7. scipy Interpolate


title: "7. scipy Interpolate"
date: 2026-05-24T14:11:20Z

1import numpy as np
2import pylab as pl
3from scipy import interpolate
import matplotlib as mpl

# SimHei supplies the CJK glyphs used by the Chinese labels and titles.
mpl.rcParams['font.sans-serif'] = ['SimHei']
# With CJK fonts such as SimHei, matplotlib's default Unicode minus sign
# (U+2212) often has no glyph, so negative tick labels render as empty boxes.
# Fall back to the ASCII hyphen-minus; the plots here all have negative axes.
mpl.rcParams['axes.unicode_minus'] = False

插值-interpolate

一维插值

WARNING

高次interp1d()插值的运算量很大，因此对于点数较多的数据，建议使用后面介绍的UnivariateSpline()。

 #%fig=`interp1d`的各阶插值
from scipy import interpolate

# Eleven coarse samples of sin on [0, 10] ...
x = np.linspace(0, 10, 11)
y = np.sin(x)
# ... and a ten-times finer grid on which each interpolant is evaluated.
xnew = np.linspace(0, 10, 101)

pl.plot(x, y, 'ro')  # the raw sample points
kinds = ('nearest', 'zero', 'slinear', 'quadratic')
for kind in kinds:
    f = interpolate.interp1d(x, y, kind=kind)  #❶ build a callable interpolant of this kind
    ynew = f(xnew)  #❷ evaluate it on the fine grid
    pl.plot(xnew, ynew, label=kind)
pl.legend(loc='lower right')
<matplotlib.legend.Legend at 0x274ccb5f6a0>

png

外推和Spline拟合

 #%fig=使用UnivariateSpline进行插值:外推(上),数据拟合(下)
# Top panel: 20 exact samples of sin on [0, 10]. s=0 forces the spline
# through every sample; evaluating on [0, 12] shows the extrapolation.
x1 = np.linspace(0, 10, 20)
y1 = np.sin(x1)
sx1 = np.linspace(0, 12, 100)
spline1 = interpolate.UnivariateSpline(x1, y1, s=0)  #❶
sy1 = spline1(sx1)

# Bottom panel: 200 noisy samples of sin on [0, 20]. A positive smoothing
# factor (s=8) lets the spline smooth the noise instead of interpolating it.
x2 = np.linspace(0, 20, 200)
noise = np.random.standard_normal(len(x2)) * 0.2
y2 = np.sin(x2) + noise
sx2 = np.linspace(0, 20, 2000)
spline2 = interpolate.UnivariateSpline(x2, y2, s=8)  #❷
sy2 = spline2(sx2)

pl.figure(figsize=(8, 5))

pl.subplot(211)
pl.plot(x1, y1, ".", label="数据点")
pl.plot(sx1, sy1, label="spline曲线")
pl.legend()

pl.subplot(212)
pl.plot(x2, y2, ".", label="数据点")
pl.plot(sx2, sy2, linewidth=2, label="spline曲线")
pl.plot(x2, np.sin(x2), label="无噪声曲线")
pl.legend()
<matplotlib.legend.Legend at 0x274bb27f588>

png

roots2 = spline2.roots()  # zeros of the smoothed spline spline2
print(np.array_str(roots2, precision=3))
[ 0.053  3.151  6.36   9.386 12.603 15.619 18.929]
 #%fig=计算Spline与水平线的交点
def roots_at(self, v):  #❶
    """Return the x positions where the spline crosses the horizontal line y = v.

    The spline is temporarily lowered by ``v`` and FITPACK is asked for the
    zeros of the shifted spline. Subtracting ``v`` from every B-spline
    coefficient lowers the whole spline by ``v``, because the B-spline basis
    functions sum to one over the fitted interval.

    Parameters
    ----------
    self : UnivariateSpline
        Spline to query. ``roots()`` only supports cubic (k=3) splines.
    v : float
        Height of the horizontal line.

    Returns
    -------
    ndarray
        The roots of ``spline(x) - v``.

    Raises
    ------
    RuntimeError
        If ``get_coeffs()`` does not hand back a view of the spline's own
        coefficient storage. In that case the in-place shift would have no
        effect and the result would silently be the roots at level 0.
    """
    coeff = self.get_coeffs()
    # The trick depends on get_coeffs() returning a view into the array the
    # spline evaluates from; two views of that array share memory, copies don't.
    if not np.shares_memory(coeff, self.get_coeffs()):
        raise RuntimeError(
            "UnivariateSpline.get_coeffs() returned a copy; "
            "cannot shift the spline in place")
    # Keep an exact copy: undoing the shift with `coeff += v` can leave
    # floating-point residue, permanently perturbing the spline.
    saved = coeff.copy()
    coeff -= v
    try:
        return self.roots()
    finally:
        coeff[:] = saved
10
11
# Attach roots_at as a method so every UnivariateSpline instance can call it.
interpolate.UnivariateSpline.roots_at = roots_at  #❷

