提交 3d549eb8 编写于 作者: J jiangjiajun

adapt for dev paddle

上级 0497123e
......@@ -98,29 +98,12 @@ def tf2paddle(model_path,
print("Now translating model from tensorflow to paddle.")
model = TFDecoder(model_path, define_input_shape=define_input_shape)
if not without_data_format_optimization:
mapper = TFOpMapper(model)
optimizer = TFOptimizer(mapper)
# neccesary optimization
optimizer.delete_redundance_code()
# optimizer below is experimental
optimizer.optimize_elementwise_op()
optimizer.merge_activation()
optimizer.merge_bias()
optimizer.optimize_sub_graph()
# optimizer.merge_batch_norm()
# optimizer.merge_prelu()
else:
mapper = TFOpMapperNHWC(model)
optimizer = TFOptimizer(mapper)
optimizer.delete_redundance_code()
optimizer.strip_graph()
optimizer.merge_activation()
optimizer.merge_bias()
optimizer.make_nchw_input_output()
optimizer.remove_transpose()
# optimizer.merge_activation()
# optimizer.merge_bias()
mapper.save_inference_model(save_dir)
......@@ -189,14 +172,6 @@ def main():
assert args.framework is not None, "--framework is not defined(support tensorflow/caffe/onnx)"
assert args.save_dir is not None, "--save_dir is not defined"
try:
import paddle
v0, v1, v2 = paddle.__version__.split('.')
if int(v0) != 1 or int(v1) < 5:
print("paddlepaddle>=1.5.0 is required")
return
except:
print("paddlepaddle not installed, use \"pip install paddlepaddle\"")
if args.framework == "tensorflow":
assert args.model is not None, "--model should be defined while translating tensorflow model"
......
......@@ -80,6 +80,11 @@ class Layer(object):
param_attr = collections.OrderedDict(self.param_attr)
for key, value in param_attr.items():
if isinstance(value, GraphNode):
value_name = value.layer_name
if hasattr(value, "index"):
value_name += "[{}]".format(value.index)
value = value_name
if '\n' in str(value):
value = string(str(value).replace('\n', ','))
layer_code = layer_code + key + "={}, ".format(value)
......
......@@ -389,26 +389,10 @@ class TFDecoder(object):
compare01 = (results[0] == results[1])
compare12 = (results[1] == results[2])
if compare01.all() and compare12.all():
return results[0].tolist()
if (compare01 == compare12).all():
index = numpy.argwhere(compare01 == False).flatten()
if index.shape[0] != 1:
raise Exception("There's not only one unstable dimension")
results[0][index[0]] = -1
index = numpy.argwhere(results[0] < 0).flatten()
if index.shape[0] > 2:
print("Warning: More than two dimension less than zero")
if index.shape[0] == 2 and out_shape is not None:
if out_shape[index[1]] > 0:
results[0][index[1]] = out_shape[index[1]]
else:
results[0][index[0]] = out_shape[index[0]]
return results[0].tolist()
else:
raise Exception("Couldn't infer a stable shape shape tensor value")
compare = compare01 & compare12
index = numpy.argwhere(compare==False).flatten()
results[0][index] = -1
return results[0].tolist()
def infer_tensor_shape(self, graph_node):
if hasattr(graph_node, "index"):
......@@ -436,11 +420,7 @@ class TFDecoder(object):
if compare01.all() and compare12.all():
return shape[0].tolist()
if (compare01 == compare12).all():
index = numpy.argwhere(compare01 == False).flatten()
if index.shape[0] != 1:
raise Exception("There's not only one unstable dimension")
if index[0] != 0:
raise Exception("Batch size not in the first dimension")
shapes[0][0] = -1
return shapes[0].tolist()
compare = compare01 & compare12
index = numpy.argwhere(compare==False).flatten()
shapes[0][index] = -1
return shapes[0].tolist()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册