2 Answers
Answer 1 — contributed 1805 experience points, earned 10+ upvotes
每個(gè)正在尋找我的問題的答案的人,請(qǐng)看下面。
注意:我想您已經(jīng)將模型保存在其中checkpoint_dir并希望以圖形模式獲取此模型,以便您可以將其另存為.pb文件。
import tensorflow as tf

model = ContextExtractor()
predictions = model(images, training=False)

checkpoint = tf.train.Checkpoint(model=model, global_step=tf.train.get_or_create_global_step())
status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
status.assert_consumed()

with tf.Session() as sess:
    status.initialize_or_restore(sess)  # this is the main line for loading
    # Running one batch may be necessary so the graph is actually built before freezing
    img_batch = get_image(...)
    ans = sess.run(predictions, feed_dict={images: img_batch})
    frozen_graph = freeze_session(sess, output_names=[out.op.name for out in model.outputs])
    # save your model
    tf.train.write_graph(frozen_graph, "where/to/save", "tf_model.pb", as_text=False)
Answer 2 — contributed 1810 experience points, earned 4+ upvotes
You should get the session:

tf.keras.backend.get_session()

Then freeze the model, e.g. as done here: https://www.dlology.com/blog/how-to-convert-trained-keras-model-to-tensorflow-and-make-prediction/
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    """
    Freezes the state of a session into a pruned computation graph.

    Creates a new computation graph where variable nodes are replaced by
    constants taking their current value in the session. The new graph will be
    pruned so subgraphs that are not necessary to compute the requested
    outputs are removed.
    @param session The TensorFlow session to be frozen.
    @param keep_var_names A list of variable names that should not be frozen,
                          or None to freeze all the variables in the graph.
    @param output_names Names of the relevant graph outputs.
    @param clear_devices Remove the device directives from the graph for better portability.
    @return The frozen graph definition.
    """
    from tensorflow.python.framework.graph_util import convert_variables_to_constants
    graph = session.graph
    with graph.as_default():
        freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
        output_names = output_names or []
        output_names += [v.op.name for v in tf.global_variables()]
        # Graph -> GraphDef ProtoBuf
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""
        frozen_graph = convert_variables_to_constants(session, input_graph_def,
                                                      output_names, freeze_var_names)
        return frozen_graph
from tensorflow.keras import backend as K

frozen_graph = freeze_session(K.get_session(),
                              output_names=[out.op.name for out in model.outputs])
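To make the freeze-and-prune idea concrete without needing TensorFlow installed, here is a toy, library-free sketch of what freeze_session does conceptually: variable nodes are replaced by constants holding their current values, and nodes not reachable from the requested outputs are pruned away. The function toy_freeze and the dict-based graph representation are illustrative assumptions, not TensorFlow API.

```python
def toy_freeze(nodes, variable_values, output_names):
    """nodes: {name: {"op": str, "inputs": [names]}} -> frozen, pruned graph."""
    # 1. Replace every variable node with a constant capturing its current value.
    frozen = {}
    for name, node in nodes.items():
        if node["op"] == "Variable":
            frozen[name] = {"op": "Const", "value": variable_values[name], "inputs": []}
        else:
            frozen[name] = dict(node)

    # 2. Prune: keep only nodes reachable (backwards) from the requested outputs.
    keep, stack = set(), list(output_names)
    while stack:
        name = stack.pop()
        if name not in keep:
            keep.add(name)
            stack.extend(frozen[name]["inputs"])
    return {name: node for name, node in frozen.items() if name in keep}


graph = {
    "w": {"op": "Variable", "inputs": []},
    "x": {"op": "Placeholder", "inputs": []},
    "y": {"op": "MatMul", "inputs": ["x", "w"]},
    "debug": {"op": "Print", "inputs": ["y"]},  # not needed to compute "y" -> pruned
}
frozen = toy_freeze(graph, {"w": 0.5}, output_names=["y"])
print(sorted(frozen))     # ['w', 'x', 'y']  ("debug" was pruned)
print(frozen["w"]["op"])  # Const
```

The real convert_variables_to_constants does the same two steps against a GraphDef protobuf, reading the current variable values out of the live session.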
Then save the model as a .pb file (also shown in the link):
tf.train.write_graph(frozen_graph, "model", "tf_model.pb", as_text=False)
If this is too cumbersome, try saving the Keras model as .h5 (an HDF5 file) and then follow the instructions in the linked article.

From the TensorFlow documentation:

Write compatible code: the same code written for eager execution will also build a graph during graph execution. To do this, simply run the same code in a new Python session where eager execution is not enabled.

Also from the same page:

To save and load models, tf.train.Checkpoint stores the internal state of objects without needing hidden variables. To record the state of a model, an optimizer, and a global step, pass them to a tf.train.Checkpoint:
import os
import tempfile

checkpoint_dir = tempfile.mkdtemp()
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
root = tf.train.Checkpoint(optimizer=optimizer,
                           model=model,
                           optimizer_step=tf.train.get_or_create_global_step())

root.save(checkpoint_prefix)
root.restore(tf.train.latest_checkpoint(checkpoint_dir))
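As a rough illustration of the object-based checkpointing idea quoted above (state is stored under the attribute names you pass in, such as optimizer and model, rather than under hidden graph-level variable names), here is a minimal library-free sketch. ToyCheckpoint and Stateful are hypothetical names invented for this example, not TensorFlow API.

```python
import json
import os
import tempfile


class Stateful:
    """A stand-in for any object with restorable state (model, optimizer, ...)."""
    def __init__(self, **state):
        self.state = dict(state)


class ToyCheckpoint:
    """Saves/restores tracked objects keyed by the names they were passed under."""
    def __init__(self, **objects):
        self.objects = objects  # name -> object with a mutable .state dict

    def save(self, prefix):
        path = prefix + ".json"
        payload = {name: obj.state for name, obj in self.objects.items()}
        with open(path, "w") as f:
            json.dump(payload, f)
        return path

    def restore(self, path):
        with open(path) as f:
            payload = json.load(f)
        for name, obj in self.objects.items():
            obj.state.update(payload[name])


model = Stateful(weights=[0.1, 0.2])
optimizer = Stateful(lr=0.01)

checkpoint_dir = tempfile.mkdtemp()
root = ToyCheckpoint(optimizer=optimizer, model=model)
path = root.save(os.path.join(checkpoint_dir, "ckpt"))

model.state["weights"] = [9.9, 9.9]  # clobber the state, then restore it
root.restore(path)
print(model.state["weights"])  # [0.1, 0.2]
```

The point mirrors tf.train.Checkpoint: because state is addressed through the object graph you construct, the same code can save and restore in both eager and graph execution.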
I recommend the last section of this page: https://www.tensorflow.org/guide/eager

Hope this helps.