diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc index f4d99f55b5ec00fcca61d81e7bf71f7791248fe9..bf8b11e5f5ae801621f84bdbeffb5c4cf2dd8905 100644 --- a/paddle/framework/backward_test.cc +++ b/paddle/framework/backward_test.cc @@ -127,8 +127,8 @@ class FillZeroOpMaker : public OpProtoAndCheckerMaker { public: FillZeroOpMaker(OpProto *proto, OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("x", "x"); - AddOutput("out", "out"); + AddInput("Src", "x"); + AddOutput("Dst", "out"); AddComment(""); } }; @@ -138,7 +138,7 @@ class AddOpMaker : public OpProtoAndCheckerMaker { AddOpMaker(OpProto *proto, OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "x").AsDuplicable(); - AddOutput("Y", "y"); + AddOutput("Out", "out"); AddComment(""); } }; diff --git a/paddle/framework/op_info.h b/paddle/framework/op_info.h index 94245c6c44aca962b0db890947a9dc5550ac0799..b98d8f23a14cf6fbe787953ad16b5c9ab99222ad 100644 --- a/paddle/framework/op_info.h +++ b/paddle/framework/op_info.h @@ -80,9 +80,19 @@ class OpInfoMap { } const OpInfo& Get(const std::string& type) const { + auto op_info_ptr = GetNullable(type); + PADDLE_ENFORCE_NOT_NULL(op_info_ptr, "Operator %s has not been registered", + type); + return *op_info_ptr; + } + + const OpInfo* GetNullable(const std::string& type) const { auto it = map_.find(type); - PADDLE_ENFORCE(it != map_.end(), "Operator %s are not found", type); - return it->second; + if (it == map_.end()) { + return nullptr; + } else { + return &it->second; + } } template diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 7abbde610f1e9c530393b9a9cabe40b826712212..790cfc4746b1d34da413fa3c29a266f962c6dde6 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -33,12 +33,12 @@ ExecutionContext::GetEigenDevice() const { } #endif -const std::string& OperatorBase::Input(const std::string& name) const { +std::string OperatorBase::Input(const std::string& name) const { auto& ins = Inputs(name); - PADDLE_ENFORCE_EQ(ins.size(), 1UL, + PADDLE_ENFORCE_LE(ins.size(), 1UL, "Op %s input %s should contain only one variable", type_, name); - return ins[0]; + return ins.empty() ? kEmptyVarName : ins[0]; } const std::vector& OperatorBase::Inputs( @@ -49,12 +49,12 @@ const std::vector& OperatorBase::Inputs( return it->second; } -const std::string& OperatorBase::Output(const std::string& name) const { +std::string OperatorBase::Output(const std::string& name) const { auto& outs = Outputs(name); - PADDLE_ENFORCE_EQ(outs.size(), 1UL, + PADDLE_ENFORCE_LE(outs.size(), 1UL, "Op %s output %s should contain only one variable", type_, name); - return outs[0]; + return outs.empty() ? kEmptyVarName : outs[0]; } const std::vector& OperatorBase::Outputs( @@ -119,16 +119,8 @@ OperatorBase::OperatorBase(const std::string& type, const VariableNameMap& outputs, const AttributeMap& attrs) : type_(type), inputs_(inputs), outputs_(outputs), attrs_(attrs) { - static std::atomic gUniqId(0UL); - for (auto& output : outputs_) { - for (auto& output_name : output.second) { - if (output_name == kTempVarName) { - output_name += type_; - output_name += "@"; - output_name += std::to_string(gUniqId.fetch_add(1)); - } - } - } + GenerateTemporaryNames(); + CheckAllInputOutputSet(); } std::vector OperatorBase::OutputVars(bool has_intermediate) const { @@ -156,6 +148,35 @@ std::vector OperatorBase::OutputVars(bool has_intermediate) const { return ret_val; } +void OperatorBase::CheckAllInputOutputSet() const { + auto& info_map = OpInfoMap::Instance(); + auto* op_info = info_map.GetNullable(Type()); + if (op_info == nullptr || op_info->proto_ == nullptr) return; + + for (auto& in : op_info->Proto().inputs()) { + PADDLE_ENFORCE(inputs_.find(in.name()) != inputs_.end(), + "Type %s's input %s is not set", Type(), in.name()); + } + + for (auto& out : op_info->Proto().outputs()) { + PADDLE_ENFORCE(outputs_.find(out.name()) != outputs_.end(), + "Type %s's output %s is not set", Type(), out.name()); + } +} + +void OperatorBase::GenerateTemporaryNames() { + static std::atomic gUniqId(0UL); + for (auto& output : outputs_) { + for (auto& output_name : output.second) { + if (output_name == kTempVarName) { + output_name += type_; + output_name += "@"; + output_name += std::to_string(gUniqId.fetch_add(1)); + } + } + } +} + void OpProtoAndCheckerMaker::Validate() { validated_ = true; CheckNoDuplicatedInOutAttrs(); diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 8397570d26f06f0238e9c5afc85d721df7679257..590e335fdc8843ed9edd01a09605163de93f52d9 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -95,12 +95,12 @@ class OperatorBase { const VariableNameMap& Inputs() const { return inputs_; } const VariableNameMap& Outputs() const { return outputs_; } //! Get a input with argument's name described in `op_proto` - const std::string& Input(const std::string& name) const; + std::string Input(const std::string& name) const; //! Get a input which has multiple variables. const std::vector& Inputs(const std::string& name) const; //! Get a output with argument's name described in `op_proto` - const std::string& Output(const std::string& name) const; + std::string Output(const std::string& name) const; //! Get an output which has multiple variables. //! TODO add a vector_view to prevent memory copy. const std::vector& Outputs(const std::string& name) const; @@ -127,6 +127,10 @@ class OperatorBase { // IG (Inputs Gradients) VariableNameMap outputs_; AttributeMap attrs_; + + private: + void GenerateTemporaryNames(); + void CheckAllInputOutputSet() const; }; // Macro for define a clone method. @@ -238,11 +242,13 @@ class InferShapeContext { } const Variable* InputVar(const std::string& name) const { - return scope_.FindVar(op_.Input(name)); + auto ipt = op_.Input(name); + return ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt); } Variable* OutputVar(const std::string& name) const { - return scope_.FindVar(op_.Output(name)); + auto opt = op_.Output(name); + return opt == kEmptyVarName ? nullptr : scope_.FindVar(opt); } const std::vector MultiInputVar( @@ -250,9 +256,11 @@ class InferShapeContext { auto names = op_.Inputs(name); std::vector res; res.reserve(names.size()); - std::transform( - names.begin(), names.end(), std::back_inserter(res), - [this](const std::string& name) { return scope_.FindVar(name); }); + std::transform(names.begin(), names.end(), std::back_inserter(res), + [this](const std::string& name) { + return name == kEmptyVarName ? nullptr + : scope_.FindVar(name); + }); return res; } @@ -260,24 +268,24 @@ class InferShapeContext { auto names = op_.Outputs(name); std::vector res; res.reserve(names.size()); - std::transform( - names.begin(), names.end(), std::back_inserter(res), - [this](const std::string& name) { return scope_.FindVar(name); }); + std::transform(names.begin(), names.end(), std::back_inserter(res), + [this](const std::string& name) { + return name == kEmptyVarName ? nullptr + : scope_.FindVar(name); + }); return res; } template const T* Input(const std::string& name) const { auto* var = InputVar(name); - PADDLE_ENFORCE_NOT_NULL(var, "Input(%s) should not be nullptr", name); - return &var->Get(); + return var == nullptr ? nullptr : &var->Get(); } template T* Output(const std::string& name) const { auto var = OutputVar(name); - PADDLE_ENFORCE_NOT_NULL(var, "Output(%s) should not be nullptr", name); - return var->GetMutable(); + return var == nullptr ? nullptr : var->GetMutable(); } template @@ -288,10 +296,7 @@ class InferShapeContext { std::transform(names.begin(), names.end(), std::back_inserter(res), [&](const std::string& sub_name) { auto var = scope_.FindVar(sub_name); - PADDLE_ENFORCE_NOT_NULL( - var, "MultiInput(%s:%s) should not be nullptr", name, - sub_name); - return &var->Get(); + return var == nullptr ? nullptr : &var->Get(); }); return res; } @@ -304,10 +309,7 @@ class InferShapeContext { std::transform(names.begin(), names.end(), std::back_inserter(res), [&](const std::string& sub_name) { auto var = scope_.FindVar(sub_name); - PADDLE_ENFORCE_NOT_NULL( - var, "MultiOutput(%s:%s) should not be nullptr.", name, - sub_name); - return var->GetMutable(); + return var == nullptr ? nullptr : var->GetMutable(); }); return res; }