提交 3b58574b 编写于 作者: Q qiaolongfei

add check in OPeratorContext Input/Output

上级 61ebacbc
...@@ -52,7 +52,8 @@ std::vector<std::string> OperatorBase::Inputs(const std::string& name) const { ...@@ -52,7 +52,8 @@ std::vector<std::string> OperatorBase::Inputs(const std::string& name) const {
PADDLE_ENFORCE(in_out_idxs_ != nullptr, "IO Idx could not be nullptr"); PADDLE_ENFORCE(in_out_idxs_ != nullptr, "IO Idx could not be nullptr");
auto input_format = GetAttr<std::vector<int>>("input_format"); auto input_format = GetAttr<std::vector<int>>("input_format");
auto offset = in_out_idxs_->at(name); auto offset = in_out_idxs_->at(name);
PADDLE_ENFORCE(input_format.at((size_t)offset + 1) <= inputs_.size(), PADDLE_ENFORCE(input_format.at(static_cast<size_t>(offset) + 1) <=
static_cast<int>(inputs_.size()),
"Input Out Of Range"); "Input Out Of Range");
return std::vector<std::string>{ return std::vector<std::string>{
...@@ -78,7 +79,8 @@ std::vector<std::string> OperatorBase::Outputs(const std::string& name) const { ...@@ -78,7 +79,8 @@ std::vector<std::string> OperatorBase::Outputs(const std::string& name) const {
PADDLE_ENFORCE(in_out_idxs_ != nullptr, "InOut Indice could not be nullptr"); PADDLE_ENFORCE(in_out_idxs_ != nullptr, "InOut Indice could not be nullptr");
auto output_format = GetAttr<std::vector<int>>("output_format"); auto output_format = GetAttr<std::vector<int>>("output_format");
auto offset = in_out_idxs_->at(name); auto offset = in_out_idxs_->at(name);
PADDLE_ENFORCE(output_format.at((size_t)offset + 1) <= outputs_.size(), PADDLE_ENFORCE(output_format.at(static_cast<size_t>(offset) + 1) <=
static_cast<int>(outputs_.size()),
"Output Out of Range"); "Output Out of Range");
return std::vector<std::string>{ return std::vector<std::string>{
outputs_.begin() + output_format.at(offset), outputs_.begin() + output_format.at(offset),
......
...@@ -108,11 +108,11 @@ class OperatorContext { ...@@ -108,11 +108,11 @@ class OperatorContext {
size_t OutputSize() const { return op_.outputs_.size(); } size_t OutputSize() const { return op_.outputs_.size(); }
const Variable* InputVar(const size_t& index) const { const Variable* InputVar(const size_t index) const {
return scope_->GetVariable(op_.inputs_.at(index)); return scope_->GetVariable(op_.inputs_.at(index));
} }
Variable* OutputVar(const size_t& index) const { Variable* OutputVar(const size_t index) const {
return scope_->GetVariable(op_.outputs_.at(index)); return scope_->GetVariable(op_.outputs_.at(index));
} }
...@@ -146,23 +146,31 @@ class OperatorContext { ...@@ -146,23 +146,31 @@ class OperatorContext {
} }
template <typename T> template <typename T>
const T* Input(const size_t& index) const { const T* Input(const size_t index) const {
return &(InputVar(index)->Get<T>()); auto var = InputVar(index);
PADDLE_ENFORCE(var != nullptr, "Input(%d) should not be nullptr", index);
return &var->Get<T>();
} }
template <typename T> template <typename T>
T* Output(const size_t& index) const { T* Output(const size_t index) const {
return OutputVar(index)->GetMutable<T>(); auto var = OutputVar(index);
PADDLE_ENFORCE(var != nullptr, "Output(%d) should not be nullptr", index);
return var->GetMutable<T>();
} }
template <typename T> template <typename T>
const T* Input(const std::string& name) const { const T* Input(const std::string& name) const {
return &(InputVar(name)->Get<T>()); auto var = InputVar(name);
PADDLE_ENFORCE(var != nullptr, "Input(%s) should not be nullptr", name);
return &var->Get<T>();
} }
template <typename T> template <typename T>
T* Output(const std::string& name) const { T* Output(const std::string& name) const {
return OutputVar(name)->GetMutable<T>(); auto var = OutputVar(name);
PADDLE_ENFORCE(var != nullptr, "Output(%s) should not be nullptr", name);
return var->GetMutable<T>();
} }
template <typename T> template <typename T>
...@@ -171,8 +179,12 @@ class OperatorContext { ...@@ -171,8 +179,12 @@ class OperatorContext {
std::vector<const T*> res; std::vector<const T*> res;
res.reserve(names.size()); res.reserve(names.size());
std::transform(names.begin(), names.end(), std::back_inserter(res), std::transform(names.begin(), names.end(), std::back_inserter(res),
[this](const std::string& name) { [&](const std::string& sub_name) {
return &scope_->GetVariable(name)->Get<T>(); auto var = scope_->GetVariable(sub_name);
PADDLE_ENFORCE(var != nullptr,
"MultiInput(%s:%s) should not be nullptr",
name, sub_name);
return &var->Get<T>();
}); });
return res; return res;
} }
...@@ -183,8 +195,12 @@ class OperatorContext { ...@@ -183,8 +195,12 @@ class OperatorContext {
std::vector<const T*> res; std::vector<const T*> res;
res.reserve(names.size()); res.reserve(names.size());
std::transform(names.begin(), names.end(), std::back_inserter(res), std::transform(names.begin(), names.end(), std::back_inserter(res),
[this](const std::string& name) { [&](const std::string& sub_name) {
return scope_->GetVariable(name)->GetMutable<T>(); auto var = scope_->GetVariable(sub_name);
PADDLE_ENFORCE(var != nullptr,
"MultiOutput(%s:%s) should not be nullptr",
name, sub_name);
return var->GetMutable<T>();
}); });
return res; return res;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册