Chainerで簡単なクラス分類をしてみる
Chainerを試してみるために簡単なサンプルプログラムを動かしてみたいと思います。
まず必要なライブラリをインポートします。
import numpy as np import chainer from chainer import cuda, Function, gradient_check, Variable from chainer import optimizers, serializers, utils from chainer import Link, Chain, ChainList import chainer.functions as F import chainer.links as L
それから、学習用データを読み込みます。今回はsklearnからアヤメのデータを使って4入力(がくの長さ、幅と茎の長さ、幅)からアヤメの種類(setosa, versicolor, virginica)を分類できるように試してみます。
# Set data from sklearn import datasets iris = datasets.load_iris() X = iris.data.astype(np.float32) Y = iris.target.astype(np.float32) N = Y.size Y2 = np.zeros(3 * N).reshape(N,3).astype(np.float32) for i in range(N): Y2[i,np.int(Y[i])] = 1.0 index = np.arange(N) xtrain = X[index[index % 2 != 0],:] ytrain = Y2[index[index % 2 != 0],:] xtest = X[index[index % 2 == 0],:] yans = Y[index[index % 2 == 0]]
モデルの定義として4入力、3出力で中間層はなしのものを定義しています。入力に対して出力が簡単に結びつくような今回のケースでは中間層は必要なさそうに思います。多クラスの分類としてはソフトマックス関数を使用し誤差関数として二乗誤差を使用しています。ソフトマックスを利用した多クラス分類等では交差エントロピーを使うように思っていたのですが、この辺りはちゃんと学習して使いこなせるようにしたいと思います。
# Define model class IrisRogi(Chain): def __init__(self): super(IrisRogi, self).__init__( # 入力4軸(がくの長さ、幅と茎の長さ、幅) 出力3クラス分類(irisの種類setosa, versicolor, virginica) l1=L.Linear(4,3), ) def __call__(self,x,y): # 順伝番した結果に対してmean_squared_errorで二乗誤差求めている return F.mean_squared_error(self.fwd(x), y) def fwd(self,x): # 多クラス分類(irisの種類)のための各ユニットの出力としてソフトマックスを使用する return F.softmax(self.l1(x)) # Initialize model model = IrisRogi() # パラメータの最適化で比較的高速に良い値を出すadamを使用する optimizer = optimizers.Adam() optimizer.setup(model)
それから学習を行います。
# Learn n = int(index.size / 2) bs = 25 for j in range(5000): sffindx = np.random.permutation(n) accum_loss = None for i in range(0, n, bs): x = Variable(xtrain[sffindx[i:(i+bs) if (i+bs) < n else n]]) y = Variable(ytrain[sffindx[i:(i+bs) if (i+bs) < n else n]]) model.zerograds() # 勾配を初期化 loss = model(x,y) # 順方向に計算し誤差を算出 loss.backward() # 逆伝番で勾配の向きを計算 optimizer.update() # 逆伝番で得た勾配からパラメータを更新する
学習で得たパラメータを利用しテストをしてみます。
# Test xt = Variable(xtest, volatile='on') yy = model.fwd(xt) ans = yy.data nrow, ncol = ans.shape ok = 0 for i in range(nrow): cls = np.argmax(ans[i,:]) # print( ans[i,:], cls) if cls == yans[i]: ok += 1 print (ok, "/", nrow, " = ", (ok * 1.0)/nrow)
自分の環境で動かしてみたところ、73/75の精度で正解していたのでちゃんと動いてることは確認できます。