未验证 提交 6084af47 编写于 作者: Y Yiqun Liu 提交者: GitHub

Fix the bug when a input variable of op is dispensable. (#10268)

* Fix the bug when a input variable of op is dispensable.

* Add HasInputs/Outputs interfaces to OperatorBase.

* Remove the unreferenced header file.
上级 8a0c7e2e
......@@ -108,7 +108,7 @@ paddle_error paddle_matrix_get_row(paddle_matrix mat,
paddle_error paddle_matrix_get_shape(paddle_matrix mat,
uint64_t* height,
uint64_t* width) {
if (mat == nullptr) return kPD_NULLPTR;
if (mat == nullptr || cast(mat)->mat == nullptr) return kPD_NULLPTR;
if (height != nullptr) {
*height = cast(mat)->mat->getHeight();
}
......
......@@ -93,6 +93,14 @@ void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
RunImpl(scope, place);
}
bool OperatorBase::HasInputs(const std::string& name) const {
if (inputs_.find(name) != inputs_.end()) {
return true;
} else {
return false;
}
}
std::string OperatorBase::Input(const std::string& name) const {
auto& ins = Inputs(name);
PADDLE_ENFORCE_LE(ins.size(), 1UL,
......@@ -109,6 +117,14 @@ const std::vector<std::string>& OperatorBase::Inputs(
return it->second;
}
bool OperatorBase::HasOutputs(const std::string& name) const {
if (outputs_.find(name) != outputs_.end()) {
return true;
} else {
return false;
}
}
std::string OperatorBase::Output(const std::string& name) const {
auto& outs = Outputs(name);
PADDLE_ENFORCE_LE(outs.size(), 1UL,
......@@ -220,13 +236,18 @@ void OperatorBase::CheckAllInputOutputSet() const {
if (op_info == nullptr || op_info->proto_ == nullptr) return;
for (auto& in : op_info->Proto().inputs()) {
PADDLE_ENFORCE(inputs_.find(in.name()) != inputs_.end(),
"Type %s's input %s is not set", Type(), in.name());
if (!in.dispensable()) {
PADDLE_ENFORCE(inputs_.find(in.name()) != inputs_.end(),
"Operator %s's input, %s, is not set", Type(), in.name());
}
}
for (auto& out : op_info->Proto().outputs()) {
PADDLE_ENFORCE(outputs_.find(out.name()) != outputs_.end(),
"Type %s's output %s is not set", Type(), out.name());
if (!out.dispensable()) {
PADDLE_ENFORCE(outputs_.find(out.name()) != outputs_.end(),
"Operator %s's output, %s, is not set", Type(),
out.name());
}
}
}
......@@ -332,6 +353,9 @@ class RuntimeInferShapeContext : public InferShapeContext {
: op_(op), scope_(scope) {}
bool HasInput(const std::string& name) const override {
if (!op_.HasInputs(name)) {
return false;
}
auto& ins = Inputs(name);
size_t length = ins.size();
if (length == 0) {
......@@ -345,6 +369,9 @@ class RuntimeInferShapeContext : public InferShapeContext {
}
bool HasOutput(const std::string& name) const override {
if (!op_.HasOutputs(name)) {
return false;
}
auto& outs = Outputs(name);
size_t length = outs.size();
if (length == 0) {
......@@ -358,6 +385,9 @@ class RuntimeInferShapeContext : public InferShapeContext {
}
bool HasInputs(const std::string& name) const override {
if (!op_.HasInputs(name)) {
return false;
}
auto inputs = op_.Inputs(name);
if (inputs.empty()) {
return false;
......@@ -371,6 +401,9 @@ class RuntimeInferShapeContext : public InferShapeContext {
}
bool HasOutputs(const std::string& name) const override {
if (!op_.HasOutputs(name)) {
return false;
}
auto outputs = op_.Outputs(name);
if (outputs.empty()) {
return false;
......
......@@ -105,6 +105,7 @@ class OperatorBase {
const VariableNameMap& Inputs() const { return inputs_; }
const VariableNameMap& Outputs() const { return outputs_; }
bool HasInputs(const std::string& name) const;
//! Get a input with argument's name described in `op_proto`
std::string Input(const std::string& name) const;
//! Get a input which has multiple variables.
......@@ -112,6 +113,7 @@ class OperatorBase {
//! Get all inputs variable names
std::vector<std::string> InputVars() const;
bool HasOutputs(const std::string& name) const;
//! Get a output with argument's name described in `op_proto`
std::string Output(const std::string& name) const;
//! Get an output which has multiple variables.
......
......@@ -18,7 +18,6 @@ limitations under the License. */
#include <string>
#include <vector>
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/profiler.pb.h"
namespace paddle {
namespace platform {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册