未验证 提交 0073f9bd 编写于 作者: T Thunderbrook 提交者: GitHub

support ps-gpu (#28752)

* ps gpu transpile

* ps gpu

* remove op

* gps trainer

* local ps

* add macro

* HeterBox

* def cuda

* tab

* code style

* style

Co-authored-by: Thunderbrook <a754913769#163.com>
上级 768dab44
......@@ -200,23 +200,41 @@ cc_library(naive_executor SRCS naive_executor.cc DEPS op_registry device_context
cc_library(executor_gc_helper SRCS executor_gc_helper.cc DEPS scope proto_desc operator garbage_collector)
if(WITH_DISTRIBUTE)
cc_library(executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc
dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
heterxpu_trainer.cc
data_feed.cc device_worker.cc hogwild_worker.cc hetercpu_worker.cc downpour_worker.cc downpour_worker_opt.cc
pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto trainer_desc_proto glog fs shell
fleet_wrapper heter_wrapper box_wrapper lodtensor_printer
lod_rank_table feed_fetch_method sendrecvop_rpc communicator collective_helper ${GLOB_DISTRIBUTE_DEPS}
graph_to_program_pass variable_helper data_feed_proto timer monitor
heter_service_proto)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
if(WITH_PSLIB)
cc_library(executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc
dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
heterxpu_trainer.cc
data_feed.cc device_worker.cc hogwild_worker.cc hetercpu_worker.cc
heterbox_worker.cc heterbox_trainer.cc downpour_worker.cc downpour_worker_opt.cc
pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto trainer_desc_proto glog fs shell
fleet_wrapper heter_wrapper box_wrapper lodtensor_printer
lod_rank_table feed_fetch_method sendrecvop_rpc communicator collective_helper ${GLOB_DISTRIBUTE_DEPS}
graph_to_program_pass variable_helper data_feed_proto timer monitor
heter_service_proto pslib_brpc)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
else()
cc_library(executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc
dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
heterxpu_trainer.cc
data_feed.cc device_worker.cc hogwild_worker.cc hetercpu_worker.cc
heterbox_worker.cc heterbox_trainer.cc downpour_worker.cc downpour_worker_opt.cc
pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto trainer_desc_proto glog fs shell
fleet_wrapper heter_wrapper box_wrapper lodtensor_printer
lod_rank_table feed_fetch_method sendrecvop_rpc communicator collective_helper ${GLOB_DISTRIBUTE_DEPS}
graph_to_program_pass variable_helper data_feed_proto timer monitor
heter_service_proto)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
endif()
elseif(WITH_PSLIB)
cc_library(executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc
dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
heterxpu_trainer.cc
data_feed.cc device_worker.cc hogwild_worker.cc hetercpu_worker.cc downpour_worker.cc downpour_worker_opt.cc
data_feed.cc device_worker.cc hogwild_worker.cc hetercpu_worker.cc
heterbox_worker.cc heterbox_trainer.cc downpour_worker.cc downpour_worker_opt.cc
pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto data_feed_proto heter_service_proto trainer_desc_proto glog
lod_rank_table fs shell fleet_wrapper heter_wrapper box_wrapper lodtensor_printer feed_fetch_method
......@@ -229,7 +247,8 @@ else()
cc_library(executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc
dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
heterxpu_trainer.cc
data_feed.cc device_worker.cc hogwild_worker.cc hetercpu_worker.cc downpour_worker.cc downpour_worker_opt.cc
data_feed.cc device_worker.cc hogwild_worker.cc hetercpu_worker.cc
heterbox_worker.cc heterbox_trainer.cc downpour_worker.cc downpour_worker_opt.cc
pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto data_feed_proto heter_service_proto trainer_desc_proto glog
lod_rank_table fs shell fleet_wrapper heter_wrapper box_wrapper lodtensor_printer feed_fetch_method
......
......@@ -92,6 +92,7 @@ class PullDenseWorker {
void Wait(std::vector<::std::future<int32_t>>* status_vec);
void PullDense(bool force_update = false);
void CreatePinVar();
void MergeDenseParam();
int GetThreadIdByScope(const Scope* scope);
void SetThreadIdByScope(const Scope* scope, int tid);
static std::shared_ptr<PullDenseWorker> GetInstance() {
......@@ -164,7 +165,12 @@ class DeviceWorker {
virtual void SetDataFeed(DataFeed* data_feed);
virtual void SetWorkerNum(int num) {}
virtual void CacheProgram(const ProgramDesc& main_program) {}
virtual void ProduceTasks() {}
virtual void GetXpuOpIndex() {}
#ifdef PADDLE_WITH_CUDA
virtual void SetStream(const cudaStream_t stream) {}
virtual void SetEvent(const cudaEvent_t event) {}
#endif
virtual void SetNeedDumpField(bool need_dump_field) {
need_dump_field_ = need_dump_field;
}
......@@ -187,6 +193,7 @@ class DeviceWorker {
device_reader_->SetPlace(place);
}
virtual Scope* GetThreadScope() { return thread_scope_; }
DataFeed* device_reader_ = nullptr;
protected:
virtual void DumpParam(const Scope& scope, const int batch_id);
......@@ -195,7 +202,6 @@ class DeviceWorker {
Scope* root_scope_ = nullptr;
Scope* thread_scope_;
paddle::platform::Place place_;
DataFeed* device_reader_ = nullptr;
int64_t batch_num_;
FetchConfig fetch_config_;
bool use_cvm_;
......@@ -431,6 +437,106 @@ class HeterCpuWorker : public HogwildWorker {
};
#endif
#if (defined PADDLE_WITH_CUDA || defined PADDLE_WITH_XPU) && \
(defined PADDLE_WITH_PSLIB)
class HeterBoxWorker : public HogwildWorker {
public:
HeterBoxWorker() {}
virtual ~HeterBoxWorker() {}
virtual void Initialize(const TrainerDesc& desc);
virtual void TrainFiles();
virtual void SetNeedDump(bool need_dump_field);
virtual void SetChannelWriter(ChannelObject<std::string>* queue);
virtual void SetWorkerNum(int num) { worker_num_ = num; }
virtual void CacheProgram(const ProgramDesc& main_program) {
new (&program_) ProgramDesc(main_program);
}
virtual void ProduceTasks() override;
virtual void SetStream(const cudaStream_t stream) { copy_stream_ = stream; }
virtual void SetEvent(const cudaEvent_t event) { event_ = event; }
virtual void TrainFilesWithProfiler() {}
void ResetStat();
protected:
std::shared_ptr<paddle::framework::FleetWrapper> fleet_ptr_;
void FillSparseValue(std::shared_ptr<HeterTask> task, size_t table_id);
void PushGradients();
void CollectLabelInfo(std::shared_ptr<HeterTask> task, size_t table_id);
void AdjustInsWeight(std::shared_ptr<HeterTask> task);
void DumpParam();
void CopySparseTable();
void CopyDenseTable();
void CopyDenseVars();
private:
int mpi_rank_;
std::mutex mutex_;
std::vector<std::string> send_var_list_;
int worker_num_;
ProgramDesc program_;
HeterObjectPool<HeterTask> object_pool_;
bool need_dump_param_;
std::vector<std::string> dump_param_;
bool need_to_push_dense_;
bool need_dump_field_;
bool dump_slot_;
bool need_to_push_sparse_;
std::vector<std::string> dump_fields_;
ChannelWriter<std::string> writer_;
DownpourWorkerParameter param_;
float scale_datanorm_;
// just save the value in param_ for easy access
std::map<uint64_t, std::string> label_var_name_;
std::map<uint64_t, std::vector<std::string>> sparse_key_names_;
std::map<uint64_t, std::vector<std::string>> sparse_value_names_;
std::map<uint64_t, std::vector<std::string>> sparse_grad_names_;
std::map<uint64_t, std::vector<std::string>> dense_value_names_;
std::map<uint64_t, std::vector<std::string>> dense_grad_names_;
platform::Place root_place_;
// actually pushed feasign of each table
std::map<uint64_t, std::vector<uint64_t>> sparse_push_keys_;
// skipped ops
std::vector<std::string> skip_ops_;
std::vector<::std::future<int32_t>> push_sparse_status_;
std::vector<::std::future<int32_t>> push_dense_status_;
// adjust ins weight
AdjustInsWeightConfig adjust_ins_weight_config_;
std::vector<float> nid_show_;
// check nan and inf during training
std::vector<std::string> check_nan_var_names_;
// copy table
CopyTableConfig copy_table_config_;
std::map<uint64_t, uint64_t> table_dependency_;
std::vector<std::pair<uint64_t, uint64_t>> copy_sparse_tables_;
std::vector<std::pair<uint64_t, uint64_t>> copy_dense_tables_;
std::unordered_map<uint64_t, std::unordered_set<uint64_t>> feasign_set_;
paddle::framework::Channel<std::shared_ptr<HeterTask>> pull_queue_;
paddle::framework::Channel<std::shared_ptr<HeterTask>> push_queue_;
cudaEvent_t event_;
cudaStream_t copy_stream_;
int batch_cnt_{0};
std::atomic<int> done_cnt_{0};
double total_time_;
double read_time_;
double pack_time_;
double pull_sparse_local_time_;
double op_all_time_;
double xpu_op_time_;
double xpu_wait_time_;
double cpu_op_time_;
double collect_label_time_;
double fill_sparse_time_;
double push_sparse_time_;
double gpu_2_cpu_time_;
double cpu_2_gpu_time_;
uint64_t total_inst_;
};
#endif
#if defined(PADDLE_WITH_NCCL)
class SectionWorker : public DeviceWorker {
public:
......
......@@ -66,6 +66,7 @@ REGISTER_DEVICE_WORKER_CLASS(DownpourWorker);
REGISTER_DEVICE_WORKER_CLASS(DownpourWorkerOpt);
#ifdef PADDLE_WITH_PSLIB
REGISTER_DEVICE_WORKER_CLASS(HeterCpuWorker);
REGISTER_DEVICE_WORKER_CLASS(HeterBoxWorker);
#endif
#if defined(PADDLE_WITH_NCCL)
REGISTER_DEVICE_WORKER_CLASS(SectionWorker);
......
......@@ -214,12 +214,11 @@ void FleetWrapper::HeterPullSparseVars(
}
void FleetWrapper::HeterPushSparseVars(
std::shared_ptr<HeterTask> task, const uint64_t table_id,
const std::vector<std::string>& sparse_key_names,
std::shared_ptr<HeterTask> task, const Scope& scope,
const uint64_t table_id, const std::vector<std::string>& sparse_key_names,
const std::vector<std::string>& sparse_grad_names, const int emb_dim,
std::vector<::std::future<int32_t>>* push_sparse_status, const bool use_cvm,
const bool dump_slot, const bool no_cvm) {
auto& scope = *(task->scope_);
int batch_size = task->cur_batch_;
int offset = 2;
int slot_offset = 0;
......
......@@ -95,8 +95,8 @@ class FleetWrapper {
const std::vector<std::string>& var_emb_names);
void HeterPushSparseVars(
std::shared_ptr<HeterTask> task, const uint64_t table_id,
const std::vector<std::string>& sparse_key_names,
std::shared_ptr<HeterTask> task, const Scope& scope,
const uint64_t table_id, const std::vector<std::string>& sparse_key_names,
const std::vector<std::string>& sparse_grad_names, const int emb_dim,
std::vector<::std::future<int32_t>>* push_sparse_status,
const bool use_cvm, const bool dump_slot, const bool no_cvm);
......
......@@ -88,12 +88,10 @@ class HeterWrapper {
#ifdef PADDLE_WITH_CUDA
void DeSerializeToTensor(Scope* scope, const VariableMessage& req_var,
platform::Place place,
cudaStream_t stream = nullptr);
#else
platform::Place place, cudaStream_t stream);
#endif
void DeSerializeToTensor(Scope* scope, const VariableMessage& req_var,
platform::Place place);
#endif
// HeterWrapper singleton
static std::shared_ptr<HeterWrapper> GetInstance() {
if (NULL == s_instance_) {
......
......@@ -29,6 +29,7 @@ limitations under the License. */
#include "brpc/channel.h"
#include "brpc/controller.h"
#include "brpc/server.h"
#include "paddle/fluid/platform/timer.h"
namespace paddle {
namespace framework {
......@@ -100,6 +101,9 @@ class HeterTask {
collect_label_time = 0;
fill_sparse_time = 0;
push_sparse_time = 0;
gpu_2_cpu_time = 0;
cpu_2_gpu_time = 0;
timeline.Reset();
}
void Show() {
std::cout << "features size " << features_.size() << std::endl;
......@@ -110,6 +114,8 @@ class HeterTask {
}
void PackTask(Scope* scope, int taskid, DataFeed* reader, int cur_batch,
const ProgramDesc& program);
void PackGpuTask(Scope* thread_scope, DataFeed* reader,
const ProgramDesc& program);
Scope* scope_{nullptr};
int taskid_;
......@@ -132,6 +138,9 @@ class HeterTask {
double collect_label_time{0};
double fill_sparse_time{0};
double push_sparse_time{0};
double gpu_2_cpu_time{0};
double cpu_2_gpu_time{0};
platform::Timer timeline;
};
template <class T>
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cstdlib>
#include <string>
#include <vector>
#include "io/fs.h"
#include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include "paddle/fluid/framework/trainer.h"
#if (defined PADDLE_WITH_CUDA || defined PADDLE_WITH_XPU) && \
(defined PADDLE_WITH_PSLIB)
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cuda_device_guard.h"
#endif
namespace paddle {
namespace framework {
void HeterBoxTrainer::Initialize(const TrainerDesc& trainer_desc,
Dataset* dataset) {
thread_num_ = trainer_desc.thread_num();
param_ = trainer_desc.downpour_param();
for (int i = 0; i < param_.dense_table_size(); ++i) {
uint64_t table_id = static_cast<uint64_t>(param_.dense_table(i).table_id());
auto table = param_.dense_table(i);
dense_grad_names_[table_id].resize(table.dense_grad_name_size());
for (int j = 0; j < table.dense_grad_name_size(); ++j) {
dense_grad_names_[table_id][j] = table.dense_grad_name(j);
}
}
RegisterHeterCallback();
scale_datanorm_ = trainer_desc.scale_datanorm();
int place_num = trainer_desc.worker_places_size();
const std::vector<paddle::framework::DataFeed*> readers =
dataset->GetReaders();
for (int i = 0; i < place_num; ++i) {
int num = trainer_desc.worker_places(i);
#ifdef PADDLE_WITH_CUDA
platform::CUDAPlace place = platform::CUDAPlace(num);
platform::CUDADeviceGuard guard(place.device);
cudaStream_t stream;
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamCreate(&stream));
copy_streams_.push_back(stream);
places_.push_back(place);
cudaEvent_t event;
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaEventCreateWithFlags(&event, cudaEventDisableTiming));
events_.push_back(event);
#endif
#ifdef PADDLE_WITH_XPU
platform::XPUPlace place = platform::XPUPlace(num);
places_.push_back(place);
#endif
}
for (int i = 0; i < trainer_desc.downpour_param().stat_var_names_size();
i++) {
need_merge_var_names_.push_back(
trainer_desc.downpour_param().stat_var_names(i));
}
VLOG(3) << "going to initialize pull dense worker";
pull_dense_worker_ = PullDenseWorker::GetInstance();
pull_dense_worker_->Initialize(trainer_desc);
VLOG(3) << "initialize pull dense worker";
SetDebug(trainer_desc.debug());
fleet_ptr_ = FleetWrapper::GetInstance();
trainer_desc_ = trainer_desc;
workers_.resize(place_num);
for (int i = 0; i < place_num; ++i) {
workers_[i] = DeviceWorkerFactory::CreateDeviceWorker(
trainer_desc.device_worker_name());
workers_[i]->SetDeviceIndex(i);
workers_[i]->SetDataFeed(readers[i]);
workers_[i]->Initialize(trainer_desc);
workers_[i]->SetWorkerNum(place_num);
}
}
void HeterBoxTrainer::DumpWork(int tid) {}
void HeterBoxTrainer::RegisterHeterCallback() {
auto fleet_ptr = FleetWrapper::GetInstance();
fleet_ptr->RegisterHeterCallback([this](int worker, int taskid) {
// workers_[worker]->Schedule(taskid);
});
}
void HeterBoxTrainer::InitTrainerEnv(const ProgramDesc& main_program,
const platform::Place& place) {
for (size_t i = 0; i < places_.size(); ++i) {
workers_[i]->SetPlace(places_[i]);
workers_[i]->SetStream(copy_streams_[i]);
workers_[i]->SetEvent(events_[i]);
workers_[i]->SetReaderPlace(platform::CPUPlace());
workers_[i]->SetRootScope(root_scope_);
workers_[i]->CreateDeviceResource(main_program); // Program
workers_[i]->BindingDataFeedMemory();
#ifdef PADDLE_WITH_PSLIB
workers_[i]->CacheProgram(main_program);
#endif
}
for (size_t num = 0; num < places_.size(); ++num) {
auto place = places_[num];
Scope* scope = workers_[num]->GetThreadScope();
auto stream = copy_streams_[num];
auto event = events_[num];
auto dev_id = BOOST_GET_CONST(platform::CUDAPlace, place).device;
platform::CUDADeviceGuard guard(dev_id);
auto& block = main_program.Block(0);
for (auto& var : block.AllVars()) {
if (var->Persistable()) {
auto name = var->Name();
Variable* root_var = root_scope_->FindVar(name);
if (!root_var) {
continue;
}
LoDTensor* root_tensor = root_var->GetMutable<LoDTensor>();
auto* ptr = scope->Var(name);
InitializeVariable(ptr, proto::VarType::LOD_TENSOR);
LoDTensor* thread_tensor = ptr->GetMutable<LoDTensor>();
#define HeterMemcpyFunc(cpp_type, proto_type) \
do { \
if (root_tensor->type() == proto_type) { \
HeterMemCpy<cpp_type>(thread_tensor, root_tensor, place, stream); \
} \
} while (0)
_ForEachDataType_(HeterMemcpyFunc);
}
}
PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventRecord(event, stream));
cudaEventSynchronize(event);
}
place_ = place;
}
template <typename T>
void HeterBoxTrainer::HeterMemCpy(LoDTensor* thread_tensor,
LoDTensor* root_tensor,
const paddle::platform::Place& thread_place,
cudaStream_t stream) {
T* thread_ptr =
thread_tensor->mutable_data<T>(root_tensor->dims(), thread_place);
T* root_ptr = root_tensor->data<T>();
if (platform::is_cpu_place(root_tensor->place())) {
memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, thread_place), thread_ptr,
platform::CPUPlace(), root_ptr,
sizeof(T) * root_tensor->numel(), stream);
} else {
memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, thread_place), thread_ptr,
BOOST_GET_CONST(platform::CUDAPlace, root_tensor->place()),
root_ptr, sizeof(T) * root_tensor->numel(), stream);
}
}
void HeterBoxTrainer::InitOtherEnv(const ProgramDesc& main_program) {
pull_dense_worker_->SetRootScope(root_scope_);
pull_dense_worker_->CreatePinVar();
for (size_t i = 0; i < places_.size(); ++i) {
pull_dense_worker_->AddThreadScope(workers_[i]->GetThreadScope());
pull_dense_worker_->AddPlace(places_[i]);
#ifdef PADDLE_WITH_CUDA
pull_dense_worker_->AddStream(copy_streams_[i]);
#endif
}
VLOG(3) << "init other env done.";
}
void HeterBoxTrainer::Run() {
int pull_thread_num = 3 * places_.size();
for (size_t thidx = 0; thidx < places_.size(); ++thidx) {
workers_[thidx]->device_reader_->Start();
std::dynamic_pointer_cast<paddle::framework::HeterBoxWorker>(
workers_[thidx])
->ResetStat();
}
for (int i = 0; i < pull_thread_num; ++i) {
int worker_id = i % places_.size();
pull_threads_.push_back(
std::thread(&DeviceWorker::ProduceTasks, workers_[worker_id].get()));
}
for (size_t thidx = 0; thidx < places_.size(); ++thidx) {
threads_.push_back(
std::thread(&DeviceWorker::TrainFiles, workers_[thidx].get()));
}
}
template <typename T>
void HeterBoxTrainer::MergeToRootScope(LoDTensor* root_tensor,
LoDTensor* tensor) {
LoDTensor tmp_root;
TensorCopy(*root_tensor, platform::CPUPlace(), &tmp_root);
T* tmp_root_data = tmp_root.data<T>();
LoDTensor tmp_tensor;
TensorCopy(*tensor, platform::CPUPlace(), &tmp_tensor);
T* data = tmp_tensor.data<T>();
for (int i = 0; i < tmp_tensor.numel(); i++) {
tmp_root_data[i] += data[i];
}
TensorCopy(tmp_root, platform::CPUPlace(), root_tensor);
}
Scope* HeterBoxTrainer::GetWorkerScope(int thread_id) { return nullptr; }
void HeterBoxTrainer::Finalize() {
for (auto& th : pull_threads_) {
th.join();
}
for (auto& th : threads_) {
th.join();
}
for (size_t i = 0; i < need_merge_var_names_.size(); i++) {
Variable* root_var = root_scope_->FindVar(need_merge_var_names_[i]);
if (root_var == nullptr) {
continue;
}
LoDTensor* root_tensor = root_var->GetMutable<LoDTensor>();
for (size_t j = 0; j < places_.size(); j++) {
Scope* cur_thread_scope = workers_[j]->GetThreadScope();
Variable* thread_var =
cur_thread_scope->FindVar(need_merge_var_names_[i]);
if (thread_var == nullptr) {
continue;
}
LoDTensor* thread_tensor = thread_var->GetMutable<LoDTensor>();
#define MergeCallback(cpp_type, proto_type) \
do { \
if (root_tensor->type() == proto_type) { \
if (thread_tensor->type() != proto_type) { \
VLOG(0) << "Error: thread id=" << j << ", need_merge_var_names_[" << i \
<< "] " << need_merge_var_names_[i] \
<< ", root tensor type=" << root_tensor->type() \
<< ", thread tensor type=" << thread_tensor->type(); \
exit(-1); \
} \
MergeToRootScope<cpp_type>(root_tensor, thread_tensor); \
} \
} while (0)
_ForEachDataType_(MergeCallback);
}
}
pull_dense_worker_->MergeDenseParam();
root_scope_->DropKids();
}
} // namespace framework
} // namespace paddle
#endif
此差异已折叠。
......@@ -811,9 +811,9 @@ void HeterCpuWorker::TrainFilesWithProfiler() {
}
timeline.Start();
fleet_ptr_->HeterPushSparseVars(
task, tid, sparse_key_names_[tid], sparse_grad_names_[tid],
table.emb_dim(), &push_sparse_status_, use_cvm_, dump_slot_,
no_cvm_);
task, *(task->scope_), tid, sparse_key_names_[tid],
sparse_grad_names_[tid], table.emb_dim(), &push_sparse_status_,
use_cvm_, dump_slot_, no_cvm_);
timeline.Pause();
task->push_sparse_time += timeline.ElapsedSec();
task->total_time += timeline.ElapsedSec();
......@@ -1074,9 +1074,9 @@ void HeterCpuWorker::TrainFiles() {
}
}
fleet_ptr_->HeterPushSparseVars(
task, tid, sparse_key_names_[tid], sparse_grad_names_[tid],
table.emb_dim(), &push_sparse_status_, use_cvm_, dump_slot_,
no_cvm_);
task, *(task->scope_), tid, sparse_key_names_[tid],
sparse_grad_names_[tid], table.emb_dim(), &push_sparse_status_,
use_cvm_, dump_slot_, no_cvm_);
}
}
......
......@@ -415,7 +415,7 @@ int HeterXpuTrainer::RunTask(const HeterRequest* request,
std::shared_ptr<HeterServiceContext> context = object_pool_.Get();
if (!context->scope_) {
int num = rand_r() % places_.size();
int num = rand() % places_.size();
context->place_num_ = num;
auto place = places_[num];
context->scope_ = &(place_scopes_[num]->NewScope());
......
......@@ -225,5 +225,22 @@ void PullDenseWorker::SetThreadIdByScope(const Scope* scope, int tid) {
scope_to_thread_id_[scope] = tid;
}
void PullDenseWorker::MergeDenseParam() {
for (int x = 0; x < dwp_param_.program_config(0).pull_dense_table_id_size();
++x) {
uint64_t tid = static_cast<uint64_t>(
dwp_param_.program_config(0).pull_dense_table_id(x));
for (size_t j = 0; j < dense_value_names_[tid].size(); j++) {
auto& name = dense_value_names_[tid][j];
Variable* root_var = root_scope_->FindVar(name);
LoDTensor* root_tensor = root_var->GetMutable<LoDTensor>();
Variable* var = thread_scopes_[0]->FindVar(name);
LoDTensor* tensor = var->GetMutable<LoDTensor>();
TensorCopy((*tensor), root_tensor->place(), root_tensor);
}
}
}
} // namespace framework
} // namespace paddle
......@@ -224,6 +224,56 @@ class HeterXpuTrainer : public TrainerBase {
std::vector<cudaEvent_t> events_;
#endif
};
class HeterBoxTrainer : public TrainerBase {
public:
HeterBoxTrainer() {}
virtual ~HeterBoxTrainer() {}
virtual void Initialize(const TrainerDesc& trainer_desc, Dataset* data_set);
virtual void InitTrainerEnv(const ProgramDesc& main_program,
const platform::Place& place);
virtual void InitOtherEnv(const ProgramDesc& main_program);
virtual void Run();
virtual void Finalize();
virtual void RegisterHeterCallback();
virtual void DumpWork(int tid);
virtual Scope* GetWorkerScope(int thread_id);
virtual void CacheProgram(const ProgramDesc& main_program) {
new (&program_) ProgramDesc(main_program);
}
virtual std::string GetDumpPath(int tid) { return ""; }
virtual void InitDumpEnv() {}
template <typename T>
#ifdef PADDLE_WITH_CUDA
void HeterMemCpy(LoDTensor* tensor, LoDTensor* root_tensor,
const paddle::platform::Place& thread_place,
cudaStream_t stream);
#endif
void CreateThreadParam(const ProgramDesc& program, int num);
template <typename T>
void MergeToRootScope(LoDTensor* root_tensor, LoDTensor* thread_tensor);
protected:
DownpourWorkerParameter param_;
std::map<uint64_t, std::vector<std::string>> dense_grad_names_;
std::vector<std::string> need_merge_var_names_;
float scale_datanorm_;
paddle::platform::Place place_;
ProgramDesc program_;
std::shared_ptr<paddle::framework::FleetWrapper> fleet_ptr_;
std::shared_ptr<paddle::framework::PullDenseWorker> pull_dense_worker_;
std::vector<std::shared_ptr<DeviceWorker>> workers_;
std::vector<platform::Place> places_;
// ps-gpu
std::vector<std::thread> pull_threads_;
std::vector<std::thread> threads_;
int use_ps_gpu_;
int thread_num_;
#ifdef PADDLE_WITH_CUDA
std::vector<cudaStream_t> copy_streams_;
std::vector<cudaEvent_t> events_;
#endif
};
#endif
#if defined(PADDLE_WITH_NCCL)
......
......@@ -59,6 +59,8 @@ message TrainerDesc {
optional int32 xpu_start_idx = 30;
optional int32 xpu_end_idx = 31;
optional bool use_ps_gpu = 32 [ default = false ];
// device worker parameters
optional HogwildWorkerParameter hogwild_param = 101;
optional DownpourWorkerParameter downpour_param = 103;
......
......@@ -66,6 +66,7 @@ REGISTER_TRAINER_CLASS(DistMultiTrainer);
#if (defined PADDLE_WITH_CUDA || defined PADDLE_WITH_XPU) && \
(defined PADDLE_WITH_PSLIB)
REGISTER_TRAINER_CLASS(HeterXpuTrainer);
REGISTER_TRAINER_CLASS(HeterBoxTrainer);
#endif
#if defined(PADDLE_WITH_NCCL)
REGISTER_TRAINER_CLASS(PipelineTrainer);
......
......@@ -1368,6 +1368,8 @@ class Executor(object):
is_heter = 1
if program._fleet_opt.get("trainer", "") == "HeterXpuTrainer":
is_heter = 1
if program._fleet_opt.get("use_ps_gpu", ""):
is_heter = 1
if scope is None:
scope = global_scope()
if fetch_list is None:
......
......@@ -987,12 +987,64 @@ class DownpourOptimizer(DistributedOptimizer):
"""
raise NotImplementedError()
def _remove_collective_ops(self, program, name):
"""
colective init op should call once, so remove other call.
"""
block = program.global_block()
for ids, op in list(enumerate(block.ops)):
if op.type == name:
block._remove_op(ids)
return
def apply_gradients(self, params_grads):
"""
Currently, apply_gradients function can not be called through DistributedOptimizer
"""
raise NotImplementedError()
def get_dist_env(self):
trainer_id = int(os.getenv('PADDLE_TRAINER_ID', '0'))
trainer_endpoints = ''
current_endpoint = ''
num_trainers = 0
if os.getenv('PADDLE_TRAINER_ENDPOINTS') and os.getenv(
'PADDLE_CURRENT_ENDPOINT'):
trainer_endpoints = os.getenv('PADDLE_TRAINER_ENDPOINTS')
current_endpoint = os.getenv('PADDLE_CURRENT_ENDPOINT')
num_trainers = len(trainer_endpoints.split(','))
return {
'trainer_id': trainer_id,
'num_trainers': num_trainers,
'current_endpoint': current_endpoint,
'trainer_endpoints': trainer_endpoints
}
def _remove_collective_op_for_embedding(self, loss, table_name):
"""
find multi-sparse-table
"""
table_name = [name + "@GRAD" for name in table_name]
need_remove_op_index = []
block = loss.block.program.global_block()
collective_ops = ["c_sync_calc_stream", "c_allreduce_sum"]
for ids, op in list(enumerate(block.ops)):
if op.type in collective_ops:
if op.input("X")[0] in table_name:
need_remove_op_index.append(ids)
if op.type == "lookup_table_grad":
need_remove_op_index.append(ids)
try:
if op.output("Out")[0] in table_name:
need_remove_op_index.append(ids)
except:
pass
need_remove_op_index.sort(reverse=True)
for index in need_remove_op_index:
block._remove_op(index)
def minimize(self,
losses,
scopes=None,
......@@ -1043,5 +1095,31 @@ class DownpourOptimizer(DistributedOptimizer):
fleet._main_programs = programs
fleet._scopes = scopes
if opt_info["use_ps_gpu"]:
from paddle.fluid.transpiler.collective import SingleProcessMultiThread
# check start program
env = self.get_dist_env()
if not isinstance(losses, list):
startup_programs = [startup_programs]
for i in range(0, len(startup_programs)):
t = SingleProcessMultiThread()
start_program = startup_programs[i]
main_program = programs[i]
t.transpile(
startup_program=start_program,
main_program=main_program,
rank=env["trainer_id"],
endpoints=env["trainer_endpoints"],
current_endpoint=env['current_endpoint'],
wait_port=False)
if i > 0:
self._remove_collective_ops(start_program,
"c_comm_init_all")
for i in range(0, len(losses)):
loss = losses[i]
embedding_table = self._distributed_optimizer._find_multi_distributed_lookup_table(
[loss])
self._remove_collective_op_for_embedding(loss, embedding_table)
return [optimize_ops, param_grads]
......@@ -15,14 +15,17 @@
__all__ = ["DistributedAdam", "FLEET_GLOBAL_DICT"]
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table
from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table_inputs
from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table_outputs
from google.protobuf import text_format
from collections import OrderedDict
import copy
from .node import DownpourWorker, DownpourServer
from . import ps_pb2 as pslib
OpRole = core.op_proto_and_checker_maker.OpRole
# this dict is for store info about pull/push sparse ops.
FLEET_GLOBAL_DICT = {
# global settings
......@@ -87,6 +90,8 @@ class DistributedAdam(DistributedOptimizerImplBase):
self.supported_embedding_grad_types = [
"lookup_table_grad", "push_sparse", "push_sparse_v2"
]
op_maker = core.op_proto_and_checker_maker
self.op_role_key = op_maker.kOpRoleAttrName()
def _find_distributed_lookup_table_inputs(self, program, table_names):
"""
......@@ -145,6 +150,26 @@ class DistributedAdam(DistributedOptimizerImplBase):
[local_vars[name] for name in op.input("Out@GRAD")])
return grads_dict
def _is_optimizer_op(self, op):
return self.op_role_key in op.attr_names and \
int(op.all_attrs()[self.op_role_key]) & int(OpRole.Optimize)
def _remove_optimize_op_for_embedding(self, loss, table_name):
"""
find multi-sparse-table
"""
table_name = [name + "@GRAD" for name in table_name]
need_remove_op_index = []
block = loss.block.program.global_block()
for ids, op in list(enumerate(block.ops)):
if self._is_optimizer_op(op):
if op.input("Grad")[0] in table_name:
need_remove_op_index.append(ids)
need_remove_op_index.sort(reverse=True)
for index in need_remove_op_index:
block._remove_op(index)
def _find_multi_distributed_lookup_table(self, losses):
"""
find multi-sparse-table
......@@ -314,7 +339,8 @@ class DistributedAdam(DistributedOptimizerImplBase):
sparse_table_to_index = OrderedDict()
sparse_table_index = 0
for loss in losses:
for num in range(len(losses)):
loss = losses[num]
prog_id = str(id(loss.block.program))
# param_grads of program
params_grads = sorted(
......@@ -322,6 +348,18 @@ class DistributedAdam(DistributedOptimizerImplBase):
no_grad_set),
key=lambda x: x[0].name)
flag_use_ps_gpu = strategy.get("use_ps_gpu", False)
if flag_use_ps_gpu:
if not isinstance(startup_program, list):
startup_program = [startup_program]
optimizer = copy.deepcopy(self._optimizer)
optimize_ops = optimizer.apply_optimize(
loss,
startup_program=startup_program[num],
params_grads=params_grads)
embedding_table = self._find_multi_distributed_lookup_table(
[loss])
self._remove_optimize_op_for_embedding(loss, embedding_table)
# has condition_block op means multi-task
flag_multi_task = self._has_conditional_block(loss)
if flag_multi_task:
......@@ -725,6 +763,7 @@ class DistributedAdam(DistributedOptimizerImplBase):
opt_info["dump_fields_path"] = strategy.get("dump_fields_path", "")
opt_info["dump_param"] = strategy.get("dump_param", [])
opt_info["worker_places"] = strategy.get("worker_places", [])
opt_info["use_ps_gpu"] = strategy.get("use_ps_gpu", False)
if server._server.downpour_server_param.downpour_table_param[
0].accessor.accessor_class in [
"DownpourCtrAccessor", "DownpourCtrDoubleAccessor",
......
......@@ -23,18 +23,18 @@ import os
import sys
import time
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.log_helper import get_logger
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet as fleet_pslib
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet as fleet_transpiler
from paddle.distributed.fleet.utils.fs import LocalFS, HDFSClient
from . import utils
OpRole = core.op_proto_and_checker_maker.OpRole
__all__ = ["FleetUtil"]
_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
fleet = fleet_pslib
fleet = None
class FleetUtil(object):
......@@ -52,9 +52,13 @@ class FleetUtil(object):
def __init__(self, mode="pslib"):
global fleet
op_maker = core.op_proto_and_checker_maker
self.op_role_key = op_maker.kOpRoleAttrName()
if mode == "pslib":
from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet as fleet_pslib
fleet = fleet_pslib
elif mode == "transpiler":
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet as fleet_transpiler
fleet = fleet_transpiler
else:
raise ValueError(
......@@ -1616,20 +1620,26 @@ class FleetUtil(object):
program = utils.load_program(prog_path, is_text)
utils.parse_program(program, output_dir)
def _is_optimizer_op(self, op):
return self.op_role_key in op.attr_names and \
int(op.all_attrs()[self.op_role_key]) & int(OpRole.Optimize)
def split_program_by_device(self, program):
ops_list = []
type_list = []
pre = None
type_cpu = "cpu"
for op in program.global_block().ops:
if self._is_optimizer_op(op):
break
if op.has_attr("op_device"):
if pre is None or pre != op.attr("op_device"):
cur_attr = op.attr("op_device") if op.attr(
"op_device") != "" else type_cpu
if pre is None or pre != cur_attr:
ops_list.append([])
type_list.append(
op.attr("op_device")
if op.attr("op_device") != "" else type_cpu)
type_list.append(cur_attr)
ops_list[-1].append(op)
pre = op.attr("op_device")
pre = cur_attr
l = len(type_list)
i = 0
type_heter = None
......
......@@ -79,7 +79,6 @@ class HDFSClient(FS):
time_out=5 * 60 * 1000, #ms
sleep_inter=1000): #ms
# Raise exception if JAVA_HOME not exists.
java_home = os.environ["JAVA_HOME"]
self.pre_commands = []
hadoop_bin = '%s/bin/hadoop' % hadoop_home
......
......@@ -17,7 +17,7 @@ import sys
import os
__all__ = [
'TrainerDesc', 'MultiTrainer', 'DistMultiTrainer', 'PipelineTrainer',
'HeterXpuTrainer'
'HeterXpuTrainer', 'HeterBoxWorker'
]
......@@ -166,6 +166,9 @@ class TrainerDesc(object):
for place in worker_places:
self.proto_desc.worker_places.append(place)
def _set_use_ps_gpu(self, use_ps_gpu=False):
self.proto_desc.use_ps_gpu = use_ps_gpu
def _set_thread_barrier(self, thread_barrier):
self.proto_desc.thread_barrier = thread_barrier
......@@ -340,6 +343,30 @@ class HeterXpuTrainer(TrainerDesc):
self._device_worker._gen_worker_desc(self.proto_desc)
class HeterBoxTrainer(TrainerDesc):
"""
Implement of HeterBoxTrainer.
It's for Distributed training.
"""
def __init__(self):
super(HeterBoxTrainer, self).__init__()
pass
def _set_program(self, program):
super(HeterBoxTrainer, self)._set_program(program)
self._program = program
def _gen_trainer_desc(self):
super(HeterBoxTrainer, self)._gen_trainer_desc()
self.proto_desc.class_name = "HeterBoxTrainer"
if self._program == None:
raise RuntimeError("None Program")
self._device_worker._set_infer(self._infer)
self._device_worker._set_program(self._program)
self._device_worker._gen_worker_desc(self.proto_desc)
class PipelineTrainer(TrainerDesc):
"""
Implement of PipelineTrainer.
......
......@@ -22,7 +22,7 @@ from paddle.fluid.log_helper import get_logger
local_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
from .trainer_desc import MultiTrainer, DistMultiTrainer, PipelineTrainer, HeterXpuTrainer
from .trainer_desc import MultiTrainer, DistMultiTrainer, PipelineTrainer, HeterXpuTrainer, HeterBoxTrainer
from .device_worker import Hogwild, DownpourSGD, Section, DownpourSGDOPT
from .framework import Variable
from multiprocessing import Process, Manager
......@@ -77,6 +77,8 @@ class TrainerFactory(object):
trainer._set_dump_param(opt_info["dump_param"])
if opt_info.get("worker_places") is not None:
trainer._set_worker_places(opt_info["worker_places"])
if opt_info.get("use_ps_gpu") is not None:
trainer._set_use_ps_gpu(opt_info["use_ps_gpu"])
if opt_info.get("enable_random_dump") is not None:
trainer._set_enable_random_dump(opt_info[
"enable_random_dump"])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册