699の名前——機械学習における「Hello World」

このデータセットを開いたとき、ふと頭をよぎった疑問がある。class列の2と4は、それぞれ「良性」と「悪性」を意味する——だが、その二つの数字の背後に座っているのは、誰なのか。

699人の女性たち

ウィスコンシン乳がんデータセット(Wisconsin Breast Cancer Dataset)は、機械学習における「Hello World」と言える存在だ。699件のサンプル、各サンプルには10個の細胞特徴量がある。clump_thickness(塊の厚さ)、uniform_cell_size(細胞サイズの均一性)、uniform_cell_shape(細胞形状の均一性)、marginal_adhesion(辺縁接着性)、single_epithelial_size(単上皮細胞サイズ)、bare_nuclei(裸核)、bland_chromatin(淡明クロマチン)、normal_nucleoli(正常核小体)、mitoses(核分裂)。

これらの特徴量は抽象的に聞こえるが、病理医が顕微鏡越しに実際に見てきたものだ。

細胞の層数を数え、細胞の直径を測り、細胞の縁が滑らかかギザギザかを観察し、細胞核が何色に染まっているかを見る。1から10までの数字のひとつひとつが、人の目と人の脳による判断——経験の蓄積であり、訓練の痕跡なのだ。

データの中の一行一行は、かつて診察室に座り、生検の結果を待っていた、ひとりの女性だった。

細胞の言葉

病理医が細胞を見るまなざしは、私たちが人間を見るときとよく似ている。

形を見る——丸みを帯びていれば良性、異形であれば悪性。境界を見る——境界が明瞭なら通常は良性、癒着していれば悪性であることが多い。染色を見る——均一に染まっていれば比較的正常、クロマチンが濃縮して黒ずんでいれば懸念材料だ。

これは神秘でもなんでもない。数千例の症例が訓練した直感である。

そして機械学習がやっていることは、端的に言えばこうだ。医師の経験を迂回し、アルゴリズムによってその「直感」を数値化し、ルール化する。SVM(サポートベクターマシン)は高次元空間で最適な分割線を見つけ、良性サンプルと悪性サンプルを分離する。KNN(K近傍法)は「類は友を呼ぶ」——未知のサンプルに最も近い5つのサンプルのクラスを見て、多数決でどのクラスに属するかを決める。

KNNのK=5という値は、経験的にこの数値が概ね良好な結果を示すことがわかっているからだ。小さすぎるとノイズに影響されやすく、大きすぎると境界がぼやける。

SVMのカーネル関数は、細胞の特徴量を高次元空間に写像する。その空間では、線形分離できなかったデータが分離可能になる。

10分割交差検証

全データをそのまま訓練に使ってテストしたら、何が起きるか。

モデルは訓練データを「暗記」しただけで、本当の法則を学んでいないかもしれない。試験前に答えを丸暗記して、初見の問題で手も足も出なくなるようなものだ。

10分割交差検証(10-fold cross validation)が解決するのは、まさにこの問題である。データをランダムに10分割し、順番に9つを訓練セット、1つをテストセットとして、10回の実験を行い、最後に平均精度を取る。

結果はこうだ。KNN 96.6%、SVM 96.0%。一見すると大差ない——しかし実際には——

KNNの標準偏差は2.9%、SVMは3.3%だった。KNNのほうが安定している。

これは何を意味するか。KNNは10回の実験におけるばらつきが小さく、SVMは時に良く、時に悪くという振れ幅がある。実運用においては、安定性のほうが、たまに出る高スコアよりも重要だ。

Precision、Recall、そして命の重み

しかし、正解率(accuracy)だけが指標ではない。医療の現場では、誤診と見逃しでは代償が異なる。

Precision(適合率):悪性と予測したサンプルのうち、実際に悪性である割合。

Recall(再現率):すべての悪性サンプルのうち、正しく検出された割合。

F1スコアはこの両者の調和平均である。

SVMのレポートにおけるクラス2(良性):

SVMのレポートにおけるクラス4(悪性):

がん検診においては、Recallのほうがより重要だ。一人のがん患者を見逃すことは、一人の良性腫瘍を誤診することより危険である——後者は追加検査で済むが、前者は治療のタイミングを逃しかねない。

3%の精度誤差は、699件のサンプルにおいて約21人に相当する。21人の女性である。診察室に座り、「追加検査をお勧めします」と書かれたレポートを受け取った、21人の女性だ。

技術は確率を提示できる。しかし、彼女たちに代わって、その待ち時間の重みを引き受けることはできない。

アルゴリズムは本物の細胞を見たことがない。ただ、699人の女性の経験の中に、ひとつの参照点を見つけただけだ。


付録

データ前処理

import numpy as np
import pandas as pd
from sklearn import preprocessing, model_selection
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC
from sklearn.metrics import classification_report, accuracy_score

# データセットの読み込み
names = ['id', 'clump_thickness', 'uniform_cell_size', 'uniform_cell_shape',
       'marginal_adhesion', 'single_epithelial_size', 'bare_nuclei',
       'bland_chromatin', 'normal_nucleoli', 'mitoses', 'class']
df = pd.read_csv('data.csv', names=names)

# 欠損データの置換
df.replace('?', -99999, inplace=True)
# 無関係な特徴量の削除
df.drop(['id'], axis=1, inplace=True)

データ分割

X = np.array(df.drop(['class'], 1))
y = np.array(df['class'])

X_train, X_test, y_train, y_test = model_selection.train_test_split(
    X, y, test_size=0.2)

モデル訓練と交差検証

seed = 8
scoring = 'accuracy'

models = []
models.append(('KNN', KNeighborsClassifier(n_neighbors = 5)))

# scikit-learn 0.22 で SVC のデフォルトパラメータが変更されました
# gamma パラメータのデフォルト値が 'auto' から 'scale' に変更されました
# models.append(('SVM', SVC()))
models.append(('SVM', SVC(gamma='auto')))

results = []
names = []

for name, model in models:
    # kfold = model_selection.KFold(n_splits=10, random_state = seed)
    # seed を指定するにはパラメータ shuffle = True が必要です
    kfold = model_selection.KFold(n_splits=10, shuffle = True, random_state = seed)
    cv_results = model_selection.cross_val_score(model, X_train, y_train, cv=kfold, scoring=scoring)
    results.append(cv_results)
    names.append(name)
    msg = "%s: %f (%f)" % (name, cv_results.mean(), cv_results.std())
    print(msg)

出力:

KNN: 0.966039 (0.029270)
SVM: 0.960649 (0.032726)

モデル予測と評価

for name, model in models:
    model.fit(X_train, y_train)
    predictions = model.predict(X_test)
    print(name)
    print(accuracy_score(y_test, predictions))
    print(classification_report(y_test, predictions))

SVM 出力:

0.9642857142857143
              precision    recall  f1-score   support

           2       1.00      0.95      0.97        95
           4       0.90      1.00      0.95        45

    accuracy                           0.96       140
   macro avg       0.95      0.97      0.96       140
weighted avg       0.97      0.96      0.96       140