未验证 提交 d014e29f 编写于 作者: C Chengmo 提交者: GitHub

fix error message (#27318)

* Fix the error messages of the sgd, momentum, dpsgd, and rmsprop operators
上级 35074963
......@@ -24,32 +24,45 @@ class DpsgdOp : public framework::OperatorWithKernel {
// Validates the inputs and outputs of DpsgdOp and sets the shape of ParamOut.
//
// Checks, in order:
//   - inputs Param, Grad and LearningRate are present;
//   - the first Param and Grad variables have type LOD_TENSOR;
//   - output ParamOut is present;
//   - the product of LearningRate's dims equals 1;
//   - Param and Grad have identical dims.
// When every check passes, ParamOut is given the dims of Param.
//
// NOTE(review): this listing looks like a diff hunk whose +/- markers were
// stripped: each check appears first with its old bare-string message and
// then with the platform::errors::* replacement. Confirm against the real
// file before treating the text below as compilable.
void InferShape(framework::InferShapeContext *ctx) const override {
// Presence checks for the required inputs.
PADDLE_ENFORCE_EQ(ctx->HasInput("Param"), true,
"Input(Param) of DpsgdOp should not be null.");
platform::errors::NotFound(
"Input(Param) of DpsgdOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Grad"), true,
"Input(Grad) of DpsgdOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("LearningRate"), true,
"Input(LearningRate) of DpsgdOp should not be null.");
platform::errors::NotFound(
"Input(Grad) of DpsgdOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("LearningRate"), true,
platform::errors::NotFound(
"Input(LearningRate) of DpsgdOp should not be null."));
// Variable-type checks: Param and Grad must both be LOD_TENSOR.
// NOTE(review): the old form passes two arguments (variable name and type)
// for a single %s; the new form passes only the type.
PADDLE_ENFORCE_EQ(
ctx->GetInputsVarType("Param").front(),
framework::proto::VarType::LOD_TENSOR,
"The input var's type should be LoDTensor, but the received is %s",
ctx->Inputs("Param").front(), ctx->GetInputsVarType("Param").front());
platform::errors::InvalidArgument(
"The input var's type should be LoDTensor, but the received is %s",
ctx->GetInputsVarType("Param").front()));
PADDLE_ENFORCE_EQ(
ctx->GetInputsVarType("Grad").front(),
framework::proto::VarType::LOD_TENSOR,
"The input var's type should be LoDTensor, but the received is %s",
ctx->Inputs("Grad").front(), ctx->GetInputsVarType("Grad").front());
platform::errors::InvalidArgument(
"The input var's type should be LoDTensor, but the received is %s",
ctx->GetInputsVarType("Grad").front()));
// Presence check for the single output.
PADDLE_ENFORCE_EQ(ctx->HasOutput("ParamOut"), true,
"Output(ParamOut) of DpsgdOp should not be null.");
platform::errors::NotFound(
"Output(ParamOut) of DpsgdOp should not be null."));
// The check is on the product of the dims (element count), and the value
// formatted into the message is that product, not the dims themselves.
auto lr_dims = ctx->GetInputDim("LearningRate");
PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
"Learning rate should have 1 dimension");
platform::errors::InvalidArgument(
"Learning rate should have 1 dimension. But Received "
"LearningRate's dims [%s].",
framework::product(lr_dims)));
// Param and Grad must agree in shape.
// NOTE(review): "Para's" in the new message looks like a typo for "Param's".
auto param_dims = ctx->GetInputDim("Param");
PADDLE_ENFORCE_EQ(
param_dims, ctx->GetInputDim("Grad"),
"Param and Grad input of DpsgdOp should have same dimension");
platform::errors::InvalidArgument(
"Param and Grad input of DpsgdOp should have same dimension. But "
"received Para's dim [%s] and Grad's dim [%s].",
param_dims, ctx->GetInputDim("Grad")));
// The updated parameter has the same shape as the input parameter.
ctx->SetOutputDim("ParamOut", param_dims);
}
......
......@@ -28,17 +28,19 @@ class DpsgdOpKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext &ctx) const override {
const auto *param_var = ctx.InputVar("Param");
PADDLE_ENFORCE_EQ(param_var->IsType<framework::LoDTensor>(), true,
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.InputNames("Param").front(),
framework::ToTypeName(param_var->Type()));
platform::errors::InvalidArgument(
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.InputNames("Param").front(),
framework::ToTypeName(param_var->Type())));
const auto *grad_var = ctx.InputVar("Grad");
PADDLE_ENFORCE_EQ(grad_var->IsType<framework::LoDTensor>(), true,
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.InputNames("Grad").front(),
framework::ToTypeName(grad_var->Type()));
platform::errors::InvalidArgument(
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.InputNames("Grad").front(),
framework::ToTypeName(grad_var->Type())));
const auto *learning_rate = ctx.Input<framework::Tensor>("LearningRate");
......
......@@ -40,43 +40,62 @@ class MomentumOp : public framework::OperatorWithKernel {
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Param"),
"Input(param) of Momentum should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Grad"),
"Input(grad) of Momentum should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Velocity"),
"Input(velocity) of Momentum should not be null.");
PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
"Input(LearningRate) of Momentum should not be null.");
PADDLE_ENFORCE(
ctx->GetInputsVarType("Param").front() ==
framework::proto::VarType::LOD_TENSOR,
"The input var's type should be LoDTensor, but the received is %s",
ctx->Inputs("Param").front(), ctx->GetInputsVarType("Param").front());
PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
"Output(ParamOut) of Momentum should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("VelocityOut"),
"Output(VelocityOut) of Momentum should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("Param"), true,
platform::errors::NotFound(
"Input(param) of Momentum should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Grad"), true,
platform::errors::NotFound(
"Input(grad) of Momentum should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Velocity"), true,
platform::errors::NotFound(
"Input(velocity) of Momentum should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("LearningRate"), true,
platform::errors::NotFound(
"Input(LearningRate) of Momentum should not be null."));
PADDLE_ENFORCE_EQ(
ctx->GetInputsVarType("Param").front(),
framework::proto::VarType::LOD_TENSOR,
platform::errors::InvalidArgument(
"The input var's type should be LoDTensor, but the received is %s",
ctx->GetInputsVarType("Param").front()));
PADDLE_ENFORCE_EQ(ctx->HasOutput("ParamOut"), true,
platform::errors::NotFound(
"Output(ParamOut) of Momentum should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("VelocityOut"), true,
platform::errors::NotFound(
"Output(VelocityOut) of Momentum should not be null."));
auto lr_dims = ctx->GetInputDim("LearningRate");
PADDLE_ENFORCE_NE(framework::product(lr_dims), 0,
"Maybe the Input variable LearningRate has not "
"been initialized. You may need to confirm "
"if you put exe.run(startup_program) "
"after optimizer.minimize function.");
platform::errors::InvalidArgument(
"Maybe the Input variable LearningRate has not "
"been initialized. You may need to confirm "
"if you put exe.run(startup_program) "
"after optimizer.minimize function."));
PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
"Learning_rate should be a scalar");
platform::errors::InvalidArgument(
"Learning_rate should be a scalar. But Received "
"LearningRate's dim [%s]",
framework::product(lr_dims)));
auto param_dim = ctx->GetInputDim("Param");
if (ctx->GetInputsVarType("Grad")[0] ==
framework::proto::VarType::LOD_TENSOR) {
PADDLE_ENFORCE_EQ(
param_dim, ctx->GetInputDim("Grad"),
"Param and Grad input of MomentumOp should have the same dimension.");
platform::errors::InvalidArgument(
"Param and Grad input of MomentumOp should have the same "
"dimension. But received Param's dim [%s] and Grad's dim [%s].",
param_dim, ctx->GetInputDim("Grad")));
PADDLE_ENFORCE_EQ(
param_dim, ctx->GetInputDim("Velocity"),
"Param and Velocity of MomentumOp should have the same dimension.");
platform::errors::InvalidArgument(
"Param and Velocity of MomentumOp should have the same "
"dimension. But received Param's dim [%s] and Velocity [%s].",
param_dim, ctx->GetInputDim("Velocity")));
}
ctx->SetOutputDim("ParamOut", param_dim);
......@@ -398,10 +417,12 @@ class MomentumOpKernel : public framework::OpKernel<T> {
for_range(functor);
}
} else {
PADDLE_THROW(
string::Sprintf("MomentumOp only supports LoDTensor or SelectedRows "
"gradient, but the received Variable Type is %s",
framework::ToTypeName(grad_var->Type())));
PADDLE_ENFORCE_EQ(false, true,
platform::errors::PermissionDenied(
"Unsupported Variable Type of Grad "
"in MomentumOp. Excepted LodTensor "
"or SelectedRows, But received [%s]",
paddle::framework::ToTypeName(grad_var->Type())));
}
}
};
......
......@@ -22,47 +22,75 @@ class RmspropOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Param"),
"Input(Param) of RmspropOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("MeanSquare"),
"Input(MeanSquare) of RmspropOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
"Input(LearningRate) of RmspropOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Grad"),
"Input(Grad) of RmspropOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Moment"),
"Input(Moment) of RmspropOp should not be null.");
PADDLE_ENFORCE(
ctx->GetInputsVarType("Param").front() ==
framework::proto::VarType::LOD_TENSOR,
"The input var's type should be LoDTensor, but the received is %s",
ctx->Inputs("Param").front(), ctx->GetInputsVarType("Param").front());
PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
"Output(param_out) of RmspropOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("MomentOut"),
"Output(MomentOut) of RmspropOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("MeanSquareOut"),
"Output(MeanSquareOut) of RmspropOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("Param"), true,
platform::errors::NotFound(
"Input(Param) of RmspropOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("MeanSquare"), true,
platform::errors::NotFound(
"Input(MeanSquare) of RmspropOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("LearningRate"), true,
platform::errors::NotFound(
"Input(LearningRate) of RmspropOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Grad"), true,
platform::errors::NotFound(
"Input(Grad) of RmspropOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Moment"), true,
platform::errors::NotFound(
"Input(Moment) of RmspropOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->GetInputsVarType("Param").front(),
framework::proto::VarType::LOD_TENSOR,
platform::errors::InvalidArgument(
"The input var's type in RmspropOp should be "
"LoDTensor, but the received is %s",
ctx->GetInputsVarType("Param").front()));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("ParamOut"), true,
platform::errors::NotFound(
"Output(param_out) of RmspropOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("MomentOut"), true,
platform::errors::NotFound(
"Output(MomentOut) of RmspropOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("MeanSquareOut"), true,
platform::errors::NotFound(
"Output(MeanSquareOut) of RmspropOp should not be null."));
if (ctx->Attrs().Get<bool>("centered")) {
PADDLE_ENFORCE(ctx->HasOutput("MeanGradOut"),
"Output(MeanGradOut) of RmspropOp should not be null.");
PADDLE_ENFORCE_EQ(
ctx->HasOutput("MeanGradOut"), true,
platform::errors::NotFound(
"Output(MeanGradOut) of RmspropOp should not be null."));
}
auto param_dim = ctx->GetInputDim("Param");
PADDLE_ENFORCE_EQ(
param_dim, ctx->GetInputDim("Grad"),
"Param and grad input of RmspropOp should have the same dimension.");
platform::errors::InvalidArgument(
"Param and grad input of RmspropOp should have the same dimension. "
"But received Param's dim [%s] and Grad's dim [%s].",
param_dim, ctx->GetInputDim("Grad")));
PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("Moment"),
"Param and Momentum input of RmspropOp "
"should have the same dimension.");
platform::errors::InvalidArgument(
"Param and Momentum input of RmspropOp "
"should have the same dimension. But received "
"Param's dim [%s] and Moment [%s]",
param_dim, ctx->GetInputDim("Moment")));
PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("MeanSquare"),
"Param and Momentum input of RmspropOp "
"should have the same dimension.");
platform::errors::InvalidArgument(
"Param and Momentum input of RmspropOp "
"should have the same dimension. But received "
"Param's dim [%s] and MeanSquare [%s]",
param_dim, ctx->GetInputDim("MeanSquare")));
auto lr_dim = ctx->GetInputDim("LearningRate");
PADDLE_ENFORCE_EQ(framework::product(lr_dim), 1,
"Learning Rate should be a scalar.");
platform::errors::InvalidArgument(
"Learning Rate of RmspropOp should be a scalar. But "
"received LearningRate's dim [%s]",
framework::product(lr_dim)));
ctx->SetOutputDim("ParamOut", param_dim);
ctx->SetOutputDim("MomentOut", param_dim);
......
......@@ -148,11 +148,15 @@ class RmspropOpKernel : public framework::OpKernel<T> {
auto &mom_tensor = *ctx.Input<LoDTensor>("Moment");
PADDLE_ENFORCE_EQ(&p_tensor, param_out,
"Param and ParamOut must be the same Tensor");
platform::errors::InvalidArgument(
"Param and ParamOut must be the same Tensor"));
PADDLE_ENFORCE_EQ(&mom_tensor, moment_out,
"Moment and MomentOut must be the same Tensor");
PADDLE_ENFORCE_EQ(&ms_tensor, mean_square_out,
"MeanSquare and MeanSquareOut must be the same Tensor");
platform::errors::InvalidArgument(
"Moment and MomentOut must be the same Tensor"));
PADDLE_ENFORCE_EQ(
&ms_tensor, mean_square_out,
platform::errors::InvalidArgument(
"MeanSquare and MeanSquareOut must be the same Tensor"));
auto &dev_ctx = ctx.template device_context<DeviceContext>();
size_t limit = static_cast<size_t>(ms_tensor.numel());
......@@ -179,8 +183,10 @@ class RmspropOpKernel : public framework::OpKernel<T> {
auto &mg_tensor = *ctx.Input<LoDTensor>("MeanGrad");
auto mg = EigenVector<T>::Flatten(mg_tensor);
auto *mean_grad_out = ctx.Output<LoDTensor>("MeanGradOut");
PADDLE_ENFORCE_EQ(&mg_tensor, mean_grad_out,
"MeanGrad and MeanGradOut must be the same Tensor");
PADDLE_ENFORCE_EQ(
&mg_tensor, mean_grad_out,
platform::errors::InvalidArgument(
"MeanGrad and MeanGradOut must be the same Tensor"));
auto mg_out = EigenVector<T>::Flatten(*mean_grad_out);
mg_out.device(place) = rho * mg + (1 - rho) * g;
......@@ -198,8 +204,10 @@ class RmspropOpKernel : public framework::OpKernel<T> {
if (centered) {
auto &mg_tensor = *ctx.Input<LoDTensor>("MeanGrad");
auto *mean_grad_out = ctx.Output<LoDTensor>("MeanGradOut");
PADDLE_ENFORCE_EQ(&mg_tensor, mean_grad_out,
"MeanGrad and MeanGradOut must be the same Tensor");
PADDLE_ENFORCE_EQ(
&mg_tensor, mean_grad_out,
platform::errors::InvalidArgument(
"MeanGrad and MeanGradOut must be the same Tensor"));
for_range(CenteredRmspropFunctor<T, DenseRmspropGradFunctor<T>>(
param_out->mutable_data<T>(ctx.GetPlace()),
mean_square_out->mutable_data<T>(ctx.GetPlace()),
......@@ -233,8 +241,10 @@ class RmspropOpKernel : public framework::OpKernel<T> {
if (centered) {
auto &mg_tensor = *ctx.Input<LoDTensor>("MeanGrad");
auto *mean_grad_out = ctx.Output<LoDTensor>("MeanGradOut");
PADDLE_ENFORCE_EQ(&mg_tensor, mean_grad_out,
"MeanGrad and MeanGradOut must be the same Tensor");
PADDLE_ENFORCE_EQ(
&mg_tensor, mean_grad_out,
platform::errors::InvalidArgument(
"MeanGrad and MeanGradOut must be the same Tensor"));
for_range(CenteredRmspropFunctor<T, SparseRmspropGradFunctor<T>>(
param_out->mutable_data<T>(ctx.GetPlace()),
mean_square_out->mutable_data<T>(ctx.GetPlace()),
......@@ -249,7 +259,12 @@ class RmspropOpKernel : public framework::OpKernel<T> {
rho, epsilon, momentum, grad_func));
}
} else {
PADDLE_THROW("RMSProp only supports LoDTensor or SelectedRows gradient");
PADDLE_ENFORCE_EQ(false, true,
platform::errors::PermissionDenied(
"Unsupported Variable Type of Grad "
"in RmspropOp. Excepted LodTensor "
"or SelectedRows, But received [%s]",
paddle::framework::ToTypeName(grad_var->Type())));
}
}
};
......
......@@ -22,23 +22,31 @@ class SGDOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Param"),
"Input(Param) of SGDOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Grad"),
"Input(Grad) of SGDOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
"Input(LearningRate) of SGDOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
"Output(ParamOut) of SGDOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("Param"), true,
platform::errors::NotFound(
"Input(Param) of SGDOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("Grad"), true,
platform::errors::NotFound("Input(Grad) of SGDOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("LearningRate"), true,
platform::errors::NotFound(
"Input(LearningRate) of SGDOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("ParamOut"), true,
platform::errors::NotFound(
"Output(ParamOut) of SGDOp should not be null."));
auto lr_dims = ctx->GetInputDim("LearningRate");
PADDLE_ENFORCE_NE(framework::product(lr_dims), 0,
"Maybe the Input variable LearningRate has not "
"been initialized. You may need to confirm "
"if you put exe.run(startup_program) "
"after optimizer.minimize function.");
platform::errors::NotFound(
"Maybe the Input variable LearningRate has not "
"been initialized. You may need to confirm "
"if you put exe.run(startup_program) "
"after optimizer.minimize function."));
PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
"Learning rate should have 1 element");
platform::errors::InvalidArgument(
"Learning rate should have 1 element. But received "
"LearningRate dims [%s]",
framework::product(lr_dims)));
auto param_dim = ctx->GetInputDim("Param");
if (ctx->GetInputsVarType("Grad")[0] ==
framework::proto::VarType::LOD_TENSOR) {
......
......@@ -57,11 +57,12 @@ class SGDOpKernel<platform::CUDADeviceContext, T>
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto* param_var = ctx.InputVar("Param");
PADDLE_ENFORCE(param_var->IsType<framework::LoDTensor>(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.InputNames("Param").front(),
framework::ToTypeName(param_var->Type()));
PADDLE_ENFORCE_EQ(param_var->IsType<framework::LoDTensor>(), true,
platform::errors::InvalidArgument(
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.InputNames("Param").front(),
paddle::framework::ToTypeName(param_var->Type())));
auto* param = ctx.Input<framework::Tensor>("Param");
auto* param_out = ctx.Output<framework::Tensor>("ParamOut");
......@@ -91,18 +92,30 @@ class SGDOpKernel<platform::CUDADeviceContext, T>
// TODO(qijun): In Sparse SGD operator, in-place update is enforced.
// This manual optimization brings difficulty to track data dependency.
// It's better to find a more elegant solution.
PADDLE_ENFORCE_EQ(param, param_out);
PADDLE_ENFORCE_EQ(
param, param_out,
platform::errors::InvalidArgument(
"The input tensor Param of SgdOp should be equal with ParamOut "
"if variable's type is SelectedRows."));
auto* grad = ctx.Input<framework::SelectedRows>("Grad");
auto in_height = grad->height();
auto out_dims = param_out->dims();
PADDLE_ENFORCE_EQ(in_height, out_dims[0]);
PADDLE_ENFORCE_EQ(in_height, out_dims[0],
platform::errors::InvalidArgument(
"The input tensor Grad's height of SgdOp should be "
"equal with ParamOut's dims. But received Grad's "
"height [%s] and ParamOut's dims [%s]",
in_height, out_dims[0]));
auto& in_value = grad->value();
auto& in_rows = grad->rows();
int64_t in_row_numel = in_value.numel() / in_rows.size();
PADDLE_ENFORCE_EQ(in_row_numel, param_out->numel() / in_height);
PADDLE_ENFORCE_EQ(in_row_numel, param_out->numel() / in_height,
platform::errors::InvalidArgument(
"The in_row_numel of SgdOp should be equal with "
"param_out's numel / in_height."));
auto* in_data = in_value.data<T>();
auto* out_data = param_out->data<T>();
......@@ -118,7 +131,12 @@ class SGDOpKernel<platform::CUDADeviceContext, T>
out_data, in_row_numel, in_rows.size());
} else {
PADDLE_THROW("Unsupported Variable Type of Grad");
PADDLE_ENFORCE_EQ(false, true,
platform::errors::PermissionDenied(
"Unsupported Variable Type of Grad "
"in SgdOp. Excepted LodTensor or "
"SelectedRows, But received [%s]",
paddle::framework::ToTypeName(grad_var->Type())));
}
}
};
......
......@@ -44,8 +44,20 @@ class SGDOpKernel<platform::CPUDeviceContext, T>
if (grad_var->IsType<framework::LoDTensor>()) {
const auto *grad = ctx.Input<framework::Tensor>("Grad");
auto sz = param_out->numel();
PADDLE_ENFORCE_EQ(param->numel(), sz);
PADDLE_ENFORCE_EQ(grad->numel(), sz);
PADDLE_ENFORCE_EQ(param->numel(), sz,
platform::errors::InvalidArgument(
"The input tensor Param's numel of SgdOp "
"should be equal with ParamOut's numel. "
"But received Param's "
"numel = [%s], ParamOut's numel = [%s]",
param->numel(), sz));
PADDLE_ENFORCE_EQ(grad->numel(), sz,
platform::errors::InvalidArgument(
"The input tensor Grad's numel of SgdOp "
"should be equal with ParamOut's numel. "
"But received Grad's "
"numel = [%s], ParamOut's numel = [%s]",
grad->numel(), sz));
jit::sgd_attr_t attr(1, sz, 1, sz, 1);
const T *lr = learning_rate->data<T>();
......@@ -62,7 +74,11 @@ class SGDOpKernel<platform::CPUDeviceContext, T>
// TODO(qijun): In Sparse SGD operator, in-place update is enforced.
// This manual optimization brings difficulty to track data dependency.
// It's better to find a more elegant solution.
PADDLE_ENFORCE_EQ(param, param_out);
PADDLE_ENFORCE_EQ(param, param_out,
platform::errors::InvalidArgument(
"The input tensor Param of SgdOp "
"should be equal with ParamOut if variable's "
"type is SelectedRows. "));
const auto *grad = ctx.Input<framework::SelectedRows>("Grad");
auto &grad_rows = grad->rows();
......@@ -73,7 +89,13 @@ class SGDOpKernel<platform::CPUDeviceContext, T>
}
auto out_dims = param_out->dims();
PADDLE_ENFORCE_EQ(grad->height(), out_dims[0]);
PADDLE_ENFORCE_EQ(
grad->height(), out_dims[0],
platform::errors::InvalidArgument(
"The input tensor Grad's height of SgdOp "
"should be equal with ParamOut's dims. But received Grad's "
"height [%s] and ParamOut's dims [%s]",
grad->height(), out_dims[0]));
auto &grad_value = grad->value();
const T *param_data = param->data<T>();
const T *grad_data = grad_value.data<T>();
......@@ -87,19 +109,31 @@ class SGDOpKernel<platform::CPUDeviceContext, T>
attr.grad_height = grad_rows.size(); // note: it is not grad->height()
attr.grad_width = grad_value.numel() / attr.grad_height;
attr.selected_rows_size = grad_rows.size();
PADDLE_ENFORCE_EQ(attr.grad_width, attr.param_width);
PADDLE_ENFORCE_EQ(
attr.grad_width, attr.param_width,
platform::errors::InvalidArgument(
"The grad_value's numel of SgdOp "
"should be equal with param_out's numel. But received "
"grad_value's numel [%s] and param_out's numel [%s]",
attr.grad_width, attr.param_width));
auto sgd =
jit::KernelFuncs<jit::SgdTuple<T>, platform::CPUPlace>::Cache().At(
attr);
sgd(lr, param_data, grad_data, rows_data, out_data, &attr);
} else {
PADDLE_THROW("Unsupported Variable Type of Grad");
PADDLE_ENFORCE_EQ(
false, true,
platform::errors::PermissionDenied(
"Unsupported Variable Type of Grad in SgdOp. Excepted "
"LodTensor or SelectedRows, But received [%s]",
paddle::framework::ToTypeName(grad_var->Type())));
}
} else if (param_var->IsType<framework::SelectedRows>()) {
PADDLE_ENFORCE(grad_var->IsType<framework::SelectedRows>(),
"when param "
"is SelectedRows, gradient should also be SelectedRows");
PADDLE_ENFORCE_EQ(grad_var->IsType<framework::SelectedRows>(), true,
platform::errors::InvalidArgument(
"when param is SelectedRows, "
"gradient should also be SelectedRows"));
const auto &param = param_var->Get<framework::SelectedRows>();
auto *param_out = ctx.Output<framework::SelectedRows>("ParamOut");
const auto &grad = grad_var->Get<framework::SelectedRows>();
......@@ -112,27 +146,36 @@ class SGDOpKernel<platform::CPUDeviceContext, T>
auto param_row_width = param.value().dims()[1];
auto grad_row_width = grad.value().dims()[1];
VLOG(4) << " param rows: " << param.rows().size()
<< " param memory rows: " << param.value().dims()[0]
<< " grad rows: " << grad.rows().size()
<< " grad memory rows: " << grad.value().dims()[0];
PADDLE_ENFORCE_EQ(param_row_width, grad_row_width,
"param_row should have the same size with grad_row");
PADDLE_ENFORCE_EQ(
param_row_width, grad_row_width,
platform::errors::InvalidArgument(
"The param_row in SgdOP should have the same size with grad_row. "
"But received param_row's width is [%s], and grad_row's width is "
"[%s]",
param_row_width, grad_row_width));
const auto *lr = learning_rate->data<T>();
const auto *grad_data = grad.value().data<T>();
auto *out_data = param_out->mutable_value()->data<T>();
for (size_t i = 0; i < grad.rows().size(); i++) {
int64_t id_index = param_out->AutoGrownIndex(grad.rows()[i], false);
PADDLE_ENFORCE_GE(id_index, static_cast<int64_t>(0),
"id should be in the table");
PADDLE_ENFORCE_GE(
id_index, static_cast<int64_t>(0),
platform::errors::InvalidArgument(
"The id in SgdOp should be >= 0. But recevied id_index is [%s]",
id_index));
for (int64_t j = 0; j < grad_row_width; j++) {
out_data[id_index * grad_row_width + j] -=
lr[0] * grad_data[i * grad_row_width + j];
}
}
} else {
PADDLE_THROW("Unsupported Variable Type of Parameter");
PADDLE_ENFORCE_EQ(
false, true,
platform::errors::PermissionDenied(
"Unsupported Variable Type of Parameter in SgdOp. Excepted "
"LodTensor or SelectedRows, But received [%s]",
paddle::framework::ToTypeName(param_var->Type())));
}
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册