mnistの1クラスだけ学習させたVariational AutoEncoderの中間層の可視化

今更だけどVariational AutoEncoder…. 相変わらずよくわからない…

でもとりあえず可視化だけしてみる練習(TensorFlowの練習もかねて). 解説記事ではないので悪しからず.

 

以下の「visualize results」にgif , 「コード」に コード起きました. TensorFlowベースです.

https://urusulambda.github.io/products/vaeae.html

 

複数種類の数字を混ぜたmnistをそのままVariational AutoEncoderに流し込んでいる解説をよく見かける.ただ、これはあまりVAEのイメージが湧きにくい気がしている.(勝手な想像だけど)恐らくそうしているのは、あのVariationalAutoEncoderならではの滑らかな数字の変化を見せたいから持ってくるのだろうと思っている.

私のような確率弱者は、ガウス分布になるように学習をさせますって言っている割にどう結果を見てもガウス分布に沿って点在しているようには見えないじゃん!って感じ、最初困惑するのではないのだろうか.

なのでここでは1つの数字だけ3だけをVAEで学習するようにして可視化します.

 

3の数字だけ学習

3だけを訓練画像にして3だけをテスト画像にし、中間層(隠れ層)の次元数をプロットしやすいように2にした.そしてプロットしたものが以下.

(https://urusulambda.github.io/products/vaeae.html に学習過程のgif)

3digits.png

この画像では確かに歪ではあるものの(0, 0)を中心にガウス分布上に広がるようになっている.

理解が正しければ、
中央の(0, 0)近辺に置かれる画像が最も頻繁に現れやすいものになっているはずで、
周りに行くほど頻度が少ない3が割り当てられるはず???

Variational AutoEncoderは画像を上記のようなどこかに割り当てるときに中央には同様な結果に変換がされやすいものをそうでない周辺には同じような結果には変換されないものがくるという認識.だから分布はガウス分布に沿ってかつ中央の方の3はありそうな3となって周辺に行くほど汚い字になっている.

確かに周辺は頻度が少ないというか歪な3が多い気はする.左上の方なんて大分にょろにょろした3ばかり並んでる.

3の数字だけ学習してそうじゃないのも入れてみる

ってことで3の画像だけで訓練することには変わりないけど、3じゃない数字も入力したときにどうなるかもみる

svae.png

うーん、微妙.

確かに3はガウス分布を保ててているし、「0」「1」「9」「7」は外側に飛ばされている.

一方で「2」や「6」が大分中心に現れてしまっている.学習回数が足りない、チューニングされていないということも原因かもしれない.

もうちょい学習させてみてどうなるかとか次元数を増やしてPCAなどでプロットしたらどうなるかなどもいずれ確認したい.

 

About the author

コメントを残す