Commit 61a7df2e authored by F fengjiayi

update

Parent 590e6111
@@ -157,7 +157,10 @@ void BindBlockDesc(py::module &m) {
       .def_property_readonly("parent", &BlockDescBind::Parent)
       .def("append_op", &BlockDescBind::AppendOp,
            py::return_value_policy::reference)
-      .def("append_allocated_op", &BlockDescBind::AppendAllocatedOp)
+      .def("append_allocated_op",
+           [](BlockDescBind &self, OpDescBind *op_desc) {
+             self.AppendAllocatedOp(std::unique_ptr<OpDescBind>(op_desc));
+           })
       .def("prepend_op", &BlockDescBind::PrependOp,
            py::return_value_policy::reference)
       .def("var",
...
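On the C++ side, append_allocated_op is rebound as a lambda that wraps the raw OpDescBind* received from Python in a std::unique_ptr before forwarding it to BlockDescBind::AppendAllocatedOp, so the block takes ownership of an op desc allocated on the Python side. Below is a minimal sketch of driving this binding from Python; it assumes the core.OpDesc(type, inputs, outputs, attrs) constructor that the backward.py changes further down rely on, and the program, variable name, and attribute values are illustrative only, not part of this commit:

import paddle.v2.fluid.core as core
import paddle.v2.fluid.framework as framework

prog = framework.Program()
block = prog.current_block()
# Declare the output variable in the block desc, then build a standalone
# OpDesc on the Python side (name and attrs are illustrative).
block.desc.var("loss@GRAD")
op_desc = core.OpDesc("fill_constant", {}, {"Out": ["loss@GRAD"]},
                      {"shape": [1], "value": 1.0,
                       "dtype": core.DataType.FP32})
# The rebound append_allocated_op transfers ownership of op_desc to the block.
block.desc.append_allocated_op(op_desc)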
 from paddle.v2.fluid import framework as framework
 from . import core
 import collections
+import pdb
 __all__ = ['append_backward_ops']
@@ -15,7 +16,8 @@ def rename_arg(op_desc_list, old_name, new_name, begin_idx=None, end_idx=None):
         op_desc_list[i].rename_output(old_name, new_name)


-def backward_impl(block,
+def backward_impl(target,
+                  block,
                   target_block,
                   no_grad_set,
                   grad_info_map,
@@ -29,8 +31,8 @@ def backward_impl(block,
             sub_block_idx = each_op.block_attr("sub_block")
             sub_block = program.block(sub_block_idx)
             grad_sub_block = program.create_block(parent_idx=sub_block_idx)
-            backward_impl(sub_block, grad_sub_block, no_grad_set, grad_info_map,
-                          callback)
+            backward_impl(target, sub_block, grad_sub_block, no_grad_set,
+                          grad_info_map, callback)
             grad_sub_block_list.append(grad_sub_block)
         grad_op_desc = core.get_grad_op_desc(each_op.desc,
                                              no_grad_set[block.idx],
@@ -46,6 +48,7 @@ def backward_impl(block,
     for pos, op_desc in enumerate(grad_op_descs):
         for var_name in op_desc.input_arg_names():
             if len(var_inputs[var_name]) > 1:
+                pdb.set_trace()
                 pending_sum_ops.append((core.OpDesc(
                     type="sum_op",
                     inputs=var_inputs[var_name],
@@ -55,7 +58,7 @@ def backward_impl(block,
         for var_name in op_desc.output_arg_names():
             if len(var_inputs[var_name]) == 0:
                 # it's the first time we get the variable
-                var_inputs[var_name] = var_name
+                var_inputs[var_name] = [var_name]
             else:
                 if len(var_inputs[var_name] == 1):
                     new_name = var_name + "@RENAME@" + \
@@ -73,8 +76,9 @@ def backward_impl(block,
                 var_inputs[var_name].append(new_name)

     for var_name, inputs in var_inputs.iteritems():
         if len(inputs) > 1:
-            pending_sum_ops.append((core.OpDesc(
-                type="sum_op", inputs=inputs, outputs=var_name, attrs={}),
+            pdb.set_trace()
+            pending_sum_ops.append((core.OpDesc("sum_op", {"X": inputs},
+                                                {"Out": var_name}, {}),
                                     len(grad_op_descs)))
     # TODO: remove op in no grad set
@@ -84,6 +88,7 @@ def backward_impl(block,
     # create new gradient variables in the target block desc
     for op_desc in grad_op_descs:
         for grad_var_name in op_desc.output_arg_names():
+            grad_var_name = grad_var_name.encode("ascii")
             if target_block.desc.has_var(
                     grad_var_name) or grad_var_name == core.get_empty_var_name(
             ):
@@ -93,6 +98,16 @@ def backward_impl(block,
                 continue
             grad_info_map[grad_to_var[grad_var_name]] = (grad_var_name,
                                                          target_block)
+    if target_block.idx == 0:
+        grad_target_name = (target.name + "@GRAD")
+        target_block.desc.var(grad_target_name)
+        grad_op_descs.insert(
+            0,
+            core.OpDesc(u"fill_constant", {}, {
+                u"Out": [unicode(grad_target_name, "ascii")]
+            }, {u"shape": (1),
+                u"value": 1.0,
+                u"dtype": core.DataType.FP32}))
     # insert backward operators to target_block
     for op_desc in grad_op_descs:
         target_block.desc.append_allocated_op(op_desc)
@@ -118,18 +133,22 @@ def append_backward_ops(loss, parameter_list=None, no_grad_set=None):
     assert isinstance(loss, framework.Variable)
     if no_grad_set is None:
+        no_grad_set = dict()
         program = loss.block.program
         assert isinstance(program, framework.Program)
-        no_grad_set = list()
         for block in program.blocks:
             assert isinstance(block, framework.Block)
+            block_no_grad_set = set()
             for var in block.vars.itervalues():
                 assert isinstance(var, framework.Variable)
                 if var.stop_gradient:
-                    no_grad_set.append(var.name)
-        no_grad_set = set(no_grad_set)
+                    block_no_grad_set.add(var.name)
+            no_grad_set[block.idx] = block_no_grad_set

-    param_grad_map = loss.block.program.append_backward(loss, no_grad_set)
+    grad_info_map = dict()
+    root_block = loss.block.program.block(0)
+    backward_impl(loss, root_block, root_block, no_grad_set, grad_info_map)
+    pdb.set_trace()
     if parameter_list is not None:
         parameters = parameter_list
     else:
@@ -137,9 +156,9 @@ def append_backward_ops(loss, parameter_list=None, no_grad_set=None):
         parameters = [param.name for param in params]
     params_and_grads = []
     for param in parameters:
-        if param not in param_grad_map:
+        if param not in grad_info_map:
             raise ValueError("param %s is not in map" % param)
-        grad_info = param_grad_map[param]
+        grad_info = grad_info_map[param]
         grad_block = loss.block.program.block(grad_info[1])
         if not grad_block.has_var(grad_info[0]):
             raise ValueError("grad block[{0}] did not have grad var {1}".format(
...
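Taken together, the backward.py changes make append_backward_ops build a per-block no_grad_set (a dict mapping each block index to the set of stop_gradient variable names), delegate gradient-op construction to the new backward_impl, and pass the target variable down so the root block can seed target.name + "@GRAD" with a fill_constant op of value 1.0. A minimal usage sketch, under the assumption that loss is a scalar framework.Variable produced by an already-built fluid Program and that the module is importable as paddle.v2.fluid.backward (how the forward program is constructed is outside this commit):

from paddle.v2.fluid.backward import append_backward_ops

# loss: a scalar framework.Variable from an existing fluid Program.
# append_backward_ops collects stop_gradient variables per block, appends
# the gradient operators via backward_impl, and returns parameter/gradient
# pairs for an optimizer to consume.
params_and_grads = append_backward_ops(loss)
for param, grad in params_and_grads:
    # Each entry is expected to pair a parameter with its gradient variable.
    print(param.name, grad.name)

Note that this work-in-progress commit leaves the import pdb / pdb.set_trace() calls in place, so running this path as committed drops into the debugger.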