Commit 3967c640 authored by Y yeliang2258

add GatherElements, support list input

Parent 2f01932d
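
For reference, the op this commit adds a mapper for, ONNX GatherElements, copies one element per output position along a single axis (for the 2-D case with axis=0, output[i][j] = data[index[i][j]][j]). Below is a minimal NumPy sketch of that reference semantics; it is not part of the commit and all names are illustrative.

import numpy as np

def gather_elements_ref(data, index, axis=0):
    # Output has the same shape as index; each output position copies the
    # element of data whose coordinate along `axis` is taken from index.
    out = np.empty_like(index, dtype=data.dtype)
    for pos in np.ndindex(index.shape):
        src = list(pos)
        src[axis] = index[pos]
        out[pos] = data[tuple(src)]
    return out

data = np.array([[1, 2], [3, 4]])
index = np.array([[0, 0], [1, 0]])
print(gather_elements_ref(data, index, axis=1))  # [[1 1] [4 3]]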
......@@ -27,12 +27,13 @@ from x2paddle.core.util import *
class PaddleLayer(object):
def __init__(self, id, kernel, inputs, outputs, scope_name="", **kwargs):
assert isinstance(
inputs,
dict), "parameter 'inputs' for PaddleLayer should be type of dict"
assert isinstance(inputs, (
dict, list
)), "parameter 'inputs' for PaddleLayer should be type of dict or list"
assert isinstance(
outputs,
list), "parameter 'outputs' for PaddleLayer should be type of list"
if isinstance(inputs, dict):
for k, v in inputs.items():
if isinstance(v, (list, tuple)):
for i in v:
......@@ -164,6 +165,7 @@ class PaddleGraph(object):
self.clear_edges()
outputs_from_nodes = dict()
for layer_id, layer in self.layers.items():
if isinstance(layer.inputs, dict):
for input_key, input_var in layer.inputs.items():
vs = input_var
if not isinstance(vs, (list, tuple)):
......@@ -186,6 +188,26 @@ class PaddleGraph(object):
if layer_id not in self.edges_in:
self.edges_in[layer_id] = list()
self.edges_in[layer_id].append(in_layer_id)
else:
for v in layer.inputs:
assert v in outputs_from_nodes or (
inputs is not None and v in list(inputs.values())
) or (
outputs is not None and v in outputs
), "Couldn't find {} in previous layers, the layers should be make by topological sort".format(
v)
if v in outputs_from_nodes:
in_layer_id = outputs_from_nodes[v]
else:
in_layer_id = -1
if in_layer_id not in self.edges_out:
self.edges_out[in_layer_id] = list()
self.edges_out[in_layer_id].append(layer_id)
if layer_id not in self.edges_in:
self.edges_in[layer_id] = list()
self.edges_in[layer_id].append(in_layer_id)
for output in layer.outputs:
outputs_from_nodes[output] = layer_id
......@@ -359,6 +381,7 @@ class PaddleGraph(object):
"class {}(paddle.nn.Layer):".format(self.name),
],
indent=0)
print(self.inputs)
input_data_name = ', '.join(self.inputs)
self.init_func.extend(gen_codes(["def __init__(self):"], indent=1))
self.init_func.extend(
......@@ -496,6 +519,7 @@ class PaddleGraph(object):
else:
line = ','.join(layer.outputs)
line += " = {}(".format(layer.kernel)
if isinstance(layer.inputs, dict):
for k, v in layer.inputs.items():
if isinstance(v, list):
line += "{}=[{}], ".format(k, ", ".join(v))
......@@ -506,6 +530,9 @@ class PaddleGraph(object):
line += v
else:
line += "{}={}, ".format(k, v)
else:
line += "{}".format(", ".join(layer.inputs))
for k, v in layer.attrs.items():
line += "{}={}, ".format(k, v)
line = line.strip(", ")
......
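
The PaddleLayer/PaddleGraph changes above allow `inputs` to be either a dict (emitted as keyword arguments) or a plain list (emitted as comma-joined positional arguments). A hypothetical illustration of the two code-generation paths, not taken from the commit, where `graph` stands for a PaddleGraph instance and all layer/variable names are made up:

graph.add_layer(
    "paddle.concat",
    inputs={"x": ["feat_a", "feat_b"]},   # dict input: keyword arguments
    outputs=["concat_out"],
    axis=-1)
# generated line: concat_out = paddle.concat(x=[feat_a, feat_b], axis=-1)

graph.add_layer(
    "paddle.meshgrid",
    inputs=["grid_x", "grid_y"],          # list input: positional arguments
    outputs=["grid_x", "grid_y"])
# generated line: grid_x,grid_y = paddle.meshgrid(grid_x, grid_y)

The new GatherElements mapper in the opset file below relies on the list form to emit its paddle.meshgrid call.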
......@@ -69,8 +69,6 @@ def _rename_or_remove_weight(weights,
if target_name is not None:
# rename weight
weights[target_name] = data
if "x2paddle_297" in weights.keys():
print("keep")
def _is_static_shape(shape):
......@@ -233,8 +231,6 @@ class OpSet9():
@print_mapping_info
def place_holder(self, node):
if node.name in ["297", "x2paddle_297"]:
print("!!!!!!!find! 1123")
shape = node.out_shapes[0]
for i, dim_shape in enumerate(shape):
if dim_shape == 0 and i == 0:
......@@ -264,7 +260,6 @@ class OpSet9():
shape=[1],
fill_value=node.weight)
else:
print("test point:", node.name)
self.weights[node.name] = node.weight
self.paddle_graph.add_layer(
"self.create_parameter",
......@@ -405,9 +400,6 @@ class OpSet9():
inputs['scale_factor'] = val_scales.name
else:
val_scales = node.get_attr('scales')[2:]
print(type(val_scales))
print(val_scales)
# inputs['scale_factor'] = val_scales
mode = node.get_attr('mode', 'nearest')
attrs.update({
......@@ -733,8 +725,6 @@ class OpSet9():
@print_mapping_info
def Constant(self, node):
if node.name in ["297", "x2paddle_297"]:
print("!!!!!!!find!")
val_output = self.graph.get_node(node.layer.output[0], copy=True)
value = node.get_attr('value')
......@@ -829,8 +819,6 @@ class OpSet9():
'fill_value': 1
}
else:
print("test:", type(shape_values))
print(shape_values.tolist())
attr_ones = {
'shape': shape_values.tolist(),
'dtype': string(val_x_dtype),
......@@ -855,8 +843,6 @@ class OpSet9():
val_x = self.graph.get_input_node(node, idx=0, copy=True)
indices = self.graph.get_input_node(node, idx=1, copy=True)
indices_shape = indices.out_shapes[0]
print("indices_shape:", node.name, " ", indices_shape, " ",
val_x.out_shapes[0])
axis = node.get_attr('axis', 0)
#assert len(
# indices_shape) <= 2, "Gather op don't support dim of indice >2 "
......@@ -1180,7 +1166,6 @@ class OpSet9():
if axes is None:
axes = [i for i in range(len(starts))]
print("axes:", axes)
for idx in range(len(ends)):
if ends[idx] > 2**31 - 1:
ends[idx] = 2**31 - 1
......@@ -1221,7 +1206,6 @@ class OpSet9():
@print_mapping_info
def GatherND(self, node):
print(len(node.inputs), node.inputs)
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_y = self.graph.get_input_node(node, idx=1, copy=True)
self.paddle_graph.add_layer(
......@@ -1387,7 +1371,6 @@ class OpSet9():
@print_mapping_info
def GatherND(self, node):
print(len(node.inputs), node.inputs)
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_y = self.graph.get_input_node(node, idx=1, copy=True)
self.paddle_graph.add_layer(
......@@ -2037,7 +2020,6 @@ class OpSet9():
def SpaceToDepth(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
blocksize = node.get_attr('blocksize')
print(blocksize)
val_x_shape = val_x.out_shapes[0]
b, c, h, w = val_x_shape
self.paddle_graph.add_layer(
......@@ -2060,12 +2042,91 @@ class OpSet9():
def GatherElements(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
indices = self.graph.get_input_node(node, idx=1, copy=True)
dtype = np.dtype(val_x.dtype)
axis = node.get_attr('axis')
val_x_shape = val_x.out_shapes[0]
indices_shape = indices.out_shapes[0]
axis = axis if axis >= 0 else axis + len(val_x_shape)
if axis == 0:
axis_perm = [i for i in range(len(val_x_shape))]
data_swaped = val_x.name
index_swaped = indices.name
else:
axis_perm = [i for i in range(len(val_x_shape))]
axis_perm[axis] = 0
axis_perm[0] = axis
data_swaped = val_x.name + "_transpose"
self.paddle_graph.add_layer(
"paddle.gather",
inputs={'x': val_x.name,
'index': indices.name},
axis=node.get_attr('axis'),
"paddle.transpose",
inputs={'x': val_x.name},
perm=axis_perm,
outputs=[data_swaped])
index_swaped = indices.name + "_transpose"
self.paddle_graph.add_layer(
"paddle.transpose",
inputs={'x': indices.name},
perm=axis_perm,
outputs=[index_swaped])
temp = indices_shape[0]
indices_shape[0] = indices_shape[axis]
indices_shape[axis] = temp
idx_tensors_per_axis_pre = [
indices_shape[i] for i in range(len(indices_shape))
]
name_list = list()
for i in range(len(idx_tensors_per_axis_pre)):
tensor_name = val_x.name + "_meshgrid_" + str(i)
self.paddle_graph.add_layer(
kernel="paddle.linspace",
inputs={},
outputs=[tensor_name],
start=0,
stop=idx_tensors_per_axis_pre[i] - 1,
num=idx_tensors_per_axis_pre[i])
name_list.append(tensor_name)
self.paddle_graph.add_layer(
"paddle.meshgrid", inputs=name_list, outputs=name_list)
self.paddle_graph.add_layer(
"paddle.cast",
inputs={"x": index_swaped},
outputs=[index_swaped],
dtype=string("float32"))
import copy
copy_name_list = copy.copy(name_list)
copy_name_list[0] = index_swaped
new_name_list = list()
for i in range(len(copy_name_list)):
unsqueeze_name = copy_name_list[i] + "_unsqueeze"
self.paddle_graph.add_layer(
"paddle.unsqueeze",
inputs={"x": copy_name_list[i]},
axis=-1,
outputs=[unsqueeze_name])
new_name_list.append(unsqueeze_name)
concat_name = val_x.name + "_concated_layer"
self.paddle_graph.add_layer(
"paddle.concat",
inputs={'x': new_name_list},
axis=-1,
outputs=[concat_name])
self.paddle_graph.add_layer(
"paddle.cast",
inputs={"x": concat_name},
outputs=[concat_name],
dtype=string("int32"))
gather_nd_name = "gather_nd_layer"
self.paddle_graph.add_layer(
"paddle.gather_nd",
inputs={'x': data_swaped,
"index": concat_name},
outputs=[gather_nd_name])
self.paddle_graph.add_layer(
"paddle.transpose",
inputs={'x': gather_nd_name},
perm=axis_perm,
outputs=[node.name])
@print_mapping_info
......
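
Taken together, the new GatherElements mapping swaps `axis` to the front, builds a full coordinate grid with paddle.linspace/paddle.meshgrid, overwrites the leading coordinate with the gathered indices, and finishes with paddle.gather_nd plus a transpose back. A rough eager-mode sketch of that lowering follows; it is an illustration of the idea, not the generated code, and the function name is made up.

import paddle

def gather_elements_sketch(data, index, axis):
    rank = len(data.shape)
    axis = axis if axis >= 0 else axis + rank
    perm = list(range(rank))
    perm[0], perm[axis] = perm[axis], perm[0]
    data_swaped = paddle.transpose(data, perm=perm)
    index_swaped = paddle.transpose(index, perm=perm)
    idx_shape = index_swaped.shape
    # one coordinate tensor per dimension, mirroring the emitted paddle.linspace calls
    coords = [paddle.linspace(0, n - 1, n) for n in idx_shape]
    coords = paddle.meshgrid(*coords)
    # replace the leading coordinate with the indices to gather along `axis`
    coords[0] = paddle.cast(index_swaped, "float32")
    full_index = paddle.concat(
        [paddle.unsqueeze(c, axis=-1) for c in coords], axis=-1)
    full_index = paddle.cast(full_index, "int32")
    out = paddle.gather_nd(data_swaped, full_index)
    return paddle.transpose(out, perm=perm)

x = paddle.to_tensor([[1., 2.], [3., 4.]])
idx = paddle.to_tensor([[0, 0], [1, 0]])
print(gather_elements_sketch(x, idx, axis=1))  # values [[1., 1.], [4., 3.]]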