未验证 提交 a97d5a61 编写于 作者: T tangwei12 提交者: GitHub

fix op error, test=develop (#24451)

上级 7c17ed57
...@@ -84,7 +84,9 @@ class RecvOp : public framework::OperatorBase { ...@@ -84,7 +84,9 @@ class RecvOp : public framework::OperatorBase {
} }
for (size_t i = 0; i < rets.size(); i++) { for (size_t i = 0; i < rets.size(); i++) {
VLOG(7) << "before sync_recv " << outs[i] << "from " << epmap[i]; VLOG(7) << "before sync_recv " << outs[i] << "from " << epmap[i];
PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, "internal error in RPCClient"); PADDLE_ENFORCE_NE(
rets[i]->Wait(), 0U,
platform::errors::ExecutionTimeout("internal error in RPCClient"));
VLOG(7) << "after sync_recv " << outs[i] << "from " << epmap[i]; VLOG(7) << "after sync_recv " << outs[i] << "from " << epmap[i];
} }
} }
......
...@@ -27,14 +27,23 @@ class RefByTrainerIdOp : public framework::OperatorWithKernel { ...@@ -27,14 +27,23 @@ class RefByTrainerIdOp : public framework::OperatorWithKernel {
: OperatorWithKernel(type, inputs, outputs, attrs) {} : OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInputs("X"), PADDLE_ENFORCE_EQ(ctx->HasInputs("X"), true,
"Input(X) of RefByTrainerIdOp should not be null."); platform::errors::InvalidArgument(
PADDLE_ENFORCE(ctx->HasInput("TrainerId"), "Input(X) of RefByTrainerIdOp should not be null."));
"Input(TrainerId) of RefByTrainerIdOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE_EQ(
"Output(Out) of RefByTrainerIdOp should not be null."); ctx->HasInput("TrainerId"), true,
PADDLE_ENFORCE_EQ(ctx->GetInputDim("TrainerId").size(), 1, platform::errors::InvalidArgument(
"TrainerId should be a scalar."); "Input(TrainerId) of RefByTrainerIdOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("Out"), true,
platform::errors::InvalidArgument(
"Output(Out) of RefByTrainerIdOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->GetInputDim("TrainerId").size(), 1,
platform::errors::InvalidArgument("TrainerId should be a scalar."));
// Out's shape is determined at runtime. // Out's shape is determined at runtime.
} }
......
...@@ -38,7 +38,10 @@ class RefByTrainerIdKernel : public framework::OpKernel<T> { ...@@ -38,7 +38,10 @@ class RefByTrainerIdKernel : public framework::OpKernel<T> {
} else { } else {
trainer_id = *trainer_id_data; trainer_id = *trainer_id_data;
} }
PADDLE_ENFORCE_LT((size_t)trainer_id, in_list.size()); PADDLE_ENFORCE_LT((size_t)trainer_id, in_list.size(),
platform::errors::InvalidArgument(
"X' size must >= TrainerId: [%s], but received [%s]",
trainer_id, in_list.size()));
out->mutable_data<T>(context.GetPlace()); out->mutable_data<T>(context.GetPlace());
framework::TensorCopy(*(in_list[trainer_id]), in_list[trainer_id]->place(), framework::TensorCopy(*(in_list[trainer_id]), in_list[trainer_id]->place(),
out); out);
......
...@@ -59,7 +59,9 @@ class SendBarrierOp : public framework::OperatorBase { ...@@ -59,7 +59,9 @@ class SendBarrierOp : public framework::OperatorBase {
} }
for (size_t i = 0; i < rets.size(); i++) { for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, "internal error in RPCClient"); PADDLE_ENFORCE_NE(
rets[i]->Wait(), 0U,
platform::errors::ExecutionTimeout("internal error in RPCClient"));
} }
} }
}; };
......
...@@ -83,7 +83,9 @@ class SendOp : public framework::OperatorBase { ...@@ -83,7 +83,9 @@ class SendOp : public framework::OperatorBase {
} }
for (size_t i = 0; i < rets.size(); i++) { for (size_t i = 0; i < rets.size(); i++) {
VLOG(7) << "before sync_send " << ins[i] << "from " << epmap[i]; VLOG(7) << "before sync_send " << ins[i] << "from " << epmap[i];
PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, "internal error in RPCClient"); PADDLE_ENFORCE_NE(
rets[i]->Wait(), 0U,
platform::errors::ExecutionTimeout("internal error in RPCClient"));
VLOG(7) << "after sync_send " << ins[i] << "from " << epmap[i]; VLOG(7) << "after sync_send " << ins[i] << "from " << epmap[i];
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册