未验证 提交 3a818e82 编写于 作者: J Jason 提交者: GitHub

Merge pull request #367 from Channingss/scatter_nd

matmul support [m,n,k] * [1,k,z]
...@@ -646,7 +646,7 @@ class OpSet9(): ...@@ -646,7 +646,7 @@ class OpSet9():
node.fluid_code.add_layer( node.fluid_code.add_layer(
'squeeze', 'squeeze',
inputs={'input': node, inputs={'input': node,
'axes': [0]}, 'axes': [axis]},
output=node, output=node,
param_attr=None) param_attr=None)
elif axis == 0 and len(indices_shape) > 1: elif axis == 0 and len(indices_shape) > 1:
...@@ -894,8 +894,6 @@ class OpSet9(): ...@@ -894,8 +894,6 @@ class OpSet9():
'this is not supported') 'this is not supported')
if len(value) == 1: if len(value) == 1:
value = value[0] value = value[0]
if dtype.name == 'int64':
dtype = 'int32'
attr = { attr = {
'shape': val_shape.layer_name, 'shape': val_shape.layer_name,
'dtype': string(dtype), 'dtype': string(dtype),
...@@ -1040,12 +1038,16 @@ class OpSet9(): ...@@ -1040,12 +1038,16 @@ class OpSet9():
@print_mapping_info @print_mapping_info
def Concat(self, node): def Concat(self, node):
inputs = [] inputs = []
dtypes = set()
for i in range(len(node.layer.input)): for i in range(len(node.layer.input)):
ipt = self.graph.get_input_node(node, idx=i, copy=True) ipt = self.graph.get_input_node(node, idx=i, copy=True)
if isinstance(ipt, str): if isinstance(ipt, str):
inputs.append(ipt) inputs.append(ipt)
else: else:
inputs.append(ipt.layer_name) inputs.append(ipt.layer_name)
dtypes.add(ipt.dtype)
if len(dtypes) > 1:
assert 'Unspported situation happened, please create issue on https://github.com/PaddlePaddle/X2Paddle/issues.'
axis = node.get_attr('axis') axis = node.get_attr('axis')
attr = {'axis': axis} attr = {'axis': axis}
node.fluid_code.add_layer( node.fluid_code.add_layer(
...@@ -1133,10 +1135,22 @@ class OpSet9(): ...@@ -1133,10 +1135,22 @@ class OpSet9():
def MatMul(self, node): def MatMul(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_y = self.graph.get_input_node(node, idx=1, copy=True) val_y = self.graph.get_input_node(node, idx=1, copy=True)
x_shape = val_x.out_shapes[0]
y_shape = val_y.out_shapes[0]
inputs = {"x": val_x, "y": val_y} inputs = {"x": val_x, "y": val_y}
attr = {"name": string(node.layer_name)} if y_shape[0] == 1 and x_shape[-1] != 1:
node.fluid_code.add_layer( y_squeeze = val_y.layer_name + '_squeeze'
"matmul", inputs=inputs, output=node, param_attr=attr) node.fluid_code.add_layer(
"squeeze",
inputs=val_y,
output=y_squeeze,
param_attr={'axes': [0]})
inputs['y'] = y_squeeze
node.fluid_code.add_layer(
"matmul", inputs=inputs, output=node, param_attr=None)
else:
node.fluid_code.add_layer(
"matmul", inputs=inputs, output=node, param_attr=None)
@print_mapping_info @print_mapping_info
def BatchNormalization(self, node): def BatchNormalization(self, node):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册