未验证 提交 0c93656b 编写于 作者: W WJJ1995 提交者: GitHub

Support PyTorch InstanceNorm2d op (#638)

* Add pytorch LeakyRelu op

* fix pytorch InstanceNorm op

* update paddle code format
上级 d5c2ceb3
...@@ -2637,7 +2637,7 @@ def aten_instance_norm(mapper, graph, node): ...@@ -2637,7 +2637,7 @@ def aten_instance_norm(mapper, graph, node):
# 处理输入1,即%88 # 处理输入1,即%88
if inputs_name[1] in mapper.pytorch_params: if inputs_name[1] in mapper.pytorch_params:
weights = mapper.pytorch_params[inputs_name[1]] weights = mapper.pytorch_params[inputs_name[1]]
mapper.paddle_params[op_name + ".weight"] = weights mapper.paddle_params[op_name + ".scale"] = weights
layer_attrs['num_features'] = weights.shape[0] layer_attrs['num_features'] = weights.shape[0]
# 处理输入2,即%85 # 处理输入2,即%85
if inputs_name[2] in mapper.pytorch_params: if inputs_name[2] in mapper.pytorch_params:
...@@ -2888,6 +2888,42 @@ def aten_leaky_relu_(mapper, graph, node): ...@@ -2888,6 +2888,42 @@ def aten_leaky_relu_(mapper, graph, node):
return current_inputs, current_outputs return current_inputs, current_outputs
def aten_leaky_relu(mapper, graph, node):
""" 构造leaky relu激活的PaddleLayer。
TorchScript示例:
%input.117 : Tensor = aten::leaky_relu(%input.114, %1570)
参数含义:
%input.117 (Tensor): 输出,leaky relu后的结果。
%input.114 (Tensor): 需要leaky relu的Tensor。
%1570 (float): 输入中的元素小于0时的斜率。
"""
scope_name = mapper.normalize_scope_name(node)
op_name = name_generator("leakly_relu", mapper.nn_name2id)
output_name = mapper._get_outputs_name(node)[0]
layer_outputs = [op_name, output_name]
layer_inputs = {}
layer_attrs = {}
inputs_name, inputs_node = mapper._get_inputs_name(node)
# 获取当前节点输出的list
current_outputs = [output_name]
# 处理输入0,即%result.5
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0]
# 获取当前节点输入、输出的list
current_inputs = list(layer_inputs.values())
# 处理输入1,即%1570
layer_attrs["negative_slope"] = mapper.attrs[inputs_name[1]]
graph.add_layer(
"paddle.nn.LeakyReLU",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name,
**layer_attrs)
return current_inputs, current_outputs
def aten_len(mapper, graph, node): def aten_len(mapper, graph, node):
""" 构造获取list长度的PaddleLayer。 """ 构造获取list长度的PaddleLayer。
TorchScript示例: TorchScript示例:
......
...@@ -13,13 +13,12 @@ ...@@ -13,13 +13,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import copy import copy
import os.path as osp import os.path as osp
from treelib import Tree from treelib import Tree
from queue import Queue from queue import Queue
from x2paddle.optimizer.pytorch_code_optimizer.layer_code_generator import gen_layer_code, rename_layers, NN_KERNEL_WITH_PARAMS, NN_KERNEL_NAME from x2paddle.optimizer.pytorch_code_optimizer.layer_code_generator import gen_layer_code, rename_layers, NN_KERNEL_WITH_PARAMS, NN_KERNEL_NAME
from x2paddle.optimizer.pytorch_code_optimizer.subgraphs_union import distinguish_sequential, get_inputs_outputs from x2paddle.optimizer.pytorch_code_optimizer.subgraphs_union import distinguish_sequential, get_inputs_outputs
from x2paddle.core.program import PaddleLayer from x2paddle.core.program import PaddleLayer
from x2paddle.optimizer.pytorch_code_optimizer.parameter_tree import PamareterNode, PamareterTree from x2paddle.optimizer.pytorch_code_optimizer.parameter_tree import PamareterNode, PamareterTree
...@@ -29,21 +28,22 @@ SEPARATOR_IN_SCOPE = "/" ...@@ -29,21 +28,22 @@ SEPARATOR_IN_SCOPE = "/"
class HierarchicalTree(Tree): class HierarchicalTree(Tree):
""" 定义层次树。 """ 定义层次树。
""" """
def __init__(self, pd_graph): def __init__(self, pd_graph):
super(HierarchicalTree, self).__init__() super(HierarchicalTree, self).__init__()
self.pd_graph = pd_graph self.pd_graph = pd_graph
self.script = pd_graph.script self.script = pd_graph.script
self.create_node("Module", self.pd_graph.name) # create root self.create_node("Module", self.pd_graph.name) # create root
self._hierarchical_order = dict() self._hierarchical_order = dict()
self.codes = list() self.codes = list()
self.identifier_idx = dict() self.identifier_idx = dict()
self.param_tree = PamareterTree() self.param_tree = PamareterTree()
self.module_name2count = dict() self.module_name2count = dict()
self.scope_name_list = list() self.scope_name_list = list()
def insert(self, layer): def insert(self, layer):
""" 往层次树中插入节点。 """ 往层次树中插入节点。
Args: Args:
layer (PaddleLayer): 需要插入的节点。 layer (PaddleLayer): 需要插入的节点。
""" """
...@@ -56,12 +56,13 @@ class HierarchicalTree(Tree): ...@@ -56,12 +56,13 @@ class HierarchicalTree(Tree):
for input_layer_id in self.pd_graph.edges_in[layer_id]: for input_layer_id in self.pd_graph.edges_in[layer_id]:
layer_id_list.append(int(input_layer_id)) layer_id_list.append(int(input_layer_id))
layer_id_list = list(set(layer_id_list)) layer_id_list = list(set(layer_id_list))
layer_id_list.sort(reverse=True) layer_id_list.sort(reverse=True)
if layer.kernel == "prim.tuple": if layer.kernel == "prim.tuple":
for i, input_layer_id in enumerate(layer_id_list): for i, input_layer_id in enumerate(layer_id_list):
input_layer_id_str = str(input_layer_id) input_layer_id_str = str(input_layer_id)
scope_name = self.pd_graph.layers[input_layer_id_str].scope_name scope_name = self.pd_graph.layers[
input_layer_id_str].scope_name
if i == 0: if i == 0:
min_scope_name = scope_name min_scope_name = scope_name
else: else:
...@@ -73,30 +74,34 @@ class HierarchicalTree(Tree): ...@@ -73,30 +74,34 @@ class HierarchicalTree(Tree):
if len1 > len2: if len1 > len2:
min_scope_name = scope_name min_scope_name = scope_name
if min_scope_name == "": if min_scope_name == "":
self.create_node(tag=layer.id, self.create_node(
identifier="no_scope_" + layer.id, tag=layer.id,
parent=self.pd_graph.name, identifier="no_scope_" + layer.id,
data=layer) parent=self.pd_graph.name,
return data=layer)
return
layer.scope_name = min_scope_name layer.scope_name = min_scope_name
scope_name = min_scope_name scope_name = min_scope_name
else: else:
for input_layer_id in layer_id_list: for input_layer_id in layer_id_list:
input_layer_id_str = str(input_layer_id) input_layer_id_str = str(input_layer_id)
if self.pd_graph.layers[input_layer_id_str].scope_name != "": if self.pd_graph.layers[
scope_name = self.pd_graph.layers[input_layer_id_str].scope_name input_layer_id_str].scope_name != "":
break scope_name = self.pd_graph.layers[
input_layer_id_str].scope_name
break
layer.scope_name = scope_name layer.scope_name = scope_name
else: else:
self.create_node(tag=layer.id, self.create_node(
identifier="no_scope_" + layer.id, tag=layer.id,
parent=self.pd_graph.name, identifier="no_scope_" + layer.id,
data=layer) parent=self.pd_graph.name,
return data=layer)
return
scopes = scope_name.split(SEPARATOR_IN_SCOPE) scopes = scope_name.split(SEPARATOR_IN_SCOPE)
for idx, scope in enumerate(scopes): for idx, scope in enumerate(scopes):
parent = SEPARATOR_IN_SCOPE.join(scopes[:idx])#.lower() parent = SEPARATOR_IN_SCOPE.join(scopes[:idx]) #.lower()
identifier = SEPARATOR_IN_SCOPE.join(scopes[:idx + 1])#.lower() identifier = SEPARATOR_IN_SCOPE.join(scopes[:idx + 1]) #.lower()
if self.contains(identifier): if self.contains(identifier):
if idx != len(scopes) - 1: if idx != len(scopes) - 1:
parent_node = self.parent(identifier) parent_node = self.parent(identifier)
...@@ -109,11 +114,13 @@ class HierarchicalTree(Tree): ...@@ -109,11 +114,13 @@ class HierarchicalTree(Tree):
self.identifier_idx[identifier] = 0 self.identifier_idx[identifier] = 0
else: else:
self.identifier_idx[identifier] += 1 self.identifier_idx[identifier] += 1
identifier_name = identifier + SEPARATOR_IN_SCOPE + str(self.identifier_idx[identifier]) identifier_name = identifier + SEPARATOR_IN_SCOPE + str(
self.create_node(tag=scopes[idx], self.identifier_idx[identifier])
identifier=identifier_name, self.create_node(
parent=identifier, tag=scopes[idx],
data=data) identifier=identifier_name,
parent=identifier,
data=data)
data.scope_name = identifier_name data.scope_name = identifier_name
continue continue
else: else:
...@@ -125,23 +132,28 @@ class HierarchicalTree(Tree): ...@@ -125,23 +132,28 @@ class HierarchicalTree(Tree):
self.identifier_idx[identifier] = 0 self.identifier_idx[identifier] = 0
else: else:
self.identifier_idx[identifier] += 1 self.identifier_idx[identifier] += 1
self.create_node(tag=scopes[idx], self.create_node(
identifier=identifier + SEPARATOR_IN_SCOPE + str(self.identifier_idx[identifier]), tag=scopes[idx],
parent=identifier, identifier=identifier + SEPARATOR_IN_SCOPE +
data=data) str(self.identifier_idx[identifier]),
parent=identifier,
data=data)
self.identifier_idx[identifier] += 1 self.identifier_idx[identifier] += 1
data = layer data = layer
self.create_node(tag=scopes[idx], self.create_node(
identifier=identifier + SEPARATOR_IN_SCOPE + str(self.identifier_idx[identifier]), tag=scopes[idx],
parent=identifier, identifier=identifier + SEPARATOR_IN_SCOPE +
data=data) str(self.identifier_idx[identifier]),
parent=identifier,
data=data)
continue continue
if idx == 0 and not self.contains(identifier): if idx == 0 and not self.contains(identifier):
data = layer if idx == len(scopes) - 1 else None data = layer if idx == len(scopes) - 1 else None
self.create_node(tag=scopes[idx], self.create_node(
identifier=identifier, tag=scopes[idx],
parent=self.pd_graph.name, identifier=identifier,
data=data) parent=self.pd_graph.name,
data=data)
else: else:
if idx == len(scopes) - 1: if idx == len(scopes) - 1:
if parent == "": if parent == "":
...@@ -153,7 +165,8 @@ class HierarchicalTree(Tree): ...@@ -153,7 +165,8 @@ class HierarchicalTree(Tree):
identifiers = list() identifiers = list()
for child in childs: for child in childs:
child_identifier = child.identifier child_identifier = child.identifier
if child_identifier.startswith(prefix) and child_identifier != prefix: if child_identifier.startswith(
prefix) and child_identifier != prefix:
identifiers.append(child_identifier) identifiers.append(child_identifier)
if len(identifiers) == 0: if len(identifiers) == 0:
identifier = prefix + "_0" identifier = prefix + "_0"
...@@ -162,14 +175,15 @@ class HierarchicalTree(Tree): ...@@ -162,14 +175,15 @@ class HierarchicalTree(Tree):
for id_obj in identifiers: for id_obj in identifiers:
identifier_ids.append(int(id_obj.split("_")[-1])) identifier_ids.append(int(id_obj.split("_")[-1]))
identifier_ids.sort() identifier_ids.sort()
identifier = prefix + "_{}".format(identifier_ids[-1] + 1) identifier = prefix + "_{}".format(identifier_ids[-1] +
1)
data = layer if idx == len(scopes) - 1 else None data = layer if idx == len(scopes) - 1 else None
self.create_node(tag=scopes[idx], self.create_node(
identifier=identifier, tag=scopes[idx],
parent=parent, identifier=identifier,
data=data) parent=parent,
data=data)
def update_hierarchical_order(self): def update_hierarchical_order(self):
""" 更新层次排序,使用一个字典存储该信息, """ 更新层次排序,使用一个字典存储该信息,
关键字为当前层次,值为节点名字。 关键字为当前层次,值为节点名字。
...@@ -201,32 +215,36 @@ class HierarchicalTree(Tree): ...@@ -201,32 +215,36 @@ class HierarchicalTree(Tree):
diff_attrs_column.append(column) diff_attrs_column.append(column)
break break
return diff_attrs_column return diff_attrs_column
def merge_node(self, sub_layers_list, attrs_table, node_name2sub_layers, module_name): def merge_node(self, sub_layers_list, attrs_table, node_name2sub_layers,
module_name):
""" 将一个scope的节点合成一个Module(Class),并将对应的Class代码 """ 将一个scope的节点合成一个Module(Class),并将对应的Class代码
放到code字符串中。 放到code字符串中。
""" """
def get_node_name(sub_layers): def get_node_name(sub_layers):
for k, v in node_name2sub_layers.items(): for k, v in node_name2sub_layers.items():
if v == sub_layers: if v == sub_layers:
node_name = k node_name = k
break break
return node_name return node_name
sub_layers = sub_layers_list[0] sub_layers = sub_layers_list[0]
node_name = get_node_name(sub_layers) node_name = get_node_name(sub_layers)
sub_layers, _, _ = rename_layers(sub_layers) sub_layers, _, _ = rename_layers(sub_layers)
diff_attrs_column = self.analyze_attrs_table(attrs_table) diff_attrs_column = self.analyze_attrs_table(attrs_table)
if module_name is None: if module_name is None:
module_name = node_name.replace("/", "_") #node_name.split("/")[-1] module_name = node_name.replace("/", "_") #node_name.split("/")[-1]
module_name = module_name[0].upper() + module_name[1:] module_name = module_name[0].upper() + module_name[1:]
if module_name in self.module_name2count: if module_name in self.module_name2count:
module_name = module_name + "_0" module_name = module_name + "_0"
code_str = gen_layer_code(self.pd_graph, sub_layers, module_name, code_str = gen_layer_code(
different_attrs=diff_attrs_column) self.pd_graph,
sub_layers,
module_name,
different_attrs=diff_attrs_column)
self.codes.append(code_str) self.codes.append(code_str)
for sub_layers in sub_layers_list: for sub_layers in sub_layers_list:
inputs, outputs = get_inputs_outputs(self.pd_graph, sub_layers) inputs, outputs = get_inputs_outputs(self.pd_graph, sub_layers)
...@@ -241,23 +259,26 @@ class HierarchicalTree(Tree): ...@@ -241,23 +259,26 @@ class HierarchicalTree(Tree):
mn = module_name.lower() + "__" mn = module_name.lower() + "__"
else: else:
mn = module_name.lower() mn = module_name.lower()
outputs = ["{}/{}".format(mn, self.module_name2count[module_name])] + outputs outputs = [
"{}/{}".format(mn, self.module_name2count[module_name])
] + outputs
node_name = get_node_name(sub_layers) node_name = get_node_name(sub_layers)
diff_attrs = dict() diff_attrs = dict()
for column in diff_attrs_column: for column in diff_attrs_column:
diff_attrs[column] = attrs_table.get(column).loc[node_name] diff_attrs[column] = attrs_table.get(column).loc[node_name]
node_name_seg = node_name.split(SEPARATOR_IN_SCOPE) node_name_seg = node_name.split(SEPARATOR_IN_SCOPE)
node_name_seg[-1] = module_name.lower() node_name_seg[-1] = module_name.lower()
new_node_name = SEPARATOR_IN_SCOPE.join(node_name_seg) new_node_name = SEPARATOR_IN_SCOPE.join(node_name_seg)
new_layer = PaddleLayer(id=list(sub_layers.keys())[-1], new_layer = PaddleLayer(
kernel="module", id=list(sub_layers.keys())[-1],
inputs=inputs_dict, kernel="module",
outputs=outputs, inputs=inputs_dict,
scope_name=new_node_name, outputs=outputs,
module=module_name, scope_name=new_node_name,
**diff_attrs) module=module_name,
**diff_attrs)
_, nn_param_nodes, _ = rename_layers(sub_layers, self.param_tree) _, nn_param_nodes, _ = rename_layers(sub_layers, self.param_tree)
param_node = PamareterNode(old_name=outputs[0]) param_node = PamareterNode(old_name=outputs[0])
for node in nn_param_nodes: for node in nn_param_nodes:
...@@ -272,28 +293,26 @@ class HierarchicalTree(Tree): ...@@ -272,28 +293,26 @@ class HierarchicalTree(Tree):
self.pd_graph.build() self.pd_graph.build()
self[node_name].data = new_layer self[node_name].data = new_layer
def find_subgraph_diff(self, module_name2sub_layers,
def find_subgraph_diff(self, module_name2sub_layers, module_name2sub_identifiers, node_name2sub_layers, name): module_name2sub_identifiers, node_name2sub_layers,
name):
""" 查找子图的diff,主要是输入参数的diff。 """ 查找子图的diff,主要是输入参数的diff。
""" """
sub_layers = module_name2sub_layers[name] sub_layers = module_name2sub_layers[name]
sub_identifiers = module_name2sub_identifiers[name] sub_identifiers = module_name2sub_identifiers[name]
new_sub_layers, new_sub_sequentials, sequentials2attrs_table = distinguish_sequential(self.pd_graph, new_sub_layers, new_sub_sequentials, sequentials2attrs_table = distinguish_sequential(
name, self.pd_graph, name, sub_layers, sub_identifiers,
sub_layers, node_name2sub_layers)
sub_identifiers,
node_name2sub_layers)
module_name2sub_layers.pop(name) module_name2sub_layers.pop(name)
module_name2sub_identifiers.pop(name) module_name2sub_identifiers.pop(name)
for k, v in new_sub_layers.items(): for k, v in new_sub_layers.items():
module_name2sub_layers[k] = v module_name2sub_layers[k] = v
module_name2sub_identifiers[k] = new_sub_sequentials[k] module_name2sub_identifiers[k] = new_sub_sequentials[k]
return sequentials2attrs_table return sequentials2attrs_table
def convert_subgraph_to_layer(self): def convert_subgraph_to_layer(self):
""" """
1. 根据_hierarchical_order,从最深的层次开始将 1. 根据_hierarchical_order,从最深的层次开始将
子图合并成layer(即合成节点)。 子图合并成layer(即合成节点)。
2. 根据参数名新旧对应关系,更新参数名。 2. 根据参数名新旧对应关系,更新参数名。
...@@ -314,14 +333,16 @@ class HierarchicalTree(Tree): ...@@ -314,14 +333,16 @@ class HierarchicalTree(Tree):
sub_layers = dict() sub_layers = dict()
sub_identifiers = dict() sub_identifiers = dict()
for successor_name in node_inst.successors(self.identifier): for successor_name in node_inst.successors(self.identifier):
sub_layers[self[successor_name].data.id] = self[successor_name].data sub_layers[self[successor_name].data.id] = self[
sub_identifiers[self[successor_name].data.id] = self[successor_name].data.scope_name.split("/")[-1] successor_name].data
sub_identifiers[self[successor_name].data.id] = self[
successor_name].data.scope_name.split("/")[-1]
node_name2sub_layers[node_name] = sub_layers node_name2sub_layers[node_name] = sub_layers
node_name_segs = node_name.split("/") node_name_segs = node_name.split("/")
# 获取Module的名字 # 获取Module的名字
module = self.script module = self.script
is_largest_module = False # 当前module是否是最外层的Module is_largest_module = False # 当前module是否是最外层的Module
for name_id, name in enumerate(node_name_segs): for name_id, name in enumerate(node_name_segs):
name = name.split("__")[0] name = name.split("__")[0]
if not hasattr(module, name): if not hasattr(module, name):
...@@ -345,23 +366,29 @@ class HierarchicalTree(Tree): ...@@ -345,23 +366,29 @@ class HierarchicalTree(Tree):
len(module_name2sub_layers[module_name][0][list(module_name2sub_layers[module_name][0].keys())[-1]].outputs): len(module_name2sub_layers[module_name][0][list(module_name2sub_layers[module_name][0].keys())[-1]].outputs):
break break
if module_name not in module_name2sub_layers: if module_name not in module_name2sub_layers:
module_name2sub_layers[module_name] = [sub_layers] module_name2sub_layers[
module_name2sub_identifiers[module_name] = [sub_identifiers] module_name] = [sub_layers]
module_name2sub_identifiers[
module_name] = [sub_identifiers]
else: else:
module_name2sub_layers[module_name].append(sub_layers) module_name2sub_layers[module_name].append(
module_name2sub_identifiers[module_name].append(sub_identifiers) sub_layers)
module_name2sub_identifiers[module_name].append(
sub_identifiers)
else: else:
module_name2sub_layers[module_name].append(sub_layers) module_name2sub_layers[module_name].append(
module_name2sub_identifiers[module_name].append(sub_identifiers) sub_layers)
module_name2sub_identifiers[module_name].append(
sub_identifiers)
else: else:
module_name2sub_layers[module_name] = [sub_layers] module_name2sub_layers[module_name] = [sub_layers]
module_name2sub_identifiers[module_name] = [sub_identifiers] module_name2sub_identifiers[
module_name] = [sub_identifiers]
module_names = list(module_name2sub_layers.keys()) module_names = list(module_name2sub_layers.keys())
for module_name in module_names: for module_name in module_names:
sequentials2attrs_table = self.find_subgraph_diff(module_name2sub_layers, sequentials2attrs_table = self.find_subgraph_diff(
module_name2sub_identifiers, module_name2sub_layers, module_name2sub_identifiers,
node_name2sub_layers, node_name2sub_layers, module_name)
module_name)
for name in sequentials2attrs_table.keys(): for name in sequentials2attrs_table.keys():
if name.startswith("Sequential"): if name.startswith("Sequential"):
# 若Module的名字为Sequential,则以scope_name的名字来命名,在merge_node中实现 # 若Module的名字为Sequential,则以scope_name的名字来命名,在merge_node中实现
...@@ -371,54 +398,57 @@ class HierarchicalTree(Tree): ...@@ -371,54 +398,57 @@ class HierarchicalTree(Tree):
while module_name in current_module_name_list: while module_name in current_module_name_list:
module_name += "__0" module_name += "__0"
current_module_name_list.append(module_name) current_module_name_list.append(module_name)
self.merge_node(module_name2sub_layers[name], self.merge_node(module_name2sub_layers[name],
sequentials2attrs_table[name], sequentials2attrs_table[name],
node_name2sub_layers, node_name2sub_layers, module_name)
module_name)
def update_parameters(self): def update_parameters(self):
""" 更新参数。 """ 更新参数。
""" """
self.param_tree.traverse() self.param_tree.traverse()
full_old_name_list = copy.deepcopy(list(self.pd_graph.parameters.keys())) full_old_name_list = copy.deepcopy(
list(self.pd_graph.parameters.keys()))
for old_name, new_name in self.param_tree.old2new.items(): for old_name, new_name in self.param_tree.old2new.items():
for full_old_name in full_old_name_list: for full_old_name in full_old_name_list:
if full_old_name.startswith("{}.".format(old_name)): if full_old_name.startswith("{}.".format(old_name)):
full_new_name = full_old_name.replace("{}.".format(old_name), "{}.".format(new_name)) full_new_name = full_old_name.replace(
"{}.".format(old_name), "{}.".format(new_name))
params = self.pd_graph.parameters.pop(full_old_name) params = self.pd_graph.parameters.pop(full_old_name)
self.pd_graph.parameters[full_new_name] = params self.pd_graph.parameters[full_new_name] = params
if full_old_name == old_name: if full_old_name == old_name:
full_new_name = full_old_name.replace(old_name, new_name) full_new_name = full_old_name.replace(old_name, new_name)
params = self.pd_graph.parameters.pop(full_old_name) params = self.pd_graph.parameters.pop(full_old_name)
self.pd_graph.parameters[full_new_name] = params self.pd_graph.parameters[full_new_name] = params
def save_source_files(self, save_dir): def save_source_files(self, save_dir):
def gen_main_code(): def gen_main_code():
input_data_name = ', '.join(self.pd_graph.inputs) input_data_name = ', '.join(self.pd_graph.inputs)
run_func_list = list() run_func_list = list()
run_func_list.append("def main({}):".format(input_data_name)) run_func_list.append("def main({}):".format(input_data_name))
run_func_list.append(" # There are {} inputs.".format(len(self.pd_graph.inputs_info))) run_func_list.append(" # There are {} inputs.".format(
len(self.pd_graph.inputs_info)))
for k, v in self.pd_graph.inputs_info.items(): for k, v in self.pd_graph.inputs_info.items():
run_func_list.append(" # {}: shape-{}, type-{}.".format(k, v[0], v[1])) run_func_list.append(" # {}: shape-{}, type-{}.".format(k, v[
run_func_list.extend( 0], v[1]))
[" paddle.disable_static()", run_func_list.extend([
" params = paddle.load('{}')".format(osp.join(osp.abspath(save_dir), "model.pdparams")), " paddle.disable_static()",
" model = {}()".format(self.pd_graph.name), " params = paddle.load('{}')".format(
" model.set_dict(params)", osp.join(osp.abspath(save_dir), "model.pdparams")),
" model.eval()", " model = {}()".format(self.pd_graph.name),
" out = model({})".format(input_data_name), " model.set_dict(params)", " model.eval()",
" return out"]) " out = model({})".format(input_data_name), " return out"
])
return "\n".join(run_func_list) return "\n".join(run_func_list)
self.update_hierarchical_order() self.update_hierarchical_order()
self.convert_subgraph_to_layer() self.convert_subgraph_to_layer()
self.update_parameters() self.update_parameters()
import_list = ["import paddle", import_list = ["import paddle",
"import math", "import math",
"from x2paddle.op_mapper.pytorch2paddle " + \ "from x2paddle.op_mapper.pytorch2paddle " + \
"import pytorch_custom_layer as x2paddle_nn" "import pytorch_custom_layer as x2paddle_nn",
"\n",] "",]
import_str = "\n".join(import_list) import_str = "\n".join(import_list) + "\n"
if not osp.exists(save_dir): if not osp.exists(save_dir):
os.makedirs(save_dir) os.makedirs(save_dir)
f = open(osp.join(save_dir, 'x2paddle_code.py'), 'w') f = open(osp.join(save_dir, 'x2paddle_code.py'), 'w')
......
...@@ -263,7 +263,7 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=dict()): ...@@ -263,7 +263,7 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=dict()):
layer.kernel.startswith("custom_layer"): layer.kernel.startswith("custom_layer"):
line = "self.{}".format(layer.outputs[0]) line = "self.{}".format(layer.outputs[0])
if layer.kernel.startswith("custom_layer"): if layer.kernel.startswith("custom_layer"):
line += "= x2paddle_nn.{}(".format(layer.kernel.split(":")[-1]) line += " = x2paddle_nn.{}(".format(layer.kernel.split(":")[-1])
else: else:
line += " = {}(".format(layer.kernel) line += " = {}(".format(layer.kernel)
for k, v in layer.attrs.items(): for k, v in layer.attrs.items():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册