“848ae7dc34c84c09ac6df93e5cfd5c2031156cea”上不存在“paddle/phi/kernels/impl/tril_kernel_impl.h”
未验证 提交 81eaa97d 编写于 作者: Y Yuang Liu 提交者: GitHub

[dygraph hybrid pp for interleave] Virtual pipeline layer forward function (#45444)

上级 9eb4d89b
...@@ -172,17 +172,26 @@ class PipelineLayerChunk(Layer): ...@@ -172,17 +172,26 @@ class PipelineLayerChunk(Layer):
def __init__(self): def __init__(self):
super(PipelineLayerChunk, self).__init__() super(PipelineLayerChunk, self).__init__()
self.functions = [] self.run_function = []
def append(self, sublayer): def append(self, sublayer):
# This method is used to unify codes in _build_layer_impl. # This method is used to unify codes in _build_layer_impl.
# For 1f1b scheduler, it will call append method of a List. # For 1f1b scheduler, it will call append method of a List.
# For interleave scheduler, it will call append method of this class. # For interleave scheduler, it will call append method of this class.
if isinstance(sublayer, Layer): if isinstance(sublayer, Layer):
self.add_sublayer(str(len(self.functions)), sublayer) self.add_sublayer(str(len(self.run_function)), sublayer)
self.functions.append(sublayer) self.run_function.append(sublayer)
def get_run_function(self):
return self.run_function
# TODO (Yuang Liu) forward function implement def forward(self, *args, **kwargs):
# Users shouldn't call PipelineLayerChunk directly, since all logics relating with recompute
# are in the forward function of PipelineLayer. Any directly call will bring unexpected
# behavior under recompute circumstance.
raise NotImplementedError(
"The forward function of PipelineLayerChunk cannot be called directly. "
"Please call forward function of PipelineLayer.")
class PipelineLayer(Layer): class PipelineLayer(Layer):
...@@ -520,8 +529,22 @@ class PipelineLayer(Layer): ...@@ -520,8 +529,22 @@ class PipelineLayer(Layer):
return execute_func return execute_func
def forward(self, input): def forward(self, input, chunk_id=None):
# TODO(Yuang Liu): forward function for interleave scheduler if chunk_id is not None:
assert isinstance(chunk_id, int), "chunk_id should be an int"
assert self._num_virtual_pipeline_stages > 1, \
"chunk_id is only valid when using virtual pipeline stage"
assert chunk_id < len(self._model_chunks), \
"The virtual pipeline only has {} chunks, " \
"but received chunk_id {}.".format(len(self._model_chunks), chunk_id)
# Get the target model chunk.
model_chunk = self._model_chunks[chunk_id]
# Update the self.run_function to the target run functions.
# Runs for 1f1b and interleave are similar, just handle all functions in self.run_function.
# The only different is that, for 1f1b, self.run_function has already been inited during build_layer.
# But for interleave, self.run_function will keep updating to the target functions at every run.
self.run_function = model_chunk.get_run_function()
if self._recompute_interval == 0: if self._recompute_interval == 0:
input = self.forward_function(0, len(self.run_function))(input) input = self.forward_function(0, len(self.run_function))(input)
else: else:
......
...@@ -33,31 +33,22 @@ class ReshapeHelp(Layer): ...@@ -33,31 +33,22 @@ class ReshapeHelp(Layer):
return x.reshape(shape=self.shape) return x.reshape(shape=self.shape)
class FakeAlexNetPipeDesc(PipelineLayer): class MLPForVirtualStageLayerTest(PipelineLayer):
def __init__(self, num_classes=10, **kwargs): def __init__(self, num_classes=10, **kwargs):
self.num_classes = num_classes self.num_classes = num_classes
decs = [ decs = [
LayerDesc(nn.Conv2D, 1, 64, kernel_size=11, stride=4, padding=5), LayerDesc(nn.Linear, 2, self.num_classes),
LayerDesc(nn.Conv2D, 64, 64, kernel_size=11, stride=4, padding=5), LayerDesc(nn.Linear, self.num_classes, 2),
LayerDesc(nn.ReLU), LayerDesc(nn.Linear, 2, self.num_classes),
LayerDesc(nn.MaxPool2D, kernel_size=2, stride=2), LayerDesc(nn.Linear, self.num_classes, 2),
LayerDesc(nn.Conv2D, 64, 192, kernel_size=5, padding=2), LayerDesc(nn.Linear, 2, self.num_classes),
LayerDesc(nn.Conv2D, 192, 192, kernel_size=5, padding=2), LayerDesc(nn.Linear, self.num_classes, 2),
F.relu, LayerDesc(nn.Linear, 2, self.num_classes),
LayerDesc(nn.MaxPool2D, kernel_size=2, stride=2), LayerDesc(nn.Linear, self.num_classes, 2),
LayerDesc(nn.Conv2D, 192, 384, kernel_size=3, padding=1),
F.relu,
LayerDesc(nn.Conv2D, 384, 256, kernel_size=3, padding=1),
F.relu,
LayerDesc(nn.Conv2D, 256, 256, kernel_size=3, padding=1),
LayerDesc(nn.Conv2D, 256, 256, kernel_size=3, padding=1),
F.relu,
LayerDesc(nn.MaxPool2D, kernel_size=2, stride=2),
LayerDesc(ReshapeHelp, shape=[-1, 256]),
LayerDesc(nn.Linear, 256, self.num_classes), # classifier
] ]
super(FakeAlexNetPipeDesc, self).__init__(layers=decs, super(MLPForVirtualStageLayerTest,
self).__init__(layers=decs,
loss_fn=nn.CrossEntropyLoss(), loss_fn=nn.CrossEntropyLoss(),
**kwargs) **kwargs)
...@@ -73,16 +64,38 @@ class TestPipeLayerAPI(unittest.TestCase): ...@@ -73,16 +64,38 @@ class TestPipeLayerAPI(unittest.TestCase):
"pp_degree": self.pipeline_parallel_size "pp_degree": self.pipeline_parallel_size
} }
fleet.init(is_collective=True, strategy=strategy) fleet.init(is_collective=True, strategy=strategy)
self.rank = fleet.worker_index()
self.hcg = fleet.get_hybrid_communicate_group() self.hcg = fleet.get_hybrid_communicate_group()
def test_pipelayer_desc(self): def test_pipelayer_desc(self):
pipe_model = FakeAlexNetPipeDesc(seg_method="layer:Conv2D", pipe_model = MLPForVirtualStageLayerTest(
seg_method="layer:Linear",
num_stages=self.pipeline_parallel_size, num_stages=self.pipeline_parallel_size,
num_virtual_pipeline_stages=2) num_virtual_pipeline_stages=2,
recompute_interval=1)
assert len(pipe_model.parameters()) > 0 assert len(pipe_model.parameters()) > 0
model_chunks = pipe_model.get_model_chunks() model_chunks = pipe_model.get_model_chunks()
assert model_chunks is not None assert model_chunks is not None
assert len(model_chunks) == 2 assert len(model_chunks) == 2
optimizer = paddle.optimizer.SGD(parameters=pipe_model.parameters())
try:
model_chunks[0](paddle.to_tensor([1., 2.]))
except NotImplementedError:
pass
# fake call for the forward function of virtual pipeline layer
for i in range(len(model_chunks)):
out = pipe_model(paddle.to_tensor([1., 2.]), chunk_id=i)
assert list(out.shape) == [2]
out = F.relu(out)
loss = paddle.mean(out)
loss.backward()
optimizer.step()
# just make sure the model can be wrapped with distributed model
dist_model = fleet.distributed_model(pipe_model) dist_model = fleet.distributed_model(pipe_model)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册