Commit 5cd2bfec authored by WangXi, committed by sandyhouse

Add optimize_offload config

Parent 47042a97
@@ -38,7 +38,8 @@ message ShardingConfig {
   optional int32 acc_steps = 7 [ default = 1 ];
   optional int32 schedule_mode = 8 [ default = 0 ];
   optional int32 pp_bz = 9 [ default = 1 ];
-  optional bool pp_allreduce_in_optimize = 10 [ default = true ];
+  optional bool pp_allreduce_in_optimize = 10 [ default = false ];
+  optional bool optimize_offload = 11 [ default = false ];
 }
 message AMPConfig {
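A usage sketch, not part of the diff: the new field would normally be switched on through fleet.DistributedStrategy, whose sharding_configs dict maps onto the ShardingConfig message above. This assumes the fleet API of this Paddle version accepts the "optimize_offload" key; fields not listed keep their proto defaults.

import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.sharding = True
strategy.sharding_configs = {
    "acc_steps": 1,                     # gradient accumulation steps (proto default is 1)
    "pp_allreduce_in_optimize": False,  # new default introduced by this commit
    "optimize_offload": True,           # flag added by this commit (default false)
}
# strategy is then passed to fleet.distributed_optimizer(...) as usual.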
@@ -103,6 +103,8 @@ class ShardingOptimizer(MetaOptimizerBase):
         self.pp_bz = self.user_defined_strategy.sharding_configs["pp_bz"]
         self.pp_allreduce_in_optimize = self.user_defined_strategy.sharding_configs[
             "pp_allreduce_in_optimize"]
+        self.optimize_offload = self.user_defined_strategy.sharding_configs[
+            "optimize_offload"]

         if self.inner_opt is None:
             raise ValueError(
@@ -359,8 +361,10 @@ class ShardingOptimizer(MetaOptimizerBase):
         main_block._sync_with_cpp()

         # TODO(wangxi): add optimize offload
-        offload_helper = OffloadHelper()
-        offload_helper.offload(main_block, startup_block)
+        if self.optimize_offload:
+            logging.info("Sharding with optimize offload !")
+            offload_helper = OffloadHelper()
+            offload_helper.offload(main_block, startup_block)

         with open("start_sharding_%d" % self.role_maker._worker_index(),
                   'w') as f:
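A side note, not from the commit: logging.info writes to the root logger, so the new "Sharding with optimize offload !" message only shows up if the process has an INFO-level handler configured. A quick way to surface it while debugging, using only the standard library:

import logging

# The root logger defaults to WARNING; raise it so INFO records
# such as "Sharding with optimize offload !" are printed.
logging.basicConfig(level=logging.INFO)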
@@ -943,9 +947,8 @@ class ShardingOptimizer(MetaOptimizerBase):
             ]
             self.pp_group_size = self.pipeline_nodes
             self.pp_group_endpoints = [
-                ep for idx, ep in enumerate(self.endpoints)
-                if (idx % self.sharding_group_size
-                    ) == self.sharding_rank
+                ep for idx, ep in enumerate(self.endpoints) if
+                (idx % self.sharding_group_size) == self.sharding_rank
             ]
         else:
             self.mp_group_id = 0
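The reformatted comprehension above is behavior-preserving: it still keeps every sharding_group_size-th endpoint, offset by the local sharding_rank, as the pipeline group. A standalone sketch with made-up endpoints (names are illustrative only):

# Illustrative only: 4 endpoints, sharding groups of size 2.
endpoints = ["10.0.0.1:6170", "10.0.0.2:6170", "10.0.0.3:6170", "10.0.0.4:6170"]
sharding_group_size = 2

for sharding_rank in range(sharding_group_size):
    pp_group_endpoints = [
        ep for idx, ep in enumerate(endpoints)
        if idx % sharding_group_size == sharding_rank
    ]
    print(sharding_rank, pp_group_endpoints)
# 0 ['10.0.0.1:6170', '10.0.0.3:6170']
# 1 ['10.0.0.2:6170', '10.0.0.4:6170']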
@@ -969,11 +972,12 @@ class ShardingOptimizer(MetaOptimizerBase):
                 self._inner_parallelism_size * self.sharding_group_size)
             self.megatron_rank = self.global_rank % self._inner_parallelism_size
             self.sharding_group_endpoints = [
-                ep for idx, ep in enumerate(self.endpoints)
-                if (idx // (self._inner_parallelism_size *
-                            self.sharding_group_size)
-                    ) == self.sharding_group_id and idx %
-                self._inner_parallelism_size == self.megatron_rank
+                ep for idx, ep in enumerate(self.endpoints) if
+                (idx //
+                 (self._inner_parallelism_size *
+                  self.sharding_group_size)) == self.sharding_group_id
+                and
+                idx % self._inner_parallelism_size == self.megatron_rank
             ]
             print("sharding_endpoint:", self.sharding_group_endpoints)
             print("sharding_rank:", self.sharding_rank)