未验证 提交 59bd1e6f 编写于 作者: Y Yuang Liu 提交者: GitHub

[fleet_executor] set learning rate for each run and set op_role for fetch var (#37945)

上级 f68b175f
......@@ -1999,6 +1999,14 @@ class Executor(object):
fetch_list=fetch_list,
feed_var_name=feed_var_name,
fetch_var_name=fetch_var_name)
main_block = cached_program.block(0)
for op in main_block.ops:
# set the op_role of fetch op to Optimize to avoid
# erase the fetched vars by gc for pipeline
if op.type == 'fetch':
op._set_attr(
'op_role',
core.op_proto_and_checker_maker.OpRole.Optimize)
self._add_program_cache(cache_key, cached_program)
if cached_ctx is None:
fleet_opt = program._pipeline_opt["fleet_opt"]
......@@ -2007,6 +2015,18 @@ class Executor(object):
self._add_ctx_cache(cache_key, cached_ctx)
if feed:
self._feed_data(cached_program, feed, feed_var_name, cached_scope)
from paddle.optimizer.lr import LRScheduler
if hasattr(program, 'lr_sheduler'):
lr_sheduler = program.lr_sheduler
assert isinstance(lr_sheduler, LRScheduler), "must be LRScheduler"
lr_value = lr_sheduler()
lr_var = program.global_block().vars[lr_sheduler._var_name]
data = np.array([lr_value]).astype(convert_dtype(lr_var.dtype))
tensor = core.get_variable_tensor(cached_scope,
lr_sheduler._var_name)
tensor.set(data, self.place)
cached_ctx.run()
if fetch_list:
arr = cached_scope.find_var(fetch_var_name).get_fetch_list()
......
......@@ -47,6 +47,18 @@ class TestFleetExecutor(unittest.TestCase):
name='y', shape=y_data.shape, dtype=y_data.dtype)
z = x + y
a = 2 * x + 3 * y
loss = paddle.mean(a)
base_lr = 0.1
passes = [30, 60, 80, 90]
steps_per_pass = 10
bd = [steps_per_pass * p for p in passes]
lr = [base_lr * (0.1**i) for i in range(len(bd) + 1)]
lr_val = paddle.optimizer.lr.PiecewiseDecay(
boundaries=bd, values=lr)
opt = paddle.optimizer.AdamW(
learning_rate=lr_val,
grad_clip=fluid.clip.GradientClipByGlobalNorm(clip_norm=1.0))
opt.minimize(loss)
# TODO: section_program will be removed in the future
empty_program._pipeline_opt = {
"fleet_opt": self.fake_fleet_opt(),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册