StratifiedGroupKFoldの実装を読んだ
有名なアルゴリズムの実装を読む シリーズの第二回。 今回はscikit-learnのStratifiedGroupKFoldの元のなったKaggle Notebookの実装を読んだ。といっても43行のとても短いコードである。
表記
- n_fold: フォールド数
- n_groups: ユニークなグループ数
- n_labels: ユニークなラベル数
処理の概要
まず、StratifiedGroupKFoldは、以下の制約を満たすようなFold分割アルゴリズムである。
- 各foldに割り当てられたラベルの分布が全体のラベルの分布になるべく一致する
- 各foldに固有のグループが割り当てられるようにする。すなわち、異なるfoldには同一のグループに所属するサンプルが存在しない
これを実現するために、以下の処理を行っている。
- グループごとにラベルの頻度分布を計算する(n_groups x n_labels)
- 各グループについて、ラベルの頻度のばらつき(具体的には標準偏差)が大きいものから順にfold割り当て処理を行う
fold割り当て処理
- 各グループのラベルの頻度分布を試しに各foldに割り当ててみる
- fold割り当て後の評価値が最小になるfoldにそのグループを割り当てる
評価値は以下のように定義している。
- 「各foldのラベル頻度と元のラベル頻度の比」の標準偏差の平均値
コストの大きいものから、評価値(=元の分布からの隔たり)が小さくなるように割り当てていくので、貪欲法のアルゴリズムの一種である。
計算量はO(n_fold * n_groups * n_labels)
である。
scikit-learnの実装
scikit-learnの実装はこの実装をより汎用的に使えるようにするための処理が追加されている。 少数グループが割り当てられなかったfoldがあった場合に警告を出したり、乱数でshuffleかどうかを切り替える機能などが追加されている。
なお、shuffleした後で結局コストの大きいものから順に処理しているので、shuffleの効果はのグループ内でラベル頻度の標準偏差が一致するものの処理順序が変わる程度の効果しかない。