提交 3f09620e 编写于 作者: Y Yang Yang

pass compile

上级 e021ad67
...@@ -55,7 +55,7 @@ static void CreateTensor(Variable* var, proto::VarType::Type var_type) { ...@@ -55,7 +55,7 @@ static void CreateTensor(Variable* var, proto::VarType::Type var_type) {
var->GetMutable<platform::PlaceList>(); var->GetMutable<platform::PlaceList>();
} else if (var_type == proto::VarType::READER) { } else if (var_type == proto::VarType::READER) {
var->GetMutable<ReaderHolder>(); var->GetMutable<ReaderHolder>();
} else if (var_type == proto::VarDesc::NCCL_COM) { } else if (var_type == proto::VarType::NCCL_COM) {
// GetMutable will be called in ncclInit // GetMutable will be called in ncclInit
} else { } else {
PADDLE_THROW( PADDLE_THROW(
......
...@@ -65,7 +65,7 @@ class NCCLInitOpVarTypeInference : public framework::VarTypeInference { ...@@ -65,7 +65,7 @@ class NCCLInitOpVarTypeInference : public framework::VarTypeInference {
framework::BlockDesc *block) const override { framework::BlockDesc *block) const override {
auto out_var_name = op_desc.Output("Communicator").front(); auto out_var_name = op_desc.Output("Communicator").front();
auto &out_var = block->FindRecursiveOrCreateVar(out_var_name); auto &out_var = block->FindRecursiveOrCreateVar(out_var_name);
auto var_type = framework::proto::VarDesc::NCCL_COM; auto var_type = framework::proto::VarType::NCCL_COM;
out_var.SetType(var_type); out_var.SetType(var_type);
} }
}; };
......
...@@ -269,7 +269,7 @@ def _append_backward_ops_(block, ...@@ -269,7 +269,7 @@ def _append_backward_ops_(block,
target_block, target_block,
no_grad_dict, no_grad_dict,
grad_to_var, grad_to_var,
callback=None): callbacks=None):
""" """
Create all grad ops, and insert them into given block Create all grad ops, and insert them into given block
...@@ -285,14 +285,13 @@ def _append_backward_ops_(block, ...@@ -285,14 +285,13 @@ def _append_backward_ops_(block,
val(str): corresponding forward variable name val(str): corresponding forward variable name
callback(callable object): a callable object used to decorate new generated grad ops callback(callable object): a callable object used to decorate new generated grad ops
""" """
if callback is None: if callbacks is None:
callbacks = []
def empty_callback(block, context): else:
pass assert (isinstance(callbacks, list))
for cb in callbacks:
callback = empty_callback if not hasattr(cb, '__call__'):
elif not hasattr(callback, '__call__'): raise ValueError("'callback' must be a callable object.")
raise ValueError("'callback' must be a callable object.")
# grad_op_descs holds created grad_op, and will be appended to target_block # grad_op_descs holds created grad_op, and will be appended to target_block
grad_op_descs = [] grad_op_descs = []
...@@ -303,9 +302,12 @@ def _append_backward_ops_(block, ...@@ -303,9 +302,12 @@ def _append_backward_ops_(block,
if op.has_attr("sub_block"): if op.has_attr("sub_block"):
sub_block = program.block(op.block_attr("sub_block")) sub_block = program.block(op.block_attr("sub_block"))
grad_sub_block = program.create_block(parent_idx=sub_block.idx) grad_sub_block = program.create_block(parent_idx=sub_block.idx)
if callbacks is None:
callbacks = [_callback_lookup_(op)]
else:
callbacks.append(_callback_lookup_(op))
_append_backward_ops_(sub_block, sub_block.ops, grad_sub_block, _append_backward_ops_(sub_block, sub_block.ops, grad_sub_block,
no_grad_dict, grad_to_var, no_grad_dict, grad_to_var, callbacks)
_callback_lookup_(op))
grad_sub_block_list.append(grad_sub_block.desc) grad_sub_block_list.append(grad_sub_block.desc)
# Getting op's corresponding grad_op # Getting op's corresponding grad_op
...@@ -325,7 +327,8 @@ def _append_backward_ops_(block, ...@@ -325,7 +327,8 @@ def _append_backward_ops_(block,
new_op_desc = target_block.desc.append_op() new_op_desc = target_block.desc.append_op()
new_op_desc.copy_from(op_desc) new_op_desc.copy_from(op_desc)
grad_to_var["__current_op_desc__"] = new_op_desc grad_to_var["__current_op_desc__"] = new_op_desc
callback(block=target_block, context=grad_to_var) for cb in callbacks:
cb(block=target_block, context=grad_to_var)
def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map): def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册