From 27cb52a4cda29184851a53d63ea45d436c632e59 Mon Sep 17 00:00:00 2001
From: ziyoujiyi <73728031+ziyoujiyi@users.noreply.github.com>
Date: Tue, 26 Apr 2022 15:24:45 +0800
Subject: [PATCH] fix heter_client&heter_server (#42188)

* back fl
* delete ssl cert
* make warning
* unittest paral degree
* solve unittest
* heter & multi cloud comm ready
* arm_brpc compile
* only output is ok
* base is ok
* add switch server bin
* adapt brpc ssl
* fix heter_server & heter_client
* int->int64_t
---
 .../collective/ProcessGroupHeter.cc           |   4 +-
 .../distributed/ps/service/brpc_ps_client.cc  |   2 -
 .../distributed/ps/service/heter_client.cc    |  15 +-
 .../distributed/ps/service/heter_client.h     |   3 +-
 .../distributed/ps/service/heter_server.cc    |  83 +++----
 .../distributed/ps/service/heter_server.h     |  57 +++--
 .../distributed/ps/service/sendrecv.proto     |   2 +-
 .../pscore/heter_cloud_comm_cpu_test.cc       | 234 ++++++++++--------
 8 files changed, 219 insertions(+), 181 deletions(-)
 mode change 100644 => 100755 paddle/fluid/distributed/collective/ProcessGroupHeter.cc
 mode change 100644 => 100755 paddle/fluid/distributed/ps/service/brpc_ps_client.cc
 mode change 100644 => 100755 paddle/fluid/distributed/ps/service/heter_client.cc
 mode change 100755 => 100644 paddle/fluid/distributed/ps/service/heter_client.h

diff --git a/paddle/fluid/distributed/collective/ProcessGroupHeter.cc b/paddle/fluid/distributed/collective/ProcessGroupHeter.cc
old mode 100644
new mode 100755
index ef57bb5ba2..ba57342081
--- a/paddle/fluid/distributed/collective/ProcessGroupHeter.cc
+++ b/paddle/fluid/distributed/collective/ProcessGroupHeter.cc
@@ -116,7 +116,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::AllReduce(
     HeterClient* client_ =
         HeterClient::GetInstance({switch_endpoint_}, {}, 0).get();
     auto dense_cpu_tensor = cpu_tensors[0];
-    std::vector<int> send_size;
+    std::vector<int64_t> send_size;
     send_size.push_back(dense_cpu_tensor.numel());
     int ret = client_->Send(
         gid_, {dense_cpu_tensor.name()}, send_size, dense_cpu_tensor.data(),
@@ -212,7 +212,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::Broadcast(
         HeterClient::GetInstance({switch_endpoint_}, {}, 0).get();
     auto dense_cpu_tensor = cpu_tensors[0];
     if (gloo_rank_ == 0) {
-      std::vector<int> send_size;
+      std::vector<int64_t> send_size;
       send_size.push_back(dense_cpu_tensor.numel());
       int ret = client_->Send(
           gid_, {dense_cpu_tensor.name()}, send_size,
diff --git a/paddle/fluid/distributed/ps/service/brpc_ps_client.cc b/paddle/fluid/distributed/ps/service/brpc_ps_client.cc
old mode 100644
new mode 100755
index 921a110984..78673184eb
--- a/paddle/fluid/distributed/ps/service/brpc_ps_client.cc
+++ b/paddle/fluid/distributed/ps/service/brpc_ps_client.cc
@@ -55,8 +55,6 @@ DEFINE_int32(pserver_sparse_merge_thread, 1, "pserver sparse merge thread num");
 DEFINE_int32(pserver_sparse_table_shard_num, 1000,
              "sparse table shard for save & load");
 
-DEFINE_int32(heter_world_size, 100, "group size");  // configurable
-
 namespace paddle {
 namespace framework {
 class Scope;
diff --git a/paddle/fluid/distributed/ps/service/heter_client.cc b/paddle/fluid/distributed/ps/service/heter_client.cc
old mode 100644
new mode 100755
index 16c1ff764d..8085ef68e1
--- a/paddle/fluid/distributed/ps/service/heter_client.cc
+++ b/paddle/fluid/distributed/ps/service/heter_client.cc
@@ -17,9 +17,11 @@
 #include "paddle/fluid/framework/convert_utils.h"
 #include "paddle/fluid/platform/profiler.h"
 
+DEFINE_int32(heter_world_size, 100, "group size");  // group max size
+DEFINE_int32(switch_send_recv_timeout_s, 600, "switch_send_recv_timeout_s");
+
 namespace paddle {
 namespace distributed {
-
 std::shared_ptr<HeterClient> HeterClient::s_instance_ = nullptr;
 
 int GetMicroId(const platform::DeviceContext& ctx,
@@ -222,6 +224,7 @@ int HeterClient::Send(const platform::DeviceContext& ctx,
   distributed::MultiVarMsg request;
   // 1. set req message_name(string)
   request.set_message_name(message_name);
+  request.set_group_id(0);
 
   // 2. set req send_var_names()
   for (auto& send_var_name : send_var_names) {
@@ -263,7 +266,7 @@ int HeterClient::Send(const platform::DeviceContext& ctx,
 }
 
 int HeterClient::Send(int group_id, const std::vector<std::string>& var_names,
-                      const std::vector<int>& vars_len, void* data_ptr,
+                      const std::vector<int64_t>& vars_size, void* data_ptr,
                       int64_t data_size) {
   OnHeterRpcDone* closure = new OnHeterRpcDone([](void* done) {
     auto* closure = reinterpret_cast<OnHeterRpcDone*>(done);
@@ -282,7 +285,7 @@ int HeterClient::Send(int group_id, const std::vector<std::string>& var_names,
   for (auto& send_var_name : var_names) {
     request.add_send_var_names(send_var_name);
   }
-  for (auto var_len : vars_len) {
+  for (auto var_len : vars_size) {
     request.add_vars_len(var_len);
   }
   auto& request_buffer = closure->cntl.request_attachment();
@@ -301,6 +304,7 @@ int HeterClient::Send(int group_id, const std::vector<std::string>& var_names,
   ::paddle::distributed::PsService_Stub stub(channel);
   stub.SendToSwitch(&closure->cntl, &request, &closure->ps_response, closure);
   fut.wait();
+  delete closure;
   return 0;
 }
 
@@ -325,6 +329,7 @@ int HeterClient::Recv(const platform::DeviceContext& ctx,
   distributed::MultiVarMsg request;
   // 1. set req message_name(string)
   request.set_message_name(message_name);
+  request.set_group_id(0);
 
   // 2. set req recv_var_names()
   for (auto& recv_var_name : recv_var_names) {
@@ -396,8 +401,8 @@ int HeterClient::Recv(int group_id, const std::vector<std::string>& var_names,
   // save in worker
   auto& res_io_buffer = closure->cntl.response_attachment();
   butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
-  io_buffer_itr.copy_and_forward(reinterpret_cast<void*>(data_ptr),
-                                 data_size * sizeof(float));
+  io_buffer_itr.copy_and_forward(reinterpret_cast<void*>(data_ptr), data_size);
+  delete closure;
   VLOG(4) << "Recv done";
   return 0;
 }
diff --git a/paddle/fluid/distributed/ps/service/heter_client.h b/paddle/fluid/distributed/ps/service/heter_client.h
old mode 100755
new mode 100644
index d1e0f21c7d..b9d6561339
--- a/paddle/fluid/distributed/ps/service/heter_client.h
+++ b/paddle/fluid/distributed/ps/service/heter_client.h
@@ -138,7 +138,8 @@ class HeterClient {
            const std::string& mode = "forward");
 
   int Send(int group_id, const std::vector<std::string>& var_names,
-           const std::vector<int>& vars_len, void* data_ptr, int64_t data_size);
+           const std::vector<int64_t>& vars_len, void* data_ptr,
+           int64_t data_size);
 
   int Send(const platform::DeviceContext& ctx, const framework::Scope& scope,
            const std::string& message_name,
diff --git a/paddle/fluid/distributed/ps/service/heter_server.cc b/paddle/fluid/distributed/ps/service/heter_server.cc
index 292b12611c..0753a6799c 100755
--- a/paddle/fluid/distributed/ps/service/heter_server.cc
+++ b/paddle/fluid/distributed/ps/service/heter_server.cc
@@ -20,8 +20,8 @@ namespace paddle {
 namespace distributed {
 // DEFINE_string(cert_path, "./cert.pem", "cert.pem path");
 // DEFINE_string(key_path, "./key.pem", "key.pem path");
-
 std::shared_ptr<HeterServer> HeterServer::s_instance_ = nullptr;
+std::mutex HeterServer::mtx_;
 
 void HeterServer::RegisterServiceHandler(std::string message_name,
                                          HeterServiceHandler func) {
@@ -130,21 +130,15 @@ int SendAndRecvVariableHandler::SaveInSwitchWithShard(
   butil::IOBufBytesIterator io_buffer_itr(request_io_buffer);
   for (int idx = 0; idx < request->send_var_names_size(); idx++) {
     const auto& var_name = request->send_var_names(idx);
-    const auto& var_len = request->vars_len(idx);
-    auto itr = local_shard.find(var_name);
-    if (itr != local_shard.end()) {
-      LOG(INFO) << "var: " << var_name << "has not been consumed!"
- << "check again"; - WaitForVarsConsumed(group_id, var_name); - } + const auto& var_size = request->vars_len(idx); + WaitForVarsConsumed(group_id, var_name); auto& value = local_shard[var_name]; - value.resize(var_len); + value.resize(var_size); io_buffer_itr.copy_and_forward(reinterpret_cast(value.data()), - var_len * sizeof(float)); - VLOG(4) << "saved data in shards: "; - for (uint32_t i = 0; i < local_shard[var_name].size(); i++) { - VLOG(4) << *(local_shard[var_name].data() + i); - } + var_size); + std::unique_lock lk(scope_mutex_); + vars_ready_flag[group_id][var_name] = 1; + VLOG(4) << "saved var_name: " << var_name << "is saved ready!"; } VLOG(4) << "SaveInSwitchWithShard success"; return 0; @@ -164,20 +158,17 @@ int SendAndRecvVariableHandler::QueryInSwitchWithShard( } auto msg_name = request->message_name(); response->set_message_name(msg_name); - for (auto& req_var_name : req_var_names) { VLOG(4) << "req var name: " << req_var_name; response->add_send_var_names(req_var_name); + WaitForVarsProduced(group_id, req_var_name); auto itr = local_shard.find(req_var_name); - if (itr == local_shard.end()) { - LOG(INFO) << "var: " << req_var_name << " not found in shards"; - WaitForVarsProduced(group_id, req_var_name); - } - LOG(INFO) << "var: " << req_var_name << " found in shards"; - itr = local_shard.find(req_var_name); auto& value = itr.value(); - response_io_buffer.append(value.data(), value.size() * sizeof(float)); - value.resize(0); // 标记位 + response_io_buffer.append(value.data(), value.size()); + value.resize(0); // 清空内存 + std::unique_lock lk(scope_mutex_); + vars_ready_flag[group_id][req_var_name] = 0; + VLOG(4) << "query var_name: " << req_var_name << "is consumed ready!"; } VLOG(4) << "heter server QueryInSwitchWithShard done"; return 0; @@ -192,37 +183,31 @@ int SendAndRecvVariableHandler::SaveInSwitchWithScope( auto& cpu_dev_ctx = *pool.Get(cpu_place); auto message_name = request->message_name(); VLOG(4) << "message_name in heter server: " << message_name; + + auto send_var_nums = request->send_var_names_size(); + std::vector send_var_names(send_var_nums); + for (int idx = 0; idx < send_var_nums; idx++) { + send_var_names[idx] = request->var_messages(idx).varname(); + } std::unique_lock lk(scope_mutex_); auto local_scope = local_scope_ptr.get(); if (!local_scope) { LOG(ERROR) << "local_scope_ptr is null in SaveInSwitchWithScope"; } - for (int idx = 0; idx < request->send_var_names_size(); idx++) { - const auto& msg = request->var_messages(idx); - std::string var_name = msg.varname(); + for (auto var_name : send_var_names) { auto* var_exist_ptr = local_scope->FindVar(var_name); if (!var_exist_ptr) { VLOG(4) << "not find var: " << var_name << " in local_scope"; } - vars_table[var_name] += 1; - VLOG(4) << "saved var_name: " << var_name - << ", cnt = " << vars_table[var_name]; + WaitForVarsConsumed(0, var_name); } auto& request_io_buffer = cntl->request_attachment(); distributed::DeserializeFromMultiVarMsgAndIOBuf(*request, &request_io_buffer, cpu_dev_ctx, local_scope); lk.unlock(); - while (true) { - int ret = 0; - for (int idx = 0; idx < request->send_var_names_size(); idx++) { - ret |= vars_table[request->var_messages(idx).varname()]; - } - if (!ret) { - VLOG(4) << "all saved vars consumed"; - break; - } - VLOG(4) << "waiting consume result......"; - sleep(1); + for (auto var_name : send_var_names) { + std::unique_lock lk(scope_mutex_); + vars_ready_flag[0][var_name] = 1; } VLOG(4) << "SaveInSwitchWithScope success"; return 0; @@ -258,19 +243,14 @@ int 
   // 3. fill var_messages(VarMessage)
   for (auto& req_var_name : req_var_names) {
-    LOG(INFO) << "query var_name: " << req_var_name;
+    WaitForVarsProduced(0, req_var_name);
     auto* send_var_msg = response->add_var_messages();
     send_var_msg->set_varname(req_var_name);
 
     framework::Variable* var_ptr;
-    while (true) {
-      var_ptr = local_scope->FindVar(req_var_name);
-      if (!var_ptr) {
-        LOG(INFO) << "local_scope not find var: " << req_var_name;
-      } else {
-        break;
-      }
-      sleep(1);
+    var_ptr = local_scope->FindVar(req_var_name);
+    if (!var_ptr) {
+      LOG(INFO) << "local_scope not find var: " << req_var_name;
     }
     butil::IOBuf temp_iobuf;
     if (var_ptr->IsType<framework::LoDTensor>()) {
@@ -282,10 +262,7 @@ int SendAndRecvVariableHandler::QueryInSwitchWithScope(
   }
   for (auto& req_var_name : req_var_names) {
     std::unique_lock<std::mutex> lk(scope_mutex_);
-    vars_table[req_var_name] -= 1;
-    VLOG(4) << "remained var: " << req_var_name
-            << ", cnt = " << vars_table[req_var_name];
-    lk.unlock();
+    vars_ready_flag[0][req_var_name] = 0;
   }
   VLOG(4) << "heter server QueryInSwitchWithScope done";
   return 0;
 }
diff --git a/paddle/fluid/distributed/ps/service/heter_server.h b/paddle/fluid/distributed/ps/service/heter_server.h
index 624e76112c..a65470cdba 100644
--- a/paddle/fluid/distributed/ps/service/heter_server.h
+++ b/paddle/fluid/distributed/ps/service/heter_server.h
@@ -56,9 +56,10 @@ class Scope;
 DECLARE_double(eager_delete_tensor_gb);
 DECLARE_int32(pserver_timeout_ms);
 DECLARE_int32(heter_world_size);
+DECLARE_int32(switch_send_recv_timeout_s);
+
 namespace paddle {
 namespace distributed {
-
 using MultiVarMsg = MultiVariableMessage;
 using VarMsg = VariableMessage;
 
@@ -95,6 +96,19 @@ using SharedTaskQueue = std::shared_ptr<
     std::unordered_map<int, std::shared_ptr<::paddle::framework::BlockingQueue<
                                 std::pair<std::string, int>>>>>;
 
+class ValueInSwitch {
+ public:
+  ValueInSwitch() {}
+  ~ValueInSwitch() {}
+  char* data() { return _data.data(); }
+  size_t size() { return _data.size(); }
+  void resize(size_t size) { _data.resize(size); }
+  void shrink_to_fit() { _data.shrink_to_fit(); }
+
+ private:
+  std::vector<char> _data;
+};
+
 class SendAndRecvVariableHandler final : public ServiceHandlerBase {
  public:
   SendAndRecvVariableHandler() {
@@ -130,22 +144,31 @@ class SendAndRecvVariableHandler final : public ServiceHandlerBase {
                              brpc::Controller* cntl);
 
   void WaitForVarsConsumed(int32_t group_id, const std::string& var_name) {
-    auto& local_shard = _local_shards[group_id];
-    while (local_shard.find(var_name) != local_shard.end()) {
-      if (local_shard[var_name].size() == 0) {
+    timeline_.Start();
+    while (true) {
+      if (vars_ready_flag[group_id][var_name] == 0) {
+        break;
+      }
+      timeline_.Pause();
+      if (timeline_.ElapsedSec() > FLAGS_switch_send_recv_timeout_s) {
+        VLOG(0) << "vars not consumed within "
+                << FLAGS_switch_send_recv_timeout_s << " seconds";
         break;
       }
-      VLOG(4) << "waiting consume result......";
-      sleep(1);
     }
     return;
   }
 
   void WaitForVarsProduced(int32_t group_id, const std::string& var_name) {
-    auto& local_shard = _local_shards[group_id];
-    while (local_shard.find(var_name) == local_shard.end()) {
-      VLOG(4) << "waiting produce result......";
-      sleep(1);
+    timeline_.Start();
+    while (true) {
+      if (vars_ready_flag[group_id][var_name] == 1) {
+        break;
+      }
+      timeline_.Pause();
+      if (timeline_.ElapsedSec() > FLAGS_switch_send_recv_timeout_s) {
+        VLOG(0) << "vars not produced within "
+                << FLAGS_switch_send_recv_timeout_s << " seconds";
+        break;
+      }
     }
     return;
   }
@@ -245,10 +268,12 @@ class SendAndRecvVariableHandler final : public ServiceHandlerBase {
   }
 
  public:
-  using shard_type = SparseTableShard<std::string, FixedFeatureValue>;
+  using shard_type = SparseTableShard<std::string, ValueInSwitch>;
   std::shared_ptr<paddle::framework::Scope> local_scope_ptr;  // for switch
-  std::unordered_map<std::string, int32_t> vars_table;
+  std::unordered_map<int32_t, std::unordered_map<std::string, int32_t>>
+      vars_ready_flag;
   std::unique_ptr<shard_type[]> _local_shards;
+  platform::Timer timeline_;
 
  private:
   // share with HeterPipelineTrainer
@@ -576,8 +601,11 @@ class HeterServer {
 
   // HeterWrapper singleton
   static std::shared_ptr<HeterServer> GetInstance() {
-    if (NULL == s_instance_) {
-      s_instance_.reset(new HeterServer());
+    if (s_instance_ == nullptr) {
+      std::unique_lock<std::mutex> lock(mtx_);
+      if (NULL == s_instance_) {
+        s_instance_.reset(new HeterServer());
+      }
     }
     return s_instance_;
   }
@@ -587,6 +615,7 @@ class HeterServer {
  private:
   static std::shared_ptr<HeterServer> s_instance_;
   mutable std::mutex mutex_;
+  static std::mutex mtx_;
   std::condition_variable cv_;
   std::condition_variable condition_ready_;
   bool stoped_ = true;
diff --git a/paddle/fluid/distributed/ps/service/sendrecv.proto b/paddle/fluid/distributed/ps/service/sendrecv.proto
index 46dcc2058f..ae6364dd83 100755
--- a/paddle/fluid/distributed/ps/service/sendrecv.proto
+++ b/paddle/fluid/distributed/ps/service/sendrecv.proto
@@ -126,7 +126,7 @@ message MultiVariableMessage {
   repeated string recv_var_names = 3;
   repeated VariableMessage var_messages = 4;
   optional bytes data = 5;
-  repeated int32 vars_len = 6;
+  repeated int64 vars_len = 6;
   optional int32 group_id = 7;
 };
diff --git a/paddle/fluid/operators/pscore/heter_cloud_comm_cpu_test.cc b/paddle/fluid/operators/pscore/heter_cloud_comm_cpu_test.cc
index 2340f443c4..cf6369eecd 100644
--- a/paddle/fluid/operators/pscore/heter_cloud_comm_cpu_test.cc
+++ b/paddle/fluid/operators/pscore/heter_cloud_comm_cpu_test.cc
@@ -15,6 +15,9 @@ limitations under the License. */
 
 #if defined PADDLE_WITH_PSCORE
 #include <gtest/gtest.h>
+#include <cstdlib>
+#include <fstream>
+#include <iostream>
 #include <string>
 #include <thread>
 #include <unordered_map>
@@ -69,44 +72,6 @@ void StartSwitchServer(
     std::vector<std::string> peer_endpoints) {
   switch_server_ptr->SetPeerEndPoints(peer_endpoints);
   switch_server_ptr->SetEndPoint(endpoints[0]);
-  /*
-  std::shared_ptr<distributed::SendAndRecvVariableHandler> b_req_handler;
-  b_req_handler.reset(new distributed::SendAndRecvVariableHandler());
-  switch_server_ptr->SetServiceHandler(b_req_handler);
-
-  switch_server_ptr->SetLocalScope();
-
-  switch_server_ptr->RegisterServiceHandler(
-      std::to_string(distributed::PS_SAVE_WITH_SCOPE),
-      [&](const MultiVarMsg* request, MultiVarMsg* response,
-          brpc::Controller* cntl) -> int {
-        return b_req_handler->SaveInSwitchWithScope(request, response, cntl);
-      });
-
-  switch_server_ptr->RegisterServiceHandler(
-      std::to_string(distributed::PS_SAVE_WITH_SHARD),
-      [&](const MultiVarMsg* request, MultiVarMsg* response,
-          brpc::Controller* cntl) -> int {
-        return b_req_handler->SaveInSwitchWithShard(request, response, cntl);
-      });
-
-  switch_server_ptr->RegisterServiceHandler(
-      std::to_string(distributed::PS_QUERY_WITH_SCOPE),
-      [&](const MultiVarMsg* request, MultiVarMsg* response,
-          brpc::Controller* cntl) -> int {
-        return b_req_handler->QueryInSwitchWithScope(request, response, cntl);
-      });
-
-  switch_server_ptr->RegisterServiceHandler(
-      std::to_string(distributed::PS_QUERY_WITH_SHARD),
-      [&](const MultiVarMsg* request, MultiVarMsg* response,
-          brpc::Controller* cntl) -> int {
-        return b_req_handler->QueryInSwitchWithShard(request, response, cntl);
-      });
-  */
   switch_server_ptr->StartHeterService(false);
 }
 
@@ -119,6 +84,129 @@ void StartSwitchInterServer(
     std::shared_ptr<distributed::HeterServer> switch_server_ptr,
     std::vector<std::string> endpoints,
     std::vector<std::string> peer_endpoints) {
   switch_server_ptr->StartHeterInterService(false);
 }
 
+void TestShardSendRecv(
+    std::shared_ptr<distributed::HeterClient> heter_client_ptr_) {
+  auto send_async = [&]() -> void {
+    std::vector<int64_t> vars_len{2 * sizeof(float),
+                                  4 * sizeof(float)};  // lengths in bytes
+    std::vector<float> values{1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
+    int64_t data_size = 6 * sizeof(float);
+    std::vector<std::string> send_var_names{"w", "x"};
+    int group_id = 0;
+    int ret = heter_client_ptr_->Send(group_id, send_var_names, vars_len,
+                                      values.data(), data_size);
+    if (!ret) {
+      LOG(INFO) << ">>>> TestShardSendRecv: worker send success";
+    }
+  };
+  std::thread t(send_async);
+
+  int group_id = 0;
+  std::vector<std::string> recv_var_names{"w", "x"};
+  int data_size = 6 * sizeof(float);
+  float* value_ptr = new float[6];
+  int ret =
+      heter_client_ptr_->Recv(group_id, recv_var_names, value_ptr, data_size);
+  if (!ret) {
+    VLOG(4) << "queried data is: ";
+    for (int i = 0; i < 6; i++) {
+      VLOG(4) << value_ptr[i] << " ";
+    }
+    LOG(INFO) << "<<<< TestShardSendRecv: worker recv success";
+  }
+  delete[] value_ptr;
+
+  t.join();
+}
+
+void PressTestSendRecv(
+    std::shared_ptr<distributed::HeterClient> heter_client_ptr_) {
+  // long l = 0, m = 0;
+  std::ifstream file("/send_20_34", std::ios::in | std::ios::binary);
+  // l = file.tellg();
+  // file.seekg(0, std::ios::end);
+  // m = file.tellg();
+  // file.close();
+  // VLOG(0) << "size of file " << "20_34" << " is " << (m - l) << " bytes.\n";
+  int64_t vars_len = 2359296 * sizeof(float);  // in bytes
+  int64_t data_size = vars_len;
+  VLOG(0) << "data size in bytes: " << data_size;
+  float* data_ptr = new float[2359296];
+  file.read((char*)data_ptr, 9437184);
+  VLOG(0) << "send data is: " << data_ptr[0] << ", " << data_ptr[1];
+  std::vector<std::string> var_names{"34"};
+  int loopCnt = 600;
+  auto send_async = [&]() -> void {
+    int i = 0;
+    while (i++ < loopCnt) {
+      heter_client_ptr_->Send(20, var_names, {vars_len}, data_ptr, data_size);
+    }
+  };
+  std::thread t(send_async);
+  float* values = new float[2359296];
+  int i = 0;
+  while (i++ < loopCnt) {
+    int ret = heter_client_ptr_->Recv(20, var_names, values, data_size);
+    if (!ret) {
+      VLOG(0) << "diff: " << abs(values[0] - 0.159544) << ", "
+              << abs(values[1] + 2.3484);
+      VLOG(0) << "loop id: " << i;
+      for (int j = 0; j < 2359296; j++) {
+        if (abs(values[j] - data_ptr[j]) > 4e-6) {
+          VLOG(0) << "error data idx: " << j;
+          VLOG(0) << "diff detail: " << values[j] << ", " << data_ptr[j];
+          LOG(INFO) << ">>>> worker recv ERROR";
+          break;
+        }
+      }
+      for (uint32_t k = 0; k < 2359296; k++) {
+        values[k] = -1;  // reset
+      }
+    }
+  }
+
+  std::ofstream recv("/recv_20_34", std::ios::out | std::ios::binary);
+  recv.write((char*)values, data_size);
+  recv.close();
+  delete[] values;
+  delete[] data_ptr;
+  t.join();
+}
+
+void TestScopeSendRecv(
+    std::shared_ptr<distributed::HeterClient> heter_client_ptr_) {
+  platform::CPUPlace place;
+  platform::CPUDeviceContext ctx(place);
+  framework::Executor exe(place);
+  std::shared_ptr<framework::Scope> send_scope_ptr =
+      std::make_shared<framework::Scope>();
+  int64_t rows_numel = 10;
+  InitTensorsOnClient(send_scope_ptr.get(), &place, rows_numel);
+  LOG(INFO) << "InitTensorsOnClient done";
+  auto send_async = [&]() -> void {
+    std::string message_name = std::to_string(distributed::PS_SAVE_WITH_SCOPE);
+    std::vector<std::string> send_var_names{"w", "x"};
+    int ret = heter_client_ptr_->Send(ctx, *send_scope_ptr, message_name,
+                                      send_var_names);
+    if (!ret) {
+      LOG(INFO) << ">>>> TestScopeSendRecv: worker send success";
+    }
+  };
+  std::thread t(send_async);
+
+  std::string message_name = std::to_string(distributed::PS_QUERY_WITH_SCOPE);
+  std::vector<std::string> recv_var_names{"w", "x"};
+  std::shared_ptr<framework::Scope> recv_scope_ptr =
+      std::make_shared<framework::Scope>();
+  int ret = heter_client_ptr_->Recv(ctx, *recv_scope_ptr, message_name,
+                                    recv_var_names);
+  if (!ret && recv_scope_ptr->FindVar("w") && recv_scope_ptr->FindVar("x")) {
+    LOG(INFO) << "<<<< TestScopeSendRecv: worker recv success";
+  } else {
+    LOG(INFO) << "<<<< TestScopeSendRecv: worker recv failed";
"<<<< TestScopeSendRecv: worker recv failed"; + } + t.join(); +} + TEST(HETERSENDANDRECV, CPU) { setenv("http_proxy", "", 1); setenv("https_proxy", "", 1); @@ -155,79 +243,19 @@ TEST(HETERSENDANDRECV, CPU) { switch_server_ptr_b->WaitServerReady(); // 获取 client 实例 + // 开启单测时,请重新设置 HeterClient 端的 recv_switch_channels_ std::shared_ptr heter_client_ptr_ = distributed::HeterClient::GetInstance( {switch_a_endpoint, switch_b_endpoint}, {}, 0); + framework::ProgramDesc program; platform::CPUPlace place; - platform::CPUDeviceContext ctx(place); framework::Executor exe(place); - - framework::ProgramDesc program; exe.Prepare(program, 0); // solve undefined symbol: tensor_table.cc - std::shared_ptr send_scope_ptr = - std::make_shared(); - int64_t rows_numel = 10; - InitTensorsOnClient(send_scope_ptr.get(), &place, rows_numel); - LOG(INFO) << "InitTensorsOnClient done"; - - auto send_async = [&]() -> void { - /* - //std::string message_name = - std::to_string(distributed::PS_SAVE_WITH_SCOPE); - std::string message_name = "send and save"; - std::vector send_var_names{"w", "x"}; - int ret = heter_client_ptr_->Send(ctx, *send_scope_ptr, message_name, - send_var_names); - if (!ret) { - LOG(ERROR) << ">>>> worker send success"; - } - */ - ///* - std::vector vars_len{2, 4}; - std::vector values{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}; - int64_t data_size = 6; - std::vector send_var_names{"w", "x"}; - int group_id = 0; - int ret = heter_client_ptr_->Send(group_id, send_var_names, vars_len, - values.data(), data_size); - if (!ret) { - LOG(INFO) << ">>>> worker send success"; - } - //*/ - }; - std::thread send_thread(send_async); - /* - std::string message_name = std::to_string(distributed::PS_QUERY_WITH_SCOPE); - std::vector recv_var_names{"w", "x"}; - std::shared_ptr recv_scope_ptr = - std::make_shared(); - int ret = heter_client_ptr_->Recv(ctx, *recv_scope_ptr, message_name, - recv_var_names); - if (!ret && recv_scope_ptr->FindVar("w") && recv_scope_ptr->FindVar("x")) { - LOG(INFO) << ">>>> worker recv success"; - } else { - LOG(INFO) << "worker recv failed"; - } - */ - ///* - int group_id = 0; - std::vector recv_var_names{"w", "x"}; - std::vector values; - int data_size = 6; - values.resize(data_size); - int ret = heter_client_ptr_->Recv(group_id, recv_var_names, values.data(), - data_size); - if (!ret) { - VLOG(4) << "queried data is: "; - for (auto f : values) { - VLOG(4) << f << " "; - } - LOG(INFO) << ">>>> worker recv success"; - } - //*/ - send_thread.join(); + // TestScopeSendRecv(heter_client_ptr_); + TestShardSendRecv(heter_client_ptr_); + // PressTestSendRecv(heter_client_ptr_); switch_server_ptr_a->Stop(); LOG(INFO) << "switch server A stopped"; -- GitLab