未验证 提交 5c9c1a39 编写于 作者: Z zhaoyingli 提交者: GitHub

[AutoParallel] recovery annotation (#49665)

* recovery annotation

* bugfix
上级 cc3b2009
...@@ -939,7 +939,6 @@ class Completer: ...@@ -939,7 +939,6 @@ class Completer:
self._dist_context._serial_main_program = serial_main_program self._dist_context._serial_main_program = serial_main_program
if not is_naive_data_parallel(self._dist_context): if not is_naive_data_parallel(self._dist_context):
print("$$$$$$ here 0", flush=True)
self._dist_context.initialize(with_graph=True) self._dist_context.initialize(with_graph=True)
self._prepare() self._prepare()
self._update_process_mesh() self._update_process_mesh()
...@@ -947,7 +946,6 @@ class Completer: ...@@ -947,7 +946,6 @@ class Completer:
# Copy the corresponding distributed attribute from graph to serial_main_program # Copy the corresponding distributed attribute from graph to serial_main_program
self._dist_context.copy_dist_attr_from_graph_to_program() self._dist_context.copy_dist_attr_from_graph_to_program()
else: else:
print("$$$$$$ here 2", flush=True)
self._logger.info("Default distributed attributed will be set.") self._logger.info("Default distributed attributed will be set.")
self._dist_context.initialize(with_graph=False) self._dist_context.initialize(with_graph=False)
# A fast and special completion for data parallel # A fast and special completion for data parallel
......
...@@ -89,31 +89,31 @@ class TestAMPPass(unittest.TestCase): ...@@ -89,31 +89,31 @@ class TestAMPPass(unittest.TestCase):
) )
def test_amp_pass(self): def test_amp_pass(self):
# # mp2 training # mp2 training
# mp_engine = self.get_engine() mp_engine = self.get_engine()
# history = mp_engine.fit(self.dataset, 3, batch_size=self.batch_size) history = mp_engine.fit(self.dataset, 3, batch_size=self.batch_size)
# mp_losses = np.array(history.history["loss"]) mp_losses = np.array(history.history["loss"])
# mp2 amp-o1 training # mp2 amp-o1 training
amp_o1_engine = self.get_engine(True, "o1") amp_o1_engine = self.get_engine(True, "o1")
history = amp_o1_engine.fit(self.dataset, 3, batch_size=self.batch_size) history = amp_o1_engine.fit(self.dataset, 3, batch_size=self.batch_size)
amp_o1_losses = np.array(history.history["loss"]) amp_o1_losses = np.array(history.history["loss"])
amp_o1_engine.evaluate(self.dataset, 3, batch_size=self.batch_size) amp_o1_engine.evaluate(self.dataset, 3, batch_size=self.batch_size)
# # self.check_results(mp_losses, amp_o1_losses) # self.check_results(mp_losses, amp_o1_losses)
# # mp2 amp-o2 training # mp2 amp-o2 training
# amp_o2_engine = self.get_engine(True, "o2") amp_o2_engine = self.get_engine(True, "o2")
# history = amp_o2_engine.fit(self.dataset, 3, batch_size=self.batch_size) history = amp_o2_engine.fit(self.dataset, 3, batch_size=self.batch_size)
# amp_o2_losses = np.array(history.history["loss"]) amp_o2_losses = np.array(history.history["loss"])
# amp_o2_engine.evaluate(self.dataset, 3, batch_size=self.batch_size) amp_o2_engine.evaluate(self.dataset, 3, batch_size=self.batch_size)
# # self.check_results(mp_losses, amp_o2_losses) # self.check_results(mp_losses, amp_o2_losses)
# # mp2 amp-o3 training # mp2 amp-o3 training
# amp_o3_engine = self.get_engine(True, "o3") amp_o3_engine = self.get_engine(True, "o3")
# history = amp_o3_engine.fit(self.dataset, 3, batch_size=self.batch_size) history = amp_o3_engine.fit(self.dataset, 3, batch_size=self.batch_size)
# amp_o3_losses = np.array(history.history["loss"]) amp_o3_losses = np.array(history.history["loss"])
# amp_o3_engine.evaluate(self.dataset, 3, batch_size=self.batch_size) amp_o3_engine.evaluate(self.dataset, 3, batch_size=self.batch_size)
# # self.check_results(mp_losses, amp_o3_losses) # self.check_results(mp_losses, amp_o3_losses)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -158,9 +158,9 @@ def train_high_level(fetch): ...@@ -158,9 +158,9 @@ def train_high_level(fetch):
eval_dataset2 = MyDataset(batch_size) eval_dataset2 = MyDataset(batch_size)
engine.evaluate(eval_dataset2, batch_size=batch_size) engine.evaluate(eval_dataset2, batch_size=batch_size)
# # predict # predict
# test_dataset = MyDataset(batch_size) test_dataset = MyDataset(batch_size)
# outputs = engine.predict(test_dataset, batch_size=batch_size) outputs = engine.predict(test_dataset, batch_size=batch_size)
# save # save
temp_dir = tempfile.TemporaryDirectory() temp_dir = tempfile.TemporaryDirectory()
...@@ -498,10 +498,10 @@ def get_cost_by_spec(): ...@@ -498,10 +498,10 @@ def get_cost_by_spec():
if __name__ == "__main__": if __name__ == "__main__":
train_high_level(fetch=True) train_high_level(fetch=True)
# train_high_level(fetch=False) train_high_level(fetch=False)
# train_low_level() train_low_level()
# train_builtin_data_vars() train_builtin_data_vars()
# train_non_builtin_data_vars() train_non_builtin_data_vars()
# get_cost() get_cost()
# get_cost_by_default_program() get_cost_by_default_program()
# get_cost_by_spec() get_cost_by_spec()
...@@ -38,8 +38,8 @@ class TestEngineAPI(unittest.TestCase): ...@@ -38,8 +38,8 @@ class TestEngineAPI(unittest.TestCase):
"paddle.distributed.launch", "paddle.distributed.launch",
"--devices", "--devices",
"0,1", "0,1",
# "--log_dir", "--log_dir",
# tmp_dir.name, tmp_dir.name,
launch_model_path, launch_model_path,
] ]
) )
......
...@@ -38,8 +38,8 @@ class TestAMPPass(unittest.TestCase): ...@@ -38,8 +38,8 @@ class TestAMPPass(unittest.TestCase):
"paddle.distributed.launch", "paddle.distributed.launch",
"--devices", "--devices",
"0,1", "0,1",
# "--log_dir", "--log_dir",
# tmp_dir.name, tmp_dir.name,
launch_model_path, launch_model_path,
] ]
) )
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册