// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <algorithm>
#include <iterator>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/op_call_stack.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
#include "paddle/fluid/prim/utils/static/static_global_utils.h"
#include "paddle/phi/core/enforce.h"

namespace paddle {
namespace prim {

/*
  This functor class is responsible for creating the gradient ops for the given
  operator fwd_op. After it is called (through operator()), the pairs of
  (gradient variable, corresponding input variable of fwd_op) will be added to
  grad_to_var. If an input variable of fwd_op is contained in no_grad_set, its
  gradient variable will be ignored or set to kEmptyVarName, depending on the
  template argument DropEmptyIG in the derived classes.
 */
class CompositeGradOpMakerBase {
 public:
  explicit CompositeGradOpMakerBase(
      const framework::OpDesc& fwd_op,
      const std::unordered_set<std::string>& no_grad_set,
      std::unordered_map<std::string, std::string>* grad_to_var,
      const framework::BlockDesc* original_block,
      const std::vector<framework::BlockDesc*>& grad_block =
          std::vector<framework::BlockDesc*>())
      : fwd_op_(fwd_op),
        no_grad_set_(no_grad_set),
        grad_to_var_(grad_to_var),
        original_block_(original_block),
        acting_program_(framework::ProgramDesc()),
        grad_block_(grad_block) {
    // TODO(jiabin): This should always execute by one thread...
    StaticCompositeContext::Instance().SetBlock(
        acting_program_.MutableBlock(0));
  }

  virtual ~CompositeGradOpMakerBase() = default;

  virtual std::vector<std::unique_ptr<framework::OpDesc>> operator()() {
    this->Apply();
    std::vector<std::unique_ptr<framework::OpDesc>> ops;
    // TODO(jiabin): Support multiple blocks later
    for (auto* op : StaticCompositeContext::Instance().GetBlock()->AllOps()) {
      ops.emplace_back(new framework::OpDesc(*op));
      ops.back()->ResetBlock();
    }
    return ops;
  }

  virtual void Apply() = 0;

  paddle::experimental::Tensor GetSingleForwardOutput(
      const std::string& name) {
    framework::VarDesc* out_desc = this->SingleForwardOutput(name);
    paddle::experimental::Tensor out = paddle::experimental::Tensor(
        std::make_shared<prim::DescTensor>(out_desc));
    return out;
  }

  paddle::experimental::Tensor GetSingleForwardInput(const std::string& name) {
    paddle::experimental::Tensor input = paddle::experimental::Tensor(
        std::make_shared<prim::DescTensor>(this->SingleForwardInput(name)));
    return input;
  }

  paddle::experimental::Tensor GetSingleOutputGrad(const std::string& name) {
    paddle::experimental::Tensor output_grad = paddle::experimental::Tensor(
        std::make_shared<prim::DescTensor>(this->SingleOutputGrad(name)));
    return output_grad;
  }

  paddle::experimental::Tensor GetSingleInputGrad(const std::string& name) {
    framework::VarDesc* input_grad_desc = this->SingleInputGrad(name);
    if (!input_grad_desc) return paddle::experimental::Tensor();
    paddle::experimental::Tensor input_grad = paddle::experimental::Tensor(
        std::make_shared<prim::DescTensor>(input_grad_desc));
    return input_grad;
  }

  paddle::optional<paddle::experimental::Tensor>
  GetOptionalSingleForwardOutput(const std::string& name) {
    paddle::optional<paddle::experimental::Tensor> output_opt;
    framework::VarDesc* output_desc = this->SingleForwardOutput(name);
    if (!output_desc) return output_opt;
    paddle::experimental::Tensor output = paddle::experimental::Tensor(
        std::make_shared<prim::DescTensor>(output_desc));
    output_opt = paddle::make_optional<paddle::experimental::Tensor>(output);
    return output_opt;
  }

  paddle::optional<paddle::experimental::Tensor>
  GetOptionalSingleForwardInput(const std::string& name) {
    paddle::optional<paddle::experimental::Tensor> input_opt;
    framework::VarDesc* input_desc = this->SingleForwardInput(name);
    if (!input_desc) return input_opt;
    paddle::experimental::Tensor input = paddle::experimental::Tensor(
        std::make_shared<prim::DescTensor>(input_desc));
    input_opt = paddle::make_optional<paddle::experimental::Tensor>(input);
    return input_opt;
  }

  paddle::optional<paddle::experimental::Tensor> GetOptionalSingleOutputGrad(
      const std::string& name) {
    paddle::optional<paddle::experimental::Tensor> output_grad_opt;
    framework::VarDesc* output_grad_desc = this->SingleOutputGrad(name);
    if (!output_grad_desc) return output_grad_opt;
    paddle::experimental::Tensor output_grad = paddle::experimental::Tensor(
        std::make_shared<prim::DescTensor>(output_grad_desc));
    output_grad_opt =
        paddle::make_optional<paddle::experimental::Tensor>(output_grad);
    return output_grad_opt;
  }

  std::vector<paddle::experimental::Tensor> GetMultiForwardOutput(
      const std::string& name) {
    std::vector<paddle::experimental::Tensor> outputs;
    std::vector<framework::VarDesc*> outputs_descs =
        this->MultiForwardOutput(name);
    outputs.reserve(outputs_descs.size());
    for (const auto& output_desc : outputs_descs) {
      outputs.emplace_back(paddle::experimental::Tensor(
          std::make_shared<prim::DescTensor>(output_desc)));
    }
    return outputs;
  }

  std::vector<paddle::experimental::Tensor> GetMultiForwardInput(
      const std::string& name) {
    std::vector<paddle::experimental::Tensor> inputs;
    std::vector<framework::VarDesc*> inputs_descs =
        this->MultiForwardInput(name);
    inputs.reserve(inputs_descs.size());
    for (const auto& input_desc : inputs_descs) {
      inputs.emplace_back(paddle::experimental::Tensor(
          std::make_shared<prim::DescTensor>(input_desc)));
    }
    return inputs;
  }

  std::vector<paddle::experimental::Tensor> GetMultiOutputGrad(
      const std::string& name) {
    std::vector<paddle::experimental::Tensor> outputs_grads;
    std::vector<framework::VarDesc*> outputs_grads_descs =
        this->MultiOutputGrad(name);
    outputs_grads.reserve(outputs_grads_descs.size());
    for (const auto& output_grad_desc : outputs_grads_descs) {
      outputs_grads.emplace_back(paddle::experimental::Tensor(
          std::make_shared<prim::DescTensor>(output_grad_desc)));
    }
    return outputs_grads;
  }

  std::vector<paddle::experimental::Tensor> GetMultiInputGrad(
      const std::string& name) {
    std::vector<paddle::experimental::Tensor> inputs_grads;
    std::vector<framework::VarDesc*> inputs_grads_descs =
        this->MultiInputGrad(name);
    inputs_grads.reserve(inputs_grads_descs.size());
    for (const auto& input_grad_desc : inputs_grads_descs) {
      if (input_grad_desc) {
        inputs_grads.emplace_back(paddle::experimental::Tensor(
            std::make_shared<prim::DescTensor>(input_grad_desc)));
      } else {
        inputs_grads.emplace_back(paddle::experimental::Tensor());
      }
    }
    return inputs_grads;
  }

  std::vector<paddle::optional<paddle::experimental::Tensor>>
  GetOptionalMultiForwardOutput(const std::string& name) {
    std::vector<paddle::optional<paddle::experimental::Tensor>> outputs_opt;
    std::vector<framework::VarDesc*> outputs_descs =
        this->MultiForwardOutput(name);
    outputs_opt.reserve(outputs_descs.size());
    for (const auto& output_desc : outputs_descs) {
      if (output_desc) {
        outputs_opt.emplace_back(
            paddle::make_optional<paddle::experimental::Tensor>(
                paddle::experimental::Tensor(
                    std::make_shared<prim::DescTensor>(output_desc))));
      } else {
        outputs_opt.emplace_back(
            paddle::make_optional<paddle::experimental::Tensor>(
                paddle::experimental::Tensor()));
      }
    }
    return outputs_opt;
  }

  std::vector<paddle::optional<paddle::experimental::Tensor>>
  GetOptionalMultiForwardInput(const std::string& name) {
    std::vector<paddle::optional<paddle::experimental::Tensor>> inputs_opt;
    std::vector<framework::VarDesc*> inputs_descs =
        this->MultiForwardInput(name);
    inputs_opt.reserve(inputs_descs.size());
    for (const auto& input_desc : inputs_descs) {
      if (input_desc) {
        inputs_opt.emplace_back(
            paddle::make_optional<paddle::experimental::Tensor>(
                paddle::experimental::Tensor(
                    std::make_shared<prim::DescTensor>(input_desc))));
      } else {
        inputs_opt.emplace_back(
            paddle::make_optional<paddle::experimental::Tensor>(
                paddle::experimental::Tensor()));
      }
    }
    return inputs_opt;
  }

  std::vector<paddle::optional<paddle::experimental::Tensor>>
  GetOptionalMultiOutputGrad(const std::string& name) {
    std::vector<paddle::optional<paddle::experimental::Tensor>> outputs_grads;
    std::vector<framework::VarDesc*> outputs_grads_descs =
        this->MultiOutputGrad(name);
    outputs_grads.reserve(outputs_grads_descs.size());
    for (const auto& output_grad_desc : outputs_grads_descs) {
      if (output_grad_desc) {
        outputs_grads.emplace_back(
            paddle::make_optional<paddle::experimental::Tensor>(
                paddle::experimental::Tensor(
                    std::make_shared<prim::DescTensor>(output_grad_desc))));
      } else {
        outputs_grads.emplace_back(
            paddle::make_optional<paddle::experimental::Tensor>(
                paddle::experimental::Tensor()));
      }
    }
    return outputs_grads;
  }

  paddle::experimental::Tensor* GetOutputPtr(
      paddle::experimental::Tensor* input) {
    if (input->defined()) return input;
    return nullptr;
  }

  std::vector<paddle::experimental::Tensor*> GetOutputPtr(
      const std::vector<paddle::experimental::Tensor*>& inputs) {
    std::vector<paddle::experimental::Tensor*> output_ptrs;
    output_ptrs.reserve(inputs.size());
    for (const auto& input : inputs) {
      if (input->defined())
        output_ptrs.emplace_back(input);
      else
        output_ptrs.emplace_back(nullptr);
    }
    return output_ptrs;
  }

  std::string GetOutputName(const paddle::experimental::Tensor& output) {
    if (!output.defined()) return framework::kEmptyVarName;
    return static_cast<prim::DescTensor*>(output.impl().get())->Name();
  }

  std::vector<std::string> GetOutputName(
      const std::vector<paddle::experimental::Tensor>& outputs) {
    std::vector<std::string> out_names;
    out_names.reserve(outputs.size());
    for (const auto& output : outputs) {
      if (!output.defined())
        out_names.emplace_back(framework::kEmptyVarName);
      else
        out_names.emplace_back(
            static_cast<prim::DescTensor*>(output.impl().get())->Name());
    }
    return out_names;
  }

 protected:
  void CopyVarFromOrig(const std::string& name) const {
    VLOG(6) << "Copy Var: " << name << " from block: " << original_block_
            << " to block: " << StaticCompositeContext::Instance().GetBlock();
    framework::VarDesc* original_var = original_block_->FindVar(name);
    PADDLE_ENFORCE_NOT_NULL(
        original_var,
        phi::errors::InvalidArgument(
            "Can't find var: %s in block %s", name, original_block_));
    *StaticCompositeContext::Instance().GetBlock()->Var(name) = *original_var;
  }

  framework::VarDesc* SingleInputGrad(const std::string& name,
                                      bool drop_empty_grad = true) const {
    auto var_name = this->SingleForwardInputVarName(name);
    auto grad_var_name = framework::GradVarName(var_name);
    if (no_grad_set_.empty() || !no_grad_set_.count(grad_var_name)) {
      (*this->grad_to_var_)[grad_var_name] = var_name;
      VLOG(8) << "Valid gradients: " << grad_var_name;
    } else {
      // TODO(jiabin): Will this cause fill zeros error?
      grad_var_name = framework::kEmptyVarName;
      if (drop_empty_grad) return nullptr;
    }

    if (original_block_->HasVar(grad_var_name)) {
      // Copy Var from original block to active block, or create a new one.
      CopyVarFromOrig(grad_var_name);
      return StaticCompositeContext::Instance().GetBlock()->FindVar(
          grad_var_name);
    } else {
      return StaticCompositeContext::Instance().GetBlock()->Var(grad_var_name);
    }
  }

  framework::VarDesc* SingleOutputGrad(const std::string& name) const {
    auto var_name = this->SingleForwardOutputVarName(name);
    auto grad_var_name = framework::GradVarName(var_name);
    (*this->grad_to_var_)[grad_var_name] = var_name;
    VLOG(8) << "Valid gradients: " << grad_var_name;

    if (original_block_->HasVar(grad_var_name)) {
      // Copy Var from original block to active block, or create a new one.
      CopyVarFromOrig(grad_var_name);
      return StaticCompositeContext::Instance().GetBlock()->FindVar(
          grad_var_name);
    } else {
      return StaticCompositeContext::Instance().GetBlock()->Var(grad_var_name);
    }
  }

  std::vector<framework::VarDesc*> MultiInputGrad(
      const std::string& name, bool drop_empty_grad = true) const {
    std::vector<std::string> ret_val;
    std::vector<framework::VarDesc*> input_grads;
    auto var_names = this->MultiForwardInputVarName(name);
    ret_val.reserve(var_names.size());
    std::transform(var_names.begin(),
                   var_names.end(),
                   std::back_inserter(ret_val),
                   [this](const std::string& fwd_var_name) -> std::string {
                     auto g_name = framework::GradVarName(fwd_var_name);
                     if (no_grad_set_.empty() || !no_grad_set_.count(g_name)) {
                       (*this->grad_to_var_)[g_name] = fwd_var_name;
                       return g_name;
                     } else {
                       return framework::kEmptyVarName;
                     }
                   });
    if (!drop_empty_grad) {
      for (const auto& name : ret_val) {
        if (original_block_->HasVar(name)) {
          // Copy Var from original block to active block, or create a new one.
          CopyVarFromOrig(name);
          input_grads.emplace_back(
              StaticCompositeContext::Instance().GetBlock()->FindVar(name));
        } else {
          input_grads.emplace_back(
              StaticCompositeContext::Instance().GetBlock()->Var(name));
        }
      }
      return input_grads;
    }
    PADDLE_ENFORCE_LE(
        var_names.size(),
        1UL,
        platform::errors::Unavailable(
            "BUG from operator developer:"
            " for input argument with a list of variables, "
            " drop_empty_grad is not allowed because it makes"
            " the correspondence between a variable and its gradient"
            " ambiguous."));

    std::vector<std::string> dropped_ret_val;
    dropped_ret_val.reserve(ret_val.size());
    std::copy_if(
        ret_val.begin(),
        ret_val.end(),
        std::back_inserter(dropped_ret_val),
        [](const std::string& str) { return str != framework::kEmptyVarName; });
    for (const auto& name : dropped_ret_val) {
      // TODO(jiabin): Will this cause fill zeros error?
      if (original_block_->HasVar(name)) {
        // Copy Var from original block to active block, or create a new one.
        CopyVarFromOrig(name);
        input_grads.emplace_back(
            StaticCompositeContext::Instance().GetBlock()->FindVar(name));
      } else {
        input_grads.emplace_back(
            StaticCompositeContext::Instance().GetBlock()->Var(name));
      }
    }
    return input_grads;
  }

  std::vector<framework::VarDesc*> MultiOutputGrad(
      const std::string& name) const {
    std::vector<std::string> ret_val;
    auto out_names = this->MultiForwardOutputVarName(name);
    ret_val.reserve(out_names.size());
    std::transform(out_names.begin(),
                   out_names.end(),
                   std::back_inserter(ret_val),
                   [this](const std::string& fwd_var_name) -> std::string {
                     auto g_name = framework::GradVarName(fwd_var_name);
                     (*this->grad_to_var_)[g_name] = fwd_var_name;
                     return g_name;
                   });
    std::vector<framework::VarDesc*> grad_out;
    for (const auto& name : ret_val) {
      // TODO(jiabin): Will this cause fill zeros error?
      if (original_block_->HasVar(name)) {
        // Copy Var from original block to active block, or create a new one.
        CopyVarFromOrig(name);
        grad_out.emplace_back(
            StaticCompositeContext::Instance().GetBlock()->FindVar(name));
      } else {
        grad_out.emplace_back(
            StaticCompositeContext::Instance().GetBlock()->Var(name));
      }
    }
    return grad_out;
  }

  framework::VarDesc* SingleForwardInput(const std::string& name) const {
    // Copy Var from original block to active block, or create a new one.
    CopyVarFromOrig(fwd_op_.Input(name).at(0));
    return StaticCompositeContext::Instance().GetBlock()->FindVar(
        fwd_op_.Input(name).at(0));
  }

  framework::VarDesc* SingleForwardOutput(const std::string& name) const {
    // Copy Var from original block to active block, or create a new one.
    CopyVarFromOrig(fwd_op_.Output(name).at(0));
    return StaticCompositeContext::Instance().GetBlock()->FindVar(
        fwd_op_.Output(name).at(0));
  }

  std::vector<framework::VarDesc*> MultiForwardInput(
      const std::string& name) const {
    std::vector<framework::VarDesc*> result;
    for (const auto& n : fwd_op_.Input(name)) {
      // Copy Var from original block to active block, or create a new one.
      CopyVarFromOrig(n);
      result.emplace_back(
          StaticCompositeContext::Instance().GetBlock()->FindVar(n));
    }
    return result;
  }

  std::vector<framework::VarDesc*> MultiForwardOutput(
      const std::string& name) const {
    std::vector<framework::VarDesc*> result;
    for (const auto& n : fwd_op_.Output(name)) {
      // Copy Var from original block to active block, or create a new one.
      CopyVarFromOrig(n);
      result.emplace_back(
          StaticCompositeContext::Instance().GetBlock()->FindVar(n));
    }
    return result;
  }

  void RecoverOutputName(const paddle::experimental::Tensor& output,
                         const std::string& origin_name) {
    if (origin_name == framework::kEmptyVarName) return;
    VLOG(4) << "Recover: "
            << static_cast<prim::DescTensor*>(output.impl().get())->Name()
            << " To: " << origin_name;
    prim::StaticCompositeContext::Instance().GetBlock()->RenameVar(
        static_cast<prim::DescTensor*>(output.impl().get())->Name(),
        origin_name);
  }

  void RecoverOutputName(
      const std::vector<paddle::experimental::Tensor>& outputs,
      const std::vector<std::string>& origin_names) {
    PADDLE_ENFORCE_EQ(outputs.size(),
                      origin_names.size(),
                      platform::errors::InvalidArgument(
                          "The size of outputs (%d) must be equal to the size "
                          "of the origin_names (%d).",
                          outputs.size(),
                          origin_names.size()));
    for (size_t i = 0; i < outputs.size(); ++i) {
      if (origin_names[i] == framework::kEmptyVarName) continue;
      prim::StaticCompositeContext::Instance().GetBlock()->RenameVar(
          static_cast<prim::DescTensor*>(outputs[i].impl().get())->Name(),
          origin_names[i]);
    }
  }

  std::string SingleForwardInputVarName(const std::string& name) const {
    return fwd_op_.Input(name).at(0);
  }

  std::string SingleForwardOutputVarName(const std::string& name) const {
    return fwd_op_.Output(name).at(0);
  }

  std::vector<std::string> MultiForwardOutputVarName(
      const std::string& name) const {
    return fwd_op_.Output(name);
  }

  std::vector<std::string> MultiForwardInputVarName(
      const std::string& name) const {
    return fwd_op_.Input(name);
  }

  static std::vector<framework::VarDesc*> EmptyInput() { return {}; }

  static std::vector<framework::VarDesc*> EmptyOutput() { return {}; }

  static std::vector<framework::VarDesc*> EmptyInputGrad() { return {}; }

  static std::vector<framework::VarDesc*> EmptyOutputGrad() { return {}; }

  std::vector<std::string> InputNames() const {
    return this->fwd_op_.InputNames();
  }

  std::vector<std::string> OutputNames() const {
    return this->fwd_op_.OutputNames();
  }

  const std::unordered_map<std::string, framework::Attribute>& Attrs() const {
    return fwd_op_.GetAttrMap();
  }

  const std::unordered_map<std::string, framework::Attribute>& RuntimeAttrs()
      const {
    return fwd_op_.GetRuntimeAttrMap();
  }

  const framework::Attribute& GetAttr(const std::string& name) const {
    auto& map = fwd_op_.GetAttrMap();
    auto it = map.find(name);
    PADDLE_ENFORCE_NE(
        it,
        map.end(),
        platform::errors::NotFound("Cannot find attribute (%s).", name));
    return it->second;
  }

  template <typename T>
  inline const T& Attr(const std::string& name) const {
    return PADDLE_GET_CONST(T, GetAttr(name));
  }

  std::string ForwardOpType() const { return this->fwd_op_.Type(); }

  const framework::BlockDesc* GetForwardOpBlock() const {
    return fwd_op_.Block();
  }

 protected:
  bool HasInput(const std::string& name) const {
    return (fwd_op_.Inputs().count(name) > 0);
  }

  bool HasOutput(const std::string& name) const {
    return (fwd_op_.Outputs().count(name) > 0);
  }

 private:
  const framework::OpDesc& fwd_op_;
  const std::unordered_set<std::string>& no_grad_set_;
  std::unordered_map<std::string, std::string>* grad_to_var_;
  const framework::BlockDesc* original_block_;
  framework::ProgramDesc acting_program_;

 protected:
  std::vector<framework::BlockDesc*> grad_block_;
};

}  // namespace prim
}  // namespace paddle
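
// ---------------------------------------------------------------------------
// Usage sketch (illustrative only, not part of this header): a derived maker
// overrides Apply() and composes the backward computation from primitive ops.
// The op name, the "X"/"Out" argument names, and the primitive call
// prim::tanh_grad<prim::DescTensor>(...) are assumptions for illustration;
// the actual primitive and its signature come from the prim backward API
// headers and depend on the op being differentiated.
//
//   class TanhCompositeGradOpMaker
//       : public paddle::prim::CompositeGradOpMakerBase {
//    public:
//     using paddle::prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
//
//     void Apply() override {
//       // Forward output and its incoming gradient, wrapped as
//       // DescTensor-backed Tensors registered in the acting block.
//       auto out = this->GetSingleForwardOutput("Out");
//       auto out_grad = this->GetSingleOutputGrad("Out");
//       // The input gradient may be dropped via no_grad_set; GetOutputPtr
//       // maps an undefined Tensor to nullptr so the primitive can skip it.
//       auto x_grad = this->GetSingleInputGrad("X");
//       auto* x_grad_ptr = this->GetOutputPtr(&x_grad);
//       std::string x_grad_name = this->GetOutputName(x_grad);
//       // Record the composite backward ops into the block held by
//       // StaticCompositeContext.
//       paddle::prim::tanh_grad<paddle::prim::DescTensor>(
//           out, out_grad, x_grad_ptr);
//       // Rename the temporary output var back to the gradient name the
//       // surrounding program expects.
//       this->RecoverOutputName(x_grad, x_grad_name);
//     }
//   };
//
// The framework then constructs the maker with the forward OpDesc and invokes
// operator(), which runs Apply() and returns the grad OpDescs recorded in the
// acting block:
//
//   TanhCompositeGradOpMaker maker(fwd_op, no_grad_set, &grad_to_var, block);
//   std::vector<std::unique_ptr<framework::OpDesc>> grad_ops = maker();
// ---------------------------------------------------------------------------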