Commit 82c61dbd authored by typhoonzero

fix testing

Parent 0598a4b3
@@ -52,7 +52,7 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
   // stub context
   SendProcessor* s = new SendProcessor(ch);
   s->Prepare(var_h, time_out);
-  s->response_call_back_ = NULL;
+  s->response_call_back_ = nullptr;
   auto call = s->stub_g_.PrepareUnaryCall(
       s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req, &cq_);
......
@@ -57,7 +57,9 @@ void ProcGetResponse(const VarHandle& var_h, const grpc::ByteBuffer& msg);
 class BaseProcessor {
  public:
-  explicit BaseProcessor(std::shared_ptr<grpc::Channel> ch) { context_ = NULL; }
+  explicit BaseProcessor(std::shared_ptr<grpc::Channel> ch) {
+    context_ = nullptr;
+  }
   virtual ~BaseProcessor() {}
@@ -105,7 +107,7 @@ class SendProcessor : public BaseProcessor {
   ::grpc::GenericStub stub_g_;
   ::grpc::ByteBuffer reply_;
-  RequestSendCallBack response_call_back_ = NULL;
+  RequestSendCallBack response_call_back_ = nullptr;
 };
 typedef std::function<void(const VarHandle&, const ::grpc::ByteBuffer&)>
......
@@ -261,8 +261,8 @@ void AsyncGRPCServer::ShutdownQueue() {
 // This URL explains why shutdown is complicate:
 void AsyncGRPCServer::ShutDown() {
   is_shut_down_ = true;
-  ShutdownQueue();
   server_->Shutdown();
+  ShutdownQueue();
 }
 void AsyncGRPCServer::TryToRegisterNewSendOne() {
......
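Note on the hunk above: the old code shut the completion queues down before the server, which can abort tags still in flight. gRPC's documented order is the reverse: call grpc::Server::Shutdown() first, then Shutdown() each ServerCompletionQueue, then keep calling Next() until it returns false so every outstanding tag is released. A minimal sketch of such a drain loop, assuming that order (the DrainQueue helper is illustrative, not part of this commit):

#include <grpc++/grpc++.h>

// Drain a completion queue after server_->Shutdown() has already run.
void DrainQueue(::grpc::ServerCompletionQueue* cq) {
  cq->Shutdown();
  void* tag = nullptr;
  bool ok = false;
  // After Shutdown(), Next() keeps returning leftover events (possibly
  // with ok == false) and returns false only once the queue is empty.
  while (cq->Next(&tag, &ok)) {
    // release whatever resources `tag` owns here
  }
}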
@@ -47,6 +47,8 @@ class AsyncGRPCServer final {
   explicit AsyncGRPCServer(const std::string &address, bool sync_mode)
       : address_(address), sync_mode_(sync_mode) {}
+  ~AsyncGRPCServer() {}
+
   void RunSyncUpdate();
   // functions to sync server barrier status.
......
@@ -53,109 +53,106 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
     e.WriteUint64(VarMsg::kTypeFieldNumber, 1);
   } else if (var->IsType<ncclUniqueId>()) {
     // NOTE: sendrecv only support RAW type for NCCL_ID
+    VLOG(3) << "serilizing: setting var type nccl id";
     e.WriteUint64(VarMsg::kTypeFieldNumber, 2);
   }
   if (!out_name.empty()) {
     e.WriteString(VarMsg::kOutVarnameFieldNumber, out_name);
   }
-  switch (framework::ToVarType(var->Type())) {
-    case framework::proto::VarType_Type_LOD_TENSOR: {
-      auto tensor = var->Get<framework::LoDTensor>();
-      e.WriteUint64(VarMsg::kDataTypeFieldNumber,
-                    framework::ToDataType(tensor.type()));
-      for (auto& dim : framework::vectorize(tensor.dims())) {
-        e.WriteUint64(VarMsg::kDimsFieldNumber, dim);
-      }
-      auto lod = tensor.lod();  // std::vector<Vector<size_t>>
-      if (lod.size() > 0) {
-        e.WriteUint64(VarMsg::kLodLevelFieldNumber, lod.size());
-        for (auto& each : lod) {
-          e.WriteVarlengthBeginning(VarMsg::kLodFieldNumber,
-                                    2 +  // tag + varintlength of submessage
-                                        1 +  // kLodDataFieldNumber
-                                        each.size());
-          // auto copied from GPU
-          for (auto& d : each) {
-            e.WriteUint64(VarMsg::LodData::kLodDataFieldNumber, d);
-          }
-        }
-      }
-      if (platform::is_gpu_place(ctx.GetPlace())) {
-#ifdef PADDLE_WITH_CUDA
-        PADDLE_ENFORCE(platform::is_gpu_place(tensor.place()));
-        platform::CPUPlace cpu;
-        auto& gpu_dev_ctx =
-            static_cast<const platform::CUDADeviceContext&>(ctx);
-        auto copy_size = tensor.numel() * framework::SizeOfType(tensor.type());
-        payload = memory::Alloc(cpu, copy_size);
-        memory::Copy(cpu, payload,
-                     boost::get<platform::CUDAPlace>(tensor.place()),
-                     reinterpret_cast<const void*>(tensor.data<void>()),
-                     copy_size, gpu_dev_ctx.stream());
-        ctx.Wait();
-        destroy_callback = [](void* backing) {
-          platform::CPUPlace cpu;
-          memory::Free(cpu, backing);
-        };
-#endif
-      } else {
-        payload = tensor.data<void>();
-      }
-      payload_size = tensor.numel() * framework::SizeOfType(tensor.type());
-      e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
-    } break;
-    case framework::proto::VarType_Type_SELECTED_ROWS: {
-      // TODO(typhoonzero): selectedrows implement should not use unique_ptr
-      auto* slr = var->GetMutable<framework::SelectedRows>();
-      e.WriteUint64(VarMsg::kDataTypeFieldNumber,
-                    framework::ToDataType(slr->value().type()));
-      for (auto& dim : framework::vectorize(slr->value().dims())) {
-        e.WriteUint64(VarMsg::kDimsFieldNumber, dim);
-      }
-      e.WriteUint64(VarMsg::kLodLevelFieldNumber, 0);
-      e.WriteUint64(VarMsg::kSlrHeightFieldNumber, slr->height());
-      auto* tensor = slr->mutable_value();
-      if (platform::is_gpu_place(ctx.GetPlace())) {
-#ifdef PADDLE_WITH_CUDA
-        platform::CPUPlace cpu;
-        auto& gpu_dev_ctx =
-            static_cast<const platform::CUDADeviceContext&>(ctx);
-        auto copy_size =
-            tensor->numel() * framework::SizeOfType(tensor->type());
-        payload = memory::Alloc(cpu, copy_size);
-        memory::Copy(cpu, payload,
-                     boost::get<platform::CUDAPlace>(tensor->place()),
-                     reinterpret_cast<const void*>(tensor->data<void>()),
-                     copy_size, gpu_dev_ctx.stream());
-        ctx.Wait();
-        destroy_callback = [](void* backing) {
-          platform::CPUPlace cpu;
-          memory::Free(cpu, backing);
-        };
-#endif
-      } else {
-        payload = slr->mutable_value()->data<void>();
-      }
-      payload_size = tensor->numel() * framework::SizeOfType(tensor->type());
-      e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
-    } break;
-    case framework::proto::VarType_Type_RAW: {
-      e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber,
-                                NCCL_UNIQUE_ID_BYTES);
-      ncclUniqueId* uid = var->GetMutable<ncclUniqueId>();
-      e.WriteRawBytes(std::string(uid->internal, NCCL_UNIQUE_ID_BYTES));
-    } break;
-    default:
-      PADDLE_THROW("Serialize does not support type: %s",
-                   typeid(var->Type()).name());
-      break;
+  if (var->IsType<framework::LoDTensor>()) {
+    // ===========================Tensor==================================
+    auto tensor = var->Get<framework::LoDTensor>();
+    e.WriteUint64(VarMsg::kDataTypeFieldNumber,
+                  framework::ToDataType(tensor.type()));
+    for (auto& dim : framework::vectorize(tensor.dims())) {
+      e.WriteUint64(VarMsg::kDimsFieldNumber, dim);
+    }
+    auto lod = tensor.lod();  // std::vector<Vector<size_t>>
+    if (lod.size() > 0) {
+      e.WriteUint64(VarMsg::kLodLevelFieldNumber, lod.size());
+      for (auto& each : lod) {
+        e.WriteVarlengthBeginning(VarMsg::kLodFieldNumber,
+                                  2 +  // tag + varintlength of submessage
+                                      1 +  // kLodDataFieldNumber
+                                      each.size());
+        // auto copied from GPU
+        for (auto& d : each) {
+          e.WriteUint64(VarMsg::LodData::kLodDataFieldNumber, d);
+        }
+      }
+    }
+    if (platform::is_gpu_place(ctx.GetPlace())) {
+#ifdef PADDLE_WITH_CUDA
+      PADDLE_ENFORCE(platform::is_gpu_place(tensor.place()));
+      platform::CPUPlace cpu;
+      auto& gpu_dev_ctx = static_cast<const platform::CUDADeviceContext&>(ctx);
+      auto copy_size = tensor.numel() * framework::SizeOfType(tensor.type());
+      payload = memory::Alloc(cpu, copy_size);
+      memory::Copy(cpu, payload,
+                   boost::get<platform::CUDAPlace>(tensor.place()),
+                   reinterpret_cast<const void*>(tensor.data<void>()),
+                   copy_size, gpu_dev_ctx.stream());
+      ctx.Wait();
+      destroy_callback = [](void* backing) {
+        platform::CPUPlace cpu;
+        memory::Free(cpu, backing);
+      };
+#endif
+    } else {
+      payload = tensor.data<void>();
+    }
+    payload_size = tensor.numel() * framework::SizeOfType(tensor.type());
+    e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
+  } else if (var->IsType<framework::SelectedRows>()) {
+    // ===========================SELECTED
+    // ROWS==================================
+    // TODO(typhoonzero): selectedrows implement should not use unique_ptr
+    auto* slr = var->GetMutable<framework::SelectedRows>();
+    e.WriteUint64(VarMsg::kDataTypeFieldNumber,
+                  framework::ToDataType(slr->value().type()));
+    for (auto& dim : framework::vectorize(slr->value().dims())) {
+      e.WriteUint64(VarMsg::kDimsFieldNumber, dim);
+    }
+    e.WriteUint64(VarMsg::kLodLevelFieldNumber, 0);
+    e.WriteUint64(VarMsg::kSlrHeightFieldNumber, slr->height());
+    auto* tensor = slr->mutable_value();
+    if (platform::is_gpu_place(ctx.GetPlace())) {
+#ifdef PADDLE_WITH_CUDA
+      platform::CPUPlace cpu;
+      auto& gpu_dev_ctx = static_cast<const platform::CUDADeviceContext&>(ctx);
+      auto copy_size = tensor->numel() * framework::SizeOfType(tensor->type());
+      payload = memory::Alloc(cpu, copy_size);
+      memory::Copy(cpu, payload,
+                   boost::get<platform::CUDAPlace>(tensor->place()),
+                   reinterpret_cast<const void*>(tensor->data<void>()),
+                   copy_size, gpu_dev_ctx.stream());
+      ctx.Wait();
+      destroy_callback = [](void* backing) {
+        platform::CPUPlace cpu;
+        memory::Free(cpu, backing);
+      };
+#endif
+    } else {
+      payload = slr->mutable_value()->data<void>();
+    }
+    payload_size = tensor->numel() * framework::SizeOfType(tensor->type());
+    e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
+  } else if (var->IsType<ncclUniqueId>()) {
+    // ===========================NCCL ID==================================
+    e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber,
+                              NCCL_UNIQUE_ID_BYTES);
+    ncclUniqueId* uid = var->GetMutable<ncclUniqueId>();
+    e.WriteRawBytes(std::string(uid->internal, NCCL_UNIQUE_ID_BYTES));
+  } else {
+    PADDLE_THROW("Serialize does not support type: %s",
+                 typeid(var->Type()).name());
   }
-  if (framework::ToVarType(var->Type()) == framework::proto::VarType_Type_RAW) {
+  if (var->IsType<ncclUniqueId>()) {
     // for serialize NCCL_ID
     ::grpc::Slice slices(e.size());
     memcpy(const_cast<uint8_t*>(slices.begin()), e.data(), e.size());
......
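For context on the serializer above: it writes protobuf wire format by hand, so WriteVarlengthBeginning(field, n) must emit the key varint ((field_number << 3) | 2 for a length-delimited field) followed by the byte count n, after which the n raw payload bytes follow. A minimal sketch of that framing, assuming a plain byte-vector buffer (AppendVarint and WriteLengthDelimitedHeader are illustrative stand-ins for the commit's encoder):

#include <cstdint>
#include <vector>

// Append v as a base-128 varint, least-significant 7-bit group first.
void AppendVarint(std::vector<uint8_t>* out, uint64_t v) {
  while (v >= 0x80) {
    out->push_back(static_cast<uint8_t>(v) | 0x80);
    v >>= 7;
  }
  out->push_back(static_cast<uint8_t>(v));
}

// Key varint for a length-delimited field (wire type 2), then the payload
// size; the caller writes the raw payload bytes immediately afterwards.
void WriteLengthDelimitedHeader(std::vector<uint8_t>* out, uint32_t field,
                                uint64_t payload_size) {
  AppendVarint(out, (static_cast<uint64_t>(field) << 3) | 2u);
  AppendVarint(out, payload_size);
}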
@@ -371,19 +371,26 @@ int VariableResponse::Parse(Source* source) {
                         meta_.type() == sendrecv::NCCL_ID) &&
                            meta_.varname() != "",
                        "meta info should be got first!");
 
+        int length = 0;
+        if (wt != WIRETYPE_LENGTH_DELIMITED ||
+            !ReadVarintSizeAsInt(&input, &length)) {
+          return tag;
+        }
         if (meta_.type() == sendrecv::NCCL_ID) {
+          VLOG(3) << "parse nccl id request";
           auto* var = scope_->FindVar(meta_.varname());
           if (var != nullptr) {
+            VLOG(3) << "parse nccl id: length " << length;
             ncclUniqueId* id = var->GetMutable<ncclUniqueId>();
-            memcpy(id->internal, meta_.serialized().c_str(),
-                   meta_.serialized().size());
+            if (!ReadRaw(&input, *dev_ctx_, platform::CPUPlace(), id->internal,
+                         length)) {
+              return tag;
+            }
+            // memcpy(id->internal, meta_.serialized().c_str(),
+            //        meta_.serialized().size());
           }
+          break;
         }
 
-        int length = 0;
-        if (wt != WIRETYPE_LENGTH_DELIMITED ||
-            !ReadVarintSizeAsInt(&input, &length)) {
-          return tag;
-        }
-
         framework::DDim dims = GetDims(meta_.dims());
......
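The reordering above matters because the payload is length-delimited on the wire: the parser has to read the varint size prefix before it can consume the NCCL ID bytes, so the old memcpy from meta_.serialized() (which need not hold the streamed payload) is replaced with a bounded raw read of exactly `length` bytes. A minimal sketch of the same read order using protobuf's CodedInputStream (the function name and capacity check are illustrative):

#include <cstdint>
#include <google/protobuf/io/coded_stream.h>

bool ReadLengthDelimited(google::protobuf::io::CodedInputStream* input,
                         void* dst, int capacity) {
  uint32_t length = 0;
  if (!input->ReadVarint32(&length)) return false;  // size prefix first
  if (static_cast<int>(length) > capacity) return false;
  return input->ReadRaw(dst, length);  // then exactly `length` raw bytes
}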
@@ -37,7 +37,8 @@ class GenNCCLIdOp : public framework::OperatorBase {
   void RunImpl(const framework::Scope& scope,
                const platform::Place& dev_place) const override {
     platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
-    auto& dev_ctx = *pool.Get(dev_place);
+    // put nccl id in CPUPlace
+    auto& dev_ctx = *pool.Get(platform::CPUPlace());
     int trainer_id = Attr<int>("trainer_id");
     framework::Scope& local_scope = scope.NewScope();
......
@@ -60,9 +61,11 @@ class GenNCCLIdOp : public framework::OperatorBase {
         Attr<std::vector<std::string>>("endpoint_list");
     detail::RPCClient client;
     for (auto& ep : endpoint_list) {
+      VLOG(3) << "sending nccl id to " << ep;
       client.AsyncSendVariable(ep, dev_ctx, *scope, "NCCLID");
     }
     client.Wait();
+    VLOG(3) << "sending completed...";
   }
 
   void GetIdByServer(framework::Scope* scope,
......
@@ -78,9 +81,14 @@ class GenNCCLIdOp : public framework::OperatorBase {
     server_thread_.reset(new std::thread(std::bind(
         &detail::AsyncGRPCServer::RunSyncUpdate, rpc_service_.get())));
 
+    rpc_service_->SetCond(0);
+    VLOG(3) << "start getting nccl id from trainer 0...";
     auto recv = rpc_service_->Get();
-    rpc_service_->ShutDown();
+    VLOG(3) << "got nccl id and stop server...";
+    // rpc_service_->SetCond(1);
+    // rpc_service_->ShutDown();
+    rpc_service_->Push(LISTEN_TERMINATE_MESSAGE);
+    VLOG(3) << "rpc server stopped";
     // TODO(wuyi): reinit nccl communicators
   }
......
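On the server-side flow in the last hunk: SetCond(0) readies the server to accept the send, Get() blocks until trainer 0's NCCLID arrives, and pushing LISTEN_TERMINATE_MESSAGE lets the handler loop drain and exit on its own instead of calling ShutDown() while a request may still be in flight. A minimal sketch of that sentinel-terminated blocking-queue pattern, assuming a queue shaped like the server's internal one (SimpleBlockingQueue is an illustrative stand-in):

#include <condition_variable>
#include <deque>
#include <mutex>
#include <string>

template <typename T>
class SimpleBlockingQueue {
 public:
  void Push(T v) {
    {
      std::lock_guard<std::mutex> lock(mu_);
      queue_.push_back(std::move(v));
    }
    cv_.notify_one();
  }
  T Pop() {  // blocks until an element is available
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return !queue_.empty(); });
    T v = std::move(queue_.front());
    queue_.pop_front();
    return v;
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::deque<T> queue_;
};

// Handler-loop shape: exit cleanly when the sentinel arrives.
// while (true) {
//   std::string msg = queue.Pop();
//   if (msg == LISTEN_TERMINATE_MESSAGE) break;  // graceful stop
//   /* handle msg */
// }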