[翻訳] Vision Transformerの視覚的解説 (A Visual Guide to Vision Transformers)
(discuss.pytorch.kr)-
ℹ️ xguruさんが紹介してくださったVisual Transformersビジュアルガイドの記事を読み、著者であり Data Scientist かつ Software Engineer であるDennis Turpの許可のもと、彼が執筆したVision Transformer(ViT)の視覚的解説(A Visual Guide to Vision Transformers)を翻訳しました。
-
Vision Transformer(ViT)は、CV(Computer Vision)分野にTransformerを適用し、物体検出や画像分類などの分野で優れた性能を示すモデルです。特に、画像から特徴(feature)を抽出するVisual Encoderとして多く使われています。
-
原文の説明は簡潔なため、理解が難しい場合に備えて、理解を助けるための注釈を一部追加しました。
Vision Transformer(ViT)の視覚的解説
この記事は、画像分類タスクで最先端(SotA, State-of-the-Art)の性能を示すディープラーニングモデル、Vision Transformers(ViTs)についての視覚的な解説です。Vision Transformerは、もともと自然言語処理(NLP)向けに設計されたTransformerアーキテクチャを画像データに適用したものです。この記事では、スクロールしながらデータの流れを理解しやすくする可視化と簡単な説明を通じて、Vision Transformerの動作方式を理解できるようにします。 (:pytorch::kr:: ここではスクロールによる説明が難しいため、画像キャプチャで代用します。原文もあわせて参照するとよいでしょう。)
> This is a visual guide to Vision Transformers (ViTs), a class of deep learning models that have achieved state-of-the-art performance on image classification tasks. Vision Transformers apply the transformer architecture, originally designed for natural language processing (NLP), to image data. This guide will walk you through the key components of Vision Transformers in a scroll story format, using visualizations and simple explanations to help you understand how these models work and how the flow of the data through the model looks like.
0. データを見てみる / Lets start with the data
一般的な畳み込みニューラルネットワーク(CNN)と同様に、Vision Transformerも教師あり学習(Supervised Learning)方式で学習します。つまり、画像とそれに対応するラベル(label)で構成されたデータセットでモデルを学習します。
> Like normal convolutional neural networks, vision transformers are trained in a supervised manner. This means that the model is trained on a dataset of images and their corresponding labels.
1. 1つのデータだけに注目する / Focus on one data point
Vision Transformerが内部でどのように動作するのかを理解するため、まずは1つのデータ(バッチサイズ1)だけに注目してみましょう。そして次の問いを考えてみます。Transformerにこのデータを入力するには、どのように準備(前処理)すればよいのでしょうか。
> To get a better understanding of what happens inside a vision transformer lets focus on a single data point (batch size of 1). And lets ask the question: How is this data point prepared in order to be consumed by a transformer?
2. ラベルはいったん脇に置く / Forget the label for the moment
ラベルについては後でもっと関係の深い形で見ていきます。今は画像1枚だけを残して見ていきましょう。
> The label will become more relevant later. For now the only thing that we are left with is a single image.
3. 画像をパッチに分割する / Create patches of the image
画像全体を同じサイズのパッチ(p x p)画像に分割し、Transformer内部で使えるように準備します。
> To prepare the image for the use inside the transformer we divide the image into equally sized patches of size p x p.
4. 画像パッチを平坦化する / Flatting of the image patches
パッチを p' = p² x c サイズのベクトルへ平坦化(flatten)します。ここで p はパッチの一辺の長さ、c はチャネル数です。 (:pytorch::kr:: たとえばRGB画像の場合、チャネル数は3です。)
> The patches are now flattened into vectors of dimension p'= p²*c where p is the size of the patch and c is the number of channels.
5. パッチから埋め込みを作る / Creating patch embeddings
前段で画像パッチから作成したベクトルを、線形変換によってエンコードします。こうして作られたパッチ埋め込みベクトル(Patch Embedding Vector) は、固定サイズ d を持ちます。
> These image patch vectors are now encoded using a linear transformation. The resulting Patch Embedding Vector has a fixed size d.
6. すべてのパッチを埋め込む / Embedding all patches
画像パッチをすべて固定サイズのベクトルに埋め込むと、n x d サイズの配列が得られます。ここで n は画像パッチの数、d は1つのパッチ埋め込みのサイズです。
> Now that we have embedded our image patches into vectors of fixed size, we are left with an array of size n x d where n is the the number of image patches and d is the size of the patch embedding
7. 分類トークン(CLS)を追加する / Appending a classification token
モデルを効果的に学習させるために、パッチ埋め込みに加えて分類トークン(CLS token)と呼ばれるベクトルを追加します。このベクトルはニューラルネットワークを通じて学習可能なパラメータであり、ランダムに初期化されます。なお、CLSトークンは1つだけで、すべてのデータに同じベクトルを追加します。(:pytorch::kr:: ここまで行うと、n個のパッチ埋め込みにCLSトークンを加えて (n+1) 個となり、各埋め込みサイズ d に対して (n+1) x d を持つことになります。)
> In order for us to effectively train our model we extend the array of patch embeddings by an additional vector called classification token (cls token). This vector is a learnable parameter of the network and is randomly initialized. Note: We only have one cls token and we append the same vector for all data points.
8. 位置埋め込みベクトルを追加する / Add positional embedding Vectors
これまでの パッチ埋め込み には個別の位置情報がありません。すべてのパッチ埋め込みに、学習可能でランダムに初期化された 位置埋め込みベクトル(Positional Embedding Vector) を加えることでこの問題を解決します。また、先ほど追加した 分類トークン(CLS token) にもこの位置ベクトルを追加します。(:pytorch::kr:: Transformerでは Positional Encoding の値を「加え」ます。したがってベクトルの大きさは変化しません。)
> Currently our patch embeddings have no positional information associated with them. We remedy that by adding a learnable randomly initialized positional embedding vector to all our patch embeddings. We also add a such a positional embedding vector to our classification token.
9. Transformerに入力する / Transformer Input
位置埋め込みベクトルを追加すると、(n+1) x d サイズの配列が残ります。この配列をTransformerへの入力として与え、これについては次のステップでさらに詳しく説明します。
> After the positional embedding vectors have been added we are left with an array of size (n+1) x d. This will be our input for the transformer which will be explained in greater detail in the next steps.
10.1. Transformer: QKVを作る / QKV Creation
Transformer入力のパッチ埋め込みベクトルは、複数の大きなベクトルへ線形に埋め込まれます。これらの新しいベクトルは、同じサイズの3つの部分に分割されます。つまり、それぞれ Q はクエリ(Query)ベクトル、K はキー(Key)ベクトル、V は値(Value)ベクトル です。これらのベクトルはそれぞれ (n+1) 個ずつ得られます。
> Our transformer input patch embedding vectors are linearly embedded into multiple large vectors. These new vectors are than separated into three equal sized parts. The Q - Query Vector, the K - Key Vector and the V - Value Vector . We will have (n+1) of a all of those vectors.
10.2. Transformer: アテンションスコアを計算する / Attention Score Calculation
まずアテンションスコア A を計算するために、すべてのクエリベクトル Q にすべてのキーベクトル K を掛けます。
> To calculate our attention scores A we will now multiply all of our query vectors Q with all of our key vectors K.
10.3. Transformer: アテンションスコア行列 / Attention Score Matrix
こうして得られたアテンションスコア行列 A の各行の合計が 1 になるように、すべての行に softmax 関数を適用します。
> Now that we have the attention score matrix A we apply a softmax function to every row such that every row sums up to 1.
10.4. Transformer: 集約されたコンテキスト情報を計算する / Aggregated Contextual Information Calculation
最初のパッチ埋め込みベクトルに対する 集約されたコンテキスト情報(aggregated contextual information) を計算するために、アテンション行列の 1行目 に注目します。ここで 値ベクトル V に対する重みとしてその要素を使用し、最初の画像パッチ埋め込みに対する 集約されたコンテキスト情報ベクトル(aggregated vector) を生成します。
> To calculate the aggregated contextual information for the first patch embedding vector. We focus on the first row of the attention matrix. And use the entires as weights for our Value Vectors V. The result is our aggregated contextual information vector for the first image patch embedding.
10.5. Transformer: すべてのパッチについて集約されたコンテキスト情報を求める / Aggregated Contextual Information for every patch
アテンションスコア行列の他の行についても同じ処理を繰り返し、N+1 個の集約されたコンテキスト情報ベクトルを得ます。つまり、すべてのパッチごとに1つ (=N個) + 分類トークン(CLS Token)に対して1つ (=1) です。ここまでで最初のアテンションヘッド(Attention Head)が完了します。
> Now we repeat this process for every row of our attention score matrix and the result will be N+1 aggregated contextual information vectors. One for every patch + one for the classification token. This steps concludes our first Attention Head.
10.6. Transformer: マルチヘッドアテンション / Multi-Head Attention
(Transformer の) マルチヘッドアテンションを扱っているため、別の QKV に対して 10.1 から 10.5 まで の全プロセスを繰り返します。上の図では説明のために2つのヘッド בלבדを仮定していますが、一般的に ViT はさらに多くのヘッドを持ちます。最終的に、複数の集約されたコンテキスト情報ベクトル(Multiple Aggregated Contextual Information Vectors)が得られます。
> Now because we are dealing multi head attention we repeat the entire process from step 10.1 - 10-5 again with a different QKV mapping. For our explanatory setup we assume 2 Heads but typically a VIT has many more. In the end this results in multiple Aggregated contextual information vectors.
10.7. Transformer: 最後のアテンション層のステップ / Last Attention Layer Step
このようにして生成した複数のヘッドを積み重ねたあと、パッチ埋め込みと同じサイズの d 次元ベクトルにマッピングします。
> These heads are stacked together and are mapped to vectors of size d which was the same size as our patch embeddings had.
10.8. Transformer: アテンション層の結果を得る / Attention Layer Result
これで前のステップからアテンション層が完成し、入力時に使っていたものと まったく同じサイズ の埋め込みが得られました。
> The previous step concluded the attention layer and we are left with the same amount of embeddings of exactly the same size as we used as input.
10.9. Transformer: 残差接続を行う / Residual connections
Transformer では 残差接続(Residual Connection) が多用されます。これは単に、前の層の入力を現在の層の出力に加えることを意味します。ここでも残差接続を行います。
> Transformers make heavy use of residual connections which simply means adding the input of the previous layer to the output the current layer. This is also something that we will do now.
10.10. Transformer: 残差接続の結果を得る / Residual connection Result
このような残差接続により、(同じサイズ d のベクトル同士を加えることで)同じサイズのベクトルが生成されます。
> The addition results in vectors of the same size.
10.11. Transformer: フィードフォワードネットワークに通す / Feed Forward Network
ここまでの結果(output)を、非線形活性化関数を持つフィードフォワード人工ニューラルネットワークに通します。
> Now these outputs are feed through a feed forward neural network with non linear activation functions
10.12. Transformer: 最終結果を得る / Final Result
Transformer にはここまでの演算のあとに、さらに別の残差接続がありますが、ここでは説明を簡潔にするため省略し、Transformer レイヤーの演算を締めくくります。最終的に Transformer は入力と同じサイズの出力を生成します。
> After the transformer step there is another residual connections which we will skip here for brevity. And so the last step concluded the transformer layer. In the end the transformer produced outputs of the same size as input.
11. Transformer の演算を繰り返す / Repeat Transformers
ここまで進めた 10.1 から 10.12 まで の Transformer の一連の演算全体を、複数回繰り返します。ここでは 6 回を例にしています。
> Repeat the entire transformer calculation Steps 10.1 - Steps 10.12 for the Transformer several times e.g. 6 times.
12. 分類トークンの出力を確認する / Identify Classification token output
最後のステップは、分類トークン(CLS token)の出力を確認することです。このベクトルは、Vision Transformer の最後の段階で使われます。
> Last step is to identify the classification token output. This vector will be used in the final step of our Vision Transformer journey.
13. 最終ステップ: 分類確率を予測する / Final Step: Predicting classification probabilities
最終段階では、この分類出力トークンを別の全結合(fully-connected)人工ニューラルネットワークに通し、入力画像の分類確率(classification probabilities)を予測します。
> In the final and last step we use this classification output token and another fully connected neural network to predict the classification probabilities of our input image.
14. Vision Transformer を学習する / Training of the Vision Transformer
前の段階で予測した分類確率(class probabilities)と正解ラベル(true class label)を比較する標準的なクロスエントロピー損失関数(Cross-Entropy Loss Function)を使って、Vision Transformer を学習します。モデルは、逆伝播(backpropagation)と勾配降下法(gradient descent)を用いて損失関数を最小化する方向にパラメータを更新しながら学習します。
> We train the Vision Transformer using a standard cross-entropy loss function, which compares the predicted class probabilities with the true class labels. The model is trained using backpropagation and gradient descent, updating the model parameters to minimize the loss function.
結論 / Conclusion
ここまで、視覚的な解説を通じて、データ準備からモデル学習まで Vision Transformer の主要な構成要素を見てきました。この解説が、Vision Transformer がどのように動作し、画像分類にどのように使われるのかを理解する助けになれば幸いです。
> In this visual guide, we have walked through the key components of Vision Transformers, from the data preparation to the training of the model. We hope this guide has helped you understand how Vision Transformers work and how they can be used to classify images.
Vision Transformer をさらによく理解できるように、小さな Colab Notebook も用意しました。'Blogpost' のコメントもご覧ください。このコードは @lucidrains による優れた VIT Pytorch implementation から取られたものなので、彼の仕事もぜひ確認してみてください。
> I prepared this little Colab Notebook to help you understand the Vision Transformer even better. Please have look for the 'Blogpost' comment. The code was taken from @lucidrains great VIT Pytorch implementation be sure to checkout his work.
ご質問やフィードバックがありましたら、いつでもお気軽にご連絡ください。お読みいただきありがとうございました! (著者の GitHub、X(Twitter)、Threads、LinkedIn)
> If you have any questions or feedback, please feel free to reach out to me. Thank you for reading!
謝辞 / Acknowledgements
- @lucidrains の VIT PyTorch 実装
- すべての画像は Wikipedia から取得しており、CC BY-SA 4.0 ライセンス に基づいて利用を許諾されています。
> * VIT Pytorch implementation
> * All images have been taken from Wikipedia and are licensed under the Creative Commons Attribution-Share Alike 4.0 International license.
さらに読む
原文
https://blog.mdturp.ch/posts/…
要約記事
https://ja.news.hada.io/topic?id=14370
Vision Transformer 論文
https://arxiv.org/abs/2010.11929v2
PR12 の Vision Transformer 論文レビュー動画
https://www.youtube.com/watch?v=D72_Cn-XV1g
Google Research の Vision Transformer リポジトリ
https://github.com/google-research/vision_transformer
PapersWithCode にまとまっている Vision Transformer 関連の論文やコードなど
https://paperswithcode.com/method/vision-transformer
⚠️広告⚠️: :pytorch:PyTorch 日本ユーザーコミュニティ がまとめたこの記事は役に立ちましたか? 会員登録すると主要な記事をメールでお届けします! (基本は Weekly ですが、Daily への変更も可能です。)
1件のコメント
役立つ資料を作成してくださってありがとうございます。^