未验证 提交 9b8cd312 编写于 作者: J Jason 提交者: GitHub

Merge pull request #328 from Channingss/scope

support custom Scope
...@@ -195,9 +195,14 @@ def onnx2paddle(model_path, save_dir, params_merge=False): ...@@ -195,9 +195,14 @@ def onnx2paddle(model_path, save_dir, params_merge=False):
def paddle2onnx(model_path, save_dir, opset_version=10): def paddle2onnx(model_path, save_dir, opset_version=10):
from x2paddle.decoder.paddle_decoder import PaddleDecoder from x2paddle.decoder.paddle_decoder import PaddleDecoder
from x2paddle.op_mapper.paddle2onnx.paddle_op_mapper import PaddleOpMapper from x2paddle.op_mapper.paddle2onnx.paddle_op_mapper import PaddleOpMapper
import paddle.fluid as fluid
model = PaddleDecoder(model_path, '__model__', '__params__') model = PaddleDecoder(model_path, '__model__', '__params__')
mapper = PaddleOpMapper() mapper = PaddleOpMapper()
mapper.convert(model.program, save_dir, opset_number=opset_version) mapper.convert(
model.program,
save_dir,
scope=fluid.global_scope(),
opset_version=opset_version)
def main(): def main():
...@@ -264,7 +269,7 @@ def main(): ...@@ -264,7 +269,7 @@ def main():
elif args.framework == "paddle2onnx": elif args.framework == "paddle2onnx":
assert args.model is not None, "--model should be defined while translating paddle model to onnx" assert args.model is not None, "--model should be defined while translating paddle model to onnx"
paddle2onnx(args.model, args.save_dir, args.onnx_opset) paddle2onnx(args.model, args.save_dir, opset_version=args.onnx_opset)
else: else:
raise Exception( raise Exception(
......
...@@ -59,7 +59,7 @@ class OpSet9(object): ...@@ -59,7 +59,7 @@ class OpSet9(object):
'Constant', inputs=[], outputs=[name], value=tensor) 'Constant', inputs=[], outputs=[name], value=tensor)
return node return node
def convert_weights(self, program): def convert_weights(self, program, scope=None):
var_names = program.global_block().vars var_names = program.global_block().vars
nodes = list() nodes = list()
for name in var_names: for name in var_names:
...@@ -68,7 +68,7 @@ class OpSet9(object): ...@@ -68,7 +68,7 @@ class OpSet9(object):
continue continue
if not var.persistable: if not var.persistable:
continue continue
weight = np.array(fluid.global_scope().find_var(name).get_tensor()) weight = np.array(scope.find_var(name).get_tensor())
tensor = helper.make_tensor( tensor = helper.make_tensor(
name=name, name=name,
dims=var.shape, dims=var.shape,
......
...@@ -33,9 +33,9 @@ class PaddleOpMapper(object): ...@@ -33,9 +33,9 @@ class PaddleOpMapper(object):
self.name_counter = dict() self.name_counter = dict()
self.op_set = None self.op_set = None
def convert(self, program, save_dir, opset_number=10): def convert(self, program, save_dir, scope=None, opset_version=10):
self.op_set = self.create_opset(opset_number) self.op_set = self.create_opset(opset_version)
weight_nodes = self.op_set.convert_weights(program) weight_nodes = self.op_set.convert_weights(program, scope=scope)
op_nodes = list() op_nodes = list()
input_nodes = list() input_nodes = list()
output_nodes = list() output_nodes = list()
...@@ -77,7 +77,7 @@ class PaddleOpMapper(object): ...@@ -77,7 +77,7 @@ class PaddleOpMapper(object):
initializer=[], initializer=[],
inputs=input_nodes, inputs=input_nodes,
outputs=output_nodes) outputs=output_nodes)
opset_imports = [helper.make_opsetid("", opset_number)] opset_imports = [helper.make_opsetid("", opset_version)]
model = helper.make_model( model = helper.make_model(
graph, producer_name='X2Paddle', opset_imports=opset_imports) graph, producer_name='X2Paddle', opset_imports=opset_imports)
onnx.checker.check_model(model) onnx.checker.check_model(model)
...@@ -89,20 +89,20 @@ class PaddleOpMapper(object): ...@@ -89,20 +89,20 @@ class PaddleOpMapper(object):
print("\nTranslated model saved in {}".format( print("\nTranslated model saved in {}".format(
os.path.join(save_dir, 'x2paddle_model.onnx'))) os.path.join(save_dir, 'x2paddle_model.onnx')))
def create_opset(self, opset_number): def create_opset(self, opset_version=10):
run_opset = self.default_opset run_opset = self.default_opset
opset = '' opset = ''
if opset_number in self.support_opsets: if opset_version in self.support_opsets:
run_opset = opset_number run_opset = opset_version
else: else:
for support_opset_number in self.support_opsets: for support_opset_version in self.support_opsets:
if support_opset_number < opset_number: if support_opset_version < opset_version:
run_opset = support_opset_number run_opset = support_opset_version
else: else:
break break
print( print(
'Now, onnx2paddle support convert onnx model opset_verison {},' 'Now, onnx2paddle support convert onnx model opset_verison {},'
'opset_verison of your onnx model is {}, automatically treated as op_set: {}.' 'opset_verison of your onnx model is {}, automatically treated as op_set: {}.'
.format(self.support_opsets, opset_number, run_opset)) .format(self.support_opsets, opset_version, run_opset))
opset = 'OpSet' + str(run_opset) opset = 'OpSet' + str(run_opset)
return eval(opset)() return eval(opset)()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册