Re:ゼロから始めるML生活

どちらかといえばエミリア派です

【論文メモ:DINO】Emerging Properties in Self-Supervised Vision Transformers

タイトルの論文を読んでみたので、内容に関する雑なメモです。

tl;dr;

  • NLPでTransformerが成功している背景にあるのは、教師なし学習による事前学習の手法がうまくいったことが一因であると考えられる
  • 一方で画像分野ではTransformerは苦戦している感じがあり、事前学習がポイントになると考えられる
  • DINOはラベルなしの画像データによる事前学習の手法で、これによりラベルなしのデータによるTransformerの事前学習がうまくいって全体の性能が良くなる

論文

arxiv.org

著者

Mathilde Caron, Hugo Touvron, Ishan Misra, Hervé Jégou, Julien Mairal, Piotr Bojanowski, Armand Joulin

背景

近年、画像認識の界隈では畳み込み演算を駆使したCNNの対抗馬としてTransformerが注目されています。

もともとTransformerはNLPの世界で誕生したもので、強力な事前学習モデルからの転移学習によって多様なタスクで高い成果を達成できることが特徴となっています。 しかし、画像分野でVision Transformerが注目されているものの、NLPの世界ほどTransformerとして目立った功績が達成できていないのが実際です。

その主な原因として、自己教師あり学習(self-supervised learning)の難しさにあると考えられます。 Transformerは一般に大量の教師データが必要になることが多い一方、ラベル付き画像データを十分な量用意することは困難であるケースも多いかと思います。 BERTではテキスト中の単語をマスクし、その単語を予測する自己教師あり学習を事前学習として大量に行うことで、高い成果を達成しています。

つまり、NLPの世界で言うMasked Language Modelingで行われているような事前学習を、画像でやったらうまくいくのでは?というのがモチベーションとしてあるみたいです。

目的・アプローチ

目的

  • 画像認識の分野におけるTransformerの有効な自己教師あり学習の手法の考案

アプローチ

  • DINO (self-distillation with no labels)
    • ラベルのないMean Teacherの自己蒸留

DINO

特徴

DINO (self-distillation with no labels)と言っているだけあって、基本的な蒸留の考え方をラベルなしのデータに応用するという方針をとっています。 そのため、ラベル付きのデータなしに事前学習を行うことができます。

構造

DINOの構造は下記のようになっています。

f:id:nogawanogawa:20210801222801j:plain:w400
Emerging Properties in Self-Supervised Vision Transformersより

教師・生徒ネットワークの二種類が使用されます。 これら2つは内部的には同じ構造のネットワークを使用しています。

Augmentation

教師と生徒のネットワークは構造は同じですが、入力される画像は異なっています。 教師ネットワークに対してはglobal(例:元画像の50%以上を使用して切り出したもの)のみ、生徒ネットワークにはglobalに加えてlocal(例:元画像の50%未満を使用して切り出したもの)を加えて入力します。

これにより、教師と生徒のネットワークで別々の学習データが入力されることになります。

学習

ネットワーク全体を使用して、教師・生徒ネットワークの出力が同じになることを目的関数とします。


\displaystyle{
min_{\theta_{s}} H (P_t(x), P_s(x))
}

ここでHはクロスエントロピーを表しています。

これとAugmentationを考慮すると、


