博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
sklearn 决策树
阅读量:6480 次
发布时间:2019-06-23

本文共 5272 字,大约阅读时间需要 17 分钟。

import pandas as pd
import numpy as np
data = pd.read_csv('./tt/train.csv')复制代码
data.columns复制代码
Index(['PassengerId', 'Survived', 'Pclass', 'Name', 'Sex', 'Age', 'SibSp',       'Parch', 'Ticket', 'Fare', 'Cabin', 'Embarked'],      dtype='object')复制代码
data = data[['Survived', 'Pclass','Sex', 'Age', 'SibSp',       'Parch', 'Fare', 'Embarked']]复制代码
data['Age'] = data['Age'].fillna(data['Age'].mean())复制代码
data.fillna(0, inplace=True)复制代码
data['Sex'] =[1 if x=='male' else 0 for x in data.Sex]复制代码
# One-hot encode the passenger class: one 0/1 int32 indicator column per
# Pclass value (1, 2, 3).
data['p1'] = np.array(data['Pclass'] == 1).astype(np.int32)
data['p2'] = np.array(data['Pclass'] == 2).astype(np.int32)
data['p3'] = np.array(data['Pclass'] == 3).astype(np.int32)
del data['Pclass']复制代码
data.Embarked.unique()复制代码
array(['S', 'C', 'Q', 0], dtype=object)复制代码
# One-hot encode the port of embarkation ('S', 'C', 'Q') as 0/1 int32
# indicator columns. Rows whose Embarked value is the fill value 0 match none
# of the three comparisons and therefore get 0 in every column.
data['e1'] = np.array(data['Embarked'] == 'S').astype(np.int32)
data['e2'] = np.array(data['Embarked'] == 'C').astype(np.int32)
data['e3'] = np.array(data['Embarked'] == 'Q').astype(np.int32)
del data['Embarked']复制代码
data.values.dtype复制代码
dtype('float64')复制代码
data.columns复制代码
Index(['Survived', 'Sex', 'Age', 'SibSp', 'Parch', 'Fare', 'p1', 'p2', 'p3',       'e1', 'e2', 'e3'],      dtype='object')复制代码
data_train = data[['Sex', 'Age', 'SibSp', 'Parch', 'Fare', 'p1', 'p2', 'p3',       'e1', 'e2', 'e3']].values复制代码
data_target = data['Survived'].values.reshape(len(data),1)复制代码
np.shape(data_train),np.shape(data_target)复制代码
((891, 11), (891, 1))复制代码
from sklearn.model_selection import train_test_split复制代码
x_train, x_test, y_train, y_test = train_test_split(data_train, data_target, test_size = 0.2)复制代码
x_train.shape, x_test.shape复制代码
((712, 11), (179, 11))复制代码
from sklearn.tree import DecisionTreeClassifier复制代码
model = DecisionTreeClassifier()复制代码
model.fit(x_train, y_train)复制代码
DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,            max_features=None, max_leaf_nodes=None,            min_impurity_decrease=0.0, min_impurity_split=None,            min_samples_leaf=1, min_samples_split=2,            min_weight_fraction_leaf=0.0, presort=False, random_state=None,            splitter='best')复制代码
model.score(x_test, y_test)复制代码
0.7932960893854749复制代码
model.score(x_train, y_train)复制代码
0.9845505617977528复制代码
def m_score(depth):
    """Fit a decision tree capped at ``depth`` levels and score it.

    Uses the module-level ``x_train``/``y_train`` split for fitting and the
    ``x_test``/``y_test`` split for evaluation.

    Args:
        depth: Value passed to ``DecisionTreeClassifier(max_depth=...)``.

    Returns:
        A ``(train_score, test_score)`` tuple of the values returned by
        ``model.score`` on the training and test splits respectively.
    """
    model = DecisionTreeClassifier(max_depth=depth)
    model.fit(x_train, y_train)
    train_score = model.score(x_train, y_train)
    test_score = model.score(x_test, y_test)
    return train_score, test_score
depths = range(2, 15)复制代码
scores = [m_score(depth) for depth in depths]复制代码
scores复制代码
[(0.7921348314606742, 0.7932960893854749), (0.8258426966292135, 0.8547486033519553), (0.8412921348314607, 0.8491620111731844), (0.8469101123595506, 0.8435754189944135), (0.8595505617977528, 0.8435754189944135), (0.8623595505617978, 0.8435754189944135), (0.8735955056179775, 0.8435754189944135), (0.8876404494382022, 0.8435754189944135), (0.9087078651685393, 0.8603351955307262), (0.9199438202247191, 0.8491620111731844), (0.9311797752808989, 0.8268156424581006), (0.9382022471910112, 0.8156424581005587), (0.9452247191011236, 0.8379888268156425)]复制代码
# Unzip the (train_score, test_score) pairs into two parallel lists.
train_s = [s[0] for s in scores]
test_s = [s[1] for s in scores]
import matplotlib.pyplot as plt复制代码
# Draw the training and test score curves on the same axes. Only y values
# are passed, so the x axis is the position in the list, not the depth itself.
plt.plot(train_s)
plt.plot(test_s)
[
]复制代码

