未验证 提交 5c9c1a39 编写于 作者: Z zhaoyingli 提交者: GitHub

[AutoParallel] recovery annotation (#49665)

* recovery annotation: remove leftover debug prints and restore the test code that had been commented out

* bugfix
上级 cc3b2009
......@@ -939,7 +939,6 @@ class Completer:
self._dist_context._serial_main_program = serial_main_program
if not is_naive_data_parallel(self._dist_context):
print("$$$$$$ here 0", flush=True)
self._dist_context.initialize(with_graph=True)
self._prepare()
self._update_process_mesh()
......@@ -947,7 +946,6 @@ class Completer:
# Copy the corresponding distributed attribute from graph to serial_main_program
self._dist_context.copy_dist_attr_from_graph_to_program()
else:
print("$$$$$$ here 2", flush=True)
self._logger.info("Default distributed attributed will be set.")
self._dist_context.initialize(with_graph=False)
# A fast and special completion for data parallel
......
......@@ -89,31 +89,31 @@ class TestAMPPass(unittest.TestCase):
)
def test_amp_pass(self):
# # mp2 training
# mp_engine = self.get_engine()
# history = mp_engine.fit(self.dataset, 3, batch_size=self.batch_size)
# mp_losses = np.array(history.history["loss"])
# mp2 training
mp_engine = self.get_engine()
history = mp_engine.fit(self.dataset, 3, batch_size=self.batch_size)
mp_losses = np.array(history.history["loss"])
# mp2 amp-o1 training
amp_o1_engine = self.get_engine(True, "o1")
history = amp_o1_engine.fit(self.dataset, 3, batch_size=self.batch_size)
amp_o1_losses = np.array(history.history["loss"])
amp_o1_engine.evaluate(self.dataset, 3, batch_size=self.batch_size)
# # self.check_results(mp_losses, amp_o1_losses)
# # mp2 amp-o2 training
# amp_o2_engine = self.get_engine(True, "o2")
# history = amp_o2_engine.fit(self.dataset, 3, batch_size=self.batch_size)
# amp_o2_losses = np.array(history.history["loss"])
# amp_o2_engine.evaluate(self.dataset, 3, batch_size=self.batch_size)
# # self.check_results(mp_losses, amp_o2_losses)
# # mp2 amp-o3 training
# amp_o3_engine = self.get_engine(True, "o3")
# history = amp_o3_engine.fit(self.dataset, 3, batch_size=self.batch_size)
# amp_o3_losses = np.array(history.history["loss"])
# amp_o3_engine.evaluate(self.dataset, 3, batch_size=self.batch_size)
# # self.check_results(mp_losses, amp_o3_losses)
# self.check_results(mp_losses, amp_o1_losses)
# mp2 amp-o2 training
amp_o2_engine = self.get_engine(True, "o2")
history = amp_o2_engine.fit(self.dataset, 3, batch_size=self.batch_size)
amp_o2_losses = np.array(history.history["loss"])
amp_o2_engine.evaluate(self.dataset, 3, batch_size=self.batch_size)
# self.check_results(mp_losses, amp_o2_losses)
# mp2 amp-o3 training
amp_o3_engine = self.get_engine(True, "o3")
history = amp_o3_engine.fit(self.dataset, 3, batch_size=self.batch_size)
amp_o3_losses = np.array(history.history["loss"])
amp_o3_engine.evaluate(self.dataset, 3, batch_size=self.batch_size)
# self.check_results(mp_losses, amp_o3_losses)
if __name__ == "__main__":
......
......@@ -158,9 +158,9 @@ def train_high_level(fetch):
eval_dataset2 = MyDataset(batch_size)
engine.evaluate(eval_dataset2, batch_size=batch_size)
# # predict
# test_dataset = MyDataset(batch_size)
# outputs = engine.predict(test_dataset, batch_size=batch_size)
# predict
test_dataset = MyDataset(batch_size)
outputs = engine.predict(test_dataset, batch_size=batch_size)
# save
temp_dir = tempfile.TemporaryDirectory()
......@@ -498,10 +498,10 @@ def get_cost_by_spec():
if __name__ == "__main__":
train_high_level(fetch=True)
# train_high_level(fetch=False)
# train_low_level()
# train_builtin_data_vars()
# train_non_builtin_data_vars()
# get_cost()
# get_cost_by_default_program()
# get_cost_by_spec()
train_high_level(fetch=False)
train_low_level()
train_builtin_data_vars()
train_non_builtin_data_vars()
get_cost()
get_cost_by_default_program()
get_cost_by_spec()
......@@ -38,8 +38,8 @@ class TestEngineAPI(unittest.TestCase):
"paddle.distributed.launch",
"--devices",
"0,1",
# "--log_dir",
# tmp_dir.name,
"--log_dir",
tmp_dir.name,
launch_model_path,
]
)
......
......@@ -38,8 +38,8 @@ class TestAMPPass(unittest.TestCase):
"paddle.distributed.launch",
"--devices",
"0,1",
# "--log_dir",
# tmp_dir.name,
"--log_dir",
tmp_dir.name,
launch_model_path,
]
)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册