提交 e164cf67 编写于 作者: C Channingss

fix bug

上级 fc5c67c4
......@@ -556,14 +556,6 @@ class PaddleOpMapper(object):
'SizeTensor' in input_names and
len(op.input('SizeTensor')) > 0):
node_list = list()
empty_name = self.get_name(op.type, 'empty')
empty_tensor = helper.make_tensor(
empty_name,
onnx_pb.TensorProto.FLOAT, (0, ),
np.array([]).astype('float32'),
raw=False)
empty_node = helper.make_node(
'Constant', [], outputs=[empty_name], value=empty_tensor)
shape_name0 = self.get_name(op.type, 'shape')
shape_node0 = helper.make_node(
'Shape', inputs=op.input('X'), outputs=[shape_name0])
......@@ -578,8 +570,7 @@ class PaddleOpMapper(object):
'Slice',
inputs=[shape_name0, starts_name, ends_name],
outputs=[shape_name1])
node_list.extend(
[empty_node, shape_node0, starts_node, ends_node, shape_node1])
node_list.extend([shape_node0, starts_node, ends_node, shape_node1])
if 'OutSize' in input_names and len(op.input('OutSize')) > 0:
cast_shape_name = self.get_name(op.type, "shape.cast")
cast_shape_node = helper.make_node(
......@@ -617,7 +608,7 @@ class PaddleOpMapper(object):
outputs=[name_h_w],
value=helper.make_tensor(
name=name_h_w,
data_type=onnx_pb.TensorProto.FLOAT,
data_type=onnx_pb.TensorProto.INT64,
dims=[2],
vals=[height, width],
raw=False))
......@@ -629,10 +620,24 @@ class PaddleOpMapper(object):
outputs=[shape_name4],
axis=0)
node_list.append(shape_node4)
cast_shape_name3 = self.get_name(op.type, "shape.cast")
cast_shape_node3 = helper.make_node(
'Cast',
inputs=[shape_name3],
outputs=[cast_shape_name3],
to=onnx_pb.TensorProto.FLOAT)
node_list.append(cast_shape_node3)
cast_shape_name4 = self.get_name(op.type, "shape.cast")
cast_shape_node4 = helper.make_node(
'Cast',
inputs=[shape_name4],
outputs=[cast_shape_name4],
to=onnx_pb.TensorProto.FLOAT)
node_list.append(cast_shape_node4)
outputs_h_w_scales = op.output('Out')[0] + "@out_hw_scales"
node_h_w_scales = helper.make_node(
'Div',
inputs=[shape_name3, shape_name4],
inputs=[cast_shape_name3, cast_shape_name4],
outputs=[outputs_h_w_scales])
node_list.append(node_h_w_scales)
result_node = helper.make_node(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册