未验证 提交 6b86e966 编写于 作者: L lilong12 提交者: GitHub

Fix the bug in pipeline for dygraph mode (#32716) (#32728)

* update, test=develop
上级 4593597d
......@@ -108,7 +108,6 @@ class PipelineLayer(Layer):
# construct layer
self.run_function = []
self._build_layer()
self.to(paddle.CUDAPlace(self.device_id))
def _segment_network(self, seg_method):
logger.info("start segment network..")
......
......@@ -16,7 +16,21 @@ import abc
import paddle
from ...utils import hybrid_parallel_util as hp_util
__all__ = ['get_tensor_bytes', ]
__all__ = [
'get_tensor_bytes',
'is_float_tensor',
]
FLOAT_TYPES = [
paddle.float16,
paddle.float32,
paddle.float64,
]
def is_float_tensor(tensor):
"""Is a float tensor"""
return tensor.dtype in FLOAT_TYPES
def get_tensor_bytes(tensor):
......@@ -48,10 +62,6 @@ class Generator():
self.stage_id = stage_id
self.prev_stage = self.stage_id - 1
self.next_stage = self.stage_id + 1
assert self.micro_batches >= self.stages, (
"micro_batches {} "
"must be greater than or equal to {}".format(self.micro_batches,
self.stages))
@abc.abstractmethod
def generate(self):
......@@ -73,18 +83,25 @@ class TrainGenerator(Generator):
cmds = []
forward_steps = 0
backward_steps = 0
while (forward_steps < startup_steps):
cmds.append(Forward)
forward_steps += 1
#while (forward_steps < startup_steps):
# cmds.append(Forward(cache_id=forward_steps))
# forward_steps += 1
#while (forward_steps < self.micro_batches):
# cmds.append(Forward(cache_id=forward_steps))
# forward_steps += 1
# cmds.append(Backward(cache_id=backward_steps))
# backward_steps += 1
#while (backward_steps < self.micro_batches):
# cmds.append(Backward(cache_id=backward_steps))
# backward_steps += 1
#cmds.append(Optimize())
while (forward_steps < self.micro_batches):
cmds.append(Forward)
cmds.append(Forward(cache_id=forward_steps))
forward_steps += 1
cmds.append(Backward)
backward_steps += 1
while (backward_steps < self.micro_batches):
cmds.append(Backward)
cmds.append(Backward(cache_id=backward_steps))
backward_steps += 1
cmds.append(Optimize)
cmds.append(Optimize())
yield cmds
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册