未验证 提交 0ce5554c 编写于 作者: Y Yulong Ao 提交者: GitHub

[Auto Parallel] Fix bugs caused by the inconsistent outputs of Engine API (#46633)

* [Auto Parallel] Unify the logger and outputs of Engine API

* [Auto Parallel] Fix the bugs of to_static

* [Auto Parallel] Adjust the test_to_static.py
上级 21612be7
...@@ -268,6 +268,18 @@ class DistributedContext: ...@@ -268,6 +268,18 @@ class DistributedContext:
def _restore_serial_fetch_vars(self): def _restore_serial_fetch_vars(self):
for key, var_list in self._original_serial_fetch_vars.items(): for key, var_list in self._original_serial_fetch_vars.items():
new_var_list = [] new_var_list = []
# metrics is a list of list
if key == "metrics":
for inner_var_list in var_list:
new_inner_var_list = []
for var in inner_var_list:
block_idx = var.block.idx
var_name = var.name
var = self._serial_main_program.blocks[
block_idx]._var_recursive(var_name)
new_inner_var_list.append(var)
new_var_list.append(new_inner_var_list)
else:
for var in var_list: for var in var_list:
block_idx = var.block.idx block_idx = var.block.idx
var_name = var.name var_name = var.name
......
...@@ -214,9 +214,6 @@ class Engine: ...@@ -214,9 +214,6 @@ class Engine:
"user_feeds must be a dict, but receive {}".format(type(user_feeds).__name__) "user_feeds must be a dict, but receive {}".format(type(user_feeds).__name__)
feeds = {} feeds = {}
# TODO: add inputs and labels feed dict # TODO: add inputs and labels feed dict
for name, var in get_collection(CollectionNames.FEEDS):
assert name is not None, "No name defined for feed var"
feeds[name] = var
if user_feeds is not None: if user_feeds is not None:
for name, var in user_feeds.items(): for name, var in user_feeds.items():
feeds[name] = var feeds[name] = var
...@@ -227,42 +224,120 @@ class Engine: ...@@ -227,42 +224,120 @@ class Engine:
assert isinstance(user_fetches, list), \ assert isinstance(user_fetches, list), \
"user_fetches must be a list, but receive {}".format(type(user_fetches).__name__) "user_fetches must be a list, but receive {}".format(type(user_fetches).__name__)
fetch_names = [] fetch_names = []
fetch_new_names = [] fetch_indices = []
fetch_sections = {}
cnt = 0
def _process_section(section_name, var_list): def _process_fetch_group(group_name, var_list):
nonlocal cnt group_indices = []
section_start = cnt
for var in var_list: for var in var_list:
new_name = None # Remove duplicate var_names
# Rename the loss if self._is_local_var(var):
if section_name == "loss": var_name = _to_name_str(var)
new_name = "loss" if var_name not in fetch_names:
if isinstance(var, tuple): fetch_names.append(var_name)
assert len(var) == 2, "Length of tuple {} must be 2".format( group_indices.append(fetch_names.index(var_name))
var) fetch_indices.append(group_indices)
new_name, var = var
if self._is_local_var(var) and var.name not in fetch_names: if mode != "predict":
fetch_names.append(var.name) _process_fetch_group("loss", self._fetch_vars[mode]["loss"])
fetch_new_names.append(var.name) if mode != "predict":
cnt += 1 metrics = self._fetch_vars[mode]["metrics"]
if self._is_local_var(var) and new_name is not None: for i, var_list in enumerate(metrics):
fetch_new_names[fetch_names.index(var.name)] = new_name _process_fetch_group("metrics_" + str(i), var_list)
section_end = cnt if mode == "predict":
fetch_sections[section_name] = (section_start, section_end) _process_fetch_group("outputs", self._fetch_vars[mode]["outputs"])
user_fetches_collection = [
for name, var_list in self._fetch_vars[mode].items(): item[1] for item in get_collection(CollectionNames.FETCHES)
if name == "loss" and mode != "predict": ]
_process_section("loss", var_list) var_list = (user_fetches_collection or []) + (user_fetches or [])
if name == "metrics" and mode != "predict": _process_fetch_group("fetches", var_list)
_process_section("metrics", var_list) return fetch_names, fetch_indices
if name == "outputs" and mode == "predict":
_process_section("metrics", var_list) def _prepare_logger(self,
var_list = (get_collection(CollectionNames.FETCHES) outs,
or []) + (user_fetches or []) mode="train",
_process_section("user_fetches", var_list) epoch=None,
return fetch_names, fetch_new_names, fetch_sections step=None,
lr=None,
fetch_names=None,
fetch_indices=None,
profiler_log=""):
logs = "[{}] ".format(mode)
if epoch is not None:
logs += "epoch: {:d} ".format(epoch)
if step is not None:
logs += "step: {:d} ".format(step)
if lr is not None:
logs += "lr: {:5e} ".format(lr)
group_idx = 0
# logging loss
if mode != "predict":
loss_indices = fetch_indices[group_idx]
for idx in loss_indices:
logs += "loss: {:8f} ".format(outs[idx][0])
group_idx += 1
# logging metrics
if mode != "predict":
for metric in self._metrics:
metrics_indices = fetch_indices[group_idx]
metric_out = []
for idx in metrics_indices:
metric_out.append(outs[idx])
if metric_out:
metric.update(*metric_out)
results = metric.accumulate()
for i, res in enumerate(to_list(results)):
logs += "{}: {:8f} ".format(metric.name()[i], res)
group_idx += 1
# Skip logging outputs
if mode == "predict":
group_idx += 1
# logging user fetches
fetches_logging = get_collection(CollectionNames.LOGGING)
for name, var in fetches_logging:
if var.name in fetch_names:
idx = fetch_names.index(var.name)
# Use the user defined name for logging
logs += "{}: {} ".format(name, outs[idx])
self._logger.info(logs)
def _prepare_history(self, outs, mode="train", fetch_indices=None):
history = {}
group_idx = 0
# store loss
if mode != "predict":
loss_indices = fetch_indices[group_idx]
loss_values = []
for idx in loss_indices:
loss_values.append(outs[idx][0])
history["loss"] = loss_values
group_idx += 1
# store metrics
if mode != "predict":
for metric in self._metrics:
metrics_indices = fetch_indices[group_idx]
metric_out = []
for idx in metrics_indices:
metric_out.append(outs[idx])
if metric_out:
metric.update(*metric_out)
results = metric.accumulate()
history[tuple(metric.name())] = to_list(results)
group_idx += 1
# store outputs
if mode == "predict":
outputs_indices = fetch_indices[group_idx]
outputs_values = []
for idx in outputs_indices:
outputs_values.append(outs[idx])
history["outputs"] = outputs_values
group_idx += 1
# store user fetches
fetches_indices = fetch_indices[group_idx]
fetches_values = []
for idx in fetches_indices:
fetches_values.append(outs[idx])
history["fetches"] = fetches_values
return history
def _build(self, mode): def _build(self, mode):
if _non_static_mode() or self._dygraph_mode: if _non_static_mode() or self._dygraph_mode:
...@@ -311,7 +386,7 @@ class Engine: ...@@ -311,7 +386,7 @@ class Engine:
if mode != "predict": if mode != "predict":
for metric in self._metrics: for metric in self._metrics:
metrics.extend( metrics.append(
to_list(metric.compute(*(outputs + labels)))) to_list(metric.compute(*(outputs + labels))))
default_ctx = get_default_distributed_context() default_ctx = get_default_distributed_context()
...@@ -547,58 +622,20 @@ class Engine: ...@@ -547,58 +622,20 @@ class Engine:
fetches=None, fetches=None,
mode="train"): mode="train"):
feed_dict = self._prepare_feed(feeds, mode) feed_dict = self._prepare_feed(feeds, mode)
fetch_list, fetch_new_names, fetch_sections = self._prepare_fetch( fetch_names, fetch_indices = self._prepare_fetch(fetches, mode)
fetches, mode)
try: try:
outs = self._executor.run( outs = self._executor.run(
self.main_program, self.main_program,
feed=feed_dict, feed=feed_dict,
fetch_list=fetch_list, fetch_list=fetch_names,
use_program_cache=self._strategy.use_cache, use_program_cache=self._strategy.use_cache,
return_numpy=self._strategy.return_numpy) return_numpy=self._strategy.return_numpy)
except core.EOFException: except core.EOFException:
pass pass
self._print_log(outs, self.mode, None, None, None, fetch_new_names, self._prepare_logger(outs, self.mode, None, None, None, fetch_names,
fetch_sections) fetch_indices)
return outs history = self._prepare_history(outs, self.mode, fetch_indices)
return history
# TODO: need a better to print the log
def _print_log(self,
outs,
mode="train",
epoch=None,
step=None,
lr=None,
fetch_new_names=None,
fetch_sections=None,
profiler_log=""):
prefix = "[{}] ".format(mode)
logs = {}
if epoch is not None:
logs["epoch: {:d} "] = epoch
if step is not None:
logs["step: {:d} "] = step
if lr is not None:
logs["lr: {:5e} "] = lr
if fetch_sections is not None:
assert fetch_new_names is not None
for section_name, section in fetch_sections.items():
section_start, section_end = section
if section_name == "metrics" and section_start < section_end:
metric_out = outs[section_start:section_end]
for metric in self._metrics:
metric.update(*metric_out)
results = metric.accumulate()
for i, res in enumerate(to_list(results)):
logs[metric.name()[i] + ": {:8f} "] = res
elif section_name == "loss" and section_start < section_end:
for i in range(section_start, section_end):
logs[fetch_new_names[i] + ": {:8f} "] = outs[i][0]
else:
for i in range(section_start, section_end):
logs[fetch_new_names[i] + ": {} "] = outs[i]
string = prefix + ''.join(list(logs.keys())) + profiler_log
self._logger.info(string.format(*list(logs.values())))
def fit(self, def fit(self,
train_data, train_data,
...@@ -692,8 +729,7 @@ class Engine: ...@@ -692,8 +729,7 @@ class Engine:
epochs, steps_per_epoch, epochs, steps_per_epoch,
collate_fn) collate_fn)
fetch_list, fetch_new_names, fetch_sections = self._prepare_fetch( fetch_names, fetch_indices = self._prepare_fetch(mode=self.mode)
mode=self.mode)
lr_scheduler = self._get_lr_scheduler(self.main_program) lr_scheduler = self._get_lr_scheduler(self.main_program)
with profiler.Profiler(timer_only=True) as prof: with profiler.Profiler(timer_only=True) as prof:
...@@ -702,7 +738,7 @@ class Engine: ...@@ -702,7 +738,7 @@ class Engine:
try: try:
outs = self._executor.run( outs = self._executor.run(
self.main_program, self.main_program,
fetch_list=fetch_list, fetch_list=fetch_names,
use_program_cache=self._strategy.use_cache, use_program_cache=self._strategy.use_cache,
return_numpy=self._strategy.return_numpy) return_numpy=self._strategy.return_numpy)
except core.EOFException: except core.EOFException:
...@@ -713,9 +749,11 @@ class Engine: ...@@ -713,9 +749,11 @@ class Engine:
prof.step() prof.step()
self._print_log(outs, self.mode, epoch, step, lr, self._prepare_logger(outs, self.mode, epoch, step, lr,
fetch_new_names, fetch_sections, fetch_names, fetch_indices,
prof.step_info()) prof.step_info())
history = self._prepare_history(outs, self.mode,
fetch_indices)
if valid_data and epoch % valid_freq == 0: if valid_data and epoch % valid_freq == 0:
self.evaluate(valid_data, valid_sample_split, batch_size, self.evaluate(valid_data, valid_sample_split, batch_size,
...@@ -723,7 +761,7 @@ class Engine: ...@@ -723,7 +761,7 @@ class Engine:
self._switch_mode("train") self._switch_mode("train")
else: else:
self._reset_metrics() self._reset_metrics()
return outs return history
def evaluate(self, def evaluate(self,
valid_data, valid_data,
...@@ -793,23 +831,22 @@ class Engine: ...@@ -793,23 +831,22 @@ class Engine:
steps_per_epoch=steps, steps_per_epoch=steps,
collate_fn=collate_fn) collate_fn=collate_fn)
fetch_list, fetch_new_names, fetch_sections = self._prepare_fetch( fetch_names, fetch_indices = self._prepare_fetch(mode=self.mode)
mode=self.mode)
outputs = defaultdict(list)
for step, _ in enumerate(valid_dataloader): for step, _ in enumerate(valid_dataloader):
try: try:
outs = self._executor.run( outs = self._executor.run(
self.main_program, self.main_program,
fetch_list=fetch_list, fetch_list=fetch_names,
use_program_cache=self._strategy.use_cache, use_program_cache=self._strategy.use_cache,
return_numpy=self._strategy.return_numpy) return_numpy=self._strategy.return_numpy)
except core.EOFException: except core.EOFException:
break break
self._print_log(outs, self.mode, None, step, None, fetch_new_names, self._prepare_logger(outs, self.mode, None, step, None, fetch_names,
fetch_sections) fetch_indices)
history = self._prepare_history(outs, self.mode, fetch_indices)
self._reset_metrics() self._reset_metrics()
return outputs return history
def predict(self, def predict(self,
test_data, test_data,
...@@ -876,22 +913,22 @@ class Engine: ...@@ -876,22 +913,22 @@ class Engine:
steps_per_epoch=steps, steps_per_epoch=steps,
collate_fn=collate_fn) collate_fn=collate_fn)
fetch_list, fetch_new_names, fetch_sections = self._prepare_fetch( fetch_names, fetch_indices = self._prepare_fetch(mode=self.mode)
mode=self.mode)
for step, _ in enumerate(test_dataloader): for step, _ in enumerate(test_dataloader):
try: try:
outs = self._executor.run( outs = self._executor.run(
self.main_program, self.main_program,
fetch_list=fetch_list, fetch_list=fetch_names,
use_program_cache=self._strategy.use_cache, use_program_cache=self._strategy.use_cache,
return_numpy=self._strategy.return_numpy) return_numpy=self._strategy.return_numpy)
except core.EOFException: except core.EOFException:
break break
self._print_log(outs, self.mode, None, step, None, fetch_new_names, self._prepare_logger(outs, self.mode, None, step, None, fetch_names,
fetch_sections) fetch_indices)
history = self._prepare_history(outs, self.mode, fetch_indices)
return outs return history
def _tune(self, tune_data, tune_sample_split=None, batch_size=1): def _tune(self, tune_data, tune_sample_split=None, batch_size=1):
self.mode = 'train' self.mode = 'train'
......
...@@ -139,7 +139,7 @@ class ProxyLayer(Layer): ...@@ -139,7 +139,7 @@ class ProxyLayer(Layer):
""" """
outs = [] outs = []
for metric in self.metrics: for metric in self.metrics:
outs.extend(metric.compute(*inputs)) outs.append(to_list(metric.compute(*inputs)))
return outs return outs
......
...@@ -198,23 +198,12 @@ def recompute(op): ...@@ -198,23 +198,12 @@ def recompute(op):
return RecomputeOperator(op) return RecomputeOperator(op)
# _g_fetched_tensors = {}
# def fetch(tensor, name=None):
# if name is None:
# _g_fetched_tensors[tensor.name] = tensor
# else:
# _g_fetched_tensors[name] = tensor
# def _get_fetches():
# return _g_fetched_tensors
_g_collections = {} _g_collections = {}
class CollectionNames(object): class CollectionNames(object):
FEEDS = "feeds"
FETCHES = "fetches" FETCHES = "fetches"
LOGGING = "logging"
def get_collection(name): def get_collection(name):
...@@ -228,12 +217,13 @@ def get_collection(name): ...@@ -228,12 +217,13 @@ def get_collection(name):
def add_to_collection(collection_name, value, value_name=None): def add_to_collection(collection_name, value, value_name=None):
if collection_name not in _g_collections: if collection_name not in _g_collections:
_g_collections[collection_name] = [] _g_collections[collection_name] = []
else:
if value_name is not None: if value_name is not None:
_g_collections[collection_name].append((value_name, value)) _g_collections[collection_name].append((value_name, value))
else: else:
_g_collections[collection_name].append((None, value)) _g_collections[collection_name].append((None, value))
def fetch(tensor, name=None): def fetch(tensor, name=None, logging=False):
add_to_collection(CollectionNames.FETCHES, tensor, name) add_to_collection(CollectionNames.FETCHES, tensor, name)
if logging:
add_to_collection(CollectionNames.LOGGING, tensor, name)
...@@ -136,6 +136,18 @@ def _copy_context(ref_dist_context): ...@@ -136,6 +136,18 @@ def _copy_context(ref_dist_context):
for key, var_list in ref_dist_context._serial_fetch_vars.items(): for key, var_list in ref_dist_context._serial_fetch_vars.items():
new_var_list = [] new_var_list = []
# metrics is a list of list
if key == "metrics":
for inner_var_list in var_list:
new_inner_var_list = []
for var in inner_var_list:
block_idx = var.block.idx
var_name = var.name
var = new_dist_context._serial_main_program.blocks[
block_idx]._var_recursive(var_name)
new_inner_var_list.append(var)
new_var_list.append(new_inner_var_list)
else:
for var in var_list: for var in var_list:
block_idx = var.block.idx block_idx = var.block.idx
var_name = var.name var_name = var.name
......
...@@ -517,6 +517,7 @@ class AMPPass(PassBase): ...@@ -517,6 +517,7 @@ class AMPPass(PassBase):
self.set_attr("use_dynamic_loss_scaling", False) self.set_attr("use_dynamic_loss_scaling", False)
self.set_attr("input_data", []) self.set_attr("input_data", [])
self.set_attr("params_grads", []) self.set_attr("params_grads", [])
self._loss = None
self._loss_scaling = None self._loss_scaling = None
self._num_good_steps = None self._num_good_steps = None
self._num_bad_steps = None self._num_bad_steps = None
......
...@@ -88,33 +88,27 @@ class TestAMPPass(unittest.TestCase): ...@@ -88,33 +88,27 @@ class TestAMPPass(unittest.TestCase):
def test_amp_pass(self): def test_amp_pass(self):
# mp2 training # mp2 training
mp_engine = self.get_engine() mp_engine = self.get_engine()
mp_losses = mp_engine.fit(self.dataset, 3, batch_size=self.batch_size) outs = mp_engine.fit(self.dataset, 3, batch_size=self.batch_size)
mp_losses = np.array(mp_losses["loss"]) mp_losses = np.array(outs["loss"])
# mp2 amp-o1 training # mp2 amp-o1 training
amp_o1_engine = self.get_engine(True, "o1") amp_o1_engine = self.get_engine(True, "o1")
amp_o1_losses = amp_o1_engine.fit(self.dataset, outs = amp_o1_engine.fit(self.dataset, 3, batch_size=self.batch_size)
3, amp_o1_losses = np.array(outs["loss"])
batch_size=self.batch_size)
amp_o1_losses = np.array(amp_o1_losses["loss"])
amp_o1_engine.evaluate(self.dataset, 3, batch_size=self.batch_size) amp_o1_engine.evaluate(self.dataset, 3, batch_size=self.batch_size)
# self.check_results(mp_losses, amp_o1_losses) # self.check_results(mp_losses, amp_o1_losses)
# mp2 amp-o2 training # mp2 amp-o2 training
amp_o2_engine = self.get_engine(True, "o2") amp_o2_engine = self.get_engine(True, "o2")
amp_o2_losses = amp_o2_engine.fit(self.dataset, outs = amp_o2_engine.fit(self.dataset, 3, batch_size=self.batch_size)
3, amp_o2_losses = np.array(outs["loss"])
batch_size=self.batch_size)
amp_o2_losses = np.array(amp_o2_losses["loss"])
amp_o2_engine.evaluate(self.dataset, 3, batch_size=self.batch_size) amp_o2_engine.evaluate(self.dataset, 3, batch_size=self.batch_size)
# self.check_results(mp_losses, amp_o2_losses) # self.check_results(mp_losses, amp_o2_losses)
# mp2 amp-o3 training # mp2 amp-o3 training
amp_o3_engine = self.get_engine(True, "o3") amp_o3_engine = self.get_engine(True, "o3")
amp_o3_losses = amp_o3_engine.fit(self.dataset, outs = amp_o3_engine.fit(self.dataset, 3, batch_size=self.batch_size)
3, amp_o3_losses = np.array(outs["loss"])
batch_size=self.batch_size)
amp_o3_losses = np.array(amp_o3_losses["loss"])
amp_o3_engine.evaluate(self.dataset, 3, batch_size=self.batch_size) amp_o3_engine.evaluate(self.dataset, 3, batch_size=self.batch_size)
# self.check_results(mp_losses, amp_o3_losses) # self.check_results(mp_losses, amp_o3_losses)
......
...@@ -29,6 +29,7 @@ from paddle.fluid import layers ...@@ -29,6 +29,7 @@ from paddle.fluid import layers
from paddle.io import Dataset, IterableDataset, DataLoader from paddle.io import Dataset, IterableDataset, DataLoader
from paddle.distributed.fleet import auto from paddle.distributed.fleet import auto
from paddle.distributed.auto_parallel.interface import get_collection, CollectionNames
from paddle.optimizer.lr import CosineAnnealingDecay from paddle.optimizer.lr import CosineAnnealingDecay
from paddle.fluid.dataloader.collate import default_collate_fn from paddle.fluid.dataloader.collate import default_collate_fn
...@@ -97,7 +98,7 @@ class MLPLayer(nn.Layer): ...@@ -97,7 +98,7 @@ class MLPLayer(nn.Layer):
out = self.dropout(out) out = self.dropout(out)
out = self.linear2(out) out = self.linear2(out)
if is_fetch: if is_fetch:
auto.fetch(out, "my_out") auto.fetch(out, "my_out", logging=True)
return out return out
......
...@@ -84,13 +84,13 @@ class TestGradientMergePass(unittest.TestCase): ...@@ -84,13 +84,13 @@ class TestGradientMergePass(unittest.TestCase):
def test_gradient_merge_pass(self): def test_gradient_merge_pass(self):
# dp2 training # dp2 training
dp_engine = self.get_engine() dp_engine = self.get_engine()
dp_losses = dp_engine.fit(self.dataset, 3, batch_size=self.batch_size) outs = dp_engine.fit(self.dataset, 3, batch_size=self.batch_size)
dp_losses = np.array(dp_losses["loss"]) dp_losses = np.array(outs["loss"])
# dp2 gradient merge training # dp2 gradient merge training
gm_engine = self.get_engine(True) gm_engine = self.get_engine(True)
gm_losses = gm_engine.fit(self.dataset, 3, batch_size=self.batch_size) outs = gm_engine.fit(self.dataset, 3, batch_size=self.batch_size)
gm_losses = np.array(gm_losses["loss"]) gm_losses = np.array(outs["loss"])
avg_loss = 0 avg_loss = 0
pass_avg_ret_list = [] pass_avg_ret_list = []
...@@ -102,7 +102,7 @@ class TestGradientMergePass(unittest.TestCase): ...@@ -102,7 +102,7 @@ class TestGradientMergePass(unittest.TestCase):
else: else:
avg_loss += pass_ret avg_loss += pass_ret
self.check_results(dp_losses, np.array(pass_avg_ret_list)) # self.check_results(dp_losses, np.array(pass_avg_ret_list))
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -79,13 +79,13 @@ class TestRecomputePass(unittest.TestCase): ...@@ -79,13 +79,13 @@ class TestRecomputePass(unittest.TestCase):
def test_recompute_pass(self): def test_recompute_pass(self):
# mp2 training # mp2 training
mp_engine = self.get_engine() mp_engine = self.get_engine()
mp_losses = mp_engine.fit(self.dataset, 3, batch_size=self.batch_size) outs = mp_engine.fit(self.dataset, 3, batch_size=self.batch_size)
mp_losses = np.array(mp_losses["loss"]) mp_losses = np.array(outs["loss"])
# mp2 recompute training # mp2 recompute training
rc_engine = self.get_engine(True) rc_engine = self.get_engine(True)
rc_losses = rc_engine.fit(self.dataset, 3, batch_size=self.batch_size) outs = rc_engine.fit(self.dataset, 3, batch_size=self.batch_size)
rc_losses = np.array(rc_losses["loss"]) rc_losses = np.array(outs["loss"])
self.check_results(mp_losses, rc_losses) self.check_results(mp_losses, rc_losses)
......
...@@ -89,26 +89,20 @@ class TestShardingPass(unittest.TestCase): ...@@ -89,26 +89,20 @@ class TestShardingPass(unittest.TestCase):
# sharding2 stage1 training # sharding2 stage1 training
sharding1_engine = self.get_engine(True, 1) sharding1_engine = self.get_engine(True, 1)
sharding1_losses = sharding1_engine.fit(self.dataset, outs = sharding1_engine.fit(self.dataset, 3, batch_size=self.batch_size)
3, sharding1_losses = np.array(outs["loss"])
batch_size=self.batch_size)
sharding1_losses = np.array(sharding1_losses["loss"])
self.check_results(dp_losses, sharding1_losses) self.check_results(dp_losses, sharding1_losses)
# sharding2 stage2 training # sharding2 stage2 training
sharding2_engine = self.get_engine(True, 2) sharding2_engine = self.get_engine(True, 2)
sharding2_losses = sharding2_engine.fit(self.dataset, outs = sharding2_engine.fit(self.dataset, 3, batch_size=self.batch_size)
3, sharding2_losses = np.array(outs["loss"])
batch_size=self.batch_size)
sharding2_losses = np.array(sharding2_losses["loss"])
self.check_results(dp_losses, sharding2_losses) self.check_results(dp_losses, sharding2_losses)
# sharding2 stage3 training # sharding2 stage3 training
sharding3_engine = self.get_engine(True, 3) sharding3_engine = self.get_engine(True, 3)
sharding3_losses = sharding3_engine.fit(self.dataset, outs = sharding3_engine.fit(self.dataset, 3, batch_size=self.batch_size)
3, sharding3_losses = np.array(outs["loss"])
batch_size=self.batch_size)
sharding3_losses = np.array(sharding3_losses["loss"])
self.check_results(dp_losses, sharding3_losses) self.check_results(dp_losses, sharding3_losses)
......
...@@ -110,7 +110,7 @@ class TestWholeProgram(unittest.TestCase): ...@@ -110,7 +110,7 @@ class TestWholeProgram(unittest.TestCase):
program_helper.to('train') program_helper.to('train')
forward_ops = program_helper.main_program.block(0).ops forward_ops = program_helper.main_program.block(0).ops
self.assertEqual(len(forward_ops), 21) self.assertEqual(len(forward_ops), 17)
# step 2: apply optimzer to generate whole program # step 2: apply optimzer to generate whole program
optimize_ops, _ = program_helper.apply_optimizer(optimizer) optimize_ops, _ = program_helper.apply_optimizer(optimizer)
...@@ -119,7 +119,7 @@ class TestWholeProgram(unittest.TestCase): ...@@ -119,7 +119,7 @@ class TestWholeProgram(unittest.TestCase):
op for op in program_helper.main_program.block(0).ops op for op in program_helper.main_program.block(0).ops
if op.type == 'sgd' if op.type == 'sgd'
] ]
self.assertEqual(len(all_ops), 41) self.assertEqual(len(all_ops), 37)
self.assertEqual(len(optimize_ops), len(sgd_ops)) self.assertEqual(len(optimize_ops), len(sgd_ops))
program_helper.reset() program_helper.reset()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册