def m_score(value):
    """Fit a decision tree with the given impurity threshold and score it.

    Replaces the earlier depth-based ``m_score``. Uses the module-level
    ``x_train``/``y_train`` split for fitting and ``x_test``/``y_test`` for
    evaluation.

    Args:
        value: Value passed to
            ``DecisionTreeClassifier(min_impurity_split=...)``.

    Returns:
        A ``(train_score, test_score)`` tuple of the values returned by
        ``model.score`` on the training and test splits respectively.
    """
    # NOTE(review): ``min_impurity_split`` was deprecated in scikit-learn 0.19
    # and removed in 1.0. It is kept here because the recorded results depend
    # on it; on newer versions ``min_impurity_decrease`` is the replacement,
    # but its semantics differ, so the scores would change.
    model = DecisionTreeClassifier(min_impurity_split=value)
    model.fit(x_train, y_train)
    train_score = model.score(x_train, y_train)
    test_score = model.score(x_test, y_test)
    return train_score, test_score
values = np.linspace(0, 0.5, 50)复制代码
scores = [m_score(value) for value in values]复制代码
# Unzip the (train_score, test_score) pairs into two parallel lists.
train_s = [s[0] for s in scores]
test_s = [s[1] for s in scores]
best_index = np.argmax(test_s)复制代码
# Look up the best test score and the threshold that produced it, using the
# index found by np.argmax over test_s.
best_score = test_s[best_index]
best_value = values[best_index]
best_score, best_value复制代码
(0.8659217877094972, 0.19387755102040816)复制代码
# Draw the training and test score curves on the same axes. Only y values
# are passed, so the x axis is the index into ``values``, not the threshold.
plt.plot(train_s)
plt.plot(test_s)
[
]复制代码

from sklearn.model_selection import GridSearchCV复制代码
# Candidate hyper-parameter values for the grid search: 50 evenly spaced
# impurity thresholds in [0, 0.5] and tree depths 2 through 14.
values = np.linspace(0, 0.5, 50)
depths = range(2, 15)

# NOTE(review): ``min_impurity_split`` was deprecated in scikit-learn 0.19 and
# removed in 1.0; newer versions need ``min_impurity_decrease`` instead
# (different semantics).
param_grid = {
    'max_depth': depths,
    'min_impurity_split': values,
}
model = GridSearchCV(DecisionTreeClassifier(), param_grid, cv = 5)复制代码
model.fit(data_train, data_target)复制代码
GridSearchCV(cv=5, error_score='raise',       estimator=DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,            max_features=None, max_leaf_nodes=None,            min_impurity_decrease=0.0, min_impurity_split=None,            min_samples_leaf=1, min_samples_split=2,            min_weight_fraction_leaf=0.0, presort=False, random_state=None,            splitter='best'),       fit_params=None, iid=True, n_jobs=1,       param_grid={'max_depth': range(2, 15), 'min_impurity_split': array([0.     , 0.0102 , 0.02041, 0.03061, 0.04082, 0.05102, 0.06122,       0.07143, 0.08163, 0.09184, 0.10204, 0.11224, 0.12245, 0.13265,       0.14286, 0.15306, 0.16327, 0.17347, 0.18367, 0.19388, 0.20408,       0.21429, 0.22449, 0.23...16, 0.41837,       0.42857, 0.43878, 0.44898, 0.45918, 0.46939, 0.47959, 0.4898 ,       0.5    ])},       pre_dispatch='2*n_jobs', refit=True, return_train_score='warn',       scoring=None, verbose=0)复制代码
model.best_params_复制代码
{'max_depth': 9, 'min_impurity_split': 0.21428571428571427}复制代码
model.best_score_复制代码
0.8316498316498316复制代码

转载地址:http://tjzuo.baihongyu.com/

你可能感兴趣的文章
收集的QCon 北京(Beijing) 2010 PPT 及总结
查看>>
Qt 让QLabel自适应text的大小,并且自动换行(转)
查看>>
PostgreSQL学习手册(十六) SQL语言函数
查看>>
网络编程——第一篇 基础之进程线程
查看>>
9.png 技巧
查看>>
hdu 4715(打表)
查看>>
java J2EE学习入门
查看>>
Linux系统信息查看命令大全
查看>>
为什么项目的jar包会和tomcat的jar包冲突?
查看>>
这些.NET开源项目你知道吗?.NET平台开源文档与报表处理组件集合(三)
查看>>
linux ps top 命令 VSZ,RSS,TTY,STAT, VIRT,RES,SHR,DATA的含义【转】
查看>>
程序员接私活记(转)
查看>>
eclipse如何修改dynamic web module version
查看>>
Have You Ever Wondered About the Difference Between NOT NULL and DEFAULT?
查看>>
自己定义ImageView,实现点击之后算出点击的是身体的哪个部位
查看>>
颜色模式中8位,16位,24位,32位色彩是什么意思?会有什么区别?计算机颜色格式( 8位 16位 24位 32位色)【转】...
查看>>
js 类属性的命名 尽量避免关键字
查看>>
【新产品发布】迷你 STLINK / V2,可支持ST公司全系列的 STM8 / STM32
查看>>
strcpy()、memcpy()、memmove()、memset()的内部实现
查看>>
Project Euler 001-006 解法总结
查看>>