Commit 5903151f authored by Chen Weihang

move apply in minimize

Parent 28dac4ec
@@ -19,7 +19,7 @@ import paddle
 from .. import framework
 from .. import core
 from ..framework import Variable, Parameter, ParamBase
-from .base import switch_to_static_graph, to_variable
+from .base import switch_to_static_graph
 from .math_op_patch import monkey_patch_math_varbase
 from .parallel import scale_loss
......
@@ -22,7 +22,7 @@ from collections import defaultdict
 import paddle
 from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table
 from paddle.fluid.framework import Program, Variable, name_scope, default_main_program, default_startup_program, device_guard
-from paddle.fluid.dygraph.parallel import scale_loss, apply_collective_grads
+from paddle.fluid.dygraph.parallel import apply_collective_grads
 from . import framework
 from . import layers
@@ -772,8 +772,14 @@ class Optimizer(object):
         self._dtype = loss.dtype
         if framework.in_dygraph_mode():
+            parameter_list = parameter_list if parameter_list \
+                else self._parameter_list
+
+            if paddle.distributed.get_world_size() > 1:
+                apply_collective_grads(parameter_list)
+
             params_grads = []
-            for param in self._parameter_list:
+            for param in parameter_list:
                 if not param.trainable:
                     continue
                 if param._grad_ivar() is not None:
@@ -941,10 +947,6 @@ class Optimizer(object):
         parameter_list = parameter_list if parameter_list \
             else self._parameter_list
-        if paddle.distributed.get_world_size() > 1:
-            loss = scale_loss(loss)
-            apply_collective_grads(parameter_list)
-
         params_grads = self.backward(
             loss,
             startup_program=startup_program,
......
@@ -16,7 +16,7 @@ from .optimizer import Optimizer
 from .adam import Adam
 from ..fluid import framework
 import paddle
-from paddle.fluid.dygraph.parallel import scale_loss, apply_collective_grads
+from paddle.fluid.dygraph.parallel import apply_collective_grads
 
 __all__ = ['AdamW']
@@ -189,10 +189,6 @@ class AdamW(Adam):
         parameters = parameters if parameters \
             else self._parameter_list
-        if paddle.distributed.get_world_size() > 1:
-            loss = scale_loss(loss)
-            apply_collective_grads(parameter_list)
-
         params_grads = self.backward(
             loss=loss,
             startup_program=startup_program,
......
@@ -22,7 +22,7 @@ from collections import defaultdict
 import paddle
 from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table
 from paddle.fluid.framework import Program, Variable, name_scope, default_main_program, default_startup_program, device_guard
-from paddle.fluid.dygraph.parallel import scale_loss, apply_collective_grads
+from paddle.fluid.dygraph.parallel import apply_collective_grads
 from ..fluid import framework
 from ..fluid import layers
@@ -676,8 +676,14 @@ class Optimizer(object):
         self._dtype = loss.dtype
         if framework.in_dygraph_mode():
+            parameter_list = parameters if parameters \
+                else self._parameter_list
+
+            if paddle.distributed.get_world_size() > 1:
+                apply_collective_grads(parameter_list)
+
             params_grads = []
-            for param in self._parameter_list:
+            for param in parameter_list:
                 if not param.trainable:
                     continue
                 if param._grad_ivar() is not None:
@@ -873,10 +879,6 @@ class Optimizer(object):
         parameter_list = parameters if parameters \
             else self._parameter_list
-        if paddle.distributed.get_world_size() > 1:
-            loss = scale_loss(loss)
-            apply_collective_grads(parameter_list)
-
         params_grads = self.backward(
             loss,
             startup_program=startup_program,
......
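For context, here is a minimal multi-card dygraph training sketch showing where the moved call now takes effect: gradient aggregation happens inside the optimizer's backward() (which minimize() calls) rather than in minimize() itself. This sketch is not part of the commit; it assumes the 2.0-rc style paddle.distributed / paddle.DataParallel API, and the model, shapes, and launch command are illustrative only.

# Hypothetical data-parallel dygraph loop (not from this commit).
# Launch across cards with something like:
#   python -m paddle.distributed.launch train.py
import paddle
import paddle.nn as nn

paddle.distributed.init_parallel_env()

model = paddle.DataParallel(nn.Linear(10, 1))
opt = paddle.optimizer.Adam(learning_rate=0.01, parameters=model.parameters())

x = paddle.randn([4, 10])
loss = model(x).mean()

loss.backward()
# minimize() delegates to backward(), which now calls
# apply_collective_grads(parameter_list) when
# paddle.distributed.get_world_size() > 1; previously minimize()
# scaled the loss and applied collective grads before that call.
opt.minimize(loss)
model.clear_gradients()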