pl.plot(sx2, sy2, linewidth=2, label="spline曲线")

ax = pl.gca()
for level in (0.5, 0.75, -0.5, -0.75):
    ax.axhline(level, ls=":", color="k")  # dotted reference line y = level
    xr = spline2.roots_at(level)  #❸ x positions where spline2 crosses y = level
    pl.plot(xr, spline2(xr), "ro")  # mark each crossing

png

参数插值

 #%fig=使用参数插值连接二维平面上的点
# Ordered points of a 2-D curve, written as (x, y) pairs.
points = [
    (4.913, 5.2785), (4.913, 5.2875), (4.918, 5.291), (4.938, 5.289),
    (4.955, 5.28), (4.949, 5.26), (4.911, 5.245), (4.848, 5.245),
    (4.864, 5.2615), (4.893, 5.278), (4.935, 5.2775), (4.981, 5.261),
    (5.01, 5.245), (5.021, 5.241),
]
x, y = (list(coords) for coords in zip(*points))

pl.plot(x, y, "o")

# s=0 passes the curve through every point; s=1e-4 allows slight smoothing.
for s in (0, 1e-4):
    # splprep returns the spline (tck) and u, the parameter value of each point.
    tck, u = interpolate.splprep([x, y], s=s)  #❶
    u_fine = np.linspace(u[0], u[-1], 200)
    xi, yi = interpolate.splev(u_fine, tck)  #❷
    pl.plot(xi, yi, lw=2, label="s=%g" % s)

pl.legend()
<matplotlib.legend.Legend at 0x274ccd64780>

png

单调插值

 # Cubic-spline interpolation of sampled sin, compared with PCHIP, a
# monotonicity-preserving piecewise cubic interpolant.
import numpy as np
import matplotlib.pyplot as plt
from scipy import interpolate

x = np.arange(0, 2 * np.pi + np.pi / 4, 2 * np.pi / 8)  # samples every pi/4, from 0 to about 2*pi
y = np.sin(x)
tck = interpolate.splrep(x, y, s=0)  # s=0: the spline passes through every sample
xnew = np.arange(0, 2 * np.pi, np.pi / 50)
ynew = interpolate.splev(xnew, tck, der=0)
# PCHIP preserves the monotonicity of the data between samples and does not
# overshoot, which an ordinary cubic spline does not guarantee.
ypchip = interpolate.PchipInterpolator(x, y)(xnew)

plt.figure()
# Label every line explicitly: a positional legend list pairs labels with
# lines purely by call order and easily mislabels them.
plt.plot(x, y, 'x', label='Data')
plt.plot(x, y, 'b', label='Linear')
plt.plot(xnew, ynew, label='Cubic Spline')
plt.plot(xnew, ypchip, label='PCHIP (monotone)')
plt.plot(xnew, np.sin(xnew), label='True')
plt.legend()
plt.axis([-0.05, 6.33, -1.05, 1.05])
plt.title('三次样条插值与PCHIP单调插值')
plt.show()

png

多维插值

 #%fig=使用样条对规则网格数据进行二维插值
def func(x, y):  #❶
    """Return (x + y) * exp(-5 * (x**2 + y**2)), evaluated elementwise.

    Serves as the known surface sampled by the 2-D interpolation examples.
    """
    radius_sq = x**2 + y**2
    return (x + y) * np.exp(-5.0 * radius_sq)
 4
 5
 # Split the X-Y plane into a 15x15 grid.
y, x = np.mgrid[-1:1:15j, -1:1:15j]  #❷
fvals = func(x, y)  # function value at every grid node; rows follow y, columns follow x

# 2-D interpolation.
# interpolate.interp2d was deprecated in SciPy 1.10 and removed in SciPy 1.14.
# For data on a regular grid SciPy recommends RectBivariateSpline instead.
# It takes the 1-D axis coordinates and expects z[i, j] = f(x[i], y[j]),
# which is the transpose of fvals. Its default s=0 makes it interpolate.
# Values can differ slightly from the old interp2d(kind='cubic') result.
grid_spline = interpolate.RectBivariateSpline(
    x[0, :], y[:, 0], fvals.T, kx=3, ky=3)


def newfunc(xs, ys):  #❸
    """Evaluate the cubic spline on the grid ``xs`` x ``ys``.

    Returns an array of shape (len(ys), len(xs)): rows follow y, the layout
    interp2d produced and the one imshow(origin="lower") expects.
    """
    return grid_spline(xs, ys).T


# Interpolate onto a 100x100 grid.
xnew = np.linspace(-1, 1, 100)
ynew = np.linspace(-1, 1, 100)
fnew = newfunc(xnew, ynew)  #❹
#%hide
pl.subplot(121)
pl.imshow(
    fvals,
    extent=[-1, 1, -1, 1],
    cmap=pl.cm.jet,
    interpolation='nearest',
    origin="lower")
