1 回答

TA貢獻1799條經驗 獲得超8個贊
使用 input_map 參數,我成功地將一個僅解碼 jpg 圖像的新圖形映射到原始圖形的輸入(此處:node.name='image_tensor:0')。只需確保為解碼器圖設定不同的 name_scope(此處:decoder)。之後,您可以使用 TensorFlow 的 SavedModelBuilder 保存這個連接後的新圖。
這是一個物體檢測網絡的例子:
import tensorflow as tf
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants
# Patch a frozen object-detection GraphDef so that it accepts encoded JPEG
# strings instead of decoded image tensors, then export the combined graph as
# a SavedModel for TensorFlow Serving / the CloudML Prediction API.
#
# Uses TF 1.x graph-mode APIs (tf.placeholder, tf.Session, tf.GraphDef);
# under TF 2.x these live in tf.compat.v1 and need eager execution disabled.

# Frozen GraphDef to patch.
model = 'path/to/model.pb'
# Output directory. It is meant to encode the model name and version
# (e.g. './output/my_model/1/') -- adjust before running.
# NOTE(review): SavedModelBuilder refuses to write into an existing
# directory, so make sure this path does not exist yet.
export_path = './output/dir/'

# Read the serialized graph first so the file handle is closed before any
# graph construction happens. GFile is the non-deprecated FastGFile.
with tf.gfile.GFile(model, 'rb') as f:
    serialized_graph = f.read()
graph_def = tf.GraphDef()
graph_def.ParseFromString(serialized_graph)

# Decoder sub-graph. The distinct name scope keeps its ops from clashing
# with the imported network's node names.
with tf.name_scope('decoder'):
    # New model input: a 1-D batch of encoded JPEG strings.
    image_str_tensor = tf.placeholder(
        tf.string, shape=[None], name='encoded_image_string_tensor')

    # The CloudML Prediction API always "feeds" the Tensorflow graph with
    # dynamic batch sizes e.g. (?,). decode_jpeg only processes scalar
    # strings because it cannot guarantee a batch of images would have
    # the same output size. We use tf.map_fn to give decode_jpeg a scalar
    # string from dynamic batches.
    def decode_and_resize(encoded):
        """Decode a single JPEG string into a 3-channel uint8 image tensor.

        Despite the name, no resizing is done here; add any resize or other
        preprocessing before the final cast.
        """
        decoded = tf.image.decode_jpeg(encoded, channels=3)
        # do additional image manipulation here (like resize etc...)
        # map_fn below is declared with dtype=tf.uint8, so whatever the extra
        # manipulation produces must be cast back to uint8.
        return tf.cast(decoded, dtype=tf.uint8)

    image = tf.map_fn(decode_and_resize, image_str_tensor,
                      back_prop=False, dtype=tf.uint8)

# Splice the original network in, replacing its 'image_tensor:0' input with
# the decoded batch. name="" imports the nodes without any prefix (it resets
# to the root scope), so the outputs keep their original names below.
tf.import_graph_def(graph_def, name="", input_map={'image_tensor:0': image})

g = tf.get_default_graph()

with tf.Session() as sess:
    # Serving input: the encoded-string placeholder created above.
    inp = g.get_tensor_by_name('decoder/encoded_image_string_tensor:0')
    # Serving outputs, looked up by their names in the original graph.
    out = {
        'num_detections': g.get_tensor_by_name('num_detections:0'),
        'detection_scores': g.get_tensor_by_name('detection_scores:0'),
        'detection_boxes': g.get_tensor_by_name('detection_boxes:0'),
    }

    builder = tf.saved_model.builder.SavedModelBuilder(export_path)
    tensor_info_inputs = {
        'inputs': tf.saved_model.utils.build_tensor_info(inp),
    }
    tensor_info_outputs = {
        key: tf.saved_model.utils.build_tensor_info(tensor)
        for key, tensor in out.items()
    }

    # Prediction signature for TensorFlow Serving.
    detection_signature = tf.saved_model.signature_def_utils.build_signature_def(
        inputs=tensor_info_inputs,
        outputs=tensor_info_outputs,
        method_name=signature_constants.PREDICT_METHOD_NAME)

    # Register the same signature under an explicit name and as the default
    # serving signature. main_op runs when the SavedModel is loaded; here it
    # initializes any lookup tables present in the graph.
    builder.add_meta_graph_and_variables(
        sess, [tag_constants.SERVING],
        signature_def_map={
            'detection_signature': detection_signature,
            signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                detection_signature,
        },
        main_op=tf.tables_initializer())

    # Write saved_model.pb (and the variables directory) to export_path.
    builder.save()
另外:如果您難以找到正確的輸入和輸出節點,可以運行以下程式碼來列出圖中的所有節點:
# Print the NodeDef of every operation in graph `g`; useful for finding the
# input/output node names to use when mapping and exporting the graph.
for op in g.get_operations():
    print(op.node_def)
添加回答
舉報