Commit 6c16858f, authored by sandyhouse

update, test=develop

Parent: a6344af2
@@ -441,14 +441,14 @@ class SectionWorker : public DeviceWorker {
   void SetSkipVars(const std::vector<std::string>& skip_vars) {
     skip_vars_ = skip_vars;
   }
+  void SetStartCpuCoreId(int id) { cpu_id_ = id; }
   // static void ResetBatchId() { batch_id_ = 0; }
-  static std::atomic<int> cpu_id_;
 
  protected:
   void AutoSetCPUAffinity(bool reuse);
   int section_id_;
   int thread_id_;
+  int cpu_id_;
   int num_microbatches_;
   std::vector<Scope*> microbatch_scopes_;
   std::vector<std::string> skip_vars_;
......
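Note: the hunk above replaces the shared `static std::atomic<int> cpu_id_` counter with a plain per-worker `cpu_id_` member that the trainer assigns through `SetStartCpuCoreId`. A minimal Python sketch of the core-selection policy this enables, assuming `AutoSetCPUAffinity` keeps the modulo-on-overflow behavior shown in section_worker.cc further down; the helper name `pick_core` is hypothetical:

```python
# Hypothetical model of AutoSetCPUAffinity's core selection: each worker
# now brings its own cpu_id_ instead of drawing from a shared atomic counter.
def pick_core(cpu_id, concurrency_cap, reuse=True):
    proc = cpu_id
    if proc >= concurrency_cap:
        if not reuse:
            return None  # mirrors the LOG(INFO) early return in the C++ code
        proc %= concurrency_cap  # wrap around when cores are oversubscribed
    return proc

assert pick_core(3, concurrency_cap=8) == 3
assert pick_core(11, concurrency_cap=8) == 3  # 11 % 8
assert pick_core(11, concurrency_cap=8, reuse=False) is None
```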
@@ -34,8 +34,8 @@ void PipelineTrainer::Initialize(const TrainerDesc& trainer_desc,
   ParseDumpConfig(trainer_desc);
   // get filelist from trainer_desc here
   // const std::vector<paddle::framework::DataFeed*> readers =
-  // VLOG(3) << "Number of program sections: " << section_num_;
   // dataset->GetReaders();
+  // VLOG(3) << "Number of program sections: " << section_num_;
   // VLOG(3) << "readers num: " << readers.size();
   // int num_readers = readers.size();
   // PADDLE_ENFORCE_EQ(num_readers, 1,
@@ -108,6 +108,7 @@ void PipelineTrainer::Initialize(const TrainerDesc& trainer_desc,
   this_worker->SetPlace(place_);
   this_worker->Initialize(trainer_desc);
   this_worker->SetMicrobatchNum(num_microbatches_);
+  this_worker->SetStartCpuCoreId(start_cpu_core_id_);
   // set debug here
   SetDebug(trainer_desc.debug());
@@ -207,7 +208,7 @@ void PipelineTrainer::CopyParameters(int microbatch_id,
     } else if (!var->Persistable() && !is_param_grad) {
       auto* ptr = microbatch_scopes_[microbatch_id]->Var(var->Name());
       VLOG(3) << "Create variable " << var->Name() << " microbatch "
-              << ", which pointer is " << ptr;
+              << microbatch_id << ", which pointer is " << ptr;
       InitializeVariable(ptr, var->GetType());
     }
   }
@@ -235,39 +236,40 @@ void PipelineTrainer::CopyParameters(int microbatch_id,
 //   }
 // }
 
-void PipelineTrainer::GetSkipVars(const ProgramDesc& program) {
-  auto& global_block = program.Block(0);
-  for (auto& op : global_block.AllOps()) {
-    if (op->Type() != "c_send") {
-      continue;
-    }
-    auto input_arg_names = op->InputArgumentNames();
-    PADDLE_ENFORCE_EQ(input_arg_names.size(), 1,
-                      platform::errors::InvalidArgument(
-                          "Number of input arguments for c_send op must be 1, "
-                          "but the value given is %d.",
-                          input_arg_names.size()));
-    std::string input_arg_name = input_arg_names[0];
-    if (input_arg_name.rfind("@GRAD") != input_arg_name.size() - 5) {
-      skip_vars_.emplace_back(input_arg_name);
-      VLOG(3) << "add skip var name: " << input_arg_name;
-    }
-  }
-}
+// void PipelineTrainer::GetSkipVars(const ProgramDesc& program) {
+//   auto& global_block = program.Block(0);
+//   for (auto& op : global_block.AllOps()) {
+//     if (op->Type() != "c_send") {
+//       continue;
+//     }
+//     auto input_arg_names = op->InputArgumentNames();
+//     PADDLE_ENFORCE_EQ(input_arg_names.size(), 1,
+//                       platform::errors::InvalidArgument(
+//                           "Number of input arguments for c_send op must be 1,
+//                           "
+//                           "but the value given is %d.",
+//                           input_arg_names.size()));
+//     std::string input_arg_name = input_arg_names[0];
+//     if (input_arg_name.rfind("@GRAD") != input_arg_name.size() - 5) {
+//       skip_vars_.emplace_back(input_arg_name);
+//       VLOG(3) << "add skip var name: " << input_arg_name;
+//     }
+//   }
+// }
 
 void PipelineTrainer::InitTrainerEnv(const ProgramDesc& main_program,
                                      const platform::Place& place) {
   PADDLE_ENFORCE_NOT_NULL(root_scope_, platform::errors::InvalidArgument(
                                            "root_scope_ can not be nullptr"));
-  auto start_cpu_id = trainer_desc_.section_param().start_cpu_core_id();
-  SectionWorker::cpu_id_.store(start_cpu_id);
+  // auto start_cpu_id = trainer_desc_.section_param().start_cpu_core_id();
+  // SectionWorker::cpu_id_.store(start_cpu_id);
   // minibatch_scopes_.resize(section_num_);
   // microbatch_scopes_.resize(section_num_);
   // minibatch_scopes_.resize(1);
   microbatch_scopes_.resize(num_microbatches_);
   // skip_vars_.resize(section_num_);
-  VLOG(3) << "Init ScopeQueues and create all scopes";
+  VLOG(3) << "Create minibatch and microbatch scopes...";
   // for (int i = 0; i < section_num_; ++i) {
   minibatch_scope_ = &root_scope_->NewScope();
   std::shared_ptr<framework::ProgramDesc> program;
@@ -282,7 +284,7 @@ void PipelineTrainer::InitTrainerEnv(const ProgramDesc& main_program,
     CopyParameters(j, *program, place_);
   }
   // GetSkipVars(i, *program);
-  GetSkipVars(*program);
+  // GetSkipVars(*program);
   // }
   // for (int i = 0; i < section_num_; ++i) {
......
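Note: after this change `InitTrainerEnv` builds one minibatch scope under the root scope plus `num_microbatches_` microbatch scopes, and `CopyParameters` creates only non-persistable variables per microbatch. A rough mock of the resulting hierarchy, in plain Python rather than Paddle API:

```python
# Rough mock of the scope tree InitTrainerEnv creates: persistable
# parameters are shared via the minibatch scope, non-persistable
# variables live in one scope per microbatch.
num_microbatches = 4
scope_tree = {
    "root_scope": {
        "minibatch_scope": {
            "microbatch_scopes": [
                "non-persistable vars of microbatch %d" % j
                for j in range(num_microbatches)
            ]
        }
    }
}
assert len(scope_tree["root_scope"]["minibatch_scope"]["microbatch_scopes"]) == 4
```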
@@ -30,7 +30,7 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
-std::atomic<int> SectionWorker::cpu_id_(0);
+// std::atomic<int> SectionWorker::cpu_id_(0);
 // std::mutex SectionWorker::thread_mutex;
 // std::mutex SectionWorker::cout_mutex;
 // std::condition_variable SectionWorker::thread_condition;
@@ -48,18 +48,20 @@ void SectionWorker::Initialize(const TrainerDesc& desc) {
 }
 
 void SectionWorker::AutoSetCPUAffinity(bool reuse) {
-  int thread_cpu_id = cpu_id_.fetch_add(1);
+  // int thread_cpu_id = cpu_id_.fetch_add(1);
   unsigned concurrency_cap = std::thread::hardware_concurrency();
-  unsigned proc = thread_cpu_id;
+  // unsigned proc = thread_cpu_id;
+  unsigned proc = cpu_id_;
 
   if (proc >= concurrency_cap) {
     if (reuse) {
       proc %= concurrency_cap;
     } else {
       LOG(INFO) << "All " << concurrency_cap
-                << " CPUs have been set affinities. Fail to set "
-                << thread_cpu_id << "th thread";
+                << " CPUs have been set affinities. Fail to set " << cpu_id_
+                << "th thread.";
+      // << thread_cpu_id << "th thread";
       return;
     }
   }
@@ -78,7 +80,8 @@ void SectionWorker::AutoSetCPUAffinity(bool reuse) {
       (0 == CPU_ISSET(proc, &mask))) {
     LOG(WARNING) << "Fail to set thread affinity to CPU " << proc;
   }
-  VLOG(3) << "Set " << thread_cpu_id << "th thread affinity to CPU " << proc;
+  // VLOG(3) << "Set " << thread_cpu_id << "th thread affinity to CPU " << proc;
+  VLOG(3) << "Set " << cpu_id_ << "th thread affinity to CPU " << proc;
 }
 
 void SectionWorker::TrainFiles() {
@@ -141,7 +144,8 @@ void SectionWorker::TrainFiles() {
         VLOG(3) << "thread completed.";
         // VLOG(3) << "called notify all";
         // thread_condition.notify_all();
-        VLOG(0) << "EOF encountered";
+        VLOG(3) << "EOF encountered";
+        // throw platform::EOFException();
         break;
       }
     }
@@ -191,8 +195,8 @@ void SectionWorker::TrainFilesWithProfiler() {
   platform::Timer batch_timer;
   platform::Timer timeline;
 
-  std::vector<double> op_total_time;
   std::vector<std::string> op_name;
+  std::vector<double> op_total_time;
   std::vector<double> op_max_time;
   std::vector<double> op_min_time;
   std::vector<uint64_t> op_count;
@@ -204,6 +208,7 @@ void SectionWorker::TrainFilesWithProfiler() {
   op_min_time.resize(ops_.size());
   for (size_t i = 0; i < op_min_time.size(); ++i) {
     op_min_time[i] = DBL_MAX;
+    op_max_time[i] = 0.0;
   }
   op_count.resize(ops_.size());
@@ -235,7 +240,7 @@ void SectionWorker::TrainFilesWithProfiler() {
   struct timeval micro_end;
   // Start a minibatch.
   batch_timer.Start();
-  int real_microbatch_num = 0;
+  // int real_microbatch_num = 0;
   for (int i = 0; i < num_microbatches_; ++i) {
     try {
       int op_idx = 0;
@@ -253,8 +258,9 @@ void SectionWorker::TrainFilesWithProfiler() {
             op_role == (static_cast<int>(OpRole::kForward) |
                         static_cast<int>(OpRole::kLoss));
         if ((i == 0 && run_first_mbatch) || (i != 0 && run_others)) {
-          VLOG(3) << "running an op " << op->Type() << " for " << thread_id_
-                  << " for scope " << i;
+          // VLOG(3) << "running an op " << op->Type() << " for " << thread_id_
+          //         << " for scope " << i;
+          VLOG(3) << "running an op " << op->Type() << " for scope " << i;
           timeline.Start();
           op->Run(*microbatch_scopes_[i], place_);
           if (gc) {
@@ -365,11 +371,11 @@ void SectionWorker::TrainFilesWithProfiler() {
     }
   }
   dev_ctx_->Wait();
-  if (real_microbatch_num == 0) {
-    batch_timer.Pause();
-    VLOG(0) << "batch time: " << batch_timer.ElapsedUS();
-    return;
-  }
+  // if (real_microbatch_num == 0) {
+  //   batch_timer.Pause();
+  //   VLOG(0) << "batch time: " << batch_timer.ElapsedUS();
+  //   return;
+  // }
   // update pass
   int op_idx = 0;
   gettimeofday(&micro_start, NULL);
......
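Note: the profiler hunk adds an explicit `op_max_time[i] = 0.0` next to the `DBL_MAX` initialization of `op_min_time`, so both extrema start from neutral values before the first sample. A minimal sketch of the accumulation the profiler performs (values in microseconds; variable names are illustrative):

```python
# Why min starts at +inf and max at 0: the first timing sample must win
# both comparisons, after which total/count give the mean per-op cost.
samples_us = [120.0, 95.0, 140.0]
op_min, op_max, op_total = float("inf"), 0.0, 0.0
for t in samples_us:
    op_min = min(op_min, t)
    op_max = max(op_max, t)
    op_total += t
assert (op_min, op_max) == (95.0, 140.0)
assert abs(op_total / len(samples_us) - 118.33) < 0.01
```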
@@ -84,7 +84,7 @@ message DownpourWorkerParameter {
 }
 
 message SectionWorkerParameter {
-  SectionConfig section_config = 1;
+  optional SectionConfig section_config = 1;
   optional int32 queue_size = 2 [ default = 1 ];
   optional int64 sync_steps = 3 [ default = 1 ];
   optional int32 start_cpu_core_id = 4 [ default = 1 ];
......
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/collective/c_recv_op.h"
+#include <string>
 
 namespace paddle {
 namespace operators {
@@ -33,14 +34,36 @@ class CRecvOp : public framework::OperatorWithKernel {
         ring_id, 0,
         platform::errors::InvalidArgument(
             "The ring_id (%d) for c_send_op must be non-negative.", ring_id));
+    auto out_shape = ctx->Attrs().Get<std::vector<int>>("out_shape");
+    PADDLE_ENFORCE_GE(out_shape.size(), 1,
+                      platform::errors::InvalidArgument(
+                          "The size of the output shape must be greater than 0 "
+                          "but the value given is %d.",
+                          out_shape.size()));
   }
 
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    auto out = ctx.Output<framework::LoDTensor>("Out");
-    auto dtype = out->type();
-    return framework::OpKernelType(dtype, ctx.GetPlace());
+    VLOG(0) << "wow1";
+    std::string dtype = ctx.Attr<std::string>("dtype");
+    framework::proto::VarType::Type type;
+    if (dtype == "fp32") {
+      type = framework::proto::VarType::FP32;
+    } else if (dtype == "fp64") {
+      type = framework::proto::VarType::FP64;
+    } else if (dtype == "fp16") {
+      type = framework::proto::VarType::FP16;
+    } else if (dtype == "int32") {
+      type = framework::proto::VarType::INT32;
+    } else if (dtype == "int64") {
+      type = framework::proto::VarType::INT64;
+    } else {
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "Unknown data type %s for c_recv op.", dtype));
+    }
+    VLOG(0) << "wow2";
+    return framework::OpKernelType(type, ctx.GetPlace());
     // OperatorWithKernel::IndicateVarDataType(ctx, "Out"), ctx.GetPlace());
   }
 };
@@ -52,6 +75,11 @@ class CRecvOpMaker : public framework::OpProtoAndCheckerMaker {
     AddAttr<int>("ring_id", "(int default 0) nccl communication ring id.")
         .SetDefault(0);
     AddAttr<int>("peer", "(int default 0) rank id for sender.").SetDefault(0);
+    AddAttr<std::string>("dtype",
+                         "(std::string default fp32) data type of tensor.")
+        .SetDefault("fp32");
+    AddAttr<std::vector<int>>("out_shape", "shape of the output tensor.")
+        .SetDefault(std::vector<int>());
     AddAttr<bool>(
         "use_calc_stream",
         "(bool default false) eject CUDA operations to calculation stream.")
......
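Note: `GetExpectedKernelType` no longer inspects the possibly uninitialized output tensor; the kernel type now comes from the new `dtype` string attribute. A minimal Python model of that dispatch (the string keys match the `AddAttr` defaults above, and the enum names mirror `framework::proto::VarType`):

```python
# Minimal model of the dtype-string dispatch in CRecvOp::GetExpectedKernelType.
DTYPE_TO_VARTYPE = {
    "fp16": "FP16",
    "fp32": "FP32",  # the AddAttr default
    "fp64": "FP64",
    "int32": "INT32",
    "int64": "INT64",
}

def expected_kernel_vartype(dtype_attr):
    if dtype_attr not in DTYPE_TO_VARTYPE:
        raise ValueError("Unknown data type %s for c_recv op." % dtype_attr)
    return DTYPE_TO_VARTYPE[dtype_attr]

assert expected_kernel_vartype("fp32") == "FP32"
```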
@@ -27,13 +27,20 @@ class CRecvOpCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
 #if defined(PADDLE_WITH_NCCL)
+    VLOG(0) << "here1";
     auto out = ctx.Output<framework::LoDTensor>("Out");
-    int numel = out->numel();
-    ncclDataType_t dtype = platform::ToNCCLDataType(out->type());
+    VLOG(0) << "here2";
+    auto out_shape = ctx.Attr<std::vector<int>>("out_shape");
+    auto out_dims = paddle::framework::make_ddim(out_shape);
     int rid = ctx.Attr<int>("ring_id");
     auto place = ctx.GetPlace();
     auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
+    out->mutable_data<T>(out_dims, place);
+    VLOG(0) << "out_dims:" << out_dims;
+    ncclDataType_t dtype = platform::ToNCCLDataType(out->type());
+    int numel = out->numel();
+    VLOG(0) << "numel:" << numel;
 
     cudaStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
@@ -49,9 +56,10 @@ class CRecvOpCUDAKernel : public framework::OpKernel<T> {
         platform::errors::InvalidArgument("The value of peer (%d) you set must "
                                           "be less than comm->nranks (%d).",
                                           peer, comm->nranks()));
+    VLOG(0) << "here3";
     PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclRecv(
         out->data<T>(), numel, dtype, peer, comm->comm(), stream));
-    VLOG(3) << "rank " << comm->rank() << " recv "
+    VLOG(0) << "rank " << comm->rank() << " recv "
             << framework::product(out->dims()) << " from " << peer;
 #else
     PADDLE_THROW(
......
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 from __future__ import print_function
+from __future__ import division
 
 import paddle.fluid as fluid
 from paddle.fluid import core, unique_name
@@ -21,9 +22,50 @@ from .meta_optimizer_base import MetaOptimizerBase
 from .common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY, CollectiveHelper, is_update_op, is_loss_grad_op, is_backward_op, is_optimizer_op
 
-class PipelineHelper(CollectiveHelper):
-    def __init__(self, role_maker, nrings=1, wait_port='6174'):
-        super(PipelineHelper, self).__init__(role_maker, nrings, wait_port)
+def _get_node_num(endpoints):
+    ss = set()
+    for ep in endpoints:
+        ip = ep.split(":")[0].strip()
+        if ip not in ss:
+            ss.add(ip)
+    return len(ss)
+
+
+class PipelineHelper(object):
+    def __init__(self, role_maker, wait_port='6174'):
+        self.wait_port = wait_port
+        self.role_maker = role_maker
+
+    def update_startup_program(self, startup_program=None):
+        self.startup_program = startup_program
+        if startup_program is None:
+            self.startup_program = fluid.default_startup_program()
+
+        endpoints = self.role_maker.get_trainer_endpoints()
+        current_endpoint = endpoints[self.role_maker.worker_index()]
+        node_num = _get_node_num(endpoints)
+        assert len(endpoints) % node_num == 0
+        gpus_per_node = len(endpoints) // node_num
+
+        # Create a global ring for all gpus
+        print("current_endpoint:", current_endpoint)
+        print("endpoints:", endpoints)
+        print("rank:", self.role_maker.worker_index())
+        self._init_communicator(
+            self.startup_program, current_endpoint, endpoints,
+            self.role_maker.worker_index(), 0, self.wait_port)
+        if node_num == 1: return
+
+        # Create rings for gpus with the same gpu id
+        eps = []
+        local_rank = self.role_maker.worker_index() % gpus_per_node
+        ring_id = local_rank + 1
+        for i in range(node_num):
+            eps.append(endpoints[i * gpus_per_node + local_rank])
+        temp_rank = self.role_maker.worker_index() // node_num
+        self._init_communicator(self.startup_program, current_endpoint, eps,
+                                temp_rank, ring_id, self.wait_port)
+        self._broadcast_params(ring_id)
 
     def _init_communicator(self, program, current_endpoint, endpoints, rank,
                            ring_id, wait_port):
@@ -46,9 +88,8 @@ class PipelineHelper(CollectiveHelper):
                     'rank': rank,
                     'endpoint': current_endpoint,
                     'other_endpoints': other_endpoints,
-                    OP_ROLE_KEY: OpRole.Forward
+                    OP_ROLE_KEY: OpRole.Forward,
                 })
-
             block.append_op(
                 type='c_comm_init',
                 inputs={'X': nccl_id_var},
@@ -58,12 +99,10 @@ class PipelineHelper(CollectiveHelper):
                     'rank': rank,
                     'ring_id': ring_id,
                     OP_ROLE_KEY: OpRole.Forward,
-                    'device_id': OpRole.Forward
                 })
 
-    def _broadcast_params(self):
+    def _broadcast_params(self, ring_id):
         block = self.startup_program.global_block()
-        ring_id = 0
         for param in block.iter_parameters():
             if param.is_distributed:
                 continue
@@ -78,13 +117,12 @@ class PipelineHelper(CollectiveHelper):
                     OP_ROLE_KEY: OpRole.Forward
                 })
 
-        for ring_id in range(self.nrings):
-            block.append_op(
-                type='c_sync_comm_stream',
-                inputs={'X': param},
-                outputs={'Out': param},
-                attrs={'ring_id': ring_id,
-                       OP_ROLE_KEY: OpRole.Forward})
+        block.append_op(
+            type='c_sync_comm_stream',
+            inputs={'X': param},
+            outputs={'Out': param},
+            attrs={'ring_id': ring_id,
+                   OP_ROLE_KEY: OpRole.Forward})
 
 
 class PipelineOptimizer(MetaOptimizerBase):
@@ -100,7 +138,12 @@ class PipelineOptimizer(MetaOptimizerBase):
         super(PipelineOptimizer, self)._set_basic_info(
             loss, role_maker, user_defined_optimizer, user_defined_strategy)
         num_microbatches = user_defined_strategy.pipeline_configs['micro_batch']
-        self.wrapped_opt = PO(self.inner_opt, num_microbatches=num_microbatches)
+        endpoints = role_maker.get_trainer_endpoints()
+        current_endpoint = endpoints[role_maker.worker_index()]
+        self.local_rank = self._get_local_rank(current_endpoint, endpoints)
+        self.wrapped_opt = PO(self.inner_opt,
+                              num_microbatches=num_microbatches,
+                              start_cpu_core_id=self.local_rank)
 
     def _can_apply(self):
         if self.user_defined_strategy.pipeline == True:
@@ -111,23 +154,37 @@ class PipelineOptimizer(MetaOptimizerBase):
         dist_strategy.pipeline = False
         dist_strategy.pipeline_configs = {}
 
+    def _get_local_rank(self, current_endpoint, endpoints):
+        cur_node_endpoints = []
+        cur_ip = current_endpoint.split(':')[0].strip()
+        for ep in endpoints:
+            if cur_ip == ep.split(':')[0].strip():
+                cur_node_endpoints.append(ep)
+        return cur_node_endpoints.index(current_endpoint)
+
     def minimize_impl(self,
                       loss,
                       startup_program=None,
                       parameter_list=None,
                       no_grad_set=None):
-        optimize_ops, params_grads, prog_list = \
-            self.wrapped_opt.minimize(loss, startup_program,
-                                      parameter_list, no_grad_set)
-        if self.role_maker.worker_num() == 1:
-            return optimize_ops, params_grads
         endpoints = self.role_maker.get_trainer_endpoints()
         current_endpoint = endpoints[self.role_maker.worker_index()]
+        node_num = _get_node_num(endpoints)
+        gpus_per_node = len(endpoints) // node_num
         self.startup_program = startup_program
+        self.local_rank = self._get_local_rank(current_endpoint, endpoints)
         if startup_program is None:
             self.startup_program = fluid.default_startup_program()
+        if self.role_maker.worker_num() == 1:
+            return self.inner_opt.minimize(loss, startup_program,
+                                           parameter_list, no_grad_set)
+
+        loss.block.program._pipeline_opt = dict()
+        loss.block.program._pipeline_opt['local_rank'] = self.local_rank
+        optimize_ops, params_grads, prog_list = \
+            self.wrapped_opt.minimize(loss, startup_program,
+                                      parameter_list, no_grad_set)
         assert prog_list
         self.main_program_list = prog_list
         self.main_program = loss.block.program
@@ -139,24 +196,24 @@ class PipelineOptimizer(MetaOptimizerBase):
         self.endpoints = endpoints
         self.current_endpoint = current_endpoint
 
-        pipeline_helper = PipelineHelper(self.role_maker, nrings=self.nrings)
+        pipeline_helper = PipelineHelper(self.role_maker)
         pipeline_helper.update_startup_program(self.startup_program)
 
-        self._transpile_main_program()
+        self._transpile_main_program(loss, node_num, gpus_per_node)
         return optimize_ops, params_grads
 
-    def _transpile_main_program(self):
-        self._insert_loss_grad_ops()
-        for ring_id in range(self.nrings):
+    def _transpile_main_program(self, loss, node_num, gpus_per_node):
+        self._insert_loss_grad_ops(loss, gpus_per_node, node_num)
+        for ring_id in range(1, node_num + 1):
             self._insert_allreduce_ops(ring_id)
 
-    def _insert_loss_grad_ops(self):
+    def _insert_loss_grad_ops(self, loss, gpus_per_node, node_num):
         """
         In order to keep the learning rate consistent in different numbers of
        training workers, we scale the loss grad by the number of workers
         """
-        block = self.main_program_list[self.nrings - 1]['program'].global_block(
-        )
+        block = self.main_program_list[gpus_per_node - 1][
+            'program'].global_block()
         for idx, op in reversed(list(enumerate(block.ops))):
             if is_loss_grad_op(op):
                 loss_grad_var = block.vars[op.output_arg_names[0]]
@@ -166,12 +223,12 @@ class PipelineOptimizer(MetaOptimizerBase):
                     inputs={'X': loss_grad_var},
                     outputs={'Out': loss_grad_var},
                     attrs={
-                        'scale': 1.0 / self.nranks,
+                        'scale': 1.0 / node_num,
                         OP_ROLE_KEY: OpRole.Backward
                     })
 
     def _insert_allreduce_ops(self, ring_id):
-        block = self.main_program_list[ring_id]['program'].global_block()
+        block = self.main_program_list[ring_id - 1]['program'].global_block()
         origin_block = self.main_program.global_block()
         grad = None
         for idx, op in reversed(list(enumerate(block.ops))):
......
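Note: a worked example of the ring layout the new `PipelineHelper.update_startup_program` builds, reusing the diff's own `_get_node_num`. Ring 0 spans all GPUs; ring `local_rank + 1` connects the GPUs with the same local index across nodes. The endpoint addresses are illustrative and assume fleet lists endpoints grouped by node:

```python
def _get_node_num(endpoints):
    ss = set()
    for ep in endpoints:
        ip = ep.split(":")[0].strip()
        if ip not in ss:
            ss.add(ip)
    return len(ss)

# 2 nodes x 2 GPUs.
endpoints = ["10.0.0.1:6070", "10.0.0.1:6071",
             "10.0.0.2:6070", "10.0.0.2:6071"]
node_num = _get_node_num(endpoints)          # 2 distinct IPs
gpus_per_node = len(endpoints) // node_num   # 2
for rank in range(len(endpoints)):
    local_rank = rank % gpus_per_node
    ring_id = local_rank + 1                 # ring 0 is the global ring
    peers = [endpoints[i * gpus_per_node + local_rank]
             for i in range(node_num)]
    # rank 0 and rank 2 share ring 1: ['10.0.0.1:6070', '10.0.0.2:6070']
assert node_num == 2 and gpus_per_node == 2
```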
@@ -406,25 +406,44 @@ class Section(DeviceWorker):
         section_param = trainer_desc.section_param
         section_param.num_microbatches = pipeline_opt["num_microbatches"]
         section_param.start_cpu_core_id = pipeline_opt["start_cpu_core_id"]
-        for i, program in enumerate(pipeline_opt["section_program_list"]):
-            cfg = section_param.section_config.add()
-            cfg.program_desc.ParseFromString(program["program"]._get_desc()
-                                             .serialize_to_string())
-            # TODO: why does not work
-            # cfg.program_desc.CopyFrom(program.program._get_desc())
-            place = pipeline_opt["place_list"][i]
-            place_id = pipeline_opt["place_id_list"][i]
-            if isinstance(place, core.CPUPlace):
-                cfg.place = cfg.CPUPlace
-            elif isinstance(place, core.CUDAPlace):
-                cfg.place = cfg.CUDAPlace
-            elif isinstance(place, core.CUDAPinnedPlace):
-                cfg.place = cfg.CUDAPinnedPlace
-            else:
-                raise NotImplementedError(
-                    "SectionWorker only supports CPUPlace, CUDAPlace and CUDAPinnedPlace now."
-                )
-            cfg.place_id = place_id
+        cfg = section_param.section_config
+        program = pipeline_opt["section_program"]
+        cfg.program_desc.ParseFromString(program["program"]._get_desc()
+                                         .serialize_to_string())
+        # TODO: why does not work
+        # cfg.program_desc.CopyFrom(program.program._get_desc())
+        place = pipeline_opt["place"]
+        place_id = pipeline_opt["place_id"]
+        if isinstance(place, core.CPUPlace):
+            cfg.place = cfg.CPUPlace
+        elif isinstance(place, core.CUDAPlace):
+            cfg.place = cfg.CUDAPlace
+        elif isinstance(place, core.CUDAPinnedPlace):
+            cfg.place = cfg.CUDAPinnedPlace
+        else:
+            raise NotImplementedError(
+                "SectionWorker only supports CPUPlace, CUDAPlace and CUDAPinnedPlace now."
+            )
+        cfg.place_id = place_id
+        # for i, program in enumerate(pipeline_opt["section_program_list"]):
+        #     cfg = section_param.section_config.add()
+        #     cfg.program_desc.ParseFromString(program["program"]._get_desc()
+        #                                      .serialize_to_string())
+        #     # TODO: why does not work
+        #     # cfg.program_desc.CopyFrom(program.program._get_desc())
+        #     place = pipeline_opt["place_list"][i]
+        #     place_id = pipeline_opt["place_id_list"][i]
+        #     if isinstance(place, core.CPUPlace):
+        #         cfg.place = cfg.CPUPlace
+        #     elif isinstance(place, core.CUDAPlace):
+        #         cfg.place = cfg.CUDAPlace
+        #     elif isinstance(place, core.CUDAPinnedPlace):
+        #         cfg.place = cfg.CUDAPinnedPlace
+        #     else:
+        #         raise NotImplementedError(
+        #             "SectionWorker only supports CPUPlace, CUDAPlace and CUDAPinnedPlace now."
+        #         )
+        #     cfg.place_id = place_id
 
 
 class DeviceWorkerFactory(object):
......
@@ -3818,6 +3818,24 @@ class PipelineOptimizer(object):
         return programs
 
+    def _split_startup_program(self, startup_program, local_rank):
+        block = startup_program.block(0)
+        new_startup_program = Program()
+        for op in block.ops:
+            device = op.attr(self._op_device_key)
+            if device:
+                device_index = int(device.split(":")[1])
+            else:
+                device_index = 0
+            if device_index != local_rank: continue
+            op_role = op.attr(self._op_role_key)
+            op_desc = op.desc
+            ap_op = new_startup_program.block(0).desc.append_op()
+            ap_op.copy_from(op_desc)
+            ap_op._set_attr(self._op_device_key, device)
+        self._create_vars(new_startup_program.block(0), startup_program)
+        return new_startup_program
+
     def _find_post_op(self, ops, cur_op, var_name):
         """
         Find the real post op that has variable named var_name as input.
@@ -3933,6 +3951,7 @@ class PipelineOptimizer(object):
             if op.type == "read":
                 break
         first_dev_spec = devices[0]
+        first_dev_index = int(first_dev_spec.split(':')[1])
         for var_name in data_devices_map.keys():
             for device in data_devices_map[var_name]:
                 if device == first_dev_spec: continue
@@ -3940,13 +3959,15 @@ class PipelineOptimizer(object):
                 assert main_var.is_data
                 if not var_name in first_block.vars:
                     self._create_var(first_block, main_var, var_name)
+                dev_index = int(device.split(':')[1])
                 first_block._insert_op(
                     index=insert_index,
                     type='c_send',
                     inputs={'X': first_block.var(var_name)},
                     attrs={
                         self._op_device_key: first_dev_spec,
-                        self._op_role_key: self._op_role.Forward
+                        self._op_role_key: self._op_role.Forward,
+                        'peer': dev_index
                     })
                 # Get the device that that data on
                 assert device in devices
@@ -3961,8 +3982,10 @@ class PipelineOptimizer(object):
                     type='c_recv',
                     outputs={'Out': [new_var]},
                     attrs={
+                        'out_shape': new_var.shape,
                         self._op_device_key: device,
                         self._op_role_key: self._op_role.Forward,
+                        'peer': first_dev_index
                     })
 
     def _strip_grad_suffix(self, name):
@@ -4105,13 +4128,16 @@ class PipelineOptimizer(object):
             op_role = op.all_attrs()[self._op_role_key]
             var = block.vars[var_name]
+            prev_device_index = int(prev_device_spec.split(':')[1])
+            cur_device_index = int(cur_device_spec.split(':')[1])
             block._insert_op(
                 index=index + extra_index,
                 type='c_send',
                 inputs={'X': var},
                 attrs={
                     self._op_device_key: prev_device_spec,
-                    self._op_role_key: op_role
+                    self._op_role_key: op_role,
+                    'peer': prev_device_index
                 })
             extra_index += 1
             block._insert_op(
@@ -4119,8 +4145,10 @@ class PipelineOptimizer(object):
                 type='c_recv',
                 outputs={'Out': [var]},
                 attrs={
+                    'out_shape': var.shape,
                     self._op_device_key: cur_device_spec,
-                    self._op_role_key: op_role
+                    self._op_role_key: op_role,
+                    'peer': cur_device_index
                 })
             extra_index += 1
@@ -4271,9 +4299,13 @@ class PipelineOptimizer(object):
             write_prog = write_info[var_name]
             write_block = write_prog.block(0)
             write_device = self._get_device_info(write_block)
+            write_dev_index = int(write_device.split(':')[1])
             all_progs = var_info[var_name]
             for prog in all_progs:
                 if prog == write_prog: continue
+                read_block = prog.block(0)
+                read_device = self._get_device_info(read_block)
+                read_dev_index = int(read_device.split(':')[1])
 
                 write_block._insert_op(
                     index=0,
@@ -4283,19 +4315,20 @@ class PipelineOptimizer(object):
                         self._op_device_key: write_device,
                         # A trick to make the role LRSched to avoid copy every
                         # microbatch
-                        self._op_role_key: self._op_role.LRSched
+                        self._op_role_key: self._op_role.LRSched,
+                        'peer': read_dev_index
                     })
-                read_block = prog.block(0)
-                read_device = self._get_device_info(read_block)
                 read_block._insert_op(
                     index=0,
                     type='c_recv',
                     outputs={'Out': [read_block.var(var_name)]},
                     attrs={
+                        'out_shape': read_block.var(var_name).shape,
                         self._op_device_key: read_device,
                         # A trick to make the role LRSched to avoid copy every
                         # microbatch
                         self._op_role_key: self._op_role.LRSched,
+                        'peer': write_dev_index
                     })
 
     def minimize(self,
@@ -4363,12 +4396,25 @@ class PipelineOptimizer(object):
         # Step7: Add sub blocks for section programs
         self._add_sub_blocks(main_block, program_list)
 
+        assert (main_program._pipeline_opt and
+                isinstance(main_program._pipeline_opt, dict) and
+                'local_rank' in main_program._pipeline_opt), \
+            "You must use pipeline with fleet"
+        local_rank = main_program._pipeline_opt['local_rank']
+
+        # Step8: Split startup program
+        startup_program = self._split_startup_program(
+            startup_program, program_list[local_rank]['program'])
+
+        with open("startup_prog_%d" % local_rank, 'w') as f:
+            f.writelines(str(startup_program))
+        with open("main_prog_%d" % local_rank, 'w') as f:
+            f.writelines(str(program_list[local_rank]['program']))
+
         main_program._pipeline_opt = {
             "trainer": "PipelineTrainer",
             "device_worker": "Section",
-            "section_program_list": program_list,
-            "place_list": place_list,
-            "place_id_list": place_id_list,
+            "section_program": program_list[local_rank],
+            "place": place_list[local_rank],
+            "place_id": place_id_list[local_rank],
             "sync_steps": -1,
             "num_microbatches": self._num_microbatches,
             "start_cpu_core_id": self._start_cpu_core_id,
......
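Note: the new `_split_startup_program` keeps only the startup ops whose `op_device` attribute targets the local rank, with an empty attribute counting as device 0. A self-contained Python model of that per-op filter; the op names and device strings here are illustrative:

```python
# Model of _split_startup_program's per-op filter.
def keep_op(device_attr, local_rank):
    device_index = int(device_attr.split(":")[1]) if device_attr else 0
    return device_index == local_rank

startup_ops = [("fill_constant", "gpu:0"),
               ("fill_constant", "gpu:1"),
               ("uniform_random", "")]  # no device attr -> device 0
assert [n for n, d in startup_ops if keep_op(d, 0)] == \
    ["fill_constant", "uniform_random"]
assert [n for n, d in startup_ops if keep_op(d, 1)] == ["fill_constant"]
```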