未验证 提交 55098b97 编写于 作者: M MRXLT 提交者: GitHub

fleet support paddle.optimzier (#28026)

fleet support paddle.optimzier

* bug fix

* fix fleet_base

* bug fix

* fix coverage
上级 5bb348a1
...@@ -1084,17 +1084,11 @@ class Fleet(object): ...@@ -1084,17 +1084,11 @@ class Fleet(object):
loss_name=loss.name, share_vars_from=None) loss_name=loss.name, share_vars_from=None)
loss.block.program._graph = compiled_program loss.block.program._graph = compiled_program
return self.user_defined_optimizer.minimize( return self.user_defined_optimizer.minimize(
loss, loss, startup_program, parameter_list, no_grad_set=no_grad_set)
startup_program=startup_program,
parameter_list=parameter_list,
no_grad_set=no_grad_set)
if meta_optimizer: if meta_optimizer:
optimize_ops, params_grads = meta_optimizer.minimize( optimize_ops, params_grads = meta_optimizer.minimize(
loss, loss, startup_program, parameter_list, no_grad_set=no_grad_set)
startup_program=startup_program,
parameter_list=parameter_list,
no_grad_set=no_grad_set)
default_program = paddle.static.default_main_program() default_program = paddle.static.default_main_program()
...@@ -1103,20 +1097,14 @@ class Fleet(object): ...@@ -1103,20 +1097,14 @@ class Fleet(object):
else: else:
optimize_ops, params_grads = self.user_defined_optimizer.minimize( optimize_ops, params_grads = self.user_defined_optimizer.minimize(
loss, loss, startup_program, parameter_list, no_grad_set=no_grad_set)
startup_program=startup_program,
parameter_list=parameter_list,
no_grad_set=no_grad_set)
context["program_optimize_ops"] = optimize_ops context["program_optimize_ops"] = optimize_ops
context["program_params_grads"] = params_grads context["program_params_grads"] = params_grads
if graph_optimizer: if graph_optimizer:
optimize_ops, params_grads = graph_optimizer.minimize( optimize_ops, params_grads = graph_optimizer.minimize(
loss, loss, startup_program, parameter_list, no_grad_set=no_grad_set)
startup_program=startup_program,
parameter_list=parameter_list,
no_grad_set=no_grad_set)
# since we do not encourage users to use graph operations # since we do not encourage users to use graph operations
# if a graph optimizer takes effect, mostly # if a graph optimizer takes effect, mostly
# optimizers_ops and params_grads are None # optimizers_ops and params_grads are None
......
...@@ -19,6 +19,7 @@ import abc ...@@ -19,6 +19,7 @@ import abc
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.executor import Executor from paddle.fluid.executor import Executor
from paddle.fluid.optimizer import SGD from paddle.fluid.optimizer import SGD
from paddle.optimizer import SGD as SGD_v2
from paddle.fluid.incubate.fleet.base.mode import Mode from paddle.fluid.incubate.fleet.base.mode import Mode
from paddle.distributed.fleet.base.role_maker import RoleMakerBase from paddle.distributed.fleet.base.role_maker import RoleMakerBase
...@@ -291,7 +292,8 @@ class DistributedOptimizer(object): ...@@ -291,7 +292,8 @@ class DistributedOptimizer(object):
def __init__(self, optimizer, strategy=None): def __init__(self, optimizer, strategy=None):
if not isinstance(optimizer, SGD.__bases__) \ if not isinstance(optimizer, SGD.__bases__) \
and not isinstance(optimizer, OptimizerWithMixedPrecision): and not isinstance(optimizer, OptimizerWithMixedPrecision) \
and not isinstance(optimizer, SGD_v2.__base__):
raise TypeError("optimizer must be an instance of Optimizer") raise TypeError("optimizer must be an instance of Optimizer")
self._optimizer = optimizer self._optimizer = optimizer
......
...@@ -28,6 +28,8 @@ from paddle.fluid.incubate.fleet.base.fleet_base import DistributedOptimizer ...@@ -28,6 +28,8 @@ from paddle.fluid.incubate.fleet.base.fleet_base import DistributedOptimizer
from paddle.fluid import compiler from paddle.fluid import compiler
from paddle.fluid.incubate.checkpoint.checkpoint_saver import PaddleModel, CheckpointSaver from paddle.fluid.incubate.checkpoint.checkpoint_saver import PaddleModel, CheckpointSaver
import paddle
import os import os
import sys import sys
import six import six
...@@ -505,10 +507,7 @@ class CollectiveOptimizer(DistributedOptimizer): ...@@ -505,10 +507,7 @@ class CollectiveOptimizer(DistributedOptimizer):
self._strategy) self._strategy)
optimize_ops, param_grads = self._optimizer.minimize( optimize_ops, param_grads = self._optimizer.minimize(
loss, loss, startup_program, parameter_list, no_grad_set=no_grad_set)
startup_program=startup_program,
parameter_list=parameter_list,
no_grad_set=no_grad_set)
fleet._origin_program = main_program.clone(for_test=False) fleet._origin_program = main_program.clone(for_test=False)
fleet._transpiled_program = main_program fleet._transpiled_program = main_program
......
...@@ -60,7 +60,7 @@ class TestFleetGraphExecutionMetaOptimizer(unittest.TestCase): ...@@ -60,7 +60,7 @@ class TestFleetGraphExecutionMetaOptimizer(unittest.TestCase):
strategy = paddle.distributed.fleet.DistributedStrategy() strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.nccl_comm_num = 2 strategy.nccl_comm_num = 2
strategy.sync_nccl_allreduce = True strategy.sync_nccl_allreduce = True
optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.01) optimizer = paddle.optimizer.SGD(learning_rate=0.01)
optimizer = fleet.distributed_optimizer( optimizer = fleet.distributed_optimizer(
optimizer, strategy=strategy) optimizer, strategy=strategy)
optimizer.minimize(avg_cost) optimizer.minimize(avg_cost)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册