提交 c1feb27f 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #3831 from reyoung/feature/fix_empty_input_and_output

Make operator Input/Output can return nullptr
...@@ -127,8 +127,8 @@ class FillZeroOpMaker : public OpProtoAndCheckerMaker { ...@@ -127,8 +127,8 @@ class FillZeroOpMaker : public OpProtoAndCheckerMaker {
public: public:
FillZeroOpMaker(OpProto *proto, OpAttrChecker *op_checker) FillZeroOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("x", "x"); AddInput("Src", "x");
AddOutput("out", "out"); AddOutput("Dst", "out");
AddComment(""); AddComment("");
} }
}; };
...@@ -138,7 +138,7 @@ class AddOpMaker : public OpProtoAndCheckerMaker { ...@@ -138,7 +138,7 @@ class AddOpMaker : public OpProtoAndCheckerMaker {
AddOpMaker(OpProto *proto, OpAttrChecker *op_checker) AddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "x").AsDuplicable(); AddInput("X", "x").AsDuplicable();
AddOutput("Y", "y"); AddOutput("Out", "out");
AddComment(""); AddComment("");
} }
}; };
......
...@@ -80,9 +80,19 @@ class OpInfoMap { ...@@ -80,9 +80,19 @@ class OpInfoMap {
} }
const OpInfo& Get(const std::string& type) const { const OpInfo& Get(const std::string& type) const {
auto op_info_ptr = GetNullable(type);
PADDLE_ENFORCE_NOT_NULL(op_info_ptr, "Operator %s has not been registered",
type);
return *op_info_ptr;
}
const OpInfo* GetNullable(const std::string& type) const {
auto it = map_.find(type); auto it = map_.find(type);
PADDLE_ENFORCE(it != map_.end(), "Operator %s are not found", type); if (it == map_.end()) {
return it->second; return nullptr;
} else {
return &it->second;
}
} }
template <typename Callback> template <typename Callback>
......
...@@ -33,12 +33,12 @@ ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const { ...@@ -33,12 +33,12 @@ ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
} }
#endif #endif
const std::string& OperatorBase::Input(const std::string& name) const { std::string OperatorBase::Input(const std::string& name) const {
auto& ins = Inputs(name); auto& ins = Inputs(name);
PADDLE_ENFORCE_EQ(ins.size(), 1UL, PADDLE_ENFORCE_LE(ins.size(), 1UL,
"Op %s input %s should contain only one variable", type_, "Op %s input %s should contain only one variable", type_,
name); name);
return ins[0]; return ins.empty() ? kEmptyVarName : ins[0];
} }
const std::vector<std::string>& OperatorBase::Inputs( const std::vector<std::string>& OperatorBase::Inputs(
...@@ -49,12 +49,12 @@ const std::vector<std::string>& OperatorBase::Inputs( ...@@ -49,12 +49,12 @@ const std::vector<std::string>& OperatorBase::Inputs(
return it->second; return it->second;
} }
const std::string& OperatorBase::Output(const std::string& name) const { std::string OperatorBase::Output(const std::string& name) const {
auto& outs = Outputs(name); auto& outs = Outputs(name);
PADDLE_ENFORCE_EQ(outs.size(), 1UL, PADDLE_ENFORCE_LE(outs.size(), 1UL,
"Op %s output %s should contain only one variable", type_, "Op %s output %s should contain only one variable", type_,
name); name);
return outs[0]; return outs.empty() ? kEmptyVarName : outs[0];
} }
const std::vector<std::string>& OperatorBase::Outputs( const std::vector<std::string>& OperatorBase::Outputs(
...@@ -119,16 +119,8 @@ OperatorBase::OperatorBase(const std::string& type, ...@@ -119,16 +119,8 @@ OperatorBase::OperatorBase(const std::string& type,
const VariableNameMap& outputs, const VariableNameMap& outputs,
const AttributeMap& attrs) const AttributeMap& attrs)
: type_(type), inputs_(inputs), outputs_(outputs), attrs_(attrs) { : type_(type), inputs_(inputs), outputs_(outputs), attrs_(attrs) {
static std::atomic<size_t> gUniqId(0UL); GenerateTemporaryNames();
for (auto& output : outputs_) { CheckAllInputOutputSet();
for (auto& output_name : output.second) {
if (output_name == kTempVarName) {
output_name += type_;
output_name += "@";
output_name += std::to_string(gUniqId.fetch_add(1));
}
}
}
} }
std::vector<std::string> OperatorBase::OutputVars(bool has_intermediate) const { std::vector<std::string> OperatorBase::OutputVars(bool has_intermediate) const {
...@@ -156,6 +148,35 @@ std::vector<std::string> OperatorBase::OutputVars(bool has_intermediate) const { ...@@ -156,6 +148,35 @@ std::vector<std::string> OperatorBase::OutputVars(bool has_intermediate) const {
return ret_val; return ret_val;
} }
void OperatorBase::CheckAllInputOutputSet() const {
auto& info_map = OpInfoMap::Instance();
auto* op_info = info_map.GetNullable(Type());
if (op_info == nullptr || op_info->proto_ == nullptr) return;
for (auto& in : op_info->Proto().inputs()) {
PADDLE_ENFORCE(inputs_.find(in.name()) != inputs_.end(),
"Type %s's input %s is not set", Type(), in.name());
}
for (auto& out : op_info->Proto().outputs()) {
PADDLE_ENFORCE(outputs_.find(out.name()) != outputs_.end(),
"Type %s's output %s is not set", Type(), out.name());
}
}
void OperatorBase::GenerateTemporaryNames() {
static std::atomic<size_t> gUniqId(0UL);
for (auto& output : outputs_) {
for (auto& output_name : output.second) {
if (output_name == kTempVarName) {
output_name += type_;
output_name += "@";
output_name += std::to_string(gUniqId.fetch_add(1));
}
}
}
}
void OpProtoAndCheckerMaker::Validate() { void OpProtoAndCheckerMaker::Validate() {
validated_ = true; validated_ = true;
CheckNoDuplicatedInOutAttrs(); CheckNoDuplicatedInOutAttrs();
......
...@@ -95,12 +95,12 @@ class OperatorBase { ...@@ -95,12 +95,12 @@ class OperatorBase {
const VariableNameMap& Inputs() const { return inputs_; } const VariableNameMap& Inputs() const { return inputs_; }
const VariableNameMap& Outputs() const { return outputs_; } const VariableNameMap& Outputs() const { return outputs_; }
//! Get a input with argument's name described in `op_proto` //! Get a input with argument's name described in `op_proto`
const std::string& Input(const std::string& name) const; std::string Input(const std::string& name) const;
//! Get a input which has multiple variables. //! Get a input which has multiple variables.
const std::vector<std::string>& Inputs(const std::string& name) const; const std::vector<std::string>& Inputs(const std::string& name) const;
//! Get a output with argument's name described in `op_proto` //! Get a output with argument's name described in `op_proto`
const std::string& Output(const std::string& name) const; std::string Output(const std::string& name) const;
//! Get an output which has multiple variables. //! Get an output which has multiple variables.
//! TODO add a vector_view to prevent memory copy. //! TODO add a vector_view to prevent memory copy.
const std::vector<std::string>& Outputs(const std::string& name) const; const std::vector<std::string>& Outputs(const std::string& name) const;
...@@ -127,6 +127,10 @@ class OperatorBase { ...@@ -127,6 +127,10 @@ class OperatorBase {
// IG (Inputs Gradients) // IG (Inputs Gradients)
VariableNameMap outputs_; VariableNameMap outputs_;
AttributeMap attrs_; AttributeMap attrs_;
private:
void GenerateTemporaryNames();
void CheckAllInputOutputSet() const;
}; };
// Macro for define a clone method. // Macro for define a clone method.
...@@ -238,11 +242,13 @@ class InferShapeContext { ...@@ -238,11 +242,13 @@ class InferShapeContext {
} }
const Variable* InputVar(const std::string& name) const { const Variable* InputVar(const std::string& name) const {
return scope_.FindVar(op_.Input(name)); auto ipt = op_.Input(name);
return ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
} }
Variable* OutputVar(const std::string& name) const { Variable* OutputVar(const std::string& name) const {
return scope_.FindVar(op_.Output(name)); auto opt = op_.Output(name);
return opt == kEmptyVarName ? nullptr : scope_.FindVar(opt);
} }
const std::vector<const Variable*> MultiInputVar( const std::vector<const Variable*> MultiInputVar(
...@@ -250,9 +256,11 @@ class InferShapeContext { ...@@ -250,9 +256,11 @@ class InferShapeContext {
auto names = op_.Inputs(name); auto names = op_.Inputs(name);
std::vector<const Variable*> res; std::vector<const Variable*> res;
res.reserve(names.size()); res.reserve(names.size());
std::transform( std::transform(names.begin(), names.end(), std::back_inserter(res),
names.begin(), names.end(), std::back_inserter(res), [this](const std::string& name) {
[this](const std::string& name) { return scope_.FindVar(name); }); return name == kEmptyVarName ? nullptr
: scope_.FindVar(name);
});
return res; return res;
} }
...@@ -260,24 +268,24 @@ class InferShapeContext { ...@@ -260,24 +268,24 @@ class InferShapeContext {
auto names = op_.Outputs(name); auto names = op_.Outputs(name);
std::vector<const Variable*> res; std::vector<const Variable*> res;
res.reserve(names.size()); res.reserve(names.size());
std::transform( std::transform(names.begin(), names.end(), std::back_inserter(res),
names.begin(), names.end(), std::back_inserter(res), [this](const std::string& name) {
[this](const std::string& name) { return scope_.FindVar(name); }); return name == kEmptyVarName ? nullptr
: scope_.FindVar(name);
});
return res; return res;
} }
template <typename T> template <typename T>
const T* Input(const std::string& name) const { const T* Input(const std::string& name) const {
auto* var = InputVar(name); auto* var = InputVar(name);
PADDLE_ENFORCE_NOT_NULL(var, "Input(%s) should not be nullptr", name); return var == nullptr ? nullptr : &var->Get<T>();
return &var->Get<T>();
} }
template <typename T> template <typename T>
T* Output(const std::string& name) const { T* Output(const std::string& name) const {
auto var = OutputVar(name); auto var = OutputVar(name);
PADDLE_ENFORCE_NOT_NULL(var, "Output(%s) should not be nullptr", name); return var == nullptr ? nullptr : var->GetMutable<T>();
return var->GetMutable<T>();
} }
template <typename T> template <typename T>
...@@ -288,10 +296,7 @@ class InferShapeContext { ...@@ -288,10 +296,7 @@ class InferShapeContext {
std::transform(names.begin(), names.end(), std::back_inserter(res), std::transform(names.begin(), names.end(), std::back_inserter(res),
[&](const std::string& sub_name) { [&](const std::string& sub_name) {
auto var = scope_.FindVar(sub_name); auto var = scope_.FindVar(sub_name);
PADDLE_ENFORCE_NOT_NULL( return var == nullptr ? nullptr : &var->Get<T>();
var, "MultiInput(%s:%s) should not be nullptr", name,
sub_name);
return &var->Get<T>();
}); });
return res; return res;
} }
...@@ -304,10 +309,7 @@ class InferShapeContext { ...@@ -304,10 +309,7 @@ class InferShapeContext {
std::transform(names.begin(), names.end(), std::back_inserter(res), std::transform(names.begin(), names.end(), std::back_inserter(res),
[&](const std::string& sub_name) { [&](const std::string& sub_name) {
auto var = scope_.FindVar(sub_name); auto var = scope_.FindVar(sub_name);
PADDLE_ENFORCE_NOT_NULL( return var == nullptr ? nullptr : var->GetMutable<T>();
var, "MultiOutput(%s:%s) should not be nullptr.", name,
sub_name);
return var->GetMutable<T>();
}); });
return res; return res;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册