提交 9bbc701b 编写于 作者: C channingss

complete GEMM op

上级 fed47521
...@@ -464,6 +464,7 @@ class ONNXOpMapper(OpMapper): ...@@ -464,6 +464,7 @@ class ONNXOpMapper(OpMapper):
inputs=matmul_inputs, inputs=matmul_inputs,
output=val_mm, output=val_mm,
param_attr=attr_matmul) param_attr=attr_matmul)
if beta != 0: if beta != 0:
if beta == 1.: if beta == 1.:
add_inputs = {"x": val_mm, "y": val_c} add_inputs = {"x": val_mm, "y": val_c}
...@@ -473,7 +474,19 @@ class ONNXOpMapper(OpMapper): ...@@ -473,7 +474,19 @@ class ONNXOpMapper(OpMapper):
output=node, output=node,
param_attr=attr) param_attr=attr)
else: else:
pass var_beta = node.layer_name + '_beta'
matmul_beta_inputs = {"x": val_c, "y": var_beta}
node.fluid_code.add_layer("Constant",
inputs=matmul_beta_inputs,
output=var_beta,
param_attr={'value': beta})
add_inputs = {"x": val_mm, "y": var_beta}
attr = {"name": string(node.layer_name)}
node.fluid_code.add_layer("elementwise_add",
inputs=add_inputs,
output=node,
param_attr=attr)
def Add(self, node): def Add(self, node):
val_x = self.graph.get_node(node.layer.input[0], copy=True) val_x = self.graph.get_node(node.layer.input[0], copy=True)
...@@ -504,25 +517,6 @@ class ONNXOpMapper(OpMapper): ...@@ -504,25 +517,6 @@ class ONNXOpMapper(OpMapper):
output=node, output=node,
param_attr=attr) param_attr=attr)
def LRN(self, node):
val_x = self.graph.get_node(node.layer.input[0], copy=True)
size = node.get_attr('size') # required
alpha = node.get_attr('alpha', 0.0001) # optional
beta = node.get_attr('beta', 0.75) # optional
bias = node.get_attr('bias', 1.0) # optional
attr = {
"n": max(1, size),
"k": bias,
"alpha": alpha,
'beta': beta,
"name": string(node.layer_name)
}
node.fluid_code.add_layer("lrn",
inputs=val_x,
output=node,
param_attr=attr)
def BatchNormalization(self, node): def BatchNormalization(self, node):
val_x = self.graph.get_node(node.layer.input[0], copy=True) val_x = self.graph.get_node(node.layer.input[0], copy=True)
val_scale = self.graph.get_node(node.layer.input[1], copy=True) val_scale = self.graph.get_node(node.layer.input[1], copy=True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册