Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle-Lite
提交
cb138726
P
Paddle-Lite
项目概览
PaddlePaddle
/
Paddle-Lite
通知
337
Star
4
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
271
列表
看板
标记
里程碑
合并请求
78
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle-Lite
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
271
Issue
271
列表
看板
标记
里程碑
合并请求
78
合并请求
78
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
cb138726
编写于
7月 02, 2020
作者:
W
Wilber
提交者:
GitHub
7月 02, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[CUDA] [Kernel] Add fc cuda kernel. (#3873)
上级
0edd6cf1
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
419 addition
and
2 deletion
+419
-2
lite/core/mir/fusion/fc_fuse_pass.cc
lite/core/mir/fusion/fc_fuse_pass.cc
+0
-1
lite/kernels/cuda/CMakeLists.txt
lite/kernels/cuda/CMakeLists.txt
+3
-1
lite/kernels/cuda/fc_compute.cu
lite/kernels/cuda/fc_compute.cu
+176
-0
lite/kernels/cuda/fc_compute.h
lite/kernels/cuda/fc_compute.h
+45
-0
lite/kernels/cuda/fc_compute_test.cc
lite/kernels/cuda/fc_compute_test.cc
+195
-0
未找到文件。
lite/core/mir/fusion/fc_fuse_pass.cc
浏览文件 @
cb138726
...
...
@@ -40,5 +40,4 @@ REGISTER_MIR_PASS(lite_fc_fuse_pass, paddle::lite::mir::FcFusePass)
.
BindTargets
({
TARGET
(
kAny
)})
.
ExcludeTargets
({
TARGET
(
kXPU
),
TARGET
(
kX86
)})
.
ExcludeTargets
({
TARGET
(
kBM
)})
.
ExcludeTargets
({
TARGET
(
kCUDA
)})
.
BindKernel
(
"fc"
);
lite/kernels/cuda/CMakeLists.txt
浏览文件 @
cb138726
...
...
@@ -6,6 +6,7 @@ message(STATUS "compile with lite CUDA kernels")
# basic kernels
add_kernel
(
mul_compute_cuda CUDA basic SRCS mul_compute.cc DEPS
${
lite_kernel_deps
}
${
math_cuda
}
)
add_kernel
(
fc_compute_cuda CUDA basic SRCS fc_compute.cu DEPS
${
lite_kernel_deps
}
${
math_cuda
}
)
add_kernel
(
search_group_padding_compute_cuda CUDA basic SRCS search_group_padding_compute.cu DEPS
${
lite_kernel_deps
}
)
add_kernel
(
io_copy_compute_cuda CUDA basic SRCS io_copy_compute.cc DEPS
${
lite_kernel_deps
}
)
add_kernel
(
leaky_relu_compute_cuda CUDA basic SRCS leaky_relu_compute.cu DEPS
${
lite_kernel_deps
}
)
...
...
@@ -65,7 +66,8 @@ nv_test(concat_compute_cuda_test SRCS concat_compute_test.cc DEPS concat_compute
nv_test
(
elementwise_compute_cuda_test SRCS elementwise_compute_test.cc DEPS elementwise_compute_cuda
)
nv_test
(
softmax_compute_cuda_test SRCS softmax_compute_test.cc DEPS softmax_compute_cuda
)
#nv_test(layout_cuda_test SRCS layout_compute_test.cc DEPS layout_compute_cuda)
nv_test
(
mul_compute_cuda_test SRCS mul_compute_test.cc DEPS mul_compute_cuda
)
nv_test
(
mul_compute_cuda_test SRCS mul_compute_test.cc DEPS mul_compute_cuda
)
nv_test
(
fc_compute_cuda_test SRCS fc_compute_test.cc DEPS fc_compute_cuda
)
nv_test
(
dropout_compute_cuda_test SRCS dropout_compute_test.cc DEPS dropout_compute_cuda
)
nv_test
(
bilinear_interp_compute_cuda_test SRCS bilinear_interp_compute_test.cc DEPS bilinear_interp_compute_cuda
)
#nv_test(pool_compute_cuda_test SRCS pool_compute_test.cc DEPS pool_compute_cuda)
...
...
lite/kernels/cuda/fc_compute.cu
0 → 100644
浏览文件 @
cb138726
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/fc_compute.h"
#include <string>
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/core/op_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
kernels
{
namespace
cuda
{
template
<
typename
T
>
struct
FcTypeTraits
;
template
<
>
struct
FcTypeTraits
<
float
>
{
typedef
float4
Type
;
};
template
<
typename
T
>
__global__
void
bias_v4
(
const
int
num
,
const
T
*
bias
,
T
*
data
,
int
K
)
{
CUDA_KERNEL_LOOP
(
index
,
num
)
{
int
bias_idx
=
index
%
K
;
const
T
bias_ptr
=
bias
[
bias_idx
];
const
T
in_ptr
=
data
[
index
];
T
packed_val
;
packed_val
.
x
=
in_ptr
.
x
+
bias_ptr
.
x
;
packed_val
.
y
=
in_ptr
.
y
+
bias_ptr
.
y
;
packed_val
.
z
=
in_ptr
.
z
+
bias_ptr
.
z
;
packed_val
.
w
=
in_ptr
.
w
+
bias_ptr
.
w
;
data
[
index
]
=
packed_val
;
}
}
template
<
typename
T
>
__global__
void
bias_relu_v4
(
const
int
num
,
const
T
*
bias
,
T
*
data
,
int
K
)
{
CUDA_KERNEL_LOOP
(
index
,
num
)
{
int
bias_idx
=
index
%
K
;
const
T
bias_ptr
=
bias
[
bias_idx
];
const
T
in_ptr
=
data
[
index
];
T
packed_val
;
packed_val
.
x
=
fmaxf
(
0.
f
,
in_ptr
.
x
+
bias_ptr
.
x
);
packed_val
.
y
=
fmaxf
(
0.
f
,
in_ptr
.
y
+
bias_ptr
.
y
);
packed_val
.
z
=
fmaxf
(
0.
f
,
in_ptr
.
z
+
bias_ptr
.
z
);
packed_val
.
w
=
fmaxf
(
0.
f
,
in_ptr
.
w
+
bias_ptr
.
w
);
data
[
index
]
=
packed_val
;
}
}
template
<
typename
T
>
__global__
void
general_bias
(
const
int
num
,
const
T
*
bias
,
T
*
data
)
{
int
offset
=
blockIdx
.
x
*
num
;
for
(
int
i
=
threadIdx
.
x
;
i
<
num
;
i
+=
blockDim
.
x
)
{
T
temp
;
#if __CUDA_ARCH__ >= 350
temp
=
__ldg
(
data
+
offset
+
i
)
+
__ldg
(
bias
+
i
);
#else
temp
=
data
[
offset
+
i
]
+
bias
[
i
];
#endif
data
[
offset
+
i
]
=
temp
;
}
}
template
<
typename
T
>
__global__
void
general_relu_bias
(
const
int
num
,
const
T
*
bias
,
T
*
data
)
{
int
offset
=
blockIdx
.
x
*
num
;
for
(
int
i
=
threadIdx
.
x
;
i
<
num
;
i
+=
blockDim
.
x
)
{
T
temp
;
#if __CUDA_ARCH__ >= 350
temp
=
__ldg
(
data
+
offset
+
i
)
+
__ldg
(
bias
+
i
);
#else
temp
=
data
[
offset
+
i
]
+
bias
[
i
];
#endif
data
[
offset
+
i
]
=
static_cast
<
int
>
(
temp
>
0
)
*
temp
;
}
}
template
<
typename
T
,
PrecisionType
PType
>
void
FcCompute
<
T
,
PType
>::
PrepareForRun
()
{
gemm_impl_
.
reset
(
new
lite
::
cuda
::
math
::
Gemm
<
T
,
T
>
);
}
template
<
typename
T
,
PrecisionType
PType
>
void
FcCompute
<
T
,
PType
>::
Run
()
{
auto
&
context
=
this
->
ctx_
->
template
As
<
CUDAContext
>();
auto
stream
=
context
.
exec_stream
();
auto
&
param
=
this
->
template
Param
<
param_t
>();
const
auto
*
x_data
=
param
.
input
->
template
data
<
T
>();
const
auto
*
w_data
=
param
.
w
->
template
data
<
T
>();
const
auto
*
b_data
=
param
.
bias
?
param
.
bias
->
template
data
<
T
>()
:
nullptr
;
auto
out_vec
=
param
.
output
->
dims
().
Vectorize
();
out_vec
.
back
()
=
param
.
w
->
dims
()[
1
];
param
.
output
->
Resize
(
out_vec
);
auto
*
out_data
=
param
.
output
->
template
mutable_data
<
T
>(
TARGET
(
kCUDA
));
int
in_num_col_dims
=
param
.
in_num_col_dims
;
int
M
=
static_cast
<
int
>
(
param
.
input
->
dims
().
Slice
(
0
,
param
.
in_num_col_dims
).
production
());
int
K
=
static_cast
<
int
>
(
param
.
input
->
dims
()
.
Slice
(
param
.
in_num_col_dims
,
param
.
input
->
dims
().
size
())
.
production
());
int
K2
=
static_cast
<
int
>
(
param
.
w
->
dims
()[
0
]);
int
N
=
static_cast
<
int
>
(
param
.
w
->
dims
()[
1
]);
CHECK_EQ
(
K
,
K2
)
<<
"x_w must be equal with y_h"
;
CHECK
(
gemm_impl_
->
init
(
false
,
false
,
M
,
N
,
K
,
&
context
));
gemm_impl_
->
run
(
1.0
f
,
0.0
f
,
x_data
,
w_data
,
out_data
,
&
context
);
if
(
b_data
==
nullptr
)
{
return
;
}
std
::
string
activation_type
=
param
.
activation_type
;
if
(
N
%
4
==
0
)
{
const
int
threads
=
256
;
const
int
num
=
M
*
N
/
4
;
const
int
blocks
=
(
num
+
threads
-
1
)
/
threads
;
typedef
typename
FcTypeTraits
<
T
>::
Type
trans_type
;
const
auto
*
bias_ptr_v4
=
reinterpret_cast
<
const
trans_type
*>
(
b_data
);
auto
*
data_ptr_v4
=
reinterpret_cast
<
trans_type
*>
(
out_data
);
if
(
activation_type
==
"relu"
)
{
bias_relu_v4
<
trans_type
><<<
blocks
,
threads
,
0
,
stream
>>>
(
num
,
bias_ptr_v4
,
data_ptr_v4
,
N
/
4
);
}
else
if
(
activation_type
==
""
)
{
bias_v4
<
trans_type
><<<
blocks
,
threads
,
0
,
stream
>>>
(
num
,
bias_ptr_v4
,
data_ptr_v4
,
N
/
4
);
}
else
{
LOG
(
FATAL
)
<<
"not supported activation type: "
<<
activation_type
;
}
}
else
{
const
int
threads
=
256
;
const
int
blocks
=
M
;
if
(
activation_type
==
"relu"
)
{
general_relu_bias
<
T
><<<
blocks
,
threads
,
0
,
stream
>>>
(
N
,
b_data
,
out_data
);
}
else
if
(
activation_type
==
""
)
{
general_bias
<
T
><<<
blocks
,
threads
,
0
,
stream
>>>
(
N
,
b_data
,
out_data
);
}
else
{
LOG
(
FATAL
)
<<
"not supported activation type: "
<<
activation_type
;
}
}
}
}
// namespace cuda
}
// namespace kernels
}
// namespace lite
}
// namespace paddle
using
FcFp32
=
paddle
::
lite
::
kernels
::
cuda
::
FcCompute
<
float
,
PRECISION
(
kFloat
)
>
;
REGISTER_LITE_KERNEL
(
fc
,
kCUDA
,
kFloat
,
kNCHW
,
FcFp32
,
def
)
.
BindInput
(
"Input"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
))})
.
BindInput
(
"Bias"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
))})
.
BindInput
(
"W"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
))})
.
BindOutput
(
"Out"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
))})
.
Finalize
();
lite/kernels/cuda/fc_compute.h
0 → 100644
浏览文件 @
cb138726
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "lite/backends/cuda/math/gemm.h"
#include "lite/core/kernel.h"
#include "lite/operators/op_params.h"
namespace
paddle
{
namespace
lite
{
namespace
kernels
{
namespace
cuda
{
template
<
typename
T
,
PrecisionType
PType
>
class
FcCompute
:
public
KernelLite
<
TARGET
(
kCUDA
),
PType
>
{
public:
using
param_t
=
operators
::
FcParam
;
void
PrepareForRun
()
override
;
void
Run
()
override
;
virtual
~
FcCompute
()
=
default
;
private:
std
::
unique_ptr
<
lite
::
cuda
::
math
::
Gemm
<
T
,
T
>>
gemm_impl_
{
nullptr
};
};
}
// namespace cuda
}
// namespace kernels
}
// namespace lite
}
// namespace paddle
lite/kernels/cuda/fc_compute_test.cc
0 → 100644
浏览文件 @
cb138726
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/fc_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "lite/api/test_helper.h"
#include "lite/utils/float16.h"
namespace
paddle
{
namespace
lite
{
namespace
kernels
{
namespace
cuda
{
class
FcTest
:
public
::
testing
::
Test
{
protected:
FcTest
()
:
m
(
128
),
k
(
512
),
n
(
64
),
in_num_col_dims
(
1
),
act_type
(
"relu"
),
x_shape
({
m
,
k
}),
w_shape
({
k
,
n
}),
b_shape
({
n
}),
out_shape
({
m
,
n
})
{
X_gpu
.
Resize
(
lite
::
DDim
(
x_shape
));
X_ref
.
Resize
(
lite
::
DDim
(
x_shape
));
W_gpu
.
Resize
(
lite
::
DDim
(
w_shape
));
W_ref
.
Resize
(
lite
::
DDim
(
w_shape
));
b_gpu
.
Resize
(
lite
::
DDim
(
b_shape
));
b_ref
.
Resize
(
lite
::
DDim
(
b_shape
));
auto
x_ref_data
=
X_ref
.
mutable_data
<
float
>
();
auto
w_ref_data
=
W_ref
.
mutable_data
<
float
>
();
auto
b_ref_data
=
b_ref
.
mutable_data
<
float
>
();
// prepare input
for
(
int64_t
i
=
0
;
i
<
X_ref
.
numel
();
i
++
)
{
x_ref_data
[
i
]
=
static_cast
<
float
>
(
i
%
10
*
0.2
);
}
for
(
int64_t
i
=
0
;
i
<
W_ref
.
numel
();
i
++
)
{
w_ref_data
[
i
]
=
static_cast
<
float
>
(
i
%
10
*
0.2
);
}
for
(
int64_t
i
=
0
;
i
<
b_ref
.
numel
();
i
++
)
{
b_ref_data
[
i
]
=
static_cast
<
float
>
(
i
%
10
*
0.2
);
}
Out_ref
.
Resize
(
lite
::
DDim
(
out_shape
));
Out_cpu
.
Resize
(
Out_ref
.
dims
());
Out_gpu
.
Resize
(
Out_ref
.
dims
());
fc_cpu_base
(
&
X_ref
,
&
W_ref
,
&
b_ref
,
&
Out_ref
);
device_init
();
}
void
device_init
()
{
ctx
.
reset
(
new
KernelContext
);
cudaStreamCreate
(
&
stream
);
auto
&
context
=
ctx
->
As
<
CUDAContext
>
();
context
.
SetExecStream
(
stream
);
param
.
input
=
&
X_gpu
;
param
.
w
=
&
W_gpu
;
param
.
bias
=
&
b_gpu
;
param
.
in_num_col_dims
=
in_num_col_dims
;
param
.
activation_type
=
act_type
;
param
.
output
=
&
Out_gpu
;
}
void
float_data_init
()
{
X_gpu
.
Assign
<
float
,
lite
::
DDim
,
TARGET
(
kCUDA
)
>
(
X_ref
.
data
<
float
>
(),
X_gpu
.
dims
());
W_gpu
.
Assign
<
float
,
lite
::
DDim
,
TARGET
(
kCUDA
)
>
(
W_ref
.
data
<
float
>
(),
W_gpu
.
dims
());
b_gpu
.
Assign
<
float
,
lite
::
DDim
,
TARGET
(
kCUDA
)
>
(
b_ref
.
data
<
float
>
(),
b_gpu
.
dims
());
}
void
half_data_init
()
{
X_half
.
Resize
(
lite
::
DDim
(
x_shape
));
auto
x_half_data
=
X_half
.
mutable_data
<
half
>
();
for
(
int64_t
i
=
0
;
i
<
X_half
.
numel
();
i
++
)
{
x_half_data
[
i
]
=
half
(
lite
::
float16
(
X_ref
.
data
<
float
>
()[
i
]));
}
X_gpu
.
Assign
<
half
,
lite
::
DDim
,
TARGET
(
kCUDA
)
>
(
x_half_data
,
X_gpu
.
dims
());
W_half
.
Resize
(
W_ref
.
dims
());
auto
w_half_data
=
W_half
.
mutable_data
<
half
>
();
for
(
int64_t
i
=
0
;
i
<
W_half
.
numel
();
i
++
)
{
w_half_data
[
i
]
=
half
(
lite
::
float16
(
W_ref
.
data
<
float
>
()[
i
]));
}
W_gpu
.
Assign
<
half
,
lite
::
DDim
,
TARGET
(
kCUDA
)
>
(
w_half_data
,
W_gpu
.
dims
());
b_half
.
Resize
(
b_ref
.
dims
());
auto
b_half_data
=
b_half
.
mutable_data
<
half
>
();
for
(
int64_t
i
=
0
;
i
<
b_half
.
numel
();
i
++
)
{
b_half_data
[
i
]
=
half
(
lite
::
float16
(
b_ref
.
data
<
float
>
()[
i
]));
}
b_gpu
.
Assign
<
half
,
lite
::
DDim
,
TARGET
(
kCUDA
)
>
(
b_half_data
,
b_gpu
.
dims
());
}
void
fc_cpu_base
(
const
lite
::
Tensor
*
X
,
const
lite
::
Tensor
*
W
,
const
lite
::
Tensor
*
b
,
lite
::
Tensor
*
Out
)
{
const
float
*
data_in
=
X
->
data
<
float
>
();
const
float
*
bias
=
b
->
data
<
float
>
();
const
float
*
weights
=
W
->
data
<
float
>
();
float
*
data_out
=
Out
->
mutable_data
<
float
>
();
int
out_rows
=
X
->
dims
()[
0
];
int
in_cols
=
X
->
numel
()
/
out_rows
;
int
out_cols
=
W
->
numel
()
/
in_cols
;
int
index_out
;
for
(
int
i
=
0
;
i
<
out_rows
;
i
++
)
{
for
(
int
j
=
0
;
j
<
out_cols
;
j
++
)
{
index_out
=
i
*
out_cols
+
j
;
data_out
[
index_out
]
=
bias
?
bias
[
j
]
:
0
;
for
(
int
k
=
0
;
k
<
in_cols
;
k
++
)
{
data_out
[
index_out
]
+=
data_in
[
i
*
in_cols
+
k
]
*
weights
[
k
*
out_cols
+
j
];
}
if
(
act_type
==
"relu"
)
{
data_out
[
index_out
]
*=
static_cast
<
int
>
(
data_out
[
index_out
]
>
0
);
}
}
}
}
int
m
,
k
,
n
,
in_num_col_dims
;
std
::
string
act_type
;
std
::
vector
<
int64_t
>
x_shape
,
w_shape
,
b_shape
,
out_shape
;
lite
::
Tensor
X_ref
,
W_ref
,
b_ref
,
Out_ref
;
lite
::
Tensor
X_gpu
,
W_gpu
,
b_gpu
;
lite
::
Tensor
X_half
,
W_half
,
b_half
;
lite
::
Tensor
Out_cpu
,
Out_gpu
;
operators
::
FcParam
param
;
std
::
unique_ptr
<
KernelContext
>
ctx
;
cudaStream_t
stream
;
};
TEST_F
(
FcTest
,
TestFP32
)
{
float_data_init
();
FcCompute
<
float
,
PRECISION
(
kFloat
)
>
kernel
;
kernel
.
SetParam
(
param
);
kernel
.
SetContext
(
std
::
move
(
ctx
));
for
(
int
i
=
0
;
i
<
FLAGS_warmup
;
++
i
)
{
kernel
.
Launch
();
cudaDeviceSynchronize
();
}
auto
start
=
GetCurrentUS
();
kernel
.
PrepareForRun
();
for
(
int
i
=
0
;
i
<
FLAGS_repeats
;
++
i
)
{
kernel
.
Run
();
}
cudaDeviceSynchronize
();
auto
duration
=
(
GetCurrentUS
()
-
start
)
/
1000.0
;
LOG
(
INFO
)
<<
"fp32, warmup: "
<<
FLAGS_warmup
<<
", repeats: "
<<
FLAGS_repeats
<<
", spend "
<<
duration
/
FLAGS_repeats
<<
" ms in average."
;
CopySync
<
TARGET
(
kCUDA
)
>
(
Out_cpu
.
mutable_data
<
float
>
(),
Out_gpu
.
data
<
float
>
(),
sizeof
(
float
)
*
Out_gpu
.
numel
(),
IoDirection
::
DtoH
);
for
(
int
i
=
0
;
i
<
Out_gpu
.
numel
();
++
i
)
{
float
res
=
Out_cpu
.
data
<
float
>
()[
i
];
float
ref
=
Out_ref
.
data
<
float
>
()[
i
];
EXPECT_NEAR
(
fabs
(
res
-
ref
)
/
ref
,
0.
f
,
1e-5
);
}
}
}
// namespace cuda
}
// namespace kernels
}
// namespace lite
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录