diff --git a/docs/pytorch_project_convertor/after_convert.md b/docs/pytorch_project_convertor/after_convert.md index c0c5ecbbbd3f02d4e82313bceac3e80aca9d92a9..d2766c566f8da125ee69115ad45af860fea341f9 100644 --- a/docs/pytorch_project_convertor/after_convert.md +++ b/docs/pytorch_project_convertor/after_convert.md @@ -24,19 +24,18 @@ class VocDataset(paddle.io.Dataset): ... ``` -3. 若存在Tensor对比操作(包含==、!=、<、<=、>、>=操作符),在对比操作符前添加对Tensor类型的判断,如果为bool型强转为int型,并在对比后转换回bool型。 +3. 若存在Tensor对比操作(包含==、!=、<、<=、>、>=操作符),在对比操作符前添加对Tensor类型的判断,如果为非bool型强转为bool型,并在对比后转换回bool型。 ``` -# 原始代码(其中c_trg是Tensor) +# 原始代码(其中c_trg是非bool型的Tensor) c_trg = c_trg == 0 # 替换后代码 -is_bool = False -if str(c_trg.dtype) == "VarType.BOOL": - c_trg = c_trg.cast("int32") - is_bool = True -c_trg = c_trg == 0 -if is_bool: - c_trg = c_trg.cast("bool") +c_trg = c_trg.cast("int32") +c_trg_tmp = paddle.zeros_like(c_trg) +paddle.assign(c_trg, c_trg_tmp) +c_trg_tmp = c_trg_tmp.cast("bool") +c_trg_tmp[:, i] = c_trg[:, i] == 0 +c_trg = c_trg_tmp ``` 4. 如若转换后的运行代码的入口为sh脚本文件,且其中有预训练模型路径,应将其中的预训练模型的路径字符串中的“.pth”、“.pt”、“.ckpt”替换为“.pdiparams”。 diff --git a/docs/pytorch_project_convertor/demo/stargan.md b/docs/pytorch_project_convertor/demo/stargan.md index 17a1b0000f2f829423d03ccb875c1b65b8d0836e..ceb2d2227cf227aed536730d681999ca84140b0e 100644 --- a/docs/pytorch_project_convertor/demo/stargan.md +++ b/docs/pytorch_project_convertor/demo/stargan.md @@ -80,17 +80,15 @@ class Solver(object): if j != i: c_trg[:, j] = 0 else: - # 如果为bool型,需要强转为int32, - # 在17-20行实现 - is_bool = False - if str(c_trg.dtype) == "VarType.BOOL": - c_trg = c_trg.cast("int32") - is_bool = True - c_trg[:, i] = (c_trg[:, i] == 0) - # 如果为bool类型转换为原类型 - # 在23-24行实现 - if is_bool: - c_trg = c_trg.cast("bool") + # 如果为非int型,需要强转为int32, + # 在18-22行实现 + # c_trg[:, i] = (c_trg[:, i] == 0) + c_trg = c_trg.cast("int32") + c_trg_tmp = paddle.zeros_like(c_trg) + paddle.assign(c_trg, c_trg_tmp) + c_trg_tmp = c_trg_tmp.cast("bool") + c_trg_tmp[:, i] = c_trg[:, i] == 0 + c_trg = c_trg_tmp ... ... ... diff --git a/x2paddle/op_mapper/dygraph/pytorch2paddle/aten.py b/x2paddle/op_mapper/dygraph/pytorch2paddle/aten.py index 669465cf218cf23a4d506e102c0012256503ee4a..6f46a3aeb274b05ba1414ba5e79793edd37b1225 100644 --- a/x2paddle/op_mapper/dygraph/pytorch2paddle/aten.py +++ b/x2paddle/op_mapper/dygraph/pytorch2paddle/aten.py @@ -15,7 +15,8 @@ import copy import numpy as np -from x2paddle.core.util import * +from x2paddle.core.util import name_generator, string +from x2paddle.utils import paddle_dtypes from x2paddle.core.program import PaddleGraph dtype_dict = { @@ -182,13 +183,8 @@ def aten_addmm(mapper, graph, node): # 获取当前节点输出的list current_outputs = [output_name] # 处理输入0,即%150 - mapper._check_input( - graph, - inputs_node[0], - inputs_name[0], - current_outputs, - scope_name, - add_dim=True) + mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, + scope_name) layer_inputs["input"] = inputs_name[0] # 处理输入1,即%input.3 mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs, @@ -247,13 +243,8 @@ def aten_add(mapper, graph, node): scope_name) layer_inputs["x"] = inputs_name[0] # 处理输入1,即%288 - mapper._check_input( - graph, - inputs_node[1], - inputs_name[1], - current_outputs, - scope_name, - add_dim=True) + mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs, + scope_name) layer_inputs["y"] = inputs_name[1] # 获取当前节点输入的list current_inputs = list(layer_inputs.values()) @@ -289,13 +280,8 @@ def aten_add_(mapper, graph, node): scope_name) layer_inputs["x"] = inputs_name[0] # 处理输入1,即%150 - mapper._check_input( - graph, - inputs_node[1], - inputs_name[1], - current_outputs, - scope_name, - add_dim=True) + mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs, + scope_name) layer_inputs["y"] = inputs_name[1] # 获取当前节点输入的list current_inputs = list(layer_inputs.values()) @@ -745,13 +731,8 @@ def aten_bmm(mapper, graph, node): scope_name) layer_inputs["x"] = inputs_name[0] # 处理输入1,即%288 - mapper._check_input( - graph, - inputs_node[1], - inputs_name[1], - current_outputs, - scope_name, - add_dim=True) + mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs, + scope_name) layer_inputs["y"] = inputs_name[1] # 获取当前节点输入的list current_inputs = list(layer_inputs.values()) @@ -1854,17 +1835,12 @@ def aten_expand_as(mapper, graph, node): inputs={"input": inputs_name[0]}, outputs=[inputs_name[0] + "_type"], scope_name=scope_name) - graph.add_layer( - "prim.str", - inputs={"input": inputs_name[0] + "_type"}, - outputs=[inputs_name[0] + "_type"], - scope_name=scope_name) graph.add_layer( "prim.eq", inputs={"x": inputs_name[0] + "_type"}, outputs=[inputs_name[0] + "_cond"], scope_name=scope_name, - y=string("VarType.BOOL")) + y=paddle_dtypes.t_bool) graph.add_layer( "prim.if", {'input': inputs_name[0] + "_cond"}, outputs=[inputs_name[0] + "_if1"], @@ -2101,10 +2077,11 @@ def aten_floor(mapper, graph, node): outputs=[inputs_name[0] + "_type"], scope_name=scope_name) graph.add_layer( - "prim.startswith", {'input': inputs_name[0] + "_type"}, + "prim.eq", + inputs={"x": inputs_name[0] + "_type"}, outputs=[inputs_name[0] + "_cond"], scope_name=scope_name, - start_str=string("VarType")) + y=paddle_dtypes.t_bool) graph.add_layer( "prim.if", {'input': inputs_name[0] + "_cond"}, outputs=[inputs_name[0] + "_if"], @@ -5004,13 +4981,8 @@ def aten_sub(mapper, graph, node): scope_name) layer_inputs["x"] = inputs_name[0] # 处理输入1,即%836 - mapper._check_input( - graph, - inputs_node[1], - inputs_name[1], - current_outputs, - scope_name, - add_dim=True) + mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs, + scope_name) layer_inputs["y"] = inputs_name[1] # 处理输入2,即%3 if len(inputs_node) > 2: diff --git a/x2paddle/op_mapper/dygraph/pytorch2paddle/prim.py b/x2paddle/op_mapper/dygraph/pytorch2paddle/prim.py index cc9c61efb110c6ba68363180680fda1947c3998b..fa8698a6fddd89836fc4517d1bbfef187d60ad24 100644 --- a/x2paddle/op_mapper/dygraph/pytorch2paddle/prim.py +++ b/x2paddle/op_mapper/dygraph/pytorch2paddle/prim.py @@ -15,7 +15,7 @@ import torch import numpy as np -from x2paddle.core.util import * +from x2paddle.core.util import string def prim_Constant(mapper, graph, node): diff --git a/x2paddle/op_mapper/dygraph/pytorch2paddle/pytorch_custom_layer/gather.py b/x2paddle/op_mapper/dygraph/pytorch2paddle/pytorch_custom_layer/gather.py index 6ef9d488587cdbe01dbf7ae343a008094113bc81..56cb9e3299874101c617b8f27496ff962713757e 100644 --- a/x2paddle/op_mapper/dygraph/pytorch2paddle/pytorch_custom_layer/gather.py +++ b/x2paddle/op_mapper/dygraph/pytorch2paddle/pytorch_custom_layer/gather.py @@ -13,13 +13,13 @@ # limitations under the License. import paddle +from x2paddle.core.util import * + class Gather(object): def __init__(self, dim): self.dim = dim - self.dtype_mapping = {"VarType.INT32": "int32", - "VarType.INT64": "int64"} - + def __call__(self, x, index): if self.dim < 0: self.dim += len(x.shape) @@ -31,27 +31,27 @@ class Gather(object): index_range[0] = self.dim index_range[self.dim] = 0 index_swaped = paddle.transpose(index, perm=index_range) - dtype = self.dtype_mapping[str(index.dtype)] - + dtype = index.dtype + x_shape = paddle.shape(x_swaped) index_shape = paddle.shape(index_swaped) - - prod = paddle.prod(x_shape, dtype=dtype) / x_shape[0] - + + prod = paddle.cast(paddle.prod(x_shape), dtype=dtype) / x_shape[0] + x_swaped_flattend = paddle.flatten(x_swaped) index_swaped_flattend = paddle.flatten(index_swaped) index_swaped_flattend *= prod - + bias = paddle.arange(start=0, end=prod, dtype=dtype) bias = paddle.reshape(bias, x_shape[1:]) bias = paddle.crop(bias, index_shape[1:]) bias = paddle.flatten(bias) bias = paddle.tile(bias, [index_shape[0]]) index_swaped_flattend += bias - + gathered = paddle.index_select(x_swaped_flattend, index_swaped_flattend) gathered = paddle.reshape(gathered, index_swaped.shape) - + out = paddle.transpose(gathered, perm=x_range) return out diff --git a/x2paddle/op_mapper/dygraph/pytorch2paddle/pytorch_op_mapper.py b/x2paddle/op_mapper/dygraph/pytorch2paddle/pytorch_op_mapper.py index 931ace6efd47ce3e11ae0e2ce9f8ab84046ab989..912c7fac9b4e3ae439e92adad9ffca3ac6ad2155 100644 --- a/x2paddle/op_mapper/dygraph/pytorch2paddle/pytorch_op_mapper.py +++ b/x2paddle/op_mapper/dygraph/pytorch2paddle/pytorch_op_mapper.py @@ -16,7 +16,7 @@ import torch import numpy as np from x2paddle.core.op_mapper import OpMapper -from x2paddle.core.util import * +from x2paddle.core.util import string from x2paddle.core.program import PaddleGraph from x2paddle.op_mapper.dygraph.pytorch2paddle import prim from x2paddle.op_mapper.dygraph.pytorch2paddle import aten @@ -169,18 +169,10 @@ class PyTorchOpMapper(OpMapper): outputs_name.append(output_name) return outputs_name - def _check_input(self, - graph, - node, - output_name, - node_outputs, - scope_name, - add_dim=False): + def _check_input(self, graph, node, output_name, node_outputs, scope_name): if node.kind() == "prim::GetAttr": param = self.pytorch_params[output_name] if isinstance(param, np.ndarray): - if add_dim: - param = param[np.newaxis, :] self.paddle_params[output_name] = param layer_id = graph.add_layer( "self.create_parameter", @@ -208,8 +200,6 @@ class PyTorchOpMapper(OpMapper): else: if id1_part[i] == "0" and id2_part[ i] == "1": - if add_dim: - param = param[np.newaxis, :] self.paddle_params[output_name] = param layer_id = graph.add_layer( "self.create_parameter", diff --git a/x2paddle/project_convertor/pytorch/mapper.py b/x2paddle/project_convertor/pytorch/mapper.py index 452e0233e34997bb6aceca1198f748573160d7ec..aa62921e6c2a6fb6ac6d83b74908c18ecffdc822 100644 --- a/x2paddle/project_convertor/pytorch/mapper.py +++ b/x2paddle/project_convertor/pytorch/mapper.py @@ -13,7 +13,7 @@ # limitations under the License. from x2paddle.project_convertor.pytorch.api_mapper import * -from x2paddle.utils import * +from x2paddle.utils import is_new_version OPTIMIZER_MAPPER = { "torch.optim": ["paddle.optimizer", None], @@ -25,7 +25,6 @@ OPTIMIZER_MAPPER = { ["paddle.optimizer.lr.MultiStepDecay", LRScheculerMapper], "torch.optim.Adam": ["x2paddle.torch2paddle.Adam", None], "torch.optim.SGD": ["x2paddle.torch2paddle.Momentum", None] - } NN_MAPPER = { @@ -169,11 +168,42 @@ DIST_MAPPER = { ["x2paddle.torch2paddle.init_process_group", None] } -DTYPE_MAPPER = { - "torch.float32": [string("float32"), None], - "torch.long": [string("int64"), None], - "torch.bool": [string("bool"), None] -} +if is_new_version: + DTYPE_MAPPER = { + "torch.float16": ["paddle.float16", None], + "torch.half": ["paddle.float16", None], + "torch.float32": ["paddle.float32", None], + "torch.float": ["paddle.float32", None], + "torch.float64": ["paddle.float64", None], + "torch.double": ["paddle.float64", None], + "torch.uint8": ["paddle.uint8", None], + "torch.int8": ["paddle.int8", None], + "torch.int16": ["paddle.int16", None], + "torch.short": ["paddle.int16", None], + "torch.int32": ["paddle.int32", None], + "torch.int": ["paddle.int32", None], + "torch.int64": ["paddle.int64", None], + "torch.long": ["paddle.int64", None], + "torch.bool": ["paddle.bool", None], + } +else: + DTYPE_MAPPER = { + "torch.float16": [string("float16"), None], + "torch.half": [string("float16"), None], + "torch.float32": [string("float32"), None], + "torch.float": [string("float32"), None], + "torch.float64": [string("float64"), None], + "torch.double": [string("float64"), None], + "torch.uint8": [string("uint8"), None], + "torch.int8": [string("int8"), None], + "torch.int16": [string("int16"), None], + "torch.short": [string("int16"), None], + "torch.int32": [string("int32"), None], + "torch.int": [string("int32"), None], + "torch.int64": [string("int64"), None], + "torch.long": [string("int64"), None], + "torch.bool": [string("bool"), None], + } TORCHVISION_MAPPER = { "torchvision": ["paddle.vision", None], diff --git a/x2paddle/project_convertor/pytorch/torch2paddle/nn_init.py b/x2paddle/project_convertor/pytorch/torch2paddle/nn_init.py index 01d9ccb4ff1a901ea13332895083668440ad49c2..0d5e1af6aa2d6c546369312b1a9867bab844c9f0 100644 --- a/x2paddle/project_convertor/pytorch/torch2paddle/nn_init.py +++ b/x2paddle/project_convertor/pytorch/torch2paddle/nn_init.py @@ -19,6 +19,7 @@ from paddle.fluid import framework from paddle.fluid.core import VarDesc from paddle.fluid.initializer import XavierInitializer, MSRAInitializer from paddle.fluid.data_feeder import check_variable_and_dtype +from x2paddle.utils import paddle_dtypes def _calculate_fan_in_and_fan_out(var): @@ -101,8 +102,8 @@ class KaimingNormal(MSRAInitializer): self._seed = block.program.random_seed # to be compatible of fp16 initalizers - if var.dtype == VarDesc.VarType.FP16: - out_dtype = VarDesc.VarType.FP32 + if var.dtype == paddle_dtypes.t_float16: + out_dtype = paddle_dtypes.t_float32 out_var = block.create_var( name=unique_name.generate(".".join( ['masra_init', var.name, 'tmp'])), @@ -169,8 +170,8 @@ class XavierNormal(XavierInitializer): self._seed = block.program.random_seed # to be compatible of fp16 initalizers - if var.dtype == VarDesc.VarType.FP16: - out_dtype = VarDesc.VarType.FP32 + if var.dtype == paddle_dtypes.t_float16: + out_dtype = paddle_dtypes.t_float32 out_var = block.create_var( name=unique_name.generate(".".join( ['xavier_init', var.name, 'tmp'])), @@ -195,7 +196,7 @@ class XavierNormal(XavierInitializer): "seed": self._seed }, stop_gradient=True) - if var.dtype == VarDesc.VarType.FP16: + if var.dtype == paddle_dtypes.t_float16: block.append_op( type="cast", inputs={"X": out_var}, @@ -233,8 +234,8 @@ class XavierUniform(XavierInitializer): self._seed = block.program.random_seed # to be compatible of fp16 initalizers - if var.dtype == VarDesc.VarType.FP16: - out_dtype = VarDesc.VarType.FP32 + if var.dtype == paddle_dtypes.t_float16: + out_dtype = paddle_dtypes.t_float32 out_var = block.create_var( name=unique_name.generate(".".join( ['xavier_init', var.name, 'tmp'])), @@ -260,7 +261,7 @@ class XavierUniform(XavierInitializer): "seed": self._seed }, stop_gradient=True) - if var.dtype == VarDesc.VarType.FP16: + if var.dtype == paddle_dtypes.t_float16: block.append_op( type="cast", inputs={"X": out_var}, diff --git a/x2paddle/project_convertor/pytorch/torch2paddle/varbase.py b/x2paddle/project_convertor/pytorch/torch2paddle/varbase.py index 7b0e932ed9569a02c6820ec449712818c622e634..c1596769bbc6fab1436db401e5e8e68e22ffd835 100644 --- a/x2paddle/project_convertor/pytorch/torch2paddle/varbase.py +++ b/x2paddle/project_convertor/pytorch/torch2paddle/varbase.py @@ -14,6 +14,7 @@ import paddle from paddle.fluid.core import VarBase +from x2paddle.utils import paddle_dtypes def is_condition_one(idx): @@ -23,8 +24,8 @@ def is_condition_one(idx): a[mask, :] a[mask, ...] """ - if not (isinstance(idx[0], paddle.Tensor) and - str(idx[0].dtype) == "VarType.BOOL"): + if not (isinstance(idx[0], paddle.Tensor) and \ + idx[0].dtype == paddle_dtypes.t_bool): return False if len(idx) == 1: return False @@ -57,13 +58,13 @@ VarBase.tmp = VarBase.__getitem__ def __getitem__(self, idx): is_bool = False - if str(self.dtype) == "VarType.BOOL": + if self.dtype == paddle_dtypes.t_bool: self = self.cast("int32") is_bool = True if isinstance(idx, paddle.Tensor) and len(idx.shape) == 1: out = paddle.gather(self, idx) return out.cast("bool") if is_bool else out - elif isinstance(idx, paddle.Tensor) and str(idx.dtype) == "VarType.BOOL": + elif isinstance(idx, paddle.Tensor) and idx.dtype == paddle_dtypes.t_bool: idx = paddle.cast(idx, "int32") idx = paddle.nonzero(idx) out = paddle.gather_nd(self, idx) @@ -100,7 +101,7 @@ VarBase.setitem_tmp = VarBase.__setitem__ def __setitem__(self, idx, value): - if isinstance(idx, paddle.Tensor) and str(idx.dtype) == "VarType.BOOL": + if isinstance(idx, paddle.Tensor) and idx.dtype == paddle_dtypes.t_bool: """ a = paddle.to_tensor(np.array([1,2,3]).astype("float32")) mask = paddle.to_tensor(np.array([True, False, True]).astype("bool")) diff --git a/x2paddle/utils.py b/x2paddle/utils.py index 7ffa1ec4ea7512f6fa5e30de5d6862302832a408..37409cdb2d187d3b5e077ef92e3edc9f6b161db0 100644 --- a/x2paddle/utils.py +++ b/x2paddle/utils.py @@ -13,8 +13,49 @@ # See the License for the specific language governing permissions and # limitations under the License. +import paddle + def string(param): """ 生成字符串。 """ return "\'{}\'".format(param) + + +def check_version(): + version = paddle.__version__ + v0, v1, v2 = version.split('.') + if not ((v0 == '0' and v1 == '0' and v2 == '0') or + (int(v0) >= 2 and int(v1) >= 1)): + return False + else: + return True + + +class PaddleDtypes(): + def __init__(self, is_new_version=True): + if is_new_version: + self.t_float16 = paddle.float16 + self.t_float32 = paddle.float32 + self.t_float64 = paddle.float64 + self.t_uint8 = paddle.uint8 + self.t_int8 = paddle.int8 + self.t_int16 = paddle.int16 + self.t_int32 = paddle.int32 + self.t_int64 = paddle.int64 + self.t_bool = paddle.bool + else: + from paddle.fluid.core import VarDesc + self.t_float16 = VarDesc.VarType.FP16 + self.t_float32 = VarDesc.VarType.FP32 + self.t_float64 = VarDesc.VarType.FP64 + self.t_uint8 = VarDesc.VarType.UINT8 + self.t_int8 = VarDesc.VarType.INT8 + self.t_int16 = VarDesc.VarType.INT16 + self.t_int32 = VarDesc.VarType.INT32 + self.t_int64 = VarDesc.VarType.INT64 + self.t_bool = VarDesc.VarType.BOOL + + +is_new_version = check_version() +paddle_dtypes = PaddleDtypes(is_new_version)