提交 5b1e074e 编写于 作者: M mamingjie-China

add tanh and squareddifference in tensorflow_nhwc

上级 20a14d29
......@@ -41,6 +41,7 @@ class TFOpMapperNHWC(OpMapper):
'Exp': ['exp'],
'Rsqrt': ['rsqrt'],
'swish_f32': ['swish'],
'Tanh': ['tanh'],
'LeakyRelu': ['leaky_relu', {
'alpha': 'alpha'
}]
......@@ -913,6 +914,12 @@ class TFOpMapperNHWC(OpMapper):
self.add_omit_nodes(kernel.layer_name, node.layer_name)
self.add_omit_nodes(out_shape.layer_name, node.layer_name)
if out_shape.layer_type == "Const":
out_shape = out_shape.value.tolist()
else:
out_shape = self.decoder.infer_shape_tensor(out_shape,
node.out_shapes[0])
in_shape = input.out_shapes[0]
if in_shape.count(-1) > 2:
in_shape = self.decoder.infer_tensor(input).shape
......@@ -920,7 +927,7 @@ class TFOpMapperNHWC(OpMapper):
if k_size.count(-1) > 2:
k_size = self.decoder.infer_tensor(kernel).shape
pad_mode = node.get_attr("padding")
pad_mode = node.get_attr("padding").decode()
strides = node.get_attr("strides")
dilations = node.get_attr("dilations")
data_format = node.get_attr("data_format").decode()
......@@ -969,6 +976,24 @@ class TFOpMapperNHWC(OpMapper):
output=node,
param_attr=attr)
if pad_mode == "SAME":
if node.tf_data_format == "NHWC":
print(out_shape)
out_shape = [out_shape[i] for i in [0, 3, 1, 2]]
print(out_shape)
for i in range(4):
if out_shape[i] < 0:
out_shape[i] = 999999
attr = {
"axes": [0, 1, 2, 3],
"starts": [0, 0, 0, 0],
"ends": out_shape
}
node.fluid_code.add_layer("slice",
inputs=node,
output=node,
param_attr=attr)
if not channel_first:
attr = {"perm": [0, 2, 3, 1]}
node.fluid_code.add_layer("transpose",
......@@ -1127,3 +1152,17 @@ class TFOpMapperNHWC(OpMapper):
inputs=None,
output=node,
param_attr=attr)
def SquaredDifference(self, node):
x = self.graph.get_node(node.layer.input[0], copy=True)
y = self.graph.get_node(node.layer.input[1], copy=True)
inputs = {"x": x, "y": y}
node.fluid_code.add_layer("elementwise_sub",
inputs=inputs,
output=node,
param_attr=None)
inputs = {"x": node, "y": node}
node.fluid_code.add_layer("elementwise_mul",
inputs=inputs,
output=node,
param_attr=None)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册