未验证 提交 8ad90558 编写于 作者: C chengduo 提交者: GitHub

Add is_test for while_op (#12874)

* add is_test for while_op

* Change API
上级 c6f212a3
......@@ -191,7 +191,7 @@ paddle.fluid.layers.argsort ArgSpec(args=['input', 'axis', 'name'], varargs=None
paddle.fluid.layers.ones ArgSpec(args=['shape', 'dtype', 'force_cpu'], varargs=None, keywords=None, defaults=(False,))
paddle.fluid.layers.zeros ArgSpec(args=['shape', 'dtype', 'force_cpu'], varargs=None, keywords=None, defaults=(False,))
paddle.fluid.layers.reverse ArgSpec(args=['x', 'axis'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.While.__init__ ArgSpec(args=['self', 'cond', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.While.__init__ ArgSpec(args=['self', 'cond', 'is_test', 'name'], varargs=None, keywords=None, defaults=(False, None))
paddle.fluid.layers.While.block ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.Switch.__init__ ArgSpec(args=['self', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.Switch.case ArgSpec(args=['self', 'condition'], varargs=None, keywords=None, defaults=None)
......
......@@ -58,11 +58,15 @@ class WhileOp : public framework::OperatorBase {
PADDLE_ENFORCE(platform::is_cpu_place(cond.place()),
"Condition of while op must in CPU memory.");
bool is_test = Attr<bool>("is_test");
auto ctx = executor.Prepare(*program, block->ID());
while (cond.data<bool>()[0]) {
auto &current_scope = scope.NewScope();
step_scopes->push_back(&current_scope);
executor.RunPreparedContext(ctx.get(), &current_scope, false);
if (is_test) {
scope.DeleteScope(&current_scope);
}
}
}
};
......@@ -88,6 +92,7 @@ class WhileOpMaker : public framework::OpProtoAndCheckerMaker {
"variables generated in the i'th step.");
AddAttr<framework::BlockDesc *>(kStepBlock,
"The step block inside WhileOp");
AddAttr<bool>("is_test", "True if in test phase.").SetDefault(false);
AddComment(R"DOC(
)DOC");
}
......@@ -103,6 +108,8 @@ class WhileGradOp : public framework::OperatorBase {
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
PADDLE_ENFORCE(!Attr<bool>("is_test"),
"GradOp is only callable when is_test is false");
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(dev_place);
......
......@@ -661,6 +661,7 @@ class While(object):
Args:
cond (Variable): condition used to compare.
is_test(bool): A flag indicating whether execution is in test phase.
name (str): The name of this layer.
Examples:
......@@ -683,7 +684,7 @@ class While(object):
IN_WHILE_BLOCK = 1
AFTER_WHILE_BLOCK = 2
def __init__(self, cond, name=None):
def __init__(self, cond, is_test=False, name=None):
self.helper = LayerHelper("while", name=name)
self.status = While.BEFORE_WHILE_BLOCK
if not isinstance(cond, Variable):
......@@ -694,6 +695,7 @@ class While(object):
if reduce(lambda a, b: a * b, cond.shape, 1) != 1:
raise TypeError("condition should be a bool scalar")
self.cond_var = cond
self.is_test = is_test
def block(self):
return WhileGuard(self)
......@@ -735,7 +737,8 @@ class While(object):
},
outputs={'Out': out_vars,
'StepScopes': [step_scope]},
attrs={'sub_block': while_block})
attrs={'sub_block': while_block,
"is_test": self.is_test})
def lod_rank_table(x, level=0):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册