Unverified commit 55098b97, authored by MRXLT, committed by GitHub

fleet support paddle.optimizer (#28026)

fleet support paddle.optimizer

* bug fix

* fix fleet_base

* bug fix

* fix coverage
Parent 5bb348a1
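What this change enables: fleet.distributed_optimizer() now accepts optimizers from the new paddle.optimizer namespace in addition to the legacy paddle.fluid.optimizer classes. A minimal usage sketch, assuming a collective job launched via the distributed launcher (the tiny network here is a placeholder, not from this PR):

import paddle
import paddle.distributed.fleet as fleet

paddle.enable_static()
fleet.init(is_collective=True)

# Placeholder network; any static-graph model producing a loss works here.
x = paddle.static.data(name="x", shape=[None, 13], dtype="float32")
y = paddle.static.data(name="y", shape=[None, 1], dtype="float32")
pred = paddle.static.nn.fc(x, size=1)
avg_cost = paddle.mean(paddle.nn.functional.square_error_cost(pred, y))

strategy = fleet.DistributedStrategy()
optimizer = paddle.optimizer.SGD(learning_rate=0.01)  # new-style optimizer
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
optimizer.minimize(avg_cost)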
@@ -1084,17 +1084,11 @@ class Fleet(object):
                     loss_name=loss.name, share_vars_from=None)
             loss.block.program._graph = compiled_program
             return self.user_defined_optimizer.minimize(
-                loss,
-                startup_program=startup_program,
-                parameter_list=parameter_list,
-                no_grad_set=no_grad_set)
+                loss, startup_program, parameter_list, no_grad_set=no_grad_set)
 
         if meta_optimizer:
             optimize_ops, params_grads = meta_optimizer.minimize(
-                loss,
-                startup_program=startup_program,
-                parameter_list=parameter_list,
-                no_grad_set=no_grad_set)
+                loss, startup_program, parameter_list, no_grad_set=no_grad_set)
 
             default_program = paddle.static.default_main_program()
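Note on the edit pattern above (it repeats in the hunks below): the legacy paddle.fluid.optimizer.Optimizer.minimize names its third parameter parameter_list, while the new paddle.optimizer.Optimizer.minimize names it parameters, so a keyword call can match only one generation. Passing the first three arguments positionally works for both. A sketch of the idea (call_minimize is a hypothetical helper, not part of this PR):

def call_minimize(opt, loss, startup_program=None, parameter_list=None,
                  no_grad_set=None):
    # Passed positionally, the third argument lands in the right slot
    # whether the wrapped optimizer names it `parameter_list` (fluid) or
    # `parameters` (paddle 2.x); no_grad_set is spelled the same in both,
    # so it stays a keyword.
    return opt.minimize(loss, startup_program, parameter_list,
                        no_grad_set=no_grad_set)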
@@ -1103,20 +1097,14 @@ class Fleet(object):
         else:
             optimize_ops, params_grads = self.user_defined_optimizer.minimize(
-                loss,
-                startup_program=startup_program,
-                parameter_list=parameter_list,
-                no_grad_set=no_grad_set)
+                loss, startup_program, parameter_list, no_grad_set=no_grad_set)
 
         context["program_optimize_ops"] = optimize_ops
         context["program_params_grads"] = params_grads
 
         if graph_optimizer:
             optimize_ops, params_grads = graph_optimizer.minimize(
-                loss,
-                startup_program=startup_program,
-                parameter_list=parameter_list,
-                no_grad_set=no_grad_set)
+                loss, startup_program, parameter_list, no_grad_set=no_grad_set)
             # since we do not encourage users to use graph operations
             # if a graph optimizer takes effect, mostly
             # optimizers_ops and params_grads are None
......
@@ -19,6 +19,7 @@ import abc
 import paddle.fluid as fluid
 from paddle.fluid.executor import Executor
 from paddle.fluid.optimizer import SGD
+from paddle.optimizer import SGD as SGD_v2
 
 from paddle.fluid.incubate.fleet.base.mode import Mode
 from paddle.distributed.fleet.base.role_maker import RoleMakerBase
@@ -291,7 +292,8 @@ class DistributedOptimizer(object):
     def __init__(self, optimizer, strategy=None):
         if not isinstance(optimizer, SGD.__bases__) \
-                and not isinstance(optimizer, OptimizerWithMixedPrecision):
+                and not isinstance(optimizer, OptimizerWithMixedPrecision) \
+                and not isinstance(optimizer, SGD_v2.__base__):
             raise TypeError("optimizer must be an instance of Optimizer")
 
         self._optimizer = optimizer
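Why both spellings appear in this check: SGD.__bases__ is the tuple of base classes of the legacy fluid SGD (isinstance accepts a tuple of classes), while SGD_v2.__base__ is the single direct base class of the new paddle.optimizer.SGD, so together the checks admit either optimizer generation plus the mixed-precision wrapper. A small illustration (the expected True/False results are assumptions about the two class hierarchies):

import paddle
from paddle.fluid.optimizer import SGD
from paddle.optimizer import SGD as SGD_v2

# New-style optimizers require a parameter list in dygraph mode.
w = paddle.to_tensor([0.0], stop_gradient=False)
opt = SGD_v2(learning_rate=0.01, parameters=[w])

print(isinstance(opt, SGD.__bases__))    # False: legacy fluid hierarchy
print(isinstance(opt, SGD_v2.__base__))  # True: new paddle.optimizer base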
......
@@ -28,6 +28,8 @@ from paddle.fluid.incubate.fleet.base.fleet_base import DistributedOptimizer
 from paddle.fluid import compiler
 from paddle.fluid.incubate.checkpoint.checkpoint_saver import PaddleModel, CheckpointSaver
 import paddle
+import os
+import sys
 import six
@@ -505,10 +507,7 @@ class CollectiveOptimizer(DistributedOptimizer):
             self._strategy)
 
         optimize_ops, param_grads = self._optimizer.minimize(
-            loss,
-            startup_program=startup_program,
-            parameter_list=parameter_list,
-            no_grad_set=no_grad_set)
+            loss, startup_program, parameter_list, no_grad_set=no_grad_set)
 
         fleet._origin_program = main_program.clone(for_test=False)
         fleet._transpiled_program = main_program
......
@@ -60,7 +60,7 @@ class TestFleetGraphExecutionMetaOptimizer(unittest.TestCase):
             strategy = paddle.distributed.fleet.DistributedStrategy()
             strategy.nccl_comm_num = 2
             strategy.sync_nccl_allreduce = True
-            optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.01)
+            optimizer = paddle.optimizer.SGD(learning_rate=0.01)
             optimizer = fleet.distributed_optimizer(
                 optimizer, strategy=strategy)
             optimizer.minimize(avg_cost)
......