xgboostでKaggleの自転車需要予測をやってみた

Pocket

xgboostでの予測

はじめに

こんにちは。はんぺんです。

最近、xgboostを使う機会があったので、それについてまとめようと思います。

使うデータはKaggleにあるBike Sharing Demandです。

コードはGitにあげています。

環境は以下の通りです

  • macOS High Sierra 10.13.6
  • Python 3.6
  • Anaconda 5.6

コードの解説

Pythonで書きました。

少しずつ解説していきます。

インポート

必要なパッケージをインポートします。

データ取得

CSVからでデータを読み込んでデータの加工をします。

ここで、データおログ変換することでデータの分散が抑えられるので変換します。

データ分割

目的変数を決めてデータを分割します。

ここでは、train-validationとtestにデータをわけます。

パラメータチューニング

max_depth & learning_rate

グリッドサーチの範囲設定を行います。

今回はまずmax_depthとlearning_rateをグリッドサーチで探索します。探索の範囲は,max_depthは3~10、learning_rateは0.5~0.4くらいで私はいつもやっています。

刻み幅は時間との兼ね合いもありますが、時間があるのなら0.2とかもっと細かくしてもよいです。

この二つを決めればとりあえずOKみたいなところある思ってます。(違ったらごめんなさい)

グリッドサーチのための関数を定義します。

クロスバリデーションはcv=3(データを3つに分ける)で行います。

また、sklearnのGridSearchの仕様でrmseではなくmseしかつかえないのでそれを使います。

ルートついてるかついてないかの違いなので、影響ないと考えています。

グリッドサーチを行って結果を保存します。

データやPCの性能によって早さは異なりますが今回は私のMacbookproで1分ほどで終わりました。

この結果から、max_depth=7, learning_rate=0.1が最適なパラメータであることがわかります。

なんとなく好きなので3Dでもプロットしてみましたが、ヒートマップの方が見やすそうです。

n_estimators

次にn_estimatorsを探索します。

これは何個の木を生成して回帰をするかというパラメータです。基本的に多くするほど学習エラーは下がりますが、ある値でバリデーションエラーは悪化して行く傾向にあります。

経験からすると70 〜130くらいになることが多いです。(今回は175と多めでした)

train-errorとvalidation-errorの可視化結果が以下のようになります。

これではよくわかりませんが、拡大すると、

175が最適な値であることがわかります。

以上のクロスバリデーションの結果より、n_estimators=175を採用します。

学習

パラメータの探索結果を元に学習器を生成し、訓練させます。
モデルの保存です。

予測

訓練させた学習器を用いて予測を行います。

そしてtestデータに実際の値と予測値のカラムを追加して、評価を行なっていきます。

誤差率計算

Kaggleで使われている評価指標を使っていきます。

可視化

Feature importance

Xgboostには説明変数の中でどれが大事だったかを示す、Feature importanceというものがあります。

これを見ながらデータの特徴を考えて説明変数を追加して行くというのが一般的です。

散布図

実際の値と予測値がどれくらいあっているかを直感的にわかるようにグラフ化ました。

X軸に実績値、Y軸に予測値を取っています。

y=xの赤い線に点が乗っているほど性能の高いモデルということが言えます。

さいごに

いかがでしたでしょうか。

今回はxgboostでKaggleの自電車需要のデータを題材にして説明を行いました。

xgboostは非常に強力な機械学習のパッケージで、Kaggleのみならず業務データの分析においても非常に高い精度が出ることが多く、使えて損はないと思います。

私ももっと使いこなせるように頑張ろうと思います。

ありがとうございました。

ソースコードは以下においてあります。

GitHub

スポンサーリンク














コメントを残す

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です

このサイトはスパムを低減するために Akismet を使っています。コメントデータの処理方法の詳細はこちらをご覧ください