/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #pragma once #include #include #include #include #include #include #include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/imperative/dygraph_grad_maker.h" #include "paddle/fluid/imperative/layer.h" #include "paddle/fluid/imperative/type_defs.h" namespace paddle { namespace framework { /* This functor class is responsible for creating the gradient ops for the given operator fwd_op. After it is called (through operator()), the pairs of (gradient variable, corresponding input variable of fwd_op) will be added to grad_to_var. If an input variable of fwd_op is contained in no_grad_set, its gradient varialbe will be ignored or kEmptyVarName depending on the template argument DropEmptyIG in the derived classes. */ class GradOpDescMakerBase { public: explicit GradOpDescMakerBase( const OpDesc& fwd_op, const std::unordered_set& no_grad_set, std::unordered_map* grad_to_var, const std::vector& grad_block = std::vector()) : fwd_op_(fwd_op), no_grad_set_(no_grad_set), grad_to_var_(grad_to_var), grad_block_(grad_block) {} virtual ~GradOpDescMakerBase() = default; virtual std::vector> operator()() const = 0; protected: std::vector InputGrad(const std::string& name, bool drop_empty_grad = true) const { std::vector ret_val; auto var_names = this->Input(name); ret_val.reserve(var_names.size()); std::transform(var_names.begin(), var_names.end(), std::back_inserter(ret_val), [this](const std::string& fwd_var_name) -> std::string { auto g_name = GradVarName(fwd_var_name); if (no_grad_set_.empty() || !no_grad_set_.count(g_name)) { (*this->grad_to_var_)[g_name] = fwd_var_name; return g_name; } else { return kEmptyVarName; } }); if (!drop_empty_grad) { return ret_val; } PADDLE_ENFORCE_LE(var_names.size(), 1UL, "BUG from operator developer:" " for input argument with a list of variables, " " drop_empty_grad is not allowed because it makes" " the correspondence bewteen a variable and its gradient" " ambiguous." " Op type %s", fwd_op_.Type()); std::vector dropped_ret_val; dropped_ret_val.reserve(ret_val.size()); std::copy_if(ret_val.begin(), ret_val.end(), std::back_inserter(dropped_ret_val), [](const std::string& str) { return str != kEmptyVarName; }); return dropped_ret_val; } std::vector OutputGrad(const std::string& name) const { std::vector ret_val; auto onames = this->Output(name); ret_val.reserve(onames.size()); std::transform(onames.begin(), onames.end(), std::back_inserter(ret_val), [this](const std::string& fwd_var_name) -> std::string { auto g_name = GradVarName(fwd_var_name); (*this->grad_to_var_)[g_name] = fwd_var_name; return g_name; }); return ret_val; } std::vector Empty() const { return {}; } std::vector InputNames() const { return this->fwd_op_.InputNames(); } std::vector OutputNames() const { return this->fwd_op_.OutputNames(); } std::vector Input(const std::string& name) const { return fwd_op_.Input(name); } std::vector Output(const std::string& name) const { return fwd_op_.Output(name); } const std::unordered_map& Attrs() const { return fwd_op_.GetAttrMap(); } const Attribute& GetAttr(const std::string& name) const { auto& map = fwd_op_.GetAttrMap(); auto it = map.find(name); PADDLE_ENFORCE(it != map.end(), "Cannot find attribute %s", name); return it->second; } template inline const T& Attr(const std::string& name) const { return boost::get(GetAttr(name)); } std::string ForwardOpType() const { return this->fwd_op_.Type(); } protected: bool HasInput(const std::string& name) const { return (fwd_op_.Inputs().count(name) > 0); } bool HasOutput(const std::string& name) const { return (fwd_op_.Outputs().count(name) > 0); } private: const OpDesc& fwd_op_; const std::unordered_set& no_grad_set_; std::unordered_map* grad_to_var_; protected: std::vector grad_block_; }; template class SingleGradOpMaker { public: std::vector> operator()() const { PADDLE_ENFORCE(false, "should not call this function"); return {}; } protected: virtual std::unique_ptr Apply() const = 0; }; template <> class SingleGradOpMaker : public GradOpDescMakerBase { public: using GradOpDescMakerBase::GradOpDescMakerBase; std::vector> operator()() const { std::vector> retv; retv.emplace_back(this->Apply()); return retv; } protected: virtual std::unique_ptr Apply() const = 0; }; template <> class SingleGradOpMaker : public imperative::GradOpBaseMakerBase { public: using GradOpBaseMakerBase::GradOpBaseMakerBase; public: std::vector> operator()() const { std::vector> retv; retv.emplace_back(this->Apply()); return retv; } protected: virtual std::unique_ptr Apply() const = 0; }; template class DefaultGradOpMaker final : public SingleGradOpMaker { public: using SingleGradOpMaker::SingleGradOpMaker; protected: std::unique_ptr Apply() const final { auto* grad = new T(); grad->SetType(this->ForwardOpType() + "_grad"); for (auto& input_param : this->InputNames()) { grad->SetInput(input_param, this->Input(input_param)); grad->SetOutput(GradVarName(input_param), this->InputGrad(input_param, DropEmptyIG)); } for (auto& output_param : this->OutputNames()) { grad->SetInput(output_param, this->Output(output_param)); grad->SetInput(GradVarName(output_param), this->OutputGrad(output_param)); } grad->SetAttrMap(this->Attrs()); return std::unique_ptr(grad); } }; template class EmptyGradOpMaker { public: virtual std::vector> operator()() const final { /* NOLINT */ return {}; } }; template <> class EmptyGradOpMaker final : public GradOpDescMakerBase { public: using GradOpDescMakerBase::GradOpDescMakerBase; std::vector> operator()() const final { return {}; } }; template <> class EmptyGradOpMaker final : public imperative::GradOpBaseMakerBase { public: using GradOpBaseMakerBase::GradOpBaseMakerBase; std::vector> operator()() const final { return {}; } }; } // namespace framework } // namespace paddle