提交 6eca1a44 编写于 作者: W wjj19950828

Merge remote-tracking branch 'upstream/develop' into Fixed_Flatten

...@@ -192,6 +192,14 @@ def tf2paddle(model_path, ...@@ -192,6 +192,14 @@ def tf2paddle(model_path,
ConverterCheck( ConverterCheck(
task="TensorFlow", time_info=time_info, task="TensorFlow", time_info=time_info,
lite_state="Success").start() lite_state="Success").start()
# for convert survey
logging.info("================================================")
logging.info("")
logging.info(
"Model Convertd! Fill this survey to help X2Paddle better, https://iwenjuan.baidu.com/?code=npyd51 "
)
logging.info("")
logging.info("================================================")
def caffe2paddle(proto_file, def caffe2paddle(proto_file,
...@@ -240,6 +248,14 @@ def caffe2paddle(proto_file, ...@@ -240,6 +248,14 @@ def caffe2paddle(proto_file,
if not disable_feedback: if not disable_feedback:
ConverterCheck( ConverterCheck(
task="Caffe", time_info=time_info, lite_state="Success").start() task="Caffe", time_info=time_info, lite_state="Success").start()
# for convert survey
logging.info("================================================")
logging.info("")
logging.info(
"Model Convertd! Fill this survey to help X2Paddle better, https://iwenjuan.baidu.com/?code=npyd51 "
)
logging.info("")
logging.info("================================================")
def onnx2paddle(model_path, def onnx2paddle(model_path,
...@@ -293,6 +309,14 @@ def onnx2paddle(model_path, ...@@ -293,6 +309,14 @@ def onnx2paddle(model_path,
if not disable_feedback: if not disable_feedback:
ConverterCheck( ConverterCheck(
task="ONNX", time_info=time_info, lite_state="Success").start() task="ONNX", time_info=time_info, lite_state="Success").start()
# for convert survey
logging.info("================================================")
logging.info("")
logging.info(
"Model Convertd! Fill this survey to help X2Paddle better, https://iwenjuan.baidu.com/?code=npyd51 "
)
logging.info("")
logging.info("================================================")
def pytorch2paddle(module, def pytorch2paddle(module,
...@@ -364,6 +388,14 @@ def pytorch2paddle(module, ...@@ -364,6 +388,14 @@ def pytorch2paddle(module,
ConverterCheck( ConverterCheck(
task="PyTorch", time_info=time_info, task="PyTorch", time_info=time_info,
lite_state="Success").start() lite_state="Success").start()
# for convert survey
logging.info("================================================")
logging.info("")
logging.info(
"Model Convertd! Fill this survey to help X2Paddle better, https://iwenjuan.baidu.com/?code=npyd51 "
)
logging.info("")
logging.info("================================================")
def main(): def main():
......
...@@ -122,6 +122,7 @@ class OpSet9(): ...@@ -122,6 +122,7 @@ class OpSet9():
'Mul': 'paddle.multiply', 'Mul': 'paddle.multiply',
'Pow': 'paddle.pow', 'Pow': 'paddle.pow',
'Less': 'paddle.less_than', 'Less': 'paddle.less_than',
'LessOrEqual': 'paddle.less_equal',
} }
directly_map_ops = { directly_map_ops = {
...@@ -742,27 +743,21 @@ class OpSet9(): ...@@ -742,27 +743,21 @@ class OpSet9():
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
axes = node.get_attr('axes') axes = node.get_attr('axes')
if axes is None: if axes is None:
axes = self.graph.get_input_node(node, idx=1, copy=True) axes_node = self.graph.get_input_node(node, idx=1, copy=True)
axes = _const_weight_or_none(axes_node, necessary=True)
# deal with scalar(0D) tensor
if len(val_x.out_shapes[0]) == 0 and len(axes) == 1 and axes[0] == 0: if len(val_x.out_shapes[0]) == 0 and len(axes) == 1 and axes[0] == 0:
if node.name:
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
'paddle.reshape', 'paddle.reshape',
inputs={"x": val_x.name}, inputs={"x": val_x.name},
outputs=[node.name], outputs=[node.name],
shape=[1]) shape=[1])
else: else:
if isinstance(axes, list) or isinstance(axes, tuple):
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
'paddle.unsqueeze', 'paddle.unsqueeze',
inputs={"x": val_x.name}, inputs={"x": val_x.name},
axis=axes, axis=axes,
outputs=[node.name]) outputs=[node.name])
else:
self.paddle_graph.add_layer(
'paddle.unsqueeze',
inputs={"x": val_x.name,
"axis": axes.name},
outputs=[node.name])
@print_mapping_info @print_mapping_info
def Shrink(self, node): def Shrink(self, node):
...@@ -897,72 +892,31 @@ class OpSet9(): ...@@ -897,72 +892,31 @@ class OpSet9():
def Gather(self, node): def Gather(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
indices = self.graph.get_input_node(node, idx=1, copy=True) indices = self.graph.get_input_node(node, idx=1, copy=True)
indices_values = _const_weight_or_none(indices, necessary=True)
if isinstance(indices_values, np.ndarray):
indices_values = indices_values.tolist()
indices_shape = indices.out_shapes[0] indices_shape = indices.out_shapes[0]
val_x_shape = val_x.out_shapes[0]
axis = node.get_attr('axis', 0) axis = node.get_attr('axis', 0)
#assert len( if len(indices_shape) == 1 or \
# indices_shape) <= 2, "Gather op don't support dim of indice >2 " (indices_values is not None and isinstance(indices_values, int)) or \
if axis == 0 and len(indices_shape) <= 1: (indices_values is not None and len(indices_values) == 1):
if len(val_x.out_shapes[0]) <= 1:
self.paddle_graph.add_layer(
'paddle.gather',
inputs={'x': val_x.name,
'index': indices.name},
outputs=[node.name])
elif len(val_x.out_shapes[0]) > 1:
if len(indices_shape) == 0:
self.paddle_graph.add_layer(
'paddle.reshape',
inputs={"x": indices.name},
outputs=[indices.name],
shape=[-1, ])
gather_ = node.name + '_1'
self.paddle_graph.add_layer(
'paddle.gather',
inputs={'x': val_x.name,
'index': indices.name},
outputs=[gather_])
self.paddle_graph.add_layer(
'paddle.squeeze',
inputs={'x': gather_},
outputs=[node.name],
axis=[0])
else:
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
'paddle.gather', 'paddle.gather',
inputs={'x': val_x.name, inputs={'x': val_x.name,
'index': indices.name}, 'index': indices.name},
outputs=[node.name])
elif axis > 0 and len(indices_shape) <= 1:
perm = list(range(len(val_x.out_shapes[0])))
perm = [axis] + perm[:axis] + perm[axis + 1:]
name_trans = val_x.name + '_trans'
self.paddle_graph.add_layer(
'paddle.transpose',
inputs={"x": val_x.name},
outputs=[name_trans],
perm=perm)
self.paddle_graph.add_layer(
'paddle.gather',
inputs={'x': name_trans,
'index': indices.name},
outputs=[node.name])
new_perm = [0] * len(perm)
for i in range(len(perm)):
new_perm[perm[i]] = i
self.paddle_graph.add_layer(
'paddle.transpose',
inputs={"x": node.name},
outputs=[node.name], outputs=[node.name],
perm=new_perm) axis=axis)
if len(indices_shape) < 1: # deal with indice is scalar(0D) Tensor
if isinstance(indices_values, int) and len(val_x_shape) > 1:
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
'paddle.squeeze', 'paddle.squeeze',
inputs={'x': node.name}, inputs={'x': node.name},
outputs=[node.name], outputs=[node.name],
axis=[axis]) axis=[axis])
elif axis == 0 and len(indices_shape) > 1: else:
if val_x.out_shapes[0] is not None and isinstance( # if val_x is DataNode, convert gather to embedding
val_x, ONNXGraphDataNode): if axis == 0 and isinstance(val_x, ONNXGraphDataNode):
indices_cast = indices.name + '_cast' indices_cast = indices.name + '_cast'
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
'paddle.cast', 'paddle.cast',
...@@ -977,79 +931,68 @@ class OpSet9(): ...@@ -977,79 +931,68 @@ class OpSet9():
'paddle.nn.Embedding', 'paddle.nn.Embedding',
inputs={"x": indices_cast}, inputs={"x": indices_cast},
outputs=layer_outputs, outputs=layer_outputs,
num_embeddings=val_x.out_shapes[0][0], num_embeddings=val_x_shape[0],
embedding_dim=val_x.out_shapes[0][1]) embedding_dim=val_x_shape[1])
else: else:
from functools import reduce
reshape_shape = reduce(lambda x, y: x * y, indices_shape)
indices_reshape = indices.name + '_shape'
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
'paddle.reshape', 'paddle.reshape',
inputs={"x": indices.name}, inputs={"x": indices.name},
outputs=[indices_reshape], outputs=[indices.name + "_reshape"],
shape=[reshape_shape, ]) shape=[-1])
gather_1d = node.name + '_1D'
perm = list(range(len(val_x.out_shapes[0])))
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
'paddle.gather', 'paddle.gather',
inputs={'x': val_x.name, inputs={
'index': indices_reshape}, 'x': val_x.name,
outputs=[node.name]) 'index': indices.name + "_reshape"
val_x_shape = val_x.out_shapes[0] },
reshaped_shape = [] outputs=[gather_1d],
for i in perm: axis=axis)
reshaped_shape.append(indices_shape[i]) # if shape is known
for i in val_x_shape[:axis] + val_x_shape[axis + 1:]: if len(indices_shape) != 0 and len(val_x_shape) != 0:
reshaped_shape.append(i)
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
'paddle.reshape', 'paddle.reshape',
inputs={"x": node.name}, inputs={'x': gather_1d},
outputs=[node.name], outputs=[node.name],
shape=reshaped_shape) shape=val_x_shape[:axis] + indices_shape +
elif axis > 0 and len(indices_shape) > 1: val_x_shape[axis + 1:])
from functools import reduce else:
reshape_shape = reduce(lambda x, y: x * y, indices_shape) all_shape_name = list()
indices_reshape = indices.name + '_shape'
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
'paddle.reshape', kernel="paddle.shape",
inputs={"x": indices.name}, inputs={"input": val_x.name},
outputs=[indices_reshape], outputs=[val_x.name + "_shape"])
shape=[reshape_shape, ])
perm = list(range(len(val_x.out_shapes[0])))
perm = [axis] + perm[:axis] + perm[axis + 1:]
name_trans = val_x.name + '_transpose'
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
'paddle.transpose', kernel="paddle.shape",
inputs={"x": val_x.name}, inputs={"input": indices.name},
outputs=[name_trans], outputs=[indices.name + "_shape"])
perm=perm)
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
'paddle.gather', "paddle.slice",
inputs={'x': name_trans, inputs={"input": val_x.name + "_shape"},
'index': indices_reshape}, outputs=[val_x.name + "_shape_slice_start"],
outputs=[node.name]) axes=[0],
input_transpose = node.name + '_transpose' starts=[0],
new_perm = [0] * len(perm) ends=[axis])
for i in range(len(perm)): all_shape_name.append(val_x.name + "_shape_slice_start")
new_perm[perm[i]] = i all_shape_name.append(indices.name + "_shape")
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
'paddle.transpose', "paddle.slice",
inputs={"x": node.name}, inputs={"input": val_x.name + "_shape"},
outputs=[input_transpose], outputs=[val_x.name + "_shape_slice_end"],
perm=new_perm) axes=[0],
perm = new_perm starts=[axis + 1],
val_x_shape = val_x.out_shapes[0] ends=[2147483647])
reshaped_shape = [] all_shape_name.append(val_x.name + "_shape_slice_end")
for i in perm: self.paddle_graph.add_layer(
reshaped_shape.append(indices_shape[i]) 'paddle.concat',
for i in val_x_shape[:axis] + val_x_shape[axis + 1:]: inputs={"x": all_shape_name},
reshaped_shape.append(i) outputs=[node.name + "_all_shape"],
axis=0)
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
'paddle.reshape', 'paddle.reshape',
inputs={"x": input_transpose}, inputs={'x': gather_1d},
outputs=[node.name], outputs=[node.name],
shape=reshaped_shape) shape=node.name + "_all_shape")
@print_mapping_info @print_mapping_info
def ScatterND(self, node): def ScatterND(self, node):
...@@ -1260,16 +1203,6 @@ class OpSet9(): ...@@ -1260,16 +1203,6 @@ class OpSet9():
outputs=[node.name], outputs=[node.name],
**layer_attrs) **layer_attrs)
@print_mapping_info
def GatherND(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_y = self.graph.get_input_node(node, idx=1, copy=True)
self.paddle_graph.add_layer(
"paddle.gather_nd",
inputs={"x": val_x.name,
"index": val_y.name},
outputs=[node.name])
@print_mapping_info @print_mapping_info
def Clip(self, node): def Clip(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
...@@ -1425,16 +1358,6 @@ class OpSet9(): ...@@ -1425,16 +1358,6 @@ class OpSet9():
"y": val_y.name}, "y": val_y.name},
outputs=[node.name]) outputs=[node.name])
@print_mapping_info
def GatherND(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_y = self.graph.get_input_node(node, idx=1, copy=True)
self.paddle_graph.add_layer(
"paddle.gather_nd",
inputs={"x": val_x.name,
"index": val_y.name},
outputs=[node.name])
@print_mapping_info @print_mapping_info
def And(self, node): def And(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
...@@ -1732,7 +1655,8 @@ class OpSet9(): ...@@ -1732,7 +1655,8 @@ class OpSet9():
x_shape = val_x.out_shapes[0] x_shape = val_x.out_shapes[0]
y_shape = val_y.out_shapes[0] y_shape = val_y.out_shapes[0]
inputs_dict = {"x": val_x.name, "y": val_y.name} inputs_dict = {"x": val_x.name, "y": val_y.name}
if y_shape[0] == 1 and x_shape[-1] != 1 and x_shape[0] != 1: if len(y_shape) != 0 and y_shape[0] == 1 and len(
x_shape) != 0 and x_shape[-1] != 1 and x_shape[0] != 1:
y_squeeze = val_y.name + '_squeeze' y_squeeze = val_y.name + '_squeeze'
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
"paddle.squeeze", "paddle.squeeze",
......
...@@ -118,7 +118,9 @@ class TraceFcFuser(FuseBase): ...@@ -118,7 +118,9 @@ class TraceFcFuser(FuseBase):
(1, 0)) (1, 0))
self.rm_params.add(weight_name) self.rm_params.add(weight_name)
bias_numpy = parameters[bias_name] bias_numpy = parameters[bias_name]
parameters["{}.bias".format(linear_name)] = np.squeeze(bias_numpy) if len(bias_numpy.shape) == 2:
bias_numpy = np.squeeze(bias_numpy)
parameters["{}.bias".format(linear_name)] = bias_numpy
self.rm_params.add(bias_name) self.rm_params.add(bias_name)
new_layer = PaddleLayer( new_layer = PaddleLayer(
layers_id[0], layers_id[0],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册