Unverified commit 815f7a67, authored by Baibaifan, committed by GitHub

change_ASP_sharding_option (#40028)

Parent 34d93bee
......
@@ -1430,6 +1430,22 @@ class Fleet(object):
         # cache original feed forward program
         self.origin_main_program = loss.block.program
+        # add distributed attr
+        if not hasattr(self.origin_main_program, "distributed_info_"):
+            setattr(self.origin_main_program, "distributed_info_", dict())
+            self.origin_main_program.distributed_info_[
+                "dp_degree"] = self._user_defined_strategy.sharding_configs[
+                    "dp_degree"]
+            self.origin_main_program.distributed_info_[
+                "mp_degree"] = self._user_defined_strategy.sharding_configs[
+                    "mp_degree"]
+            self.origin_main_program.distributed_info_[
+                "pp_degree"] = self._user_defined_strategy.sharding_configs[
+                    "pp_degree"]
+            self.origin_main_program.distributed_info_[
+                "sharding_degree"] = self._user_defined_strategy.sharding_configs[
+                    "sharding_degree"]
         context["origin_main_program"] = self.origin_main_program
         context["loss"] = loss
         if startup_program == None:
......
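The hunk above makes fleet's minimize() record the parallelism degrees from the user-defined strategy on the origin main program. A minimal sketch of that tagging pattern follows, outside the real Fleet code path and with made-up degree values; the dict below is only a stand-in for self._user_defined_strategy.sharding_configs.

# Sketch only: attach "distributed_info_" to a static Program the same way
# minimize() does, so later passes such as sparsity.prune_model can read it.
import paddle

paddle.enable_static()

main_prog = paddle.static.Program()

# Hypothetical stand-in for the strategy's sharding_configs.
sharding_configs = {
    "dp_degree": 1,
    "mp_degree": 1,
    "pp_degree": 1,
    "sharding_degree": 2,
}

if not hasattr(main_prog, "distributed_info_"):
    setattr(main_prog, "distributed_info_", dict())
    for key in ("dp_degree", "mp_degree", "pp_degree", "sharding_degree"):
        main_prog.distributed_info_[key] = sharding_configs[key]

print(main_prog.distributed_info_["sharding_degree"])  # prints 2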
......
@@ -155,8 +155,7 @@ def prune_model(main_program=None,
                 n=2,
                 m=4,
                 mask_algo='mask_1d',
-                with_mask=True,
-                sharding=False):
+                with_mask=True):
     r"""
     Pruning parameters of supported layers in :attr:`main_program` via
     specified mask generation function given by :attr:`mask_algo`. This
......
@@ -179,7 +178,6 @@
         mask_algo (string, optional): The function name to generate sparse mask. Default is `mask_1d`.
                                       The valid inputs should be one of 'mask_1d', 'mask_2d_greedy' and 'mask_2d_best'.
         with_mask (bool, optional): Whether to also prune the mask Variables related to the parameters. True means pruning them as well; False means not. Default is True.
-        sharding (bool, optional): Whether to turn on sharding (model parallel) during training. Please consider turning it ON when encountering OOM using sharding. Default is False.
     Returns:
         dictionary: A dictionary with key: `parameter name` (string) and value: its corresponding mask Variable.
     Examples:
......
@@ -221,7 +219,10 @@
             # Must call `exe.run(startup_program)` first before calling `sparsity.prune_model`
             sparsity.prune_model(main_program, mask_algo='mask_2d_best')
     """
-    if sharding:
+    if main_program is not None and hasattr(
+            main_program,
+            "distributed_info_") and main_program.distributed_info_[
+                "sharding_degree"] > 1 and paddle.fluid.is_compiled_with_cuda():
         gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
         place = paddle.CUDAPlace(gpu_id)
     else:
......
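With the explicit sharding argument gone, prune_model infers the device from the attribute set by minimize(). Below is a simplified restatement of that check as a standalone helper, not the library's code: the else-branch is truncated in the diff above, so the CPUPlace fallback is only an assumption, and the public paddle.is_compiled_with_cuda() is used in place of the fluid call from the diff.

# Sketch: pick the per-rank GPU only when the program carries a
# sharding_degree greater than 1 and Paddle was built with CUDA.
import os
import paddle


def select_prune_place(main_program):
    use_rank_gpu = (main_program is not None
                    and hasattr(main_program, "distributed_info_")
                    and main_program.distributed_info_["sharding_degree"] > 1
                    and paddle.is_compiled_with_cuda())
    if use_rank_gpu:
        gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
        return paddle.CUDAPlace(gpu_id)
    return paddle.CPUPlace()  # assumed fallback; the real branch is elided above


print(select_prune_place(None))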
......
@@ -98,7 +98,7 @@ class TestFleetWithASPSharding(unittest.TestCase):
         feeder = fluid.DataFeeder(feed_list=[input_x, input_y], place=place)
         exe.run(startup_prog)
-        sparsity.prune_model(train_prog, sharding=True)
+        sparsity.prune_model(train_prog)
         data = (np.random.randn(64, 32), np.random.randint(2, size=(64, 1)))
         exe.run(train_prog, feed=feeder.feed([data]))
......
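On the user side the change is visible in the test above: the call site simply drops the keyword, and whether pruning runs on the rank's GPU now follows from the sharding configuration given to fleet. A hedged sketch of that configuration is below; the field names are the real DistributedStrategy ones, but the degree values are examples only.

# Sketch: configure sharding on the fleet strategy instead of passing
# sharding=True to prune_model. When fleet.distributed_optimizer(...).minimize()
# runs with this strategy, it stores sharding_degree on the main program and
# prune_model then selects the per-rank GPU on its own.
import paddle
from paddle.distributed import fleet

strategy = fleet.DistributedStrategy()
strategy.sharding = True
strategy.sharding_configs = {
    "dp_degree": 1,
    "mp_degree": 1,
    "pp_degree": 1,
    "sharding_degree": 2,  # example value; > 1 enables the GPU branch in prune_model
}

print(strategy.sharding_configs["sharding_degree"])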