前にこんな感じのことをつぶやいていました。
pytorch lightningでcross validation書くときって、こんな感じに書かなきゃダメなの?もっとシンプルに書けるもんなのか?(全然わかってない顔)https://t.co/R3OzpZC0lq
— 野川の側 (@nogawanogawa) 2022年11月13日
なんかいいやり方はないものかと考えつつ、一旦ここで書いてあるやり方で書いてみようと思いました。 やってみたらなんかうまく行かなかったので、多分なんかおかしいですが、自分の備忘録として残しておきます。
ざっくりコードリーディング
下記のコードを参考にしています。
dataclass
trainとvalidationのデータセットの分割はdataclassで定義しているようです。
@dataclass class MNISTKFoldDataModule(BaseKFoldDataModule): train_dataset: Optional[Dataset] = None test_dataset: Optional[Dataset] = None train_fold: Optional[Dataset] = None val_fold: Optional[Dataset] = None def prepare_data(self) -> None: # download the data. MNIST(DATASETS_PATH, download=True, transform=T.Compose([T.ToTensor(), T.Normalize(mean=(0.5,), std=(0.5,))])) def setup(self, stage: str) -> None: # load the data dataset = MNIST(DATASETS_PATH, transform=T.Compose([T.ToTensor(), T.Normalize(mean=(0.5,), std=(0.5,))])) self.train_dataset, self.test_dataset = random_split(dataset, [50000, 10000]) def setup_folds(self, num_folds: int) -> None: self.num_folds = num_folds self.splits = [split for split in KFold(num_folds).split(range(len(self.train_dataset)))] def setup_fold_index(self, fold_index: int) -> None: train_indices, val_indices = self.splits[fold_index] self.train_fold = Subset(self.train_dataset, train_indices) self.val_fold = Subset(self.train_dataset, val_indices) def train_dataloader(self) -> DataLoader: return DataLoader(self.train_fold) def val_dataloader(self) -> DataLoader: return DataLoader(self.val_fold) def test_dataloader(self) -> DataLoader: return DataLoader(self.test_dataset) def __post_init__(cls): super().__init__()
今回関係するのは、setup_folds
とsetup_folds_indices
ですかね。
setup_folds
でvalidation用のデータのindexを決める- そのindexを使用して
setup_fold_index
でtrainとvalidationのdatasetを分割・作成
って流れを想定しているようです。
複数モデルの混ぜ合わせ
最終的な推論時に、fold毎の結果をまとめるためのクラスとして、EnsembleVotingModel
が用意されています。
class EnsembleVotingModel(LightningModule): def __init__(self, model_cls: Type[LightningModule], checkpoint_paths: List[str]) -> None: super().__init__() # Create `num_folds` models with their associated fold weights self.models = torch.nn.ModuleList([model_cls.load_from_checkpoint(p) for p in checkpoint_paths]) self.test_acc = Accuracy() def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: # Compute the averaged predictions over the `num_folds` models. logits = torch.stack([m(batch[0]) for m in self.models]).mean(0) loss = F.nll_loss(logits, batch[1]) self.test_acc(logits, batch[1]) self.log("test_acc", self.test_acc) self.log("test_loss", loss)
test_step
で、各モデルの推論結果の平均を取るようになっていますね。
その後はaccを計算しているだけなのであんま関係なさそう。
Loop Module
CVを切る際にはTrainerをそのままは使用できない(そのままでは書けない)っぽいのでLoop Moduleを継承して記述していますね。
ループは、Lightningの中核をなすデフォルトの勾配降下最適化ループを、上級ユーザが別の最適化パラダイムに置き換えることを可能にします。 Loops — PyTorch Lightning 1.8.6 documentation (DeepLで翻訳)
pytorch-lightning.readthedocs.io
class KFoldLoop(Loop): def __init__(self, num_folds: int, export_path: str) -> None: super().__init__() self.num_folds = num_folds self.current_fold: int = 0 self.export_path = export_path @property def done(self) -> bool: return self.current_fold >= self.num_folds def connect(self, fit_loop: FitLoop) -> None: self.fit_loop = fit_loop def reset(self) -> None: """Nothing to reset in this loop.""" def on_run_start(self, *args: Any, **kwargs: Any) -> None: """Used to call `setup_folds` from the `BaseKFoldDataModule` instance and store the original weights of the model.""" assert isinstance(self.trainer.datamodule, BaseKFoldDataModule) self.trainer.datamodule.setup_folds(self.num_folds) self.lightning_module_state_dict = deepcopy(self.trainer.lightning_module.state_dict()) def on_advance_start(self, *args: Any, **kwargs: Any) -> None: """Used to call `setup_fold_index` from the `BaseKFoldDataModule` instance.""" print(f"STARTING FOLD {self.current_fold}") assert isinstance(self.trainer.datamodule, BaseKFoldDataModule) self.trainer.datamodule.setup_fold_index(self.current_fold) def advance(self, *args: Any, **kwargs: Any) -> None: """Used to the run a fitting and testing on the current hold.""" self._reset_fitting() # requires to reset the tracking stage. self.fit_loop.run() self._reset_testing() # requires to reset the tracking stage. # the test loop normally expects the model to be the pure LightningModule, but since we are running the # test loop during fitting, we need to temporarily unpack the wrapped module wrapped_model = self.trainer.strategy.model self.trainer.strategy.model = self.trainer.strategy.lightning_module self.trainer.test_loop.run() self.trainer.strategy.model = wrapped_model self.current_fold += 1 # increment fold tracking number. def on_advance_end(self) -> None: """Used to save the weights of the current fold and reset the LightningModule and its optimizers.""" self.trainer.save_checkpoint(osp.join(self.export_path, f"model.{self.current_fold}.pt")) # restore the original weights + optimizers and schedulers. self.trainer.lightning_module.load_state_dict(self.lightning_module_state_dict) self.trainer.strategy.setup_optimizers(self.trainer) self.replace(fit_loop=FitLoop) def on_run_end(self) -> None: """Used to compute the performance of the ensemble model on the test set.""" checkpoint_paths = [osp.join(self.export_path, f"model.{f_idx + 1}.pt") for f_idx in range(self.num_folds)] voting_model = EnsembleVotingModel(type(self.trainer.lightning_module), checkpoint_paths) voting_model.trainer = self.trainer # This requires to connect the new model and move it the right device. self.trainer.strategy.connect(voting_model) self.trainer.strategy.model_to_device() self.trainer.test_loop.run() def on_save_checkpoint(self) -> Dict[str, int]: return {"current_fold": self.current_fold} def on_load_checkpoint(self, state_dict: Dict) -> None: self.current_fold = state_dict["current_fold"] def _reset_fitting(self) -> None: self.trainer.reset_train_dataloader() self.trainer.reset_val_dataloader() self.trainer.state.fn = TrainerFn.FITTING self.trainer.training = True def _reset_testing(self) -> None: self.trainer.reset_test_dataloader() self.trainer.state.fn = TrainerFn.TESTING self.trainer.testing = True def __getattr__(self, key) -> Any: # requires to be overridden as attributes of the wrapped loop are being accessed. if key not in self.__dict__: return getattr(self.fit_loop, key) return self.__dict__[key] def __setstate__(self, state: Dict[str, Any]) -> None: self.__dict__.update(state)
Loop内部のrunのメソッドはこんな感じの順番でメソッドを呼んでいくらしいです。
# class Loop: # # # # def run(self, ...): # # self.reset(...) # # self.on_run_start(...) # # # # while not self.done: # # self.on_advance_start(...) # # self.advance(...) # # self.on_advance_end(...) # # # # return self.on_run_end(...)
ざっくり主要なメソッドの説明を書いてくと、こんな感じですかね。
- 上に書いてあるメインのメソッド
- reset : Loopのstateのリセット(今回は何もしていない)
- on_run_start : 学習を開始する前準備(ここでfold毎の準備)
- done : cvの終了判定
- on_advance_start : Foldの学習の実行前の設定(対象のfold番号を設定してあげてる)
- advance : Foldの学習の実行
- self.on_advance_end : Foldの学習の終了時の処理(モデルの保存、もろもろstateをリセット)
- on_run_end : 全Foldのモデルを使って推論
- その他
- connect : cvを行う関係上で必要になる内部ループの開始(Loops — PyTorch Lightning 1.8.6 documentation)
connect
は外かfit_loopを書き換えるために用意してるみたいですね。
if __name__ == "__main__": seed_everything(42) model = LitImageClassifier() datamodule = MNISTKFoldDataModule() trainer = Trainer( max_epochs=10, limit_train_batches=2, limit_val_batches=2, limit_test_batches=2, num_sanity_val_steps=0, devices=2, accelerator="cpu", strategy="ddp", ) internal_fit_loop = trainer.fit_loop trainer.fit_loop = KFoldLoop(5, export_path="./") trainer.fit_loop.connect(internal_fit_loop) trainer.fit(model, datamodule)
やってみる
雰囲気はなんとなくわかった気がするので、上のサンプルコードを適当にいじって自分でもやってみようと思います。 この辺のコード書くのは久しぶりだったので、この辺見て思い出しながら書いてました。
ログ(ここをクリック)
INFO:lightning_lite.utilities.seed:Global seed set to 42 Loaded pretrained weights for efficientnet-b7 INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0] INFO:pytorch_lightning.callbacks.model_summary: | Name | Type | Params ------------------------------------------------- 0 | model | EfficientNet | 63.9 M 1 | accuracy | MulticlassAccuracy | 0 2 | criterion | CrossEntropyLoss | 0 ------------------------------------------------- 63.9 M Trainable params 0 Non-trainable params 63.9 M Total params 255.506 Total estimated model params size (MB) STARTING FOLD 0 /usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py:1595: PossibleUserWarning: The number of training batches (8) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch. rank_zero_warn( Epoch 1: 100% 16/16 [00:11<00:00, 1.42it/s, loss=2.62, v_num=77] INFO:pytorch_lightning.callbacks.early_stopping:Metric val_acc improved. New best score: 0.277 INFO:pytorch_lightning.callbacks.early_stopping:Metric val_acc improved by 0.031 >= min_delta = 0.0. New best score: 0.308 INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=2` reached. Testing DataLoader 0: 100% 8/8 [00:01<00:00, 6.91it/s] ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── Test metric DataLoader 0 ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── test_acc 0.23529411852359772 test_loss 24.689242482185364 ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── STARTING FOLD 1 Epoch 1: 100% 16/16 [00:11<00:00, 1.38it/s, loss=2.54, v_num=77] INFO:pytorch_lightning.callbacks.early_stopping:Metric val_acc improved. New best score: 0.354 INFO:pytorch_lightning.callbacks.early_stopping:Metric val_acc improved by 0.031 >= min_delta = 0.0. New best score: 0.385 INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=2` reached. Testing DataLoader 0: 100% 8/8 [00:00<00:00, 9.65it/s] ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── Test metric DataLoader 0 ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── test_acc 0.23529411852359772 test_loss 26.890554547309875 ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── STARTING FOLD 2 Epoch 1: 100% 16/16 [00:12<00:00, 1.28it/s, loss=2.96, v_num=77] INFO:pytorch_lightning.callbacks.early_stopping:Metric val_acc improved. New best score: 0.169 INFO:pytorch_lightning.callbacks.early_stopping:Metric val_acc improved by 0.031 >= min_delta = 0.0. New best score: 0.200 INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=2` reached. Testing DataLoader 0: 100% 8/8 [00:00<00:00, 9.85it/s] ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── Test metric DataLoader 0 ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── test_acc 0.11764705926179886 test_loss 28.032548666000366 ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── STARTING FOLD 3 Epoch 1: 100% 16/16 [00:11<00:00, 1.33it/s, loss=3.22, v_num=77] INFO:pytorch_lightning.callbacks.early_stopping:Metric val_acc improved. New best score: 0.062 INFO:pytorch_lightning.callbacks.early_stopping:Metric val_acc improved by 0.092 >= min_delta = 0.0. New best score: 0.154 INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=2` reached. Testing DataLoader 0: 100% 8/8 [00:00<00:00, 10.21it/s] ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── Test metric DataLoader 0 ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── test_acc 0.05882352963089943 test_loss 28.37675142288208 ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── STARTING FOLD 4 Epoch 1: 100% 16/16 [00:11<00:00, 1.37it/s, loss=3.4, v_num=77] INFO:pytorch_lightning.callbacks.early_stopping:Metric val_acc improved. New best score: 0.077 INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=2` reached. Testing DataLoader 0: 100% 8/8 [00:00<00:00, 10.11it/s] ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── Test metric DataLoader 0 ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── test_acc 0.0 test_loss 28.572951912879944 ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── Loaded pretrained weights for efficientnet-b7 Loaded pretrained weights for efficientnet-b7 Loaded pretrained weights for efficientnet-b7 Loaded pretrained weights for efficientnet-b7 Loaded pretrained weights for efficientnet-b7 Testing DataLoader 0: 100% 8/8 [00:02<00:00, 3.18it/s] ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── Test metric DataLoader 0 ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── test_acc 0.0
実行ログを見てみると、
- accがだんだん低下している
- 多分Optimizerかなんかがfoldごとに初期化できてない
- 最後のAccが0
- metrics(acc)がちゃんと計算できてない
ってのがわかったので、なんかおかしいです。 ここまでやってうまくいかなかったので、PyTorchそのままで書けって神のお告げだと思い、断念しました。
参考文献
下記の文献を参考にさせていただきました。
- lightning/kfold.py at master · Lightning-AI/lightning · GitHub
- Loops — PyTorch Lightning 1.8.6 documentation
感想
書くだけ書いてみて、いい感じに記述できなかったので、失敗したまま放置しました。 いい感じに書ける方いましたら、誰か教えて下さい。
書き方が悪いのか、それともそもそもそれを想定していないのかはわかりませんが、個人的にはこれだったら生PyTorch書いたほうがシンプルな気がしたので、しばらくはそっちでやろうという気持ちになったのでもう良いです。