“fc374821ddb9d40daaaf443c3d78ac2d3643ce03”上不存在“paddle/fluid/lite/api/paddle_lite_factory_helper.h”
提交 aa98edf9 编写于 作者: M mamingjie-China

update

上级 6469514a
...@@ -199,17 +199,21 @@ class TFOpMapperNHWC(OpMapper): ...@@ -199,17 +199,21 @@ class TFOpMapperNHWC(OpMapper):
def Fill(self, node): def Fill(self, node):
dims = self.graph.get_node(node.layer.input[0]) dims = self.graph.get_node(node.layer.input[0])
input_value = self.graph.get_node(node.layer.input[1]) input_value = self.graph.get_node(node.layer.input[1])
inputs = dict()
attr = dict()
assert input_value.layer_type == "Const", "Value of fill OP should be Const" assert input_value.layer_type == "Const", "Value of fill OP should be Const"
if dims.layer_type == "Const":
attr["shape"] = dims.value.tolist()
else:
inputs["shape"] = dims.name
attr["dtype"] = string(input_value.dtype)
attr["value"] = input_value.value
input_value = input_value.value
input_dtype = string(input_value.dtype)
program.add_layer( program.add_layer(
"fluid.layers.fill_constant", "fluid.layers.fill_constant",
inputs={}, inputs=inputs,
outputs=[node.name], outputs=[node.name],
shape=dims, **attr)
dtype=string(input_dtype),
value=input_value)
def DepthToSpace(self, node): def DepthToSpace(self, node):
input = self.graph.get_node(node.layer.input[0]) input = self.graph.get_node(node.layer.input[0])
...@@ -251,8 +255,8 @@ class TFOpMapperNHWC(OpMapper): ...@@ -251,8 +255,8 @@ class TFOpMapperNHWC(OpMapper):
shape=[0, c, h, w]) shape=[0, c, h, w])
program.add_layer( program.add_layer(
kernel="fluid.layers.pixed_shuffle", kernel="fluid.layers.pixel_shuffle",
inputs={"input": reshape_name}, inputs={"x": reshape_name},
outputs=[node.name], outputs=[node.name],
upscale_factor=block_size) upscale_factor=block_size)
...@@ -309,9 +313,9 @@ class TFOpMapperNHWC(OpMapper): ...@@ -309,9 +313,9 @@ class TFOpMapperNHWC(OpMapper):
data_format = node.get_attr("data_format").decode() data_format = node.get_attr("data_format").decode()
pad_mode = node.get_attr("padding").decode() pad_mode = node.get_attr("padding").decode()
if data_format == "NHWC": if data_format == "NHWC":
n, c, h, w = input.out_shapes[0]
else:
n, h, w, c = input.out_shapes[0] n, h, w, c = input.out_shapes[0]
else:
n, c, h, w = input.out_shapes[0]
if kernel.layer_type == 'Const': if kernel.layer_type == 'Const':
kernel_value = kernel.value kernel_value = kernel.value
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册