diff --git a/paddle/fluid/operators/controlflow/conditional_block_op.cc b/paddle/fluid/operators/controlflow/conditional_block_op.cc index f22f584213ce325d3fda17efb0cb48725ede7e53..be8d10c0506a91fe20252640c84dece634d951fe 100644 --- a/paddle/fluid/operators/controlflow/conditional_block_op.cc +++ b/paddle/fluid/operators/controlflow/conditional_block_op.cc @@ -74,6 +74,15 @@ class ConditionalBlockOp : public ConditionalOp { } }; +class ConditionalBlockInferShape : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *context) const override { + PADDLE_ENFORCE_EQ(context->HasInputs(ConditionalOp::kCondition), true, + platform::errors::InvalidArgument( + "conditional_block_op must have condition input")); + } +}; + class ConditionalBlockGradOp : public ConditionalOp { public: ConditionalBlockGradOp(const std::string &type, @@ -278,6 +287,7 @@ class ConditionalBlockGradMaker : public framework::SingleGradOpMaker { namespace ops = paddle::operators; REGISTER_OPERATOR(conditional_block, ops::ConditionalBlockOp, + ops::ConditionalBlockInferShape, ops::ConditionalBlockOpProtoMaker, ops::ConditionalBlockGradMaker); REGISTER_OPERATOR(conditional_block_grad, ops::ConditionalBlockGradOp, diff --git a/python/paddle/fluid/tests/unittests/test_conditional_block.py b/python/paddle/fluid/tests/unittests/test_conditional_block.py index 5b2b71d050c42b4fea84bab89824d3f5c164b36e..6a71d396b48b089d7cf06df1d143f3957a61deb1 100644 --- a/python/paddle/fluid/tests/unittests/test_conditional_block.py +++ b/python/paddle/fluid/tests/unittests/test_conditional_block.py @@ -14,42 +14,70 @@ from __future__ import print_function +import numpy as np import unittest +import paddle.fluid as fluid import paddle.fluid.layers as layers import paddle.fluid.core as core -from paddle.fluid.framework import default_startup_program, default_main_program from paddle.fluid.executor import Executor from paddle.fluid.backward import append_backward from paddle.fluid.layers.control_flow import ConditionalBlock -import numpy class ConditionalBlockTest(unittest.TestCase): def test_forward(self): - data = layers.data(name='X', shape=[1], dtype='float32') - data.stop_gradient = False - cond = ConditionalBlock(inputs=[data]) - out = layers.create_tensor(dtype='float32') - with cond.block(): - hidden = layers.fc(input=data, size=10) - layers.assign(hidden, out) - - cpu = core.CPUPlace() - exe = Executor(cpu) - exe.run(default_startup_program()) - - x = numpy.random.random(size=(10, 1)).astype('float32') - - outs = exe.run(feed={'X': x}, fetch_list=[out])[0] - print(outs) - loss = layers.mean(out) - append_backward(loss=loss) - outs = exe.run( - feed={'X': x}, - fetch_list=[ - default_main_program().block(0).var(data.name + "@GRAD") - ])[0] - print(outs) + main_program = fluid.Program() + startup_program = fluid.Program() + with fluid.program_guard(main_program, startup_program): + data = layers.data(name='X', shape=[1], dtype='float32') + data.stop_gradient = False + cond = ConditionalBlock(inputs=[data]) + out = layers.create_tensor(dtype='float32') + with cond.block(): + hidden = layers.fc(input=data, size=10) + layers.assign(hidden, out) + + cpu = core.CPUPlace() + exe = Executor(cpu) + exe.run(startup_program) + + x = np.random.random(size=(10, 1)).astype('float32') + + outs = exe.run(main_program, feed={'X': x}, fetch_list=[out])[0] + print(outs) + loss = layers.mean(out) + append_backward(loss=loss) + outs = exe.run( + main_program, + feed={'X': x}, + fetch_list=[main_program.block(0).var(data.name + "@GRAD")])[0] + print(outs) + + +class TestConditionalBlockOpInferShape(unittest.TestCase): + def test_infer_shape(self): + main_program = fluid.Program() + startup_program = fluid.Program() + with fluid.program_guard(main_program, startup_program): + global_block = main_program.global_block() + sub_block = main_program._create_block() + main_program._rollback() + step_scope = global_block.create_var( + type=core.VarDesc.VarType.STEP_SCOPES) + cond_var = layers.fill_constant( + shape=[1], dtype='bool', value=False) + + op = global_block.append_op( + type='conditional_block', + inputs={ + 'Cond': [cond_var], + 'Input': [], + }, + outputs={'Out': [], + 'Scope': [step_scope]}, + attrs={'sub_block': sub_block, + 'is_scalar_condition': True}) + op.desc.infer_shape(global_block.desc) if __name__ == '__main__':