我正在 TensorFlow 中開(kāi)展一個(gè) VAE 項(xiàng)目,其中編碼器/解碼器網(wǎng)絡(luò)內(nèi)置于函數(shù)中。這個(gè)想法是能夠保存,然后加載訓(xùn)練好的模型并使用編碼器功能進(jìn)行采樣?;謴?fù)模型后,我無(wú)法運(yùn)行解碼器功能并將恢復(fù)的訓(xùn)練變量返回給我,出現(xiàn)“未初始化值”錯(cuò)誤。我認(rèn)為這是因?yàn)樵摵瘮?shù)要么創(chuàng)建一個(gè)新函數(shù),要么覆蓋現(xiàn)有函數(shù),要么以其他方式。但我無(wú)法弄清楚如何解決這個(gè)問(wèn)題。這是一些代碼:class VAE(object): def __init__(self, restore=True): self.session = tf.Session() if restore: self.restore_model() self.build_decoder = tf.make_template('decoder', self._build_decoder)@staticmethoddef _build_decoder(z, output_size=768, hidden_size=200, hidden_activation=tf.nn.elu, output_activation=tf.nn.sigmoid): x = tf.layers.dense(z, hidden_size, activation=hidden_activation) x = tf.layers.dense(x, hidden_size, activation=hidden_activation) logits = tf.layers.dense(x, output_size, activation=output_activation) return distributions.Independent(distributions.Bernoulli(logits), 2)def sample_decoder(self, n_samples): prior = self.build_prior(self.latent_dim) samples = self.build_decoder(prior.sample(n_samples), self.input_size).mean() return self.session.run([samples])def restore_model(self): print("Restoring") self.saver = tf.train.import_meta_graph(os.path.join(self.save_dir, "turbolearn.meta")) self.saver.restore(self.sess, tf.train.latest_checkpoint(self.save_dir)) self._restored = True想跑 samples = vae.sample_decoder(5)在我的訓(xùn)練程序中,我運(yùn)行: if self.checkpoint: self.saver.save(self.session, os.path.join(self.save_dir, "myvae"), write_meta_graph=True)更新根據(jù)下面的建議答案,我更改了恢復(fù)方法self.saver = tf.train.Saver()self.saver.restore(self.session, tf.train.latest_checkpoint(self.save_dir))但是現(xiàn)在在創(chuàng)建 Saver() 對(duì)象時(shí)出現(xiàn)值錯(cuò)誤:ValueError: No variables to save
在 TensorFlow 中保存和恢復(fù)函數(shù)
三國(guó)紛爭(zhēng)
2021-08-14 17:20:22