From 259c603e8882bad7730f9c8e35718148a1f69418 Mon Sep 17 00:00:00 2001 From: jiangjiajun Date: Tue, 22 Oct 2019 13:09:45 +0000 Subject: [PATCH] fix code --- x2paddle/convert.py | 38 +++++++++++++++++++++----------------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/x2paddle/convert.py b/x2paddle/convert.py index 678e66c..0dcda08 100644 --- a/x2paddle/convert.py +++ b/x2paddle/convert.py @@ -98,12 +98,12 @@ def tf2paddle(model_path, print("Now translating model from tensorflow to paddle.") model = TFDecoder(model_path, define_input_shape=define_input_shape) - mapper = TFOpMapperNHWC(model) - optimizer = TFOptimizer(mapper) - optimizer.delete_redundance_code() - optimizer.strip_graph() -# optimizer.merge_activation() -# optimizer.merge_bias() + mapper = TFOpMapperNHWC(model) + optimizer = TFOptimizer(mapper) + optimizer.delete_redundance_code() + optimizer.strip_graph() + # optimizer.merge_activation() + # optimizer.merge_bias() mapper.save_inference_model(save_dir) @@ -169,22 +169,26 @@ def main(): x2paddle.__version__)) return - try: - import paddle - v0, v1, v2 = paddle.__version__.split('.') - if v0 == 0 and v1 == 0 and v2 == 0: - print("You have installed paddlepaddle-dev? We're not sure it's working for x2paddle!" - elif int(v0) != 1 or int(v1) < 6: - print("paddlepaddle>=1.6.1 is required") - return - except: + try: + import paddle + v0, v1, v2 = paddle.__version__.split('.') + if int(v0) == 0 and int(v1) == 0 and int(v2) == 0: + print( + "You have installed paddlepaddle-dev? We're not sure it's working for x2paddle!" + ) + print( + "==================paddlepaddle>=1.6.1 is strongly recommended=================" + ) + elif int(v0) != 1 or int(v1) < 6: + print("paddlepaddle>=1.6.1 is required") + return + except: print("paddlepaddle not installed, use \"pip install paddlepaddle\"") - return + return assert args.framework is not None, "--framework is not defined(support tensorflow/caffe/onnx)" assert args.save_dir is not None, "--save_dir is not defined" - if args.framework == "tensorflow": assert args.model is not None, "--model should be defined while translating tensorflow model" without_data_format_optimization = False -- GitLab