未验证 提交 e9f4c95b 编写于 作者: S SunAhong1993 提交者: GitHub

Merge pull request #1 from PaddlePaddle/develop

00
......@@ -90,11 +90,13 @@ def tf2paddle(model_path,
version = tf.__version__
if version >= '2.0.0' or version < '1.0.0':
print(
"1.0.0<=tensorflow<2.0.0 is required, and v1.14.0 is recommended"
"[ERROR] 1.0.0<=tensorflow<2.0.0 is required, and v1.14.0 is recommended"
)
return
except:
print("Tensorflow is not installed, use \"pip install tensorflow\".")
print(
"[ERROR] Tensorflow is not installed, use \"pip install tensorflow\"."
)
return
from x2paddle.decoder.tf_decoder import TFDecoder
......@@ -140,7 +142,7 @@ def caffe2paddle(proto, weight, save_dir, caffe_proto, params_merge=False):
if (int(ver_part[0]) == 3 and int(ver_part[1]) >= 6) \
or (int(ver_part[0]) > 3):
version_satisfy = True
assert version_satisfy, 'google.protobuf >= 3.6.0 is required'
assert version_satisfy, '[ERROR] google.protobuf >= 3.6.0 is required'
print("Now translating model from caffe to paddle.")
model = CaffeDecoder(proto, weight, caffe_proto)
mapper = CaffeOpMapper(model)
......@@ -156,10 +158,10 @@ def onnx2paddle(model_path, save_dir, params_merge=False):
import onnx
version = onnx.version.version
if version != '1.6.0':
print("onnx==1.6.0 is required")
print("[ERROR] onnx==1.6.0 is required")
return
except:
print("onnx is not installed, use \"pip install onnx==1.6.0\".")
print("[ERROR] onnx is not installed, use \"pip install onnx==1.6.0\".")
return
print("Now translating model from onnx to paddle.")
......@@ -199,21 +201,26 @@ def main():
import onnxruntime as rt
version = rt.__version__
if version != '1.0.0':
print("onnxruntime==1.0.0 is required")
print("[ERROR] onnxruntime==1.0.0 is required")
return
except:
print(
"onnxruntime is not installed, use \"pip install onnxruntime==1.0.0\"."
"[ERROR] onnxruntime is not installed, use \"pip install onnxruntime==1.0.0\"."
)
try:
import paddle
v0, v1, v2 = paddle.__version__.split('.')
if int(v0) != 1 or int(v1) < 6:
print("paddlepaddle>=1.6.0 is required")
print("paddle.__version__ = {}".format(paddle.__version__))
if v0 == '0' and v1 == '0' and v2 == '0':
print("[WARNING] You are use develop version of paddlepaddle")
elif int(v0) != 1 or int(v1) < 6:
print("[ERROR] paddlepaddle>=1.6.0 is required")
return
except:
print("paddlepaddle not installed, use \"pip install paddlepaddle\"")
print(
"[ERROR] paddlepaddle not installed, use \"pip install paddlepaddle\""
)
if args.framework == "tensorflow":
assert args.model is not None, "--model should be defined while translating tensorflow model"
......
......@@ -36,6 +36,8 @@ class Layer(object):
if self.is_custom_layer:
layer_code = layer_code + self.op + "("
elif self.op == "=":
layer_code = layer_code
else:
layer_code = layer_code + "fluid.layers." + self.op + "("
......@@ -70,11 +72,15 @@ class Layer(object):
elif isinstance(self.inputs, GraphNode):
if hasattr(self.inputs, "index"):
layer_code += (self.inputs.layer_name +
"[{}]".format(self.inputs.index) + ", ")
"[{}]".format(self.inputs.index))
else:
layer_code += (self.inputs.layer_name + ", ")
layer_code += (self.inputs.layer_name)
if self.op != "=":
layer_code += ", "
elif isinstance(self.inputs, six.string_types):
layer_code += (self.inputs + ", ")
layer_code += (self.inputs)
if self.op != "=":
layer_code += ", "
else:
raise Exception("Unknown type of inputs.")
......@@ -85,7 +91,9 @@ class Layer(object):
layer_code = layer_code + key + "={}, ".format(value)
layer_code = layer_code.strip(", ")
return layer_code + ")"
if self.op != "=":
layer_code += ")"
return layer_code
class FluidCode(object):
......
......@@ -171,6 +171,14 @@ class CaffeGraph(Graph):
self.input2layers(input_layers)
self.transform_input_layers(layers, input_layers)
layers = input_layers + layers
for layer in layers:
if hasattr(layer, 'name'):
name = getattr(layer, 'name')
setattr(layer, 'name', name.replace('/', '_').replace('-', '_'))
for i, name in enumerate(layer.bottom):
layer.bottom[i] = name.replace('/', '_').replace('-', '_')
for i, name in enumerate(layer.top):
layer.top[i] = name.replace('/', '_').replace('-', '_')
top_layer = {}
for layer in layers:
......@@ -232,10 +240,12 @@ class CaffeDecoder(object):
def load_using_pb(self):
data = self.resolver.NetParameter()
data.MergeFromString(open(self.model_path, 'rb').read())
pair = lambda layer: (layer.name, self.normalize_pb_data(layer))
layers = data.layers or data.layer
for layer in layers:
setattr(layer, 'name',
layer.name.replace('/', '_').replace('-', '_'))
pair = lambda layer: (layer.name, self.normalize_pb_data(layer))
self.params = [pair(layer) for layer in layers if layer.blobs]
def normalize_pb_data(self, layer):
......@@ -246,14 +256,13 @@ class CaffeDecoder(object):
if layer.type == 'PReLU':
c_o, c_i, h, w = map(int, [1] + \
list(dims) + [1]* (3 - len(dims)))
elif layer.type == 'Normalize':
elif layer.type == 'Normalize' and len(dims) == 4:
data = np.asarray(list(blob.data), dtype=np.float32)
transformed.append(data)
continue
else:
c_o, c_i, h, w = map(int, [1] * (4 - len(dims)) \
+ list(dims))
c_o, c_i, h, w = map(int,
[1] * (4 - len(dims)) + list(dims))
else:
c_o = blob.num
c_i = blob.channels
......
......@@ -48,6 +48,9 @@ class TFGraphNode(GraphNode):
@property
def out_shapes(self):
if self.layer_type == "OneShotIterator":
values = self.layer.attr["output_shapes"].list.shape
else:
values = self.layer.attr["_output_shapes"].list.shape
out_shapes = list()
for value in values:
......@@ -62,6 +65,8 @@ class TFGraphNode(GraphNode):
dtype = self.layer.attr[k].type
if dtype > 0:
break
if dtype == 0:
dtype = self.layer.attr['output_types'].list.type[0]
if dtype not in self.dtype_map:
raise Exception("Dtype[{}] not in dtype_map".format(dtype))
return self.dtype_map[dtype]
......@@ -136,6 +141,7 @@ class TFGraph(Graph):
# tensorflow graph optimize
self._remove_isolated_node()
self._optimize_dialiation_conv()
self._remove_identity_node()
self._remove_cast_node()
......@@ -175,6 +181,34 @@ class TFGraph(Graph):
idx = self.topo_sort.index(node_name)
del self.topo_sort[idx]
def _optimize_dialiation_conv(self):
for name in list(self.node_map.keys()):
node = self.node_map[name]
if node.layer_type == "SpaceToBatchND":
is_dilation = True
out_node0 = self.node_map[node.outputs[0]]
if out_node0.layer_type != 'ExpandDims':
is_dilation = False
continue
out_node1 = self.node_map[out_node0.outputs[0]]
if out_node1.layer_type != 'Conv2D':
is_dilation = False
continue
out_node2 = self.node_map[out_node1.outputs[0]]
if out_node2.layer_type != 'Squeeze':
is_dilation = False
continue
out_node3 = self.node_map[out_node2.outputs[0]]
if out_node3.layer_type != 'BatchToSpaceND':
is_dilation = False
continue
if is_dilation:
node.skip = True
out_node3.skip = True
block_shape = self.node_map[node.inputs[1]]
out_node1.dilation = block_shape.value.tolist()
def _remove_isolated_node(self):
# delete isolated nodes
isolated_nodes = list()
......@@ -197,7 +231,7 @@ class TFGraph(Graph):
def _remove_identity_node(self):
identity_ops = [
'Identity', 'StopGradient', 'Switch', 'Merge',
'PlaceholderWithDefault'
'PlaceholderWithDefault', 'IteratorGetNext'
]
identity_node = list()
for node_name, node in self.node_map.items():
......@@ -288,7 +322,7 @@ class TFDecoder(object):
graph_def = cp.deepcopy(graph_def)
input_map = dict()
for layer in graph_def.node:
if layer.op != "Placeholder":
if layer.op != "Placeholder" and layer.op != "OneShotIterator":
continue
graph_node = TFGraphNode(layer)
dtype = graph_node.layer.attr['dtype'].type
......@@ -306,6 +340,14 @@ class TFDecoder(object):
if shape.count(-1) > 1:
need_define_shape = 2
if need_define_shape == 1:
try:
shape = graph_node.out_shapes[0]
if len(shape) > 0 and shape.count(-1) < 2:
need_define_shape = 0
except:
pass
if need_define_shape > 0:
shape = None
if graph_node.get_attr("shape"):
......
......@@ -12,7 +12,6 @@ def detectionoutput_layer(inputs,
share_location=True,
keep_top_k=100,
confidence_threshold=0.1,
num_classes=2,
input_shape=None,
name=None):
nms_param_str = nms_param
......@@ -37,9 +36,9 @@ def detectionoutput_layer(inputs,
pb = fluid.layers.reshape(x=pb, shape=[-1, 4])
pbv = fluid.layers.reshape(x=pbv, shape=[-1, 4])
mbox_loc = inputs[0]
mbox_loc = fluid.layers.reshape(x=mbox_loc, shape=[0, -1, 4])
mbox_loc = fluid.layers.reshape(x=mbox_loc, shape=[-1, pb.shape[0], 4])
mbox_conf_flatten = fluid.layers.reshape(x=mbox_conf_flatten,
shape=[0, -1, num_classes])
shape=[0, pb.shape[0], -1])
default = {"nms_threshold": 0.3, "top_k": 10, "eta": 1.0}
fields = ['eta', 'top_k', 'nms_threshold']
......
......@@ -467,7 +467,7 @@ class CaffeOpMapper(OpMapper):
def Concat(self, node):
assert len(
node.inputs
) > 1, 'The count of Concat node\'s input is not more than 1.'
) >= 1, 'The count of Concat node\'s input is not more than 1.'
inputs = []
for i in range(len(node.inputs)):
input = self.graph.get_bottom_node(node, idx=i, copy=True)
......@@ -797,21 +797,21 @@ class CaffeOpMapper(OpMapper):
input = self.graph.get_bottom_node(node, idx=0, copy=True)
example = self.graph.get_bottom_node(node, idx=1, copy=True)
params = node.layer.crop_param
axis = parmas.axis
axis = params.axis
input_shape = node.input_shape[0]
if axis < 0:
axis += len(input_shape)
offset_real = [0] * len(input_shape)
if hasattr(params, offset):
if hasattr(params, "offset") and len(params.offset) > 0:
offset = list(params.offset)
assert (len(input_shape) - axis) == len(
offset), "invalid offset[%s] in crop layer" % (str(offset))
offset_real = [0] * axis + offset
attr = {'offsets': offset_real, 'name': string(node.layer_name)}
attr = {'offsets': list(offset_real), 'name': string(node.layer_name)}
node.fluid_code.add_layer("crop",
inputs={
'x': input,
'y': example
'shape': node.input_shape[1]
},
output=node,
param_attr=attr)
......
......@@ -293,12 +293,15 @@ def shape_reshape(layer, input_shape):
explicit_count *= count(l)
for i in range(len(copy_axes)):
explicit_count *= outshape[start_axis + copy_axes[i]]
outshape[start_axis + inferred_axis] = -1
outshape[0] = 0
else:
outshape[0] = -1
assert input_count % explicit_count == 0, "[Reshape]botom count[%d] "\
"must be divisible by product of the specified dimensions[%d] "\
% (input_count, explicit_count)
outshape[start_axis + inferred_axis] = int(input_count / explicit_count)
output_count = count(outshape)
assert output_count == input_count, "[Reshape]output count[%d] must match input count[%d]" % (
output_count, input_count)
outshape[0] = -1
return [outshape]
......@@ -342,10 +345,9 @@ def shape_flatten(layer, input_shape):
output_shape = inshape[0:start_axis]
if len(inshape[start_axis:end_axis]) != 0:
flat_sz = reduce(lambda a, b: a * b, inshape[start_axis:end_axis])
flat_sz = -1
output_shape[0] = 0
output_shape += [flat_sz]
output_shape += inshape[end_axis:len(inshape)]
output_shape[0] = -1
return [output_shape]
......
......@@ -32,11 +32,12 @@ default_op_mapping = {
dict(),
dict(
min=(_np.asarray([255, 255, 127, 255],
dtype=_np.uint8).view(_np.float32)),
dtype=_np.uint8).view(_np.float32)[0]),
max=(_np.asarray([255, 255, 127, 127],
dtype=_np.uint8).view(_np.float32)),
dtype=_np.uint8).view(_np.float32)[0]),
)
],
'Erf': ['erf', ['X'], ['Out']],
'Ceil': ['ceil', ['X'], ['Out']],
'ReduceMean': [
'reduce_mean', ['X'], ['Out'],
......
......@@ -373,7 +373,6 @@ class ONNXOpMapper(OpMapper):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_scales = self.graph.get_input_node(node, idx=1, copy=True)
val_y = self.graph.get_node(node.layer.output[0], copy=True)
out_shape = val_y.out_shapes[0]
if out_shape is not None:
assert len(out_shape) == 4, 'only 4-D Tensor as X and Y supported'
......@@ -383,7 +382,6 @@ class ONNXOpMapper(OpMapper):
if isinstance(val_scales, ONNXGraphNode):
scales, _, _ = self.get_dynamic_shape(val_scales.layer_name)
attr = {'name': string(node.layer_name)}
use_scales = True
if scales is not None:
......@@ -708,8 +706,8 @@ class ONNXOpMapper(OpMapper):
self.omit_nodes.append(starts.layer_name)
self.omit_nodes.append(ends.layer_name)
starts = _const_weight_or_none(starts)
ends = _const_weight_or_none(ends)
starts = _const_weight_or_none(starts).copy()
ends = _const_weight_or_none(ends).copy()
else:
starts = node.get_attr('starts')
ends = node.get_attr('ends')
......
......@@ -85,7 +85,7 @@ class TFOpMapper(OpMapper):
not_placeholder = list()
for name in self.graph.input_nodes:
if self.graph.get_node(name).layer_type != "Placeholder":
if self.graph.get_node(name).layer_type != "Placeholder" and self.graph.get_node(name).layer_type != "OneShotIterator":
not_placeholder.append(name)
for name in not_placeholder:
idx = self.graph.input_nodes.index(name)
......@@ -287,6 +287,9 @@ class TFOpMapper(OpMapper):
output=node,
param_attr=attr)
def OneShotIterator(self, node):
return self.Placeholder(node)
def Const(self, node):
shape = node.out_shapes[0]
dtype = node.dtype
......@@ -492,6 +495,9 @@ class TFOpMapper(OpMapper):
output=node,
param_attr=attr)
def FusedBatchNormV3(self, node):
return self.FusedBatchNorm(node)
def DepthwiseConv2dNative(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True)
kernel = self.graph.get_node(node.layer.input[1], copy=True)
......@@ -712,7 +718,7 @@ class TFOpMapper(OpMapper):
if input.tf_data_format == "NHWC":
if len(input.out_shapes[0]) == 4:
expand_times = [expand_times[i] for i in [0, 3, 1, 2]]
elif len(input.out_shape[0]) == 3:
elif len(input.out_shapes[0]) == 3:
expand_times = [expand_times[i] for i in [2, 0, 1]]
for i in range(len(expand_times)):
if expand_times[i] < 0:
......@@ -812,7 +818,7 @@ class TFOpMapper(OpMapper):
node.fluid_code.add_layer("range",
inputs=inputs,
output=node,
param_attr=None)
param_attr=attr)
def Mean(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True)
......
......@@ -40,6 +40,7 @@ class TFOpMapperNHWC(OpMapper):
'Sigmoid': ['sigmoid'],
'Exp': ['exp'],
'Rsqrt': ['rsqrt'],
'Sqrt': ['sqrt'],
'swish_f32': ['swish'],
'Tanh': ['tanh'],
'LeakyRelu': ['leaky_relu', {
......@@ -48,6 +49,7 @@ class TFOpMapperNHWC(OpMapper):
}
elementwise_ops = {
'Add': 'elementwise_add',
'AddV2': 'elementwise_add',
'RealDiv': 'elementwise_div',
'Sub': 'elementwise_sub',
'Maximum': 'elementwise_max',
......@@ -90,10 +92,12 @@ class TFOpMapperNHWC(OpMapper):
if len(unsupported_ops) > 0:
continue
func = getattr(self, op)
try:
func(node)
except:
unsupported_ops.add(op)
else:
unsupported_ops.add(op)
continue
if len(unsupported_ops) > 0:
print("========= {} OPs are not supported yet ===========".format(
len(unsupported_ops)))
......@@ -342,7 +346,6 @@ class TFOpMapperNHWC(OpMapper):
def Conv2D(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True)
kernel = self.graph.get_node(node.layer.input[1], copy=True)
assert kernel.layer_type == "Const", "Kernel of Conv2D should be Const"
self.add_omit_nodes(kernel.layer_name, node.layer_name)
in_shape = input.out_shapes[0]
......@@ -358,8 +361,12 @@ class TFOpMapperNHWC(OpMapper):
pad_mode = node.get_attr("padding").decode()
channel_first = data_format == "NCHW"
if kernel.layer_type == 'Const':
kernel_value = kernel.value
else:
kernel_value = self.decoder.infer_tensor(kernel)
self.weights[kernel.layer_name.replace('/', '_')] = numpy.transpose(
kernel.value, (3, 2, 0, 1))
kernel_value, (3, 2, 0, 1))
if not channel_first:
in_shape = [in_shape[i] for i in [0, 3, 1, 2]]
......@@ -381,6 +388,11 @@ class TFOpMapperNHWC(OpMapper):
"dilation": dilations[2:4],
"padding": string(pad_mode)
}
if hasattr(node, 'dilation') and attr['dilation'] == [1, 1]:
if len(node.dilation) == 1:
attr['dilation'] = [1, node.dilation[0]]
node.fluid_code.add_layer("conv2d",
inputs=input,
output=node,
......@@ -732,13 +744,12 @@ class TFOpMapperNHWC(OpMapper):
"start": start,
"end": limit,
"step": delta,
"dtype": string(dtype)
}
attr = {"dtype": string(node.dtype)}
node.fluid_code.add_layer("range",
inputs=inputs,
output=node,
param_attr=None)
param_attr=attr)
def Mean(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True)
......@@ -1135,3 +1146,39 @@ class TFOpMapperNHWC(OpMapper):
inputs=inputs,
output=node,
param_attr=None)
def ExpandDims(self, node):
x = self.graph.get_node(node.layer.input[0], copy=True)
y = self.graph.get_node(node.layer.input[1], copy=True)
if y.layer_type == 'Const':
dim = y.value.tolist()
else:
dim = self.decoder.infer_tensor(y)
self.add_omit_nodes(y.layer_name, node.layer_name)
attr = {'axes': [dim]}
node.fluid_code.add_layer("unsqueeze",
inputs=x,
output=node,
param_attr=attr)
def BatchToSpaceND(self, node):
x = self.graph.get_node(node.layer.input[0], copy=True)
y = self.graph.get_node(node.layer.input[1], copy=True)
if hasattr(node, 'skip') and node.skip:
node.fluid_code.add_layer("=",
inputs=x,
output=node,
param_attr=None)
else:
raise Exception("BatchToSpaceND is not supported")
def SpaceToBatchND(self, node):
x = self.graph.get_node(node.layer.input[0], copy=True)
y = self.graph.get_node(node.layer.input[1], copy=True)
if hasattr(node, 'skip') and node.skip:
node.fluid_code.add_layer("=",
inputs=x,
output=node,
param_attr=None)
else:
raise Exception("SpaceToBatchND is not supported")
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册