My data in a CSV file:
cucumber,green,15,4
tomato,red,7,7
carrots,Orange,13,3
onion,White,8,8
potatoes,Gray,8,6
apple,Red,7,6
apple,Yellow,6,5
coconut,Brown,25,20
orange,Orange,7,7
banana,Yellow,16,4
lemon,Yellow,5,4
watermelon,Green,30,25
cherries,Black,2,2
I want to predict a fruit!
import csv
from sklearn import tree

x = []
y = []
lst = []
with open('F5-ML-TEST.csv', 'r') as csvfile:
    data = csv.reader(csvfile)
    for line in data:
        lst.append(line[1])
        lst.append(line[2])
        lst.append(line[3])
        x.append(lst)
        y.append(line[0])
        lst = []

print('x ----- >', x)
print('y ----- >', y)

clf = tree.DecisionTreeClassifier()
clf = clf.fit(x, y)

new_data = [["red", 7, 7], ["yellow", 5, 6]]
answer = clf.predict(new_data)
print('answer[0]====== >', answer[0])
print('answer[1]====== >', answer[1])
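What you need to do is encode the string data as numeric features, because the decision tree cannot work with strings such as "red" directly. I copied your input here: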
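import numpy as np  # needed for np.concatenate further down
import pandas as pd
from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder
from sklearn.tree import DecisionTreeClassifier

# Read the comma-separated rows copied from the question (no header row)
df = pd.read_clipboard(header=None, sep=',')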
0 1 2 3
0 cucumber green 15 4
1 tomato red 7 7
2 carrots Orange 13 3
3 onion White 8 8
4 potatoes Gray 8 6
5 apple Red 7 6
6 apple Yellow 6 5
You need to encode the color column:
# scikit-learn >= 1.2 names this argument sparse_output; older releases call it sparse
ohe = OneHotEncoder(sparse_output=False)
colors = ohe.fit_transform(df.iloc[:, 1].values.reshape(-1, 1))
It now looks like this, with one column per color:
array([[0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
[0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
[0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.], ...
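To make the column layout easier to read, here is a minimal sketch on a hypothetical four-row color column (not the full data set). It assumes you want unseen colors to be tolerated, hence `handle_unknown="ignore"`, which the code above does not use. Note that the encoder is case-sensitive, which is why "red" and "Red" end up in separate columns in the output above.

```python
import numpy as np
from sklearn.preprocessing import OneHotEncoder

# Hypothetical mini color column, for illustration only
colors_raw = np.array(["green", "red", "Orange", "red"]).reshape(-1, 1)

# handle_unknown="ignore" encodes an unseen color as an all-zero row
# instead of raising an error at transform time
ohe_demo = OneHotEncoder(handle_unknown="ignore")
encoded = ohe_demo.fit_transform(colors_raw).toarray()

# categories_ gives the meaning of each output column, in sorted order
print(ohe_demo.categories_[0])   # ['Orange' 'green' 'red']
print(encoded)

# A color that was never seen during fit
print(ohe_demo.transform([["purple"]]).toarray())   # [[0. 0. 0.]]
```

Inspecting `categories_` is the quickest way to check which column corresponds to which color.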
Then concatenate it with the other columns, which are already numeric:
inputs = np.concatenate([df.iloc[:, 2:].values, colors], axis=1)
Now turn your targets (the fruit names) into numbers:
oe = OrdinalEncoder()
targets = oe.fit_transform(df.iloc[:, 0].values.reshape(-1, 1))
They now look like this:
array([[ 5.],
[10.],
[ 2.],
[ 7.],
[ 9.],
[ 0.], ...
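As a side note, scikit-learn classifiers also accept string labels as the target, so encoding the fruit names is optional; only the input features have to be numeric. A minimal sketch on a hypothetical three-row sample (the `_demo` names are made up for illustration):

```python
from sklearn.tree import DecisionTreeClassifier

# Hypothetical sample: the two numeric columns only, string labels as-is
X_demo = [[15, 4], [7, 7], [30, 25]]
y_demo = ["cucumber", "tomato", "watermelon"]

clf_demo = DecisionTreeClassifier(random_state=0).fit(X_demo, y_demo)

# Predictions come back as the original strings, no decoding step needed
pred = clf_demo.predict([[15, 4], [30, 25]])
print(pred)   # ['cucumber' 'watermelon']
```

With this variant there is no need to map codes back to names after predicting.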
Then you can fit your decision tree:
clf = DecisionTreeClassifier()
clf = clf.fit(inputs, targets)
Now you can even predict on new data:
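new_data = [["red", 7, 7], ["Yellow", 5, 6]]
# Numeric columns first, then the one-hot colors, matching the training layout
new_data = np.concatenate([[i[1:] for i in new_data],
                           ohe.transform([[i[0]] for i in new_data])], axis=1)
answer = clf.predict(new_data)
# Map the predicted codes back to fruit names
oe.categories_[0][answer.astype(int)]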
Out[88]: array(['tomato', 'apple'], dtype=object)
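The same steps can also be bundled so that encoding and prediction happen in one call. The following is a sketch, not the approach used above: it swaps the manual `np.concatenate` for `ColumnTransformer` and `make_pipeline`, embeds the CSV rows from the question as a string, and uses made-up column names (`fruit`, `color`, `num_a`, `num_b`), since the file has no header.

```python
import io

import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import OneHotEncoder
from sklearn.tree import DecisionTreeClassifier

CSV_TEXT = """cucumber,green,15,4
tomato,red,7,7
carrots,Orange,13,3
onion,White,8,8
potatoes,Gray,8,6
apple,Red,7,6
apple,Yellow,6,5
coconut,Brown,25,20
orange,Orange,7,7
banana,Yellow,16,4
lemon,Yellow,5,4
watermelon,Green,30,25
cherries,Black,2,2
"""

# Column names are assumptions; the original file has no header row
features = ["color", "num_a", "num_b"]
df = pd.read_csv(io.StringIO(CSV_TEXT), header=None, names=["fruit"] + features)

# One-hot encode the color column, pass the numeric columns through unchanged
preprocess = ColumnTransformer(
    [("color", OneHotEncoder(handle_unknown="ignore"), ["color"])],
    remainder="passthrough",
)
model = make_pipeline(preprocess, DecisionTreeClassifier(random_state=0))

# String labels are used directly as the target
model.fit(df[features], df["fruit"])

new_data = pd.DataFrame([["red", 7, 7], ["Yellow", 5, 6]], columns=features)
answer = model.predict(new_data)
print(answer)
```

Because the pipeline stores the fitted encoder, new rows are encoded with the same column order as the training data, and a color that was not seen in training is encoded as all zeros instead of raising an error.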