Unverified commit a2da1efa authored by zmx, committed by GitHub

[Heterps]Refactor Heter Pipeline Parameter Server (#36845)

* change username

* fix

* fix

* fix

* fix

* fix

* update

* update

* update unittests

* fix

* update

* fix

* update

* fix

* fix

* fix

* update

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update send_and_recv op. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix ut. test=develop

* fix unit. notest,test=coverage

* fix ut. notest, test=coverage

* update. notest,test=coverage

* fix ut. notest, test=coverage

* fix ut. notest, test=coverage

* fix. notest, test=coverage

* fix. notest, test=coverage

* fix ut. notest, test=coverage

* fix ut. notest, test=coverage

* fix ut. notest, test=coverage

* fix ut. notest, test=coverage

* add func. notest, test=coverage

* fix ut. notest, test=coverage

* fix. test=develop

* fix. test=develop
Parent 52645667
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #include "paddle/fluid/distributed/service/communicator.h"
 #include <google/protobuf/text_format.h>
 #include "gflags/gflags.h"
@@ -361,6 +360,8 @@ void Communicator::InitParams(const RecvCtxMap &recv_varname_to_ctx) {
               << " from 0' trainer done";
       }
     }
+    std::this_thread::sleep_for(
+        std::chrono::milliseconds(100 + trainer_id_ * 10));
     BarrierWithTable(1);
     return;
   }
@@ -518,7 +519,6 @@ void AsyncCommunicator::SendByCommunicator() {
         MergeVars<float>(var_name, vars[i], send_scope_.get(), 1);
       }
     }
     if (ctx.is_tensor_table) {
       SendGlobalStep(ctx, merged_var_num, send_scope_.get());
     } else if (ctx.is_sparse) {
......
@@ -25,6 +25,36 @@ namespace distributed {
 std::shared_ptr<HeterClient> HeterClient::s_instance_ = NULL;
 bool HeterClient::is_initialized_ = false;
+int GetMicroId(const platform::DeviceContext& ctx,
+               const framework::Scope* scope) {
+  framework::Variable* var = scope->FindVar("microbatch_id");
+  PADDLE_ENFORCE_EQ(var->IsType<framework::LoDTensor>(), true,
+                    platform::errors::InvalidArgument(
+                        "the type of micro id should be LoDTensor."));
+  auto micro_id = -1;
+  auto* tensor = var->GetMutable<framework::LoDTensor>();
+  if (platform::is_cpu_place(tensor->place())) {
+    auto data = reinterpret_cast<const float*>(tensor->data<void>());
+    micro_id = static_cast<int>(data[0]);
+  } else {
+#ifdef PADDLE_WITH_CUDA
+    std::vector<char> temp;
+    temp.resize(tensor->numel() * framework::SizeOfType(tensor->type()));
+    char* temp_ptr = temp.data();
+    auto stream =
+        reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
+    memory::Copy(platform::CPUPlace(), temp_ptr,
+                 BOOST_GET_CONST(platform::CUDAPlace, tensor->place()),
+                 tensor->data<void>(),
+                 tensor->numel() * framework::SizeOfType(tensor->type()),
+                 stream);
+    float* temp_ptr_float = reinterpret_cast<float*>(temp_ptr);
+    micro_id = static_cast<int>(temp_ptr_float[0]);
+#endif
+  }
+  return micro_id;
+}
 void HeterClient::MainThread() {
   while (running_) {
     RpcProfilerControl();
@@ -99,43 +129,68 @@ void HeterClient::CreateClient2XpuConnection() {
       }
     }
   }
+  previous_xpu_channels_.resize(previous_xpu_list_.size());
+  for (size_t i = 0; i < previous_xpu_list_.size(); ++i) {
+    previous_xpu_channels_[i].reset(new brpc::Channel());
+    if (previous_xpu_channels_[i]->Init(previous_xpu_list_[i].c_str(), "",
+                                        &options) != 0) {
+      VLOG(0) << "HeterClient channel init fail. Try Again";
+      auto ip_port = paddle::string::Split(previous_xpu_list_[i], ':');
+      std::string ip = ip_port[0];
+      int port = std::stoi(ip_port[1]);
+      std::string int_ip_port = GetIntTypeEndpoint(ip, port);
+      if (previous_xpu_channels_[i]->Init(int_ip_port.c_str(), "", &options) !=
+          0) {
+        LOG(ERROR) << "BrpcPsServer start failed, ip_port= " << int_ip_port;
+      }
+    }
+  }
 }
 void HeterClient::SendAndRecvAsync(
-    const std::vector<std::string>& ep, const platform::DeviceContext& ctx,
-    const framework::Scope& scope, const std::string& message_name,
+    const platform::DeviceContext& ctx, const framework::Scope& scope,
+    const std::string& message_name,
     const std::vector<std::string>& send_var_name,
-    const std::vector<std::string>& recv_var_name) {
+    const std::vector<std::string>& recv_var_name, const std::string& mode) {
   platform::RecordEvent record_event("HeterClient->SendAndRecvAsync");
   const platform::DeviceContext* p_ctx = &ctx;
   const framework::Scope* p_scope = &scope;
   const std::string message_name_val = message_name;
   const std::vector<std::string> send_var_name_val = send_var_name;
   const std::vector<std::string> recv_var_name_val = recv_var_name;
-  VLOG(3) << "GRPCClient::SendAndRecv Begin, message_name: "
+  VLOG(3) << "BRPCClient::SendAndRecv Begin, message_name: "
           << message_name_val;
-  // Todo: get correct channel
-  int num = trainer_id_ % xpu_channels_.size();
-  brpc::Controller cntl;
-  cntl.set_timeout_ms(FLAGS_pserver_timeout_ms);
-  distributed::MultiVarMsg request, response;
-  auto& request_io_buffer = cntl.request_attachment();
-  ::paddle::distributed::PsService_Stub stub(xpu_channels_[num].get());
-  distributed::SerializeToMultiVarMsgAndIOBuf(
-      message_name_val, send_var_name_val, recv_var_name_val, *p_ctx, p_scope,
-      &request, &request_io_buffer);
-  stub.SendAndRecvVariable(&cntl, &request, &response, NULL);
-  PADDLE_ENFORCE_NE(
-      cntl.Failed(), true,
-      platform::errors::Unimplemented(
-          "HeterClient::SendAndRecv meets brpc error, error message is %s",
-          cntl.ErrorText()));
-  VLOG(4) << "call heter_worker success";
-  auto& response_io_buffer = cntl.response_attachment();
-  distributed::DeserializeFromMultiVarMsgAndIOBuf(response, &response_io_buffer,
-                                                  ctx, p_scope);
+  brpc::Channel* channel = nullptr;
+  distributed::MultiVarMsg request;
+  OnHeterRpcDone* closure = new OnHeterRpcDone([p_ctx, p_scope](void* done) {
+    auto* closure = reinterpret_cast<OnHeterRpcDone*>(done);
+    PADDLE_ENFORCE_NE(
+        closure->cntl.Failed(), true,
+        platform::errors::Unimplemented(
+            "HeterClient::SendAndRecv meets brpc error, error message is %s",
+            closure->cntl.ErrorText()));
+    VLOG(4) << "call heter_worker success";
+  });
+  closure->cntl.set_timeout_ms(FLAGS_pserver_timeout_ms);
+  auto& request_io_buffer = closure->cntl.request_attachment();
+  distributed::SerializeToMultiVarMsgAndIOBuf(
+      message_name_val, send_var_name_val, recv_var_name_val, *p_ctx, p_scope,
+      &request, &request_io_buffer);
+  int micro_id = GetMicroId(ctx, p_scope);
+  auto minibatch_id = micro_id / 10;
+  // select channel according to micro id
+  if (mode == "forward") {
+    int num = minibatch_id % xpu_channels_.size();
+    channel = xpu_channels_[num].get();
+  } else if (mode == "backward") {
+    int num = minibatch_id % previous_xpu_channels_.size();
+    channel = previous_xpu_channels_[num].get();
+  }
+  ::paddle::distributed::PsService_Stub stub(channel);
+  stub.SendAndRecvVariable(&closure->cntl, &request, &closure->response,
+                           closure);
 }
 std::future<int32_t> HeterClient::SendCmd(
......
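The routing above relies on a packing convention for the micro id: CopyParameters (in heter_section_worker.cc below) writes thread_id_ * 10 + microbatch_id into the "microbatch_id" tensor, so micro_id / 10 recovers the minibatch (thread) index and micro_id % 10 the micro-batch index, which caps micro batches at 10 per minibatch. A minimal standalone sketch of this convention; the function names are illustrative, not from the patch:

int EncodeMicroId(int minibatch_index, int microbatch_index) {
  // assumes fewer than 10 micro batches per mini batch, as the patch does
  return minibatch_index * 10 + microbatch_index;
}

int SelectChannelIndex(int micro_id, int num_channels) {
  int minibatch_index = micro_id / 10;  // drop the micro-batch digit
  return minibatch_index % num_channels;
}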
@@ -76,20 +76,23 @@ class HeterClient {
   void CreateClient2XpuConnection();
-  void SendAndRecvAsync(const std::vector<std::string>& ep,
-                        const platform::DeviceContext& ctx,
+  void SendAndRecvAsync(const platform::DeviceContext& ctx,
                         const framework::Scope& scope,
                         const std::string& message_name,
                         const std::vector<std::string>& send_var_name,
-                        const std::vector<std::string>& recv_var_name);
+                        const std::vector<std::string>& recv_var_name,
+                        const std::string& mode = "forward");
   // HeterClient singleton
   static std::shared_ptr<HeterClient> GetInstance(
-      const std::vector<std::string>& endpoint, const int& trainer_id) {
+      const std::vector<std::string>& endpoint,
+      const std::vector<std::string>& previous_endpoint,
+      const int& trainer_id) {
     if (NULL == s_instance_) {
       is_initialized_ = true;
       s_instance_.reset(new paddle::distributed::HeterClient());
       s_instance_->SetXpuList(endpoint);
+      s_instance_->SetPreviousXpuList(previous_endpoint);
       s_instance_->SetTrainerID(trainer_id);
       s_instance_->CreateClient2XpuConnection();
     }
@@ -118,6 +121,10 @@ class HeterClient {
     xpu_list_ = xpu_list;
   }
+  void SetPreviousXpuList(const std::vector<std::string>& xpu_list) {
+    previous_xpu_list_ = xpu_list;
+  }
   void SetTrainerID(const int& trainer_id) { trainer_id_ = trainer_id; }
 private:
@@ -125,9 +132,11 @@ class HeterClient {
   static bool is_initialized_;
   std::unique_ptr<std::thread> main_thread_{nullptr};
   std::vector<std::shared_ptr<brpc::Channel>> xpu_channels_;
+  std::vector<std::shared_ptr<brpc::Channel>> previous_xpu_channels_;
   DISABLE_COPY_AND_ASSIGN(HeterClient);
   std::vector<std::string> xpu_list_;
+  std::vector<std::string> previous_xpu_list_;
   bool running_ = false;
   int trainer_id_;
......
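For reference, a hedged usage sketch of the extended singleton interface; the endpoints, message name, and header path are illustrative placeholders, not values taken from the patch:

#include <string>
#include <vector>
#include "paddle/fluid/distributed/service/heter_client.h"  // assumed path

void InitHeterClientExample() {
  // next-stage and previous-stage heter worker endpoints (illustrative)
  std::vector<std::string> next_stage = {"127.0.0.1:8200"};
  std::vector<std::string> prev_stage = {"127.0.0.1:8100"};
  int trainer_id = 0;
  auto client = paddle::distributed::HeterClient::GetInstance(
      next_stage, prev_stage, trainer_id);
  // "forward" mode routes over xpu_channels_; "backward" mode routes
  // gradients over previous_xpu_channels_ instead:
  // client->SendAndRecvAsync(ctx, scope, "forward_0", send_vars, recv_vars,
  //                          "forward");
}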
@@ -46,20 +46,20 @@ void HeterServer::StartHeterService() {
     ready_ = 1;
   }
   condition_ready_.notify_all();
   std::unique_lock<std::mutex> running_lock(mutex_);
+  stoped_ = false;
   cv_.wait(running_lock, [&] {
     VLOG(1) << "Heter Server is Stop? " << stoped_;
     return stoped_;
   });
 }
-void HeterServer::SetEndPoint(std::string& endpoint) {
+void HeterServer::SetEndPoint(const std::string& endpoint) {
   endpoint_ = endpoint;
   service_.SetEndpoint(endpoint);
 }
-void HeterServer::SetFanin(int& fan_in) { service_.SetFanin(fan_in); }
+void HeterServer::SetFanin(const int& fan_in) { service_.SetFanin(fan_in); }
 void HeterServer::WaitServerReady() {
   std::unique_lock<std::mutex> lock(this->mutex_ready_);
......
@@ -27,6 +27,7 @@ limitations under the License. */
 #include "brpc/server.h"
 #include "paddle/fluid/distributed/service/brpc_utils.h"
 #include "paddle/fluid/distributed/service/sendrecv.pb.h"
+#include "paddle/fluid/framework/blocking_queue.h"
 #include "paddle/fluid/framework/executor.h"
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/framework/scope.h"
@@ -45,6 +46,7 @@ namespace paddle {
 namespace framework {
 class Executor;
 class ProgramDesc;
+class Scope;
 }  // namespace framework
 namespace platform {
 class DeviceContext;
@@ -61,7 +63,7 @@ using VarMsg = ::paddle::distributed::VariableMessage;
 class HeterService;
 typedef int32_t (HeterService::*serviceHandlerFunc)(
-    const PsRequestMessage& request, PsResponseMessage& response,
+    const PsRequestMessage& request, PsResponseMessage& response,  // NOLINT
     brpc::Controller* cntl);
 typedef std::function<void(void*)> HeterRpcCallbackFunc;
@@ -124,19 +126,27 @@ class HeterService : public ::paddle::distributed::PsService {
     handler_map_[message_name] = func;
   }
+  int32_t ForceExit() {
+    VLOG(3) << "heter service force exit";
+    is_exit_ = true;
+    return 0;
+  }
   void SetEndpoint(const std::string& end_point) { endpoint_ = end_point; }
   void SetFanin(const int& fan_in) { fan_in_ = fan_in; }
   bool IsExit() { return is_exit_; }
 private:
   int32_t stop_profiler(const PsRequestMessage& request,
-                        PsResponseMessage& response, brpc::Controller* cntl);
+                        PsResponseMessage& response,  // NOLINT
+                        brpc::Controller* cntl);
   int32_t start_profiler(const PsRequestMessage& request,
-                         PsResponseMessage& response, brpc::Controller* cntl);
+                         PsResponseMessage& response,  // NOLINT
+                         brpc::Controller* cntl);
   int32_t stop_heter_worker(const PsRequestMessage& request,
-                            PsResponseMessage& response,
+                            PsResponseMessage& response,  // NOLINT
                             brpc::Controller* cntl);
 private:
@@ -148,19 +158,152 @@ class HeterService : public ::paddle::distributed::PsService {
   bool is_exit_ = false;
 };
+using SharedMiniScope =
+    std::shared_ptr<std::unordered_map<int, ::paddle::framework::Scope*>>;
+using SharedMicroScope = std::shared_ptr<std::unordered_map<
+    int, std::shared_ptr<std::vector<::paddle::framework::Scope*>>>>;
+using SharedTaskQueue = std::shared_ptr<
+    std::unordered_map<int, std::shared_ptr<::paddle::framework::BlockingQueue<
+                                std::pair<std::string, int>>>>>;
+class HeterRequestHandler {
+ public:
+  HeterRequestHandler() : dev_ctx_(nullptr), scope_(nullptr) {}
+  virtual ~HeterRequestHandler() {}
+  void SetScope(const framework::Scope* scope) { scope_ = scope; }
+  void SetDevCtx(const platform::DeviceContext* dev_ctx) { dev_ctx_ = dev_ctx; }
+  virtual int Handle(const MultiVarMsg* request, MultiVarMsg* response,
+                     brpc::Controller* cntl) = 0;
+ protected:
+  const platform::DeviceContext* dev_ctx_;
+  const framework::Scope* scope_;
+};
+class RequestSendAndRecvHandler final : public HeterRequestHandler {
+ public:
+  RequestSendAndRecvHandler() {
+    this->num_microbatch_ = 0;
+    this->num_minibatch_ = 0;
+  }
+  virtual ~RequestSendAndRecvHandler() {}
+  // void SetMiniScopes(SharedMiniScope mini_scopes) {
+  //  mini_scopes_ = mini_scopes;
+  //  num_minibatch_ = mini_scopes_->size();
+  //}
+  void SetMicroScopes(SharedMicroScope micro_scopes) {
+    micro_scopes_ = micro_scopes;
+    num_microbatch_ = micro_scopes_->size();
+  }
+  void SetTaskQueue(SharedTaskQueue task_queue) { task_queue_ = task_queue; }
+  int Handle(const MultiVarMsg* request, MultiVarMsg* response,
+             brpc::Controller* cntl) override {
+    platform::RecordEvent record_event("RequestSendAndRecvHandler->Handle");
+    FLAGS_eager_delete_tensor_gb = -1;
+    // get microID from request
+    // deserialize variable to micro scope
+    // Push to heter worker's task_queue
+    std::unique_ptr<paddle::framework::Scope> local_scope_ptr(
+        new paddle::framework::Scope());
+    auto& local_scope = *(local_scope_ptr.get());
+    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
+    platform::CPUPlace cpu_place;
+    auto& cpu_dev_ctx = *pool.Get(cpu_place);
+    auto message_name = request->message_name();
+    auto& request_io_buffer = cntl->request_attachment();
+    distributed::DeserializeFromMultiVarMsgAndIOBuf(
+        *request, &request_io_buffer, cpu_dev_ctx, &local_scope);
+    auto* var = local_scope.FindVar("microbatch_id");
+    PADDLE_ENFORCE_NE(var, nullptr,
+                      platform::errors::InvalidArgument(
+                          "Cannot find variable microbatch_id in scope."));
+    auto* tensor = var->GetMutable<framework::LoDTensor>();
+    auto data = reinterpret_cast<const float*>(tensor->data<void>());
+    auto micro_id = static_cast<int>(data[0]);
+    int minibatch_index = micro_id / 10;
+    int microbatch_index = micro_id % 10;
+    // PADDLE_ENFORCE_EQ(
+    //    (*mini_scopes_).find(minibatch_index) != (*mini_scopes_).end(), 1,
+    //    platform::errors::InvalidArgument(
+    //        "minibatch index should be in current trainer"));
+    PADDLE_ENFORCE_EQ(
+        (*micro_scopes_).find(minibatch_index) != (*micro_scopes_).end(), 1,
+        platform::errors::InvalidArgument(
+            "minibatch index should be in current trainer"));
+    auto* micro_scope =
+        (*((*micro_scopes_)[minibatch_index]))[microbatch_index];
+    distributed::DeserializeFromMultiVarMsgAndIOBuf(
+        *request, &request_io_buffer, *dev_ctx_, micro_scope);
+    // blocking queue handles multi thread
+    (*task_queue_)[minibatch_index]->Push(
+        std::make_pair(message_name, microbatch_index));
+    auto response_var_nums = request->recv_var_names_size();
+    std::vector<std::string> response_var_names(response_var_nums),
+        empty_var_names{};
+    for (int var_idx = 0; var_idx < response_var_nums; ++var_idx) {
+      response_var_names[var_idx] = request->recv_var_names(var_idx);
+    }
+    auto& response_io_buffer = cntl->response_attachment();
+    distributed::SerializeToMultiVarMsgAndIOBuf(
+        message_name, response_var_names, empty_var_names, *dev_ctx_,
+        &local_scope, response, &response_io_buffer);
+    return 0;
+  }
+ private:
+  // share with HeterPipelineTrainer
+  // SharedMiniScope mini_scopes_{nullptr};
+  SharedMicroScope micro_scopes_{nullptr};
+  int num_microbatch_;
+  int num_minibatch_;
+  bool is_first_stage_ = false;
+  bool is_last_stage_ = false;
+  SharedTaskQueue task_queue_;
+};
 class HeterServer {
  public:
   virtual ~HeterServer() {}
   void Stop() {
-    VLOG(3) << "HeterServer Stop()";
     std::unique_lock<std::mutex> lock(mutex_);
+    if (stoped_ == true) return;
+    if (!IsExit()) service_.ForceExit();
+    VLOG(3) << "HeterServer Stop()";
     stoped_ = true;
     cv_.notify_all();
     server_.Stop(1000);
     server_.Join();
   }
+  bool IsStop() {
+    std::unique_lock<std::mutex> lock(mutex_);
+    if (stoped_ == true)
+      return true;
+    else
+      return false;
+  }
   bool IsExit() { return service_.IsExit(); }
   HeterServer() {}
@@ -170,8 +313,25 @@ class HeterServer {
   void StartHeterService();
-  void SetEndPoint(std::string& endpoint);
-  void SetFanin(int& fan_in);
+  void SetEndPoint(const std::string& endpoint);
+  void SetFanin(const int& fan_in);
+  void SetRequestHandler(
+      std::shared_ptr<RequestSendAndRecvHandler> request_handler) {
+    request_handler_ = request_handler;
+  }
+  // void SetMiniBatchScopes(SharedMiniScope mini_scopes) {
+  //  request_handler_->SetMiniScopes(mini_scopes);
+  //}
+  void SetMicroBatchScopes(SharedMicroScope micro_scopes) {
+    request_handler_->SetMicroScopes(micro_scopes);
+  }
+  void SetTaskQueue(SharedTaskQueue task_queue) {
+    request_handler_->SetTaskQueue(task_queue);
+  }
   // HeterWrapper singleton
   static std::shared_ptr<HeterServer> GetInstance() {
@@ -188,84 +348,19 @@ class HeterServer {
   mutable std::mutex mutex_;
   std::condition_variable cv_;
   std::condition_variable condition_ready_;
-  bool stoped_ = false;
+  bool stoped_ = true;
   std::string endpoint_;
 protected:
   brpc::Server server_;
   HeterService service_;
+  std::shared_ptr<RequestSendAndRecvHandler> request_handler_;
   DISABLE_COPY_AND_ASSIGN(HeterServer);
   std::mutex mutex_ready_;
   int ready_;
 };
-class HeterRequestHandler {
- public:
-  HeterRequestHandler()
-      : dev_ctx_(nullptr),
-        executor_(nullptr),
-        scope_(nullptr),
-        program_(nullptr) {}
-  virtual ~HeterRequestHandler() {}
-  void SetScope(framework::Scope* scope) { scope_ = scope; }
-  void SetDevCtx(const platform::DeviceContext* dev_ctx) { dev_ctx_ = dev_ctx; }
-  void SetProgram(framework::ProgramDesc* program) { program_ = program; }
-  void SetExecutor(framework::Executor* executor) { executor_ = executor; }
-  void SetGradToPreparedCtx(
-      std::unordered_map<
-          std::string, std::shared_ptr<framework::ExecutorPrepareContext>>* g) {
-    message_to_prepared_ctx_ = g;
-  }
-  virtual int Handle(const MultiVarMsg* request, MultiVarMsg* response,
-                     brpc::Controller* cntl) = 0;
- protected:
-  const platform::DeviceContext* dev_ctx_;
-  framework::Executor* executor_;
-  framework::Scope* scope_;
-  framework::ProgramDesc* program_;
-  std::unordered_map<std::string,
-                     std::shared_ptr<framework::ExecutorPrepareContext>>*
-      message_to_prepared_ctx_;
-};
-class RequestSendAndRecvHandler final : public HeterRequestHandler {
- public:
-  RequestSendAndRecvHandler() {}
-  virtual ~RequestSendAndRecvHandler() {}
-  int Handle(const MultiVarMsg* request, MultiVarMsg* response,
-             brpc::Controller* cntl) override {
-    platform::RecordEvent record_event("RequestSendAndRecvHandler->Handle");
-    FLAGS_eager_delete_tensor_gb = -1;
-    auto& local_scope = scope_->NewScope();
-    auto message_name = request->message_name();
-    auto& request_io_buffer = cntl->request_attachment();
-    distributed::DeserializeFromMultiVarMsgAndIOBuf(
-        *request, &request_io_buffer, *dev_ctx_, &local_scope);
-    executor_->RunPreparedContext(
-        (*message_to_prepared_ctx_)[message_name].get(), &local_scope, false);
-    auto response_var_nums = request->recv_var_names_size();
-    std::vector<std::string> response_var_names(response_var_nums),
-        empty_var_names{};
-    for (int var_idx = 0; var_idx < response_var_nums; ++var_idx) {
-      response_var_names[var_idx] = request->recv_var_names(var_idx);
-    }
-    auto& response_io_buffer = cntl->response_attachment();
-    distributed::SerializeToMultiVarMsgAndIOBuf(
-        message_name, response_var_names, empty_var_names, *dev_ctx_,
-        &local_scope, response, &response_io_buffer);
-    scope_->DeleteScope(&local_scope);
-    return 0;
-  }
-};
 }  // end namespace distributed
 }  // end namespace paddle
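The Handle-to-worker handoff above is a plain producer/consumer pattern: Handle deserializes into the matching micro scope and pushes (message_name, microbatch_index); a HeterSectionWorker pops the pair and runs the corresponding forward or backward block. A self-contained sketch of the queue's role using only the standard library (a stand-in for framework::BlockingQueue, whose exact interface beyond Push/Pop is not shown in this patch):

#include <condition_variable>
#include <mutex>
#include <queue>
#include <string>
#include <utility>

class SimpleTaskQueue {
 public:
  void Push(std::pair<std::string, int> task) {
    {
      std::lock_guard<std::mutex> lock(mu_);
      tasks_.push(std::move(task));
    }
    cv_.notify_one();  // wake one waiting worker
  }
  std::pair<std::string, int> Pop() {
    std::unique_lock<std::mutex> lock(mu_);
    // block until a handler thread has pushed a task
    cv_.wait(lock, [this] { return !tasks_.empty(); });
    auto task = std::move(tasks_.front());
    tasks_.pop();
    return task;
  }
 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::queue<std::pair<std::string, int>> tasks_;
};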
@@ -51,7 +51,7 @@ void PSCore::init_gflag(const std::string& gflags) {
   std::vector<std::string> flags = paddle::string::split_string(gflags);
   if (flags.size() < 1) {
     flags.push_back("-max_body_size=314217728");
-    flags.push_back("-bthread_concurrency=40");
+    flags.push_back("-bthread_concurrency=200");
     flags.push_back("-socket_max_unwritten_bytes=2048000000");
     flags.push_back("-max_connection_pool_size=1950");
   }
......
@@ -302,17 +302,24 @@ if(WITH_DISTRIBUTE)
   elseif(WITH_PSCORE)
     cc_library(executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc
                dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
-               heterxpu_trainer.cc
+               heterxpu_trainer.cc heter_pipeline_trainer.cc
                data_feed.cc device_worker.cc hogwild_worker.cc hetercpu_worker.cc
                downpour_worker.cc downpour_worker_opt.cc
-               pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
+               pull_dense_worker.cc section_worker.cc heter_section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
                device_context scope framework_proto data_feed_proto heter_service_proto trainer_desc_proto glog
                lod_rank_table fs shell fleet_wrapper heter_wrapper box_wrapper lodtensor_printer feed_fetch_method
-               graph_to_program_pass variable_helper timer monitor heter_service_proto fleet fleet_executor)
+               graph_to_program_pass variable_helper timer monitor heter_service_proto fleet heter_server brpc fleet_executor)
     set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
+    if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0)
+      set(DISTRIBUTE_COMPILE_FLAGS
+          "${DISTRIBUTE_COMPILE_FLAGS} -faligned-new")
+    endif()
     set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
+    set_source_files_properties(device_worker.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
     set_source_files_properties(multi_trainer.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
     set_source_files_properties(hogwild_worker.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
+    set_source_files_properties(heter_section_worker.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
+    set_source_files_properties(heter_pipeline_trainer.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
   else()
     cc_library(executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc
                dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
@@ -367,6 +374,8 @@ if(WITH_PSCORE)
   get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS)
   cc_test(dist_multi_trainer_test SRCS dist_multi_trainer_test.cc DEPS
           conditional_block_op executor gloo_wrapper ${RPC_DEPS})
+  cc_test(heter_pipeline_trainer_test SRCS heter_pipeline_trainer_test.cc DEPS
+          conditional_block_op scale_op heter_listen_and_serv_op executor heter_server gloo_wrapper ${RPC_DEPS})
 else()
   cc_test(dist_multi_trainer_test SRCS dist_multi_trainer_test.cc DEPS
           conditional_block_op executor gloo_wrapper)
......
@@ -606,5 +606,88 @@ class SectionWorker : public DeviceWorker {
 };
 #endif
+#if defined(PADDLE_WITH_PSCORE)
+class HeterSectionWorker : public DeviceWorker {
+ public:
+  HeterSectionWorker() {}
+  ~HeterSectionWorker() override {}
+  void Initialize(const TrainerDesc& desc) override;
+  void CreateDeviceResource(const ProgramDesc& main_prog) override{};
+  void TrainFiles() override;
+  void TrainFilesWithProfiler() override;
+  void BindingDataFeedMemory() override {}
+  void BindingDataFeedMemory(int micro_id);
+  void PrintFetchVars() override;
+  const platform::Place& place() const { return place_; }
+  void SetDeviceIndex(int tid) override { thread_id_ = tid; }
+  void SetThreadNum(int thread_num) { thread_num_ = thread_num; }
+  void SetMicrobatchNum(int num) { num_microbatches_ = num; }
+  void SetPipelineStageNum(int num) { num_pipeline_stages_ = num; }
+  void SetPipelineStage(int stage) { pipeline_stage_ = stage; }
+  std::shared_ptr<std::vector<Scope*>> GetMicrobatchScopes() {
+    return microbatch_scopes_;
+  }
+  using SHARED_THREAD_QUEUE = std::shared_ptr<
+      ::paddle::framework::BlockingQueue<std::pair<std::string, int>>>;
+  SHARED_THREAD_QUEUE GetThreadQueue() { return thread_queue_; }
+  void CopyParameters(int microbatch_id, const ProgramDesc& program,
+                      const platform::Place& place);
+  void SetMinibatchScope(Scope* scope) { minibatch_scope_ = scope; }
+  void SetTrainerId(int trainer_id) { this->trainer_id_ = trainer_id; }
+  void SetTrainers(int trainers) { this->trainers_ = trainers; }
+  void CreateMicrobatchScopes();
+  void RunForward(int micro_id);
+  void RunBackward(int micro_id);
+  void RunListen();
+  void MiniBatchBarrier();
+  void Run();
+  void BatchPostProcess();
+  void SetDebug(bool debug) { debug_ = debug; }
+  Scope* GetThreadScope() override { return minibatch_scope_; }
+  // multi-stream
+  // #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+  //  void SetStream(const gpuStream_t stream) override {}
+  //  void SetEvent(const gpuEvent_t event) override {}
+  // #endif
+ protected:
+  int trainer_id_;
+  int trainers_;
+  int thread_num_;
+  int thread_id_;
+  int num_microbatches_;
+  int num_pipeline_stages_;
+  int pipeline_stage_;
+  bool epoch_finish_;
+  std::shared_ptr<std::vector<Scope*>> microbatch_scopes_;
+  Scope* minibatch_scope_;
+  std::vector<int> micro_ids_{};
+  std::unique_ptr<OperatorBase> listen_op_{nullptr};
+  std::vector<std::unique_ptr<OperatorBase>> forward_ops_;
+  std::vector<std::unique_ptr<OperatorBase>> backward_ops_;
+  std::shared_ptr<framework::ProgramDesc> program_;
+  std::shared_ptr<
+      ::paddle::framework::BlockingQueue<std::pair<std::string, int>>>
+      thread_queue_;
+  static uint64_t batch_id_;
+  uint64_t total_ins_num_ = 0;
+  platform::DeviceContext* dev_ctx_ = nullptr;
+  bool debug_ = false;
+  std::vector<double> op_total_time_;
+  std::vector<std::string> op_name_;
+  platform::Timer timeline_;
+  double total_time_ = 0.0;
+  double read_time_ = 0.0;
+};
+#endif
 }  // namespace framework
 }  // namespace paddle
@@ -65,6 +65,11 @@ std::shared_ptr<DeviceWorker> DeviceWorkerFactory::CreateDeviceWorker(
 REGISTER_DEVICE_WORKER_CLASS(HogwildWorker);
 REGISTER_DEVICE_WORKER_CLASS(DownpourWorker);
 REGISTER_DEVICE_WORKER_CLASS(DownpourWorkerOpt);
+#if defined(PADDLE_WITH_PSCORE)
+REGISTER_DEVICE_WORKER_CLASS(HeterSectionWorker);
+#endif
 #ifdef PADDLE_WITH_PSLIB
 REGISTER_DEVICE_WORKER_CLASS(HeterCpuWorker);
 #endif
......
@@ -129,7 +129,7 @@ void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope,
 std::shared_ptr<TrainerBase> Executor::InitForDataset(
     const ProgramDesc& main_program, const std::string& trainer_desc_str,
     Scope* scope, Dataset* dataset) {
-  VLOG(3) << "Start to RunFromDataset in executor";
+  VLOG(3) << "Start to InitForDataset in executor";
   TrainerDesc trainer_desc;
   bool success = trainer_desc.ParseFromString(trainer_desc_str);
   PADDLE_ENFORCE_EQ(success, true,
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#if defined(PADDLE_WITH_PSCORE)
#include "paddle/fluid/distributed/service/heter_server.h"
#include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/framework/trainer.h"
#include "paddle/fluid/framework/trainer_desc.pb.h"
namespace paddle {
namespace framework {
class Variable;
using MiniScope = std::unordered_map<int, Scope*>;
using MicroScope =
std::unordered_map<int, std::shared_ptr<std::vector<Scope*>>>;
using TaskQueue =
std::unordered_map<int, std::shared_ptr<::paddle::framework::BlockingQueue<
std::pair<std::string, int>>>>;
void HeterPipelineTrainer::ResetDataset(Dataset* dataset) {
if (pipeline_stage_ == 0) {
SetDataset(dataset);
const std::vector<paddle::framework::DataFeed*> readers =
dataset->GetReaders();
VLOG(3) << "readers num: " << readers.size();
// change thread num is not supported
PADDLE_ENFORCE_EQ(thread_num_, readers.size(),
platform::errors::InvalidArgument(
"change Dataset thread_num is not supported"));
int cnt = -1;
for (auto& worker_pair : workers_) {
cnt++;
auto device_worker = worker_pair.second;
auto this_worker =
std::dynamic_pointer_cast<paddle::framework::HeterSectionWorker>(
device_worker);
this_worker->SetDataFeed(readers[cnt]);
this_worker->SetReaderPlace(place_);
}
}
}
void HeterPipelineTrainer::Initialize(const TrainerDesc& trainer_desc,
Dataset* dataset) {
thread_num_ = trainer_desc.thread_num();
ParseDumpConfig(trainer_desc);
SetDebug(trainer_desc.debug());
// for (int i = 0; i < trainer_desc.downpour_param().stat_var_names_size();
// i++) {
// need_merge_var_names_.push_back(
// trainer_desc.downpour_param().stat_var_names(i));
//}
// get filelist from trainer_desc here
const std::vector<paddle::framework::DataFeed*> readers =
dataset->GetReaders();
VLOG(3) << "readers num: " << readers.size();
// change thread num to readers num
thread_num_ = readers.size();
VLOG(3) << "worker thread num: " << thread_num_;
const auto& heter_section_params = trainer_desc.heter_section_param();
num_pipeline_stages_ = heter_section_params.num_pipeline_stages();
pipeline_stage_ = heter_section_params.pipeline_stage();
num_microbatches_ = heter_section_params.num_microbatches();
VLOG(3) << "Number of microbatches per minibatch: " << num_microbatches_;
trainer_desc_ = trainer_desc;
trainer_id_ = trainer_desc.trainer_id();
for (int i = 0; i < num_pipeline_stages_; ++i) {
auto trainer_num = trainer_desc.trainers(i);
trainers_.push_back(trainer_num);
}
int cpu_trainer_num = trainers_[0];
int cur_stage_trainer_num = trainers_[pipeline_stage_];
int global_thread_num = cpu_trainer_num * thread_num_;
int previous_trainers = 0;
for (int i = 0; i < pipeline_stage_; i++) previous_trainers += trainers_[i];
int stage_trainer_id =
trainer_id_ - previous_trainers; // trainer id in current stage
int cnt = -1;
for (int i = stage_trainer_id; i < global_thread_num;
i += cur_stage_trainer_num) {
cnt++;
workers_[i] = DeviceWorkerFactory::CreateDeviceWorker(
trainer_desc.device_worker_name());
auto this_worker =
std::dynamic_pointer_cast<paddle::framework::HeterSectionWorker>(
workers_[i]);
this_worker->SetDebug(debug_);
this_worker->SetNeedDumpField(need_dump_field_);
this_worker->SetNeedDumpParam(need_dump_param_);
this_worker->SetDumpFieldVector(dump_fields_);
this_worker->SetDumpParamVector(dump_param_);
this_worker->InitRandomDumpConfig(trainer_desc);
this_worker->SetDeviceIndex(i);
if (pipeline_stage_ == 0) {
this_worker->SetDataFeed(readers[cnt]);
}
this_worker->SetMicrobatchNum(num_microbatches_);
this_worker->SetPipelineStageNum(num_pipeline_stages_);
this_worker->SetPipelineStage(pipeline_stage_);
}
}
void HeterPipelineTrainer::InitOtherEnv(const ProgramDesc& main_program) {
if (need_dump_field_) {
InitDumpEnv();
}
}
std::string HeterPipelineTrainer::GetDumpPath(int tid) {
return string::format_string("%s/part-%05d", dump_fields_path_.c_str(), tid);
}
void HeterPipelineTrainer::InitDumpEnv() {
queue_ = paddle::framework::MakeChannel<std::string>();
for (int i = 0; i < thread_num_; ++i) {
workers_[i]->SetChannelWriter(queue_.get());
}
dump_thread_num_ = 1;
for (int i = 0; i < dump_thread_num_; i++) {
dump_thread_.push_back(
std::thread(std::bind(&TrainerBase::DumpWork, this, i)));
}
}
void HeterPipelineTrainer::InitTrainerEnv(const ProgramDesc& main_program,
const platform::Place& place) {
place_ = place;
PADDLE_ENFORCE_NOT_NULL(root_scope_, platform::errors::InvalidArgument(
"root_scope_ can not be nullptr"));
// initialize mini_scopes & micro_scopes
mini_scopes_.reset(new MiniScope{});
micro_scopes_.reset(new MicroScope{});
task_queue_.reset(new TaskQueue{});
for (auto& worker_pair : workers_) {
auto worker_index = worker_pair.first;
auto device_worker = worker_pair.second;
auto this_worker =
std::dynamic_pointer_cast<paddle::framework::HeterSectionWorker>(
device_worker);
this_worker->SetPlace(place);
this_worker->Initialize(trainer_desc_);
if (pipeline_stage_ == 0) {
this_worker->SetReaderPlace(place);
}
this_worker->SetRootScope(root_scope_);
// generate mini_batch scope for every worker
auto* minibatch_scope = &root_scope_->NewScope();
(*mini_scopes_)[worker_index] = minibatch_scope;
this_worker->SetMinibatchScope(minibatch_scope);
// after set micro num & mini batch scope
this_worker->CreateMicrobatchScopes();
(*micro_scopes_)[worker_index] = this_worker->GetMicrobatchScopes();
(*task_queue_)[worker_index] = this_worker->GetThreadQueue();
}
}
void HeterPipelineTrainer::Run() {
VLOG(3) << "Going to run HeterPipelineTrainer::Run()";
if (listen_ptr_ == nullptr) {
for (auto& worker_pair : workers_) {
auto& device_worker = worker_pair.second;
auto worker_0 =
std::dynamic_pointer_cast<paddle::framework::HeterSectionWorker>(
device_worker);
listen_ptr_.reset(new std::thread(
std::bind(&HeterSectionWorker::RunListen, worker_0.get())));
break;
}
}
auto heter_server = paddle::distributed::HeterServer::GetInstance();
heter_server->WaitServerReady();
// heter_server->SetMiniBatchScopes(mini_scopes_);
heter_server->SetMicroBatchScopes(micro_scopes_);
heter_server->SetTaskQueue(task_queue_);
// main training logic
if (pipeline_stage_ == 0) { // for cpu trainer
for (auto& worker_pair : workers_) {
auto device_worker = worker_pair.second;
if (!debug_) {
threads_.push_back(
std::thread(&DeviceWorker::TrainFiles, device_worker.get()));
} else {
threads_.push_back(std::thread(&DeviceWorker::TrainFilesWithProfiler,
device_worker.get()));
}
}
} else { // for heter worker
for (auto& worker_pair : workers_) {
auto device_worker = worker_pair.second;
if (!debug_) {
threads_.push_back(
std::thread(&DeviceWorker::TrainFiles, device_worker.get()));
} else {
threads_.push_back(std::thread(&DeviceWorker::TrainFilesWithProfiler,
device_worker.get()));
}
}
}
for (auto& th : threads_) {
th.join();
}
if (threads_.size() > 0) {
threads_.clear();
}
VLOG(3) << "Epoch Trainging done";
}
void HeterPipelineTrainer::Finalize() {
VLOG(3) << "HeterPipelineTrainer Finalize";
auto heter_server = paddle::distributed::HeterServer::GetInstance();
heter_server->Stop();
if (listen_ptr_) {
(listen_ptr_.get())->join();
listen_ptr_.reset(nullptr);
}
if (need_dump_field_) {
FinalizeDumpEnv();
}
root_scope_->DropKids();
}
Scope* HeterPipelineTrainer::GetWorkerScope(int thread_id) {
return workers_[thread_id]->GetThreadScope();
}
} // end namespace framework
} // end namespace paddle
#endif
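To make the worker-index arithmetic in Initialize() concrete, assume (these numbers are illustrative only) trainers_ = {2, 2, 1}, i.e. two CPU trainers, two stage-1 heter trainers, and one stage-2 trainer, with thread_num_ = 2, so global_thread_num = 2 * 2 = 4. A stage-1 trainer with trainer_id_ = 3 computes previous_trainers = 2 and stage_trainer_id = 3 - 2 = 1, then owns workers 1 and 3, since i starts at stage_trainer_id and advances by cur_stage_trainer_num = 2. Each pipeline stage thus partitions the same global worker-index space without overlap.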
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#if defined(PADDLE_WITH_PSCORE)
#include "gtest/gtest.h"
#include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/trainer.h"
#include "paddle/fluid/framework/trainer_factory.h"
#if defined _WIN32 || defined __APPLE__
#else
#define _LINUX
#endif
USE_OP(scale);
USE_NO_KERNEL_OP(heter_listen_and_serv);
namespace paddle {
namespace framework {
framework::BlockDesc* AppendSendAndRecvBlock(framework::ProgramDesc* program) {
auto root_block = program->MutableBlock(0);
auto* block = program->AppendBlock(*root_block);
auto* block2 = program->AppendBlock(*root_block);
framework::OpDesc* op = block->AppendOp();
op->SetType("scale");
op->SetInput("X", {"x"});
op->SetOutput("Out", {"res"});
op->SetAttr("scale", 0.5f);
framework::OpDesc* op2 = block2->AppendOp();
op2->SetType("scale");
op2->SetInput("X", {"x"});
op2->SetOutput("Out", {"res"});
op2->SetAttr("scale", 0.5f);
auto& out = *root_block->Var("res");
out.SetType(framework::proto::VarType::LOD_TENSOR);
out.SetShape({1, 10});
auto& persistable_var = *root_block->Var("p_var");
persistable_var.SetType(framework::proto::VarType::LOD_TENSOR);
persistable_var.SetShape({1, 10});
persistable_var.SetPersistable(true);
return block;
}
void GetHeterListenAndServProgram(framework::ProgramDesc* program) {
auto root_block = program->MutableBlock(0);
auto* sub_block = AppendSendAndRecvBlock(program);
std::vector<framework::BlockDesc*> optimize_blocks;
optimize_blocks.push_back(sub_block);
std::vector<std::string> message_to_block_id = {"x:1"};
std::string endpoint = "127.0.0.1:19944";
framework::OpDesc* op = root_block->AppendOp();
op->SetType("heter_listen_and_serv");
op->SetInput("X", {});
op->SetAttr("message_to_block_id", message_to_block_id);
op->SetAttr("optimize_blocks", optimize_blocks);
op->SetAttr("endpoint", endpoint);
op->SetAttr("fanin", 1);
op->SetAttr("pserver_id", 0);
}
TEST(HeterPipelineTrainerTest, GPU) {
#ifdef _LINUX
TrainerDesc t, t2, t3;
// t
t.set_class_name("HeterPipelineTrainer");
t.set_device_worker_name("HeterSectionWorker");
t.set_thread_num(1);
t.set_trainer_id(0);
t.add_trainers(1);
t.add_trainers(1);
t.add_trainers(1);
auto* heter_section_param = t.mutable_heter_section_param();
heter_section_param->set_num_pipeline_stages(3);
heter_section_param->set_pipeline_stage(0);
heter_section_param->set_num_microbatches(1);
// t2
t2.set_class_name("HeterPipelineTrainer");
t2.set_device_worker_name("HeterSectionWorker");
t2.set_thread_num(1);
t2.set_trainer_id(1);
t2.add_trainers(1);
t2.add_trainers(1);
t2.add_trainers(1);
auto* heter_section_param2 = t2.mutable_heter_section_param();
heter_section_param2->set_num_pipeline_stages(3);
heter_section_param2->set_pipeline_stage(1);
heter_section_param2->set_num_microbatches(1);
// t3
t3.set_class_name("HeterPipelineTrainer");
t3.set_device_worker_name("HeterSectionWorker");
t3.set_thread_num(1);
t3.set_trainer_id(1);
t3.add_trainers(1);
t3.add_trainers(1);
t3.add_trainers(1);
t3.add_dump_fields("hello");
t3.add_dump_param("fc_0");
auto* heter_section_param3 = t3.mutable_heter_section_param();
heter_section_param3->set_num_pipeline_stages(3);
heter_section_param3->set_pipeline_stage(2);
heter_section_param3->set_num_microbatches(1);
std::string str;
str += "name: \"MultiSlotDataFeed\"\nbatch_size: 2\nmulti_slot_desc {\n";
str += "slots {\nname: \"words\"\ntype: \"uint64\"\nis_dense: false\n";
str += "is_used: true\n}\nslots {\nname: \"label\"\ntype: \"uint64\"\n";
str += "is_dense: false\nis_used: true\n}\n}\n";
std::shared_ptr<MultiSlotDataset> dataset =
std::make_shared<MultiSlotDataset>();
dataset->SetFileList(std::vector<std::string>{"a1.txt", "a2.txt"});
dataset->SetThreadNum(1);
dataset->SetTrainerNum(1);
dataset->SetDataFeedDesc(str);
dataset->CreateReaders();
ProgramDesc p;
// construct program
// AppendSendAndRecvBlock(&p);
GetHeterListenAndServProgram(&p);
auto* section_config = heter_section_param->mutable_section_config();
proto::ProgramDesc* pd = new proto::ProgramDesc(*(p.Proto()));
section_config->set_allocated_program_desc(pd);
ProgramDesc p2;
// construct program
// AppendSendAndRecvBlock(&p2);
GetHeterListenAndServProgram(&p2);
auto* section_config2 = heter_section_param2->mutable_section_config();
proto::ProgramDesc* pd2 = new proto::ProgramDesc(*(p2.Proto()));
section_config2->set_allocated_program_desc(pd2);
ProgramDesc p3;
// construct program
// AppendSendAndRecvBlock(&p3);
GetHeterListenAndServProgram(&p3);
auto* section_config3 = heter_section_param3->mutable_section_config();
proto::ProgramDesc* pd3 = new proto::ProgramDesc(*(p3.Proto()));
section_config3->set_allocated_program_desc(pd3);
Scope root_scope, root_scope2, root_scope3;
paddle::platform::CPUPlace place;
paddle::platform::CUDAPlace place2;
// tmp1
std::shared_ptr<TrainerBase> tmp1;
tmp1 = TrainerFactory::CreateTrainer(t.class_name());
tmp1->SetScope(&root_scope);
tmp1->Initialize(t, dataset.get());
tmp1->InitTrainerEnv(p, place);
tmp1->InitOtherEnv(p);
tmp1->GetWorkerScope(0);
tmp1->ResetDataset(dataset.get());
tmp1->Finalize();
// tmp2
std::shared_ptr<TrainerBase> tmp2;
tmp2 = TrainerFactory::CreateTrainer(t2.class_name());
tmp2->SetScope(&root_scope2);
tmp2->Initialize(t2, dataset.get());
tmp2->InitTrainerEnv(p2, place2);
tmp2->InitOtherEnv(p2);
tmp2->GetWorkerScope(0);
tmp2->ResetDataset(dataset.get());
tmp2->Finalize();
// tmp3
std::shared_ptr<TrainerBase> tmp3;
tmp3 = TrainerFactory::CreateTrainer(t3.class_name());
tmp3->SetScope(&root_scope3);
tmp3->Initialize(t3, dataset.get());
tmp3->InitTrainerEnv(p3, place);
tmp3->InitOtherEnv(p3);
// tmp3->GetDumpPath(0);
// tmp3->InitDumpEnv();
// tmp3->FinalizeDumpEnv();
tmp3->GetWorkerScope(0);
tmp3->ResetDataset(dataset.get());
tmp3->Finalize();
// tmp4 for coverage
std::shared_ptr<TrainerBase> tmp4;
tmp4 = TrainerFactory::CreateTrainer("MultiTrainer");
tmp4->ResetDataset(dataset.get());
// heter_section_worker test
std::shared_ptr<DeviceWorker> w_0;
w_0 = DeviceWorkerFactory::CreateDeviceWorker("HeterSectionWorker");
w_0->CreateDeviceResource(p3);
w_0->BindingDataFeedMemory();
#endif
}
} // namespace framework
} // namespace paddle
#endif
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#if defined(PADDLE_WITH_PSCORE)
#include <float.h>
#include "paddle/fluid/distributed/service/heter_server.h"
#include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/lodtensor_printer.h"
namespace paddle {
namespace framework {
void SetMicroId(paddle::framework::Scope* scope,
platform::DeviceContext* dev_ctx, const platform::Place& place,
int micro_id) {
// create microbatch_id variable
// and set micro id value
auto* ptr = scope->Var("microbatch_id");
InitializeVariable(ptr, proto::VarType::LOD_TENSOR);
framework::Variable* var = scope->FindVar("microbatch_id");
PADDLE_ENFORCE_EQ(var->IsType<framework::LoDTensor>(), 1,
platform::errors::InvalidArgument(
"the type of microbatch_id should be LoDTensor"));
auto* tensor = var->GetMutable<framework::LoDTensor>();
std::vector<int> dims{1};
tensor->Resize(framework::make_ddim(dims));
void* tensor_data =
tensor->mutable_data(place, framework::proto::VarType::FP32);
if (platform::is_gpu_place(place)) {
#ifdef PADDLE_WITH_CUDA
std::vector<char> temp;
temp.resize(tensor->numel() * framework::SizeOfType(tensor->type()));
char* temp_ptr = temp.data();
float* temp_ptr_float = reinterpret_cast<float*>(temp_ptr);
temp_ptr_float[0] = micro_id;
auto stream =
reinterpret_cast<const platform::CUDADeviceContext&>(*dev_ctx).stream();
memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, place), tensor_data,
platform::CPUPlace(), reinterpret_cast<void*>(temp_ptr),
tensor->numel() * framework::SizeOfType(tensor->type()),
stream);
#endif
} else {
float* temp_ptr = reinterpret_cast<float*>(tensor_data);
temp_ptr[0] = micro_id;
}
}
class TrainerDesc;
uint64_t HeterSectionWorker::batch_id_(0);
void HeterSectionWorker::Initialize(const TrainerDesc& desc) {
trainer_desc_ = desc;
fetch_config_ = desc.fetch_config();
dev_ctx_ = platform::DeviceContextPool::Instance().Get(place_);
program_.reset(new ProgramDesc(
desc.heter_section_param().section_config().program_desc()));
thread_queue_.reset(
new ::paddle::framework::BlockingQueue<std::pair<std::string, int>>());
bool is_first_stage = (pipeline_stage_ == 0);
bool is_last_stage = (pipeline_stage_ + 1 == num_pipeline_stages_);
if (is_first_stage) {
for (auto& op_desc : program_->Block(0).AllOps()) {
auto op = std::move(OpRegistry::CreateOp(*op_desc));
auto op_type = op->Type();
if (listen_op_ == nullptr && op_type == "heter_listen_and_serv") {
listen_op_ = std::move(op);
} else {
forward_ops_.push_back(std::move(op));
}
}
for (auto& op_desc : program_->Block(1).AllOps()) {
backward_ops_.push_back(OpRegistry::CreateOp(*op_desc));
}
} else if (is_last_stage) {
for (auto& op_desc : program_->Block(0).AllOps()) {
if (listen_op_ == nullptr) {
listen_op_ = std::move(OpRegistry::CreateOp(*op_desc));
}
}
for (auto& op_desc : program_->Block(1).AllOps()) {
auto op = std::move(OpRegistry::CreateOp(*op_desc));
int op_role = op->Attr<int>(std::string("op_role"));
bool is_forward_op = (op_role == static_cast<int>(OpRole::kForward)) ||
(op_role == (static_cast<int>(OpRole::kForward) |
static_cast<int>(OpRole::kLoss))) ||
(op_role == static_cast<int>(OpRole::kLRSched));
if (is_forward_op) {
forward_ops_.push_back(std::move(op));
} else {
backward_ops_.push_back(std::move(op));
}
}
} else {
for (auto& op_desc : program_->Block(0).AllOps()) {
if (listen_op_ == nullptr) {
listen_op_ = std::move(OpRegistry::CreateOp(*op_desc));
}
}
for (auto& op_desc : program_->Block(1).AllOps()) {
forward_ops_.push_back(OpRegistry::CreateOp(*op_desc));
}
for (auto& op_desc : program_->Block(2).AllOps()) {
backward_ops_.push_back(OpRegistry::CreateOp(*op_desc));
}
}
}
void HeterSectionWorker::RunBackward(int micro_id) {
for (size_t i = 0; i < backward_ops_.size(); i++) {
auto& op = backward_ops_[i];
VLOG(3) << "Backward: start to run op " << op->Type() << " for micro-batch "
<< micro_id;
if (debug_) {
timeline_.Start();
}
op->Run(*((*microbatch_scopes_)[micro_id]), place_);
dev_ctx_->Wait();
if (debug_) {
timeline_.Pause();
int offset = forward_ops_.size();
op_total_time_[i + offset] += timeline_.ElapsedSec();
total_time_ += timeline_.ElapsedSec();
}
VLOG(3) << "Backward: finish running op " << op->Type()
<< " for micro-batch " << micro_id;
}
}
void HeterSectionWorker::MiniBatchBarrier() {
// get micro id & deserialize data
std::set<int> micro_ids;
while (micro_ids.size() < micro_ids_.size()) {
auto task = (*thread_queue_).Pop();
auto message_name = task.first;
auto micro_id = task.second;
PADDLE_ENFORCE_EQ(message_name.find("backward") != std::string::npos, true,
platform::errors::InvalidArgument(
"cpu trainers only receive backward data"));
    PADDLE_ENFORCE_EQ(
        micro_ids.find(micro_id) == micro_ids.end(), true,
        platform::errors::InvalidArgument(
            "received duplicate micro_id in MiniBatchBarrier"));
micro_ids.insert(micro_id);
// backward data has been deserialized to micro scope
// now run backward computation
RunBackward(micro_id);
batch_num_++;
BatchPostProcess();
}
micro_ids_.clear();
}
void HeterSectionWorker::RunListen() { listen_op_->Run(*root_scope_, place_); }
void HeterSectionWorker::RunForward(int micro_id) {
if (pipeline_stage_ == 0) {
BindingDataFeedMemory(micro_id);
if (debug_) {
timeline_.Start();
}
int cur_micro_batch = device_reader_->Next();
if (cur_micro_batch <= 0) {
epoch_finish_ = true;
return;
}
if (debug_) {
timeline_.Pause();
read_time_ += timeline_.ElapsedSec();
total_time_ += timeline_.ElapsedSec();
total_ins_num_ += cur_micro_batch;
}
VLOG(3) << "read a batch in thread " << thread_id_ << " micro " << micro_id;
}
for (size_t i = 0; i < forward_ops_.size(); i++) {
auto& op = forward_ops_[i];
VLOG(3) << "Forward: start to run op " << op->Type() << " for micro-batch "
<< micro_id;
if (debug_) {
timeline_.Start();
}
op->Run(*((*microbatch_scopes_)[micro_id]), place_);
dev_ctx_->Wait();
if (debug_) {
timeline_.Pause();
op_total_time_[i] += timeline_.ElapsedSec();
total_time_ += timeline_.ElapsedSec();
}
VLOG(3) << "Forward: finish running op " << op->Type()
<< " for micro-batch " << micro_id;
}
}
void HeterSectionWorker::BindingDataFeedMemory(int micro_id) {
const std::vector<std::string>& input_feed =
device_reader_->GetUseSlotAlias();
for (auto name : input_feed) {
device_reader_->AddFeedVar((*microbatch_scopes_)[micro_id]->FindVar(name),
name);
}
}
void HeterSectionWorker::CreateMicrobatchScopes() {
PADDLE_ENFORCE_NOT_NULL(
minibatch_scope_,
platform::errors::InvalidArgument(
"minibatch_scope_ can not be nullptr when create MicroBatch Scopes"));
microbatch_scopes_.reset(new std::vector<paddle::framework::Scope*>{});
(*microbatch_scopes_).resize(num_microbatches_);
VLOG(3) << "Create microbatch scopes...";
std::shared_ptr<framework::ProgramDesc> program;
program.reset(new ProgramDesc(
trainer_desc_.heter_section_param().section_config().program_desc()));
for (int j = 0; j < num_microbatches_; ++j) {
(*microbatch_scopes_)[j] = &minibatch_scope_->NewScope();
CopyParameters(j, *program, place_);
}
}
void HeterSectionWorker::CopyParameters(int microbatch_id,
const ProgramDesc& program,
const platform::Place& place) {
auto& global_block = program.Block(0);
auto var_list = global_block.AllVars();
if (program.Size() > 1) {
auto& heter_block = program.Block(1);
auto heter_var_list = heter_block.AllVars();
var_list.insert(var_list.end(), heter_var_list.begin(),
heter_var_list.end());
}
if (program.Size() > 2) {
auto& heter_block = program.Block(2);
auto heter_var_list = heter_block.AllVars();
var_list.insert(var_list.end(), heter_var_list.begin(),
heter_var_list.end());
}
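  // The global micro-batch id folds (thread_id_, microbatch_id) into one
  // integer; thread k owns the range [10 * k, 10 * k + 9], so the encoding
  // stays collision-free only while num_microbatches_ <= 10.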
auto global_micro_id = thread_id_ * 10 + microbatch_id;
SetMicroId((*microbatch_scopes_)[microbatch_id], dev_ctx_, place,
global_micro_id);
for (auto& var : var_list) {
if (var->Persistable() && microbatch_id == 0) {
if (root_scope_->FindVar(var->Name()) != nullptr) continue;
auto* ptr = root_scope_->Var(var->Name());
InitializeVariable(ptr, var->GetType());
VLOG(5) << "Create persistable var: " << var->Name()
<< ", which pointer is " << ptr;
} else if (!var->Persistable()) {
auto* ptr = (*microbatch_scopes_)[microbatch_id]->Var(var->Name());
VLOG(5) << "Create variable " << var->Name() << " for microbatch "
<< microbatch_id << ", which pointer is " << ptr;
InitializeVariable(ptr, var->GetType());
}
}
}
void HeterSectionWorker::Run() {
if (debug_) {
size_t total_ops_size = forward_ops_.size() + backward_ops_.size();
    op_name_.reserve(total_ops_size);  // names are pushed back below; resize
                                       // would leave empty leading entries
op_total_time_.resize(total_ops_size);
platform::SetNumThreads(1);
// forward op + backward op
for (auto& op : forward_ops_) {
op_name_.push_back(op->Type());
}
for (auto& op : backward_ops_) {
op_name_.push_back(op->Type());
}
for (size_t i = 0; i < op_total_time_.size(); ++i) {
op_total_time_[i] = 0.0;
}
}
bool is_first_stage = (pipeline_stage_ == 0);
bool is_last_stage = (pipeline_stage_ + 1 == num_pipeline_stages_);
if (is_first_stage) { // for cpu trainer
while (!epoch_finish_) {
// forward
for (int i = 0; i < num_microbatches_; i++) {
VLOG(5) << "Run " << i << " microbatch";
RunForward(i);
if (epoch_finish_ == true) {
break;
}
micro_ids_.push_back(i);
}
// backward
if (micro_ids_.size() > 0) {
MiniBatchBarrier();
}
}
} else { // for heter worker
auto heter_server = paddle::distributed::HeterServer::GetInstance();
while (true) {
if (heter_server->IsStop()) {
epoch_finish_ = true;
break;
}
auto task = (*thread_queue_).Pop();
auto message_name = task.first;
auto micro_id = task.second;
if (is_last_stage) {
        PADDLE_ENFORCE_EQ(message_name.find("forward") != std::string::npos,
                          true,
                          platform::errors::InvalidArgument(
                              "the last stage only receives forward data"));
RunForward(micro_id);
RunBackward(micro_id);
batch_num_++;
BatchPostProcess();
} else {
if (message_name.find("forward") != std::string::npos) {
RunForward(micro_id);
} else if (message_name.find("backward") != std::string::npos) {
RunBackward(micro_id);
batch_num_++;
BatchPostProcess();
}
}
}
}
}
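// Scheduling recap for Run(): stage 0 drives the data feed and issues a
// forward pass per micro batch before blocking in MiniBatchBarrier(); the
// last pipeline stage runs forward and backward back to back for every
// micro batch it receives; intermediate stages dispatch on the message name
// and run either the forward or the backward ops for that micro batch.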
void HeterSectionWorker::BatchPostProcess() {
PrintFetchVars();
// dump param & field
if (need_dump_field_) {
DumpField(*((*microbatch_scopes_)[0]), dump_mode_, dump_interval_);
}
if (need_dump_param_ && thread_id_ == 0) {
DumpParam(*((*microbatch_scopes_)[0]), batch_num_);
}
// print each op time
if (thread_id_ == 0) {
size_t total_ops_size = forward_ops_.size() + backward_ops_.size();
if (batch_num_ > 0 && batch_num_ % 100 == 0) {
for (size_t i = 0; i < total_ops_size; ++i) {
fprintf(stderr, "op_name:[%zu][%s], op_mean_time:[%fs]\n", i,
op_name_[i].c_str(), op_total_time_[i] / batch_num_);
}
if (pipeline_stage_ == 0) {
fprintf(stderr, "mean read time: %fs\n", read_time_ / batch_num_);
fprintf(stderr, "IO percent: %f\n", read_time_ / total_time_ * 100);
}
fprintf(stderr, "%6.2f instances/s\n", total_ins_num_ / total_time_);
}
}
}
void HeterSectionWorker::TrainFiles() {
total_ins_num_ = 0;
batch_num_ = 0;
platform::SetNumThreads(1);
timeline_.Start();
VLOG(3) << "begin section_worker TrainFiles";
epoch_finish_ = false;
if (pipeline_stage_ == 0) {
device_reader_->Start();
}
while (!epoch_finish_) {
Run();
dev_ctx_->Wait();
}
timeline_.Pause();
VLOG(3) << "worker " << thread_id_ << " train cost " << timeline_.ElapsedSec()
<< " seconds, ins_num: " << total_ins_num_;
}
void HeterSectionWorker::PrintFetchVars() {
// call count
int batch_per_print = fetch_config_.print_period();
int fetch_var_num = fetch_config_.fetch_var_names_size();
if (fetch_var_num == 0) {
return;
}
if (thread_id_ == 0 && batch_num_ % batch_per_print == 0) {
time_t curtime;
time(&curtime);
char mbstr[80];
std::strftime(mbstr, sizeof(mbstr), "%Y-%m-%d %H:%M:%S",
std::localtime(&curtime));
std::stringstream ss;
ss << "time: [" << mbstr << "], ";
ss << "batch: [" << batch_num_ << "], ";
for (int i = 0; i < fetch_var_num; ++i) {
platform::PrintVar((*microbatch_scopes_)[0],
fetch_config_.fetch_var_names(i),
fetch_config_.fetch_var_str_format(i), &ss);
if (i < fetch_var_num - 1) {
ss << ", ";
}
}
std::cout << ss.str() << std::endl;
}
}
void HeterSectionWorker::TrainFilesWithProfiler() {
VLOG(3) << "begin section_worker TrainFilesWithProfiler";
batch_num_ = 0;
epoch_finish_ = false;
total_ins_num_ = 0;
op_name_.clear();
op_total_time_.clear();
if (pipeline_stage_ == 0) {
device_reader_->Start();
}
while (!epoch_finish_) {
Run();
dev_ctx_->Wait();
if (epoch_finish_) {
// dump param for debug
if (need_dump_field_ || need_dump_param_) {
writer_.Flush();
}
}
}
}
} // namespace framework
} // namespace paddle
#endif
@@ -27,7 +27,6 @@ limitations under the License. */
#include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/framework/fleet/heter_context.h"
-//#include "paddle/fluid/framework/fleet/heter_wrapper.h"
#include "paddle/fluid/framework/heter_util.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/program_desc.h"
@@ -72,6 +71,7 @@ class TrainerBase {
  virtual Scope* GetWorkerScope(int thread_id) = 0;
  virtual void InitDumpEnv() = 0;
  virtual void DumpWork(int tid);
  virtual void ResetDataset(Dataset* dataset_ptr) {}
 protected:
  virtual std::string GetDumpPath(int tid) = 0;
@@ -263,7 +263,7 @@ class PSGPUTrainer : public TrainerBase {
    new (&program_) ProgramDesc(main_program);
  }
  virtual std::string GetDumpPath(int tid);
-  virtual void InitDumpEnv() override;
  void InitDumpEnv() override;
  virtual void MergeDenseParam();
  template <typename T>
@@ -325,5 +325,56 @@ class PipelineTrainer : public TrainerBase {
};
#endif
#if defined(PADDLE_WITH_PSCORE)
class HeterPipelineTrainer : public TrainerBase {
public:
HeterPipelineTrainer() {}
~HeterPipelineTrainer() override {}
void Initialize(const TrainerDesc& trainer_desc, Dataset* data_set) override;
void InitTrainerEnv(const ProgramDesc& main_program,
const platform::Place& place) override;
void InitOtherEnv(const ProgramDesc& main_program) override;
void Run() override;
void Finalize() override;
Scope* GetWorkerScope(int thread_id) override;
void InitDumpEnv() override;
std::string GetDumpPath(int tid) override;
void ResetDataset(Dataset* dataset_ptr) override;
protected:
int trainer_id_; // stage_trainer_id
std::vector<int> trainers_; // std::vector<int> trainers
int thread_num_;
std::vector<std::thread> threads_;
std::vector<std::string> need_merge_var_names_;
#ifdef PADDLE_WITH_HETERPS
std::vector<platform::Place> places_;
#endif
int num_microbatches_;
platform::Place place_;
TrainerDesc trainer_desc_;
int num_pipeline_stages_;
int pipeline_stage_;
std::unordered_map<int, std::shared_ptr<paddle::framework::DeviceWorker>>
workers_;
std::shared_ptr<std::unordered_map<
int, std::shared_ptr<::paddle::framework::BlockingQueue<
std::pair<std::string, int>>>>>
task_queue_;
platform::DeviceContext* dev_ctx_ = nullptr;
std::shared_ptr<std::unordered_map<int, Scope*>> mini_scopes_;
std::shared_ptr<std::unordered_map<int, std::shared_ptr<std::vector<Scope*>>>>
micro_scopes_;
std::unique_ptr<std::thread> listen_ptr_ = nullptr;
};
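// mini_scopes_, micro_scopes_ and task_queue_ are all keyed by the worker
// thread id; micro_scopes_[tid] holds one Scope per micro batch, and the
// matching queue carries the (message_name, micro_id) pairs that
// HeterSectionWorker::Run() pops to choose between forward and backward.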
#endif
} // namespace framework
} // namespace paddle
@@ -63,11 +63,15 @@ message TrainerDesc {
  optional string user_define_dump_filename = 33;
  optional bool scale_sparse_gradient_with_batch_size = 34 [ default = true ];
  repeated int32 trainers = 35;
  optional int32 trainer_id = 36;
  // device worker parameters
  optional HogwildWorkerParameter hogwild_param = 101;
  optional DownpourWorkerParameter downpour_param = 103;
  optional PullDenseWorkerParameter pull_dense_param = 102;
  optional SectionWorkerParameter section_param = 104;
  optional HeterSectionWorkerParameter heter_section_param = 105;
  // datafeed desc
  optional DataFeedDesc data_desc = 201;
}
@@ -99,6 +103,17 @@ message SectionWorkerParameter {
  optional int32 schedule_mode = 9 [ default = 0 ];
}
message HeterSectionWorkerParameter {
optional SectionConfig section_config = 1;
optional int32 queue_size = 2 [ default = 1 ];
optional int64 sync_steps = 3 [ default = 1 ];
optional int32 start_cpu_core_id = 4 [ default = 1 ];
repeated string param_need_sync = 5;
optional int32 num_microbatches = 6;
optional int32 num_pipeline_stages = 7 [ default = 1 ];
optional int32 pipeline_stage = 8 [ default = 1 ];
}
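// A minimal text-format sketch of how these fields might be filled in
// (illustrative values only, not taken from this commit):
//
//   heter_section_param {
//     section_config { ... }
//     num_microbatches: 4
//     num_pipeline_stages: 2
//     pipeline_stage: 0
//   }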
message SectionConfig {
  enum Place {
    CPUPlace = 0;
@@ -66,6 +66,11 @@ std::shared_ptr<TrainerBase> TrainerFactory::CreateTrainer(
REGISTER_TRAINER_CLASS(MultiTrainer);
REGISTER_TRAINER_CLASS(DistMultiTrainer);
#if defined(PADDLE_WITH_PSCORE)
REGISTER_TRAINER_CLASS(HeterPipelineTrainer);
#endif
#if (defined PADDLE_WITH_CUDA || defined PADDLE_WITH_HIP || \
     defined PADDLE_WITH_XPU) && \
    (defined PADDLE_WITH_PSLIB)
@@ -29,5 +29,11 @@ set(OPERATOR_DEPS ${OPERATOR_DEPS} ${DISTRIBUTE_DEPS} PARENT_SCOPE)
set_source_files_properties(heter_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(heter_server_test SRCS heter_server_test.cc DEPS ${RPC_DEPS} ${DISTRIBUTE_DEPS} executor scope proto_desc scale_op eigen_function)
set_source_files_properties(send_and_recv_op_cpu_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(send_and_recv_cpu_test SRCS send_and_recv_op_cpu_test.cc DEPS executor scope proto_desc scale_op send_and_recv_op ${RPC_DEPS} ${DISTRIBUTE_DEPS} eigen_function)
set_source_files_properties(send_and_recv_op_gpu_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(send_and_recv_gpu_test SRCS send_and_recv_op_gpu_test.cc DEPS executor scope proto_desc scale_op send_and_recv_op ${RPC_DEPS} ${DISTRIBUTE_DEPS} eigen_function)
set_source_files_properties(heter_listen_and_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(heter_listen_and_server_test SRCS heter_listen_and_server_test.cc DEPS executor scope proto_desc scale_op heter_listen_and_serv_op ${RPC_DEPS} ${DISTRIBUTE_DEPS} eigen_function)
@@ -49,9 +49,7 @@ HeterListenAndServOp::~HeterListenAndServOp() { Stop(); }
void HeterListenAndServOp::Stop() {}
-void HeterListenAndServOp::RunAsyncLoop(framework::Executor *executor,
-                                        framework::ProgramDesc *program,
-                                        framework::Scope *recv_scope) const {
void HeterListenAndServOp::RunAsyncLoop(framework::ProgramDesc *program) const {
  VLOG(2) << "RunAsyncLoop";
  auto message_to_block_id_str =
      Attr<std::vector<std::string>>("message_to_block_id");
@@ -90,28 +88,6 @@ void HeterListenAndServOp::RunAsyncLoop(framework::Executor *executor,
  for (size_t blkid = 1; blkid < num_blocks; ++blkid) {
    block_list.push_back(blkid);
  }
-  auto optimize_prepared = executor->Prepare(*program, block_list);
-  // execute global block if needed, block id 1 in the program is global
-  // block if it's not bind to a grad var for it's update.
-  if (block_list[0] == 1 &&
-      message_to_block_id.find_value(static_cast<int32_t>(1)) ==
-          message_to_block_id.end()) {
-    executor->RunPreparedContext(optimize_prepared[0].get(), recv_scope);
-  }
-  std::unordered_map<std::string,
-                     std::shared_ptr<framework::ExecutorPrepareContext>>
-      message_to_prepared_ctx;
-  for (size_t i = 0; i < block_list.size(); ++i) {
-    auto blkid = block_list[i];
-    auto it = message_to_block_id.find_value(blkid);
-    if (it != message_to_block_id.end()) {
-      message_to_prepared_ctx[it->first] = optimize_prepared[i];
-    }
-  }
-  request_send_and_recv_handler_->SetGradToPreparedCtx(
-      &message_to_prepared_ctx);
  for (size_t i = 0; i < block_list.size(); ++i) {
    auto blkid = block_list[i];
@@ -125,7 +101,7 @@ void HeterListenAndServOp::RunAsyncLoop(framework::Executor *executor,
  }
  while (true) {
-    if (rpc_service_->IsExit()) {
    if (rpc_service_->IsExit() || rpc_service_->IsStop()) {
      rpc_service_->Stop();
      VLOG(0) << "get exit. rpc_processor stop!";
      break;
@@ -145,7 +121,6 @@ void HeterListenAndServOp::RunImpl(const framework::Scope &scope,
  auto &dev_ctx = *pool.Get(dev_place);
  VLOG(1) << "HeterListenAndServOp::RunImpl On gpu? "
          << platform::is_gpu_place(dev_place);
-  framework::Scope &recv_scope = scope.NewScope();
  auto pserver_id = Attr<int>("pserver_id");
  auto fan_in = Attr<int>("fanin");
@@ -154,8 +129,8 @@ void HeterListenAndServOp::RunImpl(const framework::Scope &scope,
  PADDLE_ENFORCE_EQ(rpc_service_, nullptr,
                    platform::errors::PreconditionNotMet(
                        "RPC service has been created unexpectedly."));
  std::string endpoint = Attr<std::string>("endpoint");
  VLOG(4) << "pserver_id: " << pserver_id << ", end_point:" << endpoint;
  rpc_service_ = distributed::HeterServer::GetInstance();
@@ -168,15 +143,14 @@ void HeterListenAndServOp::RunImpl(const framework::Scope &scope,
                    platform::errors::PreconditionNotMet(
                        "optimize blocks is less than 1. Optimize blocks "
                        "should be 1 at least on the pserver side."));
  auto *program = optimize_blocks[0]->Program();
-  framework::Executor executor(dev_place);
  request_send_and_recv_handler_.reset(
      new distributed::RequestSendAndRecvHandler());
-  request_send_and_recv_handler_->SetScope(&recv_scope);
  request_send_and_recv_handler_->SetScope(&scope);
  request_send_and_recv_handler_->SetDevCtx(&dev_ctx);
-  request_send_and_recv_handler_->SetProgram(program);
-  request_send_and_recv_handler_->SetExecutor(&executor);
  rpc_service_->SetRequestHandler(request_send_and_recv_handler_);
  VLOG(2) << "RunAsyncLoop";
  auto message_to_block_id_str =
@@ -186,7 +160,7 @@ void HeterListenAndServOp::RunImpl(const framework::Scope &scope,
  server_thread_.reset(new std::thread(RunServer, rpc_service_));
  VLOG(3) << "wait server thread to become ready...";
  rpc_service_->WaitServerReady();
-  RunAsyncLoop(&executor, program, &recv_scope);
  RunAsyncLoop(program);
  VLOG(3) << "Wait for Server_thread_ stop";
  (server_thread_.get())->join();
  VLOG(3) << "Server_thread_ stop";
@@ -77,9 +77,7 @@ class HeterListenAndServOp : public framework::OperatorBase {
                       const framework::AttributeMap& attrs);
  virtual ~HeterListenAndServOp();
-  void RunAsyncLoop(framework::Executor* executor,
-                    framework::ProgramDesc* program,
-                    framework::Scope* recv_scope) const;
  void RunAsyncLoop(framework::ProgramDesc* program) const;
  void Stop() override;
@@ -89,7 +87,7 @@ class HeterListenAndServOp : public framework::OperatorBase {
 protected:
  mutable std::shared_ptr<paddle::distributed::HeterServer> rpc_service_;
  mutable std::shared_ptr<std::thread> server_thread_;
-  mutable std::shared_ptr<paddle::distributed::HeterRequestHandler>
  mutable std::shared_ptr<paddle::distributed::RequestSendAndRecvHandler>
      request_send_and_recv_handler_;
};
@@ -19,6 +19,7 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/service/heter_client.h"
#include "paddle/fluid/distributed/service/heter_server.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/op_registry.h"
@@ -76,6 +77,9 @@ void CreateVarsOnScope(framework::Scope* scope, platform::CPUPlace* place) {
  auto x_var = scope->Var("x");
  x_var->GetMutable<framework::LoDTensor>();
  auto micro_var = scope->Var("microbatch_id");
  micro_var->GetMutable<framework::LoDTensor>();
  auto res_var = scope->Var("res");
  res_var->GetMutable<framework::LoDTensor>();
}
@@ -88,6 +92,32 @@ void InitTensorsOnClient(framework::Scope* scope, platform::CPUPlace* place,
      x_var->mutable_data<float>(framework::DDim({1, rows_numel}), *place);
  for (int64_t i = 0; i < rows_numel; ++i) x_ptr[i] = 1.0;
auto micro_id_var =
scope->Var("microbatch_id")->GetMutable<framework::LoDTensor>();
float* micro_id_ptr =
micro_id_var->mutable_data<float>(framework::DDim({1}), *place);
micro_id_ptr[0] = 0;
auto res_var = scope->Var("res")->GetMutable<framework::LoDTensor>();
float* res_ptr =
res_var->mutable_data<float>(framework::DDim({1, rows_numel}), *place);
for (int64_t i = 0; i < rows_numel; ++i) res_ptr[i] = 1.0;
}
void InitTensorsOnClient2(framework::Scope* scope, platform::CPUPlace* place,
int64_t rows_numel) {
CreateVarsOnScope(scope, place);
auto x_var = scope->Var("x")->GetMutable<framework::LoDTensor>();
float* x_ptr =
x_var->mutable_data<float>(framework::DDim({1, rows_numel}), *place);
for (int64_t i = 0; i < rows_numel; ++i) x_ptr[i] = 1.0;
auto micro_id_var =
scope->Var("microbatch_id")->GetMutable<framework::LoDTensor>();
float* micro_id_ptr =
micro_id_var->mutable_data<float>(framework::DDim({1}), *place);
micro_id_ptr[0] = 1;
auto res_var = scope->Var("res")->GetMutable<framework::LoDTensor>(); auto res_var = scope->Var("res")->GetMutable<framework::LoDTensor>();
float* res_ptr = float* res_ptr =
res_var->mutable_data<float>(framework::DDim({1, rows_numel}), *place); res_var->mutable_data<float>(framework::DDim({1, rows_numel}), *place);
...@@ -121,45 +151,78 @@ TEST(HETER_LISTEN_AND_SERV, CPU) { ...@@ -121,45 +151,78 @@ TEST(HETER_LISTEN_AND_SERV, CPU) {
setenv("http_proxy", "", 1); setenv("http_proxy", "", 1);
setenv("https_proxy", "", 1); setenv("https_proxy", "", 1);
std::string endpoint = "127.0.0.1:19944"; std::string endpoint = "127.0.0.1:19944";
std::string previous_endpoint = "127.0.0.1:19944";
LOG(INFO) << "before StartSendAndRecvServer"; LOG(INFO) << "before StartSendAndRecvServer";
FLAGS_eager_delete_tensor_gb = -1; FLAGS_eager_delete_tensor_gb = -1;
std::thread server_thread(StartHeterServer); std::thread server_thread(StartHeterServer);
sleep(1); sleep(1);
auto b_rpc_service = distributed::HeterServer::GetInstance();
b_rpc_service->WaitServerReady();
using MicroScope =
std::unordered_map<int, std::shared_ptr<std::vector<framework::Scope*>>>;
std::shared_ptr<MicroScope> micro_scopes(new MicroScope{});
std::shared_ptr<std::vector<framework::Scope*>> micro_scope(
new std::vector<framework::Scope*>{});
(*micro_scope).push_back(new framework::Scope());
(*micro_scope).push_back(new framework::Scope());
(*micro_scopes)[0] = micro_scope;
b_rpc_service->SetMicroBatchScopes(micro_scopes);
using TaskQueue =
std::unordered_map<int,
std::shared_ptr<::paddle::framework::BlockingQueue<
std::pair<std::string, int>>>>;
using SharedTaskQueue = std::shared_ptr<std::unordered_map<
int, std::shared_ptr<::paddle::framework::BlockingQueue<
std::pair<std::string, int>>>>>;
SharedTaskQueue task_queue_(new TaskQueue{});
(*task_queue_)[0] = std::make_shared<
::paddle::framework::BlockingQueue<std::pair<std::string, int>>>();
b_rpc_service->SetTaskQueue(task_queue_);
LOG(INFO) << "before HeterClient::GetInstance"; LOG(INFO) << "before HeterClient::GetInstance";
distributed::HeterClient* rpc_client = distributed::HeterClient* rpc_client =
distributed::HeterClient::GetInstance({endpoint}, 0).get(); distributed::HeterClient::GetInstance({endpoint}, {previous_endpoint}, 0)
.get();
PADDLE_ENFORCE_NE(rpc_client, nullptr, PADDLE_ENFORCE_NE(rpc_client, nullptr,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Client Start Fail, Check Your Code & Env")); "Client Start Fail, Check Your Code & Env"));
framework::Scope scope; framework::Scope* scope = (*micro_scope)[0];
platform::CPUPlace place; platform::CPUPlace place;
platform::CPUDeviceContext ctx(place); platform::CPUDeviceContext ctx(place);
// create var on local scope // create var on local scope
int64_t rows_numel = 10; int64_t rows_numel = 10;
LOG(INFO) << "before InitTensorsOnClient"; LOG(INFO) << "before InitTensorsOnClient";
InitTensorsOnClient(&scope, &place, rows_numel); InitTensorsOnClient(scope, &place, rows_numel);
std::string in_var_name("x"); std::string in_var_name("x");
std::string micro_var_name("microbatch_id");
std::string out_var_name("res"); std::string out_var_name("res");
std::vector<std::string> send_var = {in_var_name}; std::vector<std::string> send_var = {in_var_name, micro_var_name};
std::vector<std::string> recv_var = {out_var_name}; std::vector<std::string> recv_var = {};
LOG(INFO) << "before SendAndRecvAsync"; LOG(INFO) << "before SendAndRecvAsync";
rpc_client->SendAndRecvAsync({endpoint}, ctx, scope, in_var_name, send_var, rpc_client->SendAndRecvAsync(ctx, *scope, in_var_name, send_var, recv_var,
recv_var); "forward");
auto var = scope.Var(out_var_name); auto task = (*task_queue_)[0]->Pop();
auto value = var->GetMutable<framework::LoDTensor>(); PADDLE_ENFORCE_EQ(
auto ptr = value->mutable_data<float>(place); task.first, "x",
platform::errors::InvalidArgument(
LOG(INFO) << "before CHECK"; "Recv message and Send message name not match, Check your Code"));
for (int64_t i = 0; i < rows_numel; ++i) {
LOG(INFO) << "ptr " << i << " is " << ptr[i]; InitTensorsOnClient2((*micro_scope)[1], &place, rows_numel);
EXPECT_EQ(ptr[i], 0.5); LOG(INFO) << "before SendAndRecvAsync 2";
} rpc_client->SendAndRecvAsync(ctx, *((*micro_scope)[1]), in_var_name, send_var,
LOG(INFO) << "end CHECK"; recv_var, "backward");
auto task2 = (*task_queue_)[0]->Pop();
PADDLE_ENFORCE_EQ(
task2.first, "x",
platform::errors::InvalidArgument(
"Recv message and Send message name not match, Check your Code"));
rpc_client->Stop(); rpc_client->Stop();
LOG(INFO) << "end server Stop"; LOG(INFO) << "end server Stop";
server_thread.join(); server_thread.join();
@@ -57,6 +57,9 @@ void CreateVarsOnScope(framework::Scope* scope, platform::CPUPlace* place) {
  auto out_var = scope->Var("out");
  out_var->GetMutable<framework::LoDTensor>();
auto micro_var = scope->Var("microbatch_id");
micro_var->GetMutable<framework::LoDTensor>();
auto ids_var = scope->Var("ids"); auto ids_var = scope->Var("ids");
ids_var->GetMutable<framework::LoDTensor>(); ids_var->GetMutable<framework::LoDTensor>();
...@@ -75,6 +78,37 @@ void InitTensorsOnClient(framework::Scope* scope, platform::CPUPlace* place, ...@@ -75,6 +78,37 @@ void InitTensorsOnClient(framework::Scope* scope, platform::CPUPlace* place,
ids_var->mutable_data<int64_t>(framework::DDim({rows_numel, 1}), *place); ids_var->mutable_data<int64_t>(framework::DDim({rows_numel, 1}), *place);
for (int64_t i = 0; i < rows_numel; ++i) ids_ptr[i] = i * 2; for (int64_t i = 0; i < rows_numel; ++i) ids_ptr[i] = i * 2;
auto micro_id_var =
scope->Var("microbatch_id")->GetMutable<framework::LoDTensor>();
float* micro_id_ptr =
micro_id_var->mutable_data<float>(framework::DDim({1}), *place);
micro_id_ptr[0] = 0;
auto x_var = scope->Var("x")->GetMutable<framework::LoDTensor>();
float* x_ptr =
x_var->mutable_data<float>(framework::DDim({1, rows_numel}), *place);
for (int64_t i = 0; i < rows_numel; ++i) x_ptr[i] = 1.0;
auto res_var = scope->Var("res")->GetMutable<framework::LoDTensor>();
float* res_ptr =
res_var->mutable_data<float>(framework::DDim({1, rows_numel}), *place);
for (int64_t i = 0; i < rows_numel; ++i) res_ptr[i] = 1.0;
}
void InitTensorsOnClient2(framework::Scope* scope, platform::CPUPlace* place,
int64_t rows_numel) {
CreateVarsOnScope(scope, place);
auto ids_var = scope->Var("ids")->GetMutable<framework::LoDTensor>();
int64_t* ids_ptr =
ids_var->mutable_data<int64_t>(framework::DDim({rows_numel, 1}), *place);
for (int64_t i = 0; i < rows_numel; ++i) ids_ptr[i] = i * 2;
auto micro_id_var =
scope->Var("microbatch_id")->GetMutable<framework::LoDTensor>();
float* micro_id_ptr =
micro_id_var->mutable_data<float>(framework::DDim({1}), *place);
micro_id_ptr[0] = 1;
auto x_var = scope->Var("x")->GetMutable<framework::LoDTensor>(); auto x_var = scope->Var("x")->GetMutable<framework::LoDTensor>();
float* x_ptr = float* x_ptr =
x_var->mutable_data<float>(framework::DDim({1, rows_numel}), *place); x_var->mutable_data<float>(framework::DDim({1, rows_numel}), *place);
@@ -114,29 +148,19 @@ void StartSendAndRecvServer(std::string endpoint) {
  LOG(INFO) << "before AppendSendAndRecvBlock";
  auto block = AppendSendAndRecvBlock(&program);
  std::string in_var_name("x");
  std::string in_var_name2("y");
  std::vector<int> prefetch_block_ids{block->ID()};
-  auto prepared = exe.Prepare(program, prefetch_block_ids);
  LOG(INFO) << "before InitTensorsOnServer";
  InitTensorsOnServer(&scope, &place, 10);
  LOG(INFO) << "end InitTensorsOnServer";
-  std::unordered_map<std::string,
-                     std::shared_ptr<framework::ExecutorPrepareContext>>
-      message_to_prepared_ctx;
-  message_to_prepared_ctx[in_var_name] = prepared[0];
  std::shared_ptr<distributed::RequestSendAndRecvHandler> b_req_handler;
  b_req_handler.reset(new distributed::RequestSendAndRecvHandler());
-  LOG(INFO) << "before SetProgram";
-  b_req_handler->SetProgram(&program);
-  LOG(INFO) << "before SetGradToPreparedCtx";
-  b_req_handler->SetGradToPreparedCtx(&message_to_prepared_ctx);
  LOG(INFO) << "before SetDevCtx";
  b_req_handler->SetDevCtx(&ctx);
  LOG(INFO) << "before SetScope";
  b_req_handler->SetScope(&scope);
-  LOG(INFO) << "before SetExecutor";
-  b_req_handler->SetExecutor(&exe);
  LOG(INFO) << "before HeterServer::GetInstance";
  b_rpc_service = distributed::HeterServer::GetInstance();
  b_rpc_service->SetEndPoint(endpoint);
@@ -146,7 +170,13 @@ void StartSendAndRecvServer(std::string endpoint) {
                     brpc::Controller* cntl) -> int {
        return b_req_handler->Handle(request, response, cntl);
      });
b_rpc_service->RegisterServiceHandler(
in_var_name2, [&](const MultiVarMsg* request, MultiVarMsg* response,
brpc::Controller* cntl) -> int {
return b_req_handler->Handle(request, response, cntl);
});
b_rpc_service->SetRequestHandler(b_req_handler);
LOG(INFO) << "before HeterServer::RunServer"; LOG(INFO) << "before HeterServer::RunServer";
std::thread server_thread(std::bind(RunServer, b_rpc_service)); std::thread server_thread(std::bind(RunServer, b_rpc_service));
...@@ -157,47 +187,82 @@ TEST(SENDANDRECV, CPU) { ...@@ -157,47 +187,82 @@ TEST(SENDANDRECV, CPU) {
setenv("http_proxy", "", 1); setenv("http_proxy", "", 1);
setenv("https_proxy", "", 1); setenv("https_proxy", "", 1);
std::string endpoint = "127.0.0.1:4444"; std::string endpoint = "127.0.0.1:4444";
std::string previous_endpoint = "127.0.0.1:4444";
LOG(INFO) << "before StartSendAndRecvServer"; LOG(INFO) << "before StartSendAndRecvServer";
b_rpc_service = distributed::HeterServer::GetInstance(); b_rpc_service = distributed::HeterServer::GetInstance();
std::thread server_thread(StartSendAndRecvServer, endpoint); std::thread server_thread(StartSendAndRecvServer, endpoint);
b_rpc_service->WaitServerReady(); b_rpc_service->WaitServerReady();
using MicroScope =
std::unordered_map<int, std::shared_ptr<std::vector<framework::Scope*>>>;
std::shared_ptr<MicroScope> micro_scopes(new MicroScope{});
std::shared_ptr<std::vector<framework::Scope*>> micro_scope(
new std::vector<framework::Scope*>{});
(*micro_scope).push_back(new framework::Scope());
(*micro_scope).push_back(new framework::Scope());
(*micro_scopes)[0] = micro_scope;
b_rpc_service->SetMicroBatchScopes(micro_scopes);
using TaskQueue =
std::unordered_map<int,
std::shared_ptr<::paddle::framework::BlockingQueue<
std::pair<std::string, int>>>>;
using SharedTaskQueue = std::shared_ptr<std::unordered_map<
int, std::shared_ptr<::paddle::framework::BlockingQueue<
std::pair<std::string, int>>>>>;
SharedTaskQueue task_queue_(new TaskQueue{});
(*task_queue_)[0] = std::make_shared<
::paddle::framework::BlockingQueue<std::pair<std::string, int>>>();
b_rpc_service->SetTaskQueue(task_queue_);
LOG(INFO) << "before HeterClient::GetInstance"; LOG(INFO) << "before HeterClient::GetInstance";
distributed::HeterClient* rpc_client = distributed::HeterClient* rpc_client =
distributed::HeterClient::GetInstance({endpoint}, 0).get(); distributed::HeterClient::GetInstance({endpoint}, {previous_endpoint}, 0)
.get();
PADDLE_ENFORCE_NE(rpc_client, nullptr, PADDLE_ENFORCE_NE(rpc_client, nullptr,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Client Start Fail, Check Your Code & Env")); "Client Start Fail, Check Your Code & Env"));
framework::Scope scope; framework::Scope* scope = (*micro_scope)[0];
platform::CPUPlace place; platform::CPUPlace place;
platform::CPUDeviceContext ctx(place); platform::CPUDeviceContext ctx(place);
// create var on local scope // create var on local scope
int64_t rows_numel = 10; int64_t rows_numel = 10;
LOG(INFO) << "before InitTensorsOnClient"; LOG(INFO) << "before InitTensorsOnClient";
InitTensorsOnClient(&scope, &place, rows_numel); InitTensorsOnClient(scope, &place, rows_numel);
std::string in_var_name("x"); std::string in_var_name("x");
std::string micro_var_name("microbatch_id");
std::string out_var_name("res"); std::string out_var_name("res");
std::vector<std::string> send_var = {in_var_name}; std::vector<std::string> send_var = {in_var_name, micro_var_name};
std::vector<std::string> recv_var = {out_var_name}; std::vector<std::string> recv_var = {};
LOG(INFO) << "before SendAndRecvAsync"; LOG(INFO) << "before SendAndRecvAsync";
rpc_client->SendAndRecvAsync({endpoint}, ctx, scope, in_var_name, send_var, rpc_client->SendAndRecvAsync(ctx, *scope, in_var_name, send_var, recv_var,
recv_var); "forward");
auto var = scope.Var(out_var_name);
auto value = var->GetMutable<framework::LoDTensor>(); LOG(INFO) << "client wait for Pop";
auto ptr = value->mutable_data<float>(place); auto task = (*task_queue_)[0]->Pop();
LOG(INFO) << "client get from task queue";
LOG(INFO) << "before CHECK"; PADDLE_ENFORCE_EQ(
for (int64_t i = 0; i < rows_numel; ++i) { task.first, "x",
LOG(INFO) << "ptr " << i << " is " << ptr[i]; platform::errors::InvalidArgument(
EXPECT_EQ(ptr[i], 0.5); "Recv message and Send message name not match, Check your Code"));
}
LOG(INFO) << "end CHECK"; InitTensorsOnClient2((*micro_scope)[1], &place, rows_numel);
LOG(INFO) << "before SendAndRecvAsync 2";
std::string in_var_name2("y");
rpc_client->SendAndRecvAsync(ctx, *((*micro_scope)[1]), in_var_name2,
send_var, recv_var, "backward");
LOG(INFO) << "after SendAndRecvAsync 2";
auto task2 = (*task_queue_)[0]->Pop();
PADDLE_ENFORCE_EQ(
task2.first, "y",
platform::errors::InvalidArgument(
"Recv message and Send message name not match, Check your Code"));
rpc_client->FinalizeWorker(); rpc_client->FinalizeWorker();
// b_rpc_service->Stop();
b_rpc_service->Stop(); b_rpc_service->Stop();
LOG(INFO) << "end server Stop"; LOG(INFO) << "end server Stop";
server_thread.join(); server_thread.join();
@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
@@ -34,17 +35,22 @@ class SendAndRecvKernel : public framework::OpKernel<T> {
    auto message_name = ctx.Attr<std::string>("message_name");
    auto send_var_name = ctx.Attr<std::vector<std::string>>("send_var_name");
    auto recv_var_name = ctx.Attr<std::vector<std::string>>("recv_var_name");
-    auto epmap = ctx.Attr<std::vector<std::string>>("endpoints");
    auto next_epmap = ctx.Attr<std::vector<std::string>>("next_endpoints");
    auto previous_epmap =
        ctx.Attr<std::vector<std::string>>("previous_endpoints");
    auto trainer_id = ctx.Attr<int>("trainer_id");
    auto mode = ctx.Attr<std::string>("mode");
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    auto& context = *pool.Get(place);
    distributed::HeterClient* rpc_client =
-        distributed::HeterClient::GetInstance(epmap, trainer_id).get();
        distributed::HeterClient::GetInstance(next_epmap, previous_epmap,
                                              trainer_id)
            .get();
    VLOG(3) << "SendAndRecvOp message_name: " << message_name;
-    rpc_client->SendAndRecvAsync(epmap, context, scope, message_name,
-                                 send_var_name, recv_var_name);
    rpc_client->SendAndRecvAsync(context, scope, message_name, send_var_name,
                                 recv_var_name, mode);
  }
};
@@ -67,11 +73,17 @@ class SendAndRecvOpMaker : public framework::OpProtoAndCheckerMaker {
    AddInput("X", "Tensor Input variable to be sent").AsDuplicable();
    AddOutput("Out", "Tensor Output variable to be recv").AsDuplicable();
    AddAttr<std::string>("message_name", "");
    AddAttr<std::string>("mode", "forward or backward").SetDefault("forward");
    AddAttr<std::vector<std::string>>("send_var_name", "Send Tensor's name");
    AddAttr<std::vector<std::string>>("recv_var_name", "Recv Tensor's name");
    AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
    AddAttr<std::vector<std::string>>("endpoints", "Server endpoint")
        .SetDefault({"127.0.0.1:6164"});
    AddAttr<std::vector<std::string>>("next_endpoints", "Server endpoint")
        .SetDefault({"127.0.0.1:6164"});
    AddAttr<std::vector<std::string>>("previous_endpoints",
                                      "Previous Server endpoint")
        .SetDefault({"127.0.0.1:6164"});
    AddComment(R"DOC(
SendAndRecv operator
This operator will send variables to listen_and_serve op at the parameter server.
@@ -86,7 +98,25 @@ class SendAndRecvOpMaker : public framework::OpProtoAndCheckerMaker {
namespace ops = paddle::operators;
REGISTER_OPERATOR(send_and_recv, ops::SendAndRecvOp, ops::SendAndRecvOpMaker);
REGISTER_OP_CUDA_KERNEL(
send_and_recv,
ops::SendAndRecvKernel<paddle::platform::CUDADeviceContext, float>,
ops::SendAndRecvKernel<paddle::platform::CUDADeviceContext, double>,
ops::SendAndRecvKernel<paddle::platform::CUDADeviceContext, int>,
ops::SendAndRecvKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
    send_and_recv,
-    ops::SendAndRecvKernel<paddle::platform::CPUDeviceContext, float>)
    ops::SendAndRecvKernel<paddle::platform::CPUDeviceContext, float>,
ops::SendAndRecvKernel<paddle::platform::CPUDeviceContext, double>,
ops::SendAndRecvKernel<paddle::platform::CPUDeviceContext, int>,
ops::SendAndRecvKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_VERSION(send_and_recv)
.AddCheckpoint(
R"ROC(add new attributes [next_endpoints] [previous_endpoints] and [mode])ROC",
paddle::framework::compatible::OpVersionDesc()
.NewAttr("next_endpoints", "Server endpoint",
std::vector<std::string>({"127.0.0.1:6164"}))
.NewAttr("previous_endpoints", "Server endpoint",
std::vector<std::string>({"127.0.0.1:6164"}))
.NewAttr("mode", "forward or backward", "forward"));
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <stdlib.h>
#include <memory>
#include <string>
#include <thread> // NOLINT
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/service/heter_client.h"
#include "paddle/fluid/distributed/service/heter_server.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace framework = paddle::framework;
namespace platform = paddle::platform;
namespace distributed = paddle::distributed;
using MultiVarMsg = ::paddle::distributed::MultiVariableMessage;
using VarMsg = ::paddle::distributed::VariableMessage;
USE_OP(scale);
USE_OP(send_and_recv);
std::shared_ptr<distributed::HeterServer> b_rpc_service;
framework::BlockDesc* AppendSendAndRecvBlock(framework::ProgramDesc* program) {
auto root_block = program->MutableBlock(0);
auto* block = program->AppendBlock(*root_block);
framework::OpDesc* op = block->AppendOp();
op->SetType("scale");
op->SetInput("X", {"x"});
op->SetOutput("Out", {"res"});
op->SetAttr("scale", 0.5f);
auto& out = *root_block->Var("res");
out.SetType(framework::proto::VarType::LOD_TENSOR);
out.SetShape({1, 10});
return block;
}
void CreateVarsOnScope(framework::Scope* scope) {
auto w_var = scope->Var("w");
w_var->GetMutable<framework::SelectedRows>();
auto out_var = scope->Var("out");
out_var->GetMutable<framework::LoDTensor>();
auto micro_var = scope->Var("microbatch_id");
micro_var->GetMutable<framework::LoDTensor>();
auto ids_var = scope->Var("ids");
ids_var->GetMutable<framework::LoDTensor>();
auto x_var = scope->Var("x");
x_var->GetMutable<framework::LoDTensor>();
auto res_var = scope->Var("res");
res_var->GetMutable<framework::LoDTensor>();
}
void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place,
int64_t rows_numel) {
CreateVarsOnScope(scope);
auto w = scope->Var("w")->GetMutable<framework::SelectedRows>();
auto w_value = w->mutable_value();
w_value->Resize({rows_numel, 10});
for (int64_t i = 0; i < rows_numel; ++i) w->AutoGrownIndex(i, true);
auto ptr = w_value->mutable_data<float>(*place);
for (int64_t i = 0; i < w_value->numel(); ++i) {
ptr[i] = static_cast<float>(i / 10);
}
}
void InitTensorsOnClient(framework::Scope* scope, platform::CPUPlace* place,
int64_t rows_numel) {
CreateVarsOnScope(scope);
auto ids_var = scope->Var("ids")->GetMutable<framework::LoDTensor>();
int64_t* ids_ptr =
ids_var->mutable_data<int64_t>(framework::DDim({rows_numel, 1}), *place);
for (int64_t i = 0; i < rows_numel; ++i) ids_ptr[i] = i * 2;
auto micro_id_var =
scope->Var("microbatch_id")->GetMutable<framework::LoDTensor>();
float* micro_id_ptr =
micro_id_var->mutable_data<float>(framework::DDim({1}), *place);
micro_id_ptr[0] = 0;
auto x_var = scope->Var("x")->GetMutable<framework::LoDTensor>();
float* x_ptr =
x_var->mutable_data<float>(framework::DDim({1, rows_numel}), *place);
for (int64_t i = 0; i < rows_numel; ++i) x_ptr[i] = 1.0;
auto res_var = scope->Var("res")->GetMutable<framework::LoDTensor>();
float* res_ptr =
res_var->mutable_data<float>(framework::DDim({1, rows_numel}), *place);
for (int64_t i = 0; i < rows_numel; ++i) res_ptr[i] = 1.0;
}
void RunServer(std::shared_ptr<paddle::distributed::HeterServer> service) {
service->StartHeterService();
}
void StartSendAndRecvServer(std::string endpoint) {
framework::ProgramDesc program;
framework::Scope scope;
platform::CPUPlace place;
framework::Executor exe(place);
platform::CPUDeviceContext ctx(place);
LOG(INFO) << "before AppendSendAndRecvBlock";
auto block = AppendSendAndRecvBlock(&program);
std::string in_var_name("x");
// std::string in_var_name2("y");
std::vector<int> prefetch_block_ids{block->ID()};
LOG(INFO) << "before InitTensorsOnServer";
InitTensorsOnServer(&scope, &place, 10);
LOG(INFO) << "end InitTensorsOnServer";
std::shared_ptr<distributed::RequestSendAndRecvHandler> b_req_handler;
b_req_handler.reset(new distributed::RequestSendAndRecvHandler());
LOG(INFO) << "before SetDevCtx";
b_req_handler->SetDevCtx(&ctx);
LOG(INFO) << "before SetScope";
b_req_handler->SetScope(&scope);
LOG(INFO) << "before HeterServer::GetInstance";
b_rpc_service = distributed::HeterServer::GetInstance();
b_rpc_service->SetEndPoint(endpoint);
LOG(INFO) << "before HeterServer::RegisterServiceHandler";
b_rpc_service->RegisterServiceHandler(
in_var_name, [&](const MultiVarMsg* request, MultiVarMsg* response,
brpc::Controller* cntl) -> int {
return b_req_handler->Handle(request, response, cntl);
});
b_rpc_service->SetRequestHandler(b_req_handler);
LOG(INFO) << "before HeterServer::RunServer";
std::thread server_thread(std::bind(RunServer, b_rpc_service));
server_thread.join();
}
TEST(SENDANDRECV, CPU) {
setenv("http_proxy", "", 1);
setenv("https_proxy", "", 1);
std::string endpoint = "127.0.0.1:4444";
std::string previous_endpoint = "127.0.0.1:4444";
LOG(INFO) << "before StartSendAndRecvServer";
b_rpc_service = distributed::HeterServer::GetInstance();
std::thread server_thread(StartSendAndRecvServer, endpoint);
b_rpc_service->WaitServerReady();
using MicroScope =
std::unordered_map<int, std::shared_ptr<std::vector<framework::Scope*>>>;
std::shared_ptr<MicroScope> micro_scopes(new MicroScope{});
std::shared_ptr<std::vector<framework::Scope*>> micro_scope(
new std::vector<framework::Scope*>{});
(*micro_scope).push_back(new framework::Scope());
(*micro_scopes)[0] = micro_scope;
b_rpc_service->SetMicroBatchScopes(micro_scopes);
using TaskQueue =
std::unordered_map<int,
std::shared_ptr<::paddle::framework::BlockingQueue<
std::pair<std::string, int>>>>;
using SharedTaskQueue = std::shared_ptr<std::unordered_map<
int, std::shared_ptr<::paddle::framework::BlockingQueue<
std::pair<std::string, int>>>>>;
SharedTaskQueue task_queue_(new TaskQueue{});
(*task_queue_)[0] = std::make_shared<
::paddle::framework::BlockingQueue<std::pair<std::string, int>>>();
b_rpc_service->SetTaskQueue(task_queue_);
LOG(INFO) << "before HeterClient::GetInstance";
distributed::HeterClient* rpc_client =
distributed::HeterClient::GetInstance({endpoint}, {previous_endpoint}, 0)
.get();
PADDLE_ENFORCE_NE(rpc_client, nullptr,
platform::errors::InvalidArgument(
"Client Start Fail, Check Your Code & Env"));
framework::Scope* scope = (*micro_scope)[0];
platform::CPUPlace place;
platform::CPUDeviceContext ctx(place);
framework::Executor exe(place);
// create var on local scope
int64_t rows_numel = 10;
LOG(INFO) << "before InitTensorsOnClient";
InitTensorsOnClient(scope, &place, rows_numel);
std::string in_var_name("x");
std::string micro_var_name("microbatch_id");
// std::string out_var_name("res");
std::vector<std::string> send_var{in_var_name, micro_var_name};
std::vector<std::string> recv_var{};
std::string mode_str("forward");
LOG(INFO) << "add block & op1";
framework::ProgramDesc program;
auto root_block = program.MutableBlock(0);
// op for forward
framework::OpDesc* op = root_block->AppendOp();
op->SetType("send_and_recv");
LOG(INFO) << "op1 set input";
op->SetInput("X", std::vector<std::string>({in_var_name}));
op->SetOutput("Out", {});
op->SetAttr("next_endpoints", std::vector<std::string>({endpoint}));
op->SetAttr("previous_endpoints",
std::vector<std::string>({previous_endpoint}));
op->SetAttr("trainer_id", 0);
op->SetAttr("mode", mode_str);
op->SetAttr("message_name", in_var_name);
op->SetAttr("send_var_name", send_var);
op->SetAttr("recv_var_name", recv_var);
std::string mode_str2("backward");
// op for backward
LOG(INFO) << "add op2";
framework::OpDesc* op2 = root_block->AppendOp();
op2->SetType("send_and_recv");
LOG(INFO) << "op2 set input";
op2->SetInput("X", std::vector<std::string>({in_var_name}));
op2->SetOutput("Out", {});
op2->SetAttr("next_endpoints", std::vector<std::string>({endpoint}));
op2->SetAttr("previous_endpoints",
std::vector<std::string>({previous_endpoint}));
op2->SetAttr("trainer_id", 0);
op2->SetAttr("mode", mode_str2);
op2->SetAttr("message_name", in_var_name);
op2->SetAttr("send_var_name", send_var);
op2->SetAttr("recv_var_name", recv_var);
LOG(INFO) << "exe before prepare";
auto prepared = exe.Prepare(program, 0);
LOG(INFO) << "exe after prepare";
LOG(INFO) << "before RunPreparedContext";
exe.RunPreparedContext(prepared.get(), scope, false);
LOG(INFO) << "client wait for Pop";
auto task = (*task_queue_)[0]->Pop();
LOG(INFO) << "client get from task queue";
PADDLE_ENFORCE_EQ(
task.first, "x",
platform::errors::InvalidArgument(
"Recv message and Send message name not match, Check your Code"));
auto task2 = (*task_queue_)[0]->Pop();
PADDLE_ENFORCE_EQ(
task2.first, "x",
platform::errors::InvalidArgument(
"Recv message and Send message name not match, Check your Code"));
rpc_client->FinalizeWorker();
b_rpc_service->Stop();
LOG(INFO) << "end server Stop";
server_thread.join();
LOG(INFO) << "end server thread join";
}
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <stdlib.h>
#include <memory>
#include <string>
#include <thread> // NOLINT
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/service/heter_client.h"
#include "paddle/fluid/distributed/service/heter_server.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/device_context.h"
namespace framework = paddle::framework;
namespace platform = paddle::platform;
namespace distributed = paddle::distributed;
namespace memory = paddle::memory;
using MultiVarMsg = ::paddle::distributed::MultiVariableMessage;
using VarMsg = ::paddle::distributed::VariableMessage;
USE_OP(scale);
USE_OP(send_and_recv);
std::shared_ptr<distributed::HeterServer> b_rpc_service2;
framework::BlockDesc* AppendSendAndRecvBlock(framework::ProgramDesc* program) {
auto root_block = program->MutableBlock(0);
auto* block = program->AppendBlock(*root_block);
framework::OpDesc* op = block->AppendOp();
op->SetType("scale");
op->SetInput("X", {"x"});
op->SetOutput("Out", {"res"});
op->SetAttr("scale", 0.5f);
auto& out = *root_block->Var("res");
out.SetType(framework::proto::VarType::LOD_TENSOR);
out.SetShape({1, 10});
return block;
}
void CreateVarsOnScope(framework::Scope* scope) {
auto w_var = scope->Var("w");
w_var->GetMutable<framework::SelectedRows>();
auto out_var = scope->Var("out");
out_var->GetMutable<framework::LoDTensor>();
auto micro_var = scope->Var("microbatch_id");
micro_var->GetMutable<framework::LoDTensor>();
auto ids_var = scope->Var("ids");
ids_var->GetMutable<framework::LoDTensor>();
auto x_var = scope->Var("x");
x_var->GetMutable<framework::LoDTensor>();
auto res_var = scope->Var("res");
res_var->GetMutable<framework::LoDTensor>();
}
void InitTensorsOnClient(framework::Scope* scope, int64_t rows_numel,
const platform::DeviceContext& ctx) {
CreateVarsOnScope(scope);
const auto place = ctx.GetPlace();
// auto ids_var = scope->Var("ids")->GetMutable<framework::LoDTensor>();
// int64_t* ids_ptr =
// ids_var->mutable_data<int64_t>(framework::DDim({rows_numel, 1}),
// *place);
// for (int64_t i = 0; i < rows_numel; ++i) ids_ptr[i] = i * 2;
auto stream =
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
auto micro_id_var =
scope->Var("microbatch_id")->GetMutable<framework::LoDTensor>();
float* micro_id_ptr =
micro_id_var->mutable_data<float>(framework::DDim({1}), place);
std::vector<float> temp_vec{0};
float* temp_ptr = temp_vec.data();
memory::Copy(
BOOST_GET_CONST(platform::CUDAPlace, place),
reinterpret_cast<void*>(micro_id_ptr), platform::CPUPlace(),
reinterpret_cast<void*>(temp_ptr),
micro_id_var->numel() * framework::SizeOfType(micro_id_var->type()),
stream);
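  // On CUDAPlace the tensor's memory lives on the device, so the scalar
  // micro-batch id is staged in a host vector and copied in through
  // memory::Copy on the device context's stream rather than written directly.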
auto x_var = scope->Var("x")->GetMutable<framework::LoDTensor>();
float* x_ptr =
x_var->mutable_data<float>(framework::DDim({1, rows_numel}), place);
std::vector<float> x_vec;
for (int64_t i = 0; i < rows_numel; ++i) x_vec.push_back(1.0);
float* x_vec_ptr = x_vec.data();
memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, place),
reinterpret_cast<void*>(x_ptr), platform::CPUPlace(),
reinterpret_cast<void*>(x_vec_ptr),
x_var->numel() * framework::SizeOfType(x_var->type()), stream);
// auto res_var = scope->Var("res")->GetMutable<framework::LoDTensor>();
// float* res_ptr =
// res_var->mutable_data<float>(framework::DDim({1, rows_numel}), place);
// for (int64_t i = 0; i < rows_numel; ++i) res_ptr[i] = 1.0;
}
void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place,
int64_t rows_numel) {
CreateVarsOnScope(scope);
auto w = scope->Var("w")->GetMutable<framework::SelectedRows>();
auto w_value = w->mutable_value();
w_value->Resize({rows_numel, 10});
for (int64_t i = 0; i < rows_numel; ++i) w->AutoGrownIndex(i, true);
auto ptr = w_value->mutable_data<float>(*place);
for (int64_t i = 0; i < w_value->numel(); ++i) {
ptr[i] = static_cast<float>(i / 10);
}
}
void RunServer(std::shared_ptr<paddle::distributed::HeterServer> service) {
service->StartHeterService();
}
void StartSendAndRecvServer(std::string endpoint) {
framework::ProgramDesc program;
framework::Scope scope;
platform::CPUPlace place;
framework::Executor exe(place);
platform::CPUDeviceContext ctx(place);
LOG(INFO) << "before AppendSendAndRecvBlock";
auto block = AppendSendAndRecvBlock(&program);
std::string in_var_name("x");
std::vector<int> prefetch_block_ids{block->ID()};
LOG(INFO) << "before InitTensorsOnServer";
InitTensorsOnServer(&scope, &place, 10);
LOG(INFO) << "end InitTensorsOnServer";
std::shared_ptr<distributed::RequestSendAndRecvHandler> b_req_handler;
b_req_handler.reset(new distributed::RequestSendAndRecvHandler());
LOG(INFO) << "before SetDevCtx";
b_req_handler->SetDevCtx(&ctx);
LOG(INFO) << "before SetScope";
b_req_handler->SetScope(&scope);
LOG(INFO) << "before HeterServer::GetInstance";
b_rpc_service2 = distributed::HeterServer::GetInstance();
b_rpc_service2->SetEndPoint(endpoint);
LOG(INFO) << "before HeterServer::RegisterServiceHandler";
b_rpc_service2->RegisterServiceHandler(
in_var_name, [&](const MultiVarMsg* request, MultiVarMsg* response,
brpc::Controller* cntl) -> int {
return b_req_handler->Handle(request, response, cntl);
});
b_rpc_service2->SetRequestHandler(b_req_handler);
LOG(INFO) << "before HeterServer::RunServer";
std::thread server_thread(std::bind(RunServer, b_rpc_service2));
server_thread.join();
}
TEST(SENDANDRECV, GPU) {
setenv("http_proxy", "", 1);
setenv("https_proxy", "", 1);
std::string endpoint = "127.0.0.1:4445";
std::string previous_endpoint = "127.0.0.1:4445";
LOG(INFO) << "before StartSendAndRecvServer";
b_rpc_service2 = distributed::HeterServer::GetInstance();
std::thread server_thread(StartSendAndRecvServer, endpoint);
b_rpc_service2->WaitServerReady();
using MicroScope =
std::unordered_map<int, std::shared_ptr<std::vector<framework::Scope*>>>;
std::shared_ptr<MicroScope> micro_scopes(new MicroScope{});
std::shared_ptr<std::vector<framework::Scope*>> micro_scope(
new std::vector<framework::Scope*>{});
(*micro_scope).push_back(new framework::Scope());
(*micro_scopes)[0] = micro_scope;
b_rpc_service2->SetMicroBatchScopes(micro_scopes);
using TaskQueue =
std::unordered_map<int,
std::shared_ptr<::paddle::framework::BlockingQueue<
std::pair<std::string, int>>>>;
using SharedTaskQueue = std::shared_ptr<std::unordered_map<
int, std::shared_ptr<::paddle::framework::BlockingQueue<
std::pair<std::string, int>>>>>;
SharedTaskQueue task_queue_(new TaskQueue{});
(*task_queue_)[0] = std::make_shared<
::paddle::framework::BlockingQueue<std::pair<std::string, int>>>();
b_rpc_service2->SetTaskQueue(task_queue_);
LOG(INFO) << "before HeterClient::GetInstance";
distributed::HeterClient* rpc_client =
distributed::HeterClient::GetInstance({endpoint}, {previous_endpoint}, 0)
.get();
PADDLE_ENFORCE_NE(rpc_client, nullptr,
platform::errors::InvalidArgument(
"Client Start Fail, Check Your Code & Env"));
framework::Scope* scope = (*micro_scope)[0];
platform::CUDAPlace place;
platform::CUDADeviceContext ctx(place);
framework::Executor exe(place);
// create var on local scope
int64_t rows_numel = 10;
LOG(INFO) << "before InitTensorsOnClient";
InitTensorsOnClient(scope, rows_numel, ctx);
LOG(INFO) << "after InitTensorsOnClient2";
std::string in_var_name("x");
std::string micro_var_name("microbatch_id");
// std::string out_var_name("res");
std::vector<std::string> send_var{in_var_name, micro_var_name};
std::vector<std::string> recv_var{};
std::string mode_str("forward");
LOG(INFO) << "add block & op1";
framework::ProgramDesc program;
auto root_block = program.MutableBlock(0);
// op for forward
framework::OpDesc* op = root_block->AppendOp();
op->SetType("send_and_recv");
LOG(INFO) << "op1 set input";
op->SetInput("X", std::vector<std::string>({in_var_name}));
op->SetOutput("Out", {});
op->SetAttr("next_endpoints", std::vector<std::string>({endpoint}));
op->SetAttr("previous_endpoints",
std::vector<std::string>({previous_endpoint}));
op->SetAttr("trainer_id", 0);
op->SetAttr("mode", mode_str);
op->SetAttr("message_name", in_var_name);
op->SetAttr("send_var_name", send_var);
op->SetAttr("recv_var_name", recv_var);
op->SetAttr("op_device", std::string("gpu"));
std::string mode_str2("backward");
// op for backward
LOG(INFO) << "add op2";
framework::OpDesc* op2 = root_block->AppendOp();
op2->SetType("send_and_recv");
LOG(INFO) << "op2 set input";
op2->SetInput("X", std::vector<std::string>({in_var_name}));
op2->SetOutput("Out", {});
op2->SetAttr("next_endpoints", std::vector<std::string>({endpoint}));
op2->SetAttr("previous_endpoints",
std::vector<std::string>({previous_endpoint}));
op2->SetAttr("trainer_id", 0);
op2->SetAttr("mode", mode_str2);
op2->SetAttr("message_name", in_var_name);
op2->SetAttr("send_var_name", send_var);
op2->SetAttr("recv_var_name", recv_var);
op2->SetAttr("op_device", std::string("gpu"));
LOG(INFO) << "exe before prepare";
auto prepared = exe.Prepare(program, 0);
LOG(INFO) << "exe after prepare";
LOG(INFO) << "before RunPreparedContext";
exe.RunPreparedContext(prepared.get(), scope, false);
LOG(INFO) << "client wait for Pop";
auto task = (*task_queue_)[0]->Pop();
LOG(INFO) << "client get from task queue";
PADDLE_ENFORCE_EQ(
task.first, "x",
platform::errors::InvalidArgument(
"Recv message and Send message name not match, Check your Code"));
auto task2 = (*task_queue_)[0]->Pop();
PADDLE_ENFORCE_EQ(
task2.first, "x",
platform::errors::InvalidArgument(
"Recv message and Send message name not match, Check your Code"));
rpc_client->FinalizeWorker();
b_rpc_service2->Stop();
LOG(INFO) << "end server Stop";
server_thread.join();
LOG(INFO) << "end server thread join";
}
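For orientation, the forward send_and_recv op that the test assembles in C++ can also be expressed from the Python side. The sketch below is a hypothetical mirror of the attributes above (it reuses the test's endpoint and variable names) and is not part of this commit:

import paddle.fluid as fluid

prog = fluid.Program()
block = prog.global_block()
x = block.create_var(name="x", shape=[1, 10], dtype="float32")
block.append_op(
    type="send_and_recv",  # assumes the op is registered in this build
    inputs={"X": [x]},
    outputs={"Out": []},
    attrs={
        "mode": "forward",
        "message_name": "x",
        "send_var_name": ["x", "microbatch_id"],
        "recv_var_name": [],
        "next_endpoints": ["127.0.0.1:4445"],
        "previous_endpoints": ["127.0.0.1:4445"],
        "trainer_id": 0,
        "op_device": "gpu",
    })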
...@@ -165,9 +165,11 @@ void BindDistCommunicator(py::module* m) { ...@@ -165,9 +165,11 @@ void BindDistCommunicator(py::module* m) {
void BindHeterClient(py::module* m) { void BindHeterClient(py::module* m) {
py::class_<HeterClient, std::shared_ptr<HeterClient>>(*m, "HeterClient") py::class_<HeterClient, std::shared_ptr<HeterClient>>(*m, "HeterClient")
.def(py::init( .def(py::init([](const std::vector<std::string>& endpoints,
[](const std::vector<std::string>& endpoint, const int& trainer_id) { const std::vector<std::string>& previous_endpoints,
return HeterClient::GetInstance(endpoint, trainer_id); const int& trainer_id) {
return HeterClient::GetInstance(endpoints, previous_endpoints,
trainer_id);
})) }))
.def("stop", &HeterClient::Stop); .def("stop", &HeterClient::Stop);
} }
......
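The rebound constructor now takes the next-stage endpoints, the previous-stage endpoints, and the trainer id. A minimal usage sketch with illustrative addresses (assuming a build that compiles this binding in):

from paddle.fluid import core

# next-stage endpoints, previous-stage endpoints, trainer_id
client = core.HeterClient(["10.0.0.2:6170"], ["10.0.0.1:6170"], 0)
client.stop()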
...@@ -2008,7 +2008,8 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -2008,7 +2008,8 @@ All parameter, weight, gradient are variables in Paddle.
return self.GetWorkerScope(thread_id); return self.GetWorkerScope(thread_id);
}, },
py::return_value_policy::reference) py::return_value_policy::reference)
.def("finalize", &TrainerBase::Finalize); .def("finalize", &TrainerBase::Finalize)
.def("ResetDataset", &TrainerBase::ResetDataset);
m.def("_get_eager_deletion_vars", &framework::GetEagerDeletionCleanVars); m.def("_get_eager_deletion_vars", &framework::GetEagerDeletionCleanVars);
......
...@@ -68,11 +68,14 @@ server_num = fleet.server_num ...@@ -68,11 +68,14 @@ server_num = fleet.server_num
server_index = fleet.server_index server_index = fleet.server_index
server_endpoints = fleet.server_endpoints server_endpoints = fleet.server_endpoints
is_server = fleet.is_server is_server = fleet.is_server
is_heter_worker = fleet.is_heter_worker
util = UtilBase() util = UtilBase()
barrier_worker = fleet.barrier_worker barrier_worker = fleet.barrier_worker
init_worker = fleet.init_worker init_worker = fleet.init_worker
init_heter_worker = fleet.init_heter_worker
init_server = fleet.init_server init_server = fleet.init_server
run_server = fleet.run_server run_server = fleet.run_server
run_heter_worker = fleet.run_heter_worker
stop_worker = fleet.stop_worker stop_worker = fleet.stop_worker
distributed_optimizer = fleet.distributed_optimizer distributed_optimizer = fleet.distributed_optimizer
save_inference_model = fleet.save_inference_model save_inference_model = fleet.save_inference_model
......
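Taken together, the new exports support a role-dependent entry flow; a hedged sketch (network building and optimizer setup omitted):

import paddle.distributed.fleet as fleet

fleet.init()
# build net and call fleet.distributed_optimizer(...).minimize(...) here
if fleet.is_server():
    fleet.init_server()
    fleet.run_server()
elif fleet.is_heter_worker():
    dataset = None  # replace with a real dataset, see run_heter_worker below
    fleet.init_heter_worker()
    fleet.run_heter_worker(dataset)
else:
    fleet.init_worker()
    # trainer main loop here
    fleet.stop_worker()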
...@@ -563,8 +563,25 @@ class Fleet(object): ...@@ -563,8 +563,25 @@ class Fleet(object):
fleet.is_server() fleet.is_server()
""" """
return self._role_maker._is_server( return self._role_maker._is_server()
) or self._role_maker._is_heter_worker()
def is_heter_worker(self):
"""
Check whether the node is an instance of heter worker.
Returns:
bool: True if this is a node of heter worker,
False if not.
Examples:
.. code-block:: python
import paddle.distributed.fleet as fleet
fleet.init()
fleet.is_heter_worker()
"""
return self._role_maker._is_heter_worker()
def barrier_worker(self): def barrier_worker(self):
""" """
...@@ -600,6 +617,30 @@ class Fleet(object): ...@@ -600,6 +617,30 @@ class Fleet(object):
""" """
self._runtime_handle._init_worker() self._runtime_handle._init_worker()
@is_non_distributed_check
@inited_runtime_handler
def init_heter_worker(self):
"""
        init_heter_worker runs an executor to initialize the heter worker's startup program.
Returns:
None
Examples:
.. code-block:: python
import paddle.distributed.fleet as fleet
fleet.init()
# build net
# fleet.distributed_optimizer(...)
fleet.init_heter_worker()
"""
self._runtime_handle._init_heter_worker()
@is_non_distributed_check @is_non_distributed_check
@inited_runtime_handler @inited_runtime_handler
def init_server(self, *args, **kwargs): def init_server(self, *args, **kwargs):
...@@ -649,6 +690,31 @@ class Fleet(object): ...@@ -649,6 +690,31 @@ class Fleet(object):
""" """
self._runtime_handle.load_model(path, mode) self._runtime_handle.load_model(path, mode)
@is_non_distributed_check
@inited_runtime_handler
def run_heter_worker(self, dataset):
"""
        run_heter_worker runs the heter trainer's main program with an executor.
Returns:
None
Examples:
.. code-block:: python
import paddle.distributed.fleet as fleet
fleet.init()
# build net
# fleet.distributed_optimizer(...)
dataset = ""
if fleet.is_heter_worker():
fleet.run_heter_worker(dataset)
"""
self._runtime_handle._run_heter_worker(dataset)
@is_non_distributed_check @is_non_distributed_check
@inited_runtime_handler @inited_runtime_handler
def run_server(self): def run_server(self):
...@@ -1526,11 +1592,13 @@ class Fleet(object): ...@@ -1526,11 +1592,13 @@ class Fleet(object):
else: else:
apply_ir_passes(loss.block.program, startup_program, self) apply_ir_passes(loss.block.program, startup_program, self)
if not self._role_maker._is_heter_parameter_server_mode:
program = paddle.static.default_main_program() program = paddle.static.default_main_program()
opt_info = {} opt_info = {}
opt_info["mpi_size"] = self.worker_num() opt_info["mpi_size"] = self.worker_num()
opt_info["mpi_rank"] = self.worker_index() opt_info["mpi_rank"] = self.worker_index()
for k, v in self._user_defined_strategy.trainer_desc_configs.items(): for k, v in self._user_defined_strategy.trainer_desc_configs.items(
):
opt_info[k] = v opt_info[k] = v
program._fleet_opt = opt_info program._fleet_opt = opt_info
......
...@@ -371,11 +371,6 @@ class RoleMakerBase(object): ...@@ -371,11 +371,6 @@ class RoleMakerBase(object):
self._role = None self._role = None
self._current_id = -1 self._current_id = -1
# for heter parameter server mode
self._heter_trainer_endpoints = []
self._heter_trainer_device = "CPU"
self._is_heter_parameter_server_mode = False
def _is_worker(self): def _is_worker(self):
""" """
return is_worker() of current process return is_worker() of current process
...@@ -487,56 +482,55 @@ class RoleMakerBase(object): ...@@ -487,56 +482,55 @@ class RoleMakerBase(object):
""" """
print("warning: RoleMakerBase does not have barrier worker.") print("warning: RoleMakerBase does not have barrier worker.")
def _is_heter_worker(self): #def _is_heter_worker(self):
""" # """
Return is_heter_worker() of current process # Return is_heter_worker() of current process
""" # """
warnings.warn("RoleMakerBase does not have function: _is_heter_worker.") # raise NotImplementedError("Please implement this method in child class")
return False
#def _heter_worker_num(self):
def _heter_worker_num(self): # """
""" # Get current total heter-worker number.
Get current total heter-worker number. #
# Returns:
Returns: # int: heter_worker number
int: heter_worker number # """
""" # raise NotImplementedError("Please implement this method in child class")
warnings.warn(
"RoleMakerBase does not have function: _heter_worker_num.") #def _get_heter_worker_endpoints(self):
return 0 # """
# Returns:
def _get_heter_worker_endpoints(self): # string: all heter_trainers'endpoints
""" # """
Returns: # raise NotImplementedError("Please implement this method in child class")
string: all heter_trainers'endpoints
""" #def _get_heter_worker_endpoint(self):
assert self._heter_trainer_endpoints != [], "Heter Worker Endpoints Not initialized" # """
return self._heter_trainer_endpoints # Returns:
# int: corresponding heter_trainer's endpoint
def _get_heter_worker_endpoint(self): # """
""" # raise NotImplementedError("Please implement this method in child class")
Returns:
int: corresponding heter_trainer's endpoint
e.g: if we have 4 cpu-trainer(default), 2 gpu-trainer(heter)
then No.0 and No.2 cpu-trainer will work with No.0 gpu-trainer
and No.1 and No.3 cpu-trainer will work with No.1 gpu-trainer
"""
assert self._heter_trainer_endpoints != [], "Heter Worker Endpoints Not initialized"
return self._heter_trainer_endpoints[(self._current_id) %
self._heter_worker_num()]
class PaddleCloudRoleMaker(RoleMakerBase): class PaddleCloudRoleMaker(RoleMakerBase):
def __init__(self, is_collective=False, **kwargs): def __init__(self, is_collective=False, **kwargs):
super(PaddleCloudRoleMaker, self).__init__() super(PaddleCloudRoleMaker, self).__init__()
self._is_collective = is_collective self._is_collective = is_collective
self._non_distributed = False self._non_distributed = False
self._kwargs = kwargs self._kwargs = kwargs
self._role_is_generated = False self._role_is_generated = False
# for heterps
self._stage_id = 1
self._stage_num = 1
self._next_heter_trainer_endpoints = []
self._previous_heter_trainer_endpoints = []
self._heter_trainer_endpoints = []
self._heter_trainer_device = "CPU"
self._is_heter_parameter_server_mode = False
self._stage_trainers = []
self._server_endpoints = [] self._server_endpoints = []
self._worker_endpoints = [] self._worker_endpoints = []
...@@ -551,6 +545,38 @@ class PaddleCloudRoleMaker(RoleMakerBase): ...@@ -551,6 +545,38 @@ class PaddleCloudRoleMaker(RoleMakerBase):
def _all_reduce(self, input, mode="sum", comm_world="worker"): def _all_reduce(self, input, mode="sum", comm_world="worker"):
return self._gloo.all_reduce(input, mode, comm_world) return self._gloo.all_reduce(input, mode, comm_world)
def _heter_device_type(self):
"""
return the heter device type that current heter worker is using
"""
if not self._role_is_generated:
self._generate_role()
return self._heter_trainer_device
def _get_stage_id(self):
"""
return stage id of current heter worker
"""
if not self._role_is_generated:
self._generate_role()
return self._stage_id
def _get_stage_trainers(self):
"""
return trainer num of all stages
"""
if not self._role_is_generated:
self._generate_role()
return self._stage_trainers
def _get_num_stage(self):
"""
return stage num
"""
if not self._role_is_generated:
self._generate_role()
return self._stage_num
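Downstream passes (see the parameter-server optimizer changes further below) read the pipeline topology through these accessors; a hedged sketch of the values they yield once the role has been generated from the launch environment:

role_maker = PaddleCloudRoleMaker()
stage_id = role_maker._get_stage_id()              # e.g. 2 for the first heter stage
stage_num = role_maker._get_num_stage()            # e.g. 3 pipeline stages in total
stage_trainers = role_maker._get_stage_trainers()  # e.g. [2, 2, 1], trainers per stage
device = role_maker._heter_device_type()           # "cpu", "gpu" or "xpu"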
def _is_worker(self): def _is_worker(self):
""" """
whether current process is worker whether current process is worker
...@@ -655,6 +681,32 @@ class PaddleCloudRoleMaker(RoleMakerBase): ...@@ -655,6 +681,32 @@ class PaddleCloudRoleMaker(RoleMakerBase):
self._generate_role() self._generate_role()
return self._worker_endpoints return self._worker_endpoints
def _get_trainer_endpoint(self):
if not self._role_is_generated:
self._generate_role()
        assert self._role == Role.WORKER, "_get_trainer_endpoint should be invoked by trainer"
return self._cur_endpoint
def _get_heter_worker_endpoints(self):
"""
Returns:
            string: all heter_trainers' endpoints
"""
if not self._role_is_generated:
self._generate_role()
assert self._heter_trainer_endpoints != [], "Heter Worker Endpoints Not initialized"
return self._heter_trainer_endpoints
def _get_heter_worker_endpoint(self):
"""
Returns:
            string: the corresponding heter_trainer's endpoint
"""
if not self._role_is_generated:
self._generate_role()
assert self._role == Role.HETER_WORKER, "_get_heter_worker_endpoint should be invoked by heter worker"
return self._cur_endpoint
def _get_pserver_endpoints(self): def _get_pserver_endpoints(self):
""" """
get endpoint of all pservers get endpoint of all pservers
...@@ -663,6 +715,28 @@ class PaddleCloudRoleMaker(RoleMakerBase): ...@@ -663,6 +715,28 @@ class PaddleCloudRoleMaker(RoleMakerBase):
self._generate_role() self._generate_role()
return self._server_endpoints return self._server_endpoints
def _get_previous_trainers(self):
"""
        invoked by trainer or heter worker
"""
if not self._role_is_generated:
self._generate_role()
assert self._role in (
Role.WORKER, Role.HETER_WORKER
), "_get_previous_trainers should be invoked by trainer or heter worker"
return self._previous_heter_trainer_endpoints
def _get_next_trainers(self):
"""
        invoked by trainer or heter worker
"""
if not self._role_is_generated:
self._generate_role()
assert self._role in (
Role.WORKER, Role.HETER_WORKER
), "_get_next_trainers should be invoked by trainer or heter worker"
return self._next_heter_trainer_endpoints
def _is_non_distributed(self): def _is_non_distributed(self):
""" """
Return True if indispensable environment for fleetrun is not found Return True if indispensable environment for fleetrun is not found
...@@ -730,23 +804,67 @@ class PaddleCloudRoleMaker(RoleMakerBase): ...@@ -730,23 +804,67 @@ class PaddleCloudRoleMaker(RoleMakerBase):
"TRAINING_ROLE must be PSERVER or TRAINER or HETER_TRAINER, but get {}, please check your environment.". "TRAINING_ROLE must be PSERVER or TRAINER or HETER_TRAINER, but get {}, please check your environment.".
format(training_role)) format(training_role))
# For heter parameter server env setting # For Heter Parameter Server env setting
heter_trainer_eplist = os.getenv("PADDLE_HETER_TRAINER_IP_PORT_LIST", next_heter_trainer_eplist = os.getenv(
"") "PADDLE_NEXT_HETER_TRAINER_IP_PORT_LIST", "")
if heter_trainer_eplist != "": previous_heter_trainer_eplist = os.getenv(
"PADDLE_PREVIOUS_HETER_TRAINER_IP_PORT_LIST", "")
all_heter_trainer_eplist = os.getenv(
"PADDLE_ALL_HETER_TRAINER_IP_PORT_LIST", "")
if all_heter_trainer_eplist != "":
self._heter_trainer_endpoints = all_heter_trainer_eplist.split(",")
self._is_heter_parameter_server_mode = True
self._heter_trainers_num = len(self._heter_trainer_endpoints)
if previous_heter_trainer_eplist == "":
assert training_role in (
"TRAINER", "PSERVER"
), "training_role should be trainer or pserver"
else:
try: try:
heter_trainer_eplist = os.environ[ self._previous_heter_trainer_endpoints = previous_heter_trainer_eplist.split(
"PADDLE_HETER_TRAINER_IP_PORT_LIST"].split(",") ",")
except: except:
raise ValueError( raise ValueError(
"Can not Find PADDLE_HETER_TRAINER_IP_PORT_LIST in env or its format doesn't match the requirement: 'IP:PORT,IP:PORT' ." "Can not Find PADDLE_PREVIOUS_HETER_TRAINER_IP_PORT_LIST in env or its format doesn't match the requirement: 'IP:PORT,IP:PORT' ."
) )
self._is_heter_parameter_server_mode = True if next_heter_trainer_eplist == "":
heter_trainers_num = len(heter_trainer_eplist) assert training_role in (
"HETER_TRAINER", "PSERVER"
), "training_role should be heter trainer or pserver"
else:
try:
self._next_heter_trainer_endpoints = next_heter_trainer_eplist.split(
",")
except:
raise ValueError(
"Can not Find PADDLE_NEXT_HETER_TRAINER_IP_PORT_LIST in env or its format doesn't match the requirement: 'IP:PORT,IP:PORT' ."
)
#self._is_heter_parameter_server_mode = True
#heter_trainers_num = len(all_heter_trainer_eplist.split(","))
#self._heter_trainer_endpoints = all_heter_trainer_eplist.split(",")
else: else:
self._is_heter_parameter_server_mode = False self._is_heter_parameter_server_mode = False
heter_trainers_num = 0 self._heter_trainers_num = 0
#if previous_heter_trainer_eplist == "":
# self._is_heter_parameter_server_mode = False
# heter_trainers_num = 0
#else: ## for the last heter worker
# try:
# previous_heter_trainer_eplist = os.environ[
# "PADDLE_PREVIOUS_HETER_TRAINER_IP_PORT_LIST"].split(",")
# self._previous_heter_trainer_endpoints = previous_heter_trainer_eplist
# except:
# raise ValueError(
# "Can not Find PADDLE_PREVIOUS_HETER_TRAINER_IP_PORT_LIST in env or its format doesn't match the requirement: 'IP:PORT,IP:PORT' ."
# )
# self._is_heter_parameter_server_mode = True
# heter_trainers_num = len(all_heter_trainer_eplist.split(","))
# self._heter_trainer_endpoints = all_heter_trainer_eplist.split(",")
if training_role == "TRAINER": if training_role == "TRAINER":
role = Role.WORKER role = Role.WORKER
...@@ -756,22 +874,75 @@ class PaddleCloudRoleMaker(RoleMakerBase): ...@@ -756,22 +874,75 @@ class PaddleCloudRoleMaker(RoleMakerBase):
"Can not find PADDLE_TRAINER_ID, please check your environment." "Can not find PADDLE_TRAINER_ID, please check your environment."
) )
current_id = int(current_id) current_id = int(current_id)
if len(self._worker_endpoints) > 0: if self._is_heter_parameter_server_mode:
self._cur_endpoint = self._worker_endpoints[current_id] self._stage_id = os.getenv("STAGE_ID", None)
if self._stage_id == None:
raise ValueError(
"Can not find STAGE_ID, please check your environment.")
self._stage_id = int(self._stage_id)
self._stage_num = os.getenv("STAGE_NUM", None)
if self._stage_num == None:
raise ValueError(
"Can not find STAGE_NUM, please check your environment.")
self._stage_num = int(self._stage_num)
self._stage_trainers = os.getenv("PADDLE_STAGE_TRAINERS_NUM",
None)
if self._stage_trainers == None:
raise ValueError(
"Can not find PADDLE_STAGE_TRAINERS_NUM, please check your environment."
)
self._stage_trainers = eval(self._stage_trainers)
cur_port = os.getenv("PADDLE_PORT", None)
if cur_port == None:
raise ValueError(
"Can not find PADDLE_PORT, please check your environment.")
cur_ip = os.getenv("POD_IP", None)
if cur_ip == None:
raise ValueError(
"Can not find POD_IP, please check your environment.")
curr_endpoint = ":".join([cur_ip, cur_port])
self._cur_endpoint = curr_endpoint
elif training_role == "PSERVER": elif training_role == "PSERVER":
role = Role.SERVER role = Role.SERVER
port = os.getenv("PADDLE_PORT", None) cur_port = os.getenv("PADDLE_PORT", None)
if port == None: if cur_port == None:
raise ValueError( raise ValueError(
"Can not find PADDLE_PORT, please check your environment.") "Can not find PADDLE_PORT, please check your environment.")
ip = os.getenv("POD_IP", None) cur_ip = os.getenv("POD_IP", None)
if ip == None: if cur_ip == None:
raise ValueError( raise ValueError(
"Can not find POD_IP, please check your environment.") "Can not find POD_IP, please check your environment.")
self._cur_endpoint = ip + ":" + port curr_endpoint = ":".join([cur_ip, cur_port])
self._cur_endpoint = curr_endpoint
current_id = self._server_endpoints.index(self._cur_endpoint) current_id = self._server_endpoints.index(self._cur_endpoint)
elif training_role == "HETER_TRAINER": elif training_role == "HETER_TRAINER":
role = Role.HETER_WORKER role = Role.HETER_WORKER
self._stage_id = os.getenv("STAGE_ID", None)
if self._stage_id == None:
raise ValueError(
"Can not find STAGE_ID, please check your environment.")
self._stage_id = int(self._stage_id)
self._stage_num = os.getenv("STAGE_NUM", None)
if self._stage_num == None:
raise ValueError(
"Can not find STAGE_NUM, please check your environment.")
self._stage_num = int(self._stage_num)
self._stage_trainers = os.getenv("PADDLE_STAGE_TRAINERS_NUM", None)
if self._stage_trainers == None:
raise ValueError(
"Can not find PADDLE_STAGE_TRAINERS_NUM, please check your environment."
)
self._stage_trainers = eval(self._stage_trainers)
self._heter_trainer_device = os.getenv("HETER_DEVICE_TYPE", None)
if self._heter_trainer_device == None:
raise ValueError(
"Can not find HETER_DEVICE_TYPE, please check your environment."
)
assert self._heter_trainer_device in (
"cpu", "gpu", "xpu"
), "HETER_DEVICE_TYPE should be cpu,gpu or xpu"
cur_port = os.getenv("PADDLE_PORT", None) cur_port = os.getenv("PADDLE_PORT", None)
if cur_port == None: if cur_port == None:
raise ValueError( raise ValueError(
...@@ -781,15 +952,15 @@ class PaddleCloudRoleMaker(RoleMakerBase): ...@@ -781,15 +952,15 @@ class PaddleCloudRoleMaker(RoleMakerBase):
raise ValueError( raise ValueError(
"Can not find POD_IP, please check your environment.") "Can not find POD_IP, please check your environment.")
curr_endpoint = ":".join([cur_ip, cur_port]) curr_endpoint = ":".join([cur_ip, cur_port])
current_id = heter_trainer_eplist.index(curr_endpoint) self._cur_endpoint = curr_endpoint
current_id = all_heter_trainer_eplist.split(",").index(
curr_endpoint) + trainers_num
self._trainers_num = trainers_num self._trainers_num = trainers_num
self._role = role self._role = role
self._current_id = current_id self._current_id = current_id
self._nodes_num = len( self._nodes_num = len(
set([x.split(':')[0] for x in self._worker_endpoints])) set([x.split(':')[0] for x in self._worker_endpoints]))
self._heter_trainers_num = heter_trainers_num
self._heter_trainer_endpoints = heter_trainer_eplist
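Putting the parsing above together: a heter trainer in the middle stage of a three-stage job would see an environment like the following (addresses and counts are illustrative; fleetrun sets these automatically):

import os

os.environ.update({
    "TRAINING_ROLE": "HETER_TRAINER",
    "PADDLE_PSERVERS_IP_PORT_LIST": "10.0.0.9:6170",
    "PADDLE_TRAINER_ENDPOINTS": "10.0.0.1:6171,10.0.0.2:6171",
    "PADDLE_ALL_HETER_TRAINER_IP_PORT_LIST": "10.0.0.3:6172,10.0.0.4:6173",
    "PADDLE_PREVIOUS_HETER_TRAINER_IP_PORT_LIST": "10.0.0.1:6171,10.0.0.2:6171",
    "PADDLE_NEXT_HETER_TRAINER_IP_PORT_LIST": "10.0.0.4:6173",
    "PADDLE_STAGE_TRAINERS_NUM": "[2, 1, 1]",
    "STAGE_ID": "2",
    "STAGE_NUM": "3",
    "HETER_DEVICE_TYPE": "gpu",
    "PADDLE_PORT": "6172",
    "POD_IP": "10.0.0.3",
    "PADDLE_TRAINERS_NUM": "2",
})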
def _collective_env(self): def _collective_env(self):
self._current_id = int(os.getenv("PADDLE_TRAINER_ID", "0")) self._current_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
......
...@@ -200,6 +200,11 @@ see: http://www.paddlepaddle.org/documentation/docs/zh/1.6/user_guides/howto/tra ...@@ -200,6 +200,11 @@ see: http://www.paddlepaddle.org/documentation/docs/zh/1.6/user_guides/howto/tra
type=str, type=str,
default="", default="",
help="User defined heter workers ip:port") help="User defined heter workers ip:port")
ps_group.add_argument(
"--heter_devices",
type=str,
default="",
help="User defined heter devices")
ps_group.add_argument("--worker_num", type=int, help="number of workers") ps_group.add_argument("--worker_num", type=int, help="number of workers")
ps_group.add_argument("--server_num", type=int, help="number of servers") ps_group.add_argument("--server_num", type=int, help="number of servers")
...@@ -353,11 +358,11 @@ def launch_ps(args, distribute_mode): ...@@ -353,11 +358,11 @@ def launch_ps(args, distribute_mode):
if cloud_flag and distribute_mode == DistributeMode.PS: if cloud_flag and distribute_mode == DistributeMode.PS:
direct_start(args) direct_start(args)
return return
elif cloud_flag and distribute_mode == DistributeMode.PS_HETER: #elif cloud_flag and distribute_mode == DistributeMode.PS_HETER:
cloud_ps_heter_env_set(args) # cloud_ps_heter_env_set(args)
args.workers = os.getenv("PADDLE_TRAINER_ENDPOINTS") # args.workers = os.getenv("PADDLE_TRAINER_ENDPOINTS")
args.servers = os.getenv("PADDLE_PSERVERS_IP_PORT_LIST") # args.servers = os.getenv("PADDLE_PSERVERS_IP_PORT_LIST")
args.heter_workers = os.getenv("PADDLE_HETER_TRAINER_IP_PORT_LIST") # args.heter_workers = os.getenv("PADDLE_HETER_TRAINER_IP_PORT_LIST")
ps_launcher = ParameterServerLauncher(args, distribute_mode) ps_launcher = ParameterServerLauncher(args, distribute_mode)
ps_launcher.start_ps() ps_launcher.start_ps()
...@@ -390,11 +395,11 @@ def which_distributed_mode(args): ...@@ -390,11 +395,11 @@ def which_distributed_mode(args):
ps_args = [ ps_args = [
'--worker_num', '--server_num', '--heter_worker_num', '--servers', '--worker_num', '--server_num', '--heter_worker_num', '--servers',
'--workers', '--heter_workers', '--http_port' '--workers', '--heter_workers', '--heter_devices', '--http_port'
] ]
collective_args = ['--ips'] collective_args = ['--ips']
ps_heter_args = ["--heter_worker_num", "--heter_workers"] ps_heter_args = ["--heter_worker_num", "--heter_workers", "--heter_devices"]
has_ps_args = [ has_ps_args = [
ps_arg for ps_arg in ps_args if ps_arg in " ".join(sys.argv[1:-1]) ps_arg for ps_arg in ps_args if ps_arg in " ".join(sys.argv[1:-1])
......
...@@ -147,6 +147,7 @@ class Trainer(object): ...@@ -147,6 +147,7 @@ class Trainer(object):
self.accelerators = [] self.accelerators = []
self.endpoint = None self.endpoint = None
self.rank = None self.rank = None
self.stage = None
def __str__(self): def __str__(self):
return "accelerator:{} endpoint:{} rank:{}".format( return "accelerator:{} endpoint:{} rank:{}".format(
...@@ -921,10 +922,15 @@ class ParameterServerLauncher(object): ...@@ -921,10 +922,15 @@ class ParameterServerLauncher(object):
self.is_local = True self.is_local = True
self.current_node_ip = "" self.current_node_ip = ""
self.stage_trainer_num = []
self.stage_heter_map = {}
self.stage_list = []
self.stage_device_map = {}
self.stage_num = 0
self.get_role_endpoints(args) self.get_role_endpoints(args)
def get_role_endpoints(self, args): def get_role_endpoints(self, args):
# get server envs
if args.server_num: if args.server_num:
self.server_num = args.server_num self.server_num = args.server_num
if args.servers: if args.servers:
...@@ -981,35 +987,140 @@ class ParameterServerLauncher(object): ...@@ -981,35 +987,140 @@ class ParameterServerLauncher(object):
else: else:
self.worker_endpoints = args.workers self.worker_endpoints = args.workers
# get http_port
if args.http_port:
self.http_port = args.http_port
else:
http_port = get_ports(1, self.server_num + self.worker_num)
http_ip = self.server_endpoints.split(",")[0].split(":")[0]
self.http_port = http_ip + ":" + str(http_port[0])
# get heter worker envs # get heter worker envs
if self.distribute_mode == DistributeMode.PS_HETER: if self.distribute_mode == DistributeMode.PS_HETER:
assert args.heter_devices != "", "The setting of Parameter-Server heter mode must has heter_devices."
self.stage_device_map[1] = "cpu" # for cpu trainer
heter_devices_list = args.heter_devices.split(";")
for i in range(len(heter_devices_list)):
self.stage_device_map[i + 2] = heter_devices_list[i]
self.stage_heter_map[1] = self.worker_endpoints
if args.heter_worker_num: if args.heter_worker_num:
self.heter_worker_num = args.heter_worker_num self.stage_heter_trainer_num = args.heter_worker_num.split(",")
self.stage_heter_trainer_num = [
int(trainer_num)
for trainer_num in self.stage_heter_trainer_num
]
if args.heter_workers: if args.heter_workers:
assert len(args.heter_workers.split(";")) == len(
self.stage_heter_trainer_num
), "The stage_num and heter_workers doesn't match. Expect heter_workers endpoints stage num epual to heter_worker_num stage, but received heter_workers enpoint stage num: {} and heter_worker_num stage {}".format(
len(args.heter_workers.split(";")),
len(self.stage_heter_trainer_num))
heter_worker_endpoints_list = args.heter_workers.split(";")
self.heter_worker_endpoints = ""
for i in range(len(self.stage_heter_trainer_num)):
if self.heter_worker_endpoints != "":
self.heter_worker_endpoints += ","
heter_worker_endpoints = heter_worker_endpoints_list[
i].split(",")
assert len( assert len(
args.heter_workers.split(",") heter_worker_endpoints
) == self.heter_worker_num, "The heter_worker_num and heter_workers doesn't match. Expect heter_workers endpoints num epual to heter_worker_num, but received heter_workers enpoint num: {} and heter_worker_num {}".format( ) == self.stage_heter_trainer_num[
len(args.heter_workers.split(",")), i], "The heter trainer num in stage {} is not equal in args.heter_worker_num and args.heter_workers".format(
self.heter_worker_num) i)
self.heter_worker_endpoints = args.heter_workers
heter_worker_endpoints_ips = [
x.strip().split(":")[0]
for x in heter_worker_endpoints
]
heter_worker_endpoints_len = [
len(x.strip().split(":"))
for x in heter_worker_endpoints
]
if 1 in heter_worker_endpoints_len:
                        # if heter_worker_endpoint carries no port, assign default ports.
heter_worker_endpoints_port = get_ports(
len(heter_worker_endpoints_ips), self.worker_num
+ self.server_num + self.heter_worker_num)
new_heter_worker_endpoints = []
for j in range(len(heter_worker_endpoints_ips)):
new_heter_worker_endpoints.append(":".join((
heter_worker_endpoints_ips[j], str(
heter_worker_endpoints_port[j]))))
ip_port_list = ",".join(new_heter_worker_endpoints)
else: else:
ports = get_ports(self.heter_worker_num, ip_port_list = ",".join(heter_worker_endpoints)
self.server_num + self.worker_num)
self.heter_worker_endpoints = ",".join( self.stage_heter_map[i + 2] = ip_port_list
self.stage_list.extend([i + 2] *
len(ip_port_list.split(',')))
self.heter_worker_num += self.stage_heter_trainer_num[i]
self.heter_worker_endpoints += ip_port_list
else:
for i in range(len(self.stage_heter_trainer_num)):
heter_trainer_num = self.stage_heter_trainer_num[i]
ports = get_ports(heter_trainer_num,
self.server_num + self.worker_num +
self.heter_worker_num)
ip_port_list = ",".join(
["127.0.0.1:" + str(x) for x in ports]) ["127.0.0.1:" + str(x) for x in ports])
self.stage_heter_map[i + 2] = ip_port_list
self.stage_list.extend([i + 2] *
len(ip_port_list.split(',')))
self.heter_worker_num += heter_trainer_num
if self.heter_worker_endpoints != "":
self.heter_worker_endpoints += ","
self.heter_worker_endpoints += ip_port_list
else: else:
assert args.heter_workers != "", "The setting of Parameter-Server heter mode must has heter_worker_num or heter_workers." assert args.heter_workers != "", "The setting of Parameter-Server heter mode must has heter_worker_num or heter_workers."
self.heter_worker_endpoints = args.heter_workers self.stage_heter_trainer_num = []
self.heter_worker_num = len( heter_worker_endpoints_list = args.heter_workers.split(";")
self.heter_worker_endpoints.split(",")) self.heter_worker_endpoints = ""
for i in range(len(heter_worker_endpoints_list)):
if self.heter_worker_endpoints != "":
self.heter_worker_endpoints += ","
heter_worker_endpoints = heter_worker_endpoints_list[
i].split(",")
self.stage_heter_trainer_num.append(
len(heter_worker_endpoints))
heter_worker_endpoints_ips = [
x.strip().split(":")[0] for x in heter_worker_endpoints
]
heter_worker_endpoints_len = [
len(x.strip().split(":"))
for x in heter_worker_endpoints
]
if 1 in heter_worker_endpoints_len:
                    # if heter_worker_endpoint carries no port, assign default ports.
heter_worker_endpoints_port = get_ports(
len(heter_worker_endpoints_ips), self.worker_num +
self.server_num + self.heter_worker_num)
new_heter_worker_endpoints = []
for j in range(len(heter_worker_endpoints_ips)):
new_heter_worker_endpoints.append(":".join((
heter_worker_endpoints_ips[j], str(
heter_worker_endpoints_port[j]))))
ip_port_list = ",".join(new_heter_worker_endpoints)
else:
ip_port_list = ",".join(heter_worker_endpoints)
self.stage_heter_map[i + 2] = ip_port_list
self.stage_list.extend([i + 2] *
len(ip_port_list.split(',')))
self.heter_worker_num += self.stage_heter_trainer_num[-1]
if self.heter_worker_endpoints != "":
self.heter_worker_endpoints += ","
self.heter_worker_endpoints += ip_port_list
self.stage_trainer_num = [self.worker_num
] + self.stage_heter_trainer_num
self.stage_num = len(self.stage_trainer_num)
# get http_port
if args.http_port:
self.http_port = args.http_port
else:
http_port = get_ports(
1, self.server_num + self.worker_num + self.heter_worker_num)
http_ip = self.server_endpoints.split(",")[0].split(":")[0]
self.http_port = http_ip + ":" + str(http_port[0])
# check local or user define # check local or user define
self.server_endpoints_ips = [ self.server_endpoints_ips = [
...@@ -1024,8 +1135,14 @@ class ParameterServerLauncher(object): ...@@ -1024,8 +1135,14 @@ class ParameterServerLauncher(object):
self.worker_endpoints_port = [ self.worker_endpoints_port = [
x.strip().split(":")[1] for x in self.worker_endpoints.split(",") x.strip().split(":")[1] for x in self.worker_endpoints.split(",")
] ]
self.node_ips = list( self.node_ips = []
set(self.server_endpoints_ips + self.worker_endpoints_ips)) for ip in self.server_endpoints_ips:
if ip not in self.node_ips:
self.node_ips.append(ip)
for ip in self.worker_endpoints_ips:
if ip not in self.node_ips:
self.node_ips.append(ip)
if self.distribute_mode == DistributeMode.PS_HETER: if self.distribute_mode == DistributeMode.PS_HETER:
self.heter_worker_endpoints_ips = [ self.heter_worker_endpoints_ips = [
x.strip().split(":")[0] x.strip().split(":")[0]
...@@ -1035,8 +1152,9 @@ class ParameterServerLauncher(object): ...@@ -1035,8 +1152,9 @@ class ParameterServerLauncher(object):
x.strip().split(":")[1] x.strip().split(":")[1]
for x in self.heter_worker_endpoints.split(",") for x in self.heter_worker_endpoints.split(",")
] ]
self.node_ips = list( for ip in self.heter_worker_endpoints_ips:
set(self.node_ips + self.heter_worker_endpoints_ips)) if ip not in self.node_ips:
self.node_ips.append(ip)
if len(set(self.node_ips)) == 1: if len(set(self.node_ips)) == 1:
self.is_local = True self.is_local = True
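For concreteness, a hedged sketch of the stage bookkeeping get_role_endpoints builds for --workers "a:1,b:1" --heter_workers "c:2;d:3" --heter_devices "gpu;xpu" (stage 1 is always the cpu trainer stage):

stage_device_map = {1: "cpu", 2: "gpu", 3: "xpu"}
stage_heter_map = {1: "a:1,b:1", 2: "c:2", 3: "d:3"}
stage_list = [2, 3]            # stage of each heter worker, in endpoint order
stage_trainer_num = [2, 1, 1]  # trainers per stage, cpu stage first
stage_num = 3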
...@@ -1061,7 +1179,6 @@ class ParameterServerLauncher(object): ...@@ -1061,7 +1179,6 @@ class ParameterServerLauncher(object):
server_rank = 0 server_rank = 0
worker_rank = 0 worker_rank = 0
heter_worker_rank = 0 heter_worker_rank = 0
for node_rank, ip in enumerate(self.node_ips): for node_rank, ip in enumerate(self.node_ips):
pod = Pod() pod = Pod()
pod.rank = node_rank pod.rank = node_rank
...@@ -1080,6 +1197,7 @@ class ParameterServerLauncher(object): ...@@ -1080,6 +1197,7 @@ class ParameterServerLauncher(object):
worker.endpoint = "%s:%s" % (ip, worker.endpoint = "%s:%s" % (ip,
self.worker_endpoints_port[j]) self.worker_endpoints_port[j])
worker.rank = worker_rank worker.rank = worker_rank
worker.stage = 1
worker_rank += 1 worker_rank += 1
pod.workers.append(worker) pod.workers.append(worker)
for k in range(len(self.heter_worker_endpoints_ips)): for k in range(len(self.heter_worker_endpoints_ips)):
...@@ -1088,6 +1206,7 @@ class ParameterServerLauncher(object): ...@@ -1088,6 +1206,7 @@ class ParameterServerLauncher(object):
heter_worker.endpoint = "%s:%s" % ( heter_worker.endpoint = "%s:%s" % (
ip, self.heter_worker_endpoints_port[k]) ip, self.heter_worker_endpoints_port[k])
heter_worker.rank = heter_worker_rank heter_worker.rank = heter_worker_rank
heter_worker.stage = self.stage_list[k]
heter_worker_rank += 1 heter_worker_rank += 1
pod.heter_workers.append(heter_worker) pod.heter_workers.append(heter_worker)
...@@ -1153,16 +1272,32 @@ class ParameterServerLauncher(object): ...@@ -1153,16 +1272,32 @@ class ParameterServerLauncher(object):
current_env.pop("http_proxy", None) current_env.pop("http_proxy", None)
current_env.pop("https_proxy", None) current_env.pop("https_proxy", None)
for idx, cur_server in enumerate(pod.servers): for idx, cur_server in enumerate(pod.servers):
if self.distribute_mode == DistributeMode.PS_HETER:
proc_env = { proc_env = {
"PADDLE_PSERVERS_IP_PORT_LIST": self.server_endpoints, "PADDLE_PSERVERS_IP_PORT_LIST": self.server_endpoints,
"PADDLE_TRAINER_ENDPOINTS": self.worker_endpoints, "PADDLE_TRAINER_ENDPOINTS": self.worker_endpoints,
"PADDLE_HETER_TRAINER_IP_PORT_LIST": "PADDLE_ALL_HETER_TRAINER_IP_PORT_LIST":
self.heter_worker_endpoints, self.heter_worker_endpoints,
"PADDLE_PORT": cur_server.endpoint.split(":")[1], "PADDLE_PORT": cur_server.endpoint.split(":")[1],
"TRAINING_ROLE": "PSERVER", "TRAINING_ROLE": "PSERVER",
"PADDLE_TRAINERS_NUM": str(self.worker_num), "PADDLE_TRAINERS_NUM": str(self.worker_num),
"POD_IP": cur_server.endpoint.split(":")[0], "POD_IP": cur_server.endpoint.split(":")[0],
"PADDLE_WITH_GLOO": str(os.getenv("PADDLE_WITH_GLOO", "0")), "PADDLE_WITH_GLOO":
str(os.getenv("PADDLE_WITH_GLOO", "0")),
"PADDLE_GLOO_RENDEZVOUS": "3",
"PADDLE_GLOO_FS_PATH": self.gloo_rendezvous_dir,
"PADDLE_GLOO_HTTP_ENDPOINT": self.http_port
}
else:
proc_env = {
"PADDLE_PSERVERS_IP_PORT_LIST": self.server_endpoints,
"PADDLE_TRAINER_ENDPOINTS": self.worker_endpoints,
"PADDLE_PORT": cur_server.endpoint.split(":")[1],
"TRAINING_ROLE": "PSERVER",
"PADDLE_TRAINERS_NUM": str(self.worker_num),
"POD_IP": cur_server.endpoint.split(":")[0],
"PADDLE_WITH_GLOO":
str(os.getenv("PADDLE_WITH_GLOO", "0")),
"PADDLE_GLOO_RENDEZVOUS": "3", "PADDLE_GLOO_RENDEZVOUS": "3",
"PADDLE_GLOO_FS_PATH": self.gloo_rendezvous_dir, "PADDLE_GLOO_FS_PATH": self.gloo_rendezvous_dir,
"PADDLE_GLOO_HTTP_ENDPOINT": self.http_port "PADDLE_GLOO_HTTP_ENDPOINT": self.http_port
...@@ -1216,17 +1351,47 @@ class ParameterServerLauncher(object): ...@@ -1216,17 +1351,47 @@ class ParameterServerLauncher(object):
device_list = [str(x) for x in range(0, heter_device_num)] device_list = [str(x) for x in range(0, heter_device_num)]
for idx, cur_worker in enumerate(pod.workers): for idx, cur_worker in enumerate(pod.workers):
device_id = "0" if heter_device_num == 0 else str(device_list[ device_id = "0" if heter_device_num == 0 else str(device_list[(
idx % heter_device_num]) idx) % heter_device_num])
if self.distribute_mode == DistributeMode.PS_HETER:
proc_env = { proc_env = {
"PADDLE_PSERVERS_IP_PORT_LIST": self.server_endpoints, "PADDLE_PSERVERS_IP_PORT_LIST": self.server_endpoints,
"PADDLE_TRAINER_ENDPOINTS": self.worker_endpoints, "PADDLE_TRAINER_ENDPOINTS": self.worker_endpoints,
"PADDLE_TRAINERS_NUM": str(self.worker_num), "PADDLE_TRAINERS_NUM": str(self.worker_num),
"PADDLE_HETER_TRAINER_IP_PORT_LIST": "PADDLE_STAGE_TRAINERS_NUM": str(self.stage_trainer_num),
"STAGE_ID": "1",
"STAGE_NUM": str(self.stage_num),
"PADDLE_PREVIOUS_HETER_TRAINER_IP_PORT_LIST": "",
"PADDLE_NEXT_HETER_TRAINER_IP_PORT_LIST":
self.stage_heter_map[2],
"PADDLE_ALL_HETER_TRAINER_IP_PORT_LIST":
self.heter_worker_endpoints, self.heter_worker_endpoints,
"HETER_DEVICE_TYPE": self.stage_device_map[1],
"TRAINING_ROLE": "TRAINER", "TRAINING_ROLE": "TRAINER",
"POD_IP": cur_worker.endpoint.split(":")[0],
"PADDLE_PORT": cur_worker.endpoint.split(":")[1],
"PADDLE_TRAINER_ID": str(cur_worker.rank), "PADDLE_TRAINER_ID": str(cur_worker.rank),
"PADDLE_WITH_GLOO": str(os.getenv("PADDLE_WITH_GLOO", "0")), "PADDLE_WITH_GLOO":
str(os.getenv("PADDLE_WITH_GLOO", "0")),
"PADDLE_GLOO_RENDEZVOUS": "3",
"PADDLE_GLOO_FS_PATH": self.gloo_rendezvous_dir,
"FLAGS_selected_gpus": "0",
"FLAGS_selected_xpus": "0",
"CUDA_VISIBLE_DEVICES": device_id,
"XPU_VISIBLE_DEVICES": device_id,
"PADDLE_GLOO_HTTP_ENDPOINT": self.http_port
}
else:
proc_env = {
"PADDLE_PSERVERS_IP_PORT_LIST": self.server_endpoints,
"PADDLE_TRAINER_ENDPOINTS": self.worker_endpoints,
"PADDLE_TRAINERS_NUM": str(self.worker_num),
"TRAINING_ROLE": "TRAINER",
"POD_IP": cur_worker.endpoint.split(":")[0],
"PADDLE_PORT": cur_worker.endpoint.split(":")[1],
"PADDLE_TRAINER_ID": str(cur_worker.rank),
"PADDLE_WITH_GLOO":
str(os.getenv("PADDLE_WITH_GLOO", "0")),
"PADDLE_GLOO_RENDEZVOUS": "3", "PADDLE_GLOO_RENDEZVOUS": "3",
"PADDLE_GLOO_FS_PATH": self.gloo_rendezvous_dir, "PADDLE_GLOO_FS_PATH": self.gloo_rendezvous_dir,
"FLAGS_selected_gpus": "0", "FLAGS_selected_gpus": "0",
...@@ -1235,8 +1400,8 @@ class ParameterServerLauncher(object): ...@@ -1235,8 +1400,8 @@ class ParameterServerLauncher(object):
"XPU_VISIBLE_DEVICES": device_id, "XPU_VISIBLE_DEVICES": device_id,
"PADDLE_GLOO_HTTP_ENDPOINT": self.http_port "PADDLE_GLOO_HTTP_ENDPOINT": self.http_port
} }
current_env.update(proc_env)
current_env.update(proc_env)
cmd = [sys.executable, "-u", args.training_script cmd = [sys.executable, "-u", args.training_script
] + args.training_script_args ] + args.training_script_args
self.cmds["worker"].append(cmd) self.cmds["worker"].append(cmd)
...@@ -1282,19 +1447,28 @@ class ParameterServerLauncher(object): ...@@ -1282,19 +1447,28 @@ class ParameterServerLauncher(object):
elif fluid.core.is_compiled_with_xpu(): elif fluid.core.is_compiled_with_xpu():
heter_device_num = fluid.core.get_xpu_device_count() heter_device_num = fluid.core.get_xpu_device_count()
device_list = [str(x) for x in range(0, heter_device_num)] device_list = [str(x) for x in range(0, heter_device_num)]
if heter_device_num == 0:
return
for idx, cur_heter_worker in enumerate(pod.heter_workers): for idx, cur_heter_worker in enumerate(pod.heter_workers):
device_id = str(device_list[idx % heter_device_num]) device_id = "0" if heter_device_num == 0 else str(device_list[(
idx) % heter_device_num])
stage_id = cur_heter_worker.stage
proc_env = { proc_env = {
"PADDLE_PSERVERS_IP_PORT_LIST": self.server_endpoints, "PADDLE_PSERVERS_IP_PORT_LIST": self.server_endpoints,
"PADDLE_TRAINER_ENDPOINTS": self.worker_endpoints, "PADDLE_TRAINER_ENDPOINTS": self.worker_endpoints,
"PADDLE_HETER_TRAINER_IP_PORT_LIST": "PADDLE_NEXT_HETER_TRAINER_IP_PORT_LIST":
self.stage_heter_map[stage_id + 1]
if stage_id <= self.stage_num - 1 else "",
"PADDLE_PREVIOUS_HETER_TRAINER_IP_PORT_LIST":
self.stage_heter_map[stage_id - 1],
"PADDLE_ALL_HETER_TRAINER_IP_PORT_LIST":
self.heter_worker_endpoints, self.heter_worker_endpoints,
"HETER_DEVICE_TYPE": self.stage_device_map[stage_id],
"STAGE_ID": str(stage_id),
"STAGE_NUM": str(self.stage_num),
"PADDLE_PORT": cur_heter_worker.endpoint.split(":")[1], "PADDLE_PORT": cur_heter_worker.endpoint.split(":")[1],
"TRAINING_ROLE": "HETER_TRAINER", "TRAINING_ROLE": "HETER_TRAINER",
"PADDLE_TRAINERS_NUM": str(self.worker_num), "PADDLE_TRAINERS_NUM": str(self.worker_num),
"PADDLE_STAGE_TRAINERS_NUM": str(self.stage_trainer_num),
"POD_IP": cur_heter_worker.endpoint.split(":")[0], "POD_IP": cur_heter_worker.endpoint.split(":")[0],
"PADDLE_WITH_GLOO": str(os.getenv("PADDLE_WITH_GLOO", "0")), "PADDLE_WITH_GLOO": str(os.getenv("PADDLE_WITH_GLOO", "0")),
"PADDLE_GLOO_RENDEZVOUS": "3", "PADDLE_GLOO_RENDEZVOUS": "3",
......
...@@ -30,6 +30,16 @@ class ParameterServerOptimizer(MetaOptimizerBase): ...@@ -30,6 +30,16 @@ class ParameterServerOptimizer(MetaOptimizerBase):
# we do not allow meta optimizer to be inner optimizer currently # we do not allow meta optimizer to be inner optimizer currently
self.meta_optimizers_white_list = [] self.meta_optimizers_white_list = []
def _set_basic_info(self, loss, role_maker, user_defined_optimizer,
user_defined_strategy):
super(ParameterServerOptimizer, self)._set_basic_info(
loss, role_maker, user_defined_optimizer, user_defined_strategy)
#self.micro_batch_size = user_defined_strategy.pipeline_configs[
# 'micro_batch_size']
self.num_microbatches = user_defined_strategy.pipeline_configs[
'accumulate_steps']
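Since num_microbatches now comes from the user strategy, a heter pipeline job is expected to set accumulate_steps through pipeline_configs before minimize(); a minimal sketch:

import paddle
import paddle.distributed.fleet as fleet

fleet.init()
strategy = fleet.DistributedStrategy()
strategy.a_sync = True
strategy.pipeline_configs = {"accumulate_steps": 4}  # becomes num_microbatches
inner_opt = paddle.optimizer.SGD(learning_rate=0.01)
optimizer = fleet.distributed_optimizer(inner_opt, strategy)
# optimizer.minimize(loss)  # loss comes from the user's network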
def _is_graph_out(self): def _is_graph_out(self):
return False return False
...@@ -97,7 +107,7 @@ class ParameterServerOptimizer(MetaOptimizerBase): ...@@ -97,7 +107,7 @@ class ParameterServerOptimizer(MetaOptimizerBase):
if not use_ps_gpu: if not use_ps_gpu:
_main = worker.delete_optimizer_pass(_main, compiled_config) _main = worker.delete_optimizer_pass(_main, compiled_config)
_main = worker.append_send_ops_pass(_main, compiled_config) _main = worker.append_send_ops_pass(_main, compiled_config)
_startup = worker.delet_extra_optimizes_pass(_startup, _startup = worker.delete_extra_optimizes_pass(_startup,
compiled_config) compiled_config)
# for startup program # for startup program
...@@ -122,15 +132,14 @@ class ParameterServerOptimizer(MetaOptimizerBase): ...@@ -122,15 +132,14 @@ class ParameterServerOptimizer(MetaOptimizerBase):
from paddle.fluid.incubate.fleet.parameter_server.ir import heter_trainer_pass as heter_worker from paddle.fluid.incubate.fleet.parameter_server.ir import heter_trainer_pass as heter_worker
if self.role_maker._is_heter_worker(): if self.role_maker._is_heter_worker():
# for heter worker # for heter worker
stage_id = self.role_maker._get_stage_id()
device = self.role_maker._heter_device_type().lower()
_main = heter_worker.split_heter_worker_ops_pass( _main = heter_worker.split_heter_worker_ops_pass(
_main, compiled_config) _main, compiled_config, stage_id, device)
else: else:
# for default worker # for default worker
_main = heter_worker.split_trainer_ops_pass(_main, _main = heter_worker.split_trainer_ops_pass(_main,
compiled_config) compiled_config)
# for startup change
_startup = heter_worker.delete_startup_useless_ops_var_pass(
_startup, _main, compiled_config)
else: else:
_main = worker.append_send_ops_pass(_main, compiled_config) _main = worker.append_send_ops_pass(_main, compiled_config)
_startup = _startup _startup = _startup
...@@ -319,22 +328,53 @@ class ParameterServerOptimizer(MetaOptimizerBase): ...@@ -319,22 +328,53 @@ class ParameterServerOptimizer(MetaOptimizerBase):
if self.role_maker._is_worker() or self.role_maker._is_heter_worker(): if self.role_maker._is_worker() or self.role_maker._is_heter_worker():
main_program, startup_program = self._build_trainer_programs( main_program, startup_program = self._build_trainer_programs(
compiled_config) compiled_config)
if self.role_maker._is_heter_parameter_server_mode:
_origin_startup_program._heter_pipeline_opt = {
"startup_program": startup_program,
}
loss.block.program._heter_pipeline_opt = {
"trainer": "HeterPipelineTrainer",
"device_worker": "HeterSection",
"trainers": self.role_maker._get_stage_trainers(
), ## trainer num in each stage
"trainer_id": int(self.role_maker._role_id()),
"pipeline_stage": int(self.role_maker._get_stage_id()) - 1,
"num_pipeline_stages":
int(self.role_maker._get_num_stage()),
"section_program": main_program,
"num_microbatches": self.num_microbatches,
}
else:
loss.block.program = main_program
fluid.framework.switch_startup_program(startup_program)
elif self.role_maker._is_server(): elif self.role_maker._is_server():
main_program, startup_program = self._build_pserver_programs( main_program, startup_program = self._build_pserver_programs(
compiled_config) compiled_config)
loss.block.program = main_program loss.block.program = main_program
fluid.framework.switch_startup_program(startup_program) fluid.framework.switch_startup_program(startup_program)
return None, None return None, None
def _disable_strategy(self, dist_strategy): def _disable_strategy(self, dist_strategy):
#if self.role_maker._is_heter_parameter_server_mode:
# dist_strategy.pipeline = False
# dist_strategy.pipeline_configs = {
# "micro_batch_size": 1,
# "accumulate_steps": 1,
# }
dist_strategy.a_sync = False dist_strategy.a_sync = False
a_sync_configs = dist_strategy.a_sync_configs a_sync_configs = dist_strategy.a_sync_configs
a_sync_configs["k_steps"] = -1 a_sync_configs["k_steps"] = -1
dist_strategy.a_sync_configs = a_sync_configs dist_strategy.a_sync_configs = a_sync_configs
def _enable_strategy(self, dist_strategy, context): def _enable_strategy(self, dist_strategy, context):
#if self.role_maker._is_heter_parameter_server_mode:
# dist_strategy.pipeline = True
# dist_strategy.pipeline_configs = {
# "micro_batch_size": 1,
# "accumulate_steps": 1,
# }
a_sync_configs = dist_strategy.a_sync_configs a_sync_configs = dist_strategy.a_sync_configs
if a_sync_configs["k_steps"] >= 0: if a_sync_configs["k_steps"] >= 0:
return return
......
...@@ -528,6 +528,7 @@ class TheOnePSRuntime(RuntimeBase): ...@@ -528,6 +528,7 @@ class TheOnePSRuntime(RuntimeBase):
split_dense_table=self.role_maker._is_heter_parameter_server_mode) split_dense_table=self.role_maker._is_heter_parameter_server_mode)
send_ctx = self.compiled_strategy.get_the_one_send_context( send_ctx = self.compiled_strategy.get_the_one_send_context(
split_dense_table=self.role_maker._is_heter_parameter_server_mode, split_dense_table=self.role_maker._is_heter_parameter_server_mode,
use_origin_program=True,
ep_list=endpoints) ep_list=endpoints)
trainer_config = self.async_strategy.get_trainer_runtime_config() trainer_config = self.async_strategy.get_trainer_runtime_config()
...@@ -545,8 +546,8 @@ class TheOnePSRuntime(RuntimeBase): ...@@ -545,8 +546,8 @@ class TheOnePSRuntime(RuntimeBase):
kwargs['need_global_step'] = "0" kwargs['need_global_step'] = "0"
kwargs["trainer_id"] = self.role_maker._role_id() kwargs["trainer_id"] = self.role_maker._role_id()
kwargs["trainers"] = self.role_maker._worker_num() kwargs["trainers"] = self.role_maker._worker_num()
if self.role_maker._is_heter_worker(): #if self.role_maker._is_heter_worker():
kwargs["trainer_id"] += kwargs["trainers"] # kwargs["trainer_id"] += kwargs["trainers"]
for table in server.servers[0].tables: for table in server.servers[0].tables:
if table.table_class == "BarrierTable": if table.table_class == "BarrierTable":
...@@ -589,14 +590,18 @@ class TheOnePSRuntime(RuntimeBase): ...@@ -589,14 +590,18 @@ class TheOnePSRuntime(RuntimeBase):
if launch_barrier and launch_barrier_flag: if launch_barrier and launch_barrier_flag:
# for trainer wait server ready # for trainer wait server ready
wait_server_ready(self.role_maker._get_pserver_endpoints()) wait_server_ready(self.role_maker._get_pserver_endpoints())
if self.role_maker._is_heter_parameter_server_mode and self.role_maker._get_next_trainers(
# for ps-heter mode, wait heter worker ready ) != []:
if self.role_maker._is_heter_parameter_server_mode and self.role_maker._is_worker( wait_server_ready(self.role_maker._get_next_trainers())
): if self.role_maker._is_heter_parameter_server_mode:
wait_server_ready(self.role_maker._get_heter_worker_endpoints()) previous_trainers = []
if self.role_maker._get_previous_trainers() != []:
self._heter_client = HeterClient( previous_trainers = self.role_maker._get_previous_trainers()
self.role_maker._get_heter_worker_endpoints(), next_trainers = []
if self.role_maker._get_next_trainers() != []:
next_trainers = self.role_maker._get_next_trainers()
self._heter_client = HeterClient(next_trainers,
previous_trainers,
self.role_maker._role_id()) self.role_maker._role_id())
def _push_sparse_param(self, def _push_sparse_param(self,
...@@ -608,18 +613,16 @@ class TheOnePSRuntime(RuntimeBase): ...@@ -608,18 +613,16 @@ class TheOnePSRuntime(RuntimeBase):
def _get_executor(self): def _get_executor(self):
executor = fluid.Executor(fluid.CPUPlace()) executor = fluid.Executor(fluid.CPUPlace())
if self.role_maker._is_heter_parameter_server_mode: if self.role_maker._is_heter_parameter_server_mode:
heter_worker_device_guard = self.context[
"valid_strategy"].a_sync_configs[
"heter_worker_device_guard"].upper()
if heter_worker_device_guard not in ["GPU", "XPU", "CPU"]:
raise ValueError("Heter Worker Not Support Device {}".format(
heter_worker_device_guard))
if self.role_maker._is_heter_worker(): if self.role_maker._is_heter_worker():
if heter_worker_device_guard == "GPU": heter_device_type = self.role_maker._heter_device_type().upper()
if heter_device_type not in ["GPU", "XPU", "CPU"]:
raise ValueError("Heter Worker Not Support Device {}".
                            format(heter_device_type))
if heter_device_type == "GPU":
executor = Executor( executor = Executor(
fluid.CUDAPlace( fluid.CUDAPlace(
int(os.getenv("FLAGS_selected_gpus", "0")))) int(os.getenv("FLAGS_selected_gpus", "0"))))
elif heter_worker_device_guard == "XPU": elif heter_device_type == "XPU":
executor = Executor( executor = Executor(
fluid.XPUPlace( fluid.XPUPlace(
int(os.getenv("FLAGS_selected_xpus", "0")))) int(os.getenv("FLAGS_selected_xpus", "0"))))
...@@ -813,14 +816,12 @@ class TheOnePSRuntime(RuntimeBase): ...@@ -813,14 +816,12 @@ class TheOnePSRuntime(RuntimeBase):
return worker return worker
def _init_server(self, dirname=None, var_names=None, **kwargs): def _init_server(self, dirname=None, var_names=None, **kwargs):
if self.role_maker._is_heter_worker():
self._init_heter_worker()
return
role_id = self.compiled_strategy.get_role_id() role_id = self.compiled_strategy.get_role_id()
endpoints = self.compiled_strategy.get_ps_endpoints() endpoints = self.compiled_strategy.get_ps_endpoints()
is_sync = self.compiled_strategy.is_sync_mode() is_sync = self.compiled_strategy.is_sync_mode()
trainers = self.compiled_strategy.get_trainers() trainers = self.compiled_strategy.get_trainers()
if self.role_maker._is_heter_parameter_server_mode:
trainers += len(self.role_maker._get_heter_worker_endpoints())
server = self._get_fleet_proto(is_server=True, is_sync=is_sync) server = self._get_fleet_proto(is_server=True, is_sync=is_sync)
proto_txt = str(server) proto_txt = str(server)
...@@ -875,22 +876,35 @@ class TheOnePSRuntime(RuntimeBase): ...@@ -875,22 +876,35 @@ class TheOnePSRuntime(RuntimeBase):
self._server.load_sparse(dirname, "0", table_id) self._server.load_sparse(dirname, "0", table_id)
def _run_server(self): def _run_server(self):
if self.role_maker._is_heter_worker():
self._run_heter_worker()
return
ep = self.compiled_strategy.get_ps_endpoint() ep = self.compiled_strategy.get_ps_endpoint()
host, port = ep.split(":") host, port = ep.split(":")
self._server.run_server(host, int(port)) self._server.run_server(host, int(port))
def _init_heter_worker(self): def _init_heter_worker(self):
executor = self._get_executor() executor = self._get_executor()
executor.run(fluid.default_startup_program()) startup_program = fluid.default_startup_program()
#real_startup_program = startup_program._heter_pipeline_opt[
# "startup_program"]
executor.run(startup_program)
self._init_worker() self._init_worker()
def _run_heter_worker(self): def _run_heter_worker(self,
dataset=None,
scope=None,
thread=0,
debug=False,
fetch_list=None,
fetch_info=None,
print_period=100,
fetch_handler=None):
executor = self._get_executor() executor = self._get_executor()
executor.run(fluid.default_main_program()) executor.train_from_dataset(
program=fluid.default_main_program(),
dataset=dataset,
debug=debug,
fetch_list=fetch_list,
fetch_info=fetch_info,
print_period=print_period)
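_run_heter_worker now drives training through train_from_dataset, so callers must hand in a dataset. A hedged sketch of the call path via the public API (feed_vars stands in for the network's feed variables):

import paddle
import paddle.distributed.fleet as fleet

dataset = paddle.distributed.InMemoryDataset()
dataset.init(batch_size=32, use_var=feed_vars, pipe_command="cat")
dataset.set_filelist(["train_data/part-0"])
dataset.load_into_memory()
if fleet.is_heter_worker():
    fleet.run_heter_worker(dataset)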
    def _stop_worker(self):
        self._communicator.stop()
...
@@ -93,7 +93,7 @@ from .dygraph.varbase_patch_methods import monkey_patch_varbase
from . import generator
from .core import _cuda_synchronize
from .generator import Generator
from .trainer_desc import TrainerDesc, DistMultiTrainer, PipelineTrainer, HeterPipelineTrainer, MultiTrainer, HeterXpuTrainer
from .transpiler import HashName, RoundRobin
from .backward import append_backward
...
@@ -191,8 +191,9 @@ class LargeScaleKV(object):
class HeterClient(object):
    def __init__(self, endpoint, previous_endpoint, trainer_id):
        self.heter_client_ = core.HeterClient(endpoint, previous_endpoint,
                                              trainer_id)

    def stop(self):
        self.heter_client_.stop()
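A minimal sketch of the new three-argument constructor; the endpoint lists and trainer id below are illustrative assumptions, in practice they come from the role maker:

next_stage = ["127.0.0.1:8200"]      # endpoints of the downstream stage
previous_stage = ["127.0.0.1:8100"]  # endpoints of the upstream stage
client = HeterClient(next_stage, previous_stage, trainer_id=0)
client.stop()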
@@ -16,7 +16,8 @@
from __future__ import print_function

__all__ = [
    'DeviceWorker', 'Hogwild', 'DownpourSGD', 'Section', 'DownpourSGDOPT',
    'HeterSection'
]
@@ -444,6 +445,36 @@ class Section(DeviceWorker):
        cfg.place_id = place_id
class HeterSection(DeviceWorker):
    """HeterSectionWorker."""

    def __init__(self):
        """Init."""
        super(HeterSection, self).__init__()

    def _gen_worker_desc(self, trainer_desc):
        """
        Generate worker desc, which device worker is HeterSectionWorker.
        Args:
            trainer_desc(TrainerDesc): a TrainerDesc object
        """
        from google.protobuf import text_format
        from . import core
        trainer_desc.device_worker_name = "HeterSectionWorker"
        heter_pipeline_opt = self._program._heter_pipeline_opt
        heter_section_param = trainer_desc.heter_section_param
        heter_section_param.num_microbatches = heter_pipeline_opt[
            "num_microbatches"]
        heter_section_param.pipeline_stage = heter_pipeline_opt[
            "pipeline_stage"]
        heter_section_param.num_pipeline_stages = heter_pipeline_opt[
            "num_pipeline_stages"]
        cfg = heter_section_param.section_config
        program = heter_pipeline_opt["section_program"]
        cfg.program_desc.ParseFromString(program._get_desc()
                                         .serialize_to_string())
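`_gen_worker_desc` above is driven entirely by the `_heter_pipeline_opt` dict attached to the program; a sketch of its expected shape (the concrete values and the `section_prog` variable are illustrative assumptions, the keys match the lookups above):

program._heter_pipeline_opt = {
    "num_microbatches": 4,            # micro-batches per mini-batch
    "pipeline_stage": 0,              # index of this worker's stage
    "num_pipeline_stages": 2,         # e.g. CPU trainer + GPU heter worker
    "section_program": section_prog,  # Program slice run by this stage
}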
class DeviceWorkerFactory(object):
    def _create_device_worker(self, worker_type):
        classname = worker_type.capitalize()
...
@@ -690,6 +690,7 @@ class Executor(object):
        self.place = framework._get_paddle_place(place)
        self.program_caches = dict()
        self.ctx_caches = dict()
        self.trainer_caches = dict()
        self.scope_caches = dict()
        self.var_caches = dict()
        self.pruned_program_caches = dict()
@@ -713,6 +714,9 @@ class Executor(object):
    def _get_ctx_cache(self, program_cache_key):
        return self.ctx_caches.get(program_cache_key, None)

    def _get_trainer_cache(self, program_cache_key):
        return self.trainer_caches.get(program_cache_key, None)

    def _get_program_cache(self, program_cache_key):
        return self.program_caches.get(program_cache_key, None)
@@ -734,6 +738,9 @@ class Executor(object):
    def _add_ctx_cache(self, ctx_cache_key, ctx):
        self.ctx_caches[ctx_cache_key] = ctx

    def _add_trainer_cache(self, trainer_cache_key, ctx):
        self.trainer_caches[trainer_cache_key] = ctx

    def _add_scope_cache(self, scope_cache_key, scope):
        self.scope_caches[scope_cache_key] = scope
@@ -995,8 +1002,11 @@ class Executor(object):
            exe.close()
        """
        if not self._closed:
            self._closed = True
            for k, trainer_instance in self.trainer_caches.items():
                self._default_executor.release_trainer(trainer_instance)
                del trainer_instance
            self._default_executor.close()
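Since cached trainers now live for the executor's lifetime, a long-running heter pipeline job should close the executor explicitly; a sketch of the intended pattern:

exe = fluid.Executor(fluid.CPUPlace())
try:
    pass  # train_from_dataset calls that populate trainer_caches
finally:
    exe.close()  # releases every cached trainer before closing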
def _run_parallel(self, program, scope, feed, fetch_list, fetch_var_name, def _run_parallel(self, program, scope, feed, fetch_list, fetch_var_name,
return_numpy, return_merged): return_numpy, return_merged):
@@ -1286,6 +1296,12 @@ class Executor(object):
                program,
                fetch_list=fetch_list,
                use_program_cache=use_program_cache)
if isinstance(program, Program) and program._heter_pipeline_opt:
if "startup_program" in program._heter_pipeline_opt:
program = program._heter_pipeline_opt["startup_program"]
# TODO(zhangminxu): support heterps pipeline training using exe.run
        if isinstance(program, Program) and \
                len(program.global_block().ops) == 0:
            if use_default_main_program:
@@ -1588,6 +1604,9 @@ class Executor(object):
        if program._pipeline_opt:
            trainer = TrainerFactory()._create_trainer(
                program._pipeline_opt)
        elif program._heter_pipeline_opt:
            trainer = TrainerFactory()._create_trainer(
                program._heter_pipeline_opt)
        else:
            trainer = TrainerFactory()._create_trainer(program._fleet_opt)
        trainer._set_thread_barrier(program._is_distributed)
@@ -1598,6 +1617,9 @@ class Executor(object):
            if program._pipeline_opt:
                trainer = TrainerFactory()._create_trainer(
                    program.program._pipeline_opt)
            elif program._heter_pipeline_opt:
                trainer = TrainerFactory()._create_trainer(
                    program.program._heter_pipeline_opt)
            else:
                trainer = TrainerFactory()._create_trainer(
                    program.program._fleet_opt)
@@ -1681,7 +1703,6 @@ class Executor(object):
                'op_role',
                core.op_proto_and_checker_maker.OpRole.Optimize)
        fetch_list = None
        scope, trainer = self._prepare_trainer(
            program=program,
            dataset=dataset,
@@ -1696,14 +1717,28 @@ class Executor(object):
        trainer._gen_trainer_desc()

        if program._pipeline_opt is None:
            if program._heter_pipeline_opt is None:
                self._dump_debug_info(program=program, trainer=trainer)
        # in case of calling _set_use_ps_gpu explicitly
        if dataset.use_ps_gpu is False:
            dataset._set_use_ps_gpu(trainer.proto_desc.use_ps_gpu)
        dataset._dynamic_adjust_before_train(trainer.proto_desc.thread_num)

        if program._heter_pipeline_opt is None:
            trainer_instance = self._default_executor.init_for_dataset(
                program.desc, trainer._desc(), scope, dataset.dataset)
        else:
            # cache trainer instance for heterps pipeline training
            if fetch_list is None:
                fetch_list = []
            cache_key = _get_strong_program_cache_key(program, None, fetch_list)
            trainer_instance = self._get_trainer_cache(cache_key)
            if trainer_instance is None:
                trainer_instance = self._default_executor.init_for_dataset(
                    program.desc, trainer._desc(), scope, dataset.dataset)
                self._add_trainer_cache(cache_key, trainer_instance)
            else:
                trainer_instance.ResetDataset(dataset.dataset)
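The effect of the cache is that repeated `train_from_dataset` calls on the same heter pipeline program reuse one trainer and only re-bind the dataset; a sketch (dataset construction elided, pass count illustrative):

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
for pass_id in range(10):
    exe.train_from_dataset(
        program=fluid.default_main_program(),  # carries _heter_pipeline_opt
        dataset=dataset)  # later passes hit the cache and call ResetDataset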
        if fetch_handler is not None:
            scope0 = trainer_instance.get_worker_scope(0)
@@ -1711,10 +1746,11 @@ class Executor(object):
            fetch_monitor.start()
            self._default_executor.run_from_dataset(trainer_instance)
            fetch_monitor.stop()
            if program._heter_pipeline_opt is None:
                self._default_executor.release_trainer(trainer_instance)
        else:
            self._default_executor.run_from_dataset(trainer_instance)
            if program._heter_pipeline_opt is None:
                self._default_executor.release_trainer(trainer_instance)

        dataset._dynamic_adjust_after_train()
...
@@ -4477,6 +4477,9 @@ class Program(object):
        # assigned if this program has been parsed by a pipeline optimizer
        self._pipeline_opt = None

        # assigned if this program has been parsed by a heter pipeline parameter server optimizer
        self._heter_pipeline_opt = None

        # appending gradients times
        self._appending_grad_times = 0
...
@@ -20,6 +20,7 @@ import paddle.fluid.framework as framework
from paddle.fluid.transpiler.details.program_utils import delete_ops
from paddle.fluid.incubate.fleet.parameter_server.ir.trainer_pass import find_heter_ops
from paddle.fluid.incubate.fleet.parameter_server.ir.trainer_pass import union_forward_gradient_op
from paddle.fluid.incubate.fleet.parameter_server.ir.trainer_pass import create_heter_program
from paddle.fluid.incubate.fleet.parameter_server.ir.trainer_pass import create_trainer_program
from paddle.fluid.incubate.fleet.parameter_server.ir.trainer_pass import find_block_joints
@@ -27,7 +28,7 @@ from paddle.fluid.incubate.fleet.parameter_server.ir.trainer_pass import find_op
from paddle.fluid.incubate.fleet.parameter_server.ir.trainer_pass import get_vars_name_in_block
def split_heter_worker_ops_pass(program, config, stage_id, device):
    """
    split heter worker program from origin-program
    1. find heter op (located on different device)
@@ -43,19 +44,15 @@ def split_heter_worker_ops_pass(program, config):
        )
        return program

    program_block_ops = union_forward_gradient_op(program_block_ops)
    block_vars_detail = find_block_joints(program, program_block_ops, heter_ops)
    heter_program = framework.Program()
    create_heter_program(program, config, heter_program, program_block_ops,
                         heter_ops, block_vars_detail, device, stage_id)
    return heter_program


def split_trainer_ops_pass(program, config, default_device="cpu"):
    """
    split cpu-trainer program from origin-program
    1. find heter op (located on different device)
@@ -63,38 +60,13 @@ def split_trainer_ops_pass(program, config):
    3. create cpu-trainer program, add send&recv op
    """
    # Todo: support user define default_device (MrChengmo)
    default_device_ = default_device
    program, heter_ops, default_ops, program_block_ops = find_heter_ops(
        program, default_device_)
    program_block_ops = union_forward_gradient_op(program_block_ops)

    block_vars_detail = find_block_joints(program, program_block_ops, heter_ops)
    trainer_program = program.clone()
    create_trainer_program(trainer_program, program, config, program_block_ops,
                           block_vars_detail)
    return trainer_program
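Roughly, the two passes are driven per role; a sketch under the assumption that `origin_program`, `config`, and the role flag are provided by the caller (stage id and device are illustrative):

if is_cpu_trainer:  # stage 1 on CPU
    main_program = split_trainer_ops_pass(origin_program, config)
else:               # heter worker, e.g. stage 2 on GPU
    main_program = split_heter_worker_ops_pass(
        origin_program, config, stage_id=2, device="gpu")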
@@ -216,12 +216,36 @@ class CompileTimeStrategy(object):
        except Exception:
            return self.role_maker.get_heter_worker_endpoints()
def get_next_stage_trainers(self):
try:
return self.role_maker._get_next_trainers()
except Exception:
return self.role_maker.get_next_trainers()
    def get_heter_worker_endpoint(self):
        try:
            return self.role_maker._get_heter_worker_endpoint()
        except Exception:
            return self.role_maker.get_heter_worker_endpoint()
def get_trainer_endpoints(self):
try:
return self.role_maker._get_trainer_endpoints()
except Exception:
return self.role_maker.get_trainer_endpoints()
def get_trainer_endpoint(self):
try:
return self.role_maker._get_trainer_endpoint()
except Exception:
return self.role_maker.get_trainer_endpoint()
def get_previous_stage_trainers(self):
try:
return self.role_maker._get_previous_trainers()
except Exception:
return self.role_maker.get_previous_trainers()
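Taken together, these getters let a pipeline stage resolve its neighbours; an illustrative use, where `strategy` is assumed to be a CompileTimeStrategy built in heter pipeline mode:

prev_eps = strategy.get_previous_stage_trainers()  # upstream endpoints
next_eps = strategy.get_next_stage_trainers()      # downstream endpoints
my_ep = strategy.get_trainer_endpoint()            # this worker's endpoint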
    def get_origin_programs(self):
        return self.origin_main_program, self.origin_startup_program
...
@@ -105,6 +105,8 @@ def distributed_ops_pass(program, config, use_ps_gpu=False):
            if op.type in SPARSE_OP_TYPE_DICT.keys() \
                    and op.attr('remote_prefetch') is True:
                param_name = op.input(SPARSE_OP_TYPE_DICT[op.type])[0]
# trick for matchnet, need to modify
param_name += op.input("Ids")[0][0]
                ops = pull_sparse_ops.get(param_name, [])
                ops.append(op)
                pull_sparse_ops[param_name] = ops
@@ -208,7 +210,9 @@ def distributed_ops_pass(program, config, use_ps_gpu=False):
        for param, ops in pull_sparse_ops.items():
            all_ops = program.global_block().ops
op_device = ""
if config.is_heter_ps_mode:
op_device = ops[0].attr("op_device")
            inputs = [
                program.global_block().vars[op.input("Ids")[0]] for op in ops
            ]
@@ -258,6 +262,10 @@ def distributed_ops_pass(program, config, use_ps_gpu=False):
                                       outputs_idxs[out_id])

            if min(outputs_idxs) - max(inputs_idxs) >= 1:
if max(inputs_idxs) == -1:
distributed_idx = min(op_idxs)
else:
                    distributed_idx = max(inputs_idxs) + 1

                if use_ps_gpu:
@@ -283,7 +291,8 @@ def distributed_ops_pass(program, config, use_ps_gpu=False):
                        "is_distributed": is_distributed,
                        "padding_idx": padding_idx,
                        "table_id": table_id,
                        "lookup_table_version": op_type,
                        "op_device": op_device
                    })
            else:
                for i in range(len(inputs_idxs)):
@@ -299,7 +308,8 @@ def distributed_ops_pass(program, config, use_ps_gpu=False):
                        "is_distributed": is_distributed,
                        "padding_idx": padding_idx,
                        "table_id": table_id,
                        "lookup_table_version": op_type,
                        "op_device": op_device
                    })

    pull_sparse_ops = _get_pull_sparse_ops(program)
@@ -504,7 +514,7 @@ def ps_gpu_pass(program):
    return program


def delete_extra_optimizes_pass(program, config):
    optimize_vars = []
    optimize_op_role_vars = []
    optimize_need_delete_vars = []
@@ -516,7 +526,6 @@ def delet_extra_optimizes_pass(program, config):
    optimize_vars = list(set(optimize_vars))
    optimize_op_role_vars = list(set(optimize_op_role_vars))
    for var in optimize_vars:
        if var not in optimize_op_role_vars:
            optimize_need_delete_vars.append(var)
@@ -553,7 +562,7 @@ def find_heter_ops(program, default_device="cpu"):
        elif op_type in COMMUNICATE_OPS_TYPE and current_heter_device != default_device:
            # for distributed communicate ops: send & recv & barrier etc.
            # Todo: need update this method
            #op._set_attr('op_device', current_heter_device)
            return True
        elif op_device is None or op_device == default_device:
            op._set_attr('op_device', default_device)
@@ -574,6 +583,138 @@ def find_heter_ops(program, default_device="cpu"):
                heter_ops[op_device] = {}
            current_heter_block_ops.append(op)

    origin_program = program.clone()
block = program.global_block()
'''
re-place sum op to fix bug for union forward backward op
'''
var2idx = {}
op_list = list(block.ops)
op_size = len(op_list)
for i in range(op_size - 1, -1, -1):
op_list = list(block.ops)
op = op_list[i]
if "_grad" in op.type:
forward_op_type = op.type.split("_grad")[0]
if forward_op_type in SPARSE_OP_TYPE_DICT.keys() \
and op.attr('remote_prefetch') is True:
param_name = op.input(SPARSE_OP_TYPE_DICT[forward_op_type])[0]
if param_name in var2idx:
## insert sum op & remove sum op from var2idx and origin place
op_list = list(block.ops)
sum_op = op_list[var2idx[param_name]]
sum_op_inputs = {
sum_op.input_names[0]: [
block.vars[input]
for input in sum_op.input_arg_names
]
}
sum_op_outputs = {
sum_op.output_names[0]: [
block.vars[output]
for output in sum_op.output_arg_names
]
}
block._insert_op(
index=i + 1,
type=sum_op.type,
inputs=sum_op_inputs,
outputs=sum_op_outputs,
attrs=sum_op.all_attrs())
block._remove_op(var2idx[param_name] + 1)
var2idx.pop(param_name)
for var_ in var2idx:
var2idx[var_] += 1
elif forward_op_type == "elementwise_mul":
"""
get output varname of pre op
"""
output_vars_no_grad = []
for key in pre_op.output_names:
for varname in op.output(key):
if varname == "@EMPTY@":
continue
if "lod_tensor_blocking_queue" in varname:
continue
output_vars_no_grad.append(varname.split("@GRAD")[0])
for no_grad_var in output_vars_no_grad:
if no_grad_var in var2idx:
"""
insert sum op & remove sum op from var2idx and origin place
"""
op_list = list(block.ops)
sum_op = op_list[var2idx[no_grad_var]]
sum_op_inputs = {
sum_op.input_names[0]: [
block.vars[input]
for input in sum_op.input_arg_names
]
}
sum_op_outputs = {
sum_op.output_names[0]: [
block.vars[output]
for output in sum_op.output_arg_names
]
}
block._insert_op(
index=i + 1,
type=sum_op.type,
inputs=sum_op_inputs,
outputs=sum_op_outputs,
attrs=sum_op.all_attrs())
block._remove_op(var2idx[no_grad_var] + 1)
var2idx.pop(no_grad_var)
for var_ in var2idx:
var2idx[var_] += 1
else:
if op.type == "sum":
var = op.output("Out")[0]
if "@GRAD" in var:
origin_var = var.split("@GRAD")[0]
pre_op = op_list[i - 1]
if "_grad" in pre_op.type:
forward_op_type = pre_op.type.split("_grad")[0]
if forward_op_type in SPARSE_OP_TYPE_DICT.keys() \
and pre_op.attr('remote_prefetch') is True:
param_name = pre_op.input(SPARSE_OP_TYPE_DICT[
forward_op_type])[0]
if param_name == origin_var and op.attr(
"op_device") == pre_op.attr("op_device"):
continue
else:
var2idx[origin_var] = i
elif forward_op_type == "elementwise_mul":
output_vars = []
for key in pre_op.output_names:
for varname in pre_op.output(key):
if varname == "@EMPTY@":
continue
if "lod_tensor_blocking_queue" in varname:
continue
output_vars.append(varname)
input_vars = []
for key in op.input_names:
for varname in op.input(key):
if varname == "@EMPTY@":
continue
if "lod_tensor_blocking_queue" in varname:
continue
input_vars.append(varname)
is_match = False
for varname in output_vars:
if varname in input_vars:
is_match = True
break
if is_match:
continue
else:
var2idx[origin_var] = i
else:
var2idx[origin_var] = i
    origin_program = program.clone()
    block = program.global_block()
@@ -581,7 +722,6 @@ def find_heter_ops(program, default_device="cpu"):
    default_ops = {default_device: {}}
    heter_ops = {}
    block_index = 0

    current_heter_block_ops = []
    current_default_block_ops = []
@@ -652,11 +792,12 @@ def find_heter_ops(program, default_device="cpu"):
        print(
            "There are {} OPs in your main_program, and contains {} heter-OPs which is made up of {} heter-blocks.".
            format(len(block.ops), total_heter_ops, heter_blocks))

    return origin_program, heter_ops, default_ops, program_block_ops
def create_heter_program(program, config, heter_program, program_block_ops_list,
                         heter_ops, block_var_detail, current_device, stage_id):
    # This function mainly includes the following contents:
    # 1. For every heter block:
    #     a) copy heter device op from origin program
@@ -670,7 +811,7 @@ def create_heter_program(program, config, heter_program, heter_ops,
    #     d) copy send op from origin program for var@grad which located in current heter block
    #     e) re-check every op in current block if its device is not current heter device
    # 2. Create send op for step counter in last heter-block
    # 3. Create Listen&Serv OP and Send&Recv OP for distributed training
    # 4. update CompileTimeStrategy for heter_program

    optimizer_block = []
@@ -678,33 +819,84 @@ def create_heter_program(program, config, heter_program, heter_ops,
    send_grad_var_list = []
    pre_block_idx = heter_program.num_blocks - 1

    stage_id = int(stage_id)
    print("stage id", stage_id)
    heter_block_ops_forward = program_block_ops_list[stage_id - 1]["forward"]
    heter_block_ops_backward = program_block_ops_list[stage_id - 1]["backward"]

    heter_block = heter_program._create_block(pre_block_idx)
    optimizer_block.append(heter_block)
    for _, op in enumerate(heter_block_ops_forward):
        block_append_op(heter_program, program, heter_block, op)

    entrance_vars = block_var_detail[stage_id - 1]["forward"]["entrance"]
    add_vars_by_var_list(entrance_vars, program, heter_program, heter_block)
    exit_vars = block_var_detail[stage_id - 1]["forward"]["exit"]
    add_vars_by_var_list(exit_vars, program, heter_program, heter_block)

    first_op_index_fp = len(heter_block.ops)
if stage_id < len(program_block_ops_list):
heter_block_bp = heter_program._create_block(pre_block_idx)
optimizer_block.append(heter_block_bp)
for _, op in enumerate(heter_block_ops_backward):
block_append_op(heter_program, program, heter_block_bp, op)
bp_entrance_vars = block_var_detail[stage_id - 1]["backward"][
"entrance"]
add_vars_by_var_list(bp_entrance_vars, program, heter_program,
heter_block_bp)
bp_exit_vars = block_var_detail[stage_id - 1]["backward"]["exit"]
add_vars_by_var_list(bp_exit_vars, program, heter_program,
heter_block_bp)
backward_comm_info = get_communicate_var_info(
program, stage_id, bp_entrance_vars, type="backward")
grad_to_block_id.append(backward_comm_info["block_input_var_name"] + ":"
+ str(heter_block_bp.idx))
else:
for _, op in enumerate(heter_block_ops_backward):
block_append_op(heter_program, program, heter_block, op)
bp_entrance_vars = block_var_detail[stage_id - 1]["backward"][
"entrance"]
add_vars_by_var_list(bp_entrance_vars, program, heter_program,
heter_block)
bp_exit_vars = block_var_detail[stage_id - 1]["backward"]["exit"]
add_vars_by_var_list(bp_exit_vars, program, heter_program, heter_block)
heter_block_bp = heter_block
    forward_comm_info = get_communicate_var_info(
        program, stage_id, entrance_vars, type="forward")

    grad_to_block_id.append(forward_comm_info["block_input_var_name"] + ":" +
                            str(heter_block.idx))

    first_op_index_bp = len(heter_block_bp.ops)
if stage_id <= len(block_var_detail) - 1:
static_var = insert_communicate_op(program, config, heter_block,
stage_id, first_op_index_fp,
block_var_detail, current_device)
static_var_bp = insert_communicate_op(
program, config, heter_block_bp, stage_id, first_op_index_bp,
block_var_detail, current_device, False)
    # add send op
    send_grad_var_list = add_heter_send_op(
        program, heter_program, heter_block_bp, block_var_detail[stage_id - 1])

    # ---------------
    # add step counter
    send_input_vars = []
    dummy_output = []
    pserver_endpoints = config.get_ps_endpoints()

    # optimizer_block[-1].append_op(
    #     type="send",
    #     inputs={"X": send_input_vars},
@@ -718,14 +910,18 @@ def create_heter_program(program, config, heter_program, heter_ops,
    # add info in listen&serv
    attrs = {
#"mode": "sync",
#"trainers": config.get_trainers(),
#"trainer_id": config.get_role_id() + config.get_trainers(),
"message_to_block_id": grad_to_block_id, "message_to_block_id": grad_to_block_id,
"optimize_blocks": optimizer_block, "optimize_blocks": optimizer_block,
# runtime attribute # runtime attribute
"endpoint": config.get_heter_worker_endpoint(), "endpoint": config.get_heter_worker_endpoint(),
"fanin": config.get_trainers(), "fanin": len(config.get_previous_stage_trainers()),
"pserver_id": config.get_role_id(), "pserver_id": config.get_role_id(),
"distributed_mode": config.get_distributed_mode(), "distributed_mode": config.get_distributed_mode(),
"rpc_exec_thread_num": int(os.getenv("CPU_NUM", 32)) "rpc_exec_thread_num": int(os.getenv("CPU_NUM", 32)),
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
} }
# append the listen_and_serv op # append the listen_and_serv op
heter_program.global_block().append_op( heter_program.global_block().append_op(
@@ -747,7 +943,8 @@ def check_heter_compile_time_strategy(program, config, send_grad_var_list):
            config.remove_var_pair_by_grad(useless_grad_var)


def create_trainer_program(program, origin_program, config,
                           program_block_ops_list, block_var_detail):
    # This function mainly includes the following contents:
    # 1. For every heter block in origin program
    #     a) delete heter op and related variables
@@ -759,17 +956,127 @@ def create_trainer_program(program, config, heter_ops, block_var_detail):
    #     d) remove send op which related var@grad is not in trainer program
    # 2. check every op's device

    static_var = []
    for heter_block_index in range(1, len(program_block_ops_list)):
        ops_list = program_block_ops_list[heter_block_index][
            "forward"] + program_block_ops_list[heter_block_index]["backward"]
        static_var += replace_ops_by_communicate_op(
            program, config, heter_block_index, ops_list, block_var_detail)
        remove_trainer_send_op(program, config, heter_block_index,
                               block_var_detail)
optimizer_block = []
grad_to_block_id = []
bp_ops_list = program_block_ops_list[0]["backward"]
delete_same_ops(program.global_block(), bp_ops_list)
delete_trainer_useless_var(config, program, static_var)
backward_block = create_backward_block(program, origin_program, config,
bp_ops_list, block_var_detail)
bp_entrance_vars = block_var_detail[0]["backward"]["entrance"]
backward_comm_info = get_communicate_var_info(
origin_program, 1, bp_entrance_vars, type="backward")
grad_to_block_id.append(backward_comm_info["block_input_var_name"] + ":" +
str(backward_block.idx))
optimizer_block.append(backward_block)
attrs = {
#"mode": "sync",
#"trainers": config.get_trainers(),
#"trainer_id": config.get_role_id(),
"message_to_block_id": grad_to_block_id,
"optimize_blocks": optimizer_block,
# runtime attribute
"endpoint": config.get_trainer_endpoint(), ## get trainer endpoint
"fanin": 0, ## get heter worker
"pserver_id": config.get_role_id(),
"distributed_mode": config.get_distributed_mode(),
"rpc_exec_thread_num": int(os.getenv("CPU_NUM", 32)),
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
}
# append the listen_and_serv op
program.global_block()._insert_op(
index=0,
type="heter_listen_and_serv",
inputs={'X': []},
outputs={},
attrs=attrs)
## TODO add check for bp block
check_op_device(program.global_block(), DEFAULT_DEVICE) check_op_device(program.global_block(), DEFAULT_DEVICE)
def insert_communicate_op(origin_program,
config,
heter_block,
stage_id,
first_op_index,
block_var_detail,
device,
is_forward=True):
if is_forward:
next_heter_worker_endpoints = config.get_next_stage_trainers()
previous_heter_worker_endpoints = config.get_previous_stage_trainers()
entrance_var = block_var_detail[stage_id]["forward"]["entrance"]
        comm_info = get_communicate_var_info(origin_program, stage_id + 1,
entrance_var)
else:
next_heter_worker_endpoints = config.get_next_stage_trainers()
#if next_heter_worker_endpoints == "":
# next_heter_worker_endpoints = []
previous_heter_worker_endpoints = config.get_previous_stage_trainers()
entrance_var = block_var_detail[stage_id - 1]["backward"]["exit"]
        comm_info = get_communicate_var_info(origin_program, stage_id - 1,
entrance_var, "backward")
heter_block._insert_op(
index=first_op_index,
type="send_and_recv",
inputs={"X": heter_block.vars[entrance_var[0]]},
outputs={"Out": []},
attrs={
"mode": "forward" if is_forward else "backward",
"send_var_name": entrance_var + ["microbatch_id"],
"recv_var_name": [],
"message_name": comm_info["block_input_var_name"],
"next_endpoints": next_heter_worker_endpoints,
"previous_endpoints": previous_heter_worker_endpoints,
"trainer_id": config.get_role_id(),
"op_device": device,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
})
return entrance_var
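For reference, the attrs carried by one forward hop inserted above might look like the following; the variable names, endpoints, and trainer id are purely illustrative:

# attrs = {
#     "mode": "forward",
#     "send_var_name": ["fc_0.tmp_1", "microbatch_id"],
#     "recv_var_name": [],
#     "message_name": "forward_joint_1_2@Heter",
#     "next_endpoints": ["10.0.0.2:8200"],
#     "previous_endpoints": ["10.0.0.1:8100"],
#     "trainer_id": 0,
# }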
def create_backward_block(program, origin_program, config, bp_ops_list,
block_var_detail):
pre_block_idx = program.num_blocks - 1
heter_block = program._create_block(pre_block_idx)
for _, op in enumerate(bp_ops_list):
if op.type == "send":
send_varnames = op.attr('send_varnames')
is_skip = False
for varname in send_varnames:
if varname not in program.global_block(
).vars and varname not in heter_block.vars:
is_skip = True
break
        if is_skip:
continue
block_append_op(program, origin_program, heter_block, op)
entrance_vars = block_var_detail[0]["backward"]["entrance"]
add_vars_by_var_list(entrance_vars, origin_program, program, heter_block)
exit_vars = block_var_detail[0]["backward"]["exit"]
add_vars_by_var_list(exit_vars, origin_program, program, heter_block)
return heter_block
def replace_ops_by_communicate_op(program, config, heter_block_index, ops_list,
                                  block_var_detail):
    all_op = program.global_block().ops
@@ -782,37 +1089,44 @@ def replace_ops_by_communicate_op(program, config, heter_block_index, ops_list,
    assert first_op_idx != -1
    delete_same_ops(program.global_block(), ops_list)

    mode = config.get_distributed_mode()
    next_heter_worker_endpoints = config.get_next_stage_trainers()

    entrance_var = block_var_detail[heter_block_index]["forward"]["entrance"]

    comm_info = get_communicate_var_info(program, heter_block_index + 1,
                                         entrance_var)
    program.global_block()._insert_op(
        index=first_op_idx,
        type="send_and_recv",
        inputs={"X": program.global_block().vars[entrance_var[0]]},
        outputs={"Out": []},
        attrs={
            "mode": "forward",
            "send_var_name": entrance_var + ["microbatch_id"],
            "recv_var_name": [],
            "message_name": comm_info["block_input_var_name"],
            "next_endpoints": next_heter_worker_endpoints,
            "previous_endpoints": [],
            "trainer_id": config.get_role_id(),
            RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
        })

    return entrance_var
def remove_trainer_send_op(program, config, heter_block_index,
                           block_var_detail):
    # if the trainer does FF->BP->SEND, it has the pair: var, var@GRAD
    # if the trainer only does SEND, it has one var: var@GRAD
    # delete the send op if the trainer doesn't have the pair var (var <-> var@GRAD)
    persistables = block_var_detail[heter_block_index]["forward"]["persistables"] + \
                   block_var_detail[heter_block_index]["backward"]["persistables"]
    need_remove_send_op = []
    need_remove_grad_var = []
    for op in find_send_op(program):
@@ -848,7 +1162,7 @@ def add_heter_send_op(program, heter_program, block, block_var_detail):
    send_grad_var_list = []
    send_op_dict = _get_send_op_dict()
    table_dict = {}
    for persistable_var in block_var_detail["backward"]["persistables"]:
        # check var_name == var@GRAD
        if "@GRAD" not in persistable_var:
            continue
@@ -897,18 +1211,21 @@ def find_send_op(program):
    return send_op_list


def get_communicate_var_info(program,
                             block_index,
                             entrance_var_list,
                             type="forward"):
    input_var_reshape_dim = []
    input_var_reshape_name = []

    if type == "forward":
        block_input_var_name = "forward_joint_{}_{}@Heter".format(
            block_index - 1, block_index)
    else:
        block_input_var_name = "backward_joint_{}_{}@Heter".format(
            block_index + 1, block_index)

    entrance_var_list.sort()
    # input
    # Heter_SERVER_BLOCK_index@JOINT_VAR -> slice -> var@Heter_SERVER_BLOCK@INPUT_RESHAPE_VAR -> reshape -> var
    for name in entrance_var_list:
@@ -924,30 +1241,95 @@ def get_communicate_var_info(program,
    # output
    # var -> reshape -> var@Heter_SERVER_BLOCK@INPUT_RESHAPE_VAR -> concat -> Heter_SERVER_BLOCK_index@JOINT_VAR
    #for var_name in exit_var_list:
    #    var = program.global_block().vars[var_name]
    #    shape = var.shape
    #    # if len(shape) < 2 or shape[0] != -1:
    #    #     raise ValueError(
    #    #         "Variable {} not support heter training. its shape is {}".
    #    #         format(var_name, shape))
    #    send_reshape_dim = -1 * reduce(lambda x, y: x * y, shape)
    #    output_var_reshape_dim.append(send_reshape_dim)
    #    output_var_reshape_name.append("{}.output_reshape@Heter".format(
    #        var_name))

    info = {
        "input_var_reshape_dim": input_var_reshape_dim,
        "input_var_reshape_name": input_var_reshape_name,
        "block_input_var_name": block_input_var_name,
        # "output_var_reshape_dim": output_var_reshape_dim,
        # "output_var_reshape_name": output_var_reshape_name,
        # "block_output_var_name": block_output_var_name
    }

    return info
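A quick, self-contained check of the joint-variable naming scheme used above (illustration only, duplicating the two format strings):

def _joint_name(block_index, type="forward"):
    if type == "forward":
        return "forward_joint_{}_{}@Heter".format(block_index - 1, block_index)
    return "backward_joint_{}_{}@Heter".format(block_index + 1, block_index)

assert _joint_name(2) == "forward_joint_1_2@Heter"
assert _joint_name(2, "backward") == "backward_joint_3_2@Heter"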
def union_forward_gradient_op(program_block_ops_list):
    """
    before analyzing the input & output of each block in program_block_list,
    union the forward op and its corresponding gradient op to eliminate the
    unnecessary variable transmission
    """
    """
    fix for 2emb model, re-place sum op
    """
block_length = len(program_block_ops_list)
'''
## get the final part
final_part_idx = -1
for i in range(block_length):
op_list = program_block_ops_list[i]
for op in op_list:
if "_grad" in op.type:
final_part_idx = i
break
if final_part_idx != -1:
break
## eliminate wrong partition because of sum op
## lookup_table_v2_grad
    ## every lookup_table_v2_grad op block should follow a sum op
var2idx = {}
for i in range(final_part_idx, block_length):
op_list = program_block_ops_list[i]
for j in range(len(op_list) - 1, -1, -1):
op = op_list[j]
#if op.type == "lookup_table_v2_grad":
# if j < len(op_list) - 1):
# else:
# ## get var and record place
if _grad in op.type:
forward_op_type = op.type.split("_grad")[0]
if forward_op_type in SPARSE_OP_TYPE_DICT.keys() \
and op.attr('remote_prefetch') is True:
param_name = op.input(SPARSE_OP_TYPE_DICT[forward_op_type])[0]
var2idx[] = [i,j] ##
'''
union_program_block_ops_list = []
assert block_length % 2 != 0, "the length of program_block_ops_list should be odd"
for i in range(0, block_length // 2):
block_op_list = {"forward": program_block_ops_list[i]}
block_op_list.update({
"backward": program_block_ops_list[block_length - 1 - i]
})
union_program_block_ops_list.append(block_op_list)
block_op_list = {"forward": [], "backward": []}
for op in program_block_ops_list[block_length // 2]:
if not "_grad" in op.type and not (op.type == "sum"):
block_op_list["forward"].append(op)
else:
block_op_list["backward"].append(op)
union_program_block_ops_list.append(block_op_list)
return union_program_block_ops_list
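A toy illustration of the pairing, using stand-in ops with only a `type` field (the op names are arbitrary): with five blocks, block i is united with block n-1-i, and the odd middle block is split by the `_grad`/`sum` test:

class _Op(object):
    def __init__(self, type):
        self.type = type

blocks = [
    [_Op("mul")],                           # stage-1 forward
    [_Op("relu")],                          # stage-2 forward
    [_Op("softmax"), _Op("softmax_grad")],  # top stage: fwd + bwd together
    [_Op("relu_grad")],                     # stage-2 backward
    [_Op("mul_grad"), _Op("sum")],          # stage-1 backward
]
union = union_forward_gradient_op(blocks)
assert union[0] == {"forward": blocks[0], "backward": blocks[4]}
assert union[1] == {"forward": blocks[1], "backward": blocks[3]}
assert [op.type for op in union[2]["forward"]] == ["softmax"]
assert [op.type for op in union[2]["backward"]] == ["softmax_grad"]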
def find_block_joints(program, program_block_ops_list, heter_ops):
    block_var_detail = find_entrance_exit_private(program,
                                                  program_block_ops_list)
@@ -955,6 +1337,7 @@ def find_block_joints(program, program_block_ops_list, heter_ops):
                                           block_var_detail, heter_ops)
    block_var_detail = delete_block_useless_exit(
        program, program_block_ops_list, block_var_detail)

    return block_var_detail
@@ -962,8 +1345,9 @@ def find_entrance_exit_private(program, program_block_ops_list):
    block_var_detail = []
    persistables = []
    for index, block_op_list in enumerate(program_block_ops_list):
        ## forward
        block_input, block_output = find_ops_list_input_output(
            program, block_op_list["forward"])
        persistables = screen_persistables(
            program, block_input) + screen_persistables(program, block_output)
        # find entrance & exit
@@ -971,11 +1355,33 @@ def find_entrance_exit_private(program, program_block_ops_list):
        block_entrance = list(set(block_input) - set(block_private_vars))
        block_exit = list(set(block_output) - set(block_private_vars))
        detail = {
            "forward": {
                "entrance": block_entrance,
                "exit": block_exit,
                "private": block_private_vars,
                "persistables": persistables
            }
        }
## backward
bp_block_input, bp_block_output = find_ops_list_input_output(
program, block_op_list["backward"])
bp_persistables = screen_persistables(
program, bp_block_input) + screen_persistables(program,
bp_block_output)
# find entrance & exit
bp_block_private_vars = list(set(bp_block_input) & set(bp_block_output))
bp_block_entrance = list(
set(bp_block_input) - set(bp_block_private_vars))
bp_block_exit = list(set(bp_block_output) - set(bp_block_private_vars))
detail.update({
"backward": {
"entrance": bp_block_entrance,
"exit": bp_block_exit,
"private": bp_block_private_vars,
"persistables": bp_persistables
}
})
        block_var_detail.append(detail)
    return block_var_detail
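So each entry of `block_var_detail` now carries a forward and a backward view; schematically:

# block_var_detail[i] == {
#     "forward":  {"entrance": [...], "exit": [...],
#                  "private": [...], "persistables": [...]},
#     "backward": {"entrance": [...], "exit": [...],
#                  "private": [...], "persistables": [...]},
# }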
@@ -985,20 +1391,64 @@ def entrance_exit_check(program, program_block_ops_list, block_var_detail,
    for index in range(len(block_var_detail) - 1, -1, -1):
        if index - 1 < 0:
            break
        previous_block_exit = block_var_detail[index - 1]["forward"]["exit"]
        previous_block_exit.sort()
        current_block_entrance = block_var_detail[index]["forward"]["entrance"]
backward_entrance = block_var_detail[index]["backward"]["entrance"]
forward_all = block_var_detail[index]["forward"][
"entrance"] + block_var_detail[index]["forward"][
"private"] + block_var_detail[index]["forward"]["exit"]
for var in backward_entrance:
if not ("@GRAD" in var) and not (var in forward_all):
current_block_entrance.append(var)
current_block_entrance.sort() current_block_entrance.sort()
if previous_block_exit == current_block_entrance: if previous_block_exit == current_block_entrance:
continue continue
exist_vars = list( exist_vars = list(
set(previous_block_exit) & set(current_block_entrance)) set(previous_block_exit) & set(current_block_entrance))
need_add_vars = list(set(current_block_entrance) - set(exist_vars)) need_add_vars = list(set(current_block_entrance) - set(exist_vars))
need_add_vars = find_need_var_from_previous_block( # var in different stage should not be ignored, since they are not placed in the same program & device
need_add_vars, block_var_detail, index, heter_ops) #need_add_vars = find_need_var_from_previous_block(
# need_add_vars, block_var_detail, index, heter_ops)
previous_block_private = block_var_detail[index - 1]["forward"][
"private"]
previous_block_entrance = block_var_detail[index - 1]["forward"][
"entrance"]
for var in need_add_vars:
if var not in previous_block_private and var not in previous_block_entrance:
previous_block_entrance.append(var)
previous_block_exit.append(var)
if not var in current_block_entrance:
current_block_entrance.append(var)
    for index in range(0, len(block_var_detail) - 1, 1):
        previous_block_exit = block_var_detail[index + 1]["backward"]["exit"]
previous_block_exit.sort()
current_block_entrance = block_var_detail[index]["backward"]["entrance"]
current_block_entrance.sort()
if previous_block_exit == current_block_entrance:
continue
exist_vars = list(
set(previous_block_exit) & set(current_block_entrance))
need_add_vars = list(set(current_block_entrance) - set(exist_vars))
need_ignore_vars = []
for var in need_add_vars:
if not "@GRAD" in var:
need_ignore_vars.append(var)
need_add_vars = list(
set(need_add_vars).difference(set(need_ignore_vars)))
previous_block_private = block_var_detail[index + 1]["backward"][
"private"]
previous_block_entrance = block_var_detail[index + 1]["backward"][
"entrance"]
        for var in need_add_vars:
            if var not in previous_block_private and var not in previous_block_entrance:
                previous_block_entrance.append(var)
@@ -1014,6 +1464,7 @@ def find_need_var_from_previous_block(need_add_vars, block_var_detail,
        index_device_map[index] = DEFAULT_DEVICE
    for device in heter_ops:
        for index in heter_ops[device].keys():
            if index < len(block_var_detail):
                index_device_map[index] = device

    pre_index = current_index - 1
@@ -1040,11 +1491,12 @@ def delete_block_useless_exit(program, program_block_ops_list,
def delete_block_useless_exit(program, program_block_ops_list,
                              block_var_detail):
    ## forward
    for index in range(len(block_var_detail)):
        if index == len(block_var_detail) - 1:
            break
        current_block_exit = block_var_detail[index]["forward"]["exit"]
        next_block_entrance = block_var_detail[index + 1]["forward"]["entrance"]
        need_delete_var = []
        for var in current_block_exit:
            if var not in next_block_entrance:
@@ -1052,6 +1504,19 @@ def delete_block_useless_exit(program, program_block_ops_list,
        for var in need_delete_var:
            current_block_exit.remove(var)
## backward
for index in range(len(block_var_detail) - 1, -1, -1):
if index - 1 < 0:
break
current_block_exit = block_var_detail[index]["backward"]["exit"]
next_block_entrance = block_var_detail[index - 1]["backward"][
"entrance"]
need_delete_var = []
for var in current_block_exit:
if var not in next_block_entrance:
need_delete_var.append(var)
for var in need_delete_var:
current_block_exit.remove(var)
    return block_var_detail
@@ -1065,6 +1530,8 @@ def screen_persistables(program, var_list):
    need_remove = []
    for var_name in var_list:
        if "@GRAD" in var_name:
if "GRAD" != var_name.split("@")[-1]:
continue
origin_var_name = var_name.split("@GRAD")[0] origin_var_name = var_name.split("@GRAD")[0]
var = program.global_block().vars[origin_var_name] var = program.global_block().vars[origin_var_name]
else: else:
@@ -1168,27 +1635,40 @@ def insert_recv_slice_op(program, block, index, var_name, var_shape, dtype,
        index += 1


def add_heter_trainer_useful_vars(config, program, heter_program, heter_block,
                                  static_var):
    static_var = list(set(static_var))
    for var_name in static_var:
        if var_name not in heter_program.global_block(
        ).vars and var_name not in heter_block.vars:
            var = program.global_block().vars[var_name]
            if var.persistable:
                heter_program.global_block()._clone_variable(
                    var, force_persistable=False)
            else:
                heter_block._clone_variable(var, force_persistable=False)
def delete_trainer_useless_var(config, program, static_var):
    static_var = list(set(static_var))
    program_useful_var_list = []
    for op in program.global_block().ops:
        input_var_list, output_var_list = find_op_input_output(
            program, program.global_block(), op)
        op_var_list = list(set(input_var_list).union(set(output_var_list)))
        program_useful_var_list = list(
            set(program_useful_var_list).union(set(op_var_list)))
    program_useful_var_list += static_var
    program_useless_var_list = list(
        set(get_vars_name_in_block(program.global_block())).difference(
            set(program_useful_var_list)))
    for var in program_useless_var_list:
        program.global_block()._remove_var(var)
    return program_useless_var_list
def block_append_op(program, origin_program, block, op): def block_append_op(program, origin_program, block, op):
merge_ordereddict = origin_program.global_block().vars.copy() merge_ordereddict = origin_program.global_block().vars.copy()
merge_ordereddict.update(block.vars) merge_ordereddict.update(block.vars)
inputs = _get_input_map_from_op(merge_ordereddict, op) inputs = _get_input_map_from_op(merge_ordereddict, op)
...@@ -1242,7 +1722,8 @@ def block_append_op(program, origin_program, block, op): ...@@ -1242,7 +1722,8 @@ def block_append_op(program, origin_program, block, op):
def add_vars_by_var_list(var_name_list, origin_program, program, block): def add_vars_by_var_list(var_name_list, origin_program, program, block):
for var_name in var_name_list: for var_name in var_name_list:
if var_name not in program.global_block().vars: if var_name not in program.global_block(
).vars and var_name not in block.vars:
var = origin_program.global_block().vars[var_name] var = origin_program.global_block().vars[var_name]
if var.persistable: if var.persistable:
program.global_block()._clone_variable( program.global_block()._clone_variable(
......
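Reviewer note: delete_trainer_useless_var keeps the union of every op's inputs and outputs plus static_var, and removes everything else from the global block. In set terms (toy names, illustrative only):

    all_vars  = {"a", "b", "tmp_0", "dangling"}   # get_vars_name_in_block(global_block)
    used_vars = {"a", "b", "tmp_0"}               # union over all ops, plus static_var
    useless   = all_vars - used_vars              # {"dangling"} -> _remove_var()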
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -37,7 +37,7 @@ fluid.default_startup_program().random_seed = 1
 fluid.default_main_program().random_seed = 1
 
-class TestHeterPsCTR2x2(FleetDistHeterRunnerBase):
+class TestHeterPipelinePsCTR2x2(FleetDistHeterRunnerBase):
     """
     For test CTR model, using Fleet api
     """
@@ -54,6 +54,7 @@ class TestHeterPsCTR2x2(FleetDistHeterRunnerBase):
         """
         dnn_input_dim, lr_input_dim = int(1e5), int(1e5)
 
+        with fluid.device_guard("cpu"):
             dnn_data = fluid.layers.data(
                 name="dnn_data",
                 shape=[-1, 1],
@@ -75,13 +76,6 @@ class TestHeterPsCTR2x2(FleetDistHeterRunnerBase):
             datas = [dnn_data, lr_data, label]
 
-        if args.reader == "pyreader":
-            self.reader = fluid.io.PyReader(
-                feed_list=datas,
-                capacity=64,
-                iterable=False,
-                use_double_buffer=False)
-
             # build dnn model
             dnn_layer_dims = [128, 64, 32, 1]
             dnn_embedding = fluid.layers.embedding(
@@ -105,7 +99,8 @@ class TestHeterPsCTR2x2(FleetDistHeterRunnerBase):
                     name="wide_embedding",
                     initializer=fluid.initializer.Constant(value=0.01)),
                 is_sparse=True)
-        lr_pool = fluid.layers.sequence_pool(input=lr_embbding, pool_type="sum")
+            lr_pool = fluid.layers.sequence_pool(
+                input=lr_embbding, pool_type="sum")
 
         with fluid.device_guard("gpu"):
             for i, dim in enumerate(dnn_layer_dims[1:]):
@@ -118,6 +113,7 @@ class TestHeterPsCTR2x2(FleetDistHeterRunnerBase):
                     name='dnn-fc-%d' % i)
                 dnn_out = fc
 
+        with fluid.device_guard("cpu"):
             merge_layer = fluid.layers.concat(input=[dnn_out, lr_pool], axis=1)
             label = fluid.layers.cast(label, dtype="int64")
             predict = fluid.layers.fc(input=merge_layer, size=2, act='softmax')
@@ -143,59 +139,40 @@ class TestHeterPsCTR2x2(FleetDistHeterRunnerBase):
         with open(os.path.join(dirname, "__model__.proto"), "w") as wn:
             wn.write(str(program))
 
-    def do_pyreader_training(self, fleet):
-        """
-        do training using dataset, using fetch handler to catch variable
-        Args:
-            fleet(Fleet api): the fleet object of Parameter Server, define distribute training role
-        """
-        exe = fluid.Executor(fluid.CPUPlace())
-        exe.run(fluid.default_startup_program())
-        fleet.init_worker()
-
-        batch_size = 4
-        train_reader = paddle.batch(fake_ctr_reader(), batch_size=batch_size)
-        self.reader.decorate_sample_list_generator(train_reader)
-
-        for epoch_id in range(1):
-            self.reader.start()
-            try:
-                pass_start = time.time()
-                while True:
-                    exe.run(program=fluid.default_main_program())
-                pass_time = time.time() - pass_start
-            except fluid.core.EOFException:
-                self.reader.reset()
-
-        if fleet.is_first_worker():
-            model_path = tempfile.mkdtemp()
-            fleet.save_persistables(executor=exe, dirname=model_path)
-            shutil.rmtree(model_path)
-
     def do_dataset_training(self, fleet):
         train_file_list = ctr_dataset_reader.prepare_fake_data()
 
         exe = fluid.Executor(fluid.CPUPlace())
+        real_program = fluid.default_main_program()._heter_pipeline_opt[
+            "section_program"]
+        print(real_program)
 
-        exe.run(fluid.default_startup_program())
+        real_startup = fluid.default_startup_program()._heter_pipeline_opt[
+            "startup_program"]
+        exe.run(real_startup)
         fleet.init_worker()
 
         thread_num = int(os.getenv("CPU_NUM", 2))
         batch_size = 128
-        filelist = fleet.util.get_file_shard(train_file_list)
+
+        block_size = len(train_file_list) // fleet.worker_num()
+        worker_id = fleet.worker_index()
+        filelist = train_file_list[worker_id * block_size:(worker_id + 1) *
+                                   block_size]
+        #filelist = fleet.util.get_file_shard(train_file_list)
         print("filelist: {}".format(filelist))
 
         # config dataset
-        dataset = paddle.distributed.QueueDataset()
-        dataset._set_batch_size(batch_size)
-        dataset._set_use_var(self.feeds)
-        pipe_command = 'python ctr_dataset_reader.py'
-        dataset._set_pipe_command(pipe_command)
+        dataset = fluid.DatasetFactory().create_dataset()
+        dataset.set_batch_size(batch_size)
+        dataset.set_use_var(self.feeds)
+        pipe_command = 'python3 ctr_dataset_reader.py'
+        dataset.set_pipe_command(pipe_command)
         dataset.set_filelist(filelist)
-        dataset._set_thread(thread_num)
+        dataset.set_thread(thread_num)
 
         for epoch_id in range(1):
             pass_start = time.time()
@@ -209,7 +186,55 @@ class TestHeterPsCTR2x2(FleetDistHeterRunnerBase):
                 debug=int(os.getenv("Debug", "0")))
             pass_time = time.time() - pass_start
             print("do_dataset_training done. using time {}".format(pass_time))
+        exe.close()
+
+    def do_dataset_heter_training(self, fleet):
+
+        fleet.init_heter_worker()
+        real_program = fluid.default_main_program()._heter_pipeline_opt[
+            "section_program"]
+        print(real_program)
+
+        train_file_list = ctr_dataset_reader.prepare_fake_data()
+        #exe = fluid.Executor(fluid.CPUPlace())
+        #exe.run(fluid.default_startup_program())
+        #fleet.init_worker()
+
+        thread_num = int(os.getenv("CPU_NUM", 2))
+        batch_size = 128
+
+        #filelist = fleet.util.get_file_shard(train_file_list)
+        block_size = len(train_file_list) // fleet.worker_num()
+        filelist = train_file_list[0:block_size]
+        print("filelist: {}".format(filelist))
+
+        # config dataset
+        dataset = fluid.DatasetFactory().create_dataset()
+        dataset.set_batch_size(batch_size)
+        dataset.set_use_var(self.feeds)
+        pipe_command = 'python3 ctr_dataset_reader.py'
+        dataset.set_pipe_command(pipe_command)
+        dataset.set_filelist(filelist)
+        dataset.set_thread(thread_num)
+
+        fleet.run_heter_worker(dataset)
+        print("do_dataset_heter_training done. using time {}".format(pass_time))
+
+        #for epoch_id in range(1):
+        #    pass_start = time.time()
+        #    dataset.set_filelist(filelist)
+        #    exe.train_from_dataset(
+        #        program=fluid.default_main_program(),
+        #        dataset=dataset,
+        #        fetch_list=[self.avg_cost],
+        #        fetch_info=["cost"],
+        #        print_period=2,
+        #        debug=int(os.getenv("Debug", "0")))
+        #    pass_time = time.time() - pass_start
+        #    print("do_dataset_heter_training done. using time {}".format(pass_time))
 
 
 if __name__ == "__main__":
-    runtime_main(TestHeterPsCTR2x2)
+    runtime_main(TestHeterPipelinePsCTR2x2)
...
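Reviewer note: under the new flow each process runs only its own pipeline section; the heter-pipeline optimizer apparently stashes the per-stage programs on the main/startup programs under _heter_pipeline_opt. A minimal sketch of how a CPU trainer picks them up (keys exactly as used in the diff; the optimizer having populated this attribute is assumed from the test setup):

    import paddle.fluid as fluid

    # set by fleet.distributed_optimizer(...).minimize(...) with pipeline enabled
    opt = fluid.default_main_program()._heter_pipeline_opt
    section_program = opt["section_program"]      # this stage's sub-program
    real_startup = fluid.default_startup_program()._heter_pipeline_opt[
        "startup_program"]

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(real_startup)                         # run only the stage-local startup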
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.fluid as fluid
import paddle.distributed.fleet as fleet
import paddle.distributed.fleet.base.role_maker as role_maker

fluid.disable_dygraph()


def get_dataset(inputs):
    dataset = fluid.DatasetFactory().create_dataset()
    dataset.set_use_var(inputs)
    dataset.set_batch_size(1)
    dataset.set_filelist([])
    dataset.set_thread(1)
    return dataset


def net(batch_size=4, lr=0.01):
    """
    network definition

    Args:
        batch_size(int): the size of mini-batch for training
        lr(float): learning rate of training
    Returns:
        avg_cost: LoDTensor of cost.
    """
    dnn_input_dim, lr_input_dim = int(2), int(2)

    with fluid.device_guard("cpu"):
        dnn_data = fluid.layers.data(
            name="dnn_data",
            shape=[-1, 1],
            dtype="int64",
            lod_level=1,
            append_batch_size=False)
        lr_data = fluid.layers.data(
            name="lr_data",
            shape=[-1, 1],
            dtype="int64",
            lod_level=1,
            append_batch_size=False)
        label = fluid.layers.data(
            name="click",
            shape=[-1, 1],
            dtype="float32",
            lod_level=0,
            append_batch_size=False)

        datas = [dnn_data, lr_data, label]

        # build dnn model
        dnn_layer_dims = [2, 1]
        dnn_embedding = fluid.layers.embedding(
            is_distributed=False,
            input=dnn_data,
            size=[dnn_input_dim, dnn_layer_dims[0]],
            param_attr=fluid.ParamAttr(
                name="deep_embedding",
                initializer=fluid.initializer.Constant(value=0.01)),
            is_sparse=True)
        dnn_pool = fluid.layers.sequence_pool(
            input=dnn_embedding, pool_type="sum")
        dnn_out = dnn_pool

        # build lr model
        lr_embbding = fluid.layers.embedding(
            is_distributed=False,
            input=lr_data,
            size=[lr_input_dim, 1],
            param_attr=fluid.ParamAttr(
                name="wide_embedding",
                initializer=fluid.initializer.Constant(value=0.01)),
            is_sparse=True)
        lr_pool = fluid.layers.sequence_pool(input=lr_embbding, pool_type="sum")

    with fluid.device_guard("gpu"):
        for i, dim in enumerate(dnn_layer_dims[1:]):
            fc = fluid.layers.fc(
                input=dnn_out,
                size=dim,
                act="relu",
                param_attr=fluid.ParamAttr(
                    initializer=fluid.initializer.Constant(value=0.01)),
                name='dnn-fc-%d' % i)
            dnn_out = fc

        merge_layer = fluid.layers.concat(input=[dnn_out, lr_pool], axis=1)
        label = fluid.layers.cast(label, dtype="int64")
        predict = fluid.layers.fc(input=merge_layer, size=2, act='softmax')

        cost = fluid.layers.cross_entropy(input=predict, label=label)
        avg_cost = fluid.layers.mean(x=cost)
    return datas, avg_cost


'''
optimizer = fluid.optimizer.Adam(learning_rate=0.01)

role = role_maker.PaddleCloudRoleMaker()
fleet.init(role)

strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.a_sync = True
strategy.a_sync_configs = {"heter_worker_device_guard": 'gpu'}

strategy.pipeline = True
strategy.pipeline_configs = {"accumulate_steps": 1, "micro_batch_size": 2048}

feeds, avg_cost = net()
optimizer = fleet.distributed_optimizer(optimizer, strategy)
optimizer.minimize(avg_cost)

dataset = get_dataset(feeds)
'''

if fleet.is_server():
    pass
    #fleet.init_server()
    #fleet.run_server()
elif fleet.is_heter_worker():
    pass
    #fleet.init_heter_worker()
    #fleet.run_heter_worker(dataset=dataset)
    fleet.stop_worker()
elif fleet.is_worker():
    pass
    #place = fluid.CPUPlace()
    #exe = fluid.Executor(place)
    #exe.run(fluid.default_startup_program())
    #fleet.init_worker()
    #step = 1
    #for i in range(step):
    #    exe.train_from_dataset(
    #        program=fluid.default_main_program(), dataset=dataset, debug=False)
    #exe.close()
    #fleet.stop_worker()
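Reviewer note: the role dispatch above is stubbed out (everything is behind pass or comments). For reference, a sketch of the flow those comments and the quoted setup block describe once enabled; untested here, and fleet.init(role) must run before any is_server()/is_worker() query:

    role = role_maker.PaddleCloudRoleMaker()
    fleet.init(role)

    strategy = paddle.distributed.fleet.DistributedStrategy()
    strategy.a_sync = True
    strategy.a_sync_configs = {"heter_worker_device_guard": 'gpu'}
    strategy.pipeline = True
    strategy.pipeline_configs = {"accumulate_steps": 1, "micro_batch_size": 2048}

    feeds, avg_cost = net()
    optimizer = fleet.distributed_optimizer(
        fluid.optimizer.Adam(learning_rate=0.01), strategy)
    optimizer.minimize(avg_cost)
    dataset = get_dataset(feeds)

    if fleet.is_server():
        fleet.init_server()
        fleet.run_server()
    elif fleet.is_heter_worker():
        fleet.init_heter_worker()
        fleet.run_heter_worker(dataset=dataset)
        fleet.stop_worker()
    elif fleet.is_worker():
        exe = fluid.Executor(fluid.CPUPlace())
        exe.run(fluid.default_startup_program())
        fleet.init_worker()
        exe.train_from_dataset(
            program=fluid.default_main_program(), dataset=dataset, debug=False)
        exe.close()
        fleet.stop_worker()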
@@ -52,27 +52,74 @@ class FleetDistHeterRunnerBase(object):
     def build_role(self, args):
         environs = {}
-        environs["PADDLE_PSERVERS_IP_PORT_LIST"] = args.endpoints
-        environs["PADDLE_TRAINER_ENDPOINTS"] = args.trainer_endpoints
-        environs[
-            "PADDLE_HETER_TRAINER_IP_PORT_LIST"] = args.heter_trainer_endpoints
-        environs["PADDLE_HETER_TRAINER_DEVICE"] = args.heter_trainer_device
-        environs["TRAINING_ROLE"] = args.role.upper()
-        environs["PADDLE_TRAINERS_NUM"] = args.trainers
-        environs["PADDLE_TRAINER_ID"] = args.current_id
-        if args.role.upper() == "PSERVER":
-            environs["POD_IP"] = args.endpoints.split(",")[int(
-                args.current_id)].split(":")[0]
-            environs["PADDLE_PORT"] = args.endpoints.split(",")[int(
-                args.current_id)].split(":")[1]
-        elif args.role.upper() == "HETER_TRAINER":
-            environs["POD_IP"] = args.heter_trainer_endpoints.split(",")[int(
-                args.current_id)].split(":")[0]
-            environs["PADDLE_PORT"] = args.heter_trainer_endpoints.split(",")[
-                int(args.current_id)].split(":")[1]
-            environs["FLAGS_selected_gpus"] = args.current_id
+        heter_trainer_endpoints = args.heter_trainer_endpoints.split(";")
+        all_heter_trainer_endpoints = ",".join(heter_trainer_endpoints)
+        if args.role.upper() == "PSERVER":
+            environs["PADDLE_PSERVERS_IP_PORT_LIST"] = args.endpoints
+            environs["PADDLE_TRAINER_ENDPOINTS"] = args.trainer_endpoints
+            environs[
+                "PADDLE_ALL_HETER_TRAINER_IP_PORT_LIST"] = all_heter_trainer_endpoints
+            environs["POD_IP"] = args.endpoints.split(",")[int(
+                args.current_id)].split(":")[0]
+            environs["PADDLE_PORT"] = args.endpoints.split(",")[int(
+                args.current_id)].split(":")[1]
+            environs["TRAINING_ROLE"] = args.role.upper()
+            environs["PADDLE_TRAINERS_NUM"] = args.trainers
+        elif args.role.upper() == "HETER_TRAINER":
+            previous_endpoints = args.trainer_endpoints if args.stage_id == 2 else heter_trainer_endpoints[
+                0]
+            next_endpoints = heter_trainer_endpoints[
+                1] if args.stage_id == 2 else ""
+            heter_device = args.heter_trainer_device.split(";")[args.stage_id -
+                                                                2]
+            environs["PADDLE_PSERVERS_IP_PORT_LIST"] = args.endpoints
+            environs["PADDLE_TRAINER_ENDPOINTS"] = args.trainer_endpoints
+            environs["PADDLE_NEXT_HETER_TRAINER_IP_PORT_LIST"] = next_endpoints
+            environs[
+                "PADDLE_PREVIOUS_HETER_TRAINER_IP_PORT_LIST"] = previous_endpoints
+            environs[
+                "PADDLE_ALL_HETER_TRAINER_IP_PORT_LIST"] = all_heter_trainer_endpoints
+            environs["HETER_DEVICE_TYPE"] = heter_device
+            environs["TRAINING_ROLE"] = args.role.upper()
+            environs["POD_IP"] = all_heter_trainer_endpoints.split(",")[int(
+                args.current_id)].split(":")[0]
+            environs["PADDLE_PORT"] = all_heter_trainer_endpoints.split(",")[
+                int(args.current_id)].split(":")[1]
+            environs["PADDLE_TRAINERS_NUM"] = args.trainers
+            environs["PADDLE_STAGE_TRAINERS_NUM"] = [2, 2, 2]
+            environs["FLAGS_selected_gpus"] = 0
+            environs["FLAGS_selected_xpus"] = 0
+            environs["CUDA_VISIBLE_DEVICES"] = 0
+            environs["XPU_VISIBLE_DEVICES"] = 0
+            environs["STAGE_ID"] = args.stage_id
+            environs["STAGE_NUM"] = 3
+        elif args.role.upper() == "TRAINER":
+            environs["PADDLE_PSERVERS_IP_PORT_LIST"] = args.endpoints
+            environs["PADDLE_TRAINER_ENDPOINTS"] = args.trainer_endpoints
+            environs[
+                "PADDLE_NEXT_HETER_TRAINER_IP_PORT_LIST"] = heter_trainer_endpoints[
+                    0]
+            environs["PADDLE_PREVIOUS_HETER_TRAINER_IP_PORT_LIST"] = ""
+            environs[
+                "PADDLE_ALL_HETER_TRAINER_IP_PORT_LIST"] = all_heter_trainer_endpoints
+            environs["HETER_DEVICE_TYPE"] = "cpu"
+            environs["TRAINING_ROLE"] = args.role.upper()
+            environs["PADDLE_TRAINER_ID"] = args.current_id
+            environs["POD_IP"] = args.trainer_endpoints.split(",")[int(
+                args.current_id)].split(":")[0]
+            environs["PADDLE_PORT"] = args.trainer_endpoints.split(",")[int(
+                args.current_id)].split(":")[1]
+            environs["PADDLE_TRAINERS_NUM"] = args.trainers
+            environs["PADDLE_STAGE_TRAINERS_NUM"] = [2, 2, 2]
+            environs["FLAGS_selected_gpus"] = 0
+            environs["FLAGS_selected_xpus"] = 0
+            environs["CUDA_VISIBLE_DEVICES"] = 0
+            environs["XPU_VISIBLE_DEVICES"] = 0
+            environs["STAGE_ID"] = 1
+            environs["STAGE_NUM"] = 3
 
         for k, v in environs.items():
+            print(k, v)
             os.environ[k] = str(v)
 
         self.role = role_maker.PaddleCloudRoleMaker()
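Reviewer note: a quick summary of the topology this encodes. Stage 1 is the CPU trainer, stages 2 and 3 are heter trainers, and each process learns its neighbours from the PREVIOUS/NEXT endpoint lists. Sketch of the environment a stage-2 heter trainer ends up with (port numbers illustrative):

    # stage 1 (TRAINER, cpu) -> stage 2 (HETER_TRAINER, gpu) -> stage 3 (HETER_TRAINER, cpu)
    environs = {
        "TRAINING_ROLE": "HETER_TRAINER",
        "STAGE_ID": 2,
        "STAGE_NUM": 3,
        "HETER_DEVICE_TYPE": "gpu",    # from heter_trainer_device = "gpu;cpu"
        # previous stage for stage 2 is the CPU trainer group:
        "PADDLE_PREVIOUS_HETER_TRAINER_IP_PORT_LIST": "127.0.0.1:6174,127.0.0.1:6175",
        # next stage is the second heter endpoint group (stage 3):
        "PADDLE_NEXT_HETER_TRAINER_IP_PORT_LIST": "127.0.0.1:6178,127.0.0.1:6179",
    }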
@@ -85,6 +132,11 @@ class FleetDistHeterRunnerBase(object):
             "launch_barrier": True,
             "heter_worker_device_guard": 'gpu'
         }
+        self.strategy.pipeline = True
+        self.strategy.pipeline_configs = {
+            "accumulate_steps": 1,
+            "micro_batch_size": 2048
+        }
         return self.strategy
     def build_optimizer(self, avg_cost, strategy):
@@ -96,12 +148,12 @@ class FleetDistHeterRunnerBase(object):
         fleet.init_server()
         fleet.run_server()
 
+    def run_dataset_heter_trainer(self, args):
+        out = self.do_dataset_heter_training(fleet)
+
     def run_dataset_trainer(self, args):
         out = self.do_dataset_training(fleet)
 
-    def run_pyreader_trainer(self, args):
-        out = self.do_pyreader_training(fleet)
-
     def net(self, args, batch_size=4, lr=0.01):
         raise NotImplementedError(
             "get_model should be implemented by child classes.")
@@ -110,9 +162,9 @@ class FleetDistHeterRunnerBase(object):
         raise NotImplementedError(
             "do_dataset_training should be implemented by child classes.")
 
-    def do_pyreader_training(self, fleet):
+    def do_dataset_heter_training(self, fleet):
         raise NotImplementedError(
-            "do_pyreader_training should be implemented by child classes.")
+            "do_dataset_heter_training should be implemented by child classes.")
 
 
 class TestFleetHeterBase(unittest.TestCase):
@@ -132,12 +184,12 @@ class TestFleetHeterBase(unittest.TestCase):
         self.startTime = time.time()
 
         self._mode = "async"
-        self._reader = "pyreader"
+        self._reader = "dataset"
         self._trainers = 2
         self._pservers = 2
         self._port_set = set()
 
-        self._heter_device = "gpu"
+        self._heter_device = "gpu;cpu"
 
         global DIST_UT_PORT
         if DIST_UT_PORT == 0 and os.getenv("PADDLE_DIST_UT_PORT"):
@@ -151,7 +203,9 @@ class TestFleetHeterBase(unittest.TestCase):
                 DIST_UT_PORT + 2, DIST_UT_PORT + 3)
             self._heter_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
                 DIST_UT_PORT + 4, DIST_UT_PORT + 5)
-            DIST_UT_PORT += 6
+            self._heter_endpoints_2 = "127.0.0.1:%s,127.0.0.1:%s" % (
+                DIST_UT_PORT + 6, DIST_UT_PORT + 7)
+            DIST_UT_PORT += 8
         else:
             self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
                 self._find_free_port(), self._find_free_port())
@@ -159,6 +213,8 @@ class TestFleetHeterBase(unittest.TestCase):
                 self._find_free_port(), self._find_free_port())
             self._heter_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
                 self._find_free_port(), self._find_free_port())
+            self._heter_endpoints_2 = "127.0.0.1:%s,127.0.0.1:%s" % (
+                self._find_free_port(), self._find_free_port())
 
         self._python_interp = sys.executable
         self._geo_sgd_need_push_nums = 5
@@ -219,12 +275,17 @@ class TestFleetHeterBase(unittest.TestCase):
         return tr0_proc, tr1_proc, tr0_pipe, tr1_pipe
 
     def _start_heter_trainer(self, cmd, required_envs):
-        heter0_cmd, heter1_cmd = cmd.format(0), cmd.format(1)
+        heter0_cmd, heter1_cmd, heter2_cmd, heter3_cmd = cmd.format(
+            0, 2), cmd.format(1, 2), cmd.format(2, 3), cmd.format(3, 3)
 
         heter0_pipe = open(tempfile.gettempdir() + "/heter0_err.log", "wb+")
         heter1_pipe = open(tempfile.gettempdir() + "/heter1_err.log", "wb+")
+        heter2_pipe = open(tempfile.gettempdir() + "/heter2_err.log", "wb+")
+        heter3_pipe = open(tempfile.gettempdir() + "/heter3_err.log", "wb+")
         heter0_out = open(tempfile.gettempdir() + "/heter0_out.log", "wb+")
         heter1_out = open(tempfile.gettempdir() + "/heter1_out.log", "wb+")
+        heter2_out = open(tempfile.gettempdir() + "/heter2_out.log", "wb+")
+        heter3_out = open(tempfile.gettempdir() + "/heter3_out.log", "wb+")
 
         heter0_proc = subprocess.Popen(
             heter0_cmd.strip().split(" "),
@@ -236,8 +297,18 @@ class TestFleetHeterBase(unittest.TestCase):
             stdout=heter1_out,
             stderr=heter1_pipe,
             env=required_envs)
+        heter2_proc = subprocess.Popen(
+            heter2_cmd.strip().split(" "),
+            stdout=heter2_out,
+            stderr=heter2_pipe,
+            env=required_envs)
+        heter3_proc = subprocess.Popen(
+            heter3_cmd.strip().split(" "),
+            stdout=heter3_out,
+            stderr=heter3_pipe,
+            env=required_envs)
 
-        return heter0_proc, heter1_proc, heter0_pipe, heter1_pipe
+        return heter0_proc, heter1_proc, heter2_proc, heter3_proc, heter0_pipe, heter1_pipe, heter2_pipe, heter3_pipe
 
     def _run_cluster(self, model, envs):
         env = {
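Reviewer note: cmd.format(current_id, stage_id) now feeds both the global heter id and the pipeline stage, so heter workers 0 and 1 become stage 2 while workers 2 and 3 become stage 3. A sketch of the four formatted commands (flags abbreviated):

    # heter_cmd contains "... --current_id {} --stage_id {} ..."
    cmd.format(0, 2)  # heter worker 0 -> stage 2
    cmd.format(1, 2)  # heter worker 1 -> stage 2
    cmd.format(2, 3)  # heter worker 2 -> stage 3
    cmd.format(3, 3)  # heter worker 3 -> stage 3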
@@ -251,26 +322,31 @@ class TestFleetHeterBase(unittest.TestCase):
             envs['COVERAGE_FILE'] = os.getenv('COVERAGE_FILE', '')
             python_path += " -m coverage run --branch -p"
         env.update(envs)
+        self._all_heter_endpoints = ";".join(
+            (self._heter_endpoints, self._heter_endpoints_2))
 
         tr_cmd = "{0} {1} --role trainer --endpoints {2} --trainer_endpoints {3} --current_id {{}} --trainers {4} --mode {5} --geo_sgd_need_push_nums {6} --reader {7} --gloo_path {8} --heter_trainer_endpoints {9} --heter_trainer_device {10}".format(
             python_path, model, self._ps_endpoints, self._tr_endpoints,
             self._trainers, self._mode, self._geo_sgd_need_push_nums,
-            self._reader, gloo_path, self._heter_endpoints, self._heter_device)
+            self._reader, gloo_path, self._all_heter_endpoints,
+            self._heter_device)
 
         ps_cmd = "{0} {1} --role pserver --endpoints {2} --trainer_endpoints {3} --current_id {{}} --trainers {4} --mode {5} --geo_sgd_need_push_nums {6} --reader {7} --gloo_path {8} --heter_trainer_endpoints {9} --heter_trainer_device {10}".format(
             python_path, model, self._ps_endpoints, self._tr_endpoints,
             self._trainers, self._mode, self._geo_sgd_need_push_nums,
-            self._reader, gloo_path, self._heter_endpoints, self._heter_device)
+            self._reader, gloo_path, self._all_heter_endpoints,
+            self._heter_device)
 
-        heter_cmd = "{0} {1} --role heter_trainer --endpoints {2} --trainer_endpoints {3} --current_id {{}} --trainers {4} --mode {5} --geo_sgd_need_push_nums {6} --reader {7} --gloo_path {8} --heter_trainer_endpoints {9} --heter_trainer_device {10}".format(
+        heter_cmd = "{0} {1} --role heter_trainer --endpoints {2} --trainer_endpoints {3} --current_id {{}} --stage_id {{}} --trainers {4} --mode {5} --geo_sgd_need_push_nums {6} --reader {7} --gloo_path {8} --heter_trainer_endpoints {9} --heter_trainer_device {10}".format(
             python_path, model, self._ps_endpoints, self._tr_endpoints,
             self._trainers, self._mode, self._geo_sgd_need_push_nums,
-            self._reader, gloo_path, self._heter_endpoints, self._heter_device)
+            self._reader, gloo_path, self._all_heter_endpoints,
+            self._heter_device)
 
         # Run dist train to compare with local results
         ps0, ps1, ps0_pipe, ps1_pipe = self._start_pserver(ps_cmd, env)
         tr0, tr1, tr0_pipe, tr1_pipe = self._start_trainer(tr_cmd, env)
-        heter0, heter1, heter0_pipe, heter1_pipe = self._start_heter_trainer(
-            heter_cmd, env)
+        heter0, heter1, heter2, heter3, heter0_pipe, heter1_pipe, heter2_pipe, heter3_pipe = self._start_heter_trainer(
+            heter_cmd, env)
 
         # Wait until trainer process terminate
@@ -300,11 +376,15 @@ class TestFleetHeterBase(unittest.TestCase):
         ps1_pipe.close()
         heter0_pipe.close()
         heter1_pipe.close()
+        heter2_pipe.close()
+        heter3_pipe.close()
 
         ps0.terminate()
         ps1.terminate()
         heter0.terminate()
         heter1.terminate()
+        heter2.terminate()
+        heter3.terminate()
 
         self.assertEqual(tr0_ret, 0, "something wrong in tr0, please check")
         self.assertEqual(tr1_ret, 0, "something wrong in tr1, please check")
         shutil.rmtree(gloo_path)
@@ -349,6 +429,7 @@ def runtime_main(test_class):
     parser.add_argument('--gloo_path', type=str, required=False, default="")
     parser.add_argument('--current_id', type=int, required=False, default=0)
     parser.add_argument('--trainers', type=int, required=False, default=1)
+    parser.add_argument('--stage_id', type=int, required=False, default=1)
     parser.add_argument('--mode', type=str, required=False, default='async')
     parser.add_argument(
         '--geo_sgd_need_push_nums', type=int, required=False, default=2)
@@ -362,11 +443,11 @@ def runtime_main(test_class):
     avg_cost = model.net(args)
     model.build_optimizer(avg_cost, strategy)
 
-    if args.role == "pserver" or args.role == "heter_trainer":
+    if args.role == "pserver":
         model.run_pserver(args)
+    elif args.role == "heter_trainer":
+        model.run_dataset_heter_trainer(args)
+        fleet.stop_worker()
     else:
-        if args.reader == "dataset":
-            model.run_dataset_trainer(args)
-        else:
-            model.run_pyreader_trainer(args)
+        model.run_dataset_trainer(args)
         fleet.stop_worker()
@@ -23,10 +23,10 @@ import paddle
 paddle.enable_static()
 
-class TestDistHeterPyreaderAsync2x2(TestFleetHeterBase):
+class TestDistHeterDatasetAsync2x2(TestFleetHeterBase):
     def _setup_config(self):
         self._mode = "async"
-        self._reader = "pyreader"
+        self._reader = "dataset"
 
     def check_with_place(self,
                          model_file,
@@ -45,14 +45,16 @@ class TestDistHeterPyreaderAsync2x2(TestFleetHeterBase):
         required_envs.update(need_envs)
 
         if check_error_log:
-            required_envs["GLOG_v"] = "3"
+            required_envs["GLOG_v"] = "4"
             required_envs["GLOG_logtostderr"] = "1"
 
         tr0_losses, tr1_losses = self._run_cluster(model_file, required_envs)
 
     def test_dist_train(self):
         self.check_with_place(
-            "dist_fleet_heter_ctr.py", delta=1e-5, check_error_log=True)
+            "dist_fleet_heter_pipeline_ctr.py",
+            delta=1e-5,
+            check_error_log=True)
 
 
 if __name__ == "__main__":
...
@@ -32,9 +32,15 @@ class TestDistFleetHeterProgram(unittest.TestCase):
             "PADDLE_PSERVERS_IP_PORT_LIST"] = "127.0.0.1:36012,127.0.0.1:36013"
         environs["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36014,127.0.0.1:36015"
         environs[
-            "PADDLE_HETER_TRAINER_IP_PORT_LIST"] = "127.0.0.1:36016,127.0.0.1:36017"
+            "PADDLE_ALL_HETER_TRAINER_IP_PORT_LIST"] = "127.0.0.1:36016,127.0.0.1:36017"
+        environs[
+            "PADDLE_PREVIOUS_HETER_TRAINER_IP_PORT_LIST"] = "127.0.0.1:36014,127.0.0.1:36015"
         environs["PADDLE_HETER_TRAINER_DEVICE"] = "gpu"
         environs["TRAINING_ROLE"] = "HETER_TRAINER"
+        environs["STAGE_ID"] = 2
+        environs["STAGE_NUM"] = 2
+        environs["HETER_DEVICE_TYPE"] = "gpu"
+        environs["PADDLE_STAGE_TRAINERS_NUM"] = [2, 2]
         environs["PADDLE_TRAINERS_NUM"] = 2
         environs["PADDLE_TRAINER_ID"] = 0
         environs["POD_IP"] = "127.0.0.1"
...
@@ -23,6 +23,7 @@ import paddle.fluid as fluid
 class TestFleetBase(unittest.TestCase):
     def setUp(self):
         os.environ["POD_IP"] = "127.0.0.1"
+        os.environ["PADDLE_PORT"] = "36000"
         os.environ["PADDLE_TRAINERS_NUM"] = "2"
         os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = \
             "127.0.0.1:36001,127.0.0.2:36001"
...
@@ -16,7 +16,7 @@
 set -e
 
-server_port_00=${PADDLE_DIST_UT_PORT}
+server_port_00=$(( PADDLE_DIST_UT_PORT ))
 server_port_10=$(( PADDLE_DIST_UT_PORT + 1 ))
 worker_port_00=$(( PADDLE_DIST_UT_PORT + 2 ))
 worker_port_10=$(( PADDLE_DIST_UT_PORT + 3 ))
@@ -30,12 +30,11 @@ heter_worker_port_0=$(( PADDLE_DIST_UT_PORT + 8 ))
 heter_worker_port_1=$(( PADDLE_DIST_UT_PORT + 9 ))
 
 function test_launch_ps(){
     python -m paddle.distributed.fleet.launch \
         --servers="127.0.0.1:${server_port_00},127.0.0.1:${server_port_10}" \
         --workers="127.0.0.1:${worker_port_00},127.0.0.1:${worker_port_10}" \
-        fleet_ps_training.py 2> ut.elog
-    if grep -q "server are killed" ut.elog; then
+        fleet_ps_training.py 2> ut1.elog
+    if grep -q "server are killed" ut1.elog; then
         echo "test pserver launch succeed"
     else
         echo "test pserver launch failed"
@@ -48,11 +47,12 @@ function test_launch_ps_heter(){
         --servers="127.0.0.1:${server_port_01},127.0.0.1:${server_port_11}" \
         --workers="127.0.0.1:${worker_port_01},127.0.0.1:${worker_port_11}" \
         --heter_workers="127.0.0.1:${heter_worker_port_0},127.0.0.1:${heter_worker_port_1}" \
-        fleet_ps_training.py 2> ut.elog
-    if grep -q "server are killed" ut.elog; then
-        echo "test heter pserver launch succeed"
+        --heter_devices="gpu" \
+        fleet_heter_ps_training.py 2> ut2.elog
+    if grep -q "server are killed" ut2.elog; then
+        echo "test heter trainer launch succeed"
     else
-        echo "test pserver launch failed"
+        echo "test heter trainer launch failed"
         exit -1
     fi
 }
...
@@ -17,7 +17,7 @@ import sys
 import os
 
 __all__ = [
     'TrainerDesc', 'MultiTrainer', 'DistMultiTrainer', 'PipelineTrainer',
-    'HeterXpuTrainer'
+    'HeterXpuTrainer', 'HeterPipelineTrainer'
 ]
@@ -118,6 +118,13 @@ class TrainerDesc(object):
     def _set_program(self, program):
         self._program = program
 
+    def _set_trainer_id(self, trainer_id):
+        self.proto_desc.trainer_id = trainer_id
+
+    def _set_trainers(self, trainers):
+        for trainer_num in trainers:
+            self.proto_desc.trainers.append(trainer_num)
+
     def _set_use_cvm(self, use_cvm=False):
         self.proto_desc.use_cvm = use_cvm
@@ -374,6 +381,30 @@ class PSGPUTrainer(TrainerDesc):
         self._device_worker._gen_worker_desc(self.proto_desc)
 
+class HeterPipelineTrainer(TrainerDesc):
+    """
+    Implement of HeterPipelineTrainer.
+    It's for HeterPS Pipeline training.
+    """
+
+    def __init__(self):
+        super(HeterPipelineTrainer, self).__init__()
+        pass
+
+    def _set_program(self, program):
+        super(HeterPipelineTrainer, self)._set_program(program)
+        self._program = program
+
+    def _gen_trainer_desc(self):
+        super(HeterPipelineTrainer, self)._gen_trainer_desc()
+        self.proto_desc.class_name = "HeterPipelineTrainer"
+        if self._program == None:
+            raise RuntimeError("None Program")
+        self._device_worker._set_infer(self._infer)
+        self._device_worker._set_program(self._program)
+        self._device_worker._gen_worker_desc(self.proto_desc)
+
+
 class PipelineTrainer(TrainerDesc):
     """
     Implement of PipelineTrainer.
...
@@ -22,8 +22,8 @@ from paddle.fluid.log_helper import get_logger
 local_logger = get_logger(
     __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
 
-from .trainer_desc import MultiTrainer, DistMultiTrainer, PipelineTrainer, HeterXpuTrainer, PSGPUTrainer
-from .device_worker import Hogwild, DownpourSGD, Section, DownpourSGDOPT
+from .trainer_desc import MultiTrainer, DistMultiTrainer, PipelineTrainer, HeterXpuTrainer, PSGPUTrainer, HeterPipelineTrainer
+from .device_worker import Hogwild, DownpourSGD, Section, DownpourSGDOPT, HeterSection
 from .framework import Variable
 from multiprocessing import Process, Manager
@@ -56,6 +56,10 @@ class TrainerFactory(object):
             # for debug tools
             if opt_info is not None:
+                if opt_info.get("trainers") is not None:
+                    trainer._set_trainers(opt_info["trainers"])
+                if opt_info.get("trainer_id") is not None:
+                    trainer._set_trainer_id(opt_info["trainer_id"])
                 if opt_info.get("dump_slot") is not None:
                     trainer._set_dump_slot(opt_info["dump_slot"])
                 if opt_info.get("mpi_rank") is not None:
...
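Reviewer note: putting the two pieces together, the factory now forwards per-stage trainer counts and the trainer id from opt_info into the trainer proto. A sketch of the opt_info a heter-pipeline program would carry (values mirror the test topology; keys are exactly the ones TrainerFactory reads above):

    opt_info = {
        "trainers": [2, 2, 2],   # trainers per pipeline stage (cpu, gpu, cpu)
        "trainer_id": 0,         # this process's id within the pipeline
    }
    # TrainerFactory then calls:
    #   trainer._set_trainers(opt_info["trainers"])      # proto_desc.trainers   <- [2, 2, 2]
    #   trainer._set_trainer_id(opt_info["trainer_id"])  # proto_desc.trainer_id <- 0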