Unverified    Commit 09482dde authored by Chengmo, committed by GitHub

【Paddle.Fleet】Fix one ps gradient clip (#31664)

* fix one ps gradient clip

Parent 740359ed
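In short: the one-ps code path and its unit tests move off the deprecated fluid.clip.set_gradient_clip API onto the Paddle 2.0 paddle.nn.ClipGrad* classes, which are handed to the optimizer through its grad_clip argument; the optimizer-op lookup is also guarded, and the clip op name scope constant is updated. A minimal sketch of the 2.0-style wiring the tests now use (illustration only, not part of this diff; assumes a Paddle 2.x install in dygraph mode):

import paddle

# The clip object is passed to the optimizer instead of being set globally
# through the old fluid.clip.set_gradient_clip(...) call.
linear = paddle.nn.Linear(10, 10)
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=2.0)
sgd = paddle.optimizer.SGD(learning_rate=0.1,
                           parameters=linear.parameters(),
                           grad_clip=clip)

loss = paddle.mean(linear(paddle.rand([4, 10])))
loss.backward()
sgd.step()        # gradients are clipped by global norm before the update
sgd.clear_grad()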
@@ -150,7 +150,8 @@ class CommonAccessor:
         oop = None
         for op in optimizer_ops:
-            if op.input("Param")[0] == param_name:
+            if ("Param" in op.input_names) and (
+                    op.input("Param")[0] == param_name):
                 oop = op
                 break
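The added guard matters because optimizer_ops can contain ops that have no Param input (for example the helper ops that gradient clipping inserts), and calling op.input("Param") on those would fail. A self-contained illustration of the pattern, with a hypothetical OpStub standing in for a framework op (not Paddle code):

class OpStub:
    def __init__(self, input_names, inputs):
        self.input_names = input_names
        self._inputs = inputs

    def input(self, name):
        return self._inputs[name]

optimizer_ops = [
    OpStub(["X"], {"X": ["global_norm_0"]}),  # clip helper op, no Param input
    OpStub(["Param", "Grad"],
           {"Param": ["fc_0.w_0"], "Grad": ["fc_0.w_0@GRAD"]}),  # sgd op
]

param_name = "fc_0.w_0"
oop = None
for op in optimizer_ops:
    # Checking input_names first short-circuits before op.input("Param")
    # can raise on an op that has no such input slot.
    if ("Param" in op.input_names) and (
            op.input("Param")[0] == param_name):
        oop = op
        break
assert oop is optimizer_ops[1]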
@@ -31,7 +31,7 @@ from paddle.fluid.incubate.fleet.parameter_server.ir.ps_dispatcher import RoundR
 from paddle.fluid.transpiler.details.program_utils import delete_ops
 
 OP_NAME_SCOPE = "op_namescope"
-CLIP_OP_NAME_SCOPE = "@CLIP"
+CLIP_OP_NAME_SCOPE = "gradient_clip"
 STEP_COUNTER = "@PS_STEP_COUNTER@"
 LEARNING_RATE_DECAY_COUNTER = "@LR_DECAY_COUNTER@"
@@ -32,7 +32,7 @@ from paddle.fluid.incubate.fleet.parameter_server.ir.public import get_sparse_ta
 from paddle.fluid.incubate.fleet.parameter_server.mode import DistributedMode
 
 OP_NAME_SCOPE = "op_namescope"
-CLIP_OP_NAME_SCOPE = "@CLIP"
+CLIP_OP_NAME_SCOPE = "gradient_clip"
 STEP_COUNTER = "@PS_STEP_COUNTER@"
 OP_ROLE_VAR_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleVarAttrName()
 RPC_OP_ROLE_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleAttrName()
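Both files change the same constant: ops created by the 2.0 clip classes are tagged with an op_namescope containing gradient_clip, whereas the old set_gradient_clip path used @CLIP, so the parameter-server IR code that filters clip ops by this scope must match the new tag. A standalone sketch of how such a constant is typically used (illustrative, not Paddle source):

OP_NAME_SCOPE = "op_namescope"
CLIP_OP_NAME_SCOPE = "gradient_clip"

def is_clip_op(attrs):
    # attrs is a plain dict standing in for an op's attribute map
    return CLIP_OP_NAME_SCOPE in attrs.get(OP_NAME_SCOPE, "")

print(is_clip_op({"op_namescope": "/gradient_clip_@CLIP"}))  # True
print(is_clip_op({"op_namescope": "/optimizer"}))            # False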
@@ -18,6 +18,7 @@ from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distribu
 import paddle.distributed.fleet as fleet
 import paddle.distributed.fleet.base.role_maker as role_maker
 import paddle.fluid as fluid
+import paddle
 
 """
 high level unit test for distribute fleet.
 """
@@ -112,23 +113,21 @@ class FleetDistRunnerBase(object):
     def build_optimizer(self, avg_cost, strategy):
         use_grad_clip = int(os.getenv('GRAD_CLIP', 0))
+        grad_clip = None
         if use_grad_clip:
             # 1: clip_by_value; 2: clip_by_norm; 3:clip_by_global_norm
             if use_grad_clip == 1:
-                fluid.clip.set_gradient_clip(
-                    clip=fluid.clip.GradientClipByValue(2.0))
+                grad_clip = paddle.nn.ClipGradByValue(min=-5.0, max=5.0)
             elif use_grad_clip == 2:
-                fluid.clip.set_gradient_clip(
-                    clip=fluid.clip.GradientClipByNorm(2.0))
+                grad_clip = paddle.nn.ClipGradByNorm(2.0)
             elif use_grad_clip == 3:
-                fluid.clip.set_gradient_clip(
-                    clip=fluid.clip.GradientClipByGlobalNorm(2.0))
+                grad_clip = paddle.nn.ClipGradByGlobalNorm(2.0)
 
         use_decay = int(os.getenv("USE_DECAY", "0"))
         if use_decay:
             scheduler = paddle.optimizer.lr.ExponentialDecay(
                 learning_rate=LEARNING_RATE, gamma=0.999, verbose=True)
-            optimizer = fluid.optimizer.SGD(scheduler)
+            optimizer = fluid.optimizer.SGD(scheduler, grad_clip=grad_clip)
             """
             # learning rate decay method before 2.0
             optimizer = fluid.optimizer.SGD(
@@ -139,7 +138,7 @@ class FleetDistRunnerBase(object):
                     staircase=True))
             """
         else:
-            optimizer = fluid.optimizer.SGD(LEARNING_RATE)
+            optimizer = fluid.optimizer.SGD(LEARNING_RATE, grad_clip=grad_clip)
         optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
         optimizer.minimize(avg_cost)
......
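Condensed for reference, the selection logic that build_optimizer now performs, keyed off the same GRAD_CLIP environment variable (a runnable 2.x sketch, not part of the diff; paddle.optimizer.SGD stands in for the fluid.optimizer.SGD call used above):

import os
import paddle

use_grad_clip = int(os.getenv("GRAD_CLIP", "0"))
grad_clip = None
if use_grad_clip == 1:
    grad_clip = paddle.nn.ClipGradByValue(min=-5.0, max=5.0)
elif use_grad_clip == 2:
    grad_clip = paddle.nn.ClipGradByNorm(2.0)
elif use_grad_clip == 3:
    grad_clip = paddle.nn.ClipGradByGlobalNorm(2.0)

layer = paddle.nn.Linear(8, 1)
opt = paddle.optimizer.SGD(learning_rate=0.01,
                           parameters=layer.parameters(),
                           grad_clip=grad_clip)  # grad_clip=None disables clipping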
@@ -16,53 +16,66 @@ from __future__ import print_function
 
 import os
 import unittest
-import paddle.fluid as fluid
-import paddle.fluid.incubate.fleet.base.role_maker as role_maker
-from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet
-from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig
 from test_dist_fleet_base import TestFleetBase
 from dist_fleet_simnet_bow import train_network
 
 
-@unittest.skip(reason="Skip unstable ut, add it after PR 22957 merged")
-class TestDistGeoClipByGlobalNormTranspiler(unittest.TestCase):
-    def test_pserver(self):
-        role = role_maker.UserDefinedRoleMaker(
-            current_id=0,
-            role=role_maker.Role.SERVER,
-            worker_num=2,
-            server_endpoints=["127.0.0.1:36011", "127.0.0.1:36012"])
-
-        fleet.init(role)
-
-        batch_size = 128
-        is_sparse = True
-        is_distribute = False
-
-        strategy = DistributeTranspilerConfig()
-        strategy.sync_mode = False
-        strategy.geo_sgd_mode = True
-        strategy.geo_sgd_need_push_nums = 5
-
-        avg_cost, _, _, _ = train_network(batch_size, is_distribute, is_sparse)
-        fluid.clip.set_gradient_clip(
-            clip=fluid.clip.GradientClipByGlobalNorm(2.0))
-
-        optimizer = fluid.optimizer.SGD(0.1)
-        optimizer = fleet.distributed_optimizer(optimizer, strategy)
-        optimizer.minimize(avg_cost)
-
-        pserver_startup_program = fleet.startup_program
-        pserver_mian_program = fleet.main_program
+class TestDistGeoClipByGlobalNorm(TestFleetBase):
+    def _setup_config(self):
+        self._mode = "geo"
+        self._reader = "dataset"
+        self._geo_sgd_need_push_nums = 5
+        self._grad_clip_mode = 3
+
+    def check_with_place(self,
+                         model_file,
+                         delta=1e-3,
+                         check_error_log=False,
+                         need_envs={}):
+        required_envs = {
+            "PATH": os.getenv("PATH", ""),
+            "PYTHONPATH": os.getenv("PYTHONPATH", ""),
+            "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
+            "FLAGS_rpc_deadline": "5000",  # 5sec to fail fast
+            "http_proxy": ""
+        }
+        required_envs.update(need_envs)
+        tr0_losses, tr1_losses = self._run_cluster(model_file, required_envs)
+
+    def test_dist_train(self):
+        self.check_with_place(
+            "dist_fleet_ctr.py", delta=1e-5, check_error_log=True)
 
 
-@unittest.skip(reason="Skip unstable ut, add it after PR 22957 merged")
-class TestDistGeoClipByGlobalNorm(TestFleetBase):
+class TestDistASyncClipByValue(TestFleetBase):
     def _setup_config(self):
-        self._mode = "geo"
-        self._geo_sgd_need_push_nums = 5
-        self._grad_clip_mode = 3
+        self._mode = "async"
+        self._reader = "dataset"
+        self._grad_clip_mode = 1
 
     def check_with_place(self,
                          model_file,
                          delta=1e-3,
                          check_error_log=False,
                          need_envs={}):
         required_envs = {
             "PATH": os.getenv("PATH", ""),
             "PYTHONPATH": os.getenv("PYTHONPATH", ""),
             "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
             "FLAGS_rpc_deadline": "5000",  # 5sec to fail fast
             "http_proxy": ""
         }
         required_envs.update(need_envs)
         tr0_losses, tr1_losses = self._run_cluster(model_file, required_envs)
 
     def test_dist_train(self):
         self.check_with_place(
             "dist_fleet_ctr.py", delta=1e-5, check_error_log=True)
@@ -84,8 +97,11 @@ class TestDistGeoClipByGlobalNorm(TestFleetBase):
         self.check_with_place(
             "dist_fleet_ctr.py", delta=1e-5, check_error_log=True)
 
 
+class TestDistASyncClipByNorm(TestFleetBase):
     def _setup_config(self):
-        self._sync_mode = False
+        self._mode = "async"
+        self._reader = "dataset"
         self._grad_clip_mode = 2
 
     def check_with_place(self,
@@ -109,7 +125,6 @@ class TestDistGeoClipByGlobalNorm(TestFleetBase):
             "dist_fleet_ctr.py", delta=1e-5, check_error_log=True)
 
 
-@unittest.skip(reason="Skip unstable ut, add it after PR 22957 merged")
 class TestDistASyncClipByGlobalNorm(TestFleetBase):
     def _setup_config(self):
         self._mode = "async"
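The _grad_clip_mode values in these test classes line up with the GRAD_CLIP switch in build_optimizer above (1: clip by value, 2: clip by norm, 3: clip by global norm); the fleet test base passes the mode to the trainer processes through the environment. A hedged sketch of that hand-off (run_trainer is illustrative, not the actual TestFleetBase internals):

import os
import subprocess
import sys

def run_trainer(script, grad_clip_mode):
    # Copy the current environment and set the switch that
    # build_optimizer reads back via os.getenv('GRAD_CLIP', 0).
    env = os.environ.copy()
    env["GRAD_CLIP"] = str(grad_clip_mode)
    return subprocess.run([sys.executable, script], env=env, check=False)

# run_trainer("dist_fleet_ctr.py", 3)  # would exercise clip-by-global-norm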