提交 5f55fea8 编写于 作者: J jiangjiajun

add support for more ops

上级 90d2eb65
......@@ -176,7 +176,7 @@ class TFGraph(Graph):
def _remove_identity_node(self):
identity_node = list()
for node_name, node in self.node_map.items():
if node.layer_type == "Identity":
if node.layer_type == "Identity" or node.layer_type == "StopGradient":
identity_node.append(node_name)
for node_name in identity_node:
......@@ -374,3 +374,38 @@ class TFDecoder(object):
return results[0].tolist()
else:
raise Exception("Couldn't infer a stable shape shape tensor value")
def infer_tensor_shape(self, graph_node):
if hasattr(graph_node, "index"):
tensor_name = graph_node.layer.name + ":{}".format(graph_node.index)
else:
tensor_name = graph_node.layer.name + ":0"
feed = dict()
batch_size = [2, 3, 5]
shapes = list()
for b in batch_size:
for input_name, info in self.input_info.items():
(shape, dtype) = cp.deepcopy(info)
input_tensor = self.sess.graph.get_tensor_by_name(input_name +
":0")
if shape.count(-1) > 0:
shape[shape.index(-1)] = b
feed[input_tensor] = numpy.random.random_sample(shape)
output_tensor = self.sess.graph.get_tensor_by_name(tensor_name)
shape = self.sess.run([output_tensor], feed)[0].shape
shapes.append(numpy.array(shape))
compare01 = (shapes[0] == shapes[1])
compare12 = (shapes[1] == shapes[2])
if compare01.all() and compare12.all():
return shape[0].tolist()
if (compare01 == compare12).all():
index = numpy.argwhere(compare01 == False).flatten()
if index.shape[0] != 1:
raise Exception("There's not only one unstable dimension")
if index[0] != 0:
raise Exception("Batch size not in the first dimension")
shapes[0][0] = -1
return shapes[0].tolist()
......@@ -58,6 +58,7 @@ class TFOpMapper(OpMapper):
'Exp': ['exp'],
'Rsqrt': ['rsqrt'],
'swish_f32': ['swish'],
'Tanh': ['tanh'],
'LeakyRelu': ['leaky_relu', {
'alpha': 'alpha'
}]
......@@ -188,6 +189,10 @@ class TFOpMapper(OpMapper):
if y_shape[index] != x_shape[index]:
is_sub_seq = False
if not is_sub_seq:
if x_shape.count(-1) > 2:
x_shape = self.decoder.infer_tensor_shape(x_input)
if y_shape.count(-1) > 2:
y_shape = self.decoder.infer_tensor_shape(y_input)
x_expand_times = [1] * len(x_shape)
y_expand_times = [1] * len(y_shape)
x_need_expand = False
......@@ -913,6 +918,12 @@ class TFOpMapper(OpMapper):
self.add_omit_nodes(kernel.layer_name, node.layer_name)
self.add_omit_nodes(out_shape.layer_name, node.layer_name)
if out_shape.layer_type == "Const":
out_shape = out_shape.value.tolist()
else:
out_shape = self.decoder.infer_shape_tensor(out_shape,
node.out_shapes[0])
in_shape = input.out_shapes[0]
if in_shape.count(-1) > 2:
in_shape = self.decoder.infer_tensor(input).shape
......@@ -920,7 +931,7 @@ class TFOpMapper(OpMapper):
if k_size.count(-1) > 2:
k_size = self.decoder.infer_tensor(kernel).shape
pad_mode = node.get_attr("padding")
pad_mode = node.get_attr("padding").decode()
strides = node.get_attr("strides")
dilations = node.get_attr("dilations")
data_format = node.get_attr("data_format").decode()
......@@ -963,6 +974,22 @@ class TFOpMapper(OpMapper):
output=node,
param_attr=attr)
if pad_mode == "SAME":
if node.tf_data_format == "NHWC":
out_shape = [out_shape[i] for i in [0, 3, 1, 2]]
for i in range(4):
if out_shape[i] < 0:
out_shape[i] = 999999
attr = {
"axes": [0, 1, 2, 3],
"starts": [0, 0, 0, 0],
"ends": out_shape
}
node.fluid_code.add_layer("slice",
inputs=node,
output=node,
param_attr=attr)
def Max(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True)
reduce_idx = self.graph.get_node(node.layer.input[1], copy=True)
......@@ -1142,3 +1169,17 @@ class TFOpMapper(OpMapper):
inputs=None,
output=node,
param_attr=attr)
def SquaredDifference(self, node):
x = self.graph.get_node(node.layer.input[0], copy=True)
y = self.graph.get_node(node.layer.input[1], copy=True)
inputs = {"x": x, "y": y}
node.fluid_code.add_layer("elementwise_sub",
inputs=inputs,
output=node,
param_attr=None)
inputs = {"x": node, "y": node}
node.fluid_code.add_layer("elementwise_mul",
inputs=inputs,
output=node,
param_attr=None)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册