2.多项式回归详解:从原理推导到实战预测

1 前言

在上一章节,我们学习了如何使用线性回归对数据进行预测,但发现一元线性回归并不能很好的反应某些数据。

因此本章将要继续学习更高级的内容,让我们拥有回归预测不那么线性的数据,也就是多项式回归。

所谓多项式,指的是由变量和系数常量通过有限次加减乘除以及自然幂次的乘方运算得到的表达式,是整式的一种。

当未知数只有一个时,就被成为一元多项式,也就是上一章节所介绍的内容,比如x2x+4x^2-x+4

而未知数不止一个的多项式就称为多元多项式,例如x33xy2zyz+1x^3-3xy^2z-yz+1就是一个三元多项式。

2 认识多项式

为了更加直观的演示,首先这里提供一组数据,然后将其画出来:

from matplotlib import pyplot as plt

x = [4, 7, 13, 26, 31, 41, 59, 62, 71, 78]
y = [21, 35, 49, 54, 43, 30, 31, 48, 66, 74]

plt.scatter(x, y)
plt.show()

效果如下:

image.png

很明显可以看出来,这组数据具有明显的波动性,如果用前面学过的内容、用直线去拟合这组数据肯定效果不好。

因此下面我们来尝试使用多项式来拟合数据,首先尝试用二次多项式。

一个标准的一元高阶多项式函数如下:

y(x,w)=w0+w1x+w2x2++wmxm=j=0mwjxjy(x, \mathbf{w}) = w_0 + w_1 x + w_2 x^2 + \cdots + w_m x^m = \sum_{j=0}^{m} w_j x^j

其中m代表多项式的阶数,xjx^j表示x的j次幂,w代表x的系数。

因此当我们尝试用上面的一元高阶多项式去拟合数据时,就需要确定多项式系数w与阶数m,这是多项式的两个基本要素。

由于阶数我们可以实现尝试确定、比如这里设定阶数为2,那么就只需要求解系数w即可:

y(x,w)=w0+w1x+w1x2y(x, \mathbf{w}) = w_0 + w_1 x + w_1 x^2

此时,我们又回到了前面的内容,依旧是自定义损失函数:

def func(p, x):
    # 根据公式,定义 2 次多项式函数
    w0, w1, w2 = p
    x = np.asarray(x)
    f = w0 + w1 * x + w2 * x**2
    return f


def err_func(p, x, y):
    # 残差函数(观测值与拟合值之间的差距)
    f=func(p,x)
    return f-y

然后调用求解最佳拟合参数的函数即可:

from scipy.optimize import least_squares
#...
res=least_squares(err_func,[0,0,0],args=(x,y))

这里使用了scipy库内部的least_squares函数,通过最小二乘法求解最佳拟合参数,而无需自己去手动实现。

如果不存在该库可能需要自己安装一下:

uv pip install scipy

除非你需要更多自定义行为,否则只需要上面三个参数即可:损失函数、初始系数值、训练数据。

这里为了简单,直接设置了三个0作为初始值,该值并不影响结果,但它会影响最终最终多项式的次数,所以这里涉及三个,代表求解2次多项式(0、1、2次共三个系数)。

下面是一个完整的实例程序:

import numpy as np
from matplotlib import pyplot as plt
from scipy.optimize import least_squares
x = [4, 7, 13, 26, 31, 41, 59, 62, 71, 78]
y = [21, 35, 49, 54, 43, 30, 31, 48, 66, 74]

def func(p, x):
    w0, w1, w2 = p
    x = np.asarray(x)
    f = w0 + w1 * x + w2 * x**2
作者:余识
全部文章:0
会员文章:0
总阅读量:0
c/c++pythonrustJavaScriptwindowslinux