# - 基于T-R-train.csv数据建立线性回归模型,计算模型的r2分数,可视化模型预测结果
# - 加入多项式(2次、5次),建立回归模型
# - 计算多项式回归模型对预测数据进行预测的r2分数,判断哪个模型预测更准确
import pandas as pd
import numpy as np
# Load the training set and pull out the feature column 'T' and the
# target column 'rate'.
data_train = pd.read_csv('T-R-train.csv')

X_train = data_train.loc[:, 'T']
y_train = data_train.loc[:, 'rate']

# `%matplotlib inline` is an IPython magic, not Python syntax, and it had been
# fused onto the `y_train` assignment above. Run it through the IPython API
# when a kernel is present and skip it under a plain interpreter.
try:
    get_ipython().run_line_magic('matplotlib', 'inline')
except NameError:
    pass  # not running under IPython/Jupyter; nothing to configure

from matplotlib import pyplot as plt

# Scatter plot of the raw training samples.
fig1 = plt.figure(figsize=(5, 5))
plt.scatter(X_train, y_train)
plt.title('raw data')
plt.xlabel('temperature')
plt.ylabel('rate')
plt.show()

from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score

# scikit-learn estimators take a 2-D feature matrix, so the single 'T' column
# is reshaped to (n_samples, 1) before fitting.
X_train = np.array(X_train).reshape(-1, 1)
lr1 = LinearRegression()
lr1.fit(X_train, y_train)

# Load the held-out test set. The reshape must happen BEFORE predict();
# the original reshaped X_test only after it had already been passed in.
data_test = pd.read_csv('T-R-test.csv')
X_test = np.array(data_test.loc[:, 'T']).reshape(-1, 1)
y_test = data_test.loc[:, 'rate']

y_train_predict = lr1.predict(X_train)
y_test_predict = lr1.predict(X_test)

# R^2 on both splits for the plain linear model.
r2_train = r2_score(y_train, y_train_predict)
r2_test = r2_score(y_test, y_test_predict)
print('training r2', r2_train)
print('test r2', r2_test)

# Evaluate the fitted line on an evenly spaced grid (40..90, 300 points)
# so it can be drawn as a curve over the training scatter.
X_range = np.linspace(40, 90, 300).reshape(-1, 1)
y_range_predict = lr1.predict(X_range)

fig2 = plt.figure(figsize=(5, 5))
plt.plot(X_range, y_range_predict)
plt.scatter(X_train, y_train)
plt.title('prediction data')
plt.xlabel('temperature')
plt.ylabel('rate')
plt.show()

# 多项式模型
from sklearn.preprocessing import PolynomialFeatures

# Expand the single feature into degree-2 polynomial terms. The transformer
# is fitted on the training data and then reused unchanged on the test data.
poly2 = PolynomialFeatures(degree=2)
X_2_train = poly2.fit_transform(X_train)
X_2_test = poly2.transform(X_test)
print(X_2_train)

# Linear regression on the expanded features (original called `lr2.fig`).
lr2 = LinearRegression()
lr2.fit(X_2_train, y_train)

y_2_train_predict = lr2.predict(X_2_train)
y_2_test_predict = lr2.predict(X_2_test)
r2_2_train = r2_score(y_train, y_2_train_predict)
r2_2_test = r2_score(y_test, y_2_test_predict)
print('training r2_2', r2_2_train)
print('test r2_2', r2_2_test)

# Plotting grid: keep the raw values for the x-axis and feed the model the
# polynomial expansion of that grid. The original stored the expansion in
# X_2_test (clobbering the test features) and predicted on the raw grid,
# whose column count does not match what lr2 was fitted on.
X_range = np.linspace(40, 90, 300).reshape(-1, 1)
X_2_range = poly2.transform(X_range)
y_2_range_predict = lr2.predict(X_2_range)

fig3 = plt.figure(figsize=(5, 5))
plt.plot(X_range, y_2_range_predict)
plt.scatter(X_train, y_train)
plt.scatter(X_test, y_test)
plt.title('prediction data')
plt.xlabel('temperature')
plt.ylabel('rate')
plt.show()
