Commit 5cd2bfec authored by WangXi, committed by sandyhouse

Add optimize_offload config

Parent 47042a97
@@ -38,7 +38,8 @@ message ShardingConfig {
   optional int32 acc_steps = 7 [ default = 1 ];
   optional int32 schedule_mode = 8 [ default = 0 ];
   optional int32 pp_bz = 9 [ default = 1 ];
-  optional bool pp_allreduce_in_optimize = 10 [ default = true ];
+  optional bool pp_allreduce_in_optimize = 10 [ default = false ];
+  optional bool optimize_offload = 11 [ default = false ];
 }
 message AMPConfig {
...
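For reference, the new proto field surfaces to users through the sharding_configs dict on fleet's DistributedStrategy. Below is a minimal usage sketch, not taken from this commit; the model, loss, and cluster launch are elided and all values are assumed examples.

# Hypothetical usage sketch for the new optimize_offload flag. Only the
# config keys shown in the diff above are real; everything else here is
# an assumed example.
import paddle
import paddle.distributed.fleet as fleet

paddle.enable_static()
fleet.init(is_collective=True)

strategy = fleet.DistributedStrategy()
strategy.sharding = True
strategy.sharding_configs = {
    "acc_steps": 4,            # gradient accumulation steps (example value)
    "optimize_offload": True,  # flag added by this commit
}

optimizer = paddle.optimizer.Momentum(learning_rate=0.01, momentum=0.9)
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
# optimizer.minimize(loss) would then apply the ShardingOptimizer pass,
# with the offload rewrite enabled by the flag above.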
@@ -103,6 +103,8 @@ class ShardingOptimizer(MetaOptimizerBase):
         self.pp_bz = self.user_defined_strategy.sharding_configs["pp_bz"]
         self.pp_allreduce_in_optimize = self.user_defined_strategy.sharding_configs[
             "pp_allreduce_in_optimize"]
+        self.optimize_offload = self.user_defined_strategy.sharding_configs[
+            "optimize_offload"]

         if self.inner_opt is None:
             raise ValueError(
@@ -359,8 +361,10 @@ class ShardingOptimizer(MetaOptimizerBase):
         main_block._sync_with_cpp()

         # TODO(wangxi): add optimize offload
-        offload_helper = OffloadHelper()
-        offload_helper.offload(main_block, startup_block)
+        if self.optimize_offload:
+            logging.info("Sharding with optimize offload !")
+            offload_helper = OffloadHelper()
+            offload_helper.offload(main_block, startup_block)

         with open("start_sharding_%d" % self.role_maker._worker_index(),
                   'w') as f:
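With this change the offload pass runs only when the user opts in. As a rough intuition for what optimizer-state offload buys (a toy illustration only, not the real OffloadHelper, which rewrites the static program): optimizer state stays in host memory and is touched on the device only around each update, trading transfer time for device memory.

# Toy illustration of the optimizer-offload idea; NOT the real OffloadHelper.
import numpy as np

class ToyOffloadedMomentum:
    def __init__(self, params, lr=0.01, mu=0.9):
        self.lr, self.mu = lr, mu
        # Host-side copy of the momentum state (pinned CPU memory in a
        # real runtime).
        self.momentum_host = {k: np.zeros_like(v) for k, v in params.items()}

    def step(self, params, grads):
        for k in params:
            m = self.momentum_host[k]            # "upload" state
            m = self.mu * m + grads[k]           # update momentum
            params[k] -= self.lr * m             # apply update on device
            self.momentum_host[k] = m            # "offload" state back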
@@ -943,9 +947,8 @@ class ShardingOptimizer(MetaOptimizerBase):
             ]
             self.pp_group_size = self.pipeline_nodes
             self.pp_group_endpoints = [
-                ep for idx, ep in enumerate(self.endpoints)
-                if (idx % self.sharding_group_size
-                    ) == self.sharding_rank
+                ep for idx, ep in enumerate(self.endpoints) if
+                (idx % self.sharding_group_size) == self.sharding_rank
             ]
         else:
             self.mp_group_id = 0
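The comprehension change above is pure reformatting, but the selection logic is easy to sanity-check standalone: a rank with sharding_rank r keeps every endpoint whose index is congruent to r modulo sharding_group_size. A self-contained check with made-up endpoints and sizes:

# Standalone check of the pp_group endpoint selection; all values are
# made-up examples.
endpoints = ["n0:6070", "n0:6071", "n1:6070", "n1:6071"]
sharding_group_size = 2
sharding_rank = 1

pp_group_endpoints = [
    ep for idx, ep in enumerate(endpoints)
    if idx % sharding_group_size == sharding_rank
]
print(pp_group_endpoints)  # ['n0:6071', 'n1:6071']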
@@ -969,11 +972,12 @@ class ShardingOptimizer(MetaOptimizerBase):
                 self._inner_parallelism_size * self.sharding_group_size)
             self.megatron_rank = self.global_rank % self._inner_parallelism_size
             self.sharding_group_endpoints = [
-                ep for idx, ep in enumerate(self.endpoints)
-                if (idx // (self._inner_parallelism_size *
-                            self.sharding_group_size)
-                    ) == self.sharding_group_id and idx %
-                self._inner_parallelism_size == self.megatron_rank
+                ep for idx, ep in enumerate(self.endpoints) if
+                (idx //
+                 (self._inner_parallelism_size *
+                  self.sharding_group_size)) == self.sharding_group_id
+                and
+                idx % self._inner_parallelism_size == self.megatron_rank
             ]
             print("sharding_endpoint:", self.sharding_group_endpoints)
             print("sharding_rank:", self.sharding_rank)
...
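The sharding-plus-megatron case combines a block test (integer division picks the sharding group) with a stride test (modulo picks the megatron rank). Again a self-contained check with made-up values:

# Standalone check of the sharding_group_endpoints selection; all values
# are made-up examples.
endpoints = ["n%d:607%d" % (i // 2, i % 2) for i in range(8)]
inner_parallelism_size = 2
sharding_group_size = 2
global_rank = 5

sharding_group_id = global_rank // (inner_parallelism_size * sharding_group_size)
megatron_rank = global_rank % inner_parallelism_size

sharding_group_endpoints = [
    ep for idx, ep in enumerate(endpoints)
    if (idx // (inner_parallelism_size * sharding_group_size)) == sharding_group_id
    and idx % inner_parallelism_size == megatron_rank
]
print(sharding_group_endpoints)  # indices 5 and 7 -> ['n2:6071', 'n3:6071']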