From 9bbc701b76281ac6d8143cdc98b60668f27f879d Mon Sep 17 00:00:00 2001
From: channingss
Date: Thu, 8 Aug 2019 10:15:59 +0800
Subject: [PATCH] complete GEMM op

---
 x2paddle/op_mapper/onnx_op_mapper.py | 34 ++++++++++++----------------
 1 file changed, 14 insertions(+), 20 deletions(-)

diff --git a/x2paddle/op_mapper/onnx_op_mapper.py b/x2paddle/op_mapper/onnx_op_mapper.py
index 26ae32e..a298cca 100644
--- a/x2paddle/op_mapper/onnx_op_mapper.py
+++ b/x2paddle/op_mapper/onnx_op_mapper.py
@@ -464,6 +464,7 @@ class ONNXOpMapper(OpMapper):
                                   inputs=matmul_inputs,
                                   output=val_mm,
                                   param_attr=attr_matmul)
 
+        if beta != 0:
             if beta == 1.:
                 add_inputs = {"x": val_mm, "y": val_c}
@@ -473,7 +474,19 @@
                                           output=node,
                                           param_attr=attr)
             else:
-                pass
+                var_beta = node.layer_name + '_beta'
+                matmul_beta_inputs = {"x": val_c, "y": var_beta}
+                node.fluid_code.add_layer("Constant",
+                                          inputs=matmul_beta_inputs,
+                                          output=var_beta,
+                                          param_attr={'value': beta})
+
+                add_inputs = {"x": val_mm, "y": var_beta}
+                attr = {"name": string(node.layer_name)}
+                node.fluid_code.add_layer("elementwise_add",
+                                          inputs=add_inputs,
+                                          output=node,
+                                          param_attr=attr)
 
     def Add(self, node):
         val_x = self.graph.get_node(node.layer.input[0], copy=True)
@@ -504,25 +517,6 @@
                                   output=node,
                                   param_attr=attr)
 
-    def LRN(self, node):
-        val_x = self.graph.get_node(node.layer.input[0], copy=True)
-        size = node.get_attr('size')  # required
-        alpha = node.get_attr('alpha', 0.0001)  # optional
-        beta = node.get_attr('beta', 0.75)  # optional
-        bias = node.get_attr('bias', 1.0)  # optional
-
-        attr = {
-            "n": max(1, size),
-            "k": bias,
-            "alpha": alpha,
-            'beta': beta,
-            "name": string(node.layer_name)
-        }
-        node.fluid_code.add_layer("lrn",
-                                  inputs=val_x,
-                                  output=node,
-                                  param_attr=attr)
-
     def BatchNormalization(self, node):
         val_x = self.graph.get_node(node.layer.input[0], copy=True)
         val_scale = self.graph.get_node(node.layer.input[1], copy=True)
-- 
GitLab
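
Reviewer note (not part of the patch): the ONNX Gemm operator handled by this mapper is defined as Y = alpha * A' * B' + beta * C, where A' and B' are A and B optionally transposed via the transA/transB attributes and C is broadcast to the shape of the product. The NumPy sketch below is only a comparison point for the fluid layers emitted above; the function name gemm_reference and the use of NumPy are illustrative assumptions, not code from x2paddle.

    import numpy as np

    def gemm_reference(a, b, c=None, alpha=1.0, beta=1.0, trans_a=0, trans_b=0):
        # ONNX Gemm reference semantics: Y = alpha * A' @ B' + beta * C.
        a = a.T if trans_a else a
        b = b.T if trans_b else b
        y = alpha * np.matmul(a, b)
        if c is not None and beta != 0:
            # C must be broadcastable to the shape of A' @ B'.
            y = y + beta * np.asarray(c)
        return y

    # Example: y has shape (2, 4); beta scales the bias term C before the add.
    y = gemm_reference(np.ones((2, 3)), np.ones((3, 4)),
                       c=np.zeros((2, 4)), alpha=2.0, beta=0.5)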