未验证 提交 b6032933 编写于 作者: J Jason 提交者: GitHub

Merge pull request #166 from PaddlePaddle/develop

pull
...@@ -312,6 +312,10 @@ class TFDecoder(object): ...@@ -312,6 +312,10 @@ class TFDecoder(object):
right_shape_been_input = False right_shape_been_input = False
while not right_shape_been_input: while not right_shape_been_input:
try:
shape = raw_input(
"Shape of Input(e.g. None,224,224,3): ")
except:
shape = input("Shape of Input(e.g. None,224,224,3): ") shape = input("Shape of Input(e.g. None,224,224,3): ")
if shape.count("None") > 1: if shape.count("None") > 1:
print("Only 1 dimension can be None, type again:)") print("Only 1 dimension can be None, type again:)")
......
...@@ -1010,7 +1010,7 @@ class TFOpMapper(OpMapper): ...@@ -1010,7 +1010,7 @@ class TFOpMapper(OpMapper):
attr = { attr = {
"bias_attr": False, "bias_attr": False,
"param_attr": string(kernel.layer_name), "param_attr": string(kernel.layer_name),
"num_filters": k_size[3], "num_filters": k_size[2],
"filter_size": k_size[0:2], "filter_size": k_size[0:2],
"stride": strides[2:4], "stride": strides[2:4],
"dilation": dilations[2:4], "dilation": dilations[2:4],
......
...@@ -1007,7 +1007,7 @@ class TFOpMapperNHWC(OpMapper): ...@@ -1007,7 +1007,7 @@ class TFOpMapperNHWC(OpMapper):
attr = { attr = {
"bias_attr": False, "bias_attr": False,
"param_attr": string(kernel.layer_name), "param_attr": string(kernel.layer_name),
"num_filters": k_size[3], "num_filters": k_size[2],
"filter_size": k_size[0:2], "filter_size": k_size[0:2],
"stride": strides[2:4], "stride": strides[2:4],
"dilation": dilations[2:4], "dilation": dilations[2:4],
......
...@@ -517,7 +517,7 @@ class TFOptimizer(object): ...@@ -517,7 +517,7 @@ class TFOptimizer(object):
l.op = 'transpose' l.op = 'transpose'
l.inputs = true_node.fluid_code.layers[3].output l.inputs = true_node.fluid_code.layers[3].output
l.param_attr = {'perm': [0, 3, 1, 2]} l.param_attr = {'perm': [0, 3, 1, 2]}
if type(l.inputs) == str: if isinstance(l.inputs, six.string_types):
l.output = l.inputs l.output = l.inputs
else: else:
l.output = l.inputs.layer_name l.output = l.inputs.layer_name
...@@ -550,7 +550,7 @@ class TFOptimizer(object): ...@@ -550,7 +550,7 @@ class TFOptimizer(object):
node = self.graph.get_node(name) node = self.graph.get_node(name)
if len(node.out_shapes[0]) == 4 and node.tf_data_format == "NHWC": if len(node.out_shapes[0]) == 4 and node.tf_data_format == "NHWC":
shape = node.fluid_code.layers[0].param_attr["shape"] shape = node.fluid_code.layers[0].param_attr["shape"]
shape = [shape[i] for i in [0, 3, 1, 2]] shape = [shape[j] for j in [0, 3, 1, 2]]
node.fluid_code.layers[0].param_attr["shape"] = shape node.fluid_code.layers[0].param_attr["shape"] = shape
node.fluid_code.layers[0].output = "nhwc_" + name node.fluid_code.layers[0].output = "nhwc_" + name
attr = {"perm": [0, 2, 3, 1]} attr = {"perm": [0, 2, 3, 1]}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册