未验证 提交 228bce12 编写于 作者: L lilong12 提交者: GitHub

Add 3d parallelism (#31796)

Add 3d Parallelism
Co-authored-by: NWangXi <wangxi16@baidu.com>
Co-authored-by: NJZ-LIANG <jianzhongliang10@gmail.com>
Co-authored-by: Nroot <root@yq01-sys-hic-k8s-v100-box-a225-0562.yq01.baidu.com>
上级 594bbcb1
...@@ -28,6 +28,7 @@ limitations under the License. */ ...@@ -28,6 +28,7 @@ limitations under the License. */
#include <vector> #include <vector>
#include "paddle/fluid/framework/data_feed.h" #include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/heter_service.h" #include "paddle/fluid/framework/heter_service.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
...@@ -451,7 +452,7 @@ class HeterBoxWorker : public HogwildWorker { ...@@ -451,7 +452,7 @@ class HeterBoxWorker : public HogwildWorker {
virtual void CacheProgram(const ProgramDesc& main_program) { virtual void CacheProgram(const ProgramDesc& main_program) {
new (&program_) ProgramDesc(main_program); new (&program_) ProgramDesc(main_program);
} }
virtual void ProduceTasks() override; void ProduceTasks() override;
virtual void SetStream(const cudaStream_t stream) { copy_stream_ = stream; } virtual void SetStream(const cudaStream_t stream) { copy_stream_ = stream; }
virtual void SetEvent(const cudaEvent_t event) { event_ = event; } virtual void SetEvent(const cudaEvent_t event) { event_ = event; }
virtual void TrainFilesWithProfiler() {} virtual void TrainFilesWithProfiler() {}
...@@ -550,7 +551,7 @@ class PSGPUWorker : public HogwildWorker { ...@@ -550,7 +551,7 @@ class PSGPUWorker : public HogwildWorker {
virtual void CacheProgram(const ProgramDesc& main_program) { virtual void CacheProgram(const ProgramDesc& main_program) {
new (&program_) ProgramDesc(main_program); new (&program_) ProgramDesc(main_program);
} }
virtual void ProduceTasks() override; void ProduceTasks() override;
virtual void SetStream(const cudaStream_t stream) { copy_stream_ = stream; } virtual void SetStream(const cudaStream_t stream) { copy_stream_ = stream; }
virtual void SetEvent(const cudaEvent_t event) { event_ = event; } virtual void SetEvent(const cudaEvent_t event) { event_ = event; }
virtual void TrainFilesWithProfiler() {} virtual void TrainFilesWithProfiler() {}
...@@ -654,6 +655,9 @@ class SectionWorker : public DeviceWorker { ...@@ -654,6 +655,9 @@ class SectionWorker : public DeviceWorker {
void SetDeviceIndex(int tid) override {} void SetDeviceIndex(int tid) override {}
void SetThreadIndex(int thread_id) { thread_id_ = thread_id; } void SetThreadIndex(int thread_id) { thread_id_ = thread_id; }
void SetMicrobatchNum(int num) { num_microbatches_ = num; } void SetMicrobatchNum(int num) { num_microbatches_ = num; }
void SetPipelineStageNum(int num) { num_pipeline_stages_ = num; }
void SetPipelineStage(int stage) { pipeline_stage_ = stage; }
void SetScheduleMode(int mode) { schedule_mode_ = mode; }
void SetMicrobatchScopes(const std::vector<Scope*>& scope) { void SetMicrobatchScopes(const std::vector<Scope*>& scope) {
microbatch_scopes_ = scope; microbatch_scopes_ = scope;
} }
...@@ -661,11 +665,23 @@ class SectionWorker : public DeviceWorker { ...@@ -661,11 +665,23 @@ class SectionWorker : public DeviceWorker {
void SetSkipVars(const std::vector<std::string>& skip_vars) { void SetSkipVars(const std::vector<std::string>& skip_vars) {
skip_vars_ = skip_vars; skip_vars_ = skip_vars;
} }
void RunBackward(
int micro_id, std::unique_ptr<GarbageCollector>&,
std::unordered_map<const OperatorBase*, std::vector<std::string>>&);
void RunForward(
int micro_id, std::unique_ptr<GarbageCollector>&,
std::unordered_map<const OperatorBase*, std::vector<std::string>>&);
void RunUpdate(
std::unique_ptr<GarbageCollector>&,
std::unordered_map<const OperatorBase*, std::vector<std::string>>&);
protected: protected:
int section_id_; int section_id_;
int thread_id_; int thread_id_;
int num_microbatches_; int num_microbatches_;
int num_pipeline_stages_;
int pipeline_stage_;
int schedule_mode_; // 0 for GPipe and 1 for deepspeed
std::vector<Scope*> microbatch_scopes_; std::vector<Scope*> microbatch_scopes_;
std::vector<std::string> skip_vars_; std::vector<std::string> skip_vars_;
const Scope* minibatch_scope_; const Scope* minibatch_scope_;
......
...@@ -32,6 +32,14 @@ message ShardingConfig { ...@@ -32,6 +32,14 @@ message ShardingConfig {
optional float fuse_broadcast_MB = 1 [ default = 32.0 ]; optional float fuse_broadcast_MB = 1 [ default = 32.0 ];
optional bool hybrid_dp = 2 [ default = false ]; optional bool hybrid_dp = 2 [ default = false ];
optional int32 sharding_group_size = 3 [ default = 8 ]; optional int32 sharding_group_size = 3 [ default = 8 ];
optional bool as_outer_parallelism = 4 [ default = false ];
optional int32 parallelism = 5 [ default = 1 ];
optional bool use_pipeline = 6 [ default = false ];
optional int32 acc_steps = 7 [ default = 1 ];
optional int32 schedule_mode = 8 [ default = 0 ];
optional int32 pp_bz = 9 [ default = 1 ];
optional bool pp_allreduce_in_optimize = 10 [ default = false ];
optional bool optimize_offload = 11 [ default = false ];
} }
message AMPConfig { message AMPConfig {
...@@ -44,6 +52,8 @@ message AMPConfig { ...@@ -44,6 +52,8 @@ message AMPConfig {
repeated string custom_white_list = 7; repeated string custom_white_list = 7;
repeated string custom_black_list = 8; repeated string custom_black_list = 8;
repeated string custom_black_varnames = 9; repeated string custom_black_varnames = 9;
optional bool use_pure_fp16 = 10 [ default = false ];
optional bool use_fp16_guard = 11 [ default = true ];
} }
message LocalSGDConfig { message LocalSGDConfig {
...@@ -117,6 +127,8 @@ message AsyncConfig { ...@@ -117,6 +127,8 @@ message AsyncConfig {
message PipelineConfig { optional int32 micro_batch = 1 [ default = 1 ]; } message PipelineConfig { optional int32 micro_batch = 1 [ default = 1 ]; }
message ModelParallelConfig { optional int32 parallelism = 1 [ default = 1 ]; }
message DistributedStrategy { message DistributedStrategy {
// bool options // bool options
optional Mode mode = 1 [ default = COLLECTIVE ]; optional Mode mode = 1 [ default = COLLECTIVE ];
...@@ -140,12 +152,13 @@ message DistributedStrategy { ...@@ -140,12 +152,13 @@ message DistributedStrategy {
optional int32 fuse_grad_size_in_MB = 19 [ default = 32 ]; optional int32 fuse_grad_size_in_MB = 19 [ default = 32 ];
optional float fuse_grad_size_in_TFLOPS = 20 [ default = 50 ]; optional float fuse_grad_size_in_TFLOPS = 20 [ default = 50 ];
optional bool cudnn_exhaustive_search = 21 [ default = true ]; optional bool cudnn_exhaustive_search = 21 [ default = true ];
optional int32 conv_workspace_size_limit = 22 [ default = 4000 ]; optional int32 conv_workspace_size_limit = 22 [ default = 512 ];
optional bool cudnn_batchnorm_spatial_persistent = 23 [ default = true ]; optional bool cudnn_batchnorm_spatial_persistent = 23 [ default = true ];
optional bool adaptive_localsgd = 24 [ default = false ]; optional bool adaptive_localsgd = 24 [ default = false ];
optional bool fp16_allreduce = 25 [ default = false ]; optional bool fp16_allreduce = 25 [ default = false ];
optional bool sharding = 26 [ default = false ]; optional bool sharding = 26 [ default = false ];
optional float last_comm_group_size_MB = 27 [ default = 1 ]; optional float last_comm_group_size_MB = 27 [ default = 1 ];
optional bool model_parallel = 28 [ default = false ];
optional RecomputeConfig recompute_configs = 101; optional RecomputeConfig recompute_configs = 101;
optional AMPConfig amp_configs = 102; optional AMPConfig amp_configs = 102;
...@@ -158,6 +171,7 @@ message DistributedStrategy { ...@@ -158,6 +171,7 @@ message DistributedStrategy {
optional LambConfig lamb_configs = 109; optional LambConfig lamb_configs = 109;
optional AdaptiveLocalSGDConfig adaptive_localsgd_configs = 110; optional AdaptiveLocalSGDConfig adaptive_localsgd_configs = 110;
optional ShardingConfig sharding_configs = 111; optional ShardingConfig sharding_configs = 111;
optional ModelParallelConfig model_parallel_configs = 112;
optional BuildStrategy build_strategy = 201; optional BuildStrategy build_strategy = 201;
optional ExecutionStrategy execution_strategy = 202; optional ExecutionStrategy execution_strategy = 202;
} }
......
...@@ -25,6 +25,9 @@ namespace framework { ...@@ -25,6 +25,9 @@ namespace framework {
void PipelineTrainer::Initialize(const TrainerDesc& trainer_desc, void PipelineTrainer::Initialize(const TrainerDesc& trainer_desc,
Dataset* dataset) { Dataset* dataset) {
const auto& section_params = trainer_desc.section_param(); const auto& section_params = trainer_desc.section_param();
const auto num_pipeline_stages_ = section_params.num_pipeline_stages();
const auto pipeline_stage_ = section_params.pipeline_stage();
const auto schedule_mode_ = section_params.schedule_mode();
num_microbatches_ = section_params.num_microbatches(); num_microbatches_ = section_params.num_microbatches();
VLOG(3) << "Number of microbatches per minibatch: " << num_microbatches_; VLOG(3) << "Number of microbatches per minibatch: " << num_microbatches_;
trainer_desc_ = trainer_desc; trainer_desc_ = trainer_desc;
...@@ -40,6 +43,9 @@ void PipelineTrainer::Initialize(const TrainerDesc& trainer_desc, ...@@ -40,6 +43,9 @@ void PipelineTrainer::Initialize(const TrainerDesc& trainer_desc,
this_worker->SetPlace(place_); this_worker->SetPlace(place_);
this_worker->Initialize(trainer_desc); this_worker->Initialize(trainer_desc);
this_worker->SetMicrobatchNum(num_microbatches_); this_worker->SetMicrobatchNum(num_microbatches_);
this_worker->SetPipelineStageNum(num_pipeline_stages_);
this_worker->SetPipelineStage(pipeline_stage_);
this_worker->SetScheduleMode(schedule_mode_);
} }
void PipelineTrainer::InitOtherEnv(const ProgramDesc& main_program) { void PipelineTrainer::InitOtherEnv(const ProgramDesc& main_program) {
...@@ -76,7 +82,10 @@ void PipelineTrainer::CopyParameters(int microbatch_id, ...@@ -76,7 +82,10 @@ void PipelineTrainer::CopyParameters(int microbatch_id,
for (auto& var : global_block.AllVars()) { for (auto& var : global_block.AllVars()) {
bool is_param_grad = false; bool is_param_grad = false;
size_t pos = 0; size_t pos = 0;
if ((pos = var->Name().find(kGradVarSuffix)) != std::string::npos) { // A magic suffix to indicated the merged gradient.
std::string magicSuffix = "MERGED";
if ((pos = var->Name().find(kGradVarSuffix)) != std::string::npos &&
var->Name().find(magicSuffix) != std::string::npos) {
auto prefix_name = var->Name().substr(0, pos); auto prefix_name = var->Name().substr(0, pos);
if (param_map.find(prefix_name) != param_map.end()) { if (param_map.find(prefix_name) != param_map.end()) {
is_param_grad = true; is_param_grad = true;
......
...@@ -11,55 +11,31 @@ limitations under the License. */ ...@@ -11,55 +11,31 @@ limitations under the License. */
#if defined(PADDLE_WITH_NCCL) #if defined(PADDLE_WITH_NCCL)
#include <float.h> #include <float.h>
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/program_desc.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"
#include "paddle/fluid/framework/device_worker.h" #include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/framework/fleet/box_wrapper.h" #include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/trainer_desc.pb.h"
#include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/lodtensor_printer.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class TrainerDesc;
uint64_t SectionWorker::batch_id_(0); uint64_t SectionWorker::batch_id_(0);
void SectionWorker::Initialize(const TrainerDesc& desc) { void SectionWorker::Initialize(const TrainerDesc &desc) {
dev_ctx_ = platform::DeviceContextPool::Instance().Get(place_); dev_ctx_ = platform::DeviceContextPool::Instance().Get(place_);
program_.reset( program_.reset(
new ProgramDesc(desc.section_param().section_config().program_desc())); new ProgramDesc(desc.section_param().section_config().program_desc()));
for (auto& op_desc : program_->Block(0).AllOps()) { for (auto &op_desc : program_->Block(0).AllOps()) {
ops_.push_back(OpRegistry::CreateOp(*op_desc)); ops_.push_back(OpRegistry::CreateOp(*op_desc));
} }
} }
void SectionWorker::TrainFiles() { void SectionWorker::RunForward(
VLOG(5) << "begin section_worker TrainFiles"; int micro_id, std::unique_ptr<GarbageCollector> &gc,
std::unordered_map<const OperatorBase *, std::vector<std::string>>
int64_t max_memory_size = GetEagerDeletionThreshold(); &unused_vars_) {
std::unique_ptr<GarbageCollector> gc; for (auto &op : ops_) {
auto unused_vars_ = GetUnusedVars(program_->Block(0), ops_, skip_vars_);
if (max_memory_size >= 0) {
#ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(place_)) {
if (IsFastEagerDeletionModeEnabled()) {
gc.reset(new UnsafeFastGPUGarbageCollector(
BOOST_GET_CONST(platform::CUDAPlace, place_), max_memory_size));
}
}
#endif
}
for (int i = 0; i < num_microbatches_; ++i) {
for (auto& op : ops_) {
int op_role = op->Attr<int>(std::string("op_role")); int op_role = op->Attr<int>(std::string("op_role"));
// We run op with op_role = kLRSched only for the first microbatch // We run op with op_role = kLRSched only for the first microbatch
// to avoid increasing the @LR_DECAY_STEP@ multiple times. // to avoid increasing the @LR_DECAY_STEP@ multiple times.
...@@ -70,49 +46,115 @@ void SectionWorker::TrainFiles() { ...@@ -70,49 +46,115 @@ void SectionWorker::TrainFiles() {
bool run_others = op_role == static_cast<int>(OpRole::kForward) || bool run_others = op_role == static_cast<int>(OpRole::kForward) ||
op_role == (static_cast<int>(OpRole::kForward) | op_role == (static_cast<int>(OpRole::kForward) |
static_cast<int>(OpRole::kLoss)); static_cast<int>(OpRole::kLoss));
if ((i == 0 && run_first_mbatch) || (i != 0 && run_others)) { if ((micro_id == 0 && run_first_mbatch) || (micro_id != 0 && run_others)) {
VLOG(3) << "Forward: running op " << op->Type() << " for micro-batch " VLOG(3) << "Forward: running op " << op->Type() << " for micro-batch "
<< i; << micro_id;
op->Run(*microbatch_scopes_[i], place_); op->Run(*microbatch_scopes_[micro_id], place_);
if (gc) { if (gc) {
DeleteUnusedTensors(*microbatch_scopes_[i], op.get(), unused_vars_, DeleteUnusedTensors(*microbatch_scopes_[micro_id], op.get(),
gc.get()); unused_vars_, gc.get());
}
} }
} }
cudaDeviceSynchronize();
} }
}
// backward pass void SectionWorker::RunBackward(
for (int i = 0; i < num_microbatches_; ++i) { int micro_id, std::unique_ptr<GarbageCollector> &gc,
for (auto& op : ops_) { std::unordered_map<const OperatorBase *, std::vector<std::string>>
&unused_vars_) {
for (auto &op : ops_) {
int op_role = op->Attr<int>(std::string("op_role")); int op_role = op->Attr<int>(std::string("op_role"));
if (op_role == static_cast<int>(OpRole::kBackward) || if (op_role == static_cast<int>(OpRole::kBackward) ||
op_role == (static_cast<int>(OpRole::kBackward) | op_role == (static_cast<int>(OpRole::kBackward) |
static_cast<int>(OpRole::kLoss))) { static_cast<int>(OpRole::kLoss))) {
VLOG(3) << "Backward: running op " << op->Type() << " for micro-batch " VLOG(3) << "Backward: running op " << op->Type() << " for micro-batch "
<< i; << micro_id;
op->Run(*microbatch_scopes_[i], place_); op->Run(*microbatch_scopes_[micro_id], place_);
if (gc) { if (gc) {
DeleteUnusedTensors(*microbatch_scopes_[i], op.get(), unused_vars_, DeleteUnusedTensors(*microbatch_scopes_[micro_id], op.get(),
gc.get()); unused_vars_, gc.get());
} }
} }
} }
cudaDeviceSynchronize(); }
}
// update pass void SectionWorker::RunUpdate(
for (auto& op : ops_) { std::unique_ptr<GarbageCollector> &gc,
std::unordered_map<const OperatorBase *, std::vector<std::string>>
&unused_vars_) {
for (auto &op : ops_) {
int op_role = op->Attr<int>(std::string("op_role")); int op_role = op->Attr<int>(std::string("op_role"));
if (op_role == static_cast<int>(OpRole::kOptimize)) { if (op_role == static_cast<int>(OpRole::kOptimize)) {
VLOG(3) << "Update: running op " << op->Type(); VLOG(3) << "Update: running op " << op->Type();
op->Run(*microbatch_scopes_[0], place_); op->Run(*microbatch_scopes_[num_microbatches_ - 1], place_);
if (gc) { if (gc) {
DeleteUnusedTensors(*microbatch_scopes_[0], op.get(), unused_vars_, DeleteUnusedTensors(*microbatch_scopes_[num_microbatches_ - 1],
gc.get()); op.get(), unused_vars_, gc.get());
}
}
}
}
void SectionWorker::TrainFiles() {
VLOG(5) << "begin section_worker TrainFiles";
int64_t max_memory_size = GetEagerDeletionThreshold();
std::unique_ptr<GarbageCollector> gc;
auto unused_vars_ = GetUnusedVars(program_->Block(0), ops_, skip_vars_);
if (max_memory_size >= 0) {
#ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(place_)) {
if (IsFastEagerDeletionModeEnabled()) {
gc.reset(new UnsafeFastGPUGarbageCollector(
BOOST_GET_CONST(platform::CUDAPlace, place_), max_memory_size));
}
}
#endif
}
if (schedule_mode_ == 0) {
// Gpipe scheduler which runs all forwards first, then backwards, then
// update
// step1: run forward
for (int i = 0; i < num_microbatches_; ++i) {
RunForward(i, gc, unused_vars_);
}
// step2: run backward
for (int i = 0; i < num_microbatches_; ++i) {
RunBackward(i, gc, unused_vars_);
}
// step2: run update
RunUpdate(gc, unused_vars_);
} else {
// 1F1B scheduler
auto startup_steps = num_pipeline_stages_ - pipeline_stage_ - 1;
VLOG(3) << "startup_steps:" << startup_steps
<< ", num_stages: " << num_pipeline_stages_
<< ", stage:" << pipeline_stage_;
if (startup_steps > num_microbatches_) {
startup_steps = num_microbatches_;
}
int fw_step = 0;
int bw_step = 0;
// startup phase
while (fw_step < startup_steps) {
RunForward(fw_step, gc, unused_vars_);
fw_step += 1;
}
// 1f1b phase
while (fw_step < num_microbatches_) {
RunForward(fw_step, gc, unused_vars_);
fw_step += 1;
RunBackward(bw_step, gc, unused_vars_);
bw_step += 1;
} }
// backward phase
while (bw_step < num_microbatches_) {
RunBackward(bw_step, gc, unused_vars_);
bw_step += 1;
} }
RunUpdate(gc, unused_vars_);
} }
dev_ctx_->Wait(); dev_ctx_->Wait();
++batch_id_; ++batch_id_;
......
...@@ -93,6 +93,9 @@ message SectionWorkerParameter { ...@@ -93,6 +93,9 @@ message SectionWorkerParameter {
optional int32 start_cpu_core_id = 4 [ default = 1 ]; optional int32 start_cpu_core_id = 4 [ default = 1 ];
repeated string param_need_sync = 5; repeated string param_need_sync = 5;
optional int32 num_microbatches = 6; optional int32 num_microbatches = 6;
optional int32 num_pipeline_stages = 7 [ default = 1 ];
optional int32 pipeline_stage = 8 [ default = 1 ];
optional int32 schedule_mode = 9 [ default = 0 ];
} }
message SectionConfig { message SectionConfig {
......
...@@ -55,7 +55,8 @@ class CGenNCCLIdOp : public framework::OperatorBase { ...@@ -55,7 +55,8 @@ class CGenNCCLIdOp : public framework::OperatorBase {
SendBroadCastNCCLID(endpoint_list, 1, func, local_scope); SendBroadCastNCCLID(endpoint_list, 1, func, local_scope);
} else { } else {
std::string endpoint = Attr<std::string>("endpoint"); std::string endpoint = Attr<std::string>("endpoint");
RecvBroadCastNCCLID(endpoint, 1, func, local_scope); int server_fd = platform::SocketServer::GetInstance(endpoint).socket();
platform::RecvBroadCastCommID(server_fd, endpoint, &nccl_ids);
} }
scope.DeleteScope(&local_scope); scope.DeleteScope(&local_scope);
} }
...@@ -71,8 +72,7 @@ class CGenNCCLIdOp : public framework::OperatorBase { ...@@ -71,8 +72,7 @@ class CGenNCCLIdOp : public framework::OperatorBase {
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope& scope, void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override { const platform::Place& dev_place) const override {}
}
}; };
#endif #endif
......
...@@ -31,7 +31,9 @@ limitations under the License. */ ...@@ -31,7 +31,9 @@ limitations under the License. */
#include "paddle/fluid/string/split.h" #include "paddle/fluid/string/split.h"
namespace paddle { namespace paddle {
namespace operators { namespace platform {
std::once_flag SocketServer::init_flag_;
constexpr char COMM_HEAD[] = "_pd_gen_comm_id_"; constexpr char COMM_HEAD[] = "_pd_gen_comm_id_";
...@@ -340,5 +342,34 @@ void RecvBroadCastNCCLID(int server_fd, std::string endpoint, int nccl_comm_num, ...@@ -340,5 +342,34 @@ void RecvBroadCastNCCLID(int server_fd, std::string endpoint, int nccl_comm_num,
CloseSocket(client); CloseSocket(client);
} }
} // namespace operators SocketServer& SocketServer::GetInstance(const std::string& end_point) {
static SocketServer instance;
std::call_once(init_flag_, [&]() {
instance.server_fd_ = CreateListenSocket(end_point);
instance.end_point_ = end_point;
});
PADDLE_ENFORCE_NE(instance.server_fd_, -1,
platform::errors::Unavailable(
"listen socket failed with end_point=%s", end_point));
PADDLE_ENFORCE_EQ(instance.end_point_, end_point,
platform::errors::InvalidArgument(
"old end_point=%s must equal with new end_point=%s",
instance.end_point_, end_point));
return instance;
}
/// template instantiation
#define INSTANT_TEMPLATE(Type) \
template void SendBroadCastCommID<Type>(std::vector<std::string> servers, \
std::vector<Type> * nccl_ids); \
template void RecvBroadCastCommID<Type>(std::string endpoint, \
std::vector<Type> * nccl_ids);
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
INSTANT_TEMPLATE(ncclUniqueId)
#endif
#ifdef PADDLE_WITH_XPU_BKCL
INSTANT_TEMPLATE(BKCLUniqueId)
#endif
} // namespace platform
} // namespace paddle } // namespace paddle
...@@ -15,6 +15,8 @@ limitations under the License. */ ...@@ -15,6 +15,8 @@ limitations under the License. */
#pragma once #pragma once
#include <functional> #include <functional>
#include <memory>
#include <mutex> // NOLINT
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -25,7 +27,7 @@ class Scope; ...@@ -25,7 +27,7 @@ class Scope;
} // namespace paddle } // namespace paddle
namespace paddle { namespace paddle {
namespace operators { namespace platform {
int CreateListenSocket(const std::string& ep); int CreateListenSocket(const std::string& ep);
...@@ -41,8 +43,26 @@ void RecvBroadCastNCCLID(std::string endpoint, int nccl_comm_num, ...@@ -41,8 +43,26 @@ void RecvBroadCastNCCLID(std::string endpoint, int nccl_comm_num,
const framework::Scope& scope); const framework::Scope& scope);
// recv nccl id from socket // recv nccl id from socket
void RecvBroadCastNCCLID(int server_fd, std::string endpoint, int nccl_comm_num, template <typename CommUniqueId>
std::function<std::string(size_t)> func, void RecvBroadCastCommID(int server_fd, std::string endpoint,
const framework::Scope& scope); std::vector<CommUniqueId>* nccl_ids);
} // namespace operators
class SocketServer {
public:
SocketServer() = default;
~SocketServer() { CloseSocket(server_fd_); }
int socket() const { return server_fd_; }
static SocketServer& GetInstance(const std::string& end_point);
private:
int server_fd_{-1};
std::string end_point_;
static std::once_flag init_flag_;
};
} // namespace platform
} // namespace paddle } // namespace paddle
...@@ -736,6 +736,60 @@ class DistributedStrategy(object): ...@@ -736,6 +736,60 @@ class DistributedStrategy(object):
"sharding_configs") "sharding_configs")
assign_configs_value(self.strategy.sharding_configs, configs) assign_configs_value(self.strategy.sharding_configs, configs)
@property
def model_parallel(self):
"""
Indicating whether we are using model parallel parallelism for distributed training.
Examples:
.. code-block:: python
import paddle.distributed.fleet as fleet
strategy = fleet.DistributedStrategy()
strategy.model_parallel = True
"""
return self.strategy.model_parallel
@model_parallel.setter
@is_strict_auto
def model_parallel(self, flag):
if isinstance(flag, bool):
self.strategy.model_parallel = flag
else:
print("WARNING: model_parallel should have value of bool type")
@property
def model_parallel_configs(self):
"""
Set model_parallel parallelism configurations.
**Notes**:
**Detailed arguments for model_parallel_configs**
**parallelism**: degree of model parallel
Examples:
.. code-block:: python
import paddle.distributed.fleet as fleet
strategy = fleet.DistributedStrategy()
strategy.model_parallel = True
strategy.model_parallel_configs = {"parallelism": 12}
"""
return get_msg_dict(self.strategy.model_parallel_configs)
@model_parallel_configs.setter
@is_strict_auto
def model_parallel_configs(self, configs):
check_configs_key(self.strategy.model_parallel_configs, configs,
"model_parallel_configs")
assign_configs_value(self.strategy.model_parallel_configs, configs)
@property @property
def pipeline(self): def pipeline(self):
""" """
......
...@@ -17,6 +17,7 @@ from .gradient_merge_optimizer import GradientMergeOptimizer ...@@ -17,6 +17,7 @@ from .gradient_merge_optimizer import GradientMergeOptimizer
from .graph_execution_optimizer import GraphExecutionOptimizer from .graph_execution_optimizer import GraphExecutionOptimizer
from .parameter_server_optimizer import ParameterServerOptimizer from .parameter_server_optimizer import ParameterServerOptimizer
from .pipeline_optimizer import PipelineOptimizer from .pipeline_optimizer import PipelineOptimizer
from .model_parallel_optimizer import ModelParallelOptimizer
from .localsgd_optimizer import LocalSGDOptimizer from .localsgd_optimizer import LocalSGDOptimizer
from .localsgd_optimizer import AdaptiveLocalSGDOptimizer from .localsgd_optimizer import AdaptiveLocalSGDOptimizer
from .lars_optimizer import LarsOptimizer from .lars_optimizer import LarsOptimizer
......
...@@ -50,15 +50,17 @@ class AMPOptimizer(MetaOptimizerBase): ...@@ -50,15 +50,17 @@ class AMPOptimizer(MetaOptimizerBase):
self.inner_opt, amp_lists, config['init_loss_scaling'], self.inner_opt, amp_lists, config['init_loss_scaling'],
config['incr_every_n_steps'], config['decr_every_n_nan_or_inf'], config['incr_every_n_steps'], config['decr_every_n_nan_or_inf'],
config['incr_ratio'], config['decr_ratio'], config['incr_ratio'], config['decr_ratio'],
config['use_dynamic_loss_scaling']) config['use_dynamic_loss_scaling'], config['use_pure_fp16'],
config['use_fp16_guard'])
# if worker_num > 1, all cards will communication with each other, # if worker_num > 1, all cards will communication with each other,
# add is_distributed to optimize amp, overlap communication and # add is_distributed to optimize amp, overlap communication and
# computation by split the check_finite_and_unscale op. # computation by split the check_finite_and_unscale op.
is_distributed = self.role_maker._worker_num() > 1 is_distributed = self.role_maker._worker_num() > 1
if self.user_defined_strategy.sharding: #if self.user_defined_strategy.sharding or self.user_defined_strategy.model_parallel:
# FIXME(wangxi). sharding failed when split check_finite_and_unscale # # FIXME(wangxi). sharding failed when split check_finite_and_unscale
is_distributed = False # # FIXME(JZ-LIANG). To support Sharding-Megatron-AMP, Megatron should follow Sharding's behavior
# is_distributed = False
self.wrapped_opt._set_distributed(is_distributed) self.wrapped_opt._set_distributed(is_distributed)
def _can_apply(self): def _can_apply(self):
...@@ -112,3 +114,11 @@ class AMPOptimizer(MetaOptimizerBase): ...@@ -112,3 +114,11 @@ class AMPOptimizer(MetaOptimizerBase):
self.wrapped_opt.minimize(loss, startup_program, self.wrapped_opt.minimize(loss, startup_program,
parameter_list, no_grad_set) parameter_list, no_grad_set)
return optimize_ops, params_grads return optimize_ops, params_grads
def amp_init(self,
place,
scope=None,
test_program=None,
use_fp16_test=False):
return self.wrapped_opt.amp_init(place, scope, test_program,
use_fp16_test)
...@@ -25,6 +25,24 @@ OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName() ...@@ -25,6 +25,24 @@ OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
OP_ROLE_VAR_KEY = core.op_proto_and_checker_maker.kOpRoleVarAttrName() OP_ROLE_VAR_KEY = core.op_proto_and_checker_maker.kOpRoleVarAttrName()
class Topology:
"""A 4-D structure to describe the process group."""
def __init__(self, axes, dims):
pass
class ParallelGrid:
"""Initialize each process group."""
def __init__(self, topology):
self.build_global_group()
self.build_mp_group()
self.build_sharding_group()
self.build_pp_group()
self.build_dp_group()
def is_update_op(op): def is_update_op(op):
return 'Param' in op.input_names and 'Grad' in op.input_names and \ return 'Param' in op.input_names and 'Grad' in op.input_names and \
"LearningRate" in op.input_names "LearningRate" in op.input_names
...@@ -66,16 +84,49 @@ class CollectiveHelper(object): ...@@ -66,16 +84,49 @@ class CollectiveHelper(object):
self.role_maker._worker_index(), ring_id, self.wait_port) self.role_maker._worker_index(), ring_id, self.wait_port)
self._broadcast_params() self._broadcast_params()
def _init_communicator(self, program, current_endpoint, endpoints, rank, def _init_communicator(self,
ring_id, wait_port): program,
current_endpoint,
endpoints,
rank,
ring_id,
wait_port,
sync=True):
nranks = len(endpoints) nranks = len(endpoints)
other_endpoints = endpoints[:] other_endpoints = endpoints[:]
other_endpoints.remove(current_endpoint) other_endpoints.remove(current_endpoint)
block = program.global_block() block = program.global_block()
if core.is_compiled_with_cuda(): if core.is_compiled_with_cuda():
if rank == 0 and wait_port: if not wait_port and sync:
wait_server_ready(other_endpoints) temp_var = block.create_var(
nccl_id_var = block.create_var( name=unique_name.generate('temp_var'),
dtype=core.VarDesc.VarType.INT32,
persistable=False,
stop_gradient=True)
block.append_op(
type='fill_constant',
inputs={},
outputs={'Out': [temp_var]},
attrs={
'shape': [1],
'dtype': temp_var.dtype,
'value': 1,
'force_cpu': False,
OP_ROLE_KEY: OpRole.Forward
})
block.append_op(
type='c_allreduce_sum',
inputs={'X': [temp_var]},
outputs={'Out': [temp_var]},
attrs={'ring_id': 3,
OP_ROLE_KEY: OpRole.Forward})
block.append_op(
type='c_sync_comm_stream',
inputs={'X': temp_var},
outputs={'Out': temp_var},
attrs={'ring_id': 3,
OP_ROLE_KEY: OpRole.Forward})
comm_id_var = block.create_var(
name=unique_name.generate('nccl_id'), name=unique_name.generate('nccl_id'),
persistable=True, persistable=True,
type=core.VarDesc.VarType.RAW) type=core.VarDesc.VarType.RAW)
...@@ -100,9 +151,7 @@ class CollectiveHelper(object): ...@@ -100,9 +151,7 @@ class CollectiveHelper(object):
OP_ROLE_KEY: OpRole.Forward OP_ROLE_KEY: OpRole.Forward
}) })
elif core.is_compiled_with_npu(): elif core.is_compiled_with_npu():
endpoint_to_index_map = { endpoint_to_index_map = {e: idx for idx, e in enumerate(endpoints)}
e: idx for idx, e in enumerate(endpoints)
}
block.append_op( block.append_op(
type='c_comm_init_hcom', type='c_comm_init_hcom',
inputs={}, inputs={},
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
from __future__ import print_function
from __future__ import division
import paddle.fluid as fluid
from paddle.fluid import core, unique_name
from ..base.private_helper_function import wait_server_ready
from .meta_optimizer_base import MetaOptimizerBase
from .common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY, CollectiveHelper, is_update_op, is_loss_grad_op, is_backward_op, is_optimizer_op
class ModelParallelHelper(object):
def __init__(self, role_maker, wait_port=True, megatron_dp=False):
self.wait_port = wait_port
self.role_maker = role_maker
self.megatron_dp = megatron_dp
def update_startup_program(self,
startup_program=None,
inner_parallelism=None):
self.startup_program = startup_program
nranks = self.role_maker._worker_num()
rank = self.role_maker._worker_index()
endpoints = self.role_maker._get_trainer_endpoints()
current_endpoint = endpoints[rank]
# Create ring 0 for all model parallel parts within a single model
mp_endpoints = []
mp_rank = rank % inner_parallelism
mp_id = rank // inner_parallelism
for idx, ep in enumerate(endpoints):
if idx // inner_parallelism == mp_id:
mp_endpoints.append(ep)
print("model parallel eps:{}, rank{}".format(mp_endpoints, mp_rank))
self._init_communicator(self.startup_program, current_endpoint,
mp_endpoints, mp_rank, 0, self.wait_port)
self._broadcast_params(0, broadcast_distributed_weight=False)
print("megatron group size: {}".format(inner_parallelism))
print("megatron rank: {}".format(mp_rank))
print("megatron endpoints: {}".format(mp_endpoints))
if self.megatron_dp:
mp_num = len(endpoints) // inner_parallelism
if mp_num == 1: return
# Create rings for gpus as the same model parallel part
eps = []
dp_rank = rank // inner_parallelism
dp_id = rank % inner_parallelism
#if dp_rank == 1: dp_rank =0
#if dp_rank == 0: dp_rank =1
ring_id = 1
for idx, ep in enumerate(endpoints):
if idx % inner_parallelism == dp_id:
eps.append(ep)
#ep = eps.pop(0)
#eps.insert(1, ep)
print("data parallel eps:{}, rank{}".format(eps, dp_rank))
self._init_communicator(self.startup_program, current_endpoint, eps,
dp_rank, ring_id, self.wait_port)
self._broadcast_params(ring_id, broadcast_distributed_weight=True)
def _init_communicator(self, program, current_endpoint, endpoints, rank,
ring_id, wait_port):
nranks = len(endpoints)
other_endpoints = endpoints[:]
other_endpoints.remove(current_endpoint)
if rank == 0 and wait_port:
wait_server_ready(other_endpoints)
block = program.global_block()
nccl_id_var = block.create_var(
name=unique_name.generate('nccl_id'),
persistable=True,
type=core.VarDesc.VarType.RAW)
block.append_op(
type='c_gen_nccl_id',
inputs={},
outputs={'Out': nccl_id_var},
attrs={
'rank': rank,
'endpoint': current_endpoint,
'other_endpoints': other_endpoints,
OP_ROLE_KEY: OpRole.Forward,
})
block.append_op(
type='c_comm_init',
inputs={'X': nccl_id_var},
outputs={},
attrs={
'nranks': nranks,
'rank': rank,
'ring_id': ring_id,
OP_ROLE_KEY: OpRole.Forward,
})
def _broadcast_params(self, ring_id, broadcast_distributed_weight):
block = self.startup_program.global_block()
for param in block.iter_parameters():
if not broadcast_distributed_weight and param.is_distributed:
continue
block.append_op(
type='c_broadcast',
inputs={'X': param},
outputs={'Out': param},
attrs={
'ring_id': ring_id,
'root': 0,
OP_ROLE_KEY: OpRole.Forward
})
block.append_op(
type='c_sync_comm_stream',
inputs={'X': param},
outputs={'Out': param},
attrs={'ring_id': ring_id,
OP_ROLE_KEY: OpRole.Forward})
class ModelParallelOptimizer(MetaOptimizerBase):
def __init__(self, optimizer):
super(ModelParallelOptimizer, self).__init__(optimizer)
self.inner_opt = optimizer
self.meta_optimizers_white_list = [
"RecomputeOptimizer",
"AMPOptimizer",
"LarsOptimizer",
"LambOptimizer",
]
self.meta_optimizers_black_list = ["GraphExecutionOptimizer", ]
self.megatron_dp = False
def _set_basic_info(self, loss, role_maker, user_defined_optimizer,
user_defined_strategy):
super(ModelParallelOptimizer, self)._set_basic_info(
loss, role_maker, user_defined_optimizer, user_defined_strategy)
self.inner_parallelism = user_defined_strategy.model_parallel_configs[
'parallelism']
def _can_apply(self):
if not self.role_maker._is_collective:
return False
if self.user_defined_strategy.model_parallel == True:
return True
return False
def _disable_strategy(self, dist_strategy):
dist_strategy.model_parallel = False
dist_strategy.model_parallel_configs = {}
def _enable_strategy(self, dist_strategy, context):
dist_strategy.model_parallel = True
dist_strategy.model_parallel_configs = {"parallelism": 1, }
# the following function will be used by AMP if both Megatron and AMP are turn on together.
def apply_gradients(self, params_grads):
return self.minimize_impl(params_grads=params_grads)
def minimize_impl(self,
loss,
startup_program=None,
parameter_list=None,
no_grad_set=None):
endpoints = self.role_maker._get_trainer_endpoints()
current_endpoint = endpoints[self.role_maker._worker_index()]
self.startup_program = startup_program
if startup_program is None:
self.startup_program = fluid.default_startup_program()
# (TODO) check the order of metaoptimizer
# (TODO) check the params_grads
optimize_ops, params_grads = self.inner_opt.minimize(
loss, self.startup_program, parameter_list, no_grad_set)
self.main_program = loss.block.program
self.inner_parallelism = self.inner_parallelism
self.nranks = len(endpoints)
pipeline_helper = ModelParallelHelper(self.role_maker)
pipeline_helper.update_startup_program(self.startup_program,
self.inner_parallelism)
assert self.nranks % self.inner_parallelism == 0
if self.megatron_dp:
# data parallelism
dp_parallelism = self.nranks // self.inner_parallelism
self._transpile_main_program(loss, dp_parallelism)
return optimize_ops, params_grads
def _transpile_main_program(self, loss, dp_parallelism):
self._insert_loss_grad_ops(loss, dp_parallelism)
ring_id = 1
print("ring_id: ", ring_id)
# for ring_id in range(1, dp_parallelism + 1):
self._insert_allreduce_ops(loss, ring_id)
def _insert_loss_grad_ops(self, loss, dp_parallelism):
"""
In order to keep the learning rate consistent in different numbers of
training workers, we scale the loss grad by the number of workers
"""
block = loss.block
for idx, op in reversed(list(enumerate(block.ops))):
if is_loss_grad_op(op):
loss_grad_var = block.vars[op.output_arg_names[0]]
block._insert_op(
idx + 1,
type='scale',
inputs={'X': loss_grad_var},
outputs={'Out': loss_grad_var},
attrs={
'scale': 1.0 / dp_parallelism,
OP_ROLE_KEY: OpRole.Backward
})
def _insert_allreduce_ops(self, loss, ring_id):
block = loss.block
grad = None
for idx, op in reversed(list(enumerate(block.ops))):
if is_backward_op(op) and \
OP_ROLE_VAR_KEY in op.attr_names:
op_role_var = op.all_attrs()[OP_ROLE_VAR_KEY]
if len(op_role_var) == 0:
continue
assert len(op_role_var) % 2 == 0
offset = idx
for i in range(0, len(op_role_var), 2):
param = block.vars[op_role_var[i]]
grad = block.vars[op_role_var[i + 1]]
#if param.is_distributed:
# continue
if offset == idx:
offset += 1
block._insert_op(
offset,
type='c_sync_calc_stream',
inputs={'X': grad},
outputs={'Out': grad},
attrs={OP_ROLE_KEY: OpRole.Backward})
offset += 1
block._insert_op(
offset,
type='c_allreduce_sum',
inputs={'X': grad},
outputs={'Out': grad},
attrs={
'ring_id': ring_id,
OP_ROLE_KEY: OpRole.Backward
})
if grad is None:
return
for idx, op in list(enumerate(block.ops)):
if is_optimizer_op(op):
block._insert_op(
idx,
type='c_sync_comm_stream',
inputs={'X': grad},
outputs={'Out': grad},
attrs={'ring_id': ring_id,
OP_ROLE_KEY: OpRole.Backward})
break
...@@ -154,8 +154,10 @@ class PipelineOptimizer(MetaOptimizerBase): ...@@ -154,8 +154,10 @@ class PipelineOptimizer(MetaOptimizerBase):
def __init__(self, optimizer): def __init__(self, optimizer):
super(PipelineOptimizer, self).__init__(optimizer) super(PipelineOptimizer, self).__init__(optimizer)
self.inner_opt = optimizer self.inner_opt = optimizer
# we do not allow meta optimizer to be inner optimizer currently self.meta_optimizers_white_list = [
self.meta_optimizers_white_list = [] "RecomputeOptimizer",
"AMPOptimizer",
]
self.meta_optimizers_black_list = ["GraphExecutionOptimizer", ] self.meta_optimizers_black_list = ["GraphExecutionOptimizer", ]
def _set_basic_info(self, loss, role_maker, user_defined_optimizer, def _set_basic_info(self, loss, role_maker, user_defined_optimizer,
......
...@@ -73,7 +73,7 @@ class FP16Utils(object): ...@@ -73,7 +73,7 @@ class FP16Utils(object):
@staticmethod @staticmethod
def prune_fp16(block, shard, reduced_grads_to_param, ring_id): def prune_fp16(block, shard, reduced_grads_to_param, ring_id):
""" """
1. prune all cast_fp32_to_fp16 ops if the param not belongs to this shard 1. prune all cast_fp16_to_fp32 ops if the param not belongs to this shard
2. revise amp inifine grad checking for sharding 2. revise amp inifine grad checking for sharding
""" """
# remove cast # remove cast
...@@ -81,7 +81,9 @@ class FP16Utils(object): ...@@ -81,7 +81,9 @@ class FP16Utils(object):
if not FP16Utils.is_fp32_cast_op(block, op): if not FP16Utils.is_fp32_cast_op(block, op):
continue continue
output_name = op.desc.output_arg_names()[0] output_name = op.desc.output_arg_names()[0]
param_name = output_name.strip("@GRAD") param_name = output_name.strip(
"@GRAD@MERGED"
) if "@MERGED" in output_name else output_name.strip("@GRAD")
if param_name not in shard.global_params: if param_name not in shard.global_params:
raise ValueError("Output 'X' of cast_op must be a grad of" raise ValueError("Output 'X' of cast_op must be a grad of"
"model param, but {} is not a grad".format( "model param, but {} is not a grad".format(
...@@ -103,20 +105,35 @@ class FP16Utils(object): ...@@ -103,20 +105,35 @@ class FP16Utils(object):
op._rename_input(inf_var_name, inf_var_name + "@sharding") op._rename_input(inf_var_name, inf_var_name + "@sharding")
if op.type in ["check_finite_and_unscale", "update_loss_scaling"]: if op.type in ["check_finite_and_unscale", "update_loss_scaling"]:
reversed_x = [] reversed_x = []
reversed_x_paramname = []
for input_name in op.desc.input('X'): for input_name in op.desc.input('X'):
param_name = input_name.strip("@GRAD") param_name = input_name.strip("@GRAD@MERGED")
if param_name not in shard.global_params: if param_name not in shard.global_params:
raise ValueError( raise ValueError(
"Input 'X' of check_finite_and_unscale must" "Input 'X' of check_finite_and_unscale must"
"be grads, but {} is not a grad".format(input_name)) "be grads, but {} is not a grad".format(input_name))
if shard.has_param(param_name): if shard.has_param(param_name):
reversed_x.append(input_name) reversed_x.append(input_name)
reversed_x_paramname.append(param_name)
op.desc.set_input('X', reversed_x) op.desc.set_input('X', reversed_x)
op.desc.set_output('Out', reversed_x) op.desc.set_output('Out', reversed_x)
# the grad checking should take the all and only param in the current shard
to_check_param = set(reversed_x_paramname)
should_check_param = set(shard.global_params).intersection(
set([
param
for param, worker_idx in shard.global_param2device.
items() if worker_idx == shard.worker_idx
]))
#assert to_check_param == should_check_param, "amp check_finite_and_unscale checking miss [{}] and got unexpected [{}]".format(
# should_check_param - to_check_param,
# to_check_param - should_check_param)
if update_loss_scaling_op_idx == -1: if update_loss_scaling_op_idx == -1:
return return
inf_var = block.var(inf_var_name) inf_var = block.var(inf_var_name)
inf_var_fp32 = block.create_var( inf_var_int32 = block.create_var(
name=inf_var_name + "@cast_int32", name=inf_var_name + "@cast_int32",
shape=inf_var.shape, shape=inf_var.shape,
dtype=core.VarDesc.VarType.INT32) dtype=core.VarDesc.VarType.INT32)
...@@ -128,32 +145,36 @@ class FP16Utils(object): ...@@ -128,32 +145,36 @@ class FP16Utils(object):
update_loss_scaling_op_idx, update_loss_scaling_op_idx,
type='cast', type='cast',
inputs={'X': inf_var}, inputs={'X': inf_var},
outputs={'Out': inf_var_fp32}, outputs={'Out': inf_var_int32},
attrs={ attrs={
"in_dtype": inf_var.dtype, "in_dtype": inf_var.dtype,
"out_dtype": inf_var_fp32.dtype, "out_dtype": inf_var_int32.dtype,
OP_ROLE_KEY: OpRole.Optimize OP_ROLE_KEY: OpRole.Optimize
}) })
insert_sync_calc_op(block, update_loss_scaling_op_idx + 1, # this allreduce communication should not overlap with calc
[inf_var_fp32]) # insert_sync_calc_op(block, update_loss_scaling_op_idx + 1,
# [inf_var_int32])
block._insert_op_without_sync( block._insert_op_without_sync(
update_loss_scaling_op_idx + 2, update_loss_scaling_op_idx + 1,
type='c_allreduce_max', type='c_allreduce_max',
inputs={'X': inf_var_fp32}, inputs={'X': inf_var_int32},
outputs={'Out': inf_var_fp32}, outputs={'Out': inf_var_int32},
attrs={'ring_id': ring_id, attrs={
OP_ROLE_KEY: OpRole.Optimize}) 'ring_id': ring_id,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Optimize
})
comm_op_num = insert_sync_comm_op(block, update_loss_scaling_op_idx + 3, # comm_op_num = insert_sync_comm_op(block, update_loss_scaling_op_idx + 3,
ring_id, [inf_var_fp32]) # ring_id, [inf_var_int32])
block._insert_op_without_sync( block._insert_op_without_sync(
update_loss_scaling_op_idx + 3 + comm_op_num, update_loss_scaling_op_idx + 2,
type='cast', type='cast',
inputs={'X': inf_var_fp32}, inputs={'X': inf_var_int32},
outputs={'Out': inf_var_sharding}, outputs={'Out': inf_var_sharding},
attrs={ attrs={
"in_dtype": inf_var_fp32.dtype, "in_dtype": inf_var_int32.dtype,
"out_dtype": inf_var_sharding.dtype, "out_dtype": inf_var_sharding.dtype,
OP_ROLE_KEY: OpRole.Optimize OP_ROLE_KEY: OpRole.Optimize
}) })
......
...@@ -16,8 +16,8 @@ from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole ...@@ -16,8 +16,8 @@ from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole
class GradientClipHelper(object): class GradientClipHelper(object):
def __init__(self, sharding_ring_id): def __init__(self, mp_ring_id):
self.sharding_ring_id = sharding_ring_id self.mp_ring_id = mp_ring_id
def _is_gradient_clip_op(self, op): def _is_gradient_clip_op(self, op):
return op.desc.has_attr("op_namescope") \ return op.desc.has_attr("op_namescope") \
...@@ -31,6 +31,7 @@ class GradientClipHelper(object): ...@@ -31,6 +31,7 @@ class GradientClipHelper(object):
""" """
deperated_vars = set() deperated_vars = set()
deperate_op_idx = set() deperate_op_idx = set()
reversed_x_paramname = []
for idx, op in enumerate(block.ops): for idx, op in enumerate(block.ops):
if not self._is_gradient_clip_op(op): if not self._is_gradient_clip_op(op):
continue continue
...@@ -40,14 +41,17 @@ class GradientClipHelper(object): ...@@ -40,14 +41,17 @@ class GradientClipHelper(object):
for input_name in op.desc.input_arg_names(): for input_name in op.desc.input_arg_names():
if input_name in deperated_vars: if input_name in deperated_vars:
deperate_op = True deperate_op = True
param_name = input_name.strip("@GRAD") param_name = input_name.strip("@GRAD@MERGED")
if shard.is_param(param_name) and \ if shard.is_param(param_name) and \
not shard.has_param(param_name): not shard.has_param(param_name):
deperate_op = True deperate_op = True
elif shard.is_param(param_name):
reversed_x_paramname.append(param_name)
if deperate_op: if deperate_op:
deperate_op_idx.add(idx) deperate_op_idx.add(idx)
for output_name in op.desc.output_arg_names(): for output_name in op.desc.output_arg_names():
if output_name not in op.desc.input_arg_names():
deperated_vars.add(output_name) deperated_vars.add(output_name)
if not deperated_vars: if not deperated_vars:
...@@ -65,31 +69,47 @@ class GradientClipHelper(object): ...@@ -65,31 +69,47 @@ class GradientClipHelper(object):
for input_name in op.desc.input_arg_names(): for input_name in op.desc.input_arg_names():
if input_name not in deperated_vars: if input_name not in deperated_vars:
reversed_inputs.append(input_name) reversed_inputs.append(input_name)
op.desc.set_input("X", reversed_inputs) op.desc.set_input("X", reversed_inputs)
assert (len(op.desc.output_arg_names()) == 1) assert (len(op.desc.output_arg_names()) == 1)
sum_res = op.desc.output_arg_names()[0] sum_res = op.desc.output_arg_names()[0]
block._insert_op_without_sync(
idx + 1, # this allreduce should not overlap with calc and should be scheduled in calc stream
type='c_sync_comm_stream', # block._insert_op_without_sync(
inputs={'X': sum_res}, # idx + 1,
outputs={'Out': sum_res}, # type='c_sync_comm_stream',
attrs={'ring_id': 0, # inputs={'X': sum_res},
OP_ROLE_KEY: OpRole.Optimize}) # outputs={'Out': sum_res},
# attrs={'ring_id': 0,
# OP_ROLE_KEY: OpRole.Optimize})
block._insert_op_without_sync( block._insert_op_without_sync(
idx + 1, idx + 1,
type='c_allreduce_sum', type='c_allreduce_sum',
inputs={'X': sum_res}, inputs={'X': sum_res},
outputs={'Out': sum_res}, outputs={'Out': sum_res},
attrs={ attrs={
'ring_id': self.sharding_ring_id, 'ring_id': self.mp_ring_id,
OP_ROLE_KEY: OpRole.Optimize 'op_namescope': "/gradient_clip_model_parallelism",
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Optimize,
}) })
block._insert_op_without_sync( # block._insert_op_without_sync(
idx + 1, # idx + 1,
type='c_sync_calc_stream', # type='c_sync_calc_stream',
inputs={'X': sum_res}, # inputs={'X': sum_res},
outputs={'Out': sum_res}, # outputs={'Out': sum_res},
attrs={OP_ROLE_KEY: OpRole.Optimize}) # attrs={OP_ROLE_KEY: OpRole.Optimize})
# the grad sum here should take the all and only param in the current shard
to_check_param = set(reversed_x_paramname)
should_check_param = set(shard.global_params).intersection(
set([
param for param, worker_idx in shard.global_param2device.items()
if worker_idx == shard.worker_idx
]))
assert to_check_param == should_check_param, "amp check_finite_and_unscale checking miss [{}] and got unexpected [{}]".format(
should_check_param - to_check_param,
to_check_param - should_check_param)
for var_name in deperated_vars: for var_name in deperated_vars:
block._remove_var(var_name, sync=False) block._remove_var(var_name, sync=False)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..common import is_optimizer_op, OP_ROLE_KEY, OpRole
from paddle.fluid import core, unique_name
class OffloadHelper(object):
cpu_place_type = 0
cuda_place_type = 1
cuda_pinned_place_type = 2
def __init__(self):
pass
"0: dst is on CPUPlace. "
"1: dst is on CUDAPlace. "
"2: dst is on CUDAPinnedPlace. "
def _insert_cast_op(self, block, idx, src_name, dst_name):
src_var = block.var(src_name)
if not block.has_var(dst_name):
block.create_var(
name=dst_name,
shape=src_var.shape,
dtype=core.VarDesc.VarType.FP16,
persistable=True)
dst_var = block.var(dst_name)
assert dst_var.dtype == core.VarDesc.VarType.FP16
block._insert_op_without_sync(
idx,
type='cast',
inputs={'X': src_var},
outputs={'Out': dst_var},
attrs={
'in_dtype': src_var.dtype,
'out_dtype': dst_var.dtype,
OP_ROLE_KEY: OpRole.Optimize
})
def _insert_memcpy_op(self, block, idx, src_name, dst_name, dst_place_type):
src_var = block.var(src_name)
dst_var = block.var(dst_name)
block._insert_op_without_sync(
idx,
type='memcpy',
inputs={'X': src_var},
outputs={'Out': dst_var},
attrs={
'dst_place_type': dst_place_type,
OP_ROLE_KEY: OpRole.Optimize,
})
def _insert_fetch_op(self, block, idx, src_name, dst_name):
self._insert_memcpy_op(block, idx, src_name, dst_name,
OffloadHelper.cuda_place_type)
def _insert_offload_op(self, block, idx, src_name, dst_name):
self._insert_memcpy_op(block, idx, src_name, dst_name,
OffloadHelper.cuda_pinned_place_type)
def _get_offload_var_name(self, name):
return unique_name.generate(name + '@offload')
def _create_offload_var(self, var_name, offload_var_name, blocks):
for block in blocks:
var = block.var(var_name)
var.persistable = False
offload_var = block.create_var(
name=offload_var_name,
shape=var.shape,
dtype=var.dtype,
persistable=True)
def offload_fp32param(self, block, startup_block):
"""
(p_fp16) = cast(p)
(p_fp16_recompute) = cast(p)
(pout,) = adam(p)
===========================>
rename(p_fp16_recompute, p_fp16)
(p,) = prefetch(p@offload)
(pout,) = adam(p)
(p_fp16) = cast(p)
(p@offload) = memcpy(p)
"""
param_to_idx = dict()
param_to_fp16 = dict()
# recompute_var which need rename to fp16_param
fp16_param_to_recompute = dict()
recompute_to_fp16 = dict()
def remove_param(input_name):
param_to_idx.pop(input_name)
if input_name in param_to_fp16:
fp16_param = param_to_fp16.pop(input_name)
if fp16_param in fp16_param_to_recompute:
recompute = fp16_param_to_recompute.pop(fp16_param)
recompute_to_fp16.pop(recompute)
# step1: record param
for idx, op in reversed(list(enumerate(block.ops))):
if op.type in ('adam', 'momentum', 'lars', 'lamb'):
param = op.desc.input("Param")[0]
param_to_idx[param] = idx
# step2: remove param which can't offload
for idx, op in enumerate(block.ops):
if is_optimizer_op(op):
break
for input_name in op.desc.input_arg_names():
if input_name not in param_to_idx:
continue
# param is real used by fp32 op
if op.type != 'cast':
remove_param(input_name)
continue
# param is only used by cast op,
# which to cast fp32_param to fp16_param
output_name = op.output_arg_names[0]
if 'cast_fp16' not in output_name:
remove_param(input_name)
continue
if 'subprog' not in output_name:
assert output_name == input_name + '.cast_fp16'
assert input_name not in param_to_fp16, \
"There must be only one cast op from fp32 param to fp16 param."
param_to_fp16[input_name] = output_name
else:
# fp16-->recompute_var
assert input_name in param_to_fp16, \
"param must first be cast to fp16"
fp16_param = param_to_fp16[input_name]
fp16_param_to_recompute[fp16_param] = output_name
recompute_to_fp16[output_name] = fp16_param
param_name_to_offload_name = dict()
# step3: main_block add offload, cast op
# change recompute to fp16, remove cast(param) to fp16
for idx, op in reversed(list(enumerate(block.ops))):
if op.type in ('adam', 'momentum', 'lars', 'lamb'):
param = op.desc.input("Param")[0]
if param not in param_to_idx: continue
# step3.1: create offload_var
offload_var_name = self._get_offload_var_name(param)
param_name_to_offload_name[param] = offload_var_name
self._create_offload_var(param, offload_var_name,
[block, startup_block])
# step3.2: insert cast op and offload op
self._insert_offload_op(block, idx + 1, param, offload_var_name)
assert param in param_to_fp16
fp16_param_name = param_to_fp16[param]
fp16_param_var = block.var(fp16_param_name)
fp16_param_var.persistable = True
self._insert_cast_op(block, idx + 1, param,
param_to_fp16[param])
# step3.3: insert fetch op
self._insert_fetch_op(block, idx, offload_var_name, param)
continue
# step3.4: remove cast op
if op.type == 'cast':
input_name = op.desc.input_arg_names()[0]
if input_name in param_to_idx:
block._remove_op(idx, sync=False)
continue
# step3.5: change recompute_param to fp16_param
for input_name in op.desc.input_arg_names():
if input_name in recompute_to_fp16:
op._rename_input(input_name, recompute_to_fp16[input_name])
for output_name in op.desc.output_arg_names():
if output_name in recompute_to_fp16:
op._rename_output(output_name,
recompute_to_fp16[output_name])
# step4: remove recompute_param
for name in recompute_to_fp16.keys():
block._remove_var(name, sync=False)
# step5: startup_block add offload
visited_vars = set()
for idx, op in reversed(list(enumerate(startup_block.ops))):
for out_name in op.output_arg_names:
if out_name in visited_vars:
continue
if out_name in param_name_to_offload_name:
var_name = out_name
offload_var_name = param_name_to_offload_name[var_name]
self._insert_offload_op(startup_block, idx + 1, var_name,
offload_var_name)
self._insert_cast_op(startup_block, idx + 1, var_name,
param_to_fp16[var_name])
visited_vars.add(out_name)
block._sync_with_cpp()
startup_block._sync_with_cpp()
def offload(self, block, startup_block):
"""
(m1, m2) = prefetch(m1@offload, m2@offload)
(m1out, m2out, pout) = adam(m1, m2, p)
(m1@offload, m2@offload) = memcpy(m1, m2)
"""
vars_name_to_offload_name = dict()
# main_block add offload
for idx, op in reversed(list(enumerate(block.ops))):
if not is_optimizer_op(op):
break
vars_name = []
if op.type == "adam":
# {Moment1Out = [''], Moment2Out = [''], ParamOut = ['']} =
# adam(inputs={Moment1 = [''], Moment2 = [''], Param = ['']})
vars_name.append(op.desc.input("Moment1")[0])
vars_name.append(op.desc.input("Moment2")[0])
elif op.type == 'momentum':
pass
elif op.type == 'lars':
pass
elif op.type == 'lamb':
pass
# step1: create and init offload_var
for var_name in vars_name:
assert var_name not in vars_name_to_offload_name
offload_var_name = self._get_offload_var_name(var_name)
vars_name_to_offload_name[var_name] = offload_var_name
self._create_offload_var(var_name, offload_var_name,
[block, startup_block])
# step2: insert offload op
for var_name in vars_name:
offload_var_name = vars_name_to_offload_name[var_name]
self._insert_offload_op(block, idx + 1, var_name,
offload_var_name)
# step3: insert fetch op
for var_name in vars_name:
offload_var_name = vars_name_to_offload_name[var_name]
self._insert_fetch_op(block, idx, offload_var_name, var_name)
# startup_block add offload
visited_vars = set()
for idx, op in reversed(list(enumerate(startup_block.ops))):
for out_name in op.output_arg_names:
if out_name in visited_vars:
continue
if out_name in vars_name_to_offload_name:
var_name = out_name
offload_var_name = vars_name_to_offload_name[var_name]
# insert offload op after var is generated
self._insert_offload_op(startup_block, idx + 1, var_name,
offload_var_name)
visited_vars.add(out_name)
block._sync_with_cpp()
startup_block._sync_with_cpp()
...@@ -126,6 +126,9 @@ class ProgramDeps(object): ...@@ -126,6 +126,9 @@ class ProgramDeps(object):
def should_remove_op(self, op_idx): def should_remove_op(self, op_idx):
op = self._block.ops[op_idx] op = self._block.ops[op_idx]
# remove check_finite_and_unscale op if its input 'X' is empty
if op.type == 'check_finite_and_unscale' and len(op.input('X')) == 0:
return True
for output_name in op.desc.output_arg_names(): for output_name in op.desc.output_arg_names():
if output_name not in self._should_removed_var: if output_name not in self._should_removed_var:
return False return False
......
...@@ -28,17 +28,20 @@ def check_broadcast(block): ...@@ -28,17 +28,20 @@ def check_broadcast(block):
if the broadcasted var has a fill_constant op, the fill_constant if the broadcasted var has a fill_constant op, the fill_constant
op should stay forward before the broadcast op, and before a op should stay forward before the broadcast op, and before a
sync_calc op. Otherwise, raise error. sync_calc op. Otherwise, raise error.
should ignore and skip broadcast_op of inner_parallelism (e.g. Megatron)
""" """
broadcast_vars = {} broadcast_vars = {}
for idx, op in enumerate(block.ops): for idx, op in enumerate(block.ops):
if op.type == "c_broadcast": if op.type == "c_broadcast":
if op.all_attrs()["use_calc_stream"] == False:
var_name = op.desc.input_arg_names()[0] var_name = op.desc.input_arg_names()[0]
if "@BroadCast" in var_name: if "@BroadCast" in var_name:
if var_name in broadcast_vars: if var_name in broadcast_vars:
raise ValueError("var_name areadly exist: {}" raise ValueError("var_name areadly exist: {}"
"the old pos is {}, the new pos is {}". "the old pos is {}, the new pos is {}".
format(var_name, broadcast_vars[var_name][ format(var_name, broadcast_vars[
"broadcast_pos"], idx)) var_name]["broadcast_pos"], idx))
broadcast_vars[var_name] = { broadcast_vars[var_name] = {
"fill_constant_pos": -1, "fill_constant_pos": -1,
"broadcast_pos": idx, "broadcast_pos": idx,
...@@ -61,6 +64,7 @@ def check_broadcast(block): ...@@ -61,6 +64,7 @@ def check_broadcast(block):
last_sync_calc_op_idx = idx last_sync_calc_op_idx = idx
continue continue
if op.type == "c_broadcast": if op.type == "c_broadcast":
if op.all_attrs()["use_calc_stream"] == False:
var_name = op.desc.input_arg_names()[0] var_name = op.desc.input_arg_names()[0]
if "@BroadCast" in var_name: if "@BroadCast" in var_name:
if broadcast_vars[var_name]["fill_constant_pos"] != -1: if broadcast_vars[var_name]["fill_constant_pos"] != -1:
...@@ -78,25 +82,29 @@ def check_broadcast(block): ...@@ -78,25 +82,29 @@ def check_broadcast(block):
return return
def check_allreduce_sum(block, shard, dp_ring_id=-1): def check_allreduce_sum(block, shard, sharding_ring_id, dp_ring_id=-1):
""" """
the op order should be: the op order should be:
grad: grad:
- 0: op that generate Var - 0: op that generate Var
- 1: sync_calc - 1: sync_calc
- 2: allreduce_sum_sharding - 2: reduce_sum_sharding (allreduce --> reduce)
- 3: sync_comm - 3: sync_comm
- 4: allreuce_sum_dp (dp_grads) - 4: allreuce_sum_dp (dp_grads)
- 5: sync_comm (dp_grads) - 5: sync_comm (dp_grads)
- 6: op that use Var (dp_grads & sum) - 6: op that use Var (dp_grads & sum)
should ignore and skip allreduce_op of inner_parallelism (e.g. Megatron)
""" """
vars_status = {} vars_status = {}
dp_grads_status = {} dp_grads_status = {}
idx_last_grad_allreduce = -1 idx_last_grad_allreduce = -1
idx_amp_allreduce = -1 idx_amp_allreduce = -1
idx_gradient_clip_allreduce = -1 idx_gradient_clip_allreduce = -1
for idx, op in enumerate(block.ops): for idx, op in enumerate(block.ops):
if op.type == "c_allreduce_sum": if op.type == "c_allreduce_sum" or op.type == "c_reduce_sum":
if op.all_attrs()["use_calc_stream"] == False:
ring_id = op.desc.attr("ring_id") ring_id = op.desc.attr("ring_id")
var_name = op.desc.input_arg_names()[0] var_name = op.desc.input_arg_names()[0]
param = var_name.split("@")[0] param = var_name.split("@")[0]
...@@ -107,7 +115,7 @@ def check_allreduce_sum(block, shard, dp_ring_id=-1): ...@@ -107,7 +115,7 @@ def check_allreduce_sum(block, shard, dp_ring_id=-1):
else: else:
dp_grads_status[var_name] = -1 dp_grads_status[var_name] = -1
if ring_id != 0: if ring_id != sharding_ring_id:
assert shard.has_param(param) assert shard.has_param(param)
assert ring_id == dp_ring_id assert ring_id == dp_ring_id
...@@ -129,17 +137,20 @@ def check_allreduce_sum(block, shard, dp_ring_id=-1): ...@@ -129,17 +137,20 @@ def check_allreduce_sum(block, shard, dp_ring_id=-1):
var_name] == 0: var_name] == 0:
dp_grads_status[var_name] = 1 dp_grads_status[var_name] = 1
elif op.type == "c_allreduce_sum": elif op.type == "c_allreduce_sum" or op.type == "c_reduce_sum":
if op.all_attrs()["use_calc_stream"] == False:
var_name = op.desc.input_arg_names()[0] var_name = op.desc.input_arg_names()[0]
ring_id = op.desc.attr("ring_id") ring_id = op.desc.attr("ring_id")
if ring_id == 0: if ring_id == sharding_ring_id:
assert op.type == "c_reduce_sum", "Grad in Sharding group should be reduce rather than allreduce"
if var_name in vars_status: if var_name in vars_status:
_status = vars_status[var_name] _status = vars_status[var_name]
else: else:
_status = dp_grads_status[var_name] _status = dp_grads_status[var_name]
if _status == -1: if _status == -1:
raise ValueError("{} is not generated, but you are" raise ValueError("{} is not generated, but you are"
"trying to all-reduce it".format(var_name)) "trying to all-reduce it".format(
var_name))
if _status == 0: if _status == 0:
raise ValueError("There should be a sync_calc op " raise ValueError("There should be a sync_calc op "
"after generate Var: {} and before the" "after generate Var: {} and before the"
...@@ -159,7 +170,7 @@ def check_allreduce_sum(block, shard, dp_ring_id=-1): ...@@ -159,7 +170,7 @@ def check_allreduce_sum(block, shard, dp_ring_id=-1):
elif op.type == "c_sync_comm_stream": elif op.type == "c_sync_comm_stream":
var_name = op.desc.input_arg_names()[0] var_name = op.desc.input_arg_names()[0]
ring_id = op.desc.attr("ring_id") ring_id = op.desc.attr("ring_id")
if ring_id == 0: if ring_id == sharding_ring_id:
for var_name in op.desc.input_arg_names(): for var_name in op.desc.input_arg_names():
if var_name in vars_status: if var_name in vars_status:
assert vars_status[var_name] == 2 assert vars_status[var_name] == 2
...@@ -181,6 +192,9 @@ def check_allreduce_sum(block, shard, dp_ring_id=-1): ...@@ -181,6 +192,9 @@ def check_allreduce_sum(block, shard, dp_ring_id=-1):
raise ValueError("There should be a sync_comm op " raise ValueError("There should be a sync_comm op "
"after allreduce the Var: {}".format( "after allreduce the Var: {}".format(
input_name)) input_name))
raise ValueError(
"The reduce output grad [{}] should NOT be be used in Non-root rank.".
format(input_name))
if input_name in dp_grads_status: if input_name in dp_grads_status:
if dp_ring_id == -1: if dp_ring_id == -1:
if dp_grads_status[input_name] != 3: if dp_grads_status[input_name] != 3:
...@@ -225,6 +239,13 @@ def get_valid_op_role(block, insert_idx): ...@@ -225,6 +239,13 @@ def get_valid_op_role(block, insert_idx):
return get_valid_op_role(block, insert_idx + 1) return get_valid_op_role(block, insert_idx + 1)
# if insert_idx >= len(block.ops): return OpRole.Optimize
# if op_role == int(OpRole.Backward): return OpRole.Backward
# if op_role == int(OpRole.Optimize): return OpRole.Optimize
# if op_role in [int(OpRole.Forward), int(OpRole.Loss)]:
# return OpRole.Forward
# return get_valid_op_role(block, insert_idx + 1)
def insert_sync_calc_op(block, insert_idx, calc_dep_vars): def insert_sync_calc_op(block, insert_idx, calc_dep_vars):
""" """
...@@ -259,6 +280,9 @@ def insert_sync_comm_ops(block, insert_idx, ring_id, comm_dep_vars): ...@@ -259,6 +280,9 @@ def insert_sync_comm_ops(block, insert_idx, ring_id, comm_dep_vars):
""" """
insert sync_comm_op for vars insert sync_comm_op for vars
""" """
if len(comm_dep_vars) == 0:
return 0
op_role = get_valid_op_role(block, insert_idx) op_role = get_valid_op_role(block, insert_idx)
block._insert_op_without_sync( block._insert_op_without_sync(
insert_idx, insert_idx,
...@@ -313,6 +337,9 @@ def insert_allreduce_ops(block, insert_idx, ring_id, allreduce_vars): ...@@ -313,6 +337,9 @@ def insert_allreduce_ops(block, insert_idx, ring_id, allreduce_vars):
""" """
_add_allreduce_ops _add_allreduce_ops
""" """
if len(allreduce_vars) == 0:
return
for var in allreduce_vars: for var in allreduce_vars:
block._insert_op_without_sync( block._insert_op_without_sync(
insert_idx, insert_idx,
...@@ -325,6 +352,62 @@ def insert_allreduce_ops(block, insert_idx, ring_id, allreduce_vars): ...@@ -325,6 +352,62 @@ def insert_allreduce_ops(block, insert_idx, ring_id, allreduce_vars):
return return
def get_grad_device(grad_name, shard):
assert "@GRAD" in grad_name, "[{}] should be a grad variable.".format(
grad_name)
base_name = None
# mind the traversal order
possible_suffixes = [
'.cast_fp16@GRAD@MERGED', '.cast_fp16@GRAD', '@GRAD@MERGED', '@GRAD'
]
for suffix in possible_suffixes:
if suffix in grad_name:
base_name = re.sub(suffix, '', grad_name)
break
assert base_name in shard.global_param2device, "[{}] should be a param variable.".format(
base_name)
return shard.global_param2device[base_name]
def insert_reduce_ops(block,
insert_idx,
ring_id,
reduce_vars,
shard,
op_role=OpRole.Backward,
use_calc_stream=False):
"""
_add_allreduce_ops
"""
for var in reduce_vars:
root_id = get_grad_device(var, shard)
assert root_id >= 0, "root id should be a positive int".format(var)
block._insert_op_without_sync(
insert_idx,
type='c_reduce_sum',
inputs={'X': var},
outputs={'Out': var},
attrs={
'ring_id': ring_id,
'root_id': root_id,
'use_calc_stream': use_calc_stream,
OP_ROLE_KEY: op_role
})
return
def get_first_check_finite_and_unscale_op_idx(block):
for idx, op in enumerate(block.ops):
if op.type == "check_finite_and_unscale":
return idx
raise ValueError("check_finite_and_unscale does not exist in block")
def insert_broadcast_ops(block, insert_idx, ring_id, broadcast2root): def insert_broadcast_ops(block, insert_idx, ring_id, broadcast2root):
""" """
_add_broadcast_ops _add_broadcast_ops
...@@ -428,7 +511,7 @@ def comm_analyse(main_program): ...@@ -428,7 +511,7 @@ def comm_analyse(main_program):
count)) count))
def add_sync_comm(program, dist_strategy): def add_sync_comm(program, nccl_ids):
""" """
When clone a test prog by clone from the sharding main prog, When clone a test prog by clone from the sharding main prog,
part of the sync_comm op maybe be pruned by mistake, this function part of the sync_comm op maybe be pruned by mistake, this function
...@@ -438,6 +521,9 @@ def add_sync_comm(program, dist_strategy): ...@@ -438,6 +521,9 @@ def add_sync_comm(program, dist_strategy):
#NOTE (liangjianzhong): only support one comm stream by now, use more than one #NOTE (liangjianzhong): only support one comm stream by now, use more than one
# comm streams will cause error. should be revise in future. # comm streams will cause error. should be revise in future.
assert isinstance(
nccl_ids, list
), "the second argument of this function should be a list of nccl_ids"
block = program.global_block() block = program.global_block()
not_sync_vars = set([]) not_sync_vars = set([])
for op in block.ops: for op in block.ops:
...@@ -448,7 +534,7 @@ def add_sync_comm(program, dist_strategy): ...@@ -448,7 +534,7 @@ def add_sync_comm(program, dist_strategy):
for input_name in op.desc.input_arg_names(): for input_name in op.desc.input_arg_names():
not_sync_vars.remove(input_name) not_sync_vars.remove(input_name)
if not_sync_vars: if not_sync_vars:
for nccl_id in range(dist_strategy.nccl_comm_num): for nccl_id in nccl_ids:
block.append_op( block.append_op(
type='c_sync_comm_stream', type='c_sync_comm_stream',
inputs={'X': list(not_sync_vars)}, inputs={'X': list(not_sync_vars)},
...@@ -467,6 +553,9 @@ def save_persistables(exe, dirname, main_program, filename=None): ...@@ -467,6 +553,9 @@ def save_persistables(exe, dirname, main_program, filename=None):
This function handles the model saving for sharding training. This function handles the model saving for sharding training.
""" """
if main_program._pipeline_opt:
main_program = main_program._pipeline_opt['section_program']['program']
def is_opt_vars(var): def is_opt_vars(var):
# NOTE(liangjianzhong): The checks should be updated when add new compatible optimizer # NOTE(liangjianzhong): The checks should be updated when add new compatible optimizer
# now only Momentum and adam are compatible with sharding # now only Momentum and adam are compatible with sharding
......
...@@ -115,7 +115,7 @@ class ProgramStats(object): ...@@ -115,7 +115,7 @@ class ProgramStats(object):
updated_min_idx = min_idx updated_min_idx = min_idx
while idx_ > pre_segment_end_idx: while idx_ > pre_segment_end_idx:
if is_amp_cast(self.ops[idx_]): if is_amp_cast(self.ops[idx_]):
_logger.debug("found amp-cast op: {}, : {}".format(self.ops[ _logger.info("found amp-cast op: {}, : {}".format(self.ops[
idx_].desc.type(), self.ops[idx_].desc.input_arg_names()[ idx_].desc.type(), self.ops[idx_].desc.input_arg_names()[
0])) 0]))
updated_min_idx = idx_ updated_min_idx = idx_
...@@ -155,7 +155,7 @@ class ProgramStats(object): ...@@ -155,7 +155,7 @@ class ProgramStats(object):
sorted_checkpoints = [] sorted_checkpoints = []
for name in checkpoints_name: for name in checkpoints_name:
if name not in self.var_op_deps: if name not in self.var_op_deps:
_logger.debug( _logger.info(
"Recompute Optimizer: deleted %s from checkpoints, because it is not used in paddle program." "Recompute Optimizer: deleted %s from checkpoints, because it is not used in paddle program."
% name) % name)
elif self.var_op_deps[name]["var_as_output_ops"] == []: elif self.var_op_deps[name]["var_as_output_ops"] == []:
...@@ -233,6 +233,8 @@ def _add_needed_descs_to_block(descs, block, main_block, in_memory_vars): ...@@ -233,6 +233,8 @@ def _add_needed_descs_to_block(descs, block, main_block, in_memory_vars):
new_op_desc = block.desc.append_op() new_op_desc = block.desc.append_op()
new_op_desc.copy_from(desc) new_op_desc.copy_from(desc)
new_op_desc._set_attr(op_role_attr_name, backward) new_op_desc._set_attr(op_role_attr_name, backward)
if desc.has_attr('op_device'):
new_op_desc._set_attr('op_device', desc.attr('op_device'))
result_descs.append(new_op_desc) result_descs.append(new_op_desc)
return result_descs return result_descs
...@@ -252,6 +254,8 @@ def _add_descs_to_block(descs, block): ...@@ -252,6 +254,8 @@ def _add_descs_to_block(descs, block):
new_op_desc = block.desc.append_op() new_op_desc = block.desc.append_op()
new_op_desc.copy_from(desc) new_op_desc.copy_from(desc)
new_op_desc._set_attr(op_role_attr_name, backward) new_op_desc._set_attr(op_role_attr_name, backward)
if desc.has_attr('op_device'):
new_op_desc._set_attr('op_device', desc.attr('op_device'))
result_descs.append(new_op_desc) result_descs.append(new_op_desc)
return result_descs return result_descs
...@@ -784,7 +788,6 @@ def _append_backward_ops_with_checkpoints_( ...@@ -784,7 +788,6 @@ def _append_backward_ops_with_checkpoints_(
start_idx = 0 start_idx = 0
pre_segment_end_idx = -1 pre_segment_end_idx = -1
while True: while True:
_logger.debug("FW op range[0] - [{}]".format(len(ops)))
if start_idx >= len(checkpoints_name) - 1: if start_idx >= len(checkpoints_name) - 1:
break break
# min_idx: checkpoint_1' s input op # min_idx: checkpoint_1' s input op
...@@ -797,6 +800,9 @@ def _append_backward_ops_with_checkpoints_( ...@@ -797,6 +800,9 @@ def _append_backward_ops_with_checkpoints_(
min_idx = program_stat._update_segment_start( min_idx = program_stat._update_segment_start(
min_idx, pre_segment_end_idx) min_idx, pre_segment_end_idx)
segments.append([min_idx, max_idx + 1]) segments.append([min_idx, max_idx + 1])
else:
_logger.info("Could not recompute op range [{}] - [{}] ".format(
min_idx, max_idx + 1))
start_idx += 1 start_idx += 1
...@@ -806,15 +812,15 @@ def _append_backward_ops_with_checkpoints_( ...@@ -806,15 +812,15 @@ def _append_backward_ops_with_checkpoints_(
recompute_segments = segments recompute_segments = segments
for i, (idx1, idx2) in enumerate(recompute_segments): for i, (idx1, idx2) in enumerate(recompute_segments):
_logger.debug("recompute segment[{}]".format(i)) _logger.info("recompute segment[{}]".format(i))
_logger.debug("segment start op: [{}]: [{}]".format(ops[idx1].desc.type( _logger.info("segment start op: [{}]: [{}]".format(ops[idx1].desc.type(
), ops[idx1].desc.input_arg_names())) ), ops[idx1].desc.input_arg_names()))
_logger.debug("segment end op: [{}]: [{}]".format(ops[ _logger.info("segment end op: [{}]: [{}]".format(ops[
idx2 - 1].desc.type(), ops[idx2 - 1].desc.input_arg_names())) idx2 - 1].desc.type(), ops[idx2 - 1].desc.input_arg_names()))
_logger.debug("recompute segment[{}]".format(i)) _logger.info("recompute segment[{}]".format(i))
_logger.debug("segment start op: [{}]: [{}]".format(ops[idx1].desc.type( _logger.info("segment start op: [{}]: [{}]".format(ops[idx1].desc.type(
), ops[idx1].desc.input_arg_names())) ), ops[idx1].desc.input_arg_names()))
_logger.debug("segment end op: [{}]: [{}]".format(ops[ _logger.info("segment end op: [{}]: [{}]".format(ops[
idx2 - 1].desc.type(), ops[idx2 - 1].desc.input_arg_names())) idx2 - 1].desc.type(), ops[idx2 - 1].desc.input_arg_names()))
# 2) go through all forward ops and induct all variables that will be hold in memory # 2) go through all forward ops and induct all variables that will be hold in memory
...@@ -825,9 +831,9 @@ def _append_backward_ops_with_checkpoints_( ...@@ -825,9 +831,9 @@ def _append_backward_ops_with_checkpoints_(
program_stat.get_out_of_subgraph_vars(segment[0], segment[1])) program_stat.get_out_of_subgraph_vars(segment[0], segment[1]))
cross_vars = set(vars_should_be_hold) - set(checkpoints_name) cross_vars = set(vars_should_be_hold) - set(checkpoints_name)
_logger.debug("found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( \ _logger.info("found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( \
len(cross_vars), cross_vars)) len(cross_vars), cross_vars))
_logger.debug("found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( \ _logger.info("found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( \
len(cross_vars), cross_vars)) len(cross_vars), cross_vars))
# b. output of seed op should be kept in memory # b. output of seed op should be kept in memory
...@@ -843,6 +849,7 @@ def _append_backward_ops_with_checkpoints_( ...@@ -843,6 +849,7 @@ def _append_backward_ops_with_checkpoints_(
vars_in_memory = vars_should_be_hold + checkpoints_name vars_in_memory = vars_should_be_hold + checkpoints_name
max_calculated_op_position = len(ops) max_calculated_op_position = len(ops)
device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
if recompute_segments == []: if recompute_segments == []:
gap_ops = ops[0:max_calculated_op_position] gap_ops = ops[0:max_calculated_op_position]
for op in reversed(gap_ops): for op in reversed(gap_ops):
...@@ -852,6 +859,11 @@ def _append_backward_ops_with_checkpoints_( ...@@ -852,6 +859,11 @@ def _append_backward_ops_with_checkpoints_(
_pretty_op_desc_(op.desc, "with_sub_block")) _pretty_op_desc_(op.desc, "with_sub_block"))
grad_op_desc, op_grad_to_var = core.get_grad_op_desc( grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
op.desc, cpt.to_text(no_grad_dict[block.idx]), []) op.desc, cpt.to_text(no_grad_dict[block.idx]), [])
# Set device for grad_op according to forward Op
if op.desc.has_attr(device_attr_name):
op_device = op.desc.attr(device_attr_name)
for op_desc in grad_op_desc:
op_desc._set_attr(device_attr_name, op_device)
added_descs = _add_descs_to_block(grad_op_desc, local_block) added_descs = _add_descs_to_block(grad_op_desc, local_block)
grad_op_descs.extend(added_descs) grad_op_descs.extend(added_descs)
grad_to_var.update(op_grad_to_var) grad_to_var.update(op_grad_to_var)
...@@ -866,6 +878,11 @@ def _append_backward_ops_with_checkpoints_( ...@@ -866,6 +878,11 @@ def _append_backward_ops_with_checkpoints_(
_pretty_op_desc_(op.desc, "with_sub_block")) _pretty_op_desc_(op.desc, "with_sub_block"))
grad_op_desc, op_grad_to_var = core.get_grad_op_desc( grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
op.desc, cpt.to_text(no_grad_dict[block.idx]), []) op.desc, cpt.to_text(no_grad_dict[block.idx]), [])
# Set device for grad_op according to forward Op
if op.desc.has_attr(device_attr_name):
op_device = op.desc.attr(device_attr_name)
for op_desc in grad_op_desc:
op_desc._set_attr(device_attr_name, op_device)
added_descs = _add_descs_to_block(grad_op_desc, local_block) added_descs = _add_descs_to_block(grad_op_desc, local_block)
grad_op_descs.extend(added_descs) grad_op_descs.extend(added_descs)
grad_to_var.update(op_grad_to_var) grad_to_var.update(op_grad_to_var)
...@@ -888,6 +905,18 @@ def _append_backward_ops_with_checkpoints_( ...@@ -888,6 +905,18 @@ def _append_backward_ops_with_checkpoints_(
continue continue
if name not in var_name_dict: if name not in var_name_dict:
var_name_dict[name] = name + var_suffix var_name_dict[name] = name + var_suffix
# we should create the rename var in subprog, otherwise its VarType will be BOOL
block.create_var(
name=var_name_dict[name],
shape=block.program.global_block().var(name).shape,
dtype=block.program.global_block().var(name).dtype,
type=block.program.global_block().var(name).type,
persistable=block.program.global_block().var(
name).persistable,
stop_gradient=block.program.global_block().var(name)
.stop_gradient)
# 3.a. add ops in current recompute_segment as forward recomputation ops # 3.a. add ops in current recompute_segment as forward recomputation ops
buffer_descs = _add_needed_descs_to_block(ff_ops, buffer_block, block, buffer_descs = _add_needed_descs_to_block(ff_ops, buffer_block, block,
vars_in_memory) vars_in_memory)
......
...@@ -489,9 +489,14 @@ class ClipGradByGlobalNorm(ClipGradBase): ...@@ -489,9 +489,14 @@ class ClipGradByGlobalNorm(ClipGradBase):
continue continue
with p.block.program._optimized_guard([p, g]): with p.block.program._optimized_guard([p, g]):
new_grad = layers.elementwise_mul(x=g, y=scale_var) p.block.append_op(
param_new_grad_name_dict[p.name] = new_grad.name type='elementwise_mul',
params_and_grads.append((p, new_grad)) inputs={'X': g,
'Y': scale_var},
outputs={'Out': g})
param_new_grad_name_dict[p.name] = p.name
params_and_grads.append((p, p))
_correct_clip_op_role_var(params_and_grads, param_new_grad_name_dict) _correct_clip_op_role_var(params_and_grads, param_new_grad_name_dict)
return params_and_grads return params_and_grads
......
...@@ -123,7 +123,8 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype): ...@@ -123,7 +123,8 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
outputs={"Out": out_var}, outputs={"Out": out_var},
attrs={ attrs={
"in_dtype": in_var.dtype, "in_dtype": in_var.dtype,
"out_dtype": out_var.dtype "out_dtype": out_var.dtype,
"op_device": op.attr("op_device")
}) })
num_cast_ops += 1 num_cast_ops += 1
_rename_arg(op, in_var.name, out_var.name) _rename_arg(op, in_var.name, out_var.name)
...@@ -171,8 +172,11 @@ def _insert_cast_post_op(block, op, idx, src_dtype, dest_dtype, target_name, ...@@ -171,8 +172,11 @@ def _insert_cast_post_op(block, op, idx, src_dtype, dest_dtype, target_name,
type="cast", type="cast",
inputs={"X": target_var}, inputs={"X": target_var},
outputs={"Out": cast_var}, outputs={"Out": cast_var},
attrs={"in_dtype": target_var.dtype, attrs={
"out_dtype": cast_var.dtype}) "in_dtype": target_var.dtype,
"out_dtype": cast_var.dtype,
"op_device": op.attr("op_device")
})
num_cast_ops += 1 num_cast_ops += 1
op_var_rename_map[block.idx][target_var.name] = cast_var.name op_var_rename_map[block.idx][target_var.name] = cast_var.name
......
...@@ -413,6 +413,9 @@ class Section(DeviceWorker): ...@@ -413,6 +413,9 @@ class Section(DeviceWorker):
section_param = trainer_desc.section_param section_param = trainer_desc.section_param
section_param.num_microbatches = pipeline_opt["num_microbatches"] section_param.num_microbatches = pipeline_opt["num_microbatches"]
section_param.start_cpu_core_id = pipeline_opt["start_cpu_core_id"] section_param.start_cpu_core_id = pipeline_opt["start_cpu_core_id"]
section_param.pipeline_stage = pipeline_opt["pipeline_stage"]
section_param.num_pipeline_stages = pipeline_opt["num_pipeline_stages"]
section_param.schedule_mode = pipeline_opt["schedule_mode"]
cfg = section_param.section_config cfg = section_param.section_config
program = pipeline_opt["section_program"] program = pipeline_opt["section_program"]
cfg.program_desc.ParseFromString(program["program"]._get_desc() cfg.program_desc.ParseFromString(program["program"]._get_desc()
......
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册