提交 4ed14de2 编写于 作者: S SunAhong1993

add tf SplitV

上级 a7496317
......@@ -1066,6 +1066,24 @@ class TFOpMapper(OpMapper):
num_or_sections=num_split,
axis=dim)
def SplitV(self, node):
input = self.graph.get_input_node(node, 0)
size_splits = self.graph.get_input_node(node, 1)
assert size_splits.layer_type == "Const", "size_splits of SplitV OP should be Const"
size_splits = size_splits.value.tolist()
dim = self.graph.get_input_node(node, 2)
assert dim.layer_type == "Const", "dim of SplitV OP should be Const"
dim = dim.value
self.paddle_graph.add_layer(
kernel="paddle.split",
inputs={"x": input.name},
outputs=[
"{}_p{}".format(node.layer_name, i) for i in range(len(size_splits))
],
num_or_sections=size_splits,
axis=dim)
def Slice(self, node):
input = self.graph.get_input_node(node, 0)
begin = self.graph.get_input_node(node, 1)
......
......@@ -1043,6 +1043,24 @@ class TFOpMapper(OpMapper):
num_or_sections=num_split,
axis=dim)
def SplitV(self, node):
input = self.graph.get_input_node(node, 0)
size_splits = self.graph.get_input_node(node, 1)
assert size_splits.layer_type == "Const", "size_splits of SplitV OP should be Const"
size_splits = size_splits.value.tolist()
dim = self.graph.get_input_node(node, 2)
assert dim.layer_type == "Const", "dim of SplitV OP should be Const"
dim = dim.value
self.paddle_graph.add_layer(
kernel="paddle.split",
inputs={"x": input.name},
outputs=[
"{}_p{}".format(node.layer_name, i) for i in range(len(size_splits))
],
num_or_sections=size_splits,
axis=dim)
def Slice(self, node):
input = self.graph.get_input_node(node, 0)
begin = self.graph.get_input_node(node, 1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册