diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 9bf60b7b11636df97031d111031ef782a173006b..c08c6bba592fb490cc033ed22698b238b9b5477c 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -52,7 +52,8 @@ std::vector OperatorBase::Inputs(const std::string& name) const { PADDLE_ENFORCE(in_out_idxs_ != nullptr, "IO Idx could not be nullptr"); auto input_format = GetAttr>("input_format"); auto offset = in_out_idxs_->at(name); - PADDLE_ENFORCE(input_format.at((size_t)offset + 1) <= inputs_.size(), + PADDLE_ENFORCE(input_format.at(static_cast(offset) + 1) <= + static_cast(inputs_.size()), "Input Out Of Range"); return std::vector{ @@ -78,7 +79,8 @@ std::vector OperatorBase::Outputs(const std::string& name) const { PADDLE_ENFORCE(in_out_idxs_ != nullptr, "InOut Indice could not be nullptr"); auto output_format = GetAttr>("output_format"); auto offset = in_out_idxs_->at(name); - PADDLE_ENFORCE(output_format.at((size_t)offset + 1) <= outputs_.size(), + PADDLE_ENFORCE(output_format.at(static_cast(offset) + 1) <= + static_cast(outputs_.size()), "Output Out of Range"); return std::vector{ outputs_.begin() + output_format.at(offset), diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index ef1521b83bb50774d7b4f710a5deff879373ba35..ff518265a45382115b53d61cab52b85a73de5f70 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -108,11 +108,11 @@ class OperatorContext { size_t OutputSize() const { return op_.outputs_.size(); } - const Variable* InputVar(const size_t& index) const { + const Variable* InputVar(const size_t index) const { return scope_->GetVariable(op_.inputs_.at(index)); } - Variable* OutputVar(const size_t& index) const { + Variable* OutputVar(const size_t index) const { return scope_->GetVariable(op_.outputs_.at(index)); } @@ -146,23 +146,31 @@ class OperatorContext { } template - const T* Input(const size_t& index) const { - return &(InputVar(index)->Get()); + const T* Input(const size_t index) const { + auto var = InputVar(index); + PADDLE_ENFORCE(var != nullptr, "Input(%d) should not be nullptr", index); + return &var->Get(); } template - T* Output(const size_t& index) const { - return OutputVar(index)->GetMutable(); + T* Output(const size_t index) const { + auto var = OutputVar(index); + PADDLE_ENFORCE(var != nullptr, "Output(%d) should not be nullptr", index); + return var->GetMutable(); } template const T* Input(const std::string& name) const { - return &(InputVar(name)->Get()); + auto var = InputVar(name); + PADDLE_ENFORCE(var != nullptr, "Input(%s) should not be nullptr", name); + return &var->Get(); } template T* Output(const std::string& name) const { - return OutputVar(name)->GetMutable(); + auto var = OutputVar(name); + PADDLE_ENFORCE(var != nullptr, "Output(%s) should not be nullptr", name); + return var->GetMutable(); } template @@ -171,8 +179,12 @@ class OperatorContext { std::vector res; res.reserve(names.size()); std::transform(names.begin(), names.end(), std::back_inserter(res), - [this](const std::string& name) { - return &scope_->GetVariable(name)->Get(); + [&](const std::string& sub_name) { + auto var = scope_->GetVariable(sub_name); + PADDLE_ENFORCE(var != nullptr, + "MultiInput(%s:%s) should not be nullptr", + name, sub_name); + return &var->Get(); }); return res; } @@ -183,8 +195,12 @@ class OperatorContext { std::vector res; res.reserve(names.size()); std::transform(names.begin(), names.end(), std::back_inserter(res), - [this](const std::string& name) { - return scope_->GetVariable(name)->GetMutable(); + [&](const std::string& sub_name) { + auto var = scope_->GetVariable(sub_name); + PADDLE_ENFORCE(var != nullptr, + "MultiOutput(%s:%s) should not be nullptr", + name, sub_name); + return var->GetMutable(); }); return res; }