diff --git a/paddle/framework/details/op_registry.h b/paddle/framework/details/op_registry.h index d2516ccc1eec21682276c2fddf049984453404d4..daa474e8c5a223589018720da29a5c3363b5934d 100644 --- a/paddle/framework/details/op_registry.h +++ b/paddle/framework/details/op_registry.h @@ -14,6 +14,7 @@ #pragma once +#include "paddle/framework/grad_op_desc_maker.h" #include "paddle/framework/op_info.h" #include "paddle/framework/op_proto_maker.h" #include "paddle/framework/operator.h" @@ -96,7 +97,10 @@ struct OpInfoFiller { template struct OpInfoFiller { void operator()(const char* op_type, OpInfo* info) const { - info->grad_op_maker_ = new T(); + info->grad_op_maker_ = [](const OpDescBind& fwd_op) { + T maker(fwd_op); + return maker(); + }; } }; } // namespace details diff --git a/paddle/framework/grad_op_desc_maker.h b/paddle/framework/grad_op_desc_maker.h new file mode 100644 index 0000000000000000000000000000000000000000..cb4d160bd084f7739da9002cde9a73b7e3b477bb --- /dev/null +++ b/paddle/framework/grad_op_desc_maker.h @@ -0,0 +1,115 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#include "paddle/framework/op_desc.h" +#include "paddle/framework/operator.h" + +namespace paddle { +namespace framework { + +class GradOpDescMakerBase { + public: + explicit GradOpDescMakerBase(const OpDescBind& fwd_op) : fwd_op_(fwd_op) {} + + virtual ~GradOpDescMakerBase() = default; + virtual std::vector operator()() const = 0; + + protected: + static std::vector ToGradNames( + const std::vector& var_names) { + std::vector ret_val; + ret_val.reserve(var_names.size()); + std::transform(var_names.begin(), var_names.end(), + std::back_inserter(ret_val), GradVarName); + return ret_val; + } + + std::vector InputGrad(const std::string& name) const { + return ToGradNames(fwd_op_.Input(name)); + } + + std::vector OutputGrad(const std::string& name) const { + return ToGradNames(fwd_op_.Output(name)); + } + + std::vector InputParamNames() const { + return this->fwd_op_.InputParamNames(); + } + + std::vector OutputParamNames() const { + return this->fwd_op_.OutputParamNames(); + } + + std::vector Input(const std::string& name) const { + return fwd_op_.Input(name); + } + + std::vector Output(const std::string& name) const { + return fwd_op_.Output(name); + } + + const std::unordered_map& Attrs() const { + return fwd_op_.GetAttrMap(); + } + + const Attribute& GetAttr(const std::string& name) const { + auto& map = fwd_op_.GetAttrMap(); + auto it = map.find(name); + PADDLE_ENFORCE(it != map.end(), "Cannot find attribute %s", name); + return it->second; + } + + std::string ForwardOpType() const { return this->fwd_op_.Type(); } + + private: + const OpDescBind& fwd_op_; +}; + +class SingleGradOpDescMaker : public GradOpDescMakerBase { + public: + std::vector operator()() const { return {this->Apply()}; } + + protected: + virtual OpDescBind Apply() const = 0; +}; + +class DefaultGradOpDescMaker : public SingleGradOpDescMaker { + protected: + virtual OpDescBind Apply() const { + OpDescBind grad; + grad.SetType(this->GradOpType()); + + for (auto& input_param : this->InputParamNames()) { + grad.SetInput(input_param, this->Input(input_param)); + grad.SetOutput(GradVarName(input_param), this->InputGrad(input_param)); + } + + for (auto& output_param : this->OutputParamNames()) { + grad.SetInput(output_param, this->Output(output_param)); + grad.SetInput(GradVarName(output_param), this->OutputGrad(output_param)); + } + + grad.SetAttrMap(this->Attrs()); + + return grad; + } + + virtual std::string GradOpType() const { + return this->ForwardOpType() + "_grad"; + } +}; + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/op_desc.h b/paddle/framework/op_desc.h index 0cf7d13971675eb825bcd0c7636896f0862d6ebb..851a305061c1f402be5c903852762d89f15c267c 100644 --- a/paddle/framework/op_desc.h +++ b/paddle/framework/op_desc.h @@ -60,17 +60,31 @@ class OpDescBind { void SetBlockAttr(const std::string &name, BlockDescBind &block); - // Only be used in C++ - void SetAttrMap(const std::unordered_map &attr_map); - Attribute GetAttr(const std::string &name) const; int GetBlockAttr(const std::string &name) const; - // Only be used in C++ + // The following methods should only be used in C++ const std::unordered_map &GetAttrMap() const; + void SetAttrMap(const std::unordered_map &attr_map); + + std::vector InputParamNames() const { return MapKeys(inputs_); } + std::vector OutputParamNames() const { + return MapKeys(outputs_); + } + private: + template + static std::vector MapKeys(const MapType &map) { + std::vector ret_val; + ret_val.reserve(map.size()); + std::transform( + map.begin(), map.end(), ret_val.begin(), + [](const typename MapType::value_type &pair) { return pair.first; }); + return ret_val; + } + struct SetAttrDescVisitor : public boost::static_visitor { explicit SetAttrDescVisitor(OpDesc::Attr *attr) : attr_(attr) {} mutable OpDesc::Attr *attr_; diff --git a/paddle/framework/op_info.h b/paddle/framework/op_info.h index 6d1ee4dece14affa78ccf4ff1d9c7f09992f127f..8149c0061ac67ecfdc52013cc06252628d92168c 100644 --- a/paddle/framework/op_info.h +++ b/paddle/framework/op_info.h @@ -29,16 +29,10 @@ using OpCreator = std::function; -class GradOpDescMakerBase { - public: - virtual ~GradOpDescMakerBase() = default; - virtual std::vector operator()(const OpDescBind&) const = 0; -}; - struct OpInfo { OpCreator creator_; std::string grad_op_type_; - GradOpDescMakerBase* grad_op_maker_{nullptr}; + std::function(const OpDescBind&)> grad_op_maker_; OpProto* proto_{nullptr}; OpAttrChecker* checker_{nullptr};