提交 6469514a 编写于 作者: M mamingjie-China

update

上级 a9bb2791
...@@ -197,8 +197,8 @@ class TFOpMapperNHWC(OpMapper): ...@@ -197,8 +197,8 @@ class TFOpMapperNHWC(OpMapper):
perm=perm) perm=perm)
def Fill(self, node): def Fill(self, node):
dims = self.graph.get_node(node.layer.input[0], copy=True) dims = self.graph.get_node(node.layer.input[0])
input_value = self.graph.get_node(node.layer.input[1], copy=True) input_value = self.graph.get_node(node.layer.input[1])
assert input_value.layer_type == "Const", "Value of fill OP should be Const" assert input_value.layer_type == "Const", "Value of fill OP should be Const"
input_value = input_value.value input_value = input_value.value
...@@ -212,7 +212,7 @@ class TFOpMapperNHWC(OpMapper): ...@@ -212,7 +212,7 @@ class TFOpMapperNHWC(OpMapper):
value=input_value) value=input_value)
def DepthToSpace(self, node): def DepthToSpace(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True) input = self.graph.get_node(node.layer.input[0])
block_size = node.get_attr("block_size") block_size = node.get_attr("block_size")
data_format = node.get_attr("data_format").decode() data_format = node.get_attr("data_format").decode()
...@@ -264,7 +264,7 @@ class TFOpMapperNHWC(OpMapper): ...@@ -264,7 +264,7 @@ class TFOpMapperNHWC(OpMapper):
perm=[0, 2, 3, 1]) perm=[0, 2, 3, 1])
def MaxPool(self, node): def MaxPool(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True) input = self.graph.get_node(node.layer.input[0])
k_size = node.get_attr("ksize") k_size = node.get_attr("ksize")
strides = node.get_attr("strides") strides = node.get_attr("strides")
...@@ -592,7 +592,7 @@ class TFOpMapperNHWC(OpMapper): ...@@ -592,7 +592,7 @@ class TFOpMapperNHWC(OpMapper):
perm=[0, 2, 3, 1]) perm=[0, 2, 3, 1])
def AvgPool(self, node): def AvgPool(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True) input = self.graph.get_node(node.layer.input[0])
k_size = node.get_attr("ksize") k_size = node.get_attr("ksize")
strides = node.get_attr("strides") strides = node.get_attr("strides")
...@@ -904,7 +904,7 @@ class TFOpMapperNHWC(OpMapper): ...@@ -904,7 +904,7 @@ class TFOpMapperNHWC(OpMapper):
keep_dim=keep_dims) keep_dim=keep_dims)
def RandomUniform(self, node): def RandomUniform(self, node):
shape = self.graph.get_node(node.layer.input[0], copy=True) shape = self.graph.get_node(node.layer.input[0])
if shape.layer_type == "Const": if shape.layer_type == "Const":
shape = shape.value.tolist() shape = shape.value.tolist()
program.add_layer( program.add_layer(
...@@ -984,8 +984,8 @@ class TFOpMapperNHWC(OpMapper): ...@@ -984,8 +984,8 @@ class TFOpMapperNHWC(OpMapper):
perm=[0, 2, 3, 1]) perm=[0, 2, 3, 1])
def Tile(self, node): def Tile(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True) input = self.graph.get_node(node.layer.input[0])
expand_times = self.graph.get_node(node.layer.input[1], copy=True) expand_times = self.graph.get_node(node.layer.input[1])
inputs = {"x": input.name} inputs = {"x": input.name}
attr = dict() attr = dict()
if expand_times.layer_type == "Const": if expand_times.layer_type == "Const":
...@@ -1001,9 +1001,9 @@ class TFOpMapperNHWC(OpMapper): ...@@ -1001,9 +1001,9 @@ class TFOpMapperNHWC(OpMapper):
**attr) **attr)
def Range(self, node): def Range(self, node):
start = self.graph.get_node(node.layer.input[0], copy=True) start = self.graph.get_node(node.layer.input[0])
limit = self.graph.get_node(node.layer.input[1], copy=True) limit = self.graph.get_node(node.layer.input[1])
delta = self.graph.get_node(node.layer.input[2], copy=True) delta = self.graph.get_node(node.layer.input[2])
inputs = dict() inputs = dict()
attr = dict() attr = dict()
...@@ -1028,8 +1028,8 @@ class TFOpMapperNHWC(OpMapper): ...@@ -1028,8 +1028,8 @@ class TFOpMapperNHWC(OpMapper):
**attr) **attr)
def SquaredDifference(self, node): def SquaredDifference(self, node):
x = self.graph.get_node(node.layer.input[0], copy=True) x = self.graph.get_node(node.layer.input[0])
y = self.graph.get_node(node.layer.input[1], copy=True) y = self.graph.get_node(node.layer.input[1])
inputs = {"x": x.name, "y": y.name} inputs = {"x": x.name, "y": y.name}
program.add_layer( program.add_layer(
"fluid.layers.elementwise_sub", inputs=inputs, outputs=[node.name]) "fluid.layers.elementwise_sub", inputs=inputs, outputs=[node.name])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册