diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index de15c8be7a7fc82f43ca31d03fc3c92e51b39f36..5bdb27683ee68440c590e6ca7f141ea0e80d9ba9 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -21,6 +21,7 @@ limitations under the License. */ #include "google/protobuf/io/zero_copy_stream_impl.h" #include "google/protobuf/message.h" #include "google/protobuf/text_format.h" +#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/feed_fetch_method.h" #include "paddle/fluid/framework/lod_rank_table.h" #include "paddle/fluid/framework/lod_tensor_array.h" @@ -145,8 +146,10 @@ std::shared_ptr Executor::InitForDataset( VLOG(3) << "Start to RunFromDataset in executor"; TrainerDesc trainer_desc; bool success = trainer_desc.ParseFromString(trainer_desc_str); - PADDLE_ENFORCE_EQ(success, true, "Fail to parse TrainerDesc from string:\n%s", - trainer_desc_str.c_str()); + PADDLE_ENFORCE_EQ(success, true, + platform::errors::PreconditionNotMet( + "Fail to parse TrainerDesc from string:\n%s", + trainer_desc_str.c_str())); VLOG(3) << "Going to create trainer, trainer class is " << trainer_desc.class_name(); std::shared_ptr trainer; @@ -165,8 +168,9 @@ std::shared_ptr Executor::InitForDataset( } void Executor::RunFromDataset(std::shared_ptr trainer) { - PADDLE_ENFORCE_NE(trainer, nullptr, - "Trainer is nullptr, invoke InitForDataset first"); + PADDLE_ENFORCE_NOT_NULL( + trainer, platform::errors::InvalidArgument( + "Trainer is nullptr, invoke InitForDataset first")); // training and finalize training VLOG(3) << "Trainer starts to run"; trainer->Run(); @@ -203,29 +207,41 @@ static bool has_feed_operators( if (op->Type() == kFeedOpType) { feed_count++; // The input variable's name of feed_op should be feed_holder_name. - PADDLE_ENFORCE_EQ(op->Input("X")[0], feed_holder_name, - "Input to feed op should be '%s'", feed_holder_name); + PADDLE_ENFORCE_EQ( + op->Input("X")[0], feed_holder_name, + platform::errors::PreconditionNotMet( + "Input to feed op should be '%s', but received '%s'.", + feed_holder_name, op->Input("X")[0])); std::string feed_target_name = op->Output("Out")[0]; - PADDLE_ENFORCE( - feed_targets.find(feed_target_name) != feed_targets.end(), - "Feed operator output name '%s' cannot be found in 'feed_targets'", - feed_target_name); + PADDLE_ENFORCE_NE(feed_targets.find(feed_target_name), feed_targets.end(), + platform::errors::PreconditionNotMet( + "Feed operator output name '%s' cannot be found in " + "'feed_targets'", + feed_target_name)); } } if (feed_count > 0) { PADDLE_ENFORCE_EQ( feed_count, feed_targets.size(), - "The number of feed operators should match 'feed_targets'"); + platform::errors::PreconditionNotMet( + "The number of feed operators should match 'feed_targets', but " + "received feed_count: %zu, required feed_targets.size(): %zu.", + feed_count, feed_targets.size())); if (!feed_holder_name.empty()) { // When feed operator are present, so should be feed_holder. auto var = block.FindVar(feed_holder_name); - PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable", - feed_holder_name); - PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FEED_MINIBATCH, - "'%s' variable should be 'FEED_MINIBATCH' type", - feed_holder_name); + PADDLE_ENFORCE_NOT_NULL( + var, + platform::errors::PreconditionNotMet( + "Block should already have a '%s' variable", feed_holder_name)); + PADDLE_ENFORCE_EQ( + var->GetType(), proto::VarType::FEED_MINIBATCH, + platform::errors::PreconditionNotMet( + "'%s' variable should be 'FEED_MINIBATCH' type, but received " + "'%s'.", + feed_holder_name, DataTypeToString(var->GetType()))); } } @@ -247,29 +263,41 @@ static bool has_fetch_operators( if (op->Type() == kFetchOpType) { fetch_count++; // The output variable's name of fetch_op should be fetch_holder_name. - PADDLE_ENFORCE_EQ(op->Output("Out")[0], fetch_holder_name, - "Output of fetch op should be '%s'", fetch_holder_name); + PADDLE_ENFORCE_EQ( + op->Output("Out")[0], fetch_holder_name, + platform::errors::PreconditionNotMet( + "Output of fetch op should be '%s', but received '%s'.", + fetch_holder_name, op->Output("Out")[0])); std::string fetch_target_name = op->Input("X")[0]; - PADDLE_ENFORCE( - fetch_targets.find(fetch_target_name) != fetch_targets.end(), - "Fetch operator input name '%s' cannot be found in 'fetch_targets'", - fetch_target_name); + PADDLE_ENFORCE_NE(fetch_targets.find(fetch_target_name), + fetch_targets.end(), + platform::errors::NotFound( + "Fetch operator input name '%s' cannot be found in " + "'fetch_targets'.", + fetch_target_name)); } } if (fetch_count > 0) { PADDLE_ENFORCE_EQ( fetch_count, fetch_targets.size(), - "The number of fetch operators should match 'fetch_targets'"); + platform::errors::PreconditionNotMet( + "The number of fetch operators should match 'fetch_targets', but " + "received fetch_count: %zu, required fetch_targets.size(): %zu.", + fetch_count, fetch_targets.size())); if (!fetch_holder_name.empty()) { // When fetch operator are present, so should be fetch_holder. auto var = block.FindVar(fetch_holder_name); - PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable", - fetch_holder_name); - PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FETCH_LIST, - "'%s' variable should be 'FETCH_LIST' type", - fetch_holder_name); + PADDLE_ENFORCE_NOT_NULL( + var, + platform::errors::PreconditionNotMet( + "Block should already have a '%s' variable.", fetch_holder_name)); + PADDLE_ENFORCE_EQ( + var->GetType(), proto::VarType::FETCH_LIST, + platform::errors::PreconditionNotMet( + "'%s' variable should be 'FETCH_LIST' type, but received '%s'.", + fetch_holder_name, DataTypeToString(var->GetType()))); } } @@ -354,7 +382,11 @@ std::unique_ptr Executor::Prepare( const std::vector& skip_ref_cnt_vars, bool force_disable_gc) { std::unique_ptr ctx( new ExecutorPrepareContext(program, block_id)); - PADDLE_ENFORCE_LT(static_cast(block_id), program.Size()); + PADDLE_ENFORCE_LT(static_cast(block_id), program.Size(), + platform::errors::InvalidArgument( + "Input block id = %d, but it should be less than " + "program.size() which is %d", + static_cast(block_id), program.Size())); auto& block = program.Block(block_id); for (auto& op_desc : block.AllOps()) { ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc)); @@ -367,14 +399,20 @@ std::vector> Executor::Prepare( const ProgramDesc& program, const std::vector& block_ids, const std::vector>& skip_ref_cnt_vars, bool force_disable_gc) { - PADDLE_ENFORCE( + PADDLE_ENFORCE_EQ( skip_ref_cnt_vars.empty() || skip_ref_cnt_vars.size() == block_ids.size(), - "skip_ref_cnt_vars should be either empty or equals to block number %d", - block_ids.size()); + true, + platform::errors::InvalidArgument("skip_ref_cnt_vars should be either " + "empty or equals to block number %d", + block_ids.size())); std::vector> result; size_t idx = 0; for (auto& bid : block_ids) { - PADDLE_ENFORCE_LT(static_cast(bid), program.Size()); + PADDLE_ENFORCE_LT(static_cast(bid), program.Size(), + platform::errors::InvalidArgument( + "Input block id = %zu, but it should be less than " + "program.size() which is %zu", + static_cast(bid), program.Size())); auto* ctx = new ExecutorPrepareContext(program, bid); auto& block = program.Block(bid); for (auto& op_desc : block.AllOps()) { @@ -397,7 +435,8 @@ void Executor::RunPartialPreparedContext(ExecutorPrepareContext* ctx, bool create_local_scope, bool create_vars, bool keep_kids) { platform::RecordBlock b(kProgramId); - PADDLE_ENFORCE_NOT_NULL(scope); + PADDLE_ENFORCE_NOT_NULL( + scope, platform::errors::InvalidArgument("Scope shouldn't be null")); Scope* local_scope = scope; if (create_vars) { if (create_local_scope) { @@ -470,12 +509,14 @@ void Executor::RunPreparedContext( const std::string& fetch_holder_name) { auto& global_block = ctx->prog_.Block(ctx->block_id_); - PADDLE_ENFORCE( - has_feed_operators(global_block, *feed_targets, feed_holder_name), - "Program in ExecutorPrepareContext should has feed_ops."); - PADDLE_ENFORCE( + PADDLE_ENFORCE_EQ( + has_feed_operators(global_block, *feed_targets, feed_holder_name), true, + platform::errors::PreconditionNotMet( + "Program in ExecutorPrepareContext should has feed_ops.")); + PADDLE_ENFORCE_EQ( has_fetch_operators(global_block, *fetch_targets, fetch_holder_name), - "Program in the prepared context should has fetch_ops."); + true, platform::errors::PreconditionNotMet( + "Program in the prepared context should has fetch_ops.")); // map the data of feed_targets to feed_holder for (auto* op : global_block.AllOps()) { diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 851cce56dccbf9ce7d5ca7c04dabdf5716e8b8e1..6be9c5b4b7d50a8fd7ad7c4c0c5c9ebc9f3e73f6 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -155,7 +155,9 @@ class ParallelExecutorPrivate { VLOG(10) << "find nccl_id_var:" << var_name << ", nccl_id:" << nccl_id; } else { nccl_id = new ncclUniqueId(); - PADDLE_ENFORCE(platform::dynload::ncclGetUniqueId(nccl_id)); + PADDLE_ENFORCE_EQ( + platform::dynload::ncclGetUniqueId(nccl_id), ncclSuccess, + platform::errors::PreconditionNotMet("Get NCCL unique ID failed.")); VLOG(10) << "can't find nccl_id_var:" << var_name << ", nccl_id:" << nccl_id; } @@ -178,7 +180,9 @@ class ParallelExecutorPrivate { for (int i = 0; i < static_cast(bst.nccl_comm_num_); i++) { std::string var_name = platform::GetFlatNCCLVarName(i); auto nccl_id_var = scope->FindVar(var_name); - PADDLE_ENFORCE(nccl_id_var, "can't find %s nccl_id_var", var_name); + PADDLE_ENFORCE_NOT_NULL( + nccl_id_var, + platform::errors::NotFound("Can't find nccl_id_var '%s'.", var_name)); auto nccl_id = nccl_id_var->GetMutable(); flat_nccl_ids.push_back(nccl_id); } @@ -191,7 +195,9 @@ class ParallelExecutorPrivate { for (int i = 0; i < static_cast(bst.nccl_comm_num_); i++) { std::string var_name = platform::GetHierarchicalInterNCCLVarName(i); auto nccl_id_var = scope->FindVar(var_name); - PADDLE_ENFORCE(nccl_id_var, "can't find %s nccl_id_var", var_name); + PADDLE_ENFORCE_NOT_NULL(nccl_id_var, + platform::errors::NotFound( + "Can't find nccl_id_var '%s'.", var_name)); auto inter_nccl_id = nccl_id_var->GetMutable(); inter_nccl_ids.push_back(inter_nccl_id); } @@ -200,7 +206,9 @@ class ParallelExecutorPrivate { for (int i = 0; i < static_cast(bst.nccl_comm_num_); i++) { std::string var_name = platform::GetHierarchicalExterNCCLVarName(i); auto nccl_id_var = scope->FindVar(var_name); - PADDLE_ENFORCE(nccl_id_var, "can't find %s nccl_id_var", var_name); + PADDLE_ENFORCE_NOT_NULL(nccl_id_var, + platform::errors::NotFound( + "Can't find nccl_id_var '%s'.", var_name)); auto nccl_id = nccl_id_var->GetMutable(); exter_nccl_ids.push_back(nccl_id); } @@ -216,8 +224,9 @@ class ParallelExecutorPrivate { const std::string var_name = "NCCLCommunicator"; auto var = scope->FindVar(var_name); if (var != nullptr) { - PADDLE_ENFORCE(var->IsInitialized(), - "if %s exists, it must be initialized", var_name); + PADDLE_ENFORCE_EQ(var->IsInitialized(), true, + platform::errors::PreconditionNotMet( + "if %s exists, it must be initialized", var_name)); VLOG(1) << "find " << var_name << " in scope, so use it and does not recreate!"; nccl_ctxs_ = var->GetMutable(); @@ -225,15 +234,21 @@ class ParallelExecutorPrivate { } if (bst->use_hierarchical_allreduce_) { - PADDLE_ENFORCE(bst->num_trainers_ > 1, "num_trainers:%llu < 1", - bst->num_trainers_); - PADDLE_ENFORCE(bst->hierarchical_allreduce_inter_nranks_ > 1, - "inter_nranks:%d < 1", - bst->hierarchical_allreduce_inter_nranks_); - PADDLE_ENFORCE( - (bst->num_trainers_ % bst->hierarchical_allreduce_inter_nranks_ == 0), - "num_trainers:%llu mod inter_nranks:%d != 0", bst->num_trainers_, - bst->hierarchical_allreduce_inter_nranks_); + PADDLE_ENFORCE_GT( + bst->num_trainers_, 1, + platform::errors::PreconditionNotMet( + "The num_trainers should be greater than 1, but received %llu.", + bst->num_trainers_)); + PADDLE_ENFORCE_GT( + bst->hierarchical_allreduce_inter_nranks_, 1, + platform::errors::PreconditionNotMet( + "The inter_nranks should be greater than 1, but received %d.", + bst->hierarchical_allreduce_inter_nranks_)); + PADDLE_ENFORCE_EQ( + bst->num_trainers_ % bst->hierarchical_allreduce_inter_nranks_, 0, + platform::errors::PreconditionNotMet( + "num_trainers:%llu mod inter_nranks:%d != 0", bst->num_trainers_, + bst->hierarchical_allreduce_inter_nranks_)); bst->hierarchical_allreduce_exter_nranks_ = bst->num_trainers_ / bst->hierarchical_allreduce_inter_nranks_; @@ -369,7 +384,8 @@ ir::Graph *ParallelExecutorPrivate::ApplyMemoryOptimizePass(ir::Graph *graph) { max_memory_size)); VLOG(10) << "Created GarbageCollector at " << place; } else { - PADDLE_THROW("Unsupported place for garbage collection"); + PADDLE_THROW(platform::errors::PreconditionNotMet( + "Unsupported place for garbage collection")); } #ifdef PADDLE_WITH_CUDA } @@ -455,7 +471,9 @@ ParallelExecutor::ParallelExecutor(const std::vector &places, } #if defined(PADDLE_WITH_CUDA) && defined(_WIN32) if (member_->use_cuda_) { - PADDLE_ENFORCE(places.size() == 1, "Windows can support Single GPU only."); + PADDLE_ENFORCE_EQ( + places.size(), 1, + platform::errors::Unavailable("Windows can support Single GPU only.")); } #endif @@ -484,7 +502,11 @@ ParallelExecutor::ParallelExecutor(const std::vector &places, } } else { member_->own_local_scope_ = false; - PADDLE_ENFORCE_EQ(member_->places_.size(), local_scopes.size()); + PADDLE_ENFORCE_EQ(member_->places_.size(), local_scopes.size(), + platform::errors::PreconditionNotMet( + "member_->places_.size() = %d is not equal to " + "local_scopes.size() = %d", + member_->places_.size(), local_scopes.size())); for (size_t i = 0; i < member_->places_.size(); ++i) { member_->local_scopes_.emplace_back(&local_scopes[i]->NewScope()); } @@ -492,8 +514,9 @@ ParallelExecutor::ParallelExecutor(const std::vector &places, std::vector graphs; if (member_->build_strategy_.async_mode_) { - PADDLE_ENFORCE(!member_->use_cuda_, - "gpu mode does not support async_mode_ now!"); + PADDLE_ENFORCE_EQ(member_->use_cuda_, false, + platform::errors::Unavailable( + "gpu mode does not support async_mode_ now!")); graphs.push_back(graph); for (size_t i = 1; i < places.size(); ++i) { auto *tmp_graph = new ir::Graph(graph->OriginProgram()); @@ -622,8 +645,12 @@ ParallelExecutor::ParallelExecutor(const std::vector &places, scope_map.emplace(scope, &local_exec_scope); } - PADDLE_ENFORCE_EQ(member_->local_scopes_.size(), - member_->local_exec_scopes_.size()); + PADDLE_ENFORCE_EQ( + member_->local_scopes_.size(), member_->local_exec_scopes_.size(), + platform::errors::PreconditionNotMet( + "member_->local_scopes_.size() = %d is not equal to " + "member_->local_exec_scopes_.size() = %d", + member_->local_scopes_.size(), member_->local_exec_scopes_.size())); std::vector final_graphs; @@ -655,8 +682,8 @@ ParallelExecutor::ParallelExecutor(const std::vector &places, } } #else - PADDLE_THROW( - "Paddle should be compiled with CUDA for ParallelGraph Execution."); + PADDLE_THROW(platform::errors::PreconditionNotMet( + "Paddle should be compiled with CUDA for ParallelGraph Execution.")); #endif } else { bool has_drop_last_read_op = details::HasDropLastReadOp(*graph); @@ -746,7 +773,10 @@ void ParallelExecutor::BCastParamsToDevices( } PADDLE_ENFORCE_EQ(member_->places_.size(), buffers.size(), - "variables' buffer size to bcast NOT equal to places"); + platform::errors::PreconditionNotMet( + "variables' buffer size to bcast is %d, which is " + "NOT equal to places size %d", + buffers.size(), member_->places_.size())); { auto *nccl_ctxs = member_->nccl_ctxs_->DefaultFlatCtx(); platform::NCCLGroupGuard guard; @@ -889,15 +919,17 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes( "You should set the environment variable CPU_NUM in the system " "to determine the number of devices you need."; } - PADDLE_THROW(error_info); + PADDLE_THROW(platform::errors::PreconditionNotMet(error_info)); } else if (is_persistable) { if (lod_tensors.size() == 1) { lod_tensors.reserve(num_places); auto &tensor = lod_tensors.front(); - PADDLE_ENFORCE_EQ(tensor.dims(), pair.second.dims(), - "The dim doesn't match."); - PADDLE_ENFORCE_EQ(tensor.place(), member_->places_.at(0), - "The place doesn't match."); + PADDLE_ENFORCE_EQ( + tensor.dims(), pair.second.dims(), + platform::errors::PreconditionNotMet("The dim doesn't match.")); + PADDLE_ENFORCE_EQ( + tensor.place(), member_->places_.at(0), + platform::errors::PreconditionNotMet("The place doesn't match.")); for (size_t i = 1; i < num_places; ++i) { lod_tensors.emplace_back(); auto &tmp = lod_tensors.back(); @@ -914,7 +946,7 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes( "value, you should feed %d samples.", lod_tensors.size(), pair.first, num_places, (is_cpu_place ? "CPU" : "GPU"), pair.first, num_places, num_places); - PADDLE_THROW(error_info); + PADDLE_THROW(platform::errors::PreconditionNotMet(error_info)); } } diff --git a/paddle/fluid/operators/split_lod_tensor_op.cc b/paddle/fluid/operators/split_lod_tensor_op.cc index 72faac4471db2d4421276ef136febca353c79c81..aa8c0b13dbb2e585e97cd47a0340336b01c2c408 100644 --- a/paddle/fluid/operators/split_lod_tensor_op.cc +++ b/paddle/fluid/operators/split_lod_tensor_op.cc @@ -58,7 +58,8 @@ class SplitLoDTensorOp : public framework::OperatorBase { framework::TensorCopy(mask, platform::CPUPlace(), dev_ctx, cpu_mask.get()); #else - PADDLE_THROW("Not supported GPU, Please compile WITH_GPU option"); + PADDLE_THROW(paddle::platform::errors::Fatal( + "Not support GPU, Please compile WITH_GPU option")); #endif } auto *mask_data = cpu_mask->data(); @@ -146,34 +147,34 @@ class SplitLoDTensorOpProtoMaker : public framework::OpProtoAndCheckerMaker { class SplitLoDTensorInferShape : public framework::InferShapeBase { public: void operator()(framework::InferShapeContext *context) const override { - PADDLE_ENFORCE(context->HasInput("X"), - "SplitLoDTensorOp must have input X."); - PADDLE_ENFORCE(context->HasInput("Mask"), - "SplitLoDTensorOp must have input Mask."); - PADDLE_ENFORCE(context->HasOutput("OutTrue"), - "SplitLoDTensorOp must have output OutTrue."); - PADDLE_ENFORCE(context->HasOutput("OutFalse"), - "SplitLoDTensorOp must have output OutFalse."); + OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "SplitLoDTensor"); + OP_INOUT_CHECK(context->HasInput("Mask"), "Input", "Mask", + "SplitLoDTensor"); + OP_INOUT_CHECK(context->HasOutput("OutTrue"), "Output", "OutTrue", + "SplitLoDTensor"); + OP_INOUT_CHECK(context->HasOutput("OutFalse"), "Output", "OutFalse", + "SplitLoDTensor"); auto mask_dim = context->GetInputDim("Mask"); - PADDLE_ENFORCE_EQ(mask_dim.size(), 2, - "If you are using IfElse OP:" - "\n\nie = fluid.layers.IfElse(cond=cond)\nwith " - "ie.true_block():\n out_1 = ie.input(x)\n\n" - "Please ensure that the cond should be a 2-D tensor and " - "the second dim size of cond should be 1. " - "But now the cond's shape is [", - *mask_dim.Get(), "].\n"); - if (context->IsRuntime()) { - PADDLE_ENFORCE_EQ(mask_dim[1], 1, - "If you are using IfElse OP:" - "\n\nie = fluid.layers.IfElse(cond=cond)\nwith " - "ie.true_block():\n out_1 = ie.input(x)\n\n" - "Please ensure that the cond should be a 2-D tensor " - "and the second dim size of cond should be 1. " - "But now the cond's shape is [", - *mask_dim.Get(), "].\n"); - } + PADDLE_ENFORCE_EQ( + mask_dim.size(), 2, + platform::errors::InvalidArgument( + "If you are using IfElse OP:" + "\n\nie = fluid.layers.IfElse(cond=cond)\nwith " + "ie.true_block():\n out_1 = ie.input(x)\n\n" + "Please ensure that the cond should be a 2-D tensor and " + "the second dim size of cond should be 1. " + "But now the cond's shape is [", + *mask_dim.Get(), "].\n")); + PADDLE_ENFORCE_EQ(mask_dim[1], 1, + platform::errors::InvalidArgument( + "If you are using IfElse OP:" + "\n\nie = fluid.layers.IfElse(cond=cond)\nwith " + "ie.true_block():\n out_1 = ie.input(x)\n\n" + "Please ensure that the cond should be a 2-D tensor " + "and the second dim size of cond should be 1. " + "But now the cond's shape is [", + *mask_dim.Get(), "].\n")); context->SetOutputDim("OutTrue", context->GetInputDim("X")); context->SetOutputDim("OutFalse", context->GetInputDim("X")); diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py index dd94171e07ab55a4518fa7eecdbc290befff3754..fb54bfa4092a3360ca5362bbd172e45aa53ae01b 100755 --- a/python/paddle/fluid/layers/control_flow.py +++ b/python/paddle/fluid/layers/control_flow.py @@ -110,9 +110,9 @@ def split_lod_tensor(input, mask, level=0): data into two parts. Args: - input(tuple|list|None): The input tensor that contains complete + input(Variable|tuple|list|None): The input tensor that contains complete lod information needed to construct the output. - mask(list): A bool column vector which masks the input. + mask(Variable|list): A bool column vector which masks the input. level(int): The specific lod level to split. Returns: @@ -135,6 +135,10 @@ def split_lod_tensor(input, mask, level=0): input=x, mask=y, level=level) """ + check_type(input, 'input', (Variable, list, tuple, type(None)), + 'fluid.layers.split_lod_tensor') + check_type(mask, 'mask', (Variable, list), 'fluid.layers.split_lod_tensor') + check_type(level, 'level', int, 'fluid.layers.split_lod_tensor') helper = LayerHelper('split_lod_tensor', **locals()) out_true = helper.create_variable_for_type_inference(dtype=input.dtype) out_false = helper.create_variable_for_type_inference(dtype=input.dtype) @@ -407,6 +411,7 @@ class StaticRNN(object): AFTER_RNN_BLOCK = 2 def __init__(self, name=None): + check_type(name, "name", (str, type(None)), "fluid.layers.StaticRNN") self.helper = LayerHelper("static_rnn", name=name) self.memories = {} # memory map, from pre_mem.name --> MemoryLink self.inputs = [] # input variable list in current block @@ -511,6 +516,12 @@ class StaticRNN(object): """ self._assert_in_rnn_block_('memory') + check_type(init, "init", (Variable, type(None)), + "fluid.layers.StaticRNN.memory") + check_type(shape, "shape", (list, tuple, type(None)), + "fluid.layers.StaticRNN.memory") + check_type(batch_ref, "batch_ref", (Variable, type(None)), + "fluid.layers.StaticRNN.memory") if init is None: if shape is None or batch_ref is None: raise ValueError( @@ -587,8 +598,7 @@ class StaticRNN(object): """ self._assert_in_rnn_block_('step_input') - if not isinstance(x, Variable): - raise TypeError("step input takes a Variable") + check_type(x, "x", Variable, "fluid.layers.StaticRNN.step_input") if self.seq_len is None: self.seq_len = x.shape[0] elif x.shape[0] != -1 and self.seq_len != x.shape[0]: @@ -641,8 +651,7 @@ class StaticRNN(object): """ self._assert_in_rnn_block_('step_output') - if not isinstance(o, Variable): - raise TypeError("step output takes a Variable") + check_type(o, "o", Variable, "fluid.layers.StaticRNN.step_output") tmp_o = self.helper.create_variable_for_type_inference(dtype=o.dtype) self.helper.append_op( @@ -715,8 +724,8 @@ class StaticRNN(object): None """ - if not isinstance(mem, Variable) or not isinstance(var, Variable): - raise TypeError("update memory should take variables") + check_type(mem, "mem", Variable, "fluid.layers.StaticRNN.update_memory") + check_type(var, "var", Variable, "fluid.layers.StaticRNN.update_memory") self.memories[mem.name].mem = var def _parent_block(self): @@ -2134,6 +2143,8 @@ def cond(pred, true_fn=None, false_fn=None, name=None): return false_fn() return None + check_variable_and_dtype(pred, "pred", ['bool'], "fluid.layers.cond") + check_type(name, "name", (str, type(None)), "fluid.layers.cond") helper = LayerHelper('cond', **locals()) true_output = None false_output = None @@ -2545,8 +2556,8 @@ class IfElse(object): IN_IF_ELSE_FALSE_BLOCKS = 2 def __init__(self, cond, name=None): - if not isinstance(cond, Variable): - raise TypeError("cond must be a Variable") + check_type(cond, "cond", Variable, "fluid.layers.IfElse") + check_type(name, "name", (str, type(None)), "fluid.layers.IfElse") self.helper = LayerHelper('ifelse', name=name) self.cond = cond self.input_table = {} @@ -2605,8 +2616,8 @@ class IfElse(object): self.IN_IF_ELSE_TRUE_BLOCKS else 0] parent_block = self._parent_block() for each_out in outs: - if not isinstance(each_out, Variable): - raise TypeError("Each output should be a variable") + check_type(each_out, "each output", Variable, + "fluid.layers.IfElse.output") # create outside tensor outside_out = parent_block.create_var( name=unique_name.generate_with_ignorable_key("_".join( diff --git a/python/paddle/fluid/parallel_executor.py b/python/paddle/fluid/parallel_executor.py index 58380cf8e1440cab766f8a88c33341e22a0300ca..de69a1eea2eab2b2bbad0ffb379c3b416f225488 100644 --- a/python/paddle/fluid/parallel_executor.py +++ b/python/paddle/fluid/parallel_executor.py @@ -17,6 +17,7 @@ from . import core from . import framework from . import executor from . import compiler +from .data_feeder import check_type import sys __all__ = ['ParallelExecutor'] @@ -357,14 +358,14 @@ class ParallelExecutor(object): parallel_exe.drop_local_exe_scopes() """ - assert isinstance( - self._compiled_program._executor, - core.ParallelExecutor), "The Executor should be ParallelExecutor." + check_type(self._compiled_program._executor, + "the Executor of compiled program", core.ParallelExecutor, + "ParallelExecutor.drop_local_exe_scopes") self._compiled_program._executor.drop_local_exe_scopes() # This API is used to check whether DropLocalExeScopes can work. def _need_create_local_exe_scopes(self): - assert isinstance( - self._compiled_program._executor, - core.ParallelExecutor), "The Executor should be ParallelExecutor." + check_type(self._compiled_program._executor, + "the Executor of compiled program", core.ParallelExecutor, + "ParallelExecutor._need_create_local_exe_scopes") return self._compiled_program._executor._need_create_local_exe_scopes() diff --git a/python/paddle/fluid/tests/test_if_else_op.py b/python/paddle/fluid/tests/test_if_else_op.py index 0cadb747835ec33f996943676b332020d4451596..1c992b9d8cd38a3851f99b1fc78ef5639c7f6eef 100644 --- a/python/paddle/fluid/tests/test_if_else_op.py +++ b/python/paddle/fluid/tests/test_if_else_op.py @@ -221,5 +221,27 @@ class TestIfElseFalseBranch(TestIfElse): self.data = np.random.rand(25, 1).astype(np.float32) +class TestIfElseError(unittest.TestCase): + def test_input_type_error(self): + main_program = Program() + startup_program = Program() + with program_guard(main_program, startup_program): + src = layers.data(name='data', shape=[1], dtype='float32') + const_value = layers.fill_constant( + [1], dtype='float32', value=123.0) + ifcond = layers.less_than(x=src, y=const_value) + with self.assertRaises(TypeError): + ie = layers.IfElse(set()) + with self.assertRaises(TypeError): + ie = layers.IfElse(ifcond, set()) + + with self.assertRaises(TypeError): + ie = layers.IfElse(ifcond) + with ie.true_block(): + true_target = ie.input(src) + true_target = fluid.layers.exp(true_target) + ie.output([]) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_cond.py b/python/paddle/fluid/tests/unittests/test_cond.py index 20c5d1af97c3cfdb04663acc02bb376f8e11b4b8..ac36329bcf629ec8f89eda43f29a6345042c180f 100644 --- a/python/paddle/fluid/tests/unittests/test_cond.py +++ b/python/paddle/fluid/tests/unittests/test_cond.py @@ -539,5 +539,28 @@ class TestCondBackward(unittest.TestCase): use_parallel_exe) +class TestCondWithError(unittest.TestCase): + def test_input_type_error(self): + main_program = framework.Program() + startup_program = framework.Program() + with framework.program_guard(main_program, startup_program): + pred = fluid.data(name='y', shape=[1], dtype='bool') + + def func(): + return pred + + with self.assertRaises(TypeError): + layers.cond(None, func, func) + + with self.assertRaises(TypeError): + layers.cond(pred, func, set()) + + with self.assertRaises(TypeError): + layers.cond(pred, set(), func) + + with self.assertRaises(TypeError): + layers.cond(pred, func, func, set()) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_split_and_merge_lod_tensor_op.py b/python/paddle/fluid/tests/unittests/test_split_and_merge_lod_tensor_op.py index f407eb1d8b75cc3c29fc798c19d4284881dcdd49..ffbdddf4b88a4660088d18d38fd6667842186b56 100644 --- a/python/paddle/fluid/tests/unittests/test_split_and_merge_lod_tensor_op.py +++ b/python/paddle/fluid/tests/unittests/test_split_and_merge_lod_tensor_op.py @@ -221,5 +221,26 @@ class TestCPUSplitMergeLoDTensorGrad(unittest.TestCase): self.assertAlmostEqual(1.0, g_out_sum, delta=0.1) +class TestSplitLodTensorWithError(unittest.TestCase): + def test_error(self): + main_program = Program() + startup_program = Program() + with program_guard(main_program, startup_program): + x = layers.data( + name='x', shape=[1], dtype='float32', stop_gradient=False) + y = layers.data( + name='y', shape=[1], dtype='bool', stop_gradient=False) + level = 0 + + with self.assertRaises(TypeError): + split_lod_tensor(input=set(), mask=y, level=level) + + with self.assertRaises(TypeError): + split_lod_tensor(input=x, mask=set(), level=level) + + with self.assertRaises(TypeError): + split_lod_tensor(input=x, mask=set(), level=None) + + if __name__ == '__main__': unittest.main()