提交 6469514a 编写于 作者: M mamingjie-China

update

上级 a9bb2791
......@@ -197,8 +197,8 @@ class TFOpMapperNHWC(OpMapper):
perm=perm)
def Fill(self, node):
dims = self.graph.get_node(node.layer.input[0], copy=True)
input_value = self.graph.get_node(node.layer.input[1], copy=True)
dims = self.graph.get_node(node.layer.input[0])
input_value = self.graph.get_node(node.layer.input[1])
assert input_value.layer_type == "Const", "Value of fill OP should be Const"
input_value = input_value.value
......@@ -212,7 +212,7 @@ class TFOpMapperNHWC(OpMapper):
value=input_value)
def DepthToSpace(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True)
input = self.graph.get_node(node.layer.input[0])
block_size = node.get_attr("block_size")
data_format = node.get_attr("data_format").decode()
......@@ -264,7 +264,7 @@ class TFOpMapperNHWC(OpMapper):
perm=[0, 2, 3, 1])
def MaxPool(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True)
input = self.graph.get_node(node.layer.input[0])
k_size = node.get_attr("ksize")
strides = node.get_attr("strides")
......@@ -592,7 +592,7 @@ class TFOpMapperNHWC(OpMapper):
perm=[0, 2, 3, 1])
def AvgPool(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True)
input = self.graph.get_node(node.layer.input[0])
k_size = node.get_attr("ksize")
strides = node.get_attr("strides")
......@@ -904,7 +904,7 @@ class TFOpMapperNHWC(OpMapper):
keep_dim=keep_dims)
def RandomUniform(self, node):
shape = self.graph.get_node(node.layer.input[0], copy=True)
shape = self.graph.get_node(node.layer.input[0])
if shape.layer_type == "Const":
shape = shape.value.tolist()
program.add_layer(
......@@ -984,8 +984,8 @@ class TFOpMapperNHWC(OpMapper):
perm=[0, 2, 3, 1])
def Tile(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True)
expand_times = self.graph.get_node(node.layer.input[1], copy=True)
input = self.graph.get_node(node.layer.input[0])
expand_times = self.graph.get_node(node.layer.input[1])
inputs = {"x": input.name}
attr = dict()
if expand_times.layer_type == "Const":
......@@ -1001,9 +1001,9 @@ class TFOpMapperNHWC(OpMapper):
**attr)
def Range(self, node):
start = self.graph.get_node(node.layer.input[0], copy=True)
limit = self.graph.get_node(node.layer.input[1], copy=True)
delta = self.graph.get_node(node.layer.input[2], copy=True)
start = self.graph.get_node(node.layer.input[0])
limit = self.graph.get_node(node.layer.input[1])
delta = self.graph.get_node(node.layer.input[2])
inputs = dict()
attr = dict()
......@@ -1028,8 +1028,8 @@ class TFOpMapperNHWC(OpMapper):
**attr)
def SquaredDifference(self, node):
x = self.graph.get_node(node.layer.input[0], copy=True)
y = self.graph.get_node(node.layer.input[1], copy=True)
x = self.graph.get_node(node.layer.input[0])
y = self.graph.get_node(node.layer.input[1])
inputs = {"x": x.name, "y": y.name}
program.add_layer(
"fluid.layers.elementwise_sub", inputs=inputs, outputs=[node.name])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册