Commit 3c80edf9 authored by SunAhong1993

fix the gather

Parent e8df9aec
......@@ -76,6 +76,7 @@ class PaddleGraph(object):
self.source_type = source_type
self.custom_code = None
self.inputs_info = None
self.can_dygraph2static = True
def set_name(self, name):
......@@ -166,6 +167,8 @@ class PaddleGraph(object):
self.clear_edges()
outputs_from_nodes = dict()
for layer_id, layer in self.layers.items():
if layer.kernel == "custom_layer:Gather":
self.can_dygraph2static = False
for input_key, input_var in layer.inputs.items():
vs = input_var
if not isinstance(vs, list):
......@@ -283,7 +286,7 @@ class PaddleGraph(object):
self.gen_dygraph_code(save_dir)
self.dump_dygraph_parameter(save_dir)
# 动转静
if len(self.inputs_info) > 0:
if len(self.inputs_info) > 0 and self.can_dygraph2static:
input_shapes = list()
input_types = list()
for input_name in self.inputs:
......
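Note on the new can_dygraph2static flag above: when the graph contains the custom Gather layer, the dygraph-to-static export step is skipped. Below is a minimal sketch of what that gated step typically looks like in Paddle, assuming the export uses paddle.jit.save with InputSpec; the function name, shapes, and save path are illustrative, not the repository's exact code.

import paddle
from paddle.static import InputSpec

def export_static(layer, input_shapes, input_types, save_dir):
    # Build one InputSpec per model input, then trace the dygraph Layer into
    # a static program. Graphs flagged can_dygraph2static=False (e.g. those
    # containing the custom Gather) skip this call entirely.
    specs = [InputSpec(shape=shape, dtype=dtype)
             for shape, dtype in zip(input_shapes, input_types)]
    paddle.jit.save(layer, save_dir, input_spec=specs)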
......@@ -21,14 +21,14 @@ import numpy as np
class Decoder(object):
def _optimize_graph(self, graph):
torch._C._jit_pass_constant_propagation(graph)
torch._C._jit_pass_dce(graph)
torch._C._jit_pass_lint(graph)
torch._C._jit_pass_peephole(graph)
torch._C._jit_pass_lint(graph)
torch._C._jit_pass_dce(graph)
torch._C._jit_pass_lint(graph)
torch._C._jit_pass_canonicalize(graph)
torch._C._jit_pass_lint(graph)
# torch._C._jit_pass_dce(graph)
# torch._C._jit_pass_lint(graph)
# torch._C._jit_pass_peephole(graph)
# torch._C._jit_pass_lint(graph)
# torch._C._jit_pass_dce(graph)
# torch._C._jit_pass_lint(graph)
# torch._C._jit_pass_canonicalize(graph)
# torch._C._jit_pass_lint(graph)
torch._C._jit_pass_constant_propagation(graph)
return graph
......
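After this change _optimize_graph only runs TorchScript constant propagation; the dce/peephole/canonicalize/lint passes are disabled. A minimal usage sketch, assuming Decoder can be instantiated directly and using a toy traced module (the model below is illustrative, not part of the commit):

import torch

class TinyNet(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x + 1.0)

# Trace a module, then run its graph through the decoder's optimizer;
# with this commit only constant propagation is applied to the graph.
traced = torch.jit.trace(TinyNet().eval(), torch.randn(2, 3))
optimized_graph = Decoder()._optimize_graph(traced.graph)
print(optimized_graph)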
......@@ -402,7 +402,7 @@ class TFDecoder(object):
right_shape_been_input = False
while not right_shape_been_input:
try:
shape = raw_input(
shape = input(
"Shape of Input(e.g. None,224,224,3): ")
except:
shape = input("Shape of Input(e.g. None,224,224,3): ")
......
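The change above replaces the Python 2 built-in raw_input with input, which is the only form available on Python 3 (the surrounding except clause already fell back to input). For reference, a small 2/3-compatible sketch of the same prompt; the read_line name is illustrative:

# raw_input raises NameError on Python 3, where the built-in is input.
try:
    read_line = raw_input  # Python 2
except NameError:
    read_line = input      # Python 3
shape = read_line("Shape of Input(e.g. None,224,224,3): ")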
......@@ -20,22 +20,41 @@ import numpy as np
class Gather(object):
def __init__(self, dim):
self.dim = dim
self.dtype_mapping = {"VarType.INT32": "int32",
"VarType.INT64": "int64"}
def __call__(self, x, index):
out_list = list()
dims = list()
index_shape = index.shape
x_type = x.numpy().dtype
for s in index_shape:
dims.append(list(range(s)))
for id in product(*dims):
id = list(id)
id_tensor = paddle.to_tensor(np.array(id).astype('int32'))
dim_id = paddle.gather_nd(index, id_tensor).numpy()
id[self.dim] = dim_id
id_tensor = paddle.to_tensor(np.array(id).astype('int32'))
data = paddle.gather_nd(x, id_tensor).numpy()
out_list.append(data)
out = paddle.to_tensor(np.array(out_list).astype(x_type))
out = paddle.reshape(out, index_shape)
if self.dim < 0:
self.dim += len(x.shape)
x_range = list(range(len(x.shape)))
x_range[0] = self.dim
x_range[self.dim] = 0
x_swaped = paddle.transpose(x, perm=x_range)
index_range = list(range(len(index.shape)))
index_range[0] = self.dim
index_range[self.dim] = 0
index_swaped = paddle.transpose(index, perm=index_range)
dtype = self.dtype_mapping[str(index.dtype)]
x_shape = paddle.shape(x_swaped)
index_shape = paddle.shape(index_swaped)
prod = paddle.prod(x_shape, dtype=dtype) / x_shape[0]
x_swaped_flattend = paddle.flatten(x_swaped)
index_swaped_flattend = paddle.flatten(index_swaped)
index_swaped_flattend *= prod
bias = paddle.arange(start=0, end=prod, dtype=dtype)
bias = paddle.reshape(bias, x_shape[1:])
bias = paddle.crop(bias, index_shape[1:])
bias = paddle.flatten(bias)
bias = paddle.tile(bias, [index_shape[0]])
index_swaped_flattend += bias
gathered = paddle.index_select(x_swaped_flattend, index_swaped_flattend)
gathered = paddle.reshape(gathered, index_swaped.shape)
out = paddle.transpose(gathered, perm=x_range)
return out
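The rewritten Gather drops the per-element Python loop of the old implementation: it moves the gather axis to the front, converts every gather index into a flat position (index value times slice size, plus the within-slice offset supplied by the cropped and tiled bias), and performs a single paddle.index_select on the flattened tensor. Below is a minimal NumPy sketch of the same flatten-and-offset trick reproducing torch.gather semantics; the function name and test values are illustrative, not part of the commit.

import numpy as np

def gather_like_torch(x, dim, index):
    # Move the gather axis to the front, mirroring the transposes above.
    dim = dim % x.ndim
    perm = list(range(x.ndim))
    perm[0], perm[dim] = perm[dim], perm[0]
    x_sw, idx_sw = x.transpose(perm), index.transpose(perm)
    prod = int(np.prod(x_sw.shape[1:]))            # elements per axis-0 slice
    # Within-slice offsets, cropped to the index's trailing shape and tiled
    # once per row (the NumPy analogue of paddle.crop + paddle.tile above).
    bias = np.arange(prod).reshape(x_sw.shape[1:])
    bias = bias[tuple(slice(0, s) for s in idx_sw.shape[1:])].ravel()
    flat_idx = idx_sw.ravel() * prod + np.tile(bias, idx_sw.shape[0])
    out = x_sw.ravel()[flat_idx].reshape(idx_sw.shape)
    return out.transpose(perm)

x = np.arange(12).reshape(3, 4)
index = np.array([[0, 1], [2, 0], [1, 3]])
print(gather_like_torch(x, 1, index))  # same result as torch.gather(x, 1, index)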
......@@ -372,7 +372,7 @@ class HierarchicalTree(Tree):
"import paddle.fluid as fluid",
"from paddle.fluid.initializer import Constant",
"from paddle.fluid.param_attr import ParamAttr",
"imort math",
"import math",
"from x2paddle.op_mapper.dygraph.pytorch2paddle " + \
"import pytorch_custom_layer as x2paddle_nn"
"\n",]
......