Unverified commit 228bce12, authored by lilong12, committed by GitHub

Add 3d parallelism (#31796)

Add 3d Parallelism
Co-authored-by: WangXi <wangxi16@baidu.com>
Co-authored-by: JZ-LIANG <jianzhongliang10@gmail.com>
Co-authored-by: root <root@yq01-sys-hic-k8s-v100-box-a225-0562.yq01.baidu.com>
Parent 594bbcb1
......@@ -28,6 +28,7 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/heter_service.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
......@@ -451,7 +452,7 @@ class HeterBoxWorker : public HogwildWorker {
virtual void CacheProgram(const ProgramDesc& main_program) {
new (&program_) ProgramDesc(main_program);
}
virtual void ProduceTasks() override;
void ProduceTasks() override;
virtual void SetStream(const cudaStream_t stream) { copy_stream_ = stream; }
virtual void SetEvent(const cudaEvent_t event) { event_ = event; }
virtual void TrainFilesWithProfiler() {}
......@@ -550,7 +551,7 @@ class PSGPUWorker : public HogwildWorker {
virtual void CacheProgram(const ProgramDesc& main_program) {
new (&program_) ProgramDesc(main_program);
}
virtual void ProduceTasks() override;
void ProduceTasks() override;
virtual void SetStream(const cudaStream_t stream) { copy_stream_ = stream; }
virtual void SetEvent(const cudaEvent_t event) { event_ = event; }
virtual void TrainFilesWithProfiler() {}
......@@ -654,6 +655,9 @@ class SectionWorker : public DeviceWorker {
void SetDeviceIndex(int tid) override {}
void SetThreadIndex(int thread_id) { thread_id_ = thread_id; }
void SetMicrobatchNum(int num) { num_microbatches_ = num; }
void SetPipelineStageNum(int num) { num_pipeline_stages_ = num; }
void SetPipelineStage(int stage) { pipeline_stage_ = stage; }
void SetScheduleMode(int mode) { schedule_mode_ = mode; }
void SetMicrobatchScopes(const std::vector<Scope*>& scope) {
microbatch_scopes_ = scope;
}
......@@ -661,11 +665,23 @@ class SectionWorker : public DeviceWorker {
void SetSkipVars(const std::vector<std::string>& skip_vars) {
skip_vars_ = skip_vars;
}
void RunBackward(
int micro_id, std::unique_ptr<GarbageCollector>&,
std::unordered_map<const OperatorBase*, std::vector<std::string>>&);
void RunForward(
int micro_id, std::unique_ptr<GarbageCollector>&,
std::unordered_map<const OperatorBase*, std::vector<std::string>>&);
void RunUpdate(
std::unique_ptr<GarbageCollector>&,
std::unordered_map<const OperatorBase*, std::vector<std::string>>&);
protected:
int section_id_;
int thread_id_;
int num_microbatches_;
int num_pipeline_stages_;
int pipeline_stage_;
int schedule_mode_;  // 0: F-then-B (GPipe-like), 1: 1F1B (DeepSpeed-like)
std::vector<Scope*> microbatch_scopes_;
std::vector<std::string> skip_vars_;
const Scope* minibatch_scope_;
......
......@@ -32,6 +32,14 @@ message ShardingConfig {
optional float fuse_broadcast_MB = 1 [ default = 32.0 ];
optional bool hybrid_dp = 2 [ default = false ];
optional int32 sharding_group_size = 3 [ default = 8 ];
optional bool as_outer_parallelism = 4 [ default = false ];
optional int32 parallelism = 5 [ default = 1 ];
optional bool use_pipeline = 6 [ default = false ];
optional int32 acc_steps = 7 [ default = 1 ];
optional int32 schedule_mode = 8 [ default = 0 ];
optional int32 pp_bz = 9 [ default = 1 ];
optional bool pp_allreduce_in_optimize = 10 [ default = false ];
optional bool optimize_offload = 11 [ default = false ];
}
message AMPConfig {
......@@ -44,6 +52,8 @@ message AMPConfig {
repeated string custom_white_list = 7;
repeated string custom_black_list = 8;
repeated string custom_black_varnames = 9;
optional bool use_pure_fp16 = 10 [ default = false ];
optional bool use_fp16_guard = 11 [ default = true ];
}
message LocalSGDConfig {
......@@ -117,6 +127,8 @@ message AsyncConfig {
message PipelineConfig { optional int32 micro_batch = 1 [ default = 1 ]; }
message ModelParallelConfig { optional int32 parallelism = 1 [ default = 1 ]; }
message DistributedStrategy {
// bool options
optional Mode mode = 1 [ default = COLLECTIVE ];
......@@ -140,12 +152,13 @@ message DistributedStrategy {
optional int32 fuse_grad_size_in_MB = 19 [ default = 32 ];
optional float fuse_grad_size_in_TFLOPS = 20 [ default = 50 ];
optional bool cudnn_exhaustive_search = 21 [ default = true ];
optional int32 conv_workspace_size_limit = 22 [ default = 4000 ];
optional int32 conv_workspace_size_limit = 22 [ default = 512 ];
optional bool cudnn_batchnorm_spatial_persistent = 23 [ default = true ];
optional bool adaptive_localsgd = 24 [ default = false ];
optional bool fp16_allreduce = 25 [ default = false ];
optional bool sharding = 26 [ default = false ];
optional float last_comm_group_size_MB = 27 [ default = 1 ];
optional bool model_parallel = 28 [ default = false ];
optional RecomputeConfig recompute_configs = 101;
optional AMPConfig amp_configs = 102;
......@@ -158,6 +171,7 @@ message DistributedStrategy {
optional LambConfig lamb_configs = 109;
optional AdaptiveLocalSGDConfig adaptive_localsgd_configs = 110;
optional ShardingConfig sharding_configs = 111;
optional ModelParallelConfig model_parallel_configs = 112;
optional BuildStrategy build_strategy = 201;
optional ExecutionStrategy execution_strategy = 202;
}
......
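These proto fields surface on the Python side through fleet.DistributedStrategy (see the distributed_strategy.py changes below). A minimal sketch, with illustrative values only, of a job that exercises the new options:

import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()

# sharding as the outer parallelism around an inner (Megatron-style) group
strategy.sharding = True
strategy.sharding_configs = {
    "sharding_group_size": 8,
    "as_outer_parallelism": True,
    "parallelism": 2,            # size of the inner model-parallel group
    "use_pipeline": True,
    "acc_steps": 4,              # micro-batches accumulated per step
    "schedule_mode": 1,          # 0: F-then-B (GPipe-like), 1: 1F1B
    "pp_allreduce_in_optimize": False,
    "optimize_offload": False,
}

# pure-fp16 AMP enabled by the new AMPConfig fields
strategy.amp = True
strategy.amp_configs = {
    "init_loss_scaling": 32768,
    "use_pure_fp16": True,
    "use_fp16_guard": True,
}

# the new model-parallel switch and its config
strategy.model_parallel = True
strategy.model_parallel_configs = {"parallelism": 2}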
......@@ -25,6 +25,9 @@ namespace framework {
void PipelineTrainer::Initialize(const TrainerDesc& trainer_desc,
Dataset* dataset) {
const auto& section_params = trainer_desc.section_param();
const auto num_pipeline_stages_ = section_params.num_pipeline_stages();
const auto pipeline_stage_ = section_params.pipeline_stage();
const auto schedule_mode_ = section_params.schedule_mode();
num_microbatches_ = section_params.num_microbatches();
VLOG(3) << "Number of microbatches per minibatch: " << num_microbatches_;
trainer_desc_ = trainer_desc;
......@@ -40,6 +43,9 @@ void PipelineTrainer::Initialize(const TrainerDesc& trainer_desc,
this_worker->SetPlace(place_);
this_worker->Initialize(trainer_desc);
this_worker->SetMicrobatchNum(num_microbatches_);
this_worker->SetPipelineStageNum(num_pipeline_stages_);
this_worker->SetPipelineStage(pipeline_stage_);
this_worker->SetScheduleMode(schedule_mode_);
}
void PipelineTrainer::InitOtherEnv(const ProgramDesc& main_program) {
......@@ -76,7 +82,10 @@ void PipelineTrainer::CopyParameters(int microbatch_id,
for (auto& var : global_block.AllVars()) {
bool is_param_grad = false;
size_t pos = 0;
if ((pos = var->Name().find(kGradVarSuffix)) != std::string::npos) {
// A magic suffix to indicate the merged gradient.
std::string magicSuffix = "MERGED";
if ((pos = var->Name().find(kGradVarSuffix)) != std::string::npos &&
var->Name().find(magicSuffix) != std::string::npos) {
auto prefix_name = var->Name().substr(0, pos);
if (param_map.find(prefix_name) != param_map.end()) {
is_param_grad = true;
......
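A small, hedged Python illustration (variable names are hypothetical) of the filter CopyParameters now applies: a variable counts as a parameter gradient only if its name carries both the @GRAD suffix and the MERGED marker and its prefix matches a known parameter.

def is_merged_param_grad(var_name, params):
    GRAD_SUFFIX = "@GRAD"            # kGradVarSuffix on the C++ side
    pos = var_name.find(GRAD_SUFFIX)
    if pos == -1 or "MERGED" not in var_name:
        return False
    return var_name[:pos] in params

params = {"linear_0.w_0"}
print(is_merged_param_grad("linear_0.w_0@GRAD@MERGED", params))  # True
print(is_merged_param_grad("linear_0.w_0@GRAD", params))         # False, not merged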
......@@ -11,36 +11,90 @@ limitations under the License. */
#if defined(PADDLE_WITH_NCCL)
#include <float.h>
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/program_desc.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"
#include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/framework/fleet/box_wrapper.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/trainer_desc.pb.h"
#include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/lodtensor_printer.h"
namespace paddle {
namespace framework {
class TrainerDesc;
uint64_t SectionWorker::batch_id_(0);
void SectionWorker::Initialize(const TrainerDesc& desc) {
void SectionWorker::Initialize(const TrainerDesc &desc) {
dev_ctx_ = platform::DeviceContextPool::Instance().Get(place_);
program_.reset(
new ProgramDesc(desc.section_param().section_config().program_desc()));
for (auto& op_desc : program_->Block(0).AllOps()) {
for (auto &op_desc : program_->Block(0).AllOps()) {
ops_.push_back(OpRegistry::CreateOp(*op_desc));
}
}
void SectionWorker::RunForward(
int micro_id, std::unique_ptr<GarbageCollector> &gc,
std::unordered_map<const OperatorBase *, std::vector<std::string>>
&unused_vars_) {
for (auto &op : ops_) {
int op_role = op->Attr<int>(std::string("op_role"));
// We run op with op_role = kLRSched only for the first microbatch
// to avoid increasing the @LR_DECAY_STEP@ multiple times.
bool run_first_mbatch = op_role == static_cast<int>(OpRole::kForward) ||
op_role == (static_cast<int>(OpRole::kForward) |
static_cast<int>(OpRole::kLoss)) ||
op_role == static_cast<int>(OpRole::kLRSched);
bool run_others = op_role == static_cast<int>(OpRole::kForward) ||
op_role == (static_cast<int>(OpRole::kForward) |
static_cast<int>(OpRole::kLoss));
if ((micro_id == 0 && run_first_mbatch) || (micro_id != 0 && run_others)) {
VLOG(3) << "Forward: running op " << op->Type() << " for micro-batch "
<< micro_id;
op->Run(*microbatch_scopes_[micro_id], place_);
if (gc) {
DeleteUnusedTensors(*microbatch_scopes_[micro_id], op.get(),
unused_vars_, gc.get());
}
}
}
}
void SectionWorker::RunBackward(
int micro_id, std::unique_ptr<GarbageCollector> &gc,
std::unordered_map<const OperatorBase *, std::vector<std::string>>
&unused_vars_) {
for (auto &op : ops_) {
int op_role = op->Attr<int>(std::string("op_role"));
if (op_role == static_cast<int>(OpRole::kBackward) ||
op_role == (static_cast<int>(OpRole::kBackward) |
static_cast<int>(OpRole::kLoss))) {
VLOG(3) << "Backward: running op " << op->Type() << " for micro-batch "
<< micro_id;
op->Run(*microbatch_scopes_[micro_id], place_);
if (gc) {
DeleteUnusedTensors(*microbatch_scopes_[micro_id], op.get(),
unused_vars_, gc.get());
}
}
}
}
void SectionWorker::RunUpdate(
std::unique_ptr<GarbageCollector> &gc,
std::unordered_map<const OperatorBase *, std::vector<std::string>>
&unused_vars_) {
for (auto &op : ops_) {
int op_role = op->Attr<int>(std::string("op_role"));
if (op_role == static_cast<int>(OpRole::kOptimize)) {
VLOG(3) << "Update: running op " << op->Type();
op->Run(*microbatch_scopes_[num_microbatches_ - 1], place_);
if (gc) {
DeleteUnusedTensors(*microbatch_scopes_[num_microbatches_ - 1],
op.get(), unused_vars_, gc.get());
}
}
}
}
void SectionWorker::TrainFiles() {
VLOG(5) << "begin section_worker TrainFiles";
......@@ -58,61 +112,49 @@ void SectionWorker::TrainFiles() {
#endif
}
for (int i = 0; i < num_microbatches_; ++i) {
for (auto& op : ops_) {
int op_role = op->Attr<int>(std::string("op_role"));
// We run op with op_role = kLRSched only for the first microbatch
// to avoid increasing the @LR_DECAY_STEP@ multiple times.
bool run_first_mbatch = op_role == static_cast<int>(OpRole::kForward) ||
op_role == (static_cast<int>(OpRole::kForward) |
static_cast<int>(OpRole::kLoss)) ||
op_role == static_cast<int>(OpRole::kLRSched);
bool run_others = op_role == static_cast<int>(OpRole::kForward) ||
op_role == (static_cast<int>(OpRole::kForward) |
static_cast<int>(OpRole::kLoss));
if ((i == 0 && run_first_mbatch) || (i != 0 && run_others)) {
VLOG(3) << "Forward: running op " << op->Type() << " for micro-batch "
<< i;
op->Run(*microbatch_scopes_[i], place_);
if (gc) {
DeleteUnusedTensors(*microbatch_scopes_[i], op.get(), unused_vars_,
gc.get());
}
}
if (schedule_mode_ == 0) {
// F-then-B (GPipe-like) scheduler: run all forwards first, then all
// backwards, and finally the update.
// step1: run forward
for (int i = 0; i < num_microbatches_; ++i) {
RunForward(i, gc, unused_vars_);
}
cudaDeviceSynchronize();
}
// backward pass
for (int i = 0; i < num_microbatches_; ++i) {
for (auto& op : ops_) {
int op_role = op->Attr<int>(std::string("op_role"));
if (op_role == static_cast<int>(OpRole::kBackward) ||
op_role == (static_cast<int>(OpRole::kBackward) |
static_cast<int>(OpRole::kLoss))) {
VLOG(3) << "Backward: running op " << op->Type() << " for micro-batch "
<< i;
op->Run(*microbatch_scopes_[i], place_);
if (gc) {
DeleteUnusedTensors(*microbatch_scopes_[i], op.get(), unused_vars_,
gc.get());
}
}
// step2: run backward
for (int i = 0; i < num_microbatches_; ++i) {
RunBackward(i, gc, unused_vars_);
}
// step3: run update
RunUpdate(gc, unused_vars_);
} else {
// 1F1B scheduler
auto startup_steps = num_pipeline_stages_ - pipeline_stage_ - 1;
VLOG(3) << "startup_steps:" << startup_steps
<< ", num_stages: " << num_pipeline_stages_
<< ", stage:" << pipeline_stage_;
if (startup_steps > num_microbatches_) {
startup_steps = num_microbatches_;
}
int fw_step = 0;
int bw_step = 0;
// startup phase
while (fw_step < startup_steps) {
RunForward(fw_step, gc, unused_vars_);
fw_step += 1;
}
cudaDeviceSynchronize();
}
// update pass
for (auto& op : ops_) {
int op_role = op->Attr<int>(std::string("op_role"));
if (op_role == static_cast<int>(OpRole::kOptimize)) {
VLOG(3) << "Update: running op " << op->Type();
op->Run(*microbatch_scopes_[0], place_);
if (gc) {
DeleteUnusedTensors(*microbatch_scopes_[0], op.get(), unused_vars_,
gc.get());
}
// 1f1b phase
while (fw_step < num_microbatches_) {
RunForward(fw_step, gc, unused_vars_);
fw_step += 1;
RunBackward(bw_step, gc, unused_vars_);
bw_step += 1;
}
// backward phase
while (bw_step < num_microbatches_) {
RunBackward(bw_step, gc, unused_vars_);
bw_step += 1;
}
RunUpdate(gc, unused_vars_);
}
dev_ctx_->Wait();
++batch_id_;
......
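The per-stage schedule implemented above can be summarized in a short Python sketch (illustration only, mirroring the C++ control flow): F is a forward micro-batch, B a backward micro-batch, U the optimizer update.

def pipeline_schedule(num_microbatches, num_stages, stage, schedule_mode):
    steps = []
    if schedule_mode == 0:
        # F-then-B (GPipe-like): all forwards, then all backwards.
        steps += [("F", i) for i in range(num_microbatches)]
        steps += [("B", i) for i in range(num_microbatches)]
    else:
        # 1F1B: warm-up forwards, interleaved steady state, then drain backwards.
        startup = min(num_stages - stage - 1, num_microbatches)
        fw = bw = 0
        while fw < startup:
            steps.append(("F", fw))
            fw += 1
        while fw < num_microbatches:
            steps.append(("F", fw))
            fw += 1
            steps.append(("B", bw))
            bw += 1
        while bw < num_microbatches:
            steps.append(("B", bw))
            bw += 1
    steps.append(("U", None))  # one update per mini-batch
    return steps

# stage 2 of a 4-stage pipeline, 4 micro-batches, 1F1B:
print(pipeline_schedule(4, 4, 2, schedule_mode=1))
# [('F', 0), ('F', 1), ('B', 0), ('F', 2), ('B', 1), ('F', 3), ('B', 2), ('B', 3), ('U', None)]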
......@@ -93,6 +93,9 @@ message SectionWorkerParameter {
optional int32 start_cpu_core_id = 4 [ default = 1 ];
repeated string param_need_sync = 5;
optional int32 num_microbatches = 6;
optional int32 num_pipeline_stages = 7 [ default = 1 ];
optional int32 pipeline_stage = 8 [ default = 1 ];
optional int32 schedule_mode = 9 [ default = 0 ];
}
message SectionConfig {
......
......@@ -55,7 +55,8 @@ class CGenNCCLIdOp : public framework::OperatorBase {
SendBroadCastNCCLID(endpoint_list, 1, func, local_scope);
} else {
std::string endpoint = Attr<std::string>("endpoint");
RecvBroadCastNCCLID(endpoint, 1, func, local_scope);
int server_fd = platform::SocketServer::GetInstance(endpoint).socket();
platform::RecvBroadCastCommID(server_fd, endpoint, &nccl_ids);
}
scope.DeleteScope(&local_scope);
}
......@@ -71,8 +72,7 @@ class CGenNCCLIdOp : public framework::OperatorBase {
: OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
}
const platform::Place& dev_place) const override {}
};
#endif
......
......@@ -31,7 +31,9 @@ limitations under the License. */
#include "paddle/fluid/string/split.h"
namespace paddle {
namespace operators {
namespace platform {
std::once_flag SocketServer::init_flag_;
constexpr char COMM_HEAD[] = "_pd_gen_comm_id_";
......@@ -340,5 +342,34 @@ void RecvBroadCastNCCLID(int server_fd, std::string endpoint, int nccl_comm_num,
CloseSocket(client);
}
} // namespace operators
SocketServer& SocketServer::GetInstance(const std::string& end_point) {
static SocketServer instance;
std::call_once(init_flag_, [&]() {
instance.server_fd_ = CreateListenSocket(end_point);
instance.end_point_ = end_point;
});
PADDLE_ENFORCE_NE(instance.server_fd_, -1,
platform::errors::Unavailable(
"listen socket failed with end_point=%s", end_point));
PADDLE_ENFORCE_EQ(instance.end_point_, end_point,
platform::errors::InvalidArgument(
"old end_point=%s must equal with new end_point=%s",
instance.end_point_, end_point));
return instance;
}
/// template instantiation
#define INSTANT_TEMPLATE(Type) \
template void SendBroadCastCommID<Type>(std::vector<std::string> servers, \
std::vector<Type> * nccl_ids); \
template void RecvBroadCastCommID<Type>(std::string endpoint, \
std::vector<Type> * nccl_ids);
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
INSTANT_TEMPLATE(ncclUniqueId)
#endif
#ifdef PADDLE_WITH_XPU_BKCL
INSTANT_TEMPLATE(BKCLUniqueId)
#endif
} // namespace platform
} // namespace paddle
......@@ -15,6 +15,8 @@ limitations under the License. */
#pragma once
#include <functional>
#include <memory>
#include <mutex> // NOLINT
#include <string>
#include <vector>
......@@ -25,7 +27,7 @@ class Scope;
} // namespace paddle
namespace paddle {
namespace operators {
namespace platform {
int CreateListenSocket(const std::string& ep);
......@@ -41,8 +43,26 @@ void RecvBroadCastNCCLID(std::string endpoint, int nccl_comm_num,
const framework::Scope& scope);
// recv nccl id from socket
void RecvBroadCastNCCLID(int server_fd, std::string endpoint, int nccl_comm_num,
std::function<std::string(size_t)> func,
const framework::Scope& scope);
} // namespace operators
template <typename CommUniqueId>
void RecvBroadCastCommID(int server_fd, std::string endpoint,
std::vector<CommUniqueId>* nccl_ids);
class SocketServer {
public:
SocketServer() = default;
~SocketServer() { CloseSocket(server_fd_); }
int socket() const { return server_fd_; }
static SocketServer& GetInstance(const std::string& end_point);
private:
int server_fd_{-1};
std::string end_point_;
static std::once_flag init_flag_;
};
} // namespace platform
} // namespace paddle
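A rough Python analogue of the SocketServer singleton above (an illustration under assumptions, not Paddle API): one listening socket per process, created once and reused for every communicator-id exchange on the same endpoint.

import socket
import threading

class SocketServerSketch:
    _instance = None
    _lock = threading.Lock()

    @classmethod
    def get_instance(cls, endpoint):
        with cls._lock:
            if cls._instance is None:
                host, port = endpoint.rsplit(":", 1)
                srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                srv.bind((host, int(port)))
                srv.listen()
                cls._instance = cls()
                cls._instance.endpoint = endpoint
                cls._instance.server = srv
            # mirrors the PADDLE_ENFORCE_EQ check: the endpoint must not change
            assert cls._instance.endpoint == endpoint
            return cls._instance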
......@@ -736,6 +736,60 @@ class DistributedStrategy(object):
"sharding_configs")
assign_configs_value(self.strategy.sharding_configs, configs)
@property
def model_parallel(self):
"""
Indicating whether we are using model parallelism for distributed training.
Examples:
.. code-block:: python
import paddle.distributed.fleet as fleet
strategy = fleet.DistributedStrategy()
strategy.model_parallel = True
"""
return self.strategy.model_parallel
@model_parallel.setter
@is_strict_auto
def model_parallel(self, flag):
if isinstance(flag, bool):
self.strategy.model_parallel = flag
else:
print("WARNING: model_parallel should have value of bool type")
@property
def model_parallel_configs(self):
"""
Set model parallelism configurations.
**Notes**:
**Detailed arguments for model_parallel_configs**
**parallelism**: degree of model parallelism
Examples:
.. code-block:: python
import paddle.distributed.fleet as fleet
strategy = fleet.DistributedStrategy()
strategy.model_parallel = True
strategy.model_parallel_configs = {"parallelism": 12}
"""
return get_msg_dict(self.strategy.model_parallel_configs)
@model_parallel_configs.setter
@is_strict_auto
def model_parallel_configs(self, configs):
check_configs_key(self.strategy.model_parallel_configs, configs,
"model_parallel_configs")
assign_configs_value(self.strategy.model_parallel_configs, configs)
@property
def pipeline(self):
"""
......
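Putting the two new properties together with the meta optimizer added below (a hedged end-to-end sketch; assumes a static-graph program and a multi-GPU launch via paddle.distributed.launch):

import paddle
import paddle.distributed.fleet as fleet

paddle.enable_static()
fleet.init(is_collective=True)

strategy = fleet.DistributedStrategy()
strategy.model_parallel = True
strategy.model_parallel_configs = {"parallelism": 2}

optimizer = paddle.optimizer.SGD(learning_rate=0.001)
optimizer = fleet.distributed_optimizer(optimizer, strategy)
# optimizer.minimize(loss) then routes through ModelParallelOptimizer,
# which builds the model-parallel (and optionally data-parallel) NCCL rings.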
......@@ -17,6 +17,7 @@ from .gradient_merge_optimizer import GradientMergeOptimizer
from .graph_execution_optimizer import GraphExecutionOptimizer
from .parameter_server_optimizer import ParameterServerOptimizer
from .pipeline_optimizer import PipelineOptimizer
from .model_parallel_optimizer import ModelParallelOptimizer
from .localsgd_optimizer import LocalSGDOptimizer
from .localsgd_optimizer import AdaptiveLocalSGDOptimizer
from .lars_optimizer import LarsOptimizer
......
......@@ -50,15 +50,17 @@ class AMPOptimizer(MetaOptimizerBase):
self.inner_opt, amp_lists, config['init_loss_scaling'],
config['incr_every_n_steps'], config['decr_every_n_nan_or_inf'],
config['incr_ratio'], config['decr_ratio'],
config['use_dynamic_loss_scaling'])
config['use_dynamic_loss_scaling'], config['use_pure_fp16'],
config['use_fp16_guard'])
# if worker_num > 1, all cards will communicate with each other,
# add is_distributed to optimize amp, overlap communication and
# computation by splitting the check_finite_and_unscale op.
is_distributed = self.role_maker._worker_num() > 1
if self.user_defined_strategy.sharding:
# FIXME(wangxi). sharding failed when split check_finite_and_unscale
is_distributed = False
#if self.user_defined_strategy.sharding or self.user_defined_strategy.model_parallel:
# # FIXME(wangxi). sharding failed when split check_finite_and_unscale
# # FIXME(JZ-LIANG). To support Sharding-Megatron-AMP, Megatron should follow Sharding's behavior
# is_distributed = False
self.wrapped_opt._set_distributed(is_distributed)
def _can_apply(self):
......@@ -112,3 +114,11 @@ class AMPOptimizer(MetaOptimizerBase):
self.wrapped_opt.minimize(loss, startup_program,
parameter_list, no_grad_set)
return optimize_ops, params_grads
def amp_init(self,
place,
scope=None,
test_program=None,
use_fp16_test=False):
return self.wrapped_opt.amp_init(place, scope, test_program,
use_fp16_test)
......@@ -25,6 +25,24 @@ OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
OP_ROLE_VAR_KEY = core.op_proto_and_checker_maker.kOpRoleVarAttrName()
class Topology:
"""A 4-D structure to describe the process group."""
def __init__(self, axes, dims):
pass
class ParallelGrid:
"""Initialize each process group."""
def __init__(self, topology):
self.build_global_group()
self.build_mp_group()
self.build_sharding_group()
self.build_pp_group()
self.build_dp_group()
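# Illustrative sketch only (not part of this commit): one way a 4-D topology can
# map a global rank to (dp, pp, sharding, mp) coordinates, which is the kind of
# information ParallelGrid's build_*_group() methods need to derive group members.
def coord_of_rank(rank, dims):
    # dims ordered from outermost to innermost axis, e.g. [dp, pp, sharding, mp]
    coord = []
    for size in reversed(dims):
        coord.append(rank % size)
        rank //= size
    return tuple(reversed(coord))

# 16 ranks laid out as dp=2, pp=2, sharding=2, mp=2:
# coord_of_rank(5, [2, 2, 2, 2]) == (0, 1, 0, 1)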
def is_update_op(op):
return 'Param' in op.input_names and 'Grad' in op.input_names and \
"LearningRate" in op.input_names
......@@ -66,16 +84,49 @@ class CollectiveHelper(object):
self.role_maker._worker_index(), ring_id, self.wait_port)
self._broadcast_params()
def _init_communicator(self, program, current_endpoint, endpoints, rank,
ring_id, wait_port):
def _init_communicator(self,
program,
current_endpoint,
endpoints,
rank,
ring_id,
wait_port,
sync=True):
nranks = len(endpoints)
other_endpoints = endpoints[:]
other_endpoints.remove(current_endpoint)
block = program.global_block()
if core.is_compiled_with_cuda():
if rank == 0 and wait_port:
wait_server_ready(other_endpoints)
nccl_id_var = block.create_var(
if not wait_port and sync:
temp_var = block.create_var(
name=unique_name.generate('temp_var'),
dtype=core.VarDesc.VarType.INT32,
persistable=False,
stop_gradient=True)
block.append_op(
type='fill_constant',
inputs={},
outputs={'Out': [temp_var]},
attrs={
'shape': [1],
'dtype': temp_var.dtype,
'value': 1,
'force_cpu': False,
OP_ROLE_KEY: OpRole.Forward
})
block.append_op(
type='c_allreduce_sum',
inputs={'X': [temp_var]},
outputs={'Out': [temp_var]},
attrs={'ring_id': 3,
OP_ROLE_KEY: OpRole.Forward})
block.append_op(
type='c_sync_comm_stream',
inputs={'X': temp_var},
outputs={'Out': temp_var},
attrs={'ring_id': 3,
OP_ROLE_KEY: OpRole.Forward})
comm_id_var = block.create_var(
name=unique_name.generate('nccl_id'),
persistable=True,
type=core.VarDesc.VarType.RAW)
......@@ -100,9 +151,7 @@ class CollectiveHelper(object):
OP_ROLE_KEY: OpRole.Forward
})
elif core.is_compiled_with_npu():
endpoint_to_index_map = {
e: idx for idx, e in enumerate(endpoints)
}
endpoint_to_index_map = {e: idx for idx, e in enumerate(endpoints)}
block.append_op(
type='c_comm_init_hcom',
inputs={},
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
from __future__ import print_function
from __future__ import division
import paddle.fluid as fluid
from paddle.fluid import core, unique_name
from ..base.private_helper_function import wait_server_ready
from .meta_optimizer_base import MetaOptimizerBase
from .common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY, CollectiveHelper, is_update_op, is_loss_grad_op, is_backward_op, is_optimizer_op
class ModelParallelHelper(object):
def __init__(self, role_maker, wait_port=True, megatron_dp=False):
self.wait_port = wait_port
self.role_maker = role_maker
self.megatron_dp = megatron_dp
def update_startup_program(self,
startup_program=None,
inner_parallelism=None):
self.startup_program = startup_program
nranks = self.role_maker._worker_num()
rank = self.role_maker._worker_index()
endpoints = self.role_maker._get_trainer_endpoints()
current_endpoint = endpoints[rank]
# Create ring 0 for all model parallel parts within a single model
mp_endpoints = []
mp_rank = rank % inner_parallelism
mp_id = rank // inner_parallelism
for idx, ep in enumerate(endpoints):
if idx // inner_parallelism == mp_id:
mp_endpoints.append(ep)
print("model parallel eps:{}, rank{}".format(mp_endpoints, mp_rank))
self._init_communicator(self.startup_program, current_endpoint,
mp_endpoints, mp_rank, 0, self.wait_port)
self._broadcast_params(0, broadcast_distributed_weight=False)
print("megatron group size: {}".format(inner_parallelism))
print("megatron rank: {}".format(mp_rank))
print("megatron endpoints: {}".format(mp_endpoints))
if self.megatron_dp:
mp_num = len(endpoints) // inner_parallelism
if mp_num == 1: return
# Create rings for GPUs that hold the same model parallel part
eps = []
dp_rank = rank // inner_parallelism
dp_id = rank % inner_parallelism
#if dp_rank == 1: dp_rank =0
#if dp_rank == 0: dp_rank =1
ring_id = 1
for idx, ep in enumerate(endpoints):
if idx % inner_parallelism == dp_id:
eps.append(ep)
#ep = eps.pop(0)
#eps.insert(1, ep)
print("data parallel eps:{}, rank{}".format(eps, dp_rank))
self._init_communicator(self.startup_program, current_endpoint, eps,
dp_rank, ring_id, self.wait_port)
self._broadcast_params(ring_id, broadcast_distributed_weight=True)
def _init_communicator(self, program, current_endpoint, endpoints, rank,
ring_id, wait_port):
nranks = len(endpoints)
other_endpoints = endpoints[:]
other_endpoints.remove(current_endpoint)
if rank == 0 and wait_port:
wait_server_ready(other_endpoints)
block = program.global_block()
nccl_id_var = block.create_var(
name=unique_name.generate('nccl_id'),
persistable=True,
type=core.VarDesc.VarType.RAW)
block.append_op(
type='c_gen_nccl_id',
inputs={},
outputs={'Out': nccl_id_var},
attrs={
'rank': rank,
'endpoint': current_endpoint,
'other_endpoints': other_endpoints,
OP_ROLE_KEY: OpRole.Forward,
})
block.append_op(
type='c_comm_init',
inputs={'X': nccl_id_var},
outputs={},
attrs={
'nranks': nranks,
'rank': rank,
'ring_id': ring_id,
OP_ROLE_KEY: OpRole.Forward,
})
def _broadcast_params(self, ring_id, broadcast_distributed_weight):
block = self.startup_program.global_block()
for param in block.iter_parameters():
if not broadcast_distributed_weight and param.is_distributed:
continue
block.append_op(
type='c_broadcast',
inputs={'X': param},
outputs={'Out': param},
attrs={
'ring_id': ring_id,
'root': 0,
OP_ROLE_KEY: OpRole.Forward
})
block.append_op(
type='c_sync_comm_stream',
inputs={'X': param},
outputs={'Out': param},
attrs={'ring_id': ring_id,
OP_ROLE_KEY: OpRole.Forward})
class ModelParallelOptimizer(MetaOptimizerBase):
def __init__(self, optimizer):
super(ModelParallelOptimizer, self).__init__(optimizer)
self.inner_opt = optimizer
self.meta_optimizers_white_list = [
"RecomputeOptimizer",
"AMPOptimizer",
"LarsOptimizer",
"LambOptimizer",
]
self.meta_optimizers_black_list = ["GraphExecutionOptimizer", ]
self.megatron_dp = False
def _set_basic_info(self, loss, role_maker, user_defined_optimizer,
user_defined_strategy):
super(ModelParallelOptimizer, self)._set_basic_info(
loss, role_maker, user_defined_optimizer, user_defined_strategy)
self.inner_parallelism = user_defined_strategy.model_parallel_configs[
'parallelism']
def _can_apply(self):
if not self.role_maker._is_collective:
return False
if self.user_defined_strategy.model_parallel == True:
return True
return False
def _disable_strategy(self, dist_strategy):
dist_strategy.model_parallel = False
dist_strategy.model_parallel_configs = {}
def _enable_strategy(self, dist_strategy, context):
dist_strategy.model_parallel = True
dist_strategy.model_parallel_configs = {"parallelism": 1, }
# the following function will be used by AMP if both Megatron and AMP are turned on together.
def apply_gradients(self, params_grads):
return self.minimize_impl(params_grads=params_grads)
def minimize_impl(self,
loss,
startup_program=None,
parameter_list=None,
no_grad_set=None):
endpoints = self.role_maker._get_trainer_endpoints()
current_endpoint = endpoints[self.role_maker._worker_index()]
self.startup_program = startup_program
if startup_program is None:
self.startup_program = fluid.default_startup_program()
# (TODO) check the order of metaoptimizer
# (TODO) check the params_grads
optimize_ops, params_grads = self.inner_opt.minimize(
loss, self.startup_program, parameter_list, no_grad_set)
self.main_program = loss.block.program
self.inner_parallelism = self.inner_parallelism
self.nranks = len(endpoints)
pipeline_helper = ModelParallelHelper(self.role_maker)
pipeline_helper.update_startup_program(self.startup_program,
self.inner_parallelism)
assert self.nranks % self.inner_parallelism == 0
if self.megatron_dp:
# data parallelism
dp_parallelism = self.nranks // self.inner_parallelism
self._transpile_main_program(loss, dp_parallelism)
return optimize_ops, params_grads
def _transpile_main_program(self, loss, dp_parallelism):
self._insert_loss_grad_ops(loss, dp_parallelism)
ring_id = 1
print("ring_id: ", ring_id)
# for ring_id in range(1, dp_parallelism + 1):
self._insert_allreduce_ops(loss, ring_id)
def _insert_loss_grad_ops(self, loss, dp_parallelism):
"""
In order to keep the learning rate consistent across different numbers of
training workers, we scale the loss grad by the number of workers
"""
block = loss.block
for idx, op in reversed(list(enumerate(block.ops))):
if is_loss_grad_op(op):
loss_grad_var = block.vars[op.output_arg_names[0]]
block._insert_op(
idx + 1,
type='scale',
inputs={'X': loss_grad_var},
outputs={'Out': loss_grad_var},
attrs={
'scale': 1.0 / dp_parallelism,
OP_ROLE_KEY: OpRole.Backward
})
def _insert_allreduce_ops(self, loss, ring_id):
block = loss.block
grad = None
for idx, op in reversed(list(enumerate(block.ops))):
if is_backward_op(op) and \
OP_ROLE_VAR_KEY in op.attr_names:
op_role_var = op.all_attrs()[OP_ROLE_VAR_KEY]
if len(op_role_var) == 0:
continue
assert len(op_role_var) % 2 == 0
offset = idx
for i in range(0, len(op_role_var), 2):
param = block.vars[op_role_var[i]]
grad = block.vars[op_role_var[i + 1]]
#if param.is_distributed:
# continue
if offset == idx:
offset += 1
block._insert_op(
offset,
type='c_sync_calc_stream',
inputs={'X': grad},
outputs={'Out': grad},
attrs={OP_ROLE_KEY: OpRole.Backward})
offset += 1
block._insert_op(
offset,
type='c_allreduce_sum',
inputs={'X': grad},
outputs={'Out': grad},
attrs={
'ring_id': ring_id,
OP_ROLE_KEY: OpRole.Backward
})
if grad is None:
return
for idx, op in list(enumerate(block.ops)):
if is_optimizer_op(op):
block._insert_op(
idx,
type='c_sync_comm_stream',
inputs={'X': grad},
outputs={'Out': grad},
attrs={'ring_id': ring_id,
OP_ROLE_KEY: OpRole.Backward})
break
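The ring membership computed by ModelParallelHelper above reduces to simple index arithmetic; a hedged sketch (endpoint values are illustrative):

def mp_and_dp_groups(rank, endpoints, inner_parallelism):
    mp_id = rank // inner_parallelism    # which model-parallel group this rank belongs to
    mp_rank = rank % inner_parallelism   # rank inside that group (ring 0)
    mp_eps = [ep for i, ep in enumerate(endpoints) if i // inner_parallelism == mp_id]

    dp_id = rank % inner_parallelism     # which data-parallel group (megatron_dp)
    dp_rank = rank // inner_parallelism  # rank inside that group (ring 1)
    dp_eps = [ep for i, ep in enumerate(endpoints) if i % inner_parallelism == dp_id]
    return (mp_rank, mp_eps), (dp_rank, dp_eps)

eps = ["10.0.0.%d:6170" % i for i in range(4)]
print(mp_and_dp_groups(3, eps, inner_parallelism=2))
# ((1, ['10.0.0.2:6170', '10.0.0.3:6170']), (1, ['10.0.0.1:6170', '10.0.0.3:6170']))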
......@@ -154,8 +154,10 @@ class PipelineOptimizer(MetaOptimizerBase):
def __init__(self, optimizer):
super(PipelineOptimizer, self).__init__(optimizer)
self.inner_opt = optimizer
# we do not allow meta optimizer to be inner optimizer currently
self.meta_optimizers_white_list = []
self.meta_optimizers_white_list = [
"RecomputeOptimizer",
"AMPOptimizer",
]
self.meta_optimizers_black_list = ["GraphExecutionOptimizer", ]
def _set_basic_info(self, loss, role_maker, user_defined_optimizer,
......
......@@ -73,7 +73,7 @@ class FP16Utils(object):
@staticmethod
def prune_fp16(block, shard, reduced_grads_to_param, ring_id):
"""
1. prune all cast_fp32_to_fp16 ops if the param not belongs to this shard
1. prune all cast_fp16_to_fp32 ops if the param does not belong to this shard
2. revise amp inifine grad checking for sharding
"""
# remove cast
......@@ -81,7 +81,9 @@ class FP16Utils(object):
if not FP16Utils.is_fp32_cast_op(block, op):
continue
output_name = op.desc.output_arg_names()[0]
param_name = output_name.strip("@GRAD")
param_name = output_name.strip(
"@GRAD@MERGED"
) if "@MERGED" in output_name else output_name.strip("@GRAD")
if param_name not in shard.global_params:
raise ValueError("Output 'X' of cast_op must be a grad of"
"model param, but {} is not a grad".format(
......@@ -103,20 +105,35 @@ class FP16Utils(object):
op._rename_input(inf_var_name, inf_var_name + "@sharding")
if op.type in ["check_finite_and_unscale", "update_loss_scaling"]:
reversed_x = []
reversed_x_paramname = []
for input_name in op.desc.input('X'):
param_name = input_name.strip("@GRAD")
param_name = input_name.strip("@GRAD@MERGED")
if param_name not in shard.global_params:
raise ValueError(
"Input 'X' of check_finite_and_unscale must"
"be grads, but {} is not a grad".format(input_name))
if shard.has_param(param_name):
reversed_x.append(input_name)
reversed_x_paramname.append(param_name)
op.desc.set_input('X', reversed_x)
op.desc.set_output('Out', reversed_x)
# the grad checking should cover all and only the params in the current shard
to_check_param = set(reversed_x_paramname)
should_check_param = set(shard.global_params).intersection(
set([
param
for param, worker_idx in shard.global_param2device.
items() if worker_idx == shard.worker_idx
]))
#assert to_check_param == should_check_param, "amp check_finite_and_unscale checking miss [{}] and got unexpected [{}]".format(
# should_check_param - to_check_param,
# to_check_param - should_check_param)
if update_loss_scaling_op_idx == -1:
return
inf_var = block.var(inf_var_name)
inf_var_fp32 = block.create_var(
inf_var_int32 = block.create_var(
name=inf_var_name + "@cast_int32",
shape=inf_var.shape,
dtype=core.VarDesc.VarType.INT32)
......@@ -128,32 +145,36 @@ class FP16Utils(object):
update_loss_scaling_op_idx,
type='cast',
inputs={'X': inf_var},
outputs={'Out': inf_var_fp32},
outputs={'Out': inf_var_int32},
attrs={
"in_dtype": inf_var.dtype,
"out_dtype": inf_var_fp32.dtype,
"out_dtype": inf_var_int32.dtype,
OP_ROLE_KEY: OpRole.Optimize
})
insert_sync_calc_op(block, update_loss_scaling_op_idx + 1,
[inf_var_fp32])
# this allreduce communication should not overlap with calc
# insert_sync_calc_op(block, update_loss_scaling_op_idx + 1,
# [inf_var_int32])
block._insert_op_without_sync(
update_loss_scaling_op_idx + 2,
update_loss_scaling_op_idx + 1,
type='c_allreduce_max',
inputs={'X': inf_var_fp32},
outputs={'Out': inf_var_fp32},
attrs={'ring_id': ring_id,
OP_ROLE_KEY: OpRole.Optimize})
inputs={'X': inf_var_int32},
outputs={'Out': inf_var_int32},
attrs={
'ring_id': ring_id,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Optimize
})
comm_op_num = insert_sync_comm_op(block, update_loss_scaling_op_idx + 3,
ring_id, [inf_var_fp32])
# comm_op_num = insert_sync_comm_op(block, update_loss_scaling_op_idx + 3,
# ring_id, [inf_var_int32])
block._insert_op_without_sync(
update_loss_scaling_op_idx + 3 + comm_op_num,
update_loss_scaling_op_idx + 2,
type='cast',
inputs={'X': inf_var_fp32},
inputs={'X': inf_var_int32},
outputs={'Out': inf_var_sharding},
attrs={
"in_dtype": inf_var_fp32.dtype,
"in_dtype": inf_var_int32.dtype,
"out_dtype": inf_var_sharding.dtype,
OP_ROLE_KEY: OpRole.Optimize
})
......
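The global found-inf check above works because a max over integer flags is a logical OR; a short illustration of the reduction c_allreduce_max performs:

# one local found_inf flag per rank in the sharding ring, cast to int32
local_found_inf = [0, 1, 0, 0]
global_found_inf = bool(max(local_found_inf))   # what c_allreduce_max computes
print(global_found_inf)  # True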
......@@ -16,8 +16,8 @@ from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole
class GradientClipHelper(object):
def __init__(self, sharding_ring_id):
self.sharding_ring_id = sharding_ring_id
def __init__(self, mp_ring_id):
self.mp_ring_id = mp_ring_id
def _is_gradient_clip_op(self, op):
return op.desc.has_attr("op_namescope") \
......@@ -31,6 +31,7 @@ class GradientClipHelper(object):
"""
deperated_vars = set()
deperate_op_idx = set()
reversed_x_paramname = []
for idx, op in enumerate(block.ops):
if not self._is_gradient_clip_op(op):
continue
......@@ -40,15 +41,18 @@ class GradientClipHelper(object):
for input_name in op.desc.input_arg_names():
if input_name in deperated_vars:
deperate_op = True
param_name = input_name.strip("@GRAD")
param_name = input_name.strip("@GRAD@MERGED")
if shard.is_param(param_name) and \
not shard.has_param(param_name):
deperate_op = True
elif shard.is_param(param_name):
reversed_x_paramname.append(param_name)
if deperate_op:
deperate_op_idx.add(idx)
for output_name in op.desc.output_arg_names():
deperated_vars.add(output_name)
if output_name not in op.desc.input_arg_names():
deperated_vars.add(output_name)
if not deperated_vars:
# got no gradient_clip op
......@@ -65,31 +69,47 @@ class GradientClipHelper(object):
for input_name in op.desc.input_arg_names():
if input_name not in deperated_vars:
reversed_inputs.append(input_name)
op.desc.set_input("X", reversed_inputs)
assert (len(op.desc.output_arg_names()) == 1)
sum_res = op.desc.output_arg_names()[0]
block._insert_op_without_sync(
idx + 1,
type='c_sync_comm_stream',
inputs={'X': sum_res},
outputs={'Out': sum_res},
attrs={'ring_id': 0,
OP_ROLE_KEY: OpRole.Optimize})
# this allreduce should not overlap with calc and should be scheduled in calc stream
# block._insert_op_without_sync(
# idx + 1,
# type='c_sync_comm_stream',
# inputs={'X': sum_res},
# outputs={'Out': sum_res},
# attrs={'ring_id': 0,
# OP_ROLE_KEY: OpRole.Optimize})
block._insert_op_without_sync(
idx + 1,
type='c_allreduce_sum',
inputs={'X': sum_res},
outputs={'Out': sum_res},
attrs={
'ring_id': self.sharding_ring_id,
OP_ROLE_KEY: OpRole.Optimize
'ring_id': self.mp_ring_id,
'op_namescope': "/gradient_clip_model_parallelism",
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Optimize,
})
block._insert_op_without_sync(
idx + 1,
type='c_sync_calc_stream',
inputs={'X': sum_res},
outputs={'Out': sum_res},
attrs={OP_ROLE_KEY: OpRole.Optimize})
# block._insert_op_without_sync(
# idx + 1,
# type='c_sync_calc_stream',
# inputs={'X': sum_res},
# outputs={'Out': sum_res},
# attrs={OP_ROLE_KEY: OpRole.Optimize})
# the grad sum here should cover all and only the params in the current shard
to_check_param = set(reversed_x_paramname)
should_check_param = set(shard.global_params).intersection(
set([
param for param, worker_idx in shard.global_param2device.items()
if worker_idx == shard.worker_idx
]))
assert to_check_param == should_check_param, "amp check_finite_and_unscale checking miss [{}] and got unexpected [{}]".format(
should_check_param - to_check_param,
to_check_param - should_check_param)
for var_name in deperated_vars:
block._remove_var(var_name, sync=False)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..common import is_optimizer_op, OP_ROLE_KEY, OpRole
from paddle.fluid import core, unique_name
class OffloadHelper(object):
cpu_place_type = 0
cuda_place_type = 1
cuda_pinned_place_type = 2
def __init__(self):
    pass

# dst_place_type values used by the memcpy ops below:
#   0: dst is on CPUPlace.
#   1: dst is on CUDAPlace.
#   2: dst is on CUDAPinnedPlace.
def _insert_cast_op(self, block, idx, src_name, dst_name):
src_var = block.var(src_name)
if not block.has_var(dst_name):
block.create_var(
name=dst_name,
shape=src_var.shape,
dtype=core.VarDesc.VarType.FP16,
persistable=True)
dst_var = block.var(dst_name)
assert dst_var.dtype == core.VarDesc.VarType.FP16
block._insert_op_without_sync(
idx,
type='cast',
inputs={'X': src_var},
outputs={'Out': dst_var},
attrs={
'in_dtype': src_var.dtype,
'out_dtype': dst_var.dtype,
OP_ROLE_KEY: OpRole.Optimize
})
def _insert_memcpy_op(self, block, idx, src_name, dst_name, dst_place_type):
src_var = block.var(src_name)
dst_var = block.var(dst_name)
block._insert_op_without_sync(
idx,
type='memcpy',
inputs={'X': src_var},
outputs={'Out': dst_var},
attrs={
'dst_place_type': dst_place_type,
OP_ROLE_KEY: OpRole.Optimize,
})
def _insert_fetch_op(self, block, idx, src_name, dst_name):
self._insert_memcpy_op(block, idx, src_name, dst_name,
OffloadHelper.cuda_place_type)
def _insert_offload_op(self, block, idx, src_name, dst_name):
self._insert_memcpy_op(block, idx, src_name, dst_name,
OffloadHelper.cuda_pinned_place_type)
def _get_offload_var_name(self, name):
return unique_name.generate(name + '@offload')
def _create_offload_var(self, var_name, offload_var_name, blocks):
for block in blocks:
var = block.var(var_name)
var.persistable = False
offload_var = block.create_var(
name=offload_var_name,
shape=var.shape,
dtype=var.dtype,
persistable=True)
def offload_fp32param(self, block, startup_block):
"""
(p_fp16) = cast(p)
(p_fp16_recompute) = cast(p)
(pout,) = adam(p)
===========================>
rename(p_fp16_recompute, p_fp16)
(p,) = prefetch(p@offload)
(pout,) = adam(p)
(p_fp16) = cast(p)
(p@offload) = memcpy(p)
"""
param_to_idx = dict()
param_to_fp16 = dict()
# recompute vars which need to be renamed to fp16_param
fp16_param_to_recompute = dict()
recompute_to_fp16 = dict()
def remove_param(input_name):
param_to_idx.pop(input_name)
if input_name in param_to_fp16:
fp16_param = param_to_fp16.pop(input_name)
if fp16_param in fp16_param_to_recompute:
recompute = fp16_param_to_recompute.pop(fp16_param)
recompute_to_fp16.pop(recompute)
# step1: record param
for idx, op in reversed(list(enumerate(block.ops))):
if op.type in ('adam', 'momentum', 'lars', 'lamb'):
param = op.desc.input("Param")[0]
param_to_idx[param] = idx
# step2: remove param which can't offload
for idx, op in enumerate(block.ops):
if is_optimizer_op(op):
break
for input_name in op.desc.input_arg_names():
if input_name not in param_to_idx:
continue
# param is really used by an fp32 op
if op.type != 'cast':
remove_param(input_name)
continue
# param is only used by the cast op,
# which casts fp32_param to fp16_param
output_name = op.output_arg_names[0]
if 'cast_fp16' not in output_name:
remove_param(input_name)
continue
if 'subprog' not in output_name:
assert output_name == input_name + '.cast_fp16'
assert input_name not in param_to_fp16, \
"There must be only one cast op from fp32 param to fp16 param."
param_to_fp16[input_name] = output_name
else:
# fp16-->recompute_var
assert input_name in param_to_fp16, \
"param must first be cast to fp16"
fp16_param = param_to_fp16[input_name]
fp16_param_to_recompute[fp16_param] = output_name
recompute_to_fp16[output_name] = fp16_param
param_name_to_offload_name = dict()
# step3: main_block add offload, cast op
# change recompute to fp16, remove cast(param) to fp16
for idx, op in reversed(list(enumerate(block.ops))):
if op.type in ('adam', 'momentum', 'lars', 'lamb'):
param = op.desc.input("Param")[0]
if param not in param_to_idx: continue
# step3.1: create offload_var
offload_var_name = self._get_offload_var_name(param)
param_name_to_offload_name[param] = offload_var_name
self._create_offload_var(param, offload_var_name,
[block, startup_block])
# step3.2: insert cast op and offload op
self._insert_offload_op(block, idx + 1, param, offload_var_name)
assert param in param_to_fp16
fp16_param_name = param_to_fp16[param]
fp16_param_var = block.var(fp16_param_name)
fp16_param_var.persistable = True
self._insert_cast_op(block, idx + 1, param,
param_to_fp16[param])
# step3.3: insert fetch op
self._insert_fetch_op(block, idx, offload_var_name, param)
continue
# step3.4: remove cast op
if op.type == 'cast':
input_name = op.desc.input_arg_names()[0]
if input_name in param_to_idx:
block._remove_op(idx, sync=False)
continue
# step3.5: change recompute_param to fp16_param
for input_name in op.desc.input_arg_names():
if input_name in recompute_to_fp16:
op._rename_input(input_name, recompute_to_fp16[input_name])
for output_name in op.desc.output_arg_names():
if output_name in recompute_to_fp16:
op._rename_output(output_name,
recompute_to_fp16[output_name])
# step4: remove recompute_param
for name in recompute_to_fp16.keys():
block._remove_var(name, sync=False)
# step5: startup_block add offload
visited_vars = set()
for idx, op in reversed(list(enumerate(startup_block.ops))):
for out_name in op.output_arg_names:
if out_name in visited_vars:
continue
if out_name in param_name_to_offload_name:
var_name = out_name
offload_var_name = param_name_to_offload_name[var_name]
self._insert_offload_op(startup_block, idx + 1, var_name,
offload_var_name)
self._insert_cast_op(startup_block, idx + 1, var_name,
param_to_fp16[var_name])
visited_vars.add(out_name)
block._sync_with_cpp()
startup_block._sync_with_cpp()
def offload(self, block, startup_block):
"""
(m1, m2) = prefetch(m1@offload, m2@offload)
(m1out, m2out, pout) = adam(m1, m2, p)
(m1@offload, m2@offload) = memcpy(m1, m2)
"""
vars_name_to_offload_name = dict()
# main_block add offload
for idx, op in reversed(list(enumerate(block.ops))):
if not is_optimizer_op(op):
break
vars_name = []
if op.type == "adam":
# {Moment1Out = [''], Moment2Out = [''], ParamOut = ['']} =
# adam(inputs={Moment1 = [''], Moment2 = [''], Param = ['']})
vars_name.append(op.desc.input("Moment1")[0])
vars_name.append(op.desc.input("Moment2")[0])
elif op.type == 'momentum':
pass
elif op.type == 'lars':
pass
elif op.type == 'lamb':
pass
# step1: create and init offload_var
for var_name in vars_name:
assert var_name not in vars_name_to_offload_name
offload_var_name = self._get_offload_var_name(var_name)
vars_name_to_offload_name[var_name] = offload_var_name
self._create_offload_var(var_name, offload_var_name,
[block, startup_block])
# step2: insert offload op
for var_name in vars_name:
offload_var_name = vars_name_to_offload_name[var_name]
self._insert_offload_op(block, idx + 1, var_name,
offload_var_name)
# step3: insert fetch op
for var_name in vars_name:
offload_var_name = vars_name_to_offload_name[var_name]
self._insert_fetch_op(block, idx, offload_var_name, var_name)
# startup_block add offload
visited_vars = set()
for idx, op in reversed(list(enumerate(startup_block.ops))):
for out_name in op.output_arg_names:
if out_name in visited_vars:
continue
if out_name in vars_name_to_offload_name:
var_name = out_name
offload_var_name = vars_name_to_offload_name[var_name]
# insert offload op after var is generated
self._insert_offload_op(startup_block, idx + 1, var_name,
offload_var_name)
visited_vars.add(out_name)
block._sync_with_cpp()
startup_block._sync_with_cpp()
......@@ -126,6 +126,9 @@ class ProgramDeps(object):
def should_remove_op(self, op_idx):
op = self._block.ops[op_idx]
# remove check_finite_and_unscale op if its input 'X' is empty
if op.type == 'check_finite_and_unscale' and len(op.input('X')) == 0:
return True
for output_name in op.desc.output_arg_names():
if output_name not in self._should_removed_var:
return False
......
......@@ -28,21 +28,24 @@ def check_broadcast(block):
if the broadcasted var has a fill_constant op, the fill_constant
op should stay forward before the broadcast op, and before a
sync_calc op. Otherwise, raise error.
Broadcast ops of inner parallelism (e.g. Megatron) should be ignored and skipped.
"""
broadcast_vars = {}
for idx, op in enumerate(block.ops):
if op.type == "c_broadcast":
var_name = op.desc.input_arg_names()[0]
if "@BroadCast" in var_name:
if var_name in broadcast_vars:
raise ValueError("var_name areadly exist: {}"
"the old pos is {}, the new pos is {}".
format(var_name, broadcast_vars[var_name][
"broadcast_pos"], idx))
broadcast_vars[var_name] = {
"fill_constant_pos": -1,
"broadcast_pos": idx,
}
if op.all_attrs()["use_calc_stream"] == False:
var_name = op.desc.input_arg_names()[0]
if "@BroadCast" in var_name:
if var_name in broadcast_vars:
raise ValueError("var_name areadly exist: {}"
"the old pos is {}, the new pos is {}".
format(var_name, broadcast_vars[
var_name]["broadcast_pos"], idx))
broadcast_vars[var_name] = {
"fill_constant_pos": -1,
"broadcast_pos": idx,
}
for idx, op in enumerate(block.ops):
if op.type == "fill_constant":
......@@ -61,14 +64,15 @@ def check_broadcast(block):
last_sync_calc_op_idx = idx
continue
if op.type == "c_broadcast":
var_name = op.desc.input_arg_names()[0]
if "@BroadCast" in var_name:
if broadcast_vars[var_name]["fill_constant_pos"] != -1:
assert (last_sync_calc_op_idx != -1)
assert (broadcast_vars[var_name]["fill_constant_pos"] <
last_sync_calc_op_idx)
assert (last_sync_calc_op_idx < idx)
continue
if op.all_attrs()["use_calc_stream"] == False:
var_name = op.desc.input_arg_names()[0]
if "@BroadCast" in var_name:
if broadcast_vars[var_name]["fill_constant_pos"] != -1:
assert (last_sync_calc_op_idx != -1)
assert (broadcast_vars[var_name]["fill_constant_pos"] <
last_sync_calc_op_idx)
assert (last_sync_calc_op_idx < idx)
continue
for input_name in op.desc.input_arg_names():
if input_name in broadcast_vars:
assert (broadcast_vars[input_name]["broadcast_pos"] != -1)
......@@ -78,43 +82,47 @@ def check_broadcast(block):
return
def check_allreduce_sum(block, shard, dp_ring_id=-1):
def check_allreduce_sum(block, shard, sharding_ring_id, dp_ring_id=-1):
"""
the op order should be:
grad:
- 0: op that generate Var
- 1: sync_calc
- 2: allreduce_sum_sharding
- 2: reduce_sum_sharding (allreduce --> reduce)
- 3: sync_comm
- 4: allreduce_sum_dp (dp_grads)
- 5: sync_comm (dp_grads)
- 6: op that use Var (dp_grads & sum)
Allreduce ops of inner parallelism (e.g. Megatron) should be ignored and skipped.
"""
vars_status = {}
dp_grads_status = {}
idx_last_grad_allreduce = -1
idx_amp_allreduce = -1
idx_gradient_clip_allreduce = -1
for idx, op in enumerate(block.ops):
if op.type == "c_allreduce_sum":
ring_id = op.desc.attr("ring_id")
var_name = op.desc.input_arg_names()[0]
param = var_name.split("@")[0]
if op.type == "c_allreduce_sum" or op.type == "c_reduce_sum":
if op.all_attrs()["use_calc_stream"] == False:
ring_id = op.desc.attr("ring_id")
var_name = op.desc.input_arg_names()[0]
param = var_name.split("@")[0]
assert 'sum' in var_name or ("@GRAD" in var_name)
if 'sum' in var_name or (not shard.has_param(param)):
vars_status[var_name] = -1
else:
dp_grads_status[var_name] = -1
assert 'sum' in var_name or ("@GRAD" in var_name)
if 'sum' in var_name or (not shard.has_param(param)):
vars_status[var_name] = -1
else:
dp_grads_status[var_name] = -1
if ring_id != 0:
assert shard.has_param(param)
assert ring_id == dp_ring_id
if ring_id != sharding_ring_id:
assert shard.has_param(param)
assert ring_id == dp_ring_id
if "sum" in var_name:
idx_amp_allreduce = idx
elif "@GRAD":
idx_last_grad_allreduce = idx
if "sum" in var_name:
idx_amp_allreduce = idx
elif "@GRAD":
idx_last_grad_allreduce = idx
if op.type == "c_allreduce_max":
idx_gradient_clip_allreduce = idx
......@@ -129,37 +137,40 @@ def check_allreduce_sum(block, shard, dp_ring_id=-1):
var_name] == 0:
dp_grads_status[var_name] = 1
elif op.type == "c_allreduce_sum":
var_name = op.desc.input_arg_names()[0]
ring_id = op.desc.attr("ring_id")
if ring_id == 0:
if var_name in vars_status:
_status = vars_status[var_name]
else:
_status = dp_grads_status[var_name]
if _status == -1:
raise ValueError("{} is not generated, but you are"
"trying to all-reduce it".format(var_name))
if _status == 0:
raise ValueError("There should be a sync_calc op "
"after generate Var: {} and before the"
"c_allreduce_sum op".format(var_name))
assert (_status == 1)
if var_name in vars_status:
vars_status[var_name] = 2
elif op.type == "c_allreduce_sum" or op.type == "c_reduce_sum":
if op.all_attrs()["use_calc_stream"] == False:
var_name = op.desc.input_arg_names()[0]
ring_id = op.desc.attr("ring_id")
if ring_id == sharding_ring_id:
assert op.type == "c_reduce_sum", "Grad in Sharding group should be reduce rather than allreduce"
if var_name in vars_status:
_status = vars_status[var_name]
else:
_status = dp_grads_status[var_name]
if _status == -1:
raise ValueError("{} is not generated, but you are"
"trying to all-reduce it".format(
var_name))
if _status == 0:
raise ValueError("There should be a sync_calc op "
"after generate Var: {} and before the"
"c_allreduce_sum op".format(var_name))
assert (_status == 1)
if var_name in vars_status:
vars_status[var_name] = 2
else:
dp_grads_status[var_name] = 2
else:
dp_grads_status[var_name] = 2
else:
assert ring_id == dp_ring_id
param = var_name.split("@")[0]
assert shard.has_param(param)
assert dp_grads_status[var_name] == 3
dp_grads_status[var_name] = 4
assert ring_id == dp_ring_id
param = var_name.split("@")[0]
assert shard.has_param(param)
assert dp_grads_status[var_name] == 3
dp_grads_status[var_name] = 4
elif op.type == "c_sync_comm_stream":
var_name = op.desc.input_arg_names()[0]
ring_id = op.desc.attr("ring_id")
if ring_id == 0:
if ring_id == sharding_ring_id:
for var_name in op.desc.input_arg_names():
if var_name in vars_status:
assert vars_status[var_name] == 2
......@@ -181,6 +192,9 @@ def check_allreduce_sum(block, shard, dp_ring_id=-1):
raise ValueError("There should be a sync_comm op "
"after allreduce the Var: {}".format(
input_name))
raise ValueError(
"The reduce output grad [{}] should NOT be be used in Non-root rank.".
format(input_name))
if input_name in dp_grads_status:
if dp_ring_id == -1:
if dp_grads_status[input_name] != 3:
......@@ -225,6 +239,13 @@ def get_valid_op_role(block, insert_idx):
return get_valid_op_role(block, insert_idx + 1)
# if insert_idx >= len(block.ops): return OpRole.Optimize
# if op_role == int(OpRole.Backward): return OpRole.Backward
# if op_role == int(OpRole.Optimize): return OpRole.Optimize
# if op_role in [int(OpRole.Forward), int(OpRole.Loss)]:
# return OpRole.Forward
# return get_valid_op_role(block, insert_idx + 1)
def insert_sync_calc_op(block, insert_idx, calc_dep_vars):
"""
......@@ -259,6 +280,9 @@ def insert_sync_comm_ops(block, insert_idx, ring_id, comm_dep_vars):
"""
insert sync_comm_op for vars
"""
if len(comm_dep_vars) == 0:
return 0
op_role = get_valid_op_role(block, insert_idx)
block._insert_op_without_sync(
insert_idx,
......@@ -313,6 +337,9 @@ def insert_allreduce_ops(block, insert_idx, ring_id, allreduce_vars):
"""
_add_allreduce_ops
"""
if len(allreduce_vars) == 0:
return
for var in allreduce_vars:
block._insert_op_without_sync(
insert_idx,
......@@ -325,6 +352,62 @@ def insert_allreduce_ops(block, insert_idx, ring_id, allreduce_vars):
return
def get_grad_device(grad_name, shard):
assert "@GRAD" in grad_name, "[{}] should be a grad variable.".format(
grad_name)
base_name = None
# mind the traversal order
possible_suffixes = [
'.cast_fp16@GRAD@MERGED', '.cast_fp16@GRAD', '@GRAD@MERGED', '@GRAD'
]
for suffix in possible_suffixes:
if suffix in grad_name:
base_name = re.sub(suffix, '', grad_name)
break
assert base_name in shard.global_param2device, "[{}] should be a param variable.".format(
base_name)
return shard.global_param2device[base_name]
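# Usage sketch (illustration only; the Dummy names below are hypothetical):
#   a grad name maps back to its owning parameter, then to the rank that owns it.
#
#   class _DummyShard:
#       global_param2device = {"fc_0.w_0": 3}
#
#   get_grad_device("fc_0.w_0.cast_fp16@GRAD@MERGED", _DummyShard())  # -> 3
#   get_grad_device("fc_0.w_0@GRAD", _DummyShard())                   # -> 3
#
# The suffix list is ordered longest-first, so '.cast_fp16@GRAD@MERGED' is
# stripped before the shorter '@GRAD' would match the same name.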
def insert_reduce_ops(block,
insert_idx,
ring_id,
reduce_vars,
shard,
op_role=OpRole.Backward,
use_calc_stream=False):
"""
_add_reduce_ops
"""
for var in reduce_vars:
root_id = get_grad_device(var, shard)
assert root_id >= 0, "root id should be a positive int".format(var)
block._insert_op_without_sync(
insert_idx,
type='c_reduce_sum',
inputs={'X': var},
outputs={'Out': var},
attrs={
'ring_id': ring_id,
'root_id': root_id,
'use_calc_stream': use_calc_stream,
OP_ROLE_KEY: op_role
})
return
def get_first_check_finite_and_unscale_op_idx(block):
for idx, op in enumerate(block.ops):
if op.type == "check_finite_and_unscale":
return idx
raise ValueError("check_finite_and_unscale does not exist in block")
def insert_broadcast_ops(block, insert_idx, ring_id, broadcast2root):
"""
_add_broadcast_ops
......@@ -428,7 +511,7 @@ def comm_analyse(main_program):
count))
def add_sync_comm(program, dist_strategy):
def add_sync_comm(program, nccl_ids):
"""
When cloning a test program from the sharding main program,
some of the sync_comm ops may be pruned by mistake; this function
......@@ -438,6 +521,9 @@ def add_sync_comm(program, dist_strategy):
# NOTE(liangjianzhong): only one comm stream is supported for now; using more
# than one comm stream will cause errors. Should be revised in the future.
assert isinstance(
nccl_ids, list
), "the second argument of this function should be a list of nccl_ids"
block = program.global_block()
not_sync_vars = set([])
for op in block.ops:
......@@ -448,7 +534,7 @@ def add_sync_comm(program, dist_strategy):
for input_name in op.desc.input_arg_names():
not_sync_vars.remove(input_name)
if not_sync_vars:
for nccl_id in range(dist_strategy.nccl_comm_num):
for nccl_id in nccl_ids:
block.append_op(
type='c_sync_comm_stream',
inputs={'X': list(not_sync_vars)},
......@@ -467,6 +553,9 @@ def save_persistables(exe, dirname, main_program, filename=None):
This function handles the model saving for sharding training.
"""
if main_program._pipeline_opt:
main_program = main_program._pipeline_opt['section_program']['program']
def is_opt_vars(var):
# NOTE(liangjianzhong): The checks should be updated when adding a new compatible optimizer;
# for now only Momentum and Adam are compatible with sharding
......
......@@ -16,14 +16,16 @@ from paddle.fluid import unique_name, core
import paddle.fluid as fluid
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_VAR_KEY, CollectiveHelper
from paddle.distributed.fleet.meta_optimizers.common import is_backward_op
from paddle.distributed.fleet.meta_optimizers.common import is_backward_op, is_optimizer_op, is_update_op
from paddle.distributed.fleet.meta_optimizers.meta_optimizer_base import MetaOptimizerBase
from paddle.distributed.fleet.meta_optimizers.sharding.shard import Shard, ProgramSegment
from paddle.distributed.fleet.meta_optimizers.sharding.fp16_helper import FP16Utils
from paddle.distributed.fleet.meta_optimizers.sharding.weight_decay_helper import WeightDecayHelper
from paddle.distributed.fleet.meta_optimizers.sharding.gradient_clip_helper import GradientClipHelper
from .sharding.offload_helper import OffloadHelper
from paddle.distributed.fleet.meta_optimizers.sharding.prune import ProgramDeps
from paddle.distributed.fleet.meta_optimizers.sharding.utils import *
import logging
from functools import reduce
......@@ -31,6 +33,8 @@ __all__ = ["ShardingOptimizer"]
class ShardingOptimizer(MetaOptimizerBase):
"""Sharding Optimizer."""
def __init__(self, optimizer):
super(ShardingOptimizer, self).__init__(optimizer)
self.inner_opt = optimizer
......@@ -39,6 +43,8 @@ class ShardingOptimizer(MetaOptimizerBase):
"AMPOptimizer",
"LarsOptimizer",
"LambOptimizer",
# "ModelParallelOptimizer",
"PipelineOptimizer",
]
self.meta_optimizers_black_list = ["GraphExecutionOptimizer", ]
self._main_program = None
......@@ -51,6 +57,10 @@ class ShardingOptimizer(MetaOptimizerBase):
self._reduced_grads_to_param = {}
self._shard = Shard()
# use sharding as outer parallelism (e.g. inner:Megatron & outer sharding)
self._as_outer_parallelism = False
self._inner_parallelism_size = None
def _can_apply(self):
if not self.role_maker._is_collective:
return False
......@@ -71,6 +81,7 @@ class ShardingOptimizer(MetaOptimizerBase):
startup_program=None,
parameter_list=None,
no_grad_set=None):
"""Implementation of minimize."""
# TODO: (JZ-LIANG) support multiple comm in future
# self._nrings = self.user_defined_strategy.nccl_comm_num
self._nrings_sharding = 1
......@@ -79,20 +90,72 @@ class ShardingOptimizer(MetaOptimizerBase):
"fuse_broadcast_MB"]
self.hybrid_dp = self.user_defined_strategy.sharding_configs[
"hybrid_dp"]
self._as_outer_parallelism = self.user_defined_strategy.sharding_configs[
"as_outer_parallelism"]
self._inner_parallelism_size = int(
self.user_defined_strategy.sharding_configs["parallelism"])
self.use_pipeline = self.user_defined_strategy.sharding_configs[
"use_pipeline"]
self.acc_steps = self.user_defined_strategy.sharding_configs[
"acc_steps"]
self.schedule_mode = self.user_defined_strategy.sharding_configs[
"schedule_mode"]
self.pp_bz = self.user_defined_strategy.sharding_configs["pp_bz"]
self.pp_allreduce_in_optimize = self.user_defined_strategy.sharding_configs[
"pp_allreduce_in_optimize"]
if self.inner_opt is None:
raise ValueError(
"self.inner_opt of ShardingOptimizer should not be None.")
optimize_ops, params_grads = self.inner_opt.minimize(
loss, startup_program, parameter_list, no_grad_set)
if self.use_pipeline:
pp_optimizer = fluid.optimizer.PipelineOptimizer(self.inner_opt,
self.acc_steps)
main_program = loss.block.program
main_program._pipeline_opt = dict()
main_program._pipeline_opt['schedule_mode'] = self.schedule_mode
main_program._pipeline_opt['pp_bz'] = self.pp_bz
pp_rank = self.role_maker._worker_index() // (
self.user_defined_strategy.sharding_configs[
'sharding_group_size'] * self._inner_parallelism_size)
main_program._pipeline_opt['local_rank'] = pp_rank
main_program._pipeline_opt[
'global_rank'] = self.role_maker._worker_index()
main_program._pipeline_opt['use_sharding'] = True
main_program._pipeline_opt['ring_id'] = 20
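# NOTE: ring ids starting at 20 are used for pipeline communication; extra
# rings for stage pairs are allocated incrementally from this base
# (see PipelineOptimizer._insert_sendrecv_ops_for_boundaries below).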
optimize_ops, params_grads, program_list, self.pipeline_pair, self.pp_ring_map = pp_optimizer.minimize(
loss, startup_program, parameter_list, no_grad_set)
self.pipeline_nodes = len(program_list)
else:
optimize_ops, params_grads = self.inner_opt.minimize(
loss, startup_program, parameter_list, no_grad_set)
if startup_program is None:
startup_program = default_startup_program()
main_block = loss.block
if self.use_pipeline:
startup_program = startup_program._pipeline_opt['startup_program']
#main_program = main_program._pipeline_opt['section_program']['program']
print("pp_rank:", pp_rank)
main_program = program_list[pp_rank]['program']
with open("main_%d" % self.role_maker._worker_index(), 'w') as f:
f.writelines(str(main_program))
main_block = main_program.global_block()
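# keep only the (param, grad) pairs whose parameter exists in this
# pipeline stage's sub-program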
new_params_grads = []
for param, grad in params_grads:
if main_block.has_var(param.name):
new_params_grads.append((param, grad))
params_grads = new_params_grads
else:
main_block = loss.block
startup_block = startup_program.global_block()
self._main_program = main_block.program
self._startup_program = startup_program
if self.use_pipeline:
pp_optimizer._rename_gradient_var_name(main_block)
with open("main_%d" % self.role_maker._worker_index(), 'w') as f:
f.writelines(str(main_program))
# step1: set_up
self._set_up(params_grads)
......@@ -105,17 +168,76 @@ class ShardingOptimizer(MetaOptimizerBase):
startup_block._sync_with_cpp()
# step4: insert reduce_sum for grad
insert_scale_loss_grad_ops(
main_block, scale=1.0 / self.role_maker._worker_num())
# grad_scale_coeff = self.role_maker._worker_num()
# if self._as_outer_parallelism:
# grad_scale_coeff = grad_scale_coeff / self._inner_parallelism_size
# insert_scale_loss_grad_ops(main_block, scale=1.0 / grad_scale_coeff)
sharding_group_size = self.user_defined_strategy.sharding_configs[
'sharding_group_size']
insert_scale_loss_grad_ops(main_block, scale=1.0 / sharding_group_size)
main_block._sync_with_cpp()
# step5: remove unneeded ops and vars from block
self._prune_main_program(main_block)
self._prune_startup_program(startup_block)
if self.hybrid_dp:
self._initialization_broadcast(startup_program)
if self.use_pipeline:
# pp_optimizer._rename_gradient_var_name(main_block)
# crop ops: drop update/cast ops for parameters not held by this shard
for idx, op in reversed(list(enumerate(main_block.ops))):
if is_update_op(op):
op_role_var = op.attr('op_role_var')
param_name = op_role_var[0]
if not self._shard.has_param(param_name):
main_block._remove_op(idx)
for idx, op in reversed(list(enumerate(main_block.ops))):
if op.type != 'cast': continue
in_name = op.input_arg_names[0]
if in_name not in self._params: continue
#if self._shard.has_param(param_name): continue
if in_name not in main_block.vars:
main_block._remove_op(idx)
accumulated_grad_names = pp_optimizer._accumulate_gradients(
main_block)
# accumulated_grad_names = sorted(accumulated_grad_names)
if self.pp_allreduce_in_optimize:
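# defer the gradient reduction to the optimize stage: reduce the merged
# FP32 grads to their owning shards right before the first
# check_finite_and_unscale op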
print("persistable FP32 grad: ")
print(accumulated_grad_names)
first_optimize_op_index = get_first_check_finite_and_unscale_op_idx(
main_block)
insert_reduce_ops(
main_block,
first_optimize_op_index,
self.sharding_ring_id,
accumulated_grad_names,
self._shard,
core.op_proto_and_checker_maker.OpRole.Optimize,
use_calc_stream=True)
main_block._sync_with_cpp()
# TODO(wangxi): add optimize offload
if self.optimize_offload:
logging.info("Sharding with optimize offload !")
offload_helper = OffloadHelper()
offload_helper.offload(main_block, startup_block)
offload_helper.offload_fp32param(main_block, startup_block)
with open("start_sharding_%d" % self.role_maker._worker_index(),
'w') as f:
f.writelines(str(startup_block.program))
with open("main_sharding_%d" % self.role_maker._worker_index(),
'w') as f:
f.writelines(str(main_block.program))
# check op dependency
check_broadcast(main_block)
check_allreduce_sum(main_block, self._shard, self.dp_ring_id)
#check_allreduce_sum(main_block, self._shard, self.sharding_ring_id,
# self.dp_ring_id)
#check_allreduce_sum(main_block, self._shard, self.dp_ring_id)
self._wait()
return optimize_ops, params_grads
......@@ -129,16 +251,72 @@ class ShardingOptimizer(MetaOptimizerBase):
self._nrings_sharding)
# config sharding & dp groups
self._init_comm()
# global
if self._as_outer_parallelism:
print("global_group_endpoints:", self.global_group_endpoints)
print("global_rank:", self.global_rank)
print("global_ring_id:", self.global_group_id)
self._collective_helper._init_communicator(
self._startup_program, self.current_endpoint,
self.global_group_endpoints, self.global_rank,
self.global_group_id, False)
if self._as_outer_parallelism:
print("mp_group_endpoints:", self.mp_group_endpoints)
print("mp_rank:", self.mp_rank)
print("mp_ring_id:", self.mp_group_id)
self._collective_helper._init_communicator(
self._startup_program, self.current_endpoint,
self.mp_group_endpoints, self.mp_rank, self.mp_group_id, False)
# sharding
print("sharding_group_endpoints:", self.sharding_group_endpoints)
print("sharding_rank:", self.sharding_rank)
print("sharding_ring_id:", self.sharding_ring_id)
self._collective_helper._init_communicator(
self._startup_program, self.current_endpoint,
self.sharding_group_endpoints, self.sharding_rank,
self.sharding_ring_id, True)
self.sharding_ring_id, False)
# dp
if self.hybrid_dp:
self._collective_helper._init_communicator(
self._startup_program, self.current_endpoint,
self.dp_group_endpoints, self.dp_rank, self.dp_ring_id, True)
self.dp_group_endpoints, self.dp_rank, self.dp_ring_id, False)
# pp
if self.use_pipeline:
print("pp_group_endpoints:", self.pp_group_endpoints)
print("pp_rank:", self.pp_rank)
print("pp_ring_id:", self.pp_ring_id)
if self.schedule_mode == 0: # GPipe
self._collective_helper._init_communicator(
self._startup_program, self.current_endpoint,
self.pp_group_endpoints, self.pp_rank, self.pp_ring_id,
False)
self._collective_helper._init_communicator(
self._startup_program, self.current_endpoint,
self.pp_group_endpoints, self.pp_rank, self.pp_ring_id + 2,
False)
else:
for pair in self.pipeline_pair:
pair_key = pair[0] * 1000 + pair[1]
ring_id = self.pp_ring_map[pair_key]
print("pp pair:{}, ring_id: {}".format(pair, ring_id))
if self.pp_rank not in pair: continue
pp_group_endpoints = [
self.pp_group_endpoints[pair[0]],
self.pp_group_endpoints[pair[1]],
]
if pair[0] < pair[1]:
start_ring_id = self.pp_ring_id + pair[1] - pair[0] - 1
else:
start_ring_id = self.pp_ring_id + 2 + pair[0] - pair[
1] - 1
pp_rank = 0 if self.pp_rank == pair[0] else 1
self._collective_helper._init_communicator(
self._startup_program, self.current_endpoint,
pp_group_endpoints, pp_rank, ring_id, False, False)
startup_block = self._startup_program.global_block()
startup_block._sync_with_cpp()
......@@ -153,10 +331,35 @@ class ShardingOptimizer(MetaOptimizerBase):
self._main_program.global_block())
def _wait(self, ):
endpoints = self.role_maker._get_trainer_endpoints()
current_endpoint = endpoints[self.role_maker._worker_index()]
if self.role_maker._worker_index() == 0:
self._collective_helper._wait(current_endpoint, endpoints)
# only the first parallelism group that inits nccl needs to wait.
if self._as_outer_parallelism:
endpoints = self.role_maker._get_trainer_endpoints()
current_endpoint = endpoints[self.role_maker._worker_index()]
else:
endpoints = self.sharding_group_endpoints[:]
current_endpoint = self.sharding_group_endpoints[self.sharding_rank]
if self._as_outer_parallelism:
if self.role_maker._worker_index() == 0:
self._collective_helper._wait(current_endpoint, endpoints)
else:
if self.sharding_rank == 0:
self._collective_helper._wait(current_endpoint, endpoints)
# def _wait(self, ):
# # only the first parallelsm group that init nccl need to be wait.
# if self._as_outer_parallelism:
# endpoints = self.role_maker._get_trainer_endpoints()
# else:
# endpoints = self.sharding_group_endpoints[:]
# current_endpoint = endpoints[self.role_maker._worker_index()]
# if self._as_outer_parallelism:
# if self.role_maker._worker_index() == 0:
# self._collective_helper._wait(current_endpoint, endpoints)
# else:
# if self.sharding_rank == 0:
# self._collective_helper._wait(current_endpoint, endpoints)
def _split_program(self, block):
for op_idx, op in reversed(list(enumerate(block.ops))):
......@@ -197,17 +400,22 @@ class ShardingOptimizer(MetaOptimizerBase):
self._main_program.global_block().var(input_name))
# find reduce vars
if is_backward_op(op) and \
OP_ROLE_VAR_KEY in op.attr_names:
op_role_var = op.all_attrs()[OP_ROLE_VAR_KEY]
if len(op_role_var) != 0:
assert len(op_role_var) % 2 == 0
for i in range(0, len(op_role_var), 2):
param, reduced_grad = op_role_var[i], op_role_var[i + 1]
segment._allreduce_vars.append(reduced_grad)
assert (
reduced_grad not in self._reduced_grads_to_param)
self._reduced_grads_to_param[reduced_grad] = param
if self.use_pipeline and self.pp_allreduce_in_optimize:
# place pipeline gradient allreduce in optimize
pass
else:
if is_backward_op(op) and \
OP_ROLE_VAR_KEY in op.attr_names:
op_role_var = op.all_attrs()[OP_ROLE_VAR_KEY]
if len(op_role_var) != 0:
assert len(op_role_var) % 2 == 0
for i in range(0, len(op_role_var), 2):
param, reduced_grad = op_role_var[i], op_role_var[
i + 1]
segment._allreduce_vars.append(reduced_grad)
#assert (
# reduced_grad not in self._reduced_grads_to_param)
self._reduced_grads_to_param[reduced_grad] = param
# find cast op
if FP16Utils.is_fp16_cast_op(block, op, self._params):
......@@ -234,9 +442,14 @@ class ShardingOptimizer(MetaOptimizerBase):
"""
weightdecay_helper = WeightDecayHelper()
weightdecay_helper.prune_weight_decay(block, self._shard)
# NOTE (JZ-LIANG) the sync of FoundInfinite should be done within one entire Model Parallelism
# group, and each Data Parallelism group should have its own sync of FoundInfinite
model_parallelism_ring_id = self.sharding_ring_id
if self._as_outer_parallelism:
model_parallelism_ring_id = self.global_group_id
FP16Utils.prune_fp16(block, self._shard, self._reduced_grads_to_param,
self.sharding_ring_id)
gradientclip_helper = GradientClipHelper(self.sharding_ring_id)
model_parallelism_ring_id)
gradientclip_helper = GradientClipHelper(model_parallelism_ring_id)
gradientclip_helper.prune_gradient_clip(block, self._shard)
# build prog deps
......@@ -264,8 +477,13 @@ class ShardingOptimizer(MetaOptimizerBase):
# Prune
for idx, op in reversed(list(enumerate(block.ops))):
if op.type in [
"c_allreduce_sum", "c_sync_comm_stream",
"c_calc_comm_stream", "c_gen_nccl_id", "c_comm_init, c_comm_init_hcom"
"c_allreduce_sum",
"c_sync_comm_stream",
"c_calc_comm_stream",
"c_gen_nccl_id",
"c_comm_init",
'send_v2',
'recv_v2',
]:
pass
elif op.type == "conditional_block":
......@@ -303,15 +521,41 @@ class ShardingOptimizer(MetaOptimizerBase):
program_deps.remove_op(idx)
block._sync_with_cpp()
for idx, op in reversed(list(enumerate(block.ops))):
if op.type == 'concat' and is_optimizer_op(op):
# remove inputs that not on this card
reserved_x = []
for var_name in op.desc.input("X"):
if block.has_var(var_name): reserved_x.append(var_name)
op.desc.set_input('X', reserved_x)
block._sync_with_cpp()
return
def _add_broadcast_allreduce(self, block):
"""
_add_broadcast_allreduce
when combined with pipeline (gradient accumulation),
the grad allreduce should be done in the optimize role
"""
if len(self._segments) < 1:
return
# sharding
if self.use_pipeline and self.pp_allreduce_in_optimize:
for idx in range(len(self._segments)):
assert len(self._segments[idx]._allreduce_vars) == 0
# fix the _end_idx for segments[-1] if pp is used.
new_end_idx = self._segments[-1]._end_idx
for idx in range(self._segments[-1]._end_idx - 1,
self._segments[-1]._start_idx - 1, -1):
op = block.ops[idx]
if op.type == "fill_constant" or op.type == "sum":
if "MERGED" in op.output_arg_names[0]: new_end_idx = idx + 1
elif op.type == "cast":
if "@TMP" in op.output_arg_names[0]: new_end_idx = idx + 1
self._segments[-1]._end_idx = new_end_idx
if self._segments[-1]._allreduce_vars:
shard_allredue_vars = self._shard.filter_grads(self._segments[-1]
._allreduce_vars)
......@@ -323,9 +567,15 @@ class ShardingOptimizer(MetaOptimizerBase):
insert_sync_comm_ops(block, self._segments[-1]._end_idx,
self.sharding_ring_id,
self._segments[-1]._allreduce_vars)
insert_allreduce_ops(block, self._segments[-1]._end_idx,
self.sharding_ring_id,
self._segments[-1]._allreduce_vars)
# allreduce --> reduce
insert_reduce_ops(
block,
self._segments[-1]._end_idx,
self.sharding_ring_id,
self._segments[-1]._allreduce_vars,
self._shard,
op_role=OpRole.Backward,
use_calc_stream=False)
for idx, segment in reversed(list(enumerate(self._segments))):
allreduce_vars = self._segments[
......@@ -391,6 +641,7 @@ class ShardingOptimizer(MetaOptimizerBase):
fill_constant_vars)
# step4: add `cast` ops
print("cast_ops:", cast_ops)
insert_cast_ops(block, segment._end_idx, cast_ops)
# step5: add broadcast ops
......@@ -404,8 +655,15 @@ class ShardingOptimizer(MetaOptimizerBase):
insert_sync_comm_ops(block, segment._start_idx,
self.sharding_ring_id, allreduce_vars)
# sharding
insert_allreduce_ops(block, segment._start_idx,
self.sharding_ring_id, allreduce_vars)
# allreduce --> reduce
insert_reduce_ops(
block,
segment._start_idx,
self.sharding_ring_id,
allreduce_vars,
self._shard,
op_role=OpRole.Backward,
use_calc_stream=False)
block._sync_with_cpp()
......@@ -459,6 +717,7 @@ class ShardingOptimizer(MetaOptimizerBase):
def _init_comm(self):
if self.hybrid_dp:
assert not self._as_outer_parallelism, "hybrid dp conflicts with using sharding as outer parallelism"
self.sharding_group_size = self.user_defined_strategy.sharding_configs[
"sharding_group_size"]
self.sharding_ring_id = 0
......@@ -476,6 +735,7 @@ class ShardingOptimizer(MetaOptimizerBase):
ep for idx, ep in enumerate(self.endpoints)
if (idx % self.sharding_group_size) == self.sharding_rank
]
assert self.global_word_size > self.sharding_group_size, \
"global_word_size: {} should be larger than sharding_group_size: {}".format(self.global_word_size, self.sharding_group_size)
assert self.global_word_size % self.sharding_group_size == 0, \
......@@ -485,30 +745,215 @@ class ShardingOptimizer(MetaOptimizerBase):
self.global_word_size,
self.sharding_group_size,
self.dp_group_size)
self.pp_ring_id = -1
self.pp_rank = -1
self.pp_group_size = None
self.pp_group_endpoints = None
# sharding parallelism is the only model parallelism in the current setting
self.mp_group_id = self.sharding_ring_id
self.mp_rank = self.sharding_rank
self.mp_group_size = self.sharding_group_size
self.mp_group_endpoints = self.sharding_group_endpoints[:]
logging.info("Using Sharding & DP mode!")
else:
self.sharding_ring_id = 0
self.sharding_rank = self.global_rank
self.sharding_group_size = self.role_maker._worker_num()
self.sharding_group_endpoints = self.endpoints
if self._as_outer_parallelism and not self.use_pipeline:
self.sharding_ring_id = 1
assert self.global_word_size > self._inner_parallelism_size, \
"global_word_size: {} should be larger than inner_parallelism_size: {}".format(self.global_word_size, self._inner_parallelism_size)
assert self.global_word_size % self._inner_parallelism_size == 0, \
"global_word_size: {} should be divisible by the inner_parallelism_size: {}".format(self.global_word_size, self._inner_parallelism_size)
self.sharding_rank = self.global_rank // self._inner_parallelism_size
self.sharding_group_size = self.role_maker._worker_num(
) // self._inner_parallelism_size
_offset = self.global_rank % self._inner_parallelism_size
self.sharding_group_endpoints = [
ep for idx, ep in enumerate(self.endpoints)
if idx % self._inner_parallelism_size == _offset
]
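# ranks sharing the same offset inside the inner (e.g. Megatron) group form
# one sharding group, i.e. sharding groups stride through the endpoint list
# with step inner_parallelism_size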
# the current entire model parallelism group is the combination of inner & sharding parallelism
self.mp_group_id = 2
self.mp_rank = self.global_rank
self.mp_group_size = self.role_maker._worker_num()
self.mp_group_endpoints = self.endpoints[:]
logging.info("Using Sharding as outer parallelism mode!")
# print(
# "init the nccl comm for megatron paramllelism, this should be done in Megatron Metaoptimizer"
# )
# partition_idx = self.global_rank // self._inner_parallelism_size
# magetron_endpoints = self.endpoints[
# partition_idx * self._inner_parallelism_size:partition_idx *
# self._inner_parallelism_size + self._inner_parallelism_size]
# magetron_rank = self.global_rank % self._inner_parallelism_size
# self._collective_helper._init_communicator(
# program=self._startup_program,
# current_endpoint=self.current_endpoint,
# endpoints=magetron_endpoints,
# rank=magetron_rank,
# ring_id=0,
# wait_port=True)
# logging.info("megatron group size: {}".format(
# self._inner_parallelism_size))
# logging.info("megatron rank: {}".format(magetron_rank))
# logging.info("megatron endpoints: {}".format(
# magetron_endpoints))
if self.use_pipeline:
if self._inner_parallelism_size == 1:
self.sharding_ring_id = 0
self.sharding_group_size = self.user_defined_strategy.sharding_configs[
'sharding_group_size']
self.sharding_rank = self.global_rank % self.sharding_group_size
assert self.sharding_group_size * self.pipeline_nodes * self._inner_parallelism_size == self.role_maker._worker_num(
)
self.pp_ring_id = 20
self.pp_rank = self.global_rank // (
self.sharding_group_size * self._inner_parallelism_size)
self.sharding_group_endpoints = [
ep for idx, ep in enumerate(self.endpoints)
if (idx // self.sharding_group_size) == self.pp_rank
]
self.pp_group_size = self.pipeline_nodes
self.pp_group_endpoints = [
ep for idx, ep in enumerate(self.endpoints)
if (idx % self.sharding_group_size
) == self.sharding_rank
]
else:
self.mp_group_id = 0
self.sharding_ring_id = 1
self.pp_ring_id = 20
self.mp_rank = self.global_rank % self._inner_parallelism_size
self.mp_group = self.global_rank // self._inner_parallelism_size
self.mp_group_endpoints = [
ep for idx, ep in enumerate(self.endpoints)
if idx // self._inner_parallelism_size == self.mp_group
]
print("megatron_group_endpoints:", self.mp_group_endpoints)
print("megatron_rank:", self.mp_rank)
# self.cards_per_node = 8
self.sharding_group_size = self.user_defined_strategy.sharding_configs[
'sharding_group_size']
self.sharding_rank = (
self.global_rank //
self._inner_parallelism_size) % self.sharding_group_size
self.sharding_group_id = self.global_rank // (
self._inner_parallelism_size * self.sharding_group_size)
self.megatron_rank = self.global_rank % self._inner_parallelism_size
self.sharding_group_endpoints = [
ep for idx, ep in enumerate(self.endpoints)
if (idx // (self._inner_parallelism_size *
self.sharding_group_size)
) == self.sharding_group_id and idx %
self._inner_parallelism_size == self.megatron_rank
]
print("sharding_endpoint:", self.sharding_group_endpoints)
print("sharding_rank:", self.sharding_rank)
assert self.sharding_group_size * self.pipeline_nodes * self._inner_parallelism_size == self.role_maker._worker_num(
)
self.pp_rank = self.global_rank // (
self.sharding_group_size *
self._inner_parallelism_size) % self.pipeline_nodes
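# rank layout (illustrative): inner (megatron) parallelism varies fastest,
# then sharding, then pipeline. E.g. with inner_parallelism_size=2,
# sharding_group_size=4 and 16 workers (2 pipeline stages), global rank 13
# maps to mp_rank=1, sharding_rank=2, pp_rank=1.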
offset = self.sharding_group_size * self._inner_parallelism_size
# TODO: Adjust for dp
idx_with_pp_0 = self.global_rank % (
self.sharding_group_size * self._inner_parallelism_size)
self.pp_group_endpoints = []
for i in range(self.pipeline_nodes):
self.pp_group_endpoints.append(self.endpoints[
idx_with_pp_0])
idx_with_pp_0 += offset
print("pp_group_endpoints:", self.pp_group_endpoints)
print("pp_rank:", self.pp_rank)
#self.pp_group_endpoints = [
# ep for idx, ep in enumerate(self.endpoints)
# if (idx % self.sharding_group_size) == self.sharding_rank
#]
self.global_group_id = 3
self.global_rank = self.global_rank
self.global_group_size = self.role_maker._worker_num()
self.global_group_endpoints = self.endpoints[:]
logging.info("Using Sharding as outer parallelism mode!")
self.dp_ring_id = -1
self.dp_rank = -1
self.dp_group_size = None
self.dp_group_endpoints = None
logging.info("Using Sharding with pipeline!")
#else:
# self.sharding_ring_id = 0
# self.sharding_rank = self.global_rank
# self.sharding_group_size = self.role_maker._worker_num()
# self.sharding_group_endpoints = self.endpoints
# # sharding parallelism is the only model parallelism in the current setting
# self.mp_group_id = self.sharding_ring_id
# self.mp_rank = self.sharding_rank
# self.mp_group_size = self.sharding_group_size
# self.mp_group_endpoints = self.sharding_group_endpoints[:]
# logging.info("Using Sharing alone mode !")
self.dp_ring_id = -1
self.dp_rank = -1
self.dp_group_size = None
self.dp_group_endpoints = None
#self.pp_ring_id = -1
#self.pp_rank = -1
#self.pp_group_size = None
#self.pp_group_endpoints = None
#self.dp_ring_id = -1
#self.dp_rank = -1
#self.dp_group_size = None
#self.dp_group_endpoints = None
logging.info("Using Sharding alone mode!")
logging.info("global world size: {}".format(self.global_word_size))
logging.info("global rank: {}".format(self.global_rank))
logging.info("sharding group_size: {}".format(self.sharding_group_size))
logging.info("sharding rank: {}".format(self.sharding_rank))
logging.info("dp group size: {}".format(self.dp_group_size))
logging.info("dp rank: {}".format(self.dp_rank))
logging.info("current endpoint: {}".format(self.current_endpoint))
logging.info("sharding group endpoints: {}".format(
self.sharding_group_endpoints))
logging.info("dp group endpoints: {}".format(self.dp_group_endpoints))
logging.info("global world endpoints: {}".format(self.endpoints))
#logging.info("global word size: {}".format(self.global_word_size))
#logging.info("global rank: {}".format(self.global_rank))
#logging.info("sharding group_size: {}".format(self.sharding_group_size))
#logging.info("sharding rank: {}".format(self.sharding_rank))
#logging.info("current model parallelism group_size: {}".format(
# self.mp_group_size))
#logging.info("current model parallelism rank: {}".format(self.mp_rank))
#logging.info("dp group size: {}".format(self.dp_group_size))
#logging.info("dp rank: {}".format(self.dp_rank))
#logging.info("current endpoint: {}".format(self.current_endpoint))
#logging.info("global word endpoints: {}".format(self.endpoints))
#logging.info("sharding group endpoints: {}".format(
# self.sharding_group_endpoints))
#logging.info("current model parallelism group endpoints: {}".format(
# self.mp_group_endpoints))
#logging.info("dp group endpoints: {}".format(self.dp_group_endpoints))
return
def _initialization_broadcast(self, startup_prog):
"""
This function ensures that the parameter initialization across the dp group
is identical when hybrid-dp is used.
"""
block = startup_prog.global_block()
params = []
for param in block.iter_parameters():
params.append(param)
block.append_op(
type='c_broadcast',
inputs={'X': param},
outputs={'Out': param},
attrs={
'ring_id': self.dp_ring_id,
'root': 0,
OP_ROLE_KEY: OpRole.Forward
})
block.append_op(
type='c_sync_comm_stream',
inputs={'X': params},
outputs={'Out': params},
attrs={'ring_id': self.dp_ring_id,
OP_ROLE_KEY: OpRole.Forward})
......@@ -115,7 +115,7 @@ class ProgramStats(object):
updated_min_idx = min_idx
while idx_ > pre_segment_end_idx:
if is_amp_cast(self.ops[idx_]):
_logger.debug("found amp-cast op: {}, : {}".format(self.ops[
_logger.info("found amp-cast op: {}, : {}".format(self.ops[
idx_].desc.type(), self.ops[idx_].desc.input_arg_names()[
0]))
updated_min_idx = idx_
......@@ -155,7 +155,7 @@ class ProgramStats(object):
sorted_checkpoints = []
for name in checkpoints_name:
if name not in self.var_op_deps:
_logger.debug(
_logger.info(
"Recompute Optimizer: deleted %s from checkpoints, because it is not used in paddle program."
% name)
elif self.var_op_deps[name]["var_as_output_ops"] == []:
......@@ -233,6 +233,8 @@ def _add_needed_descs_to_block(descs, block, main_block, in_memory_vars):
new_op_desc = block.desc.append_op()
new_op_desc.copy_from(desc)
new_op_desc._set_attr(op_role_attr_name, backward)
if desc.has_attr('op_device'):
new_op_desc._set_attr('op_device', desc.attr('op_device'))
result_descs.append(new_op_desc)
return result_descs
......@@ -252,6 +254,8 @@ def _add_descs_to_block(descs, block):
new_op_desc = block.desc.append_op()
new_op_desc.copy_from(desc)
new_op_desc._set_attr(op_role_attr_name, backward)
if desc.has_attr('op_device'):
new_op_desc._set_attr('op_device', desc.attr('op_device'))
result_descs.append(new_op_desc)
return result_descs
......@@ -784,7 +788,6 @@ def _append_backward_ops_with_checkpoints_(
start_idx = 0
pre_segment_end_idx = -1
while True:
_logger.debug("FW op range[0] - [{}]".format(len(ops)))
if start_idx >= len(checkpoints_name) - 1:
break
# min_idx: checkpoint_1' s input op
......@@ -797,6 +800,9 @@ def _append_backward_ops_with_checkpoints_(
min_idx = program_stat._update_segment_start(
min_idx, pre_segment_end_idx)
segments.append([min_idx, max_idx + 1])
else:
_logger.info("Could not recompute op range [{}] - [{}] ".format(
min_idx, max_idx + 1))
start_idx += 1
......@@ -806,15 +812,15 @@ def _append_backward_ops_with_checkpoints_(
recompute_segments = segments
for i, (idx1, idx2) in enumerate(recompute_segments):
_logger.debug("recompute segment[{}]".format(i))
_logger.debug("segment start op: [{}]: [{}]".format(ops[idx1].desc.type(
_logger.info("recompute segment[{}]".format(i))
_logger.info("segment start op: [{}]: [{}]".format(ops[idx1].desc.type(
), ops[idx1].desc.input_arg_names()))
_logger.debug("segment end op: [{}]: [{}]".format(ops[
_logger.info("segment end op: [{}]: [{}]".format(ops[
idx2 - 1].desc.type(), ops[idx2 - 1].desc.input_arg_names()))
_logger.debug("recompute segment[{}]".format(i))
_logger.debug("segment start op: [{}]: [{}]".format(ops[idx1].desc.type(
_logger.info("recompute segment[{}]".format(i))
_logger.info("segment start op: [{}]: [{}]".format(ops[idx1].desc.type(
), ops[idx1].desc.input_arg_names()))
_logger.debug("segment end op: [{}]: [{}]".format(ops[
_logger.info("segment end op: [{}]: [{}]".format(ops[
idx2 - 1].desc.type(), ops[idx2 - 1].desc.input_arg_names()))
# 2) go through all forward ops and induct all variables that will be hold in memory
......@@ -825,9 +831,9 @@ def _append_backward_ops_with_checkpoints_(
program_stat.get_out_of_subgraph_vars(segment[0], segment[1]))
cross_vars = set(vars_should_be_hold) - set(checkpoints_name)
_logger.debug("found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( \
_logger.info("found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( \
len(cross_vars), cross_vars))
_logger.debug("found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( \
_logger.info("found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( \
len(cross_vars), cross_vars))
# b. output of seed op should be kept in memory
......@@ -843,6 +849,7 @@ def _append_backward_ops_with_checkpoints_(
vars_in_memory = vars_should_be_hold + checkpoints_name
max_calculated_op_position = len(ops)
device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
if recompute_segments == []:
gap_ops = ops[0:max_calculated_op_position]
for op in reversed(gap_ops):
......@@ -852,6 +859,11 @@ def _append_backward_ops_with_checkpoints_(
_pretty_op_desc_(op.desc, "with_sub_block"))
grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
op.desc, cpt.to_text(no_grad_dict[block.idx]), [])
# Set device for grad_op according to forward Op
if op.desc.has_attr(device_attr_name):
op_device = op.desc.attr(device_attr_name)
for op_desc in grad_op_desc:
op_desc._set_attr(device_attr_name, op_device)
added_descs = _add_descs_to_block(grad_op_desc, local_block)
grad_op_descs.extend(added_descs)
grad_to_var.update(op_grad_to_var)
......@@ -866,6 +878,11 @@ def _append_backward_ops_with_checkpoints_(
_pretty_op_desc_(op.desc, "with_sub_block"))
grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
op.desc, cpt.to_text(no_grad_dict[block.idx]), [])
# Set device for grad_op according to forward Op
if op.desc.has_attr(device_attr_name):
op_device = op.desc.attr(device_attr_name)
for op_desc in grad_op_desc:
op_desc._set_attr(device_attr_name, op_device)
added_descs = _add_descs_to_block(grad_op_desc, local_block)
grad_op_descs.extend(added_descs)
grad_to_var.update(op_grad_to_var)
......@@ -888,6 +905,18 @@ def _append_backward_ops_with_checkpoints_(
continue
if name not in var_name_dict:
var_name_dict[name] = name + var_suffix
# we should create the renamed var in the sub-program, otherwise its VarType will be BOOL
block.create_var(
name=var_name_dict[name],
shape=block.program.global_block().var(name).shape,
dtype=block.program.global_block().var(name).dtype,
type=block.program.global_block().var(name).type,
persistable=block.program.global_block().var(
name).persistable,
stop_gradient=block.program.global_block().var(name)
.stop_gradient)
# 3.a. add ops in current recompute_segment as forward recomputation ops
buffer_descs = _add_needed_descs_to_block(ff_ops, buffer_block, block,
vars_in_memory)
......
......@@ -489,9 +489,14 @@ class ClipGradByGlobalNorm(ClipGradBase):
continue
with p.block.program._optimized_guard([p, g]):
new_grad = layers.elementwise_mul(x=g, y=scale_var)
param_new_grad_name_dict[p.name] = new_grad.name
params_and_grads.append((p, new_grad))
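# scale the gradient in place (X and Out are both g) so the clipped
# gradient keeps its original variable name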
p.block.append_op(
type='elementwise_mul',
inputs={'X': g,
'Y': scale_var},
outputs={'Out': g})
param_new_grad_name_dict[p.name] = p.name
params_and_grads.append((p, p))
_correct_clip_op_role_var(params_and_grads, param_new_grad_name_dict)
return params_and_grads
......
......@@ -123,7 +123,8 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
outputs={"Out": out_var},
attrs={
"in_dtype": in_var.dtype,
"out_dtype": out_var.dtype
"out_dtype": out_var.dtype,
"op_device": op.attr("op_device")
})
num_cast_ops += 1
_rename_arg(op, in_var.name, out_var.name)
......@@ -171,8 +172,11 @@ def _insert_cast_post_op(block, op, idx, src_dtype, dest_dtype, target_name,
type="cast",
inputs={"X": target_var},
outputs={"Out": cast_var},
attrs={"in_dtype": target_var.dtype,
"out_dtype": cast_var.dtype})
attrs={
"in_dtype": target_var.dtype,
"out_dtype": cast_var.dtype,
"op_device": op.attr("op_device")
})
num_cast_ops += 1
op_var_rename_map[block.idx][target_var.name] = cast_var.name
......
......@@ -413,6 +413,9 @@ class Section(DeviceWorker):
section_param = trainer_desc.section_param
section_param.num_microbatches = pipeline_opt["num_microbatches"]
section_param.start_cpu_core_id = pipeline_opt["start_cpu_core_id"]
section_param.pipeline_stage = pipeline_opt["pipeline_stage"]
section_param.num_pipeline_stages = pipeline_opt["num_pipeline_stages"]
section_param.schedule_mode = pipeline_opt["schedule_mode"]
cfg = section_param.section_config
program = pipeline_opt["section_program"]
cfg.program_desc.ParseFromString(program["program"]._get_desc()
......
......@@ -19,6 +19,7 @@ import six
import os
import logging
from collections import defaultdict
import time
import paddle
from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table
......@@ -3759,15 +3760,21 @@ class PipelineOptimizer(object):
def __init__(self, optimizer, num_microbatches=1, start_cpu_core_id=0):
if framework.in_dygraph_mode():
raise Exception("In dygraph, don't support PipelineOptimizer.")
if not isinstance(optimizer, Optimizer) and not isinstance(
optimizer, paddle.optimizer.Optimizer) and not isinstance(
optimizer, paddle.fluid.contrib.mixed_precision.decorator.
OptimizerWithMixedPrecision):
supported_opt_types = (Optimizer, paddle.fluid.contrib.mixed_precision.
decorator.OptimizerWithMixedPrecision)
if not isinstance(optimizer, supported_opt_types):
raise ValueError("The 'optimizer' parameter for "
"PipelineOptimizer must be an instance of "
"Optimizer, but the given type is {}.".format(
type(optimizer)))
"PipelineOptimizer must be an instance of one of "
"{}, but the type is {}.".format(
supported_opt_types, type(optimizer)))
self._optimizer = optimizer
# Get the original optimizer defined by users, such as SGD
self._origin_optimizer = self._optimizer
while hasattr(self._origin_optimizer, "inner_opt"):
self._origin_optimizer = self._origin_optimizer.inner_opt
assert num_microbatches >= 1, (
"num_microbatches must be a positive value.")
self._num_microbatches = num_microbatches
......@@ -3781,52 +3788,147 @@ class PipelineOptimizer(object):
self._op_role_var_key = op_maker.kOpRoleVarAttrName()
self._op_device_key = op_maker.kOpDeviceAttrName()
self._param_device_map = None
self._pipeline_pair = []
self._pp_ring_map = dict()
def _create_vars(self, block, ori_block):
# Create vars for block, copied from main_program's global block
# Create vars for block, copied from ori_block
used_var_set = set()
for op_idx in range(block.desc.op_size()):
op_desc = block.desc.op(op_idx)
vars = op_desc.input_arg_names() + op_desc.output_arg_names()
# Whether to insert an allreduce_sum or allreduce_max op:
# for amp and global gradient clip strategies, we need
# the global information, so an allreduce op is needed.
should_insert = False
op = block.ops[op_idx]
# For ops that process vars on all devices, remove their input
# vars that are not in this block
reserved_x = []
if op.type == 'reduce_any' and self._is_optimize_op(op):
should_insert = True
if op.type == 'concat' and self._is_optimize_op(op):
for input_name in op.desc.input("X"):
if block._find_var_recursive(input_name):
reserved_x.append(input_name)
op.desc.set_input('X', reserved_x)
print('reserved_x:', reserved_x)
if op.type == 'update_loss_scaling':
for input_name in op.desc.input("X"):
if block._find_var_recursive(input_name):
reserved_x.append(input_name)
op.desc.set_input('X', reserved_x)
op.desc.set_output('Out', reserved_x)
if op.type == 'sum' and self._is_gradient_clip_op(op):
for input_name in op.desc.input("X"):
if block._find_var_recursive(input_name):
reserved_x.append(input_name)
op.desc.set_input('X', reserved_x)
should_insert = True
vars = op.desc.input_arg_names() + op.desc.output_arg_names()
for var in vars:
# a var whose name contains "blocking_queue"
# only exists in startup program
if var in used_var_set or "_blocking_queue" in var:
continue
if var in used_var_set or "_blocking_queue" in var: continue
used_var_set.add(var)
if block._find_var_recursive(str(var)): continue
source_var = ori_block._var_recursive(str(var))
if source_var.type == core.VarDesc.VarType.READER:
block.create_var(
dest_var = block.create_var(
name=var,
type=core.VarDesc.VarType.READER,
persistable=source_var.persistable)
else:
block._clone_variable(source_var, False)
dest_var = block._clone_variable(source_var, False)
dest_var.stop_gradient = source_var.stop_gradient
continue
# TODO: add allreduce_max when running without sharding
if not should_insert: continue
out_name = op.desc.output_arg_names()[0]
out_var = block.var(out_name)
offset = 0
if op.type == "reduce_any":
# cast the bool var to int32 to use allreduce op
temp_var_name = unique_name.generate(out_name + "_cast_int32")
temp_var = block.create_var(
name=temp_var_name, shape=[1], dtype="int32")
block._insert_op(
op_idx + 1 + offset,
type='cast',
inputs={'X': out_var},
outputs={'Out': temp_var},
attrs={
'in_dtype': out_var.dtype,
'out_dtype': temp_var.dtype,
self._op_role_key:
core.op_proto_and_checker_maker.OpRole.Optimize
})
offset += 1
# block._insert_op(
# op_idx + 1 + offset,
# type='c_sync_calc_stream',
# inputs={'X': temp_var if op.type == "reduce_any" else out_var},
# outputs={
# 'Out': temp_var if op.type == "reduce_any" else out_var
# },
# attrs={
# OP_ROLE_KEY:
# core.op_proto_and_checker_maker.OpRole.Optimize,
# })
# offset += 1
block._insert_op(
op_idx + 1 + offset,
type='c_allreduce_max'
if op.type == "reduce_any" else 'c_allreduce_sum',
inputs={'X': temp_var if op.type == "reduce_any" else out_var},
outputs={
'Out': temp_var if op.type == "reduce_any" else out_var
},
attrs={
'ring_id': self.ring_id,
self._op_role_key:
core.op_proto_and_checker_maker.OpRole.Optimize,
'use_calc_stream': True
})
offset += 1
# block._insert_op(
# # op_idx + 1 + extra_index,
# op_idx + 1 + offset,
# type='c_sync_comm_stream',
# inputs={'X': temp_var if op.type == "reduce_any" else out_var},
# outputs={
# 'Out': temp_var if op.type == "reduce_any" else out_var
# },
# attrs={
# 'ring_id': self.ring_id,
# OP_ROLE_KEY:
# core.op_proto_and_checker_maker.OpRole.Optimize,
# })
# offset += 1
if op.type == "reduce_any":
block._insert_op(
op_idx + 1 + offset,
type='cast',
inputs={'X': temp_var},
outputs={'Out': out_var},
attrs={
'in_dtype': temp_var.dtype,
'out_dtype': out_var.dtype,
self._op_role_key:
core.op_proto_and_checker_maker.OpRole.Optimize
})
def _is_loss_grad_op(self, op):
if self._op_role_key not in op.attr_names:
return False
op_role = int(op.all_attrs()[self._op_role_key])
assert self._op_role_key in op.attr_names
op_role = int(op.attr(self._op_role_key))
return op_role & int(self._op_role.Backward) and op_role & int(
self._op_role.Loss)
def _is_backward_op(self, op):
return self._op_role_key in op.attr_names and int(op.all_attrs()[
self._op_role_key]) & int(self._op_role.Backward)
def _is_optimize_op(self, op):
return self._op_role_key in op.attr_names and int(op.all_attrs()[
self._op_role_key]) & int(self._op_role.Optimize)
def _is_update_op(self, op):
return 'Param' in op.input_names and 'Grad' in op.input_names and (
"LearningRate" in op.input_names)
def _split_program(self, main_program, devices):
"""
Split a program into sections according to devices that ops run on.
The ops of the role LRSched are copied to all sections.
Ops whose op_device attr is "gpu:all" are copied to all sections.
Args:
main_program (Program): the main program
......@@ -3842,27 +3944,20 @@ class PipelineOptimizer(object):
block = main_program.block(0)
for op in block.ops:
device = op.attr(self._op_device_key)
op_role = op.attr(self._op_role_key)
if int(op_role) & int(self._op_role.LRSched):
# Copy ops of the role LRSched to all sections.
for device in device_program_map.keys():
program = device_program_map[device]
op_desc = op.desc
ap_op = program["program"].block(0).desc.append_op()
ap_op.copy_from(op_desc)
# ap_op._set_attr(self._op_device_key, "")
elif op.type == "create_py_reader" or op.type == "read" or op.type == "create_double_buffer_reader":
# Copy read related ops to all section to make them exit after each epoch.
# Copy ops whose op_device is set to "gpu:all" to all sections.
if device == "gpu:all":
for device in device_program_map.keys():
program = device_program_map[device]
op_desc = op.desc
ap_op = program["program"].block(0).desc.append_op()
ap_op.copy_from(op_desc)
ap_op._set_attr(self._op_device_key, "")
else:
program = device_program_map[device]
op_desc = op.desc
ap_op = program["program"].block(0).desc.append_op()
ap_op.copy_from(op_desc)
ap_op._set_attr(self._op_device_key, "")
for key in devices:
program = device_program_map[key]
......@@ -3921,6 +4016,11 @@ class PipelineOptimizer(object):
var_name as output.
var_name (string): Variable name.
"""
# To skip the cast op added by amp which has no op_device set
if '.cast_fp32' in var_name:
var_name = var_name.replace('.cast_fp32', '')
if '.cast_fp16' in var_name:
var_name = var_name.replace('.cast_fp16', '')
post_op = []
before = True
for op in ops:
......@@ -3949,7 +4049,7 @@ class PipelineOptimizer(object):
"""
prev_op = []
for op in ops:
if op.type == 'send_v2' or op.type == 'recv_v2':
if op.type == 'send_v2' or op.type == 'recv_v2' or op.type == 'c_broadcast':
continue
if op == cur_op:
break
......@@ -3964,11 +4064,8 @@ class PipelineOptimizer(object):
return None
def _rename_arg(self, op, old_name, new_name):
op_desc = op.desc
if isinstance(op_desc, tuple):
op_desc = op_desc[0]
op_desc._rename_input(old_name, new_name)
op_desc._rename_output(old_name, new_name)
op._rename_input(old_name, new_name)
op._rename_output(old_name, new_name)
def _create_var(self, block, ref_var, name):
"""
......@@ -3982,9 +4079,10 @@ class PipelineOptimizer(object):
dtype=ref_var.dtype,
type=ref_var.type,
lod_level=ref_var.lod_level,
persistable=False,
is_data=False,
persistable=ref_var.persistable,
is_data=ref_var.is_data,
need_check_feed=ref_var.desc.need_check_feed())
new_var.stop_gradient = ref_var.stop_gradient
return new_var
def _get_data_var_info(self, block):
......@@ -4037,6 +4135,7 @@ class PipelineOptimizer(object):
if not var_name in first_block.vars:
self._create_var(first_block, main_var, var_name)
dev_index = int(device.split(':')[1])
print("dev_index:", dev_index)
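# ring selection encodes direction: self.ring_id is used when sending to a
# later stage (dev_index > first_dev_index), self.ring_id + 2 when sending
# to an earlier stage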
first_block._insert_op(
index=insert_index,
type='send_v2',
......@@ -4044,8 +4143,11 @@ class PipelineOptimizer(object):
attrs={
self._op_device_key: first_dev_spec,
self._op_role_key: self._op_role.Forward,
'use_calc_stream': True,
'use_calc_stream': False,
'peer': dev_index,
#'ring_id': self.ring_id,
'ring_id': self.ring_id
if dev_index > first_dev_index else self.ring_id + 2,
})
# Get the device that that data on
assert device in devices
......@@ -4070,6 +4172,21 @@ class PipelineOptimizer(object):
self._op_role_key: self._op_role.Forward,
'peer': first_dev_index,
'use_calc_stream': True,
#'ring_id': self.ring_id,
'ring_id': self.ring_id
if first_dev_index < dev_index else self.ring_id + 2,
})
block._insert_op(
index=index + 1,
type='c_sync_comm_stream',
inputs={'X': [new_var]},
outputs={'Out': [new_var]},
attrs={
self._op_device_key: device,
self._op_role_key: self._op_role.Forward,
#'ring_id': self.ring_id,
'ring_id': self.ring_id
if first_dev_index > dev_index else self.ring_id + 2,
})
def _strip_grad_suffix(self, name):
......@@ -4085,79 +4202,190 @@ class PipelineOptimizer(object):
"""
return name + core.grad_var_suffix()
def _add_opdevice_attr_for_regularization_clip(self, block):
def _is_forward_op(self, op):
"""
Add op_device attribute for regulization and clip ops.
Check whether the op_role attribute of the op is Forward.
"""
for op in block.ops:
# role for regularization and clip ops is optimize
if int(op.attr(self._op_role_key)) != int(self._op_role.Optimize):
continue
if op.has_attr(self._op_device_key) and (
op.attr(self._op_device_key) != ""):
continue
assert self._op_role_var_key in op.attr_names
op_role_var = op.all_attrs()[self._op_role_var_key]
assert len(op_role_var) == 2
param_name = op_role_var[0]
device = self._param_device_map[param_name]
op._set_attr(self._op_device_key, device)
assert self._op_role_key in op.attr_names
return int(op.attr(self._op_role_key)) == int(self._op_role.Forward)
def _is_backward_op(self, op):
"""
Check whether the op_role attribute of the op is Backward.
"""
assert self._op_role_key in op.attr_names
return int(op.attr(self._op_role_key)) == int(self._op_role.Backward)
def _add_default_opdevice_attr(self, block):
def _is_loss_op(self, op):
"""
1. Add default op_device attribute for lr-related ops.
The default value is the one that of the first place.
2. Add default op_device attribute for sum ops added during
backward. For these ops, we set the op_device attribute
as the one of its post op, i.e, which op has the output of the
sum op as an input.
Check whether the op_role attribute of the op is Loss.
"""
first_devcie = ""
assert self._op_role_key in op.attr_names
return int(op.attr(self._op_role_key)) == int(self._op_role.Loss)
# Get the device spec of the first place.
# device_spec: 'cpu' for cpu device and 'gpu:id' for gpu device,
# e.g. 'gpu:0', 'gpu:1', etc.
for op in block.ops:
if op.has_attr(self._op_device_key) and (
op.attr(self._op_device_key) != ""):
first_device = op.attr(self._op_device_key)
break
assert first_device
first_device_type = first_device.split(":")[0]
assert first_device_type == "gpu"
def _is_optimize_op(self, op):
"""
Check whether the op_role attribute of the op is Optimize.
"""
assert self._op_role_key in op.attr_names
return int(op.attr(self._op_role_key)) == int(self._op_role.Optimize)
# set op_device attr for lr-related ops
lrsched_role = int(self._op_role.LRSched)
for op in block.ops:
if not op.has_attr(self._op_device_key) or (
op.attr(self._op_device_key) == ""):
if op.type == "sum":
# For sum ops that compute the sum of @RENAMED@ vars
for name in op.desc.input_arg_names():
assert '@RENAME@' in name
assert len(op.desc.output_arg_names()) == 1
out_name = op.desc.output_arg_names()[0]
post_op = self._find_post_op(block.ops, op, out_name)
device = post_op.attr(self._op_device_key)
assert device
op._set_attr(self._op_device_key, device)
continue
def _is_lrsched_op(self, op):
"""
Check whether the op_role attribute of the op is LRSched.
"""
assert self._op_role_key in op.attr_names
return int(op.attr(self._op_role_key)) == int(self._op_role.LRSched)
def _is_update_op(self, op):
"""
Check whether the op updates a parameter using its gradient.
"""
return 'Param' in op.input_names and 'Grad' in op.input_names and (
"LearningRate" in op.input_names)
def _get_op_device_attr(self, op):
"""
Get the op_device attribute of a op.
"""
device = op.attr(self._op_device_key) \
if op.has_attr(self._op_device_key) else None
if device:
assert device[0:3] == 'gpu', "Now, only gpu devices are " \
"supported in pipeline parallelism."
return device
def _add_op_device_attr_for_op(self, op, idx, block):
"""
Add the op_device attribute for ops that do not have that attribute set.
assert op.attr(self._op_role_key) == lrsched_role, (
"Op whose op_device attr has not been set for pipeline"
" must be of the role LRSched.")
op._set_attr(self._op_device_key, first_device)
We use "gpu:all" to indicate that the op should be put on all
sub-programs, such as lr-related ops. Note that "gpu:all"
is only used by pipeline as an indicator.
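For example, lr-schedule ops and the global reductions added for
gradient clipping are marked "gpu:all" and later copied into every
stage's sub-program by _split_program.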
"""
lrsched_role = int(self._op_role.LRSched)
if op.attr(self._op_role_key) == lrsched_role:
# For LRSched ops, we should put them on all sub-programs to
# make sure each sub-program updates the lr correctly
op._set_attr(self._op_device_key, "gpu:all")
elif op.type == "sum" and self._is_backward_op(op):
# For sum ops that compute the sum of @RENAMED@ vars
for name in op.desc.input_arg_names():
assert '@RENAME@' in name, \
"The op must be a sum op used to accumulate renamed vars."
assert len(op.desc.output_arg_names()) == 1
out_name = op.desc.output_arg_names()[0]
post_op = self._find_post_op(block.ops, op, out_name)
assert post_op.has_attr(
'op_device'), "{} has no op_device attr for var {}".format(
post_op.type, out_name)
device = post_op.attr(self._op_device_key)
assert device, "The post op must have op_device set."
op._set_attr(self._op_device_key, device)
elif (op.type == "cast" or
op.type == "scale") and self._is_backward_op(op):
prev_op = self._find_real_prev_op(block.ops, op,
op.desc.input("X")[0])
op._set_attr('op_device', prev_op.attr('op_device'))
elif op.type == "memcpy" and not self._is_optimize_op(op):
assert len(op.input_arg_names) == 1 and len(
op.output_arg_names) == 1
input_name = op.input_arg_names[0]
output_name = op.output_arg_names[0]
if '@Fetch' in output_name:
post_op = self._find_post_op(block.ops, op, output_name)
op._set_attr('op_device', post_op.attr('op_device'))
else:
prev_op = self._find_real_prev_op(block.ops, op,
op.desc.input("X")[0])
op._set_attr('op_device', prev_op.attr('op_device'))
elif self._is_loss_op(op):
# For loss * loss_scaling op added by AMP
offset = 1
while (not block.ops[idx + offset].has_attr(self._op_device_key) or
not block.ops[idx + offset].attr(self._op_device_key)):
offset += 1
# assert block.ops[idx + 1].type == "fill_constant"
# assert block.ops[idx + 2].type == "elementwise_mul_grad"
# assert block.ops[idx + 3].type == "elementwise_add_grad"
# assert block.ops[idx + 4].type == "mean_grad"
# device = block.ops[idx + 4].attr(self._op_device_key)
device = block.ops[idx + offset].attr(self._op_device_key)
assert device, "Please put your program within a device_guard scope."
# op._set_attr(self._op_device_key, device)
# block.ops[idx + 1]._set_attr(self._op_device_key, device)
# block.ops[idx + 2]._set_attr(self._op_device_key, device)
# block.ops[idx + 2]._set_attr(self._op_device_key, device)
for i in range(offset):
block.ops[idx + i]._set_attr(self._op_device_key, device)
elif self._is_optimize_op(op) and op.type == "check_finite_and_unscale":
#op._set_attr(self._op_device_key, "gpu:all")
op_role_var = op.attr(self._op_role_var_key)
param_name = op_role_var[0]
device = self._param_device_map[param_name]
op._set_attr(self._op_device_key, device)
elif self._is_optimize_op(op) and op.type == "cast":
# For fp16-->fp32 cast added by AMP
grad_name = op.output('Out')
assert len(grad_name) == 1
param_name = grad_name[0].strip(core.grad_var_suffix())
device = self._param_device_map[param_name]
op._set_attr(self._op_device_key, device)
elif self._is_gradient_clip_op(op) or self._is_regularization_op(op):
# For gradient clip and regularization ops, we set their op_device
# attribute to the device where their corresponding parameters on.
assert self._op_role_var_key in op.attr_names, "gradient_clip " \
"and regularization ops must have op_role_var attribute."
op_role_var = op.attr(self._op_role_var_key)
assert len(op_role_var) == 2, "op_role_var for gradient_clip " \
"and regularization ops must have two elements."
param_name = op_role_var[0]
device = self._param_device_map[param_name]
# Ops added by the global gradient clip (sum, sqrt, fill_constant,
# elementwise_max, elementwise_div) must be put on all devices
if (op.type == 'sum' or op.type == 'sqrt' or
op.type == 'fill_constant' or
op.type == 'elementwise_max' or
op.type == 'elementwise_div'):
device = "gpu:all"
op._set_attr(self._op_device_key, device)
else:
other_known_ops = [
'update_loss_scaling', 'reduce_any', 'concat', 'sum'
]
assert op.type in other_known_ops, "For other ops without " \
"op_device set, they must be one of {}, but it " \
"is {}".format(other_known_ops, op.type)
assert self._is_optimize_op(op)
op._set_attr(self._op_device_key, "gpu:all")
def _add_op_device_attr(self, block):
"""
Add the op_device attribute for ops in the block that do not
have that attribute set.
"""
for idx, op in enumerate(list(block.ops)):
if (op.type == "create_py_reader" or op.type == "read" or
op.type == "create_double_buffer_reader"):
# Copy read-related ops to all sections to make them exit
# after each epoch.
# We use "gpu:all" to represent the op should be put on all
# sub-programs, such as lr-related ops. Note that: "gpu:all"
# is only used by pipeline as an indicator.
op._set_attr(self._op_device_key, "gpu:all")
continue
# op_device attribute has been set
if self._get_op_device_attr(op): continue
self._add_op_device_attr_for_op(op, idx, block)
def _check_validation(self, block):
"""
Check whether ops in a block are all validate (i.e., the
op_device attribute has been set).
Then, return all device specifications in order.
Check whether ops in a block have the op_device attribute set.
Then, return all devices in order.
"""
device_specs = []
device_list = []
for op in block.ops:
type = op.type
if not op._has_kernel(type):
if not op._has_kernel(op.type):
assert op.type == "conditional_block" and (
op.attr(self._op_role_key) == int(self._op_role.LRSched)), (
"Now, the only supported op without kernel is "
......@@ -4165,15 +4393,16 @@ class PipelineOptimizer(object):
assert op.has_attr(self._op_device_key), (
"op ({}) has no {} attribute.".format(op.type,
self._op_device_key))
dev_spec = op.attr(self._op_device_key)
assert dev_spec, ("op_device attribute for op "
"{} has not been set.".format(op.type))
dev_type = dev_spec.split(':')[0]
device = op.attr(self._op_device_key)
assert device, ("op_device attribute for op "
"{} has not been set.".format(op.type))
if device == "gpu:all": continue
dev_type = device.split(':')[0]
assert dev_type == "gpu", ("Now only gpu devices are supported "
"for pipeline parallelism.")
if not dev_spec in device_specs:
device_specs.append(dev_spec)
return device_specs
if not device in device_list:
device_list.append(device)
return device_list
def _insert_sendrecv_ops_for_boundaries(self, block):
"""
......@@ -4182,75 +4411,387 @@ class PipelineOptimizer(object):
"""
extra_index = 0
# A map from var to device spec where op takes it as input,
# A map from var to device where op takes it as input,
# avoiding multiple send and recv ops.
var_devspec = dict()
var_dev_map = dict()
for index, op in enumerate(list(block.ops)):
# skips lr-related ops and vars, as we will process them later.
if int(op.attr(self._op_role_key)) & int(self._op_role.LRSched):
continue
# skips update ops and vars, as we will process them later.
if self._is_update_op(op): continue
cur_device_spec = op.attr(self._op_device_key)
cur_device = op.attr(self._op_device_key)
if cur_device == "gpu:all": continue
for var_name in op.input_arg_names:
# i.e., lod_tensor_blocking_queue created by DataLoader,
# which only exists in startup program.
if not var_name in block.vars: continue
# if not var_name in block.vars: continue
var = block.var(var_name)
# skip data, because we will process it later
if var.is_data: continue
prev_device = None
if var_name in self._param_device_map:
prev_device = self._param_device_map[var_name]
prev_op = self._find_real_prev_op(block.ops, op, var_name)
if prev_op is None:
continue
prev_device_spec = prev_op.attr(self._op_device_key)
if not prev_device:
prev_device = prev_op.attr(self._op_device_key) \
if prev_op else None
if not prev_device or prev_device == 'gpu:all': continue
if prev_device_spec != cur_device_spec:
if var_name not in var_devspec:
var_devspec[var_name] = []
if cur_device_spec in var_devspec[var_name]: continue
var_devspec[var_name].append(cur_device_spec)
if prev_device != cur_device:
if var_name not in var_dev_map: var_dev_map[var_name] = []
if cur_device in var_dev_map[var_name]: continue
var_dev_map[var_name].append(cur_device)
op_role = op.all_attrs()[self._op_role_key]
var = block.vars[var_name]
prev_device_index = int(prev_device_spec.split(':')[1])
cur_device_index = int(cur_device_spec.split(':')[1])
prev_device_index = int(prev_device.split(':')[1])
cur_device_index = int(cur_device.split(':')[1])
pair = (prev_device_index, cur_device_index)
pair_key = prev_device_index * 1000 + cur_device_index
if cur_device_index > prev_device_index:
ring_id = self.ring_id + cur_device_index - prev_device_index - 1
else:
ring_id = self.ring_id + 2 + prev_device_index - cur_device_index - 1
if self.schedule_mode == 0: # GPipe
block._insert_op(
index=index + extra_index,
type='send_v2',
inputs={'X': var},
attrs={
self._op_device_key: prev_device,
self._op_role_key: op_role,
'use_calc_stream': True,
'peer': cur_device_index,
'ring_id': self.ring_id
if cur_device_index > prev_device_index else
self.ring_id + 2,
})
extra_index += 1
block._insert_op(
index=index + extra_index,
type='recv_v2',
outputs={'Out': [var]},
attrs={
'out_shape': var.shape,
'dtype': var.dtype,
self._op_device_key: cur_device,
self._op_role_key: op_role,
'use_calc_stream': True,
'peer': prev_device_index,
'ring_id': self.ring_id
if cur_device_index > prev_device_index else
self.ring_id + 2,
})
extra_index += 1
continue
assert self.schedule_mode == 1
if pair not in self._pipeline_pair:
self._pipeline_pair.append(pair)
self._pp_ring_map[pair_key] = self.ring_id
ring_id = self.ring_id
self.ring_id += 1
else:
ring_id = self._pp_ring_map[pair_key]
block._insert_op(
index=index + extra_index,
type='send_v2',
#type='send_v2',
type='c_broadcast',
inputs={'X': var},
outputs={'Out': var},
attrs={
self._op_device_key: prev_device_spec,
self._op_device_key: prev_device,
self._op_role_key: op_role,
'use_calc_stream': True,
'peer': cur_device_index,
'use_calc_stream': False,
#'peer': cur_device_index,
#'ring_id': self.ring_id if cur_device_index > prev_device_index else self.ring_id + 2,
'ring_id': ring_id,
#'ring_id': self.ring_id,
#'root': prev_device_index,
'root': 0,
})
extra_index += 1
block._insert_op(
index=index + extra_index,
type='c_sync_comm_stream',
inputs={'X': [var]},
outputs={'Out': [var]},
attrs={
self._op_device_key: cur_device,
self._op_role_key:
core.op_proto_and_checker_maker.OpRole.Backward,
#self._op_role_key: op_role,
'ring_id': ring_id,
#'ring_id': self.ring_id if prev_device_index > cur_device_index else self.ring_id + 2,
})
extra_index += 1
#block._insert_op(
# index=index + extra_index,
# type='c_sync_comm_stream',
# inputs={'X': [var]},
# outputs={'Out': [var]},
# attrs={
# self._op_device_key: cur_device,
# self._op_role_key:
# core.op_proto_and_checker_maker.OpRole.Backward,
# 'ring_id': self.ring_id,
# #'ring_id': self.ring_id if prev_device_index > cur_device_index else self.ring_id + 2,
# })
#extra_index += 1
fill_shape = list(var.shape)
fill_shape[0] = self.pp_bz
block._insert_op(
index=index + extra_index,
#type='recv_v2',
type='fill_constant',
inputs={},
outputs={'Out': [var]},
attrs={
'shape': fill_shape,
'dtype': var.dtype,
self._op_device_key: cur_device,
self._op_role_key: op_role,
'value': float(0.0),
})
extra_index += 1
block._insert_op(
index=index + extra_index,
type='c_sync_comm_stream',
inputs={'X': [var]},
outputs={'Out': [var]},
attrs={
self._op_device_key: cur_device,
#self._op_role_key: core.op_proto_and_checker_maker.OpRole.Backward,
self._op_role_key: op_role,
'ring_id': ring_id,
#'ring_id': self.ring_id if prev_device_index > cur_device_index else self.ring_id + 2,
})
extra_index += 1
block._insert_op(
index=index + extra_index,
#type='recv_v2',
type='c_broadcast',
inputs={'X': var},
outputs={'Out': var},
attrs={
#'out_shape': var.shape,
#'dtype': var.dtype,
self._op_device_key: cur_device,
self._op_role_key: op_role,
'use_calc_stream': False,
#'peer': prev_device_index,
#'root': prev_device_index,
'root': 0,
#'ring_id': self.ring_id,
'ring_id': ring_id,
#'ring_id': self.ring_id if cur_device_index > prev_device_index else self.ring_id + 2,
#'ring_id': self.ring_id if prev_device_index < cur_device_index else self.ring_id + 2,
})
extra_index += 1
block._insert_op(
index=index + extra_index,
type='recv_v2',
type='c_sync_comm_stream',
inputs={'X': [var]},
outputs={'Out': [var]},
attrs={
self._op_device_key: cur_device,
#self._op_role_key: core.op_proto_and_checker_maker.OpRole.Backward,
self._op_role_key: op_role,
'ring_id': ring_id,
#'ring_id': self.ring_id if prev_device_index > cur_device_index else self.ring_id + 2,
})
extra_index += 1
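# Illustrative sketch (not part of this commit): for schedule_mode == 1 every
# (prev_stage, cur_stage) boundary gets its own communication ring. The pair is
# encoded as prev * 1000 + cur and mapped to a fresh ring_id the first time it
# is seen, mirroring the _pipeline_pair / _pp_ring_map bookkeeping above.
def _ring_id_for_pair_example(pipeline_pairs, pp_ring_map, next_ring_id, prev_id, cur_id):
    pair = (prev_id, cur_id)
    pair_key = prev_id * 1000 + cur_id
    if pair not in pipeline_pairs:
        pipeline_pairs.append(pair)
        pp_ring_map[pair_key] = next_ring_id
        next_ring_id += 1
    return pp_ring_map[pair_key], next_ring_id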
def _xx_insert_sendrecv_ops_for_boundaries(self, block):
"""
Insert a pair of send and recv ops for every two
consecutive ops on different devices.
"""
extra_index = 0
# A map from var to device where op takes it as input,
# avoiding multiple send and recv ops.
input_var_to_device = dict()
# A map from output var to the ops which generate it.
output_var_to_op = dict()
for index, op in enumerate(list(block.ops)):
for var_name in op.output_arg_names:
ops = output_var_to_op.setdefault(var_name, [])
ops.append([op, index])
for index, op in enumerate(list(block.ops)):
cur_device = op.attr(self._op_device_key)
if cur_device == "gpu:all": continue
for var_name in op.input_arg_names:
var = block.var(var_name)
if var.is_data: continue
#if var_name not in input_var_to_device:
# input_var_to_device[var_name] = []
#if cur_device in input_var_to_device[var_name]:
# continue
#input_var_to_device[var_name].append(cur_device)
prev_device = None
generate_ops = output_var_to_op.get(var_name)
if generate_ops is None:
if var_name not in self._param_device_map:
continue
prev_device = self._param_device_map[var_name]
prev_op = None
for gen_op, gen_idx in reversed(generate_ops):
if gen_idx < index:
prev_op = gen_op
break
if not prev_device:
prev_device = prev_op.attr(self._op_device_key) \
if prev_op else None
if prev_device is None or prev_device == 'gpu:all': continue
if prev_device == cur_device: continue
if var_name not in input_var_to_device:
input_var_to_device[var_name] = []
if (cur_device, prev_device) in input_var_to_device[var_name]:
continue
device_type = cur_device.split(':')[0] + ':'
def _insert_send_recv(cur_id, prev_id):
nonlocal extra_index
cur_dev = device_type + str(cur_id)
prev_dev = device_type + str(prev_id)
if (cur_dev, prev_dev) in input_var_to_device[var_name]:
return
if cur_id - prev_id > 1:
_insert_send_recv(cur_id - 1, prev_id)
_insert_send_recv(cur_id, cur_id - 1)
input_var_to_device[var_name].append(
(cur_dev, prev_dev))
return
elif cur_id - prev_id < -1:
_insert_send_recv(cur_id + 1, prev_id)
_insert_send_recv(cur_id, cur_id + 1)
input_var_to_device[var_name].append(
(cur_dev, prev_dev))
return
assert abs(cur_id - prev_id) == 1
input_var_to_device[var_name].append((cur_dev, prev_dev))
op_role = op.all_attrs()[self._op_role_key]
var = block.vars[var_name]
pair = (prev_id, cur_id)
pair_key = prev_id * 1000 + cur_id
if cur_id > prev_id:
ring_id = self.ring_id + cur_id - prev_id - 1
else:
ring_id = self.ring_id + 2 + prev_id - cur_id - 1
print("call xx_insert, schedule_mode:", self.schedule_mode)
assert self.schedule_mode == 1
if pair not in self._pipeline_pair:
self._pipeline_pair.append(pair)
self._pp_ring_map[pair_key] = self.ring_id
ring_id = self.ring_id
self.ring_id += 1
else:
ring_id = self._pp_ring_map[pair_key]
print("opt: pp_pair: {}, ring_id: {}".format(pair, ring_id))
block._insert_op_without_sync(
index=index + extra_index,
type='c_sync_calc_stream',
inputs={'X': [var]},
outputs={'Out': [var]},
attrs={
self._op_device_key: prev_dev,
self._op_role_key: op_role,
})
extra_index += 1
block._insert_op_without_sync(
index=index + extra_index,
type="c_broadcast",
inputs={'X': var},
outputs={'Out': var},
attrs={
self._op_device_key: prev_dev,
self._op_role_key: op_role,
'use_calc_stream': False,
'ring_id': ring_id,
'root': 0,
})
extra_index += 1
fill_shape = list(var.shape)
fill_shape[0] = self.pp_bz
block._insert_op_without_sync(
index=index + extra_index,
type='fill_constant',
inputs={},
outputs={'Out': [var]},
attrs={
'shape': fill_shape,
'dtype': var.dtype,
self._op_device_key: cur_dev,
self._op_role_key: op_role,
'value': float(0.0),
})
extra_index += 1
block._insert_op_without_sync(
index=index + extra_index,
type='c_broadcast',
inputs={'X': var},
outputs={'Out': var},
attrs={
self._op_device_key: cur_dev,
self._op_role_key: op_role,
'use_calc_stream': True,
'root': 0,
'ring_id': ring_id,
})
extra_index += 1
def _clear_gradients(self, main_block, dev_spec):
# block._insert_op_without_sync(
# index=index + extra_index,
# type='c_sync_comm_stream',
# inputs={'X': [var]},
# outputs={'Out': [var]},
# attrs={
# self._op_device_key: cur_dev,
# self._op_role_key: op_role,
# 'ring_id': ring_id,
# })
# extra_index += 1
_insert_send_recv(
int(cur_device.split(':')[1]),
int(prev_device.split(':')[1]))
block._sync_with_cpp()
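# Illustrative sketch (not part of this commit): the recursive _insert_send_recv
# above decomposes a transfer between non-adjacent stages into adjacent hops, so
# a var produced on stage 0 and consumed on stage 3 is forwarded 0->1, 1->2, 2->3.
def _expand_hops_example(prev_id, cur_id):
    hops = []
    def _expand(cur, prev):
        if cur - prev > 1:
            _expand(cur - 1, prev)
            _expand(cur, cur - 1)
        elif cur - prev < -1:
            _expand(cur + 1, prev)
            _expand(cur, cur + 1)
        else:
            hops.append((prev, cur))
    _expand(cur_id, prev_id)
    return hops
# _expand_hops_example(0, 3) == [(0, 1), (1, 2), (2, 3)]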
def _clear_gradients(self, main_block, param_names):
"""
Clear gradients at the beginning of each run of a minibatch.
"""
for param_name in self._param_device_map:
device = self._param_device_map[param_name]
if device != dev_spec: continue
# for param_name in self._param_device_map:
print("param_names:", param_names)
for param_name in param_names:
# device = self._param_device_map[param_name]
# if device != dev_spec: continue
grad_name = self._append_grad_suffix(param_name)
if not main_block.has_var(grad_name): continue
grad_var = main_block.vars[grad_name]
# if not main_block.has_var(grad_name): continue
assert main_block.has_var(grad_name)
grad_var = main_block.var(grad_name)
grad_var.persistable = True
main_block._insert_op(
index=0,
type='fill_constant',
......@@ -4260,21 +4801,20 @@ class PipelineOptimizer(object):
'shape': grad_var.shape,
'dtype': grad_var.dtype,
'value': float(0),
self._op_device_key: device,
# self._op_device_key: device,
# a trick to run this op once per mini-batch
self._op_role_key: self._op_role.Optimize.LRSched,
})
def _accumulate_gradients(self, block):
def _insert_loss_scale(self, block):
"""
Accumulate the gradients generated in microbatch to the one in mini-batch.
We also scale the loss corresponding to the number of micro-batches.
"""
if self._num_microbatches == 1: return
for index, op in reversed(tuple(enumerate(list(block.ops)))):
offset = index
device = op.attr(self._op_device_key)
#device = op.attr(self._op_device_key)
# Backward pass
if self._is_loss_grad_op(op):
loss_grad_var = block.vars[op.output_arg_names[0]]
scale_factor = self._num_microbatches
......@@ -4285,36 +4825,130 @@ class PipelineOptimizer(object):
outputs={'Out': loss_grad_var},
attrs={
'scale': 1.0 / scale_factor,
self._op_device_key: device,
#self._op_device_key: device,
self._op_role_key: self._op_role.Backward
})
break
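# Illustrative sketch (not part of this commit): with gradient accumulation over
# num_microbatches micro-batches, the loss gradient of each micro-batch is scaled
# by 1.0 / num_microbatches so that the accumulated gradient equals the
# mini-batch average, which is what the inserted scale op does.
def _scaled_loss_grad_example(loss_grad_value, num_microbatches):
    return loss_grad_value * (1.0 / num_microbatches)
# e.g. _scaled_loss_grad_example(1.0, 4) == 0.25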
def _rename_gradient_var_name(self, block):
for index, op in enumerate(block.ops):
if not self._is_optimize_op(op): continue
input_names = op.input_arg_names
output_names = op.output_arg_names
in_out_names = input_names + output_names
if op.type == 'cast': continue
# append "MERGED" to the names of parameter gradients,
# and mofify the op_role_var attribute (by rename_arg func).
for name in in_out_names:
if not core.grad_var_suffix() in name: continue
param_name = name.strip(core.grad_var_suffix())
new_grad_name = name + "@MERGED"
self._rename_arg(op, name, new_grad_name)
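# Illustrative sketch (not part of this commit): optimizer ops are rewired to
# read the accumulated "<param>@GRAD@MERGED" buffers instead of per-micro-batch
# gradients; only names containing the gradient suffix are renamed. "@GRAD" is
# assumed here to stand for the suffix returned by core.grad_var_suffix().
def _merged_name_example(name, grad_suffix="@GRAD"):
    # e.g. "fc_0.w_0@GRAD" -> "fc_0.w_0@GRAD@MERGED"
    return name + "@MERGED" if grad_suffix in name else name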
def _accumulate_gradients(self, block, pp_allreduce_in_optimize=False):
"""
Create a new merged gradient for each parameter and accumulate the
corresponding gradient to it.
"""
merged_gradient_names = []
first_opt_op_idx = None
for index, op in reversed(tuple(enumerate(list(block.ops)))):
# remove the cast op of fp16 grad to fp32 grad
if self._is_optimize_op(op) and op.type == 'cast':
in_name = op.input_arg_names[0]
out_name = op.output_arg_names[0]
if out_name.strip('@GRAD') in self._param_device_map:
assert in_name.replace('.cast_fp16', '') == out_name
block._remove_op(index)
continue
if self._is_backward_op(op) and not first_opt_op_idx:
first_opt_op_idx = index + 1
if block.ops[first_opt_op_idx].type == "c_sync_comm_stream":
#block.ops[first_opt_op_idx]._set_attr(self._op_role_key, self._op_role.Backward)
first_opt_op_idx += 1
if self._is_backward_op(op) and (
self._op_role_var_key in op.attr_names):
op_role_var = op.all_attrs()[self._op_role_var_key]
op_role_var = op.attr(self._op_role_var_key)
if len(op_role_var) == 0:
continue
assert len(op_role_var) % 2 == 0
offset = index
#op._remove_attr(self._op_role_var_key)
for i in range(0, len(op_role_var), 2):
grad_name = op_role_var[i + 1]
grad_var = block.vars[grad_name]
new_grad_var_name = unique_name.generate(grad_name)
new_var = self._create_var(block, grad_var,
new_grad_var_name)
self._rename_arg(op, grad_name, new_grad_var_name)
offset = 0
param_name = op_role_var[i]
if not block.has_var(param_name): continue
if '@BroadCast' in param_name: continue
param_grad_name = param_name + core.grad_var_suffix()
merged_param_grad_name = param_grad_name + '@MERGED'
if not block.has_var(merged_param_grad_name):
self._create_var(block, block.vars[param_name],
merged_param_grad_name)
assert block.has_var(merged_param_grad_name)
param_grad_var = block.var(param_grad_name)
merged_param_grad_var = block.var(merged_param_grad_name)
merged_param_grad_var.persistable = True
block._insert_op(
index=offset + 1,
type='sum',
inputs={'X': [grad_var, new_var]},
outputs={'Out': grad_var},
index=first_opt_op_idx + offset,
type='fill_constant',
inputs={},
outputs={'Out': [merged_param_grad_var]},
attrs={
self._op_device_key: device,
self._op_role_key: self._op_role.Backward,
self._op_role_var_key: op_role_var
'shape': merged_param_grad_var.shape,
'dtype': merged_param_grad_var.dtype,
'value': float(0),
# a trick to run this op once per mini-batch
self._op_role_key: self._op_role.Optimize.LRSched,
})
offset += 1
grad_name = op_role_var[i + 1]
grad_var = block.vars[grad_name]
if not 'cast_fp16' in grad_name:
block._insert_op(
index=first_opt_op_idx + offset,
type='sum',
inputs={'X': [grad_var, merged_param_grad_var]},
outputs={'Out': merged_param_grad_var},
attrs={
self._op_role_key: self._op_role.Backward,
})
offset += 1
merged_gradient_names.append(merged_param_grad_name)
else:
# cast gradient to fp32 to accumulate to merged gradient
cast_grad_var_name = param_grad_name + '@TMP'
cast_grad_var = self._create_var(block, param_grad_var,
cast_grad_var_name)
cast_grad_var.persistable = False
block._insert_op(
index=first_opt_op_idx + offset,
type='cast',
inputs={'X': grad_var},
outputs={'Out': cast_grad_var},
attrs={
'in_dtype': grad_var.dtype,
'out_dtype': cast_grad_var.dtype,
self._op_role_key: self._op_role.Backward,
})
offset += 1
block._insert_op(
index=first_opt_op_idx + offset,
type='sum',
inputs={
'X': [merged_param_grad_var, cast_grad_var]
},
outputs={'Out': merged_param_grad_var},
attrs={
# self._op_device_key: device,
self._op_role_key: self._op_role.Backward,
#self._op_role_var_key: op_role_var
})
offset += 1
merged_gradient_names.append(merged_param_grad_name)
return merged_gradient_names
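# Illustrative sketch (not part of this commit): per parameter, accumulation
# keeps a persistable merged buffer that is zeroed once per mini-batch and then
# summed with each micro-batch gradient; fp16 gradients are cast to the merged
# dtype first, mirroring the two branches above. Plain floats stand in for
# tensors here.
def _accumulate_example(micro_grads):
    merged = 0.0  # the fill_constant(0) run once per mini-batch
    for g in micro_grads:
        merged = merged + float(g)  # cast if needed, then sum into the merged grad
    return merged
# _accumulate_example([0.5, 0.25, 0.25]) == 1.0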
def _add_sub_blocks(self, main_block, program_list):
main_program = main_block.program
......@@ -4372,7 +5006,7 @@ class PipelineOptimizer(object):
block = prog.block(0)
for op in block.ops:
if op.type == "recv_v2" or op.type == "create_py_reader" or \
op.type == "read":
op.type == "read" or op.type == "update_loss_scaling":
continue
# We have processed lr related vars
if op.attr(self._op_role_key) == int(
......@@ -4407,11 +5041,14 @@ class PipelineOptimizer(object):
inputs={'X': write_block.var(var_name), },
attrs={
self._op_device_key: write_device,
'use_calc_stream': True,
'use_calc_stream': False,
# A trick to make the role LRSched to avoid copying it every
# microbatch
self._op_role_key: self._op_role.LRSched,
'peer': read_dev_index,
#'ring_id': self.ring_id,
'ring_id': self.ring_id if
read_dev_index > write_dev_index else self.ring_id + 2,
})
read_block._insert_op(
index=0,
......@@ -4421,34 +5058,77 @@ class PipelineOptimizer(object):
'out_shape': read_block.var(var_name).shape,
'dtype': read_block.var(var_name).dtype,
self._op_device_key: read_device,
'use_calc_stream': True,
'use_calc_stream': False,
# A trick to make the role LRSched to avoid copying it every
# microbatch
self._op_role_key: self._op_role.LRSched,
'peer': write_dev_index,
#'ring_id': self.ring_id,
'ring_id': self.ring_id if
write_dev_index < read_dev_index else self.ring_id + 2,
})
read_block._insert_op(
index=1,
type='c_sync_comm_stream',
inputs={'X': [read_block.var(var_name)]},
outputs={'Out': [read_block.var(var_name)]},
attrs={
self._op_device_key: read_device,
# A trick to make the role LRSched to avoid copying it every
# microbatch
self._op_role_key: self._op_role.LRSched,
'peer': write_dev_index
#'ring_id': self.ring_id,
'ring_id': self.ring_id if
write_dev_index > read_dev_index else self.ring_id + 2,
})
def _is_gradient_clip_op(self, op):
return op.desc.has_attr("op_namescope") \
and op.desc.attr("op_namescope").startswith("/gradient_clip")
def _is_regularization_op(self, op):
return op.desc.has_attr("op_namescope") \
and op.desc.attr("op_namescope").startswith("/regularization")
def minimize(self,
loss,
startup_program=None,
parameter_list=None,
no_grad_set=None):
main_block = loss.block
self.origin_main_block = main_block
if startup_program is None:
startup_program = default_startup_program()
optimize_ops, params_grads = self._optimizer.minimize(
loss, startup_program, parameter_list, no_grad_set)
self._param_device_map = self._optimizer._param_device_map
# Step1: add default op_device attribute for regularization and clip ops
self._add_opdevice_attr_for_regularization_clip(main_block)
# Step2: add default op_device attribute for ops whose op_device
# attribute have not been set yet. Then check all ops have the
# op_device attribute.
self._add_default_opdevice_attr(main_block)
device_specs = self._check_validation(main_block)
self._param_device_map = self._origin_optimizer._param_device_map
assert main_block.program._pipeline_opt \
and 'local_rank' in main_block.program._pipeline_opt, \
'Please use pipeline with fleet.'
local_rank = main_block.program._pipeline_opt['local_rank']
schedule_mode = 0
if 'schedule_mode' in main_block.program._pipeline_opt:
schedule_mode = main_block.program._pipeline_opt['schedule_mode']
self.schedule_mode = schedule_mode
self.pp_bz = main_block.program._pipeline_opt['pp_bz']
self.use_sharding = False
if 'use_sharding' in main_block.program._pipeline_opt:
self.use_sharding = main_block.program._pipeline_opt['use_sharding']
self.ring_id = 0
if 'ring_id' in main_block.program._pipeline_opt:
self.ring_id = main_block.program._pipeline_opt['ring_id']
if main_block.program._pipeline_opt['global_rank'] == 0:
with open("startup_raw", 'w') as f:
f.writelines(str(startup_program))
with open("main_raw", 'w') as f:
f.writelines(str(main_block.program))
# Step1: add default op_device attribute for ops.
self._add_op_device_attr(main_block)
device_list = self._check_validation(main_block)
def device_cmp(device1, device2):
dev1_id = int(device1.split(':')[1])
......@@ -4460,66 +5140,169 @@ class PipelineOptimizer(object):
else:
return 0
sorted_device_spec = sorted(device_specs, key=cmp_to_key(device_cmp))
assert sorted_device_spec == device_specs, (
"With pipeline "
"parallelism, you must use gpu devices one after another "
"in the order of their ids.")
sorted_device_list = sorted(device_list, key=cmp_to_key(device_cmp))
assert sorted_device_list == device_list, (
"With pipeline parallelism, you must use gpu devices one after "
"another in the order of their ids.")
# Step3: add send and recv ops between section boundaries
self._insert_sendrecv_ops_for_boundaries(main_block)
# Step2: add send and recv ops between section boundaries
self._xx_insert_sendrecv_ops_for_boundaries(main_block)
# Step4: split program into sections and add pairs of
# Step3: split program into sections and add pairs of
# send and recv ops for data var.
main_program = main_block.program
program_list = self._split_program(main_program, device_specs)
program_list = self._split_program(main_program, device_list)
#cur_device_index = 0
#device_num = len(program_list)
for p in program_list:
self._create_vars(p["program"].block(0),
main_program.global_block())
self._insert_sendrecv_for_data_var(main_block, program_list,
startup_program, device_specs)
# Step5: Special Case: process persistable vars that exist in
self._create_vars(p["program"].block(0), main_block)
# # Add send/recv pair to sync the execution.
# block = p['program'].block(0)
# prev_device_index = cur_device_index - 1
# next_device_index = cur_device_index + 1
# add_send_for_forward = False
# add_send_for_backward = False
# add_recv_for_backward = False
# extra_index = 0
# new_var = block.create_var(
# name=unique_name.generate('sync'),
# shape=[1],
# dtype='float32',
# persistable=False,
# stop_gradient=True)
# block._insert_op(
# index=0,
# type='fill_constant',
# inputs={},
# outputs={'Out': [new_var]},
# attrs={
# 'shape': [1],
# 'dtype': new_var.dtype,
# self._op_role_key: self._op_role.Forward,
# 'value': float(0.0),
# })
# extra_index += 1
# for op_idx, op in enumerate(list(block.ops)):
# if op_idx == extra_index:
# if cur_device_index > 0:
# pair_key = prev_device_index * 1000 + cur_device_index
# ring_id = self._pp_ring_map[pair_key]
# block._insert_op(
# index=op_idx,
# type='recv_v2',
# outputs={'Out': [new_var]},
# attrs={
# 'out_shape': new_var.shape,
# 'dtype': new_var.dtype,
# self._op_role_key: self._op_role.Forward,
# 'peer': 0,
# 'use_calc_stream': True,
# 'ring_id': ring_id,
# })
# extra_index += 1
# continue
# if op.type == "send_v2" and self._is_forward_op(op) \
# and not add_send_for_forward \
# and cur_device_index < device_num - 1:
# add_send_for_forward = True
# pair_key = cur_device_index * 1000 + next_device_index
# ring_id = self._pp_ring_map[pair_key]
# block._insert_op(
# index=op_idx + extra_index,
# type='send_v2',
# inputs={'Out': new_var},
# attrs={
# 'out_shape': new_var.shape,
# 'dtype': new_var.dtype,
# self._op_role_key: self._op_role.Forward,
# 'peer': 1,
# 'use_calc_stream': True,
# 'ring_id': ring_id,
# })
# extra_index += 1
# if self._is_backward_op(op) and not add_recv_for_backward \
# and cur_device_index < device_num - 1:
# pair_key = next_device_index * 1000 + cur_device_index
# add_recv_for_backward = True
# ring_id = self._pp_ring_map[pair_key]
# block._insert_op(
# index=op_idx + extra_index,
# type='recv_v2',
# outputs={'Out': [new_var]},
# attrs={
# 'out_shape': new_var.shape,
# 'dtype': new_var.dtype,
# self._op_role_key: self._op_role.Backward,
# 'peer': 0,
# 'use_calc_stream': True,
# 'ring_id': ring_id,
# })
# if op.type == "send_v2" and self._is_backward_op(op) \
# and not add_send_for_backward \
# and cur_device_index > 0:
# pair_key = cur_device_index * 1000 + prev_device_index
# add_send_for_backward = True
# ring_id = self._pp_ring_map[pair_key]
# block._insert_op(
# index=op_idx + extra_index,
# type='send_v2',
# outputs={'Out': [new_var]},
# attrs={
# 'out_shape': new_var.shape,
# 'dtype': new_var.dtype,
# self._op_role_key: self._op_role.Backward,
# 'peer': 1,
# 'use_calc_stream': True,
# 'ring_id': ring_id,
# })
# cur_device_index += 1
#self._insert_sendrecv_for_data_var(main_block, program_list,
# startup_program, device_list)
# Step4: Special Case: process persistable vars that exist in
# multiple sections
self._process_persistable_vars_in_multi_sections(
main_program, startup_program, program_list)
#self._process_persistable_vars_in_multi_sections(
# main_program, startup_program, program_list)
# Step6: Add sub blocks for section programs
# Step5: Add sub blocks for section programs
self._add_sub_blocks(main_block, program_list)
assert (main_program._pipeline_opt and
isinstance(main_program._pipeline_opt, dict) and
'local_rank' in main_program._pipeline_opt), \
"You must use pipeline with fleet"
local_rank = main_program._pipeline_opt['local_rank'] % len(
device_specs)
local_rank = main_program._pipeline_opt['local_rank'] % len(device_list)
place_list = []
for dev_spec in device_specs:
dev_index = dev_spec.split(":")[1]
place_list.append(core.CUDAPlace(local_rank))
for dev in device_list:
dev_index = int(dev.split(":")[1])
place_list.append(core.CUDAPlace(dev_index % 8))
# Step7: Split startup program
# Step6: Split startup program
new_startup_program = self._split_startup_program(startup_program,
local_rank)
# Step8: clear gradients before each mini-batch and
# accumulate gradients during backward
self._clear_gradients(
program_list[local_rank]['program'].global_block(),
dev_spec=device_specs[local_rank])
self._accumulate_gradients(program_list[local_rank]['program']
.global_block())
startup_program._pipeline_opt = {
"startup_program": new_startup_program,
}
real_block = program_list[local_rank]['program'].global_block()
self._insert_loss_scale(real_block)
if not self.use_sharding:
# Step7: clear gradients before each mini-batch and
# accumulate gradients during backward
param_list = []
for param, grad in params_grads:
if real_block.has_var(param): param_list.append(param)
#self._clear_gradients(real_block, param_list)
self._rename_gradient_var_name(real_block)
real_block._sync_with_cpp()
self._accumulate_gradients(real_block)
real_block._sync_with_cpp()
place_id = int(os.getenv("FLAGS_selected_gpus", "0"))
main_program._pipeline_opt = {
"trainer": "PipelineTrainer",
"device_worker": "Section",
"inner_parallelism": len(device_specs),
"inner_parallelism": len(device_list),
"num_pipeline_stages": len(device_list),
"pipeline_stage": local_rank,
"schedule_mode": schedule_mode,
"section_program": program_list[local_rank],
"place": place_list[local_rank],
"place_id": place_id,
......@@ -4527,7 +5310,7 @@ class PipelineOptimizer(object):
"num_microbatches": self._num_microbatches,
"start_cpu_core_id": self._start_cpu_core_id,
}
return optimize_ops, params_grads, program_list
return optimize_ops, params_grads, program_list, self._pipeline_pair, self._pp_ring_map
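# Illustrative sketch (not part of this commit): the fields this pass reads from
# main_program._pipeline_opt (normally filled in by fleet / the sharding pass)
# and that drive the steps above. All values below are placeholders.
pipeline_opt_example = {
    'local_rank': 0,        # rank within the pipeline group
    'global_rank': 0,
    'schedule_mode': 1,     # 0 for GPipe, 1 for the broadcast-based schedule above
    'pp_bz': 1,             # micro-batch size used for the fill_constant shapes
    'use_sharding': False,
    'ring_id': 20,          # base ring id for pipeline communication
}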
class RecomputeOptimizer(Optimizer):
......@@ -4928,10 +5711,10 @@ class RecomputeOptimizer(Optimizer):
for output_var in output_vars:
if output_var in need_offload_checkpoint_names:
assert len(
output_vars
) == 1, "chekpoint should be the only Output of a certain op, but [{}] is from [{}]".format(
output_var, op)
#assert len(
# output_vars
#) == 1, "chekpoint should be the only Output of a certain op, but [{}] is from [{}]".format(
# output_var, op)
if output_var in self.un_offload_checkpoint_names:
# insert a sync op if the last checkpoint has not been synced
......@@ -4956,14 +5739,14 @@ class RecomputeOptimizer(Optimizer):
format(output_var))
# need to sync the last checkpoint that needs offloading before the op that outputs the last checkpoint
if output_var == last_checkpoint:
assert len(
output_vars
) == 1, "chekpoint should be the only Output of a certain op, but [{}] is from [{}]".format(
output_var, op)
assert last_offload_checkpoint == self.sorted_checkpoint_names[
-2], "the last offload chekpoint before [{}] is suppose to be [{}], but got [{}]".format(
last_checkpoint, self.sorted_checkpoint_names[-2],
last_offload_checkpoint)
#assert len(
# output_vars
#) == 1, "chekpoint should be the only Output of a certain op, but [{}] is from [{}]".format(
# output_var, op)
#assert last_offload_checkpoint == self.sorted_checkpoint_names[
# -2], "the last offload chekpoint before [{}] is suppose to be [{}], but got [{}]".format(
# last_checkpoint, self.sorted_checkpoint_names[-2],
# last_offload_checkpoint)
# sync if the last checkpoint has not been synced
if self.checkpoint_usage_count_and_idx[
last_offload_checkpoint]['idx'] == 0:
......@@ -5487,7 +6270,7 @@ class GradientMergeOptimizer(object):
def _is_the_backward_op(self, op):
op_maker = core.op_proto_and_checker_maker
backward = core.op_proto_and_checker_maker.OpRole.Backward
if op_maker.kOpRoleVarAttrName() in op.attr_names and \
int(op.all_attrs()[op_maker.kOpRoleAttrName()]) == int(backward):
return True