Commit 837378b5 authored by chengduo, committed by XiaoguangHu

Cherry pick Add sub-scope check in RecurrentOp and WhileOp (#20549)

* Fix recurrent bug
test=develop

* Fix while op bug
test=release/1.6

* Add unit test
test=release/1.6
Parent 57173763
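For context: a CompiledProgram keeps its execution scope alive across runs, so the second time a program containing a While (or recurrent) op is executed, the op starts out with the step scopes left behind by the first run. Before this fix, a forward-only rerun could fail the existing "The StepScope should be empty." check, and deleting leftover sub-scopes unconditionally could touch scopes the parent scope had already released. The sketch below distills that double-run pattern from the test_simple_net_forward case added in this commit; it is illustrative only, uses the fluid 1.6-era API, and the names (x, limit, mem) are made up for the example.

import numpy
import paddle.fluid as fluid
import paddle.fluid.layers as layers

# Build a trivial While loop that copies the input forward through an array.
main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
    x = layers.data("x", shape=[10], append_batch_size=False, dtype='float32')
    i = layers.zeros(shape=[1], dtype='int64')
    i.stop_gradient = True
    limit = layers.fill_constant(shape=[1], dtype='int64', value=3)
    limit.stop_gradient = True
    mem = layers.array_write(x=x, i=i)
    cond = layers.less_than(x=i, y=limit)
    while_op = layers.While(cond=cond)
    with while_op.block():
        prev = layers.array_read(array=mem, i=i)
        i = layers.increment(x=i, in_place=True)
        layers.array_write(prev, i=i, array=mem)
        layers.less_than(x=i, y=limit, cond=cond)

# A CompiledProgram reuses its scope across runs, so the second exe.run call
# enters WhileOp with the step scopes left over from the first call -- the
# situation the new cleanup code handles.
binary = fluid.compiler.CompiledProgram(main_program)
exe = fluid.Executor(fluid.CPUPlace())
feed = {'x': numpy.random.random(size=[10]).astype('float32')}
for _ in range(2):
    exe.run(binary, feed=feed)

The two C++ hunks below cover both sides of this: WhileOp now clears any leftover step scopes (guarded by scope.HasKid) before the emptiness check, and RecurrentOp's ClearStepScopes only deletes sub-scopes the parent scope still owns.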
...@@ -62,6 +62,17 @@ class WhileOp : public framework::OperatorBase {
    auto step_scopes =
        scope.FindVar(Output(kStepScopes))->GetMutable<StepScopeVar>();

    if (step_scopes->size() > 0) {
      platform::DeviceContextPool::Instance().Get(dev_place)->Wait();
      for (auto &s : *step_scopes) {
        if (scope.HasKid(s)) {
          scope.DeleteScope(s);
        }
      }
      step_scopes->clear();
    }

    PADDLE_ENFORCE_EQ(step_scopes->size(), 0, "The StepScope should be empty.");
    PADDLE_ENFORCE(platform::is_cpu_place(cond.place()),
                   "Condition of while op must in CPU memory.");
......
...@@ -48,8 +48,10 @@ static void ClearStepScopes(const platform::DeviceContext &dev_ctx,
  dev_ctx.Wait();
  for (auto *sub_scope : *step_scopes) {
    if (parent_scope->HasKid(sub_scope)) {
      parent_scope->DeleteScope(sub_scope);
    }
  }
  step_scopes->clear();
}
......
...@@ -18,46 +18,38 @@ import unittest
import paddle.fluid.layers as layers
from paddle.fluid.executor import Executor
import paddle.fluid.core as core
import paddle.fluid as fluid
from paddle.fluid.backward import append_backward
import numpy

class TestWhileOp(unittest.TestCase):
    def simple_net(self):
        d0 = layers.data(
            "d0", shape=[10], append_batch_size=False, dtype='float32')
        d1 = layers.data(
            "d1", shape=[10], append_batch_size=False, dtype='float32')
        d2 = layers.data(
            "d2", shape=[10], append_batch_size=False, dtype='float32')
        i = layers.zeros(shape=[1], dtype='int64')
        i.stop_gradient = True
        init = layers.zeros(shape=[10], dtype='float32')
        mem_array = layers.array_write(x=init, i=i)
        data_array = layers.array_write(x=d0, i=i)
        i = layers.increment(i)
        layers.array_write(d1, i, array=data_array)
        i = layers.increment(i)
        layers.array_write(d2, i, array=data_array)
        i = layers.zeros(shape=[1], dtype='int64')
        i.stop_gradient = True
        array_len = layers.fill_constant(shape=[1], dtype='int64', value=1)
        array_len.stop_gradient = True
        cond = layers.less_than(x=i, y=array_len)
        j = layers.fill_constant(shape=[1], dtype='int64', value=1)
        j.stop_gradient = True
        array_len2 = layers.fill_constant(shape=[1], dtype='int64', value=3)
        array_len2.stop_gradient = True
        cond2 = layers.less_than(x=j, y=array_len2)
        while_op = layers.While(cond=cond)
        while_op2 = layers.While(cond=cond2)
        with while_op.block():
...@@ -77,9 +69,15 @@ class TestWhileOp(unittest.TestCase):
                j = layers.increment(x=j, in_place=True)
                layers.array_write(result2, i=j, array=mem_array)
                layers.less_than(x=j, y=array_len2, cond=cond2)

        sum_result = layers.array_read(array=mem_array, i=j)
        loss = layers.mean(sum_result)
        return loss, sum_result

    def test_simple_net(self):
        main_program = fluid.Program()
        startup_program = fluid.Program()
        with fluid.program_guard(main_program, startup_program):
            loss, sum_result = self.simple_net()

            append_backward(loss)
...@@ -96,6 +94,23 @@ class TestWhileOp(unittest.TestCase):
                           fetch_list=[sum_result])
            self.assertAlmostEqual(numpy.sum(d), numpy.sum(outs[0]), delta=0.01)

    def test_simple_net_forward(self):
        main_program = fluid.Program()
        startup_program = fluid.Program()
        with fluid.program_guard(main_program, startup_program):
            self.simple_net()

        binary = fluid.compiler.CompiledProgram(main_program)
        cpu = core.CPUPlace()
        exe = Executor(cpu)
        d = []

        for i in range(3):
            d.append(numpy.random.random(size=[10]).astype('float32'))

        for _ in range(2):
            exe.run(binary, feed={'d0': d[0], 'd1': d[1], 'd2': d[2]})

    def test_exceptions(self):
        i = layers.zeros(shape=[2], dtype='int64')
        array_len = layers.fill_constant(shape=[2], dtype='int64', value=1)
......