未验证 提交 799ee3ee 编写于 作者: J Jason 提交者: GitHub

Merge pull request #69 from jiangjiajun/develop

little modify
......@@ -38,9 +38,3 @@ def run_net(param_dir="./"):
param_dir,
fluid.default_main_program(),
predicate=if_exist)
fluid.io.save_inference_model(dirname='inference_model',
feeded_var_names=[i.name for i in inputs],
target_vars=outputs,
executor=exe,
params_filename="__params__")
#coding:utf-8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
......@@ -181,24 +182,6 @@ class TFGraph(Graph):
self.identity_map[node_name] = input_node.layer_name
# node = self.get_node(node_name)
# # Remind: Only 1 input for Identity node
# input_node = self.get_node(node.inputs[0])
#
# # remove identity node from graph
# self.identity_map[node_name] = input_node.layer_name
# idx = input_node.outputs.index(node_name)
# del input_node.outputs[idx]
#
# output_names = node.outputs
# for output_name in output_names:
# output_node = self.get_node(output_name)
# idx = output_node.inputs.index(node_name)
# output_node.inputs[idx] = input_node.layer_name
#
# idx = self.topo_sort.index(node_name)
# del self.topo_sort[idx]
if node_name in self.output_nodes:
idx = self.output_nodes.index(node_name)
self.output_nodes[idx] = input_node.layer_name
......@@ -227,10 +210,6 @@ class TFDecoder(object):
self.sess.graph.as_default()
tf.import_graph_def(graph_def, name='', input_map=input_map)
# for node in graph_def.node:
# print(node.name, node.op, node.input)
self.sess.run(tf.global_variables_initializer())
self.tf_graph = TFGraph(
......@@ -264,13 +243,17 @@ class TFDecoder(object):
if need_define_shape > 0:
if need_define_shape == 1:
print(
"\nUnknown shape for input tensor[tensor name: \"{}\"]".
format(layer.name))
print("无法获取到输入结点\"{}\"的shape".format(layer.name))
print("Unknown shape for input tensor[tensor name: \"{}\"]".
format(layer.name))
else:
print(
"输入结点\"{}\"的shape为{},但我们现仅支持batch维为不定长,所以需要你重新设定shape".
format(layer.name, shape))
print(
"\nShape[now is {}] for input tensor[tensor name: \"{}\"] not support yet"
.format(shape, layer.name))
print("需要你手动在下面输入对应这个输入结点的shape:)")
print(
"Use your keyboard type the shape of input tensor below :)")
......@@ -293,13 +276,11 @@ class TFDecoder(object):
layer.name))
input_map["{}:0".format(layer.name)] = x2paddle_input
shape[shape.index(None)] = -1
# self.input_example_data["x2paddle_{}".format(layer.name)] = numpy.random.random_sample(shape).astype(dtype)
self.input_info["x2paddle_{}".format(layer.name)] = (shape,
dtype)
else:
value = graph_node.layer.attr["shape"].shape
shape = [dim.size for dim in value.dim]
# self.input_example_data[graph_node.layer_name] = numpy.random.random_sample(shape).astype(dtype)
self.input_info[graph_node.layer_name] = (shape, dtype)
return input_map
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册