提交 9ab7285a 编写于 作者: C Channingss

[Paddle2ONNX] add swish op

上级 05bb5a0c
...@@ -75,12 +75,6 @@ def arg_parser(): ...@@ -75,12 +75,6 @@ def arg_parser():
action="store_true", action="store_true",
default=False, default=False,
help="define input shape for tf model") help="define input shape for tf model")
parser.add_argument(
"--onnx_opset",
"-oo",
type=int,
default=10,
help="when paddle2onnx set onnx opset version to export")
parser.add_argument( parser.add_argument(
"--params_merge", "--params_merge",
"-pm", "-pm",
...@@ -192,12 +186,12 @@ def onnx2paddle(model_path, save_dir, params_merge=False): ...@@ -192,12 +186,12 @@ def onnx2paddle(model_path, save_dir, params_merge=False):
print("Paddle model and code generated.") print("Paddle model and code generated.")
def paddle2onnx(model_path, save_dir, opset): def paddle2onnx(model_path, save_dir):
from x2paddle.decoder.paddle_decoder import PaddleDecoder from x2paddle.decoder.paddle_decoder import PaddleDecoder
from x2paddle.op_mapper.paddle_op_mapper import PaddleOpMapper from x2paddle.op_mapper.paddle_op_mapper import PaddleOpMapper
model = PaddleDecoder(model_path, '__model__', '__params__') model = PaddleDecoder(model_path, '__model__', '__params__')
mapper = PaddleOpMapper() mapper = PaddleOpMapper()
mapper.convert(model.program, save_dir, opset) mapper.convert(model.program, save_dir)
def main(): def main():
...@@ -264,7 +258,7 @@ def main(): ...@@ -264,7 +258,7 @@ def main():
elif args.framework == "paddle2onnx": elif args.framework == "paddle2onnx":
assert args.model is not None, "--model should be defined while translating paddle model to onnx" assert args.model is not None, "--model should be defined while translating paddle model to onnx"
paddle2onnx(args.model, args.save_dir, args.onnx_opset) paddle2onnx(args.model, args.save_dir)
else: else:
raise Exception( raise Exception(
......
...@@ -37,7 +37,7 @@ class PaddleOpMapper(object): ...@@ -37,7 +37,7 @@ class PaddleOpMapper(object):
self.name_counter = dict() self.name_counter = dict()
def convert(self, program, save_dir, opset=10): def convert(self, program, save_dir):
weight_nodes = self.convert_weights(program) weight_nodes = self.convert_weights(program)
op_nodes = list() op_nodes = list()
input_nodes = list() input_nodes = list()
...@@ -80,9 +80,7 @@ class PaddleOpMapper(object): ...@@ -80,9 +80,7 @@ class PaddleOpMapper(object):
initializer=[], initializer=[],
inputs=input_nodes, inputs=input_nodes,
outputs=output_nodes) outputs=output_nodes)
opset_imports = [helper.make_opsetid("", opset)] model = helper.make_model(graph, producer_name='X2Paddle')
model = helper.make_model(
graph, producer_name='X2Paddle', opset_imports=opset_imports)
onnx.checker.check_model(model) onnx.checker.check_model(model)
if not os.path.isdir(save_dir): if not os.path.isdir(save_dir):
...@@ -184,6 +182,41 @@ class PaddleOpMapper(object): ...@@ -184,6 +182,41 @@ class PaddleOpMapper(object):
alpha=op.attr('alpha')) alpha=op.attr('alpha'))
return node return node
def swish(self, op, block):
"""
The activation swish, y = x / (1 + exp(-beta * x))
"""
beta = op.attr('beta')
beta_name = self.get_name(op.type, 'beta')
beta_node = onnx.helper.make_node(
'Constant',
name=beta_name,
inputs=[],
outputs=[beta_name],
value=onnx.helper.make_tensor(
name=beta_name,
data_type=onnx.TensorProto.FLOAT,
dims=(),
vals=[beta]))
beta_x_name = self.get_name(op.type, 'beta_x')
beta_x_node = onnx.helper.make_node(
'Mul',
name=beta_x_name,
inputs=[op.input('X')[0], beta_name],
outputs=[beta_x_name])
sigmoid_name = self.get_name(op.type, 'sigmoid')
sigmoid_node = onnx.helper.make_node(
'Sigmoid',
name=sigmoid_name,
inputs=[beta_x_name],
outputs=[sigmoid_name])
swish_node = onnx.helper.make_node(
'Mul',
inputs=[op.input('X')[0], sigmoid_name],
outputs=op.output('Out'))
return [beta_node, beta_x_node, sigmoid_node, swish_node]
def elementwise_add(self, op, block): def elementwise_add(self, op, block):
axis = op.attr('axis') axis = op.attr('axis')
x_shape = block.var(op.input('X')[0]).shape x_shape = block.var(op.input('X')[0]).shape
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册