未验证 提交 4f403d3e 编写于 作者: Z zhaoyingli 提交者: GitHub

update strategy (#46138)

上级 28b4240b
......@@ -81,7 +81,7 @@ set_field_default_config(AMP, "use_optimizer_fp16", False)
SHARDING = "sharding"
set_field_default_config(SHARDING, "enable", False)
set_field_default_config(SHARDING, "stage", 1)
set_field_default_config(SHARDING, "sharding_degree", 8)
set_field_default_config(SHARDING, "degree", 8)
set_field_default_config(SHARDING, "segment_broadcast_MB", 32.0)
set_field_default_config(SHARDING, "enable_tuning", False)
set_field_default_config(SHARDING, "tuning_range", [])
......
......@@ -564,9 +564,11 @@ class Engine:
self._infer_sample_spec(train_data, batch_size, train_sample_split)
if not self._mode_init_states[self.mode]:
self._prepare_single_mode(self.mode)
else:
self._switch_mode("train")
assert self.mode in self._dist_main_progs, \
"train model is not ready, please call `engine.prepare()` first."
"train model is not ready, please call `engine._prepare_single_mode('train')` first."
train_dataloader = self._create_dataloader(train_data, batch_size,
epochs, steps_per_epoch,
collate_fn)
......@@ -621,7 +623,7 @@ class Engine:
self.evaluate(valid_data, valid_sample_split, batch_size,
valid_steps, collate_fn, callbacks)
self._switch_mode("train")
else:
self._reset_metrics()
return outputs
......@@ -682,9 +684,11 @@ class Engine:
self._infer_sample_spec(valid_data, batch_size, valid_sample_split)
if not self._mode_init_states[self.mode]:
self._prepare_single_mode(self.mode)
else:
self._switch_mode("eval")
assert self.mode in self._dist_main_progs, \
"eval model is not ready, please call `engine.prepare()` first."
"eval model is not ready, please call `engine._prepare_single_mode('eval')` first."
valid_dataloader = self._create_dataloader(valid_data,
batch_size,
steps_per_epoch=steps,
......@@ -785,9 +789,11 @@ class Engine:
self._infer_sample_spec(test_data, batch_size, test_sample_split)
if not self._mode_init_states[self.mode]:
self._prepare_single_mode(self.mode)
else:
self._switch_mode("predict")
assert self.mode in self._dist_main_progs, \
"predict model is not ready, please call `engine.prepare()` first."
"predict model is not ready, please call `engine._prepare_single_mode('predict')` first."
test_dataloader = self._create_dataloader(test_data,
batch_size,
steps_per_epoch=steps,
......@@ -1059,7 +1065,7 @@ class Engine:
"""
if training:
assert 'train' in self._serial_main_progs, \
"training model is not ready, please call `engine.prepare()` first."
"training model is not ready, please call `engine._prepare_single_mode('train')` first."
serial_program = self._serial_main_progs["train"]
dist_main_prog = self._dist_main_progs["train"][self._cur_rank]
dist_context = self._dist_contexts["train"]
......
......@@ -136,13 +136,13 @@ class Strategy(BaseConfig):
sharding = strategy.sharding
self.assertEqual(sharding.enabled, False)
self.assertEqual(sharding.stage, 1)
self.assertEqual(sharding.sharding_degree, 8)
self.assertEqual(sharding.degree, 8)
sharding.enabled = True
sharding.stage = 2
sharding.sharding_degree = 2
sharding.degree = 2
self.assertEqual(sharding.enabled, True)
self.assertEqual(sharding.stage, 2)
self.assertEqual(sharding.sharding_degree, 2)
self.assertEqual(sharding.degree, 2)
"""
......
......@@ -51,7 +51,8 @@ class ShardingPass(PassBase):
super(ShardingPass, self).__init__()
self.set_attr("dist_context", None)
self.set_attr("stage", None)
self.set_attr("sharding_degree", None)
self.set_attr("sharding_degree", None) # for parallelizer
self.set_attr("degree", None) # for parallelizer_v2
self.set_attr("params_grads", [])
self.set_attr("global_rank", -1)
self.dp_groups = set()
......@@ -67,8 +68,15 @@ class ShardingPass(PassBase):
if self.get_attr("stage") not in [1, 2, 3]:
return False
if (not isinstance(self.get_attr("sharding_degree"),
int)) or self.get_attr("sharding_degree") <= 1:
if self.get_attr("sharding_degree") is not None:
if (not isinstance(self.get_attr("sharding_degree"), int)) \
or self.get_attr("sharding_degree") <= 1:
return False
elif self.get_attr("degree") is not None:
if (not isinstance(self.get_attr("degree"), int)) \
or self.get_attr("degree") <= 1:
return False
else:
return False
if len(self.get_attr("params_grads")) <= 0:
return False
......@@ -83,7 +91,8 @@ class ShardingPass(PassBase):
def _apply_single_impl(self, main_program, startup_program, context):
self._dist_context = self.get_attr("dist_context")
self.sharding_world_size = int(self.get_attr("sharding_degree"))
self.sharding_world_size = int(
self.get_attr("sharding_degree") or self.get_attr("degree"))
self.stage = int(self.get_attr("stage"))
self.global_rank = int(self.get_attr("global_rank"))
params_grads = self.get_attr("params_grads")
......
......@@ -31,7 +31,7 @@ def apply_pass(use_sharding=False):
strategy.reinit = True
if use_sharding:
sharding = strategy.sharding
sharding.sharding_degree = 2
sharding.degree = 2
sharding.stage = 2
return strategy
......
......@@ -96,7 +96,7 @@ def train(fetch):
# sharding config
sharding = dist_strategy.sharding
sharding.enable = True
sharding.sharding_degree = 2
sharding.degree = 2
sharding.stage = 3
sharding.enable_tuning = True
sharding.tuning_range = [0, 1, 2, 3]
......
......@@ -32,7 +32,7 @@ def apply_pass(use_sharding=False, stage=None):
if use_sharding:
sharding = strategy.sharding
sharding.enable = True
sharding.sharding_degree = 2
sharding.degree = 2
sharding.stage = 1
return strategy
......
......@@ -45,7 +45,7 @@ class TestStrategy(unittest.TestCase):
sharding = strategy.sharding
self.assertEqual(sharding.enable, False)
self.assertEqual(sharding.stage, 1)
self.assertEqual(sharding.sharding_degree, 8)
self.assertEqual(sharding.degree, 8)
self.assertAlmostEqual(sharding.segment_broadcast_MB, 32.0)
self.assertEqual(sharding.enable_tuning, False)
self.assertEqual(sharding.tuning_range, [])
......@@ -112,13 +112,13 @@ class TestStrategy(unittest.TestCase):
sharding = strategy.sharding
sharding.enable = True
sharding.stage = 2
sharding.sharding_degree = 2
sharding.degree = 2
sharding.segment_broadcast_MB = 64.0
sharding.enable_tuning = True
sharding.tuning_range = [1, 2, 3]
self.assertEqual(sharding.enable, True)
self.assertEqual(sharding.stage, 2)
self.assertEqual(sharding.sharding_degree, 2)
self.assertEqual(sharding.degree, 2)
self.assertAlmostEqual(sharding.segment_broadcast_MB, 64.0)
self.assertEqual(sharding.enable_tuning, True)
self.assertEqual(sharding.tuning_range, [1, 2, 3])
......@@ -175,7 +175,7 @@ class TestStrategy(unittest.TestCase):
# enable: false
# enable_tuning: true
# segment_broadcast_MB: 64.0
# sharding_degree: 8
# degree: 8
# stage: 2
# tuning_range: None
# split_data: false
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册