Commit 5903151f authored by Chen Weihang

move apply in minimize

Parent 28dac4ec
@@ -19,7 +19,7 @@ import paddle
 from .. import framework
 from .. import core
 from ..framework import Variable, Parameter, ParamBase
-from .base import switch_to_static_graph, to_variable
+from .base import switch_to_static_graph
 from .math_op_patch import monkey_patch_math_varbase
 from .parallel import scale_loss
...
@@ -22,7 +22,7 @@ from collections import defaultdict
 import paddle
 from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table
 from paddle.fluid.framework import Program, Variable, name_scope, default_main_program, default_startup_program, device_guard
-from paddle.fluid.dygraph.parallel import scale_loss, apply_collective_grads
+from paddle.fluid.dygraph.parallel import apply_collective_grads
 from . import framework
 from . import layers
@@ -772,8 +772,14 @@ class Optimizer(object):
         self._dtype = loss.dtype
         if framework.in_dygraph_mode():
+            parameter_list = parameter_list if parameter_list \
+                else self._parameter_list
+
+            if paddle.distributed.get_world_size() > 1:
+                apply_collective_grads(parameter_list)
+
             params_grads = []
-            for param in self._parameter_list:
+            for param in parameter_list:
                 if not param.trainable:
                     continue
                 if param._grad_ivar() is not None:
@@ -941,10 +947,6 @@ class Optimizer(object):
         parameter_list = parameter_list if parameter_list \
             else self._parameter_list
-        if paddle.distributed.get_world_size() > 1:
-            loss = scale_loss(loss)
-            apply_collective_grads(parameter_list)
-
         params_grads = self.backward(
             loss,
             startup_program=startup_program,
...
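The hunks above move the collective gradient aggregation out of minimize() and into the dygraph branch of Optimizer.backward(). As a rough illustration, the resulting backward() path could be sketched as below; the method signature and the (param, grad) collection beyond the visible context are assumptions, not part of this diff.

    # Sketch only: simplified from the hunks above. The signature and the
    # params_grads bookkeeping past the visible context are assumed.
    def backward(self, loss, startup_program=None, parameter_list=None,
                 no_grad_set=None, callbacks=None):
        self._dtype = loss.dtype
        if framework.in_dygraph_mode():
            parameter_list = parameter_list if parameter_list \
                else self._parameter_list

            # Moved here from minimize(): aggregate gradients across ranks
            # before collecting (param, grad) pairs.
            if paddle.distributed.get_world_size() > 1:
                apply_collective_grads(parameter_list)

            params_grads = []
            for param in parameter_list:
                if not param.trainable:
                    continue
                if param._grad_ivar() is not None:
                    # Assumed: pair each trainable parameter with its gradient.
                    params_grads.append((param, param._grad_ivar()))
            return params_grads
        # Static-graph path is unchanged by this commit (omitted here).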
@@ -16,7 +16,7 @@ from .optimizer import Optimizer
 from .adam import Adam
 from ..fluid import framework
 import paddle
-from paddle.fluid.dygraph.parallel import scale_loss, apply_collective_grads
+from paddle.fluid.dygraph.parallel import apply_collective_grads

 __all__ = ['AdamW']
@@ -189,10 +189,6 @@ class AdamW(Adam):
         parameters = parameters if parameters \
             else self._parameter_list
-        if paddle.distributed.get_world_size() > 1:
-            loss = scale_loss(loss)
-            apply_collective_grads(parameter_list)
-
         params_grads = self.backward(
             loss=loss,
             startup_program=startup_program,
...
@@ -22,7 +22,7 @@ from collections import defaultdict
 import paddle
 from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table
 from paddle.fluid.framework import Program, Variable, name_scope, default_main_program, default_startup_program, device_guard
-from paddle.fluid.dygraph.parallel import scale_loss, apply_collective_grads
+from paddle.fluid.dygraph.parallel import apply_collective_grads
 from ..fluid import framework
 from ..fluid import layers
@@ -676,8 +676,14 @@ class Optimizer(object):
         self._dtype = loss.dtype
         if framework.in_dygraph_mode():
+            parameter_list = parameters if parameters \
+                else self._parameter_list
+
+            if paddle.distributed.get_world_size() > 1:
+                apply_collective_grads(parameter_list)
+
             params_grads = []
-            for param in self._parameter_list:
+            for param in parameter_list:
                 if not param.trainable:
                     continue
                 if param._grad_ivar() is not None:
@@ -873,10 +879,6 @@ class Optimizer(object):
         parameter_list = parameters if parameters \
             else self._parameter_list
-        if paddle.distributed.get_world_size() > 1:
-            loss = scale_loss(loss)
-            apply_collective_grads(parameter_list)
-
         params_grads = self.backward(
             loss,
             startup_program=startup_program,
...
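On the paddle.optimizer side, minimize() keeps only the backward/apply steps after the removal above. A minimal sketch of the remaining dygraph flow follows; the backward keyword names and the _apply_optimize helper are assumptions based on the surrounding context, not shown in this diff.

    # Sketch only: scale_loss/apply_collective_grads no longer appear here;
    # gradient aggregation now happens inside backward().
    def minimize(self, loss, startup_program=None, parameters=None,
                 no_grad_set=None):
        parameter_list = parameters if parameters \
            else self._parameter_list

        params_grads = self.backward(
            loss,
            startup_program=startup_program,
            parameters=parameter_list,   # assumed keyword name
            no_grad_set=no_grad_set)

        # Assumed helper: applies the collected gradients with the
        # concrete optimizer's update rule.
        optimize_ops = self._apply_optimize(
            loss, startup_program=startup_program, params_grads=params_grads)

        return optimize_ops, params_grads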