Commit 9fe56c11 authored by Channingss

add scatter_nd

Parent cc9d332e
@@ -350,6 +350,7 @@ class ONNXGraph(Graph):
node.out_shapes.append(value_info['shape'])
else:
node.out_shapes.append([])
print(layer.name, node.out_shapes)
class ONNXDecoder(object):
......
@@ -53,14 +53,21 @@ class ONNXOpMapper(OpMapper):
def op_checker(self):
unsupported_ops = set()
contain_ops = set()
for node_name in self.graph.topo_sort:
node = self.graph.get_node(node_name)
op = node.layer_type
contain_ops.add(op)
if not hasattr(self.opset, op) and \
op not in self.opset.default_op_mapping and \
op not in custom_layers and \
op not in self.opset.elementwise_ops:
unsupported_ops.add(op)
print("There are {} ops need converted , list as below".format(
len(contain_ops)))
for op in contain_ops:
print(op)
if len(unsupported_ops) == 0:
return True
else:
......
@@ -597,12 +597,35 @@ class OpSet9():
#assert len(
# indices_shape) <= 2, "Gather op don't support dim of indice >2 "
if axis == 0 and len(indices_shape) <= 1:
if len(val_x.out_shapes[0]) <= 1:
node.fluid_code.add_layer(
'gather',
inputs={'input': val_x,
'index': indices},
output=node,
param_attr=None)
elif len(val_x.out_shapes[0]) > 1:
if len(indices_shape) == 0:
gather_ = node.layer_name + '_1'
node.fluid_code.add_layer(
'gather',
inputs={'input': val_x,
'index': indices},
output=gather_,
param_attr=None)
node.fluid_code.add_layer(
'squeeze',
inputs={'input': gather_,
'axes': [0]},
output=node,
param_attr=None)
else:
node.fluid_code.add_layer(
'gather',
inputs={'input': val_x,
'index': indices},
output=node,
param_attr=None)
elif axis > 0 and len(indices_shape) <= 1:
perm = list(range(len(val_x.out_shapes[0])))
perm = [axis] + perm[:axis] + perm[axis + 1:]
@@ -621,6 +644,13 @@ class OpSet9():
param_attr=None)
node.fluid_code.add_layer(
'transpose', inputs=node, output=node, param_attr=attr_trans)
if len(indices_shape) < 1:
node.fluid_code.add_layer(
'squeeze',
inputs={'input': node,
'axes': [0]},
output=node,
param_attr=None)
elif axis == 0 and len(indices_shape) > 1:
if val_x.out_shapes[0] is not None and isinstance(
val_x, ONNXGraphDataNode):
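
The new axis == 0 handling above deals with a scalar (0-d) ONNX index by gathering with a one-element index and then squeezing axis 0, presumably because ONNX Gather drops the gathered axis for a scalar index while fluid's gather keeps it as a size-1 dimension. A minimal NumPy sketch of that shape difference (plain NumPy with made-up values, not X2Paddle code):

import numpy as np

# x plays the role of val_x, index 2 the role of a 0-d ONNX index tensor.
x = np.arange(12).reshape(3, 4)

# ONNX Gather with a scalar index removes the gathered axis: shape (4,).
onnx_like = np.take(x, 2, axis=0)

# Gathering with a 1-element index keeps the axis: shape (1, 4),
# hence the extra squeeze on axis 0 emitted by the mapper.
gathered = np.take(x, np.array([2]), axis=0)
assert gathered.shape == (1, 4)
assert np.array_equal(onnx_like, gathered.squeeze(0))
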
@@ -701,6 +731,89 @@ class OpSet9():
output=node,
param_attr={'shape': reshaped_shape})
@print_mapping_info
def ScatterND(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
indices = self.graph.get_input_node(node, idx=1, copy=True)
updates = self.graph.get_input_node(node, idx=2, copy=True)
if len(indices.out_shapes[0]) == 1:
node.fluid_code.add_layer(
'scatter',
inputs={'input': val_x,
'index': indices,
'updates': updates},
output=node,
param_attr=None)
else:
input_inner_indices = node.layer_name + '_input_inner_indices'
print('val_x shape:', val_x.out_shapes[0])
print('indices shape:', indices.out_shapes[0])
print('updates shape:', updates.out_shapes[0])
node.fluid_code.add_layer(
'scatter_nd',
inputs={
'shape': val_x.out_shapes[0],
'index': indices,
'updates': updates
},
output=input_inner_indices,
param_attr=None)
constant_minus_one = node.layer_name + '_constant_minus_one'
node.fluid_code.add_layer(
'fill_constant',
inputs=None,
output=constant_minus_one,
param_attr={
'shape': updates.out_shapes[0],
'dtype': string(updates.dtype),
'value': -1
})
indices_mask = node.layer_name + '_indices_mask'
node.fluid_code.add_layer(
'scatter_nd',
inputs={
'shape': val_x.out_shapes[0],
'index': indices,
'updates': constant_minus_one
},
output=indices_mask,
param_attr=None)
constant_1 = node.layer_name + '_constant_1'
node.fluid_code.add_layer(
'fill_constant',
inputs=None,
output=constant_1,
param_attr={
'shape': val_x.out_shapes[0],
'dtype': string(val_x.dtype),
'value': 1
})
input_out_indices_mask = node.layer_name + '_input_out_indices_mask'
node.fluid_code.add_layer(
"elementwise_add",
inputs={"x": indices_mask,
"y": constant_1},
output=input_out_indices_mask,
param_attr=None)
input_out_indices = node.layer_name + '_input_out_indices'
node.fluid_code.add_layer(
"elementwise_mul",
inputs={"x": val_x,
"y": input_out_indices_mask},
output=input_out_indices,
param_attr=None)
node.fluid_code.add_layer(
"elementwise_add",
inputs={"x": input_inner_indices,
"y": input_out_indices},
output=node,
param_attr=None)
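
The general-rank branch above rebuilds ONNX ScatterND from two scatter_nd calls plus elementwise ops: the updates are scattered into a zero tensor, a complement mask (0 at the scattered positions, 1 elsewhere) is built by scattering -1 and adding 1, and the masked original input is added back on. A minimal NumPy sketch of that identity (plain NumPy with made-up values, not X2Paddle code):

import numpy as np

def scatter_nd_ref(index, updates, shape):
    # Reference scatter_nd: write `updates` at `index` positions in a zero tensor.
    out = np.zeros(shape, dtype=updates.dtype)
    for idx, upd in zip(index, updates):
        out[tuple(idx)] = upd
    return out

x = np.arange(12, dtype=np.float32).reshape(3, 4)
index = np.array([[0, 1], [2, 3]])                  # two target positions
updates = np.array([100., 200.], dtype=np.float32)

inner = scatter_nd_ref(index, updates, x.shape)     # updates, zeros elsewhere
mask = scatter_nd_ref(index, -np.ones_like(updates), x.shape) + 1.0  # 0 at targets, 1 elsewhere
result = inner + x * mask                           # ONNX ScatterND semantics

expected = x.copy()
expected[0, 1], expected[2, 3] = 100., 200.
assert np.allclose(result, expected)
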
@print_mapping_info
def Range(self, node):
val_start = self.graph.get_input_node(node, idx=0, copy=True)
@@ -828,6 +941,13 @@ class OpSet9():
inputs={'x': val_x},
output=node,
param_attr={'shape': shape_value.tolist()})
elif len(node.out_shapes[0]) > 0:
node.fluid_code.add_layer(
'reshape',
inputs={'x': val_x,
'shape': node.out_shapes[0]},
output=node,
param_attr=attr)
elif val_shape.dtype == 'int64':
val_shape_cast = val_shape.layer_name + '_cast'
node.fluid_code.add_layer(
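
The new elif len(node.out_shapes[0]) > 0 branch appears to fall back to the output shape recorded by shape inference when the target shape is not available as a constant. A minimal sketch of the idea (plain NumPy, hypothetical shapes, not X2Paddle code):

import numpy as np

# inferred_shape stands in for node.out_shapes[0], i.e. the output shape the
# shape-inference pass recorded for this Reshape node (hypothetical values).
x = np.arange(6, dtype=np.float32)
inferred_shape = [2, 3]
assert x.reshape(inferred_shape).shape == (2, 3)
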
@@ -879,6 +999,34 @@ class OpSet9():
node.fluid_code.add_layer(
'cast', inputs=val_input, output=node, param_attr=attr)
@print_mapping_info
def Cast(self, node):
val_input = self.graph.get_input_node(node, idx=0, copy=True)
val_output = self.graph.get_node(node.layer.output[0], copy=True)
dtype = node.get_attr('to')
if not isinstance(dtype, np.dtype):
dtype = TENSOR_TYPE_TO_NP_TYPE[dtype]
output_dtype = val_output.dtype
if output_dtype:
assert dtype == output_dtype, 'dtype of to unmatches output'
attr = {'dtype': string(dtype)}
node.fluid_code.add_layer(
'cast', inputs=val_input, output=node, param_attr=attr)
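
In the Cast mapper above, the 'to' attribute is an ONNX TensorProto enum value that gets translated into a NumPy dtype via TENSOR_TYPE_TO_NP_TYPE before the fluid 'cast' layer is emitted. A minimal sketch (assumes an onnx version that still exposes onnx.mapping):

import numpy as np
from onnx import TensorProto
from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE  # assumption: onnx.mapping is available in this onnx version

# The 'to' attribute of ONNX Cast is a TensorProto enum value; the mapper turns
# it into a NumPy dtype before emitting the fluid 'cast' layer.
assert TENSOR_TYPE_TO_NP_TYPE[TensorProto.INT64] == np.dtype('int64')
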
@print_mapping_info
def Not(self, node):
val_input = self.graph.get_input_node(node, idx=0, copy=True)
node.fluid_code.add_layer('logical_not', inputs=val_input, output=node)
val_output = self.graph.get_node(node.layer.output[0], copy=True)
node.fluid_code.add_layer(
'cast',
inputs=node,
output=node,
param_attr={'dtype': string('int64')})
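
The Not mapper above emits logical_not followed by a cast, apparently because ONNX Not yields a boolean tensor and downstream layers are given an int64 one instead. Mirrored in NumPy (made-up values, not X2Paddle code):

import numpy as np

x = np.array([True, False, True])
# logical_not gives a boolean result; the trailing cast converts it to int64.
assert np.array_equal(np.logical_not(x).astype(np.int64), np.array([0, 1, 0]))
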
@print_mapping_info
def AveragePool(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
......