提交 32e91854 编写于 作者: C Channingss

update

上级 8f9bd9b6
......@@ -53,21 +53,14 @@ class ONNXOpMapper(OpMapper):
def op_checker(self):
unsupported_ops = set()
contain_ops = set()
for node_name in self.graph.topo_sort:
node = self.graph.get_node(node_name)
op = node.layer_type
contain_ops.add(op)
if not hasattr(self.opset, op) and \
op not in self.opset.default_op_mapping and \
op not in custom_layers and \
op not in self.opset.elementwise_ops:
unsupported_ops.add(op)
print("There are {} ops need converted , list as below".format(
len(contain_ops)))
for op in contain_ops:
print(op)
if len(unsupported_ops) == 0:
return True
else:
......
......@@ -46,7 +46,7 @@ def _is_static_shape(shape):
for dim in shape:
if dim < 0:
negtive_dims += 1
if dim != -1:
if dim < -1:
error_dims += 1
if negtive_dims > 1:
return False
......@@ -513,8 +513,21 @@ class OpSet9():
output=node,
param_attr={'shape': [1]})
else:
node.fluid_code.add_layer(
'unsqueeze', inputs=val_x, output=node, param_attr=attr)
if str(val_x.dtype) == 'bool':
val_x_cast = val_x.layer_name + '_cast'
node.fluid_code.add_layer(
'cast',
inputs=val_x,
output=val_x_cast,
param_attr={'dtype': string('int64')})
node.fluid_code.add_layer(
'unsqueeze',
inputs=val_x_cast,
output=node,
param_attr=attr)
else:
node.fluid_code.add_layer(
'unsqueeze', inputs=val_x, output=node, param_attr=attr)
@print_mapping_info
def Shrink(self, node):
......@@ -783,9 +796,6 @@ class OpSet9():
param_attr=None)
else:
input_inner_indices = node.layer_name + '_input_inner_indices'
print('val_x shape:', val_x.out_shapes[0])
print('indices shape:', indices.out_shapes[0])
print('updates shape:', updates.out_shapes[0])
node.fluid_code.add_layer(
'scatter_nd',
inputs={
......@@ -1037,28 +1047,11 @@ class OpSet9():
node.fluid_code.add_layer(
'cast', inputs=val_input, output=node, param_attr=attr)
@print_mapping_info
def Cast(self, node):
val_input = self.graph.get_input_node(node, idx=0, copy=True)
val_output = self.graph.get_node(node.layer.output[0], copy=True)
dtype = node.get_attr('to')
if not isinstance(dtype, np.dtype):
dtype = TENSOR_TYPE_TO_NP_TYPE[dtype]
output_dtype = val_output.dtype
if output_dtype:
assert dtype == output_dtype, 'dtype of to unmatches output'
attr = {'dtype': string(dtype)}
node.fluid_code.add_layer(
'cast', inputs=val_input, output=node, param_attr=attr)
@print_mapping_info
def Not(self, node):
val_input = self.graph.get_input_node(node, idx=0, copy=True)
node.fluid_code.add_layer('logical_not', inputs=val_input, output=node)
val_output = self.graph.get_node(node.layer.output[0], copy=True)
node.fluid_code.add_layer(
'cast',
inputs=node,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册