From 3f09620ef2b1924bbeff8b9915ca2a46aed1aa5c Mon Sep 17 00:00:00 2001
From: Yang Yang
Date: Tue, 13 Feb 2018 22:09:05 +0000
Subject: [PATCH] pass compile

---
 paddle/fluid/framework/executor.cc |  2 +-
 paddle/fluid/operators/nccl_op.cc  |  2 +-
 python/paddle/v2/fluid/backward.py | 27 +++++++++++++++------------
 3 files changed, 17 insertions(+), 14 deletions(-)

diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc
index 1d7eccbc65..92b32b04d6 100644
--- a/paddle/fluid/framework/executor.cc
+++ b/paddle/fluid/framework/executor.cc
@@ -55,7 +55,7 @@ static void CreateTensor(Variable* var, proto::VarType::Type var_type) {
     var->GetMutable();
   } else if (var_type == proto::VarType::READER) {
     var->GetMutable();
-  } else if (var_type == proto::VarDesc::NCCL_COM) {
+  } else if (var_type == proto::VarType::NCCL_COM) {
     // GetMutable will be called in ncclInit
   } else {
     PADDLE_THROW(
diff --git a/paddle/fluid/operators/nccl_op.cc b/paddle/fluid/operators/nccl_op.cc
index f61b5003bd..0994bba782 100644
--- a/paddle/fluid/operators/nccl_op.cc
+++ b/paddle/fluid/operators/nccl_op.cc
@@ -65,7 +65,7 @@ class NCCLInitOpVarTypeInference : public framework::VarTypeInference {
                   framework::BlockDesc *block) const override {
     auto out_var_name = op_desc.Output("Communicator").front();
     auto &out_var = block->FindRecursiveOrCreateVar(out_var_name);
-    auto var_type = framework::proto::VarDesc::NCCL_COM;
+    auto var_type = framework::proto::VarType::NCCL_COM;
     out_var.SetType(var_type);
   }
 };
diff --git a/python/paddle/v2/fluid/backward.py b/python/paddle/v2/fluid/backward.py
index cf32c6683b..682df3301b 100644
--- a/python/paddle/v2/fluid/backward.py
+++ b/python/paddle/v2/fluid/backward.py
@@ -269,7 +269,7 @@ def _append_backward_ops_(block,
                           target_block,
                           no_grad_dict,
                           grad_to_var,
-                          callback=None):
+                          callbacks=None):
     """
     Create all grad ops, and insert them into given block
 
@@ -285,14 +285,13 @@ def _append_backward_ops_(block,
             val(str): corresponding forward variable name
         callback(callable object): a callable object used to decorate new generated grad ops
     """
-    if callback is None:
-
-        def empty_callback(block, context):
-            pass
-
-        callback = empty_callback
-    elif not hasattr(callback, '__call__'):
-        raise ValueError("'callback' must be a callable object.")
+    if callbacks is None:
+        callbacks = []
+    else:
+        assert (isinstance(callbacks, list))
+        for cb in callbacks:
+            if not hasattr(cb, '__call__'):
+                raise ValueError("'callback' must be a callable object.")
 
     # grad_op_descs holds created grad_op, and will be appended to target_block
     grad_op_descs = []
@@ -303,9 +302,12 @@ def _append_backward_ops_(block,
         if op.has_attr("sub_block"):
            sub_block = program.block(op.block_attr("sub_block"))
            grad_sub_block = program.create_block(parent_idx=sub_block.idx)
+           if callbacks is None:
+               callbacks = [_callback_lookup_(op)]
+           else:
+               callbacks.append(_callback_lookup_(op))
            _append_backward_ops_(sub_block, sub_block.ops, grad_sub_block,
-                                 no_grad_dict, grad_to_var,
-                                 _callback_lookup_(op))
+                                 no_grad_dict, grad_to_var, callbacks)
            grad_sub_block_list.append(grad_sub_block.desc)
 
         # Getting op's corresponding grad_op
@@ -325,7 +327,8 @@ def _append_backward_ops_(block,
         new_op_desc = target_block.desc.append_op()
         new_op_desc.copy_from(op_desc)
         grad_to_var["__current_op_desc__"] = new_op_desc
-        callback(block=target_block, context=grad_to_var)
+        for cb in callbacks:
+            cb(block=target_block, context=grad_to_var)
 
 
 def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
--
GitLab
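
Below is a minimal, self-contained Python sketch of the callbacks-list pattern this patch introduces in _append_backward_ops_: the single optional callback becomes a list that is validated once and then invoked, callable by callable, after every newly appended grad op desc. This is hypothetical illustration code, not the PaddlePaddle API; the names append_backward_ops and log_op are made up, and plain lists/dicts stand in for target_block.desc and grad_to_var.

    # Sketch of the "list of callbacks" mechanism, under the assumptions above.
    def append_backward_ops(op_descs, callbacks=None):
        """Toy stand-in for _append_backward_ops_; op_descs is a list of dicts."""
        if callbacks is None:
            callbacks = []
        else:
            # Mirror the patch: require a list of callables.
            assert isinstance(callbacks, list)
            for cb in callbacks:
                if not callable(cb):
                    raise ValueError("'callbacks' must contain callable objects.")

        target_block = []   # stand-in for target_block.desc
        grad_to_var = {}    # stand-in for the grad-to-var mapping
        for op_desc in op_descs:
            target_block.append(op_desc)
            grad_to_var["__current_op_desc__"] = op_desc
            # Each registered callback sees the block and the current context.
            for cb in callbacks:
                cb(block=target_block, context=grad_to_var)
        return target_block


    if __name__ == "__main__":
        def log_op(block, context):
            print("appended:", context["__current_op_desc__"]["type"])

        append_backward_ops([{"type": "mul_grad"}, {"type": "add_grad"}],
                            callbacks=[log_op])

Running the sketch prints one line per appended op, mirroring how, after this patch, every callback in the callbacks list is called with block= and context= for each new grad op desc instead of a single optional callback.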