// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h" namespace egr { static inline bool NeedCast(const paddle::Tensor& tensor, const phi::DataType& dst_dtype) { auto place = tensor.place(); auto data_type = tensor.dtype(); // Except CPU judgment, other conditions should be consistent with // amp_utils.h's judgment if (paddle::platform::is_gpu_place(place) || paddle::platform::is_cuda_pinned_place(place) || paddle::platform::is_xpu_place(place) || paddle::platform::is_custom_place(place) || paddle::platform::is_cpu_place(place)) { // CudaPinndePlace is added for varbase created by dataloader // Cpu place is for differnt place tensor, when input1 is cpu and input2 is // gpu if ((data_type == phi::DataType::FLOAT32 || data_type == phi::DataType::FLOAT16 || data_type == phi::DataType::BFLOAT16) && (data_type != dst_dtype)) { return true; } } return false; } inline paddle::Tensor Cast(const paddle::Tensor& input, const phi::DataType& dst_dtype, const bool trace_backward = true) { if (input.is_sparse_coo_tensor() || input.is_sparse_csr_tensor()) { if (trace_backward) { return sparse::cast_ad_func(input, phi::DataType::UNDEFINED, dst_dtype); } else { return paddle::experimental::sparse::cast( input, phi::DataType::UNDEFINED, dst_dtype); } } else { if (trace_backward) { return cast_ad_func(input, dst_dtype); } else { return paddle::experimental::cast(input, dst_dtype); } } } inline std::vector EagerAmpAutoCasts( const std::string& inputs_name, const std::vector& inputs, const phi::DataType& dst_dtype, std::string op_name UNUSED, bool trace_backward UNUSED = true) { VLOG(6) << "AMP AmpAutoCasts:" << " inputs(" << inputs_name << ") dst_dtype(" << phi::DataTypeToString(dst_dtype) << ")."; std::vector inputs_casted; for (auto& input : inputs) { if (NeedCast(input, dst_dtype)) { inputs_casted.emplace_back(std::move(Cast(input, dst_dtype))); } else { inputs_casted.emplace_back(input); } } return inputs_casted; } inline paddle::Tensor EagerAmpAutoCast(const std::string& input_name, const paddle::Tensor& input, const phi::DataType& dst_dtype, const std::string& op_name, bool trace_backward = true) { VLOG(6) << "AMP AmpAutoCasts:" << " input(" << egr::EagerUtils::TensorStr(input) << " to dst_dtype(" << phi::DataTypeToString(dst_dtype) << ")."; if ((op_name == "batch_norm" || op_name == "layer_norm" || op_name == "sync_batch_norm") && input_name != "x") { return input; } if (dst_dtype == phi::DataType::FLOAT16) { if (op_name == "run_program") { return input; } if ((op_name == "fused_attention" || op_name == "fused_feedforward")) { if (input_name == "LnScale" || input_name == "LnBias" || input_name == "Ln2Scale" || input_name == "Ln2Bias" || input_name == "Ln1Scale" || input_name == "Ln1Bias") { return input; } } } if (NeedCast(input, dst_dtype)) { VLOG(6) << "Input : " << input.impl() << "NeedCast"; return Cast(input, dst_dtype, trace_backward); } return input; } inline paddle::optional EagerAmpAutoCast( const std::string& input_name, const paddle::optional& input, const phi::DataType& dst_dtype, const std::string& op_name, bool trace_backward = true) { if (input) { return EagerAmpAutoCast( input_name, *input, dst_dtype, op_name, trace_backward); } return paddle::none; } inline paddle::optional> EagerAmpAutoCasts( const std::string& inputs_name, const paddle::optional>& inputs, const phi::DataType& dst_dtype, std::string op_name, bool trace_backward = true) { if (inputs) { return EagerAmpAutoCasts( inputs_name, *inputs, dst_dtype, op_name, trace_backward); } return paddle::optional>(); } } // namespace egr