diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py index bc59b87e2ffa5c653e89c759f951de5f520773ba..236322ccfca6aad442e76af6f57c6c5f83ca59bb 100755 --- a/python/paddle/distributed/fleet/base/fleet_base.py +++ b/python/paddle/distributed/fleet/base/fleet_base.py @@ -1430,6 +1430,22 @@ class Fleet(object): # cache original feed forward program self.origin_main_program = loss.block.program + # add distributed attr + if not hasattr(self.origin_main_program, "distributed_info_"): + setattr(self.origin_main_program, "distributed_info_", dict()) + self.origin_main_program.distributed_info_[ + "dp_degree"] = self._user_defined_strategy.sharding_configs[ + "dp_degree"] + self.origin_main_program.distributed_info_[ + "mp_degree"] = self._user_defined_strategy.sharding_configs[ + "mp_degree"] + self.origin_main_program.distributed_info_[ + "pp_degree"] = self._user_defined_strategy.sharding_configs[ + "pp_degree"] + self.origin_main_program.distributed_info_[ + "sharding_degree"] = self._user_defined_strategy.sharding_configs[ + "sharding_degree"] + context["origin_main_program"] = self.origin_main_program context["loss"] = loss if startup_program == None: diff --git a/python/paddle/fluid/contrib/sparsity/asp.py b/python/paddle/fluid/contrib/sparsity/asp.py index 937fcdf0463beed7d9116be1a4800a0d02238e7d..ffa12ac70460084fd49a14d0193be6e913495b9a 100644 --- a/python/paddle/fluid/contrib/sparsity/asp.py +++ b/python/paddle/fluid/contrib/sparsity/asp.py @@ -155,8 +155,7 @@ def prune_model(main_program=None, n=2, m=4, mask_algo='mask_1d', - with_mask=True, - sharding=False): + with_mask=True): r""" Pruning parameters of supported layers in :attr:`main_program` via specified mask generation function given by :attr:`mask_algo`. This @@ -179,7 +178,6 @@ def prune_model(main_program=None, mask_algo (string, optional): The function name to generate spase mask. Default is `mask_1d`. The vaild inputs should be one of 'mask_1d', 'mask_2d_greedy' and 'mask_2d_best'. with_mask (bool, optional): To prune mask Variables related to parameters or not. Ture is purning also, False is not. Defalut is True. - sharding (bool, optional): Whether to turn on sharding (model parallel) during training. Please consider turning it ON when encountering OOM using sharding. Default is False. Returns: dictionary: A dictionary with key: `parameter name` (string) and value: its corresponding mask Variable. Examples: @@ -221,7 +219,10 @@ def prune_model(main_program=None, # Must call `exe.run(startup_program)` first before calling `sparsity.prune_model` sparsity.prune_model(main_program, mask_algo='mask_2d_best') """ - if sharding: + if main_program is not None and hasattr( + main_program, + "distributed_info_") and main_program.distributed_info_[ + "sharding_degree"] > 1 and paddle.fluid.is_compiled_with_cuda(): gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0)) place = paddle.CUDAPlace(gpu_id) else: diff --git a/python/paddle/fluid/tests/unittests/asp/test_fleet_with_asp_sharding.py b/python/paddle/fluid/tests/unittests/asp/test_fleet_with_asp_sharding.py index 26170015ae8c249fb3a36d13285f5b34491acb3a..d9ddd6c88d727a4cca5e94cf19b122355f3ea6c5 100644 --- a/python/paddle/fluid/tests/unittests/asp/test_fleet_with_asp_sharding.py +++ b/python/paddle/fluid/tests/unittests/asp/test_fleet_with_asp_sharding.py @@ -98,7 +98,7 @@ class TestFleetWithASPSharding(unittest.TestCase): feeder = fluid.DataFeeder(feed_list=[input_x, input_y], place=place) exe.run(startup_prog) - sparsity.prune_model(train_prog, sharding=True) + sparsity.prune_model(train_prog) data = (np.random.randn(64, 32), np.random.randint(2, size=(64, 1))) exe.run(train_prog, feed=feeder.feed([data]))