Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
841efcd4
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
841efcd4
编写于
5月 05, 2023
作者:
C
co63oc
提交者:
GitHub
5月 05, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
【Hackathon No.59】addmm 算子FP16/BF16单测完善 (#53111)
* Add addmm tests * Fix code
上级
74074a8d
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
313 addition
and
23 deletion
+313
-23
paddle/phi/kernels/funcs/blas/blas_impl.cu.h
paddle/phi/kernels/funcs/blas/blas_impl.cu.h
+68
-0
paddle/phi/kernels/funcs/blas/blas_impl.hip.h
paddle/phi/kernels/funcs/blas/blas_impl.hip.h
+65
-1
paddle/phi/kernels/gpu/addmm_grad_kernel.cu
paddle/phi/kernels/gpu/addmm_grad_kernel.cu
+8
-2
paddle/phi/kernels/gpu/addmm_kernel.cu
paddle/phi/kernels/gpu/addmm_kernel.cu
+8
-1
paddle/phi/kernels/impl/addmm_grad_kernel_impl.h
paddle/phi/kernels/impl/addmm_grad_kernel_impl.h
+94
-10
paddle/phi/kernels/impl/addmm_kernel_impl.h
paddle/phi/kernels/impl/addmm_kernel_impl.h
+4
-2
python/paddle/fluid/tests/unittests/test_addmm_op.py
python/paddle/fluid/tests/unittests/test_addmm_op.py
+59
-4
python/paddle/tensor/math.py
python/paddle/tensor/math.py
+7
-3
未找到文件。
paddle/phi/kernels/funcs/blas/blas_impl.cu.h
浏览文件 @
841efcd4
...
...
@@ -1316,6 +1316,74 @@ inline void Blas<phi::GPUContext>::GEMM(bool transA,
});
}
template
<
>
template
<
>
inline
void
Blas
<
phi
::
GPUContext
>::
GEMM
(
bool
transA
,
bool
transB
,
int
M
,
int
N
,
int
K
,
phi
::
dtype
::
bfloat16
alpha
,
const
phi
::
dtype
::
bfloat16
*
A
,
int
lda
,
const
phi
::
dtype
::
bfloat16
*
B
,
int
ldb
,
phi
::
dtype
::
bfloat16
beta
,
phi
::
dtype
::
bfloat16
*
C
,
int
ldc
)
const
{
#if CUDA_VERSION >= 11000
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
cublasOperation_t
cuTransA
=
transA
?
CUBLAS_OP_T
:
CUBLAS_OP_N
;
cublasOperation_t
cuTransB
=
transB
?
CUBLAS_OP_T
:
CUBLAS_OP_N
;
PADDLE_ENFORCE_GE
(
context_
.
GetComputeCapability
(),
80
,
phi
::
errors
::
InvalidArgument
(
"cublas bf16 gemm requires GPU compute capability >= 80,"
"but received %d"
,
context_
.
GetComputeCapability
()));
float
h_alpha
=
static_cast
<
float
>
(
alpha
);
float
h_beta
=
static_cast
<
float
>
(
beta
);
cublasGemmAlgo_t
algo
=
CUBLAS_GEMM_DEFAULT
;
bool
use_tensor_op_math
=
context_
.
tensor_core_available
();
if
(
use_tensor_op_math
)
{
algo
=
CUBLAS_GEMM_DFALT_TENSOR_OP
;
}
VLOG
(
5
)
<<
"use_tensor_op_math: "
<<
(
use_tensor_op_math
?
"True"
:
"False"
);
context_
.
TensorCoreCublasCallIfAvailable
([
&
](
cublasHandle_t
handle
)
{
PADDLE_ENFORCE_GPU_SUCCESS
(
phi
::
dynload
::
cublasGemmEx
(
handle
,
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
h_alpha
,
B
,
CUDA_R_16BF
,
ldb
,
A
,
CUDA_R_16BF
,
lda
,
&
h_beta
,
C
,
CUDA_R_16BF
,
ldc
,
CUDA_R_32F
,
algo
));
});
#else
// raise error
PADDLE_THROW
(
phi
::
errors
::
Unimplemented
(
"cublasGemmEx with bfloat16 is not supported on cuda <= 11"
));
#endif // CUDA_VERSION >= 11000
}
template
<
>
template
<
typename
T
>
void
Blas
<
phi
::
GPUContext
>::
AXPY
(
int
n
,
T
alpha
,
const
T
*
x
,
T
*
y
)
const
{
...
...
paddle/phi/kernels/funcs/blas/blas_impl.hip.h
浏览文件 @
841efcd4
...
...
@@ -751,7 +751,7 @@ inline void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
context_
.
GetComputeCapability
(),
80
,
phi
::
errors
::
InvalidArgument
(
"rocblas
fp
16 gemm requires GPU compute capability >= 80,"
"rocblas
bf
16 gemm requires GPU compute capability >= 80,"
"but received %d"
,
context_
.
GetComputeCapability
()));
...
...
@@ -982,6 +982,70 @@ inline void Blas<phi::GPUContext>::GEMM(bool transA,
});
}
template
<
>
template
<
>
inline
void
Blas
<
phi
::
GPUContext
>::
GEMM
(
bool
transA
,
bool
transB
,
int
M
,
int
N
,
int
K
,
phi
::
dtype
::
bfloat16
alpha
,
const
phi
::
dtype
::
bfloat16
*
A
,
int
lda
,
const
phi
::
dtype
::
bfloat16
*
B
,
int
ldb
,
phi
::
dtype
::
bfloat16
beta
,
phi
::
dtype
::
bfloat16
*
C
,
int
ldc
)
const
{
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
rocblas_operation
cuTransA
=
(
transA
==
CblasNoTrans
)
?
rocblas_operation_none
:
rocblas_operation_transpose
;
rocblas_operation
cuTransB
=
(
transB
==
CblasNoTrans
)
?
rocblas_operation_none
:
rocblas_operation_transpose
;
PADDLE_ENFORCE_GE
(
context_
.
GetComputeCapability
(),
80
,
phi
::
errors
::
InvalidArgument
(
"rocblas bf16 gemm requires GPU compute capability >= 80,"
"but received %d"
,
context_
.
GetComputeCapability
()));
float
h_alpha
=
static_cast
<
float
>
(
alpha
);
float
h_beta
=
static_cast
<
float
>
(
beta
);
rocblas_gemm_algo
algo
=
rocblas_gemm_algo_standard
;
context_
.
TensorCoreCublasCallIfAvailable
([
&
](
rocblas_handle
handle
)
{
PADDLE_ENFORCE_GPU_SUCCESS
(
phi
::
dynload
::
rocblas_gemm_ex
(
handle
,
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
h_alpha
,
B
,
rocblas_datatype_bf16_r
,
ldb
,
A
,
rocblas_datatype_bf16_r
,
lda
,
&
h_beta
,
C
,
rocblas_datatype_bf16_r
,
ldc
,
C
,
rocblas_datatype_bf16_r
,
ldc
,
rocblas_datatype_f32_r
,
algo
,
0
,
0
));
});
}
template
<
>
template
<
typename
T
>
void
Blas
<
phi
::
GPUContext
>::
AXPY
(
int
n
,
T
alpha
,
const
T
*
x
,
T
*
y
)
const
{
...
...
paddle/phi/kernels/gpu/addmm_grad_kernel.cu
浏览文件 @
841efcd4
...
...
@@ -18,5 +18,11 @@ limitations under the License. */
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/addmm_grad_kernel_impl.h"
PD_REGISTER_KERNEL
(
addmm_grad
,
GPU
,
ALL_LAYOUT
,
phi
::
AddmmGradKernel
,
float
,
double
)
{}
PD_REGISTER_KERNEL
(
addmm_grad
,
GPU
,
ALL_LAYOUT
,
phi
::
AddmmGradKernel
,
float
,
double
,
phi
::
dtype
::
float16
,
phi
::
dtype
::
bfloat16
)
{}
paddle/phi/kernels/gpu/addmm_kernel.cu
浏览文件 @
841efcd4
...
...
@@ -18,4 +18,11 @@ limitations under the License. */
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/addmm_kernel_impl.h"
PD_REGISTER_KERNEL
(
addmm
,
GPU
,
ALL_LAYOUT
,
phi
::
AddmmKernel
,
float
,
double
)
{}
PD_REGISTER_KERNEL
(
addmm
,
GPU
,
ALL_LAYOUT
,
phi
::
AddmmKernel
,
float
,
double
,
phi
::
dtype
::
float16
,
phi
::
dtype
::
bfloat16
)
{}
paddle/phi/kernels/impl/addmm_grad_kernel_impl.h
浏览文件 @
841efcd4
...
...
@@ -18,13 +18,34 @@ limitations under the License. */
#include "glog/logging.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/kernels/addmm_grad_kernel.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/funcs/for_range.h"
namespace
phi
{
template
<
typename
T
>
struct
CopyOrScaleFunctor
{
CopyOrScaleFunctor
(
const
float
scale
,
const
T
*
x
,
T
*
output
,
int64_t
numel
)
:
scale_
(
scale
),
x_
(
x
),
output_
(
output
),
numel_
(
numel
)
{}
HOSTDEVICE
void
operator
()(
int64_t
idx
)
const
{
using
MPType
=
typename
phi
::
dtype
::
MPTypeTrait
<
T
>::
Type
;
const
MPType
mp_scale
=
static_cast
<
MPType
>
(
scale_
);
const
MPType
mp_x
=
static_cast
<
MPType
>
(
x_
[
idx
]);
output_
[
idx
]
=
static_cast
<
T
>
(
mp_scale
*
mp_x
);
}
private:
const
float
scale_
;
const
T
*
x_
;
T
*
output_
;
int64_t
numel_
;
};
template
<
typename
T
,
size_t
D
,
int
MajorType
=
Eigen
::
RowMajor
,
...
...
@@ -45,6 +66,13 @@ void AddmmGradKernel(const Context& dev_ctx,
DenseTensor
*
input_grad
,
DenseTensor
*
x_grad
,
DenseTensor
*
y_grad
)
{
using
MPType
=
typename
phi
::
dtype
::
MPTypeTrait
<
T
>::
Type
;
bool
is_float16_or_bfloat16
=
false
;
if
(
std
::
is_same
<
T
,
phi
::
dtype
::
float16
>::
value
||
std
::
is_same
<
T
,
phi
::
dtype
::
bfloat16
>::
value
)
{
is_float16_or_bfloat16
=
true
;
}
auto
in_dims
=
input
.
dims
();
if
(
input
.
dims
().
size
()
==
1
)
{
in_dims
=
{
1
,
input
.
dims
()[
0
]};
...
...
@@ -65,6 +93,7 @@ void AddmmGradKernel(const Context& dev_ctx,
}
auto
blas
=
funcs
::
GetBlas
<
Context
,
T
>
(
dev_ctx
);
auto
mt_blas
=
funcs
::
GetBlas
<
Context
,
MPType
>
(
dev_ctx
);
if
(
input_grad
)
{
dev_ctx
.
template
Alloc
<
T
>(
input_grad
);
total_elems
=
in_dims
[
0
]
*
in_dims
[
1
];
...
...
@@ -78,19 +107,60 @@ void AddmmGradKernel(const Context& dev_ctx,
Array2
(
input_grad
->
dims
()[
0
],
input_grad
->
dims
()[
1
]);
if
(
row_compress
&&
col_compress
)
{
eigen_dinput
.
device
(
place
)
=
eigen_dout
.
sum
().
eval
().
reshape
(
eigen_dinput_shape
);
if
(
!
is_float16_or_bfloat16
)
{
eigen_dinput
.
device
(
place
)
=
eigen_dout
.
sum
().
eval
().
reshape
(
eigen_dinput_shape
);
}
else
{
eigen_dinput
.
device
(
place
)
=
eigen_dout
.
template
cast
<
MPType
>()
.
sum
()
.
eval
()
.
reshape
(
eigen_dinput_shape
)
.
template
cast
<
T
>();
}
}
else
if
(
row_compress
)
{
eigen_dinput
.
device
(
place
)
=
eigen_dout
.
sum
(
Array1
(
0
)).
eval
().
reshape
(
eigen_dinput_shape
);
if
(
!
is_float16_or_bfloat16
)
{
eigen_dinput
.
device
(
place
)
=
eigen_dout
.
sum
(
Array1
(
0
)).
eval
().
reshape
(
eigen_dinput_shape
);
}
else
{
eigen_dinput
.
device
(
place
)
=
eigen_dout
.
template
cast
<
MPType
>()
.
sum
(
Array1
(
0
))
.
eval
()
.
reshape
(
eigen_dinput_shape
)
.
template
cast
<
T
>();
}
}
else
if
(
col_compress
)
{
eigen_dinput
.
device
(
place
)
=
eigen_dout
.
sum
(
Array1
(
1
)).
eval
().
reshape
(
eigen_dinput_shape
);
if
(
!
is_float16_or_bfloat16
)
{
eigen_dinput
.
device
(
place
)
=
eigen_dout
.
sum
(
Array1
(
1
)).
eval
().
reshape
(
eigen_dinput_shape
);
}
else
{
eigen_dinput
.
device
(
place
)
=
eigen_dout
.
template
cast
<
MPType
>()
.
sum
(
Array1
(
1
))
.
eval
()
.
reshape
(
eigen_dinput_shape
)
.
template
cast
<
T
>();
}
}
else
{
blas
.
VCOPY
(
total_elems
,
out_grad
.
data
<
T
>
(),
input_grad
->
data
<
T
>
());
// The VCOPY does not support the float16, bfloat16
if
(
!
is_float16_or_bfloat16
)
{
mt_blas
.
VCOPY
(
total_elems
,
out_grad
.
data
<
MPType
>
(),
input_grad
->
data
<
MPType
>
());
}
else
{
phi
::
funcs
::
ForRange
<
Context
>
for_range
(
dev_ctx
,
total_elems
);
CopyOrScaleFunctor
<
T
>
functor
(
1
,
out_grad
.
data
<
T
>
(),
input_grad
->
data
<
T
>
(),
total_elems
);
for_range
(
functor
);
}
}
blas
.
SCAL
(
total_elems
,
beta
,
input_grad
->
data
<
T
>
());
// The SCAL does not support the float16, bfloat16
if
(
!
is_float16_or_bfloat16
)
{
mt_blas
.
SCAL
(
total_elems
,
beta
,
input_grad
->
data
<
MPType
>
());
}
else
{
phi
::
funcs
::
ForRange
<
Context
>
for_range
(
dev_ctx
,
total_elems
);
CopyOrScaleFunctor
<
T
>
functor
(
beta
,
input_grad
->
data
<
T
>
(),
input_grad
->
data
<
T
>
(),
total_elems
);
for_range
(
functor
);
}
if
(
input
.
dims
().
size
()
==
1
)
{
input_grad
->
Resize
(
input
.
dims
());
...
...
@@ -101,14 +171,28 @@ void AddmmGradKernel(const Context& dev_ctx,
total_elems
=
x
.
dims
()[
0
]
*
x
.
dims
()[
1
];
// x_grad = out_grad * y'. x_grad: M x K, out_grad : M x N, y : K x N
blas
.
MatMul
(
out_grad
,
false
,
y
,
true
,
x_grad
);
blas
.
SCAL
(
total_elems
,
alpha
,
x_grad
->
data
<
T
>
());
if
(
!
is_float16_or_bfloat16
)
{
mt_blas
.
SCAL
(
total_elems
,
alpha
,
x_grad
->
data
<
MPType
>
());
}
else
{
phi
::
funcs
::
ForRange
<
Context
>
for_range
(
dev_ctx
,
total_elems
);
CopyOrScaleFunctor
<
T
>
functor
(
alpha
,
x_grad
->
data
<
T
>
(),
x_grad
->
data
<
T
>
(),
total_elems
);
for_range
(
functor
);
}
}
if
(
y_grad
)
{
dev_ctx
.
template
Alloc
<
T
>(
y_grad
);
total_elems
=
x
.
dims
()[
1
]
*
y
.
dims
()[
1
];
// y_grad = x' * out_grad. y_grad K x N, out_grad : M x N, x : M x K
blas
.
MatMul
(
x
,
true
,
out_grad
,
false
,
y_grad
);
blas
.
SCAL
(
total_elems
,
alpha
,
y_grad
->
data
<
T
>
());
if
(
!
is_float16_or_bfloat16
)
{
mt_blas
.
SCAL
(
total_elems
,
alpha
,
y_grad
->
data
<
MPType
>
());
}
else
{
phi
::
funcs
::
ForRange
<
Context
>
for_range
(
dev_ctx
,
total_elems
);
CopyOrScaleFunctor
<
T
>
functor
(
alpha
,
y_grad
->
data
<
T
>
(),
y_grad
->
data
<
T
>
(),
total_elems
);
for_range
(
functor
);
}
}
}
...
...
paddle/phi/kernels/impl/addmm_kernel_impl.h
浏览文件 @
841efcd4
...
...
@@ -112,17 +112,19 @@ void AddmmKernel(const Context& dev_ctx,
funcs
::
EigenBroadcast
<
std
::
decay_t
<
decltype
(
place
)
>
,
T
,
2
>::
Eval
(
place
,
eigen_out
,
eigen_input
,
bcast_dims
);
T
t_alpha
=
static_cast
<
T
>
(
alpha
);
T
t_beta
=
static_cast
<
T
>
(
beta
);
blas
.
GEMM
(
false
,
false
,
x_dims
[
0
],
y_dims
[
1
],
x_dims
[
1
],
alpha
,
t_
alpha
,
x
.
data
<
T
>
(),
x_dims
[
1
],
y
.
data
<
T
>
(),
y_dims
[
1
],
beta
,
t_
beta
,
out
->
data
<
T
>
(),
y_dims
[
1
]);
}
...
...
python/paddle/fluid/tests/unittests/test_addmm_op.py
浏览文件 @
841efcd4
...
...
@@ -15,11 +15,11 @@
import
unittest
import
numpy
as
np
from
eager_op_test
import
OpTest
from
eager_op_test
import
OpTest
,
convert_float_to_uint16
import
paddle
from
paddle
import
fluid
from
paddle.fluid
import
Program
,
program_guard
from
paddle.fluid
import
Program
,
core
,
program_guard
class
TestAddMMOp
(
OpTest
):
...
...
@@ -27,7 +27,6 @@ class TestAddMMOp(OpTest):
def
setUp
(
self
):
self
.
op_type
=
"addmm"
self
.
python_api
=
paddle
.
addmm
self
.
dtype
=
np
.
float64
self
.
init_dtype_type
()
self
.
inputs
=
{
'Input'
:
np
.
random
.
random
((
100
,
1
)).
astype
(
self
.
dtype
),
...
...
@@ -40,7 +39,7 @@ class TestAddMMOp(OpTest):
}
def
init_dtype_type
(
self
):
pass
self
.
dtype
=
np
.
float64
def
test_check_output
(
self
):
self
.
check_output
()
...
...
@@ -58,6 +57,62 @@ class TestAddMMOp(OpTest):
self
.
check_grad
([
'Input'
],
'Out'
,
no_grad_set
=
None
)
class
TestAddMMFP16Op
(
TestAddMMOp
):
def
init_dtype_type
(
self
):
self
.
dtype
=
np
.
float16
def
test_check_output
(
self
):
self
.
check_output
(
atol
=
1e-2
)
@
unittest
.
skipIf
(
not
core
.
is_compiled_with_cuda
()
or
not
core
.
is_bfloat16_supported
(
core
.
CUDAPlace
(
0
)),
"core is not compiled with CUDA or not support bfloat16"
,
)
class
TestAddMMBF16Op
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"addmm"
self
.
python_api
=
paddle
.
addmm
self
.
init_dtype_type
()
self
.
inputs
=
{
'Input'
:
np
.
random
.
random
((
100
,
1
)).
astype
(
self
.
np_dtype
),
'X'
:
np
.
random
.
random
((
100
,
10
)).
astype
(
self
.
np_dtype
),
'Y'
:
np
.
random
.
random
((
10
,
20
)).
astype
(
self
.
np_dtype
),
}
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'Input'
]
+
np
.
dot
(
self
.
inputs
[
'X'
],
self
.
inputs
[
'Y'
])
}
self
.
inputs
[
'Input'
]
=
convert_float_to_uint16
(
self
.
inputs
[
'Input'
])
self
.
inputs
[
'X'
]
=
convert_float_to_uint16
(
self
.
inputs
[
'X'
])
self
.
inputs
[
'Y'
]
=
convert_float_to_uint16
(
self
.
inputs
[
'Y'
])
self
.
outputs
[
'Out'
]
=
convert_float_to_uint16
(
self
.
outputs
[
'Out'
])
self
.
place
=
core
.
CUDAPlace
(
0
)
def
init_dtype_type
(
self
):
self
.
dtype
=
np
.
uint16
self
.
np_dtype
=
np
.
float32
def
test_check_output
(
self
):
self
.
check_output_with_place
(
self
.
place
)
def
test_check_grad_normal
(
self
):
self
.
check_grad_with_place
(
self
.
place
,
[
'Input'
,
'X'
,
'Y'
],
'Out'
)
def
test_check_grad_x
(
self
):
self
.
check_grad_with_place
(
self
.
place
,
[
'X'
],
'Out'
,
no_grad_set
=
None
)
def
test_check_grad_y
(
self
):
self
.
check_grad_with_place
(
self
.
place
,
[
'Y'
],
'Out'
,
no_grad_set
=
None
)
def
test_check_grad_input
(
self
):
self
.
check_grad_with_place
(
self
.
place
,
[
'Input'
],
'Out'
,
no_grad_set
=
None
)
class
TestAddMMOpError
(
unittest
.
TestCase
):
# test error
def
test_errors
(
self
):
...
...
python/paddle/tensor/math.py
浏览文件 @
841efcd4
...
...
@@ -1959,10 +1959,14 @@ def addmm(input, x, y, beta=1.0, alpha=1.0, name=None):
helper
=
LayerHelper
(
"addmm"
,
**
locals
())
check_variable_and_dtype
(
input
,
'Input'
,
[
'float32'
,
'float64'
],
'addmm'
input
,
'Input'
,
[
'float16'
,
'float32'
,
'float64'
,
'uint16'
],
'addmm'
)
check_variable_and_dtype
(
x
,
'X'
,
[
'float16'
,
'float32'
,
'float64'
,
'uint16'
],
'addmm'
)
check_variable_and_dtype
(
y
,
'Y'
,
[
'float16'
,
'float32'
,
'float64'
,
'uint16'
],
'addmm'
)
check_variable_and_dtype
(
x
,
'X'
,
[
'float32'
,
'float64'
],
'addmm'
)
check_variable_and_dtype
(
y
,
'Y'
,
[
'float32'
,
'float64'
],
'addmm'
)
out
=
helper
.
create_variable_for_type_inference
(
dtype
=
x
.
dtype
)
helper
.
append_op
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录