From cc18fffb9000d4b5b9352568f341844c72d14fe1 Mon Sep 17 00:00:00 2001
From: chengduoZH
Date: Tue, 11 Sep 2018 12:05:25 +0800
Subject: [PATCH] add nest while_op

---
 paddle/fluid/operators/while_op.cc             |  5 ++--
 .../fluid/tests/unittests/test_while_op.py     | 25 ++++++++++++++++---
 2 files changed, 25 insertions(+), 5 deletions(-)

diff --git a/paddle/fluid/operators/while_op.cc b/paddle/fluid/operators/while_op.cc
index 65a3bc928e4..791138a8c0e 100644
--- a/paddle/fluid/operators/while_op.cc
+++ b/paddle/fluid/operators/while_op.cc
@@ -63,7 +63,7 @@ class WhileOp : public framework::OperatorBase {
     while (cond.data<bool>()[0]) {
       auto &current_scope = scope.NewScope();
       step_scopes->push_back(&current_scope);
-      executor.RunPreparedContext(ctx.get(), &current_scope, false);
+      executor.RunPreparedContext(ctx.get(), &current_scope, false, true, true);
       if (is_test) {
         scope.DeleteScope(&current_scope);
       }
@@ -169,7 +169,8 @@ class WhileGradOp : public framework::OperatorBase {
           }
         }
       }
-      executor.RunPreparedContext(ctx.get(), *cur_scope_iter, false);
+      executor.RunPreparedContext(ctx.get(), *cur_scope_iter, false, true,
+                                  true);
 
       auto &pg_names = Outputs(kXGRAD);
       auto &p_names = Inputs(kX);
diff --git a/python/paddle/fluid/tests/unittests/test_while_op.py b/python/paddle/fluid/tests/unittests/test_while_op.py
index b75373cf24a..43fd9d425bf 100644
--- a/python/paddle/fluid/tests/unittests/test_while_op.py
+++ b/python/paddle/fluid/tests/unittests/test_while_op.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. 
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -30,8 +30,10 @@ class TestWhileOp(unittest.TestCase):
             "d1", shape=[10], append_batch_size=False, dtype='float32')
         d2 = layers.data(
             "d2", shape=[10], append_batch_size=False, dtype='float32')
+
         i = layers.zeros(shape=[1], dtype='int64')
         i.stop_gradient = True
+
         init = layers.zeros(shape=[10], dtype='float32')
         mem_array = layers.array_write(x=init, i=i)
         data_array = layers.array_write(x=d0, i=i)
@@ -45,11 +47,19 @@ class TestWhileOp(unittest.TestCase):
 
         i = layers.zeros(shape=[1], dtype='int64')
         i.stop_gradient = True
-        array_len = layers.fill_constant(shape=[1], dtype='int64', value=3)
+        array_len = layers.fill_constant(shape=[1], dtype='int64', value=1)
         array_len.stop_gradient = True
         cond = layers.less_than(x=i, y=array_len)
 
+        j = layers.fill_constant(shape=[1], dtype='int64', value=1)
+        j.stop_gradient = True
+
+        array_len2 = layers.fill_constant(shape=[1], dtype='int64', value=3)
+        array_len2.stop_gradient = True
+        cond2 = layers.less_than(x=j, y=array_len2)
+
         while_op = layers.While(cond=cond)
+        while_op2 = layers.While(cond=cond2)
         with while_op.block():
             d = layers.array_read(array=data_array, i=i)
             prev = layers.array_read(array=mem_array, i=i)
@@ -59,7 +69,16 @@ class TestWhileOp(unittest.TestCase):
             layers.array_write(result, i=i, array=mem_array)
             layers.less_than(x=i, y=array_len, cond=cond)
 
-        sum_result = layers.array_read(array=mem_array, i=i)
+            with while_op2.block():
+                d2 = layers.array_read(array=data_array, i=j)
+                prev2 = layers.array_read(array=mem_array, i=j)
+                result2 = layers.sums(input=[d2, prev2])
+
+                j = layers.increment(x=j, in_place=True)
+                layers.array_write(result2, i=j, array=mem_array)
+                layers.less_than(x=j, y=array_len2, cond=cond2)
+
+        sum_result = layers.array_read(array=mem_array, i=j)
         loss = layers.mean(sum_result)
         append_backward(loss)
 
--
GitLab
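
Note: the hunks above only build the nested-while program; the part of test_while_op.py that runs it is outside the shown context. Below is a minimal sketch of that usage, assuming fluid's Executor API of the same era (core.CPUPlace, exe.run with feed/fetch_list) and reusing the d0/d1/d2 inputs and sum_result variable constructed in the test body above; it is illustrative, not part of the patch.

    import numpy
    import paddle.fluid.core as core
    from paddle.fluid.executor import Executor

    # Build the program as in the test body above (data layers d0/d1/d2,
    # the outer and nested While blocks, sum_result, append_backward), then:
    exe = Executor(core.CPUPlace())

    # Three random input vectors, one per array_write into data_array.
    d = [numpy.random.random(size=[10]).astype('float32') for _ in range(3)]

    # Run the default main program; the nested while accumulates the inputs
    # into mem_array, so the fetched sum_result should roughly equal the sum
    # of all fed values.
    outs = exe.run(feed={'d0': d[0], 'd1': d[1], 'd2': d[2]},
                   fetch_list=[sum_result])
    print(numpy.sum(outs[0]), numpy.sum(d))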