// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/eager/api/utils/tensor_utils.h"
#include "paddle/fluid/eager/autograd_meta.h"
#include "paddle/fluid/eager/eager_tensor.h"
#include "paddle/fluid/eager/grad_node_info.h"
#include "paddle/phi/api/all.h"

// NOTE(review): in this copy of the header, the contents of template
// parameter/argument lists appear to have been stripped. Examples:
// `template class IterHelper`, `std::vector* elements`,
// `const std::shared_ptr& grad_node`, and the stray `>` in
// `std::vector> TrySyncToVars`. As written, these declarations are
// ill-formed. Restore the angle-bracket contents from upstream before
// building. They are left untouched here because the exact element types
// cannot be recovered from this text alone.

namespace egr {

// Forward declaration; used by the TensorWrapper helpers of EagerUtils below.
class TensorWrapper;

/**
 * IterHelper visits every element of a variadic argument list.
 *
 * Each argument passed to apply() is forwarded to visit():
 *  - a single element goes to the pure-virtual visit(ElementType);
 *  - a pointer to a vector is unpacked, and each element is visited in order.
 * Subclasses implement visit(ElementType) to accumulate or apply
 * per-element state.
 * **/
template class IterHelper {
  virtual void visit(ElementType element) = 0;
  // Unpack a container of elements and visit each one in order.
  void visit(std::vector* elements) {
    for (auto element : *elements) visit(element);
  }
  // Empty-pack overload: terminates the recursion of the public apply().
  template void apply() {}

 public:
  // Visit the first argument, then recurse on the remaining ones.
  template void apply(T&& arg, Args&&... args) {
    visit(std::forward(arg));
    return apply(std::forward(args)...);
  }
  virtual ~IterHelper() = default;
};

// Visitor over AutogradMeta pointers. It records whether any visited
// tensor requires grad, i.e. has StopGradient() == false.
class ComputeRequireGradIter : public IterHelper {
 public:
  // True once at least one non-null visited AutogradMeta does not stop
  // gradient. False otherwise, including when nothing was visited.
  bool RequireGrad() { return require_grad_; }

 private:
  void visit(AutogradMeta* element) override {
    // Dispensable tensors feed in a nullptr autograd_meta; skip them.
    if (!element) return;
    bool stop_gradient = element->StopGradient();
    if (!stop_gradient) require_grad_ = true;
  }
  bool require_grad_ = false;
};

// Visitor that pushes a stop_gradient value onto every visited AutogradMeta
// via WeakSetStopGradient(). Null entries are logged and skipped.
class PassStopGradientIter : public IterHelper {
 public:
  // Sets the value to propagate. If this is never called, it defaults to true.
  void SetStopGradient(bool stop_gradient) { stop_gradient_ = stop_gradient; }

 private:
  void visit(AutogradMeta* element) override {
    if (!element) {
      // TODO(jiabin): Add the Tensor name here once it is supported.
      VLOG(2) << "Tensor is NULL";
      return;
    }
    // NOTE(review): the "weak" semantics are defined by
    // AutogradMeta::WeakSetStopGradient. It presumably does not override an
    // explicitly set value; confirm in autograd_meta.h.
    element->WeakSetStopGradient(stop_gradient_);
  }
  bool stop_gradient_ = true;
};

/**
 * EagerUtils holds static helpers for conversions and for accessing the
 * autograd members of tensors. It is designed as a purely static,
 * functional utility class: every member declared here is static.
 * **/
class EagerUtils {
 public:
  /**
   * autograd_meta() is used to initialize the AutogradMeta of a tensor, or
   * of every tensor in a vector. This is necessary because it cannot be
   * initialized in egr::EagerVariable's constructor (it's an abstract class
   * there).
   * **/
  static AutogradMeta* autograd_meta(paddle::experimental::Tensor* target);
  static std::vector autograd_meta(
      std::vector* targets);
  static std::vector autograd_meta(
      std::vector* targets);

  // Returns the output rank info of `target` as a pair.
  // NOTE(review): presumably (slot id, rank within slot); confirm in the
  // out-of-line definition.
  static std::pair OutRankInfo(
      const paddle::experimental::Tensor& target);

  // Returns the grad node associated with `target`.
  static std::shared_ptr grad_node(
      const paddle::experimental::Tensor& target);
  // Returns a mutable pointer to the grad tensor of `target`.
  static paddle::experimental::Tensor* mutable_grad(
      const paddle::experimental::Tensor& target);

  // SetHistory records backward info during the forward pass. It sets the
  // grad node of each forward var's autograd meta to the current backward
  // node, `grad_node`.
  static void SetHistory(std::vector* autograd_metas,
                         const std::shared_ptr& grad_node);
  static void SetHistory(AutogradMeta* autograd_meta,
                         const std::shared_ptr& grad_node);

  // Sets the output rank info of the given autograd meta(s) for output slot
  // `slot_id`. The vector overload handles a whole slot of tensors.
  static void SetOutRankWithSlot(std::vector* targets, size_t slot_id);
  static void SetOutRankWithSlot(AutogradMeta* target, size_t slot_id);

