Unverified commit 228bce12, authored by lilong12, committed by GitHub

Add 3d parallelism (#31796)

Add 3d Parallelism
Co-authored-by: WangXi <wangxi16@baidu.com>
Co-authored-by: JZ-LIANG <jianzhongliang10@gmail.com>
Co-authored-by: root <root@yq01-sys-hic-k8s-v100-box-a225-0562.yq01.baidu.com>
Parent 594bbcb1
...@@ -28,6 +28,7 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/heter_service.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
...@@ -451,7 +452,7 @@ class HeterBoxWorker : public HogwildWorker {
  virtual void CacheProgram(const ProgramDesc& main_program) {
    new (&program_) ProgramDesc(main_program);
  }
  void ProduceTasks() override;
  virtual void SetStream(const cudaStream_t stream) { copy_stream_ = stream; }
  virtual void SetEvent(const cudaEvent_t event) { event_ = event; }
  virtual void TrainFilesWithProfiler() {}
...@@ -550,7 +551,7 @@ class PSGPUWorker : public HogwildWorker {
  virtual void CacheProgram(const ProgramDesc& main_program) {
    new (&program_) ProgramDesc(main_program);
  }
  void ProduceTasks() override;
  virtual void SetStream(const cudaStream_t stream) { copy_stream_ = stream; }
  virtual void SetEvent(const cudaEvent_t event) { event_ = event; }
  virtual void TrainFilesWithProfiler() {}
...@@ -654,6 +655,9 @@ class SectionWorker : public DeviceWorker {
  void SetDeviceIndex(int tid) override {}
  void SetThreadIndex(int thread_id) { thread_id_ = thread_id; }
  void SetMicrobatchNum(int num) { num_microbatches_ = num; }
  void SetPipelineStageNum(int num) { num_pipeline_stages_ = num; }
  void SetPipelineStage(int stage) { pipeline_stage_ = stage; }
  void SetScheduleMode(int mode) { schedule_mode_ = mode; }
  void SetMicrobatchScopes(const std::vector<Scope*>& scope) {
    microbatch_scopes_ = scope;
  }
...@@ -661,11 +665,23 @@ class SectionWorker : public DeviceWorker {
  void SetSkipVars(const std::vector<std::string>& skip_vars) {
    skip_vars_ = skip_vars;
  }
  void RunBackward(
      int micro_id, std::unique_ptr<GarbageCollector>&,
      std::unordered_map<const OperatorBase*, std::vector<std::string>>&);
  void RunForward(
      int micro_id, std::unique_ptr<GarbageCollector>&,
      std::unordered_map<const OperatorBase*, std::vector<std::string>>&);
  void RunUpdate(
      std::unique_ptr<GarbageCollector>&,
      std::unordered_map<const OperatorBase*, std::vector<std::string>>&);

 protected:
  int section_id_;
  int thread_id_;
  int num_microbatches_;
  int num_pipeline_stages_;
  int pipeline_stage_;
  int schedule_mode_;  // 0 for F-then-B (GPipe-like), 1 for 1F1B (DeepSpeed-like)
  std::vector<Scope*> microbatch_scopes_;
  std::vector<std::string> skip_vars_;
  const Scope* minibatch_scope_;
......
...@@ -32,6 +32,14 @@ message ShardingConfig {
  optional float fuse_broadcast_MB = 1 [ default = 32.0 ];
  optional bool hybrid_dp = 2 [ default = false ];
  optional int32 sharding_group_size = 3 [ default = 8 ];
  optional bool as_outer_parallelism = 4 [ default = false ];
  optional int32 parallelism = 5 [ default = 1 ];
  optional bool use_pipeline = 6 [ default = false ];
  optional int32 acc_steps = 7 [ default = 1 ];
  optional int32 schedule_mode = 8 [ default = 0 ];
  optional int32 pp_bz = 9 [ default = 1 ];
  optional bool pp_allreduce_in_optimize = 10 [ default = false ];
  optional bool optimize_offload = 11 [ default = false ];
}
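As a rough illustration (not part of this commit), the new ShardingConfig fields are expected to be driven from Python through DistributedStrategy.sharding_configs; the key names below simply mirror the proto fields added above, and the values are made up, so the exact accepted keys should be checked against the released API:

import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.sharding = True
# Keys mirror the ShardingConfig proto fields above; values are illustrative only.
strategy.sharding_configs = {
    "sharding_group_size": 8,
    "as_outer_parallelism": True,    # nest another parallelism inside sharding
    "parallelism": 2,                # degree of the inner (model) parallelism
    "use_pipeline": True,            # combine sharding with pipeline stages
    "acc_steps": 4,                  # gradient accumulation steps
    "schedule_mode": 1,              # 0: F-then-B, 1: 1F1B
    "pp_bz": 1,                      # micro-batch size for pipeline
    "pp_allreduce_in_optimize": False,
    "optimize_offload": False,
}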
message AMPConfig {
...@@ -44,6 +52,8 @@ message AMPConfig {
  repeated string custom_white_list = 7;
  repeated string custom_black_list = 8;
  repeated string custom_black_varnames = 9;
  optional bool use_pure_fp16 = 10 [ default = false ];
  optional bool use_fp16_guard = 11 [ default = true ];
}
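For reference, a minimal sketch (assuming the Python-side amp_configs keys match these proto fields and the inner optimizer supports pure-fp16 training) of how the two new AMP options could be set through a DistributedStrategy:

import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.amp = True
strategy.amp_configs = {
    "init_loss_scaling": 32768,
    "use_pure_fp16": True,    # run the whole model in fp16 instead of mixed precision
    "use_fp16_guard": True,   # only cast ops wrapped by the fp16 guard when pure fp16 is on
}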
message LocalSGDConfig {
...@@ -117,6 +127,8 @@ message AsyncConfig {
message PipelineConfig { optional int32 micro_batch = 1 [ default = 1 ]; }

message ModelParallelConfig { optional int32 parallelism = 1 [ default = 1 ]; }

message DistributedStrategy {
  // bool options
  optional Mode mode = 1 [ default = COLLECTIVE ];
...@@ -140,12 +152,13 @@ message DistributedStrategy {
  optional int32 fuse_grad_size_in_MB = 19 [ default = 32 ];
  optional float fuse_grad_size_in_TFLOPS = 20 [ default = 50 ];
  optional bool cudnn_exhaustive_search = 21 [ default = true ];
  optional int32 conv_workspace_size_limit = 22 [ default = 512 ];
  optional bool cudnn_batchnorm_spatial_persistent = 23 [ default = true ];
  optional bool adaptive_localsgd = 24 [ default = false ];
  optional bool fp16_allreduce = 25 [ default = false ];
  optional bool sharding = 26 [ default = false ];
  optional float last_comm_group_size_MB = 27 [ default = 1 ];
  optional bool model_parallel = 28 [ default = false ];

  optional RecomputeConfig recompute_configs = 101;
  optional AMPConfig amp_configs = 102;
...@@ -158,6 +171,7 @@ message DistributedStrategy {
  optional LambConfig lamb_configs = 109;
  optional AdaptiveLocalSGDConfig adaptive_localsgd_configs = 110;
  optional ShardingConfig sharding_configs = 111;
  optional ModelParallelConfig model_parallel_configs = 112;

  optional BuildStrategy build_strategy = 201;
  optional ExecutionStrategy execution_strategy = 202;
}
......
...@@ -25,6 +25,9 @@ namespace framework {
void PipelineTrainer::Initialize(const TrainerDesc& trainer_desc,
                                 Dataset* dataset) {
  const auto& section_params = trainer_desc.section_param();
  const auto num_pipeline_stages_ = section_params.num_pipeline_stages();
  const auto pipeline_stage_ = section_params.pipeline_stage();
  const auto schedule_mode_ = section_params.schedule_mode();
  num_microbatches_ = section_params.num_microbatches();
  VLOG(3) << "Number of microbatches per minibatch: " << num_microbatches_;
  trainer_desc_ = trainer_desc;
...@@ -40,6 +43,9 @@ void PipelineTrainer::Initialize(const TrainerDesc& trainer_desc,
  this_worker->SetPlace(place_);
  this_worker->Initialize(trainer_desc);
  this_worker->SetMicrobatchNum(num_microbatches_);
  this_worker->SetPipelineStageNum(num_pipeline_stages_);
  this_worker->SetPipelineStage(pipeline_stage_);
  this_worker->SetScheduleMode(schedule_mode_);
}

void PipelineTrainer::InitOtherEnv(const ProgramDesc& main_program) {
...@@ -76,7 +82,10 @@ void PipelineTrainer::CopyParameters(int microbatch_id,
  for (auto& var : global_block.AllVars()) {
    bool is_param_grad = false;
    size_t pos = 0;
    // A magic suffix to indicate a merged gradient.
    std::string magicSuffix = "MERGED";
    if ((pos = var->Name().find(kGradVarSuffix)) != std::string::npos &&
        var->Name().find(magicSuffix) != std::string::npos) {
      auto prefix_name = var->Name().substr(0, pos);
      if (param_map.find(prefix_name) != param_map.end()) {
        is_param_grad = true;
......
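To make the renaming rule above concrete, here is a small standalone check (plain Python, not from the commit; the sample variable names are hypothetical) that mirrors which names the new condition treats as merged parameter gradients:

GRAD_SUFFIX = "@GRAD"      # kGradVarSuffix on the C++ side
MAGIC_SUFFIX = "MERGED"    # marker for accumulated (merged) gradients

def is_merged_param_grad(name):
    # Mirrors the condition in PipelineTrainer::CopyParameters: the name must
    # contain both the gradient suffix and the MERGED marker.
    return GRAD_SUFFIX in name and MAGIC_SUFFIX in name

print(is_merged_param_grad("fc_0.w_0@GRAD@MERGED"))  # True: copied per micro-batch scope
print(is_merged_param_grad("fc_0.w_0@GRAD"))         # False: plain gradient is skipped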
...@@ -11,36 +11,90 @@ limitations under the License. */
#if defined(PADDLE_WITH_NCCL)
#include <float.h>
#include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/platform/device_context.h"

namespace paddle {
namespace framework {

class TrainerDesc;

uint64_t SectionWorker::batch_id_(0);

void SectionWorker::Initialize(const TrainerDesc &desc) {
  dev_ctx_ = platform::DeviceContextPool::Instance().Get(place_);
  program_.reset(
      new ProgramDesc(desc.section_param().section_config().program_desc()));
  for (auto &op_desc : program_->Block(0).AllOps()) {
    ops_.push_back(OpRegistry::CreateOp(*op_desc));
  }
}
void SectionWorker::RunForward(
int micro_id, std::unique_ptr<GarbageCollector> &gc,
std::unordered_map<const OperatorBase *, std::vector<std::string>>
&unused_vars_) {
for (auto &op : ops_) {
int op_role = op->Attr<int>(std::string("op_role"));
// We run op with op_role = kLRSched only for the first microbatch
// to avoid increasing the @LR_DECAY_STEP@ multiple times.
bool run_first_mbatch = op_role == static_cast<int>(OpRole::kForward) ||
op_role == (static_cast<int>(OpRole::kForward) |
static_cast<int>(OpRole::kLoss)) ||
op_role == static_cast<int>(OpRole::kLRSched);
bool run_others = op_role == static_cast<int>(OpRole::kForward) ||
op_role == (static_cast<int>(OpRole::kForward) |
static_cast<int>(OpRole::kLoss));
if ((micro_id == 0 && run_first_mbatch) || (micro_id != 0 && run_others)) {
VLOG(3) << "Forward: running op " << op->Type() << " for micro-batch "
<< micro_id;
op->Run(*microbatch_scopes_[micro_id], place_);
if (gc) {
DeleteUnusedTensors(*microbatch_scopes_[micro_id], op.get(),
unused_vars_, gc.get());
}
}
}
}
void SectionWorker::RunBackward(
int micro_id, std::unique_ptr<GarbageCollector> &gc,
std::unordered_map<const OperatorBase *, std::vector<std::string>>
&unused_vars_) {
for (auto &op : ops_) {
int op_role = op->Attr<int>(std::string("op_role"));
if (op_role == static_cast<int>(OpRole::kBackward) ||
op_role == (static_cast<int>(OpRole::kBackward) |
static_cast<int>(OpRole::kLoss))) {
VLOG(3) << "Backward: running op " << op->Type() << " for micro-batch "
<< micro_id;
op->Run(*microbatch_scopes_[micro_id], place_);
if (gc) {
DeleteUnusedTensors(*microbatch_scopes_[micro_id], op.get(),
unused_vars_, gc.get());
}
}
}
}
void SectionWorker::RunUpdate(
std::unique_ptr<GarbageCollector> &gc,
std::unordered_map<const OperatorBase *, std::vector<std::string>>
&unused_vars_) {
for (auto &op : ops_) {
int op_role = op->Attr<int>(std::string("op_role"));
if (op_role == static_cast<int>(OpRole::kOptimize)) {
VLOG(3) << "Update: running op " << op->Type();
op->Run(*microbatch_scopes_[num_microbatches_ - 1], place_);
if (gc) {
DeleteUnusedTensors(*microbatch_scopes_[num_microbatches_ - 1],
op.get(), unused_vars_, gc.get());
}
}
}
}
void SectionWorker::TrainFiles() {
  VLOG(5) << "begin section_worker TrainFiles";
...@@ -58,61 +112,49 @@ void SectionWorker::TrainFiles() {
#endif
  }

  if (schedule_mode_ == 0) {
    // F-then-B (GPipe-like) scheduler: run all forwards first, then all
    // backwards, then the update.
    // step1: run forward
    for (int i = 0; i < num_microbatches_; ++i) {
      RunForward(i, gc, unused_vars_);
    }
    // step2: run backward
    for (int i = 0; i < num_microbatches_; ++i) {
      RunBackward(i, gc, unused_vars_);
    }
    // step3: run update
    RunUpdate(gc, unused_vars_);
  } else {
    // 1F1B scheduler
    auto startup_steps = num_pipeline_stages_ - pipeline_stage_ - 1;
    VLOG(3) << "startup_steps:" << startup_steps
            << ", num_stages: " << num_pipeline_stages_
            << ", stage:" << pipeline_stage_;
    if (startup_steps > num_microbatches_) {
      startup_steps = num_microbatches_;
    }
    int fw_step = 0;
    int bw_step = 0;
    // startup phase
    while (fw_step < startup_steps) {
      RunForward(fw_step, gc, unused_vars_);
      fw_step += 1;
    }
    // 1f1b phase
    while (fw_step < num_microbatches_) {
      RunForward(fw_step, gc, unused_vars_);
      fw_step += 1;
      RunBackward(bw_step, gc, unused_vars_);
      bw_step += 1;
    }
    // backward phase
    while (bw_step < num_microbatches_) {
      RunBackward(bw_step, gc, unused_vars_);
      bw_step += 1;
    }
    RunUpdate(gc, unused_vars_);
  }
  dev_ctx_->Wait();
  ++batch_id_;
......
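To illustrate the two schedules implemented above, here is a small standalone Python sketch (illustration only, not from the commit) that prints the order in which one pipeline stage processes its micro-batches; the variable names follow the C++ code:

def section_schedule(num_microbatches, num_pipeline_stages, pipeline_stage, schedule_mode):
    """Return the list of (phase, micro_batch_id) steps run by one stage."""
    steps = []
    if schedule_mode == 0:
        # F-then-B (GPipe-like): all forwards, then all backwards, then update.
        steps += [("forward", i) for i in range(num_microbatches)]
        steps += [("backward", i) for i in range(num_microbatches)]
    else:
        # 1F1B: warm up with a few forwards, then alternate forward/backward.
        startup_steps = min(num_pipeline_stages - pipeline_stage - 1, num_microbatches)
        fw, bw = 0, 0
        for _ in range(startup_steps):
            steps.append(("forward", fw)); fw += 1
        while fw < num_microbatches:
            steps.append(("forward", fw)); fw += 1
            steps.append(("backward", bw)); bw += 1
        while bw < num_microbatches:
            steps.append(("backward", bw)); bw += 1
    steps.append(("update", None))
    return steps

# Stage 0 of a 4-stage pipeline with 8 micro-batches under 1F1B:
print(section_schedule(8, 4, 0, schedule_mode=1))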
...@@ -93,6 +93,9 @@ message SectionWorkerParameter {
  optional int32 start_cpu_core_id = 4 [ default = 1 ];
  repeated string param_need_sync = 5;
  optional int32 num_microbatches = 6;
  optional int32 num_pipeline_stages = 7 [ default = 1 ];
  optional int32 pipeline_stage = 8 [ default = 1 ];
  optional int32 schedule_mode = 9 [ default = 0 ];
}

message SectionConfig {
......
...@@ -55,7 +55,8 @@ class CGenNCCLIdOp : public framework::OperatorBase {
      SendBroadCastNCCLID(endpoint_list, 1, func, local_scope);
    } else {
      std::string endpoint = Attr<std::string>("endpoint");
      int server_fd = platform::SocketServer::GetInstance(endpoint).socket();
      platform::RecvBroadCastCommID(server_fd, endpoint, &nccl_ids);
    }
    scope.DeleteScope(&local_scope);
  }
...@@ -71,8 +72,7 @@ class CGenNCCLIdOp : public framework::OperatorBase {
      : OperatorBase(type, inputs, outputs, attrs) {}
  void RunImpl(const framework::Scope& scope,
               const platform::Place& dev_place) const override {}
};
#endif
......
...@@ -31,7 +31,9 @@ limitations under the License. */
#include "paddle/fluid/string/split.h"

namespace paddle {
namespace platform {

std::once_flag SocketServer::init_flag_;

constexpr char COMM_HEAD[] = "_pd_gen_comm_id_";
...@@ -340,5 +342,34 @@ void RecvBroadCastNCCLID(int server_fd, std::string endpoint, int nccl_comm_num,
    CloseSocket(client);
  }
}

SocketServer& SocketServer::GetInstance(const std::string& end_point) {
  static SocketServer instance;
  std::call_once(init_flag_, [&]() {
    instance.server_fd_ = CreateListenSocket(end_point);
    instance.end_point_ = end_point;
  });
  PADDLE_ENFORCE_NE(instance.server_fd_, -1,
                    platform::errors::Unavailable(
                        "listen socket failed with end_point=%s", end_point));
  PADDLE_ENFORCE_EQ(instance.end_point_, end_point,
                    platform::errors::InvalidArgument(
                        "old end_point=%s must equal with new end_point=%s",
                        instance.end_point_, end_point));
  return instance;
}

/// template instantiation
#define INSTANT_TEMPLATE(Type)                                                \
  template void SendBroadCastCommID<Type>(std::vector<std::string> servers,  \
                                          std::vector<Type> * nccl_ids);     \
  template void RecvBroadCastCommID<Type>(std::string endpoint,              \
                                          std::vector<Type> * nccl_ids);

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
INSTANT_TEMPLATE(ncclUniqueId)
#endif
#ifdef PADDLE_WITH_XPU_BKCL
INSTANT_TEMPLATE(BKCLUniqueId)
#endif
}  // namespace platform
}  // namespace paddle
...@@ -15,6 +15,8 @@ limitations under the License. */
#pragma once

#include <functional>
#include <memory>
#include <mutex>  // NOLINT
#include <string>
#include <vector>
...@@ -25,7 +27,7 @@ class Scope;
}  // namespace paddle

namespace paddle {
namespace platform {

int CreateListenSocket(const std::string& ep);
...@@ -41,8 +43,26 @@ void RecvBroadCastNCCLID(std::string endpoint, int nccl_comm_num,
                         const framework::Scope& scope);

// recv nccl id from socket
template <typename CommUniqueId>
void RecvBroadCastCommID(int server_fd, std::string endpoint,
                         std::vector<CommUniqueId>* nccl_ids);

class SocketServer {
 public:
  SocketServer() = default;

  ~SocketServer() { CloseSocket(server_fd_); }

  int socket() const { return server_fd_; }

  static SocketServer& GetInstance(const std::string& end_point);

 private:
  int server_fd_{-1};
  std::string end_point_;

  static std::once_flag init_flag_;
};

}  // namespace platform
}  // namespace paddle
...@@ -736,6 +736,60 @@ class DistributedStrategy(object):
                          "sharding_configs")
        assign_configs_value(self.strategy.sharding_configs, configs)

    @property
    def model_parallel(self):
        """
        Indicating whether we are using model parallelism for distributed training.

        Examples:
          .. code-block:: python
            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.model_parallel = True
        """
        return self.strategy.model_parallel

    @model_parallel.setter
    @is_strict_auto
    def model_parallel(self, flag):
        if isinstance(flag, bool):
            self.strategy.model_parallel = flag
        else:
            print("WARNING: model_parallel should have value of bool type")

    @property
    def model_parallel_configs(self):
        """
        Set model parallelism configurations.

        **Notes**:
            **Detailed arguments for model_parallel_configs**
            **parallelism**: degree of model parallelism

        Examples:
          .. code-block:: python
            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.model_parallel = True
            strategy.model_parallel_configs = {"parallelism": 12}
        """
        return get_msg_dict(self.strategy.model_parallel_configs)

    @model_parallel_configs.setter
    @is_strict_auto
    def model_parallel_configs(self, configs):
        check_configs_key(self.strategy.model_parallel_configs, configs,
                          "model_parallel_configs")
        assign_configs_value(self.strategy.model_parallel_configs, configs)

    @property
    def pipeline(self):
        """
......
...@@ -17,6 +17,7 @@ from .gradient_merge_optimizer import GradientMergeOptimizer
from .graph_execution_optimizer import GraphExecutionOptimizer
from .parameter_server_optimizer import ParameterServerOptimizer
from .pipeline_optimizer import PipelineOptimizer
from .model_parallel_optimizer import ModelParallelOptimizer
from .localsgd_optimizer import LocalSGDOptimizer
from .localsgd_optimizer import AdaptiveLocalSGDOptimizer
from .lars_optimizer import LarsOptimizer
......
...@@ -50,15 +50,17 @@ class AMPOptimizer(MetaOptimizerBase):
            self.inner_opt, amp_lists, config['init_loss_scaling'],
            config['incr_every_n_steps'], config['decr_every_n_nan_or_inf'],
            config['incr_ratio'], config['decr_ratio'],
            config['use_dynamic_loss_scaling'], config['use_pure_fp16'],
            config['use_fp16_guard'])

        # if worker_num > 1, all cards will communicate with each other,
        # add is_distributed to optimize amp, overlap communication and
        # computation by splitting the check_finite_and_unscale op.
        is_distributed = self.role_maker._worker_num() > 1
        #if self.user_defined_strategy.sharding or self.user_defined_strategy.model_parallel:
        #    # FIXME(wangxi). sharding failed when split check_finite_and_unscale
        #    # FIXME(JZ-LIANG). To support Sharding-Megatron-AMP, Megatron should follow Sharding's behavior
        #    is_distributed = False
        self.wrapped_opt._set_distributed(is_distributed)

    def _can_apply(self):
...@@ -112,3 +114,11 @@ class AMPOptimizer(MetaOptimizerBase):
            self.wrapped_opt.minimize(loss, startup_program,
                                      parameter_list, no_grad_set)
        return optimize_ops, params_grads

    def amp_init(self,
                 place,
                 scope=None,
                 test_program=None,
                 use_fp16_test=False):
        return self.wrapped_opt.amp_init(place, scope, test_program,
                                         use_fp16_test)
...@@ -25,6 +25,24 @@ OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
OP_ROLE_VAR_KEY = core.op_proto_and_checker_maker.kOpRoleVarAttrName()
class Topology:
    """A 4-D structure to describe the process group."""

    def __init__(self, axes, dims):
        pass


class ParallelGrid:
    """Initialize each process group."""

    def __init__(self, topology):
        self.build_global_group()
        self.build_mp_group()
        self.build_sharding_group()
        self.build_pp_group()
        self.build_dp_group()
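The two classes above are committed as placeholders. As a rough sketch only (the `axes`/`dims` parameters follow the stub; everything else here is an assumption, not part of this commit), a 4-D topology could map a global rank to per-axis coordinates and derive communication groups like this:

import itertools

class SketchTopology:
    """Hypothetical 4-D topology, e.g. axes=['dp', 'pp', 'mp', 'sharding']."""

    def __init__(self, axes, dims):
        assert len(axes) == len(dims)
        self.axes, self.dims = axes, dims
        # Enumerate coordinates in row-major order; list index == global rank.
        self.coords = list(itertools.product(*[range(d) for d in dims]))

    def coord(self, rank):
        return dict(zip(self.axes, self.coords[rank]))

    def ranks_along(self, axis, rank):
        """Ranks differing from `rank` only along `axis` (one comm group)."""
        i = self.axes.index(axis)
        me = self.coords[rank]
        return [r for r, c in enumerate(self.coords)
                if all(c[j] == me[j] for j in range(len(self.axes)) if j != i)]

# 16 ranks: 2-way dp x 2-way pp x 2-way mp x 2-way sharding
topo = SketchTopology(['dp', 'pp', 'mp', 'sharding'], [2, 2, 2, 2])
print(topo.coord(5))              # {'dp': 0, 'pp': 1, 'mp': 0, 'sharding': 1}
print(topo.ranks_along('mp', 5))  # [5, 7]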
def is_update_op(op):
    return 'Param' in op.input_names and 'Grad' in op.input_names and \
            "LearningRate" in op.input_names
...@@ -66,16 +84,49 @@ class CollectiveHelper(object):
            self.role_maker._worker_index(), ring_id, self.wait_port)
        self._broadcast_params()

    def _init_communicator(self,
                           program,
                           current_endpoint,
                           endpoints,
                           rank,
                           ring_id,
                           wait_port,
                           sync=True):
        nranks = len(endpoints)
        other_endpoints = endpoints[:]
        other_endpoints.remove(current_endpoint)

        block = program.global_block()
        if core.is_compiled_with_cuda():
            if not wait_port and sync:
                temp_var = block.create_var(
                    name=unique_name.generate('temp_var'),
                    dtype=core.VarDesc.VarType.INT32,
                    persistable=False,
                    stop_gradient=True)
                block.append_op(
                    type='fill_constant',
                    inputs={},
                    outputs={'Out': [temp_var]},
                    attrs={
                        'shape': [1],
                        'dtype': temp_var.dtype,
                        'value': 1,
                        'force_cpu': False,
                        OP_ROLE_KEY: OpRole.Forward
                    })
                block.append_op(
                    type='c_allreduce_sum',
                    inputs={'X': [temp_var]},
                    outputs={'Out': [temp_var]},
                    attrs={'ring_id': 3,
                           OP_ROLE_KEY: OpRole.Forward})
                block.append_op(
                    type='c_sync_comm_stream',
                    inputs={'X': temp_var},
                    outputs={'Out': temp_var},
                    attrs={'ring_id': 3,
                           OP_ROLE_KEY: OpRole.Forward})
            comm_id_var = block.create_var(
                name=unique_name.generate('nccl_id'),
                persistable=True,
                type=core.VarDesc.VarType.RAW)
...@@ -100,9 +151,7 @@ class CollectiveHelper(object):
                    OP_ROLE_KEY: OpRole.Forward
                })
        elif core.is_compiled_with_npu():
            endpoint_to_index_map = {e: idx for idx, e in enumerate(endpoints)}
            block.append_op(
                type='c_comm_init_hcom',
                inputs={},
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from __future__ import division
import paddle.fluid as fluid
from paddle.fluid import core, unique_name
from ..base.private_helper_function import wait_server_ready
from .meta_optimizer_base import MetaOptimizerBase
from .common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY, CollectiveHelper, is_update_op, is_loss_grad_op, is_backward_op, is_optimizer_op
class ModelParallelHelper(object):
def __init__(self, role_maker, wait_port=True, megatron_dp=False):
self.wait_port = wait_port
self.role_maker = role_maker
self.megatron_dp = megatron_dp
def update_startup_program(self,
startup_program=None,
inner_parallelism=None):
self.startup_program = startup_program
nranks = self.role_maker._worker_num()
rank = self.role_maker._worker_index()
endpoints = self.role_maker._get_trainer_endpoints()
current_endpoint = endpoints[rank]
# Create ring 0 for all model parallel parts within a single model
mp_endpoints = []
mp_rank = rank % inner_parallelism
mp_id = rank // inner_parallelism
for idx, ep in enumerate(endpoints):
if idx // inner_parallelism == mp_id:
mp_endpoints.append(ep)
print("model parallel eps:{}, rank{}".format(mp_endpoints, mp_rank))
self._init_communicator(self.startup_program, current_endpoint,
mp_endpoints, mp_rank, 0, self.wait_port)
self._broadcast_params(0, broadcast_distributed_weight=False)
print("megatron group size: {}".format(inner_parallelism))
print("megatron rank: {}".format(mp_rank))
print("megatron endpoints: {}".format(mp_endpoints))
if self.megatron_dp:
mp_num = len(endpoints) // inner_parallelism
if mp_num == 1: return
# Create rings for gpus as the same model parallel part
eps = []
dp_rank = rank // inner_parallelism
dp_id = rank % inner_parallelism
#if dp_rank == 1: dp_rank =0
#if dp_rank == 0: dp_rank =1
ring_id = 1
for idx, ep in enumerate(endpoints):
if idx % inner_parallelism == dp_id:
eps.append(ep)
#ep = eps.pop(0)
#eps.insert(1, ep)
print("data parallel eps:{}, rank{}".format(eps, dp_rank))
self._init_communicator(self.startup_program, current_endpoint, eps,
dp_rank, ring_id, self.wait_port)
self._broadcast_params(ring_id, broadcast_distributed_weight=True)
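# Worked example of the grouping above (illustration only, not part of the
# commit): with 8 trainers and inner_parallelism=2, rank 5 gets
#   mp_id = 5 // 2 = 2 and mp_rank = 5 % 2 = 1,
#     so model-parallel ring 0 holds endpoints[4] and endpoints[5];
# and, when megatron_dp is enabled,
#   dp_id = 5 % 2 = 1 and dp_rank = 5 // 2 = 2,
#     so data-parallel ring 1 holds endpoints[1], [3], [5] and [7].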
def _init_communicator(self, program, current_endpoint, endpoints, rank,
ring_id, wait_port):
nranks = len(endpoints)
other_endpoints = endpoints[:]
other_endpoints.remove(current_endpoint)
if rank == 0 and wait_port:
wait_server_ready(other_endpoints)
block = program.global_block()
nccl_id_var = block.create_var(
name=unique_name.generate('nccl_id'),
persistable=True,
type=core.VarDesc.VarType.RAW)
block.append_op(
type='c_gen_nccl_id',
inputs={},
outputs={'Out': nccl_id_var},
attrs={
'rank': rank,
'endpoint': current_endpoint,
'other_endpoints': other_endpoints,
OP_ROLE_KEY: OpRole.Forward,
})
block.append_op(
type='c_comm_init',
inputs={'X': nccl_id_var},
outputs={},
attrs={
'nranks': nranks,
'rank': rank,
'ring_id': ring_id,
OP_ROLE_KEY: OpRole.Forward,
})
def _broadcast_params(self, ring_id, broadcast_distributed_weight):
block = self.startup_program.global_block()
for param in block.iter_parameters():
if not broadcast_distributed_weight and param.is_distributed:
continue
block.append_op(
type='c_broadcast',
inputs={'X': param},
outputs={'Out': param},
attrs={
'ring_id': ring_id,
'root': 0,
OP_ROLE_KEY: OpRole.Forward
})
block.append_op(
type='c_sync_comm_stream',
inputs={'X': param},
outputs={'Out': param},
attrs={'ring_id': ring_id,
OP_ROLE_KEY: OpRole.Forward})
class ModelParallelOptimizer(MetaOptimizerBase):
def __init__(self, optimizer):
super(ModelParallelOptimizer, self).__init__(optimizer)
self.inner_opt = optimizer
self.meta_optimizers_white_list = [
"RecomputeOptimizer",
"AMPOptimizer",
"LarsOptimizer",
"LambOptimizer",
]
self.meta_optimizers_black_list = ["GraphExecutionOptimizer", ]
self.megatron_dp = False
def _set_basic_info(self, loss, role_maker, user_defined_optimizer,
user_defined_strategy):
super(ModelParallelOptimizer, self)._set_basic_info(
loss, role_maker, user_defined_optimizer, user_defined_strategy)
self.inner_parallelism = user_defined_strategy.model_parallel_configs[
'parallelism']
def _can_apply(self):
if not self.role_maker._is_collective:
return False
if self.user_defined_strategy.model_parallel == True:
return True
return False
def _disable_strategy(self, dist_strategy):
dist_strategy.model_parallel = False
dist_strategy.model_parallel_configs = {}
def _enable_strategy(self, dist_strategy, context):
dist_strategy.model_parallel = True
dist_strategy.model_parallel_configs = {"parallelism": 1, }
# the following function will be used by AMP if both Megatron and AMP are turned on together.
def apply_gradients(self, params_grads):
return self.minimize_impl(params_grads=params_grads)
def minimize_impl(self,
loss,
startup_program=None,
parameter_list=None,
no_grad_set=None):
endpoints = self.role_maker._get_trainer_endpoints()
current_endpoint = endpoints[self.role_maker._worker_index()]
self.startup_program = startup_program
if startup_program is None:
self.startup_program = fluid.default_startup_program()
# (TODO) check the order of metaoptimizer
# (TODO) check the params_grads
optimize_ops, params_grads = self.inner_opt.minimize(
loss, self.startup_program, parameter_list, no_grad_set)
self.main_program = loss.block.program
self.inner_parallelism = self.inner_parallelism
self.nranks = len(endpoints)
pipeline_helper = ModelParallelHelper(self.role_maker)
pipeline_helper.update_startup_program(self.startup_program,
self.inner_parallelism)
assert self.nranks % self.inner_parallelism == 0
if self.megatron_dp:
# data parallelism
dp_parallelism = self.nranks // self.inner_parallelism
self._transpile_main_program(loss, dp_parallelism)
return optimize_ops, params_grads
def _transpile_main_program(self, loss, dp_parallelism):
self._insert_loss_grad_ops(loss, dp_parallelism)
ring_id = 1
print("ring_id: ", ring_id)
# for ring_id in range(1, dp_parallelism + 1):
self._insert_allreduce_ops(loss, ring_id)
def _insert_loss_grad_ops(self, loss, dp_parallelism):
"""
In order to keep the learning rate consistent in different numbers of
training workers, we scale the loss grad by the number of workers
"""
block = loss.block
for idx, op in reversed(list(enumerate(block.ops))):
if is_loss_grad_op(op):
loss_grad_var = block.vars[op.output_arg_names[0]]
block._insert_op(
idx + 1,
type='scale',
inputs={'X': loss_grad_var},
outputs={'Out': loss_grad_var},
attrs={
'scale': 1.0 / dp_parallelism,
OP_ROLE_KEY: OpRole.Backward
})
def _insert_allreduce_ops(self, loss, ring_id):
block = loss.block
grad = None
for idx, op in reversed(list(enumerate(block.ops))):
if is_backward_op(op) and \
OP_ROLE_VAR_KEY in op.attr_names:
op_role_var = op.all_attrs()[OP_ROLE_VAR_KEY]
if len(op_role_var) == 0:
continue
assert len(op_role_var) % 2 == 0
offset = idx
for i in range(0, len(op_role_var), 2):
param = block.vars[op_role_var[i]]
grad = block.vars[op_role_var[i + 1]]
#if param.is_distributed:
# continue
if offset == idx:
offset += 1
block._insert_op(
offset,
type='c_sync_calc_stream',
inputs={'X': grad},
outputs={'Out': grad},
attrs={OP_ROLE_KEY: OpRole.Backward})
offset += 1
block._insert_op(
offset,
type='c_allreduce_sum',
inputs={'X': grad},
outputs={'Out': grad},
attrs={
'ring_id': ring_id,
OP_ROLE_KEY: OpRole.Backward
})
if grad is None:
return
for idx, op in list(enumerate(block.ops)):
if is_optimizer_op(op):
block._insert_op(
idx,
type='c_sync_comm_stream',
inputs={'X': grad},
outputs={'Out': grad},
attrs={'ring_id': ring_id,
OP_ROLE_KEY: OpRole.Backward})
break
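For context, a minimal end-to-end sketch (assumptions: a collective fleet launch and a network already split across the model-parallel group; `build_model` is a hypothetical helper, not from this commit) of turning this meta optimizer on through the strategy flags added in this PR:

import paddle
import paddle.distributed.fleet as fleet

paddle.enable_static()
fleet.init(is_collective=True)

strategy = fleet.DistributedStrategy()
strategy.model_parallel = True
strategy.model_parallel_configs = {"parallelism": 2}  # 2-way Megatron-style split

loss = build_model()                      # hypothetical: returns a scalar loss var
optimizer = paddle.optimizer.Adam(learning_rate=1e-4)
optimizer = fleet.distributed_optimizer(optimizer, strategy)
optimizer.minimize(loss)                  # ModelParallelOptimizer sets up the NCCL rings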
...@@ -154,8 +154,10 @@ class PipelineOptimizer(MetaOptimizerBase):
    def __init__(self, optimizer):
        super(PipelineOptimizer, self).__init__(optimizer)
        self.inner_opt = optimizer
        self.meta_optimizers_white_list = [
            "RecomputeOptimizer",
            "AMPOptimizer",
        ]
        self.meta_optimizers_black_list = ["GraphExecutionOptimizer", ]

    def _set_basic_info(self, loss, role_maker, user_defined_optimizer,
......
...@@ -73,7 +73,7 @@ class FP16Utils(object):
    @staticmethod
    def prune_fp16(block, shard, reduced_grads_to_param, ring_id):
        """
        1. prune all cast_fp16_to_fp32 ops if the param does not belong to this shard
        2. revise amp infinite grad checking for sharding
        """
        # remove cast
...@@ -81,7 +81,9 @@ class FP16Utils(object):
            if not FP16Utils.is_fp32_cast_op(block, op):
                continue
            output_name = op.desc.output_arg_names()[0]
            param_name = output_name.strip(
                "@GRAD@MERGED"
            ) if "@MERGED" in output_name else output_name.strip("@GRAD")
            if param_name not in shard.global_params:
                raise ValueError("Output 'X' of cast_op must be a grad of"
                                 "model param, but {} is not a grad".format(
...@@ -103,20 +105,35 @@ class FP16Utils(object):
                op._rename_input(inf_var_name, inf_var_name + "@sharding")
            if op.type in ["check_finite_and_unscale", "update_loss_scaling"]:
                reversed_x = []
                reversed_x_paramname = []
                for input_name in op.desc.input('X'):
                    param_name = input_name.strip("@GRAD@MERGED")
                    if param_name not in shard.global_params:
                        raise ValueError(
                            "Input 'X' of check_finite_and_unscale must"
                            "be grads, but {} is not a grad".format(input_name))
                    if shard.has_param(param_name):
                        reversed_x.append(input_name)
                        reversed_x_paramname.append(param_name)
                op.desc.set_input('X', reversed_x)
                op.desc.set_output('Out', reversed_x)

                # the grad checking should take the all and only param in the current shard
                to_check_param = set(reversed_x_paramname)
                should_check_param = set(shard.global_params).intersection(
                    set([
                        param
                        for param, worker_idx in shard.global_param2device.
                        items() if worker_idx == shard.worker_idx
                    ]))
                #assert to_check_param == should_check_param, "amp check_finite_and_unscale checking miss [{}] and got unexpected [{}]".format(
                #    should_check_param - to_check_param,
                #    to_check_param - should_check_param)

        if update_loss_scaling_op_idx == -1:
            return
        inf_var = block.var(inf_var_name)
        inf_var_int32 = block.create_var(
            name=inf_var_name + "@cast_int32",
            shape=inf_var.shape,
            dtype=core.VarDesc.VarType.INT32)
...@@ -128,32 +145,36 @@ class FP16Utils(object):
            update_loss_scaling_op_idx,
            type='cast',
            inputs={'X': inf_var},
            outputs={'Out': inf_var_int32},
            attrs={
                "in_dtype": inf_var.dtype,
                "out_dtype": inf_var_int32.dtype,
                OP_ROLE_KEY: OpRole.Optimize
            })
        # this allreduce communication should not overlap with calc
        # insert_sync_calc_op(block, update_loss_scaling_op_idx + 1,
        #                     [inf_var_int32])
        block._insert_op_without_sync(
            update_loss_scaling_op_idx + 1,
            type='c_allreduce_max',
            inputs={'X': inf_var_int32},
            outputs={'Out': inf_var_int32},
            attrs={
                'ring_id': ring_id,
                'use_calc_stream': True,
                OP_ROLE_KEY: OpRole.Optimize
            })
        # comm_op_num = insert_sync_comm_op(block, update_loss_scaling_op_idx + 3,
        #                                   ring_id, [inf_var_int32])
        block._insert_op_without_sync(
            update_loss_scaling_op_idx + 2,
            type='cast',
            inputs={'X': inf_var_int32},
            outputs={'Out': inf_var_sharding},
            attrs={
                "in_dtype": inf_var_int32.dtype,
                "out_dtype": inf_var_sharding.dtype,
                OP_ROLE_KEY: OpRole.Optimize
            })
......
...@@ -16,8 +16,8 @@ from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole

class GradientClipHelper(object):
    def __init__(self, mp_ring_id):
        self.mp_ring_id = mp_ring_id

    def _is_gradient_clip_op(self, op):
        return op.desc.has_attr("op_namescope") \
...@@ -31,6 +31,7 @@ class GradientClipHelper(object):
        """
        deperated_vars = set()
        deperate_op_idx = set()
        reversed_x_paramname = []
        for idx, op in enumerate(block.ops):
            if not self._is_gradient_clip_op(op):
                continue
...@@ -40,15 +41,18 @@ class GradientClipHelper(object):
            for input_name in op.desc.input_arg_names():
                if input_name in deperated_vars:
                    deperate_op = True
                param_name = input_name.strip("@GRAD@MERGED")
                if shard.is_param(param_name) and \
                        not shard.has_param(param_name):
                    deperate_op = True
                elif shard.is_param(param_name):
                    reversed_x_paramname.append(param_name)

            if deperate_op:
                deperate_op_idx.add(idx)
                for output_name in op.desc.output_arg_names():
                    if output_name not in op.desc.input_arg_names():
                        deperated_vars.add(output_name)

        if not deperated_vars:
            # got no gradient_clip op
...@@ -65,31 +69,47 @@ class GradientClipHelper(object):
                for input_name in op.desc.input_arg_names():
                    if input_name not in deperated_vars:
                        reversed_inputs.append(input_name)

                op.desc.set_input("X", reversed_inputs)
                assert (len(op.desc.output_arg_names()) == 1)
                sum_res = op.desc.output_arg_names()[0]

                # this allreduce should not overlap with calc and should be scheduled in calc stream
                # block._insert_op_without_sync(
                #     idx + 1,
                #     type='c_sync_comm_stream',
                #     inputs={'X': sum_res},
                #     outputs={'Out': sum_res},
                #     attrs={'ring_id': 0,
                #            OP_ROLE_KEY: OpRole.Optimize})
                block._insert_op_without_sync(
                    idx + 1,
                    type='c_allreduce_sum',
                    inputs={'X': sum_res},
                    outputs={'Out': sum_res},
                    attrs={
                        'ring_id': self.mp_ring_id,
                        'op_namescope': "/gradient_clip_model_parallelism",
                        'use_calc_stream': True,
                        OP_ROLE_KEY: OpRole.Optimize,
                    })
                # block._insert_op_without_sync(
                #     idx + 1,
                #     type='c_sync_calc_stream',
                #     inputs={'X': sum_res},
                #     outputs={'Out': sum_res},
                #     attrs={OP_ROLE_KEY: OpRole.Optimize})

        # the grad sum here should take the all and only param in the current shard
        to_check_param = set(reversed_x_paramname)
        should_check_param = set(shard.global_params).intersection(
            set([
                param for param, worker_idx in shard.global_param2device.items()
                if worker_idx == shard.worker_idx
            ]))
        assert to_check_param == should_check_param, "amp check_finite_and_unscale checking miss [{}] and got unexpected [{}]".format(
            should_check_param - to_check_param,
            to_check_param - should_check_param)

        for var_name in deperated_vars:
            block._remove_var(var_name, sync=False)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..common import is_optimizer_op, OP_ROLE_KEY, OpRole
from paddle.fluid import core, unique_name
class OffloadHelper(object):
cpu_place_type = 0
cuda_place_type = 1
cuda_pinned_place_type = 2
    def __init__(self):
        pass

    # dst_place_type semantics for the memcpy op:
    #   0: dst is on CPUPlace.
    #   1: dst is on CUDAPlace.
    #   2: dst is on CUDAPinnedPlace.
def _insert_cast_op(self, block, idx, src_name, dst_name):
src_var = block.var(src_name)
if not block.has_var(dst_name):
block.create_var(
name=dst_name,
shape=src_var.shape,
dtype=core.VarDesc.VarType.FP16,
persistable=True)
dst_var = block.var(dst_name)
assert dst_var.dtype == core.VarDesc.VarType.FP16
block._insert_op_without_sync(
idx,
type='cast',
inputs={'X': src_var},
outputs={'Out': dst_var},
attrs={
'in_dtype': src_var.dtype,
'out_dtype': dst_var.dtype,
OP_ROLE_KEY: OpRole.Optimize
})
def _insert_memcpy_op(self, block, idx, src_name, dst_name, dst_place_type):
src_var = block.var(src_name)
dst_var = block.var(dst_name)
block._insert_op_without_sync(
idx,
type='memcpy',
inputs={'X': src_var},
outputs={'Out': dst_var},
attrs={
'dst_place_type': dst_place_type,
OP_ROLE_KEY: OpRole.Optimize,
})
def _insert_fetch_op(self, block, idx, src_name, dst_name):
self._insert_memcpy_op(block, idx, src_name, dst_name,
OffloadHelper.cuda_place_type)
def _insert_offload_op(self, block, idx, src_name, dst_name):
self._insert_memcpy_op(block, idx, src_name, dst_name,
OffloadHelper.cuda_pinned_place_type)
def _get_offload_var_name(self, name):
return unique_name.generate(name + '@offload')
def _create_offload_var(self, var_name, offload_var_name, blocks):
for block in blocks:
var = block.var(var_name)
var.persistable = False
offload_var = block.create_var(
name=offload_var_name,
shape=var.shape,
dtype=var.dtype,
persistable=True)
def offload_fp32param(self, block, startup_block):
"""
(p_fp16) = cast(p)
(p_fp16_recompute) = cast(p)
(pout,) = adam(p)
===========================>
rename(p_fp16_recompute, p_fp16)
(p,) = prefetch(p@offload)
(pout,) = adam(p)
(p_fp16) = cast(p)
(p@offload) = memcpy(p)
"""
param_to_idx = dict()
param_to_fp16 = dict()
# recompute_var which need rename to fp16_param
fp16_param_to_recompute = dict()
recompute_to_fp16 = dict()
def remove_param(input_name):
param_to_idx.pop(input_name)
if input_name in param_to_fp16:
fp16_param = param_to_fp16.pop(input_name)
if fp16_param in fp16_param_to_recompute:
recompute = fp16_param_to_recompute.pop(fp16_param)
recompute_to_fp16.pop(recompute)
# step1: record param
for idx, op in reversed(list(enumerate(block.ops))):
if op.type in ('adam', 'momentum', 'lars', 'lamb'):
param = op.desc.input("Param")[0]
param_to_idx[param] = idx
# step2: remove param which can't offload
for idx, op in enumerate(block.ops):
if is_optimizer_op(op):
break
for input_name in op.desc.input_arg_names():
if input_name not in param_to_idx:
continue
# param is real used by fp32 op
if op.type != 'cast':
remove_param(input_name)
continue
# param is only used by cast op,
# which to cast fp32_param to fp16_param
output_name = op.output_arg_names[0]
if 'cast_fp16' not in output_name:
remove_param(input_name)
continue
if 'subprog' not in output_name:
assert output_name == input_name + '.cast_fp16'
assert input_name not in param_to_fp16, \
"There must be only one cast op from fp32 param to fp16 param."
param_to_fp16[input_name] = output_name
else:
# fp16-->recompute_var
assert input_name in param_to_fp16, \
"param must first be cast to fp16"
fp16_param = param_to_fp16[input_name]
fp16_param_to_recompute[fp16_param] = output_name
recompute_to_fp16[output_name] = fp16_param
param_name_to_offload_name = dict()
# step3: main_block add offload, cast op
# change recompute to fp16, remove cast(param) to fp16
for idx, op in reversed(list(enumerate(block.ops))):
if op.type in ('adam', 'momentum', 'lars', 'lamb'):
param = op.desc.input("Param")[0]
if param not in param_to_idx: continue
# step3.1: create offload_var
offload_var_name = self._get_offload_var_name(param)
param_name_to_offload_name[param] = offload_var_name
self._create_offload_var(param, offload_var_name,
[block, startup_block])
# step3.2: insert cast op and offload op
self._insert_offload_op(block, idx + 1, param, offload_var_name)
assert param in param_to_fp16
fp16_param_name = param_to_fp16[param]
fp16_param_var = block.var(fp16_param_name)
fp16_param_var.persistable = True
self._insert_cast_op(block, idx + 1, param,
param_to_fp16[param])
# step3.3: insert fetch op
self._insert_fetch_op(block, idx, offload_var_name, param)
continue
# step3.4: remove cast op
if op.type == 'cast':
input_name = op.desc.input_arg_names()[0]
if input_name in param_to_idx:
block._remove_op(idx, sync=False)
continue
# step3.5: change recompute_param to fp16_param
for input_name in op.desc.input_arg_names():
if input_name in recompute_to_fp16:
op._rename_input(input_name, recompute_to_fp16[input_name])
for output_name in op.desc.output_arg_names():
if output_name in recompute_to_fp16:
op._rename_output(output_name,
recompute_to_fp16[output_name])
# step4: remove recompute_param
for name in recompute_to_fp16.keys():
block._remove_var(name, sync=False)
# step5: startup_block add offload
visited_vars = set()
for idx, op in reversed(list(enumerate(startup_block.ops))):
for out_name in op.output_arg_names:
if out_name in visited_vars:
continue
if out_name in param_name_to_offload_name:
var_name = out_name
offload_var_name = param_name_to_offload_name[var_name]
self._insert_offload_op(startup_block, idx + 1, var_name,
offload_var_name)
self._insert_cast_op(startup_block, idx + 1, var_name,
param_to_fp16[var_name])
visited_vars.add(out_name)
block._sync_with_cpp()
startup_block._sync_with_cpp()
def offload(self, block, startup_block):
"""
(m1, m2) = prefetch(m1@offload, m2@offload)
(m1out, m2out, pout) = adam(m1, m2, p)
(m1@offload, m2@offload) = memcpy(m1, m2)
"""
vars_name_to_offload_name = dict()
# main_block add offload
for idx, op in reversed(list(enumerate(block.ops))):
if not is_optimizer_op(op):
break
vars_name = []
if op.type == "adam":
# {Moment1Out = [''], Moment2Out = [''], ParamOut = ['']} =
# adam(inputs={Moment1 = [''], Moment2 = [''], Param = ['']})
vars_name.append(op.desc.input("Moment1")[0])
vars_name.append(op.desc.input("Moment2")[0])
elif op.type == 'momentum':
pass
elif op.type == 'lars':
pass
elif op.type == 'lamb':
pass
# step1: create and init offload_var
for var_name in vars_name:
assert var_name not in vars_name_to_offload_name
offload_var_name = self._get_offload_var_name(var_name)
vars_name_to_offload_name[var_name] = offload_var_name
self._create_offload_var(var_name, offload_var_name,
[block, startup_block])
# step2: insert offload op
for var_name in vars_name:
offload_var_name = vars_name_to_offload_name[var_name]
self._insert_offload_op(block, idx + 1, var_name,
offload_var_name)
# step3: insert fetch op
for var_name in vars_name:
offload_var_name = vars_name_to_offload_name[var_name]
self._insert_fetch_op(block, idx, offload_var_name, var_name)
# startup_block add offload
visited_vars = set()
for idx, op in reversed(list(enumerate(startup_block.ops))):
for out_name in op.output_arg_names:
if out_name in visited_vars:
continue
if out_name in vars_name_to_offload_name:
var_name = out_name
offload_var_name = vars_name_to_offload_name[var_name]
# insert offload op after var is generated
self._insert_offload_op(startup_block, idx + 1, var_name,
offload_var_name)
visited_vars.add(out_name)
block._sync_with_cpp()
startup_block._sync_with_cpp()
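The offload docstring above boils down to wrapping every optimizer op with a fetch (offload-to-device copy) in front and an offload (device-to-CPU copy) behind it, inserted while walking the op list in reverse so indices stay valid. A minimal standalone sketch of that insertion pattern, using plain Python lists and made-up op names rather than Paddle blocks:

ops = ["forward", "backward", "adam(m1,m2,p)", "adam(m3,m4,q)"]

# Walk in reverse so the insertions at idx and idx + 1 never shift the
# positions of ops that have not been visited yet.
for idx in reversed(range(len(ops))):
    if ops[idx].startswith("adam"):
        ops.insert(idx + 1, "memcpy: moments -> moments@offload")  # offload after the update
        ops.insert(idx, "memcpy: moments@offload -> moments")      # fetch right before it

print(ops)
# ['forward', 'backward',
#  'memcpy: moments@offload -> moments', 'adam(m1,m2,p)', 'memcpy: moments -> moments@offload',
#  'memcpy: moments@offload -> moments', 'adam(m3,m4,q)', 'memcpy: moments -> moments@offload']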
...@@ -126,6 +126,9 @@ class ProgramDeps(object): ...@@ -126,6 +126,9 @@ class ProgramDeps(object):
def should_remove_op(self, op_idx): def should_remove_op(self, op_idx):
op = self._block.ops[op_idx] op = self._block.ops[op_idx]
# remove check_finite_and_unscale op if its input 'X' is empty
if op.type == 'check_finite_and_unscale' and len(op.input('X')) == 0:
return True
for output_name in op.desc.output_arg_names(): for output_name in op.desc.output_arg_names():
if output_name not in self._should_removed_var: if output_name not in self._should_removed_var:
return False return False
......
...@@ -28,21 +28,24 @@ def check_broadcast(block): ...@@ -28,21 +28,24 @@ def check_broadcast(block):
if the broadcasted var has a fill_constant op, the fill_constant if the broadcasted var has a fill_constant op, the fill_constant
op should come before the broadcast op, and before a op should come before the broadcast op, and before a
sync_calc op. Otherwise, an error is raised. sync_calc op. Otherwise, an error is raised.
broadcast ops belonging to inner parallelism (e.g. Megatron) should be ignored and skipped
""" """
broadcast_vars = {} broadcast_vars = {}
for idx, op in enumerate(block.ops): for idx, op in enumerate(block.ops):
if op.type == "c_broadcast": if op.type == "c_broadcast":
var_name = op.desc.input_arg_names()[0] if op.all_attrs()["use_calc_stream"] == False:
if "@BroadCast" in var_name: var_name = op.desc.input_arg_names()[0]
if var_name in broadcast_vars: if "@BroadCast" in var_name:
raise ValueError("var_name areadly exist: {}" if var_name in broadcast_vars:
"the old pos is {}, the new pos is {}". raise ValueError("var_name areadly exist: {}"
format(var_name, broadcast_vars[var_name][ "the old pos is {}, the new pos is {}".
"broadcast_pos"], idx)) format(var_name, broadcast_vars[
broadcast_vars[var_name] = { var_name]["broadcast_pos"], idx))
"fill_constant_pos": -1, broadcast_vars[var_name] = {
"broadcast_pos": idx, "fill_constant_pos": -1,
} "broadcast_pos": idx,
}
for idx, op in enumerate(block.ops): for idx, op in enumerate(block.ops):
if op.type == "fill_constant": if op.type == "fill_constant":
...@@ -61,14 +64,15 @@ def check_broadcast(block): ...@@ -61,14 +64,15 @@ def check_broadcast(block):
last_sync_calc_op_idx = idx last_sync_calc_op_idx = idx
continue continue
if op.type == "c_broadcast": if op.type == "c_broadcast":
var_name = op.desc.input_arg_names()[0] if op.all_attrs()["use_calc_stream"] == False:
if "@BroadCast" in var_name: var_name = op.desc.input_arg_names()[0]
if broadcast_vars[var_name]["fill_constant_pos"] != -1: if "@BroadCast" in var_name:
assert (last_sync_calc_op_idx != -1) if broadcast_vars[var_name]["fill_constant_pos"] != -1:
assert (broadcast_vars[var_name]["fill_constant_pos"] < assert (last_sync_calc_op_idx != -1)
last_sync_calc_op_idx) assert (broadcast_vars[var_name]["fill_constant_pos"] <
assert (last_sync_calc_op_idx < idx) last_sync_calc_op_idx)
continue assert (last_sync_calc_op_idx < idx)
continue
for input_name in op.desc.input_arg_names(): for input_name in op.desc.input_arg_names():
if input_name in broadcast_vars: if input_name in broadcast_vars:
assert (broadcast_vars[input_name]["broadcast_pos"] != -1) assert (broadcast_vars[input_name]["broadcast_pos"] != -1)
...@@ -78,43 +82,47 @@ def check_broadcast(block): ...@@ -78,43 +82,47 @@ def check_broadcast(block):
return return
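The docstring constraint can be read as a simple positional invariant: for each broadcasted @BroadCast variable, its fill_constant (if any) must precede a c_sync_calc_stream, which must precede the c_broadcast. A toy sequence that satisfies it (op types only, positions are list indices):

toy_ops = [
    "fill_constant",       # 0: materialize X@BroadCast on the owning rank
    "c_sync_calc_stream",  # 1: sync the calc stream before communication
    "c_broadcast",         # 2: broadcast X@BroadCast inside the sharding ring
]
fill_pos = toy_ops.index("fill_constant")
sync_pos = toy_ops.index("c_sync_calc_stream")
bcast_pos = toy_ops.index("c_broadcast")
assert fill_pos < sync_pos < bcast_pos  # the invariant check_broadcast asserts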
def check_allreduce_sum(block, shard, dp_ring_id=-1): def check_allreduce_sum(block, shard, sharding_ring_id, dp_ring_id=-1):
""" """
the op order should be: the op order should be:
grad: grad:
- 0: op that generate Var - 0: op that generate Var
- 1: sync_calc - 1: sync_calc
- 2: allreduce_sum_sharding - 2: reduce_sum_sharding (allreduce --> reduce)
- 3: sync_comm - 3: sync_comm
- 4: allreduce_sum_dp (dp_grads) - 4: allreduce_sum_dp (dp_grads)
- 5: sync_comm (dp_grads) - 5: sync_comm (dp_grads)
- 6: op that use Var (dp_grads & sum) - 6: op that use Var (dp_grads & sum)
allreduce ops belonging to inner parallelism (e.g. Megatron) should be ignored and skipped
""" """
vars_status = {} vars_status = {}
dp_grads_status = {} dp_grads_status = {}
idx_last_grad_allreduce = -1 idx_last_grad_allreduce = -1
idx_amp_allreduce = -1 idx_amp_allreduce = -1
idx_gradient_clip_allreduce = -1 idx_gradient_clip_allreduce = -1
for idx, op in enumerate(block.ops): for idx, op in enumerate(block.ops):
if op.type == "c_allreduce_sum": if op.type == "c_allreduce_sum" or op.type == "c_reduce_sum":
ring_id = op.desc.attr("ring_id") if op.all_attrs()["use_calc_stream"] == False:
var_name = op.desc.input_arg_names()[0] ring_id = op.desc.attr("ring_id")
param = var_name.split("@")[0] var_name = op.desc.input_arg_names()[0]
param = var_name.split("@")[0]
assert 'sum' in var_name or ("@GRAD" in var_name) assert 'sum' in var_name or ("@GRAD" in var_name)
if 'sum' in var_name or (not shard.has_param(param)): if 'sum' in var_name or (not shard.has_param(param)):
vars_status[var_name] = -1 vars_status[var_name] = -1
else: else:
dp_grads_status[var_name] = -1 dp_grads_status[var_name] = -1
if ring_id != 0: if ring_id != sharding_ring_id:
assert shard.has_param(param) assert shard.has_param(param)
assert ring_id == dp_ring_id assert ring_id == dp_ring_id
if "sum" in var_name: if "sum" in var_name:
idx_amp_allreduce = idx idx_amp_allreduce = idx
elif "@GRAD": elif "@GRAD":
idx_last_grad_allreduce = idx idx_last_grad_allreduce = idx
if op.type == "c_allreduce_max": if op.type == "c_allreduce_max":
idx_gradient_clip_allreduce = idx idx_gradient_clip_allreduce = idx
...@@ -129,37 +137,40 @@ def check_allreduce_sum(block, shard, dp_ring_id=-1): ...@@ -129,37 +137,40 @@ def check_allreduce_sum(block, shard, dp_ring_id=-1):
var_name] == 0: var_name] == 0:
dp_grads_status[var_name] = 1 dp_grads_status[var_name] = 1
elif op.type == "c_allreduce_sum": elif op.type == "c_allreduce_sum" or op.type == "c_reduce_sum":
var_name = op.desc.input_arg_names()[0] if op.all_attrs()["use_calc_stream"] == False:
ring_id = op.desc.attr("ring_id") var_name = op.desc.input_arg_names()[0]
if ring_id == 0: ring_id = op.desc.attr("ring_id")
if var_name in vars_status: if ring_id == sharding_ring_id:
_status = vars_status[var_name] assert op.type == "c_reduce_sum", "Grads in the sharding group should be reduced rather than allreduced"
else: if var_name in vars_status:
_status = dp_grads_status[var_name] _status = vars_status[var_name]
if _status == -1: else:
raise ValueError("{} is not generated, but you are" _status = dp_grads_status[var_name]
"trying to all-reduce it".format(var_name)) if _status == -1:
if _status == 0: raise ValueError("{} is not generated, but you are"
raise ValueError("There should be a sync_calc op " "trying to all-reduce it".format(
"after generate Var: {} and before the" var_name))
"c_allreduce_sum op".format(var_name)) if _status == 0:
assert (_status == 1) raise ValueError("There should be a sync_calc op "
if var_name in vars_status: "after generate Var: {} and before the"
vars_status[var_name] = 2 "c_allreduce_sum op".format(var_name))
assert (_status == 1)
if var_name in vars_status:
vars_status[var_name] = 2
else:
dp_grads_status[var_name] = 2
else: else:
dp_grads_status[var_name] = 2 assert ring_id == dp_ring_id
else: param = var_name.split("@")[0]
assert ring_id == dp_ring_id assert shard.has_param(param)
param = var_name.split("@")[0] assert dp_grads_status[var_name] == 3
assert shard.has_param(param) dp_grads_status[var_name] = 4
assert dp_grads_status[var_name] == 3
dp_grads_status[var_name] = 4
elif op.type == "c_sync_comm_stream": elif op.type == "c_sync_comm_stream":
var_name = op.desc.input_arg_names()[0] var_name = op.desc.input_arg_names()[0]
ring_id = op.desc.attr("ring_id") ring_id = op.desc.attr("ring_id")
if ring_id == 0: if ring_id == sharding_ring_id:
for var_name in op.desc.input_arg_names(): for var_name in op.desc.input_arg_names():
if var_name in vars_status: if var_name in vars_status:
assert vars_status[var_name] == 2 assert vars_status[var_name] == 2
...@@ -181,6 +192,9 @@ def check_allreduce_sum(block, shard, dp_ring_id=-1): ...@@ -181,6 +192,9 @@ def check_allreduce_sum(block, shard, dp_ring_id=-1):
raise ValueError("There should be a sync_comm op " raise ValueError("There should be a sync_comm op "
"after allreduce the Var: {}".format( "after allreduce the Var: {}".format(
input_name)) input_name))
raise ValueError(
"The reduce output grad [{}] should NOT be be used in Non-root rank.".
format(input_name))
if input_name in dp_grads_status: if input_name in dp_grads_status:
if dp_ring_id == -1: if dp_ring_id == -1:
if dp_grads_status[input_name] != 3: if dp_grads_status[input_name] != 3:
...@@ -225,6 +239,13 @@ def get_valid_op_role(block, insert_idx): ...@@ -225,6 +239,13 @@ def get_valid_op_role(block, insert_idx):
return get_valid_op_role(block, insert_idx + 1) return get_valid_op_role(block, insert_idx + 1)
# if insert_idx >= len(block.ops): return OpRole.Optimize
# if op_role == int(OpRole.Backward): return OpRole.Backward
# if op_role == int(OpRole.Optimize): return OpRole.Optimize
# if op_role in [int(OpRole.Forward), int(OpRole.Loss)]:
# return OpRole.Forward
# return get_valid_op_role(block, insert_idx + 1)
def insert_sync_calc_op(block, insert_idx, calc_dep_vars): def insert_sync_calc_op(block, insert_idx, calc_dep_vars):
""" """
...@@ -259,6 +280,9 @@ def insert_sync_comm_ops(block, insert_idx, ring_id, comm_dep_vars): ...@@ -259,6 +280,9 @@ def insert_sync_comm_ops(block, insert_idx, ring_id, comm_dep_vars):
""" """
insert sync_comm_op for vars insert sync_comm_op for vars
""" """
if len(comm_dep_vars) == 0:
return 0
op_role = get_valid_op_role(block, insert_idx) op_role = get_valid_op_role(block, insert_idx)
block._insert_op_without_sync( block._insert_op_without_sync(
insert_idx, insert_idx,
...@@ -313,6 +337,9 @@ def insert_allreduce_ops(block, insert_idx, ring_id, allreduce_vars): ...@@ -313,6 +337,9 @@ def insert_allreduce_ops(block, insert_idx, ring_id, allreduce_vars):
""" """
_add_allreduce_ops _add_allreduce_ops
""" """
if len(allreduce_vars) == 0:
return
for var in allreduce_vars: for var in allreduce_vars:
block._insert_op_without_sync( block._insert_op_without_sync(
insert_idx, insert_idx,
...@@ -325,6 +352,62 @@ def insert_allreduce_ops(block, insert_idx, ring_id, allreduce_vars): ...@@ -325,6 +352,62 @@ def insert_allreduce_ops(block, insert_idx, ring_id, allreduce_vars):
return return
def get_grad_device(grad_name, shard):
assert "@GRAD" in grad_name, "[{}] should be a grad variable.".format(
grad_name)
base_name = None
# mind the traversal order
possible_suffixes = [
'.cast_fp16@GRAD@MERGED', '.cast_fp16@GRAD', '@GRAD@MERGED', '@GRAD'
]
for suffix in possible_suffixes:
if suffix in grad_name:
base_name = re.sub(suffix, '', grad_name)
break
assert base_name in shard.global_param2device, "[{}] should be a param variable.".format(
base_name)
return shard.global_param2device[base_name]
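The "mind the traversal order" comment matters because '@GRAD' is a substring of every other suffix: trying it first would strip too little and leave '.cast_fp16' or '@MERGED' glued to the base name. A small standalone illustration with hypothetical gradient names:

import re

possible_suffixes = [
    '.cast_fp16@GRAD@MERGED', '.cast_fp16@GRAD', '@GRAD@MERGED', '@GRAD'
]

def base_param_name(grad_name):
    # same longest-first matching as get_grad_device above
    for suffix in possible_suffixes:
        if suffix in grad_name:
            return re.sub(suffix, '', grad_name)
    return grad_name

print(base_param_name('fc_0.w_0@GRAD'))                   # fc_0.w_0
print(base_param_name('fc_0.w_0.cast_fp16@GRAD@MERGED'))  # fc_0.w_0
# If '@GRAD' were tried first, the second name would be reduced to
# 'fc_0.w_0.cast_fp16@MERGED', which is not a parameter name.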
def insert_reduce_ops(block,
insert_idx,
ring_id,
reduce_vars,
shard,
op_role=OpRole.Backward,
use_calc_stream=False):
"""
_add_allreduce_ops
"""
for var in reduce_vars:
root_id = get_grad_device(var, shard)
assert root_id >= 0, "root id of grad [{}] should be a non-negative integer, but got {}".format(var, root_id)
block._insert_op_without_sync(
insert_idx,
type='c_reduce_sum',
inputs={'X': var},
outputs={'Out': var},
attrs={
'ring_id': ring_id,
'root_id': root_id,
'use_calc_stream': use_calc_stream,
OP_ROLE_KEY: op_role
})
return
def get_first_check_finite_and_unscale_op_idx(block):
for idx, op in enumerate(block.ops):
if op.type == "check_finite_and_unscale":
return idx
raise ValueError("check_finite_and_unscale does not exist in block")
def insert_broadcast_ops(block, insert_idx, ring_id, broadcast2root): def insert_broadcast_ops(block, insert_idx, ring_id, broadcast2root):
""" """
_add_broadcast_ops _add_broadcast_ops
...@@ -428,7 +511,7 @@ def comm_analyse(main_program): ...@@ -428,7 +511,7 @@ def comm_analyse(main_program):
count)) count))
def add_sync_comm(program, dist_strategy): def add_sync_comm(program, nccl_ids):
""" """
When a test prog is cloned from the sharding main prog, When a test prog is cloned from the sharding main prog,
part of the sync_comm ops may be pruned by mistake, this function part of the sync_comm ops may be pruned by mistake, this function
...@@ -438,6 +521,9 @@ def add_sync_comm(program, dist_strategy): ...@@ -438,6 +521,9 @@ def add_sync_comm(program, dist_strategy):
#NOTE (liangjianzhong): only one comm stream is supported by now, using more than one #NOTE (liangjianzhong): only one comm stream is supported by now, using more than one
# comm stream will cause errors. Should be revised in the future. # comm stream will cause errors. Should be revised in the future.
assert isinstance(
nccl_ids, list
), "the second argument of this function should be a list of nccl_ids"
block = program.global_block() block = program.global_block()
not_sync_vars = set([]) not_sync_vars = set([])
for op in block.ops: for op in block.ops:
...@@ -448,7 +534,7 @@ def add_sync_comm(program, dist_strategy): ...@@ -448,7 +534,7 @@ def add_sync_comm(program, dist_strategy):
for input_name in op.desc.input_arg_names(): for input_name in op.desc.input_arg_names():
not_sync_vars.remove(input_name) not_sync_vars.remove(input_name)
if not_sync_vars: if not_sync_vars:
for nccl_id in range(dist_strategy.nccl_comm_num): for nccl_id in nccl_ids:
block.append_op( block.append_op(
type='c_sync_comm_stream', type='c_sync_comm_stream',
inputs={'X': list(not_sync_vars)}, inputs={'X': list(not_sync_vars)},
...@@ -467,6 +553,9 @@ def save_persistables(exe, dirname, main_program, filename=None): ...@@ -467,6 +553,9 @@ def save_persistables(exe, dirname, main_program, filename=None):
This function handles the model saving for sharding training. This function handles the model saving for sharding training.
""" """
if main_program._pipeline_opt:
main_program = main_program._pipeline_opt['section_program']['program']
def is_opt_vars(var): def is_opt_vars(var):
# NOTE(liangjianzhong): The checks should be updated when adding a new compatible optimizer # NOTE(liangjianzhong): The checks should be updated when adding a new compatible optimizer
# now only Momentum and adam are compatible with sharding # now only Momentum and adam are compatible with sharding
......
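Since add_sync_comm now takes an explicit list of nccl/ring ids instead of a dist_strategy, a cloned-for-test program is patched roughly as below; the tiny network and the ring id 0 are placeholders, only the call shape is the point:

import paddle
import paddle.fluid as fluid
from paddle.distributed.fleet.meta_optimizers.sharding.utils import add_sync_comm

paddle.enable_static()
main_prog, startup_prog = fluid.Program(), fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    x = fluid.data(name='x', shape=[None, 8], dtype='float32')
    y = fluid.layers.fc(input=x, size=2)

test_prog = main_prog.clone(for_test=True)
# old signature: add_sync_comm(test_prog, dist_strategy)
# new signature: pass the ring ids that the sharding program actually uses
add_sync_comm(test_prog, [0])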
...@@ -16,14 +16,16 @@ from paddle.fluid import unique_name, core ...@@ -16,14 +16,16 @@ from paddle.fluid import unique_name, core
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_VAR_KEY, CollectiveHelper from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_VAR_KEY, CollectiveHelper
from paddle.distributed.fleet.meta_optimizers.common import is_backward_op from paddle.distributed.fleet.meta_optimizers.common import is_backward_op, is_optimizer_op, is_update_op
from paddle.distributed.fleet.meta_optimizers.meta_optimizer_base import MetaOptimizerBase from paddle.distributed.fleet.meta_optimizers.meta_optimizer_base import MetaOptimizerBase
from paddle.distributed.fleet.meta_optimizers.sharding.shard import Shard, ProgramSegment from paddle.distributed.fleet.meta_optimizers.sharding.shard import Shard, ProgramSegment
from paddle.distributed.fleet.meta_optimizers.sharding.fp16_helper import FP16Utils from paddle.distributed.fleet.meta_optimizers.sharding.fp16_helper import FP16Utils
from paddle.distributed.fleet.meta_optimizers.sharding.weight_decay_helper import WeightDecayHelper from paddle.distributed.fleet.meta_optimizers.sharding.weight_decay_helper import WeightDecayHelper
from paddle.distributed.fleet.meta_optimizers.sharding.gradient_clip_helper import GradientClipHelper from paddle.distributed.fleet.meta_optimizers.sharding.gradient_clip_helper import GradientClipHelper
from .sharding.offload_helper import OffloadHelper
from paddle.distributed.fleet.meta_optimizers.sharding.prune import ProgramDeps from paddle.distributed.fleet.meta_optimizers.sharding.prune import ProgramDeps
from paddle.distributed.fleet.meta_optimizers.sharding.utils import * from paddle.distributed.fleet.meta_optimizers.sharding.utils import *
import logging import logging
from functools import reduce from functools import reduce
...@@ -31,6 +33,8 @@ __all__ = ["ShardingOptimizer"] ...@@ -31,6 +33,8 @@ __all__ = ["ShardingOptimizer"]
class ShardingOptimizer(MetaOptimizerBase): class ShardingOptimizer(MetaOptimizerBase):
"""Sharding Optimizer."""
def __init__(self, optimizer): def __init__(self, optimizer):
super(ShardingOptimizer, self).__init__(optimizer) super(ShardingOptimizer, self).__init__(optimizer)
self.inner_opt = optimizer self.inner_opt = optimizer
...@@ -39,6 +43,8 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -39,6 +43,8 @@ class ShardingOptimizer(MetaOptimizerBase):
"AMPOptimizer", "AMPOptimizer",
"LarsOptimizer", "LarsOptimizer",
"LambOptimizer", "LambOptimizer",
# "ModelParallelOptimizer",
"PipelineOptimizer",
] ]
self.meta_optimizers_black_list = ["GraphExecutionOptimizer", ] self.meta_optimizers_black_list = ["GraphExecutionOptimizer", ]
self._main_program = None self._main_program = None
...@@ -51,6 +57,10 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -51,6 +57,10 @@ class ShardingOptimizer(MetaOptimizerBase):
self._reduced_grads_to_param = {} self._reduced_grads_to_param = {}
self._shard = Shard() self._shard = Shard()
# use sharding as outer parallelism (e.g. inner:Megatron & outer sharding)
self._as_outer_parallelism = False
self._inner_parallelism_size = None
def _can_apply(self): def _can_apply(self):
if not self.role_maker._is_collective: if not self.role_maker._is_collective:
return False return False
...@@ -71,6 +81,7 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -71,6 +81,7 @@ class ShardingOptimizer(MetaOptimizerBase):
startup_program=None, startup_program=None,
parameter_list=None, parameter_list=None,
no_grad_set=None): no_grad_set=None):
"""Implementation of minimize."""
# TODO: (JZ-LIANG) support multiple comm in future # TODO: (JZ-LIANG) support multiple comm in future
# self._nrings = self.user_defined_strategy.nccl_comm_num # self._nrings = self.user_defined_strategy.nccl_comm_num
self._nrings_sharding = 1 self._nrings_sharding = 1
...@@ -79,20 +90,72 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -79,20 +90,72 @@ class ShardingOptimizer(MetaOptimizerBase):
"fuse_broadcast_MB"] "fuse_broadcast_MB"]
self.hybrid_dp = self.user_defined_strategy.sharding_configs[ self.hybrid_dp = self.user_defined_strategy.sharding_configs[
"hybrid_dp"] "hybrid_dp"]
self._as_outer_parallelism = self.user_defined_strategy.sharding_configs[
"as_outer_parallelism"]
self._inner_parallelism_size = int(
self.user_defined_strategy.sharding_configs["parallelism"])
self.use_pipeline = self.user_defined_strategy.sharding_configs[
"use_pipeline"]
self.acc_steps = self.user_defined_strategy.sharding_configs[
"acc_steps"]
self.schedule_mode = self.user_defined_strategy.sharding_configs[
"schedule_mode"]
self.pp_bz = self.user_defined_strategy.sharding_configs["pp_bz"]
self.pp_allreduce_in_optimize = self.user_defined_strategy.sharding_configs[
"pp_allreduce_in_optimize"]
if self.inner_opt is None: if self.inner_opt is None:
raise ValueError( raise ValueError(
"self.inner_opt of ShardingOptimizer should not be None.") "self.inner_opt of ShardingOptimizer should not be None.")
optimize_ops, params_grads = self.inner_opt.minimize( if self.use_pipeline:
loss, startup_program, parameter_list, no_grad_set) pp_optimizer = fluid.optimizer.PipelineOptimizer(self.inner_opt,
self.acc_steps)
main_program = loss.block.program
main_program._pipeline_opt = dict()
main_program._pipeline_opt['schedule_mode'] = self.schedule_mode
main_program._pipeline_opt['pp_bz'] = self.pp_bz
pp_rank = self.role_maker._worker_index() // (
self.user_defined_strategy.sharding_configs[
'sharding_group_size'] * self._inner_parallelism_size)
main_program._pipeline_opt['local_rank'] = pp_rank
main_program._pipeline_opt[
'global_rank'] = self.role_maker._worker_index()
main_program._pipeline_opt['use_sharding'] = True
main_program._pipeline_opt['ring_id'] = 20
optimize_ops, params_grads, program_list, self.pipeline_pair, self.pp_ring_map = pp_optimizer.minimize(
loss, startup_program, parameter_list, no_grad_set)
self.pipeline_nodes = len(program_list)
else:
optimize_ops, params_grads = self.inner_opt.minimize(
loss, startup_program, parameter_list, no_grad_set)
if startup_program is None: if startup_program is None:
startup_program = default_startup_program() startup_program = default_startup_program()
main_block = loss.block if self.use_pipeline:
startup_program = startup_program._pipeline_opt['startup_program']
#main_program = main_program._pipeline_opt['section_program']['program']
print("pp_rank:", pp_rank)
main_program = program_list[pp_rank]['program']
with open("main_%d" % self.role_maker._worker_index(), 'w') as f:
f.writelines(str(main_program))
main_block = main_program.global_block()
new_params_grads = []
for param, grad in params_grads:
if main_block.has_var(param.name):
new_params_grads.append((param, grad))
params_grads = new_params_grads
else:
main_block = loss.block
startup_block = startup_program.global_block() startup_block = startup_program.global_block()
self._main_program = main_block.program self._main_program = main_block.program
self._startup_program = startup_program self._startup_program = startup_program
if self.use_pipeline:
pp_optimizer._rename_gradient_var_name(main_block)
with open("main_%d" % self.role_maker._worker_index(), 'w') as f:
f.writelines(str(main_program))
# step1: set_up # step1: set_up
self._set_up(params_grads) self._set_up(params_grads)
...@@ -105,17 +168,76 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -105,17 +168,76 @@ class ShardingOptimizer(MetaOptimizerBase):
startup_block._sync_with_cpp() startup_block._sync_with_cpp()
# step4: insert reduce_sum for grad # step4: insert reduce_sum for grad
insert_scale_loss_grad_ops( # grad_scale_coeff = self.role_maker._worker_num()
main_block, scale=1.0 / self.role_maker._worker_num()) # if self._as_outer_parallelism:
# grad_scale_coeff = grad_scale_coeff / self._inner_parallelism_size
# insert_scale_loss_grad_ops(main_block, scale=1.0 / grad_scale_coeff)
sharding_group_size = self.user_defined_strategy.sharding_configs[
'sharding_group_size']
insert_scale_loss_grad_ops(main_block, scale=1.0 / sharding_group_size)
main_block._sync_with_cpp() main_block._sync_with_cpp()
# step5: remove unneeded ops and vars from block # step5: remove unneeded ops and vars from block
self._prune_main_program(main_block) self._prune_main_program(main_block)
self._prune_startup_program(startup_block) self._prune_startup_program(startup_block)
if self.hybrid_dp:
self._initialization_broadcast(startup_program)
if self.use_pipeline:
# pp_optimizer._rename_gradient_var_name(main_block)
# crop ops
for idx, op in reversed(list(enumerate(main_block.ops))):
if is_update_op(op):
op_role_var = op.attr('op_role_var')
param_name = op_role_var[0]
if not self._shard.has_param(param_name):
main_block._remove_op(idx)
for idx, op in reversed(list(enumerate(main_block.ops))):
if op.type != 'cast': continue
in_name = op.input_arg_names[0]
if in_name not in self._params: continue
#if self._shard.has_param(param_name): continue
if in_name not in main_block.vars:
main_block._remove_op(idx)
accumulated_grad_names = pp_optimizer._accumulate_gradients(
main_block)
# accumulated_grad_names = sorted(accumulated_grad_names)
if self.pp_allreduce_in_optimize:
print("persistable FP32 grad: ")
print(accumulated_grad_names)
first_optimize_op_index = get_first_check_finite_and_unscale_op_idx(
main_block)
insert_reduce_ops(
main_block,
first_optimize_op_index,
self.sharding_ring_id,
accumulated_grad_names,
self._shard,
core.op_proto_and_checker_maker.OpRole.Optimize,
use_calc_stream=True)
main_block._sync_with_cpp()
# TODO(wangxi): add optimize offload
if self.optimize_offload:
logging.info("Sharding with optimize offload !")
offload_helper = OffloadHelper()
offload_helper.offload(main_block, startup_block)
offload_helper.offload_fp32param(main_block, startup_block)
with open("start_sharding_%d" % self.role_maker._worker_index(),
'w') as f:
f.writelines(str(startup_block.program))
with open("main_sharding_%d" % self.role_maker._worker_index(),
'w') as f:
f.writelines(str(main_block.program))
# check op dependency # check op dependency
check_broadcast(main_block) check_broadcast(main_block)
check_allreduce_sum(main_block, self._shard, self.dp_ring_id) #check_allreduce_sum(main_block, self._shard, self.sharding_ring_id,
# self.dp_ring_id)
#check_allreduce_sum(main_block, self._shard, self.dp_ring_id)
self._wait() self._wait()
return optimize_ops, params_grads return optimize_ops, params_grads
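For reference, the sharding_configs keys that this minimize implementation reads, collected into one illustrative dict; the key names come from the code above, while the values (a 4-way inner parallelism x 4-way sharding x 2-stage pipeline layout over 32 ranks) are only an example and are not taken from the PR:

example_sharding_configs = {
    "sharding_group_size": 4,          # ranks that share one sharded optimizer-state copy
    "fuse_broadcast_MB": 32,           # broadcast fusion threshold
    "hybrid_dp": False,                # extra data-parallel group on top of sharding
    "as_outer_parallelism": True,      # sharding wraps an inner parallelism (e.g. Megatron)
    "parallelism": 4,                  # size of that inner parallelism
    "use_pipeline": True,              # enable the pipeline integration added here
    "acc_steps": 8,                    # micro-batches accumulated per optimizer step
    "schedule_mode": 1,                # 0 is GPipe, otherwise pairwise p2p scheduling (see _init_comm)
    "pp_bz": 2,                        # pipeline batch-size knob forwarded via _pipeline_opt
    "pp_allreduce_in_optimize": True,  # move the pipeline grad reduce into the optimize stage
}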
...@@ -129,16 +251,72 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -129,16 +251,72 @@ class ShardingOptimizer(MetaOptimizerBase):
self._nrings_sharding) self._nrings_sharding)
# config sharding & dp groups # config sharding & dp groups
self._init_comm() self._init_comm()
# global
if self._as_outer_parallelism:
print("global_group_endpoints:", self.global_group_endpoints)
print("global_rank:", self.global_rank)
print("global_ring_id:", self.global_group_id)
self._collective_helper._init_communicator(
self._startup_program, self.current_endpoint,
self.global_group_endpoints, self.global_rank,
self.global_group_id, False)
if self._as_outer_parallelism:
print("mp_group_endpoints:", self.mp_group_endpoints)
print("mp_rank:", self.mp_rank)
print("mp_ring_id:", self.mp_group_id)
self._collective_helper._init_communicator(
self._startup_program, self.current_endpoint,
self.mp_group_endpoints, self.mp_rank, self.mp_group_id, False)
# sharding # sharding
print("sharding_group_endpoints:", self.sharding_group_endpoints)
print("sharding_rank:", self.sharding_rank)
print("sharding_ring_id:", self.sharding_ring_id)
self._collective_helper._init_communicator( self._collective_helper._init_communicator(
self._startup_program, self.current_endpoint, self._startup_program, self.current_endpoint,
self.sharding_group_endpoints, self.sharding_rank, self.sharding_group_endpoints, self.sharding_rank,
self.sharding_ring_id, True) self.sharding_ring_id, False)
# dp # dp
if self.hybrid_dp: if self.hybrid_dp:
self._collective_helper._init_communicator( self._collective_helper._init_communicator(
self._startup_program, self.current_endpoint, self._startup_program, self.current_endpoint,
self.dp_group_endpoints, self.dp_rank, self.dp_ring_id, True) self.dp_group_endpoints, self.dp_rank, self.dp_ring_id, False)
# pp
if self.use_pipeline:
print("pp_group_endpoints:", self.pp_group_endpoints)
print("pp_rank:", self.pp_rank)
print("pp_ring_id:", self.pp_ring_id)
if self.schedule_mode == 0: # GPipe
self._collective_helper._init_communicator(
self._startup_program, self.current_endpoint,
self.pp_group_endpoints, self.pp_rank, self.pp_ring_id,
False)
self._collective_helper._init_communicator(
self._startup_program, self.current_endpoint,
self.pp_group_endpoints, self.pp_rank, self.pp_ring_id + 2,
False)
else:
for pair in self.pipeline_pair:
pair_key = pair[0] * 1000 + pair[1]
ring_id = self.pp_ring_map[pair_key]
print("pp pair:{}, ring_id: {}".format(pair, ring_id))
if self.pp_rank not in pair: continue
pp_group_endpoints = [
self.pp_group_endpoints[pair[0]],
self.pp_group_endpoints[pair[1]],
]
if pair[0] < pair[1]:
start_ring_id = self.pp_ring_id + pair[1] - pair[0] - 1
else:
start_ring_id = self.pp_ring_id + 2 + pair[0] - pair[
1] - 1
pp_rank = 0 if self.pp_rank == pair[0] else 1
self._collective_helper._init_communicator(
self._startup_program, self.current_endpoint,
pp_group_endpoints, pp_rank, ring_id, False, False)
startup_block = self._startup_program.global_block() startup_block = self._startup_program.global_block()
startup_block._sync_with_cpp() startup_block._sync_with_cpp()
...@@ -153,10 +331,35 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -153,10 +331,35 @@ class ShardingOptimizer(MetaOptimizerBase):
self._main_program.global_block()) self._main_program.global_block())
def _wait(self, ): def _wait(self, ):
endpoints = self.role_maker._get_trainer_endpoints() # only the first parallelism group that inits nccl needs to wait.
current_endpoint = endpoints[self.role_maker._worker_index()] if self._as_outer_parallelism:
if self.role_maker._worker_index() == 0: endpoints = self.role_maker._get_trainer_endpoints()
self._collective_helper._wait(current_endpoint, endpoints) current_endpoint = endpoints[self.role_maker._worker_index()]
else:
endpoints = self.sharding_group_endpoints[:]
current_endpoint = self.sharding_group_endpoints[self.sharding_rank]
if self._as_outer_parallelism:
if self.role_maker._worker_index() == 0:
self._collective_helper._wait(current_endpoint, endpoints)
else:
if self.sharding_rank == 0:
self._collective_helper._wait(current_endpoint, endpoints)
# def _wait(self, ):
# # only the first parallelsm group that init nccl need to be wait.
# if self._as_outer_parallelism:
# endpoints = self.role_maker._get_trainer_endpoints()
# else:
# endpoints = self.sharding_group_endpoints[:]
# current_endpoint = endpoints[self.role_maker._worker_index()]
# if self._as_outer_parallelism:
# if self.role_maker._worker_index() == 0:
# self._collective_helper._wait(current_endpoint, endpoints)
# else:
# if self.sharding_rank == 0:
# self._collective_helper._wait(current_endpoint, endpoints)
def _split_program(self, block): def _split_program(self, block):
for op_idx, op in reversed(list(enumerate(block.ops))): for op_idx, op in reversed(list(enumerate(block.ops))):
...@@ -197,17 +400,22 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -197,17 +400,22 @@ class ShardingOptimizer(MetaOptimizerBase):
self._main_program.global_block().var(input_name)) self._main_program.global_block().var(input_name))
# find reduce vars # find reduce vars
if is_backward_op(op) and \ if self.use_pipeline and self.pp_allreduce_in_optimize:
OP_ROLE_VAR_KEY in op.attr_names: # place pipeline gradient allreduce in optimize
op_role_var = op.all_attrs()[OP_ROLE_VAR_KEY] pass
if len(op_role_var) != 0: else:
assert len(op_role_var) % 2 == 0 if is_backward_op(op) and \
for i in range(0, len(op_role_var), 2): OP_ROLE_VAR_KEY in op.attr_names:
param, reduced_grad = op_role_var[i], op_role_var[i + 1] op_role_var = op.all_attrs()[OP_ROLE_VAR_KEY]
segment._allreduce_vars.append(reduced_grad) if len(op_role_var) != 0:
assert ( assert len(op_role_var) % 2 == 0
reduced_grad not in self._reduced_grads_to_param) for i in range(0, len(op_role_var), 2):
self._reduced_grads_to_param[reduced_grad] = param param, reduced_grad = op_role_var[i], op_role_var[
i + 1]
segment._allreduce_vars.append(reduced_grad)
#assert (
# reduced_grad not in self._reduced_grads_to_param)
self._reduced_grads_to_param[reduced_grad] = param
# find cast op # find cast op
if FP16Utils.is_fp16_cast_op(block, op, self._params): if FP16Utils.is_fp16_cast_op(block, op, self._params):
...@@ -234,9 +442,14 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -234,9 +442,14 @@ class ShardingOptimizer(MetaOptimizerBase):
""" """
weightdecay_helper = WeightDecayHelper() weightdecay_helper = WeightDecayHelper()
weightdecay_helper.prune_weight_decay(block, self._shard) weightdecay_helper.prune_weight_decay(block, self._shard)
# NOTE (JZ-LIANG) the sync of FoundInfinite should be done within one entire Model Parallelism
# group, and each Data Parallelism group should have its own sync of FoundInfinite
model_parallelism_ring_id = self.sharding_ring_id
if self._as_outer_parallelism:
model_parallelism_ring_id = self.global_group_id
FP16Utils.prune_fp16(block, self._shard, self._reduced_grads_to_param, FP16Utils.prune_fp16(block, self._shard, self._reduced_grads_to_param,
self.sharding_ring_id) model_parallelism_ring_id)
gradientclip_helper = GradientClipHelper(self.sharding_ring_id) gradientclip_helper = GradientClipHelper(model_parallelism_ring_id)
gradientclip_helper.prune_gradient_clip(block, self._shard) gradientclip_helper.prune_gradient_clip(block, self._shard)
# build prog deps # build prog deps
...@@ -264,8 +477,13 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -264,8 +477,13 @@ class ShardingOptimizer(MetaOptimizerBase):
# Prune # Prune
for idx, op in reversed(list(enumerate(block.ops))): for idx, op in reversed(list(enumerate(block.ops))):
if op.type in [ if op.type in [
"c_allreduce_sum", "c_sync_comm_stream", "c_allreduce_sum",
"c_calc_comm_stream", "c_gen_nccl_id", "c_comm_init, c_comm_init_hcom" "c_sync_comm_stream",
"c_calc_comm_stream",
"c_gen_nccl_id",
"c_comm_init",
'send_v2',
'recv_v2',
]: ]:
pass pass
elif op.type == "conditional_block": elif op.type == "conditional_block":
...@@ -303,15 +521,41 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -303,15 +521,41 @@ class ShardingOptimizer(MetaOptimizerBase):
program_deps.remove_op(idx) program_deps.remove_op(idx)
block._sync_with_cpp() block._sync_with_cpp()
for idx, op in reversed(list(enumerate(block.ops))):
if op.type == 'concat' and is_optimizer_op(op):
# remove inputs that are not on this card
reserved_x = []
for var_name in op.desc.input("X"):
if block.has_var(var_name): reserved_x.append(var_name)
op.desc.set_input('X', reserved_x)
block._sync_with_cpp()
return return
def _add_broadcast_allreduce(self, block): def _add_broadcast_allreduce(self, block):
""" """
_add_broadcast_allreduce _add_broadcast_allreduce
if combined with pipeline (gradient accumulation),
the grad allreduce should be done in the optimize role
""" """
if len(self._segments) < 1: if len(self._segments) < 1:
return return
# sharding # sharding
if self.use_pipeline and self.pp_allreduce_in_optimize:
for idx in range(len(self._segments)):
assert len(self._segments[idx]._allreduce_vars) == 0
# fix the _end_idx for segments[-1] if pp is used.
new_end_idx = self._segments[-1]._end_idx
for idx in range(self._segments[-1]._end_idx - 1,
self._segments[-1]._start_idx - 1, -1):
op = block.ops[idx]
if op.type == "fill_constant" or op.type == "sum":
if "MERGED" in op.output_arg_names[0]: new_end_idx = idx + 1
elif op.type == "cast":
if "@TMP" in op.output_arg_names[0]: new_end_idx = idx + 1
self._segments[-1]._end_idx = new_end_idx
if self._segments[-1]._allreduce_vars: if self._segments[-1]._allreduce_vars:
shard_allredue_vars = self._shard.filter_grads(self._segments[-1] shard_allredue_vars = self._shard.filter_grads(self._segments[-1]
._allreduce_vars) ._allreduce_vars)
...@@ -323,9 +567,15 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -323,9 +567,15 @@ class ShardingOptimizer(MetaOptimizerBase):
insert_sync_comm_ops(block, self._segments[-1]._end_idx, insert_sync_comm_ops(block, self._segments[-1]._end_idx,
self.sharding_ring_id, self.sharding_ring_id,
self._segments[-1]._allreduce_vars) self._segments[-1]._allreduce_vars)
insert_allreduce_ops(block, self._segments[-1]._end_idx, # allreduce --> reduce
self.sharding_ring_id, insert_reduce_ops(
self._segments[-1]._allreduce_vars) block,
self._segments[-1]._end_idx,
self.sharding_ring_id,
self._segments[-1]._allreduce_vars,
self._shard,
op_role=OpRole.Backward,
use_calc_stream=False)
for idx, segment in reversed(list(enumerate(self._segments))): for idx, segment in reversed(list(enumerate(self._segments))):
allreduce_vars = self._segments[ allreduce_vars = self._segments[
...@@ -391,6 +641,7 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -391,6 +641,7 @@ class ShardingOptimizer(MetaOptimizerBase):
fill_constant_vars) fill_constant_vars)
# step4: add `cast` ops # step4: add `cast` ops
print("cast_ops:", cast_ops)
insert_cast_ops(block, segment._end_idx, cast_ops) insert_cast_ops(block, segment._end_idx, cast_ops)
# step5: add broadcast ops # step5: add broadcast ops
...@@ -404,8 +655,15 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -404,8 +655,15 @@ class ShardingOptimizer(MetaOptimizerBase):
insert_sync_comm_ops(block, segment._start_idx, insert_sync_comm_ops(block, segment._start_idx,
self.sharding_ring_id, allreduce_vars) self.sharding_ring_id, allreduce_vars)
# sharding # sharding
insert_allreduce_ops(block, segment._start_idx, # allreduce --> reduce
self.sharding_ring_id, allreduce_vars) insert_reduce_ops(
block,
segment._start_idx,
self.sharding_ring_id,
allreduce_vars,
self._shard,
op_role=OpRole.Backward,
use_calc_stream=False)
block._sync_with_cpp() block._sync_with_cpp()
...@@ -459,6 +717,7 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -459,6 +717,7 @@ class ShardingOptimizer(MetaOptimizerBase):
def _init_comm(self): def _init_comm(self):
if self.hybrid_dp: if self.hybrid_dp:
assert self._as_outer_parallelism == False, "hybrid dp conflicts with using sharding as outer parallelism"
self.sharding_group_size = self.user_defined_strategy.sharding_configs[ self.sharding_group_size = self.user_defined_strategy.sharding_configs[
"sharding_group_size"] "sharding_group_size"]
self.sharding_ring_id = 0 self.sharding_ring_id = 0
...@@ -476,6 +735,7 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -476,6 +735,7 @@ class ShardingOptimizer(MetaOptimizerBase):
ep for idx, ep in enumerate(self.endpoints) ep for idx, ep in enumerate(self.endpoints)
if (idx % self.sharding_group_size) == self.sharding_rank if (idx % self.sharding_group_size) == self.sharding_rank
] ]
assert self.global_word_size > self.sharding_group_size, \ assert self.global_word_size > self.sharding_group_size, \
"global_word_size: {} should be larger than sharding_group_size: {}".format(self.global_word_size, self.sharding_group_size) "global_word_size: {} should be larger than sharding_group_size: {}".format(self.global_word_size, self.sharding_group_size)
assert self.global_word_size % self.sharding_group_size == 0, \ assert self.global_word_size % self.sharding_group_size == 0, \
...@@ -485,30 +745,215 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -485,30 +745,215 @@ class ShardingOptimizer(MetaOptimizerBase):
self.global_word_size, self.global_word_size,
self.sharding_group_size, self.sharding_group_size,
self.dp_group_size) self.dp_group_size)
self.pp_ring_id = -1
self.pp_rank = -1
self.pp_group_size = None
self.pp_group_endpoints = None
# sharding parallelism is the only model parallelism in the current setting
self.mp_group_id = self.sharding_ring_id
self.mp_rank = self.sharding_rank
self.mp_group_size = self.sharding_group_size
self.mp_group_endpoints = self.sharding_group_endpoints[:]
logging.info("Using Sharing&DP mode !") logging.info("Using Sharing&DP mode !")
else: else:
self.sharding_ring_id = 0 if self._as_outer_parallelism and not self.use_pipeline:
self.sharding_rank = self.global_rank self.sharding_ring_id = 1
self.sharding_group_size = self.role_maker._worker_num() assert self.global_word_size > self._inner_parallelism_size, \
self.sharding_group_endpoints = self.endpoints "global_word_size: {} should be larger than inner_parallelism_size: {}".format(self.global_word_size, self._inner_parallelism_size)
assert self.global_word_size % self._inner_parallelism_size == 0, \
"global_word_size: {} should be divisible to the inner_parallelism_size: {}".format(self.global_word_size, self._inner_parallelism_size)
self.sharding_rank = self.global_rank // self._inner_parallelism_size
self.sharding_group_size = self.role_maker._worker_num(
) // self._inner_parallelism_size
_offset = self.global_rank % self._inner_parallelism_size
self.sharding_group_endpoints = [
ep for idx, ep in enumerate(self.endpoints)
if idx % self._inner_parallelism_size == _offset
]
# the current entire model parallelism group is the combination of inner & sharding parallelism
self.mp_group_id = 2
self.mp_rank = self.global_rank
self.mp_group_size = self.role_maker._worker_num()
self.mp_group_endpoints = self.endpoints[:]
logging.info("Using Sharing as Outer parallelism mode !")
# print(
# "init the nccl comm for megatron paramllelism, this should be done in Megatron Metaoptimizer"
# )
# partition_idx = self.global_rank // self._inner_parallelism_size
# magetron_endpoints = self.endpoints[
# partition_idx * self._inner_parallelism_size:partition_idx *
# self._inner_parallelism_size + self._inner_parallelism_size]
# magetron_rank = self.global_rank % self._inner_parallelism_size
# self._collective_helper._init_communicator(
# program=self._startup_program,
# current_endpoint=self.current_endpoint,
# endpoints=magetron_endpoints,
# rank=magetron_rank,
# ring_id=0,
# wait_port=True)
# logging.info("megatron group size: {}".format(
# self._inner_parallelism_size))
# logging.info("megatron rank: {}".format(magetron_rank))
# logging.info("megatron endpoints: {}".format(
# magetron_endpoints))
if self.use_pipeline:
if self._inner_parallelism_size == 1:
self.sharding_ring_id = 0
self.sharding_group_size = self.user_defined_strategy.sharding_configs[
'sharding_group_size']
self.sharding_rank = self.global_rank % self.sharding_group_size
assert self.sharding_group_size * self.pipeline_nodes * self._inner_parallelism_size == self.role_maker._worker_num(
)
self.pp_ring_id = 20
self.pp_rank = self.global_rank // (
self.sharding_group_size * self._inner_parallelism_size)
self.sharding_group_endpoints = [
ep for idx, ep in enumerate(self.endpoints)
if (idx // self.sharding_group_size) == self.pp_rank
]
self.pp_group_size = self.pipeline_nodes
self.pp_group_endpoints = [
ep for idx, ep in enumerate(self.endpoints)
if (idx % self.sharding_group_size
) == self.sharding_rank
]
else:
self.mp_group_id = 0
self.sharding_ring_id = 1
self.pp_ring_id = 20
self.mp_rank = self.global_rank % self._inner_parallelism_size
self.mp_group = self.global_rank // self._inner_parallelism_size
self.mp_group_endpoints = [
ep for idx, ep in enumerate(self.endpoints)
if idx // self._inner_parallelism_size == self.mp_group
]
print("megatron_group_endpoints:", self.mp_group_endpoints)
print("megatron_rank:", self.mp_rank)
# self.cards_per_node = 8
self.sharding_group_size = self.user_defined_strategy.sharding_configs[
'sharding_group_size']
self.sharding_rank = (
self.global_rank //
self._inner_parallelism_size) % self.sharding_group_size
self.sharding_group_id = self.global_rank // (
self._inner_parallelism_size * self.sharding_group_size)
self.megatron_rank = self.global_rank % self._inner_parallelism_size
self.sharding_group_endpoints = [
ep for idx, ep in enumerate(self.endpoints)
if (idx // (self._inner_parallelism_size *
self.sharding_group_size)
) == self.sharding_group_id and idx %
self._inner_parallelism_size == self.megatron_rank
]
print("sharding_endpoint:", self.sharding_group_endpoints)
print("sharding_rank:", self.sharding_rank)
assert self.sharding_group_size * self.pipeline_nodes * self._inner_parallelism_size == self.role_maker._worker_num(
)
self.pp_rank = self.global_rank // (
self.sharding_group_size *
self._inner_parallelism_size) % self.pipeline_nodes
offset = self.sharding_group_size * self._inner_parallelism_size
# TODO: Adjust for dp
idx_with_pp_0 = self.global_rank % (
self.sharding_group_size * self._inner_parallelism_size)
self.pp_group_endpoints = []
for i in range(self.pipeline_nodes):
self.pp_group_endpoints.append(self.endpoints[
idx_with_pp_0])
idx_with_pp_0 += offset
print("pp_group_endpoints:", self.pp_group_endpoints)
print("pp_rank:", self.pp_rank)
#self.pp_group_endpoints = [
# ep for idx, ep in enumerate(self.endpoints)
# if (idx % self.sharding_group_size) == self.sharding_rank
#]
self.global_group_id = 3
self.global_rank = self.global_rank
self.global_group_size = self.role_maker._worker_num()
self.global_group_endpoints = self.endpoints[:]
logging.info("Using Sharing as Outer parallelism mode !")
self.dp_ring_id = -1
self.dp_rank = -1
self.dp_group_size = None
self.dp_group_endpoints = None
logging.info("Using Sharing with pipeline !")
#else:
# self.sharding_ring_id = 0
# self.sharding_rank = self.global_rank
# self.sharding_group_size = self.role_maker._worker_num()
# self.sharding_group_endpoints = self.endpoints
# # sharding parallelism is the only model parallelism in the current setting
# self.mp_group_id = self.sharding_ring_id
# self.mp_rank = self.sharding_rank
# self.mp_group_size = self.sharding_group_size
# self.mp_group_endpoints = self.sharding_group_endpoints[:]
# logging.info("Using Sharing alone mode !")
self.dp_ring_id = -1 self.dp_ring_id = -1
self.dp_rank = -1 self.dp_rank = -1
self.dp_group_size = None self.dp_group_size = None
self.dp_group_endpoints = None self.dp_group_endpoints = None
#self.pp_ring_id = -1
#self.pp_rank = -1
#self.pp_group_size = None
#self.pp_group_endpoints = None
#self.dp_ring_id = -1
#self.dp_rank = -1
#self.dp_group_size = None
#self.dp_group_endpoints = None
logging.info("Using Sharing alone mode !") logging.info("Using Sharing alone mode !")
logging.info("global word size: {}".format(self.global_word_size)) #logging.info("global word size: {}".format(self.global_word_size))
logging.info("global rank: {}".format(self.global_rank)) #logging.info("global rank: {}".format(self.global_rank))
logging.info("sharding group_size: {}".format(self.sharding_group_size)) #logging.info("sharding group_size: {}".format(self.sharding_group_size))
logging.info("sharding rank: {}".format(self.sharding_rank)) #logging.info("sharding rank: {}".format(self.sharding_rank))
logging.info("dp group size: {}".format(self.dp_group_size)) #logging.info("current model parallelism group_size: {}".format(
logging.info("dp rank: {}".format(self.dp_rank)) # self.mp_group_size))
logging.info("current endpoint: {}".format(self.current_endpoint)) #logging.info("current model parallelism rank: {}".format(self.mp_rank))
logging.info("sharding group endpoints: {}".format( #logging.info("dp group size: {}".format(self.dp_group_size))
self.sharding_group_endpoints)) #logging.info("dp rank: {}".format(self.dp_rank))
logging.info("dp group endpoints: {}".format(self.dp_group_endpoints)) #logging.info("current endpoint: {}".format(self.current_endpoint))
logging.info("global word endpoints: {}".format(self.endpoints)) #logging.info("global word endpoints: {}".format(self.endpoints))
#logging.info("sharding group endpoints: {}".format(
# self.sharding_group_endpoints))
#logging.info("current model parallelism group endpoints: {}".format(
# self.mp_group_endpoints))
#logging.info("dp group endpoints: {}".format(self.dp_group_endpoints))
return return
def _initialization_broadcast(self, startup_prog):
"""
this function ensures that the initialization across dp groups is
identical when hybrid-dp is used.
"""
block = startup_prog.global_block()
params = []
for param in block.iter_parameters():
params.append(param)
block.append_op(
type='c_broadcast',
inputs={'X': param},
outputs={'Out': param},
attrs={
'ring_id': self.dp_ring_id,
'root': 0,
OP_ROLE_KEY: OpRole.Forward
})
block.append_op(
type='c_sync_comm_stream',
inputs={'X': params},
outputs={'Out': params},
attrs={'ring_id': self.dp_ring_id,
OP_ROLE_KEY: OpRole.Forward})
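The _init_comm branches above derive every group rank from integer arithmetic on the global rank. A self-contained recap of the decomposition used in the inner-parallelism + sharding + pipeline case (group sizes below are examples; the formulas mirror the assignments above):

inner_parallelism_size = 4   # e.g. Megatron tensor-parallel degree
sharding_group_size = 2
pipeline_nodes = 2
world_size = inner_parallelism_size * sharding_group_size * pipeline_nodes

for global_rank in range(world_size):
    mp_rank = global_rank % inner_parallelism_size
    sharding_rank = (global_rank // inner_parallelism_size) % sharding_group_size
    pp_rank = (global_rank //
               (inner_parallelism_size * sharding_group_size)) % pipeline_nodes
    print(global_rank, mp_rank, sharding_rank, pp_rank)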
...@@ -115,7 +115,7 @@ class ProgramStats(object): ...@@ -115,7 +115,7 @@ class ProgramStats(object):
updated_min_idx = min_idx updated_min_idx = min_idx
while idx_ > pre_segment_end_idx: while idx_ > pre_segment_end_idx:
if is_amp_cast(self.ops[idx_]): if is_amp_cast(self.ops[idx_]):
_logger.debug("found amp-cast op: {}, : {}".format(self.ops[ _logger.info("found amp-cast op: {}, : {}".format(self.ops[
idx_].desc.type(), self.ops[idx_].desc.input_arg_names()[ idx_].desc.type(), self.ops[idx_].desc.input_arg_names()[
0])) 0]))
updated_min_idx = idx_ updated_min_idx = idx_
...@@ -155,7 +155,7 @@ class ProgramStats(object): ...@@ -155,7 +155,7 @@ class ProgramStats(object):
sorted_checkpoints = [] sorted_checkpoints = []
for name in checkpoints_name: for name in checkpoints_name:
if name not in self.var_op_deps: if name not in self.var_op_deps:
_logger.debug( _logger.info(
"Recompute Optimizer: deleted %s from checkpoints, because it is not used in paddle program." "Recompute Optimizer: deleted %s from checkpoints, because it is not used in paddle program."
% name) % name)
elif self.var_op_deps[name]["var_as_output_ops"] == []: elif self.var_op_deps[name]["var_as_output_ops"] == []:
...@@ -233,6 +233,8 @@ def _add_needed_descs_to_block(descs, block, main_block, in_memory_vars): ...@@ -233,6 +233,8 @@ def _add_needed_descs_to_block(descs, block, main_block, in_memory_vars):
new_op_desc = block.desc.append_op() new_op_desc = block.desc.append_op()
new_op_desc.copy_from(desc) new_op_desc.copy_from(desc)
new_op_desc._set_attr(op_role_attr_name, backward) new_op_desc._set_attr(op_role_attr_name, backward)
if desc.has_attr('op_device'):
new_op_desc._set_attr('op_device', desc.attr('op_device'))
result_descs.append(new_op_desc) result_descs.append(new_op_desc)
return result_descs return result_descs
...@@ -252,6 +254,8 @@ def _add_descs_to_block(descs, block): ...@@ -252,6 +254,8 @@ def _add_descs_to_block(descs, block):
new_op_desc = block.desc.append_op() new_op_desc = block.desc.append_op()
new_op_desc.copy_from(desc) new_op_desc.copy_from(desc)
new_op_desc._set_attr(op_role_attr_name, backward) new_op_desc._set_attr(op_role_attr_name, backward)
if desc.has_attr('op_device'):
new_op_desc._set_attr('op_device', desc.attr('op_device'))
result_descs.append(new_op_desc) result_descs.append(new_op_desc)
return result_descs return result_descs
...@@ -784,7 +788,6 @@ def _append_backward_ops_with_checkpoints_( ...@@ -784,7 +788,6 @@ def _append_backward_ops_with_checkpoints_(
start_idx = 0 start_idx = 0
pre_segment_end_idx = -1 pre_segment_end_idx = -1
while True: while True:
_logger.debug("FW op range[0] - [{}]".format(len(ops)))
if start_idx >= len(checkpoints_name) - 1: if start_idx >= len(checkpoints_name) - 1:
break break
# min_idx: checkpoint_1' s input op # min_idx: checkpoint_1' s input op
...@@ -797,6 +800,9 @@ def _append_backward_ops_with_checkpoints_( ...@@ -797,6 +800,9 @@ def _append_backward_ops_with_checkpoints_(
min_idx = program_stat._update_segment_start( min_idx = program_stat._update_segment_start(
min_idx, pre_segment_end_idx) min_idx, pre_segment_end_idx)
segments.append([min_idx, max_idx + 1]) segments.append([min_idx, max_idx + 1])
else:
_logger.info("Could not recompute op range [{}] - [{}] ".format(
min_idx, max_idx + 1))
start_idx += 1 start_idx += 1
...@@ -806,15 +812,15 @@ def _append_backward_ops_with_checkpoints_( ...@@ -806,15 +812,15 @@ def _append_backward_ops_with_checkpoints_(
recompute_segments = segments recompute_segments = segments
for i, (idx1, idx2) in enumerate(recompute_segments): for i, (idx1, idx2) in enumerate(recompute_segments):
_logger.debug("recompute segment[{}]".format(i)) _logger.info("recompute segment[{}]".format(i))
_logger.debug("segment start op: [{}]: [{}]".format(ops[idx1].desc.type( _logger.info("segment start op: [{}]: [{}]".format(ops[idx1].desc.type(
), ops[idx1].desc.input_arg_names())) ), ops[idx1].desc.input_arg_names()))
_logger.debug("segment end op: [{}]: [{}]".format(ops[ _logger.info("segment end op: [{}]: [{}]".format(ops[
idx2 - 1].desc.type(), ops[idx2 - 1].desc.input_arg_names())) idx2 - 1].desc.type(), ops[idx2 - 1].desc.input_arg_names()))
_logger.debug("recompute segment[{}]".format(i)) _logger.info("recompute segment[{}]".format(i))
_logger.debug("segment start op: [{}]: [{}]".format(ops[idx1].desc.type( _logger.info("segment start op: [{}]: [{}]".format(ops[idx1].desc.type(
), ops[idx1].desc.input_arg_names())) ), ops[idx1].desc.input_arg_names()))
_logger.debug("segment end op: [{}]: [{}]".format(ops[ _logger.info("segment end op: [{}]: [{}]".format(ops[
idx2 - 1].desc.type(), ops[idx2 - 1].desc.input_arg_names())) idx2 - 1].desc.type(), ops[idx2 - 1].desc.input_arg_names()))
# 2) go through all forward ops and induct all variables that will be hold in memory # 2) go through all forward ops and induct all variables that will be hold in memory
...@@ -825,9 +831,9 @@ def _append_backward_ops_with_checkpoints_( ...@@ -825,9 +831,9 @@ def _append_backward_ops_with_checkpoints_(
program_stat.get_out_of_subgraph_vars(segment[0], segment[1])) program_stat.get_out_of_subgraph_vars(segment[0], segment[1]))
cross_vars = set(vars_should_be_hold) - set(checkpoints_name) cross_vars = set(vars_should_be_hold) - set(checkpoints_name)
_logger.debug("found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( \ _logger.info("found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( \
len(cross_vars), cross_vars)) len(cross_vars), cross_vars))
_logger.debug("found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( \ _logger.info("found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( \
len(cross_vars), cross_vars)) len(cross_vars), cross_vars))
# b. output of seed op should be kept in memory # b. output of seed op should be kept in memory
...@@ -843,6 +849,7 @@ def _append_backward_ops_with_checkpoints_( ...@@ -843,6 +849,7 @@ def _append_backward_ops_with_checkpoints_(
vars_in_memory = vars_should_be_hold + checkpoints_name vars_in_memory = vars_should_be_hold + checkpoints_name
max_calculated_op_position = len(ops) max_calculated_op_position = len(ops)
device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
if recompute_segments == []: if recompute_segments == []:
gap_ops = ops[0:max_calculated_op_position] gap_ops = ops[0:max_calculated_op_position]
for op in reversed(gap_ops): for op in reversed(gap_ops):
...@@ -852,6 +859,11 @@ def _append_backward_ops_with_checkpoints_( ...@@ -852,6 +859,11 @@ def _append_backward_ops_with_checkpoints_(
_pretty_op_desc_(op.desc, "with_sub_block")) _pretty_op_desc_(op.desc, "with_sub_block"))
grad_op_desc, op_grad_to_var = core.get_grad_op_desc( grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
op.desc, cpt.to_text(no_grad_dict[block.idx]), []) op.desc, cpt.to_text(no_grad_dict[block.idx]), [])
# Set device for grad_op according to forward Op
if op.desc.has_attr(device_attr_name):
op_device = op.desc.attr(device_attr_name)
for op_desc in grad_op_desc:
op_desc._set_attr(device_attr_name, op_device)
added_descs = _add_descs_to_block(grad_op_desc, local_block) added_descs = _add_descs_to_block(grad_op_desc, local_block)
grad_op_descs.extend(added_descs) grad_op_descs.extend(added_descs)
grad_to_var.update(op_grad_to_var) grad_to_var.update(op_grad_to_var)
...@@ -866,6 +878,11 @@ def _append_backward_ops_with_checkpoints_( ...@@ -866,6 +878,11 @@ def _append_backward_ops_with_checkpoints_(
_pretty_op_desc_(op.desc, "with_sub_block")) _pretty_op_desc_(op.desc, "with_sub_block"))
grad_op_desc, op_grad_to_var = core.get_grad_op_desc( grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
op.desc, cpt.to_text(no_grad_dict[block.idx]), []) op.desc, cpt.to_text(no_grad_dict[block.idx]), [])
# Set device for grad_op according to forward Op
if op.desc.has_attr(device_attr_name):
op_device = op.desc.attr(device_attr_name)
for op_desc in grad_op_desc:
op_desc._set_attr(device_attr_name, op_device)
added_descs = _add_descs_to_block(grad_op_desc, local_block) added_descs = _add_descs_to_block(grad_op_desc, local_block)
grad_op_descs.extend(added_descs) grad_op_descs.extend(added_descs)
grad_to_var.update(op_grad_to_var) grad_to_var.update(op_grad_to_var)
...@@ -888,6 +905,18 @@ def _append_backward_ops_with_checkpoints_( ...@@ -888,6 +905,18 @@ def _append_backward_ops_with_checkpoints_(
continue continue
if name not in var_name_dict: if name not in var_name_dict:
var_name_dict[name] = name + var_suffix var_name_dict[name] = name + var_suffix
# we should create the renamed var in the subprogram, otherwise its VarType will be BOOL
block.create_var(
name=var_name_dict[name],
shape=block.program.global_block().var(name).shape,
dtype=block.program.global_block().var(name).dtype,
type=block.program.global_block().var(name).type,
persistable=block.program.global_block().var(
name).persistable,
stop_gradient=block.program.global_block().var(name)
.stop_gradient)
# 3.a. add ops in current recompute_segment as forward recomputation ops # 3.a. add ops in current recompute_segment as forward recomputation ops
buffer_descs = _add_needed_descs_to_block(ff_ops, buffer_block, block, buffer_descs = _add_needed_descs_to_block(ff_ops, buffer_block, block,
vars_in_memory) vars_in_memory)
......
...@@ -489,9 +489,14 @@ class ClipGradByGlobalNorm(ClipGradBase): ...@@ -489,9 +489,14 @@ class ClipGradByGlobalNorm(ClipGradBase):
continue continue
with p.block.program._optimized_guard([p, g]): with p.block.program._optimized_guard([p, g]):
new_grad = layers.elementwise_mul(x=g, y=scale_var) p.block.append_op(
param_new_grad_name_dict[p.name] = new_grad.name type='elementwise_mul',
params_and_grads.append((p, new_grad)) inputs={'X': g,
'Y': scale_var},
outputs={'Out': g})
                param_new_grad_name_dict[p.name] = g.name
                params_and_grads.append((p, g))
_correct_clip_op_role_var(params_and_grads, param_new_grad_name_dict) _correct_clip_op_role_var(params_and_grads, param_new_grad_name_dict)
return params_and_grads return params_and_grads
......
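The clip change above scales each gradient in place (the elementwise_mul writes its result back into g) instead of creating a renamed output var, so later pipeline passes keep seeing the original gradient names. Numerically it is still plain global-norm clipping; a small self-contained sketch with Python floats:

import math

def clip_by_global_norm_inplace(grads, clip_norm):
    # global_norm = sqrt(sum_i ||g_i||^2); scale = clip_norm / max(global_norm, clip_norm)
    global_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    scale = clip_norm / max(global_norm, clip_norm)
    for grad in grads:
        for i in range(len(grad)):   # write back into the same buffer,
            grad[i] *= scale         # mirroring the in-place elementwise_mul
    return grads

grads = [[3.0, 4.0], [0.0, 12.0]]    # global norm = 13
clip_by_global_norm_inplace(grads, clip_norm=1.0)
assert abs(math.sqrt(sum(g * g for grad in grads for g in grad)) - 1.0) < 1e-6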
...@@ -123,7 +123,8 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype): ...@@ -123,7 +123,8 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
outputs={"Out": out_var}, outputs={"Out": out_var},
attrs={ attrs={
"in_dtype": in_var.dtype, "in_dtype": in_var.dtype,
"out_dtype": out_var.dtype "out_dtype": out_var.dtype,
"op_device": op.attr("op_device")
}) })
num_cast_ops += 1 num_cast_ops += 1
_rename_arg(op, in_var.name, out_var.name) _rename_arg(op, in_var.name, out_var.name)
...@@ -171,8 +172,11 @@ def _insert_cast_post_op(block, op, idx, src_dtype, dest_dtype, target_name, ...@@ -171,8 +172,11 @@ def _insert_cast_post_op(block, op, idx, src_dtype, dest_dtype, target_name,
type="cast", type="cast",
inputs={"X": target_var}, inputs={"X": target_var},
outputs={"Out": cast_var}, outputs={"Out": cast_var},
attrs={"in_dtype": target_var.dtype, attrs={
"out_dtype": cast_var.dtype}) "in_dtype": target_var.dtype,
"out_dtype": cast_var.dtype,
"op_device": op.attr("op_device")
})
num_cast_ops += 1 num_cast_ops += 1
op_var_rename_map[block.idx][target_var.name] = cast_var.name op_var_rename_map[block.idx][target_var.name] = cast_var.name
......
...@@ -413,6 +413,9 @@ class Section(DeviceWorker): ...@@ -413,6 +413,9 @@ class Section(DeviceWorker):
section_param = trainer_desc.section_param section_param = trainer_desc.section_param
section_param.num_microbatches = pipeline_opt["num_microbatches"] section_param.num_microbatches = pipeline_opt["num_microbatches"]
section_param.start_cpu_core_id = pipeline_opt["start_cpu_core_id"] section_param.start_cpu_core_id = pipeline_opt["start_cpu_core_id"]
section_param.pipeline_stage = pipeline_opt["pipeline_stage"]
section_param.num_pipeline_stages = pipeline_opt["num_pipeline_stages"]
section_param.schedule_mode = pipeline_opt["schedule_mode"]
cfg = section_param.section_config cfg = section_param.section_config
program = pipeline_opt["section_program"] program = pipeline_opt["section_program"]
cfg.program_desc.ParseFromString(program["program"]._get_desc() cfg.program_desc.ParseFromString(program["program"]._get_desc()
......
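The Section device worker now also receives the stage count, its own stage index, and the schedule mode through pipeline_opt. A hypothetical dict with the keys read above (field names taken from the diff, values made up for illustration):

# Hypothetical pipeline_opt contents matching the keys read above (values illustrative only).
pipeline_opt = {
    "num_microbatches": 4,        # micro-batches per mini-batch
    "start_cpu_core_id": 0,       # first CPU core pinned by the trainer
    "pipeline_stage": 1,          # which stage this rank runs
    "num_pipeline_stages": 4,     # total stages in the pipeline
    "schedule_mode": 1,           # 0: GPipe-style order, 1: the alternative schedule
    "section_program": None,      # per-stage program, filled in by the optimizer
}
assert 0 <= pipeline_opt["pipeline_stage"] < pipeline_opt["num_pipeline_stages"]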
...@@ -19,6 +19,7 @@ import six ...@@ -19,6 +19,7 @@ import six
import os import os
import logging import logging
from collections import defaultdict from collections import defaultdict
import time
import paddle import paddle
from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table
...@@ -3759,15 +3760,21 @@ class PipelineOptimizer(object): ...@@ -3759,15 +3760,21 @@ class PipelineOptimizer(object):
def __init__(self, optimizer, num_microbatches=1, start_cpu_core_id=0): def __init__(self, optimizer, num_microbatches=1, start_cpu_core_id=0):
if framework.in_dygraph_mode(): if framework.in_dygraph_mode():
raise Exception("In dygraph, don't support PipelineOptimizer.") raise Exception("In dygraph, don't support PipelineOptimizer.")
if not isinstance(optimizer, Optimizer) and not isinstance( supported_opt_types = (Optimizer, paddle.fluid.contrib.mixed_precision.
optimizer, paddle.optimizer.Optimizer) and not isinstance( decorator.OptimizerWithMixedPrecision)
optimizer, paddle.fluid.contrib.mixed_precision.decorator. if not isinstance(optimizer, supported_opt_types):
OptimizerWithMixedPrecision):
raise ValueError("The 'optimizer' parameter for " raise ValueError("The 'optimizer' parameter for "
"PipelineOptimizer must be an instance of " "PipelineOptimizer must be an instance of one of "
"Optimizer, but the given type is {}.".format( "{}, but the type is {}.".format(
type(optimizer))) supported_opt_types, type(optimizer)))
self._optimizer = optimizer self._optimizer = optimizer
# Get the original optimizer defined by users, such as SGD
self._origin_optimizer = self._optimizer
while hasattr(self._origin_optimizer, "inner_opt"):
self._origin_optimizer = self._origin_optimizer.inner_opt
assert num_microbatches >= 1, ( assert num_microbatches >= 1, (
"num_microbatches must be a positive value.") "num_microbatches must be a positive value.")
self._num_microbatches = num_microbatches self._num_microbatches = num_microbatches
...@@ -3781,52 +3788,147 @@ class PipelineOptimizer(object): ...@@ -3781,52 +3788,147 @@ class PipelineOptimizer(object):
self._op_role_var_key = op_maker.kOpRoleVarAttrName() self._op_role_var_key = op_maker.kOpRoleVarAttrName()
self._op_device_key = op_maker.kOpDeviceAttrName() self._op_device_key = op_maker.kOpDeviceAttrName()
self._param_device_map = None self._param_device_map = None
self._pipeline_pair = []
self._pp_ring_map = dict()
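_origin_optimizer unwraps decorator optimizers (for example an AMP or sharding wrapper) by following their inner_opt attribute until the user-defined optimizer is reached; its _param_device_map is what the pipeline pass needs later. A minimal sketch of that unwrapping with stand-in classes (not real Paddle optimizers):

# Stand-in classes: only the inner_opt chaining matters here.
class SGD:
    pass

class AMPDecorator:
    def __init__(self, inner):
        self.inner_opt = inner

class ShardingWrapper:
    def __init__(self, inner):
        self.inner_opt = inner

wrapped = ShardingWrapper(AMPDecorator(SGD()))
origin = wrapped
while hasattr(origin, "inner_opt"):   # same loop as in __init__ above
    origin = origin.inner_opt
assert isinstance(origin, SGD)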
def _create_vars(self, block, ori_block): def _create_vars(self, block, ori_block):
# Create vars for block, copied from main_program's global block # Create vars for block, copied from ori_block
used_var_set = set() used_var_set = set()
for op_idx in range(block.desc.op_size()): for op_idx in range(block.desc.op_size()):
op_desc = block.desc.op(op_idx) # Whether to insert allreduce_sum or allreduce_max op?
vars = op_desc.input_arg_names() + op_desc.output_arg_names() # For amp and global gradient clip strategies, we should
            # get the global information, so an allreduce op is needed.
should_insert = False
op = block.ops[op_idx]
# For op process vars on all devices, remove its input
# vars not in this block
reserved_x = []
if op.type == 'reduce_any' and self._is_optimize_op(op):
should_insert = True
if op.type == 'concat' and self._is_optimize_op(op):
for input_name in op.desc.input("X"):
if block._find_var_recursive(input_name):
reserved_x.append(input_name)
op.desc.set_input('X', reserved_x)
print('reserved_x:', reserved_x)
if op.type == 'update_loss_scaling':
for input_name in op.desc.input("X"):
if block._find_var_recursive(input_name):
reserved_x.append(input_name)
op.desc.set_input('X', reserved_x)
op.desc.set_output('Out', reserved_x)
if op.type == 'sum' and self._is_gradient_clip_op(op):
for input_name in op.desc.input("X"):
if block._find_var_recursive(input_name):
reserved_x.append(input_name)
op.desc.set_input('X', reserved_x)
should_insert = True
vars = op.desc.input_arg_names() + op.desc.output_arg_names()
for var in vars: for var in vars:
# a var whose name contains "blocking_queue" # a var whose name contains "blocking_queue"
# only exists in startup program # only exists in startup program
if var in used_var_set or "_blocking_queue" in var: if var in used_var_set or "_blocking_queue" in var: continue
continue
used_var_set.add(var) used_var_set.add(var)
if block._find_var_recursive(str(var)): continue if block._find_var_recursive(str(var)): continue
source_var = ori_block._var_recursive(str(var)) source_var = ori_block._var_recursive(str(var))
if source_var.type == core.VarDesc.VarType.READER: if source_var.type == core.VarDesc.VarType.READER:
block.create_var( dest_var = block.create_var(
name=var, name=var,
type=core.VarDesc.VarType.READER, type=core.VarDesc.VarType.READER,
persistable=source_var.persistable) persistable=source_var.persistable)
else: else:
block._clone_variable(source_var, False) dest_var = block._clone_variable(source_var, False)
dest_var.stop_gradient = source_var.stop_gradient
continue
            # TODO: add allreduce_max when sharding is not used
if not should_insert: continue
out_name = op.desc.output_arg_names()[0]
out_var = block.var(out_name)
offset = 0
if op.type == "reduce_any":
# cast the bool var to int32 to use allreduce op
temp_var_name = unique_name.generate(out_name + "_cast_int32")
temp_var = block.create_var(
name=temp_var_name, shape=[1], dtype="int32")
block._insert_op(
op_idx + 1 + offset,
type='cast',
inputs={'X': out_var},
outputs={'Out': temp_var},
attrs={
'in_dtype': out_var.dtype,
'out_dtype': temp_var.dtype,
self._op_role_key:
core.op_proto_and_checker_maker.OpRole.Optimize
})
offset += 1
# block._insert_op(
# op_idx + 1 + offset,
# type='c_sync_calc_stream',
# inputs={'X': temp_var if op.type == "reduce_any" else out_var},
# outputs={
# 'Out': temp_var if op.type == "reduce_any" else out_var
# },
# attrs={
# OP_ROLE_KEY:
# core.op_proto_and_checker_maker.OpRole.Optimize,
# })
# offset += 1
block._insert_op(
op_idx + 1 + offset,
type='c_allreduce_max'
if op.type == "reduce_any" else 'c_allreduce_sum',
inputs={'X': temp_var if op.type == "reduce_any" else out_var},
outputs={
'Out': temp_var if op.type == "reduce_any" else out_var
},
attrs={
'ring_id': self.ring_id,
self._op_role_key:
core.op_proto_and_checker_maker.OpRole.Optimize,
'use_calc_stream': True
})
offset += 1
# block._insert_op(
# # op_idx + 1 + extra_index,
# op_idx + 1 + offset,
# type='c_sync_comm_stream',
# inputs={'X': temp_var if op.type == "reduce_any" else out_var},
# outputs={
# 'Out': temp_var if op.type == "reduce_any" else out_var
# },
# attrs={
# 'ring_id': self.ring_id,
# OP_ROLE_KEY:
# core.op_proto_and_checker_maker.OpRole.Optimize,
# })
# offset += 1
if op.type == "reduce_any":
block._insert_op(
op_idx + 1 + offset,
type='cast',
inputs={'X': temp_var},
outputs={'Out': out_var},
attrs={
'in_dtype': temp_var.dtype,
'out_dtype': out_var.dtype,
self._op_role_key:
core.op_proto_and_checker_maker.OpRole.Optimize
})
def _is_loss_grad_op(self, op): def _is_loss_grad_op(self, op):
if self._op_role_key not in op.attr_names: assert self._op_role_key in op.attr_names
return False op_role = int(op.attr(self._op_role_key))
op_role = int(op.all_attrs()[self._op_role_key])
return op_role & int(self._op_role.Backward) and op_role & int( return op_role & int(self._op_role.Backward) and op_role & int(
self._op_role.Loss) self._op_role.Loss)
def _is_backward_op(self, op):
return self._op_role_key in op.attr_names and int(op.all_attrs()[
self._op_role_key]) & int(self._op_role.Backward)
def _is_optimize_op(self, op):
return self._op_role_key in op.attr_names and int(op.all_attrs()[
self._op_role_key]) & int(self._op_role.Optimize)
def _is_update_op(self, op):
return 'Param' in op.input_names and 'Grad' in op.input_names and (
"LearningRate" in op.input_names)
def _split_program(self, main_program, devices): def _split_program(self, main_program, devices):
""" """
Split a program into sections according to devices that ops run on. Split a program into sections according to devices that ops run on.
The ops of the role LRSched are copied to all sections. The op whose op_device attr is "gpu:all" is copied to all sections.
Args: Args:
main_program (Program): the main program main_program (Program): the main program
...@@ -3842,27 +3944,20 @@ class PipelineOptimizer(object): ...@@ -3842,27 +3944,20 @@ class PipelineOptimizer(object):
block = main_program.block(0) block = main_program.block(0)
for op in block.ops: for op in block.ops:
device = op.attr(self._op_device_key) device = op.attr(self._op_device_key)
            op_role = op.attr(self._op_role_key)             # Copy ops whose op_device is set to "gpu:all" to all sections.
if int(op_role) & int(self._op_role.LRSched): if device == "gpu:all":
# Copy ops of the role LRSched to all sections.
for device in device_program_map.keys():
program = device_program_map[device]
op_desc = op.desc
ap_op = program["program"].block(0).desc.append_op()
ap_op.copy_from(op_desc)
# ap_op._set_attr(self._op_device_key, "")
elif op.type == "create_py_reader" or op.type == "read" or op.type == "create_double_buffer_reader":
# Copy read related ops to all section to make them exit after each epoch.
for device in device_program_map.keys(): for device in device_program_map.keys():
program = device_program_map[device] program = device_program_map[device]
op_desc = op.desc op_desc = op.desc
ap_op = program["program"].block(0).desc.append_op() ap_op = program["program"].block(0).desc.append_op()
ap_op.copy_from(op_desc) ap_op.copy_from(op_desc)
ap_op._set_attr(self._op_device_key, "")
else: else:
program = device_program_map[device] program = device_program_map[device]
op_desc = op.desc op_desc = op.desc
ap_op = program["program"].block(0).desc.append_op() ap_op = program["program"].block(0).desc.append_op()
ap_op.copy_from(op_desc) ap_op.copy_from(op_desc)
ap_op._set_attr(self._op_device_key, "")
for key in devices: for key in devices:
program = device_program_map[key] program = device_program_map[key]
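With this change, _split_program keys the copy-to-every-section behaviour purely on the "gpu:all" device value rather than on the LRSched role or reader op types. A self-contained sketch of the splitting rule over a toy op list:

# Ops as (type, op_device) pairs; "gpu:all" means: copy into every section.
ops = [
    ("read",          "gpu:all"),
    ("matmul",        "gpu:0"),
    ("relu",          "gpu:1"),
    ("learning_rate", "gpu:all"),
]
devices = ["gpu:0", "gpu:1"]

sections = {dev: [] for dev in devices}
for op_type, dev in ops:
    targets = devices if dev == "gpu:all" else [dev]
    for d in targets:
        sections[d].append(op_type)

assert sections["gpu:0"] == ["read", "matmul", "learning_rate"]
assert sections["gpu:1"] == ["read", "relu", "learning_rate"]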
...@@ -3921,6 +4016,11 @@ class PipelineOptimizer(object): ...@@ -3921,6 +4016,11 @@ class PipelineOptimizer(object):
var_name as output. var_name as output.
var_name (string): Variable name. var_name (string): Variable name.
""" """
# To skip the cast op added by amp which has no op_device set
if '.cast_fp32' in var_name:
var_name = var_name.replace('.cast_fp32', '')
if '.cast_fp16' in var_name:
var_name = var_name.replace('.cast_fp16', '')
post_op = [] post_op = []
before = True before = True
for op in ops: for op in ops:
...@@ -3949,7 +4049,7 @@ class PipelineOptimizer(object): ...@@ -3949,7 +4049,7 @@ class PipelineOptimizer(object):
""" """
prev_op = [] prev_op = []
for op in ops: for op in ops:
if op.type == 'send_v2' or op.type == 'recv_v2': if op.type == 'send_v2' or op.type == 'recv_v2' or op.type == 'c_broadcast':
continue continue
if op == cur_op: if op == cur_op:
break break
...@@ -3964,11 +4064,8 @@ class PipelineOptimizer(object): ...@@ -3964,11 +4064,8 @@ class PipelineOptimizer(object):
return None return None
def _rename_arg(self, op, old_name, new_name): def _rename_arg(self, op, old_name, new_name):
op_desc = op.desc op._rename_input(old_name, new_name)
if isinstance(op_desc, tuple): op._rename_output(old_name, new_name)
op_desc = op_desc[0]
op_desc._rename_input(old_name, new_name)
op_desc._rename_output(old_name, new_name)
def _create_var(self, block, ref_var, name): def _create_var(self, block, ref_var, name):
""" """
...@@ -3982,9 +4079,10 @@ class PipelineOptimizer(object): ...@@ -3982,9 +4079,10 @@ class PipelineOptimizer(object):
dtype=ref_var.dtype, dtype=ref_var.dtype,
type=ref_var.type, type=ref_var.type,
lod_level=ref_var.lod_level, lod_level=ref_var.lod_level,
persistable=False, persistable=ref_var.persistable,
is_data=False, is_data=ref_var.is_data,
need_check_feed=ref_var.desc.need_check_feed()) need_check_feed=ref_var.desc.need_check_feed())
new_var.stop_gradient = ref_var.stop_gradient
return new_var return new_var
def _get_data_var_info(self, block): def _get_data_var_info(self, block):
...@@ -4037,6 +4135,7 @@ class PipelineOptimizer(object): ...@@ -4037,6 +4135,7 @@ class PipelineOptimizer(object):
if not var_name in first_block.vars: if not var_name in first_block.vars:
self._create_var(first_block, main_var, var_name) self._create_var(first_block, main_var, var_name)
dev_index = int(device.split(':')[1]) dev_index = int(device.split(':')[1])
print("dev_index:", dev_index)
first_block._insert_op( first_block._insert_op(
index=insert_index, index=insert_index,
type='send_v2', type='send_v2',
...@@ -4044,8 +4143,11 @@ class PipelineOptimizer(object): ...@@ -4044,8 +4143,11 @@ class PipelineOptimizer(object):
attrs={ attrs={
self._op_device_key: first_dev_spec, self._op_device_key: first_dev_spec,
self._op_role_key: self._op_role.Forward, self._op_role_key: self._op_role.Forward,
'use_calc_stream': True, 'use_calc_stream': False,
'peer': dev_index, 'peer': dev_index,
#'ring_id': self.ring_id,
'ring_id': self.ring_id
if dev_index > first_dev_index else self.ring_id + 2,
}) })
            # Get the device that the data is on             # Get the device that the data is on
assert device in devices assert device in devices
...@@ -4070,6 +4172,21 @@ class PipelineOptimizer(object): ...@@ -4070,6 +4172,21 @@ class PipelineOptimizer(object):
self._op_role_key: self._op_role.Forward, self._op_role_key: self._op_role.Forward,
'peer': first_dev_index, 'peer': first_dev_index,
'use_calc_stream': True, 'use_calc_stream': True,
#'ring_id': self.ring_id,
'ring_id': self.ring_id
if first_dev_index < dev_index else self.ring_id + 2,
})
block._insert_op(
index=index + 1,
type='c_sync_comm_stream',
inputs={'X': [new_var]},
outputs={'Out': [new_var]},
attrs={
self._op_device_key: device,
self._op_role_key: self._op_role.Forward,
#'ring_id': self.ring_id,
'ring_id': self.ring_id
if first_dev_index > dev_index else self.ring_id + 2,
}) })
def _strip_grad_suffix(self, name): def _strip_grad_suffix(self, name):
...@@ -4085,79 +4202,190 @@ class PipelineOptimizer(object): ...@@ -4085,79 +4202,190 @@ class PipelineOptimizer(object):
""" """
return name + core.grad_var_suffix() return name + core.grad_var_suffix()
def _add_opdevice_attr_for_regularization_clip(self, block): def _is_forward_op(self, op):
""" """
        Add op_device attribute for regulization and clip ops.         Whether the op_role attribute of the op is Forward.
""" """
for op in block.ops: assert self._op_role_key in op.attr_names
# role for regularization and clip ops is optimize return int(op.attr(self._op_role_key)) == int(self._op_role.Forward)
if int(op.attr(self._op_role_key)) != int(self._op_role.Optimize):
continue def _is_backward_op(self, op):
if op.has_attr(self._op_device_key) and ( """
                op.attr(self._op_device_key) != ""):         Whether the op_role attribute of the op is Backward.
continue """
assert self._op_role_var_key in op.attr_names assert self._op_role_key in op.attr_names
op_role_var = op.all_attrs()[self._op_role_var_key] return int(op.attr(self._op_role_key)) == int(self._op_role.Backward)
assert len(op_role_var) == 2
param_name = op_role_var[0]
device = self._param_device_map[param_name]
op._set_attr(self._op_device_key, device)
def _add_default_opdevice_attr(self, block): def _is_loss_op(self, op):
""" """
        1. Add default op_device attribute for lr-related ops.         Whether the op_role attribute of the op is Loss.
The default value is the one that of the first place.
2. Add default op_device attribute for sum ops added during
backward. For these ops, we set the op_device attribute
as the one of its post op, i.e, which op has the output of the
sum op as an input.
""" """
first_devcie = "" assert self._op_role_key in op.attr_names
return int(op.attr(self._op_role_key)) == int(self._op_role.Loss)
# Get the device spec of the first place. def _is_optimize_op(self, op):
# device_spec: 'cpu' for cpu device and 'gpu:id' for gpu device, """
        # e.g. 'gpu:0', 'gpu:1', etc.         Whether the op_role attribute of the op is Optimize.
for op in block.ops: """
if op.has_attr(self._op_device_key) and ( assert self._op_role_key in op.attr_names
op.attr(self._op_device_key) != ""): return int(op.attr(self._op_role_key)) == int(self._op_role.Optimize)
first_device = op.attr(self._op_device_key)
break
assert first_device
first_device_type = first_device.split(":")[0]
assert first_device_type == "gpu"
# set op_device attr for lr-related ops def _is_lrsched_op(self, op):
lrsched_role = int(self._op_role.LRSched) """
        for op in block.ops:         Whether the op_role attribute of the op is LRSched.
if not op.has_attr(self._op_device_key) or ( """
op.attr(self._op_device_key) == ""): assert self._op_role_key in op.attr_names
if op.type == "sum": return int(op.attr(self._op_role_key)) == int(self._op_role.LRSched)
# For sum ops that compute the sum of @RENAMED@ vars
for name in op.desc.input_arg_names(): def _is_update_op(self, op):
assert '@RENAME@' in name """
                        assert len(op.desc.output_arg_names()) == 1         Whether the op updates the parameter using a gradient.
out_name = op.desc.output_arg_names()[0] """
post_op = self._find_post_op(block.ops, op, out_name) return 'Param' in op.input_names and 'Grad' in op.input_names and (
device = post_op.attr(self._op_device_key) "LearningRate" in op.input_names)
assert device
op._set_attr(self._op_device_key, device) def _get_op_device_attr(self, op):
continue """
Get the op_device attribute of a op.
"""
device = op.attr(self._op_device_key) \
if op.has_attr(self._op_device_key) else None
if device:
assert device[0:3] == 'gpu', "Now, only gpu devices are " \
"supported in pipeline parallemism."
return device
def _add_op_device_attr_for_op(self, op, idx, block):
"""
        Add the op_device attribute for ops that do not have that attribute set.
assert op.attr(self._op_role_key) == lrsched_role, ( We use "gpu:all" to represent the op should be put on all
"Op whose op_device attr has not been set for pipeline" sub-programs, such as lr-related ops. Note that: "gpu:all"
" must be of the role LRSched.") is only used by pipeline as an indicator.
op._set_attr(self._op_device_key, first_device) """
lrsched_role = int(self._op_role.LRSched)
if op.attr(self._op_role_key) == lrsched_role:
# For LRSched ops, we should put them on all sub-programs to
            # make sure each sub-program updates the lr correctly
op._set_attr(self._op_device_key, "gpu:all")
elif op.type == "sum" and self._is_backward_op(op):
# For sum ops that compute the sum of @RENAMED@ vars
for name in op.desc.input_arg_names():
assert '@RENAME@' in name, \
"The op must be sum used to accumulate renamed vars."
assert len(op.desc.output_arg_names()) == 1
out_name = op.desc.output_arg_names()[0]
post_op = self._find_post_op(block.ops, op, out_name)
assert post_op.has_attr(
'op_device'), "{} has no op_device attr for var {}".format(
post_op.type, out_name)
device = post_op.attr(self._op_device_key)
assert device, "The post op must have op_device set."
op._set_attr(self._op_device_key, device)
elif (op.type == "cast" or
op.type == "scale") and self._is_backward_op(op):
prev_op = self._find_real_prev_op(block.ops, op,
op.desc.input("X")[0])
op._set_attr('op_device', prev_op.attr('op_device'))
elif op.type == "memcpy" and not self._is_optimize_op(op):
assert len(op.input_arg_names) == 1 and len(
op.output_arg_names) == 1
input_name = op.input_arg_names[0]
output_name = op.output_arg_names[0]
if '@Fetch' in output_name:
post_op = self._find_post_op(block.ops, op, output_name)
op._set_attr('op_device', post_op.attr('op_device'))
else:
prev_op = self._find_real_prev_op(block.ops, op,
op.desc.input("X")[0])
op._set_attr('op_device', prev_op.attr('op_device'))
elif self._is_loss_op(op):
# For loss * loss_scaling op added by AMP
offset = 1
while (not block.ops[idx + offset].has_attr(self._op_device_key) or
not block.ops[idx + offset].attr(self._op_device_key)):
offset += 1
# assert block.ops[idx + 1].type == "fill_constant"
# assert block.ops[idx + 2].type == "elementwise_mul_grad"
# assert block.ops[idx + 3].type == "elementwise_add_grad"
# assert block.ops[idx + 4].type == "mean_grad"
# device = block.ops[idx + 4].attr(self._op_device_key)
device = block.ops[idx + offset].attr(self._op_device_key)
assert device, "Please put you program within device_guard scope."
# op._set_attr(self._op_device_key, device)
# block.ops[idx + 1]._set_attr(self._op_device_key, device)
# block.ops[idx + 2]._set_attr(self._op_device_key, device)
# block.ops[idx + 2]._set_attr(self._op_device_key, device)
for i in range(offset):
block.ops[idx + i]._set_attr(self._op_device_key, device)
elif self._is_optimize_op(op) and op.type == "check_finite_and_unscale":
#op._set_attr(self._op_device_key, "gpu:all")
op_role_var = op.attr(self._op_role_var_key)
param_name = op_role_var[0]
device = self._param_device_map[param_name]
op._set_attr(self._op_device_key, device)
elif self._is_optimize_op(op) and op.type == "cast":
# For fp16-->fp32 cast added by AMP
grad_name = op.output('Out')
assert len(grad_name) == 1
param_name = grad_name[0].strip(core.grad_var_suffix())
device = self._param_device_map[param_name]
op._set_attr(self._op_device_key, device)
elif self._is_gradient_clip_op(op) or self._is_regularization_op(op):
# For gradient clip and regularization ops, we set their op_device
# attribute to the device where their corresponding parameters on.
assert self._op_role_var_key in op.attr_names, "gradient_clip " \
"and regularization ops must have op_role_var attribute."
op_role_var = op.attr(self._op_role_var_key)
assert len(op_role_var) == 2, "op_role_var for gradient_clip " \
"regularization ops must have two elements."
param_name = op_role_var[0]
device = self._param_device_map[param_name]
# For sum op added by global gradient clip, it must be
# put on all devices
if (op.type == 'sum' or op.type == 'sqrt' or
op.type == 'fill_constant' or
op.type == 'elementwise_max' or
op.type == 'elementwise_div'):
device = "gpu:all"
op._set_attr(self._op_device_key, device)
else:
other_known_ops = [
'update_loss_scaling', 'reduce_any', 'concat', 'sum'
]
            assert op.type in other_known_ops, "Ops without " \
                "op_device set must be one of {}, but this op " \
                "is {}".format(other_known_ops, op.type)
assert self._is_optimize_op(op)
op._set_attr(self._op_device_key, "gpu:all")
def _add_op_device_attr(self, block):
"""
        Add the op_device attribute for ops in the block that do not have
        that attribute set.
"""
for idx, op in enumerate(list(block.ops)):
if (op.type == "create_py_reader" or op.type == "read" or
op.type == "create_double_buffer_reader"):
                # Copy read-related ops to all sections to make them exit
                # after each epoch.
                # We use "gpu:all" to represent that the op should be put on all
                # sub-programs, such as lr-related ops. Note that "gpu:all"
                # is only used by pipeline as an indicator.
op._set_attr(self._op_device_key, "gpu:all")
continue
# op_device attribute has been set
if self._get_op_device_attr(op): continue
self._add_op_device_attr_for_op(op, idx, block)
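In short, _add_op_device_attr_for_op places unplaced ops by a few rules: LRSched ops go to every stage ("gpu:all"), a sum over @RENAME@ vars inherits the device of the op consuming its output, and backward cast/scale ops inherit the device of the op producing their input. A compressed sketch of those three rules over a toy op list (dicts stand in for op descs; not the Paddle API):

ops = [
    {"type": "fc_grad",  "role": "bw",  "device": "gpu:1", "in": ["out@GRAD"],        "out": ["w@GRAD@RENAME@0"]},
    {"type": "scale",    "role": "bw",  "device": None,    "in": ["w@GRAD@RENAME@0"], "out": ["w@GRAD@RENAME@1"]},
    {"type": "sum",      "role": "bw",  "device": None,    "in": ["w@GRAD@RENAME@1"], "out": ["w@GRAD"]},
    {"type": "momentum", "role": "opt", "device": "gpu:1", "in": ["w@GRAD"],          "out": ["w"]},
    {"type": "lr_sched", "role": "lr",  "device": None,    "in": [],                  "out": ["lr"]},
]

def producer_of(var, upto):           # last op before index `upto` writing `var`
    return next(op for op in reversed(ops[:upto]) if var in op["out"])

def consumer_of(var, frm):            # first op after index `frm` reading `var`
    return next(op for op in ops[frm + 1:] if var in op["in"])

for idx, op in enumerate(ops):
    if op["device"]:
        continue
    if op["role"] == "lr":
        op["device"] = "gpu:all"                                  # run on every stage
    elif op["type"] == "sum":
        op["device"] = consumer_of(op["out"][0], idx)["device"]   # follow the post op
    elif op["type"] in ("cast", "scale"):
        op["device"] = producer_of(op["in"][0], idx)["device"]    # follow the prev op

assert [op["device"] for op in ops] == ["gpu:1", "gpu:1", "gpu:1", "gpu:1", "gpu:all"]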
def _check_validation(self, block): def _check_validation(self, block):
""" """
Check whether ops in a block are all validate (i.e., the Check whether ops in a block have the op_device attribute set.
op_device attribute has been set). Then, return all devices in order.
Then, return all device specifications in order.
""" """
device_specs = [] device_list = []
for op in block.ops: for op in block.ops:
type = op.type if not op._has_kernel(op.type):
if not op._has_kernel(type):
assert op.type == "conditional_block" and ( assert op.type == "conditional_block" and (
op.attr(self._op_role_key) == int(self._op_role.LRSched)), ( op.attr(self._op_role_key) == int(self._op_role.LRSched)), (
"Now, the only supported op without kernel is " "Now, the only supported op without kernel is "
...@@ -4165,15 +4393,16 @@ class PipelineOptimizer(object): ...@@ -4165,15 +4393,16 @@ class PipelineOptimizer(object):
assert op.has_attr(self._op_device_key), ( assert op.has_attr(self._op_device_key), (
"op ({}) has no {} attribute.".format(op.type, "op ({}) has no {} attribute.".format(op.type,
self._op_device_key)) self._op_device_key))
dev_spec = op.attr(self._op_device_key) device = op.attr(self._op_device_key)
assert dev_spec, ("op_device attribute for op " assert device, ("op_device attribute for op "
"{} has not been set.".format(op.type)) "{} has not been set.".format(op.type))
dev_type = dev_spec.split(':')[0] if device == "gpu:all": continue
dev_type = device.split(':')[0]
assert dev_type == "gpu", ("Now only gpu devices are supported " assert dev_type == "gpu", ("Now only gpu devices are supported "
"for pipeline parallelism.") "for pipeline parallelism.")
if not dev_spec in device_specs: if not device in device_list:
device_specs.append(dev_spec) device_list.append(device)
return device_specs return device_list
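_check_validation now only collects the distinct op_device values in program order, skipping "gpu:all", and minimize later asserts that the collected list is already sorted by device id. A compact sketch of both checks:

from functools import cmp_to_key

op_devices = ["gpu:all", "gpu:0", "gpu:0", "gpu:1", "gpu:all", "gpu:2"]

device_list = []
for dev in op_devices:
    if dev == "gpu:all":               # placed on every stage, not a stage itself
        continue
    assert dev.split(":")[0] == "gpu", "only gpu devices are supported"
    if dev not in device_list:
        device_list.append(dev)

def device_cmp(d1, d2):                # same comparison used in minimize()
    i1, i2 = int(d1.split(":")[1]), int(d2.split(":")[1])
    return (i1 > i2) - (i1 < i2)

assert device_list == sorted(device_list, key=cmp_to_key(device_cmp))
assert device_list == ["gpu:0", "gpu:1", "gpu:2"]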
def _insert_sendrecv_ops_for_boundaries(self, block): def _insert_sendrecv_ops_for_boundaries(self, block):
""" """
...@@ -4182,75 +4411,387 @@ class PipelineOptimizer(object): ...@@ -4182,75 +4411,387 @@ class PipelineOptimizer(object):
""" """
extra_index = 0 extra_index = 0
# A map from var to device spec where op takes it as input, # A map from var to device where op takes it as input,
# avoiding multiple send and recv ops. # avoiding multiple send and recv ops.
var_devspec = dict() var_dev_map = dict()
for index, op in enumerate(list(block.ops)): for index, op in enumerate(list(block.ops)):
# skips lr-related ops and vars, as we will process them later. cur_device = op.attr(self._op_device_key)
if int(op.attr(self._op_role_key)) & int(self._op_role.LRSched): if cur_device == "gpu:all": continue
continue
# skips update ops and vars, as we will process them later.
if self._is_update_op(op): continue
cur_device_spec = op.attr(self._op_device_key)
for var_name in op.input_arg_names: for var_name in op.input_arg_names:
# i.e., lod_tensor_blocking_queue created by DataLoader, # i.e., lod_tensor_blocking_queue created by DataLoader,
# which only exists in startup program. # which only exists in startup program.
if not var_name in block.vars: continue # if not var_name in block.vars: continue
var = block.var(var_name) var = block.var(var_name)
# skip data, because we will process it later # skip data, because we will process it later
if var.is_data: continue if var.is_data: continue
prev_device = None
if var_name in self._param_device_map:
prev_device = self._param_device_map[var_name]
prev_op = self._find_real_prev_op(block.ops, op, var_name) prev_op = self._find_real_prev_op(block.ops, op, var_name)
                if prev_op is None:                 if not prev_device:
continue prev_device = prev_op.attr(self._op_device_key) \
prev_device_spec = prev_op.attr(self._op_device_key) if prev_op else None
if not prev_device or prev_device == 'gpu:all': continue
if prev_device_spec != cur_device_spec: if prev_device != cur_device:
if var_name not in var_devspec: if var_name not in var_dev_map: var_dev_map[var_name] = []
var_devspec[var_name] = [] if cur_device in var_dev_map[var_name]: continue
if cur_device_spec in var_devspec[var_name]: continue var_dev_map[var_name].append(cur_device)
var_devspec[var_name].append(cur_device_spec)
op_role = op.all_attrs()[self._op_role_key] op_role = op.all_attrs()[self._op_role_key]
var = block.vars[var_name] var = block.vars[var_name]
prev_device_index = int(prev_device_spec.split(':')[1]) prev_device_index = int(prev_device.split(':')[1])
cur_device_index = int(cur_device_spec.split(':')[1]) cur_device_index = int(cur_device.split(':')[1])
pair = (prev_device_index, cur_device_index)
pair_key = prev_device_index * 1000 + cur_device_index
if cur_device_index > prev_device_index:
ring_id = self.ring_id + cur_device_index - prev_device_index - 1
else:
ring_id = self.ring_id + 2 + prev_device_index - cur_device_index - 1
if self.schedule_mode == 0: # GPipe
block._insert_op(
index=index + extra_index,
type='send_v2',
inputs={'X': var},
attrs={
self._op_device_key: prev_device,
self._op_role_key: op_role,
'use_calc_stream': True,
'peer': cur_device_index,
'ring_id': self.ring_id
if cur_device_index > prev_device_index else
self.ring_id + 2,
})
extra_index += 1
block._insert_op(
index=index + extra_index,
type='recv_v2',
outputs={'Out': [var]},
attrs={
'out_shape': var.shape,
'dtype': var.dtype,
self._op_device_key: cur_device,
self._op_role_key: op_role,
'use_calc_stream': True,
'peer': prev_device_index,
'ring_id': self.ring_id
if cur_device_index > prev_device_index else
self.ring_id + 2,
})
extra_index += 1
continue
assert self.schedule_mode == 1
if pair not in self._pipeline_pair:
self._pipeline_pair.append(pair)
self._pp_ring_map[pair_key] = self.ring_id
ring_id = self.ring_id
self.ring_id += 1
else:
ring_id = self._pp_ring_map[pair_key]
block._insert_op( block._insert_op(
index=index + extra_index, index=index + extra_index,
type='send_v2', #type='send_v2',
type='c_broadcast',
inputs={'X': var}, inputs={'X': var},
outputs={'Out': var},
attrs={ attrs={
self._op_device_key: prev_device_spec, self._op_device_key: prev_device,
self._op_role_key: op_role, self._op_role_key: op_role,
'use_calc_stream': True, 'use_calc_stream': False,
'peer': cur_device_index, #'peer': cur_device_index,
#'ring_id': self.ring_id if cur_device_index > prev_device_index else self.ring_id + 2,
'ring_id': ring_id,
#'ring_id': self.ring_id,
#'root': prev_device_index,
'root': 0,
})
extra_index += 1
block._insert_op(
index=index + extra_index,
type='c_sync_comm_stream',
inputs={'X': [var]},
outputs={'Out': [var]},
attrs={
self._op_device_key: cur_device,
self._op_role_key:
core.op_proto_and_checker_maker.OpRole.Backward,
#self._op_role_key: op_role,
'ring_id': ring_id,
#'ring_id': self.ring_id if prev_device_index > cur_device_index else self.ring_id + 2,
})
extra_index += 1
#block._insert_op(
# index=index + extra_index,
# type='c_sync_comm_stream',
# inputs={'X': [var]},
# outputs={'Out': [var]},
# attrs={
# self._op_device_key: cur_device,
# self._op_role_key:
# core.op_proto_and_checker_maker.OpRole.Backward,
# 'ring_id': self.ring_id,
# #'ring_id': self.ring_id if prev_device_index > cur_device_index else self.ring_id + 2,
# })
#extra_index += 1
fill_shape = list(var.shape)
fill_shape[0] = self.pp_bz
block._insert_op(
index=index + extra_index,
#type='recv_v2',
type='fill_constant',
inputs={},
outputs={'Out': [var]},
attrs={
'shape': fill_shape,
'dtype': var.dtype,
self._op_device_key: cur_device,
self._op_role_key: op_role,
'value': float(0.0),
})
extra_index += 1
block._insert_op(
index=index + extra_index,
type='c_sync_comm_stream',
inputs={'X': [var]},
outputs={'Out': [var]},
attrs={
self._op_device_key: cur_device,
#self._op_role_key: core.op_proto_and_checker_maker.OpRole.Backward,
self._op_role_key: op_role,
'ring_id': ring_id,
#'ring_id': self.ring_id if prev_device_index > cur_device_index else self.ring_id + 2,
})
extra_index += 1
block._insert_op(
index=index + extra_index,
#type='recv_v2',
type='c_broadcast',
inputs={'X': var},
outputs={'Out': var},
attrs={
#'out_shape': var.shape,
#'dtype': var.dtype,
self._op_device_key: cur_device,
self._op_role_key: op_role,
'use_calc_stream': False,
#'peer': prev_device_index,
#'root': prev_device_index,
'root': 0,
#'ring_id': self.ring_id,
'ring_id': ring_id,
#'ring_id': self.ring_id if cur_device_index > prev_device_index else self.ring_id + 2,
#'ring_id': self.ring_id if prev_device_index < cur_device_index else self.ring_id + 2,
}) })
extra_index += 1 extra_index += 1
block._insert_op( block._insert_op(
index=index + extra_index, index=index + extra_index,
type='recv_v2', type='c_sync_comm_stream',
inputs={'X': [var]},
outputs={'Out': [var]},
attrs={
self._op_device_key: cur_device,
#self._op_role_key: core.op_proto_and_checker_maker.OpRole.Backward,
self._op_role_key: op_role,
'ring_id': ring_id,
#'ring_id': self.ring_id if prev_device_index > cur_device_index else self.ring_id + 2,
})
extra_index += 1
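Each communicating stage pair gets its own communication ring: the pair is encoded as prev_id * 1000 + cur_id and mapped to a fresh ring_id the first time it is seen, then reused for every later transfer across that boundary. A tiny sketch of the bookkeeping (the starting ring_id value is illustrative):

pipeline_pairs = []
pp_ring_map = {}
next_ring_id = 20                          # illustrative starting value

def ring_for(prev_id, cur_id):
    global next_ring_id
    pair, key = (prev_id, cur_id), prev_id * 1000 + cur_id
    if pair not in pipeline_pairs:
        pipeline_pairs.append(pair)
        pp_ring_map[key] = next_ring_id    # first use: allocate a new ring
        next_ring_id += 1
    return pp_ring_map[key]                # later uses: reuse the same ring

assert ring_for(0, 1) == 20   # forward transfer 0 -> 1
assert ring_for(1, 0) == 21   # backward transfer 1 -> 0 gets its own ring
assert ring_for(0, 1) == 20   # reused for every var crossing the 0 -> 1 boundary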
def _xx_insert_sendrecv_ops_for_boundaries(self, block):
"""
Insert a pair of send and recv ops for every two
consecutive ops on different devices.
"""
extra_index = 0
# A map from var to device where op takes it as input,
# avoiding multiple send and recv ops.
input_var_to_device = dict()
# A map from output var to op which generate it.
output_var_to_op = dict()
for index, op in enumerate(list(block.ops)):
for var_name in op.output_arg_names:
ops = output_var_to_op.setdefault(var_name, [])
ops.append([op, index])
for index, op in enumerate(list(block.ops)):
cur_device = op.attr(self._op_device_key)
if cur_device == "gpu:all": continue
for var_name in op.input_arg_names:
var = block.var(var_name)
if var.is_data: continue
#if var_name not in input_var_to_device:
# input_var_to_device[var_name] = []
#if cur_device in input_var_to_device[var_name]:
# continue
#input_var_to_device[var_name].append(cur_device)
prev_device = None
generate_ops = output_var_to_op.get(var_name)
if generate_ops is None:
if var_name not in self._param_device_map:
continue
prev_device = self._param_device_map[var_name]
prev_op = None
for gen_op, gen_idx in reversed(generate_ops):
if gen_idx < index:
prev_op = gen_op
break
if not prev_device:
prev_device = prev_op.attr(self._op_device_key) \
if prev_op else None
if prev_device is None or prev_device == 'gpu:all': continue
if prev_device == cur_device: continue
if var_name not in input_var_to_device:
input_var_to_device[var_name] = []
if (cur_device, prev_device) in input_var_to_device[var_name]:
continue
device_type = cur_device.split(':')[0] + ':'
def _insert_send_recv(cur_id, prev_id):
nonlocal extra_index
cur_dev = device_type + str(cur_id)
prev_dev = device_type + str(prev_id)
if (cur_dev, prev_dev) in input_var_to_device[var_name]:
return
if cur_id - prev_id > 1:
_insert_send_recv(cur_id - 1, prev_id)
_insert_send_recv(cur_id, cur_id - 1)
input_var_to_device[var_name].append(
(cur_dev, prev_dev))
return
elif cur_id - prev_id < -1:
_insert_send_recv(cur_id + 1, prev_id)
_insert_send_recv(cur_id, cur_id + 1)
input_var_to_device[var_name].append(
(cur_dev, prev_dev))
return
assert abs(cur_id - prev_id) == 1
input_var_to_device[var_name].append((cur_dev, prev_dev))
op_role = op.all_attrs()[self._op_role_key]
var = block.vars[var_name]
pair = (prev_id, cur_id)
pair_key = prev_id * 1000 + cur_id
if cur_id > prev_id:
ring_id = self.ring_id + cur_id - prev_id - 1
else:
ring_id = self.ring_id + 2 + prev_id - cur_id - 1
print("call xx_insert, schedule_mode:", self.schedule_mode)
assert self.schedule_mode == 1
if pair not in self._pipeline_pair:
self._pipeline_pair.append(pair)
self._pp_ring_map[pair_key] = self.ring_id
ring_id = self.ring_id
self.ring_id += 1
else:
ring_id = self._pp_ring_map[pair_key]
print("opt: pp_pair: {}, ring_id: {}".format(pair, ring_id))
block._insert_op_without_sync(
index=index + extra_index,
type='c_sync_calc_stream',
inputs={'X': [var]},
outputs={'Out': [var]},
attrs={
self._op_device_key: prev_dev,
self._op_role_key: op_role,
})
extra_index += 1
block._insert_op_without_sync(
index=index + extra_index,
type="c_broadcast",
inputs={'X': var},
outputs={'Out': var},
attrs={
self._op_device_key: prev_dev,
self._op_role_key: op_role,
'use_calc_stream': False,
'ring_id': ring_id,
'root': 0,
})
extra_index += 1
fill_shape = list(var.shape)
fill_shape[0] = self.pp_bz
block._insert_op_without_sync(
index=index + extra_index,
type='fill_constant',
inputs={},
outputs={'Out': [var]}, outputs={'Out': [var]},
attrs={ attrs={
'out_shape': var.shape, 'shape': fill_shape,
'dtype': var.dtype, 'dtype': var.dtype,
self._op_device_key: cur_device_spec, self._op_device_key: cur_dev,
self._op_role_key: op_role,
'value': float(0.0),
})
extra_index += 1
block._insert_op_without_sync(
index=index + extra_index,
type='c_broadcast',
inputs={'X': var},
outputs={'Out': var},
attrs={
self._op_device_key: cur_dev,
self._op_role_key: op_role, self._op_role_key: op_role,
'use_calc_stream': True, 'use_calc_stream': True,
'peer': prev_device_index, 'root': 0,
'ring_id': ring_id,
}) })
extra_index += 1 extra_index += 1
def _clear_gradients(self, main_block, dev_spec): # block._insert_op_without_sync(
# index=index + extra_index,
# type='c_sync_comm_stream',
# inputs={'X': [var]},
# outputs={'Out': [var]},
# attrs={
# self._op_device_key: cur_dev,
# self._op_role_key: op_role,
# 'ring_id': ring_id,
# })
# extra_index += 1
_insert_send_recv(
int(cur_device.split(':')[1]),
int(prev_device.split(':')[1]))
block._sync_with_cpp()
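The nested _insert_send_recv above handles non-adjacent stages by recursing, so a value produced on stage prev_id and consumed on stage cur_id is relayed hop by hop through the stages in between. A standalone sketch of just that recursion, recording hops instead of inserting ops:

def expand_hops(cur_id, prev_id, hops=None):
    # Mirrors the recursion in _insert_send_recv: only adjacent stages ever
    # talk directly; larger gaps are relayed through the stages in between.
    if hops is None:
        hops = []
    if cur_id - prev_id > 1:
        expand_hops(cur_id - 1, prev_id, hops)
        expand_hops(cur_id, cur_id - 1, hops)
        return hops
    if cur_id - prev_id < -1:
        expand_hops(cur_id + 1, prev_id, hops)
        expand_hops(cur_id, cur_id + 1, hops)
        return hops
    assert abs(cur_id - prev_id) == 1
    hops.append((prev_id, cur_id))
    return hops

assert expand_hops(3, 0) == [(0, 1), (1, 2), (2, 3)]   # forward activation
assert expand_hops(0, 2) == [(2, 1), (1, 0)]           # backward gradient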
def _clear_gradients(self, main_block, param_names):
""" """
        Clear gradients at the beginning of each run of a minibatch.         Clear gradients at the beginning of each run of a minibatch.
""" """
for param_name in self._param_device_map: # for param_name in self._param_device_map:
device = self._param_device_map[param_name] print("param_names:", param_names)
if device != dev_spec: continue for param_name in param_names:
# device = self._param_device_map[param_name]
# if device != dev_spec: continue
grad_name = self._append_grad_suffix(param_name) grad_name = self._append_grad_suffix(param_name)
if not main_block.has_var(grad_name): continue # if not main_block.has_var(grad_name): continue
grad_var = main_block.vars[grad_name] assert main_block.has_var(grad_name)
grad_var = main_block.var(grad_name)
grad_var.persistable = True
main_block._insert_op( main_block._insert_op(
index=0, index=0,
type='fill_constant', type='fill_constant',
...@@ -4260,21 +4801,20 @@ class PipelineOptimizer(object): ...@@ -4260,21 +4801,20 @@ class PipelineOptimizer(object):
'shape': grad_var.shape, 'shape': grad_var.shape,
'dtype': grad_var.dtype, 'dtype': grad_var.dtype,
'value': float(0), 'value': float(0),
self._op_device_key: device, # self._op_device_key: device,
# a trick to run this op once per mini-batch # a trick to run this op once per mini-batch
self._op_role_key: self._op_role.Optimize.LRSched, self._op_role_key: self._op_role.Optimize.LRSched,
}) })
def _accumulate_gradients(self, block): def _insert_loss_scale(self, block):
""" """
Accumulate the gradients generated in microbatch to the one in mini-batch.
        We also scale the loss gradient according to the number of micro-batches.         We also scale the loss gradient according to the number of micro-batches.
""" """
if self._num_microbatches == 1: return
for index, op in reversed(tuple(enumerate(list(block.ops)))): for index, op in reversed(tuple(enumerate(list(block.ops)))):
offset = index offset = index
device = op.attr(self._op_device_key) #device = op.attr(self._op_device_key)
# Backward pass
if self._is_loss_grad_op(op): if self._is_loss_grad_op(op):
loss_grad_var = block.vars[op.output_arg_names[0]] loss_grad_var = block.vars[op.output_arg_names[0]]
scale_factor = self._num_microbatches scale_factor = self._num_microbatches
...@@ -4285,36 +4825,130 @@ class PipelineOptimizer(object): ...@@ -4285,36 +4825,130 @@ class PipelineOptimizer(object):
outputs={'Out': loss_grad_var}, outputs={'Out': loss_grad_var},
attrs={ attrs={
'scale': 1.0 / scale_factor, 'scale': 1.0 / scale_factor,
self._op_device_key: device, #self._op_device_key: device,
self._op_role_key: self._op_role.Backward self._op_role_key: self._op_role.Backward
}) })
break break
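_insert_loss_scale divides the loss gradient by the number of micro-batches, so summing the per-micro-batch gradients reproduces the gradient of the mini-batch mean loss. A quick numeric check of that identity (toy linear loss, pure Python):

# Toy model: loss_i = w * x_i, so dloss_i/dw = x_i.
xs = [1.0, 2.0, 3.0, 6.0]
num_microbatches = 2
micro = [xs[:2], xs[2:]]

minibatch_grad = sum(xs) / len(xs)               # grad of the mean loss over all samples
accumulated = 0.0
for mb in micro:
    mb_grad = sum(mb) / len(mb)                  # grad of that micro-batch's mean loss
    accumulated += mb_grad / num_microbatches    # the 1/num_microbatches scale op
assert abs(accumulated - minibatch_grad) < 1e-12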
def _rename_gradient_var_name(self, block):
for index, op in enumerate(block.ops):
if not self._is_optimize_op(op): continue
input_names = op.input_arg_names
output_names = op.output_arg_names
in_out_names = input_names + output_names
if op.type == 'cast': continue
# append "MERGED" to the names of parameter gradients,
            # and modify the op_role_var attribute (by the rename_arg func).
for name in in_out_names:
if not core.grad_var_suffix() in name: continue
param_name = name.strip(core.grad_var_suffix())
new_grad_name = name + "@MERGED"
self._rename_arg(op, name, new_grad_name)
def _accumulate_gradients(self, block, pp_allreduce_in_optimize=False):
"""
Create a new merged gradient for each parameter and accumulate the
corresponding gradient to it.
"""
merged_gradient_names = []
first_opt_op_idx = None
for index, op in reversed(tuple(enumerate(list(block.ops)))):
# remove the cast op of fp16 grad to fp32 grad
if self._is_optimize_op(op) and op.type == 'cast':
in_name = op.input_arg_names[0]
out_name = op.output_arg_names[0]
if out_name.strip('@GRAD') in self._param_device_map:
assert in_name.replace('.cast_fp16', '') == out_name
block._remove_op(index)
continue
if self._is_backward_op(op) and not first_opt_op_idx:
first_opt_op_idx = index + 1
if block.ops[first_opt_op_idx].type == "c_sync_comm_stream":
#block.ops[first_opt_op_idx]._set_attr(self._op_role_key, self._op_role.Backward)
first_opt_op_idx += 1
if self._is_backward_op(op) and ( if self._is_backward_op(op) and (
self._op_role_var_key in op.attr_names): self._op_role_var_key in op.attr_names):
op_role_var = op.all_attrs()[self._op_role_var_key] op_role_var = op.attr(self._op_role_var_key)
if len(op_role_var) == 0: if len(op_role_var) == 0:
continue continue
assert len(op_role_var) % 2 == 0 assert len(op_role_var) % 2 == 0
offset = index #op._remove_attr(self._op_role_var_key)
for i in range(0, len(op_role_var), 2): for i in range(0, len(op_role_var), 2):
grad_name = op_role_var[i + 1] offset = 0
grad_var = block.vars[grad_name] param_name = op_role_var[i]
new_grad_var_name = unique_name.generate(grad_name) if not block.has_var(param_name): continue
new_var = self._create_var(block, grad_var, if '@BroadCast' in param_name: continue
new_grad_var_name) param_grad_name = param_name + core.grad_var_suffix()
self._rename_arg(op, grad_name, new_grad_var_name) merged_param_grad_name = param_grad_name + '@MERGED'
if not block.has_var(merged_param_grad_name):
self._create_var(block, block.vars[param_name],
merged_param_grad_name)
assert block.has_var(merged_param_grad_name)
param_grad_var = block.var(param_grad_name)
merged_param_grad_var = block.var(merged_param_grad_name)
merged_param_grad_var.persistable = True
block._insert_op( block._insert_op(
index=offset + 1, index=first_opt_op_idx + offset,
type='sum', type='fill_constant',
inputs={'X': [grad_var, new_var]}, inputs={},
outputs={'Out': grad_var}, outputs={'Out': [merged_param_grad_var]},
attrs={ attrs={
self._op_device_key: device, 'shape': merged_param_grad_var.shape,
self._op_role_key: self._op_role.Backward, 'dtype': merged_param_grad_var.dtype,
self._op_role_var_key: op_role_var 'value': float(0),
# a trick to run this op once per mini-batch
self._op_role_key: self._op_role.Optimize.LRSched,
}) })
offset += 1 offset += 1
grad_name = op_role_var[i + 1]
grad_var = block.vars[grad_name]
if not 'cast_fp16' in grad_name:
block._insert_op(
index=first_opt_op_idx + offset,
type='sum',
inputs={'X': [grad_var, merged_param_grad_var]},
outputs={'Out': merged_param_grad_var},
attrs={
self._op_role_key: self._op_role.Backward,
})
offset += 1
merged_gradient_names.append(merged_param_grad_name)
else:
# cast gradient to fp32 to accumulate to merged gradient
cast_grad_var_name = param_grad_name + '@TMP'
cast_grad_var = self._create_var(block, param_grad_var,
cast_grad_var_name)
cast_grad_var.persistable = False
block._insert_op(
index=first_opt_op_idx + offset,
type='cast',
inputs={'X': grad_var},
outputs={'Out': cast_grad_var},
attrs={
'in_dtype': grad_var.dtype,
'out_dtype': cast_grad_var.dtype,
self._op_role_key: self._op_role.Backward,
})
offset += 1
block._insert_op(
index=first_opt_op_idx + offset,
type='sum',
inputs={
'X': [merged_param_grad_var, cast_grad_var]
},
outputs={'Out': merged_param_grad_var},
attrs={
# self._op_device_key: device,
self._op_role_key: self._op_role.Backward,
#self._op_role_var_key: op_role_var
})
offset += 1
merged_gradient_names.append(merged_param_grad_name)
return merged_gradient_names
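_accumulate_gradients keeps one persistent @MERGED buffer per parameter, zero-fills it once per mini-batch (the LRSched-role fill_constant trick), and sums each micro-batch gradient into it, casting fp16 gradients to fp32 first. A numeric sketch of that accumulation with plain floats (the cast is only hinted at):

merged = {}                     # param -> accumulated fp32 gradient (the @MERGED buffer)

def zero_fill(param):           # the fill_constant op, run once per mini-batch
    merged[param] = 0.0

def accumulate(param, grad, is_fp16=False):
    if is_fp16:
        grad = float(grad)      # stands in for the fp16 -> fp32 cast op
    merged[param] += grad       # the inserted sum op

zero_fill("w")
for micro_grad in [0.5, 0.25, 0.25]:      # three micro-batches
    accumulate("w", micro_grad, is_fp16=True)
assert merged["w"] == 1.0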
def _add_sub_blocks(self, main_block, program_list): def _add_sub_blocks(self, main_block, program_list):
main_program = main_block.program main_program = main_block.program
...@@ -4372,7 +5006,7 @@ class PipelineOptimizer(object): ...@@ -4372,7 +5006,7 @@ class PipelineOptimizer(object):
block = prog.block(0) block = prog.block(0)
for op in block.ops: for op in block.ops:
if op.type == "recv_v2" or op.type == "create_py_reader" or \ if op.type == "recv_v2" or op.type == "create_py_reader" or \
op.type == "read": op.type == "read" or op.type == "update_loss_scaling":
continue continue
# We have processed lr related vars # We have processed lr related vars
if op.attr(self._op_role_key) == int( if op.attr(self._op_role_key) == int(
...@@ -4407,11 +5041,14 @@ class PipelineOptimizer(object): ...@@ -4407,11 +5041,14 @@ class PipelineOptimizer(object):
inputs={'X': write_block.var(var_name), }, inputs={'X': write_block.var(var_name), },
attrs={ attrs={
self._op_device_key: write_device, self._op_device_key: write_device,
'use_calc_stream': True, 'use_calc_stream': False,
# A trick to make the role LRSched to avoid copy every # A trick to make the role LRSched to avoid copy every
# microbatch # microbatch
self._op_role_key: self._op_role.LRSched, self._op_role_key: self._op_role.LRSched,
'peer': read_dev_index, 'peer': read_dev_index,
#'ring_id': self.ring_id,
'ring_id': self.ring_id if
read_dev_index > write_dev_index else self.ring_id + 2,
}) })
read_block._insert_op( read_block._insert_op(
index=0, index=0,
...@@ -4421,34 +5058,77 @@ class PipelineOptimizer(object): ...@@ -4421,34 +5058,77 @@ class PipelineOptimizer(object):
'out_shape': read_block.var(var_name).shape, 'out_shape': read_block.var(var_name).shape,
'dtype': read_block.var(var_name).dtype, 'dtype': read_block.var(var_name).dtype,
self._op_device_key: read_device, self._op_device_key: read_device,
'use_calc_stream': True, 'use_calc_stream': False,
# A trick to make the role LRSched to avoid copy every
# microbatch
self._op_role_key: self._op_role.LRSched,
'peer': write_dev_index,
#'ring_id': self.ring_id,
'ring_id': self.ring_id if
write_dev_index < read_dev_index else self.ring_id + 2,
})
read_block._insert_op(
index=1,
type='c_sync_comm_stream',
inputs={'X': [read_block.var(var_name)]},
outputs={'Out': [read_block.var(var_name)]},
attrs={
self._op_device_key: read_device,
# A trick to make the role LRSched to avoid copy every # A trick to make the role LRSched to avoid copy every
# microbatch # microbatch
self._op_role_key: self._op_role.LRSched, self._op_role_key: self._op_role.LRSched,
'peer': write_dev_index #'ring_id': self.ring_id,
'ring_id': self.ring_id if
write_dev_index > read_dev_index else self.ring_id + 2,
}) })
def _is_gradient_clip_op(self, op):
return op.desc.has_attr("op_namescope") \
and op.desc.attr("op_namescope").startswith("/gradient_clip")
def _is_regularization_op(self, op):
return op.desc.has_attr("op_namescope") \
and op.desc.attr("op_namescope").startswith("/regularization")
def minimize(self, def minimize(self,
loss, loss,
startup_program=None, startup_program=None,
parameter_list=None, parameter_list=None,
no_grad_set=None): no_grad_set=None):
main_block = loss.block main_block = loss.block
self.origin_main_block = main_block
if startup_program is None: if startup_program is None:
startup_program = default_startup_program() startup_program = default_startup_program()
optimize_ops, params_grads = self._optimizer.minimize( optimize_ops, params_grads = self._optimizer.minimize(
loss, startup_program, parameter_list, no_grad_set) loss, startup_program, parameter_list, no_grad_set)
self._param_device_map = self._optimizer._param_device_map self._param_device_map = self._origin_optimizer._param_device_map
assert main_block.program._pipeline_opt \
# Step1: add default op_device attribute for regulization and clip ops and 'local_rank' in main_block.program._pipeline_opt, \
self._add_opdevice_attr_for_regularization_clip(main_block) 'Please use pipeline with fleet.'
local_rank = main_block.program._pipeline_opt['local_rank']
# Step2: add default op_device attribute for ops whose op_device schedule_mode = 0
# attribute have not been set yet. Then check all ops have the if 'schedule_mode' in main_block.program._pipeline_opt:
# op_device attribute. schedule_mode = main_block.program._pipeline_opt['schedule_mode']
self._add_default_opdevice_attr(main_block) self.schedule_mode = schedule_mode
self.pp_bz = main_block.program._pipeline_opt['pp_bz']
device_specs = self._check_validation(main_block)
self.use_sharding = False
if 'use_sharding' in main_block.program._pipeline_opt:
self.use_sharding = main_block.program._pipeline_opt['use_sharding']
self.ring_id = 0
if 'ring_id' in main_block.program._pipeline_opt:
self.ring_id = main_block.program._pipeline_opt['ring_id']
if main_block.program._pipeline_opt['global_rank'] == 0:
with open("startup_raw", 'w') as f:
f.writelines(str(startup_program))
with open("main_raw", 'w') as f:
f.writelines(str(main_block.program))
# Step1: add default op_device attribute for ops.
self._add_op_device_attr(main_block)
device_list = self._check_validation(main_block)
def device_cmp(device1, device2): def device_cmp(device1, device2):
dev1_id = int(device1.split(':')[1]) dev1_id = int(device1.split(':')[1])
...@@ -4460,66 +5140,169 @@ class PipelineOptimizer(object): ...@@ -4460,66 +5140,169 @@ class PipelineOptimizer(object):
else: else:
return 0 return 0
sorted_device_spec = sorted(device_specs, key=cmp_to_key(device_cmp)) sorted_device_list = sorted(device_list, key=cmp_to_key(device_cmp))
assert sorted_device_spec == device_specs, ( assert sorted_device_list == device_list, (
"With pipeline " "With pipeline parallelism, you must use gpu devices one after "
"parallelism, you must use gpu devices one after another " "another in the order of their ids.")
"in the order of their ids.")
# Step3: add send and recv ops between section boundaries # Step2: add send and recv ops between section boundaries
self._insert_sendrecv_ops_for_boundaries(main_block) self._xx_insert_sendrecv_ops_for_boundaries(main_block)
# Step4: split program into sections and add pairs of # Step3: split program into sections and add pairs of
# send and recv ops for data var. # send and recv ops for data var.
main_program = main_block.program main_program = main_block.program
program_list = self._split_program(main_program, device_specs) program_list = self._split_program(main_program, device_list)
#cur_device_index = 0
#device_num = len(program_list)
for p in program_list: for p in program_list:
self._create_vars(p["program"].block(0), self._create_vars(p["program"].block(0), main_block)
main_program.global_block()) # # Add send/recv pair to sync the execution.
self._insert_sendrecv_for_data_var(main_block, program_list, # block = p['program'].block(0)
startup_program, device_specs) # prev_device_index = cur_device_index - 1
# next_device_index = cur_device_index + 1
# Step5: Special Case: process persistable vars that exist in # add_send_for_forward = False
# add_send_for_backward = False
# add_recv_for_backward = False
# extra_index = 0
# new_var = block.create_var(
# name=unique_name.generate('sync'),
# shape=[1],
# dtype='float32',
# persistable=False,
# stop_gradient=True)
# block._insert_op(
# index=0,
# type='fill_constant',
# inputs={},
# outputs={'Out': [new_var]},
# attrs={
# 'shape': [1],
# 'dtype': new_var.dtype,
# self._op_role_key: self._op_role.Forward,
# 'value': float(0.0),
# })
# extra_index += 1
# for op_idx, op in enumerate(list(block.ops)):
# if op_idx == extra_index:
# if cur_device_index > 0:
# pair_key = prev_device_index * 1000 + cur_device_index
# ring_id = self._pp_ring_map[pair_key]
# block._insert_op(
# index=op_idx,
# type='recv_v2',
# outputs={'Out': [new_var]},
# attrs={
# 'out_shape': new_var.shape,
# 'dtype': new_var.dtype,
# self._op_role_key: self._op_role.Forward,
# 'peer': 0,
# 'use_calc_stream': True,
# 'ring_id': ring_id,
# })
# extra_index += 1
# continue
# if op.type == "send_v2" and self._is_forward_op(op) \
# and not add_send_for_forward \
# and cur_device_index < device_num - 1:
# add_send_for_forward = True
# pair_key = cur_device_index * 1000 + next_device_index
# ring_id = self._pp_ring_map[pair_key]
# block._insert_op(
# index=op_idx + extra_index,
# type='send_v2',
# inputs={'Out': new_var},
# attrs={
# 'out_shape': new_var.shape,
# 'dtype': new_var.dtype,
# self._op_role_key: self._op_role.Forward,
# 'peer': 1,
# 'use_calc_stream': True,
# 'ring_id': ring_id,
# })
# extra_index += 1
# if self._is_backward_op(op) and not add_recv_for_backward \
# and cur_device_index < device_num - 1:
# pair_key = next_device_index * 1000 + cur_device_index
# add_recv_for_backward = True
# ring_id = self._pp_ring_map[pair_key]
# block._insert_op(
# index=op_idx + extra_index,
# type='recv_v2',
# outputs={'Out': [new_var]},
# attrs={
# 'out_shape': new_var.shape,
# 'dtype': new_var.dtype,
# self._op_role_key: self._op_role.Backward,
# 'peer': 0,
# 'use_calc_stream': True,
# 'ring_id': ring_id,
# })
# if op.type == "send_v2" and self._is_backward_op(op) \
# and not add_send_for_backward \
# and cur_device_index > 0:
# pair_key = cur_device_index * 1000 + prev_device_index
# add_send_for_backward = True
# ring_id = self._pp_ring_map[pair_key]
# block._insert_op(
# index=op_idx + extra_index,
# type='send_v2',
# outputs={'Out': [new_var]},
# attrs={
# 'out_shape': new_var.shape,
# 'dtype': new_var.dtype,
# self._op_role_key: self._op_role.Backward,
# 'peer': 1,
# 'use_calc_stream': True,
# 'ring_id': ring_id,
# })
# cur_device_index += 1
#self._insert_sendrecv_for_data_var(main_block, program_list,
# startup_program, device_list)
# Step4: Special Case: process persistable vars that exist in
# multiple sections # multiple sections
self._process_persistable_vars_in_multi_sections( #self._process_persistable_vars_in_multi_sections(
main_program, startup_program, program_list) # main_program, startup_program, program_list)
# Step6: Add sub blocks for section programs # Step5: Add sub blocks for section programs
self._add_sub_blocks(main_block, program_list) self._add_sub_blocks(main_block, program_list)
assert (main_program._pipeline_opt and local_rank = main_program._pipeline_opt['local_rank'] % len(device_list)
isinstance(main_program._pipeline_opt, dict) and
'local_rank' in main_program._pipeline_opt), \
"You must use pipeline with fleet"
local_rank = main_program._pipeline_opt['local_rank'] % len(
device_specs)
place_list = [] place_list = []
for dev_spec in device_specs: for dev in device_list:
dev_index = dev_spec.split(":")[1] dev_index = int(dev.split(":")[1])
place_list.append(core.CUDAPlace(local_rank)) place_list.append(core.CUDAPlace(dev_index % 8))
# Step7: Split startup program # Step6: Split startup program
new_startup_program = self._split_startup_program(startup_program, new_startup_program = self._split_startup_program(startup_program,
local_rank) local_rank)
# Step8: clear gradients before each mini-batch and
# accumulate gradients during backward
self._clear_gradients(
program_list[local_rank]['program'].global_block(),
dev_spec=device_specs[local_rank])
self._accumulate_gradients(program_list[local_rank]['program']
.global_block())
startup_program._pipeline_opt = { startup_program._pipeline_opt = {
"startup_program": new_startup_program, "startup_program": new_startup_program,
} }
real_block = program_list[local_rank]['program'].global_block()
self._insert_loss_scale(real_block)
if not self.use_sharding:
# Step7: clear gradients before each mini-batch and
# accumulate gradients during backward
param_list = []
for param, grad in params_grads:
if real_block.has_var(param): param_list.append(param)
#self._clear_gradients(real_block, param_list)
self._rename_gradient_var_name(real_block)
real_block._sync_with_cpp()
self._accumulate_gradients(real_block)
real_block._sync_with_cpp()
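# Conceptually, _accumulate_gradients makes each parameter update use the
# sum of the gradients produced by all micro-batches of one mini-batch. A
# plain-Python sketch of that behaviour; the `micro_batch_grads` callable
# is hypothetical and only stands in for the per-micro-batch backward pass:
def _merged_grads(num_microbatches, micro_batch_grads):
    merged = {}
    for mb in range(num_microbatches):
        for name, grad in micro_batch_grads(mb).items():
            # accumulate instead of overwriting; the optimizer then runs
            # once per mini-batch on the merged gradients
            merged[name] = merged.get(name, 0.0) + grad
    return merged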
place_id = int(os.getenv("FLAGS_selected_gpus", "0")) place_id = int(os.getenv("FLAGS_selected_gpus", "0"))
main_program._pipeline_opt = { main_program._pipeline_opt = {
"trainer": "PipelineTrainer", "trainer": "PipelineTrainer",
"device_worker": "Section", "device_worker": "Section",
"inner_parallelism": len(device_specs), "inner_parallelism": len(device_list),
"num_pipeline_stages": len(device_list),
"pipeline_stage": local_rank,
"schedule_mode": schedule_mode,
"section_program": program_list[local_rank], "section_program": program_list[local_rank],
"place": place_list[local_rank], "place": place_list[local_rank],
"place_id": place_id, "place_id": place_id,
...@@ -4527,7 +5310,7 @@ class PipelineOptimizer(object): ...@@ -4527,7 +5310,7 @@ class PipelineOptimizer(object):
"num_microbatches": self._num_microbatches, "num_microbatches": self._num_microbatches,
"start_cpu_core_id": self._start_cpu_core_id, "start_cpu_core_id": self._start_cpu_core_id,
} }
return optimize_ops, params_grads, program_list return optimize_ops, params_grads, program_list, self._pipeline_pair, self._pp_ring_map
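# "schedule_mode" above is forwarded to the section worker. The sketch
# below only illustrates the two orderings such a flag commonly selects
# (F-then-B vs. 1F1B); it is an assumption about the schedule shape, not
# the worker implementation:
def _pipeline_order(schedule_mode, num_microbatches, num_stages, stage):
    if schedule_mode == 0:
        # F-then-B: run every forward micro-batch, then every backward
        return ([("F", i) for i in range(num_microbatches)] +
                [("B", i) for i in range(num_microbatches)])
    # 1F1B: warm up with a few forwards, alternate one forward with one
    # backward in the steady state, then drain the remaining backwards
    warmup = min(num_stages - stage - 1, num_microbatches)
    order = [("F", i) for i in range(warmup)]
    fwd, bwd = warmup, 0
    while fwd < num_microbatches:
        order.append(("F", fwd))
        fwd += 1
        order.append(("B", bwd))
        bwd += 1
    while bwd < num_microbatches:
        order.append(("B", bwd))
        bwd += 1
    return order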
class RecomputeOptimizer(Optimizer): class RecomputeOptimizer(Optimizer):
...@@ -4928,10 +5711,10 @@ class RecomputeOptimizer(Optimizer): ...@@ -4928,10 +5711,10 @@ class RecomputeOptimizer(Optimizer):
for output_var in output_vars: for output_var in output_vars:
if output_var in need_offload_checkpoint_names: if output_var in need_offload_checkpoint_names:
assert len( #assert len(
output_vars # output_vars
) == 1, "chekpoint should be the only Output of a certain op, but [{}] is from [{}]".format( #) == 1, "chekpoint should be the only Output of a certain op, but [{}] is from [{}]".format(
output_var, op) # output_var, op)
if output_var in self.un_offload_checkpoint_names: if output_var in self.un_offload_checkpoint_names:
# insert sync op if last checkpoint has not been sync # insert sync op if last checkpoint has not been sync
...@@ -4956,14 +5739,14 @@ class RecomputeOptimizer(Optimizer): ...@@ -4956,14 +5739,14 @@ class RecomputeOptimizer(Optimizer):
format(output_var)) format(output_var))
# need to sync the last need to offload checkpoint before the last checkpoint as output op # need to sync the last need to offload checkpoint before the last checkpoint as output op
if output_var == last_checkpoint: if output_var == last_checkpoint:
assert len( #assert len(
output_vars # output_vars
) == 1, "chekpoint should be the only Output of a certain op, but [{}] is from [{}]".format( #) == 1, "chekpoint should be the only Output of a certain op, but [{}] is from [{}]".format(
output_var, op) # output_var, op)
assert last_offload_checkpoint == self.sorted_checkpoint_names[ #assert last_offload_checkpoint == self.sorted_checkpoint_names[
-2], "the last offload chekpoint before [{}] is suppose to be [{}], but got [{}]".format( # -2], "the last offload chekpoint before [{}] is suppose to be [{}], but got [{}]".format(
last_checkpoint, self.sorted_checkpoint_names[-2], # last_checkpoint, self.sorted_checkpoint_names[-2],
last_offload_checkpoint) # last_offload_checkpoint)
# sync if last checkpoint has not been sync # sync if last checkpoint has not been sync
if self.checkpoint_usage_count_and_idx[ if self.checkpoint_usage_count_and_idx[
last_offload_checkpoint]['idx'] == 0: last_offload_checkpoint]['idx'] == 0:
...@@ -5487,7 +6270,7 @@ class GradientMergeOptimizer(object): ...@@ -5487,7 +6270,7 @@ class GradientMergeOptimizer(object):
def _is_the_backward_op(self, op): def _is_the_backward_op(self, op):
op_maker = core.op_proto_and_checker_maker op_maker = core.op_proto_and_checker_maker
backward = core.op_proto_and_checker_maker.OpRole.Backward backward = core.op_proto_and_checker_maker.OpRole.Backward
if op_maker.kOpRoleVarAttrName() in op.attr_names and \ if op_maker.kOpRoleVarAttrName() in op.attr_names and \
int(op.all_attrs()[op_maker.kOpRoleAttrName()]) == int(backward): int(op.all_attrs()[op_maker.kOpRoleAttrName()]) == int(backward):
return True return True
......