使用高斯朴素贝叶斯算法对鸢尾花数据集进行分类
高斯朴素贝叶斯算法通常用于特征变量是连续变量,符合高素分布的情况。
使用高斯朴素贝叶斯算法对鸢尾花数据集进行分类
"""
使用高斯贝叶斯堆鸢尾花进行分类
"""
#导入需要的库
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB
from sklearn.metrics import accuracy_score
#导入数据
x,y = load_iris().data,load_iris().target
#划分数据集
x_train,x_test,y_train,y_test = train_test_split(x,y,random_state=1, test_size=50)
#定义和训练模型
model = GaussianNB()
model.fit(x_train,y_train)
#模型评估
pred = model.predict(x_test)
print("测试集数据的预测标签为",pred)
print("测试集数据的真实标签为",y_test)
print("测试集共有%d条数据,其中预测错误的数据有%d条,预测准确率为%.2f"%(x_test.shape[0],(pred!=y_test).sum(),
accuracy_score(y_test,pred)))
输出的结果为:
测试集数据的预测标签为 [0 1 1 0 2 2 2 0 0 2 1 0 2 1 1 0 1 1 0 0 1 1 2 0 2 1 0 0 1 2 1 2 1 2 2 0 1
0 1 2 2 0 1 2 1 2 0 0 0 1]
测试集数据的真实标签为 [0 1 1 0 2 1 2 0 0 2 1 0 2 1 1 0 1 1 0 0 1 1 1 0 2 1 0 0 1 2 1 2 1 2 2 0 1
0 1 2 2 0 2 2 1 2 0 0 0 1]
测试集共有50条数据,其中预测错误的数据有3条,预测准确率为0.94