pl.title("fvals")
pl.subplot(122)
pl.imshow(
    fnew,
    extent=[-1, 1, -1, 1],
    cmap=pl.cm.jet,
    interpolation='nearest',
    origin="lower")
pl.title("fnew")
pl.show()

png

griddata

WARNING

griddata()使用欧几里得距离计算插值。如果K维空间中各个维度的取值范围相差较大，则应先将数据归一化，再使用griddata()进行插值运算。

 #%fig=使用griddata进行二维插值
# Draw N random sample points in [-1, 1) x [-1, 1) and evaluate func there.
N = 200
np.random.seed(42)  # fixed seed so the scatter is reproducible
x = np.random.uniform(-1, 1, N)
y = np.random.uniform(-1, 1, N)
z = func(x, y)

# Target: a 100x100 grid, flattened into an (n, 2) array of (x, y) pairs.
yg, xg = np.mgrid[-1:1:100j, -1:1:100j]
xi = np.column_stack((xg.ravel(), yg.ravel()))

methods = 'nearest', 'linear', 'cubic'

# 'linear' and 'cubic' yield NaN outside the convex hull of the samples
# (griddata's default fill_value); 'nearest' covers the whole grid.
zgs = []
for method in methods:
    flat = interpolate.griddata((x, y), z, xi, method=method)
    zgs.append(flat.reshape(100, 100))
#%hide
fig, axes = pl.subplots(1, 3, figsize=(11.5, 3.5))

for ax, method, zg in zip(axes, methods, zgs):
    ax.imshow(zg, extent=[-1, 1, -1, 1], cmap=pl.cm.jet,
              interpolation='nearest', origin="lower")
    ax.set_xlabel(method)
    ax.scatter(x, y, c=z)  # overlay the original sample points

png

径向基函数插值

 #%fig=一维RBF插值
from scipy.interpolate import Rbf

# Four scattered 1-D samples (the x values are not sorted).
x1 = np.array([-1, 0, 2.0, 1.0])
y1 = np.array([1.0, 0.3, -0.5, 0.8])

funcs = ['multiquadric', 'gaussian', 'linear']
nx = np.linspace(-3, 4, 100)  # evaluation range extends past the data on both sides
rbfs = [Rbf(x1, y1, function=kernel) for kernel in funcs]  #❶ one interpolator per radial basis function
rbf_ys = [interp(nx) for interp in rbfs]  #❷ evaluate each interpolator on nx
#%hide
pl.plot(x1, y1, "o")
for kernel, curve in zip(funcs, rbf_ys):
    pl.plot(nx, curve, label=kernel, lw=2)

pl.ylim(-1.0, 1.5)
pl.legend()
<matplotlib.legend.Legend at 0x274caacec88>

png

# Show the node values each Rbf solved for (one per data point).
for name, rbf in zip(funcs, rbfs):
    print(name, rbf.nodes)
multiquadric [-0.88822885  2.17654513  1.42877511 -2.67919021]
gaussian [ 1.00321945 -0.02345964 -0.65441716  0.91375159]
linear [-0.26666667  0.6         0.73333333 -0.9       ]
 #%fig=二维径向基函数插值
# Fit one Rbf per basis function to the scattered (x, y, z) samples and
# evaluate each on the grid (xg, yg).
rbfs = []
rbf_zg = []
for fname in funcs:
    rbf = Rbf(x, y, z, function=fname)
    rbfs.append(rbf)
    rbf_zg.append(rbf(xg, yg).reshape(xg.shape))
#%hide
fig, axes = pl.subplots(1, 3, figsize=(11.5, 3.5))
for ax, fname, zg in zip(axes, funcs, rbf_zg):
    ax.imshow(zg, extent=[-1, 1, -1, 1], cmap=pl.cm.jet,
              interpolation='nearest', origin="lower")
    ax.set_xlabel(fname)
    ax.scatter(x, y, c=z)  # overlay the sample points

png

 #%fig=`epsilon`参数指定径向基函数中数据点的作用范围
# Gaussian RBFs with increasing epsilon: a larger epsilon widens each
# Gaussian, so every sample influences a larger neighbourhood.
epsilons = 0.1, 0.15, 0.3
rbfs = []
zgs = []
for eps in epsilons:
    rbf = Rbf(x, y, z, function="gaussian", epsilon=eps)
    rbfs.append(rbf)
    zgs.append(rbf(xg, yg).reshape(xg.shape))
#%hide
fig, axes = pl.subplots(1, 3, figsize=(11.5, 3.5))
for ax, eps, zg in zip(axes, epsilons, zgs):
    ax.imshow(zg, extent=[-1, 1, -1, 1], cmap=pl.cm.jet,
              interpolation='nearest', origin="lower")
    ax.set_xlabel("eps=%g" % eps)
    ax.scatter(x, y, c=z)  # overlay the sample points

png