正则化-机器学习
-
正则化 regularization
虚线是真实函数g(x) 圆点是加入的噪声:
不应用正则化的实现:
import numpy as np import matplotlib.pyplot as plt#真正的函数 def g(x):return 0.1 * (x**3 + x**2 + x)#加入噪声 train_x = np.linspace(-2,2,8) train_y = g(train_x) + np.random.randn(train_x.size)*0.05#绘图确认 x = np.linspace(-2,2,100) plt.plot(train_x,train_y,'o') plt.plot(x,g(x),linestyle='dashed') plt.ylim(-1,2) plt.show()#标准化 mu = train_x.mean() sigma = train_x.std() def standardize(x):return (x-mu)/sigmatrain_z = standardize(train_x)#创建训练数据的矩阵 def to_matrix(x):return np.vstack([np.ones(x.size),x,x**2,x**3,x**4,x**5,x**6,x**7,x**8,x**9,x**10,]).TX = to_matrix(train_z)#参数初始化 theta = np.random.randn(X.shape[1])#预测函数 def f(x):return np.dot(x,theta)#不应用正则化的实现#目标函数 def E(x,y):return 0.5*np.sum((y-f(x))**2)#学习率 ETA = 1e-4#误差 diff = 1#重复学习 error = E(X,train_y) while diff>1e-6:theta = theta - ETA*np.dot(f(X)-train_y,X)current_error = E(X,train_y)diff = error - current_errorerror = current_error#对结果绘图 z = standardize(x) plt.plot(train_z,train_y,'o') plt.plot(z,f(to_matrix(z))) plt.show()
过拟合的曲线:
应用正则化的实现:
import numpy as np import matplotlib.pyplot as plt#真正的函数 def g(x):return 0.1 * (x**3 + x**2 + x)#加入噪声 train_x = np.linspace(-2,2,8) train_y = g(train_x) + np.random.randn(train_x.size)*0.05x = np.linspace(-2,2,100)#标准化 mu = train_x.mean() sigma = train_x.std() def standardize(x):return (x-mu)/sigmatrain_z = standardize(train_x)#创建训练数据的矩阵 def to_matrix(x):return np.vstack([np.ones(x.size),x,x**2,x**3,x**4,x**5,x**6,x**7,x**8,x**9,x**10,]).TX = to_matrix(train_z)#参数初始化 theta = np.random.randn(X.shape[1])#预测函数 def f(x):return np.dot(x,theta)#应用正则化的实现 #正则化常量 LAMBDA = 1#目标函数 def E(x,y):return 0.5*np.sum((y-f(x))**2)#学习率 ETA = 1e-4#误差 diff = 1#重复学习 error = E(X,train_y) while diff>1e-6:#正则化项 偏置项不适合用正则化 所以为0reg_term = LAMBDA*np.hstack([0,theta[1:]])#应用正则化项 更新参数theta = theta - ETA*(np.dot(f(X)-train_y,X) + reg_term)current_error = E(X,train_y)diff = error - current_errorerror = current_error#对结果绘图 z = standardize(x) plt.plot(train_z,train_y,'o') plt.plot(z,f(to_matrix(z))) plt.show()