Re:ゼロから始めるML生活

どちらかといえばエミリア派です

PyTorch Lightningでcross validationを書こうと思ったら失敗した話

前にこんな感じのことをつぶやいていました。

なんかいいやり方はないものかと考えつつ、一旦ここで書いてあるやり方で書いてみようと思いました。 やってみたらなんかうまく行かなかったので、多分なんかおかしいですが、自分の備忘録として残しておきます。

ざっくりコードリーディング

下記のコードを参考にしています。

github.com

dataclass

trainとvalidationのデータセットの分割はdataclassで定義しているようです。

@dataclass
class MNISTKFoldDataModule(BaseKFoldDataModule):

    train_dataset: Optional[Dataset] = None
    test_dataset: Optional[Dataset] = None
    train_fold: Optional[Dataset] = None
    val_fold: Optional[Dataset] = None

    def prepare_data(self) -> None:
        # download the data.
        MNIST(DATASETS_PATH, download=True, transform=T.Compose([T.ToTensor(), T.Normalize(mean=(0.5,), std=(0.5,))]))

    def setup(self, stage: str) -> None:
        # load the data
        dataset = MNIST(DATASETS_PATH, transform=T.Compose([T.ToTensor(), T.Normalize(mean=(0.5,), std=(0.5,))]))
        self.train_dataset, self.test_dataset = random_split(dataset, [50000, 10000])

    def setup_folds(self, num_folds: int) -> None:
        self.num_folds = num_folds
        self.splits = [split for split in KFold(num_folds).split(range(len(self.train_dataset)))]

    def setup_fold_index(self, fold_index: int) -> None:
        train_indices, val_indices = self.splits[fold_index]
        self.train_fold = Subset(self.train_dataset, train_indices)
        self.val_fold = Subset(self.train_dataset, val_indices)

    def train_dataloader(self) -> DataLoader:
        return DataLoader(self.train_fold)

    def val_dataloader(self) -> DataLoader:
        return DataLoader(self.val_fold)

    def test_dataloader(self) -> DataLoader:
        return DataLoader(self.test_dataset)

    def __post_init__(cls):
        super().__init__()

今回関係するのは、setup_foldssetup_folds_indicesですかね。

  • setup_foldsでvalidation用のデータのindexを決める
  • そのindexを使用してsetup_fold_indexでtrainとvalidationのdatasetを分割・作成

って流れを想定しているようです。

複数モデルの混ぜ合わせ

最終的な推論時に、fold毎の結果をまとめるためのクラスとして、EnsembleVotingModelが用意されています。

class EnsembleVotingModel(LightningModule):
    def __init__(self, model_cls: Type[LightningModule], checkpoint_paths: List[str]) -> None:
        super().__init__()
        # Create `num_folds` models with their associated fold weights
        self.models = torch.nn.ModuleList([model_cls.load_from_checkpoint(p) for p in checkpoint_paths])
        self.test_acc = Accuracy()

    def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        # Compute the averaged predictions over the `num_folds` models.
        logits = torch.stack([m(batch[0]) for m in self.models]).mean(0)
        loss = F.nll_loss(logits, batch[1])
        self.test_acc(logits, batch[1])
        self.log("test_acc", self.test_acc)
        self.log("test_loss", loss)

test_stepで、各モデルの推論結果の平均を取るようになっていますね。 その後はaccを計算しているだけなのであんま関係なさそう。

Loop Module

CVを切る際にはTrainerをそのままは使用できない(そのままでは書けない)っぽいのでLoop Moduleを継承して記述していますね。

ループは、Lightningの中核をなすデフォルトの勾配降下最適化ループを、上級ユーザが別の最適化パラダイムに置き換えることを可能にします。 Loops — PyTorch Lightning 1.8.6 documentation (DeepLで翻訳)

pytorch-lightning.readthedocs.io

