Re:ゼロから始めるML生活

どちらかといえばエミリア派です

PyTorch Metric Learningの使い方を眺める

良質なEmbeddingを作成したくなることがあって、Deep Metric Learningを試してみることにしました。 やってみたら意外と使い方にハマったので、備忘の意味で記録していこうと思います。

Deep Metric Learning

よく深層距離学習とも言われる代物です。

深層距離学習(Deep Metric Learning)とは、サンプル間の距離(metric)または類似度(similarity)に基づいてクラスごとに分離されるよう、入力データを特徴量空間への変換を学習させる手法です。 深層距離学習(Deep Metric Learning)の基礎から紹介 - OPTiM TECH BLOG

超ざっくり言うと「近くにあってほしいものを近くに、遠くに位置してほしいものを遠くにしたEmbeddingを得るための変換」を学習する技術です。

例えば、MNISTのようなラベル付きの画像があったときに、その画像には今Embeddingが与えられてない状況があったとします。 今やりたいことは、同じラベルの画像を近くに、異なるラベルの画像を遠くに配置するように画像からEmbeddingを得ることです。

これがうまくできると、その変換器を使うことで未知の画像がどのラベルに属するかを分類することができたりするので、非常に便利です。

要点

この技術、上手にできれば非常に便利なんですが、学習が非常に難しく、バカみたいにやると思ったEmbeddingを得ることができません。

この技術の要点は、サンプリング方法とLoss関数です。 よく使われるのはTripletLossとかArcFaceと呼ばれるLossで、これらを使用することで上手に学習できたりします。

PyTorch Metric Learning

PyTorchでMetric Learningを行う拡張ライブラリとしてPyTorch Metric Learningがあります。

kevinmusgrave.github.io

今回はこちらを使って簡単Metric Learningをやってみたいと思います。

Modules

PyTorch Metric Learningでは、9つの主要モジュールが提供されており、これらを使用することで効率的にMetric Learningを記述できるようになるらしいです。

https://kevinmusgrave.github.io/pytorch-metric-learning/ より引用

  • Distances
  • Losses
  • Miners
  • Reducers
  • Regularizers
  • Samplers
  • Trainers
  • Testers
  • Accuracy Calculators

詳細はドキュメントを御覧ください。

やってみる

理論の紹介とツールの紹介はだいたいこんなもんにして、早速使ってみます。

github.com

sampleコードがあって助かりますね。

MNIST using TripletMarginLoss

とりあえず使えるようにしたいので、一番上の例からやっていきます。

colab.research.google.com

主な箇所にメモを付けながら読んでいきます。

ネットワークの定義はここですね。最終的に128の長さのembeddingの出力をするらしい。

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(9216, 128)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        return x

次に、学習の定義をしています。 mining_funcという関数が使われていますが、これがMinerですね。 Minerはdataloaderで呼んできたバッチから最適なサンプルを選択するモジュールです。

mining_funcによってindices_tupleが作られています。 Lossは今回TripletMarginLossを使用しており、そこで使用されるようですね。

### MNIST code originally from https://github.com/pytorch/examples/blob/master/mnist/main.py ###
def train(model, loss_func, mining_func, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, labels) in enumerate(train_loader):
        data, labels = data.to(device), labels.to(device)
        optimizer.zero_grad()
        embeddings = model(data)
        indices_tuple = mining_func(embeddings, labels)
        loss = loss_func(embeddings, labels, indices_tuple)
        loss.backward()
        optimizer.step()
        if batch_idx % 20 == 0:
            print(
                "Epoch {} Iteration {}: Loss = {}, Number of mined triplets = {}".format(
                    epoch, batch_idx, loss, mining_func.num_triplets
                )
            )

推論時に使用する関数を定義しています。 この辺は必要になったらちゃんと読めば良さそう。

### convenient function from pytorch-metric-learning ###
def get_all_embeddings(dataset, model):
    tester = testers.BaseTester()
    return tester.get_all_embeddings(dataset, model)


### compute accuracy using AccuracyCalculator from pytorch-metric-learning ###
def test(train_set, test_set, model, accuracy_calculator):
    train_embeddings, train_labels = get_all_embeddings(train_set, model)
    test_embeddings, test_labels = get_all_embeddings(test_set, model)
    train_labels = train_labels.squeeze(1)
    test_labels = test_labels.squeeze(1)
    print("Computing accuracy")
    accuracies = accuracy_calculator.get_accuracy(
        test_embeddings, train_embeddings, test_labels, train_labels, False
    )
    print("Test set accuracy (Precision@1) = {}".format(accuracies["precision_at_1"]))

あとは使用するモジュールの呼び出しをしているくらいですかね。

### pytorch-metric-learning stuff ###
distance = distances.CosineSimilarity()
reducer = reducers.ThresholdReducer(low=0)
loss_func = losses.TripletMarginLoss(margin=0.2, distance=distance, reducer=reducer)
mining_func = miners.TripletMarginMiner(
    margin=0.2, distance=distance, type_of_triplets="semihard"
)
accuracy_calculator = AccuracyCalculator(include=("precision_at_1",), k=1)

その他は普通のPyTorchのコードと変わらない感じで書かれています。

最後に、ここで使用されているモジュールがどういう役割を担っているのかを確認します。 (そこまでわかれば自分でいじれそうなので)

TripletMarginMiner

PyTorch Metric Learningでは下記のように処理が行われていくことが想定されています。

Your Data --> Sampler --> Miner --> Loss --> Reducer --> Final loss value

Minerでは、学習に使用するデータセットの選択を行います。 このデータセットの作り方は分類の難易度で種類が分けられており、TripletMarginMinerでは

  • all
  • hard
  • semi-hard
  • easy

の4種類から選択することになります。 詳しくはこちらの記事がわかりやすいです。

tech-blog.optim.co.jp

TripletMarginLoss

TripletMarginLossでは、

  • anchor
  • positive
  • negative

の3種類のサンプルが発生します。 これらについて、

(anchor, positive)の距離(d_{ap})を近く、(anchor, negative)の距離(d_{an})を遠くにするようにLossが計算されます。

\displaystyle{
L_{triplet} = [ d_{ap} - d_{an} + m]_{+}
}

ThresholdReducer

Reducerではサンプル毎に計算されたLossを、backwardできるように1つの値に集約します。

ThresholdReducerでは、指定された範囲に収まる損失だけを平均することで外れ値の影響を除外します。

AccuracyCalculator

AccuracyCalculatorでは、k-meansとk-NNにを基づいて精度が算出されます。 このときの、計算条件を指定する役割をAccuracyCalculatorが担っているようです。

参考文献

下記の文献を参考にさせていただきました。

感想

Metric Learningが必要になったので大急ぎで勉強してみたメモ書きでした。 自分でなんか書くのはまた今度。