Commit 8c82b81f authored by wjj19950828

resolve conflict

@@ -16,6 +16,7 @@ from six import text_type as _text_type
 from x2paddle import program
 import argparse
 import sys
+import logging


 def arg_parser():
@@ -137,12 +138,12 @@ def tf2paddle(model_path,
         import tensorflow as tf
         version = tf.__version__
         if version >= '2.0.0' or version < '1.0.0':
-            print(
+            logging.info(
                 "[ERROR] 1.0.0<=tensorflow<2.0.0 is required, and v1.14.0 is recommended"
             )
             return
     except:
-        print(
+        logging.info(
             "[ERROR] Tensorflow is not installed, use \"pip install tensorflow\"."
         )
         return
@@ -150,7 +151,7 @@ def tf2paddle(model_path,
     from x2paddle.decoder.tf_decoder import TFDecoder
     from x2paddle.op_mapper.tf2paddle.tf_op_mapper import TFOpMapper
-    print("Now translating model from tensorflow to paddle.")
+    logging.info("Now translating model from tensorflow to paddle.")
     model = TFDecoder(model_path, define_input_shape=define_input_shape)
     mapper = TFOpMapper(model)
     mapper.paddle_graph.build()
@@ -178,15 +179,15 @@ def caffe2paddle(proto_file,
             or (int(ver_part[0]) > 3):
         version_satisfy = True
     assert version_satisfy, '[ERROR] google.protobuf >= 3.6.0 is required'
-    print("Now translating model from caffe to paddle.")
+    logging.info("Now translating model from caffe to paddle.")
     model = CaffeDecoder(proto_file, weight_file, caffe_proto)
     mapper = CaffeOpMapper(model)
     mapper.paddle_graph.build()
-    print("Model optimizing ...")
+    logging.info("Model optimizing ...")
     from x2paddle.optimizer.optimizer import GraphOptimizer
     graph_opt = GraphOptimizer(source_frame="caffe")
     graph_opt.optimize(mapper.paddle_graph)
-    print("Model optimized.")
+    logging.info("Model optimized.")
     mapper.paddle_graph.gen_model(save_dir)
     if convert_to_lite:
         convert2lite(save_dir, lite_valid_places, lite_model_type)
@@ -204,12 +205,13 @@ def onnx2paddle(model_path,
         v0, v1, v2 = version.split('.')
         version_sum = int(v0) * 100 + int(v1) * 10 + int(v2)
         if version_sum < 160:
-            print("[ERROR] onnx>=1.6.0 is required")
+            logging.info("[ERROR] onnx>=1.6.0 is required")
             return
     except:
-        print("[ERROR] onnx is not installed, use \"pip install onnx==1.6.0\".")
+        logging.info(
+            "[ERROR] onnx is not installed, use \"pip install onnx==1.6.0\".")
         return
-    print("Now translating model from onnx to paddle.")
+    logging.info("Now translating model from onnx to paddle.")
     from x2paddle.decoder.onnx_decoder import ONNXDecoder
     from x2paddle.op_mapper.onnx2paddle.onnx_op_mapper import ONNXOpMapper
......@@ -233,17 +235,24 @@ def pytorch2paddle(module,
try:
import torch
version = torch.__version__
ver_part = version.split('.')
print(ver_part)
if int(ver_part[1]) < 5:
print("[ERROR] pytorch>=1.5.0 is required")
v0, v1, v2 = version.split('.')
# Avoid the situation where the version is equal to 1.7.0+cu101
if '+' in v2:
v2 = v2.split('+')[0]
version_sum = int(v0) * 100 + int(v1) * 10 + int(v2)
if version_sum < 150:
logging.info(
"[ERROR] pytorch>=1.5.0 is required, 1.6.0 is the most recommended"
)
return
if version_sum > 160:
logging.info("[WARNING] pytorch==1.6.0 is recommended")
except:
print(
"[ERROR] Pytorch is not installed, use \"pip install torch==1.5.0 torchvision\"."
logging.info(
"[ERROR] Pytorch is not installed, use \"pip install torch==1.6.0 torchvision\"."
)
return
print("Now translating model from pytorch to paddle.")
logging.info("Now translating model from pytorch to paddle.")
from x2paddle.decoder.pytorch_decoder import ScriptDecoder, TraceDecoder
from x2paddle.op_mapper.pytorch2paddle.pytorch_op_mapper import PyTorchOpMapper
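
A standalone sketch of the version check introduced above (plain Python, no torch install needed), mirroring how the hunk strips local build suffixes such as "+cu101" before folding the digits into a single comparable number:

def torch_version_sum(version):
    v0, v1, v2 = version.split('.')
    # e.g. "1.7.0+cu101" -> patch component "0+cu101" -> "0"
    if '+' in v2:
        v2 = v2.split('+')[0]
    return int(v0) * 100 + int(v1) * 10 + int(v2)

assert torch_version_sum("1.5.0") == 150        # minimum supported
assert torch_version_sum("1.6.0") == 160        # recommended release
assert torch_version_sum("1.7.0+cu101") == 170  # CUDA build suffix stripped
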
@@ -254,11 +263,11 @@ def pytorch2paddle(module,
         model = ScriptDecoder(module, input_examples)
     mapper = PyTorchOpMapper(model)
     mapper.paddle_graph.build()
-    print("Model optimizing ...")
+    logging.info("Model optimizing ...")
     from x2paddle.optimizer.optimizer import GraphOptimizer
     graph_opt = GraphOptimizer(source_frame="pytorch", jit_type=jit_type)
     graph_opt.optimize(mapper.paddle_graph)
-    print("Model optimized.")
+    logging.info("Model optimized.")
     mapper.paddle_graph.gen_model(
         save_dir, jit_type=jit_type, enable_code_optim=enable_code_optim)
     if convert_to_lite:
@@ -266,10 +275,12 @@ def pytorch2paddle(module,
 def main():
+    logging.basicConfig(level=logging.INFO)
     if len(sys.argv) < 2:
-        print("Use \"x2paddle -h\" to print the help information")
-        print("For more information, please follow our github repo below:)")
-        print("\nGithub: https://github.com/PaddlePaddle/X2Paddle.git\n")
+        logging.info("Use \"x2paddle -h\" to print the help information")
+        logging.info(
+            "For more information, please follow our github repo below:)")
+        logging.info("\nGithub: https://github.com/PaddlePaddle/X2Paddle.git\n")
         return
     parser = arg_parser()
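
A minimal sketch of why the logging.basicConfig(level=logging.INFO) call matters: without it the root logger keeps its default WARNING threshold and every logging.info(...) message introduced in this commit would be silently dropped.

import logging

# Configure the root logger once, as main() now does; otherwise INFO records
# are filtered out by the default WARNING threshold.
logging.basicConfig(level=logging.INFO)

logging.info("Now translating model from tensorflow to paddle.")  # emitted to stderr
logging.debug("hidden: DEBUG is below the configured INFO level")
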
@@ -277,8 +288,8 @@ def main():
     if args.version:
         import x2paddle
-        print("x2paddle-{} with python>=3.5, paddlepaddle>=1.6.0\n".format(
-            x2paddle.__version__))
+        logging.info("x2paddle-{} with python>=3.5, paddlepaddle>=1.6.0\n".
+                     format(x2paddle.__version__))
         return

     if not args.convert_torch_project:
@@ -289,18 +300,19 @@ def main():
         import platform
         v0, v1, v2 = platform.python_version().split('.')
         if not (int(v0) >= 3 and int(v1) >= 5):
-            print("[ERROR] python>=3.5 is required")
+            logging.info("[ERROR] python>=3.5 is required")
             return
         import paddle
         v0, v1, v2 = paddle.__version__.split('.')
-        print("paddle.__version__ = {}".format(paddle.__version__))
+        logging.info("paddle.__version__ = {}".format(paddle.__version__))
         if v0 == '0' and v1 == '0' and v2 == '0':
-            print("[WARNING] You are use develop version of paddlepaddle")
+            logging.info(
+                "[WARNING] You are use develop version of paddlepaddle")
         elif int(v0) != 2 or int(v1) < 0:
-            print("[ERROR] paddlepaddle>=2.0.0 is required")
+            logging.info("[ERROR] paddlepaddle>=2.0.0 is required")
             return
     except:
-        print(
+        logging.info(
             "[ERROR] paddlepaddle not installed, use \"pip install paddlepaddle\""
         )
@@ -341,7 +353,7 @@ def main():
             lite_valid_places=args.lite_valid_places,
             lite_model_type=args.lite_model_type)
     elif args.framework == "paddle2onnx":
-        print(
+        logging.info(
             "Paddle to ONNX tool has been migrated to the new github: https://github.com/PaddlePaddle/paddle2onnx"
         )
@@ -388,7 +388,7 @@ class PaddleGraph(object):
             gen_codes(
                 [
                     "paddle.disable_static()",
-                    "params = paddle.load('{}')".format(
+                    "params = paddle.load(r'{}')".format(
                         osp.join(osp.abspath(code_dir), "model.pdparams")),
                     "model = {}()".format(self.name),
                     "model.set_dict(params, use_structured_name={})".format(
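
The raw-string prefix in the generated loader line is presumably there for save paths that contain backslashes (e.g. on Windows). A small illustration with a hypothetical path, showing what the emitted code looks like with and without the prefix:

# Hypothetical Windows-style save path, used only for illustration.
code_dir = "C:\\Users\\demo\\pd_model"

plain = "params = paddle.load('{}')".format(code_dir)
raw = "params = paddle.load(r'{}')".format(code_dir)

print(plain)  # params = paddle.load('C:\Users\demo\pd_model')   -> '\U' is an invalid escape when the generated file is parsed
print(raw)    # params = paddle.load(r'C:\Users\demo\pd_model')  -> backslashes are read literally
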
@@ -16,7 +16,6 @@ from x2paddle.core.graph import GraphNode, Graph
 from x2paddle.decoder.onnx_shape_inference import SymbolicShapeInference
 from onnx.checker import ValidationError
 from onnx.checker import check_model
-from onnx.utils import polish_model
 from onnx import helper, shape_inference
 from onnx.helper import get_attribute_value, make_attribute
 from onnx.shape_inference import infer_shapes
@@ -518,6 +518,8 @@ class OpSet9():
         if pads is not None:
             is_pads_attr = True
         mode = node.get_attr('mode', 'constant')
+        if mode in ["edge"]:
+            mode = "replicate"
         value = node.get_attr('value', 0.)
         data_shape = val_x.out_shapes[0]
         output_shape = node.out_shapes[0]
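
The new mapping reflects that ONNX's "edge" pad mode and Paddle's "replicate" mode both repeat border values. A quick check of that equivalence (a sketch assuming paddlepaddle and numpy are installed, not the mapper's own code path):

import numpy as np
import paddle

# 1x1x1x4 NCHW tensor: [[1, 2, 3, 4]]
x = np.arange(1, 5, dtype="float32").reshape(1, 1, 1, 4)

# numpy's "edge" mode follows the same semantics as ONNX Pad(mode="edge").
ref = np.pad(x, ((0, 0), (0, 0), (0, 0), (2, 2)), mode="edge")

# Paddle's "replicate" mode pads the width dimension with the same border values.
out = paddle.nn.functional.pad(
    paddle.to_tensor(x), pad=[2, 2, 0, 0], mode="replicate")

print(ref[0, 0, 0])          # [1. 1. 1. 2. 3. 4. 4. 4.]
print(out.numpy()[0, 0, 0])  # [1. 1. 1. 2. 3. 4. 4. 4.]
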