提交 259c603e 编写于 作者: J jiangjiajun

fix code

上级 79911873
...@@ -98,12 +98,12 @@ def tf2paddle(model_path, ...@@ -98,12 +98,12 @@ def tf2paddle(model_path,
print("Now translating model from tensorflow to paddle.") print("Now translating model from tensorflow to paddle.")
model = TFDecoder(model_path, define_input_shape=define_input_shape) model = TFDecoder(model_path, define_input_shape=define_input_shape)
mapper = TFOpMapperNHWC(model) mapper = TFOpMapperNHWC(model)
optimizer = TFOptimizer(mapper) optimizer = TFOptimizer(mapper)
optimizer.delete_redundance_code() optimizer.delete_redundance_code()
optimizer.strip_graph() optimizer.strip_graph()
# optimizer.merge_activation() # optimizer.merge_activation()
# optimizer.merge_bias() # optimizer.merge_bias()
mapper.save_inference_model(save_dir) mapper.save_inference_model(save_dir)
...@@ -169,22 +169,26 @@ def main(): ...@@ -169,22 +169,26 @@ def main():
x2paddle.__version__)) x2paddle.__version__))
return return
try: try:
import paddle import paddle
v0, v1, v2 = paddle.__version__.split('.') v0, v1, v2 = paddle.__version__.split('.')
if v0 == 0 and v1 == 0 and v2 == 0: if int(v0) == 0 and int(v1) == 0 and int(v2) == 0:
print("You have installed paddlepaddle-dev? We're not sure it's working for x2paddle!" print(
elif int(v0) != 1 or int(v1) < 6: "You have installed paddlepaddle-dev? We're not sure it's working for x2paddle!"
print("paddlepaddle>=1.6.1 is required") )
return print(
except: "==================paddlepaddle>=1.6.1 is strongly recommended================="
)
elif int(v0) != 1 or int(v1) < 6:
print("paddlepaddle>=1.6.1 is required")
return
except:
print("paddlepaddle not installed, use \"pip install paddlepaddle\"") print("paddlepaddle not installed, use \"pip install paddlepaddle\"")
return return
assert args.framework is not None, "--framework is not defined(support tensorflow/caffe/onnx)" assert args.framework is not None, "--framework is not defined(support tensorflow/caffe/onnx)"
assert args.save_dir is not None, "--save_dir is not defined" assert args.save_dir is not None, "--save_dir is not defined"
if args.framework == "tensorflow": if args.framework == "tensorflow":
assert args.model is not None, "--model should be defined while translating tensorflow model" assert args.model is not None, "--model should be defined while translating tensorflow model"
without_data_format_optimization = False without_data_format_optimization = False
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册