提交 a29cb4be 编写于 作者: Q Qiyang Min 提交者: gongweibao

Fix decay bug (#11520)

* Add sub_blocks of lr_decay_op to pserver_prog after distribute_transpiler

* Remove unused logs and logics

* 1. Add ops to new block (considering the nested block condition)
2. Follow the original hierarchy of blocks
3. Change the function's name and remove debug lines
上级 e8f5757d
......@@ -295,13 +295,14 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
const ProgramDesc& program, int block_id) {
auto* ctx = new ExecutorPrepareContext(program, block_id);
std::unique_ptr<ExecutorPrepareContext> ctx(
new ExecutorPrepareContext(program, block_id));
PADDLE_ENFORCE_LT(static_cast<size_t>(block_id), program.Size());
auto& block = program.Block(block_id);
for (auto& op_desc : block.AllOps()) {
ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc));
}
return std::unique_ptr<ExecutorPrepareContext>(ctx);
return ctx;
}
std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare(
......
......@@ -644,6 +644,12 @@ class Operator(object):
def set_attr(self, name, val):
self.attrs[name] = val
if isinstance(val, Block):
self.desc.set_block_attr(name, val.desc)
elif isinstance(val, core.BlockDesc) or \
isinstance(val, core.ProgramDesc):
self.desc.set_serialized_attr(name, val.serialize_to_string())
else:
self.desc.set_attr(name, val)
@property
......
......@@ -44,7 +44,7 @@ import numpy as np
from ps_dispatcher import RoundRobin, HashName, PSDispatcher
from .. import core, framework
from ..framework import Program, default_main_program, \
default_startup_program, \
default_startup_program, Block, \
Variable, Parameter, grad_var_name
from details import *
......@@ -471,7 +471,7 @@ class DistributeTranspiler:
self._append_pserver_ops(block, op, endpoint, grad_to_block_id,
self.origin_program, merged_var)
else:
self._append_pserver_non_opt_ops(block, op, endpoint)
self._append_pserver_non_opt_ops(block, op)
def __op_have_grad_input__(op):
for varname in op.input_arg_names:
......@@ -479,13 +479,39 @@ class DistributeTranspiler:
return varname
return ""
def __clone_lr_op_sub_block__(op, program, new_block):
if not op.has_attr('sub_block'):
return
origin_block_desc = op.attr('sub_block')
origin_block = self.origin_program.block(origin_block_desc.id)
assert isinstance(origin_block, Block)
# we put the new sub block to new block to follow the block
# hierarchy of the original blocks
new_sub_block = program.create_block(new_block.idx)
# clone vars
for var in origin_block.vars:
new_sub_block.clone_variable(var)
# clone ops
for op in origin_block.ops:
self._clone_lr_op(program, new_sub_block, op)
# clone sub_block of op
__clone_lr_op_sub_block__(op, program, new_sub_block)
# reset the block of op
op.set_attr('sub_block', new_sub_block)
# append lr decay ops to the child block if exists
lr_ops = self._get_lr_ops()
if len(lr_ops) > 0:
lr_decay_block = pserver_program.create_block(
pserver_program.num_blocks - 1)
for _, op in enumerate(lr_ops):
self._append_pserver_non_opt_ops(lr_decay_block, op, endpoint)
self._append_pserver_non_opt_ops(lr_decay_block, op)
# append sub blocks to pserver_program in lr_decay_op
__clone_lr_op_sub_block__(op, pserver_program, lr_decay_block)
# append op to the current block
grad_to_block_id = []
......@@ -1116,7 +1142,29 @@ class DistributeTranspiler:
break
return grad_block
def _append_pserver_non_opt_ops(self, optimize_block, opt_op, endpoint):
def _clone_lr_op(self, program, block, op):
inputs = self._get_input_map_from_op(
self.origin_program.global_block().vars, op)
for key, varlist in inputs.iteritems():
if not isinstance(varlist, list):
varlist = [varlist]
for var in varlist:
if var not in program.global_block().vars:
block.clone_variable(var)
outputs = self._get_output_map_from_op(
self.origin_program.global_block().vars, op)
for key, varlist in outputs.iteritems():
if not isinstance(varlist, list):
varlist = [varlist]
for var in varlist:
if var not in program.global_block().vars:
block.clone_variable(var)
block.append_op(
type=op.type, inputs=inputs, outputs=outputs, attrs=op.attrs)
def _append_pserver_non_opt_ops(self, optimize_block, opt_op):
program = optimize_block.program
# Append the ops for parameters that do not need to be optimized/updated
inputs = self._get_input_map_from_op(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册