未验证 提交 8ad90558 编写于 作者: C chengduo 提交者: GitHub

Add is_test for while_op (#12874)

* add is_test for while_op

* Change API
上级 c6f212a3
...@@ -191,7 +191,7 @@ paddle.fluid.layers.argsort ArgSpec(args=['input', 'axis', 'name'], varargs=None ...@@ -191,7 +191,7 @@ paddle.fluid.layers.argsort ArgSpec(args=['input', 'axis', 'name'], varargs=None
paddle.fluid.layers.ones ArgSpec(args=['shape', 'dtype', 'force_cpu'], varargs=None, keywords=None, defaults=(False,)) paddle.fluid.layers.ones ArgSpec(args=['shape', 'dtype', 'force_cpu'], varargs=None, keywords=None, defaults=(False,))
paddle.fluid.layers.zeros ArgSpec(args=['shape', 'dtype', 'force_cpu'], varargs=None, keywords=None, defaults=(False,)) paddle.fluid.layers.zeros ArgSpec(args=['shape', 'dtype', 'force_cpu'], varargs=None, keywords=None, defaults=(False,))
paddle.fluid.layers.reverse ArgSpec(args=['x', 'axis'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.reverse ArgSpec(args=['x', 'axis'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.While.__init__ ArgSpec(args=['self', 'cond', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.While.__init__ ArgSpec(args=['self', 'cond', 'is_test', 'name'], varargs=None, keywords=None, defaults=(False, None))
paddle.fluid.layers.While.block ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.While.block ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.Switch.__init__ ArgSpec(args=['self', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.Switch.__init__ ArgSpec(args=['self', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.Switch.case ArgSpec(args=['self', 'condition'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.Switch.case ArgSpec(args=['self', 'condition'], varargs=None, keywords=None, defaults=None)
......
...@@ -58,11 +58,15 @@ class WhileOp : public framework::OperatorBase { ...@@ -58,11 +58,15 @@ class WhileOp : public framework::OperatorBase {
PADDLE_ENFORCE(platform::is_cpu_place(cond.place()), PADDLE_ENFORCE(platform::is_cpu_place(cond.place()),
"Condition of while op must in CPU memory."); "Condition of while op must in CPU memory.");
bool is_test = Attr<bool>("is_test");
auto ctx = executor.Prepare(*program, block->ID()); auto ctx = executor.Prepare(*program, block->ID());
while (cond.data<bool>()[0]) { while (cond.data<bool>()[0]) {
auto &current_scope = scope.NewScope(); auto &current_scope = scope.NewScope();
step_scopes->push_back(&current_scope); step_scopes->push_back(&current_scope);
executor.RunPreparedContext(ctx.get(), &current_scope, false); executor.RunPreparedContext(ctx.get(), &current_scope, false);
if (is_test) {
scope.DeleteScope(&current_scope);
}
} }
} }
}; };
...@@ -88,6 +92,7 @@ class WhileOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -88,6 +92,7 @@ class WhileOpMaker : public framework::OpProtoAndCheckerMaker {
"variables generated in the i'th step."); "variables generated in the i'th step.");
AddAttr<framework::BlockDesc *>(kStepBlock, AddAttr<framework::BlockDesc *>(kStepBlock,
"The step block inside WhileOp"); "The step block inside WhileOp");
AddAttr<bool>("is_test", "True if in test phase.").SetDefault(false);
AddComment(R"DOC( AddComment(R"DOC(
)DOC"); )DOC");
} }
...@@ -103,6 +108,8 @@ class WhileGradOp : public framework::OperatorBase { ...@@ -103,6 +108,8 @@ class WhileGradOp : public framework::OperatorBase {
private: private:
void RunImpl(const framework::Scope &scope, void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override { const platform::Place &dev_place) const override {
PADDLE_ENFORCE(!Attr<bool>("is_test"),
"GradOp is only callable when is_test is false");
// get device context from pool // get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(dev_place); auto &dev_ctx = *pool.Get(dev_place);
......
...@@ -661,6 +661,7 @@ class While(object): ...@@ -661,6 +661,7 @@ class While(object):
Args: Args:
cond (Variable): condition used to compare. cond (Variable): condition used to compare.
is_test(bool): A flag indicating whether execution is in test phase.
name (str): The name of this layer. name (str): The name of this layer.
Examples: Examples:
...@@ -683,7 +684,7 @@ class While(object): ...@@ -683,7 +684,7 @@ class While(object):
IN_WHILE_BLOCK = 1 IN_WHILE_BLOCK = 1
AFTER_WHILE_BLOCK = 2 AFTER_WHILE_BLOCK = 2
def __init__(self, cond, name=None): def __init__(self, cond, is_test=False, name=None):
self.helper = LayerHelper("while", name=name) self.helper = LayerHelper("while", name=name)
self.status = While.BEFORE_WHILE_BLOCK self.status = While.BEFORE_WHILE_BLOCK
if not isinstance(cond, Variable): if not isinstance(cond, Variable):
...@@ -694,6 +695,7 @@ class While(object): ...@@ -694,6 +695,7 @@ class While(object):
if reduce(lambda a, b: a * b, cond.shape, 1) != 1: if reduce(lambda a, b: a * b, cond.shape, 1) != 1:
raise TypeError("condition should be a bool scalar") raise TypeError("condition should be a bool scalar")
self.cond_var = cond self.cond_var = cond
self.is_test = is_test
def block(self): def block(self):
return WhileGuard(self) return WhileGuard(self)
...@@ -735,7 +737,8 @@ class While(object): ...@@ -735,7 +737,8 @@ class While(object):
}, },
outputs={'Out': out_vars, outputs={'Out': out_vars,
'StepScopes': [step_scope]}, 'StepScopes': [step_scope]},
attrs={'sub_block': while_block}) attrs={'sub_block': while_block,
"is_test": self.is_test})
def lod_rank_table(x, level=0): def lod_rank_table(x, level=0):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册