I am trying to implement WGAN with GP (gradient penalty) in TensorFlow 2.0. To compute the gradient penalty, you need to compute the gradients of the predictions with respect to the input images. To make this tractable, instead of computing the gradients of the predictions with respect to all the input images, it computes interpolated data points along the lines between real and fake data points and uses those as input. To implement this, I first developed a function that makes some predictions and returns the gradients with respect to some input images. At first I thought of doing it with tf.keras.backend.gradients and .compute_gradients, but those do not work in eager mode. So I am now trying to use GradientTape. Here is the code I used to test things out:

from tensorflow.keras import backend as K
from tensorflow.keras.layers import *
from tensorflow.keras.models import *
import tensorflow as tf
import numpy as np

# Comes from Generative Deep Learning by David Foster
class RandomWeightedAverage(tf.keras.layers.Layer):
    """Provides a (random) weighted average between real and generated image samples"""
    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size

    def call(self, inputs):
        alpha = K.random_uniform((self.batch_size, 1, 1, 1))
        return (alpha * inputs[0]) + ((1 - alpha) * inputs[1])

# Dummy critic
def make_critic():
    critic = Sequential()
    inputShape = (28, 28, 1)
    critic.add(Conv2D(32, (5, 5), padding="same", strides=(2, 2),
                      input_shape=inputShape))
    critic.add(LeakyReLU(alpha=0.2))
    critic.add(Conv2D(64, (5, 5), padding="same", strides=(2, 2)))
    critic.add(LeakyReLU(alpha=0.2))
    critic.add(Flatten())
    critic.add(Dense(512))
    critic.add(LeakyReLU(alpha=0.2))
    critic.add(Dropout(0.3))
    critic.add(Dense(1))
    return critic

# Gather dataset
((X_train, _), (X_test, _)) = tf.keras.datasets.fashion_mnist.load_data()
X_train = X_train.reshape(-1, 28, 28, 1)
X_test = X_test.reshape(-1, 28, 28, 1)

# Note that I am using test images as fake images for testing purposes
interpolated_img = RandomWeightedAverage(32)([X_train[0:32].astype("float"),
                                              X_test[32:64].astype("float")])

# Compute gradients of the predictions with respect to the interpolated images
critic = make_critic()
with tf.GradientTape() as tape:
    y_pred = critic(interpolated_img)
gradients = tape.gradient(y_pred, interpolated_img)

The gradients are coming out as None. Am I missing something here?
1 Answer

開滿天機
"the gradients of the predictions with respect to some tensors ... Am I missing something here?"
Yes. You need a tape.watch(interpolated_img):
with tf.GradientTape() as tape:
    tape.watch(interpolated_img)
    y_pred = critic(interpolated_img)
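With the tape now watching the tensor, the same gradient call from the question should return an actual gradient tensor rather than None:

gradients = tape.gradient(y_pred, interpolated_img)  # no longer None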
GradientTape needs to store the intermediate values of the forward pass in order to compute gradients. Usually you want gradients with respect to variables, so by default it does not keep a trace of computations that start from plain tensors, presumably to save memory.
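As a minimal, self-contained illustration of that default (this snippet is mine, not from the original post): gradients flow back to a tf.Variable automatically, but come back as None for an unwatched tf.constant.

import tensorflow as tf

v = tf.Variable(3.0)   # variables are watched by the tape automatically
c = tf.constant(3.0)   # plain tensors are not watched by default

with tf.GradientTape() as tape:
    y = v * v + c * c

dv, dc = tape.gradient(y, [v, c])
print(dv)  # tf.Tensor(6.0, shape=(), dtype=float32)
print(dc)  # None -- the tape kept no trace of the computation on c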
If you want a gradient with respect to a tensor, you need to tell the tape explicitly to watch it.
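Putting the fix back into the question's code, here is a rough sketch of how the gradient penalty term could then be computed. The norm and penalty lines follow the standard WGAN-GP formulation rather than anything in the original post, and the variable names simply reuse the question's code:

with tf.GradientTape() as tape:
    tape.watch(interpolated_img)        # explicitly trace the input tensor
    y_pred = critic(interpolated_img)

# Gradients of the critic's predictions w.r.t. the interpolated images
grads = tape.gradient(y_pred, interpolated_img)

# Standard WGAN-GP penalty: (||grad||_2 - 1)^2, averaged over the batch
grad_norms = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
gradient_penalty = tf.reduce_mean(tf.square(grad_norms - 1.0))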