提交 32289c74 编写于 作者: J jiangjiajun

add resize op

上级 405a2f18
...@@ -17,6 +17,7 @@ from x2paddle.core.op_mapper import OpMapper ...@@ -17,6 +17,7 @@ from x2paddle.core.op_mapper import OpMapper
from x2paddle.core.util import * from x2paddle.core.util import *
import inspect import inspect
import numpy import numpy
import sys
# compute padding size for SAME mode # compute padding size for SAME mode
...@@ -83,18 +84,31 @@ class TFOpMapper(OpMapper): ...@@ -83,18 +84,31 @@ class TFOpMapper(OpMapper):
del self.graph.input_nodes[idx] del self.graph.input_nodes[idx]
print("Total nodes: {}".format(len(self.graph.topo_sort))) print("Total nodes: {}".format(len(self.graph.topo_sort)))
unsupported_ops = set()
for node_name in self.graph.topo_sort: for node_name in self.graph.topo_sort:
node = self.graph.get_node(node_name) node = self.graph.get_node(node_name)
op = node.layer_type op = node.layer_type
if op in self.directly_map_ops: if op in self.directly_map_ops:
if len(unsupported_ops) > 0:
continue
self.directly_map(node) self.directly_map(node)
elif op in self.elementwise_ops: elif op in self.elementwise_ops:
if len(unsupported_ops) > 0:
continue
self.elementwise_map(node) self.elementwise_map(node)
elif hasattr(self, op): elif hasattr(self, op):
if len(unsupported_ops) > 0:
continue
func = getattr(self, op) func = getattr(self, op)
func(node) func(node)
else: else:
raise Exception("OP: [{}] not support yet".format(op)) unsupported_ops.add(op)
if len(unsupported_ops) > 0:
print("=========={} Ops are not supported yet======".format(
len(unsupported_ops)))
for op in unsupported_ops:
print("========== {} ==========".format(op))
sys.exit(-1)
def directly_map(self, node): def directly_map(self, node):
assert node.layer_type in self.directly_map_ops assert node.layer_type in self.directly_map_ops
...@@ -773,7 +787,15 @@ class TFOpMapper(OpMapper): ...@@ -773,7 +787,15 @@ class TFOpMapper(OpMapper):
begin = [begin[i] for i in [0, 3, 1, 2]] begin = [begin[i] for i in [0, 3, 1, 2]]
end = [end[i] for i in [0, 3, 1, 2]] end = [end[i] for i in [0, 3, 1, 2]]
attr = {"axes": range(len(strides)), "starts": begin, "ends": end} for i in range(len(end)):
if end[i] == 0:
end[i] = 999999
attr = {
"axes": [i for i in range(len(strides))],
"starts": begin,
"ends": end
}
node.fluid_code.add_layer("slice", node.fluid_code.add_layer("slice",
inputs=input, inputs=input,
output=node, output=node,
...@@ -955,3 +977,37 @@ class TFOpMapper(OpMapper): ...@@ -955,3 +977,37 @@ class TFOpMapper(OpMapper):
inputs=input, inputs=input,
output=node, output=node,
param_attr=attr) param_attr=attr)
def ResizeNearestNeighbor(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True)
resize_shape = self.graph.get_node(node.layer.input[1], copy=True)
self.omit_nodes.append(resize_shape.layer_name)
if resize_shape.layer_type == "Const":
resize_shape = resize_shape.value.tolist()
else:
resize_shape = self.decoder.infer_shape_tensor(resize_shape)
align_corners = node.get_attr("align_corners")
attr = {"align_corners": align_corners, "out_shape": resize_shape}
node.fluid_code.add_layer("resize_nearest",
inputs=input,
output=node,
param_attr=attr)
def ResizeBilinear(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True)
resize_shape = self.graph.get_node(node.layer.input[1], copy=True)
self.omit_nodes.append(resize_shape.layer_name)
if resize_shape.layer_type == "Const":
resize_shape = resize_shape.value.tolist()
else:
resize_shape = self.decoder.infer_shape_tensor(resize_shape)
align_corners = node.get_attr("align_corners")
attr = {
"align_corners": align_corners,
"out_shape": resize_shape,
"align_mode": 1
}
node.fluid_code.add_layer("resize_bilinear",
inputs=input,
output=node,
param_attr=attr)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册