未验证 提交 27cb52a4 编写于 作者: Z ziyoujiyi 提交者: GitHub

fix heter_client&heter_server (#42188)

* back fl

* delete ssl cert

* .

* make warning

* .

* unittest paral degree

* solve unittest

* heter & multi cloud commm ready

* .

* .

* arm_brpc compile

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* only output is ok

* base is ok

* .

* .

* .

* .

* .

* .

* .

* .

* add switch server bin

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* adapt brpc ssl

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* fix heter_server & heter_client

* .

* .

* int->int64_t

* .
上级 fccb0819
......@@ -116,7 +116,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::AllReduce(
HeterClient* client_ =
HeterClient::GetInstance({switch_endpoint_}, {}, 0).get();
auto dense_cpu_tensor = cpu_tensors[0];
std::vector<int> send_size;
std::vector<int64_t> send_size;
send_size.push_back(dense_cpu_tensor.numel());
int ret = client_->Send(
gid_, {dense_cpu_tensor.name()}, send_size, dense_cpu_tensor.data(),
......@@ -212,7 +212,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::Broadcast(
HeterClient::GetInstance({switch_endpoint_}, {}, 0).get();
auto dense_cpu_tensor = cpu_tensors[0];
if (gloo_rank_ == 0) {
std::vector<int> send_size;
std::vector<int64_t> send_size;
send_size.push_back(dense_cpu_tensor.numel());
int ret = client_->Send(
gid_, {dense_cpu_tensor.name()}, send_size,
......
......@@ -55,8 +55,6 @@ DEFINE_int32(pserver_sparse_merge_thread, 1, "pserver sparse merge thread num");
DEFINE_int32(pserver_sparse_table_shard_num, 1000,
"sparse table shard for save & load");
DEFINE_int32(heter_world_size, 100, "group size"); // 可配置
namespace paddle {
namespace framework {
class Scope;
......
......@@ -17,9 +17,11 @@
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/platform/profiler.h"
DEFINE_int32(heter_world_size, 100, "group size"); // group max size
DEFINE_int32(switch_send_recv_timeout_s, 600, "switch_send_recv_timeout_s");
namespace paddle {
namespace distributed {
std::shared_ptr<HeterClient> HeterClient::s_instance_ = nullptr;
int GetMicroId(const platform::DeviceContext& ctx,
......@@ -222,6 +224,7 @@ int HeterClient::Send(const platform::DeviceContext& ctx,
distributed::MultiVarMsg request;
// 1. set req message_name(string)
request.set_message_name(message_name);
request.set_group_id(0);
// 2. set req send_var_names(<string>)
for (auto& send_var_name : send_var_names) {
......@@ -263,7 +266,7 @@ int HeterClient::Send(const platform::DeviceContext& ctx,
}
int HeterClient::Send(int group_id, const std::vector<std::string>& var_names,
const std::vector<int>& vars_len, void* data_ptr,
const std::vector<int64_t>& vars_size, void* data_ptr,
int64_t data_size) {
OnHeterRpcDone* closure = new OnHeterRpcDone([](void* done) {
auto* closure = reinterpret_cast<OnHeterRpcDone*>(done);
......@@ -282,7 +285,7 @@ int HeterClient::Send(int group_id, const std::vector<std::string>& var_names,
for (auto& send_var_name : var_names) {
request.add_send_var_names(send_var_name);
}
for (auto var_len : vars_len) {
for (auto var_len : vars_size) {
request.add_vars_len(var_len);
}
auto& request_buffer = closure->cntl.request_attachment();
......@@ -301,6 +304,7 @@ int HeterClient::Send(int group_id, const std::vector<std::string>& var_names,
::paddle::distributed::PsService_Stub stub(channel);
stub.SendToSwitch(&closure->cntl, &request, &closure->ps_response, closure);
fut.wait();
delete closure;
return 0;
}
......@@ -325,6 +329,7 @@ int HeterClient::Recv(const platform::DeviceContext& ctx,
distributed::MultiVarMsg request;
// 1. set req message_name(string)
request.set_message_name(message_name);
request.set_group_id(0);
// 2. set req recv_var_names(<string>)
for (auto& recv_var_name : recv_var_names) {
......@@ -396,8 +401,8 @@ int HeterClient::Recv(int group_id, const std::vector<std::string>& var_names,
// save in worker
auto& res_io_buffer = closure->cntl.response_attachment();
butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
io_buffer_itr.copy_and_forward(reinterpret_cast<void*>(data_ptr),
data_size * sizeof(float));
io_buffer_itr.copy_and_forward(reinterpret_cast<void*>(data_ptr), data_size);
delete closure;
VLOG(4) << "Recv done";
return 0;
}
......
......@@ -138,7 +138,8 @@ class HeterClient {
const std::string& mode = "forward");
int Send(int group_id, const std::vector<std::string>& var_names,
const std::vector<int>& vars_len, void* data_ptr, int64_t data_size);
const std::vector<int64_t>& vars_len, void* data_ptr,
int64_t data_size);
int Send(const platform::DeviceContext& ctx, const framework::Scope& scope,
const std::string& message_name,
......
......@@ -20,8 +20,8 @@ namespace paddle {
namespace distributed {
// DEFINE_string(cert_path, "./cert.pem", "cert.pem path");
// DEFINE_string(key_path, "./key.pem", "key.pem path");
std::shared_ptr<HeterServer> HeterServer::s_instance_ = nullptr;
std::mutex HeterServer::mtx_;
void HeterServer::RegisterServiceHandler(std::string message_name,
HeterServiceHandler func) {
......@@ -130,21 +130,15 @@ int SendAndRecvVariableHandler::SaveInSwitchWithShard(
butil::IOBufBytesIterator io_buffer_itr(request_io_buffer);
for (int idx = 0; idx < request->send_var_names_size(); idx++) {
const auto& var_name = request->send_var_names(idx);
const auto& var_len = request->vars_len(idx);
auto itr = local_shard.find(var_name);
if (itr != local_shard.end()) {
LOG(INFO) << "var: " << var_name << "has not been consumed!"
<< "check again";
WaitForVarsConsumed(group_id, var_name);
}
const auto& var_size = request->vars_len(idx);
WaitForVarsConsumed(group_id, var_name);
auto& value = local_shard[var_name];
value.resize(var_len);
value.resize(var_size);
io_buffer_itr.copy_and_forward(reinterpret_cast<void*>(value.data()),
var_len * sizeof(float));
VLOG(4) << "saved data in shards: ";
for (uint32_t i = 0; i < local_shard[var_name].size(); i++) {
VLOG(4) << *(local_shard[var_name].data() + i);
}
var_size);
std::unique_lock<std::mutex> lk(scope_mutex_);
vars_ready_flag[group_id][var_name] = 1;
VLOG(4) << "saved var_name: " << var_name << "is saved ready!";
}
VLOG(4) << "SaveInSwitchWithShard success";
return 0;
......@@ -164,20 +158,17 @@ int SendAndRecvVariableHandler::QueryInSwitchWithShard(
}
auto msg_name = request->message_name();
response->set_message_name(msg_name);
for (auto& req_var_name : req_var_names) {
VLOG(4) << "req var name: " << req_var_name;
response->add_send_var_names(req_var_name);
WaitForVarsProduced(group_id, req_var_name);
auto itr = local_shard.find(req_var_name);
if (itr == local_shard.end()) {
LOG(INFO) << "var: " << req_var_name << " not found in shards";
WaitForVarsProduced(group_id, req_var_name);
}
LOG(INFO) << "var: " << req_var_name << " found in shards";
itr = local_shard.find(req_var_name);
auto& value = itr.value();
response_io_buffer.append(value.data(), value.size() * sizeof(float));
value.resize(0); // 标记位
response_io_buffer.append(value.data(), value.size());
value.resize(0); // 清空内存
std::unique_lock<std::mutex> lk(scope_mutex_);
vars_ready_flag[group_id][req_var_name] = 0;
VLOG(4) << "query var_name: " << req_var_name << "is consumed ready!";
}
VLOG(4) << "heter server QueryInSwitchWithShard done";
return 0;
......@@ -192,37 +183,31 @@ int SendAndRecvVariableHandler::SaveInSwitchWithScope(
auto& cpu_dev_ctx = *pool.Get(cpu_place);
auto message_name = request->message_name();
VLOG(4) << "message_name in heter server: " << message_name;
auto send_var_nums = request->send_var_names_size();
std::vector<std::string> send_var_names(send_var_nums);
for (int idx = 0; idx < send_var_nums; idx++) {
send_var_names[idx] = request->var_messages(idx).varname();
}
std::unique_lock<std::mutex> lk(scope_mutex_);
auto local_scope = local_scope_ptr.get();
if (!local_scope) {
LOG(ERROR) << "local_scope_ptr is null in SaveInSwitchWithScope";
}
for (int idx = 0; idx < request->send_var_names_size(); idx++) {
const auto& msg = request->var_messages(idx);
std::string var_name = msg.varname();
for (auto var_name : send_var_names) {
auto* var_exist_ptr = local_scope->FindVar(var_name);
if (!var_exist_ptr) {
VLOG(4) << "not find var: " << var_name << " in local_scope";
}
vars_table[var_name] += 1;
VLOG(4) << "saved var_name: " << var_name
<< ", cnt = " << vars_table[var_name];
WaitForVarsConsumed(0, var_name);
}
auto& request_io_buffer = cntl->request_attachment();
distributed::DeserializeFromMultiVarMsgAndIOBuf(*request, &request_io_buffer,
cpu_dev_ctx, local_scope);
lk.unlock();
while (true) {
int ret = 0;
for (int idx = 0; idx < request->send_var_names_size(); idx++) {
ret |= vars_table[request->var_messages(idx).varname()];
}
if (!ret) {
VLOG(4) << "all saved vars consumed";
break;
}
VLOG(4) << "waiting consume result......";
sleep(1);
for (auto var_name : send_var_names) {
std::unique_lock<std::mutex> lk(scope_mutex_);
vars_ready_flag[0][var_name] = 1;
}
VLOG(4) << "SaveInSwitchWithScope success";
return 0;
......@@ -258,19 +243,14 @@ int SendAndRecvVariableHandler::QueryInSwitchWithScope(
// 3. fill var_messages(VarMessage)
for (auto& req_var_name : req_var_names) {
LOG(INFO) << "query var_name: " << req_var_name;
WaitForVarsProduced(0, req_var_name);
auto* send_var_msg = response->add_var_messages();
send_var_msg->set_varname(req_var_name);
framework::Variable* var_ptr;
while (true) {
var_ptr = local_scope->FindVar(req_var_name);
if (!var_ptr) {
LOG(INFO) << "local_scope not find var: " << req_var_name;
} else {
break;
}
sleep(1);
var_ptr = local_scope->FindVar(req_var_name);
if (!var_ptr) {
LOG(INFO) << "local_scope not find var: " << req_var_name;
}
butil::IOBuf temp_iobuf;
if (var_ptr->IsType<framework::LoDTensor>()) {
......@@ -282,10 +262,7 @@ int SendAndRecvVariableHandler::QueryInSwitchWithScope(
}
for (auto& req_var_name : req_var_names) {
std::unique_lock<std::mutex> lk(scope_mutex_);
vars_table[req_var_name] -= 1;
VLOG(4) << "remained var: " << req_var_name
<< ", cnt = " << vars_table[req_var_name];
lk.unlock();
vars_ready_flag[0][req_var_name] = 0;
}
VLOG(4) << "heter server QueryInSwitchWithScope done";
return 0;
......
......@@ -56,9 +56,10 @@ class Scope;
DECLARE_double(eager_delete_tensor_gb);
DECLARE_int32(pserver_timeout_ms);
DECLARE_int32(heter_world_size);
DECLARE_int32(switch_send_recv_timeout_s);
namespace paddle {
namespace distributed {
using MultiVarMsg = MultiVariableMessage;
using VarMsg = VariableMessage;
......@@ -95,6 +96,19 @@ using SharedTaskQueue = std::shared_ptr<
std::unordered_map<int, std::shared_ptr<::paddle::framework::BlockingQueue<
std::pair<std::string, int>>>>>;
class ValueInSwitch {
public:
ValueInSwitch() {}
~ValueInSwitch() {}
char* data() { return _data.data(); }
size_t size() { return _data.size(); }
void resize(size_t size) { _data.resize(size); }
void shrink_to_fit() { _data.shrink_to_fit(); }
private:
std::vector<char> _data;
};
class SendAndRecvVariableHandler final : public ServiceHandlerBase {
public:
SendAndRecvVariableHandler() {
......@@ -130,22 +144,31 @@ class SendAndRecvVariableHandler final : public ServiceHandlerBase {
brpc::Controller* cntl);
void WaitForVarsConsumed(int32_t group_id, const std::string& var_name) {
auto& local_shard = _local_shards[group_id];
while (local_shard.find(var_name) != local_shard.end()) {
if (local_shard[var_name].size() == 0) {
timeline_.Start();
while (true) {
if (vars_ready_flag[group_id][var_name] == 0) {
break;
}
timeline_.Pause();
if (timeline_.ElapsedSec() > FLAGS_switch_send_recv_timeout_s) {
VLOG(0) << "vars not consumed exceed 10 miniutes";
break;
}
VLOG(4) << "waiting consume result......";
sleep(1);
}
return;
}
void WaitForVarsProduced(int32_t group_id, const std::string& var_name) {
auto& local_shard = _local_shards[group_id];
while (local_shard.find(var_name) == local_shard.end()) {
VLOG(4) << "waiting produce result......";
sleep(1);
timeline_.Start();
while (true) {
if (vars_ready_flag[group_id][var_name] == 1) {
break;
}
timeline_.Pause();
if (timeline_.ElapsedSec() > FLAGS_switch_send_recv_timeout_s) {
VLOG(0) << "vars not produced exceed 10 miniutes";
break;
}
}
return;
}
......@@ -245,10 +268,12 @@ class SendAndRecvVariableHandler final : public ServiceHandlerBase {
}
public:
using shard_type = SparseTableShard<std::string, FixedFeatureValue>;
using shard_type = SparseTableShard<std::string, ValueInSwitch>;
std::shared_ptr<paddle::framework::Scope> local_scope_ptr; // for switch
std::unordered_map<std::string, uint32_t> vars_table;
std::unordered_map<uint32_t, std::unordered_map<std::string, uint32_t>>
vars_ready_flag;
std::unique_ptr<shard_type[]> _local_shards;
platform::Timer timeline_;
private:
// share with HeterPipelineTrainer
......@@ -576,8 +601,11 @@ class HeterServer {
// HeterWrapper singleton
static std::shared_ptr<HeterServer> GetInstance() {
if (NULL == s_instance_) {
s_instance_.reset(new HeterServer());
if (s_instance_ == nullptr) {
std::unique_lock<std::mutex> lock(mtx_);
if (NULL == s_instance_) {
s_instance_.reset(new HeterServer());
}
}
return s_instance_;
}
......@@ -587,6 +615,7 @@ class HeterServer {
private:
static std::shared_ptr<HeterServer> s_instance_;
mutable std::mutex mutex_;
static std::mutex mtx_;
std::condition_variable cv_;
std::condition_variable condition_ready_;
bool stoped_ = true;
......
......@@ -126,7 +126,7 @@ message MultiVariableMessage {
repeated string recv_var_names = 3;
repeated VariableMessage var_messages = 4;
optional bytes data = 5;
repeated int32 vars_len = 6;
repeated int64 vars_len = 6;
optional int32 group_id = 7;
};
......
......@@ -15,6 +15,9 @@ limitations under the License. */
#if defined PADDLE_WITH_PSCORE
#include <stdlib.h>
#include <cmath>
#include <fstream>
#include <iostream>
#include <memory>
#include <random>
#include <sstream>
......@@ -69,44 +72,6 @@ void StartSwitchServer(
std::vector<std::string> peer_endpoints) {
switch_server_ptr->SetPeerEndPoints(peer_endpoints);
switch_server_ptr->SetEndPoint(endpoints[0]);
/*
std::shared_ptr<distributed::SendAndRecvVariableHandler> b_req_handler;
b_req_handler.reset(new distributed::SendAndRecvVariableHandler());
switch_server_ptr->SetServiceHandler(b_req_handler);
switch_server_ptr->SetLocalScope();
switch_server_ptr->RegisterServiceHandler(
std::to_string(distributed::PS_SAVE_WITH_SCOPE),
[&](const MultiVarMsg* request, MultiVarMsg* response,
brpc::Controller* cntl) -> int {
return b_req_handler->SaveInSwitchWithScope(request, response, cntl);
});
switch_server_ptr->RegisterServiceHandler(std::to_string(distributed::PS_SAVE_WITH_SHARD),
[&](const MultiVarMsg* request, MultiVarMsg*
response,
brpc::Controller* cntl) -> int {
return b_req_handler->SaveInSwitchWithShard(
request, response, cntl);
});
switch_server_ptr->RegisterServiceHandler(std::to_string(distributed::PS_QUERY_WITH_SCOPE),
[&](const MultiVarMsg* request, MultiVarMsg*
response,
brpc::Controller* cntl) -> int {
return b_req_handler->QueryInSwitchWithScope(
request, response, cntl);
});
switch_server_ptr->RegisterServiceHandler(std::to_string(distributed::PS_QUERY_WITH_SHARD),
[&](const MultiVarMsg* request, MultiVarMsg*
response,
brpc::Controller* cntl) -> int {
return b_req_handler->QueryInSwitchWithShard(
request, response, cntl);
});
*/
switch_server_ptr->StartHeterService(false);
}
......@@ -119,6 +84,129 @@ void StartSwitchInterServer(
switch_server_ptr->StartHeterInterService(false);
}
void TestShardSendRecv(
std::shared_ptr<distributed::HeterClient> heter_client_ptr_) {
auto send_async = [&]() -> void {
std::vector<int64_t> vars_len{2 * sizeof(float),
4 * sizeof(float)}; // 字节数
std::vector<float> values{1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
int64_t data_size = 6 * sizeof(float);
std::vector<std::string> send_var_names{"w", "x"};
int group_id = 0;
int ret = heter_client_ptr_->Send(group_id, send_var_names, vars_len,
values.data(), data_size);
if (!ret) {
LOG(INFO) << ">>>> TestShardSendRecv: worker send success";
}
};
std::thread t(send_async);
int group_id = 0;
std::vector<std::string> recv_var_names{"w", "x"};
int data_size = 6 * sizeof(float);
float* value_ptr = new float[6];
int ret =
heter_client_ptr_->Recv(group_id, recv_var_names, value_ptr, data_size);
if (!ret) {
VLOG(4) << "queried data is: ";
for (int i = 0; i < 6; i++) {
VLOG(4) << value_ptr[i] << " ";
}
delete[] value_ptr;
LOG(INFO) << "<<<< TestShardSendRecv: worker recv success";
}
t.join();
}
void PressTestSendRecv(
std::shared_ptr<distributed::HeterClient> heter_client_ptr_) {
// long l = 0, m = 0;
std::ifstream file("/send_20_34", std::ios::in | std::ios::binary);
// l = file.tellg();
// file.seekg(0, std::ios::end);
// m = file.tellg();
// file.close();
// VLOG(0) << "size of file " << "20_34" << " is " << (m - l) << " bytes.\n";
int64_t vars_len = 2359296 * sizeof(float);
int64_t data_size = vars_len * sizeof(float);
VLOG(0) << "float num: " << data_size;
float* data_ptr = new float[data_size];
file.read((char*)data_ptr, 9437184);
VLOG(0) << "send data is: " << data_ptr[0] << ", " << data_ptr[1];
std::vector<std::string> var_names{"34"};
int loopCnt = 600;
auto send_async = [&]() -> void {
int i = 0;
while (i++ < loopCnt) {
heter_client_ptr_->Send(20, var_names, {vars_len}, data_ptr, data_size);
}
};
std::thread t(send_async);
float* values = new float[2359296];
int i = 0;
while (i++ < loopCnt) {
int ret = heter_client_ptr_->Recv(20, var_names, values, data_size);
if (!ret) {
VLOG(0) << "diff: " << abs(values[0] - 0.159544) << ", "
<< abs(values[1] + 2.3484);
VLOG(0) << "loop id: " << i;
for (int j = 0; j < 2359296; j++) {
if (abs(values[j] - data_ptr[j]) > 4e-6) {
VLOG(0) << "error data idx: " << j;
VLOG(0) << "diff detail: " << values[j] << ", " << data_ptr[j];
LOG(INFO) << ">>>> worker recv ERROR";
break;
}
}
for (uint32_t i = 0; i < 2359296; i++) {
values[i] = -1; // reset
}
}
}
delete[] values;
std::ofstream recv("/recv_20_34", std::ios::out | std::ios::binary);
recv.write((char*)values, data_size);
recv.close();
t.join();
}
void TestScopeSendRecv(
std::shared_ptr<distributed::HeterClient> heter_client_ptr_) {
platform::CPUPlace place;
platform::CPUDeviceContext ctx(place);
framework::Executor exe(place);
std::shared_ptr<framework::Scope> send_scope_ptr =
std::make_shared<framework::Scope>();
int64_t rows_numel = 10;
InitTensorsOnClient(send_scope_ptr.get(), &place, rows_numel);
LOG(INFO) << "InitTensorsOnClient done";
auto send_async = [&]() -> void {
std::string message_name = std::to_string(distributed::PS_SAVE_WITH_SCOPE);
std::vector<std::string> send_var_names{"w", "x"};
int ret = heter_client_ptr_->Send(ctx, *send_scope_ptr, message_name,
send_var_names);
if (!ret) {
LOG(ERROR) << ">>>> TestScopeSendRecv: worker send success";
}
};
std::thread t(send_async);
std::string message_name = std::to_string(distributed::PS_QUERY_WITH_SCOPE);
std::vector<std::string> recv_var_names{"w", "x"};
std::shared_ptr<framework::Scope> recv_scope_ptr =
std::make_shared<framework::Scope>();
int ret = heter_client_ptr_->Recv(ctx, *recv_scope_ptr, message_name,
recv_var_names);
if (!ret && recv_scope_ptr->FindVar("w") && recv_scope_ptr->FindVar("x")) {
LOG(INFO) << "<<<< TestScopeSendRecv: worker recv success";
} else {
LOG(INFO) << "<<<< TestScopeSendRecv: worker recv failed";
}
t.join();
}
TEST(HETERSENDANDRECV, CPU) {
setenv("http_proxy", "", 1);
setenv("https_proxy", "", 1);
......@@ -155,79 +243,19 @@ TEST(HETERSENDANDRECV, CPU) {
switch_server_ptr_b->WaitServerReady();
// 获取 client 实例
// 开启单测时,请重新设置 HeterClient 端的 recv_switch_channels_
std::shared_ptr<distributed::HeterClient> heter_client_ptr_ =
distributed::HeterClient::GetInstance(
{switch_a_endpoint, switch_b_endpoint}, {}, 0);
framework::ProgramDesc program;
platform::CPUPlace place;
platform::CPUDeviceContext ctx(place);
framework::Executor exe(place);
framework::ProgramDesc program;
exe.Prepare(program, 0); // solve undefined symbol: tensor_table.cc
std::shared_ptr<framework::Scope> send_scope_ptr =
std::make_shared<framework::Scope>();
int64_t rows_numel = 10;
InitTensorsOnClient(send_scope_ptr.get(), &place, rows_numel);
LOG(INFO) << "InitTensorsOnClient done";
auto send_async = [&]() -> void {
/*
//std::string message_name =
std::to_string(distributed::PS_SAVE_WITH_SCOPE);
std::string message_name = "send and save";
std::vector<std::string> send_var_names{"w", "x"};
int ret = heter_client_ptr_->Send(ctx, *send_scope_ptr, message_name,
send_var_names);
if (!ret) {
LOG(ERROR) << ">>>> worker send success";
}
*/
///*
std::vector<int> vars_len{2, 4};
std::vector<float> values{1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
int64_t data_size = 6;
std::vector<std::string> send_var_names{"w", "x"};
int group_id = 0;
int ret = heter_client_ptr_->Send(group_id, send_var_names, vars_len,
values.data(), data_size);
if (!ret) {
LOG(INFO) << ">>>> worker send success";
}
//*/
};
std::thread send_thread(send_async);
/*
std::string message_name = std::to_string(distributed::PS_QUERY_WITH_SCOPE);
std::vector<std::string> recv_var_names{"w", "x"};
std::shared_ptr<framework::Scope> recv_scope_ptr =
std::make_shared<framework::Scope>();
int ret = heter_client_ptr_->Recv(ctx, *recv_scope_ptr, message_name,
recv_var_names);
if (!ret && recv_scope_ptr->FindVar("w") && recv_scope_ptr->FindVar("x")) {
LOG(INFO) << ">>>> worker recv success";
} else {
LOG(INFO) << "worker recv failed";
}
*/
///*
int group_id = 0;
std::vector<std::string> recv_var_names{"w", "x"};
std::vector<float> values;
int data_size = 6;
values.resize(data_size);
int ret = heter_client_ptr_->Recv(group_id, recv_var_names, values.data(),
data_size);
if (!ret) {
VLOG(4) << "queried data is: ";
for (auto f : values) {
VLOG(4) << f << " ";
}
LOG(INFO) << ">>>> worker recv success";
}
//*/
send_thread.join();
// TestScopeSendRecv(heter_client_ptr_);
TestShardSendRecv(heter_client_ptr_);
// PressTestSendRecv(heter_client_ptr_);
switch_server_ptr_a->Stop();
LOG(INFO) << "switch server A stopped";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册