  // Raw AutogradMeta accessors. The original note says these "return an
  // AutogradMeta pointer unsafely".
  // NOTE(review): presumably the nullable_* variants may return nullptr for
  // a tensor without AutogradMeta, while the unsafe_* variants assume one
  // exists. Confirm in the out-of-line definitions before relying on it.
  static AutogradMeta* nullable_autograd_meta(
      const paddle::experimental::Tensor& target);
  static AutogradMeta* nullable_autograd_meta(
      paddle::optional target);
  static std::vector nullable_autograd_meta(
      const std::vector& targets);
  static std::vector nullable_autograd_meta(
      const std::vector& targets);
  static AutogradMeta* unsafe_autograd_meta(
      const paddle::experimental::Tensor& target);
  static std::vector unsafe_autograd_meta(
      const std::vector& targets);

  // Returns false when `trace_backward` is false. Otherwise each argument in
  // `args` is passed through ComputeRequireGradIter. The result is true iff
  // any non-null AutogradMeta among them has StopGradient() == false.
  template static bool ComputeRequireGrad(T trace_backward, Args&&... args) {
    if (!trace_backward) {
      VLOG(6) << "Do not require grad because trace_backward = false";
      return false;
    }
    auto iter = ComputeRequireGradIter();
    iter.apply(std::forward(args)...);
    return iter.RequireGrad();
  }

  // Calls WeakSetStopGradient(stop_gradient) on every non-null AutogradMeta
  // among `args`. Null entries are logged at VLOG(2) and skipped.
  template static void PassStopGradient(T stop_gradient, Args&&... args) {
    auto iter = PassStopGradientIter();
    iter.SetStopGradient(stop_gradient);
    iter.apply(std::forward(args)...);
  }

  // Rejects in-place use of a leaf tensor that requires grad. The check runs
  // only when `require_any_grad` is true and `autograd_meta` is non-null.
  // It raises InvalidArgument if `target` does not stop gradient and is a
  // leaf tensor.
  static void CheckInplace(const paddle::experimental::Tensor& target,
                           const AutogradMeta* autograd_meta,
                           bool require_any_grad) {
    if (require_any_grad && autograd_meta) {
      PADDLE_ENFORCE_EQ(!autograd_meta->StopGradient() &&
                            egr::egr_utils_api::IsLeafTensor(target),
                        false,
                        paddle::platform::errors::InvalidArgument(
                            "Leaf Var (%s) that doesn't stop gradient "
                            "can't use inplace strategy.",
                            target.name()));
    }
  }

  // TensorWrapper utils: recover the tensor(s) held by TensorWrapper(s).
  // `grad_node` is forwarded to the recovery.
  // NOTE(review): its exact role is defined out of line; confirm there.
  static paddle::experimental::Tensor RecoverTensorWrapper(
      TensorWrapper* tw, const std::shared_ptr& grad_node);
  static std::vector RecoverTensorWrapper(
      std::vector* tw, const std::shared_ptr& grad_node);
  static paddle::optional RecoverOptionalTensorWrapper(
      TensorWrapper* tw, const std::shared_ptr& grad_node);

  // Intermediate helpers needed only to bridge to the legacy path.
  // Remove them once legacy is no longer needed.
  // Inner Method
  static std::shared_ptr TrySyncToVar(
      const paddle::experimental::Tensor& tensor);
  // Basic Input
  static std::vector> TrySyncToVars(
      const paddle::experimental::Tensor& tensor);
  // Basic Output
  static std::vector> TrySyncToVars(
      paddle::experimental::Tensor* tensor);
  // Multi Output
  static std::vector> TrySyncToVars(
      const std::vector& tensors);
  // Multi Input
  static std::vector> TrySyncToVars(
      const std::vector& tensors);
  // Construct empty outputs (`num` of them).
  static std::vector> CreateVars(
      const size_t num);
  // Construct Tensor(s) from var(s).
  static void ModifyInplaceInput(
      const std::shared_ptr& inplace_variable,
      paddle::experimental::Tensor* inplace_tensor);
  static std::vector GetOutputs(
      const std::vector>& outs);
  static paddle::experimental::Tensor GetOutput(
      const std::shared_ptr& out);
  static void GetOutput(const std::shared_ptr& out,
                        paddle::experimental::Tensor* out_var);
  static void GetOutputs(
      const std::vector>& outs,
      std::vector* result);
  static void GetOutputs(
      const std::vector>& outs,
      const std::vector& out_var);
  static void GetOutputs(const std::shared_ptr& out,
                         std::vector* result);
  static void GetOutputs(
      const std::shared_ptr& out,
      const std::vector& out_var);
  static void Output2Result(
      const std::vector& out_var,
      std::vector* result);
  // End of intermediate (legacy) helpers.

  // NOTE(review): behavior is defined out of line. The name suggests the
  // grad of the given tensor(s) is retained when a retain-grad option is
  // enabled; confirm before relying on it.
  static void CheckAndRetainGrad(const paddle::experimental::Tensor& tensor);
  static void CheckAndRetainGrad(
      const std::vector& tensors);
  static void CheckAndRetainGrad(
      const std::vector& tensors);
  // Returns the gradient accumulation node for `tensor`.
  static std::shared_ptr GetGradAccumulationNode(
      const paddle::experimental::Tensor& tensor);

  /**
   * Fill Zero: fills empty entries of `out_grads` with zero-valued tensors.
   * NOTE(review): presumably the shapes/dtypes come from `grad_out_metas`;
   * confirm in the out-of-line definition.
   * **/
  static void FillZeroForEmptyGradInputs(
      std::vector>* out_grads,
      const std::vector>& grad_out_metas);
};

}  // namespace egr