提交 91879f50 编写于 作者: S SunAhong1993

fix fro pre-commit

上级 2fc9ffd0
...@@ -37,20 +37,24 @@ def prim_Constant(mapper, graph, node): ...@@ -37,20 +37,24 @@ def prim_Constant(mapper, graph, node):
tensor_value = value tensor_value = value
value = "{}".format(value) value = "{}".format(value)
if "tensor" in value: if "tensor" in value:
if isinstance(tensor_value, list) or isinstance(tensor_value, tuple): if isinstance(tensor_value, list) or isinstance(tensor_value,
tuple):
name_dict = dict() name_dict = dict()
for i, tv in enumerate(tensor_value): for i, tv in enumerate(tensor_value):
output_name_i = "{}_p{}".format(output_name,i) output_name_i = "{}_p{}".format(output_name, i)
key_i = "input{}".format(i) key_i = "input{}".format(i)
mapper.paddle_params[output_name_i] = tv.cpu().detach().numpy() mapper.paddle_params[output_name_i] = tv.cpu().detach(
).numpy()
graph.add_layer( graph.add_layer(
"self.create_parameter", "self.create_parameter",
inputs={}, inputs={},
outputs=[output_name_i], outputs=[output_name_i],
scope_name=scope_name, scope_name=scope_name,
dtype=string(str(mapper.paddle_params[output_name_i].dtype)), dtype=string(
shape = mapper.paddle_params[output_name_i].shape, str(mapper.paddle_params[output_name_i].dtype)),
default_initializer="paddle.nn.initializer.Constant(value=0.0)") shape=mapper.paddle_params[output_name_i].shape,
default_initializer="paddle.nn.initializer.Constant(value=0.0)"
)
name_dict[key_i] = output_name_i name_dict[key_i] = output_name_i
graph.add_layer( graph.add_layer(
"prim.list", "prim.list",
...@@ -59,16 +63,18 @@ def prim_Constant(mapper, graph, node): ...@@ -59,16 +63,18 @@ def prim_Constant(mapper, graph, node):
scope_name=scope_name) scope_name=scope_name)
return [], [output_name] return [], [output_name]
else: else:
# mapper.pytorch_params[output_name] = tensor_value.cpu().detach().numpy() # mapper.pytorch_params[output_name] = tensor_value.cpu().detach().numpy()
mapper.paddle_params[output_name] = tensor_value.cpu().detach().numpy() mapper.paddle_params[output_name] = tensor_value.cpu().detach(
).numpy()
graph.add_layer( graph.add_layer(
"self.create_parameter", "self.create_parameter",
inputs={}, inputs={},
outputs=[output_name], outputs=[output_name],
scope_name=scope_name, scope_name=scope_name,
dtype=string(str(mapper.paddle_params[output_name].dtype)), dtype=string(str(mapper.paddle_params[output_name].dtype)),
shape = mapper.paddle_params[output_name].shape, shape=mapper.paddle_params[output_name].shape,
default_initializer="paddle.nn.initializer.Constant(value=0.0)") default_initializer="paddle.nn.initializer.Constant(value=0.0)"
)
return [], [output_name] return [], [output_name]
if "inf" in str(value): if "inf" in str(value):
t = str(type(value)).split("'")[1] t = str(type(value)).split("'")[1]
...@@ -81,7 +87,11 @@ def prim_Constant(mapper, graph, node): ...@@ -81,7 +87,11 @@ def prim_Constant(mapper, graph, node):
value = int(math.pow(2, 31) - 1) value = int(math.pow(2, 31) - 1)
mapper.attrs[output_name] = value mapper.attrs[output_name] = value
graph.add_layer( graph.add_layer(
"prim.constant", inputs={}, outputs=[output_name], scope_name=scope_name, value=value) "prim.constant",
inputs={},
outputs=[output_name],
scope_name=scope_name,
value=value)
return [], [output_name] return [], [output_name]
...@@ -105,18 +115,23 @@ def prim_data(mapper, graph, node): ...@@ -105,18 +115,23 @@ def prim_data(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%4336 # 处理输入0,即%4336
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["input"] = inputs_name[0] layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer("prim.equal", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) graph.add_layer(
"prim.equal",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
def prim_DictConstruct(mapper, graph, node): def prim_DictConstruct(mapper, graph, node):
""" 构建dict。 """ 构建dict。
TorchScript示例: TorchScript示例:
%32 : Dict(str, Tensor) = prim::DictConstruct(%30, %23, %31, %29) %32 : Dict(str, Tensor) = prim::DictConstruct(%30, %23, %31, %29)
参数含义: 参数含义:
...@@ -136,22 +151,22 @@ def prim_DictConstruct(mapper, graph, node): ...@@ -136,22 +151,22 @@ def prim_DictConstruct(mapper, graph, node):
current_outputs = [output_name] current_outputs = [output_name]
# 处理每个输入 # 处理每个输入
for i, input_name in enumerate(inputs_name): for i, input_name in enumerate(inputs_name):
if i%2 == 0: if i % 2 == 0:
layer_attrs["key{}".format(int(i/2))] = mapper.attrs[input_name] layer_attrs["key{}".format(int(i / 2))] = mapper.attrs[input_name]
else: else:
layer_inputs["value{}".format(int(i/2))] = input_name layer_inputs["value{}".format(int(i / 2))] = input_name
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer("prim.dict_construct", graph.add_layer(
inputs=layer_inputs, "prim.dict_construct",
outputs=layer_outputs, inputs=layer_inputs,
scope_name=scope_name, outputs=layer_outputs,
**layer_attrs) scope_name=scope_name,
**layer_attrs)
return current_inputs, current_outputs return current_inputs, current_outputs
def prim_GetAttr(mapper, graph, node): def prim_GetAttr(mapper, graph, node):
""" 获取attribute信息。 """ 获取attribute信息。
...@@ -212,8 +227,13 @@ def prim_If(mapper, graph, node): ...@@ -212,8 +227,13 @@ def prim_If(mapper, graph, node):
input_node = list(node.inputs())[0].node() input_node = list(node.inputs())[0].node()
script_input_unique_id = list(node.inputs())[0].unique() script_input_unique_id = list(node.inputs())[0].unique()
input_node_name = mapper.outputs_info[script_input_unique_id] input_node_name = mapper.outputs_info[script_input_unique_id]
mapper._check_input(graph, input_node, input_node_name, current_outputs, scope_name) mapper._check_input(graph, input_node, input_node_name, current_outputs,
graph.add_layer("prim.if", inputs={'input': input_node_name}, outputs=node_outputs, scope_name=scope_name) scope_name)
graph.add_layer(
"prim.if",
inputs={'input': input_node_name},
outputs=node_outputs,
scope_name=scope_name)
current_layer = list(graph.layers.values())[-1] current_layer = list(graph.layers.values())[-1]
block0 = list(node.blocks())[0] block0 = list(node.blocks())[0]
block0_graph, graph_inputs0 = mapper.traverse(block0, current_layer) block0_graph, graph_inputs0 = mapper.traverse(block0, current_layer)
...@@ -249,12 +269,17 @@ def prim_ListConstruct(mapper, graph, node): ...@@ -249,12 +269,17 @@ def prim_ListConstruct(mapper, graph, node):
current_outputs = [output_name] current_outputs = [output_name]
# 处理每个输入 # 处理每个输入
for i, input_name in enumerate(inputs_name): for i, input_name in enumerate(inputs_name):
mapper._check_input(graph, inputs_node[i], input_name, current_outputs, scope_name) mapper._check_input(graph, inputs_node[i], input_name, current_outputs,
scope_name)
layer_inputs["input{}".format(i)] = input_name layer_inputs["input{}".format(i)] = input_name
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
layer_id = graph.add_layer("prim.list", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) layer_id = graph.add_layer(
"prim.list",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
mapper.output2id[output_name] = layer_id mapper.output2id[output_name] = layer_id
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -277,13 +302,17 @@ def prim_ListUnpack(mapper, graph, node): ...@@ -277,13 +302,17 @@ def prim_ListUnpack(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = layer_outputs.copy() current_outputs = layer_outputs.copy()
# 处理输入0,即%4354 # 处理输入0,即%4354
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["input"] = inputs_name[0] layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer( graph.add_layer(
"prim.list_unpack", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) "prim.list_unpack",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
mapper.split_len[list(layer_inputs.values())[0]] = len(layer_outputs) mapper.split_len[list(layer_inputs.values())[0]] = len(layer_outputs)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -342,7 +371,11 @@ def prim_Loop(mapper, graph, node): ...@@ -342,7 +371,11 @@ def prim_Loop(mapper, graph, node):
scope_name=scope_name) scope_name=scope_name)
node_outputs.append(block_input_node_name) node_outputs.append(block_input_node_name)
graph.add_layer("prim.loop", inputs=loop_inputs, outputs=loop_outputs, scope_name=scope_name) graph.add_layer(
"prim.loop",
inputs=loop_inputs,
outputs=loop_outputs,
scope_name=scope_name)
current_layer = list(graph.layers.values())[-1] current_layer = list(graph.layers.values())[-1]
block_graph, graph_inputs = mapper.traverse(block, current_layer) block_graph, graph_inputs = mapper.traverse(block, current_layer)
for i, input_name in enumerate(graph_inputs): for i, input_name in enumerate(graph_inputs):
...@@ -370,12 +403,17 @@ def prim_min(mapper, graph, node): ...@@ -370,12 +403,17 @@ def prim_min(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%86 # 处理输入0,即%86
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["input"] = inputs_name[0] layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer("prim.min", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) graph.add_layer(
"prim.min",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -397,14 +435,19 @@ def prim_NumToTensor(mapper, graph, node): ...@@ -397,14 +435,19 @@ def prim_NumToTensor(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%86 # 处理输入0,即%86
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
inputs_inputs_name, inputs_inputs_node = mapper._get_inputs_name(inputs_node[0]) scope_name)
inputs_inputs_name, inputs_inputs_node = mapper._get_inputs_name(
inputs_node[0])
if inputs_node[0].kind() == "aten::size" and len(inputs_inputs_name) > 1: if inputs_node[0].kind() == "aten::size" and len(inputs_inputs_name) > 1:
layer_inputs["input"] = inputs_name[0] layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer( graph.add_layer(
"prim_equal", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) "prim_equal",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
else: else:
layer_inputs["fill_value"] = inputs_name[0] layer_inputs["fill_value"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
...@@ -437,13 +480,17 @@ def prim_RaiseException(mapper, graph, node): ...@@ -437,13 +480,17 @@ def prim_RaiseException(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%76 # 处理输入0,即%76
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["input"] = inputs_name[0] layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer( graph.add_layer(
"prim.exception", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) "prim.exception",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -464,13 +511,17 @@ def prim_requires_grad(mapper, graph, node): ...@@ -464,13 +511,17 @@ def prim_requires_grad(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%86 # 处理输入0,即%86
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["input"] = inputs_name[0] layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer( graph.add_layer(
"prim.requires_grad", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) "prim.requires_grad",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -527,13 +578,17 @@ def prim_shape(mapper, graph, node): ...@@ -527,13 +578,17 @@ def prim_shape(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%input.8 # 处理输入0,即%input.8
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["input"] = inputs_name[0] layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer( graph.add_layer(
"paddle.shape", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) "paddle.shape",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -560,7 +615,11 @@ def prim_TupleConstruct(mapper, graph, node): ...@@ -560,7 +615,11 @@ def prim_TupleConstruct(mapper, graph, node):
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer("prim.tuple", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) graph.add_layer(
"prim.tuple",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -590,7 +649,11 @@ def prim_TupleUnpack(mapper, graph, node): ...@@ -590,7 +649,11 @@ def prim_TupleUnpack(mapper, graph, node):
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer( graph.add_layer(
"prim.tuple_unpack", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name, **layer_attrs) "prim.tuple_unpack",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name,
**layer_attrs)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -614,12 +677,17 @@ def prim_unchecked_cast(mapper, graph, node): ...@@ -614,12 +677,17 @@ def prim_unchecked_cast(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%size.63 # 处理输入0,即%size.63
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["input"] = inputs_name[0] layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer("prim.equal", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) graph.add_layer(
"prim.equal",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -636,5 +704,9 @@ def prim_Uninitialized(mapper, graph, node): ...@@ -636,5 +704,9 @@ def prim_Uninitialized(mapper, graph, node):
output = list(node.outputs())[0] output = list(node.outputs())[0]
mapper.attrs[output_name] = None mapper.attrs[output_name] = None
graph.add_layer( graph.add_layer(
"prim.constant", inputs={}, outputs=[output_name], scope_name=scope_name, value=None) "prim.constant",
inputs={},
outputs=[output_name],
scope_name=scope_name,
value=None)
return [], [output_name] return [], [output_name]
...@@ -14,7 +14,8 @@ ...@@ -14,7 +14,8 @@
# limitations under the License. # limitations under the License.
NO_OUTPUT_COUNT = 0 NO_OUTPUT_COUNT = 0
def gen_codes(code_list, indent=0): def gen_codes(code_list, indent=0):
indent_blank = " " * indent indent_blank = " " * indent
codes = [] codes = []
...@@ -53,13 +54,24 @@ def get_value(layer, key, layer_id=None, different_attrs=None): ...@@ -53,13 +54,24 @@ def get_value(layer, key, layer_id=None, different_attrs=None):
return str(layer.attrs[key]) return str(layer.attrs[key])
def prim_add(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_add(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
line = "{} = {} + {}".format(layer.outputs[0], line = "{} = {} + {}".format(layer.outputs[0],
get_value(layer, "x", different_attrs), get_value(layer, "y", different_attrs)) get_value(layer, "x", different_attrs),
get_value(layer, "y", different_attrs))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_add_(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_add_(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
line = "{} = {} + {} * {}".format(layer.outputs[0], line = "{} = {} + {} * {}".format(layer.outputs[0],
get_value(layer, "x", different_attrs), get_value(layer, "x", different_attrs),
layer.attrs["alpha"], layer.attrs["alpha"],
...@@ -67,22 +79,39 @@ def prim_add_(layer, indent=1, init_func=[], forward_func=[], layer_id=None, dif ...@@ -67,22 +79,39 @@ def prim_add_(layer, indent=1, init_func=[], forward_func=[], layer_id=None, dif
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_and(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None, is_return_line=False): def prim_and(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None,
is_return_line=False):
line = "{} = {} and {}".format(layer.outputs[0], line = "{} = {} and {}".format(layer.outputs[0],
get_value(layer, "x", different_attrs), get_value(layer, "y", different_attrs)) get_value(layer, "x", different_attrs),
get_value(layer, "y", different_attrs))
if is_return_line: if is_return_line:
return line.split(" = ")[1] return line.split(" = ")[1]
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_append(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_append(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
line = "{}.append({})".format( line = "{}.append({})".format(
get_value(layer, "list", layer_id, different_attrs), get_value(layer, "list", layer_id, different_attrs),
get_value(layer, "element", layer_id, different_attrs)) get_value(layer, "element", layer_id, different_attrs))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_assert(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_assert(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
if layer.attrs["type"] == "eq": if layer.attrs["type"] == "eq":
values = get_value(layer, "key") values = get_value(layer, "key")
if "value" in layer.attrs: if "value" in layer.attrs:
...@@ -93,15 +122,15 @@ def prim_assert(layer, indent=1, init_func=[], forward_func=[], layer_id=None, d ...@@ -93,15 +122,15 @@ def prim_assert(layer, indent=1, init_func=[], forward_func=[], layer_id=None, d
s += "{} == {} or ".format(get_value(layer, "key"), v) s += "{} == {} or ".format(get_value(layer, "key"), v)
if len(s) > 0: if len(s) > 0:
s = s[:-4] s = s[:-4]
lc=locals() lc = locals()
exec("assert_result = {}".format(s)) exec("assert_result = {}".format(s))
assert_result = lc['assert_result'] assert_result = lc['assert_result']
line = "assert {}, \'The {} must be {}!\'".format( line = "assert {}, \'The {} must be {}!\'".format(
s, get_value(layer, "key"), get_value(layer, "value")) s, get_value(layer, "key"), get_value(layer, "value"))
else: else:
s = "{} == {}".format(get_value(layer, "key"), s = "{} == {}".format(
get_value(layer, "value")) get_value(layer, "key"), get_value(layer, "value"))
lc=locals() lc = locals()
exec("assert_result = {}".format(s)) exec("assert_result = {}".format(s))
assert_result = lc['assert_result'] assert_result = lc['assert_result']
line = "assert {}, \'The {} must be {}!\'".format( line = "assert {}, \'The {} must be {}!\'".format(
...@@ -112,7 +141,12 @@ def prim_assert(layer, indent=1, init_func=[], forward_func=[], layer_id=None, d ...@@ -112,7 +141,12 @@ def prim_assert(layer, indent=1, init_func=[], forward_func=[], layer_id=None, d
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_check_dim(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_check_dim(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
lines = [] lines = []
dim = get_value(layer, "dim", different_attrs) dim = get_value(layer, "dim", different_attrs)
lines.append("if {} < 0:".format(dim)) lines.append("if {} < 0:".format(dim))
...@@ -123,12 +157,23 @@ def prim_check_dim(layer, indent=1, init_func=[], forward_func=[], layer_id=None ...@@ -123,12 +157,23 @@ def prim_check_dim(layer, indent=1, init_func=[], forward_func=[], layer_id=None
forward_func.extend(gen_codes(lines, indent=indent)) forward_func.extend(gen_codes(lines, indent=indent))
def prim_constant(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_constant(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
line = "{} = {}".format(layer.outputs[0], layer.attrs["value"]) line = "{} = {}".format(layer.outputs[0], layer.attrs["value"])
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_contain(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None, is_return_line=False): def prim_contain(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None,
is_return_line=False):
line = "{} = {} in {}".format(layer.outputs[0], line = "{} = {} in {}".format(layer.outputs[0],
get_value(layer, "element", different_attrs), get_value(layer, "element", different_attrs),
get_value(layer, "input", different_attrs)) get_value(layer, "input", different_attrs))
...@@ -137,108 +182,182 @@ def prim_contain(layer, indent=1, init_func=[], forward_func=[], layer_id=None, ...@@ -137,108 +182,182 @@ def prim_contain(layer, indent=1, init_func=[], forward_func=[], layer_id=None,
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_dict(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_dict(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
line = "{} = dict()".format(layer.outputs[0]) line = "{} = dict()".format(layer.outputs[0])
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_dict_construct(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_dict_construct(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
lines = list() lines = list()
line = "{} = dict()".format(layer.outputs[0]) line = "{} = dict()".format(layer.outputs[0])
lines.append(line) lines.append(line)
for i in range(len(layer.inputs)): for i in range(len(layer.inputs)):
line = "{}[{}] = {}".format(layer.outputs[0], line = "{}[{}] = {}".format(
get_value(layer, "key{}".format(i), different_attrs), layer.outputs[0],
get_value(layer, "value{}".format(i), different_attrs)) get_value(layer, "key{}".format(i), different_attrs),
get_value(layer, "value{}".format(i), different_attrs))
lines.append(line) lines.append(line)
forward_func.extend(gen_codes(lines, indent=indent)) forward_func.extend(gen_codes(lines, indent=indent))
def prim_dict2values(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_dict2values(layer,
line = "{} = list({}.values())".format(layer.outputs[0], indent=1,
get_value(layer, "x", different_attrs)) init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
line = "{} = list({}.values())".format(
layer.outputs[0], get_value(layer, "x", different_attrs))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_div(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_div(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
line = "{} = {} / {}".format(layer.outputs[0], line = "{} = {} / {}".format(layer.outputs[0],
get_value(layer, "x", different_attrs), get_value(layer, "x", different_attrs),
get_value(layer, "y", different_attrs)) get_value(layer, "y", different_attrs))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_eq(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None,is_return_line=False): def prim_eq(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None,
is_return_line=False):
line = "{} = {} == {}".format(layer.outputs[0], line = "{} = {} == {}".format(layer.outputs[0],
get_value(layer, "x", different_attrs), get_value(layer, "x", different_attrs),
get_value(layer, "y", different_attrs)) get_value(layer, "y", different_attrs))
if is_return_line: if is_return_line:
return line.split(" = ")[1] return line.split(" = ")[1]
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_equal(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_equal(layer,
line = "{} = {}".format(layer.outputs[0], get_value(layer, "input", different_attrs)) indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
line = "{} = {}".format(layer.outputs[0],
get_value(layer, "input", different_attrs))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_exception(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_exception(layer,
line = "raise Exception({})".format(get_value(layer, "input", different_attrs)) indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
line = "raise Exception({})".format(
get_value(layer, "input", different_attrs))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_float(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_float(layer,
line = "{} = float({})".format(layer.outputs[0], get_value(layer, "input", different_attrs)) indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
line = "{} = float({})".format(layer.outputs[0],
get_value(layer, "input", different_attrs))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_floor(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_floor(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
line = "{} = math.floor({})".format(layer.outputs[0], line = "{} = math.floor({})".format(layer.outputs[0],
get_value(layer, "x", different_attrs)) get_value(layer, "x", different_attrs))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_floordiv(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_floordiv(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
line = "{} = {} // {}".format(layer.outputs[0], line = "{} = {} // {}".format(layer.outputs[0],
get_value(layer, "x", different_attrs), get_value(layer, "x", different_attrs),
get_value(layer, "y", different_attrs)) get_value(layer, "y", different_attrs))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_getitem(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_getitem(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
line = "{} = {}[{}]".format(layer.outputs[0], line = "{} = {}[{}]".format(layer.outputs[0],
get_value(layer, "list", different_attrs), get_value(layer, "list", different_attrs),
get_value(layer, "index", different_attrs)) get_value(layer, "index", different_attrs))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_gt(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None, is_return_line=False): def prim_gt(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None,
is_return_line=False):
line = "{} = {} > {}".format(layer.outputs[0], line = "{} = {} > {}".format(layer.outputs[0],
get_value(layer, "x", different_attrs), get_value(layer, "x", different_attrs),
get_value(layer, "y", different_attrs)) get_value(layer, "y", different_attrs))
if is_return_line: if is_return_line:
return line.split(" = ")[1] return line.split(" = ")[1]
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_if(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_if(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
try: try:
exec_s = None exec_s = None
for line in forward_func: for line in forward_func:
s = line.replace(" ", "") s = line.replace(" ", "")
if s.startswith("{} = ".format(get_value(layer, "input", different_attrs))): if s.startswith("{} = ".format(
get_value(layer, "input", different_attrs))):
exec_s = s.split(" = ")[1] exec_s = s.split(" = ")[1]
lc=locals() lc = locals()
if exec_s is not None: if exec_s is not None:
exec("if_result = {}".format(exec_s)) exec("if_result = {}".format(exec_s))
else: else:
exec("if_result = {}".format(get_value(layer, "input", different_attrs))) exec("if_result = {}".format(
get_value(layer, "input", different_attrs)))
if_result = lc['if_result'] if_result = lc['if_result']
if if_result: if if_result:
block = layer.blocks[0] block = layer.blocks[0]
else: else:
block = layer.blocks[1] block = layer.blocks[1]
if len(block.layers) > 0: if len(block.layers) > 0:
b_init_lines, b_forward_lines = block.gen_dygraph_code(indent=indent) b_init_lines, b_forward_lines = block.gen_dygraph_code(
indent=indent)
init_func.extend(b_init_lines) init_func.extend(b_init_lines)
forward_func.extend(b_forward_lines) forward_func.extend(b_forward_lines)
except: except:
...@@ -249,7 +368,8 @@ def prim_if(layer, indent=1, init_func=[], forward_func=[], layer_id=None, diffe ...@@ -249,7 +368,8 @@ def prim_if(layer, indent=1, init_func=[], forward_func=[], layer_id=None, diffe
line = "pass" line = "pass"
forward_func.extend(gen_codes([line], indent=indent + 1)) forward_func.extend(gen_codes([line], indent=indent + 1))
else: else:
b_init_lines, b_forward_lines = block.gen_dygraph_code(indent=indent + 1) b_init_lines, b_forward_lines = block.gen_dygraph_code(
indent=indent + 1)
init_func.extend(b_init_lines) init_func.extend(b_init_lines)
forward_func.extend(b_forward_lines) forward_func.extend(b_forward_lines)
block = layer.blocks[1] block = layer.blocks[1]
...@@ -263,30 +383,54 @@ def prim_if(layer, indent=1, init_func=[], forward_func=[], layer_id=None, diffe ...@@ -263,30 +383,54 @@ def prim_if(layer, indent=1, init_func=[], forward_func=[], layer_id=None, diffe
forward_func.extend(b_forward_lines) forward_func.extend(b_forward_lines)
def prim_int(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_int(layer,
line = "{} = int({})".format(layer.outputs[0], get_value(layer, "input", different_attrs)) indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
line = "{} = int({})".format(layer.outputs[0],
get_value(layer, "input", different_attrs))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_is(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None, is_return_line=False): def prim_is(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None,
is_return_line=False):
line = "{} = {} is {}".format(layer.outputs[0], line = "{} = {} is {}".format(layer.outputs[0],
get_value(layer, "x", different_attrs), get_value(layer, "x", different_attrs),
get_value(layer, "y", different_attrs)) get_value(layer, "y", different_attrs))
if is_return_line: if is_return_line:
return line.split(" = ")[1] return line.split(" = ")[1]
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_isinstance(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None, is_return_line=False): def prim_isinstance(layer,
line = "{} = isinstance({}, {})".format(layer.outputs[0], indent=1,
get_value(layer, "input", different_attrs), init_func=[],
layer.attrs["cls"]) forward_func=[],
layer_id=None,
different_attrs=None,
is_return_line=False):
line = "{} = isinstance({}, {})".format(
layer.outputs[0],
get_value(layer, "input", different_attrs), layer.attrs["cls"])
if is_return_line: if is_return_line:
return line.split(" = ")[1] return line.split(" = ")[1]
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_isnot(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None, is_return_line=False): def prim_isnot(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None,
is_return_line=False):
line = "{} = {} is not {}".format(layer.outputs[0], line = "{} = {} is not {}".format(layer.outputs[0],
get_value(layer, "x", different_attrs), get_value(layer, "x", different_attrs),
get_value(layer, "y", different_attrs)) get_value(layer, "y", different_attrs))
...@@ -295,53 +439,94 @@ def prim_isnot(layer, indent=1, init_func=[], forward_func=[], layer_id=None, di ...@@ -295,53 +439,94 @@ def prim_isnot(layer, indent=1, init_func=[], forward_func=[], layer_id=None, di
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_le(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None, is_return_line=False): def prim_le(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None,
is_return_line=False):
line = "{} = {} <= {}".format(layer.outputs[0], line = "{} = {} <= {}".format(layer.outputs[0],
get_value(layer, "x", different_attrs), get_value(layer, "x", different_attrs),
get_value(layer, "y", different_attrs)) get_value(layer, "y", different_attrs))
if is_return_line: if is_return_line:
return line.split(" = ")[1] return line.split(" = ")[1]
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_len(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_len(layer,
line = "{} = len({})".format(layer.outputs[0], get_value(layer, "input", different_attrs)) indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
line = "{} = len({})".format(layer.outputs[0],
get_value(layer, "input", different_attrs))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_len2list(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_len2list(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
lines = [] lines = []
lines.append("{} = []".format(layer.outputs[0])) lines.append("{} = []".format(layer.outputs[0]))
lines.append("for i in range({}):".format(get_value(layer, "len", different_attrs))) lines.append("for i in range({}):".format(
get_value(layer, "len", different_attrs)))
lines.append(" {}.append(i)".format(layer.outputs[0])) lines.append(" {}.append(i)".format(layer.outputs[0]))
forward_func.extend(gen_codes(lines, indent=indent)) forward_func.extend(gen_codes(lines, indent=indent))
def prim_lt(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None, is_return_line=False): def prim_lt(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None,
is_return_line=False):
line = "{} = {} < {}".format(layer.outputs[0], line = "{} = {} < {}".format(layer.outputs[0],
get_value(layer, "x", different_attrs), get_value(layer, "x", different_attrs),
get_value(layer, "y", different_attrs)) get_value(layer, "y", different_attrs))
if is_return_line: if is_return_line:
return line.split(" = ")[1] return line.split(" = ")[1]
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_list(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_list(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
input_len = len(layer.inputs) + len(layer.attrs) input_len = len(layer.inputs) + len(layer.attrs)
inputs_list = list() inputs_list = list()
for i in range(input_len): for i in range(input_len):
inputs_list.append(get_value(layer, "input{}".format(i), different_attrs)) inputs_list.append(
get_value(layer, "input{}".format(i), different_attrs))
inputs_str = ', '.join(inputs_list) inputs_str = ', '.join(inputs_list)
line = "{} = [{}]".format(layer.outputs[0], inputs_str) line = "{} = [{}]".format(layer.outputs[0], inputs_str)
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_list_unpack(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_list_unpack(layer,
line = "{} = {}".format(", ".join(layer.outputs), get_value(layer, "input", different_attrs)) indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
line = "{} = {}".format(", ".join(layer.outputs),
get_value(layer, "input", different_attrs))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_loop(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_loop(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
loop_range = get_value(layer, "input", different_attrs) loop_range = get_value(layer, "input", different_attrs)
line = "for {} in range({}):".format(layer.outputs[1], loop_range) line = "for {} in range({}):".format(layer.outputs[1], loop_range)
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
...@@ -351,171 +536,303 @@ def prim_loop(layer, indent=1, init_func=[], forward_func=[], layer_id=None, dif ...@@ -351,171 +536,303 @@ def prim_loop(layer, indent=1, init_func=[], forward_func=[], layer_id=None, dif
forward_func.extend(b_forward_lines) forward_func.extend(b_forward_lines)
def prim_min(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_min(layer,
line = "{} = min({})".format(layer.outputs[0], get_value(layer, "input", different_attrs)) indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
line = "{} = min({})".format(layer.outputs[0],
get_value(layer, "input", different_attrs))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_mul(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_mul(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
line = "{} = {} * {}".format(layer.outputs[0], line = "{} = {} * {}".format(layer.outputs[0],
get_value(layer, "x", different_attrs), get_value(layer, "x", different_attrs),
get_value(layer, "y", different_attrs)) get_value(layer, "y", different_attrs))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_ne(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None, is_return_line=False): def prim_ne(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None,
is_return_line=False):
line = "{} = {} != {}".format(layer.outputs[0], line = "{} = {} != {}".format(layer.outputs[0],
get_value(layer, "x", different_attrs), get_value(layer, "x", different_attrs),
get_value(layer, "y", different_attrs)) get_value(layer, "y", different_attrs))
if is_return_line: if is_return_line:
return line.split(" = ")[1] return line.split(" = ")[1]
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_neg(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_neg(layer,
line = "{} = -{}".format(layer.outputs[0], get_value(layer, "input", different_attrs)) indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
line = "{} = -{}".format(layer.outputs[0],
get_value(layer, "input", different_attrs))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_not(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None, is_return_line=False): def prim_not(layer,
line = "{} = not {}".format(layer.outputs[0], get_value(layer, "input", different_attrs)) indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None,
is_return_line=False):
line = "{} = not {}".format(layer.outputs[0],
get_value(layer, "input", different_attrs))
if is_return_line: if is_return_line:
return line.split(" = ")[1] return line.split(" = ")[1]
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_or(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None, is_return_line=False): def prim_or(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None,
is_return_line=False):
line = "{} = {} or {}".format(layer.outputs[0], line = "{} = {} or {}".format(layer.outputs[0],
get_value(layer, "x", different_attrs), get_value(layer, "x", different_attrs),
get_value(layer, "y", different_attrs)) get_value(layer, "y", different_attrs))
if is_return_line: if is_return_line:
return line.split(" = ")[1] return line.split(" = ")[1]
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_replaceitem(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_replaceitem(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
line = "{}[{}] = {}".format( line = "{}[{}] = {}".format(
get_value(layer, "list", layer_id, different_attrs), get_value(layer, "list", layer_id, different_attrs),
get_value(layer, "index", layer_id, different_attrs), get_value(layer, "index", layer_id, different_attrs),
get_value(layer, "item", layer_id, different_attrs)) get_value(layer, "item", layer_id, different_attrs))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_requires_grad(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_requires_grad(layer,
line = "{} = not {}.stop_gradient".format(layer.outputs[0], indent=1,
get_value(layer, "input", different_attrs)) init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
line = "{} = not {}.stop_gradient".format(
layer.outputs[0], get_value(layer, "input", different_attrs))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_rsub(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_rsub(layer,
line = "{} = {} - {} * {}".format(layer.outputs[0], indent=1,
get_value(layer, "y", different_attrs), init_func=[],
get_value(layer, "x", different_attrs), forward_func=[],
get_value(layer, "alpha", different_attrs)) layer_id=None,
different_attrs=None):
line = "{} = {} - {} * {}".format(
layer.outputs[0],
get_value(layer, "y", different_attrs),
get_value(layer, "x", different_attrs),
get_value(layer, "alpha", different_attrs))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_select(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_select(layer,
line = "{} = {}[".format(layer.outputs[0], get_value(layer, "input", different_attrs)) indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
line = "{} = {}[".format(layer.outputs[0],
get_value(layer, "input", different_attrs))
for dim in range(layer.attrs["dim"]): for dim in range(layer.attrs["dim"]):
line += ":, " line += ":, "
line += (get_value(layer, "index", different_attrs) + "]") line += (get_value(layer, "index", different_attrs) + "]")
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_set_attr(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_set_attr(layer,
line = "{} = {}".format(layer.outputs[0], get_value(layer, "input", different_attrs)) indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
line = "{} = {}".format(layer.outputs[0],
get_value(layer, "input", different_attrs))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_set_item(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_set_item(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
line = "{}[{}] = {}".format( line = "{}[{}] = {}".format(
get_value(layer, "dict", different_attrs), get_value(layer, "dict", different_attrs),
get_value(layer, "key", different_attrs), get_value(layer, "value", different_attrs)) get_value(layer, "key", different_attrs),
get_value(layer, "value", different_attrs))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_shape(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
line = "{} = {}.shape".format(layer.outputs[0],
get_value(layer, "input", different_attrs))
forward_func.extend(gen_codes([line], indent=indent))
def prim_shape_dim(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_shape(layer,
line = "{} = {}.shape[{}]".format(layer.outputs[0], indent=1,
get_value(layer, "input", different_attrs), init_func=[],
get_value(layer, "dim", different_attrs)) forward_func=[],
layer_id=None,
different_attrs=None):
line = "{} = {}.shape".format(layer.outputs[0],
get_value(layer, "input", different_attrs))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_slice(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_shape_dim(layer,
line = "{} = {}[{}: {}: {}]".format(layer.outputs[0], indent=1,
get_value(layer, "input", different_attrs), init_func=[],
get_value(layer, "start", different_attrs), forward_func=[],
get_value(layer, "end", different_attrs), layer_id=None,
get_value(layer, "step", different_attrs)) different_attrs=None):
forward_func.extend(gen_codes([line], indent=indent)) line = "{} = {}.shape[{}]".format(
layer.outputs[0],
get_value(layer, "input", different_attrs),
def prim_startswith(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None, is_return_line=False): get_value(layer, "dim", different_attrs))
line = "{} = {}.startswith({})".format(layer.outputs[0], forward_func.extend(gen_codes([line], indent=indent))
get_value(layer, "input", different_attrs),
get_value(layer, "start_str", different_attrs))
def prim_slice(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
line = "{} = {}[{}: {}: {}]".format(
layer.outputs[0],
get_value(layer, "input", different_attrs),
get_value(layer, "start", different_attrs),
get_value(layer, "end", different_attrs),
get_value(layer, "step", different_attrs))
forward_func.extend(gen_codes([line], indent=indent))
def prim_startswith(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None,
is_return_line=False):
line = "{} = {}.startswith({})".format(
layer.outputs[0],
get_value(layer, "input", different_attrs),
get_value(layer, "start_str", different_attrs))
if is_return_line: if is_return_line:
return line.split(" = ")[1] return line.split(" = ")[1]
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_str(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_str(layer,
line = "{} = str({})".format(layer.outputs[0], get_value(layer, "input", different_attrs)) indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
line = "{} = str({})".format(layer.outputs[0],
get_value(layer, "input", different_attrs))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_sub(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_sub(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
if int(float(get_value(layer, "alpha", different_attrs))) == 1: if int(float(get_value(layer, "alpha", different_attrs))) == 1:
line = "{} = {} - {}".format(layer.outputs[0], line = "{} = {} - {}".format(layer.outputs[0],
get_value(layer, "x", different_attrs), get_value(layer, "x", different_attrs),
get_value(layer, "y", different_attrs)) get_value(layer, "y", different_attrs))
else: else:
line = "{} = {} - {} * {}".format(layer.outputs[0], line = "{} = {} - {} * {}".format(
get_value(layer, "x", different_attrs), layer.outputs[0],
get_value(layer, "alpha", different_attrs), get_value(layer, "x", different_attrs),
get_value(layer, "y", different_attrs)) get_value(layer, "alpha", different_attrs),
get_value(layer, "y", different_attrs))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_tuple(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_tuple(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
input_len = len(layer.inputs) + len(layer.attrs) input_len = len(layer.inputs) + len(layer.attrs)
inputs_list = list() inputs_list = list()
for i in range(input_len): for i in range(input_len):
inputs_list.append(get_value(layer, "input{}".format(i), different_attrs)) inputs_list.append(
get_value(layer, "input{}".format(i), different_attrs))
inputs_str = ', '.join(inputs_list) inputs_str = ', '.join(inputs_list)
line = "{} = ({})".format(layer.outputs[0], inputs_str) line = "{} = ({})".format(layer.outputs[0], inputs_str)
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_tuple_unpack(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_tuple_unpack(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
outputs_str = ', '.join(layer.outputs) outputs_str = ', '.join(layer.outputs)
line = "{} = {}".format(outputs_str, get_value(layer, "input", different_attrs)) line = "{} = {}".format(outputs_str,
get_value(layer, "input", different_attrs))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_type(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_type(layer,
line = "{} = {}.dtype".format(layer.outputs[0], get_value(layer, "input", different_attrs)) indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
line = "{} = {}.dtype".format(layer.outputs[0],
get_value(layer, "input", different_attrs))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_var2list(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_var2list(layer,
line = "{} = {}.numpy().tolist()".format(layer.outputs[0], indent=1,
get_value(layer, "input", different_attrs)) init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
line = "{} = {}.numpy().tolist()".format(
layer.outputs[0], get_value(layer, "input", different_attrs))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_warnings(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_warnings(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
lines = ["import warnings"] lines = ["import warnings"]
line = "warnings.warn({}, stacklevel={})".format( line = "warnings.warn({}, stacklevel={})".format(
get_value(layer, "input", different_attrs), layer.attrs["stacklevel"]) get_value(layer, "input", different_attrs), layer.attrs["stacklevel"])
lines.append(line) lines.append(line)
forward_func.extend(gen_codes(lines, indent=indent)) forward_func.extend(gen_codes(lines, indent=indent))
...@@ -37,7 +37,7 @@ class PyTorchOpMapper(OpMapper): ...@@ -37,7 +37,7 @@ class PyTorchOpMapper(OpMapper):
self.scope_name_list = list() self.scope_name_list = list()
self.scope_name2id = dict() self.scope_name2id = dict()
self.inputs_info = dict() self.inputs_info = dict()
self.output2id = dict() # output名字和layer_id的关系,用于lstm去除前面的node self.output2id = dict() # output名字和layer_id的关系,用于lstm去除前面的node
# 转换 # 转换
if not self.op_checker(decoder.graph): if not self.op_checker(decoder.graph):
raise Exception("Model is not supported yet.") raise Exception("Model is not supported yet.")
...@@ -50,6 +50,7 @@ class PyTorchOpMapper(OpMapper): ...@@ -50,6 +50,7 @@ class PyTorchOpMapper(OpMapper):
op_list.append(node.kind()) op_list.append(node.kind())
for block in node.blocks(): for block in node.blocks():
_update_op_list(block) _update_op_list(block)
op_list = list() op_list = list()
_update_op_list(script_graph) _update_op_list(script_graph)
op_list = list(set(op_list)) op_list = list(set(op_list))
...@@ -62,11 +63,11 @@ class PyTorchOpMapper(OpMapper): ...@@ -62,11 +63,11 @@ class PyTorchOpMapper(OpMapper):
return True return True
else: else:
if len(unsupported_ops) > 0: if len(unsupported_ops) > 0:
print("\n========= {} OPs are not supported yet ===========".format( print("\n========= {} OPs are not supported yet ===========".
len(unsupported_ops))) format(len(unsupported_ops)))
for op in unsupported_ops: for op in unsupported_ops:
print("========== {} ============".format(op)) print("========== {} ============".format(op))
return False return False
def traverse(self, script_graph, parent_layer=None): def traverse(self, script_graph, parent_layer=None):
# 用于获取graph的输入 # 用于获取graph的输入
...@@ -85,20 +86,24 @@ class PyTorchOpMapper(OpMapper): ...@@ -85,20 +86,24 @@ class PyTorchOpMapper(OpMapper):
current_node_outputs.extend(outputs) current_node_outputs.extend(outputs)
# 初始化 # 初始化
graph = PaddleGraph(source_type="pytorch", parent_layer=parent_layer, graph_type="dygraph") graph = PaddleGraph(
source_type="pytorch",
parent_layer=parent_layer,
graph_type="dygraph")
if "TopLevelTracedModule" in str(type(self.script)): if "TopLevelTracedModule" in str(type(self.script)):
graph.set_script(self.script) graph.set_script(self.script)
current_node_outputs = [] current_node_outputs = []
graph_inputs = [] graph_inputs = []
# 转换输入节点 # 转换输入节点
if isinstance(script_graph, torch._C.Graph): if isinstance(script_graph, torch._C.Graph):
input_ct = 0 input_ct = 0
for i, ivalue in enumerate(script_graph.inputs()): for i, ivalue in enumerate(script_graph.inputs()):
node = ivalue.node() node = ivalue.node()
if str(ivalue.type()) not in ["Tensor", "Dict[str, Tensor]"]: if str(ivalue.type()) not in ["Tensor", "Dict[str, Tensor]"]:
graph.set_name(str(ivalue.type()).split(".")[-1]) graph.set_name(str(ivalue.type()).split(".")[-1])
continue continue
inputs, outputs = self.data(graph, node, ivalue.unique(), input_ct) inputs, outputs = self.data(graph, node,
ivalue.unique(), input_ct)
input_ct += 1 input_ct += 1
# 转换中间节点 # 转换中间节点
for node in script_graph.nodes(): for node in script_graph.nodes():
...@@ -183,8 +188,9 @@ class PyTorchOpMapper(OpMapper): ...@@ -183,8 +188,9 @@ class PyTorchOpMapper(OpMapper):
outputs=[output_name], outputs=[output_name],
scope_name=scope_name, scope_name=scope_name,
dtype=string(str(param.dtype)), dtype=string(str(param.dtype)),
shape = param.shape, shape=param.shape,
default_initializer="paddle.nn.initializer.Constant(value=0.0)") default_initializer="paddle.nn.initializer.Constant(value=0.0)"
)
self.output2id[output_name] = layer_id self.output2id[output_name] = layer_id
else: else:
if isinstance(param, dict) and "Tensor" in param and \ if isinstance(param, dict) and "Tensor" in param and \
...@@ -211,8 +217,9 @@ class PyTorchOpMapper(OpMapper): ...@@ -211,8 +217,9 @@ class PyTorchOpMapper(OpMapper):
outputs=[output_name], outputs=[output_name],
scope_name=scope_name, scope_name=scope_name,
dtype=string(str(param.dtype)), dtype=string(str(param.dtype)),
shape = param.shape, shape=param.shape,
default_initializer="paddle.nn.initializer.Constant(value=0.0)") default_initializer="paddle.nn.initializer.Constant(value=0.0)"
)
node_outputs.append(output_name) node_outputs.append(output_name)
self.output2id[output_name] = layer_id self.output2id[output_name] = layer_id
return return
...@@ -232,7 +239,8 @@ class PyTorchOpMapper(OpMapper): ...@@ -232,7 +239,8 @@ class PyTorchOpMapper(OpMapper):
value=string(param) value=string(param)
if isinstance(param, str) else param) if isinstance(param, str) else param)
node_outputs.append(output_name) node_outputs.append(output_name)
elif node.kind() == "prim::Constant" and output_name in self.pytorch_params: elif node.kind(
) == "prim::Constant" and output_name in self.pytorch_params:
param = self.pytorch_params[output_name] param = self.pytorch_params[output_name]
self.paddle_params[output_name] = param self.paddle_params[output_name] = param
layer_id = graph.add_layer( layer_id = graph.add_layer(
...@@ -241,11 +249,10 @@ class PyTorchOpMapper(OpMapper): ...@@ -241,11 +249,10 @@ class PyTorchOpMapper(OpMapper):
outputs=[output_name], outputs=[output_name],
scope_name=scope_name, scope_name=scope_name,
dtype=string(str(param.dtype)), dtype=string(str(param.dtype)),
shape = param.shape, shape=param.shape,
default_initializer="paddle.nn.initializer.Constant(value=0.0)") default_initializer="paddle.nn.initializer.Constant(value=0.0)")
self.output2id[output_name] = layer_id self.output2id[output_name] = layer_id
def _get_inputs_name(self, node): def _get_inputs_name(self, node):
inputs_name = [] inputs_name = []
inputs_node = [] inputs_node = []
...@@ -256,7 +263,6 @@ class PyTorchOpMapper(OpMapper): ...@@ -256,7 +263,6 @@ class PyTorchOpMapper(OpMapper):
inputs_node.append(script_input_node) inputs_node.append(script_input_node)
inputs_name.append(input_name) inputs_name.append(input_name)
return inputs_name, inputs_node return inputs_name, inputs_node
def data(self, graph, node, uid, input_ct): def data(self, graph, node, uid, input_ct):
scope_name = self.normalize_scope_name(node) scope_name = self.normalize_scope_name(node)
...@@ -276,7 +282,8 @@ class PyTorchOpMapper(OpMapper): ...@@ -276,7 +282,8 @@ class PyTorchOpMapper(OpMapper):
data=output_name) data=output_name)
if self.input_examples is not None: if self.input_examples is not None:
input_np = self.input_examples[input_ct].detach().numpy() input_np = self.input_examples[input_ct].detach().numpy()
self.inputs_info[output_name] = [list(input_np.shape), str(input_np.dtype)] self.inputs_info[
output_name] = [list(input_np.shape), str(input_np.dtype)]
return [], [output_name] return [], [output_name]
def equal(self, graph, node, uid=None, parent_layer=None, index=None): def equal(self, graph, node, uid=None, parent_layer=None, index=None):
...@@ -289,7 +296,8 @@ class PyTorchOpMapper(OpMapper): ...@@ -289,7 +296,8 @@ class PyTorchOpMapper(OpMapper):
control_output_id = index - 1 control_output_id = index - 1
output_node_name = parent_layer.outputs[control_output_id] output_node_name = parent_layer.outputs[control_output_id]
current_outputs = [output_node_name] current_outputs = [output_node_name]
self._check_input(graph, node, input_node_name, current_outputs, scope_name) self._check_input(graph, node, input_node_name, current_outputs,
scope_name)
graph.add_layer( graph.add_layer(
"prim.equal", "prim.equal",
inputs={'input': input_node_name}, inputs={'input': input_node_name},
...@@ -321,20 +329,20 @@ class PyTorchOpMapper(OpMapper): ...@@ -321,20 +329,20 @@ class PyTorchOpMapper(OpMapper):
self.scope_name2id[i][ns] = 0 self.scope_name2id[i][ns] = 0
real_scope_name = "/".join(name_segments[1:]) real_scope_name = "/".join(name_segments[1:])
real_father_scope_name = "/".join(name_segments[1:-1]) real_father_scope_name = "/".join(name_segments[1:-1])
for i, ns in enumerate(name_segments): for i, ns in enumerate(name_segments):
if i == 0: if i == 0:
continue continue
if self.scope_name2id[i][ns] != 0: if self.scope_name2id[i][ns] != 0:
name_segments[i] = name_segments[i] + \ name_segments[i] = name_segments[i] + \
"__{}".format(self.scope_name2id[i][ns]) "__{}".format(self.scope_name2id[i][ns])
prefix_scope_name = "/".join(name_segments[1 :i + 1]) prefix_scope_name = "/".join(name_segments[1:i + 1])
is_found = False is_found = False
for j in range(len(self.scope_name_list)): for j in range(len(self.scope_name_list)):
last_scope_name = self.scope_name_list[-1-j] last_scope_name = self.scope_name_list[-1 - j]
if last_scope_name.startswith(prefix_scope_name + "/") \ if last_scope_name.startswith(prefix_scope_name + "/") \
or last_scope_name == prefix_scope_name: or last_scope_name == prefix_scope_name:
if j != 0: # and i != len(name_segments) - 1: if j != 0: # and i != len(name_segments) - 1:
is_found = True is_found = True
origin_name_segment_i = name_segments[i].split("__")[0] origin_name_segment_i = name_segments[i].split("__")[0]
self.scope_name2id[i][origin_name_segment_i] += 1 self.scope_name2id[i][origin_name_segment_i] += 1
...@@ -346,4 +354,3 @@ class PyTorchOpMapper(OpMapper): ...@@ -346,4 +354,3 @@ class PyTorchOpMapper(OpMapper):
real_scope_name = "/".join(name_segments[1:]) real_scope_name = "/".join(name_segments[1:])
self.scope_name_list.append(real_scope_name) self.scope_name_list.append(real_scope_name)
return real_scope_name return real_scope_name
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册