未验证 提交 52a0a677 编写于 作者: G gouzil 提交者: GitHub

[Fluid] move lars_momentum_op InferShape to phi (#56749)

* move to phi

* fix

* fix type
上级 99ae88f1
......@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/multiary.h"
namespace paddle {
namespace operators {
......@@ -22,117 +24,6 @@ class LarsMomentumOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInputs("Param"), "Input", "Param", "LarsMomentum");
OP_INOUT_CHECK(ctx->HasInputs("Grad"), "Input", "Grad", "LarsMomentum");
OP_INOUT_CHECK(
ctx->HasInputs("Velocity"), "Input", "Velocity", "LarsMomentum");
OP_INOUT_CHECK(ctx->HasInputs("LearningRate"),
"Input",
"LearningRate",
"LarsMomentum");
OP_INOUT_CHECK(
ctx->HasOutputs("ParamOut"), "Output", "ParamOut", "LarsMomentum");
OP_INOUT_CHECK(ctx->HasOutputs("VelocityOut"),
"Output",
"VelocityOut",
"LarsMomentum");
PADDLE_ENFORCE_EQ(ctx->GetInputsVarType("Param").front(),
framework::proto::VarType::LOD_TENSOR,
platform::errors::InvalidArgument(
"The input var's type should be phi::DenseTensor, "
"but the received is %s",
ctx->GetInputsVarType("Param").front()));
auto lr_dims = ctx->GetInputsDim("LearningRate");
auto grad_dim = ctx->GetInputsDim("Grad");
auto param_dim = ctx->GetInputsDim("Param");
auto velocity_dim = ctx->GetInputsDim("Velocity");
auto lars_weight_decays =
ctx->Attrs().Get<std::vector<float>>("lars_weight_decay");
auto multi_precision = ctx->Attrs().Get<bool>("multi_precision");
PADDLE_ENFORCE_EQ(
param_dim.size(),
grad_dim.size(),
platform::errors::InvalidArgument(
"Input(Param) and Input(Grad) of LarsMomentumOp should have "
"same quantity. But number of Param is [%d] and Grad is [%d].",
param_dim.size(),
grad_dim.size()));
PADDLE_ENFORCE_EQ(
param_dim.size(),
velocity_dim.size(),
platform::errors::InvalidArgument(
"Input(Param) and Input(Velocity) of LarsMomentumOp should "
"have same quantity. But number of Param is [%d] and Velocity "
"is [%d].",
param_dim.size(),
velocity_dim.size()));
PADDLE_ENFORCE_EQ(
lars_weight_decays.size(),
grad_dim.size(),
platform::errors::InvalidArgument(
"Attr(Lars_weight_decay) and "
"Input(Grad) of LarsMomentumOp should have same quantity. "
"But number of Lars_weight_decay is [%d] and Grad is [%d].",
lars_weight_decays.size(),
grad_dim.size()));
if (multi_precision) {
OP_INOUT_CHECK(ctx->HasInputs("MasterParam"),
"Input",
"MasterParam",
"LarsMomentumMultiPrecision");
OP_INOUT_CHECK(ctx->HasOutputs("MasterParamOut"),
"Output",
"MasterParamOut",
"LarsMomentumMultiPrecision");
}
for (auto& lr_dim : lr_dims) {
PADDLE_ENFORCE_EQ(phi::product(lr_dim),
1,
platform::errors::InvalidArgument(
"Learning_rate should be a scalar. But Received "
"LearningRate's dim [%s]",
phi::product(lr_dim)));
}
for (size_t i = 0; i < param_dim.size(); ++i) {
PADDLE_ENFORCE_EQ(ctx->GetInputsVarType("Grad")[i],
framework::proto::VarType::LOD_TENSOR,
platform::errors::InvalidArgument(
"The Var(%s)'s type should be phi::DenseTensor, "
"but the received is %s",
ctx->Inputs("Grad")[i].front(),
ctx->GetInputsVarType("Grad")[i]));
PADDLE_ENFORCE_EQ(
param_dim[i],
grad_dim[i],
platform::errors::InvalidArgument(
"Input(Param) and Input(Grad) input of LarsMomentumOp shall "
"have same dimension. But Param`s dim is [%s] and Grad's dim "
"is [%s].",
param_dim[i],
grad_dim[i]));
PADDLE_ENFORCE_EQ(
param_dim[i],
velocity_dim[i],
platform::errors::InvalidArgument(
"Input(Param) and Input(Velocity) of LarsMomentumOp shall have "
"same dimension. But Param dim [%s] differs with Velocity dim "
"[%s].",
param_dim[i],
velocity_dim[i]));
}
ctx->SetOutputsDim("ParamOut", param_dim);
ctx->SetOutputsDim("VelocityOut", param_dim);
if (ctx->HasOutputs("MasterParamOut")) {
ctx->SetOutputsDim("MasterParamOut", param_dim);
}
}
protected:
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
......@@ -226,6 +117,10 @@ class LarsMomentumOpVarTypeInference : public framework::VarTypeInference {
} // namespace operators
} // namespace paddle
DECLARE_INFER_SHAPE_FUNCTOR(lars_momentum,
LarsMomentumInferShapeFunctor,
PD_INFER_META(phi::LarsMomentumInferMeta));
namespace ops = paddle::operators;
REGISTER_OPERATOR(
lars_momentum,
......@@ -233,4 +128,5 @@ REGISTER_OPERATOR(
ops::LarsMomentumOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
ops::LarsMomentumOpVarTypeInference);
ops::LarsMomentumOpVarTypeInference,
LarsMomentumInferShapeFunctor);
......@@ -2577,6 +2577,92 @@ void LambInferMeta(const MetaTensor& param,
}
}
void LarsMomentumInferMeta(
const std::vector<const MetaTensor*>& param,
const std::vector<const MetaTensor*>& velocity,
const std::vector<const MetaTensor*>& learning_rate,
const std::vector<const MetaTensor*>& grad,
const paddle::optional<std::vector<const MetaTensor*>>& master_param,
const std::vector<float>& lars_weight_decay,
float mu,
float lars_coeff,
float epsilon,
bool multi_precision,
float rescale_grad,
std::vector<MetaTensor*> param_out,
std::vector<MetaTensor*> velocity_out,
std::vector<MetaTensor*> master_param_out) {
std::vector<DDim> lr_dims = GetMetaTensorsDim(learning_rate);
std::vector<DDim> grad_dim = GetMetaTensorsDim(grad);
std::vector<DDim> param_dim = GetMetaTensorsDim(param);
std::vector<DDim> velocity_dim = GetMetaTensorsDim(velocity);
PADDLE_ENFORCE_EQ(
param_dim.size(),
grad_dim.size(),
phi::errors::InvalidArgument(
"Input(Param) and Input(Grad) of LarsMomentumOp should have "
"same quantity. But number of Param is [%d] and Grad is [%d].",
param_dim.size(),
grad_dim.size()));
PADDLE_ENFORCE_EQ(
param_dim.size(),
velocity_dim.size(),
phi::errors::InvalidArgument(
"Input(Param) and Input(Velocity) of LarsMomentumOp should "
"have same quantity. But number of Param is [%d] and Velocity "
"is [%d].",
param_dim.size(),
velocity_dim.size()));
PADDLE_ENFORCE_EQ(
lars_weight_decay.size(),
grad_dim.size(),
phi::errors::InvalidArgument(
"Attr(Lars_weight_decay) and "
"Input(Grad) of LarsMomentumOp should have same quantity. "
"But number of Lars_weight_decay is [%d] and Grad is [%d].",
lars_weight_decay.size(),
grad_dim.size()));
for (auto& lr_dim : lr_dims) {
PADDLE_ENFORCE_EQ(phi::product(lr_dim),
1,
phi::errors::InvalidArgument(
"Learning_rate should be a scalar. But Received "
"LearningRate's dim [%s]",
phi::product(lr_dim)));
}
for (size_t i = 0; i < param_dim.size(); ++i) {
PADDLE_ENFORCE_EQ(
param_dim[i],
grad_dim[i],
phi::errors::InvalidArgument(
"Input(Param) and Input(Grad) input of LarsMomentumOp shall "
"have same dimension. But Param`s dim is [%s] and Grad's dim "
"is [%s].",
param_dim[i],
grad_dim[i]));
PADDLE_ENFORCE_EQ(
param_dim[i],
velocity_dim[i],
phi::errors::InvalidArgument(
"Input(Param) and Input(Velocity) of LarsMomentumOp shall have "
"same dimension. But Param dim [%s] differs with Velocity dim "
"[%s].",
param_dim[i],
velocity_dim[i]));
}
for (size_t i = 0; i < param_out.size(); i++) {
param_out[i]->set_dims(param_dim[i]);
velocity_out[i]->set_dims(param_dim[i]);
if (master_param != nullptr) {
master_param_out[i]->set_dims(param_dim[i]);
}
}
}
void LLMInt8LinearInferMeta(const MetaTensor& x,
const MetaTensor& weight,
const MetaTensor& bias,
......
......@@ -459,6 +459,22 @@ void LambInferMeta(const MetaTensor& param,
MetaTensor* beta2_pow_out,
MetaTensor* master_param_outs);
void LarsMomentumInferMeta(
const std::vector<const MetaTensor*>& param,
const std::vector<const MetaTensor*>& velocity,
const std::vector<const MetaTensor*>& learning_rate,
const std::vector<const MetaTensor*>& grad,
const paddle::optional<std::vector<const MetaTensor*>>& master_param,
const std::vector<float>& lars_weight_decay,
float mu,
float lars_coeff,
float epsilon,
bool multi_precision,
float rescale_grad,
std::vector<MetaTensor*> param_out,
std::vector<MetaTensor*> velocity_out,
std::vector<MetaTensor*> master_param_out);
void LLMInt8LinearInferMeta(const MetaTensor& x,
const MetaTensor& weight,
const MetaTensor& bias,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册