1 回答

TA貢獻1852條經驗 獲得超7個贊
對我來說,你的代碼看起來完全正確!至少算法是正確的。我已經更改了您的代碼以用于numpy
快速計算而不是純Python。另外,我還配置了一些參數,例如改變了動量和學習率,也實現了MSE
。
然后我用來matplotlib
畫情節(jié)動畫。最后,在動畫上,看起來您的回歸實際上試圖將曲線擬合到數據。盡管在最后一次擬合迭代中它sin(x)
看起來像線性近似,但仍然盡可能接近二次曲線的數據點。但對于for?in來說,它看起來像是理想的近似(它從迭代周圍開始擬合)。x
[0; 2 * pi]
sin(x)
x
[0; pi]
12-th
i-th
動畫幀只是用 進行回歸dErr = 0.7 ** (i + 15)
。
我的動畫運行腳本有點慢,但是如果您save
像這樣添加參數python script.py save
,它將渲染/保存以line.gif
繪制繪圖動畫。如果您在沒有參數的情況下運行腳本,它將在您的 PC 屏幕上實時繪制/擬合動畫。
完整的代碼在圖形之后,代碼需要通過運行一次安裝一些Python模塊python -m pip install numpy matplotlib
。
接下來是sin(x)
在x
:(0, pi)
接下來是sin(x)
在x
:(0, 2 * pi)
接下來是abs(x)
在x
:(-1, 1)
# Needs: python -m pip install numpy matplotlib
import math, sys
import numpy as np, matplotlib.pyplot as plt, matplotlib.animation as animation
from matplotlib.animation import FuncAnimation
x_range = (0., math.pi, 0.1) # (xmin, xmax, xstep)
y_range = (-0.2, 1.2) # (ymin, ymax)
num_iterations = 50
def f(x):
? ? return np.sin(x)
def derr(iteration):
? ? return 0.7 ** (iteration + 15)
? ??
? ??
def MSE(a, b):
? ? return (np.abs(np.array(a) - np.array(b)) ** 2).mean()
def quadraticRegression(*, x, data, dErr):
? ? x, data = np.array(x), np.array(data)
? ? assert x.size == data.size, (x.size, data.size)
? ? a = 1 #Starting values
? ? b = 1
? ? c = 1
? ? a_momentum = 0.1 #Momentum to counter steady state error
? ? b_momentum = 0.1
? ? c_momentum = 0.1
? ? estimate = a*x**2 + b*x + c #Estimate curve
? ? error = MSE(data, estimate) #Get errors 'n stuff
? ? errorOld = 0.
? ? lr = 10. ** -4 #learning rate
? ? while abs(error - errorOld) > dErr:
? ? ? ? #Fit a (dE/da)
? ? ? ? deda = np.sum(2*x**2 * (a*x**2 + b*x + c - data))/len(data)
? ? ? ? correction = deda*lr
? ? ? ? a_momentum = (a_momentum)*0.99 + correction*0.1 #0.99 is to slow down momentum when correction speed changes
? ? ? ? a = a - correction - a_momentum
? ??
? ? ? ? #fit b (dE/db)
? ? ? ? dedb = np.sum(2*x*(a*x**2 + b*x + c - data))/len(data)
? ? ? ? correction = dedb*lr
? ? ? ? b_momentum = (b_momentum)*0.99 + correction*0.1
? ? ? ? b = b - correction - b_momentum
? ? ? ? #fit c (dE/dc)
? ? ? ? dedc = np.sum(2*(a*x**2 + b*x + c - data))/len(data)
? ? ? ? correction = dedc*lr
? ? ? ? c_momentum = (c_momentum)*0.99 + correction*0.1
? ? ? ? c = c - correction - c_momentum
? ? ? ? #Update model and find errors
? ? ? ? estimate = a*x**2 +b*x + c
? ? ? ? errorOld = error
? ? ? ? #print(error)
? ? ? ? error = MSE(data, estimate)
? ? return a, b, c, error
? ? ? ??
fig, ax = plt.subplots()
fig.set_tight_layout(True)
x = np.arange(x_range[0], x_range[1], x_range[2])
#ax.scatter(x, x + np.random.normal(0, 3.0, len(x)))
line0, line1 = None, None
do_save = len(sys.argv) > 1 and sys.argv[1] == 'save'
def g(x, derr):
? ? a, b, c, error = quadraticRegression(x = x, data = f(x), dErr = derr)
? ? return a * x ** 2 + b * x + c
? ??
def dummy(x):
? ? return np.ones_like(x, dtype = np.float64) * 100.
def update(i):
? ? global line0, line1
? ??
? ? de = derr(i)
? ??
? ? if line0 is None:
? ? ? ? assert line1 is None
? ? ? ? line0, = ax.plot(x, f(x), 'r-', linewidth=2)
? ? ? ? line1, = ax.plot(x, g(x, de), 'r-', linewidth=2, color = 'blue')
? ? ? ? ax.set_ylim(y_range[0], y_range[1])
? ? ? ??
? ? if do_save:
? ? ? ? sys.stdout.write(str(i) + ' ')
? ? ? ? sys.stdout.flush()
? ? label = 'iter {0} derr {1}'.format(i, round(de, math.ceil(-math.log(de) / math.log(10)) + 2))
? ? line1.set_ydata(g(x, de))
? ? ax.set_xlabel(label)
? ? return line1, ax
if __name__ == '__main__':
? ? anim = FuncAnimation(fig, update, frames = np.arange(0, num_iterations), interval = 200)
? ? if do_save:
? ? ? ? anim.save('line.gif', dpi = 200, writer = 'imagemagick')
? ? else:
? ? ? ? plt.show()
添加回答
舉報