未验证 提交 6084af47 编写于 作者: Y Yiqun Liu 提交者: GitHub

Fix the bug when a input variable of op is dispensable. (#10268)

* Fix the bug when a input variable of op is dispensable.

* Add HasInputs/Outputs interfaces to OperatorBase.

* Remove the unreferenced header file.
上级 8a0c7e2e
...@@ -108,7 +108,7 @@ paddle_error paddle_matrix_get_row(paddle_matrix mat, ...@@ -108,7 +108,7 @@ paddle_error paddle_matrix_get_row(paddle_matrix mat,
paddle_error paddle_matrix_get_shape(paddle_matrix mat, paddle_error paddle_matrix_get_shape(paddle_matrix mat,
uint64_t* height, uint64_t* height,
uint64_t* width) { uint64_t* width) {
if (mat == nullptr) return kPD_NULLPTR; if (mat == nullptr || cast(mat)->mat == nullptr) return kPD_NULLPTR;
if (height != nullptr) { if (height != nullptr) {
*height = cast(mat)->mat->getHeight(); *height = cast(mat)->mat->getHeight();
} }
......
...@@ -93,6 +93,14 @@ void OperatorBase::Run(const Scope& scope, const platform::Place& place) { ...@@ -93,6 +93,14 @@ void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
RunImpl(scope, place); RunImpl(scope, place);
} }
bool OperatorBase::HasInputs(const std::string& name) const {
if (inputs_.find(name) != inputs_.end()) {
return true;
} else {
return false;
}
}
std::string OperatorBase::Input(const std::string& name) const { std::string OperatorBase::Input(const std::string& name) const {
auto& ins = Inputs(name); auto& ins = Inputs(name);
PADDLE_ENFORCE_LE(ins.size(), 1UL, PADDLE_ENFORCE_LE(ins.size(), 1UL,
...@@ -109,6 +117,14 @@ const std::vector<std::string>& OperatorBase::Inputs( ...@@ -109,6 +117,14 @@ const std::vector<std::string>& OperatorBase::Inputs(
return it->second; return it->second;
} }
bool OperatorBase::HasOutputs(const std::string& name) const {
if (outputs_.find(name) != outputs_.end()) {
return true;
} else {
return false;
}
}
std::string OperatorBase::Output(const std::string& name) const { std::string OperatorBase::Output(const std::string& name) const {
auto& outs = Outputs(name); auto& outs = Outputs(name);
PADDLE_ENFORCE_LE(outs.size(), 1UL, PADDLE_ENFORCE_LE(outs.size(), 1UL,
...@@ -220,13 +236,18 @@ void OperatorBase::CheckAllInputOutputSet() const { ...@@ -220,13 +236,18 @@ void OperatorBase::CheckAllInputOutputSet() const {
if (op_info == nullptr || op_info->proto_ == nullptr) return; if (op_info == nullptr || op_info->proto_ == nullptr) return;
for (auto& in : op_info->Proto().inputs()) { for (auto& in : op_info->Proto().inputs()) {
if (!in.dispensable()) {
PADDLE_ENFORCE(inputs_.find(in.name()) != inputs_.end(), PADDLE_ENFORCE(inputs_.find(in.name()) != inputs_.end(),
"Type %s's input %s is not set", Type(), in.name()); "Operator %s's input, %s, is not set", Type(), in.name());
}
} }
for (auto& out : op_info->Proto().outputs()) { for (auto& out : op_info->Proto().outputs()) {
if (!out.dispensable()) {
PADDLE_ENFORCE(outputs_.find(out.name()) != outputs_.end(), PADDLE_ENFORCE(outputs_.find(out.name()) != outputs_.end(),
"Type %s's output %s is not set", Type(), out.name()); "Operator %s's output, %s, is not set", Type(),
out.name());
}
} }
} }
...@@ -332,6 +353,9 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -332,6 +353,9 @@ class RuntimeInferShapeContext : public InferShapeContext {
: op_(op), scope_(scope) {} : op_(op), scope_(scope) {}
bool HasInput(const std::string& name) const override { bool HasInput(const std::string& name) const override {
if (!op_.HasInputs(name)) {
return false;
}
auto& ins = Inputs(name); auto& ins = Inputs(name);
size_t length = ins.size(); size_t length = ins.size();
if (length == 0) { if (length == 0) {
...@@ -345,6 +369,9 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -345,6 +369,9 @@ class RuntimeInferShapeContext : public InferShapeContext {
} }
bool HasOutput(const std::string& name) const override { bool HasOutput(const std::string& name) const override {
if (!op_.HasOutputs(name)) {
return false;
}
auto& outs = Outputs(name); auto& outs = Outputs(name);
size_t length = outs.size(); size_t length = outs.size();
if (length == 0) { if (length == 0) {
...@@ -358,6 +385,9 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -358,6 +385,9 @@ class RuntimeInferShapeContext : public InferShapeContext {
} }
bool HasInputs(const std::string& name) const override { bool HasInputs(const std::string& name) const override {
if (!op_.HasInputs(name)) {
return false;
}
auto inputs = op_.Inputs(name); auto inputs = op_.Inputs(name);
if (inputs.empty()) { if (inputs.empty()) {
return false; return false;
...@@ -371,6 +401,9 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -371,6 +401,9 @@ class RuntimeInferShapeContext : public InferShapeContext {
} }
bool HasOutputs(const std::string& name) const override { bool HasOutputs(const std::string& name) const override {
if (!op_.HasOutputs(name)) {
return false;
}
auto outputs = op_.Outputs(name); auto outputs = op_.Outputs(name);
if (outputs.empty()) { if (outputs.empty()) {
return false; return false;
......
...@@ -105,6 +105,7 @@ class OperatorBase { ...@@ -105,6 +105,7 @@ class OperatorBase {
const VariableNameMap& Inputs() const { return inputs_; } const VariableNameMap& Inputs() const { return inputs_; }
const VariableNameMap& Outputs() const { return outputs_; } const VariableNameMap& Outputs() const { return outputs_; }
bool HasInputs(const std::string& name) const;
//! Get a input with argument's name described in `op_proto` //! Get a input with argument's name described in `op_proto`
std::string Input(const std::string& name) const; std::string Input(const std::string& name) const;
//! Get a input which has multiple variables. //! Get a input which has multiple variables.
...@@ -112,6 +113,7 @@ class OperatorBase { ...@@ -112,6 +113,7 @@ class OperatorBase {
//! Get all inputs variable names //! Get all inputs variable names
std::vector<std::string> InputVars() const; std::vector<std::string> InputVars() const;
bool HasOutputs(const std::string& name) const;
//! Get a output with argument's name described in `op_proto` //! Get a output with argument's name described in `op_proto`
std::string Output(const std::string& name) const; std::string Output(const std::string& name) const;
//! Get an output which has multiple variables. //! Get an output which has multiple variables.
......
...@@ -18,7 +18,6 @@ limitations under the License. */ ...@@ -18,7 +18,6 @@ limitations under the License. */
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/profiler.pb.h"
namespace paddle { namespace paddle {
namespace platform { namespace platform {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册