未验证 提交 6885d156 编写于 作者: G gongweibao 提交者: GitHub

fix errorenhancement test=release/1.8 (#24471)

上级 fc43909e
...@@ -57,7 +57,13 @@ class PrefetchOp : public framework::OperatorBase { ...@@ -57,7 +57,13 @@ class PrefetchOp : public framework::OperatorBase {
} }
} }
for (size_t i = 0; i < rets.size(); i++) { for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient"); PADDLE_ENFORCE_EQ(
rets[i]->Wait(), true,
platform::errors::Fatal(
"It's a fatal error of RPCClient that RPCClient can't "
"get the wait result. It may happen when trainers or "
"parameter servers exit un normally or the network "
"issue!"));
} }
} }
}; };
......
...@@ -24,34 +24,42 @@ class ProximalAdagradOp : public framework::OperatorWithKernel { ...@@ -24,34 +24,42 @@ class ProximalAdagradOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Param"), OP_INOUT_CHECK(ctx->HasInput("Param"), "Input", "Param",
"Input(Param) of ProximalAdagradOp should not be null."); "ProximalAdagradOp");
PADDLE_ENFORCE(ctx->HasInput("Moment"), OP_INOUT_CHECK(ctx->HasInput("Moment"), "Input", "Moment",
"Input(Moment) of ProximalAdagradOp should not be null."); "ProximalAdagradOp");
PADDLE_ENFORCE(ctx->HasInput("Grad"), OP_INOUT_CHECK(ctx->HasInput("Grad"), "Input", "Grad", "ProximalAdagradOp");
"Input(Grad) of ProximalAdagradOp should not be null."); OP_INOUT_CHECK(ctx->HasInput("LearningRate"), "Input", "LearningRate",
PADDLE_ENFORCE( "ProximalAdagradOp");
ctx->HasInput("LearningRate"),
"Input(LearningRate) of ProximalAdagradOp should not be null."); OP_INOUT_CHECK(ctx->HasOutput("ParamOut"), "Output", "ParamOut",
"ProximalAdagradOp");
PADDLE_ENFORCE(ctx->HasOutput("ParamOut"), OP_INOUT_CHECK(ctx->HasOutput("MomentOut"), "Output", "MomentOut",
"Output(ParamOut) of ProximalAdagradOp should not be null."); "ProximalAdagradOp");
PADDLE_ENFORCE(
ctx->HasOutput("MomentOut"),
"Output(MomentOut) of ProximalAdagradOp should not be null.");
auto param_dim = ctx->GetInputDim("Param"); auto param_dim = ctx->GetInputDim("Param");
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("Grad"),
param_dim, ctx->GetInputDim("Grad"), platform::errors::InvalidArgument(
"Param and Grad of ProximalAdagrad Op must have same dimension."); "The shape of Intput(Param) should be equal to the "
"Input(Grad) of ProximalAdagrad Op. But received "
PADDLE_ENFORCE_EQ( "Input(Param).dimensions=[%s], "
param_dim, ctx->GetInputDim("Moment"), "Input(Grad).dimensions=[%s]",
"Param and Moment of ProximalAdagrad Op must have same dimension."); param_dim, ctx->GetInputDim("Grad")));
PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("Moment"),
platform::errors::InvalidArgument(
"The shape of Intput(Param) should be equal to the "
"Input(Moment) of ProximalAdagrad Op. But received "
"Input(Param).dimensions=[%s], "
"Input(Moment).dimensions=[%s]",
param_dim, ctx->GetInputDim("Moment")));
auto lr_dim = ctx->GetInputDim("LearningRate"); auto lr_dim = ctx->GetInputDim("LearningRate");
PADDLE_ENFORCE_EQ(framework::product(lr_dim), 1, PADDLE_ENFORCE_EQ(
"Learning Rate should be a scalar."); framework::product(lr_dim), 1,
platform::errors::InvalidArgument(
"Learning Rate should be a scalar. But received dimension[%s]",
lr_dim));
ctx->SetOutputDim("ParamOut", param_dim); ctx->SetOutputDim("ParamOut", param_dim);
ctx->SetOutputDim("MomentOut", param_dim); ctx->SetOutputDim("MomentOut", param_dim);
......
...@@ -24,23 +24,29 @@ class ProximalGDOp : public framework::OperatorWithKernel { ...@@ -24,23 +24,29 @@ class ProximalGDOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Param"), OP_INOUT_CHECK(ctx->HasInput("Param"), "Input", "Param", "ProximalGDOp");
"Input(Param) of ProximalGDOp should not be null."); OP_INOUT_CHECK(ctx->HasInput("Grad"), "Input", "Grad", "ProximalGDOp");
PADDLE_ENFORCE(ctx->HasInput("Grad"), OP_INOUT_CHECK(ctx->HasInput("LearningRate"), "Input", "LearningRate",
"Input(Grad) of ProximalGDOp should not be null."); "ProximalGDOp");
PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
"Input(LearningRate) of ProximalGDOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("ParamOut"), OP_INOUT_CHECK(ctx->HasOutput("ParamOut"), "Output", "Paramout",
"Output(ParamOut) of ProximalGDOp should not be null."); "ProximalGDOp");
auto param_dim = ctx->GetInputDim("Param"); auto param_dim = ctx->GetInputDim("Param");
PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("Grad"), PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("Grad"),
"Two input of ProximalGD Op's dimension must be same."); platform::errors::InvalidArgument(
"The shape of Intput(Param) should be equal to the "
"Input(Grad) of ProximalGD Op. But received "
"Input(Param).dimensions=[%s], "
"Input(Grad).dimensions=[%s]",
param_dim, ctx->GetInputDim("Grad")));
auto lr_dim = ctx->GetInputDim("LearningRate"); auto lr_dim = ctx->GetInputDim("LearningRate");
PADDLE_ENFORCE_EQ(framework::product(lr_dim), 1, PADDLE_ENFORCE_EQ(
"Learning Rate should be a scalar."); framework::product(lr_dim), 1,
platform::errors::InvalidArgument(
"Learning Rate should be a scalar. But received dimmensions:[%s]",
lr_dim));
ctx->SetOutputDim("ParamOut", param_dim); ctx->SetOutputDim("ParamOut", param_dim);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册