\displaystyle{
min_{\theta_{s}} \sum_{x \in {x^{g}_{1}, x^{g}_{2}}}  \sum_{x' \in V,  x' \neq x}   H (P_t(x), P_s(x'))
}

のような形になります。(x^{g}がグローバル画像) これにより、ラベルが無くても学習を行うことができるんですね。

生徒ネットワークの出力は


\displaystyle{
P_s(x)^{(i)} = \frac{exp(g_{\theta s} (x) ^{(i)} / \tau _{s})}{\sum^{K}_{k=1} exp(g_{\theta s} (x) ^{(k)} / \tau _{s})}
}

のようになります。 解釈としては、クラス毎の確率の出力をexpを用いることで、分布をより際立たせたもの(sharpening)を出力となっているようです。

さらに、教師側の出力は指数移動平均を用いて更新されるようです。 教師の出力には、Cで表されるcentering項を使用して補正されていきます。

def H(t, s):
    t = t.detach() # stop gradient
    s = softmax(s / tps, dim=1)
    t = softmax((t - C) / tpt, dim=1) # center + sharpen
    return - (t * log(s)).sum(dim=1).mean()

このCは下記のようになっています。


\displaystyle{
c \leftarrow mc + (1-m) \frac{1}{B} \sum^{B}_{i=1} g_{\theta t}(x_i)
}

これは出力が1つだけ高い値を取り続けるといったようにモデルが学習してしまうのを避けるために適応されます。

このような工夫により、Augmentationと合わせて教師と生徒で異なる出力がなされることになり、そのクロスエントロピーによってロスを計算して学習していきます。

学習は生徒は通常通りバックプロパゲーションによって更新されますが、教師側の出力についてはセウトのパラメータから一定割合の荷重がかけられてネットワークが更新されていきます。

gt.params = l*gt.params + (1-l)*gs.params

こんな感じに、教師側のネットワークは生徒側のネットワークによって更新されるようになっています。

評価

例の如く、評価は駆け足で確認。

他のSSLとの比較

f:id:nogawanogawa:20210804231755j:plain
Emerging Properties in Self-Supervised Vision Transformersより

上半分では同じネットワークを使用して、事前学習だけ変えた手法を比較しています。 先行研究とResNet50で同程度、ViTを使用するとLinear classifier、K-NNどちらも先行研究より上回ります。

また、下半分より、一般的にはパラメータの数を大きくしたときに高い性能が出ますが、DINOではパラメータの数が少ないもでるでも高い性能が達成できています。

画像検索

f:id:nogawanogawa:20210804231829j:plain
Emerging Properties in Self-Supervised Vision Transformersより

Oxford and Paris image retrieval datasetsを使用して検索性能を評価すると、Mean Average Precision(mAP)で確認したそうです。 すると、教師あり学習より、DINOで学習したときのほうがk-NNでの検索性能が良かったそうです。

コピー検出

f:id:nogawanogawa:20210804231918j:plain
Emerging Properties in Self-Supervised Vision Transformersより

出力のベクトルに対してコサイン類似度を取ることで、コピー検出を行った結果、これも先行研究より優位に判定ができたとのことです。

Attention

f:id:nogawanogawa:20210804232502j:plainf:id:nogawanogawa:20210804232505j:plain
Emerging Properties in Self-Supervised Vision Transformersより

ちゃんとアテンションができているように見受けられ、教師ありで学習したときより正しくアテンションできているようになっています。

転移学習

f:id:nogawanogawa:20210804232532j:plain
Emerging Properties in Self-Supervised Vision Transformersより

普通に教師あり学習したときより、DINOを使ったモデルから転移学習したほうが精度が上がっています。

参考文献

この記事を書くにあたって下記を参考にさせていただきました。

oumpy.github.io

www.youtube.com

使う際にはこちらにライブラリがあったりします。

github.com

こんな感じでつかえるんですね…

import torch
from vit_pytorch import ViT, Dino

model = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

learner = Dino(
    model,
    image_size = 256,
    hidden_layer = 'to_latent',        # hidden layer name or index, from which to extract the embedding
    projection_hidden_size = 256,      # projector network hidden dimension
    projection_layers = 4,             # number of layers in projection network
    num_classes_K = 65336,             # output logits dimensions (referenced as K in paper)
    student_temp = 0.9,                # student temperature
    teacher_temp = 0.04,               # teacher temperature, needs to be annealed from 0.04 to 0.07 over 30 epochs
    local_upper_crop_scale = 0.4,      # upper bound for local crop - 0.4 was recommended in the paper 
    global_lower_crop_scale = 0.5,     # lower bound for global crop - 0.5 was recommended in the paper
    moving_average_decay = 0.9,        # moving average of encoder - paper showed anywhere from 0.9 to 0.999 was ok
    center_moving_average_decay = 0.9, # moving average of teacher centers - paper showed anywhere from 0.9 to 0.999 was ok
)

opt = torch.optim.Adam(learner.parameters(), lr = 3e-4)

def sample_unlabelled_images():
    return torch.randn(20, 3, 256, 256)

for _ in range(100):
    images = sample_unlabelled_images()
    loss = learner(images)
    opt.zero_grad()
    loss.backward()
    opt.step()
    learner.update_moving_average() # update moving average of teacher encoder and teacher centers

# save your improved network
torch.save(model.state_dict(), './pretrained-net.pt')

感想

atmaCupで使われていたので、勉強してみた次第です。

それにしても、すごいですね。 学習済みモデルを使えないような状況であったり、画像検索などでEmbeddingを使用する際にも有効そうですね。

こんな機会じゃないと絶対やらないところなので、良い勉強になりました。