From a756360263d60ddb38fc077fe98385cca6e46232 Mon Sep 17 00:00:00 2001 From: gfwm0502 <63550764+gfwm0502@users.noreply.github.com> Date: Fri, 17 Apr 2020 14:12:56 +0800 Subject: [PATCH] OP/API (While/while_loop/DynamicRNN) : Error Message Enhancement (#23896) As the title --- .../fluid/operators/controlflow/while_op.cc | 99 +++++++++++-------- .../operators/controlflow/while_op_helper.cc | 34 +++++-- python/paddle/fluid/layers/control_flow.py | 47 ++++----- .../fluid/tests/unittests/test_dyn_rnn.py | 33 +++++++ 4 files changed, 137 insertions(+), 76 deletions(-) diff --git a/paddle/fluid/operators/controlflow/while_op.cc b/paddle/fluid/operators/controlflow/while_op.cc index f6aaa49eced..25b55586463 100644 --- a/paddle/fluid/operators/controlflow/while_op.cc +++ b/paddle/fluid/operators/controlflow/while_op.cc @@ -49,10 +49,17 @@ class WhileOp : public framework::OperatorBase { private: void RunImpl(const framework::Scope &scope, const platform::Place &dev_place) const override { - PADDLE_ENFORCE_NOT_NULL(scope.FindVar(Input(kCondition))); + PADDLE_ENFORCE_NOT_NULL(scope.FindVar(Input(kCondition)), + platform::errors::NotFound( + "Input(Condition) of WhileOp is not found.")); auto &cond = scope.FindVar(Input(kCondition))->Get(); - PADDLE_ENFORCE_EQ(cond.dims(), paddle::framework::make_ddim({1})); + PADDLE_ENFORCE_EQ( + cond.dims(), paddle::framework::make_ddim({1}), + platform::errors::InvalidArgument( + "The shape of Input(Condition) of WhileOp must be 1. But now " + "the Condition's shape is ", + cond.dims().to_str(), ".\n")); framework::Executor executor(dev_place); auto *block = Attr(kStepBlock); @@ -72,7 +79,9 @@ class WhileOp : public framework::OperatorBase { step_scopes->clear(); } - PADDLE_ENFORCE_EQ(step_scopes->size(), 0, "The StepScope should be empty."); + PADDLE_ENFORCE_EQ(step_scopes->size(), 0, + platform::errors::PreconditionNotMet( + "The Output(StepScope) of WhileOp should be empty.")); bool cond_data = GetCondData(cond); bool is_test = Attr("is_test"); @@ -160,8 +169,10 @@ class WhileGradOp : public framework::OperatorBase { private: void RunImpl(const framework::Scope &scope, const platform::Place &dev_place) const override { - PADDLE_ENFORCE(!Attr("is_test"), - "GradOp is only callable when is_test is false"); + PADDLE_ENFORCE_EQ( + Attr("is_test"), false, + platform::errors::InvalidArgument( + "WhileGradOp is only callable when is_test is false.")); // get device context from pool platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto &dev_ctx = *pool.Get(dev_place); @@ -180,7 +191,14 @@ class WhileGradOp : public framework::OperatorBase { auto inside_og_names = Attr>("original_output_grad"); - PADDLE_ENFORCE_EQ(outside_og_names.size(), inside_og_names.size()); + PADDLE_ENFORCE_EQ(outside_og_names.size(), inside_og_names.size(), + platform::errors::InvalidArgument( + "The number of original output gradient names " + "does not match the number of backward input " + "gradient names. The number of Backward input " + "names is %d and the numbers of original output " + "gradient names is %d.", + outside_og_names.size(), inside_og_names.size())); for (auto cur_scope_iter = step_scopes->rbegin(); cur_scope_iter != step_scopes->rend(); ++cur_scope_iter) { @@ -222,11 +240,18 @@ class WhileGradOp : public framework::OperatorBase { inside_array[j].set_lod(outside_array->at(j).lod()); inside_array[j].ShareDataWith(outside_array->at(j)); } else { - PADDLE_ENFORCE_EQ(inside_array[j].numel(), 0); + PADDLE_ENFORCE_EQ( + inside_array[j].numel(), 0, + platform::errors::InvalidArgument( + "The numel of %d-th element of var %s (LoDTensorArray) " + "in while block must be 0, but received its numel is %d.", + j, inside_og_name, inside_array[j].numel())); } } } else { - PADDLE_THROW("Currently only support LoDTensor and LoDTensorArray."); + PADDLE_THROW(platform::errors::Unimplemented( + "Currently only support LoDTensor and LoDTensorArray in " + "WhileGradOp.")); } } executor.RunPreparedContext(ctx.get(), *cur_scope_iter, false, true, @@ -236,7 +261,13 @@ class WhileGradOp : public framework::OperatorBase { // and inputs. auto &pg_ig_names = Outputs(kXGRAD); auto &p_names = Inputs(kX); - PADDLE_ENFORCE_EQ(pg_ig_names.size(), p_names.size()); + PADDLE_ENFORCE_EQ(pg_ig_names.size(), p_names.size(), + platform::errors::PreconditionNotMet( + "The number of names in Outputs(X@GRAD) does not " + "match the number of names in Inputs(X). The " + "number of names in Outputs(X@GRAD) is %d and " + "the number of names in Inputs(X) is %d.", + pg_ig_names.size(), p_names.size())); for (size_t param_id = 0; param_id < pg_ig_names.size(); ++param_id) { if (pg_ig_names[param_id] == framework::kEmptyVarName) { continue; // parameter doesn't have gradient @@ -247,7 +278,9 @@ class WhileGradOp : public framework::OperatorBase { // for example lookup_table_grad_op, the input(Idx) doesn't have // gradient. auto pg_ig_var = cur_scope.FindVar(inside_grad_name); - PADDLE_ENFORCE(pg_ig_var != nullptr); + PADDLE_ENFORCE_NOT_NULL( + pg_ig_var, platform::errors::NotFound("Variable %s is not found.", + inside_grad_name)); if (pg_ig_var->IsType()) { auto pg_ig_lod_t_arr = pg_ig_var->GetMutable(); @@ -277,13 +310,16 @@ class WhileGradOp : public framework::OperatorBase { // zero gradient variable in step 0 if (cur_scope_iter == step_scopes->rbegin()) { auto *var = (*cur_scope_iter)->FindVar(inside_grad_name); - PADDLE_ENFORCE_NOT_NULL(var, "Can not find var %s", inside_grad_name); - PADDLE_ENFORCE( + PADDLE_ENFORCE_NOT_NULL( + var, platform::errors::NotFound("Variable %s is not found.", + inside_grad_name)); + PADDLE_ENFORCE_EQ( var->IsType() || var->IsType(), - "Currently the type of var only can be LoDTensorArray, " - "or LoDTensor, but the received var[%s] is %s.", - inside_grad_name, framework::ToTypeName(var->Type())); + true, platform::errors::InvalidArgument( + "Currently the type of var only can be LoDTensorArray, " + "or LoDTensor, but the received var[%s] is %s.", + inside_grad_name, framework::ToTypeName(var->Type()))); if (var->IsType()) { auto &inside_tensor = var->Get(); @@ -422,41 +458,24 @@ class WhileGradOpShapeInference : public framework::InferShapeBase { ctx->HasOutputs(framework::GradVarName(kX)); ctx->HasInputs(kOutputs); ctx->HasInputs(framework::GradVarName(kOutputs)); - auto pg_ig_names = ctx->Outputs(kXGRAD); std::vector in_var_ptrs = ctx->GetInputVarPtrs(kX); std::vector out_var_ptrs = ctx->GetOutputVarPtrs(kXGRAD); - PADDLE_ENFORCE(in_var_ptrs.size() == out_var_ptrs.size()); + PADDLE_ENFORCE_EQ(in_var_ptrs.size(), out_var_ptrs.size(), + platform::errors::InvalidArgument( + "The size of Inputs(X) must be the same as " + "the size of Outputs(X@GRAD).")); for (size_t i = 0; i < in_var_ptrs.size(); ++i) { if (pg_ig_names[i] == framework::kEmptyVarName) { continue; } - if (ctx->IsRuntime()) { - framework::Variable *in_var = - boost::get(in_var_ptrs[i]); - framework::Variable *out_var = - boost::get(out_var_ptrs[i]); - - auto type = framework::ToVarType(in_var->Type()); - if (type == framework::proto::VarType::LOD_TENSOR) { - out_var->GetMutable()->Resize( - in_var->Get().dims()); - } else if (type == framework::proto::VarType::SELECTED_ROWS) { - out_var->GetMutable()->set_height( - in_var->Get().GetCompleteDims()[0]); - } else if (type == framework::proto::VarType::LOD_TENSOR_ARRAY) { - PADDLE_THROW("WhileGradOp doesn't support type %d", - static_cast(type)); - } - } else { - framework::VarDesc *in_var = - boost::get(in_var_ptrs[i]); - boost::get(out_var_ptrs[i]) - ->SetShape(in_var->GetShape()); - } + framework::VarDesc *in_var = + boost::get(in_var_ptrs[i]); + boost::get(out_var_ptrs[i]) + ->SetShape(in_var->GetShape()); } } }; diff --git a/paddle/fluid/operators/controlflow/while_op_helper.cc b/paddle/fluid/operators/controlflow/while_op_helper.cc index 6ac41af8326..a3fe71f3ec8 100644 --- a/paddle/fluid/operators/controlflow/while_op_helper.cc +++ b/paddle/fluid/operators/controlflow/while_op_helper.cc @@ -83,7 +83,11 @@ static void ModifyWhileOpAndWhileGradOpAttr(const OpVariant &fwd_op, auto &in_grads = bwd_op.Outputs().at(framework::GradVarName(kX)); PADDLE_ENFORCE_EQ( fwd_input.size(), in_grads.size(), - "Backward input gradient number does not match forward input number."); + platform::errors::PreconditionNotMet( + "Backward output gradient number does not match forward input number." + "The number of forward input number is %d and the number of backward " + "output geadient number is %d.", + fwd_input.size(), in_grads.size())); std::unordered_set backward_skip_vars; for (size_t i = 0; i < in_grads.size(); ++i) { @@ -104,7 +108,13 @@ static void ModifyWhileOpAndWhileGradOpAttr(const OpVariant &fwd_op, static void FindAllWhileAndWhileGradOp(const framework::ProgramDesc &program, std::vector *while_ops, std::vector *while_grad_ops) { - PADDLE_ENFORCE_GE(while_ops->size(), while_grad_ops->size()); + PADDLE_ENFORCE_GE( + while_ops->size(), while_grad_ops->size(), + platform::errors::PreconditionNotMet( + "There are more while_grad_ops than forward while_ops in the graph " + "or program, the number of while_ops is %d and the number of " + "while_grad_ops is %d.", + while_ops->size(), while_grad_ops->size())); for (size_t i = 1; i < program.Size(); ++i) { auto &block = program.Block(i); for (size_t j = 0; j < block.OpSize(); ++j) { @@ -117,8 +127,13 @@ static void FindAllWhileAndWhileGradOp(const framework::ProgramDesc &program, } } - PADDLE_ENFORCE_GE(while_ops->size(), while_grad_ops->size(), - "There are extra while_grad ops in the graph or program"); + PADDLE_ENFORCE_GE( + while_ops->size(), while_grad_ops->size(), + platform::errors::InvalidArgument( + "There are more while_grad_ops than forward while_ops in the graph " + "or program, the number of while_ops is %d and the number of " + "while_grad_ops is %d.", + while_ops->size(), while_grad_ops->size())); } static void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl( @@ -140,13 +155,16 @@ static void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl( const OpVariant *matched_fwd_op = nullptr; for (auto &fwd_op : while_op_set) { if (IsMatchedWhileOpAndWhileGradOp(fwd_op, bwd_op)) { - PADDLE_ENFORCE(matched_fwd_op == nullptr, - "Found multiple matched while ops"); + PADDLE_ENFORCE_EQ(matched_fwd_op, nullptr, + platform::errors::PreconditionNotMet( + "Found multiple while forward ops match while " + "grad ops.")); matched_fwd_op = &fwd_op; } } PADDLE_ENFORCE_NOT_NULL(matched_fwd_op, - "Cannot find matched forward while op."); + platform::errors::PreconditionNotMet( + "Cannot find matched forward while op.")); ModifyWhileOpAndWhileGradOpAttr(*matched_fwd_op, bwd_op); while_op_set.erase(*matched_fwd_op); } @@ -209,7 +227,7 @@ bool GetCondData(const framework::LoDTensor &cond) { #else PADDLE_THROW(platform::errors::PreconditionNotMet( "This version of PaddlePaddle does NOT support GPU but got GPU tensor " - "Cond in WhileOp. Please compile WITH_GPU option")); + "Cond in WhileOp. Please compile WITH_GPU option.")); #endif return cpu_cond->data()[0]; } diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py index 753b869202c..5727e5cb258 100755 --- a/python/paddle/fluid/layers/control_flow.py +++ b/python/paddle/fluid/layers/control_flow.py @@ -882,14 +882,10 @@ class While(object): def __init__(self, cond, is_test=False, name=None): self.helper = LayerHelper("while", name=name) self.status = While.BEFORE_WHILE_BLOCK - if not isinstance(cond, Variable): - raise TypeError("condition should be a variable") - assert isinstance(cond, Variable) - if cond.dtype != core.VarDesc.VarType.BOOL: - raise TypeError("condition should be a boolean variable") + check_variable_and_dtype(cond, 'cond', ['bool'], 'fluid.layers.While') if reduce(lambda a, b: a * b, cond.shape, 1) != 1: raise TypeError( - "condition expected shape as [], but given shape as {0}.". + "condition expected shape as [1], but given shape as {0}.". format(list(cond.shape))) self.cond_var = cond self.is_test = is_test @@ -999,19 +995,16 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None): raise TypeError("cond in while_loop should be callable") if not callable(body): raise TypeError("body in while_loop should be callable") - if not isinstance(loop_vars, (list, tuple)): - raise TypeError("loop_vars in while_loop should be a list or tuple") + check_type(loop_vars, 'loop_vars', (list, tuple), 'fluid.layers.while_loop') if len(loop_vars) == 0: raise ValueError("loop_vars in while_loop should not be empty") pre_cond = cond(*loop_vars) - if not isinstance(pre_cond, Variable): - raise TypeError("cond in while_loop should return a variable") - if pre_cond.dtype != core.VarDesc.VarType.BOOL: - raise TypeError("cond in while_loop should return a boolean variable") + check_variable_and_dtype(pre_cond, 'var of cond returned', ['bool'], + 'fluid.layers.while_loop') if reduce(lambda a, b: a * b, pre_cond.shape, 1) != 1: raise TypeError( - "the shape of the variable returned by cond should be []," + "the shape of the variable returned by cond should be [1]," "but given shape as {0}.".format(list(pre_cond.shape))) if in_dygraph_mode(): @@ -2906,9 +2899,7 @@ class DynamicRNN(object): rnn_output = drnn() """ self._assert_in_rnn_block_("step_input") - if not isinstance(x, Variable): - raise TypeError( - "step_input() can only take a Variable as its input.") + check_type(x, 'x', Variable, 'fluid.layers.DynamicRNN.step_input()') parent_block = self._parent_block_() if self.lod_rank_table is None: self.lod_rank_table = parent_block.create_var( @@ -3075,9 +3066,7 @@ class DynamicRNN(object): rnn_output = drnn() """ self._assert_in_rnn_block_("static_input") - if not isinstance(x, Variable): - raise TypeError( - "static_input() can only take a Variable as its input") + check_type(x, 'x', Variable, 'fluid.layers.DynamicRNN.static_input()') if self.lod_rank_table is None: raise RuntimeError( "static_input() must be called after step_input().") @@ -3242,10 +3231,12 @@ class DynamicRNN(object): """ self._assert_in_rnn_block_('memory') self._init_zero_idx_() + if shape is not None: + check_type(shape, 'shape', (list, tuple), + 'fluid.layers.DynamicRNN.memory()') if init is not None: - if not isinstance(init, Variable): - raise TypeError( - "The input arg `init` of memory() must be a Variable") + check_type(init, 'init', Variable, + 'fluid.layers.DynamicRNN.memory()') parent_block = self._parent_block_() init_tensor = init if need_reorder == True: @@ -3326,12 +3317,10 @@ class DynamicRNN(object): ValueError: When :code:`update_memory()` is called before :code:`step_input()` . """ self._assert_in_rnn_block_('update_memory') - if not isinstance(ex_mem, Variable): - raise TypeError("The input arg `ex_mem` of update_memory() must " - "be a Variable") - if not isinstance(new_mem, Variable): - raise TypeError("The input arg `new_mem` of update_memory() must " - "be a Variable") + check_type(ex_mem, 'ex_mem', Variable, + 'fluid.layers.DynamicRNN.update_memory()') + check_type(new_mem, 'new_mem', Variable, + 'fluid.layers.DynamicRNN.update_memory()') mem_array = self.mem_dict.get(ex_mem.name, None) if mem_array is None: @@ -3358,6 +3347,8 @@ class DynamicRNN(object): self._assert_in_rnn_block_('output') parent_block = self._parent_block_() for each in outputs: + check_type(each, "outputs", Variable, + "fluid.layers.DynamicRNN.output") outside_array = parent_block.create_var( name=unique_name.generate_with_ignorable_key("_".join( [self.helper.name, "output_array", each.name])), diff --git a/python/paddle/fluid/tests/unittests/test_dyn_rnn.py b/python/paddle/fluid/tests/unittests/test_dyn_rnn.py index 24b54b288e5..78f4669f7a7 100644 --- a/python/paddle/fluid/tests/unittests/test_dyn_rnn.py +++ b/python/paddle/fluid/tests/unittests/test_dyn_rnn.py @@ -19,6 +19,7 @@ import paddle import unittest import numpy +from paddle.fluid.framework import Program, program_guard from paddle.fluid.layers.control_flow import lod_rank_table from paddle.fluid.layers.control_flow import max_sequence_len from paddle.fluid.layers.control_flow import lod_tensor_to_array @@ -299,5 +300,37 @@ class TestDynamicRNN(unittest.TestCase): self.train_data = train_data_orig +class TestDynamicRNNErrors(unittest.TestCase): + def test_errors(self): + with program_guard(Program(), Program()): + init = fluid.layers.zeros(shape=[1], dtype='float32') + shape = 'shape' + sentence = fluid.data( + name='sentence', shape=[None, 32], dtype='float32', lod_level=1) + + # The type of Input(shape) in API(memory) must be list or tuple + def input_shape_type_of_memory(): + drnn = fluid.layers.DynamicRNN() + with drnn.block(): + res = drnn.memory(init, shape) + + self.assertRaises(TypeError, input_shape_type_of_memory) + + # The type of element of Input(*outputs) in API(output) must be Variable. + def outputs_type_of_output(): + drnn = fluid.layers.DynamicRNN() + with drnn.block(): + word = drnn.step_input(sentence) + memory = drnn.memory(shape=[10], dtype='float32', value=0) + hidden = fluid.layers.fc(input=[word, memory], + size=10, + act='tanh') + out = np.ones(1).astype('float32') + drnn.update_memory(ex_mem=memory, new_mem=hidden) + drnn.output(hidden, out) + + self.assertRaises(TypeError, outputs_type_of_output) + + if __name__ == '__main__': unittest.main() -- GitLab