タイトルの論文を読んでみたので、内容に関する雑なメモです。
tl;dr;
- NLPでTransformerが成功している背景にあるのは、教師なし学習による事前学習の手法がうまくいったことが一因であると考えられる
- 一方で画像分野ではTransformerは苦戦している感じがあり、事前学習がポイントになると考えられる
- DINOはラベルなしの画像データによる事前学習の手法で、これによりラベルなしのデータによるTransformerの事前学習がうまくいって全体の性能が良くなる
論文
著者
Mathilde Caron, Hugo Touvron, Ishan Misra, Hervé Jégou, Julien Mairal, Piotr Bojanowski, Armand Joulin
背景
近年、画像認識の界隈では畳み込み演算を駆使したCNNの対抗馬としてTransformerが注目されています。
もともとTransformerはNLPの世界で誕生したもので、強力な事前学習モデルからの転移学習によって多様なタスクで高い成果を達成できることが特徴となっています。 しかし、画像分野でVision Transformerが注目されているものの、NLPの世界ほどTransformerとして目立った功績が達成できていないのが実際です。
その主な原因として、自己教師あり学習(self-supervised learning)の難しさにあると考えられます。 Transformerは一般に大量の教師データが必要になることが多い一方、ラベル付き画像データを十分な量用意することは困難であるケースも多いかと思います。 BERTではテキスト中の単語をマスクし、その単語を予測する自己教師あり学習を事前学習として大量に行うことで、高い成果を達成しています。
つまり、NLPの世界で言うMasked Language Modelingで行われているような事前学習を、画像でやったらうまくいくのでは?というのがモチベーションとしてあるみたいです。
目的・アプローチ
目的
- 画像認識の分野におけるTransformerの有効な自己教師あり学習の手法の考案
アプローチ
- DINO (self-distillation with no labels)
- ラベルのないMean Teacherの自己蒸留
DINO
特徴
DINO (self-distillation with no labels)と言っているだけあって、基本的な蒸留の考え方をラベルなしのデータに応用するという方針をとっています。 そのため、ラベル付きのデータなしに事前学習を行うことができます。
構造
DINOの構造は下記のようになっています。
教師・生徒ネットワークの二種類が使用されます。 これら2つは内部的には同じ構造のネットワークを使用しています。
Augmentation
教師と生徒のネットワークは構造は同じですが、入力される画像は異なっています。 教師ネットワークに対してはglobal(例:元画像の50%以上を使用して切り出したもの)のみ、生徒ネットワークにはglobalに加えてlocal(例:元画像の50%未満を使用して切り出したもの)を加えて入力します。
これにより、教師と生徒のネットワークで別々の学習データが入力されることになります。
学習
ネットワーク全体を使用して、教師・生徒ネットワークの出力が同じになることを目的関数とします。
ここでHはクロスエントロピーを表しています。
これとAugmentationを考慮すると、
のような形になります。(がグローバル画像)
これにより、ラベルが無くても学習を行うことができるんですね。
生徒ネットワークの出力は
のようになります。 解釈としては、クラス毎の確率の出力をexpを用いることで、分布をより際立たせたもの(sharpening)を出力となっているようです。
さらに、教師側の出力は指数移動平均を用いて更新されるようです。 教師の出力には、Cで表されるcentering項を使用して補正されていきます。
def H(t, s): t = t.detach() # stop gradient s = softmax(s / tps, dim=1) t = softmax((t - C) / tpt, dim=1) # center + sharpen return - (t * log(s)).sum(dim=1).mean()
このCは下記のようになっています。
これは出力が1つだけ高い値を取り続けるといったようにモデルが学習してしまうのを避けるために適応されます。
このような工夫により、Augmentationと合わせて教師と生徒で異なる出力がなされることになり、そのクロスエントロピーによってロスを計算して学習していきます。
学習は生徒は通常通りバックプロパゲーションによって更新されますが、教師側の出力についてはセウトのパラメータから一定割合の荷重がかけられてネットワークが更新されていきます。
gt.params = l*gt.params + (1-l)*gs.params
こんな感じに、教師側のネットワークは生徒側のネットワークによって更新されるようになっています。
評価
例の如く、評価は駆け足で確認。
他のSSLとの比較
上半分では同じネットワークを使用して、事前学習だけ変えた手法を比較しています。 先行研究とResNet50で同程度、ViTを使用するとLinear classifier、K-NNどちらも先行研究より上回ります。
また、下半分より、一般的にはパラメータの数を大きくしたときに高い性能が出ますが、DINOではパラメータの数が少ないもでるでも高い性能が達成できています。
画像検索
Oxford and Paris image retrieval datasetsを使用して検索性能を評価すると、Mean Average Precision(mAP)で確認したそうです。 すると、教師あり学習より、DINOで学習したときのほうがk-NNでの検索性能が良かったそうです。
コピー検出
出力のベクトルに対してコサイン類似度を取ることで、コピー検出を行った結果、これも先行研究より優位に判定ができたとのことです。
Attention
ちゃんとアテンションができているように見受けられ、教師ありで学習したときより正しくアテンションできているようになっています。
転移学習
普通に教師あり学習したときより、DINOを使ったモデルから転移学習したほうが精度が上がっています。
参考文献
この記事を書くにあたって下記を参考にさせていただきました。
使う際にはこちらにライブラリがあったりします。
こんな感じでつかえるんですね…
import torch from vit_pytorch import ViT, Dino model = ViT( image_size = 256, patch_size = 32, num_classes = 1000, dim = 1024, depth = 6, heads = 8, mlp_dim = 2048 ) learner = Dino( model, image_size = 256, hidden_layer = 'to_latent', # hidden layer name or index, from which to extract the embedding projection_hidden_size = 256, # projector network hidden dimension projection_layers = 4, # number of layers in projection network num_classes_K = 65336, # output logits dimensions (referenced as K in paper) student_temp = 0.9, # student temperature teacher_temp = 0.04, # teacher temperature, needs to be annealed from 0.04 to 0.07 over 30 epochs local_upper_crop_scale = 0.4, # upper bound for local crop - 0.4 was recommended in the paper global_lower_crop_scale = 0.5, # lower bound for global crop - 0.5 was recommended in the paper moving_average_decay = 0.9, # moving average of encoder - paper showed anywhere from 0.9 to 0.999 was ok center_moving_average_decay = 0.9, # moving average of teacher centers - paper showed anywhere from 0.9 to 0.999 was ok ) opt = torch.optim.Adam(learner.parameters(), lr = 3e-4) def sample_unlabelled_images(): return torch.randn(20, 3, 256, 256) for _ in range(100): images = sample_unlabelled_images() loss = learner(images) opt.zero_grad() loss.backward() opt.step() learner.update_moving_average() # update moving average of teacher encoder and teacher centers # save your improved network torch.save(model.state_dict(), './pretrained-net.pt')
感想
atmaCupで使われていたので、勉強してみた次第です。
それにしても、すごいですね。 学習済みモデルを使えないような状況であったり、画像検索などでEmbeddingを使用する際にも有効そうですね。
こんな機会じゃないと絶対やらないところなので、良い勉強になりました。