Commit 30a85204 authored by kavyasrinet, committed by Yi Wang

Adding the doc format for AdaDelta, AdaMax, Adam, AdaGrad, BatchNorm, Clip, Cast and AUC (#5317)

* Adding the doc format for AdaDelta

* Updating the documentation for Adagrad, Adam and Adamax

* Updating the auc op

* Fix review comments

* Updating doc for Batch Norm

* Updating the cast op

* Updating the clip op

* Fixing review comment

* Fixing review comment:

* Small change to restart PR_CI
Parent 2ac5d7d0
@@ -64,16 +64,15 @@ class AdadeltaOpMaker : public framework::OpProtoAndCheckerMaker {
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("Param", "(Tensor) Input parameter");
     AddInput("Grad", "(Tensor) Input gradient");
-    AddInput("AvgSquaredGrad",
-             "(Tensor) Input expectation of squared gradient");
+    AddInput("AvgSquaredGrad", "(Tensor) Input average of squared gradient");
     AddInput("AvgSquaredUpdate",
-             "(Tensor) Input expectation of squared parameter updates");
+             "(Tensor) Input average of squared parameter updates");
     AddOutput("ParamOut", "(Tensor) Output parameter");
     AddOutput("AvgSquaredGradOut",
-              "(Tensor) Output expectation of squared gradient");
+              "(Tensor) Output average of squared gradient");
     AddOutput("AvgSquaredUpdateOut",
-              "(Tensor) Output expectation of squared parameter updates");
+              "(Tensor) Output average of squared parameter updates");
     AddAttr<float>("rho",
                    "(float, default 0.95) Exponential decay rate "
@@ -84,22 +83,21 @@ class AdadeltaOpMaker : public framework::OpProtoAndCheckerMaker {
                    "numerical stability")
         .SetDefault(1.0e-6f);
     AddComment(R"DOC(
-Adadelta Updates Operator.
+Adadelta Optimizer.
 
-This implements the Adadelta optimizer[1]. Adadelta is a per-dimension
-adaptive learning rate method for gradient descent.
+Adadelta optimizer is implemented as explained in:
+https://arxiv.org/abs/1212.5701
+Adadelta is a per-dimension adaptive learning rate method used
+for gradient descent.
 
-Adadelta updates:
+Adadelta updates are as follows:
 
-avg_squared_grad_out = rho * avg_squared_grad + (1 - rho) * grad * grad
-param_update = - sqrt((avg_squared_update + epsilon) /
-                      (avg_squared_grad_out + epsilon)) * grad
-avg_squared_update_out = rho * avg_squared_update + (1 - rho) * param_update**2
-param_out = param + param_update
-
-References:
-  [1] ADADELTA: An Adaptive Learning Rate Method
-      https://arxiv.org/abs/1212.5701
+$$
+avgSquaredGradOut = \rho * avgSquaredGrad + (1 - \rho) * grad * grad \\
+paramUpdate = -\sqrt{\frac{avgSquaredUpdate + \epsilon}{avgSquaredGradOut + \epsilon}} * grad \\
+avgSquaredUpdateOut = \rho * avgSquaredUpdate + (1 - \rho) * paramUpdate^2 \\
+paramOut = param + paramUpdate
+$$
 )DOC");
   }
......
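To make the updated Adadelta equations easy to check, here is a minimal standalone sketch of the per-element step in plain C++. The names and the dense-vector loop are ours for illustration (the rho/epsilon defaults mirror the attributes above); this is not Paddle's actual kernel:

```cpp
#include <cmath>
#include <vector>

// Per-element Adadelta update following the equations in the DOC comment.
void AdadeltaUpdate(std::vector<float>& param, const std::vector<float>& grad,
                    std::vector<float>& avg_sq_grad,
                    std::vector<float>& avg_sq_update,
                    float rho = 0.95f, float epsilon = 1e-6f) {
  for (size_t i = 0; i < param.size(); ++i) {
    // avgSquaredGradOut = rho * avgSquaredGrad + (1 - rho) * grad^2
    avg_sq_grad[i] = rho * avg_sq_grad[i] + (1.0f - rho) * grad[i] * grad[i];
    // paramUpdate = -sqrt((avgSquaredUpdate + eps) / (avgSquaredGradOut + eps)) * grad
    float update = -std::sqrt((avg_sq_update[i] + epsilon) /
                              (avg_sq_grad[i] + epsilon)) * grad[i];
    // avgSquaredUpdateOut = rho * avgSquaredUpdate + (1 - rho) * paramUpdate^2
    avg_sq_update[i] = rho * avg_sq_update[i] + (1.0f - rho) * update * update;
    param[i] += update;  // paramOut = param + paramUpdate
  }
}
```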
@@ -73,12 +73,16 @@ class AdagradOpMaker : public framework::OpProtoAndCheckerMaker {
 Adaptive Gradient Algorithm (Adagrad).
 
-moment_out = moment + grad * grad
-param_out = param - learning_rate * grad / (sqrt(moment_out) + epsilon)
+The update is done as follows:
+
+$$
+momentOut = moment + grad * grad \\
+paramOut = param - learningRate * \frac{grad}{\sqrt{momentOut} + \epsilon}
+$$
 
-The original paper(http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
-does not have the epsilon attribute. It is added here for numerical stability
-by avoiding division by zero.
+The original paper (http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
+does not have the epsilon attribute. It is added here in our implementation,
+as also proposed here: http://cs231n.github.io/neural-networks-3/#ada,
+for numerical stability to avoid the division-by-zero error.
 )DOC");
   }
......
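Read concretely, the Adagrad equations above amount to the following per-element loop. A minimal sketch with our names; epsilon is taken as a plain parameter since the op's attribute default is not shown in this diff:

```cpp
#include <cmath>
#include <vector>

// Per-element Adagrad update following the equations in the DOC comment.
void AdagradUpdate(std::vector<float>& param, const std::vector<float>& grad,
                   std::vector<float>& moment, float learning_rate,
                   float epsilon) {
  for (size_t i = 0; i < param.size(); ++i) {
    // momentOut = moment + grad^2
    moment[i] += grad[i] * grad[i];
    // paramOut = param - learningRate * grad / (sqrt(momentOut) + epsilon)
    param[i] -= learning_rate * grad[i] / (std::sqrt(moment[i]) + epsilon);
  }
}
```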
@@ -51,8 +51,8 @@ class AdamOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE_EQ(framework::product(beta1_pow_dims), 1,
                       "Beta1 power accumulator should have 1 dimension");
     auto beta2_pow_dims = ctx->GetInputDim("Beta2Pow");
-    PADDLE_ENFORCE_EQ(framework::product(beta1_pow_dims), 1,
-                      "Beta1 power accumulator should have 1 dimension");
+    PADDLE_ENFORCE_EQ(framework::product(beta2_pow_dims), 1,
+                      "Beta2 power accumulator should have 1 dimension");
     auto param_dims = ctx->GetInputDim("Param");
     PADDLE_ENFORCE_EQ(
@@ -60,10 +60,10 @@ class AdamOp : public framework::OperatorWithKernel {
         "Param and Grad input of AdamOp should have same dimension");
     PADDLE_ENFORCE_EQ(
         param_dims, ctx->GetInputDim("Moment1"),
-        "Param and Moment input of AdamOp should have same dimension");
+        "Param and Moment1 input of AdamOp should have same dimension");
     PADDLE_ENFORCE_EQ(
         param_dims, ctx->GetInputDim("Moment2"),
-        "Param and InfNorm input of AdamOp should have same dimension");
+        "Param and Moment2 input of AdamOp should have same dimension");
     ctx->SetOutputDim("ParamOut", param_dims);
     ctx->SetOutputDim("Moment1Out", param_dims);
@@ -103,23 +103,20 @@ class AdamOpMaker : public framework::OpProtoAndCheckerMaker {
         .SetDefault(1.0e-8f);
     AddComment(R"DOC(
-Adam Updates Operator.
+Adam Optimizer.
 
 This implements the Adam optimizer from Section 2 of the Adam
-paper[1]. Adam is a first-order gradient-based optimization
-method based on adaptive estimates of lower-order moments.
+paper: https://arxiv.org/abs/1412.6980.
+Adam is a first-order gradient-based optimization method based on
+adaptive estimates of lower-order moments.
 
 Adam updates:
 
-moment1_out = beta1 * moment1 + (1 − beta1) * grad
-moment2_out = beta2 * moment2 + (1 − beta2) * grad * grad
-learning_rate_t = learning_rate_t *
-                  sqrt(1 - beta2_pow) / (1 - beta1_pow)
-param_out = param - learning_rate_t * moment1/ (sqrt(moment2) + epsilon)
-
-References:
-  [1] Adam: A Method for Stochastic Optimization
-      (https://arxiv.org/abs/1412.6980)
+$$
+moment_1^{out} = \beta_1 * moment_1 + (1 - \beta_1) * grad \\
+moment_2^{out} = \beta_2 * moment_2 + (1 - \beta_2) * grad * grad \\
+learningRate = learningRate * \frac{\sqrt{1 - \beta_2^{pow}}}{1 - \beta_1^{pow}} \\
+paramOut = param - learningRate * \frac{moment_1}{\sqrt{moment_2} + \epsilon}
+$$
 )DOC");
   }
......
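The Adam equations above translate to the loop below. This is an illustrative standalone sketch, not Paddle's kernel: beta1_pow and beta2_pow stand for the Beta1Pow/Beta2Pow accumulator inputs (beta1^t and beta2^t), and beta1/beta2/epsilon are passed in rather than read from attributes:

```cpp
#include <cmath>
#include <vector>

// Per-element Adam update following the DOC equations.
void AdamUpdate(std::vector<float>& param, const std::vector<float>& grad,
                std::vector<float>& moment1, std::vector<float>& moment2,
                float learning_rate, float beta1_pow, float beta2_pow,
                float beta1, float beta2, float epsilon) {
  // Bias-corrected step size, computed once per update.
  float lr_t = learning_rate * std::sqrt(1.0f - beta2_pow) / (1.0f - beta1_pow);
  for (size_t i = 0; i < param.size(); ++i) {
    moment1[i] = beta1 * moment1[i] + (1.0f - beta1) * grad[i];
    moment2[i] = beta2 * moment2[i] + (1.0f - beta2) * grad[i] * grad[i];
    param[i] -= lr_t * moment1[i] / (std::sqrt(moment2[i]) + epsilon);
  }
}
```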
@@ -99,26 +99,22 @@ class AdamaxOpMaker : public framework::OpProtoAndCheckerMaker {
                    "Constant for numerical stability")
         .SetDefault(1.0e-8f);
     AddComment(R"DOC(
-Adamax Updates Operator.
+Adamax Optimizer.
 
-This implements the Adamax optimizer from Section 7 of the Adam
-paper[1]. Adamax is a variant of the
+We implement the Adamax optimizer from Section 7 of the Adam
+paper: https://arxiv.org/abs/1412.6980. Adamax is a variant of the
 Adam algorithm based on the infinity norm.
 
 Adamax updates:
 
-moment_out = beta1 * moment + (1 - beta1) * grad
-inf_norm_out = max(beta2 * inf_norm + epsilon, abs(grad))
-learning_rate_t = learning_rate/(1 - beta1_pow)
-param_out = param - learning_rate_t * moment_out/inf_norm_out
+$$
+momentOut = \beta_1 * moment + (1 - \beta_1) * grad \\
+infNormOut = \max(\beta_2 * infNorm + \epsilon, |grad|) \\
+learningRate = \frac{learningRate}{1 - \beta_1^{pow}} \\
+paramOut = param - learningRate * \frac{momentOut}{infNormOut}
+$$
 
 The original paper does not have an epsilon attribute.
-However, it is added here for numerical stability
-by preventing divide by 0.
-
-References:
-  [1] Adam: A Method for Stochastic Optimization
-      (https://arxiv.org/abs/1412.6980)
+However, it is added here for numerical stability to prevent the
+division-by-0 error.
 )DOC");
   }
......
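As with Adam, the Adamax rule is easiest to verify as a loop. A minimal sketch under the same assumptions (beta1_pow is the Beta1Pow input, i.e. beta1^t; attribute values passed as parameters; not Paddle's kernel):

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Per-element Adamax update; inf_norm is the exponentially weighted
// infinity-norm accumulator from the DOC equations.
void AdamaxUpdate(std::vector<float>& param, const std::vector<float>& grad,
                  std::vector<float>& moment, std::vector<float>& inf_norm,
                  float learning_rate, float beta1_pow,
                  float beta1, float beta2, float epsilon) {
  float lr_t = learning_rate / (1.0f - beta1_pow);  // bias correction
  for (size_t i = 0; i < param.size(); ++i) {
    moment[i] = beta1 * moment[i] + (1.0f - beta1) * grad[i];
    // epsilon keeps the denominator away from zero when gradients vanish.
    inf_norm[i] = std::max(beta2 * inf_norm[i] + epsilon, std::fabs(grad[i]));
    param[i] -= lr_t * moment[i] / inf_norm[i];
  }
}
```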
@@ -23,11 +23,11 @@ class AucOp : public framework::OperatorWithKernel {
  protected:
   void InferShape(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("Out"), "Input of Out must be initialized.");
+    PADDLE_ENFORCE(ctx->HasInput("Out"), "Input of Out should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("Indices"),
-                   "Input of Indices must be initialized.");
+                   "Input of Indices should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("Label"),
-                   "Input of Label must be initialized.");
+                   "Input of Label should not be null.");
     auto inference_height = ctx->GetInputDim("Out")[0];
     auto label_height = ctx->GetInputDim("Label")[0];
@@ -52,20 +52,20 @@ class AucOpMaker : public framework::OpProtoAndCheckerMaker {
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("Out",
              "A floating point 2D tensor, values are in the range [0, 1]."
-             "Each row is descend sorted. This input should be the"
+             "Each row is sorted in descending order. This input should be the"
              "output of topk."
              "Typically, this tensor indicates the probability of each label");
     AddInput("Indices",
              "An int 2D tensor, indicating the indices of original"
-             "tensor before sort. Typically, this tensor indicates which label"
-             "the probability stands for.");
+             "tensor before sorting. Typically, this tensor indicates which "
+             "label the probability stands for.");
     AddInput("Label",
              "A 2D int tensor indicating the label of the training data."
              "The height is batch size and width is always 1.");
     // TODO(typhoonzero): support weight input
     AddOutput("AUC",
               "A scalar representing the "
-              "current area-under-curve.");
+              "current area-under-the-curve.");
     AddAttr<std::string>("curve", "Curve type, can be 'ROC' or 'PR'.")
         .SetDefault("ROC");
@@ -74,19 +74,18 @@ class AucOpMaker : public framework::OpProtoAndCheckerMaker {
                  " roc curve.")
         .SetDefault(200);
-    AddComment(
-        R"DOC(Computes the AUC according forward output and label.
-Best to use for binary classification evaluations.
+    AddComment(R"DOC(
+Area Under The Curve (AUC) Operator.
 
-If input label contains values other than 0 and 1, it will be cast
-to bool.
-
-You can find the definations here:
+This implementation computes the AUC according to the forward output and label.
+It is widely used in binary classification evaluation. Note that if the input
+label contains values other than 0 and 1, it will be cast to bool.
+
+You can find the relevant definitions here:
 https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve
 
-Possible curves are:
-- ROC: Receiver operating characteristic
-- PR: Precision Recall
+There are two types of possible curves:
+1. ROC: Receiver operating characteristic
+2. PR: Precision Recall
 )DOC");
   }
 };
......
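The op discretizes the curve using the threshold-count attribute shown above (default 200). For intuition, ROC AUC can also be computed exactly as a rank statistic; the sketch below does that for binary labels. This is our helper, not the op's implementation, and tied scores are not rank-averaged here:

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Rank-based ROC AUC (the Mann-Whitney U statistic): the probability that a
// randomly chosen positive example scores higher than a randomly chosen
// negative one.
double RocAuc(const std::vector<float>& scores, const std::vector<int>& labels) {
  std::vector<size_t> idx(scores.size());
  for (size_t i = 0; i < idx.size(); ++i) idx[i] = i;
  // Sort indices by ascending score so position r corresponds to rank r + 1.
  std::sort(idx.begin(), idx.end(),
            [&](size_t a, size_t b) { return scores[a] < scores[b]; });
  double rank_sum = 0.0;
  size_t num_pos = 0;
  for (size_t r = 0; r < idx.size(); ++r) {
    if (labels[idx[r]] == 1) {
      rank_sum += static_cast<double>(r + 1);  // 1-based rank of a positive
      ++num_pos;
    }
  }
  const size_t num_neg = idx.size() - num_pos;
  if (num_pos == 0 || num_neg == 0) return 0.0;  // AUC undefined; report 0
  return (rank_sum - num_pos * (num_pos + 1) / 2.0) /
         (static_cast<double>(num_pos) * num_neg);
}
```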
@@ -70,7 +70,7 @@ class BatchNormOp : public framework::OperatorWithKernel {
                        : x_dims[x_dims.size() - 1]);
     PADDLE_ENFORCE(x_dims.size() >= 3 && x_dims.size() <= 5,
-                   "Input x must have 3 to 5 dimensions.");
+                   "Input X must have 3 to 5 dimensions.");
     PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale").size(), 1UL);
     PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale")[0], C);
@@ -97,16 +97,16 @@ class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("X", "The input tensor");
     AddInput("Scale",
              "Scale is a 1-dimensional tensor of size C "
-             "to be applied to the output");
+             "that is applied to the output");
     AddInput("Bias",
              "Bias is a 1-dimensional tensor of size C "
-             "to be applied to the output");
+             "that is applied to the output");
     AddInput("Mean",
-             "The global mean (for training) or the "
+             "The global mean (for training) or "
              "estimated mean (for testing)");
     AddInput("Variance",
              "The global variance (for training) "
-             "or the estimated Variance (for testing)");
+             "or estimated Variance (for testing)");
     AddOutput("Y", "result after normalization");
     AddOutput("MeanOut",
               "Share memory with Mean. "
@@ -123,10 +123,14 @@ class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
               "will apply to output when training")
         .AsIntermediate();
     AddComment(R"DOC(
-https://arxiv.org/pdf/1502.03167.pdf
+Batch Normalization.
 
-NHWC `[batch, in_height, in_width, in_channels]`
-NCHW `[batch, in_channels, in_height, in_width]`
+Batch Norm has been implemented as discussed in the paper:
+https://arxiv.org/pdf/1502.03167.pdf
+Can be used as a normalizer function for conv2d and fully_connected operations.
+The required data format for this layer is one of the following:
+1. NHWC `[batch, in_height, in_width, in_channels]`
+2. NCHW `[batch, in_channels, in_height, in_width]`
 )DOC");
   }
......
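For the inference path, the normalization reduces to a per-channel affine transform using the estimated Mean/Variance inputs: y = Scale * (x - Mean) / sqrt(Variance + epsilon) + Bias. A minimal NCHW sketch with illustrative names and plain loops, assuming contiguous float data (not Paddle's kernel):

```cpp
#include <cmath>
#include <vector>

// Inference-time batch norm for NCHW data, per channel c.
void BatchNormInferNCHW(const std::vector<float>& x, std::vector<float>& y,
                        const std::vector<float>& mean,
                        const std::vector<float>& variance,
                        const std::vector<float>& scale,
                        const std::vector<float>& bias,
                        int N, int C, int H, int W, float epsilon = 1e-5f) {
  for (int n = 0; n < N; ++n) {
    for (int c = 0; c < C; ++c) {
      float inv_std = 1.0f / std::sqrt(variance[c] + epsilon);
      for (int i = 0; i < H * W; ++i) {
        int offset = (n * C + c) * H * W + i;
        y[offset] = scale[c] * (x[offset] - mean[c]) * inv_std + bias[c];
      }
    }
  }
}
```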
@@ -23,13 +23,17 @@ class CastOpProtoMaker : public framework::OpProtoAndCheckerMaker {
   CastOpProtoMaker(framework::OpProto *proto,
                    framework::OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("X", "the input tensor of cast op");
-    AddOutput("Out", "the output tensor of cast op");
-    AddComment(R"DOC(Cast operator.
-cast the input tensor to other data type.
-)DOC");
+    AddInput("X", "The input tensor of cast op");
+    AddOutput("Out", "The output tensor of cast op");
     AddAttr<int>("out_data_type", "output data type");
     AddAttr<int>("in_data_type", "input data type");
+    AddComment(R"DOC(
+Cast Operator.
+
+This Operator casts the input tensor to another data type and
+returns the output tensor.
+)DOC");
   }
 };
......
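Stripped of attribute handling, a cast is an element-wise static conversion; the op adds runtime dispatch on the in_data_type/out_data_type attributes. A hypothetical sketch for one fixed type pair:

```cpp
#include <cstdint>
#include <vector>

// Element-wise cast between two concrete types; the real op selects the
// type pair at runtime from its data-type attributes.
template <typename InT, typename OutT>
std::vector<OutT> Cast(const std::vector<InT>& x) {
  std::vector<OutT> out(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    out[i] = static_cast<OutT>(x[i]);
  }
  return out;
}

// Example: float -> int32 truncates toward zero.
// std::vector<int32_t> ints = Cast<float, int32_t>({1.9f, -2.7f});  // {1, -2}
```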
@@ -49,8 +49,11 @@ class ClipOpMaker : public framework::OpProtoAndCheckerMaker {
     AddAttr<AttrType>(
         "max", "(float)Maximum value, above which element is replaced by max");
     AddComment(R"DOC(
-Clip operator limits the given input within an interval. The interval is
-specified with arguments 'min' and 'max'.
+Clip Operator.
+
+The clip operator limits the value of the given input within an interval.
+The interval is specified with the arguments 'min' and 'max'.
 )DOC");
   }
 };
......
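For reference, the whole computation is a single element-wise clamp; a minimal sketch with the 'min' and 'max' attributes passed as plain floats:

```cpp
#include <algorithm>
#include <vector>

// Element-wise clip to [min_v, max_v], matching the op's min/max attributes.
void Clip(std::vector<float>& x, float min_v, float max_v) {
  for (float& v : x) {
    v = std::min(std::max(v, min_v), max_v);  // replace out-of-range values
  }
}
```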
@@ -4,10 +4,10 @@ To make the operator document itself more clear, we recommend operator names obe

### OpProtoMaker names

When defining an operator in Paddle, a corresponding [OpProtoMaker](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/operator.h#L170) (TODO: OpProtoMaker Doc) needs to be defined. All the Inputs/Outputs and Attributes will be written into the [OpProto](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/framework.proto#L61), and will be used in the client language to create the operator.

- Input/Output.
  - Input/Output names follow **CamelCase**, e.g. `X`, `Y`, `Matrix`, `LastAxisInMatrix`. Inputs/Outputs behave much like Variables, so we prefer meaningful English words.
  - If an operator's Inputs/Outputs are tensors in the mathematical sense and do not match any meaningful word, input names should start from `X` (e.g. `X`, `Y`) and output names should start from `Out` (e.g. `Out`). This rule is intended to unify operators that have few inputs/outputs.

- Attribute.
@@ -15,7 +15,7 @@ When defining an operator in Paddle, a corresponding [OpProtoMaker](https://gith

- Comments.
  - Input/Output/Attr comments follow the format **(type, default value) usage**, describing which type it can be and how it will be used in the operator, e.g. the Attribute `"gamma"` in Accumulator: `(float, default 1.0) Accumulation multiplier`.
  - Operator comments use the format `R"DOC(your comment here)DOC"`. You should explain the input/output of the operator first. If there is a math calculation in this operator, you should write the equation in the comment, e.g. `Out = X + Y`.

- Order.
  - Follow the order of Input/Output, then Attribute, then Comments. See the example in the best practice.
@@ -24,7 +24,7 @@ When defining an operator in Paddle, a corresponding [OpProtoMaker](https://gith

Here we give some examples to show how these rules will be used.

- The operator has one input and one output, e.g. `relu`: inputs: `X`, outputs: `Out`.
- The operator has two inputs and one output, e.g. `rowwise_add`: inputs: `X`, `Y`, outputs: `Out`.
@@ -38,8 +38,8 @@ public:
  AccumulateOpMaker(framework::OpProto *proto,
                    framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X",
             "(Tensor) The input tensor that has to be accumulated to the "
             "output tensor. If the output size is not the same as input size, "
             "the output tensor is first reshaped and initialized to zero, and "
             "only then, accumulation is done.");
    AddOutput("Out", "(Tensor) Accumulated output tensor");
    AddAttr<float>("gamma", "(float, default 1.0) Accumulation multiplier").SetDefault(1.0f);
......