SAGAN(Self Attention Generative Adversarial Network)のSelf Attention機構をざっくり理解した

GANの父であるIan Goodfellowさんが関係しているSelf Attention Generative Adversarial Network(SAGAN)の論文をいつも通り斜め読みした.

SAGAN(Self Attention Generative Adversarial Network)の論文はこちらです.

CVPR2018にて発表されている論文で、CVPR2018で行われた「Intruduction to GAN」においても自身のスライドで触れている.

今回の論文のポイントはざっくり以下らしい.

  • Spectral Normalizationを使用し、DiscriminatorとGenerator双方に適用し学習を安定化
  • Convolutionだけでは局所的なところに注目しがちだが、Self Attentionを導入して画像全体の処理状態を考慮できるようにした.

上記の工夫を入れることで、ACGANやSNGANと比較をしている.

行っているタスクはImageNetの画像を使用して学習して、いわゆるクラスを指定して画像を生成する系のタスク.例えば「サメ」という入力を入れてサメの画像を生成するようなタスク.よくあるやつですね.

これまでの簡単なまとめですが、生成含めた生成画像結果に関しては以下の画像でまとめらています.

スクリーンショット 2018-07-15 0.22.42.png

ian goodefellowの先のスライドから上記の画像を引用しました.

さて、話がそれてしまいましたが、このSAGANにおいて使用されているSelf Attention機構に関してです.

まったく私はAttention機構周りについて知らなかったため、まず機械翻訳系のAttentionについて軽く勉強しました.

語弊を多分に含むと思うので詳しくは調べていただきたいですが、簡単にAttention機構を説明すると「今回の入力を元に得た自分の隠れ層の状態を利用してどういったところに注目して処理すべきか」を識別中に算出している機構を指しています.翻訳の場合は今”ある単語”を生成するときにどこの単語に注目すべきかを全体を見ながら決めます.

Self Attention機構について

論文中では犬の足などといった部分を綺麗に生成するにはConvolutionのような比較的局所的なことに注目したネットワークでは難しいため全体の状況を見ながら決めるべきといったようなことが書かれています.

Convolutionでは局所的な部分を見るのに対して、Self Attention機構は大域的な情報を参照しているとのことです.

コード

とりあえずまずコードを貼っておきます.コードを見たら実際に行っている計算が明白なので.

https://github.com/heykeetae/Self-Attention-GAN/blob/8714a54ba5027d680190791ba3a6bb08f9c9a129/sagan_models.py#L29-L37

        m_batchsize,C,width ,height = x.size()
        proj_query  = self.query_conv(x).view(m_batchsize,-1,width*height).permute(0,2,1) # B X CX(N)
        proj_key =  self.key_conv(x).view(m_batchsize,-1,width*height) # B X C x (*W*H)
        energy =  torch.bmm(proj_query,proj_key) # transpose check
        attention = self.softmax(energy) # BX (N) X (N) 
        proj_value = self.value_conv(x).view(m_batchsize,-1,width*height) # B X C X N


        out = torch.bmm(proj_value,attention.permute(0,2,1) )
        out = out.view(m_batchsize,C,width,height)

Self Attention機構使う場所

先にSelf Attentionがネットワークのどのような形で使われているかを紹介します.

self attentionの特徴マップoが作られたらself attentionの入力になったxとγをパラメータとして足し合わせるだけです.論文中はいかに書かれています.

スクリーンショット 2018-07-15 0.49.29.png

簡単ですが念のために図で描けば

スクリーンショット 2018-07-15 0.58.32.png

ということで簡単.

あとはSelf Attention機構の中身.

Self Attention機構中身

それでは論文中に出てくるSelf Attention機構の説明図をまず紹介します.

スクリーンショット 2018-07-08 23.48.28

ここで行いたい作戦はこうです.
ある要素に注目した時に他に注目すべきところはどこかを計算するため、色やテクスチャが似ている画素/領域を見つけ出し強調具合を決めAttention Mapを作成します.
それを実現するにあたり、
1. 二つの「クエリf(x)」と「キーg(x)」を用いて各画素が画像全体のどの画素と似ているかをすべて計算します.
2. どの画素とどの画素が似ているか計算できたら、重み(h(x))を掛け合わせてSelf Attention Map(o)を作成します.
3.self attention mapを xに足し合わせます.

まず大前提として、

x : C(入力チャンネル数C)xN(self attentionに入力されるすべての値.行列になっている場合はFlatにする)の入力. x_iと書かれている時は画像全体を1次元の配列としたときのi番目の要素.

スクリーンショット 2018-07-15 0.50.59.png

f(x), g(x), h(x): C・C~ のサイズの学習するパラメータを持っていて、C・Nのサイズの行列を入力にC~ ・ Nのサイズの行列を出力する関数.

スクリーンショット 2018-07-15 0.49.10.png

また絵にすれば以下のようになります.

スクリーンショット 2018-07-15 1.10.21.png

上記のf(x)とg(x)を使って、どの画素とどの画素が似ているのかを計算します.論文でC~をC/8にしているようです.

スクリーンショット 2018-07-15 0.49.10.png

再掲した上記の式はf(x)画素i番目のチャンネル方向のベクトル列とg(x)画素j番目のチャンネル方向のベクトル列の内積をとって類似度を計算. それをi要素で足して1になるように正規化しています.スクリーンショット 2018-07-15 1.24.20.png

上記の計算によって以下のように全画素に対応したそれぞれのAttention Mapが生成されます.

スクリーンショット 2018-07-15 1.26.14.png

赤い点に注目したときのAttention Map、緑の点に注目したときのAttention Map等が可視化されているのが確認できます.色的に近い領域に強く反応しています.

最後にattention mapに重みをかけ合わせて、Self Attention Mapを生成します.このSelf Attention機構を導入した場合、f(x), g(x), h(x)のパラメータを学習しることになります.

スクリーンショット 2018-07-15 1.30.55.png

思った以上にシンプルな仕組みです.上記のoをはじめに話した y = x + γo に代入します.

おわりに

要は、全画素で似たようなチャンネルのベクトルを持つ画素を他でも探しているだけです.ただかなり力技ですね.まぁ性能あがっているようなのでいいですが….以上、色やテクスチャが似ているところが強く反応する画像版Self Attentionの斜め読みでした.

About the author

Comments

  1. ものすごい参考になりました。特に論文中の図の 1×1 conv がWと対応しているのが最初分からなかったので救われました。

    ちなみに、実は記事の最後の図の h(x) のshapeは(N, C~)で、βと掛け合わせた結果 (N, C~)に、そしてそれを(C~, C)のshapeであるW_vと掛け合わせることで o になり、これがxと同じshapeである (N, C) になるのではと思ったのですがいかがでしょうか。

コメントを残す