Re:ゼロから始めるML生活

ミスよりグズを嫌え

【論文メモ:StarGAN】StarGAN: Unified Generative Adversarial Networks for Multi-Domain Image-to-Image Translation

論文

https://arxiv.org/abs/1711.09020

著者

Yunjey Choi, Minje Choi, Munyoung Kim, Jung-Woo Ha, Sunghun Kim, Jaegul Choo

背景

近年のimage-to-image translationの進歩がめざましい一方で、拡張性に乏しい問題がある。 特に、2種類以上の変換にはペアごとに独立したモデルが必要だったため、多対多の変換は難しいという問題がある。

目的とアプローチ

目的

  • image-to-image translationにおける拡張性の向上

アプローチ

  • StarGAN
    • 1つのGenerator-Discriminatorでマルチドメインに対応

f:id:nogawanogawa:20181120151835j:plain:w350

提案手法

StarGAN

StarGANの学習の概念図を下記に示す。

f:id:nogawanogawa:20181120151826j:plain:w600

StartGANのDiscriminatorは、教師画像とGeneratorが生成した偽画像の二種類を入力とし、①画像が本物かどうかと②ドメインがただしいかどうかを判定する。 Generatorは入力としてもととなる画像と変換先のドメインを表すベクトルを入力とし、該当ドメインの画像を出力する。 さらに、Generatorによって生成された画像を元画像のドメインに向かって再度Generatorで変換し、元画像と生成画像との差分も考慮する。

Loss

StarGANでの損失関数を下記に示す。

L_{D}=-L_{adv} + \lambda_{cls}L_{cls}^{r}

L_{G}=L_{adv} + \lambda_{cls}L_{cls}^{f} + \lambda_{rvc}L_{rvc}

Adversarial Loss

Discriminatorの入力が教師データか生成した画像かを判定するLossを下記に示す。

L_{adv} = \mathbb{E}_{x}[logD_{src}(x)] + \mathbb{E}_{x, c}[log(1-D_{src}(G(x, c)))]

先行研究のWGANを参考に、 Adversarial Lossを下記のように改良したものを使用する。

L_{adv} = \mathbb{E}_{x}[D_{src}(x)] - \mathbb{E}_{x, c}[ D_{src}(x, c)] - \lambda_{gp} \mathbb{E}_{\hat{x}}[(||\nabla _{\hat{x}}D_{src}(\hat{x})|| _{2}-1)^{2} ]

Domain Classification Loss

Discriminarの入力画像が正しいドメインに含まれているかを表すLossについて、教師データを使用してDiscriminatorが正しいドメインを判定するためのLossについて下記に示す。

 L_{cls}^{r} = \mathbb{E}_{x, c'}[-logD_{cls}(c'|x)]

一方で、Generatorが正しいドメインに向けて画像を生成できているかを判定するLossを下記に示す。

 L_{cls}^{f} = \mathbb{E} _{x, c'}[-logD_{cls}(c|G(x, c))]

Reconstruction Loss

入力をGeneratorによって2回変換することでもとのドメインに戻した際の入力画像との差分を表すLossを下記に示す。

 L_{rec} = \mathbb{E}_{x, c, c'}[||x - G(G(x, c), c')||]

Mask Vector

異なるデータセットを取り扱う際には、ラベルの形式も問題になる。 例えば、CelebAでは髪の色や性別のラベルが付いているが、RaFDでは"嬉しさ"や"怒り"などの感情のラベルがついており、それぞれ異なる形式のラベルとなっている。

StarGANではこのラベルの形式の不整合に対応するために、Mask Vectorを下記のように定義する。

c = [c_1, \cdots, c_n, m ]

c_iはデータセットごとに定義されたラベルを意味し、mはどのc_iを使用するかを表すone-hot vecotorとなっている。 また、使用しないラベルについては0-Fillする。

f:id:nogawanogawa:20181120222119j:plain

これによって、使用するデータセットに関するドメイン情報についてのみ学習し、別ドメインの情報を無視した学習が可能になる。

評価

Setup

  • Baseline Models
    • DIAT
    • CycleGAN
    • IcGAN
  • Dataset
    • CelebA
    • RaFD
  • ネットワーク f:id:nogawanogawa:20181120213050j:plain

f:id:nogawanogawa:20181120213101j:plain

  • Optimizer

    • Adam
      • β1 = 0.5
      • β2 = 0.999
  • data flip

    • 0.5
  • batch size
    • 16
  • learning rate
    • 0.0001 -> 0 (linear decay)

CelebA

CelebAに関する評価結果を下記に示す。

f:id:nogawanogawa:20181120173141j:plain

クロスドメイン(1,2段目)のものと比較して、提案手法が精度が高い画像が生成できている。 原因の一つとには、クロスドメインでは過学習によりドメイン変換に失敗していることが挙げられる。

マルチドメイン(3段目)と比較すると、表情等の保存ができている。

定量的に評価するために、Amazon Mechanical Turk (AMT)に関して評価する。 これは、比較対象の4種類の手法の画像をシャッフルし、それを被験者に見せてどれが最も正しく変換されているかを選んでもらった統計値を表す。

クロスドメインの値を下記に示す。

f:id:nogawanogawa:20181120173229j:plain:w400

StarGANでは他の手法よりもAMTの値が高く、先行研究より優れた結果が得られている。

マルチドメインの値を下記に示す。

f:id:nogawanogawa:20181120173238j:plain:w400

クロスドメインでは性別に関する変換がDIATと同程度であるが、マルチドメインでの評価で性別と年齢を同時に変換する際には、提案手法の複数ドメインを考慮した変換の有効性が顕著に現れている。

RaFD

次に、RaFDで表情の変換に関する評価を行う。 生成された画像を下記に示す。

f:id:nogawanogawa:20181120173153j:plain

先行研究と遜色ない画像が提案手法で実現されていることがわかる。 この時、学習データとしてDIATやCycleGANでは各ドメインごとに500枚の画像を使用したが、提案手法では合計4,000枚の画像を使用し、それでも同程度の変換が可能であることがわかる。

次に、ResNetによるドメイン分類のエラー率に関する評価結果を下記に示す。

f:id:nogawanogawa:20181120173247j:plain:w400

使用したResNetでは99.55%の分類の精度が確認されている。 提案手法はエラー率が最低になっており、先行研究より高精度の画像変換がなされていることがわかる。

CelebA+RaFD

最後に複数のデータセットを使用した際の評価を行う。

単体のデータセットで学習した場合と、複数のデータセットを使用して学習した場合の生成画像を下記に示す。

f:id:nogawanogawa:20181120173203j:plain

単体のデータセットで学習した場合のほうが、やや画像がぼやけている。これは、学習時にCelebAを使用したかどうかが影響していると考えられる。

mask vectorの妥当性について検討する。 下記にmask vectorの有無で分けた画像を示す。

f:id:nogawanogawa:20181120173215j:plain

上段が正しくmask vectorを適用した場合、下段が不適切に使用した場合を示す。 上段ではyoung、下段ではoldの特徴が付加されており、正しくmask vectorにより、別ドメインの特徴が反映されていることがわかる。

結論

マルチドメインでのImage-to-Image変換を可能にするためにStarGANを提案した。 StarGANにより、先行研究よりも優れた画質の画像を生成が可能になった。 またシンプルなマスクベクトルを使用することで、異なるラベルのデータセットに対しても提案手法が適用可能であることを示した。

感想

なんとなくのイメージは

StarGAN = DiscoGAN (or CycleGAN) + CGAN

ですね。これをLINEは何に使うんでしょうね。気になります。