未验证 提交 41cc9f14 编写于 作者: J Jason 提交者: GitHub

Merge pull request #182 from jiangjiajun/develop-1.6

Develop 1.6
...@@ -98,29 +98,12 @@ def tf2paddle(model_path, ...@@ -98,29 +98,12 @@ def tf2paddle(model_path,
print("Now translating model from tensorflow to paddle.") print("Now translating model from tensorflow to paddle.")
model = TFDecoder(model_path, define_input_shape=define_input_shape) model = TFDecoder(model_path, define_input_shape=define_input_shape)
if not without_data_format_optimization:
mapper = TFOpMapper(model)
optimizer = TFOptimizer(mapper)
# neccesary optimization
optimizer.delete_redundance_code()
# optimizer below is experimental
optimizer.optimize_elementwise_op()
optimizer.merge_activation()
optimizer.merge_bias()
optimizer.optimize_sub_graph()
# optimizer.merge_batch_norm()
# optimizer.merge_prelu()
else:
mapper = TFOpMapperNHWC(model) mapper = TFOpMapperNHWC(model)
optimizer = TFOptimizer(mapper) optimizer = TFOptimizer(mapper)
optimizer.delete_redundance_code() optimizer.delete_redundance_code()
optimizer.strip_graph() optimizer.strip_graph()
optimizer.merge_activation() # optimizer.merge_activation()
optimizer.merge_bias() # optimizer.merge_bias()
optimizer.make_nchw_input_output()
optimizer.remove_transpose()
mapper.save_inference_model(save_dir) mapper.save_inference_model(save_dir)
...@@ -182,21 +165,29 @@ def main(): ...@@ -182,21 +165,29 @@ def main():
if args.version: if args.version:
import x2paddle import x2paddle
print("x2paddle-{} with python>=3.5, paddlepaddle>=1.5.0\n".format( print("x2paddle-{} with python>=3.5, paddlepaddle>=1.6.1\n".format(
x2paddle.__version__)) x2paddle.__version__))
return return
assert args.framework is not None, "--framework is not defined(support tensorflow/caffe/onnx)"
assert args.save_dir is not None, "--save_dir is not defined"
try: try:
import paddle import paddle
v0, v1, v2 = paddle.__version__.split('.') v0, v1, v2 = paddle.__version__.split('.')
if int(v0) != 1 or int(v1) < 5: if int(v0) == 0 and int(v1) == 0 and int(v2) == 0:
print("paddlepaddle>=1.5.0 is required") print(
"You have installed paddlepaddle-dev? We're not sure it's working for x2paddle!"
)
print(
"==================paddlepaddle>=1.6.1 is strongly recommended================="
)
elif int(v0) != 1 or int(v1) < 6:
print("paddlepaddle>=1.6.1 is required")
return return
except: except:
print("paddlepaddle not installed, use \"pip install paddlepaddle\"") print("paddlepaddle not installed, use \"pip install paddlepaddle\"")
return
assert args.framework is not None, "--framework is not defined(support tensorflow/caffe/onnx)"
assert args.save_dir is not None, "--save_dir is not defined"
if args.framework == "tensorflow": if args.framework == "tensorflow":
assert args.model is not None, "--model should be defined while translating tensorflow model" assert args.model is not None, "--model should be defined while translating tensorflow model"
......
...@@ -80,6 +80,11 @@ class Layer(object): ...@@ -80,6 +80,11 @@ class Layer(object):
param_attr = collections.OrderedDict(self.param_attr) param_attr = collections.OrderedDict(self.param_attr)
for key, value in param_attr.items(): for key, value in param_attr.items():
if isinstance(value, GraphNode):
value_name = value.layer_name
if hasattr(value, "index"):
value_name += "[{}]".format(value.index)
value = value_name
if '\n' in str(value): if '\n' in str(value):
value = string(str(value).replace('\n', ',')) value = string(str(value).replace('\n', ','))
layer_code = layer_code + key + "={}, ".format(value) layer_code = layer_code + key + "={}, ".format(value)
......
...@@ -389,27 +389,11 @@ class TFDecoder(object): ...@@ -389,27 +389,11 @@ class TFDecoder(object):
compare01 = (results[0] == results[1]) compare01 = (results[0] == results[1])
compare12 = (results[1] == results[2]) compare12 = (results[1] == results[2])
if compare01.all() and compare12.all(): compare = compare01 & compare12
index = numpy.argwhere(compare==False).flatten()
results[0][index] = -1
return results[0].tolist() return results[0].tolist()
if (compare01 == compare12).all():
index = numpy.argwhere(compare01 == False).flatten()
if index.shape[0] != 1:
raise Exception("There's not only one unstable dimension")
results[0][index[0]] = -1
index = numpy.argwhere(results[0] < 0).flatten()
if index.shape[0] > 2:
print("Warning: More than two dimension less than zero")
if index.shape[0] == 2 and out_shape is not None:
if out_shape[index[1]] > 0:
results[0][index[1]] = out_shape[index[1]]
else:
results[0][index[0]] = out_shape[index[0]]
return results[0].tolist()
else:
raise Exception("Couldn't infer a stable shape shape tensor value")
def infer_tensor_shape(self, graph_node): def infer_tensor_shape(self, graph_node):
if hasattr(graph_node, "index"): if hasattr(graph_node, "index"):
tensor_name = graph_node.layer.name + ":{}".format(graph_node.index) tensor_name = graph_node.layer.name + ":{}".format(graph_node.index)
...@@ -436,11 +420,7 @@ class TFDecoder(object): ...@@ -436,11 +420,7 @@ class TFDecoder(object):
if compare01.all() and compare12.all(): if compare01.all() and compare12.all():
return shape[0].tolist() return shape[0].tolist()
if (compare01 == compare12).all(): compare = compare01 & compare12
index = numpy.argwhere(compare01 == False).flatten() index = numpy.argwhere(compare==False).flatten()
if index.shape[0] != 1: shapes[0][index] = -1
raise Exception("There's not only one unstable dimension")
if index[0] != 0:
raise Exception("Batch size not in the first dimension")
shapes[0][0] = -1
return shapes[0].tolist() return shapes[0].tolist()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册