提交 837378b5 编写于 作者: C chengduo 提交者: XiaoguangHu

Cherry pick Add sub-scope check in RecurrentOp and WhileOp (#20549)

* fix recurrent bug
test=develop

* Fix while op bug
test=release/1.6

* add unit test
test=release/1.6
上级 57173763
......@@ -62,6 +62,17 @@ class WhileOp : public framework::OperatorBase {
auto step_scopes =
scope.FindVar(Output(kStepScopes))->GetMutable<StepScopeVar>();
if (step_scopes->size() > 0) {
platform::DeviceContextPool::Instance().Get(dev_place)->Wait();
for (auto &s : *step_scopes) {
if (scope.HasKid(s)) {
scope.DeleteScope(s);
}
}
step_scopes->clear();
}
PADDLE_ENFORCE_EQ(step_scopes->size(), 0, "The StepScope should be empty.");
PADDLE_ENFORCE(platform::is_cpu_place(cond.place()),
"Condition of while op must in CPU memory.");
......
......@@ -48,7 +48,9 @@ static void ClearStepScopes(const platform::DeviceContext &dev_ctx,
dev_ctx.Wait();
for (auto *sub_scope : *step_scopes) {
parent_scope->DeleteScope(sub_scope);
if (parent_scope->HasKid(sub_scope)) {
parent_scope->DeleteScope(sub_scope);
}
}
step_scopes->clear();
......
......@@ -18,46 +18,38 @@ import unittest
import paddle.fluid.layers as layers
from paddle.fluid.executor import Executor
import paddle.fluid.core as core
import paddle.fluid as fluid
from paddle.fluid.backward import append_backward
import numpy
class TestWhileOp(unittest.TestCase):
def test_simple_forward(self):
def simple_net(self):
d0 = layers.data(
"d0", shape=[10], append_batch_size=False, dtype='float32')
d1 = layers.data(
"d1", shape=[10], append_batch_size=False, dtype='float32')
d2 = layers.data(
"d2", shape=[10], append_batch_size=False, dtype='float32')
i = layers.zeros(shape=[1], dtype='int64')
i.stop_gradient = True
init = layers.zeros(shape=[10], dtype='float32')
mem_array = layers.array_write(x=init, i=i)
data_array = layers.array_write(x=d0, i=i)
i = layers.increment(i)
layers.array_write(d1, i, array=data_array)
i = layers.increment(i)
layers.array_write(d2, i, array=data_array)
i = layers.zeros(shape=[1], dtype='int64')
i.stop_gradient = True
array_len = layers.fill_constant(shape=[1], dtype='int64', value=1)
array_len.stop_gradient = True
cond = layers.less_than(x=i, y=array_len)
j = layers.fill_constant(shape=[1], dtype='int64', value=1)
j.stop_gradient = True
array_len2 = layers.fill_constant(shape=[1], dtype='int64', value=3)
array_len2.stop_gradient = True
cond2 = layers.less_than(x=j, y=array_len2)
while_op = layers.While(cond=cond)
while_op2 = layers.While(cond=cond2)
with while_op.block():
......@@ -77,24 +69,47 @@ class TestWhileOp(unittest.TestCase):
j = layers.increment(x=j, in_place=True)
layers.array_write(result2, i=j, array=mem_array)
layers.less_than(x=j, y=array_len2, cond=cond2)
sum_result = layers.array_read(array=mem_array, i=j)
loss = layers.mean(sum_result)
return loss, sum_result
def test_simple_net(self):
main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
loss, sum_result = self.simple_net()
append_backward(loss)
cpu = core.CPUPlace()
exe = Executor(cpu)
d = []
for i in range(3):
d.append(numpy.random.random(size=[10]).astype('float32'))
outs = exe.run(feed={'d0': d[0],
'd1': d[1],
'd2': d[2]},
fetch_list=[sum_result])
self.assertAlmostEqual(numpy.sum(d), numpy.sum(outs[0]), delta=0.01)
append_backward(loss)
def test_simple_net_forward(self):
main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
self.simple_net()
binary = fluid.compiler.CompiledProgram(main_program)
cpu = core.CPUPlace()
exe = Executor(cpu)
d = []
cpu = core.CPUPlace()
exe = Executor(cpu)
d = []
for i in range(3):
d.append(numpy.random.random(size=[10]).astype('float32'))
for i in range(3):
d.append(numpy.random.random(size=[10]).astype('float32'))
outs = exe.run(feed={'d0': d[0],
'd1': d[1],
'd2': d[2]},
fetch_list=[sum_result])
self.assertAlmostEqual(numpy.sum(d), numpy.sum(outs[0]), delta=0.01)
for _ in range(2):
exe.run(binary, feed={'d0': d[0], 'd1': d[1], 'd2': d[2]})
def test_exceptions(self):
i = layers.zeros(shape=[2], dtype='int64')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册