Unverified · Commit 4f403d3e authored by zhaoyingli, committed by GitHub

update strategy (#46138)

Parent: 28b4240b
@@ -81,7 +81,7 @@ set_field_default_config(AMP, "use_optimizer_fp16", False)
 SHARDING = "sharding"
 set_field_default_config(SHARDING, "enable", False)
 set_field_default_config(SHARDING, "stage", 1)
-set_field_default_config(SHARDING, "sharding_degree", 8)
+set_field_default_config(SHARDING, "degree", 8)
 set_field_default_config(SHARDING, "segment_broadcast_MB", 32.0)
 set_field_default_config(SHARDING, "enable_tuning", False)
 set_field_default_config(SHARDING, "tuning_range", [])
...
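The hunk above only renames the registered default field. For readers unfamiliar with the helper, here is a minimal sketch of the per-category registration pattern; the `_default_config` dict and the helper body are assumptions for illustration, not Paddle's actual implementation.

# Minimal sketch of the default-config registry pattern used above.
# The storage layout (_default_config) is an assumption, not Paddle's code.
_default_config = {}

def set_field_default_config(category, field, default_value):
    # Group defaults by category key, e.g. "sharding".
    _default_config.setdefault(category, {})[field] = default_value

SHARDING = "sharding"
set_field_default_config(SHARDING, "enable", False)
set_field_default_config(SHARDING, "degree", 8)  # renamed from "sharding_degree"

assert _default_config["sharding"]["degree"] == 8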
@@ -564,9 +564,11 @@ class Engine:
         self._infer_sample_spec(train_data, batch_size, train_sample_split)
         if not self._mode_init_states[self.mode]:
             self._prepare_single_mode(self.mode)
+        else:
+            self._switch_mode("train")
 
         assert self.mode in self._dist_main_progs, \
-            "train model is not ready, please call `engine.prepare()` first."
+            "train model is not ready, please call `engine._prepare_single_mode('train')` first."
         train_dataloader = self._create_dataloader(train_data, batch_size,
                                                    epochs, steps_per_epoch,
                                                    collate_fn)
@@ -621,8 +623,8 @@ class Engine:
                 self.evaluate(valid_data, valid_sample_split, batch_size,
                               valid_steps, collate_fn, callbacks)
                 self._switch_mode("train")
-            self._reset_metrics()
+            else:
+                self._reset_metrics()
         return outputs
     def evaluate(self,
@@ -682,9 +684,11 @@ class Engine:
         self._infer_sample_spec(valid_data, batch_size, valid_sample_split)
         if not self._mode_init_states[self.mode]:
             self._prepare_single_mode(self.mode)
+        else:
+            self._switch_mode("eval")
 
         assert self.mode in self._dist_main_progs, \
-            "eval model is not ready, please call `engine.prepare()` first."
+            "eval model is not ready, please call `engine._prepare_single_mode('eval')` first."
         valid_dataloader = self._create_dataloader(valid_data,
                                                    batch_size,
                                                    steps_per_epoch=steps,
@@ -785,9 +789,11 @@ class Engine:
         self._infer_sample_spec(test_data, batch_size, test_sample_split)
         if not self._mode_init_states[self.mode]:
             self._prepare_single_mode(self.mode)
+        else:
+            self._switch_mode("predict")
 
         assert self.mode in self._dist_main_progs, \
-            "predict model is not ready, please call `engine.prepare()` first."
+            "predict model is not ready, please call `engine._prepare_single_mode('predict')` first."
         test_dataloader = self._create_dataloader(test_data,
                                                   batch_size,
                                                   steps_per_epoch=steps,
@@ -1059,7 +1065,7 @@ class Engine:
         """
         if training:
             assert 'train' in self._serial_main_progs, \
-                "training model is not ready, please call `engine.prepare()` first."
+                "training model is not ready, please call `engine._prepare_single_mode('train')` first."
             serial_program = self._serial_main_progs["train"]
             dist_main_prog = self._dist_main_progs["train"][self._cur_rank]
             dist_context = self._dist_contexts["train"]
...
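The three Engine hunks above all add the same pattern: a mode is prepared the first time it is entered, and later calls merely switch back to the already-initialized state. A stand-alone model of that state machine, with names mirroring the Engine attributes (the class itself is illustrative, not Paddle's API):

# Stand-alone model of the prepare-once / switch-later pattern added to
# Engine.fit/evaluate/predict. Illustrative only; not Paddle's Engine API.
class ModeStates:
    def __init__(self):
        self._mode_init_states = {"train": False, "eval": False, "predict": False}
        self.mode = None

    def _prepare_single_mode(self, mode):
        # Expensive: builds programs, dataloaders, etc. Runs once per mode.
        self._mode_init_states[mode] = True
        self.mode = mode

    def _switch_mode(self, mode):
        # Cheap: reuses state prepared earlier.
        self.mode = mode

    def enter(self, mode):
        if not self._mode_init_states[mode]:
            self._prepare_single_mode(mode)
        else:
            self._switch_mode(mode)

engine = ModeStates()
engine.enter("train")   # first call: prepares train state
engine.enter("eval")    # first call: prepares eval state
engine.enter("train")   # second call: only switches back
assert engine.mode == "train"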
@@ -117,7 +117,7 @@ class TuningConfig(BaseConfig):
 class Strategy(BaseConfig):
     """
     The `Strategy` object is used to configure the parallelization and optimization behaviors.
 
     Args:
         config (dict|string, optional): If this is None, the default configurations will be used.
@@ -136,13 +136,13 @@ class Strategy(BaseConfig):
             sharding = strategy.sharding
             self.assertEqual(sharding.enable, False)
             self.assertEqual(sharding.stage, 1)
-            self.assertEqual(sharding.sharding_degree, 8)
+            self.assertEqual(sharding.degree, 8)
             sharding.enable = True
             sharding.stage = 2
-            sharding.sharding_degree = 2
+            sharding.degree = 2
             self.assertEqual(sharding.enable, True)
             self.assertEqual(sharding.stage, 2)
-            self.assertEqual(sharding.sharding_degree, 2)
+            self.assertEqual(sharding.degree, 2)
     """
...
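After the rename, user code reads and writes `sharding.degree`. A usage sketch, assuming the `auto.Strategy` entry point exposed by `paddle.distributed.fleet` as in Paddle's auto-parallel tests (requires a Paddle installation to run):

# Usage sketch of the renamed field; assumes the auto.Strategy entry point
# from paddle.distributed.fleet, per Paddle's auto-parallel tests.
from paddle.distributed.fleet import auto

strategy = auto.Strategy()
sharding = strategy.sharding
sharding.enable = True
sharding.stage = 2
sharding.degree = 2   # was: sharding.sharding_degree = 2

assert sharding.degree == 2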
@@ -51,7 +51,8 @@ class ShardingPass(PassBase):
         super(ShardingPass, self).__init__()
         self.set_attr("dist_context", None)
         self.set_attr("stage", None)
-        self.set_attr("sharding_degree", None)
+        self.set_attr("sharding_degree", None)  # for parallelizer
+        self.set_attr("degree", None)  # for parallelizer_v2
         self.set_attr("params_grads", [])
         self.set_attr("global_rank", -1)
         self.dp_groups = set()
@@ -67,8 +68,15 @@ class ShardingPass(PassBase):
         if self.get_attr("stage") not in [1, 2, 3]:
             return False
-        if (not isinstance(self.get_attr("sharding_degree"),
-                           int)) or self.get_attr("sharding_degree") <= 1:
+        if self.get_attr("sharding_degree") is not None:
+            if (not isinstance(self.get_attr("sharding_degree"), int)) \
+                    or self.get_attr("sharding_degree") <= 1:
+                return False
+        elif self.get_attr("degree") is not None:
+            if (not isinstance(self.get_attr("degree"), int)) \
+                    or self.get_attr("degree") <= 1:
+                return False
+        else:
             return False
         if len(self.get_attr("params_grads")) <= 0:
             return False
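Restated as a pure function, the new check accepts whichever of the two attributes is set (preferring `sharding_degree`, the legacy parallelizer path) and rejects the pass when neither is. The function name and dict-based attributes below are ours, for illustration only:

# Stand-alone restatement of the new _check_self degree validation.
# The function name and dict-based attrs are illustrative, not Paddle's API.
def check_degree(attrs):
    if attrs.get("sharding_degree") is not None:   # legacy parallelizer path
        d = attrs["sharding_degree"]
    elif attrs.get("degree") is not None:          # parallelizer_v2 path
        d = attrs["degree"]
    else:
        return False                               # neither attribute set
    return isinstance(d, int) and d > 1

assert check_degree({"degree": 2})
assert not check_degree({"degree": 1})             # degree must exceed 1
assert check_degree({"sharding_degree": 4})
assert not check_degree({})                        # neither set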
@@ -83,7 +91,8 @@ class ShardingPass(PassBase):
     def _apply_single_impl(self, main_program, startup_program, context):
         self._dist_context = self.get_attr("dist_context")
-        self.sharding_world_size = int(self.get_attr("sharding_degree"))
+        self.sharding_world_size = int(
+            self.get_attr("sharding_degree") or self.get_attr("degree"))
         self.stage = int(self.get_attr("stage"))
         self.global_rank = int(self.get_attr("global_rank"))
         params_grads = self.get_attr("params_grads")
...
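The `or` fallback prefers `sharding_degree` when it is set and otherwise uses `degree`; because `_check_self` has already guaranteed the chosen value is an int greater than 1, the falsiness of `None` is safe here. A two-line demonstration:

# `a or b` picks the first truthy operand, so a None sharding_degree
# falls through to degree. _check_self guarantees the winner is > 1.
sharding_degree, degree = None, 4
assert int(sharding_degree or degree) == 4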
@@ -31,7 +31,7 @@ def apply_pass(use_sharding=False):
     strategy.reinit = True
     if use_sharding:
         sharding = strategy.sharding
-        sharding.sharding_degree = 2
+        sharding.degree = 2
         sharding.stage = 2
     return strategy
...
@@ -96,7 +96,7 @@ def train(fetch):
     # sharding config
     sharding = dist_strategy.sharding
     sharding.enable = True
-    sharding.sharding_degree = 2
+    sharding.degree = 2
     sharding.stage = 3
     sharding.enable_tuning = True
     sharding.tuning_range = [0, 1, 2, 3]
...
@@ -32,7 +32,7 @@ def apply_pass(use_sharding=False, stage=None):
     if use_sharding:
         sharding = strategy.sharding
         sharding.enable = True
-        sharding.sharding_degree = 2
+        sharding.degree = 2
         sharding.stage = 1
     return strategy
...
@@ -45,7 +45,7 @@ class TestStrategy(unittest.TestCase):
         sharding = strategy.sharding
         self.assertEqual(sharding.enable, False)
         self.assertEqual(sharding.stage, 1)
-        self.assertEqual(sharding.sharding_degree, 8)
+        self.assertEqual(sharding.degree, 8)
         self.assertAlmostEqual(sharding.segment_broadcast_MB, 32.0)
         self.assertEqual(sharding.enable_tuning, False)
         self.assertEqual(sharding.tuning_range, [])
@@ -112,13 +112,13 @@ class TestStrategy(unittest.TestCase):
         sharding = strategy.sharding
         sharding.enable = True
         sharding.stage = 2
-        sharding.sharding_degree = 2
+        sharding.degree = 2
         sharding.segment_broadcast_MB = 64.0
         sharding.enable_tuning = True
         sharding.tuning_range = [1, 2, 3]
         self.assertEqual(sharding.enable, True)
         self.assertEqual(sharding.stage, 2)
-        self.assertEqual(sharding.sharding_degree, 2)
+        self.assertEqual(sharding.degree, 2)
         self.assertAlmostEqual(sharding.segment_broadcast_MB, 64.0)
         self.assertEqual(sharding.enable_tuning, True)
         self.assertEqual(sharding.tuning_range, [1, 2, 3])
@@ -175,7 +175,7 @@ class TestStrategy(unittest.TestCase):
         # enable: false
         # enable_tuning: true
         # segment_broadcast_MB: 64.0
-        # sharding_degree: 8
+        # degree: 8
         # stage: 2
         # tuning_range: None
         # split_data: false
...
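Since strategies round-trip through plain dicts (see the commented YAML above), configs serialized before this change may still carry the old key. A hypothetical one-line migration helper, not part of Paddle:

# Hypothetical migration helper for serialized strategy dicts that still
# use the pre-#46138 key name; not part of Paddle.
def migrate_sharding_key(config):
    sharding = config.get("sharding", {})
    if "sharding_degree" in sharding:
        sharding["degree"] = sharding.pop("sharding_degree")
    return config

cfg = {"sharding": {"enable": True, "sharding_degree": 8, "stage": 2}}
assert migrate_sharding_key(cfg)["sharding"]["degree"] == 8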