diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index e3c510b70346a2baf6ccd756eaf689c146efee5f..cfe9cba308556475ef64b45e7178dfc418761598 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -52,7 +52,8 @@ std::vector OperatorBase::Inputs(const std::string& name) const { PADDLE_ENFORCE(in_out_idxs_ != nullptr, "IO Idx could not be nullptr"); auto input_format = GetAttr>("input_format"); auto offset = in_out_idxs_->at(name); - PADDLE_ENFORCE(input_format.at((size_t)offset + 1) <= (int)inputs_.size(), + PADDLE_ENFORCE(input_format.at(static_cast(offset) + 1) <= + static_cast(inputs_.size()), "Input Out Of Range"); return std::vector{ @@ -78,7 +79,8 @@ std::vector OperatorBase::Outputs(const std::string& name) const { PADDLE_ENFORCE(in_out_idxs_ != nullptr, "InOut Indice could not be nullptr"); auto output_format = GetAttr>("output_format"); auto offset = in_out_idxs_->at(name); - PADDLE_ENFORCE(output_format.at((size_t)offset + 1) <= (int)outputs_.size(), + PADDLE_ENFORCE(output_format.at(static_cast(offset) + 1) <= + static_cast(outputs_.size()), "Output Out of Range"); return std::vector{ outputs_.begin() + output_format.at(offset), diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 6a9fe19b9b61333cf9db1cca3e34c72f3f9c99c5..0832a663dd01fe2921366d70599bc867e73af47c 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -161,22 +161,30 @@ class OperatorContext { template const T* Input(const size_t index) const { - return &(InputVar(index)->Get()); + auto var = InputVar(index); + PADDLE_ENFORCE(var != nullptr, "Input(%d) should not be nullptr", index); + return &var->Get(); } template T* Output(const size_t index) const { - return OutputVar(index)->GetMutable(); + auto var = OutputVar(index); + PADDLE_ENFORCE(var != nullptr, "Output(%d) should not be nullptr", index); + return var->GetMutable(); } template const T* Input(const std::string& name) const { - return &(InputVar(name)->Get()); + auto var = InputVar(name); + PADDLE_ENFORCE(var != nullptr, "Input(%s) should not be nullptr", name); + return &var->Get(); } template T* Output(const std::string& name) const { - return OutputVar(name)->GetMutable(); + auto var = OutputVar(name); + PADDLE_ENFORCE(var != nullptr, "Output(%s) should not be nullptr", name); + return var->GetMutable(); } template @@ -185,8 +193,12 @@ class OperatorContext { std::vector res; res.reserve(names.size()); std::transform(names.begin(), names.end(), std::back_inserter(res), - [this](const std::string& name) { - return &scope_.FindVar(name)->Get(); + [&](const std::string& sub_name) { + auto var = scope_.FindVar(sub_name); + PADDLE_ENFORCE(var != nullptr, + "MultiInput(%s:%s) should not be nullptr", + name, sub_name); + return &var->Get(); }); return res; } @@ -197,8 +209,12 @@ class OperatorContext { std::vector res; res.reserve(names.size()); std::transform(names.begin(), names.end(), std::back_inserter(res), - [this](const std::string& name) { - return scope_.FindVar(name)->GetMutable(); + [&](const std::string& sub_name) { + auto var = scope_.FindVar(sub_name); + PADDLE_ENFORCE(var != nullptr, + "MultiOutput(%s:%s) should not be nullptr", + name, sub_name); + return var->GetMutable(); }); return res; }