From f2e76008d37d5a0203fdd42edf4fa1ac2907c400 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Fri, 15 Dec 2017 15:42:26 +0800 Subject: [PATCH] update --- python/paddle/v2/fluid/backward.py | 19 +++++++++++++++++++ python/paddle/v2/fluid/framework.py | 6 ++++-- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/python/paddle/v2/fluid/backward.py b/python/paddle/v2/fluid/backward.py index f18858217..3a128b8e6 100644 --- a/python/paddle/v2/fluid/backward.py +++ b/python/paddle/v2/fluid/backward.py @@ -1,8 +1,27 @@ from paddle.v2.fluid import framework as framework +from . import core __all__ = ['append_backward_ops'] +def backward_impl(block, target_block, no_grad_set, grad_to_var, callback): + grad_op_descs = [] + program = block.program + for each_op in block.ops: + grad_sub_block_list = [] + if each_op.has_attr("sub_block"): + sub_block_idx = each_op.block_attr("sub_block") + sub_block = program.block(sub_block_idx) + grad_sub_block = program.create_block(parent_idx=sub_block_idx) + backward_impl(sub_block, grad_sub_block, no_grad_set, grad_to_var, + callback) + grad_sub_block_list.append(grad_sub_block) + grad_op_desc = core.get_grad_op_desc(each_op.desc, + no_grad_set[block.idx], + grad_to_var, grad_sub_block_list) + grad_op_descs.append(grad_op_desc) + + def append_backward_ops(loss, parameter_list=None, no_grad_set=None): """ Create and add gradient Operators in BlockDesc to compute diff --git a/python/paddle/v2/fluid/framework.py b/python/paddle/v2/fluid/framework.py index bf0cd275b..244a96393 100644 --- a/python/paddle/v2/fluid/framework.py +++ b/python/paddle/v2/fluid/framework.py @@ -806,9 +806,11 @@ class Program(object): self.sync_with_cpp() return param_to_grad_info - def create_block(self): + def create_block(self, parent_idx=None): new_block_idx = len(self.blocks) - self.desc.append_block(self.current_block().desc) + parent = self.current_block() if parent_idx is None else self.block( + parent_idx) + self.desc.append_block(parent.desc) self.current_block_idx = new_block_idx self.blocks.append(Block(self, self.current_block_idx)) return self.current_block() -- GitLab