From 16359da2c1b213564a3aef1fac4b98af0caf6ee2 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Fri, 31 Aug 2018 21:38:20 +0800 Subject: [PATCH] Refine While Op --- paddle/fluid/API.spec | 2 +- paddle/fluid/operators/while_op.cc | 8 ++++++++ python/paddle/fluid/layers/control_flow.py | 7 +++++-- 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index e963902a5..5f90d43ca 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -190,7 +190,7 @@ paddle.fluid.layers.argsort ArgSpec(args=['input', 'axis', 'name'], varargs=None paddle.fluid.layers.ones ArgSpec(args=['shape', 'dtype', 'force_cpu'], varargs=None, keywords=None, defaults=(False,)) paddle.fluid.layers.zeros ArgSpec(args=['shape', 'dtype', 'force_cpu'], varargs=None, keywords=None, defaults=(False,)) paddle.fluid.layers.reverse ArgSpec(args=['x', 'axis'], varargs=None, keywords=None, defaults=None) -paddle.fluid.layers.While.__init__ ArgSpec(args=['self', 'cond', 'name'], varargs=None, keywords=None, defaults=(None,)) +paddle.fluid.layers.While.__init__ ArgSpec(args=['self', 'cond', 'is_test', 'name'], varargs=None, keywords=None, defaults=(False, None)) paddle.fluid.layers.While.block ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.Switch.__init__ ArgSpec(args=['self', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.Switch.case ArgSpec(args=['self', 'condition'], varargs=None, keywords=None, defaults=None) diff --git a/paddle/fluid/operators/while_op.cc b/paddle/fluid/operators/while_op.cc index 733157ea0..147b5b1d1 100644 --- a/paddle/fluid/operators/while_op.cc +++ b/paddle/fluid/operators/while_op.cc @@ -55,6 +55,7 @@ class WhileOp : public framework::OperatorBase { auto step_scopes = scope.FindVar(Output(kStepScopes))->GetMutable(); + bool is_test = Attr("is_test"); PADDLE_ENFORCE(platform::is_cpu_place(cond.place()), "Condition of while op must in CPU memory."); while (cond.data()[0]) { @@ -63,6 +64,10 @@ class WhileOp : public framework::OperatorBase { executor.Run(*program, ¤t_scope, block->ID(), false /*create_local_scope*/); + + if (is_test) { + scope.DeleteScope(¤t_scope); + } } } }; @@ -88,6 +93,7 @@ class WhileOpMaker : public framework::OpProtoAndCheckerMaker { "variables generated in the i'th step."); AddAttr(kStepBlock, "The step block inside WhileOp"); + AddAttr("is_test", "True if in test phase.").SetDefault(false); AddComment(R"DOC( )DOC"); } @@ -103,6 +109,8 @@ class WhileGradOp : public framework::OperatorBase { private: void RunImpl(const framework::Scope &scope, const platform::Place &dev_place) const override { + PADDLE_ENFORCE(!Attr("is_test"), + "GradOp is only callable when is_test is false"); // get device context from pool platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto &dev_ctx = *pool.Get(dev_place); diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py index 173567a0a..3d9183e7c 100644 --- a/python/paddle/fluid/layers/control_flow.py +++ b/python/paddle/fluid/layers/control_flow.py @@ -661,6 +661,7 @@ class While(object): Args: cond (Variable): condition used to compare. + is_test(bool): A flag indicating whether execution is in test phase. name (str): The name of this layer. Examples: @@ -683,7 +684,7 @@ class While(object): IN_WHILE_BLOCK = 1 AFTER_WHILE_BLOCK = 2 - def __init__(self, cond, name=None): + def __init__(self, cond, is_test=False, name=None): self.helper = LayerHelper("while", name=name) self.status = While.BEFORE_WHILE_BLOCK if not isinstance(cond, Variable): @@ -694,6 +695,7 @@ class While(object): if reduce(lambda a, b: a * b, cond.shape, 1) != 1: raise TypeError("condition should be a bool scalar") self.cond_var = cond + self.is_test = is_test def block(self): return WhileGuard(self) @@ -735,7 +737,8 @@ class While(object): }, outputs={'Out': out_vars, 'StepScopes': [step_scope]}, - attrs={'sub_block': while_block}) + attrs={'sub_block': while_block, + "is_test": self.is_test}) def lod_rank_table(x, level=0): -- GitLab