提交 82c61dbd 编写于 作者: T typhoonzero

fix testing

上级 0598a4b3
......@@ -52,7 +52,7 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
// stub context
SendProcessor* s = new SendProcessor(ch);
s->Prepare(var_h, time_out);
s->response_call_back_ = NULL;
s->response_call_back_ = nullptr;
auto call = s->stub_g_.PrepareUnaryCall(
s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req, &cq_);
......
......@@ -57,7 +57,9 @@ void ProcGetResponse(const VarHandle& var_h, const grpc::ByteBuffer& msg);
class BaseProcessor {
public:
explicit BaseProcessor(std::shared_ptr<grpc::Channel> ch) { context_ = NULL; }
explicit BaseProcessor(std::shared_ptr<grpc::Channel> ch) {
context_ = nullptr;
}
virtual ~BaseProcessor() {}
......@@ -105,7 +107,7 @@ class SendProcessor : public BaseProcessor {
::grpc::GenericStub stub_g_;
::grpc::ByteBuffer reply_;
RequestSendCallBack response_call_back_ = NULL;
RequestSendCallBack response_call_back_ = nullptr;
};
typedef std::function<void(const VarHandle&, const ::grpc::ByteBuffer&)>
......
......@@ -261,8 +261,8 @@ void AsyncGRPCServer::ShutdownQueue() {
// This URL explains why shutdown is complicate:
void AsyncGRPCServer::ShutDown() {
is_shut_down_ = true;
ShutdownQueue();
server_->Shutdown();
ShutdownQueue();
}
void AsyncGRPCServer::TryToRegisterNewSendOne() {
......
......@@ -47,6 +47,8 @@ class AsyncGRPCServer final {
explicit AsyncGRPCServer(const std::string &address, bool sync_mode)
: address_(address), sync_mode_(sync_mode) {}
~AsyncGRPCServer() {}
void RunSyncUpdate();
// functions to sync server barrier status.
......
......@@ -53,14 +53,15 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
e.WriteUint64(VarMsg::kTypeFieldNumber, 1);
} else if (var->IsType<ncclUniqueId>()) {
// NOTE: sendrecv only support RAW type for NCCL_ID
VLOG(3) << "serilizing: setting var type nccl id";
e.WriteUint64(VarMsg::kTypeFieldNumber, 2);
}
if (!out_name.empty()) {
e.WriteString(VarMsg::kOutVarnameFieldNumber, out_name);
}
switch (framework::ToVarType(var->Type())) {
case framework::proto::VarType_Type_LOD_TENSOR: {
if (var->IsType<framework::LoDTensor>()) {
// ===========================Tensor==================================
auto tensor = var->Get<framework::LoDTensor>();
e.WriteUint64(VarMsg::kDataTypeFieldNumber,
framework::ToDataType(tensor.type()));
......@@ -86,8 +87,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE(platform::is_gpu_place(tensor.place()));
platform::CPUPlace cpu;
auto& gpu_dev_ctx =
static_cast<const platform::CUDADeviceContext&>(ctx);
auto& gpu_dev_ctx = static_cast<const platform::CUDADeviceContext&>(ctx);
auto copy_size = tensor.numel() * framework::SizeOfType(tensor.type());
payload = memory::Alloc(cpu, copy_size);
......@@ -107,8 +107,9 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
}
payload_size = tensor.numel() * framework::SizeOfType(tensor.type());
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
} break;
case framework::proto::VarType_Type_SELECTED_ROWS: {
} else if (var->IsType<framework::SelectedRows>()) {
// ===========================SELECTED
// ROWS==================================
// TODO(typhoonzero): selectedrows implement should not use unique_ptr
auto* slr = var->GetMutable<framework::SelectedRows>();
e.WriteUint64(VarMsg::kDataTypeFieldNumber,
......@@ -122,10 +123,8 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA
platform::CPUPlace cpu;
auto& gpu_dev_ctx =
static_cast<const platform::CUDADeviceContext&>(ctx);
auto copy_size =
tensor->numel() * framework::SizeOfType(tensor->type());
auto& gpu_dev_ctx = static_cast<const platform::CUDADeviceContext&>(ctx);
auto copy_size = tensor->numel() * framework::SizeOfType(tensor->type());
payload = memory::Alloc(cpu, copy_size);
memory::Copy(cpu, payload,
boost::get<platform::CUDAPlace>(tensor->place()),
......@@ -142,20 +141,18 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
}
payload_size = tensor->numel() * framework::SizeOfType(tensor->type());
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
} break;
case framework::proto::VarType_Type_RAW: {
} else if (var->IsType<ncclUniqueId>()) {
// ===========================NCCL ID==================================
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber,
NCCL_UNIQUE_ID_BYTES);
ncclUniqueId* uid = var->GetMutable<ncclUniqueId>();
e.WriteRawBytes(std::string(uid->internal, NCCL_UNIQUE_ID_BYTES));
} break;
default:
} else {
PADDLE_THROW("Serialize does not support type: %s",
typeid(var->Type()).name());
break;
}
if (framework::ToVarType(var->Type()) == framework::proto::VarType_Type_RAW) {
if (var->IsType<ncclUniqueId>()) {
// for serialize NCCL_ID
::grpc::Slice slices(e.size());
memcpy(const_cast<uint8_t*>(slices.begin()), e.data(), e.size());
......
......@@ -371,19 +371,26 @@ int VariableResponse::Parse(Source* source) {
meta_.type() == sendrecv::NCCL_ID) &&
meta_.varname() != "",
"meta info should be got first!");
int length = 0;
if (wt != WIRETYPE_LENGTH_DELIMITED ||
!ReadVarintSizeAsInt(&input, &length)) {
return tag;
}
if (meta_.type() == sendrecv::NCCL_ID) {
VLOG(3) << "parse nccl id request";
auto* var = scope_->FindVar(meta_.varname());
if (var != nullptr) {
VLOG(3) << "parse nccl id: length " << length;
ncclUniqueId* id = var->GetMutable<ncclUniqueId>();
memcpy(id->internal, meta_.serialized().c_str(),
meta_.serialized().size());
if (!ReadRaw(&input, *dev_ctx_, platform::CPUPlace(), id->internal,
length)) {
return tag;
}
// memcpy(id->internal, meta_.serialized().c_str(),
// meta_.serialized().size());
}
int length = 0;
if (wt != WIRETYPE_LENGTH_DELIMITED ||
!ReadVarintSizeAsInt(&input, &length)) {
return tag;
break;
}
framework::DDim dims = GetDims(meta_.dims());
......
......@@ -37,7 +37,8 @@ class GenNCCLIdOp : public framework::OperatorBase {
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& dev_ctx = *pool.Get(dev_place);
// put nccl id in CPUPlace
auto& dev_ctx = *pool.Get(platform::CPUPlace());
int trainer_id = Attr<int>("trainer_id");
framework::Scope& local_scope = scope.NewScope();
......@@ -60,9 +61,11 @@ class GenNCCLIdOp : public framework::OperatorBase {
Attr<std::vector<std::string>>("endpoint_list");
detail::RPCClient client;
for (auto& ep : endpoint_list) {
VLOG(3) << "sending nccl id to " << ep;
client.AsyncSendVariable(ep, dev_ctx, *scope, "NCCLID");
}
client.Wait();
VLOG(3) << "sending completed...";
}
void GetIdByServer(framework::Scope* scope,
......@@ -78,9 +81,14 @@ class GenNCCLIdOp : public framework::OperatorBase {
server_thread_.reset(new std::thread(std::bind(
&detail::AsyncGRPCServer::RunSyncUpdate, rpc_service_.get())));
rpc_service_->SetCond(0);
VLOG(3) << "start getting nccl id from trainer 0...";
auto recv = rpc_service_->Get();
rpc_service_->ShutDown();
VLOG(3) << "got nccl id and stop server...";
// rpc_service_->SetCond(1);
// rpc_service_->ShutDown();
rpc_service->Push(LISTEN_TERMINATE_MESSAGE);
VLOG(3) << "rpc server stopped";
// TODO(wuyi): reinit nccl communicators
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册