未验证 提交 20fb01fb 编写于 作者: M MRXLT 提交者: GitHub

fix distributed error info (#27206)

* fix distributed error info

* bug fix; notest

* error info refine

* update error info

* update error info

* update error info

* bug fix

* bug fix

* bug fix

* bug fix
上级 295e87e4
...@@ -23,10 +23,11 @@ class CAllGatherOp : public framework::OperatorWithKernel { ...@@ -23,10 +23,11 @@ class CAllGatherOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "AllGather");
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null."); OP_INOUT_CHECK(ctx->HasOutput("Out"), "Input", "Out", "AllGather");
int nranks = ctx->Attrs().Get<int>("nranks"); int nranks = ctx->Attrs().Get<int>("nranks");
PADDLE_ENFORCE_GE(nranks, 2, "nranks should be >=2"); PADDLE_ENFORCE_GE(nranks, 2, platform::errors::InvalidArgument(
"The value of nranks should be >=2."));
framework::DDim dim = ctx->GetInputDim("X"); framework::DDim dim = ctx->GetInputDim("X");
dim[0] = dim[0] * nranks; dim[0] = dim[0] * nranks;
if (dim[0] < 0) dim[0] = -1; if (dim[0] < 0) dim[0] = -1;
......
...@@ -37,7 +37,10 @@ class CAllGatherOpCUDAKernel : public framework::OpKernel<T> { ...@@ -37,7 +37,10 @@ class CAllGatherOpCUDAKernel : public framework::OpKernel<T> {
int rid = ctx.Attr<int>("ring_id"); int rid = ctx.Attr<int>("ring_id");
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
auto comm = platform::NCCLCommContext::Instance().Get(rid, place); auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
PADDLE_ENFORCE_EQ(nranks, comm->nranks()); PADDLE_ENFORCE_EQ(
nranks, comm->nranks(),
platform::errors::InvalidArgument("nranks: %s should equal to %s",
nranks, comm->nranks()));
framework::DDim out_dims = in->dims(); framework::DDim out_dims = in->dims();
out_dims[0] *= nranks; out_dims[0] *= nranks;
...@@ -59,7 +62,8 @@ class CAllGatherOpCUDAKernel : public framework::OpKernel<T> { ...@@ -59,7 +62,8 @@ class CAllGatherOpCUDAKernel : public framework::OpKernel<T> {
send_buff, recv_buff, send_numel, static_cast<ncclDataType_t>(dtype), send_buff, recv_buff, send_numel, static_cast<ncclDataType_t>(dtype),
comm->comm(), stream)); comm->comm(), stream));
#else #else
PADDLE_THROW("PaddlePaddle should compile with GPU."); PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with GPU."));
#endif #endif
} }
}; };
......
...@@ -150,13 +150,15 @@ class CAllReduceOpCUDAKernel : public framework::OpKernel<T> { ...@@ -150,13 +150,15 @@ class CAllReduceOpCUDAKernel : public framework::OpKernel<T> {
break; break;
default: default:
PADDLE_THROW("Invalid reduce type: %d", red_type); PADDLE_THROW(platform::errors::InvalidArgument(
"Invalid reduce type: %d", red_type));
} }
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllReduce( PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllReduce(
sendbuff, recvbuff, numel, dtype, nccl_red_type, comm->comm(), stream)); sendbuff, recvbuff, numel, dtype, nccl_red_type, comm->comm(), stream));
#else #else
PADDLE_THROW("PaddlePaddle should compile with GPU."); PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with GPU."));
#endif #endif
} }
}; };
......
...@@ -69,7 +69,8 @@ class CBroadcastOpCUDAKernel : public framework::OpKernel<T> { ...@@ -69,7 +69,8 @@ class CBroadcastOpCUDAKernel : public framework::OpKernel<T> {
out->Resize(x->dims()); out->Resize(x->dims());
out->set_lod(x->lod()); out->set_lod(x->lod());
#else #else
PADDLE_THROW("PaddlePaddle should compile with GPU."); PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with GPU."));
#endif #endif
} }
}; };
......
...@@ -50,7 +50,8 @@ class CCommInitAllOp : public framework::OperatorBase { ...@@ -50,7 +50,8 @@ class CCommInitAllOp : public framework::OperatorBase {
void RunImpl(const framework::Scope& scope, void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override { const platform::Place& place) const override {
PADDLE_ENFORCE_EQ(is_gpu_place(place), true, PADDLE_ENFORCE_EQ(is_gpu_place(place), true,
"CCommInitAllOp can run on gpu place only."); platform::errors::PreconditionNotMet(
"CCommInitAllOp can run on gpu place only"));
#if defined(PADDLE_WITH_NCCL) #if defined(PADDLE_WITH_NCCL)
std::vector<int> devices = Attr<std::vector<int>>("devices"); std::vector<int> devices = Attr<std::vector<int>>("devices");
...@@ -62,7 +63,8 @@ class CCommInitAllOp : public framework::OperatorBase { ...@@ -62,7 +63,8 @@ class CCommInitAllOp : public framework::OperatorBase {
platform::NCCLCommContext::Instance().CreateAllNCCLComms(devices, rid); platform::NCCLCommContext::Instance().CreateAllNCCLComms(devices, rid);
#else #else
PADDLE_THROW("PaddlePaddle should compile with GPU."); PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with GPU."));
#endif #endif
} }
}; };
......
...@@ -39,11 +39,13 @@ class CCommInitOp : public framework::OperatorBase { ...@@ -39,11 +39,13 @@ class CCommInitOp : public framework::OperatorBase {
void RunImpl(const framework::Scope& scope, void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override { const platform::Place& place) const override {
PADDLE_ENFORCE(is_gpu_place(place), PADDLE_ENFORCE_EQ(is_gpu_place(place), true,
"CCommInitOp can run on gpu place only."); platform::errors::PreconditionNotMet(
"CCommInitOp can run on gpu place only."));
auto var = scope.FindVar(Input("X")); auto var = scope.FindVar(Input("X"));
PADDLE_ENFORCE_NOT_NULL(var); PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::InvalidArgument("Input con not be empty."));
#if defined(PADDLE_WITH_NCCL) #if defined(PADDLE_WITH_NCCL)
ncclUniqueId* nccl_id = var->GetMutable<ncclUniqueId>(); ncclUniqueId* nccl_id = var->GetMutable<ncclUniqueId>();
...@@ -57,7 +59,8 @@ class CCommInitOp : public framework::OperatorBase { ...@@ -57,7 +59,8 @@ class CCommInitOp : public framework::OperatorBase {
platform::NCCLCommContext::Instance().CreateNCCLComm( platform::NCCLCommContext::Instance().CreateNCCLComm(
nccl_id, nranks, rank_id, device_id, rid); nccl_id, nranks, rank_id, device_id, rid);
#else #else
PADDLE_THROW("PaddlePaddle should compile with GPU."); PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with GPU."));
#endif #endif
} }
}; };
......
...@@ -61,9 +61,12 @@ class CGenNCCLIdOp : public framework::OperatorBase { ...@@ -61,9 +61,12 @@ class CGenNCCLIdOp : public framework::OperatorBase {
const platform::DeviceContext& dev_ctx) const { const platform::DeviceContext& dev_ctx) const {
std::string var_name = Output("Out"); std::string var_name = Output("Out");
auto var = scope->FindVar(var_name); auto var = scope->FindVar(var_name);
PADDLE_ENFORCE_NOT_NULL(var); PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::InvalidArgument("Output can not be Null"));
auto id = var->GetMutable<ncclUniqueId>(); auto id = var->GetMutable<ncclUniqueId>();
PADDLE_ENFORCE(platform::dynload::ncclGetUniqueId(id)); PADDLE_ENFORCE_EQ(platform::dynload::ncclGetUniqueId(id), 0,
platform::errors::InvalidArgument(
"ncclGetUniqueId failed with id %s", id));
std::vector<std::string> endpoint_list = std::vector<std::string> endpoint_list =
Attr<std::vector<std::string>>("other_endpoints"); Attr<std::vector<std::string>>("other_endpoints");
......
...@@ -24,14 +24,15 @@ class CReduceScatterOp : public framework::OperatorWithKernel { ...@@ -24,14 +24,15 @@ class CReduceScatterOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ReduceScatter");
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null."); OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "X", "ReduceScatter");
int nranks = ctx->Attrs().Get<int>("nranks"); int nranks = ctx->Attrs().Get<int>("nranks");
framework::DDim dim = ctx->GetInputDim("X"); framework::DDim dim = ctx->GetInputDim("X");
if (dim[0] > 0 || dim[0] < -1) { if (dim[0] > 0 || dim[0] < -1) {
PADDLE_ENFORCE(dim[0] % nranks == 0, PADDLE_ENFORCE_EQ(
"dim[0] (%d) is not divisible by nranks(%d)", dim[0], dim[0] % nranks, 0,
nranks); platform::errors::InvalidArgument(
"dim[0] (%d) is not divisible by nranks(%d)", dim[0], nranks));
dim[0] /= nranks; dim[0] /= nranks;
} }
ctx->SetOutputDim("Out", dim); ctx->SetOutputDim("Out", dim);
......
...@@ -61,7 +61,8 @@ class CReduceScatterOpCUDAKernel : public framework::OpKernel<T> { ...@@ -61,7 +61,8 @@ class CReduceScatterOpCUDAKernel : public framework::OpKernel<T> {
send_buff, recv_buff, recv_numel, static_cast<ncclDataType_t>(dtype), send_buff, recv_buff, recv_numel, static_cast<ncclDataType_t>(dtype),
ncclSum, comm->comm(), stream)); ncclSum, comm->comm(), stream));
#else #else
PADDLE_THROW("PaddlePaddle should compile with GPU."); PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with GPU."));
#endif #endif
} }
}; };
......
...@@ -30,7 +30,8 @@ template <typename T> ...@@ -30,7 +30,8 @@ template <typename T>
class CReduceScatterOpCPUKernel : public framework::OpKernel<T> { class CReduceScatterOpCPUKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_THROW("Unimplemented cpu kernel for CReduceScatterOp."); PADDLE_THROW(platform::errors::Unimplemented(
"Unimplemented cpu kernel for CReduceScatterOp."));
} }
}; };
......
...@@ -34,14 +34,16 @@ class CSyncCalcStreamOp : public framework::OperatorBase { ...@@ -34,14 +34,16 @@ class CSyncCalcStreamOp : public framework::OperatorBase {
void RunImpl(const framework::Scope& scope, void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override { const platform::Place& place) const override {
PADDLE_ENFORCE(is_gpu_place(place), PADDLE_ENFORCE_EQ(is_gpu_place(place), true,
"Sync stream op can run on gpu place only for now."); platform::errors::PreconditionNotMet(
"Sync stream op can run on gpu place only for now."));
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
auto dev_ctx = static_cast<platform::CUDADeviceContext*>( auto dev_ctx = static_cast<platform::CUDADeviceContext*>(
platform::DeviceContextPool::Instance().Get(place)); platform::DeviceContextPool::Instance().Get(place));
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(dev_ctx->stream())); PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(dev_ctx->stream()));
#else #else
PADDLE_THROW("PaddlePaddle should compile with GPU."); PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with GPU."));
#endif #endif
} }
}; };
......
...@@ -38,7 +38,8 @@ class CSyncCommStreamOp : public framework::OperatorBase { ...@@ -38,7 +38,8 @@ class CSyncCommStreamOp : public framework::OperatorBase {
void RunImpl(const framework::Scope& scope, void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override { const platform::Place& place) const override {
PADDLE_ENFORCE_EQ(is_gpu_place(place), true, PADDLE_ENFORCE_EQ(is_gpu_place(place), true,
"Sync stream op can run on gpu place only for now."); platform::errors::PreconditionNotMet(
"Sync stream op can run on gpu place only for now."));
#if defined(PADDLE_WITH_NCCL) #if defined(PADDLE_WITH_NCCL)
int ring_id = Attr<int>("ring_id"); int ring_id = Attr<int>("ring_id");
...@@ -46,7 +47,8 @@ class CSyncCommStreamOp : public framework::OperatorBase { ...@@ -46,7 +47,8 @@ class CSyncCommStreamOp : public framework::OperatorBase {
platform::NCCLCommContext::Instance().Get(ring_id, place)->stream(); platform::NCCLCommContext::Instance().Get(ring_id, place)->stream();
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream)); PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
#else #else
PADDLE_THROW("PaddlePaddle should compile with GPU."); PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with GPU."));
#endif #endif
} }
}; };
......
...@@ -130,7 +130,11 @@ class AsyncSparseParamUpdateRecorder { ...@@ -130,7 +130,11 @@ class AsyncSparseParamUpdateRecorder {
std::vector<int64_t>* result) { std::vector<int64_t>* result) {
VLOG(3) << "GetAndClear param: " << param_name VLOG(3) << "GetAndClear param: " << param_name
<< " for trainer: " << trainer_id; << " for trainer: " << trainer_id;
PADDLE_ENFORCE_LT(trainer_id, trainer_num_); PADDLE_ENFORCE_LT(
trainer_id, trainer_num_,
platform::errors::InvalidArgument(
"The value of trainer_id: %s should less than trainer_num: %s.",
trainer_id, trainer_num_));
param_to_updated_rows_.at(param_name)[trainer_id] param_to_updated_rows_.at(param_name)[trainer_id]
->GetAndClear(result) ->GetAndClear(result)
.wait(); .wait();
......
...@@ -39,8 +39,8 @@ void* RdmaMemPool::Find(const std::string& varname, int64_t size) { ...@@ -39,8 +39,8 @@ void* RdmaMemPool::Find(const std::string& varname, int64_t size) {
auto info = it->second; auto info = it->second;
if (info.data_size != size) { if (info.data_size != size) {
pthread_rwlock_unlock(&access_); pthread_rwlock_unlock(&access_);
PADDLE_ENFORCE(false, "var:%s size:%ld != %ld", varname, size, PADDLE_THROW(platform::errors::InvalidArgument(
info.data_size); "var:%s size:%ld != %ld", varname, size, info.data_size));
return nullptr; return nullptr;
} }
...@@ -52,9 +52,9 @@ void RdmaMemPool::Register(const std::string& varname, void* data, ...@@ -52,9 +52,9 @@ void RdmaMemPool::Register(const std::string& varname, void* data,
int64_t data_size) { int64_t data_size) {
void* old = Find(varname, data_size); void* old = Find(varname, data_size);
if (old != nullptr) { if (old != nullptr) {
if (data != old) { PADDLE_ENFORCE_EQ(
PADDLE_ENFORCE(false, "var:%s data:%ld != %ld", varname, data, old); data, old, platform::errors::InvalidArgument("var:%s data:%ld != %ld",
} varname, data, old));
VLOG(7) << "Find on rdma:" << varname << " data:" << data VLOG(7) << "Find on rdma:" << varname << " data:" << data
<< " data_size:" << data_size; << " data_size:" << data_size;
return; return;
......
...@@ -155,11 +155,15 @@ void SerializeToIOBuf(const std::string& name, framework::Variable* var, ...@@ -155,11 +155,15 @@ void SerializeToIOBuf(const std::string& name, framework::Variable* var,
return; return;
#endif #endif
} else { } else {
PADDLE_THROW("Serialize does not support type: %s", PADDLE_THROW(platform::errors::InvalidArgument(
typeid(var->Type()).name()); "Serialize does not support type: %s", typeid(var->Type()).name()));
} }
PADDLE_ENFORCE_NOT_NULL(payload); PADDLE_ENFORCE_NOT_NULL(
payload,
platform::errors::InvalidArgument(
"Not support type: %s, need to be LOD_TENSOR or SELECTED_ROWS.",
var->Type()));
// FIXME(gongwb): it seems that can use zero copy. // FIXME(gongwb): it seems that can use zero copy.
if (var_is_not_stable) { if (var_is_not_stable) {
...@@ -186,7 +190,10 @@ void SerializeToIOBuf(const std::string& name, framework::Variable* var, ...@@ -186,7 +190,10 @@ void SerializeToIOBuf(const std::string& name, framework::Variable* var,
if (var->IsType<framework::SelectedRows>()) { if (var->IsType<framework::SelectedRows>()) {
auto* slr = var->GetMutable<framework::SelectedRows>(); auto* slr = var->GetMutable<framework::SelectedRows>();
PADDLE_ENFORCE(VectorElemName(slr->rows()) == typeid(int64_t).name()); PADDLE_ENFORCE_EQ(VectorElemName(slr->rows()), typeid(int64_t).name(),
platform::errors::InvalidArgument(
"Got wrong type: %s, expect type: int64_t",
VectorElemName(slr->rows())));
size_t rows_memory_size = slr->rows().size() * sizeof(int64_t); size_t rows_memory_size = slr->rows().size() * sizeof(int64_t);
IOBufWriter::Append(name, iobuf, IOBufWriter::Append(name, iobuf,
...@@ -202,7 +209,9 @@ void DeserializeFromIOBuf(const ::sendrecv::VariableMessage& meta, ...@@ -202,7 +209,9 @@ void DeserializeFromIOBuf(const ::sendrecv::VariableMessage& meta,
const framework::Scope* scope, const framework::Scope* scope,
framework::Variable** var, int* trainer_id) { framework::Variable** var, int* trainer_id) {
operators::distributed::BRPCVariableResponse resp(scope, &ctx); operators::distributed::BRPCVariableResponse resp(scope, &ctx);
PADDLE_ENFORCE(resp.Parse(iobuf, meta) == 0, "parse iobuf to tensor error!"); PADDLE_ENFORCE_EQ(
resp.Parse(iobuf, meta), 0,
platform::errors::InvalidArgument("parse iobuf to tensor error!"));
*var = resp.GetVar(); *var = resp.GetVar();
*trainer_id = resp.GetTrainerId(); *trainer_id = resp.GetTrainerId();
} }
......
...@@ -90,8 +90,9 @@ class BRPCServiceImpl : public SendRecvService { ...@@ -90,8 +90,9 @@ class BRPCServiceImpl : public SendRecvService {
void _SendVariable(google::protobuf::RpcController* cntl_butil, void _SendVariable(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request, VoidMessage* response, const VariableMessage* request, VoidMessage* response,
google::protobuf::Closure* done) { google::protobuf::Closure* done) {
PADDLE_ENFORCE(request_send_h_ != nullptr, PADDLE_ENFORCE_NOT_NULL(
"RequestSend handler should be registed first!"); request_send_h_, platform::errors::PreconditionNotMet(
"RequestSend handler should be registed first!"));
brpc::ClosureGuard done_guard(done); brpc::ClosureGuard done_guard(done);
brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil); brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil);
...@@ -103,8 +104,9 @@ class BRPCServiceImpl : public SendRecvService { ...@@ -103,8 +104,9 @@ class BRPCServiceImpl : public SendRecvService {
distributed::BRPCVariableResponse resp(request_send_h_->scope(), distributed::BRPCVariableResponse resp(request_send_h_->scope(),
request_send_h_->dev_ctx(), request_send_h_->dev_ctx(),
request_send_h_->distributed_mode()); request_send_h_->distributed_mode());
PADDLE_ENFORCE(resp.Parse(cntl->request_attachment(), *request) == 0, PADDLE_ENFORCE_EQ(
"parse iobuf to tensor error!"); resp.Parse(cntl->request_attachment(), *request), 0,
platform::errors::InvalidArgument("parse iobuf to tensor error!"));
auto scope = resp.GetMutableLocalScope(); auto scope = resp.GetMutableLocalScope();
auto invar = resp.GetVar(); auto invar = resp.GetVar();
...@@ -132,8 +134,9 @@ class BRPCServiceImpl : public SendRecvService { ...@@ -132,8 +134,9 @@ class BRPCServiceImpl : public SendRecvService {
void _GetVariable(google::protobuf::RpcController* cntl_butil, void _GetVariable(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request, VariableMessage* response, const VariableMessage* request, VariableMessage* response,
google::protobuf::Closure* done) { google::protobuf::Closure* done) {
PADDLE_ENFORCE(request_get_h_ != nullptr, PADDLE_ENFORCE_NOT_NULL(
"RequestGet handler should be registed first!"); request_get_h_, platform::errors::PreconditionNotMet(
"RequestGet handler should be registed first!"));
brpc::ClosureGuard done_guard(done); brpc::ClosureGuard done_guard(done);
brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil); brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil);
...@@ -164,8 +167,10 @@ class BRPCServiceImpl : public SendRecvService { ...@@ -164,8 +167,10 @@ class BRPCServiceImpl : public SendRecvService {
const VariableMessage* request, const VariableMessage* request,
VariableMessage* response, VariableMessage* response,
google::protobuf::Closure* done) { google::protobuf::Closure* done) {
PADDLE_ENFORCE(request_getnobarrier_h_ != nullptr, PADDLE_ENFORCE_NOT_NULL(
"RequestGetNoBarrier handler should be registed first!"); request_getnobarrier_h_,
platform::errors::PreconditionNotMet(
"RequestGetNoBarrier handler should be registed first!"));
brpc::ClosureGuard done_guard(done); brpc::ClosureGuard done_guard(done);
brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil); brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil);
...@@ -204,8 +209,9 @@ class BRPCServiceImpl : public SendRecvService { ...@@ -204,8 +209,9 @@ class BRPCServiceImpl : public SendRecvService {
const VariableMessage* request, const VariableMessage* request,
VariableMessage* response, VariableMessage* response,
google::protobuf::Closure* done) { google::protobuf::Closure* done) {
PADDLE_ENFORCE(request_prefetch_h_ != nullptr, PADDLE_ENFORCE_NOT_NULL(request_prefetch_h_,
"kRequestPrefetch handler should be registed first!"); platform::errors::PreconditionNotMet(
"kRequestPrefetch handler should be registed first!");
brpc::ClosureGuard done_guard(done); brpc::ClosureGuard done_guard(done);
brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil); brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil);
...@@ -221,8 +227,9 @@ class BRPCServiceImpl : public SendRecvService { ...@@ -221,8 +227,9 @@ class BRPCServiceImpl : public SendRecvService {
distributed::BRPCVariableResponse resp( distributed::BRPCVariableResponse resp(
request_prefetch_h_->scope(), request_prefetch_h_->dev_ctx(), true); request_prefetch_h_->scope(), request_prefetch_h_->dev_ctx(), true);
PADDLE_ENFORCE(resp.Parse(cntl->request_attachment(), *request) == 0, PADDLE_ENFORCE_EQ(resp.Parse(cntl->request_attachment(), *request), 0,
"parse iobuf to tensor error!"); platform::errors::InvalidArgument(
"parse iobuf to tensor error!"));
auto scope = resp.GetMutableLocalScope(); auto scope = resp.GetMutableLocalScope();
auto invar = scope->FindVar(in_var_name); auto invar = scope->FindVar(in_var_name);
...@@ -248,9 +255,10 @@ class BRPCServiceImpl : public SendRecvService { ...@@ -248,9 +255,10 @@ class BRPCServiceImpl : public SendRecvService {
void _CheckpointNotify(google::protobuf::RpcController* cntl_butil, void _CheckpointNotify(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request, VoidMessage* response, const VariableMessage* request, VoidMessage* response,
google::protobuf::Closure* done) { google::protobuf::Closure* done) {
PADDLE_ENFORCE( PADDLE_ENFORCE_NOT_NULL(
request_checkpoint_h_ != nullptr, request_checkpoint_h_,
"kRequestCheckpointNotify handler should be registed first!"); platform::errors::PreconditionNotMet(
"kRequestCheckpointNotify handler should be registed first!"));
brpc::ClosureGuard done_guard(done); brpc::ClosureGuard done_guard(done);
brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil); brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil);
...@@ -277,9 +285,10 @@ class BRPCServiceImpl : public SendRecvService { ...@@ -277,9 +285,10 @@ class BRPCServiceImpl : public SendRecvService {
const VariableMessage* request, const VariableMessage* request,
VariableMessage* response, VariableMessage* response,
google::protobuf::Closure* done) override { google::protobuf::Closure* done) override {
PADDLE_ENFORCE( PADDLE_ENFORCE_NOT_NULL(
request_get_monomer_handler_h_ != nullptr, request_get_monomer_handler_h_,
"kRequestGetMonomerVariable handler should be registed first!"); platform::errors::PreconditionNotMet(
"kRequestGetMonomerVariable handler should be registed first!"));
brpc::ClosureGuard done_guard(done); brpc::ClosureGuard done_guard(done);
brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil); brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil);
...@@ -309,9 +318,10 @@ class BRPCServiceImpl : public SendRecvService { ...@@ -309,9 +318,10 @@ class BRPCServiceImpl : public SendRecvService {
void GetMonomerBarrier(google::protobuf::RpcController* cntl_butil, void GetMonomerBarrier(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request, VoidMessage* response, const VariableMessage* request, VoidMessage* response,
google::protobuf::Closure* done) override { google::protobuf::Closure* done) override {
PADDLE_ENFORCE( PADDLE_ENFORCE_NOT_NULL(
request_get_monomer_barrier_handler_h_ != nullptr, request_get_monomer_barrier_handler_h_,
"RequestGetMonomerBarrier handler should be registed first!"); platform::errors::PreconditionNotMet(
"RequestGetMonomerBarrier handler should be registed first!"));
brpc::ClosureGuard done_guard(done); brpc::ClosureGuard done_guard(done);
brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil); brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil);
......
...@@ -52,7 +52,8 @@ int BRPCVariableResponse::Parse(Source* source) { ...@@ -52,7 +52,8 @@ int BRPCVariableResponse::Parse(Source* source) {
PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS || PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS ||
meta_.type() == sendrecv::LOD_TENSOR) && meta_.type() == sendrecv::LOD_TENSOR) &&
meta_.varname() != "", meta_.varname() != "",
"meta info should be got first!"); platform::errors::PreconditionNotMet(
"meta info should be got first!"));
if (!CopySelectRowsData(&input, *dev_ctx_, num_bytes)) { if (!CopySelectRowsData(&input, *dev_ctx_, num_bytes)) {
return ret; return ret;
...@@ -60,7 +61,8 @@ int BRPCVariableResponse::Parse(Source* source) { ...@@ -60,7 +61,8 @@ int BRPCVariableResponse::Parse(Source* source) {
break; break;
} }
default: { default: {
PADDLE_ENFORCE(false, "not surpported %u fieldnumber", field); PADDLE_THROW(platform::errors::Unavailable(
"not surpported %u fieldnumber", field));
return ret; return ret;
} }
} }
......
...@@ -21,7 +21,6 @@ limitations under the License. */ ...@@ -21,7 +21,6 @@ limitations under the License. */
#include <thread> // NOLINT #include <thread> // NOLINT
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "gflags/gflags.h" #include "gflags/gflags.h"
#include "paddle/fluid/operators/distributed/distributed.h" #include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/request_handler.h" #include "paddle/fluid/operators/distributed/request_handler.h"
...@@ -51,7 +50,8 @@ class GetMonomerHandler final : public RequestHandler { ...@@ -51,7 +50,8 @@ class GetMonomerHandler final : public RequestHandler {
VLOG(50) << "GetMonomerHandler recv " << var_name; VLOG(50) << "GetMonomerHandler recv " << var_name;
*outvar = scope->FindVar(var_name); *outvar = scope->FindVar(var_name);
PADDLE_ENFORCE(outvar != nullptr, "%s not found", var_name); PADDLE_ENFORCE_NOT_NULL(
outvar, platform::errors::NotFound("var: %s is not found.", var_name));
return true; return true;
} }
......
...@@ -58,14 +58,19 @@ template <typename T> ...@@ -58,14 +58,19 @@ template <typename T>
class BlockingQueue { class BlockingQueue {
public: public:
explicit BlockingQueue(size_t capacity) : capacity_(capacity) { explicit BlockingQueue(size_t capacity) : capacity_(capacity) {
PADDLE_ENFORCE_GT(capacity_, 0, "The capacity must be greater than 0."); PADDLE_ENFORCE_GT(capacity_, 0,
platform::errors::InvalidArgument(
"The capacity must be greater than 0."));
} }
bool Push(const T &elem) { bool Push(const T &elem) {
{ {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [&] { return queue_.size() < capacity_; }); cv_.wait(lock, [&] { return queue_.size() < capacity_; });
PADDLE_ENFORCE_LT(queue_.size(), capacity_); PADDLE_ENFORCE_LT(
queue_.size(), capacity_,
platform::errors::OutOfRange("The queue size: %s out of capacity:%s",
queue_.size(), capacity_));
queue_.push_back(elem); queue_.push_back(elem);
} }
cv_.notify_one(); cv_.notify_one();
...@@ -76,7 +81,10 @@ class BlockingQueue { ...@@ -76,7 +81,10 @@ class BlockingQueue {
{ {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [&] { return queue_.size() < capacity_; }); cv_.wait(lock, [&] { return queue_.size() < capacity_; });
PADDLE_ENFORCE_LT(queue_.size(), capacity_); PADDLE_ENFORCE_LT(
queue_.size(), capacity_,
platform::errors::OutOfRange("The queue size: %s out of capacity:%s",
queue_.size(), capacity_));
queue_.emplace_back(std::move(elem)); queue_.emplace_back(std::move(elem));
} }
cv_.notify_one(); cv_.notify_one();
...@@ -118,7 +126,8 @@ template <typename T> ...@@ -118,7 +126,8 @@ template <typename T>
inline void MergeVars(const std::string &var_name, inline void MergeVars(const std::string &var_name,
const std::vector<std::shared_ptr<Variable>> &vars, const std::vector<std::shared_ptr<Variable>> &vars,
Scope *scope, bool merge_add = true) { Scope *scope, bool merge_add = true) {
PADDLE_ENFORCE(!vars.empty(), "should have value to merge!"); PADDLE_ENFORCE_NE(vars.empty(), true, platform::errors::InvalidArgument(
"vector vars are empty."));
auto cpu_place = platform::CPUPlace(); auto cpu_place = platform::CPUPlace();
auto &var0 = vars[0]; auto &var0 = vars[0];
auto *out_var = scope->Var(var_name); auto *out_var = scope->Var(var_name);
...@@ -132,7 +141,9 @@ inline void MergeVars(const std::string &var_name, ...@@ -132,7 +141,9 @@ inline void MergeVars(const std::string &var_name,
// check the input dims // check the input dims
for (auto &var : vars) { for (auto &var : vars) {
auto &var_t = var->Get<framework::LoDTensor>(); auto &var_t = var->Get<framework::LoDTensor>();
PADDLE_ENFORCE_EQ(var_t.dims(), dims, "should have the same dims"); PADDLE_ENFORCE_EQ(
var_t.dims(), dims,
platform::errors::InvalidArgument("vars should have the same dims"));
} }
// set output tensor to 0. // set output tensor to 0.
...@@ -173,7 +184,8 @@ inline void MergeVars(const std::string &var_name, ...@@ -173,7 +184,8 @@ inline void MergeVars(const std::string &var_name,
VLOG(3) << "merge " << var_name << " SelectedRows height: " << slr0.height() VLOG(3) << "merge " << var_name << " SelectedRows height: " << slr0.height()
<< " dims: " << slr0.value().dims() << "; merge add: " << merge_add; << " dims: " << slr0.value().dims() << "; merge add: " << merge_add;
} else { } else {
PADDLE_THROW("unsupported var type!"); PADDLE_THROW(platform::errors::InvalidArgument("unsupported var type: %s!",
var0->Type()));
} }
} }
......
...@@ -32,8 +32,9 @@ namespace distributed { ...@@ -32,8 +32,9 @@ namespace distributed {
void GRPCClient::InitImpl() { void GRPCClient::InitImpl() {
// start the client process thread // start the client process thread
// TODO(wuyi): can make this in a threadpool // TODO(wuyi): can make this in a threadpool
PADDLE_ENFORCE(client_thread_ == nullptr, PADDLE_ENFORCE_EQ(client_thread_ == nullptr, true,
"please not re init proceed thread"); platform::errors::PreconditionNotMet(
"please not re init proceed thread"));
client_thread_.reset(new std::thread(std::bind(&GRPCClient::Proceed, this))); client_thread_.reset(new std::thread(std::bind(&GRPCClient::Proceed, this)));
} }
...@@ -44,7 +45,8 @@ void GRPCClient::SendComplete() { ...@@ -44,7 +45,8 @@ void GRPCClient::SendComplete() {
VLOG(3) << "send complete message to " << it.first; VLOG(3) << "send complete message to " << it.first;
this->AsyncSendComplete(it.first); this->AsyncSendComplete(it.first);
} }
PADDLE_ENFORCE(this->Wait(), "internal grpc error"); PADDLE_ENFORCE_EQ(this->Wait(), true, platform::errors::PreconditionNotMet(
"internal grpc service error."));
completed_ = true; completed_ = true;
} }
} }
...@@ -590,7 +592,8 @@ void GRPCClient::Proceed() { ...@@ -590,7 +592,8 @@ void GRPCClient::Proceed() {
while (!stopped_ && cq_.Next(&tag, &ok)) { while (!stopped_ && cq_.Next(&tag, &ok)) {
BaseProcessor* c = static_cast<BaseProcessor*>(tag); BaseProcessor* c = static_cast<BaseProcessor*>(tag);
GPR_ASSERT(ok); GPR_ASSERT(ok);
PADDLE_ENFORCE(c); PADDLE_ENFORCE_NOT_NULL(
c, platform::errors::PreconditionNotMet("Make BaseProcessor failed."));
if (c->status_.ok()) { if (c->status_.ok()) {
VLOG(3) << c->GetVarHandlePtr()->String() << " process"; VLOG(3) << c->GetVarHandlePtr()->String() << " process";
......
...@@ -80,8 +80,8 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, ...@@ -80,8 +80,8 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
request.set_type(::sendrecv::NCCL_ID); request.set_type(::sendrecv::NCCL_ID);
#endif #endif
} else { } else {
PADDLE_THROW("Serialize does not support type: %s", PADDLE_THROW(platform::errors::InvalidArgument(
typeid(var->Type()).name()); "Serialize does not support type: %s", typeid(var->Type()).name()));
} }
std::string header; std::string header;
request.AppendToString(&header); request.AppendToString(&header);
...@@ -106,7 +106,11 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, ...@@ -106,7 +106,11 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
return; return;
} }
#endif #endif
PADDLE_ENFORCE_NOT_NULL(payload); PADDLE_ENFORCE_NOT_NULL(
payload,
platform::errors::InvalidArgument(
"Not support type: %s, need to be LOD_TENSOR or SELECTED_ROWS",
var->Type()));
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber,
payload->memory_size()); payload->memory_size());
if (payload->memory_size() >= std::numeric_limits<int>::max()) { if (payload->memory_size() >= std::numeric_limits<int>::max()) {
...@@ -128,7 +132,10 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, ...@@ -128,7 +132,10 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
auto* slr = var->GetMutable<framework::SelectedRows>(); auto* slr = var->GetMutable<framework::SelectedRows>();
ProtoEncodeHelper e2(static_cast<char*>(buf), 128); ProtoEncodeHelper e2(static_cast<char*>(buf), 128);
PADDLE_ENFORCE(VectorElemName(slr->rows()) == typeid(int64_t).name()); PADDLE_ENFORCE_EQ(VectorElemName(slr->rows()), typeid(int64_t).name(),
platform::errors::InvalidArgument(
"Got wrong type %s, expect type: int64_t",
VectorElemName(slr->rows())));
size_t rows_memory_size = slr->rows().size() * sizeof(int64_t); size_t rows_memory_size = slr->rows().size() * sizeof(int64_t);
e2.WriteVarlengthBeginning(VarMsg::kRowsFieldNumber, rows_memory_size); e2.WriteVarlengthBeginning(VarMsg::kRowsFieldNumber, rows_memory_size);
...@@ -155,7 +162,9 @@ void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg, ...@@ -155,7 +162,9 @@ void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
framework::Variable** var, int* trainer_id) { framework::Variable** var, int* trainer_id) {
platform::RecordRPCEvent record_event("deserial"); platform::RecordRPCEvent record_event("deserial");
operators::distributed::GRPCVariableResponse resp(scope, &ctx); operators::distributed::GRPCVariableResponse resp(scope, &ctx);
PADDLE_ENFORCE(resp.Parse(msg) == 0, "parse bytebuffer to tensor error!"); PADDLE_ENFORCE_EQ(
resp.Parse(msg), 0,
platform::errors::InvalidArgument("parse bytebuffer to tensor error!"));
*var = resp.GetVar(); *var = resp.GetVar();
*trainer_id = resp.GetTrainerId(); *trainer_id = resp.GetTrainerId();
} }
......
...@@ -57,7 +57,8 @@ class RequestBase { ...@@ -57,7 +57,8 @@ class RequestBase {
status_(PROCESS), status_(PROCESS),
request_handler_(request_handler), request_handler_(request_handler),
req_id_(req_id) { req_id_(req_id) {
PADDLE_ENFORCE(cq_); PADDLE_ENFORCE_NOT_NULL(cq_, platform::errors::InvalidArgument(
"ServerCompletionQueue cq are empty"));
} }
virtual ~RequestBase() {} virtual ~RequestBase() {}
virtual void Process() = 0; virtual void Process() = 0;
...@@ -550,8 +551,9 @@ void AsyncGRPCServer::StartServer() { ...@@ -550,8 +551,9 @@ void AsyncGRPCServer::StartServer() {
sleep(3); sleep(3);
} }
PADDLE_ENFORCE_NE(selected_port_, 0, "can't bind to address:%s", PADDLE_ENFORCE_NE(
bind_address_); selected_port_, 0,
platform::errors::Unavailable("can't bind to address:%s", bind_address_));
std::function<void(const std::string&, int)> f = std::function<void(const std::string&, int)> f =
std::bind(&AsyncGRPCServer::TryToRegisterNewOne, this, std::bind(&AsyncGRPCServer::TryToRegisterNewOne, this,
...@@ -649,7 +651,8 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name, ...@@ -649,7 +651,8 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
} else if (rpc_name == kRequestSendAndRecv) { } else if (rpc_name == kRequestSendAndRecv) {
b = new RequestSendAndRecv(service_.get(), cq.get(), handler, req_id); b = new RequestSendAndRecv(service_.get(), cq.get(), handler, req_id);
} else { } else {
PADDLE_ENFORCE(false, "not supported rpc"); PADDLE_THROW(
platform::errors::InvalidArgument("not supported rpc: %s", rpc_name));
} }
reqs[req_id] = b; reqs[req_id] = b;
...@@ -677,7 +680,10 @@ void AsyncGRPCServer::HandleRequest( ...@@ -677,7 +680,10 @@ void AsyncGRPCServer::HandleRequest(
auto& reqs = rpc_reqs_[rpc_name]; auto& reqs = rpc_reqs_[rpc_name];
RequestBase* base = nullptr; RequestBase* base = nullptr;
{ {
PADDLE_ENFORCE(req_id >= 0 && req_id < kRequestBufSize); PADDLE_ENFORCE_EQ(
(req_id >= 0 && req_id < kRequestBufSize), true,
platform::errors::OutOfRange("request id: %s out of bounds: [0, %s)",
req_id, kRequestBufSize));
std::unique_lock<std::mutex> lock(cq_mutex_); std::unique_lock<std::mutex> lock(cq_mutex_);
base = reqs[req_id]; base = reqs[req_id];
} }
......
...@@ -47,7 +47,8 @@ class SerializationTraits< ...@@ -47,7 +47,8 @@ class SerializationTraits<
static Status Serialize( static Status Serialize(
const paddle::operators::distributed::GRPCVariableResponse& msg, const paddle::operators::distributed::GRPCVariableResponse& msg,
grpc_byte_buffer** bp, bool* own_buffer) { grpc_byte_buffer** bp, bool* own_buffer) {
PADDLE_ENFORCE(false, "SerializationTraits::Serialize not implemented!"); PADDLE_THROW(paddle::platform::errors::Unimplemented(
"SerializationTraits::Serialize not implemented!"));
return Status(); return Status();
} }
static Status Deserialize( static Status Deserialize(
...@@ -115,7 +116,8 @@ inline const char* GrpcMethodName(GrpcMethod id) { ...@@ -115,7 +116,8 @@ inline const char* GrpcMethodName(GrpcMethod id) {
} }
// Shouldn't be reached. // Shouldn't be reached.
PADDLE_ENFORCE(false, "Invalid id: not found valid method name"); PADDLE_THROW(platform::errors::InvalidArgument(
"Invalid id: not found valid method name"));
return nullptr; return nullptr;
} }
......
...@@ -257,7 +257,8 @@ int GRPCVariableResponse::Parse(Source* source) { ...@@ -257,7 +257,8 @@ int GRPCVariableResponse::Parse(Source* source) {
PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS || PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS ||
meta_.type() == sendrecv::LOD_TENSOR) && meta_.type() == sendrecv::LOD_TENSOR) &&
meta_.varname() != "", meta_.varname() != "",
"meta info should be got first!"); platform::errors::PreconditionNotMet(
"meta info should be got first!"));
int num_bytes = 0; int num_bytes = 0;
if (wt != WIRETYPE_LENGTH_DELIMITED || if (wt != WIRETYPE_LENGTH_DELIMITED ||
......
...@@ -76,11 +76,11 @@ void HeartBeatMonitor::LostWorkerMonitor() { ...@@ -76,11 +76,11 @@ void HeartBeatMonitor::LostWorkerMonitor() {
<< timestamp - worker.timestamp; << timestamp - worker.timestamp;
if (timestamp - worker.timestamp >= FLAGS_worker_update_interval_secs) { if (timestamp - worker.timestamp >= FLAGS_worker_update_interval_secs) {
PADDLE_THROW( PADDLE_THROW(platform::errors::ExecutionTimeout(
"the latest update of worker %d is %d secs ago, we doubt the " "the latest update of worker %d is %d secs ago, we doubt the "
"the worker is not alive and this may have a bad effect on the " "the worker is not alive and this may have a bad effect on the "
"fitting result, please check", "fitting result, please check",
worker.id, FLAGS_worker_update_interval_secs); worker.id, FLAGS_worker_update_interval_secs));
} }
} }
......
...@@ -56,7 +56,8 @@ class HeartBeatMonitor { ...@@ -56,7 +56,8 @@ class HeartBeatMonitor {
is_chief_(is_chief), is_chief_(is_chief),
be_monitored_var_(be_monitored_var), be_monitored_var_(be_monitored_var),
running_(true) { running_(true) {
PADDLE_ENFORCE_GT(workers, 0, "trainers must have one or more"); PADDLE_ENFORCE_GT(workers, 0, platform::errors::InvalidArgument(
"workers must greater than 0."));
for (auto worker_id = 0; worker_id < workers; worker_id++) { for (auto worker_id = 0; worker_id < workers; worker_id++) {
UnderMonitoredWorker worker(worker_id); UnderMonitoredWorker worker(worker_id);
......
...@@ -156,8 +156,14 @@ void prefetch_core( ...@@ -156,8 +156,14 @@ void prefetch_core(
const auto *out_var_data = prefetch_out_var.data<float>(); const auto *out_var_data = prefetch_out_var.data<float>();
auto &dims = prefetch_out_var.dims(); auto &dims = prefetch_out_var.dims();
PADDLE_ENFORCE_EQ(dims.size(), 2, ""); PADDLE_ENFORCE_EQ(dims.size(), 2,
PADDLE_ENFORCE_EQ(ids_in_this_section.size(), dims[0]); platform::errors::InvalidArgument(
"The size of Tensor dims must be 2."));
PADDLE_ENFORCE_EQ(ids_in_this_section.size(), dims[0],
platform::errors::InvalidArgument(
"The size of ids in this section must equal to "
"dims[0]: %s, but got %s",
dims[0], ids_in_this_section.size()));
auto row_numel = dims[1]; auto row_numel = dims[1];
......
...@@ -127,9 +127,10 @@ void ParameterSend<T>::operator()(const CommContext &rpc_ctx, ...@@ -127,9 +127,10 @@ void ParameterSend<T>::operator()(const CommContext &rpc_ctx,
outs_dims.reserve(out_num); outs_dims.reserve(out_num);
// infer output shape // infer output shape
PADDLE_ENFORCE_EQ(rpc_ctx.height_sections.size(), out_num, PADDLE_ENFORCE_EQ(
"tensor split sections size" rpc_ctx.height_sections.size(), out_num,
"should be equal to output size."); platform::errors::InvalidArgument("tensor split sections size"
"should be equal to output size."));
for (size_t i = 0; i < out_num; ++i) { for (size_t i = 0; i < out_num; ++i) {
auto dim = send_tensor_dims; auto dim = send_tensor_dims;
dim[0] = rpc_ctx.height_sections[i]; dim[0] = rpc_ctx.height_sections[i];
...@@ -309,7 +310,8 @@ void ParameterSend<T>::operator()(const CommContext &rpc_ctx, ...@@ -309,7 +310,8 @@ void ParameterSend<T>::operator()(const CommContext &rpc_ctx,
} }
} }
} else { } else {
PADDLE_THROW("unsupported var type to send!"); PADDLE_THROW(platform::errors::InvalidArgument(
"unsupported var type: %s to send!", send_var->Type()));
} }
VLOG(4) << "Prepare to send var " << rpc_ctx.var_name; VLOG(4) << "Prepare to send var " << rpc_ctx.var_name;
......
...@@ -65,9 +65,9 @@ bool RequestSendHandler::Handle(const std::string &varname, ...@@ -65,9 +65,9 @@ bool RequestSendHandler::Handle(const std::string &varname,
if (distributed_mode_ != DistributedMode::kSync) { if (distributed_mode_ != DistributedMode::kSync) {
VLOG(3) << "async process var: " << varname; VLOG(3) << "async process var: " << varname;
if (varname == BATCH_BARRIER_MESSAGE) { if (varname == BATCH_BARRIER_MESSAGE) {
PADDLE_THROW( PADDLE_THROW(platform::errors::InvalidArgument(
"async mode should not recv BATCH_BARRIER_MESSAGE or " "async mode should not recv BATCH_BARRIER_MESSAGE or "
"COMPLETE_MESSAGE"); "COMPLETE_MESSAGE"));
} }
HeartBeatMonitor::GetInstance()->Update(trainer_id, varname, RUNNING); HeartBeatMonitor::GetInstance()->Update(trainer_id, varname, RUNNING);
...@@ -78,7 +78,10 @@ bool RequestSendHandler::Handle(const std::string &varname, ...@@ -78,7 +78,10 @@ bool RequestSendHandler::Handle(const std::string &varname,
if (string::Contains(var_name_piece, part_piece)) { if (string::Contains(var_name_piece, part_piece)) {
auto varname_splits = paddle::string::Split(varname, '@'); auto varname_splits = paddle::string::Split(varname, '@');
PADDLE_ENFORCE_EQ(varname_splits.size(), 3); PADDLE_ENFORCE_EQ(
varname_splits.size(), 3,
platform::errors::InvalidArgument(
"varname: %s should be separated into 3 parts by @", varname));
run_varname = varname_splits[0]; run_varname = varname_splits[0];
scope->Rename(varname, run_varname); scope->Rename(varname, run_varname);
} }
...@@ -192,7 +195,11 @@ bool RequestGetHandler::Handle(const std::string &varname, ...@@ -192,7 +195,11 @@ bool RequestGetHandler::Handle(const std::string &varname,
out_dims, origin_tensor.place()); out_dims, origin_tensor.place());
auto width = dims[1]; auto width = dims[1];
for (size_t i = 0; i < updated_rows.size(); ++i) { for (size_t i = 0; i < updated_rows.size(); ++i) {
PADDLE_ENFORCE_LT(updated_rows[i], dims[0]); PADDLE_ENFORCE_LT(
updated_rows[i], dims[0],
platform::errors::OutOfRange(
"The value of updated_rows: %s out of Tensor %s dims[0]: %s",
updated_rows[i], varname, dims[0]));
memcpy(data + i * width, origin_tensor_data + updated_rows[i] * width, memcpy(data + i * width, origin_tensor_data + updated_rows[i] * width,
sizeof(float) * width); sizeof(float) * width);
} }
...@@ -225,7 +232,8 @@ bool RequestGetNoBarrierHandler::Handle(const std::string &varname, ...@@ -225,7 +232,8 @@ bool RequestGetNoBarrierHandler::Handle(const std::string &varname,
*outvar = scope_->FindVar(var_name_piece.ToString()); *outvar = scope_->FindVar(var_name_piece.ToString());
return true; return true;
} else { } else {
PADDLE_THROW("GetNoBarrier must contain %s", WITHOUT_BARRIER_MESSAGE); PADDLE_THROW(platform::errors::InvalidArgument(
"GetNoBarrier must contain %s", WITHOUT_BARRIER_MESSAGE));
} }
return true; return true;
} }
......
...@@ -159,9 +159,9 @@ void RPCServer::RegisterVar(const std::string& var_name, ...@@ -159,9 +159,9 @@ void RPCServer::RegisterVar(const std::string& var_name,
{ {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
if (var_map_.find(var_name) != var_map_.end()) { PADDLE_ENFORCE_EQ(
PADDLE_ENFORCE(false, "%s alreay in var_map", var_name); var_map_.find(var_name), var_map_.end(),
} platform::errors::AlreadyExists("%s already in var_map.", var_name));
var_map_[var_name] = h; var_map_[var_name] = h;
} }
......
...@@ -172,7 +172,9 @@ TEST(COMPLETE, CPU) { ...@@ -172,7 +172,9 @@ TEST(COMPLETE, CPU) {
g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 2)); g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 2));
distributed::RPCClient* client = distributed::RPCClient* client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(0); distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
PADDLE_ENFORCE(client != nullptr); PADDLE_ENFORCE_NE(client, nullptr,
platform::errors::InvalidArgument(
"Client Start Fail, Check Your Code & Env"));
std::thread server_thread(StartServer, distributed::kRequestSend); std::thread server_thread(StartServer, distributed::kRequestSend);
g_rpc_service->WaitServerReady(); g_rpc_service->WaitServerReady();
int port = g_rpc_service->GetSelectedPort(); int port = g_rpc_service->GetSelectedPort();
......
...@@ -40,7 +40,9 @@ static TensorPayload GetCommunicationAllocationFromTensor( ...@@ -40,7 +40,9 @@ static TensorPayload GetCommunicationAllocationFromTensor(
const platform::DeviceContext& ctx, const framework::Tensor& tensor) { const platform::DeviceContext& ctx, const framework::Tensor& tensor) {
if (is_gpu_place(ctx.GetPlace())) { if (is_gpu_place(ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE(is_gpu_place(tensor.place())); PADDLE_ENFORCE_EQ(
is_gpu_place(tensor.place()), true,
platform::errors::PreconditionNotMet("Please run in gpu place."));
auto& gpu_dev_ctx = auto& gpu_dev_ctx =
reinterpret_cast<const platform::CUDADeviceContext&>(ctx); reinterpret_cast<const platform::CUDADeviceContext&>(ctx);
auto copy_size = tensor.numel() * framework::SizeOfType(tensor.type()); auto copy_size = tensor.numel() * framework::SizeOfType(tensor.type());
...@@ -53,7 +55,8 @@ static TensorPayload GetCommunicationAllocationFromTensor( ...@@ -53,7 +55,8 @@ static TensorPayload GetCommunicationAllocationFromTensor(
ctx.Wait(); ctx.Wait();
return TensorPayload(result); return TensorPayload(result);
#else #else
PADDLE_THROW("This situation should not be happened"); PADDLE_THROW(
platform::errors::Unavailable("This situation should not be happened"));
#endif #endif
} else { } else {
return TensorPayload(tensor); return TensorPayload(tensor);
......
...@@ -95,7 +95,8 @@ inline framework::proto::VarType::Type ToVarType( ...@@ -95,7 +95,8 @@ inline framework::proto::VarType::Type ToVarType(
case sendrecv::VariableMessage::BOOL: case sendrecv::VariableMessage::BOOL:
return framework::proto::VarType::BOOL; // NOLINT return framework::proto::VarType::BOOL; // NOLINT
default: default:
PADDLE_THROW("Not support type %d", type); PADDLE_THROW(
platform::errors::InvalidArgument("Not support type id: %d.", type));
} }
} }
......
...@@ -61,7 +61,8 @@ bool VariableResponse::ReadRaw(::google::protobuf::io::CodedInputStream* input, ...@@ -61,7 +61,8 @@ bool VariableResponse::ReadRaw(::google::protobuf::io::CodedInputStream* input,
} }
gpu_dev_ctx.Wait(); gpu_dev_ctx.Wait();
#else #else
PADDLE_THROW("Unexpected branch"); PADDLE_THROW(platform::errors::PreconditionNotMet(
"Unexpected branch, please compile with PADDLE_WITH_CUDA"));
#endif #endif
return true; return true;
} else if (platform::is_xpu_place(place)) { } else if (platform::is_xpu_place(place)) {
...@@ -147,7 +148,11 @@ bool VariableResponse::CopyLodTensorData( ...@@ -147,7 +148,11 @@ bool VariableResponse::CopyLodTensorData(
VLOG(6) << "Tensor.memory_size = " << tensor->memory_size() VLOG(6) << "Tensor.memory_size = " << tensor->memory_size()
<< ", Buffer Size = " << length << ", dims:" << dims << ", Buffer Size = " << length << ", dims:" << dims
<< ", numel:" << tensor->numel(); << ", numel:" << tensor->numel();
PADDLE_ENFORCE_GE(tensor->memory_size(), static_cast<unsigned int>(length)); PADDLE_ENFORCE_GE(
tensor->memory_size(), static_cast<unsigned int>(length),
platform::errors::InvalidArgument(
"The memory size of tensor: %s should greater than length: %s",
tensor->memory_size(), length));
return ReadRaw(input, ctx, tensor->place(), tensor_data, length); return ReadRaw(input, ctx, tensor->place(), tensor_data, length);
} }
...@@ -171,7 +176,12 @@ bool VariableResponse::CopySelectRowsTensorData( ...@@ -171,7 +176,12 @@ bool VariableResponse::CopySelectRowsTensorData(
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
static_cast<size_t>(tensor->numel()), static_cast<size_t>(tensor->numel()),
length / framework::SizeOfType(paddle::operators::distributed::ToVarType( length / framework::SizeOfType(paddle::operators::distributed::ToVarType(
meta_.data_type()))); meta_.data_type())),
platform::errors::InvalidArgument(
"length: %s should equal to memory size of tensor: %s", length,
tensor->numel() *
framework::SizeOfType(paddle::operators::distributed::ToVarType(
meta_.data_type()))));
void* tensor_data = tensor->mutable_data( void* tensor_data = tensor->mutable_data(
ctx.GetPlace(), ctx.GetPlace(),
paddle::operators::distributed::ToVarType(meta_.data_type())); paddle::operators::distributed::ToVarType(meta_.data_type()));
...@@ -203,11 +213,12 @@ bool VariableResponse::CopySelectRowsData( ...@@ -203,11 +213,12 @@ bool VariableResponse::CopySelectRowsData(
bool VariableResponse::ProcSerializedField( bool VariableResponse::ProcSerializedField(
int tag, ::google::protobuf::io::CodedInputStream* input, int tag, ::google::protobuf::io::CodedInputStream* input,
int64_t num_bytes) { int64_t num_bytes) {
PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS || PADDLE_ENFORCE(
meta_.type() == sendrecv::LOD_TENSOR || (meta_.type() == sendrecv::SELECTED_ROWS ||
meta_.type() == sendrecv::NCCL_ID) && meta_.type() == sendrecv::LOD_TENSOR ||
meta_.varname() != "", meta_.type() == sendrecv::NCCL_ID) &&
"meta info should be got first!"); meta_.varname() != "",
platform::errors::PreconditionNotMet("meta info should be got first!"));
if (meta_.type() == sendrecv::NCCL_ID) { if (meta_.type() == sendrecv::NCCL_ID) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
...@@ -221,7 +232,8 @@ bool VariableResponse::ProcSerializedField( ...@@ -221,7 +232,8 @@ bool VariableResponse::ProcSerializedField(
} }
return true; return true;
#else #else
PADDLE_THROW("Not compiled with CUDA!"); PADDLE_THROW(
platform::errors::PreconditionNotMet("Please compiled with CUDA!"));
return false; return false;
#endif #endif
} }
...@@ -230,7 +242,9 @@ bool VariableResponse::ProcSerializedField( ...@@ -230,7 +242,9 @@ bool VariableResponse::ProcSerializedField(
<< ", type:" << meta_.type() << std::endl; << ", type:" << meta_.type() << std::endl;
framework::DDim dims = GetDims(meta_.dims()); framework::DDim dims = GetDims(meta_.dims());
if (meta_.type() == sendrecv::LOD_TENSOR) { if (meta_.type() == sendrecv::LOD_TENSOR) {
PADDLE_ENFORCE(meta_.lod_size() >= 0, "lod info should be got first!"); PADDLE_ENFORCE_GE(
meta_.lod_size(), 0,
platform::errors::PreconditionNotMet("lod info should be got first!"));
if (!CopyLodTensorData(input, *dev_ctx_, dims, num_bytes)) { if (!CopyLodTensorData(input, *dev_ctx_, dims, num_bytes)) {
return false; return false;
} }
...@@ -245,7 +259,9 @@ bool VariableResponse::ProcSerializedField( ...@@ -245,7 +259,9 @@ bool VariableResponse::ProcSerializedField(
return true; return true;
} }
PADDLE_ENFORCE("not supported var types:", meta_.varname(), meta_.type()); PADDLE_THROW(platform::errors::InvalidArgument(
"The type: %s of var: %s is not supported", meta_.type(),
meta_.varname()));
return false; return false;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册