未验证 提交 f0c61017 编写于 作者: T tangwei12 提交者: GitHub

fix op error, test=develop (#24451) (#24539)

上级 27dee221
......@@ -84,7 +84,9 @@ class RecvOp : public framework::OperatorBase {
}
for (size_t i = 0; i < rets.size(); i++) {
VLOG(7) << "before sync_recv " << outs[i] << "from " << epmap[i];
PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, "internal error in RPCClient");
PADDLE_ENFORCE_NE(
rets[i]->Wait(), 0U,
platform::errors::ExecutionTimeout("internal error in RPCClient"));
VLOG(7) << "after sync_recv " << outs[i] << "from " << epmap[i];
}
}
......
......@@ -27,14 +27,23 @@ class RefByTrainerIdOp : public framework::OperatorWithKernel {
: OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInputs("X"),
"Input(X) of RefByTrainerIdOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("TrainerId"),
"Input(TrainerId) of RefByTrainerIdOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of RefByTrainerIdOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->GetInputDim("TrainerId").size(), 1,
"TrainerId should be a scalar.");
PADDLE_ENFORCE_EQ(ctx->HasInputs("X"), true,
platform::errors::InvalidArgument(
"Input(X) of RefByTrainerIdOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("TrainerId"), true,
platform::errors::InvalidArgument(
"Input(TrainerId) of RefByTrainerIdOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("Out"), true,
platform::errors::InvalidArgument(
"Output(Out) of RefByTrainerIdOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->GetInputDim("TrainerId").size(), 1,
platform::errors::InvalidArgument("TrainerId should be a scalar."));
// Out's shape is determined at runtime.
}
......
......@@ -38,7 +38,10 @@ class RefByTrainerIdKernel : public framework::OpKernel<T> {
} else {
trainer_id = *trainer_id_data;
}
PADDLE_ENFORCE_LT((size_t)trainer_id, in_list.size());
PADDLE_ENFORCE_LT((size_t)trainer_id, in_list.size(),
platform::errors::InvalidArgument(
"X' size must >= TrainerId: [%s], but received [%s]",
trainer_id, in_list.size()));
out->mutable_data<T>(context.GetPlace());
framework::TensorCopy(*(in_list[trainer_id]), in_list[trainer_id]->place(),
out);
......
......@@ -59,7 +59,9 @@ class SendBarrierOp : public framework::OperatorBase {
}
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, "internal error in RPCClient");
PADDLE_ENFORCE_NE(
rets[i]->Wait(), 0U,
platform::errors::ExecutionTimeout("internal error in RPCClient"));
}
}
};
......
......@@ -83,7 +83,9 @@ class SendOp : public framework::OperatorBase {
}
for (size_t i = 0; i < rets.size(); i++) {
VLOG(7) << "before sync_send " << ins[i] << "from " << epmap[i];
PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, "internal error in RPCClient");
PADDLE_ENFORCE_NE(
rets[i]->Wait(), 0U,
platform::errors::ExecutionTimeout("internal error in RPCClient"));
VLOG(7) << "after sync_send " << ins[i] << "from " << epmap[i];
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册