class KFoldLoop(Loop):
    def __init__(self, num_folds: int, export_path: str) -> None:
        super().__init__()
        self.num_folds = num_folds
        self.current_fold: int = 0
        self.export_path = export_path

    @property
    def done(self) -> bool:
        return self.current_fold >= self.num_folds

    def connect(self, fit_loop: FitLoop) -> None:
        self.fit_loop = fit_loop

    def reset(self) -> None:
        """Nothing to reset in this loop."""

    def on_run_start(self, *args: Any, **kwargs: Any) -> None:
        """Used to call `setup_folds` from the `BaseKFoldDataModule` instance and store the original weights of the
        model."""
        assert isinstance(self.trainer.datamodule, BaseKFoldDataModule)
        self.trainer.datamodule.setup_folds(self.num_folds)
        self.lightning_module_state_dict = deepcopy(self.trainer.lightning_module.state_dict())

    def on_advance_start(self, *args: Any, **kwargs: Any) -> None:
        """Used to call `setup_fold_index` from the `BaseKFoldDataModule` instance."""
        print(f"STARTING FOLD {self.current_fold}")
        assert isinstance(self.trainer.datamodule, BaseKFoldDataModule)
        self.trainer.datamodule.setup_fold_index(self.current_fold)

    def advance(self, *args: Any, **kwargs: Any) -> None:
        """Used to the run a fitting and testing on the current hold."""
        self._reset_fitting()  # requires to reset the tracking stage.
        self.fit_loop.run()

        self._reset_testing()  # requires to reset the tracking stage.

        # the test loop normally expects the model to be the pure LightningModule, but since we are running the
        # test loop during fitting, we need to temporarily unpack the wrapped module
        wrapped_model = self.trainer.strategy.model
        self.trainer.strategy.model = self.trainer.strategy.lightning_module
        self.trainer.test_loop.run()
        self.trainer.strategy.model = wrapped_model
        self.current_fold += 1  # increment fold tracking number.

    def on_advance_end(self) -> None:
        """Used to save the weights of the current fold and reset the LightningModule and its optimizers."""
        self.trainer.save_checkpoint(osp.join(self.export_path, f"model.{self.current_fold}.pt"))
        # restore the original weights + optimizers and schedulers.
        self.trainer.lightning_module.load_state_dict(self.lightning_module_state_dict)
        self.trainer.strategy.setup_optimizers(self.trainer)
        self.replace(fit_loop=FitLoop)

    def on_run_end(self) -> None:
        """Used to compute the performance of the ensemble model on the test set."""
        checkpoint_paths = [osp.join(self.export_path, f"model.{f_idx + 1}.pt") for f_idx in range(self.num_folds)]
        voting_model = EnsembleVotingModel(type(self.trainer.lightning_module), checkpoint_paths)
        voting_model.trainer = self.trainer
        # This requires to connect the new model and move it the right device.
        self.trainer.strategy.connect(voting_model)
        self.trainer.strategy.model_to_device()
        self.trainer.test_loop.run()

    def on_save_checkpoint(self) -> Dict[str, int]:
        return {"current_fold": self.current_fold}

    def on_load_checkpoint(self, state_dict: Dict) -> None:
        self.current_fold = state_dict["current_fold"]

    def _reset_fitting(self) -> None:
        self.trainer.reset_train_dataloader()
        self.trainer.reset_val_dataloader()
        self.trainer.state.fn = TrainerFn.FITTING
        self.trainer.training = True

    def _reset_testing(self) -> None:
        self.trainer.reset_test_dataloader()
        self.trainer.state.fn = TrainerFn.TESTING
        self.trainer.testing = True

    def __getattr__(self, key) -> Any:
        # requires to be overridden as attributes of the wrapped loop are being accessed.
        if key not in self.__dict__:
            return getattr(self.fit_loop, key)
        return self.__dict__[key]

    def __setstate__(self, state: Dict[str, Any]) -> None:
        self.__dict__.update(state)

Loop内部のrunのメソッドはこんな感じの順番でメソッドを呼んでいくらしいです。

# class Loop:                                                                               #
#                                                                                           #
#   def run(self, ...):                                                                     #
#       self.reset(...)                                                                     #
#       self.on_run_start(...)                                                              #
#                                                                                           #
#       while not self.done:                                                                #
#           self.on_advance_start(...)                                                      #
#           self.advance(...)                                                               #
#           self.on_advance_end(...)                                                        #
#                                                                                           #
#       return self.on_run_end(...)     

ざっくり主要なメソッドの説明を書いてくと、こんな感じですかね。

  • 上に書いてあるメインのメソッド
    • reset : Loopのstateのリセット(今回は何もしていない)
    • on_run_start : 学習を開始する前準備(ここでfold毎の準備)
    • done : cvの終了判定
    • on_advance_start : Foldの学習の実行前の設定(対象のfold番号を設定してあげてる)
    • advance : Foldの学習の実行
    • self.on_advance_end : Foldの学習の終了時の処理(モデルの保存、もろもろstateをリセット)
    • on_run_end : 全Foldのモデルを使って推論
  • その他

connect は外かfit_loopを書き換えるために用意してるみたいですね。

if __name__ == "__main__":
    seed_everything(42)
    model = LitImageClassifier()
    datamodule = MNISTKFoldDataModule()
    trainer = Trainer(
        max_epochs=10,
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
        num_sanity_val_steps=0,
        devices=2,
        accelerator="cpu",
        strategy="ddp",
    )
    internal_fit_loop = trainer.fit_loop
    trainer.fit_loop = KFoldLoop(5, export_path="./")
    trainer.fit_loop.connect(internal_fit_loop)
    trainer.fit(model, datamodule)

やってみる

