提交 3967c640 编写于 作者: Y yeliang2258

add GatherElements, suport list input

上级 2f01932d
...@@ -27,22 +27,23 @@ from x2paddle.core.util import * ...@@ -27,22 +27,23 @@ from x2paddle.core.util import *
class PaddleLayer(object): class PaddleLayer(object):
def __init__(self, id, kernel, inputs, outputs, scope_name="", **kwargs): def __init__(self, id, kernel, inputs, outputs, scope_name="", **kwargs):
assert isinstance( assert isinstance(inputs, (
inputs, dict, list
dict), "parameter 'inputs' for PaddleLayer should be type of dict" )), "parameter 'inputs' for PaddleLayer should be type of dict or list"
assert isinstance( assert isinstance(
outputs, outputs,
list), "parameter 'outputs' for PaddleLayer should be type of list" list), "parameter 'outputs' for PaddleLayer should be type of list"
for k, v in inputs.items(): if isinstance(inputs, dict):
if isinstance(v, (list, tuple)): for k, v in inputs.items():
for i in v: if isinstance(v, (list, tuple)):
assert isinstance( for i in v:
i, six.string_types assert isinstance(
i, six.string_types
), "value in inputs should be type of string or list of string"
else:
assert isinstance(v, six.string_types) or isinstance(
v, list
), "value in inputs should be type of string or list of string" ), "value in inputs should be type of string or list of string"
else:
assert isinstance(v, six.string_types) or isinstance(
v, list
), "value in inputs should be type of string or list of string"
for v in outputs: for v in outputs:
assert isinstance( assert isinstance(
v, six. v, six.
...@@ -164,11 +165,31 @@ class PaddleGraph(object): ...@@ -164,11 +165,31 @@ class PaddleGraph(object):
self.clear_edges() self.clear_edges()
outputs_from_nodes = dict() outputs_from_nodes = dict()
for layer_id, layer in self.layers.items(): for layer_id, layer in self.layers.items():
for input_key, input_var in layer.inputs.items(): if isinstance(layer.inputs, dict):
vs = input_var for input_key, input_var in layer.inputs.items():
if not isinstance(vs, (list, tuple)): vs = input_var
vs = [vs] if not isinstance(vs, (list, tuple)):
for v in vs: vs = [vs]
for v in vs:
assert v in outputs_from_nodes or (
inputs is not None and v in list(inputs.values())
) or (
outputs is not None and v in outputs
), "Couldn't find {} in previous layers, the layers should be make by topological sort".format(
v)
if v in outputs_from_nodes:
in_layer_id = outputs_from_nodes[v]
else:
in_layer_id = -1
if in_layer_id not in self.edges_out:
self.edges_out[in_layer_id] = list()
self.edges_out[in_layer_id].append(layer_id)
if layer_id not in self.edges_in:
self.edges_in[layer_id] = list()
self.edges_in[layer_id].append(in_layer_id)
else:
for v in layer.inputs:
assert v in outputs_from_nodes or ( assert v in outputs_from_nodes or (
inputs is not None and v in list(inputs.values()) inputs is not None and v in list(inputs.values())
) or ( ) or (
...@@ -186,6 +207,7 @@ class PaddleGraph(object): ...@@ -186,6 +207,7 @@ class PaddleGraph(object):
if layer_id not in self.edges_in: if layer_id not in self.edges_in:
self.edges_in[layer_id] = list() self.edges_in[layer_id] = list()
self.edges_in[layer_id].append(in_layer_id) self.edges_in[layer_id].append(in_layer_id)
for output in layer.outputs: for output in layer.outputs:
outputs_from_nodes[output] = layer_id outputs_from_nodes[output] = layer_id
...@@ -359,6 +381,7 @@ class PaddleGraph(object): ...@@ -359,6 +381,7 @@ class PaddleGraph(object):
"class {}(paddle.nn.Layer):".format(self.name), "class {}(paddle.nn.Layer):".format(self.name),
], ],
indent=0) indent=0)
print(self.inputs)
input_data_name = ', '.join(self.inputs) input_data_name = ', '.join(self.inputs)
self.init_func.extend(gen_codes(["def __init__(self):"], indent=1)) self.init_func.extend(gen_codes(["def __init__(self):"], indent=1))
self.init_func.extend( self.init_func.extend(
...@@ -496,16 +519,20 @@ class PaddleGraph(object): ...@@ -496,16 +519,20 @@ class PaddleGraph(object):
else: else:
line = ','.join(layer.outputs) line = ','.join(layer.outputs)
line += " = {}(".format(layer.kernel) line += " = {}(".format(layer.kernel)
for k, v in layer.inputs.items(): if isinstance(layer.inputs, dict):
if isinstance(v, list): for k, v in layer.inputs.items():
line += "{}=[{}], ".format(k, ", ".join(v)) if isinstance(v, list):
elif isinstance(v, tuple): line += "{}=[{}], ".format(k, ", ".join(v))
line += "{}=({}), ".format(k, ", ".join(v)) elif isinstance(v, tuple):
else: line += "{}=({}), ".format(k, ", ".join(v))
if k == "args":
line += v
else: else:
line += "{}={}, ".format(k, v) if k == "args":
line += v
else:
line += "{}={}, ".format(k, v)
else:
line += "{}".format(", ".join(layer.inputs))
for k, v in layer.attrs.items(): for k, v in layer.attrs.items():
line += "{}={}, ".format(k, v) line += "{}={}, ".format(k, v)
line = line.strip(", ") line = line.strip(", ")
......
...@@ -69,8 +69,6 @@ def _rename_or_remove_weight(weights, ...@@ -69,8 +69,6 @@ def _rename_or_remove_weight(weights,
if target_name is not None: if target_name is not None:
# rename weight # rename weight
weights[target_name] = data weights[target_name] = data
if "x2paddle_297" in weights.keys():
print("keep")
def _is_static_shape(shape): def _is_static_shape(shape):
...@@ -233,8 +231,6 @@ class OpSet9(): ...@@ -233,8 +231,6 @@ class OpSet9():
@print_mapping_info @print_mapping_info
def place_holder(self, node): def place_holder(self, node):
if node.name in ["297", "x2paddle_297"]:
print("!!!!!!!find! 1123")
shape = node.out_shapes[0] shape = node.out_shapes[0]
for i, dim_shape in enumerate(shape): for i, dim_shape in enumerate(shape):
if dim_shape == 0 and i == 0: if dim_shape == 0 and i == 0:
...@@ -264,7 +260,6 @@ class OpSet9(): ...@@ -264,7 +260,6 @@ class OpSet9():
shape=[1], shape=[1],
fill_value=node.weight) fill_value=node.weight)
else: else:
print("test point:", node.name)
self.weights[node.name] = node.weight self.weights[node.name] = node.weight
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
"self.create_parameter", "self.create_parameter",
...@@ -405,9 +400,6 @@ class OpSet9(): ...@@ -405,9 +400,6 @@ class OpSet9():
inputs['scale_factor'] = val_scales.name inputs['scale_factor'] = val_scales.name
else: else:
val_scales = node.get_attr('scales')[2:] val_scales = node.get_attr('scales')[2:]
print(type(val_scales))
print(val_scales)
# inputs['scale_factor'] = val_scales
mode = node.get_attr('mode', 'nearest') mode = node.get_attr('mode', 'nearest')
attrs.update({ attrs.update({
...@@ -733,8 +725,6 @@ class OpSet9(): ...@@ -733,8 +725,6 @@ class OpSet9():
@print_mapping_info @print_mapping_info
def Constant(self, node): def Constant(self, node):
if node.name in ["297", "x2paddle_297"]:
print("!!!!!!!find!")
val_output = self.graph.get_node(node.layer.output[0], copy=True) val_output = self.graph.get_node(node.layer.output[0], copy=True)
value = node.get_attr('value') value = node.get_attr('value')
...@@ -829,8 +819,6 @@ class OpSet9(): ...@@ -829,8 +819,6 @@ class OpSet9():
'fill_value': 1 'fill_value': 1
} }
else: else:
print("test:", type(shape_values))
print(shape_values.tolist())
attr_ones = { attr_ones = {
'shape': shape_values.tolist(), 'shape': shape_values.tolist(),
'dtype': string(val_x_dtype), 'dtype': string(val_x_dtype),
...@@ -855,8 +843,6 @@ class OpSet9(): ...@@ -855,8 +843,6 @@ class OpSet9():
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
indices = self.graph.get_input_node(node, idx=1, copy=True) indices = self.graph.get_input_node(node, idx=1, copy=True)
indices_shape = indices.out_shapes[0] indices_shape = indices.out_shapes[0]
print("indices_shape:", node.name, " ", indices_shape, " ",
val_x.out_shapes[0])
axis = node.get_attr('axis', 0) axis = node.get_attr('axis', 0)
#assert len( #assert len(
# indices_shape) <= 2, "Gather op don't support dim of indice >2 " # indices_shape) <= 2, "Gather op don't support dim of indice >2 "
...@@ -1180,7 +1166,6 @@ class OpSet9(): ...@@ -1180,7 +1166,6 @@ class OpSet9():
if axes is None: if axes is None:
axes = [i for i in range(len(starts))] axes = [i for i in range(len(starts))]
print("axes:", axes)
for idx in range(len(ends)): for idx in range(len(ends)):
if ends[idx] > 2**31 - 1: if ends[idx] > 2**31 - 1:
ends[idx] = 2**31 - 1 ends[idx] = 2**31 - 1
...@@ -1221,7 +1206,6 @@ class OpSet9(): ...@@ -1221,7 +1206,6 @@ class OpSet9():
@print_mapping_info @print_mapping_info
def GatherND(self, node): def GatherND(self, node):
print(len(node.inputs), node.inputs)
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_y = self.graph.get_input_node(node, idx=1, copy=True) val_y = self.graph.get_input_node(node, idx=1, copy=True)
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
...@@ -1387,7 +1371,6 @@ class OpSet9(): ...@@ -1387,7 +1371,6 @@ class OpSet9():
@print_mapping_info @print_mapping_info
def GatherND(self, node): def GatherND(self, node):
print(len(node.inputs), node.inputs)
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_y = self.graph.get_input_node(node, idx=1, copy=True) val_y = self.graph.get_input_node(node, idx=1, copy=True)
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
...@@ -2037,7 +2020,6 @@ class OpSet9(): ...@@ -2037,7 +2020,6 @@ class OpSet9():
def SpaceToDepth(self, node): def SpaceToDepth(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
blocksize = node.get_attr('blocksize') blocksize = node.get_attr('blocksize')
print(blocksize)
val_x_shape = val_x.out_shapes[0] val_x_shape = val_x.out_shapes[0]
b, c, h, w = val_x_shape b, c, h, w = val_x_shape
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
...@@ -2060,12 +2042,91 @@ class OpSet9(): ...@@ -2060,12 +2042,91 @@ class OpSet9():
def GatherElements(self, node): def GatherElements(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
indices = self.graph.get_input_node(node, idx=1, copy=True) indices = self.graph.get_input_node(node, idx=1, copy=True)
dtype = np.dtype(val_x.dtype) axis = node.get_attr('axis')
val_x_shape = val_x.out_shapes[0]
indices_shape = indices.out_shapes[0]
axis = axis if axis >= 0 else axis + len(val_x_shape)
if axis == 0:
axis_perm = [i for i in range(len(val_x_shape))]
data_swaped = val_x.name
index_swaped = indices.name
else:
axis_perm = [i for i in range(len(val_x_shape))]
axis_perm[axis] = 0
axis_perm[0] = axis
data_swaped = val_x.name + "_transpose"
self.paddle_graph.add_layer(
"paddle.transpose",
inputs={'x': val_x.name},
perm=axis_perm,
outputs=[data_swaped])
index_swaped = indices.name + "_transpose"
self.paddle_graph.add_layer(
"paddle.transpose",
inputs={'x': indices.name},
perm=axis_perm,
outputs=[index_swaped])
temp = indices_shape[0]
indices_shape[0] = indices_shape[axis]
indices_shape[axis] = temp
idx_tensors_per_axis_pre = [
indices_shape[i] for i in range(len(indices_shape))
]
name_list = list()
for i in range(len(idx_tensors_per_axis_pre)):
tensor_name = val_x.name + "_meshgrid_" + str(i)
self.paddle_graph.add_layer(
kernel="paddle.linspace",
inputs={},
outputs=[tensor_name],
start=0,
stop=idx_tensors_per_axis_pre[i] - 1,
num=idx_tensors_per_axis_pre[i])
name_list.append(tensor_name)
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
"paddle.gather", "paddle.meshgrid", inputs=name_list, outputs=name_list)
inputs={'x': val_x.name,
'index': indices.name}, self.paddle_graph.add_layer(
axis=node.get_attr('axis'), "paddle.cast",
inputs={"x": index_swaped},
outputs=[index_swaped],
dtype=string("float32"))
import copy
copy_name_list = copy.copy(name_list)
copy_name_list[0] = index_swaped
new_name_list = list()
for i in range(len(copy_name_list)):
unsqueeze_name = copy_name_list[i] + "_unsqueeze"
self.paddle_graph.add_layer(
"paddle.unsqueeze",
inputs={"x": copy_name_list[i]},
axis=-1,
outputs=[unsqueeze_name])
new_name_list.append(unsqueeze_name)
concat_name = val_x.name + "_concated_layer"
self.paddle_graph.add_layer(
"paddle.concat",
inputs={'x': new_name_list},
axis=-1,
outputs=[concat_name])
self.paddle_graph.add_layer(
"paddle.cast",
inputs={"x": concat_name},
outputs=[concat_name],
dtype=string("int32"))
gather_nd_name = "gather_nd_layer"
self.paddle_graph.add_layer(
"paddle.gather_nd",
inputs={'x': data_swaped,
"index": concat_name},
outputs=[gather_nd_name])
self.paddle_graph.add_layer(
"paddle.transpose",
inputs={'x': gather_nd_name},
perm=axis_perm,
outputs=[node.name]) outputs=[node.name])
@print_mapping_info @print_mapping_info
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册