未验证 提交 6db6a347 编写于 作者: S ShenLiang 提交者: GitHub

add segment methods for pp (#53368)

add utest

fix utest
上级 27016144
......@@ -109,7 +109,37 @@ class SegmentLayers:
), "layer number should be greater than number of segments"
def do_segment(self):
if self.method == "uniform":
if isinstance(self.method, list):
seg_method = self.method[:]
source_num_parts = len(seg_method) - 1
def check_sanity():
assert seg_method[0] == 0, "seg_method[0] should be 0"
for part in seg_method:
assert isinstance(part, int), "part should be int"
assert part >= 0, f"part[{part}] should be greater than 0"
assert (
part <= self.num_items
), "part[{}] should be less than num_items[{}]".format(
part, self.num_items
)
check_sanity()
if self.num_parts == source_num_parts + 1:
seg_method.append(self.num_items)
return seg_method
elif self.num_parts == source_num_parts:
return seg_method
else:
raise ValueError(
"We set seg_method as {}, this length is {}, but the number of stages is {}".format(
seg_method, len(seg_method), self.num_parts
)
)
elif self.method == "uniform":
return self.uniform(self.num_items, self.num_parts)
elif self.method.startswith('layer:'):
......@@ -144,6 +174,8 @@ class SegmentLayers:
memory_counter = 0
result[actual_num_parts] = len(weights)
return result
else:
raise ValueError(f"method {self.method} is not supported")
def _gen_layer_weight(self, layername):
weight_idxs = []
......
......@@ -136,6 +136,20 @@ class TestPipeLayerAPI(unittest.TestCase):
np.testing.assert_array_equal(param_a.name, param_b.name)
np.testing.assert_allclose(param_a.numpy(), param_b.numpy())
def test_pipelayer_segment_method(self):
init_net = AlexNetPipe()
pipe_model = PipelineLayer(
layers=init_net.to_layers(),
num_stages=self.pipeline_parallel_size,
seg_method=[0, 4],
loss_fn=nn.CrossEntropyLoss(),
)
stage_id = self.hcg.get_stage_id()
if stage_id == 0:
np.testing.assert_array_equal(len(pipe_model.parameters()), 4)
elif stage_id == 1:
np.testing.assert_array_equal(len(pipe_model.parameters()), 8)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册