Unverified commit 01173a3b authored by mamingjie-China, committed by GitHub

Merge pull request #1 from PaddlePaddle/develop

update
@@ -211,7 +211,10 @@ def main():
     try:
         import paddle
         v0, v1, v2 = paddle.__version__.split('.')
-        if int(v0) != 1 or int(v1) < 6:
+        print("paddle.__version__ = {}".format(paddle.__version__))
+        if v0 == '0' and v1 == '0' and v2 == '0':
+            print("[WARNING] You are using a develop version of paddlepaddle")
+        elif int(v0) != 1 or int(v1) < 6:
             print("[ERROR] paddlepaddle>=1.6.0 is required")
             return
     except:
...
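The new check logs the installed version and special-cases develop builds, which report their version as 0.0.0 and would otherwise fail the integer comparison. A minimal standalone sketch of the same logic, for illustration (check_paddle_version is a hypothetical name):

import sys

def check_paddle_version(version):
    # develop builds report "0.0.0" and skip the minimum-version check
    v0, v1, v2 = version.split('.')
    if v0 == '0' and v1 == '0' and v2 == '0':
        print("[WARNING] You are using a develop version of paddlepaddle")
        return True
    if int(v0) != 1 or int(v1) < 6:
        print("[ERROR] paddlepaddle>=1.6.0 is required")
        return False
    return True

check_paddle_version("1.6.1")  # OK
check_paddle_version("0.0.0")  # develop build: warns, but passes
check_paddle_version("1.5.2")  # fails the >=1.6.0 requirement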
@@ -171,6 +171,14 @@ class CaffeGraph(Graph):
             self.input2layers(input_layers)
             self.transform_input_layers(layers, input_layers)
             layers = input_layers + layers
+        for layer in layers:
+            if hasattr(layer, 'name'):
+                name = getattr(layer, 'name')
+                setattr(layer, 'name', name.replace('/', '_').replace('-', '_'))
+            for i, name in enumerate(layer.bottom):
+                layer.bottom[i] = name.replace('/', '_').replace('-', '_')
+            for i, name in enumerate(layer.top):
+                layer.top[i] = name.replace('/', '_').replace('-', '_')
         top_layer = {}
         for layer in layers:
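Caffe layer names may contain '/' or '-', which are not legal in Paddle variable names; the new loop rewrites each layer's own name and its bottom/top references with the same substitution so the graph stays consistent. The substitution itself is just:

name = "conv1/dw-bn"  # a made-up Caffe-style layer name
print(name.replace('/', '_').replace('-', '_'))  # conv1_dw_bn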
@@ -232,10 +240,12 @@ class CaffeDecoder(object):
     def load_using_pb(self):
         data = self.resolver.NetParameter()
         data.MergeFromString(open(self.model_path, 'rb').read())
-        pair = lambda layer: (layer.name, self.normalize_pb_data(layer))
         layers = data.layers or data.layer
+        for layer in layers:
+            setattr(layer, 'name',
+                    layer.name.replace('/', '_').replace('-', '_'))
+        pair = lambda layer: (layer.name, self.normalize_pb_data(layer))
         self.params = [pair(layer) for layer in layers if layer.blobs]

     def normalize_pb_data(self, layer):
@@ -246,14 +256,13 @@ class CaffeDecoder(object):
             if layer.type == 'PReLU':
                 c_o, c_i, h, w = map(int, [1] + \
                     list(dims) + [1] * (3 - len(dims)))
-            elif layer.type == 'Normalize':
+            elif layer.type == 'Normalize' and len(dims) == 4:
                 data = np.asarray(list(blob.data), dtype=np.float32)
                 transformed.append(data)
                 continue
             else:
-                c_o, c_i, h, w = map(int, [1] * (4 - len(dims)) \
-                    + list(dims))
+                c_o, c_i, h, w = map(int,
+                                     [1] * (4 - len(dims)) + list(dims))
         else:
             c_o = blob.num
             c_i = blob.channels
...
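Caffe blobs may carry fewer than four dimensions, so the reshaped map() call left-pads dims with 1s to obtain a full (c_o, c_i, h, w) tuple. A quick worked example:

dims = [64, 3]  # e.g. a 2-D weight blob
c_o, c_i, h, w = map(int, [1] * (4 - len(dims)) + list(dims))
print(c_o, c_i, h, w)  # 1 1 64 3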
@@ -48,7 +48,10 @@ class TFGraphNode(GraphNode):
     @property
     def out_shapes(self):
-        values = self.layer.attr["_output_shapes"].list.shape
+        if self.layer_type == "OneShotIterator":
+            values = self.layer.attr["output_shapes"].list.shape
+        else:
+            values = self.layer.attr["_output_shapes"].list.shape
         out_shapes = list()
         for value in values:
             shape = [dim.size for dim in value.dim]
@@ -62,6 +65,8 @@ class TFGraphNode(GraphNode):
             dtype = self.layer.attr[k].type
             if dtype > 0:
                 break
+        if dtype == 0:
+            dtype = self.layer.attr['output_types'].list.type[0]
         if dtype not in self.dtype_map:
             raise Exception("Dtype[{}] not in dtype_map".format(dtype))
         return self.dtype_map[dtype]
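In TensorFlow's DataType enum, 0 is DT_INVALID, so dtype == 0 after the loop means none of the inspected attributes carried a type; iterator-style nodes such as OneShotIterator record their element types in output_types instead. A rough sketch of the fallback with a plain dict standing in for the protobuf attr map (resolve_dtype and the key list are illustrative assumptions, not the mapper's actual code):

def resolve_dtype(attrs):
    # attrs: attribute name -> TF DataType enum int (0 == DT_INVALID)
    dtype = 0
    for k in ('dtype', 'T', 'DstT'):
        dtype = attrs.get(k, 0)
        if dtype > 0:
            break
    if dtype == 0:
        # iterator nodes list their element types separately
        dtype = attrs['output_types'][0]
    return dtype

print(resolve_dtype({'T': 1}))               # 1 (DT_FLOAT)
print(resolve_dtype({'output_types': [3]}))  # 3 (DT_INT32)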
@@ -226,7 +231,7 @@ class TFGraph(Graph):
     def _remove_identity_node(self):
         identity_ops = [
             'Identity', 'StopGradient', 'Switch', 'Merge',
-            'PlaceholderWithDefault'
+            'PlaceholderWithDefault', 'IteratorGetNext'
         ]
         identity_node = list()
         for node_name, node in self.node_map.items():
@@ -317,7 +322,7 @@ class TFDecoder(object):
         graph_def = cp.deepcopy(graph_def)
         input_map = dict()
         for layer in graph_def.node:
-            if layer.op != "Placeholder":
+            if layer.op != "Placeholder" and layer.op != "OneShotIterator":
                 continue
             graph_node = TFGraphNode(layer)
             dtype = graph_node.layer.attr['dtype'].type
@@ -335,6 +340,14 @@ class TFDecoder(object):
             if shape.count(-1) > 1:
                 need_define_shape = 2

+            if need_define_shape == 1:
+                try:
+                    shape = graph_node.out_shapes[0]
+                    if len(shape) > 0 and shape.count(-1) < 2:
+                        need_define_shape = 0
+                except:
+                    pass
+
             if need_define_shape > 0:
                 shape = None
                 if graph_node.get_attr("shape"):
...
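Previously any Placeholder with an unknown dimension forced the user to redefine its shape; the new block first consults the inferred out_shapes and accepts them when at most one dimension is -1 (the batch axis can stay symbolic). The decision rule in isolation:

def shape_is_usable(shape):
    # usable as-is when non-empty and at most one dimension is unknown
    return len(shape) > 0 and shape.count(-1) < 2

print(shape_is_usable([-1, 224, 224, 3]))  # True: only the batch dim is unknown
print(shape_is_usable([-1, -1, 3]))        # False: user must define the shape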
@@ -12,7 +12,6 @@ def detectionoutput_layer(inputs,
                           share_location=True,
                           keep_top_k=100,
                           confidence_threshold=0.1,
-                          num_classes=2,
                           input_shape=None,
                           name=None):
     nms_param_str = nms_param
@@ -37,9 +36,9 @@ def detectionoutput_layer(inputs,
     pb = fluid.layers.reshape(x=pb, shape=[-1, 4])
     pbv = fluid.layers.reshape(x=pbv, shape=[-1, 4])
     mbox_loc = inputs[0]
-    mbox_loc = fluid.layers.reshape(x=mbox_loc, shape=[0, -1, 4])
+    mbox_loc = fluid.layers.reshape(x=mbox_loc, shape=[-1, pb.shape[0], 4])
     mbox_conf_flatten = fluid.layers.reshape(x=mbox_conf_flatten,
-                                             shape=[0, -1, num_classes])
+                                             shape=[0, pb.shape[0], -1])
     default = {"nms_threshold": 0.3, "top_k": 10, "eta": 1.0}
     fields = ['eta', 'top_k', 'nms_threshold']
...
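In fluid.layers.reshape, a 0 in the target shape copies the corresponding input dimension and a single -1 is inferred; pinning the middle dimension to pb.shape[0] (the prior-box count) lets the remaining dimension be inferred, which is presumably why the num_classes parameter could be dropped. With made-up sizes:

num_priors, num_classes, batch = 1917, 21, 1  # illustrative values
loc_elems = batch * num_priors * 4
conf_elems = batch * num_priors * num_classes
# new mbox_loc target [-1, num_priors, 4]: the batch dim is inferred
print(loc_elems // (num_priors * 4))        # 1
# new mbox_conf target [0, num_priors, -1]: the class count is inferred
print(conf_elems // (batch * num_priors))   # 21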
@@ -797,21 +797,21 @@ class CaffeOpMapper(OpMapper):
         input = self.graph.get_bottom_node(node, idx=0, copy=True)
         example = self.graph.get_bottom_node(node, idx=1, copy=True)
         params = node.layer.crop_param
-        axis = parmas.axis
+        axis = params.axis
         input_shape = node.input_shape[0]
         if axis < 0:
             axis += len(input_shape)
         offset_real = [0] * len(input_shape)
-        if hasattr(params, offset):
+        if hasattr(params, "offset") and len(params.offset) > 0:
             offset = list(params.offset)
             assert (len(input_shape) - axis) == len(
                 offset), "invalid offset[%s] in crop layer" % (str(offset))
             offset_real = [0] * axis + offset
-        attr = {'offsets': offset_real, 'name': string(node.layer_name)}
+        attr = {'offsets': list(offset_real), 'name': string(node.layer_name)}
         node.fluid_code.add_layer("crop",
                                   inputs={
                                       'x': input,
-                                      'y': example
+                                      'shape': node.input_shape[1]
                                   },
                                   output=node,
                                   param_attr=attr)
...
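The offsets list handed to fluid's crop must cover every input dimension, so the per-axis offsets from the Caffe layer are left-padded with zeros for the axes before axis. For example:

input_shape = [1, 3, 224, 224]  # illustrative NCHW shape
axis, offset = 2, [16, 16]      # crop only H and W
offset_real = [0] * axis + offset
print(offset_real)              # [0, 0, 16, 16]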
@@ -293,12 +293,15 @@ def shape_reshape(layer, input_shape):
             explicit_count *= count(l)
         for i in range(len(copy_axes)):
             explicit_count *= outshape[start_axis + copy_axes[i]]
-        outshape[start_axis + inferred_axis] = -1
-        outshape[0] = 0
-    else:
-        outshape[0] = -1
+        assert input_count % explicit_count == 0, "[Reshape]bottom count[%d] "\
+            "must be divisible by product of the specified dimensions[%d] "\
+            % (input_count, explicit_count)
+        outshape[start_axis + inferred_axis] = int(input_count / explicit_count)

     output_count = count(outshape)
+    assert output_count == input_count, "[Reshape]output count[%d] must match input count[%d]" % (
+        output_count, input_count)
+    outshape[0] = -1
     return [outshape]
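Rather than leaving the inferred axis as -1, the dimension is now computed outright as input_count / explicit_count, guarded by an assertion that the division is exact; only the batch axis is left as -1 at the end. A worked example:

input_count = 2 * 3 * 8 * 8  # elements in the bottom blob
explicit_count = 3 * 4       # product of the explicitly specified dims
assert input_count % explicit_count == 0
print(int(input_count / explicit_count))  # 32, the inferred dimension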
@@ -342,10 +345,9 @@ def shape_flatten(layer, input_shape):
     output_shape = inshape[0:start_axis]
     if len(inshape[start_axis:end_axis]) != 0:
         flat_sz = reduce(lambda a, b: a * b, inshape[start_axis:end_axis])
-        flat_sz = -1
-        output_shape[0] = 0
         output_shape += [flat_sz]
     output_shape += inshape[end_axis:len(inshape)]
+    output_shape[0] = -1
     return [output_shape]
...
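shape_flatten now keeps the real product of the collapsed axes instead of a -1 placeholder, marking only the batch axis unknown afterwards. Tracing the new code on a sample shape:

from functools import reduce

inshape = [1, 3, 224, 224]   # illustrative input shape
start_axis, end_axis = 1, 4  # flatten C, H, W into one axis
output_shape = inshape[0:start_axis]
flat_sz = reduce(lambda a, b: a * b, inshape[start_axis:end_axis])
output_shape += [flat_sz]
output_shape += inshape[end_axis:len(inshape)]
output_shape[0] = -1
print(output_shape)          # [-1, 150528]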
@@ -32,11 +32,12 @@ default_op_mapping = {
         dict(),
         dict(
             min=(_np.asarray([255, 255, 127, 255],
-                             dtype=_np.uint8).view(_np.float32)),
+                             dtype=_np.uint8).view(_np.float32)[0]),
             max=(_np.asarray([255, 255, 127, 127],
-                             dtype=_np.uint8).view(_np.float32)),
+                             dtype=_np.uint8).view(_np.float32)[0]),
         )
     ],
+    'Erf': ['erf', ['X'], ['Out']],
     'Ceil': ['ceil', ['X'], ['Out']],
     'ReduceMean': [
         'reduce_mean', ['X'], ['Out'],
...
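The uint8 arrays are a bit-pattern trick: viewed as float32 (on a little-endian machine), [255, 255, 127, 255] reinterprets to -FLT_MAX and [255, 255, 127, 127] to +FLT_MAX. Since .view() returns a length-1 array, the added [0] extracts a plain scalar for the Clip defaults:

import numpy as np

lo = np.asarray([255, 255, 127, 255], dtype=np.uint8).view(np.float32)
hi = np.asarray([255, 255, 127, 127], dtype=np.uint8).view(np.float32)
print(lo, hi)        # [-3.4028235e+38] [3.4028235e+38] -- length-1 arrays
print(lo[0], hi[0])  # -3.4028235e+38 3.4028235e+38     -- plain scalars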
@@ -373,7 +373,6 @@ class ONNXOpMapper(OpMapper):
         val_x = self.graph.get_input_node(node, idx=0, copy=True)
         val_scales = self.graph.get_input_node(node, idx=1, copy=True)
         val_y = self.graph.get_node(node.layer.output[0], copy=True)
-
         out_shape = val_y.out_shapes[0]
         if out_shape is not None:
             assert len(out_shape) == 4, 'only 4-D Tensor as X and Y supported'
@@ -383,7 +382,6 @@ class ONNXOpMapper(OpMapper):
         if isinstance(val_scales, ONNXGraphNode):
             scales, _, _ = self.get_dynamic_shape(val_scales.layer_name)
         attr = {'name': string(node.layer_name)}
-
         use_scales = True
         if scales is not None:
@@ -708,8 +706,8 @@ class ONNXOpMapper(OpMapper):
             self.omit_nodes.append(starts.layer_name)
             self.omit_nodes.append(ends.layer_name)
-            starts = _const_weight_or_none(starts)
-            ends = _const_weight_or_none(ends)
+            starts = _const_weight_or_none(starts).copy()
+            ends = _const_weight_or_none(ends).copy()
         else:
             starts = node.get_attr('starts')
             ends = node.get_attr('ends')
...
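_const_weight_or_none presumably hands back the tensor cached for the constant node, so editing the result in place (for example, normalizing negative slice indices) would silently corrupt that cache; .copy() makes later edits safe. The aliasing hazard, demonstrated with numpy:

import numpy as np

cache = {'starts': np.array([0, -1])}  # stands in for the cached constant
starts = cache['starts']               # alias, no copy
starts[starts < 0] += 10               # in-place edit leaks into the cache
print(cache['starts'])                 # [0 9] -- cache corrupted

cache = {'starts': np.array([0, -1])}
starts = cache['starts'].copy()        # defensive copy
starts[starts < 0] += 10
print(cache['starts'])                 # [ 0 -1] -- cache intact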
@@ -85,7 +85,7 @@ class TFOpMapper(OpMapper):
         not_placeholder = list()
         for name in self.graph.input_nodes:
-            if self.graph.get_node(name).layer_type != "Placeholder":
+            if self.graph.get_node(name).layer_type != "Placeholder" and self.graph.get_node(name).layer_type != "OneShotIterator":
                 not_placeholder.append(name)
         for name in not_placeholder:
             idx = self.graph.input_nodes.index(name)
@@ -287,6 +287,9 @@ class TFOpMapper(OpMapper):
                                   output=node,
                                   param_attr=attr)

+    def OneShotIterator(self, node):
+        return self.Placeholder(node)
+
     def Const(self, node):
         shape = node.out_shapes[0]
         dtype = node.dtype
@@ -492,6 +495,9 @@ class TFOpMapper(OpMapper):
                                   output=node,
                                   param_attr=attr)

+    def FusedBatchNormV3(self, node):
+        return self.FusedBatchNorm(node)
+
     def DepthwiseConv2dNative(self, node):
         input = self.graph.get_node(node.layer.input[0], copy=True)
         kernel = self.graph.get_node(node.layer.input[1], copy=True)
@@ -712,7 +718,7 @@ class TFOpMapper(OpMapper):
         if input.tf_data_format == "NHWC":
             if len(input.out_shapes[0]) == 4:
                 expand_times = [expand_times[i] for i in [0, 3, 1, 2]]
-            elif len(input.out_shape[0]) == 3:
+            elif len(input.out_shapes[0]) == 3:
                 expand_times = [expand_times[i] for i in [2, 0, 1]]
         for i in range(len(expand_times)):
             if expand_times[i] < 0:
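Besides fixing the out_shape typo, this hunk is where Tile's NHWC repeat factors get permuted into Paddle's NCHW order. The index shuffle on its own:

expand_times = [2, 3, 4, 5]                     # N, H, W, C repeat factors
nchw = [expand_times[i] for i in [0, 3, 1, 2]]  # reorder to N, C, H, W
print(nchw)                                     # [2, 5, 3, 4]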
@@ -812,7 +818,7 @@ class TFOpMapper(OpMapper):
         node.fluid_code.add_layer("range",
                                   inputs=inputs,
                                   output=node,
-                                  param_attr=None)
+                                  param_attr=attr)

     def Mean(self, node):
         input = self.graph.get_node(node.layer.input[0], copy=True)
...
@@ -744,13 +744,12 @@ class TFOpMapperNHWC(OpMapper):
             "start": start,
             "end": limit,
             "step": delta,
-            "dtype": string(dtype)
         }
         attr = {"dtype": string(node.dtype)}
         node.fluid_code.add_layer("range",
                                   inputs=inputs,
                                   output=node,
-                                  param_attr=None)
+                                  param_attr=attr)

     def Mean(self, node):
         input = self.graph.get_node(node.layer.input[0], copy=True)
...