/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #pragma once #include #include #include #include "paddle/framework/op_desc.h" #include "paddle/framework/operator.h" namespace paddle { namespace framework { /* This functor class is responsible for creating the gradient ops for the given operator fwd_op. After it is called (through operator()), the pairs of (gradient variable, corresponding input variable of fwd_op) will be added to grad_to_var. If an input variable of fwd_op is contained in no_grad_set, its gradient varialbe will be ignored or kEmptyVarName depending on the template argument DropEmptyIG in the derived classes. */ class GradOpDescMakerBase { public: explicit GradOpDescMakerBase( const OpDesc& fwd_op, const std::unordered_set& no_grad_set, std::unordered_map* grad_to_var, const std::vector& grad_block = std::vector()) : fwd_op_(fwd_op), no_grad_set_(no_grad_set), grad_to_var_(grad_to_var), grad_block_(grad_block) {} virtual ~GradOpDescMakerBase() = default; virtual std::vector> operator()() const = 0; protected: std::vector InputGrad(const std::string& name, bool drop_empty_grad = true) const { std::vector ret_val; auto var_names = this->Input(name); ret_val.reserve(var_names.size()); std::transform(var_names.begin(), var_names.end(), std::back_inserter(ret_val), [this](const std::string& fwd_var_name) -> std::string { auto g_name = GradVarName(fwd_var_name); if (no_grad_set_.count(g_name)) { return kEmptyVarName; } else { (*this->grad_to_var_)[g_name] = fwd_var_name; return g_name; } }); if (!drop_empty_grad) { return ret_val; } PADDLE_ENFORCE_LE(var_names.size(), 1UL, "BUG from operator developer:" " for input argument with a list of variables, " " drop_empty_grad is not allowed because it makes" " the correspondence bewteen a variable and its gradient" " ambiguous. Use REGISTER_OP_EX to register the op" " or call InputGrad(?,false) in GradOpDescMaker." " Op type %s", fwd_op_.Type()); std::vector dropped_ret_val; dropped_ret_val.reserve(ret_val.size()); std::copy_if(ret_val.begin(), ret_val.end(), std::back_inserter(dropped_ret_val), [](const std::string& str) { return str != kEmptyVarName; }); return dropped_ret_val; } std::vector OutputGrad(const std::string& name) const { std::vector ret_val; auto onames = this->Output(name); ret_val.reserve(onames.size()); std::transform(onames.begin(), onames.end(), std::back_inserter(ret_val), [this](const std::string& fwd_var_name) -> std::string { auto g_name = GradVarName(fwd_var_name); (*this->grad_to_var_)[g_name] = fwd_var_name; return g_name; }); return ret_val; } std::vector InputNames() const { return this->fwd_op_.InputNames(); } std::vector OutputNames() const { return this->fwd_op_.OutputNames(); } std::vector Input(const std::string& name) const { return fwd_op_.Input(name); } std::vector Output(const std::string& name) const { return fwd_op_.Output(name); } const std::unordered_map& Attrs() const { return fwd_op_.GetAttrMap(); } const Attribute& GetAttr(const std::string& name) const { auto& map = fwd_op_.GetAttrMap(); auto it = map.find(name); PADDLE_ENFORCE(it != map.end(), "Cannot find attribute %s", name); return it->second; } std::string ForwardOpType() const { return this->fwd_op_.Type(); } private: const OpDesc& fwd_op_; const std::unordered_set& no_grad_set_; std::unordered_map* grad_to_var_; protected: std::vector grad_block_; }; class SingleGradOpDescMaker : public GradOpDescMakerBase { public: using GradOpDescMakerBase::GradOpDescMakerBase; std::vector> operator()() const { std::vector> retv; retv.emplace_back(this->Apply()); return retv; } protected: virtual std::unique_ptr Apply() const = 0; }; template class DefaultGradOpDescMaker : public SingleGradOpDescMaker { public: using SingleGradOpDescMaker::SingleGradOpDescMaker; protected: virtual std::unique_ptr Apply() const { auto* grad = new OpDesc(); grad->SetType(this->GradOpType()); for (auto& input_param : this->InputNames()) { grad->SetInput(input_param, this->Input(input_param)); grad->SetOutput(GradVarName(input_param), this->InputGrad(input_param, DropEmptyIG)); } for (auto& output_param : this->OutputNames()) { grad->SetInput(output_param, this->Output(output_param)); grad->SetInput(GradVarName(output_param), this->OutputGrad(output_param)); } grad->SetAttrMap(this->Attrs()); return std::unique_ptr(grad); } virtual std::string GradOpType() const { return this->ForwardOpType() + "_grad"; } }; class EmptyGradOpMaker : public GradOpDescMakerBase { public: using GradOpDescMakerBase::GradOpDescMakerBase; std::vector> operator()() const override { return {}; } }; } // namespace framework } // namespace paddle