1 回答

TA貢獻1827條經(jīng)驗 獲得超8個贊
imshow
可用于繪制二維函數(shù)。x 和 y 方向首先使用 eg 在 1D 中創(chuàng)建np.linspace
,然后通過 合并到 2D np.meshgrid
。Numpy 的魔法允許編寫簡單的表達式,這些表達式在幕后立即對整個網(wǎng)格進行操作。
from matplotlib import pyplot as plt
import matplotlib
import numpy as np
fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(10, 4), gridspec_kw={'hspace': 0.05})
distance, velocity = np.meshgrid(np.linspace(0, 1, 50), np.linspace(0, 1, 50))
reward1 = 1 - distance ** 0.4
reward1[distance < 0.1] = -1
reward1[(distance < 0.1) & (velocity < 0.1)] = 1
im1 = ax1.imshow(reward1, origin='bottom', extent=[0, 1, 0, 1], vmin=-1, vmax=1, cmap='bwr',
interpolation='nearest')
ax1.set_xlabel('distance')
ax1.set_ylabel('velocity')
plt.colorbar(im1, ax=ax1, shrink=0.9)
dist_reward = 1 - distance ** 0.4
vel_discount = (1 - np.maximum(velocity, 0.1)) ** (1 / np.maximum(distance, 0.1))
reward2 = vel_discount * dist_reward
reward2[distance < 0.1] = -1
reward2[(distance < 0.1) & (velocity < 0.1)] = 1
im2 = ax2.imshow(reward2, origin='bottom', extent=[0, 1, 0, 1], vmin=-1, vmax=1, cmap='bwr',
interpolation='nearest')
ax2.set_xlabel('distance')
ax2.set_ylabel('velocity')
plt.colorbar(im2, ax=ax2, shrink=0.9)
plt.tight_layout()
plt.show()
添加回答
舉報