diff --git a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py
index 2f5c42a69e362fdcbf063aa336257a10bc0af1ab..f3be9894a9cfeb48983b1317bf035296b0e59f50 100755
--- a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py
+++ b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py
@@ -109,7 +109,37 @@ class SegmentLayers:
         ), "layer number should be greater than number of segments"
 
     def do_segment(self):
-        if self.method == "uniform":
+
+        if isinstance(self.method, list):
+            seg_method = self.method[:]
+            source_num_parts = len(seg_method) - 1
+
+            def check_sanity():
+                assert seg_method[0] == 0, "seg_method[0] should be 0"
+                for part in seg_method:
+                    assert isinstance(part, int), "part should be int"
+                    assert part >= 0, f"part[{part}] should be greater than or equal to 0"
+                    assert (
+                        part <= self.num_items
+                    ), "part[{}] should be less than or equal to num_items[{}]".format(
+                        part, self.num_items
+                    )
+
+            check_sanity()
+
+            if self.num_parts == source_num_parts + 1:
+                seg_method.append(self.num_items)
+                return seg_method
+            elif self.num_parts == source_num_parts:
+                return seg_method
+            else:
+                raise ValueError(
+                    "seg_method is {} with length {}, but the number of stages is {}".format(
+                        seg_method, len(seg_method), self.num_parts
+                    )
+                )
+
+        elif self.method == "uniform":
             return self.uniform(self.num_items, self.num_parts)
 
         elif self.method.startswith('layer:'):
@@ -144,6 +174,8 @@ class SegmentLayers:
                     memory_counter = 0
             result[actual_num_parts] = len(weights)
             return result
+        else:
+            raise ValueError(f"method {self.method} is not supported")
 
     def _gen_layer_weight(self, layername):
         weight_idxs = []
diff --git a/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_layer.py b/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_layer.py
index 5d15e79d64b755b32afd61fea4c202d0731ddb7a..cf4c20e550ab388aaf36856e09bad2ad6b40b4c2 100644
--- a/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_layer.py
+++ b/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_layer.py
@@ -136,6 +136,20 @@ class TestPipeLayerAPI(unittest.TestCase):
             np.testing.assert_array_equal(param_a.name, param_b.name)
             np.testing.assert_allclose(param_a.numpy(), param_b.numpy())
 
+    def test_pipelayer_segment_method(self):
+        init_net = AlexNetPipe()
+        pipe_model = PipelineLayer(
+            layers=init_net.to_layers(),
+            num_stages=self.pipeline_parallel_size,
+            seg_method=[0, 4],
+            loss_fn=nn.CrossEntropyLoss(),
+        )
+        stage_id = self.hcg.get_stage_id()
+        if stage_id == 0:
+            np.testing.assert_array_equal(len(pipe_model.parameters()), 4)
+        elif stage_id == 1:
+            np.testing.assert_array_equal(len(pipe_model.parameters()), 8)
+
 
 if __name__ == '__main__':
     unittest.main()
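
Note (illustrative, not part of the patch): the list form of seg_method is read as the starting layer index of each pipeline stage; when the list has exactly num_stages entries, do_segment closes it by appending the total layer count as the final boundary. The standalone sketch below uses plain Python and a hypothetical helper name (segment_from_list) rather than Paddle's SegmentLayers, purely to show how the new branch behaves.

    def segment_from_list(seg_method, num_items, num_parts):
        # Sketch of the list branch added to SegmentLayers.do_segment.
        seg = seg_method[:]
        assert seg[0] == 0, "seg_method[0] should be 0"
        for part in seg:
            assert isinstance(part, int) and 0 <= part <= num_items
        if num_parts == len(seg):
            # One boundary short: append the total layer count as the last cut.
            return seg + [num_items]
        elif num_parts == len(seg) - 1:
            return seg
        raise ValueError(f"seg_method {seg} does not match {num_parts} stages")

    # Example: 8 layers split across 2 stages with seg_method=[0, 4]
    print(segment_from_list([0, 4], num_items=8, num_parts=2))  # -> [0, 4, 8]

With boundaries [0, 4, 8], stage 0 owns layers 0-3 and stage 1 owns layers 4-7, which is why the new unit test expects different parameter counts on the two ranks. The added else branch in do_segment also turns an unsupported method into an explicit ValueError instead of silently returning None.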