提交 0b130dd7 编写于 作者: C Channingss

add op:swish,floor,uniform_random_batch_size_like

上级 bec2dcb7
...@@ -405,6 +405,23 @@ class OpSet9(object): ...@@ -405,6 +405,23 @@ class OpSet9(object):
'Sum', inputs=op.input('X'), outputs=op.output('Out')) 'Sum', inputs=op.input('X'), outputs=op.output('Out'))
return node return node
def floor(self, op, block):
node = helper.make_node(
'Floor', inputs=op.input('X'), outputs=op.output('Out'))
return node
def uniform_random_batch_size_like(self, op, block):
node = helper.make_node(
'RandomUniformLike',
inputs=op.input('Input'),
outputs=op.output('Out'),
#shape=op.attr('shape'),
high=op.attr('max'),
dtype=self.paddle_onnx_dtype_map[op.attr('dtype')],
low=op.attr('min'),
seed=float(op.attr('seed')), )
return node
def depthwise_conv2d(self, op, block): def depthwise_conv2d(self, op, block):
return self.conv2d(op, block) return self.conv2d(op, block)
...@@ -784,6 +801,38 @@ class OpSet9(object): ...@@ -784,6 +801,38 @@ class OpSet9(object):
beta=offset) beta=offset)
return node return node
def swish(self, op, block):
beta = op.attr('beta')
beta_name = self.get_name(op.type, 'beta')
beta_node = onnx.helper.make_node(
'Constant',
name=beta_name,
inputs=[],
outputs=[beta_name],
value=onnx.helper.make_tensor(
name=beta_name,
data_type=onnx.TensorProto.FLOAT,
dims=(),
vals=[beta]))
beta_x_name = self.get_name(op.type, 'beta_x')
beta_x_node = onnx.helper.make_node(
'Mul',
name=beta_x_name,
inputs=[op.input('X')[0], beta_name],
outputs=[beta_x_name])
sigmoid_name = self.get_name(op.type, 'sigmoid')
sigmoid_node = onnx.helper.make_node(
'Sigmoid',
name=sigmoid_name,
inputs=[beta_x_name],
outputs=[sigmoid_name])
swish_node = onnx.helper.make_node(
'Mul',
inputs=[op.input('X')[0], sigmoid_name],
outputs=op.output('Out'))
return [beta_node, beta_x_node, sigmoid_node, swish_node]
def hard_swish(self, op, block): def hard_swish(self, op, block):
scale_name = self.get_name(op.type, 'scale') scale_name = self.get_name(op.type, 'scale')
offset_name = self.get_name(op.type, 'offset') offset_name = self.get_name(op.type, 'offset')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册