Commit 9d26f1a3 authored by Yang Yang

callback to list of callbacks

Parent bea80b0d
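In short: `append_backward` now takes `callbacks`, a list of callables, instead of a single `callback`, and the optimizer passes `[error_clip_callback]` accordingly. A hedged before/after sketch of a call site follows; the wrapper `build_params_grads` and the extra callback `my_logging_cb` are hypothetical, and the import paths are assumptions not shown in this diff.

# Hedged sketch of the caller-facing change. Import paths are an assumption;
# only append_backward and error_clip_callback appear in this diff.
from paddle.v2.fluid.backward import append_backward
from paddle.v2.fluid.clip import error_clip_callback


def my_logging_cb(block, context):
    # Hypothetical extra decorator; uses the same calling convention as the
    # framework: cb(block=target_block, context=grad_to_var).
    pass


def build_params_grads(loss, parameter_list=None, no_grad_set=None):
    # Before this commit (single callable):
    #   return append_backward(loss, parameter_list, no_grad_set,
    #                          callback=error_clip_callback)
    # After this commit (list of callables, each run on every appended grad op):
    return append_backward(loss, parameter_list, no_grad_set,
                           callbacks=[error_clip_callback, my_logging_cb])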
@@ -120,12 +120,11 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
   for (auto& op_desc : block.AllOps()) {
     auto op = paddle::framework::OpRegistry::CreateOp(*op_desc);
-    // VLOG(3) << op->DebugStringEx(local_scope);
+    VLOG(3) << place_ << " " << op->DebugStringEx(local_scope);
     platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
     platform::RecordEvent record_event(op->Type(), pool.Get(place_));
-    VLOG(3) << op->Type();
     op->Run(*local_scope, place_);
     if (FLAGS_benchmark) {
       VLOG(2) << "Memory used after operator " + op->Type() + " running: "
...
@@ -234,7 +234,6 @@ def _callback_lookup_(op):
                         "ncclInit",
                         {"parallel_scopes": self.parallel_scopes_name},
                         {"Communicator": ['nccl_com__do_not_change_']}, {})
-                    print(serialize_op_decs(op_desc))
                     block.program.global_block().desc.append_op().copy_from(
                         op_desc)
                     self.has_inserted_nccl_init = True
@@ -285,9 +284,7 @@ def _append_backward_ops_(block,
             val(str): corresponding forward variable name
         callback(callable object): a callable object used to decorate new generated grad ops
     """
-    if callbacks is None:
-        callbacks = []
-    else:
+    if callbacks is not None:
         assert (isinstance(callbacks, list))
         for cb in callbacks:
             if not hasattr(cb, '__call__'):
@@ -302,12 +299,17 @@ def _append_backward_ops_(block,
         if op.has_attr("sub_block"):
             sub_block = program.block(op.block_attr("sub_block"))
             grad_sub_block = program.create_block(parent_idx=sub_block.idx)
-            if callbacks is None:
-                callbacks = [_callback_lookup_(op)]
+            cb = _callback_lookup_(op)
+            if cb is not None:
+                if callbacks is None:
+                    new_callbacks = [cb]
+                else:
+                    new_callbacks = callbacks + [_callback_lookup_(op)]
+                _append_backward_ops_(sub_block, sub_block.ops, grad_sub_block,
+                                      no_grad_dict, grad_to_var, new_callbacks)
             else:
-                callbacks.append(_callback_lookup_(op))
-            _append_backward_ops_(sub_block, sub_block.ops, grad_sub_block,
-                                  no_grad_dict, grad_to_var, callbacks)
+                _append_backward_ops_(sub_block, sub_block.ops, grad_sub_block,
+                                      no_grad_dict, grad_to_var, callbacks)
             grad_sub_block_list.append(grad_sub_block.desc)

         # Getting op's corresponding grad_op
@@ -327,8 +329,10 @@ def _append_backward_ops_(block,
         new_op_desc = target_block.desc.append_op()
         new_op_desc.copy_from(op_desc)
         grad_to_var["__current_op_desc__"] = new_op_desc
-        for cb in callbacks:
-            cb(block=target_block, context=grad_to_var)
+        if callbacks is not None:
+            assert (isinstance(callbacks, list))
+            for cb in callbacks:
+                cb(block=target_block, context=grad_to_var)


 def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
@@ -408,7 +412,8 @@ def _get_stop_gradients_(program):
     return no_grad_dict


-def append_backward(loss, parameter_list=None, no_grad_set=None, callback=None):
+def append_backward(loss, parameter_list=None, no_grad_set=None,
+                    callbacks=None):
     """
     Append backward part to main_program
@@ -424,6 +429,8 @@ def append_backward(loss, parameter_list=None, no_grad_set=None, callback=None):
         (list[(Variable,Variable)]): list of (parameter, gradient) pair.
     """
     assert isinstance(loss, framework.Variable)

+    if callbacks is not None:
+        isinstance(callbacks, list)
     program = loss.block.program
     if no_grad_set is None:
@@ -451,7 +458,7 @@ def append_backward(loss, parameter_list=None, no_grad_set=None, callback=None):
     no_grad_dict[0].update(map(_append_grad_suffix_, block_no_grad_set))
     _append_backward_ops_(root_block, op_path, root_block, no_grad_dict,
-                          grad_to_var, callback)
+                          grad_to_var, callbacks)

     # Because calc_gradient may be called multiple times,
     # we need rename the internal gradient variables so that they have
...
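The hunk at @@ -327,8 +329,10 @@ above also pins down the callback contract: after each grad op desc is appended, every element of `callbacks` is invoked as `cb(block=target_block, context=grad_to_var)`, with the fresh op desc stored under `context["__current_op_desc__"]`. A self-contained toy sketch of that loop follows; the `Fake*` classes, `append_grad_ops`, and `log_grad_op` are stand-ins invented here, not framework code.

# Toy illustration of the callbacks contract; FakeOpDesc/FakeBlock are
# hypothetical stand-ins so the sketch runs without the framework.
class FakeOpDesc(object):
    def __init__(self, op_type):
        self._type = op_type

    def type(self):
        return self._type


class FakeBlock(object):
    idx = 0


def append_grad_ops(grad_op_descs, target_block, grad_to_var, callbacks=None):
    # Mirrors the post-commit loop in _append_backward_ops_: append the op
    # desc (elided here), record it, then let every callback decorate it.
    for op_desc in grad_op_descs:
        grad_to_var["__current_op_desc__"] = op_desc
        if callbacks is not None:
            assert isinstance(callbacks, list)
            for cb in callbacks:
                cb(block=target_block, context=grad_to_var)


def log_grad_op(block, context):
    # A minimal decorator: just report what was appended and where.
    print("block %d appended %s" % (block.idx,
                                    context["__current_op_desc__"].type()))


append_grad_ops([FakeOpDesc("mul_grad"), FakeOpDesc("mean_grad")],
                FakeBlock(), {}, callbacks=[log_grad_op])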
@@ -225,7 +225,7 @@ class Optimizer(object):
         `create_optimization_pass()` into one.
         """
         params_grads = append_backward(loss, parameter_list, no_grad_set,
-                                       error_clip_callback)
+                                       [error_clip_callback])

         params_grads = append_gradient_clip_ops(params_grads)
...