Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
f7eb03c6
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
f7eb03c6
编写于
6月 14, 2023
作者:
C
Charles-hit
提交者:
GitHub
6月 14, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support group_norm and cumsum prim ops bf16 dtype (#54580)
上级
3d4d995f
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
188 addition
and
54 deletion
+188
-54
paddle/fluid/prim/api/composite_backward/composite_backward_api.h
...luid/prim/api/composite_backward/composite_backward_api.h
+12
-8
paddle/phi/core/visit_type.h
paddle/phi/core/visit_type.h
+11
-2
paddle/phi/kernels/gpu/reduce.h
paddle/phi/kernels/gpu/reduce.h
+2
-1
python/paddle/incubate/autograd/composite_rules.py
python/paddle/incubate/autograd/composite_rules.py
+5
-4
test/legacy_test/CMakeLists.txt
test/legacy_test/CMakeLists.txt
+1
-1
test/legacy_test/test_cumsum_op.py
test/legacy_test/test_cumsum_op.py
+11
-9
test/legacy_test/test_group_norm_op.py
test/legacy_test/test_group_norm_op.py
+146
-29
未找到文件。
paddle/fluid/prim/api/composite_backward/composite_backward_api.h
浏览文件 @
f7eb03c6
...
...
@@ -693,11 +693,13 @@ void group_norm_grad(const Tensor& x,
Tensor
x_data
=
x
;
Tensor
out_grad_data
=
out_grad
;
if
(
x
.
dtype
()
==
phi
::
DataType
::
FLOAT16
)
{
if
(
x
.
dtype
()
==
phi
::
DataType
::
FLOAT16
||
x
.
dtype
()
==
phi
::
DataType
::
BFLOAT16
)
{
x_data
=
cast
<
T
>
(
x
,
phi
::
DataType
::
FLOAT32
);
}
if
(
out_grad
.
dtype
()
==
phi
::
DataType
::
FLOAT16
)
{
if
(
out_grad
.
dtype
()
==
phi
::
DataType
::
FLOAT16
||
out_grad
.
dtype
()
==
phi
::
DataType
::
BFLOAT16
)
{
out_grad_data
=
cast
<
T
>
(
out_grad
,
phi
::
DataType
::
FLOAT32
);
}
...
...
@@ -728,7 +730,8 @@ void group_norm_grad(const Tensor& x,
Tensor
p1
;
if
(
scale_ptr
)
{
auto
scale_data
=
scale
.
get
();
if
(
scale_data
.
dtype
()
==
phi
::
DataType
::
FLOAT16
)
{
if
(
scale_data
.
dtype
()
==
phi
::
DataType
::
FLOAT16
||
scale_data
.
dtype
()
==
phi
::
DataType
::
BFLOAT16
)
{
scale_data
=
cast
<
T
>
(
scale_data
,
phi
::
DataType
::
FLOAT32
);
}
d1
=
(
reshape
<
T
>
(
sum_y_grad_mul_x
*
scale_data
,
shape_group
))
...
...
@@ -757,7 +760,8 @@ void group_norm_grad(const Tensor& x,
auto
tmp_2
=
reshape
<
T
>
(
x_data
,
whole_group_shape
)
*
p2
+
p3
;
auto
x_grad_data
=
tmp_1
+
tmp_2
;
x_grad_data
=
reshape
<
T
>
(
x_grad_data
,
x
.
shape
());
if
(
x
.
dtype
()
==
phi
::
DataType
::
FLOAT16
)
{
if
(
x
.
dtype
()
==
phi
::
DataType
::
FLOAT16
||
x
.
dtype
()
==
phi
::
DataType
::
BFLOAT16
)
{
x_grad_data
=
cast
<
T
>
(
x_grad_data
,
x
.
dtype
());
}
...
...
@@ -770,9 +774,9 @@ void group_norm_grad(const Tensor& x,
reshape
<
T
>
(
sum_y_grad
,
shape_group
)
*
reshape
<
T
>
(
mean
,
third_shape
))
*
reshape
<
T
>
(
inv_std
,
third_shape
);
auto
scale_grad_tmp
=
reshape
<
T
>
(
tmp1
.
sum
(
std
::
vector
<
int64_t
>
({
0
}),
dtype
,
false
),
IntArray
(
std
::
vector
<
int64_t
>
({
C
})));
auto
scale_grad_tmp
=
reshape
<
T
>
(
tmp1
.
sum
(
std
::
vector
<
int64_t
>
({
0
}),
scale_ptr
->
dtype
()
,
false
),
IntArray
(
std
::
vector
<
int64_t
>
({
C
})));
set_output
<
T
>
(
scale_grad_tmp
,
scale_grad
);
}
else
{
scale_grad
=
nullptr
;
...
...
@@ -782,7 +786,7 @@ void group_norm_grad(const Tensor& x,
if
(
bias_grad
)
{
if
(
bias_ptr
)
{
auto
bias_grad_tmp
=
sum_y_grad
.
sum
(
std
::
vector
<
int64_t
>
({
0
}),
dtype
,
false
);
sum_y_grad
.
sum
(
std
::
vector
<
int64_t
>
({
0
}),
bias_ptr
->
dtype
()
,
false
);
set_output
<
T
>
(
bias_grad_tmp
,
bias_grad
);
}
else
{
bias_grad
=
nullptr
;
...
...
paddle/phi/core/visit_type.h
浏览文件 @
f7eb03c6
...
...
@@ -298,8 +298,13 @@ namespace phi {
} \
}()
#define PD_VISIT_BOOL_AND_FLOATING_AND_COMPLEX_AND_3_TYPES( \
SPECIFIED_TYPE1, SPECIFIED_TYPE2, SPECIFIED_TYPE3, TYPE, NAME, ...) \
#define PD_VISIT_BOOL_AND_FLOATING_AND_COMPLEX_AND_4_TYPES(SPECIFIED_TYPE1, \
SPECIFIED_TYPE2, \
SPECIFIED_TYPE3, \
SPECIFIED_TYPE4, \
TYPE, \
NAME, \
...) \
[&] { \
const auto& __dtype__ = TYPE; \
switch (__dtype__) { \
...
...
@@ -328,6 +333,10 @@ namespace phi {
SPECIFIED_TYPE3, \
::phi::DataTypeToCppType<SPECIFIED_TYPE3>::type, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, \
SPECIFIED_TYPE4, \
::phi::DataTypeToCppType<SPECIFIED_TYPE4>::type, \
__VA_ARGS__) \
default: \
PD_THROW("function " #NAME " is not implemented for data type `", \
__dtype__, \
...
...
paddle/phi/kernels/gpu/reduce.h
浏览文件 @
f7eb03c6
...
...
@@ -47,10 +47,11 @@ void Reduce(const KPDevice& dev_ctx,
#ifndef PADDLE_WITH_XPU_KP
if
(
out_dtype
!=
phi
::
DataType
::
UNDEFINED
&&
out_dtype
!=
x
.
dtype
())
{
auto
tmp_tensor
=
phi
::
Cast
<
T
>
(
dev_ctx
,
x
,
out_dtype
);
PD_VISIT_BOOL_AND_FLOATING_AND_COMPLEX_AND_
3
_TYPES
(
PD_VISIT_BOOL_AND_FLOATING_AND_COMPLEX_AND_
4
_TYPES
(
phi
::
DataType
::
INT32
,
phi
::
DataType
::
INT64
,
phi
::
DataType
::
FLOAT16
,
phi
::
DataType
::
BFLOAT16
,
out_dtype
,
"ReduceKernel"
,
([
&
]
{
...
...
python/paddle/incubate/autograd/composite_rules.py
浏览文件 @
f7eb03c6
...
...
@@ -656,8 +656,9 @@ def group_norm_composite(x, scale, bias, epsilon, groups, data_layout):
is_amp
=
False
from
paddle.fluid.data_feeder
import
convert_dtype
# when inputs are float16, convert to float32 in computing
if
convert_dtype
(
x
.
dtype
)
==
"float16"
:
dtype
=
convert_dtype
(
x
.
dtype
)
# when inputs are float16 or bfloat16, convert to float32 in computing
if
dtype
in
[
"float16"
,
"uint16"
]:
is_amp
=
True
x
=
cast
(
x
,
"float32"
)
scale
=
cast
(
scale
,
"float32"
)
...
...
@@ -676,9 +677,9 @@ def group_norm_composite(x, scale, bias, epsilon, groups, data_layout):
out
=
out
+
reshape
(
bias
,
(
-
1
,
1
,
1
))
ret_mean_
=
reshape
(
mean_
,
(
N
,
groups
))
ret_var_
=
reshape
(
var_
,
(
N
,
groups
))
# return output in float16, mean and var in float32
# return output in float16
or bfloat16
, mean and var in float32
if
is_amp
:
out
=
cast
(
out
,
"float16"
)
out
=
cast
(
out
,
dtype
)
return
out
,
ret_mean_
,
ret_var_
...
...
test/legacy_test/CMakeLists.txt
浏览文件 @
f7eb03c6
...
...
@@ -1033,7 +1033,7 @@ set_tests_properties(
PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_conv_nn_grad PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_program_prune_backward PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_group_norm_op PROPERTIES TIMEOUT
12
0
)
set_tests_properties
(
test_group_norm_op PROPERTIES TIMEOUT
30
0
)
set_tests_properties
(
test_imperative_optimizer PROPERTIES TIMEOUT 250
)
set_tests_properties
(
test_imperative_optimizer_v2 PROPERTIES TIMEOUT 250
)
set_tests_properties
(
test_pool2d_op PROPERTIES TIMEOUT 120
)
...
...
test/legacy_test/test_cumsum_op.py
浏览文件 @
f7eb03c6
...
...
@@ -122,7 +122,7 @@ class TestSumOp1(OpTest):
self
.
prim_op_type
=
"prim"
self
.
python_api
=
cumsum_wrapper
self
.
public_python_api
=
paddle
.
cumsum
self
.
set
_enable_cinn
()
self
.
if
_enable_cinn
()
self
.
init_dtype
()
self
.
set_attrs_input_output
()
if
self
.
dtype
==
np
.
uint16
:
...
...
@@ -141,7 +141,7 @@ class TestSumOp1(OpTest):
def
init_dtype
(
self
):
self
.
dtype
=
self
.
dtype_
=
np
.
float64
def
set
_enable_cinn
(
self
):
def
if
_enable_cinn
(
self
):
pass
def
set_attrs_input_output
(
self
):
...
...
@@ -221,7 +221,7 @@ class TestSumOpExclusive1(OpTest):
self
.
prim_op_type
=
"prim"
self
.
python_api
=
cumsum_wrapper
self
.
public_python_api
=
paddle
.
cumsum
self
.
set
_enable_cinn
()
self
.
if
_enable_cinn
()
self
.
init_dtype
()
self
.
set_attrs_input_output
()
if
self
.
dtype
==
np
.
uint16
:
...
...
@@ -240,7 +240,7 @@ class TestSumOpExclusive1(OpTest):
def
init_dtype
(
self
):
self
.
dtype
=
self
.
dtype_
=
np
.
float64
def
set
_enable_cinn
(
self
):
def
if
_enable_cinn
(
self
):
pass
def
set_attrs_input_output
(
self
):
...
...
@@ -346,7 +346,7 @@ class TestSumOpReverseExclusive(OpTest):
self
.
prim_op_type
=
"prim"
self
.
python_api
=
cumsum_wrapper
self
.
public_python_api
=
paddle
.
cumsum
self
.
set
_enable_cinn
()
self
.
if
_enable_cinn
()
self
.
init_dtype
()
self
.
attrs
=
{
'axis'
:
2
,
...
...
@@ -378,7 +378,7 @@ class TestSumOpReverseExclusive(OpTest):
def
init_dtype
(
self
):
self
.
dtype
=
self
.
dtype_
=
np
.
float64
def
set
_enable_cinn
(
self
):
def
if
_enable_cinn
(
self
):
pass
...
...
@@ -387,7 +387,7 @@ def create_test_fp16_class(parent, max_relative_error=1e-2):
def
init_dtype
(
self
):
self
.
dtype
=
self
.
dtype_
=
np
.
float16
def
set
_enable_cinn
(
self
):
def
if
_enable_cinn
(
self
):
pass
def
test_check_output
(
self
):
...
...
@@ -430,7 +430,7 @@ def create_test_bf16_class(parent):
self
.
dtype
=
np
.
uint16
self
.
dtype_
=
np
.
float32
def
set
_enable_cinn
(
self
):
def
if
_enable_cinn
(
self
):
self
.
enable_cinn
=
False
def
test_check_output
(
self
):
...
...
@@ -439,7 +439,9 @@ def create_test_bf16_class(parent):
def
test_check_grad
(
self
):
place
=
paddle
.
CUDAPlace
(
0
)
self
.
check_grad_with_place
(
place
,
[
"X"
],
"Out"
,
check_prim
=
True
)
self
.
check_grad_with_place
(
place
,
[
"X"
],
"Out"
,
check_prim
=
True
,
numeric_grad_delta
=
0.05
)
cls_name
=
"{}_{}"
.
format
(
parent
.
__name__
,
"BF16"
)
TestCumsumBF16Op
.
__name__
=
cls_name
...
...
test/legacy_test/test_group_norm_op.py
浏览文件 @
f7eb03c6
...
...
@@ -19,6 +19,7 @@ import parameterized as param
from
eager_op_test
import
(
OpTest
,
convert_float_to_uint16
,
convert_uint16_to_float
,
paddle_static_guard
,
skip_check_grad_ci
,
)
...
...
@@ -810,6 +811,96 @@ def apply_to_static(net, use_cinn):
[[
1e-3
,
1e-3
,
1e-3
]],
# gpu thresholds for static, jit, jit_cinn
None
,
),
(
'test0_bfp16'
,
(
2
,
100
,
3
,
5
),
1e-5
,
2
,
'NCHW'
,
places
,
'bfloat16'
,
[
[
1e-2
,
1e-2
,
1e-2
,
],
# cpu thresholds for static, jit, jit_cinn
[
1e-2
,
1e-2
,
1e-2
],
],
# gpu thresholds for static, jit, jit_cinn
None
,
),
(
'test1_bfp16'
,
(
2
,
100
,
3
,
5
),
1e-5
,
1
,
'NCHW'
,
places
,
'bfloat16'
,
[
[
1e-2
,
1e-2
,
1e-2
,
],
# cpu thresholds for static, jit, jit_cinn
[
1e-2
,
1e-2
,
1e-2
],
],
# gpu thresholds for static, jit, jit_cinn
None
,
),
(
'test2_bfp16'
,
(
2
,
100
,
3
,
5
),
1e-5
,
4
,
'NCHW'
,
places
,
'bfloat16'
,
[
[
1e-2
,
1e-2
,
1e-2
,
],
# cpu thresholds for static, jit, jit_cinn
[
1e-2
,
1e-2
,
1e-2
],
],
# gpu thresholds for static, jit, jit_cinn
None
,
),
(
'bigeps3_bfp16'
,
(
2
,
100
,
3
,
5
),
0.5
,
2
,
'NCHW'
,
places
,
'bfloat16'
,
[
[
1e-2
,
1e-2
,
1e-2
,
],
# cpu thresholds for static, jit, jit_cinn
[
1e-2
,
1e-2
,
1e-2
],
],
# gpu thresholds for static, jit, jit_cinn
None
,
),
(
'largedata_bfp16'
,
(
2
,
32
,
64
,
64
),
1e-5
,
4
,
'NCHW'
,
places
,
'bfloat16'
,
[
[
1e-2
,
1e-2
,
1e-2
,
],
# cpu thresholds for static, jit, jit_cinn
[
1e-2
,
1e-2
,
1e-2
],
],
# gpu thresholds for static, jit, jit_cinn
None
,
),
),
)
class
TestCompositeGroupNorm
(
unittest
.
TestCase
):
...
...
@@ -825,12 +916,23 @@ class TestCompositeGroupNorm(unittest.TestCase):
np
.
random
.
seed
(
1234
)
self
.
fwd_desire
=
[]
self
.
rev_desire
=
[]
self
.
x
=
np
.
random
.
random
(
self
.
shape
).
astype
(
self
.
dtype
)
self
.
scale
=
np
.
random
.
random
([
self
.
shape
[
1
]]).
astype
(
self
.
dtype
)
self
.
bias
=
np
.
random
.
random
([
self
.
shape
[
1
]]).
astype
(
self
.
dtype
)
if
self
.
dtype
!=
"bfloat16"
:
self
.
x
=
np
.
random
.
random
(
self
.
shape
).
astype
(
self
.
dtype
)
self
.
scale
=
np
.
random
.
random
([
self
.
shape
[
1
]]).
astype
(
self
.
dtype
)
self
.
bias
=
np
.
random
.
random
([
self
.
shape
[
1
]]).
astype
(
self
.
dtype
)
else
:
self
.
x
=
convert_float_to_uint16
(
np
.
random
.
random
(
self
.
shape
).
astype
(
"float32"
)
)
self
.
scale
=
convert_float_to_uint16
(
np
.
random
.
random
([
self
.
shape
[
1
]]).
astype
(
"float32"
)
)
self
.
bias
=
convert_float_to_uint16
(
np
.
random
.
random
([
self
.
shape
[
1
]]).
astype
(
"float32"
)
)
self
.
num_channels
=
self
.
shape
[
1
]
if
self
.
dtype
==
'float16'
:
if
self
.
dtype
in
[
'float16'
,
'bfloat16'
]
:
self
.
places
=
[]
if
paddle
.
is_compiled_with_cuda
():
self
.
places
.
append
(
paddle
.
CUDAPlace
(
0
))
...
...
@@ -879,7 +981,11 @@ class TestCompositeGroupNorm(unittest.TestCase):
paddle
.
assign
(
bias_
,
group_norm
.
bias
)
output
=
group_norm
(
input_
)
grad
=
paddle
.
grad
(
output
,
input_
)
if
self
.
dtype
==
"bfloat16"
:
output
=
paddle
.
cast
(
output
,
"float32"
)
grad
=
paddle
.
utils
.
map_structure
(
lambda
x
:
paddle
.
cast
(
x
,
"float32"
),
grad
)
return
output
,
grad
[
0
]
def
get_static_desire
(
self
,
place
):
...
...
@@ -923,7 +1029,6 @@ class TestCompositeGroupNorm(unittest.TestCase):
output
=
group_norm
(
input_
)
blocks
=
mp
.
blocks
names
=
dict
(
zip
(
blocks
[
0
].
ops
[
2
].
output_names
,
...
...
@@ -964,7 +1069,11 @@ class TestCompositeGroupNorm(unittest.TestCase):
)
paddle
.
disable_static
()
core
.
_set_prim_all_enabled
(
True
)
if
self
.
dtype
==
"bfloat16"
:
out_list
[
0
]
=
convert_uint16_to_float
(
out_list
[
0
])
i
=
3
for
i
in
range
(
3
,
len
(
out_list
)):
out_list
[
i
]
=
convert_uint16_to_float
(
out_list
[
i
])
return
out_list
[:
3
],
out_list
[
3
:]
def
test_static_comp
(
self
):
...
...
@@ -1051,6 +1160,11 @@ class TestCompositeGroupNorm(unittest.TestCase):
},
fetch_list
=
vars_list
+
[
grads
],
)
if
self
.
dtype
==
"bfloat16"
:
out_list
[
0
]
=
convert_uint16_to_float
(
out_list
[
0
])
i
=
3
for
i
in
range
(
3
,
len
(
out_list
)):
out_list
[
i
]
=
convert_uint16_to_float
(
out_list
[
i
])
fwd_actual
[
-
1
].
append
(
out_list
[
0
])
fwd_actual
[
-
1
].
append
(
out_list
[
1
])
fwd_actual
[
-
1
].
append
(
out_list
[
2
])
...
...
@@ -1075,12 +1189,14 @@ class TestCompositeGroupNorm(unittest.TestCase):
atol
=
self
.
threshold_list
[
i
][
0
]
rtol
=
self
.
threshold_list
[
i
][
0
]
for
j
in
range
(
len
(
self
.
static_fwd_desire
[
i
])):
# in float16 type, Y is float16, mean and var are float
16
# in float16 type, Y is float16, mean and var are float
32
# so check mean and var with float32 gpu threshold
if
self
.
dtype
==
'float16'
and
j
>
0
:
if
self
.
dtype
==
"float16"
and
j
>
0
:
atol
=
1e-5
rtol
=
1e-5
elif
self
.
dtype
==
"bfloat16"
and
j
>
0
:
atol
=
5e-3
rtol
=
5e-3
np
.
testing
.
assert_allclose
(
self
.
static_fwd_desire
[
i
][
j
],
fwd_actual
[
i
][
j
],
...
...
@@ -1091,13 +1207,6 @@ class TestCompositeGroupNorm(unittest.TestCase):
max_abs_diff
=
np
.
max
(
np
.
abs
(
self
.
static_fwd_desire
[
i
][
j
]
-
fwd_actual
[
i
][
j
])
)
print
(
self
.
shape
,
self
.
dtype
,
self
.
places
[
i
],
vars_name
[
j
],
max_abs_diff
,
)
# compare with eager_desire
np
.
testing
.
assert_allclose
(
self
.
fwd_desire
[
i
],
...
...
@@ -1121,14 +1230,6 @@ class TestCompositeGroupNorm(unittest.TestCase):
np
.
abs
(
self
.
static_rev_desire
[
i
][
j
]
-
rev_actual
[
i
][
j
])
)
print
(
self
.
shape
,
self
.
dtype
,
self
.
places
[
i
],
vars_name
[
j
+
3
],
max_abs_diff
,
)
np
.
testing
.
assert_allclose
(
self
.
static_rev_desire
[
i
][
j
],
rev_actual
[
i
][
j
],
...
...
@@ -1183,8 +1284,16 @@ class TestCompositeGroupNorm(unittest.TestCase):
net
=
apply_to_static
(
net
,
False
)
output
=
net
(
input_
)
grad
=
paddle
.
grad
(
output
,
input_
)
fwd_actual
.
append
(
output
.
numpy
())
rev_actual
.
append
(
grad
[
0
].
numpy
())
fwd_actual
.
append
(
convert_uint16_to_float
(
output
.
numpy
())
if
self
.
dtype
==
"bfloat16"
else
output
.
numpy
()
)
rev_actual
.
append
(
convert_uint16_to_float
(
grad
[
0
].
numpy
())
if
self
.
dtype
==
"bfloat16"
else
grad
[
0
].
numpy
()
)
for
i
in
range
(
len
(
self
.
places
)):
atol
=
self
.
threshold_list
[
i
][
1
]
...
...
@@ -1244,8 +1353,16 @@ class TestCompositeGroupNorm(unittest.TestCase):
net
=
apply_to_static
(
net
,
True
)
output
=
net
(
input_
)
grad
=
paddle
.
grad
(
output
,
input_
)
fwd_actual
.
append
(
output
.
numpy
())
rev_actual
.
append
(
grad
[
0
].
numpy
())
fwd_actual
.
append
(
convert_uint16_to_float
(
output
.
numpy
())
if
self
.
dtype
==
"bfloat16"
else
output
.
numpy
()
)
rev_actual
.
append
(
convert_uint16_to_float
(
grad
[
0
].
numpy
())
if
self
.
dtype
==
"bfloat16"
else
grad
[
0
].
numpy
()
)
i
=
0
for
place
in
self
.
places
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录