提交 08770385 编写于 作者: C Channingss

matmul support [m,n,k] * [1,k,z]

上级 29b8e2c7
......@@ -643,7 +643,7 @@ class OpSet9():
node.fluid_code.add_layer(
'squeeze',
inputs={'input': node,
'axes': [0]},
'axes': [axis]},
output=node,
param_attr=None)
elif axis == 0 and len(indices_shape) > 1:
......@@ -1132,10 +1132,22 @@ class OpSet9():
def MatMul(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_y = self.graph.get_input_node(node, idx=1, copy=True)
x_shape = val_x.out_shapes[0]
y_shape = val_y.out_shapes[0]
inputs = {"x": val_x, "y": val_y}
attr = {"name": string(node.layer_name)}
node.fluid_code.add_layer(
"matmul", inputs=inputs, output=node, param_attr=attr)
if y_shape[0] == 1 and x_shape[-1] != 1:
y_squeeze = val_y.layer_name + '_squeeze'
node.fluid_code.add_layer(
"squeeze",
inputs=val_y,
output=y_squeeze,
param_attr={'axes': [0]})
inputs['y'] = y_squeeze
node.fluid_code.add_layer(
"matmul", inputs=inputs, output=node, param_attr=None)
else:
node.fluid_code.add_layer(
"matmul", inputs=inputs, output=node, param_attr=None)
@print_mapping_info
def BatchNormalization(self, node):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册