Commit a29cb4be authored by Qiyang Min, committed by gongweibao

Fix decay bug (#11520)

* Add sub_blocks of lr_decay_op to pserver_prog after distribute_transpiler

* Remove unused logs and logic

* 1. Add ops to the new block, handling the nested block condition (sketched below)
2. Follow the original hierarchy of blocks
3. Rename the function and remove debug lines
Parent e8f5757d
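The heart of the fix is that an lr decay op can carry a nested sub_block (for example a conditional block), so copying the op into the pserver program must also rebuild that block hierarchy. Below is a minimal, self-contained sketch of the recursion, using toy ToyBlock/ToyProgram classes (hypothetical names, not the Paddle API), that mirrors what __clone_lr_op_sub_block__ does in the diff further down:

# Toy sketch (hypothetical classes, not the Paddle API): recursively copy an
# op and its nested sub-blocks into a destination program while preserving
# the parent/child block hierarchy, as the diff does for lr decay ops.

class ToyBlock(object):
    def __init__(self, idx, parent_idx=-1):
        self.idx = idx
        self.parent_idx = parent_idx
        self.ops = []  # each op: {"type": str, "sub_block": ToyBlock or None}

class ToyProgram(object):
    def __init__(self):
        self.blocks = [ToyBlock(0)]

    def create_block(self, parent_idx):
        # New blocks are appended and remember their parent, mirroring how
        # the transpiler calls program.create_block(new_block.idx).
        block = ToyBlock(len(self.blocks), parent_idx)
        self.blocks.append(block)
        return block

def clone_op_with_sub_blocks(op, dst_program, dst_block):
    """Append a copy of `op` to `dst_block`; if the op owns a sub-block,
    clone it (and any blocks nested inside it) under `dst_block`."""
    new_op = {"type": op["type"], "sub_block": None}
    dst_block.ops.append(new_op)
    if op["sub_block"] is None:
        return
    # Create the child under dst_block so the source hierarchy is kept.
    new_sub = dst_program.create_block(dst_block.idx)
    new_op["sub_block"] = new_sub
    for inner in op["sub_block"].ops:
        clone_op_with_sub_blocks(inner, dst_program, new_sub)

# Usage: an lr-decay-style op whose sub-block holds another op.
src, dst = ToyProgram(), ToyProgram()
inner = src.create_block(0)
inner.ops.append({"type": "increment", "sub_block": None})
src.blocks[0].ops.append({"type": "conditional_block", "sub_block": inner})

clone_op_with_sub_blocks(src.blocks[0].ops[0], dst, dst.blocks[0])
assert dst.blocks[1].parent_idx == 0  # hierarchy preserved in the copy

Creating each cloned sub-block with its parent's index is what keeps the copied program's hierarchy identical to the original's, which is point 2 of the commit message.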
@@ -295,13 +295,14 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
 
 std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
     const ProgramDesc& program, int block_id) {
-  auto* ctx = new ExecutorPrepareContext(program, block_id);
+  std::unique_ptr<ExecutorPrepareContext> ctx(
+      new ExecutorPrepareContext(program, block_id));
   PADDLE_ENFORCE_LT(static_cast<size_t>(block_id), program.Size());
   auto& block = program.Block(block_id);
   for (auto& op_desc : block.AllOps()) {
     ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc));
   }
-  return std::unique_ptr<ExecutorPrepareContext>(ctx);
+  return ctx;
 }
 
 std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare(
......
@@ -644,7 +644,13 @@ class Operator(object):
 
     def set_attr(self, name, val):
         self.attrs[name] = val
-        self.desc.set_attr(name, val)
+        if isinstance(val, Block):
+            self.desc.set_block_attr(name, val.desc)
+        elif isinstance(val, core.BlockDesc) or \
+                isinstance(val, core.ProgramDesc):
+            self.desc.set_serialized_attr(name, val.serialize_to_string())
+        else:
+            self.desc.set_attr(name, val)
 
     @property
     def attr_names(self):
......
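The Operator.set_attr change above is what later allows the transpiler to call op.set_attr('sub_block', new_sub_block) with a Block value: block-valued attributes need a dedicated setter instead of the plain scalar path. A rough stand-in sketch of that dispatch (ToyBlock/ToyDesc are illustrative names, not the Paddle desc API):

# Stand-in sketch of dispatching an attribute setter on the value's type,
# mirroring the shape of the set_attr change above.

class ToyBlock(object):
    def __init__(self, desc):
        self.desc = desc

class ToyDesc(object):
    def __init__(self):
        self.calls = []

    def set_attr(self, name, val):
        self.calls.append(("set_attr", name))

    def set_block_attr(self, name, block_desc):
        self.calls.append(("set_block_attr", name))

def set_attr(desc, name, val):
    # Same shape as Operator.set_attr in the diff: block values go to the
    # block-aware setter, everything else stays on the plain path.
    if isinstance(val, ToyBlock):
        desc.set_block_attr(name, val.desc)
    else:
        desc.set_attr(name, val)

desc = ToyDesc()
set_attr(desc, "sub_block", ToyBlock("block_desc_0"))
set_attr(desc, "scale", 0.1)
assert desc.calls == [("set_block_attr", "sub_block"), ("set_attr", "scale")]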
@@ -24,7 +24,7 @@ Steps to transpile trainer:
 1. split variable to multiple blocks, aligned by product(dim[1:]) (width).
 2. rename splited grad variables to add trainer_id suffix ".trainer_%d".
 3. modify trainer program add split_op to each grad variable.
-4. append send_op to send splited variables to server and 
+4. append send_op to send splited variables to server and
 5. add recv_op to fetch params(splited blocks or origin param) from server.
 6. append concat_op to merge splited blocks to update local weights.
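Step 1 of the transpile steps quoted above splits each variable into blocks aligned by product(dim[1:]), the row width. The following sketch illustrates only that alignment rule (it is not the transpiler's actual splitting code, which also applies a minimum block size): every chunk size is a multiple of the width, so chunks always contain whole rows.

# Sketch of width-aligned splitting (illustrative only): a variable of shape
# `dims` is flattened and cut into `num_parts` chunks whose sizes are all
# multiples of width = product(dims[1:]).

import numpy as np

def split_by_width(dims, num_parts):
    width = int(np.prod(dims[1:])) if len(dims) > 1 else 1
    rows = dims[0]
    # Distribute rows as evenly as possible, then scale by the width.
    base, extra = divmod(rows, num_parts)
    sizes = [(base + (1 if i < extra else 0)) * width for i in range(num_parts)]
    return [s for s in sizes if s > 0]

# A (10, 4) parameter split across 3 pservers: chunk sizes are multiples of 4.
print(split_by_width((10, 4), 3))   # [16, 12, 12]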
@@ -44,7 +44,7 @@ import numpy as np
 
 from ps_dispatcher import RoundRobin, HashName, PSDispatcher
 from .. import core, framework
 from ..framework import Program, default_main_program, \
-    default_startup_program, \
+    default_startup_program, Block, \
     Variable, Parameter, grad_var_name
 
 from details import *
@@ -471,7 +471,7 @@ class DistributeTranspiler:
                 self._append_pserver_ops(block, op, endpoint, grad_to_block_id,
                                          self.origin_program, merged_var)
             else:
-                self._append_pserver_non_opt_ops(block, op, endpoint)
+                self._append_pserver_non_opt_ops(block, op)
 
         def __op_have_grad_input__(op):
             for varname in op.input_arg_names:
@@ -479,13 +479,39 @@ class DistributeTranspiler:
                     return varname
             return ""
 
+        def __clone_lr_op_sub_block__(op, program, new_block):
+            if not op.has_attr('sub_block'):
+                return
+
+            origin_block_desc = op.attr('sub_block')
+            origin_block = self.origin_program.block(origin_block_desc.id)
+            assert isinstance(origin_block, Block)
+            # we put the new sub block to new block to follow the block
+            # hierarchy of the original blocks
+            new_sub_block = program.create_block(new_block.idx)
+
+            # clone vars
+            for var in origin_block.vars:
+                new_sub_block.clone_variable(var)
+
+            # clone ops
+            for op in origin_block.ops:
+                self._clone_lr_op(program, new_sub_block, op)
+                # clone sub_block of op
+                __clone_lr_op_sub_block__(op, program, new_sub_block)
+
+            # reset the block of op
+            op.set_attr('sub_block', new_sub_block)
+
         # append lr decay ops to the child block if exists
         lr_ops = self._get_lr_ops()
         if len(lr_ops) > 0:
             lr_decay_block = pserver_program.create_block(
                 pserver_program.num_blocks - 1)
             for _, op in enumerate(lr_ops):
-                self._append_pserver_non_opt_ops(lr_decay_block, op, endpoint)
+                self._append_pserver_non_opt_ops(lr_decay_block, op)
+                # append sub blocks to pserver_program in lr_decay_op
+                __clone_lr_op_sub_block__(op, pserver_program, lr_decay_block)
 
         # append op to the current block
         grad_to_block_id = []
@@ -1116,7 +1142,29 @@ class DistributeTranspiler:
                     break
         return grad_block
 
-    def _append_pserver_non_opt_ops(self, optimize_block, opt_op, endpoint):
+    def _clone_lr_op(self, program, block, op):
+        inputs = self._get_input_map_from_op(
+            self.origin_program.global_block().vars, op)
+        for key, varlist in inputs.iteritems():
+            if not isinstance(varlist, list):
+                varlist = [varlist]
+            for var in varlist:
+                if var not in program.global_block().vars:
+                    block.clone_variable(var)
+
+        outputs = self._get_output_map_from_op(
+            self.origin_program.global_block().vars, op)
+        for key, varlist in outputs.iteritems():
+            if not isinstance(varlist, list):
+                varlist = [varlist]
+            for var in varlist:
+                if var not in program.global_block().vars:
+                    block.clone_variable(var)
+
+        block.append_op(
+            type=op.type, inputs=inputs, outputs=outputs, attrs=op.attrs)
+
+    def _append_pserver_non_opt_ops(self, optimize_block, opt_op):
         program = optimize_block.program
         # Append the ops for parameters that do not need to be optimized/updated
         inputs = self._get_input_map_from_op(
......
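The new _clone_lr_op above follows a copy-before-append pattern: every input and output variable the op references is cloned into the destination program first, and only then is the op appended. A condensed sketch of that ordering with toy containers (hypothetical names, not the Paddle API):

# Toy sketch: make sure every variable an op touches exists in the destination
# block before the op itself is appended, the order _clone_lr_op enforces.

def clone_op(op, src_vars, dst_block):
    """op: {"type": str, "inputs": [names], "outputs": [names]}
    src_vars: dict name -> variable in the source program
    dst_block: object with `vars` (dict) and `ops` (list)."""
    for name in op["inputs"] + op["outputs"]:
        if name not in dst_block.vars:
            # clone_variable in the diff plays this role: copy the
            # variable's definition into the destination block.
            dst_block.vars[name] = src_vars[name]
    dst_block.ops.append(dict(op))

class ToyBlock(object):
    def __init__(self):
        self.vars = {}
        self.ops = []

src_vars = {"lr": "var(lr)", "lr@decayed": "var(lr@decayed)"}
dst = ToyBlock()
clone_op({"type": "scale", "inputs": ["lr"], "outputs": ["lr@decayed"]},
         src_vars, dst)
assert set(dst.vars) == {"lr", "lr@decayed"} and len(dst.ops) == 1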