Commit 75187123 authored by SunAhong1993

fix

Parent e3d2bca8
......@@ -191,7 +191,7 @@ class TFGraph(Graph):
return node
def get_input_node(self, node, idx=0, copy=False):
input_node_name = node.inputs[idx]
input_node_name = node.layer.input[idx]
return self.get_node(input_node_name, copy)
def remove_node(self, node_name):
......@@ -488,3 +488,96 @@ class TFDecoder(object):
return results[0].tolist()
else:
raise Exception("Couldn't infer a stable shape shape tensor value")
# def infer_tensor(self, graph_node):
# if hasattr(graph_node, "index"):
# tensor_name = graph_node.layer.name + ":{}".format(graph_node.index)
# else:
# tensor_name = graph_node.layer.name + ":0"
# feed = dict()
# for input_name, info in self.inputs_info.items():
# (shape, dtype) = cp.deepcopy(info)
# input_tensor = self.sess.graph.get_tensor_by_name(input_name + ":0")
# if shape.count(-1) > 0:
# shape[shape.index(-1)] = 2
# feed[input_tensor] = numpy.random.random_sample(shape)
# output_tensor = self.sess.graph.get_tensor_by_name(tensor_name)
# return self.sess.run([output_tensor], feed)[0]
# def infer_shape_tensor(self, graph_node, out_shape=None):
# if hasattr(graph_node, "index"):
# tensor_name = graph_node.layer.name + ":{}".format(graph_node.index)
# else:
# tensor_name = graph_node.layer.name + ":0"
# feed = dict()
# batch_size = [2, 3, 5]
# results = list()
# for b in batch_size:
# for input_name, info in self.inputs_info.items():
# (shape, dtype) = cp.deepcopy(info)
# input_tensor = self.sess.graph.get_tensor_by_name(input_name +
# ":0")
# if shape.count(-1) > 0:
# shape[shape.index(-1)] = b
# feed[input_tensor] = numpy.random.random_sample(shape)
# output_tensor = self.sess.graph.get_tensor_by_name(tensor_name)
# results.append(self.sess.run([output_tensor], feed)[0].flatten())
# compare01 = (results[0] == results[1])
# compare12 = (results[1] == results[2])
# if compare01.all() and compare12.all():
# return results[0].tolist()
# if (compare01 == compare12).all():
# index = numpy.argwhere(compare01 == False).flatten()
# if index.shape[0] != 1:
# raise Exception("There's not only one unstable dimension")
# results[0][index[0]] = -1
# index = numpy.argwhere(results[0] < 0).flatten()
# if index.shape[0] > 2:
# print("Warning: More than two dimension less than zero")
# if index.shape[0] == 2 and out_shape is not None:
# if out_shape[index[1]] > 0:
# results[0][index[1]] = out_shape[index[1]]
# else:
# results[0][index[0]] = out_shape[index[0]]
# return results[0].tolist()
# else:
# raise Exception("Couldn't infer a stable shape shape tensor value")
# def infer_tensor_shape(self, graph_node):
# if hasattr(graph_node, "index"):
# tensor_name = graph_node.layer.name + ":{}".format(graph_node.index)
# else:
# tensor_name = graph_node.layer.name + ":0"
# feed = dict()
# batch_size = [2, 3, 5]
# shapes = list()
# for b in batch_size:
# for input_name, info in self.inputs_info.items():
# (shape, dtype) = cp.deepcopy(info)
# input_tensor = self.sess.graph.get_tensor_by_name(input_name +
# ":0")
# if shape.count(-1) > 0:
# shape[shape.index(-1)] = b
# feed[input_tensor] = numpy.random.random_sample(shape)
# output_tensor = self.sess.graph.get_tensor_by_name(tensor_name)
# shape = self.sess.run([output_tensor], feed)[0].shape
# shapes.append(numpy.array(shape))
# compare01 = (shapes[0] == shapes[1])
# compare12 = (shapes[1] == shapes[2])
# if compare01.all() and compare12.all():
# return shape[0].tolist()
# if (compare01 == compare12).all():
# index = numpy.argwhere(compare01 == False).flatten()
# if index.shape[0] != 1:
# raise Exception("There's not only one unstable dimension")
# if index[0] != 0:
# raise Exception("Batch size not in the first dimension")
# shapes[0][0] = -1
# return shapes[0].tolist()
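For reference, the batch-probing idea behind the commented-out helpers above can be expressed as a short standalone sketch: evaluate the target tensor with a few different batch sizes and mark every output dimension that varies as -1. The function name probe_shape, the sess/inputs_info arguments, and the probe batch sizes are illustrative only, not part of this commit.

import numpy as np

def probe_shape(sess, tensor_name, inputs_info, batch_sizes=(2, 3, 5)):
    # Evaluate one tensor under several batch sizes and flag unstable dims.
    shapes = []
    for b in batch_sizes:
        feed = {}
        for input_name, (shape, dtype) in inputs_info.items():
            shape = list(shape)
            if -1 in shape:
                shape[shape.index(-1)] = b  # substitute the unknown batch dim
            tensor = sess.graph.get_tensor_by_name(input_name + ":0")
            feed[tensor] = np.random.random_sample(shape)
        output = sess.graph.get_tensor_by_name(tensor_name)
        shapes.append(np.array(sess.run(output, feed).shape))
    stable = (shapes[0] == shapes[1]) & (shapes[1] == shapes[2])
    result = shapes[0].copy()
    result[~stable] = -1  # dimensions that change with batch size are dynamic
    return result.tolist()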
......@@ -953,7 +953,6 @@ class OpSet9():
@print_mapping_info
def Split(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
paddle_op = 'split'
split = node.get_attr('split')
axis = node.get_attr('axis', 0)
......@@ -963,11 +962,10 @@ class OpSet9():
}
outputs_list = list()
if isinstance(split, list) or isinstance(split, tuple):
for i, s in enumerate(split):
outputs_list.append("{}_p{}".format(node.name, i))
for i in range(len(split)):
outputs_list.append("{}_p{}".format(node.layer_name, i))
else:
outputs_list.append(node.name)
self.paddle_graph.add_layer(
'paddle.split',
inputs={"x": val_x.name},
......
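As context for the output naming in the Split mapping above, paddle.split returns one tensor per section, so the mapper needs a distinct output name for every section. A minimal illustrative call (the input tensor and section sizes are hypothetical):

import paddle

x = paddle.rand([4, 9])
# Three sections along axis 1, mirroring a Split node with split=[2, 3, 4].
parts = paddle.split(x, num_or_sections=[2, 3, 4], axis=1)
print([tuple(p.shape) for p in parts])  # [(4, 2), (4, 3), (4, 4)]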
......@@ -153,7 +153,7 @@ class TFOpMapper(OpMapper):
layer_attrs = dict()
if len(op_info) > 1:
attrs_name_map_dict = op_info[1]
for tf_attr_name, pd_attr_name in attrs_name_map_dict:
for tf_attr_name, pd_attr_name in attrs_name_map_dict.items():
layer_attrs[pd_attr_name] = node.get_attr(tf_attr_name)
if paddle_op.startswith("paddle.nn"):
op_name = paddle_op[10:].lower()
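The switch to attrs_name_map_dict.items() above matters because iterating a dict directly yields only its keys, so the two-variable unpacking would split each key string instead of pairing TF and Paddle attribute names. A minimal illustration with a hypothetical attribute map:

# Hypothetical TF-to-Paddle attribute name map.
attrs_name_map_dict = {"ksize": "kernel_size", "strides": "stride"}

# .items() yields (key, value) pairs; iterating the dict itself would yield
# only the keys, and the two-name unpacking would break.
for tf_attr_name, pd_attr_name in attrs_name_map_dict.items():
    print(tf_attr_name, "->", pd_attr_name)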
......@@ -767,11 +767,12 @@ class TFOpMapper(OpMapper):
inputs_list = list()
for i in range(len(node.inputs) - 1):
inputs_list.append(self.graph.get_input_node(node, i))
# inputs_list = [self.graph.get_node(name) for name in node.layer.input[:-1]]
axis = self.graph.get_input_node(node, -1)
assert axis.layer_type == "Const", "axis for ConcatV2 must be type Const"
axis = axis.value
if axis < 0:
axis += len(inputs[0].out_shapes[0])
axis += len(inputs_list[0].out_shapes[0])
input_names = [i.name for i in inputs_list]
self.paddle_graph.add_layer(
......@@ -846,6 +847,13 @@ class TFOpMapper(OpMapper):
new_end.append(999999)
else:
new_end.append(end[i])
if input.dtype == "bool":
self.paddle_graph.add_layer(
"paddle.cast",
inputs={"x": input.name},
outputs=[input.name],
dtype=string("int32"))
self.paddle_graph.add_layer(
kernel="paddle.slice",
......@@ -854,6 +862,14 @@ class TFOpMapper(OpMapper):
axes=[i for i in range(len(new_begin))],
starts=new_begin,
ends=new_end)
if input.dtype == "bool":
self.paddle_graph.add_layer(
"paddle.cast",
inputs={"x": node.name},
outputs=[node.name],
dtype=string("bool"))
if len(new_axes) > 0:
self.paddle_graph.add_layer(
kernel="paddle.unsqueeze",
......@@ -905,7 +921,7 @@ class TFOpMapper(OpMapper):
# outputs=[reshape_name],
# shape=shape)
# inputs['offsets'] = reshape_name
begin = self.decoder.infer_tensor(begin).tolist()
begin = self.decoder.infer_tensor(begin, use_diff_inputs=False).to_list()
attrs['offsets'] = begin
if size.layer_type == "Const":
size = size.value.tolist()
......
......@@ -897,7 +897,7 @@ class TFOpMapper(OpMapper):
# outputs=[reshape_name],
# shape=shape)
# inputs['offsets'] = reshape_name
begin = self.decoder.infer_tensor(begin).tolist()
begin = self.decoder.infer_tensor(begin, use_diff_inputs=False).to_list()
attrs['offsets'] = begin
if size.layer_type == "Const":
size = size.value.tolist()
......