Unverified commit a4d07bb9, authored by zhangbo9674, committed by GitHub

[AMP] Add multi_precision for sgd (#38231)

Parent 08941eda
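For context, a minimal dygraph usage sketch of the flag this commit introduces, following the pattern of the tests added below (assumes a CUDA build of Paddle; the Linear model and random input are placeholders):

import paddle

paddle.set_device('gpu')
model = paddle.nn.Linear(2, 2)
# multi_precision=True makes SGD keep an fp32 master copy of every fp16 parameter.
optimizer = paddle.optimizer.SGD(learning_rate=0.001,
                                 parameters=model.parameters(),
                                 multi_precision=True)
model = paddle.amp.decorate(models=model, level='O2')
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

x = paddle.randn((2, 2))
with paddle.amp.auto_cast(level='O2'):
    loss = paddle.mean(model(x))
scaled = scaler.scale(loss)
scaled.backward()
scaler.minimize(optimizer, scaled)  # unscales gradients, then runs the sgd op
optimizer.clear_grad()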
@@ -126,13 +126,24 @@ class SGDOpMaker : public framework::OpProtoAndCheckerMaker {
    AddInput("Param", "(Tensor or SelectedRows) Input parameter");
    AddInput("LearningRate", "(Tensor) Learning rate of SGD");
    AddInput("Grad", "(Tensor or SelectedRows) Input gradient");
+   AddInput("MasterParam", "FP32 master weight for AMP.").AsDispensable();
    AddOutput("ParamOut",
              "(Tensor or SelectedRows, same with Param) "
              "Output parameter, should share the same memory with Param");
+   AddOutput("MasterParamOut",
+             "The updated FP32 master weight for AMP. "
+             "It shared memory with Input(MasterParam).")
+       .AsDispensable();
    AddAttr<bool>(
        "use_mkldnn",
        "(bool, default false) Indicates if MKL-DNN kernel will be used")
        .SetDefault(false);
+   AddAttr<bool>("multi_precision",
+                 "(bool, default false) "
+                 "Whether to use multi-precision during weight updating.")
+       .SetDefault(false);
    AddComment(R"DOC(
SGD operator
......
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include <algorithm>
+#include "paddle/fluid/operators/amp/fp16_type_traits.h"
 #include "paddle/fluid/operators/optimizers/sgd_op.h"
 #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
@@ -21,14 +22,19 @@ namespace operators {
 namespace {

-template <typename T>
-__global__ void SGDKernel(const T* g, const T* p, const T* learning_rate,
-                          const int num, T* p_out) {
-  T lr = learning_rate[0];
+template <typename T, typename MT>
+__global__ void SGDKernelMT(const T* param, const T* grad,
+                            const T* learning_rate, const int num, T* param_out,
+                            const MT* master_param, MT* master_param_out) {
+  MT lr = static_cast<MT>(learning_rate[0]);
   CUDA_KERNEL_LOOP(i, num) {
-    T g_data = g[i];
-    T p_data = p[i];
-    p_out[i] = p_data - lr * g_data;
+    MT p_data = master_param ? master_param[i] : static_cast<MT>(param[i]);
+    MT g_data = static_cast<MT>(grad[i]);
+    p_data = p_data - lr * g_data;
+    param_out[i] = static_cast<T>(p_data);
+    if (master_param_out) {
+      master_param_out[i] = p_data;
+    }
   }
 }
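The kernel now does the arithmetic in the master type MT (fp32 when T is fp16) and casts the result back to T. A NumPy sketch (illustration only, not Paddle code) of why the master copy matters:

import numpy as np

lr, grad = np.float32(1e-3), np.float32(0.1)

# Pure fp16 update: each step of 1e-4 is below fp16 resolution near 1.0,
# so the parameter never moves, no matter how many steps are taken.
p16 = np.float16(1.0)
for _ in range(10):
    p16 = np.float16(p16 - np.float16(lr) * np.float16(grad))
print(p16)          # 1.0 -- every update rounded away

# multi_precision-style update (what SGDKernelMT does): accumulate in an
# fp32 master copy and only cast back to fp16 for the working parameter.
master = np.float32(1.0)
for _ in range(10):
    master = master - lr * grad
    p16 = np.float16(master)  # analogous to param_out[i] = static_cast<T>(p_data)
print(master, p16)  # ~0.999 0.999 -- the accumulated updates survive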
@@ -63,30 +69,48 @@ class SGDOpKernel<platform::CUDADeviceContext, T>
                           "but the received is %s",
                           ctx.InputNames("Param").front(),
                           paddle::framework::ToTypeName(param_var->Type())));
+    using paddle::framework::Tensor;
+    using MPDType = typename details::MPTypeTrait<T>::Type;

     auto* param = ctx.Input<framework::Tensor>("Param");
     auto* param_out = ctx.Output<framework::Tensor>("ParamOut");
     auto* learning_rate = ctx.Input<framework::Tensor>("LearningRate");

     auto* grad_var = ctx.InputVar("Grad");
+    const bool multi_precision = ctx.Attr<bool>("multi_precision");
+    const Tensor* master_param = nullptr;
+    Tensor* master_param_out = nullptr;
+    if (multi_precision) {
+      bool has_master =
+          ctx.HasInput("MasterParam") && ctx.HasOutput("MasterParamOut");
+      PADDLE_ENFORCE_EQ(has_master, true,
+                        platform::errors::InvalidArgument(
+                            "The Input(MasterParam) and Output(MasterParamOut) "
+                            "should not be null when "
+                            "the attr `multi_precision` is true"));
+      master_param = ctx.Input<framework::Tensor>("MasterParam");
+      master_param_out = ctx.Output<framework::Tensor>("MasterParamOut");
+    }
+    const MPDType* master_in_data =
+        multi_precision ? master_param->data<MPDType>() : nullptr;
+    MPDType* master_out_data =
+        multi_precision
+            ? master_param_out->mutable_data<MPDType>(ctx.GetPlace())
+            : nullptr;
+
     // Actually, all tensors are LoDTensor except SelectedRows.
     if (grad_var->IsType<framework::LoDTensor>()) {
-      param_out->mutable_data<T>(ctx.GetPlace());
       auto* grad = ctx.Input<framework::Tensor>("Grad");
-      // LOG(ERROR) << "grad";
-      // LOG(ERROR) << ctx.op().Input("Grad");
-      auto* grad_data = grad->data<T>();
-      // LOG(ERROR) << "param";
-      auto* param_data = param->data<T>();
-      // LOG(ERROR) << "fin";
-      auto* param_out_data = param_out->data<T>();

       int block = 512;
       int grid = (param->numel() + block - 1) / block;
-      SGDKernel<T><<<grid, block, 0, ctx.cuda_device_context().stream()>>>(
-          grad_data, param_data, learning_rate->data<T>(), param->numel(),
-          param_out_data);
+      SGDKernelMT<
+          T, MPDType><<<grid, block, 0, ctx.cuda_device_context().stream()>>>(
+          param->data<T>(), grad->data<T>(), learning_rate->data<T>(),
+          param->numel(), param_out->mutable_data<T>(ctx.GetPlace()),
+          master_in_data, master_out_data);

     } else if (grad_var->IsType<framework::SelectedRows>()) {
       // TODO(qijun): In Sparse SGD operator, in-place update is enforced.
......
@@ -79,6 +79,7 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
      "Beta2Pow", "MasterParam"}},
     {"sparse_attention",
      {"Q", "K", "V", "Offset", "Columns", "KeyPaddingMask", "AttnMask"}},
+    {"sgd", {"Param", "LearningRate", "Grad", "MasterParam"}},
 };

 // NOTE(zhiqiu): Like op_ins_map.
@@ -125,6 +126,7 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
     {"adamw",
      {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut",
       "MasterParamOut"}},
+    {"sgd", {"ParamOut", "MasterParamOut"}},
     {"lamb",
      {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut",
       "MasterParamOut"}},
@@ -142,7 +144,7 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
 // especially in declarative mode.
 // For those OPs, we need to manually specify the outs need to pass in this map.
 std::map<std::string, std::set<std::string>> op_passing_outs_map = {
-    {"sgd", {"ParamOut"}},
+    {"sgd", {"ParamOut", "MasterParamOut"}},
     {"adam",
      {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut",
       "MasterParamOut"}},
......
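Registering "sgd" in op_ins_map, op_outs_map, and op_passing_outs_map is what exposes the optional master weight through the generated _C_ops.sgd entry point. A rough sketch of the resulting call (CPU, fp32, no master weight, so None fills both optional slots, mirroring the optimizer code below):

import paddle
from paddle import _C_ops

paddle.set_device('cpu')
param = paddle.to_tensor([[1.0, 2.0]])
grad = paddle.to_tensor([[0.5, 0.5]])
lr = paddle.to_tensor([0.1])

# Argument order: Param, LearningRate, Grad, MasterParam, then the passed-in
# outputs ParamOut and MasterParamOut (updated in place).
_C_ops.sgd(param, lr, grad, None, param, None)
print(param.numpy())  # [[0.95 1.95]]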
@@ -1296,6 +1296,7 @@ class SGDOptimizer(Optimizer):
                  parameter_list=None,
                  regularization=None,
                  grad_clip=None,
+                 multi_precision=False,
                  name=None):
         assert learning_rate is not None
         super(SGDOptimizer, self).__init__(
@@ -1306,26 +1307,86 @@ class SGDOptimizer(Optimizer):
             name=name)
         self.type = "sgd"
         self._use_mkldnn = False
+        self._multi_precision = multi_precision
+        self._master_weights = {}
+
+    def _create_master_weight(self, param):
+        if param.name in self._master_weights:
+            var = self._master_weights[param.name]
+        else:
+            assert isinstance(self.helper, LayerHelper)
+
+            var_name = param.name + "_fp32_master"
+            var_name = unique_name.generate(var_name)
+            var = layers.create_global_var(
+                name=var_name,
+                shape=param.shape,
+                value=0,
+                dtype='float32',
+                persistable=True)
+            block = self.helper.startup_program.global_block()
+            block.append_op(
+                type="cast",
+                inputs={"X": [param]},
+                outputs={"Out": [var]},
+                attrs={
+                    "in_dtype": param.dtype,
+                    "out_dtype": core.VarDesc.VarType.FP32
+                })
+            self._master_weights[param.name] = var
+        return var
+
+    def _create_accumulators(self, block, parameters):
+        assert isinstance(block, framework.Block)
+        if isinstance(parameters, dict):
+            parameters = self._update_param_group(parameters)
+
+        # Create accumulator tensors for first and second moments
+        for p in parameters:
+            if self._multi_precision and p.dtype == core.VarDesc.VarType.FP16:
+                master_p = self._create_master_weight(p)
+                continue
+            if p.dtype == core.VarDesc.VarType.FP16 and not self._multi_precision:
+                warnings.warn(
+                    "Accumulating with FP16 in optimizer can lead to poor accuracy or slow convergence."
+                    "Consider using multi_precision=True option of the Adam optimizer."
+                )

     @no_grad
     def _append_optimize_op(self, block, param_and_grad):
+        find_master = self._multi_precision and param_and_grad[
+            0].dtype == core.VarDesc.VarType.FP16
+        master_weight = (self._master_weights[param_and_grad[0].name]
+                         if find_master else None)
+
         lr = self._create_param_lr(param_and_grad)
         if framework.in_dygraph_mode():
-            _C_ops.sgd(param_and_grad[0], lr, param_and_grad[1],
-                       param_and_grad[0])
+            _C_ops.sgd(param_and_grad[0], lr, param_and_grad[1], master_weight,
+                       param_and_grad[0], master_weight)
             return None

         assert isinstance(block, framework.Block)
         # create the optimize op
+        inputs = {
+            "Param": param_and_grad[0],
+            "Grad": param_and_grad[1],
+            "LearningRate": lr
+        }
+
+        outputs = {"ParamOut": param_and_grad[0]}
+
+        attrs = {"multi_precision": find_master}
+
+        if find_master:
+            inputs["MasterParam"] = master_weight
+            outputs["MasterParamOut"] = master_weight
+
         sgd_op = block.append_op(
             type=self.type,
-            inputs={
-                "Param": param_and_grad[0],
-                "Grad": param_and_grad[1],
-                "LearningRate": lr
-            },
-            attrs={"use_mkldnn": self._use_mkldnn},
-            outputs={"ParamOut": param_and_grad[0]},
+            inputs=inputs,
+            outputs=outputs,
+            attrs=attrs,
             stop_gradient=True)

         return sgd_op
......
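A condensed static-graph sketch of the extended fluid optimizer under pure-fp16 AMP (assumes a CUDA build; it follows TestSGDMultiPrecision1_0.static_sgd_mp in the tests below):

import numpy as np
import paddle

paddle.enable_static()
main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()

optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.001, multi_precision=True)
# Wrap with AMP so the network runs in fp16 while SGD updates fp32 masters.
optimizer = paddle.static.amp.decorate(
    optimizer,
    init_loss_scaling=128.0,
    use_dynamic_loss_scaling=True,
    use_pure_fp16=True,
    use_fp16_guard=False)

with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(shape=[2, 2], name='X', dtype='float16')
    hidden = paddle.static.nn.fc(x=x, size=10)
    loss = paddle.fluid.layers.mean(hidden)
    optimizer.minimize(loss)

exe = paddle.static.Executor('gpu')
exe.run(startup_prog)
optimizer.amp_init(place='gpu', scope=paddle.static.global_scope())
loss_val, = exe.run(main_prog,
                    feed={'X': np.random.random(size=(2, 2)).astype('float16')},
                    fetch_list=[loss.name])

# The appended sgd ops should now carry the new attribute and master-weight slots:
for op in main_prog.global_block().ops:
    if op.type == 'sgd':
        print(op.attr('multi_precision'), op.input('MasterParam'),
              op.output('MasterParamOut'))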
@@ -192,6 +192,7 @@ class TestSGDOpOptimizeSelectedRows(unittest.TestCase):
 class TestSGDOpWithLargeInput(unittest.TestCase):
     def runTest(self):
+        paddle.enable_static()
         data = fluid.layers.fill_constant(shape=[1], value=128, dtype='int64')
         label = fluid.layers.fill_constant(
             shape=[1, 150], value=0.5, dtype='float32')
@@ -291,5 +292,212 @@ class TestSGDV2(unittest.TestCase):
         adam.clear_gradients()

class TestSGDMultiPrecision2_0(unittest.TestCase):
def dygraph_sgd_mp(self, mp):
paddle.disable_static()
paddle.seed(10)
paddle.set_device('gpu')
input = paddle.randn((2, 2))
model = paddle.nn.Linear(2, 2)
optimizer = paddle.optimizer.SGD(parameters=model.parameters(),
multi_precision=mp)
if mp == True:
model = paddle.amp.decorate(models=model, level='O2')
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
for idx in range(5):
if mp == True:
with paddle.amp.auto_cast(level='O2'):
output = model(input)
loss = paddle.mean(output)
scaled = scaler.scale(loss)
scaled.backward()
scaler.minimize(optimizer, scaled)
optimizer.clear_grad()
else:
output = model(input)
loss = paddle.mean(output)
optimizer.step()
optimizer.clear_grad()
return output, model.parameters()
def static_sgd_mp(self, mp):
paddle.enable_static()
paddle.seed(10)
np.random.seed(10)
exe = paddle.static.Executor('gpu')
train_program = paddle.static.Program()
startup_program = paddle.static.Program()
optimizer = paddle.optimizer.SGD(multi_precision=mp)
if mp:
optimizer = paddle.static.amp.decorate(
optimizer,
init_loss_scaling=128.0,
use_dynamic_loss_scaling=True,
use_pure_fp16=True,
use_fp16_guard=False)
with paddle.static.program_guard(train_program, startup_program):
if mp:
data = paddle.static.data(
shape=[2, 2], name='X', dtype='float16')
else:
data = paddle.static.data(
shape=[2, 2], name='X', dtype='float32')
hidden = paddle.static.nn.fc(x=data, size=10)
loss = paddle.fluid.layers.mean(hidden)
optimizer.minimize(loss)
exe.run(startup_program)
if mp:
optimizer.amp_init(place='gpu', scope=paddle.static.global_scope())
x = np.random.random(size=(2, 2)).astype('float16')
else:
x = np.random.random(size=(2, 2)).astype('float32')
out = []
for idx in range(5):
loss_data, = exe.run(train_program,
feed={"X": x},
fetch_list=[loss.name])
out.append(loss_data)
return out
def test_main(self):
if not paddle.is_compiled_with_cuda():
return
"Test dygraph mode"
output1_dy, params1_dy = self.dygraph_sgd_mp(mp=True)
output2_dy, params2_dy = self.dygraph_sgd_mp(mp=False)
self.assertEqual(
np.allclose(
output1_dy.astype('float32').numpy(),
output2_dy.astype('float32').numpy(),
atol=1e-01),
True)
for idx in range(len(params1_dy)):
self.assertEqual(
np.allclose(
params1_dy[idx].astype('float32').numpy(),
params2_dy[idx].astype('float32').numpy(),
atol=1e-01),
True)
"Test static mode"
output1_st = self.static_sgd_mp(mp=True)
output2_st = self.static_sgd_mp(mp=False)
for idx in range(len(output1_st)):
self.assertEqual(
np.allclose(
output1_st[idx].astype('float32'),
output2_st[idx].astype('float32'),
atol=1e-01),
True)
class TestSGDMultiPrecision1_0(unittest.TestCase):
def dygraph_sgd_mp(self, mp):
paddle.disable_static()
paddle.seed(10)
paddle.set_device('gpu')
input = paddle.randn((2, 2))
model = paddle.nn.Linear(2, 2)
optimizer = paddle.fluid.optimizer.SGD(
learning_rate=0.001,
parameter_list=model.parameters(),
multi_precision=mp)
if mp == True:
model = paddle.amp.decorate(models=model, level='O2')
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
for idx in range(5):
if mp == True:
with paddle.amp.auto_cast(level='O2'):
output = model(input)
loss = paddle.mean(output)
scaled = scaler.scale(loss)
scaled.backward()
scaler.minimize(optimizer, scaled)
optimizer.clear_gradients()
else:
output = model(input)
loss = paddle.mean(output)
optimizer.minimize(loss)
optimizer.clear_gradients()
return output, model.parameters()
def static_sgd_mp(self, mp):
paddle.enable_static()
paddle.seed(10)
np.random.seed(10)
exe = paddle.static.Executor('gpu')
train_program = paddle.static.Program()
startup_program = paddle.static.Program()
optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.001,
multi_precision=mp)
if mp:
optimizer = paddle.static.amp.decorate(
optimizer,
init_loss_scaling=128.0,
use_dynamic_loss_scaling=True,
use_pure_fp16=True,
use_fp16_guard=False)
with paddle.static.program_guard(train_program, startup_program):
if mp:
data = paddle.static.data(
shape=[2, 2], name='X', dtype='float16')
else:
data = paddle.static.data(
shape=[2, 2], name='X', dtype='float32')
hidden = paddle.static.nn.fc(x=data, size=10)
loss = paddle.fluid.layers.mean(hidden)
optimizer.minimize(loss)
exe.run(startup_program)
if mp:
optimizer.amp_init(place='gpu', scope=paddle.static.global_scope())
x = np.random.random(size=(2, 2)).astype('float16')
else:
x = np.random.random(size=(2, 2)).astype('float32')
out = []
for idx in range(5):
loss_data, = exe.run(train_program,
feed={"X": x},
fetch_list=[loss.name])
out.append(loss_data)
return out
def test_main(self):
if not paddle.is_compiled_with_cuda():
return
"Test dygraph mode"
output1_dy, params1_dy = self.dygraph_sgd_mp(mp=True)
output2_dy, params2_dy = self.dygraph_sgd_mp(mp=False)
self.assertEqual(
np.allclose(
output1_dy.astype('float32').numpy(),
output2_dy.astype('float32').numpy(),
atol=1e-01),
True)
for idx in range(len(params1_dy)):
self.assertEqual(
np.allclose(
params1_dy[idx].astype('float32').numpy(),
params2_dy[idx].astype('float32').numpy(),
atol=1e-01),
True)
"Test static mode"
output1_st = self.static_sgd_mp(mp=True)
output2_st = self.static_sgd_mp(mp=False)
for idx in range(len(output1_st)):
self.assertEqual(
np.allclose(
output1_st[idx].astype('float32'),
output2_st[idx].astype('float32'),
atol=1e-01),
True)
if __name__ == "__main__":
    unittest.main()
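The new cases can be run in isolation with the stock unittest loader (a sketch; it assumes the test module above is importable under the hypothetical name test_sgd_op and that a CUDA device is present, since both classes return early otherwise):

import unittest
import test_sgd_op  # hypothetical module name for the test file above

loader = unittest.defaultTestLoader
suite = unittest.TestSuite()
suite.addTests(loader.loadTestsFromTestCase(test_sgd_op.TestSGDMultiPrecision2_0))
suite.addTests(loader.loadTestsFromTestCase(test_sgd_op.TestSGDMultiPrecision1_0))
unittest.TextTestRunner(verbosity=2).run(suite)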
@@ -18,6 +18,10 @@ from ..fluid import framework
 from ..fluid.framework import Variable, name_scope
 from ..fluid.dygraph import no_grad
 from paddle import _C_ops
+import warnings
+from ..fluid.layer_helper import LayerHelper
+from ..fluid import unique_name
+from ..fluid import layers

 __all__ = []
@@ -75,6 +79,7 @@ class SGD(Optimizer):
                  parameters=None,
                  weight_decay=None,
                  grad_clip=None,
+                 multi_precision=False,
                  name=None):
         if learning_rate is None:
             raise ValueError("learning_rate is not set")
@@ -85,27 +90,88 @@ class SGD(Optimizer):
             grad_clip=grad_clip,
             name=name)
         self.type = "sgd"
+        self._multi_precision = multi_precision
+        self._master_weights = {}
+
+    def _create_master_weight(self, param):
+        if param.name in self._master_weights:
+            var = self._master_weights[param.name]
+        else:
+            assert isinstance(self.helper, LayerHelper)
+
+            var_name = param.name + "_fp32_master"
+            var_name = unique_name.generate(var_name)
+            var = layers.create_global_var(
+                name=var_name,
+                shape=param.shape,
+                value=0,
+                dtype='float32',
+                persistable=True)
+            block = self.helper.startup_program.global_block()
+            block.append_op(
+                type="cast",
+                inputs={"X": [param]},
+                outputs={"Out": [var]},
+                attrs={
+                    "in_dtype": param.dtype,
+                    "out_dtype": core.VarDesc.VarType.FP32
+                })
+            self._master_weights[param.name] = var
+        return var
+
+    def _create_accumulators(self, block, parameters):
+        assert isinstance(block, framework.Block)
+        if isinstance(parameters, dict):
+            parameters = self._update_param_group(parameters)
+
+        # Create accumulator tensors for first and second moments
+        for p in parameters:
+            if self._multi_precision and p.dtype == core.VarDesc.VarType.FP16:
+                master_p = self._create_master_weight(p)
+                continue
+            if p.dtype == core.VarDesc.VarType.FP16 and not self._multi_precision:
+                warnings.warn(
+                    "Accumulating with FP16 in optimizer can lead to poor accuracy or slow convergence."
+                    "Consider using multi_precision=True option of the Adam optimizer."
+                )

     @no_grad
     def _append_optimize_op(self, block, param_and_grad):
         if isinstance(param_and_grad, dict):
             param_and_grad = self._update_param_group(param_and_grad)
+
+        find_master = self._multi_precision and param_and_grad[
+            0].dtype == core.VarDesc.VarType.FP16
+        master_weight = (self._master_weights[param_and_grad[0].name]
+                         if find_master else None)
+
         lr = self._create_param_lr(param_and_grad)
         if framework.in_dygraph_mode():
-            _C_ops.sgd(param_and_grad[0], lr, param_and_grad[1],
-                       param_and_grad[0])
+            _C_ops.sgd(param_and_grad[0], lr, param_and_grad[1], master_weight,
+                       param_and_grad[0], master_weight)
             return None

         assert isinstance(block, framework.Block)
         # create the optimize op
+        inputs = {
+            "Param": param_and_grad[0],
+            "Grad": param_and_grad[1],
+            "LearningRate": lr
+        }
+
+        outputs = {"ParamOut": param_and_grad[0]}
+
+        attrs = {"multi_precision": find_master}
+
+        if find_master:
+            inputs["MasterParam"] = master_weight
+            outputs["MasterParamOut"] = master_weight
+
         sgd_op = block.append_op(
             type=self.type,
-            inputs={
-                "Param": param_and_grad[0],
-                "Grad": param_and_grad[1],
-                "LearningRate": lr
-            },
-            outputs={"ParamOut": param_and_grad[0]},
+            inputs=inputs,
+            outputs=outputs,
+            attrs=attrs,
             stop_gradient=True)

         return sgd_op
......
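After one AMP O2 step in dygraph, each fp16 parameter has an fp32 master copy keyed by its name (a sketch assuming a CUDA device; _master_weights is the internal dict filled by _create_master_weight above):

import paddle

paddle.set_device('gpu')
model = paddle.nn.Linear(2, 2)
opt = paddle.optimizer.SGD(learning_rate=0.001,
                           parameters=model.parameters(),
                           multi_precision=True)
model = paddle.amp.decorate(models=model, level='O2')  # parameters become fp16
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

with paddle.amp.auto_cast(level='O2'):
    loss = paddle.mean(model(paddle.randn((2, 2))))
scaled = scaler.scale(loss)
scaled.backward()
scaler.minimize(opt, scaled)

# One "<param name>_fp32_master*" variable per fp16 parameter.
for name, master in opt._master_weights.items():
    print(name, master.name, master.dtype)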