提交 c4714f3f 编写于 作者: M mamingjie-China

add Tile and Range

上级 e5f64f3e
......@@ -308,6 +308,10 @@ class TFOpMapperNHWC(OpMapper):
dilations = node.get_attr("dilations")
data_format = node.get_attr("data_format").decode()
pad_mode = node.get_attr("padding").decode()
if data_format == "NHWC":
n, c, h, w = input.out_shapes[0]
else:
n, h, w, c = input.out_shapes[0]
if kernel.layer_type == 'Const':
kernel_value = kernel.value
......@@ -334,6 +338,16 @@ class TFOpMapperNHWC(OpMapper):
perm=[0, 3, 1, 2])
input_name = transpose_name
if c == -1:
attr = {"shape": [0, k_size[2], 0, 0]}
node.fluid_code.add_layer(
"reshape", inputs=input, output=input, param_attr=attr)
program.add_layer(
kernel="fluid.layers.reshape",
inputs={"x": input_name},
outputs=[input_name],
shape=[0, k_size[2], 0, 0])
program.add_layer(
kernel="fluid.layers.conv2d",
inputs={"input": input_name},
......@@ -701,7 +715,7 @@ class TFOpMapperNHWC(OpMapper):
if len(new_axes) > 0:
program.add_layer(
kernel="fluid.layers.unsqueeze",
inputs={"x": node.name},
inputs={"input": node.name},
outputs=[node.name],
axes=new_axes)
if len(shrink_axes) > 0:
......@@ -710,7 +724,7 @@ class TFOpMapperNHWC(OpMapper):
else:
program.add_layer(
kernel="fluid.layers.unsqueeze",
inputs={"x": node.name},
inputs={"input": node.name},
outputs=[node.name],
axes=new_axes)
......@@ -741,14 +755,16 @@ class TFOpMapperNHWC(OpMapper):
begin = begin.value.tolist()
attrs['offsets'] = begin
else:
shape = begin.out_shapes[0]
reshape_name = gen_name("slice", "reshape")
program.add_layer(
kernel="fluid.layers.reshape",
inputs={"x": begin.name},
outputs=[reshape_name],
shape=shape)
inputs['offsets'] = reshape_name
# shape = begin.out_shapes[0]
# reshape_name = gen_name("slice", "reshape")
# program.add_layer(
# kernel="fluid.layers.reshape",
# inputs={"x": begin.name},
# outputs=[reshape_name],
# shape=shape)
# inputs['offsets'] = reshape_name
begin = self.decoder.infer_tensor(begin).tolist()
attrs['offsets'] = begin
if size.layer_type == "Const":
size = size.value.tolist()
attrs['shape'] = size
......@@ -966,3 +982,47 @@ class TFOpMapperNHWC(OpMapper):
inputs={"x": node.name},
outputs=[node.name],
perm=[0, 2, 3, 1])
def Tile(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True)
expand_times = self.graph.get_node(node.layer.input[1], copy=True)
inputs = {"x": input.name}
attr = dict()
if expand_times.layer_type == "Const":
expand_times = expand_times.value.tolist()
attr["expand_times"] = expand_times
else:
inputs["expand_times"] = expand_times.name
program.add_layer(
kernel="fluid.layers.expand",
inputs=inputs,
outputs=[node.name],
**attr)
def Range(self, node):
start = self.graph.get_node(node.layer.input[0], copy=True)
limit = self.graph.get_node(node.layer.input[1], copy=True)
delta = self.graph.get_node(node.layer.input[2], copy=True)
inputs = dict()
attr = dict()
if start.layer_type == "Const":
attr["start"] = start.value
else:
inputs["start"] = start.name
if limit.layer_type == "Const":
attr["end"] = limit.value
else:
inputs["end"] = limit.name
if delta.layer_type == "Const":
attr["step"] = delta.value
else:
inputs["step"] = delta.name
attr["dtype"] = string(node.dtype)
program.add_layer(
kernel="fluid.layers.range",
inputs=inputs,
outputs=[node.name],
**attr)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册