未验证 提交 6b86e966 编写于 作者: L lilong12 提交者: GitHub

Fix the bug in pipeline for dygraph mode (#32716) (#32728)

* update, test=develop
上级 4593597d
...@@ -108,7 +108,6 @@ class PipelineLayer(Layer): ...@@ -108,7 +108,6 @@ class PipelineLayer(Layer):
# construct layer # construct layer
self.run_function = [] self.run_function = []
self._build_layer() self._build_layer()
self.to(paddle.CUDAPlace(self.device_id))
def _segment_network(self, seg_method): def _segment_network(self, seg_method):
logger.info("start segment network..") logger.info("start segment network..")
......
...@@ -16,7 +16,21 @@ import abc ...@@ -16,7 +16,21 @@ import abc
import paddle import paddle
from ...utils import hybrid_parallel_util as hp_util from ...utils import hybrid_parallel_util as hp_util
__all__ = ['get_tensor_bytes', ] __all__ = [
'get_tensor_bytes',
'is_float_tensor',
]
# Paddle dtypes that is_float_tensor() accepts as floating point.
FLOAT_TYPES = [
paddle.float16,
paddle.float32,
paddle.float64,
]
def is_float_tensor(tensor):
    """Report whether ``tensor`` has a floating-point dtype.

    Args:
        tensor: Any object exposing a ``dtype`` attribute.

    Returns:
        bool: True when ``tensor.dtype`` is one of the entries in
        ``FLOAT_TYPES``, False otherwise.
    """
    dtype = tensor.dtype
    return dtype in FLOAT_TYPES
def get_tensor_bytes(tensor): def get_tensor_bytes(tensor):
...@@ -48,10 +62,6 @@ class Generator(): ...@@ -48,10 +62,6 @@ class Generator():
self.stage_id = stage_id self.stage_id = stage_id
self.prev_stage = self.stage_id - 1 self.prev_stage = self.stage_id - 1
self.next_stage = self.stage_id + 1 self.next_stage = self.stage_id + 1
assert self.micro_batches >= self.stages, (
"micro_batches {} "
"must be greater than or equal to {}".format(self.micro_batches,
self.stages))
@abc.abstractmethod @abc.abstractmethod
def generate(self): def generate(self):
...@@ -73,18 +83,25 @@ class TrainGenerator(Generator): ...@@ -73,18 +83,25 @@ class TrainGenerator(Generator):
cmds = [] cmds = []
forward_steps = 0 forward_steps = 0
backward_steps = 0 backward_steps = 0
while (forward_steps < startup_steps): #while (forward_steps < startup_steps):
cmds.append(Forward) # cmds.append(Forward(cache_id=forward_steps))
forward_steps += 1 # forward_steps += 1
#while (forward_steps < self.micro_batches):
# cmds.append(Forward(cache_id=forward_steps))
# forward_steps += 1
# cmds.append(Backward(cache_id=backward_steps))
# backward_steps += 1
#while (backward_steps < self.micro_batches):
# cmds.append(Backward(cache_id=backward_steps))
# backward_steps += 1
#cmds.append(Optimize())
while (forward_steps < self.micro_batches): while (forward_steps < self.micro_batches):
cmds.append(Forward) cmds.append(Forward(cache_id=forward_steps))
forward_steps += 1 forward_steps += 1
cmds.append(Backward)
backward_steps += 1
while (backward_steps < self.micro_batches): while (backward_steps < self.micro_batches):
cmds.append(Backward) cmds.append(Backward(cache_id=backward_steps))
backward_steps += 1 backward_steps += 1
cmds.append(Optimize) cmds.append(Optimize())
yield cmds yield cmds
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册