提交 35c60c0f 编写于 作者: C Channingss

merge paddle/develop

......@@ -58,7 +58,7 @@ x2paddle --framework=paddle2onnx --model=paddle_infer_model_dir --save_dir=onnx_
|--save_dir | 指定转换后的模型保存目录路径 |
|--model | 当framework为tensorflow/onnx时,该参数指定tensorflow的pb模型文件或onnx模型路径 |
|--caffe_proto | **[可选]** 由caffe.proto编译成caffe_pb2.py文件的存放路径,当存在自定义Layer时使用,默认为None |
|--without_data_format_optimization | **[可选]** For TensorFlow, 当指定该参数时,关闭NHWC->NCHW的优化,见[文档Q2](FAQ.md) |
|--without_data_format_optimization | **[可选]** For TensorFlow, 当指定该参数为False时,打开NHWC->NCHW的优化,见[文档Q2](FAQ.md),默认为True|
|--define_input_shape | **[可选]** For TensorFlow, 当指定该参数时,强制用户输入每个Placeholder的shape,见[文档Q2](FAQ.md) |
|--params_merge | **[可选]** 当指定该参数时,转换完成后,inference_model中的所有模型参数将合并保存为一个文件__params__ |
|--onnx_opset | **[可选]** 当framework为paddle2onnx时,该参数可设置转换为ONNX的OpSet版本,目前支持9、10、11,默认为10 |
......
# X2Paddle支持OP列表
> 目前X2Paddle支持50+的TensorFlow OP,30+的Caffe Layer,覆盖了大部分CV分类模型常用的操作。我们在如下列表中给出了目前X2Paddle支持的全部OP。
> 目前X2Paddle支持70+的TensorFlow OP,30+的Caffe Layer,覆盖了大部分CV分类模型常用的操作。我们在如下列表中给出了目前X2Paddle支持的全部OP。
**注:** 目前,部分OP暂未支持,如您在转换过程中出现OP不支持的情况,可自行添加或反馈给我们。欢迎通过[ISSUE反馈](https://github.com/PaddlePaddle/X2Paddle/issues/new)的方式告知我们(模型名,代码实现或模型获取方式),我们会及时跟进:)
......@@ -21,6 +21,10 @@
| 45 | Softmax | 46 | Range | 47 | ConcatV2 | 48 | MirrorPad |
| 49 | Identity | 50 | GreaterEqual | 51 | StopGradient | 52 | Minimum |
| 53 | RadnomUniform | 54 | Fill | 55 | Floor | 56 | DepthToSpace |
| 57 | Sqrt | 58 | Softplus | 59 | Erf | 60 | AddV2 |
| 61 | LessEqual | 62 | BatchMatMul | 63 | BatchMatMulV2 | 64 | ExpandDims |
| 65 | BatchToSpaceND | 66 | SpaceToBatchND | 67 | OneHot | 68 | Pow |
| 69 | All | 70 | GatherV2 | 71 | IteratorV2 | | |
## Caffe
......
......@@ -66,8 +66,8 @@ def arg_parser():
parser.add_argument(
"--without_data_format_optimization",
"-wo",
action="store_true",
default=False,
type=_text_type,
default="True",
help="tf model conversion without data format optimization")
parser.add_argument(
"--define_input_shape",
......@@ -93,7 +93,7 @@ def arg_parser():
def tf2paddle(model_path,
save_dir,
without_data_format_optimization=False,
without_data_format_optimization,
define_input_shape=False,
params_merge=False):
# check tensorflow installation and version
......@@ -240,11 +240,12 @@ def main():
if args.framework == "tensorflow":
assert args.model is not None, "--model should be defined while translating tensorflow model"
without_data_format_optimization = False
assert args.without_data_format_optimization in [
"True", "False"
], "--the param without_data_format_optimization should be defined True or False"
define_input_shape = False
params_merge = False
if args.without_data_format_optimization:
without_data_format_optimization = True
without_data_format_optimization = True if args.without_data_format_optimization == "True" else False
if args.define_input_shape:
define_input_shape = True
if args.params_merge:
......
......@@ -17,7 +17,7 @@ def normalize_layer(inputs,
scale_param = fluid.layers.create_parameter(
shape=[1] if channel_shared else [1, 1, 1, input_shape[0][1]],
dtype=input.dtype,
attr=name + '_scale')
attr=fluid.ParamAttr(name=name + '_scale'))
scale_param = fluid.layers.reshape(x=scale_param, \
shape=[1] if channel_shared else [input_shape[0][1]])
out = fluid.layers.elementwise_mul(
......
......@@ -43,6 +43,21 @@ def _const_weight_or_none(node, necessary=False):
return None
def _is_static_shape(shape):
negtive_dims = 0
error_dims = 0
for dim in shape:
if dim < 0:
negtive_dims += 1
if dim < -1:
error_dims += 1
if negtive_dims > 1:
return False
if error_dims > 0:
return False
return True
def _get_same_padding(in_size, kernel_size, stride):
new_size = int(math.ceil(in_size * 1.0 / stride))
pad_size = (new_size - 1) * stride + kernel_size - in_size
......@@ -231,42 +246,9 @@ class OpSet9():
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_y = self.graph.get_input_node(node, idx=1, copy=True)
val_y_shape = val_y.out_shapes[0]
val_x_shape = val_x.out_shapes[0]
if len(val_x_shape) < len(val_y_shape):
val_x, val_y = val_y, val_x
val_y_shape, val_x_shape = val_x_shape, val_y_shape
str_y_shape = ','.join(str(e) for e in val_y_shape)
str_x_shape = ','.join(str(e) for e in val_x_shape)
slice_idx = 0
if str_y_shape not in str_x_shape:
for dim in val_y_shape:
if dim == 1:
slice_idx += 1
else:
break
attr = {"name": string(node.layer_name)}
if slice_idx < len(val_y_shape) and slice_idx > 0:
val_y_reshaped = val_y_shape[slice_idx:]
var_y_reshaped = val_y.layer_name + '_reshaped'
attr_reshaped = {
'shape': val_y_reshaped,
'name': string(var_y_reshaped)
}
node.fluid_code.add_layer(
'reshape',
inputs=val_y,
output=var_y_reshaped,
param_attr=attr_reshaped)
inputs = {'x': val_x, 'y': var_y_reshaped}
node.fluid_code.add_layer(
op_type, inputs=inputs, output=node, param_attr=attr)
else:
inputs = {'x': val_x, 'y': val_y}
node.fluid_code.add_layer(
op_type, inputs=inputs, output=node, param_attr=attr)
op_type, inputs=inputs, output=node, param_attr=None)
@print_mapping_info
def place_holder(self, node):
......@@ -478,6 +460,19 @@ class OpSet9():
inputs=val_x,
output=node,
param_attr={'shape': [1]})
else:
if str(val_x.dtype) == 'bool':
val_x_cast = val_x.layer_name + '_cast'
node.fluid_code.add_layer(
'cast',
inputs=val_x,
output=val_x_cast,
param_attr={'dtype': string('int64')})
node.fluid_code.add_layer(
'unsqueeze',
inputs=val_x_cast,
output=node,
param_attr=attr)
else:
node.fluid_code.add_layer(
'unsqueeze', inputs=val_x, output=node, param_attr=attr)
......@@ -590,6 +585,29 @@ class OpSet9():
#assert len(
# indices_shape) <= 2, "Gather op don't support dim of indice >2 "
if axis == 0 and len(indices_shape) <= 1:
if len(val_x.out_shapes[0]) <= 1:
node.fluid_code.add_layer(
'gather',
inputs={'input': val_x,
'index': indices},
output=node,
param_attr=None)
elif len(val_x.out_shapes[0]) > 1:
if len(indices_shape) == 0:
gather_ = node.layer_name + '_1'
node.fluid_code.add_layer(
'gather',
inputs={'input': val_x,
'index': indices},
output=gather_,
param_attr=None)
node.fluid_code.add_layer(
'squeeze',
inputs={'input': gather_,
'axes': [0]},
output=node,
param_attr=None)
else:
node.fluid_code.add_layer(
'gather',
inputs={'input': val_x,
......@@ -614,6 +632,13 @@ class OpSet9():
param_attr=None)
node.fluid_code.add_layer(
'transpose', inputs=node, output=node, param_attr=attr_trans)
if len(indices_shape) < 1:
node.fluid_code.add_layer(
'squeeze',
inputs={'input': node,
'axes': [axis]},
output=node,
param_attr=None)
elif axis == 0 and len(indices_shape) > 1:
if val_x.out_shapes[0] is not None and isinstance(
val_x, ONNXGraphDataNode):
......@@ -702,6 +727,86 @@ class OpSet9():
output=node,
param_attr={'shape': reshaped_shape})
@print_mapping_info
def ScatterND(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
indices = self.graph.get_input_node(node, idx=1, copy=True)
updates = self.graph.get_input_node(node, idx=2, copy=True)
if len(indices.out_shapes[0]) == 1:
node.fluid_code.add_layer(
'scatter',
inputs={'input': val_x,
'index': indices,
'updates': updates},
output=node,
param_attr=None)
else:
input_inner_indices = node.layer_name + '_input_inner_indices'
node.fluid_code.add_layer(
'scatter_nd',
inputs={
'shape': val_x.out_shapes[0],
'index': indices,
'updates': updates
},
output=input_inner_indices,
param_attr=None)
constant_minus_one = node.layer_name + '_constant_minus_one'
node.fluid_code.add_layer(
'fill_constant',
inputs=None,
output=constant_minus_one,
param_attr={
'shape': updates.out_shapes[0],
'dtype': string(updates.dtype),
'value': -1
})
indices_mask = node.layer_name + '_indices_mask'
node.fluid_code.add_layer(
'scatter_nd',
inputs={
'shape': val_x.out_shapes[0],
'index': indices,
'updates': constant_minus_one
},
output=indices_mask,
param_attr=None)
constant_1 = node.layer_name + '_constant_1'
node.fluid_code.add_layer(
'fill_constant',
inputs=None,
output=constant_1,
param_attr={
'shape': val_x.out_shapes[0],
'dtype': string(val_x.dtype),
'value': 1
})
input_out_indices_mask = node.layer_name + '_input_out_indices_mask'
node.fluid_code.add_layer(
"elementwise_add",
inputs={"x": indices_mask,
"y": constant_1},
output=input_out_indices_mask,
param_attr=None)
input_out_indices = node.layer_name + '_input_out_indices'
node.fluid_code.add_layer(
"elementwise_mul",
inputs={"x": val_x,
"y": input_out_indices_mask},
output=input_out_indices,
param_attr=None)
node.fluid_code.add_layer(
"elementwise_add",
inputs={"x": input_inner_indices,
"y": input_out_indices},
output=node,
param_attr=None)
@print_mapping_info
def Range(self, node):
val_start = self.graph.get_input_node(node, idx=0, copy=True)
......@@ -791,8 +896,6 @@ class OpSet9():
'this is not supported')
if len(value) == 1:
value = value[0]
if dtype.name == 'int64':
dtype = 'int32'
attr = {
'shape': val_shape.layer_name,
'dtype': string(dtype),
......@@ -833,6 +936,14 @@ class OpSet9():
inputs={'x': val_x},
output=node,
param_attr={'shape': shape_value.tolist()})
elif len(node.out_shapes[0]) > 0 and _is_static_shape(node.out_shapes[
0]):
node.fluid_code.add_layer(
'reshape',
inputs={'x': val_x,
'shape': node.out_shapes[0]},
output=node,
param_attr=attr)
elif val_shape.dtype == 'int64':
val_shape_cast = val_shape.layer_name + '_cast'
node.fluid_code.add_layer(
......@@ -884,6 +995,11 @@ class OpSet9():
node.fluid_code.add_layer(
'cast', inputs=val_input, output=node, param_attr=attr)
@print_mapping_info
def Not(self, node):
val_input = self.graph.get_input_node(node, idx=0, copy=True)
node.fluid_code.add_layer('logical_not', inputs=val_input, output=node)
@print_mapping_info
def AveragePool(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
......@@ -924,12 +1040,16 @@ class OpSet9():
@print_mapping_info
def Concat(self, node):
inputs = []
dtypes = set()
for i in range(len(node.layer.input)):
ipt = self.graph.get_input_node(node, idx=i, copy=True)
if isinstance(ipt, str):
inputs.append(ipt)
else:
inputs.append(ipt.layer_name)
dtypes.add(ipt.dtype)
if len(dtypes) > 1:
assert 'Unspported situation happened, please create issue on https://github.com/PaddlePaddle/X2Paddle/issues.'
axis = node.get_attr('axis')
attr = {'axis': axis}
node.fluid_code.add_layer(
......@@ -1017,10 +1137,22 @@ class OpSet9():
def MatMul(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_y = self.graph.get_input_node(node, idx=1, copy=True)
x_shape = val_x.out_shapes[0]
y_shape = val_y.out_shapes[0]
inputs = {"x": val_x, "y": val_y}
attr = {"name": string(node.layer_name)}
if y_shape[0] == 1 and x_shape[-1] != 1:
y_squeeze = val_y.layer_name + '_squeeze'
node.fluid_code.add_layer(
"squeeze",
inputs=val_y,
output=y_squeeze,
param_attr={'axes': [0]})
inputs['y'] = y_squeeze
node.fluid_code.add_layer(
"matmul", inputs=inputs, output=node, param_attr=None)
else:
node.fluid_code.add_layer(
"matmul", inputs=inputs, output=node, param_attr=attr)
"matmul", inputs=inputs, output=node, param_attr=None)
@print_mapping_info
def BatchNormalization(self, node):
......
......@@ -93,16 +93,13 @@ class OpSet11(OpSet10):
else:
coordinate_transformation_mode = 'half_pixel'
roi_name = self.get_name(op.type, 'roi')
roi_node = self.make_constant_node(roi_name, onnx_pb.TensorProto.FLOAT,
[1, 1, 1, 1, 1, 1, 1, 1])
if ('OutSize' in input_names and len(op.input('OutSize')) > 0) or (
'SizeTensor' in input_names and
len(op.input('SizeTensor')) > 0):
node_list = list()
roi_node = self.make_constant_node(
self.get_name(op.type, 'roi'), onnx_pb.TensorProto.FLOAT,
[1, 1, 1, 1, 1, 1, 1, 1])
roi_name = self.get_name(op.type, 'roi')
roi_node = self.make_constant_node(
roi_name, onnx_pb.TensorProto.FLOAT, [1, 1, 1, 1, 1, 1, 1, 1])
empty_name = self.get_name(op.type, 'empty')
empty_tensor = helper.make_tensor(
empty_name,
......@@ -168,7 +165,7 @@ class OpSet11(OpSet10):
elif 'Scale' in input_names and len(op.input('Scale')) > 0:
node = helper.make_node(
'Resize',
inputs=[op.input('X')[0], op.input('Scale')[0]],
inputs=[op.input('X')[0], roi_name, op.input('Scale')[0]],
outputs=op.output('Out'),
mode='linear',
coordinate_transformation_mode=coordinate_transformation_mode)
......@@ -180,10 +177,6 @@ class OpSet11(OpSet10):
scale_node = self.make_constant_node(scale_name,
onnx_pb.TensorProto.FLOAT,
[1, 1, scale, scale])
roi_name = self.get_name(op.type, 'roi')
roi_node = self.make_constant_node(roi_name,
onnx_pb.TensorProto.FLOAT,
[1, 1, 1, 1, 1, 1, 1, 1])
node = helper.make_node(
'Resize',
inputs=[op.input('X')[0], roi_name, scale_name],
......@@ -194,7 +187,7 @@ class OpSet11(OpSet10):
return [scale_node, roi_node, node]
else:
raise Exception("Unexpected situation happend")
return node
return [roi_node, node]
def nearest_interp(self, op, block):
input_names = op.input_names
......@@ -204,17 +197,20 @@ class OpSet11(OpSet10):
coordinate_transformation_mode = 'align_corners'
else:
coordinate_transformation_mode = 'asymmetric'
roi_name = self.get_name(op.type, 'roi')
roi_node = self.make_constant_node(roi_name, onnx_pb.TensorProto.FLOAT,
[1, 1, 1, 1, 1, 1, 1, 1])
if 'OutSize' in input_names and len(op.input('OutSize')) > 0:
node = helper.make_node(
'Resize',
inputs=[op.input('X')[0], '', op.input('OutSize')[0]],
inputs=[op.input('X')[0], roi_name, op.input('OutSize')[0]],
outputs=op.output('Out'),
mode='nearest',
coordinate_transformation_mode=coordinate_transformation_mode)
elif 'Scale' in input_names and len(op.input('Scale')) > 0:
node = helper.make_node(
'Resize',
inputs=[op.input('X')[0], op.input('Scale')[0]],
inputs=[op.input('X')[0], roi_name, op.input('Scale')[0]],
outputs=op.output('Out'),
mode='nearest',
coordinate_transformation_mode=coordinate_transformation_mode)
......@@ -226,10 +222,6 @@ class OpSet11(OpSet10):
scale_node = self.make_constant_node(scale_name,
onnx_pb.TensorProto.FLOAT,
[1, 1, scale, scale])
roi_name = self.get_name(op.type, 'roi')
roi_node = self.make_constant_node(roi_name,
onnx_pb.TensorProto.FLOAT,
[1, 1, 1, 1, 1, 1, 1, 1])
node = helper.make_node(
'Resize',
inputs=[op.input('X')[0], roi_name, scale_name],
......@@ -240,7 +232,7 @@ class OpSet11(OpSet10):
return [scale_node, roi_node, node]
else:
raise Exception("Unexpected situation happend")
return node
return [roi_node, node]
def hard_swish(self, op, block):
min_name = self.get_name(op.type, 'min')
......
......@@ -174,14 +174,15 @@ class OpSet9(object):
inputs=[op.input('X')[0], temp_value],
outputs=op.output('Out'))
return [shape_node, y_node, node]
elif len(x_shape) == len(y_shape):
elif axis == -1 or axis == (len(x_shape) - 1
) or len(x_shape) == len(y_shape):
node = helper.make_node(
'Add',
inputs=[op.input('X')[0], op.input('Y')[0]],
outputs=op.output('Out'))
return node
else:
raise Excpetion("Unexpected situation happend in elementwise_add")
raise Exception("Unexpected situation happend in elementwise_add")
def elementwise_sub(self, op, block):
axis = op.attr('axis')
......@@ -203,14 +204,15 @@ class OpSet9(object):
inputs=[op.input('X')[0], temp_value],
outputs=op.output('Out'))
return [shape_node, y_node, node]
elif len(x_shape) == len(y_shape):
elif axis == -1 or axis == (len(x_shape) - 1
) or len(x_shape) == len(y_shape):
node = helper.make_node(
'Sub',
inputs=[op.input('X')[0], op.input('Y')[0]],
outputs=op.output('Out'))
return node
else:
raise Excpetion("Unexpected situation happend in elementwise_sub")
raise Exception("Unexpected situation happend in elementwise_sub")
def pool2d(self, op, block):
pool_type = {
......@@ -398,6 +400,11 @@ class OpSet9(object):
axis=op.attr('axis'))
return node
def sum(self, op, block):
node = helper.make_node(
'Sum', inputs=op.input('X'), outputs=op.output('Out'))
return node
def depthwise_conv2d(self, op, block):
return self.conv2d(op, block)
......@@ -560,7 +567,7 @@ class OpSet9(object):
input_shape = block.vars[op.input('X')[0]].shape
if op.attr('align_corners') or op.attr('align_mode') == 0:
raise Exception(
"Resize in onnx(opset<=10) only support coordinate_transformation_mode: 'asymmetric', Try converting with --onnx_opest 11"
"Resize in onnx(opset<=10) only support coordinate_transformation_mode: 'asymmetric', Try converting with --onnx_opset 11"
)
if ('OutSize' in input_names and len(op.input('OutSize')) > 0) or (
'SizeTensor' in input_names and
......@@ -666,14 +673,82 @@ class OpSet9(object):
input_names = op.input_names
if op.attr('align_corners'):
raise Exception(
"Resize in onnx(opset<=10) only support coordinate_transformation_mode: 'asymmetric', Try converting with --onnx_opest 11"
"Resize in onnx(opset<=10) only support coordinate_transformation_mode: 'asymmetric', Try converting with --onnx_opset 11"
)
if 'OutSize' in input_names and len(op.input('OutSize')) > 0:
node = helper.make_node(
node_list = list()
shape_name0 = self.get_name(op.type, 'shape')
shape_node0 = helper.make_node(
'Shape', inputs=op.input('X'), outputs=[shape_name0])
starts_name = self.get_name(op.type, 'slice.starts')
starts_node = self.make_constant_node(
starts_name, onnx_pb.TensorProto.INT64, [0])
ends_name = self.get_name(op.type, 'slice.ends')
ends_node = self.make_constant_node(ends_name,
onnx_pb.TensorProto.INT64, [2])
shape_name1 = self.get_name(op.type, 'shape')
shape_node1 = helper.make_node(
'Slice',
inputs=[shape_name0, starts_name, ends_name],
outputs=[shape_name1])
node_list.extend([shape_node0, starts_node, ends_node, shape_node1])
if 'OutSize' in input_names and len(op.input('OutSize')) > 0:
cast_shape_name = self.get_name(op.type, "shape.cast")
cast_shape_node = helper.make_node(
'Cast',
inputs=op.input('OutSize'),
outputs=[cast_shape_name],
to=onnx_pb.TensorProto.INT64)
node_list.append(cast_shape_node)
else:
concat_shape_name = self.get_name(
op.type, op.output('Out')[0] + "shape.concat")
concat_shape_node = helper.make_node(
"Concat",
inputs=op.input('SizeTensor'),
outputs=[concat_shape_name],
axis=0)
cast_shape_name = self.get_name(op.type, "shape.cast")
cast_shape_node = helper.make_node(
'Cast',
inputs=[concat_shape_name],
outputs=[cast_shape_name],
to=onnx_pb.TensorProto.INT64)
node_list.extend([concat_shape_node, cast_shape_node])
shape_name2 = self.get_name(op.type, "shape.concat")
shape_node2 = helper.make_node(
'Concat',
inputs=[shape_name1, cast_shape_name],
outputs=[shape_name2],
axis=0)
node_list.append(shape_node2)
cast_shape_name2 = self.get_name(op.type, "shape.cast")
cast_shape_node2 = helper.make_node(
'Cast',
inputs=[shape_name2],
outputs=[cast_shape_name2],
to=onnx_pb.TensorProto.FLOAT)
node_list.append(cast_shape_node2)
cast_shape_name0 = self.get_name(op.type, "shape.cast")
cast_shape_node0 = helper.make_node(
'Cast',
inputs=[shape_name0],
outputs=[cast_shape_name0],
to=onnx_pb.TensorProto.FLOAT)
node_list.append(cast_shape_node0)
outputs_h_w_scales = op.output('Out')[0] + "@out_hw_scales"
node_h_w_scales = helper.make_node(
'Div',
inputs=[cast_shape_name2, cast_shape_name0],
outputs=[outputs_h_w_scales])
node_list.append(node_h_w_scales)
result_node = helper.make_node(
'Resize',
inputs=[op.input('X')[0], op.input('OutSize')[0]],
inputs=[op.input('X')[0], outputs_h_w_scales],
outputs=op.output('Out'),
mode='nearest')
mode='linear')
node_list.extend([result_node])
return node_list
elif 'Scale' in input_names and len(op.input('Scale')) > 0:
node = helper.make_node(
'Resize',
......@@ -758,14 +833,15 @@ class OpSet9(object):
inputs=[op.input('X')[0], temp_value],
outputs=op.output('Out'))
return [shape_node, y_node, node]
elif len(x_shape) == len(y_shape):
elif axis == -1 or axis == (len(x_shape) - 1
) or len(x_shape) == len(y_shape):
node = helper.make_node(
'Mul',
inputs=[op.input('X')[0], op.input('Y')[0]],
outputs=op.output('Out'))
return node
else:
raise Excpetion("Unexpected situation happend in elementwise_add")
raise Exception("Unexpected situation happend in elementwise_mul")
return node
def feed(self, op, block):
......@@ -794,6 +870,14 @@ class OpSet9(object):
axes=op.attr('axes'))
return node
def cast(self, op, block):
node = helper.make_node(
'Cast',
inputs=op.input('X'),
outputs=op.output('Out'),
to=self.paddle_onnx_dtype_map[op.attr('out_dtype')])
return node
def arg_max(self, op, block):
node = helper.make_node(
'ArgMax',
......
......@@ -299,6 +299,10 @@ class TFOpMapperNHWC(OpMapper):
data_format = node.get_attr("data_format").decode()
pad_mode = node.get_attr("padding").decode()
channel_first = data_format == "NCHW"
if data_format == "NHWC":
n, h, w, c = input.out_shapes[0]
else:
n, c, h, w = input.out_shapes[0]
if kernel.layer_type == 'Const':
kernel_value = kernel.value
......@@ -329,10 +333,15 @@ class TFOpMapperNHWC(OpMapper):
"dilation": dilations[2:4],
"padding": string(pad_mode)
}
if hasattr(node, 'dilation') and attr['dilation'] == [1, 1]:
if len(node.dilation) == 1:
attr['dilation'] = [1, node.dilation[0]]
if c == -1:
reshape_attr = {"shape": [0, k_size[2], 0, 0]}
node.fluid_code.add_layer(
"reshape", inputs=input, output=input, param_attr=reshape_attr)
node.fluid_code.add_layer(
"conv2d", inputs=input, output=node, param_attr=attr)
if not channel_first:
......@@ -748,11 +757,12 @@ class TFOpMapperNHWC(OpMapper):
self.add_omit_nodes(begin.layer_name, node.layer_name)
begin = begin.value.tolist()
else:
begin = begin
shape = begin.out_shapes[0]
attr = {"shape": shape}
node.fluid_code.add_layer(
"reshape", inputs=begin, output=begin, param_attr=attr)
begin = self.decoder.infer_tensor(begin).tolist()
# shape = begin.out_shapes[0]
# attr = {"shape": shape}
# node.fluid_code.add_layer(
# "reshape", inputs=begin, output=begin, param_attr=attr)
if size.layer_type == "Const":
self.add_omit_nodes(size.layer_name, node.layer_name)
size = size.value.tolist()
......@@ -1058,13 +1068,25 @@ class TFOpMapperNHWC(OpMapper):
axis = axis.value.tolist()
assert axis == 0, "Only support axis=0 in GatherV2 OP"
attr = {'overwrite': False}
embeddings_shape = embeddings.out_shapes[0][-1]
reshape_list = list()
reshape_name = index.layer_name
if len(index.out_shapes[0]) != 1:
reshape_list = index.out_shapes[0]
reshape_attr = {"shape": [-1]}
reshape_name = "{}_reshape".format(index.layer_name)
node.fluid_code.add_layer(
"reshape", inputs=index, output=index, param_attr=reshape_attr)
inputs = {'input': embeddings, 'index': index}
"reshape",
inputs=index,
output=reshape_name,
param_attr=reshape_attr)
inputs = {'input': embeddings, 'index': reshape_name}
node.fluid_code.add_layer(
"gather", inputs=inputs, output=node, param_attr=attr)
if len(index.out_shapes[0]) != 1:
reshape_attr = {"shape": reshape_list + [embeddings_shape]}
node.fluid_code.add_layer(
"reshape", inputs=node, output=node, param_attr=reshape_attr)
def OneShotIterator(self, node):
return self.Placeholder(node)
......
......@@ -863,6 +863,9 @@ class TFOptimizer(object):
weight = numpy.expand_dims(weight, 2)
weight = numpy.expand_dims(weight, 3)
self.op_mapper.weights[in_nodes3[0].layer_name] = weight
# fix bug in Paddle1.8.3 and may change in next version.
# self.op_mapper.weights[in_nodes3[0].layer_name +
# '_1'] = weight.reshape(1, -1)
in_nodes3[0].fluid_code.layers[0].param_attr["shape"] = [
1, in_shape[-1], 1, 1
]
......
# X2Paddle模型测试库
> 目前X2Paddle支持50+的TensorFlow OP,40+的Caffe Layer,覆盖了大部分CV分类模型常用的操作。我们在如下模型列表中测试了X2Paddle的转换。
> 目前X2Paddle支持70+的TensorFlow OP,40+的Caffe Layer,覆盖了大部分CV分类模型常用的操作。我们在如下模型列表中测试了X2Paddle的转换。
**注:** 受限于不同框架的差异,部分模型可能会存在目前无法转换的情况,如TensorFlow中包含控制流的模型,NLP模型等。对于CV常见的模型,如若您发现无法转换或转换失败,存在较大diff等问题,欢迎通过[ISSUE反馈](https://github.com/PaddlePaddle/X2Paddle/issues/new)的方式告知我们(模型名,代码实现或模型获取方式),我们会及时跟进:)
......@@ -20,10 +20,13 @@
| ResNet_V1_101 | [code](https://github.com/tensorflow/models/tree/master/research/slim/nets) |-|
| ResNet_V2_101 | [code](https://github.com/tensorflow/models/tree/master/research/slim/nets) |-|
| UNet | [code1](https://github.com/jakeret/tf_unet )/[code2](https://github.com/lyatdawn/Unet-Tensorflow) |-|
|MTCNN | [code](https://github.com/AITTSMD/MTCNN-Tensorflow) |-|
|YOLO-V3| [code](https://github.com/YunYang1994/tensorflow-yolov3) | 转换需要关闭NHWC->NCHW的优化,见[文档Q2](FAQ.md) |
| FALSR | [code](https://github.com/xiaomi-automl/FALSR) | - |
| DCSCN | [code](https://modelzoo.co/model/dcscn-super-resolution) | - |
| MTCNN | [code](https://github.com/AITTSMD/MTCNN-Tensorflow) |-|
| YOLO-V3| [code](https://github.com/YunYang1994/tensorflow-yolov3) | 转换需要关闭NHWC->NCHW的优化,见[文档Q2](FAQ.md) |
| FALSR | [code](https://github.com/xiaomi-automl/FALSR) | 需使用参数without_data_format_optimization |
| DCSCN | [code](https://modelzoo.co/model/dcscn-super-resolution) | 需使用参数without_data_format_optimization |
| Bert(albert) | [code](https://github.com/google-research/albert#pre-trained-models) | 需使用参数without_data_format_optimization |
| Bert(chinese_L-12_H-768_A-12) | [code](https://github.com/google-research/bert#pre-trained-models) | 需使用参数without_data_format_optimization |
| Bert(multi_cased_L-12_H-768_A-12) | [code](https://github.com/google-research/bert#pre-trained-models) | 需使用参数without_data_format_optimization |
## Caffe
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册