雰囲気はなんとなくわかった気がするので、上のサンプルコードを適当にいじって自分でもやってみようと思います。 この辺のコード書くのは久しぶりだったので、この辺見て思い出しながら書いてました。

www.nogawanogawa.com

colab.research.google.com

ログ(ここをクリック)

INFO:lightning_lite.utilities.seed:Global seed set to 42
Loaded pretrained weights for efficientnet-b7
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name      | Type               | Params
-------------------------------------------------
0 | model     | EfficientNet       | 63.9 M
1 | accuracy  | MulticlassAccuracy | 0     
2 | criterion | CrossEntropyLoss   | 0     
-------------------------------------------------
63.9 M    Trainable params
0         Non-trainable params
63.9 M    Total params
255.506   Total estimated model params size (MB)
STARTING FOLD 0
/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py:1595: PossibleUserWarning: The number of training batches (8) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
  rank_zero_warn(
Epoch 1: 100%
16/16 [00:11<00:00, 1.42it/s, loss=2.62, v_num=77]
INFO:pytorch_lightning.callbacks.early_stopping:Metric val_acc improved. New best score: 0.277
INFO:pytorch_lightning.callbacks.early_stopping:Metric val_acc improved by 0.031 >= min_delta = 0.0. New best score: 0.308
INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=2` reached.
Testing DataLoader 0: 100%
8/8 [00:01<00:00, 6.91it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc            0.23529411852359772
        test_loss           24.689242482185364
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
STARTING FOLD 1
Epoch 1: 100%
16/16 [00:11<00:00, 1.38it/s, loss=2.54, v_num=77]
INFO:pytorch_lightning.callbacks.early_stopping:Metric val_acc improved. New best score: 0.354
INFO:pytorch_lightning.callbacks.early_stopping:Metric val_acc improved by 0.031 >= min_delta = 0.0. New best score: 0.385
INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=2` reached.
Testing DataLoader 0: 100%
8/8 [00:00<00:00, 9.65it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc            0.23529411852359772
        test_loss           26.890554547309875
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
STARTING FOLD 2
Epoch 1: 100%
16/16 [00:12<00:00, 1.28it/s, loss=2.96, v_num=77]
INFO:pytorch_lightning.callbacks.early_stopping:Metric val_acc improved. New best score: 0.169
INFO:pytorch_lightning.callbacks.early_stopping:Metric val_acc improved by 0.031 >= min_delta = 0.0. New best score: 0.200
INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=2` reached.
Testing DataLoader 0: 100%
8/8 [00:00<00:00, 9.85it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc            0.11764705926179886
        test_loss           28.032548666000366
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
STARTING FOLD 3
Epoch 1: 100%
16/16 [00:11<00:00, 1.33it/s, loss=3.22, v_num=77]
INFO:pytorch_lightning.callbacks.early_stopping:Metric val_acc improved. New best score: 0.062
INFO:pytorch_lightning.callbacks.early_stopping:Metric val_acc improved by 0.092 >= min_delta = 0.0. New best score: 0.154
INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=2` reached.
Testing DataLoader 0: 100%
8/8 [00:00<00:00, 10.21it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc            0.05882352963089943
        test_loss            28.37675142288208
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
STARTING FOLD 4
Epoch 1: 100%
16/16 [00:11<00:00, 1.37it/s, loss=3.4, v_num=77]
INFO:pytorch_lightning.callbacks.early_stopping:Metric val_acc improved. New best score: 0.077
INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=2` reached.
Testing DataLoader 0: 100%
8/8 [00:00<00:00, 10.11it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc                    0.0
        test_loss           28.572951912879944
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Loaded pretrained weights for efficientnet-b7
Loaded pretrained weights for efficientnet-b7
Loaded pretrained weights for efficientnet-b7
Loaded pretrained weights for efficientnet-b7
Loaded pretrained weights for efficientnet-b7
Testing DataLoader 0: 100%
8/8 [00:02<00:00, 3.18it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc                    0.0

実行ログを見てみると、

  • accがだんだん低下している
    • 多分Optimizerかなんかがfoldごとに初期化できてない
  • 最後のAccが0
    • metrics(acc)がちゃんと計算できてない

ってのがわかったので、なんかおかしいです。 ここまでやってうまくいかなかったので、PyTorchそのままで書けって神のお告げだと思い、断念しました。

参考文献

下記の文献を参考にさせていただきました。

感想

書くだけ書いてみて、いい感じに記述できなかったので、失敗したまま放置しました。 いい感じに書ける方いましたら、誰か教えて下さい。

書き方が悪いのか、それともそもそもそれを想定していないのかはわかりませんが、個人的にはこれだったら生PyTorch書いたほうがシンプルな気がしたので、しばらくはそっちでやろうという気持ちになったのでもう良いです。