Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
X2Paddle
提交
592894bd
X
X2Paddle
项目概览
PaddlePaddle
/
X2Paddle
1 年多 前同步成功
通知
328
Star
698
Fork
167
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
26
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
X
X2Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
26
Issue
26
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
592894bd
编写于
5月 21, 2021
作者:
S
SunAhong1993
提交者:
GitHub
5月 21, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
For Paddle2.1 (#596)
* Update stargan.md * fix the type * fix * for 2.0.0 * fix
上级
4a2df7e1
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
140 addition
and
108 deletion
+140
-108
docs/pytorch_project_convertor/after_convert.md
docs/pytorch_project_convertor/after_convert.md
+8
-9
docs/pytorch_project_convertor/demo/stargan.md
docs/pytorch_project_convertor/demo/stargan.md
+9
-11
x2paddle/op_mapper/dygraph/pytorch2paddle/aten.py
x2paddle/op_mapper/dygraph/pytorch2paddle/aten.py
+16
-44
x2paddle/op_mapper/dygraph/pytorch2paddle/prim.py
x2paddle/op_mapper/dygraph/pytorch2paddle/prim.py
+1
-1
x2paddle/op_mapper/dygraph/pytorch2paddle/pytorch_custom_layer/gather.py
...per/dygraph/pytorch2paddle/pytorch_custom_layer/gather.py
+11
-11
x2paddle/op_mapper/dygraph/pytorch2paddle/pytorch_op_mapper.py
...dle/op_mapper/dygraph/pytorch2paddle/pytorch_op_mapper.py
+2
-12
x2paddle/project_convertor/pytorch/mapper.py
x2paddle/project_convertor/pytorch/mapper.py
+37
-7
x2paddle/project_convertor/pytorch/torch2paddle/nn_init.py
x2paddle/project_convertor/pytorch/torch2paddle/nn_init.py
+9
-8
x2paddle/project_convertor/pytorch/torch2paddle/varbase.py
x2paddle/project_convertor/pytorch/torch2paddle/varbase.py
+6
-5
x2paddle/utils.py
x2paddle/utils.py
+41
-0
未找到文件。
docs/pytorch_project_convertor/after_convert.md
浏览文件 @
592894bd
...
...
@@ -24,19 +24,18 @@ class VocDataset(paddle.io.Dataset):
...
```
3.
若存在Tensor对比操作(包含==、!=、
<
、<=、
>
、>=操作符),在对比操作符前添加对Tensor类型的判断,如果为
bool型强转为int
型,并在对比后转换回bool型。
3.
若存在Tensor对比操作(包含==、!=、
<
、<=、
>
、>=操作符),在对比操作符前添加对Tensor类型的判断,如果为
非bool型强转为bool
型,并在对比后转换回bool型。
```
# 原始代码(其中c_trg是Tensor)
# 原始代码(其中c_trg是
非bool型的
Tensor)
c_trg = c_trg == 0
# 替换后代码
is_bool = False
if str(c_trg.dtype) == "VarType.BOOL":
c_trg = c_trg.cast("int32")
is_bool = True
c_trg = c_trg == 0
if is_bool:
c_trg = c_trg.cast("bool")
c_trg = c_trg.cast("int32")
c_trg_tmp = paddle.zeros_like(c_trg)
paddle.assign(c_trg, c_trg_tmp)
c_trg_tmp = c_trg_tmp.cast("bool")
c_trg_tmp[:, i] = c_trg[:, i] == 0
c_trg = c_trg_tmp
```
4.
如若转换后的运行代码的入口为sh脚本文件,且其中有预训练模型路径,应将其中的预训练模型的路径字符串中的“.pth”、“.pt”、“.ckpt”替换为“.pdiparams”。
docs/pytorch_project_convertor/demo/stargan.md
浏览文件 @
592894bd
...
...
@@ -80,17 +80,15 @@ class Solver(object):
if
j
!=
i
:
c_trg
[:,
j
]
=
0
else
:
# 如果为bool型,需要强转为int32,
# 在17-20行实现
is_bool
=
False
if
str
(
c_trg
.
dtype
)
==
"VarType.BOOL"
:
c_trg
=
c_trg
.
cast
(
"int32"
)
is_bool
=
True
c_trg
[:,
i
]
=
(
c_trg
[:,
i
]
==
0
)
# 如果为bool类型转换为原类型
# 在23-24行实现
if
is_bool
:
c_trg
=
c_trg
.
cast
(
"bool"
)
# 如果为非int型,需要强转为int32,
# 在18-22行实现
# c_trg[:, i] = (c_trg[:, i] == 0)
c_trg
=
c_trg
.
cast
(
"int32"
)
c_trg_tmp
=
paddle
.
zeros_like
(
c_trg
)
paddle
.
assign
(
c_trg
,
c_trg_tmp
)
c_trg_tmp
=
c_trg_tmp
.
cast
(
"bool"
)
c_trg_tmp
[:,
i
]
=
c_trg
[:,
i
]
==
0
c_trg
=
c_trg_tmp
...
...
...
...
...
x2paddle/op_mapper/dygraph/pytorch2paddle/aten.py
浏览文件 @
592894bd
...
...
@@ -15,7 +15,8 @@
import
copy
import
numpy
as
np
from
x2paddle.core.util
import
*
from
x2paddle.core.util
import
name_generator
,
string
from
x2paddle.utils
import
paddle_dtypes
from
x2paddle.core.program
import
PaddleGraph
dtype_dict
=
{
...
...
@@ -182,13 +183,8 @@ def aten_addmm(mapper, graph, node):
# 获取当前节点输出的list
current_outputs
=
[
output_name
]
# 处理输入0,即%150
mapper
.
_check_input
(
graph
,
inputs_node
[
0
],
inputs_name
[
0
],
current_outputs
,
scope_name
,
add_dim
=
True
)
mapper
.
_check_input
(
graph
,
inputs_node
[
0
],
inputs_name
[
0
],
current_outputs
,
scope_name
)
layer_inputs
[
"input"
]
=
inputs_name
[
0
]
# 处理输入1,即%input.3
mapper
.
_check_input
(
graph
,
inputs_node
[
1
],
inputs_name
[
1
],
current_outputs
,
...
...
@@ -247,13 +243,8 @@ def aten_add(mapper, graph, node):
scope_name
)
layer_inputs
[
"x"
]
=
inputs_name
[
0
]
# 处理输入1,即%288
mapper
.
_check_input
(
graph
,
inputs_node
[
1
],
inputs_name
[
1
],
current_outputs
,
scope_name
,
add_dim
=
True
)
mapper
.
_check_input
(
graph
,
inputs_node
[
1
],
inputs_name
[
1
],
current_outputs
,
scope_name
)
layer_inputs
[
"y"
]
=
inputs_name
[
1
]
# 获取当前节点输入的list
current_inputs
=
list
(
layer_inputs
.
values
())
...
...
@@ -289,13 +280,8 @@ def aten_add_(mapper, graph, node):
scope_name
)
layer_inputs
[
"x"
]
=
inputs_name
[
0
]
# 处理输入1,即%150
mapper
.
_check_input
(
graph
,
inputs_node
[
1
],
inputs_name
[
1
],
current_outputs
,
scope_name
,
add_dim
=
True
)
mapper
.
_check_input
(
graph
,
inputs_node
[
1
],
inputs_name
[
1
],
current_outputs
,
scope_name
)
layer_inputs
[
"y"
]
=
inputs_name
[
1
]
# 获取当前节点输入的list
current_inputs
=
list
(
layer_inputs
.
values
())
...
...
@@ -745,13 +731,8 @@ def aten_bmm(mapper, graph, node):
scope_name
)
layer_inputs
[
"x"
]
=
inputs_name
[
0
]
# 处理输入1,即%288
mapper
.
_check_input
(
graph
,
inputs_node
[
1
],
inputs_name
[
1
],
current_outputs
,
scope_name
,
add_dim
=
True
)
mapper
.
_check_input
(
graph
,
inputs_node
[
1
],
inputs_name
[
1
],
current_outputs
,
scope_name
)
layer_inputs
[
"y"
]
=
inputs_name
[
1
]
# 获取当前节点输入的list
current_inputs
=
list
(
layer_inputs
.
values
())
...
...
@@ -1854,17 +1835,12 @@ def aten_expand_as(mapper, graph, node):
inputs
=
{
"input"
:
inputs_name
[
0
]},
outputs
=
[
inputs_name
[
0
]
+
"_type"
],
scope_name
=
scope_name
)
graph
.
add_layer
(
"prim.str"
,
inputs
=
{
"input"
:
inputs_name
[
0
]
+
"_type"
},
outputs
=
[
inputs_name
[
0
]
+
"_type"
],
scope_name
=
scope_name
)
graph
.
add_layer
(
"prim.eq"
,
inputs
=
{
"x"
:
inputs_name
[
0
]
+
"_type"
},
outputs
=
[
inputs_name
[
0
]
+
"_cond"
],
scope_name
=
scope_name
,
y
=
string
(
"VarType.BOOL"
)
)
y
=
paddle_dtypes
.
t_bool
)
graph
.
add_layer
(
"prim.if"
,
{
'input'
:
inputs_name
[
0
]
+
"_cond"
},
outputs
=
[
inputs_name
[
0
]
+
"_if1"
],
...
...
@@ -2101,10 +2077,11 @@ def aten_floor(mapper, graph, node):
outputs
=
[
inputs_name
[
0
]
+
"_type"
],
scope_name
=
scope_name
)
graph
.
add_layer
(
"prim.startswith"
,
{
'input'
:
inputs_name
[
0
]
+
"_type"
},
"prim.eq"
,
inputs
=
{
"x"
:
inputs_name
[
0
]
+
"_type"
},
outputs
=
[
inputs_name
[
0
]
+
"_cond"
],
scope_name
=
scope_name
,
start_str
=
string
(
"VarType"
)
)
y
=
paddle_dtypes
.
t_bool
)
graph
.
add_layer
(
"prim.if"
,
{
'input'
:
inputs_name
[
0
]
+
"_cond"
},
outputs
=
[
inputs_name
[
0
]
+
"_if"
],
...
...
@@ -5004,13 +4981,8 @@ def aten_sub(mapper, graph, node):
scope_name
)
layer_inputs
[
"x"
]
=
inputs_name
[
0
]
# 处理输入1,即%836
mapper
.
_check_input
(
graph
,
inputs_node
[
1
],
inputs_name
[
1
],
current_outputs
,
scope_name
,
add_dim
=
True
)
mapper
.
_check_input
(
graph
,
inputs_node
[
1
],
inputs_name
[
1
],
current_outputs
,
scope_name
)
layer_inputs
[
"y"
]
=
inputs_name
[
1
]
# 处理输入2,即%3
if
len
(
inputs_node
)
>
2
:
...
...
x2paddle/op_mapper/dygraph/pytorch2paddle/prim.py
浏览文件 @
592894bd
...
...
@@ -15,7 +15,7 @@
import
torch
import
numpy
as
np
from
x2paddle.core.util
import
*
from
x2paddle.core.util
import
string
def
prim_Constant
(
mapper
,
graph
,
node
):
...
...
x2paddle/op_mapper/dygraph/pytorch2paddle/pytorch_custom_layer/gather.py
浏览文件 @
592894bd
...
...
@@ -13,13 +13,13 @@
# limitations under the License.
import
paddle
from
x2paddle.core.util
import
*
class
Gather
(
object
):
def
__init__
(
self
,
dim
):
self
.
dim
=
dim
self
.
dtype_mapping
=
{
"VarType.INT32"
:
"int32"
,
"VarType.INT64"
:
"int64"
}
def
__call__
(
self
,
x
,
index
):
if
self
.
dim
<
0
:
self
.
dim
+=
len
(
x
.
shape
)
...
...
@@ -31,27 +31,27 @@ class Gather(object):
index_range
[
0
]
=
self
.
dim
index_range
[
self
.
dim
]
=
0
index_swaped
=
paddle
.
transpose
(
index
,
perm
=
index_range
)
dtype
=
self
.
dtype_mapping
[
str
(
index
.
dtype
)]
dtype
=
index
.
dtype
x_shape
=
paddle
.
shape
(
x_swaped
)
index_shape
=
paddle
.
shape
(
index_swaped
)
prod
=
paddle
.
prod
(
x_shape
,
dtype
=
dtype
)
/
x_shape
[
0
]
prod
=
paddle
.
cast
(
paddle
.
prod
(
x_shape
)
,
dtype
=
dtype
)
/
x_shape
[
0
]
x_swaped_flattend
=
paddle
.
flatten
(
x_swaped
)
index_swaped_flattend
=
paddle
.
flatten
(
index_swaped
)
index_swaped_flattend
*=
prod
bias
=
paddle
.
arange
(
start
=
0
,
end
=
prod
,
dtype
=
dtype
)
bias
=
paddle
.
reshape
(
bias
,
x_shape
[
1
:])
bias
=
paddle
.
crop
(
bias
,
index_shape
[
1
:])
bias
=
paddle
.
flatten
(
bias
)
bias
=
paddle
.
tile
(
bias
,
[
index_shape
[
0
]])
index_swaped_flattend
+=
bias
gathered
=
paddle
.
index_select
(
x_swaped_flattend
,
index_swaped_flattend
)
gathered
=
paddle
.
reshape
(
gathered
,
index_swaped
.
shape
)
out
=
paddle
.
transpose
(
gathered
,
perm
=
x_range
)
return
out
x2paddle/op_mapper/dygraph/pytorch2paddle/pytorch_op_mapper.py
浏览文件 @
592894bd
...
...
@@ -16,7 +16,7 @@
import
torch
import
numpy
as
np
from
x2paddle.core.op_mapper
import
OpMapper
from
x2paddle.core.util
import
*
from
x2paddle.core.util
import
string
from
x2paddle.core.program
import
PaddleGraph
from
x2paddle.op_mapper.dygraph.pytorch2paddle
import
prim
from
x2paddle.op_mapper.dygraph.pytorch2paddle
import
aten
...
...
@@ -169,18 +169,10 @@ class PyTorchOpMapper(OpMapper):
outputs_name
.
append
(
output_name
)
return
outputs_name
def
_check_input
(
self
,
graph
,
node
,
output_name
,
node_outputs
,
scope_name
,
add_dim
=
False
):
def
_check_input
(
self
,
graph
,
node
,
output_name
,
node_outputs
,
scope_name
):
if
node
.
kind
()
==
"prim::GetAttr"
:
param
=
self
.
pytorch_params
[
output_name
]
if
isinstance
(
param
,
np
.
ndarray
):
if
add_dim
:
param
=
param
[
np
.
newaxis
,
:]
self
.
paddle_params
[
output_name
]
=
param
layer_id
=
graph
.
add_layer
(
"self.create_parameter"
,
...
...
@@ -208,8 +200,6 @@ class PyTorchOpMapper(OpMapper):
else
:
if
id1_part
[
i
]
==
"0"
and
id2_part
[
i
]
==
"1"
:
if
add_dim
:
param
=
param
[
np
.
newaxis
,
:]
self
.
paddle_params
[
output_name
]
=
param
layer_id
=
graph
.
add_layer
(
"self.create_parameter"
,
...
...
x2paddle/project_convertor/pytorch/mapper.py
浏览文件 @
592894bd
...
...
@@ -13,7 +13,7 @@
# limitations under the License.
from
x2paddle.project_convertor.pytorch.api_mapper
import
*
from
x2paddle.utils
import
*
from
x2paddle.utils
import
is_new_version
OPTIMIZER_MAPPER
=
{
"torch.optim"
:
[
"paddle.optimizer"
,
None
],
...
...
@@ -25,7 +25,6 @@ OPTIMIZER_MAPPER = {
[
"paddle.optimizer.lr.MultiStepDecay"
,
LRScheculerMapper
],
"torch.optim.Adam"
:
[
"x2paddle.torch2paddle.Adam"
,
None
],
"torch.optim.SGD"
:
[
"x2paddle.torch2paddle.Momentum"
,
None
]
}
NN_MAPPER
=
{
...
...
@@ -169,11 +168,42 @@ DIST_MAPPER = {
[
"x2paddle.torch2paddle.init_process_group"
,
None
]
}
DTYPE_MAPPER
=
{
"torch.float32"
:
[
string
(
"float32"
),
None
],
"torch.long"
:
[
string
(
"int64"
),
None
],
"torch.bool"
:
[
string
(
"bool"
),
None
]
}
if
is_new_version
:
DTYPE_MAPPER
=
{
"torch.float16"
:
[
"paddle.float16"
,
None
],
"torch.half"
:
[
"paddle.float16"
,
None
],
"torch.float32"
:
[
"paddle.float32"
,
None
],
"torch.float"
:
[
"paddle.float32"
,
None
],
"torch.float64"
:
[
"paddle.float64"
,
None
],
"torch.double"
:
[
"paddle.float64"
,
None
],
"torch.uint8"
:
[
"paddle.uint8"
,
None
],
"torch.int8"
:
[
"paddle.int8"
,
None
],
"torch.int16"
:
[
"paddle.int16"
,
None
],
"torch.short"
:
[
"paddle.int16"
,
None
],
"torch.int32"
:
[
"paddle.int32"
,
None
],
"torch.int"
:
[
"paddle.int32"
,
None
],
"torch.int64"
:
[
"paddle.int64"
,
None
],
"torch.long"
:
[
"paddle.int64"
,
None
],
"torch.bool"
:
[
"paddle.bool"
,
None
],
}
else
:
DTYPE_MAPPER
=
{
"torch.float16"
:
[
string
(
"float16"
),
None
],
"torch.half"
:
[
string
(
"float16"
),
None
],
"torch.float32"
:
[
string
(
"float32"
),
None
],
"torch.float"
:
[
string
(
"float32"
),
None
],
"torch.float64"
:
[
string
(
"float64"
),
None
],
"torch.double"
:
[
string
(
"float64"
),
None
],
"torch.uint8"
:
[
string
(
"uint8"
),
None
],
"torch.int8"
:
[
string
(
"int8"
),
None
],
"torch.int16"
:
[
string
(
"int16"
),
None
],
"torch.short"
:
[
string
(
"int16"
),
None
],
"torch.int32"
:
[
string
(
"int32"
),
None
],
"torch.int"
:
[
string
(
"int32"
),
None
],
"torch.int64"
:
[
string
(
"int64"
),
None
],
"torch.long"
:
[
string
(
"int64"
),
None
],
"torch.bool"
:
[
string
(
"bool"
),
None
],
}
TORCHVISION_MAPPER
=
{
"torchvision"
:
[
"paddle.vision"
,
None
],
...
...
x2paddle/project_convertor/pytorch/torch2paddle/nn_init.py
浏览文件 @
592894bd
...
...
@@ -19,6 +19,7 @@ from paddle.fluid import framework
from
paddle.fluid.core
import
VarDesc
from
paddle.fluid.initializer
import
XavierInitializer
,
MSRAInitializer
from
paddle.fluid.data_feeder
import
check_variable_and_dtype
from
x2paddle.utils
import
paddle_dtypes
def
_calculate_fan_in_and_fan_out
(
var
):
...
...
@@ -101,8 +102,8 @@ class KaimingNormal(MSRAInitializer):
self
.
_seed
=
block
.
program
.
random_seed
# to be compatible of fp16 initalizers
if
var
.
dtype
==
VarDesc
.
VarType
.
FP
16
:
out_dtype
=
VarDesc
.
VarType
.
FP
32
if
var
.
dtype
==
paddle_dtypes
.
t_float
16
:
out_dtype
=
paddle_dtypes
.
t_float
32
out_var
=
block
.
create_var
(
name
=
unique_name
.
generate
(
"."
.
join
(
[
'masra_init'
,
var
.
name
,
'tmp'
])),
...
...
@@ -169,8 +170,8 @@ class XavierNormal(XavierInitializer):
self
.
_seed
=
block
.
program
.
random_seed
# to be compatible of fp16 initalizers
if
var
.
dtype
==
VarDesc
.
VarType
.
FP
16
:
out_dtype
=
VarDesc
.
VarType
.
FP
32
if
var
.
dtype
==
paddle_dtypes
.
t_float
16
:
out_dtype
=
paddle_dtypes
.
t_float
32
out_var
=
block
.
create_var
(
name
=
unique_name
.
generate
(
"."
.
join
(
[
'xavier_init'
,
var
.
name
,
'tmp'
])),
...
...
@@ -195,7 +196,7 @@ class XavierNormal(XavierInitializer):
"seed"
:
self
.
_seed
},
stop_gradient
=
True
)
if
var
.
dtype
==
VarDesc
.
VarType
.
FP
16
:
if
var
.
dtype
==
paddle_dtypes
.
t_float
16
:
block
.
append_op
(
type
=
"cast"
,
inputs
=
{
"X"
:
out_var
},
...
...
@@ -233,8 +234,8 @@ class XavierUniform(XavierInitializer):
self
.
_seed
=
block
.
program
.
random_seed
# to be compatible of fp16 initalizers
if
var
.
dtype
==
VarDesc
.
VarType
.
FP
16
:
out_dtype
=
VarDesc
.
VarType
.
FP
32
if
var
.
dtype
==
paddle_dtypes
.
t_float
16
:
out_dtype
=
paddle_dtypes
.
t_float
32
out_var
=
block
.
create_var
(
name
=
unique_name
.
generate
(
"."
.
join
(
[
'xavier_init'
,
var
.
name
,
'tmp'
])),
...
...
@@ -260,7 +261,7 @@ class XavierUniform(XavierInitializer):
"seed"
:
self
.
_seed
},
stop_gradient
=
True
)
if
var
.
dtype
==
VarDesc
.
VarType
.
FP
16
:
if
var
.
dtype
==
paddle_dtypes
.
t_float
16
:
block
.
append_op
(
type
=
"cast"
,
inputs
=
{
"X"
:
out_var
},
...
...
x2paddle/project_convertor/pytorch/torch2paddle/varbase.py
浏览文件 @
592894bd
...
...
@@ -14,6 +14,7 @@
import
paddle
from
paddle.fluid.core
import
VarBase
from
x2paddle.utils
import
paddle_dtypes
def
is_condition_one
(
idx
):
...
...
@@ -23,8 +24,8 @@ def is_condition_one(idx):
a[mask, :]
a[mask, ...]
"""
if
not
(
isinstance
(
idx
[
0
],
paddle
.
Tensor
)
and
str
(
idx
[
0
].
dtype
)
==
"VarType.BOOL"
):
if
not
(
isinstance
(
idx
[
0
],
paddle
.
Tensor
)
and
\
idx
[
0
].
dtype
==
paddle_dtypes
.
t_bool
):
return
False
if
len
(
idx
)
==
1
:
return
False
...
...
@@ -57,13 +58,13 @@ VarBase.tmp = VarBase.__getitem__
def
__getitem__
(
self
,
idx
):
is_bool
=
False
if
s
tr
(
self
.
dtype
)
==
"VarType.BOOL"
:
if
s
elf
.
dtype
==
paddle_dtypes
.
t_bool
:
self
=
self
.
cast
(
"int32"
)
is_bool
=
True
if
isinstance
(
idx
,
paddle
.
Tensor
)
and
len
(
idx
.
shape
)
==
1
:
out
=
paddle
.
gather
(
self
,
idx
)
return
out
.
cast
(
"bool"
)
if
is_bool
else
out
elif
isinstance
(
idx
,
paddle
.
Tensor
)
and
str
(
idx
.
dtype
)
==
"VarType.BOOL"
:
elif
isinstance
(
idx
,
paddle
.
Tensor
)
and
idx
.
dtype
==
paddle_dtypes
.
t_bool
:
idx
=
paddle
.
cast
(
idx
,
"int32"
)
idx
=
paddle
.
nonzero
(
idx
)
out
=
paddle
.
gather_nd
(
self
,
idx
)
...
...
@@ -100,7 +101,7 @@ VarBase.setitem_tmp = VarBase.__setitem__
def
__setitem__
(
self
,
idx
,
value
):
if
isinstance
(
idx
,
paddle
.
Tensor
)
and
str
(
idx
.
dtype
)
==
"VarType.BOOL"
:
if
isinstance
(
idx
,
paddle
.
Tensor
)
and
idx
.
dtype
==
paddle_dtypes
.
t_bool
:
"""
a = paddle.to_tensor(np.array([1,2,3]).astype("float32"))
mask = paddle.to_tensor(np.array([True, False, True]).astype("bool"))
...
...
x2paddle/utils.py
浏览文件 @
592894bd
...
...
@@ -13,8 +13,49 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import
paddle
def
string
(
param
):
""" 生成字符串。
"""
return
"
\'
{}
\'
"
.
format
(
param
)
def
check_version
():
version
=
paddle
.
__version__
v0
,
v1
,
v2
=
version
.
split
(
'.'
)
if
not
((
v0
==
'0'
and
v1
==
'0'
and
v2
==
'0'
)
or
(
int
(
v0
)
>=
2
and
int
(
v1
)
>=
1
)):
return
False
else
:
return
True
class
PaddleDtypes
():
def
__init__
(
self
,
is_new_version
=
True
):
if
is_new_version
:
self
.
t_float16
=
paddle
.
float16
self
.
t_float32
=
paddle
.
float32
self
.
t_float64
=
paddle
.
float64
self
.
t_uint8
=
paddle
.
uint8
self
.
t_int8
=
paddle
.
int8
self
.
t_int16
=
paddle
.
int16
self
.
t_int32
=
paddle
.
int32
self
.
t_int64
=
paddle
.
int64
self
.
t_bool
=
paddle
.
bool
else
:
from
paddle.fluid.core
import
VarDesc
self
.
t_float16
=
VarDesc
.
VarType
.
FP16
self
.
t_float32
=
VarDesc
.
VarType
.
FP32
self
.
t_float64
=
VarDesc
.
VarType
.
FP64
self
.
t_uint8
=
VarDesc
.
VarType
.
UINT8
self
.
t_int8
=
VarDesc
.
VarType
.
INT8
self
.
t_int16
=
VarDesc
.
VarType
.
INT16
self
.
t_int32
=
VarDesc
.
VarType
.
INT32
self
.
t_int64
=
VarDesc
.
VarType
.
INT64
self
.
t_bool
=
VarDesc
.
VarType
.
BOOL
is_new_version
=
check_version
()
paddle_dtypes
=
PaddleDtypes
(
is_new_version
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录