/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

// NOTE(review): every angle-bracket span (include target, template
// parameter/argument lists) was lost when this file was extracted; they are
// reconstructed below from the upstream PaddlePaddle header -- confirm
// against the repository before merging.
#include <sstream>

#include "paddle/phi/kernels/funcs/elementwise_base.h"

#if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__)
#include "paddle/phi/kernels/funcs/dims_simplifier.h"

namespace kps = phi::kps;
#endif

namespace phi {
namespace funcs {

#if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__)

// How a kernel's inputs are read relative to the output:
//  - kMixed:       some inputs broadcast, some are elementwise.
//  - kBroadcast:   inputs need broadcasting.
//  - kElementwise: every input has the same numel as the output.
enum BroadcastLoadType { kMixed = 1, kBroadcast = 2, kElementwise = 3 };

// Per-input functor (applied via UnrollerWithoutVecSize) that records the
// raw data pointer of input `Index` and whether that input must be
// broadcast, i.e. whether its element count differs from the output's.
template <int Index, int VecSize>
struct UseBroadcast {
  template <typename ArgsT, typename Array1, typename Array2>
  static HOSTDEVICE void Apply(
      const std::vector<const DenseTensor *> &ins_tensor,
      const ArgsT &args,
      int64_t numel,
      Array1 *ins_data,         // out: raw byte pointer of each input
      Array2 *use_broadcast,    // out: per-input broadcast flag
      int *broadcast_num,       // out: running count of broadcast inputs
      bool *all_elementwise) {  // out: cleared once any input broadcasts
    (*ins_data)[Index] = (const _ptr_ char *)(ins_tensor[Index]->data());
    // An input whose numel equals the output numel can be read elementwise;
    // anything else has to take the broadcast path.
    bool is_same_dim = ins_tensor[Index]->numel() == numel;
    if (is_same_dim) {
      (*use_broadcast)[Index] = false;
    } else {
      (*use_broadcast)[Index] = true;
      (*broadcast_num)++;
    }
    *all_elementwise &= is_same_dim;
  }
};

// Inspects the inputs/outputs of a broadcast kernel launch and derives:
//  - numel:           element count of the (shape-identical) outputs,
//  - vec_size:        largest vector width admissible for every in/out ptr,
//  - broadcast_num:   how many inputs require broadcasting,
//  - all_elementwise: true iff no input requires broadcasting,
// together with the per-input pointers and broadcast flags.
template <typename OutT, int Arity, typename Functor>
struct LoaderTypeClassifier {
 public:
  int64_t numel{0};
  int vec_size{4};
  int broadcast_num{0};
  bool all_elementwise{true};
  phi::Array<bool, Arity> use_broadcast;
  phi::Array<const _ptr_ char *__restrict__, Arity> ins_data;

  LoaderTypeClassifier() {}
  LoaderTypeClassifier(const std::vector<const DenseTensor *> &ins,
                       std::vector<DenseTensor *> *outs) {
    using Traits = phi::funcs::FunctionTraits<Functor>;
    using ArgsT = typename Traits::ArgsTuple;
    ArgsT arg;
    uint64_t out_addr = reinterpret_cast<uint64_t>((*outs)[0]->data<OutT>());

    // Narrow vec_size by the type/alignment of every input pointer.
    UnrollerWithoutVecSize<VecSizeGetter, Arity>::step(ins, arg, &vec_size);

    for (auto i = 1; i < outs->size(); ++i) {
      PADDLE_ENFORCE_EQ(
          (*outs)[i]->dims(),
          (*outs)[0]->dims(),
          phi::errors::InvalidArgument(
              "The shape of each output tensor shall be identical yet, but "
              "%d-th output tensor`s shape is not.",
              i));
      // OR-ing the addresses accumulates the worst (least aligned) bit
      // pattern, so a single GetVectorizedSize call below bounds the vector
      // width for all outputs at once.
      out_addr =
          (out_addr | reinterpret_cast<uint64_t>((*outs)[i]->data<OutT>()));
    }

    vec_size = std::min(
        vec_size,
        phi::GetVectorizedSize<OutT>(reinterpret_cast<OutT *>(out_addr)));
    numel = (*outs)[0]->numel();
    // Fill ins_data / use_broadcast and the broadcast statistics per input.
    UnrollerWithoutVecSize<UseBroadcast, Arity>::step(ins,
                                                      arg,
                                                      numel,
                                                      &ins_data,
                                                      &use_broadcast,
                                                      &broadcast_num,
                                                      &all_elementwise);
  }
};

// Common broadcast/elementwise Loader: default-initializes the `Index`-th
// argument slot of each of the VecSize arg tuples, then reads this thread's
// elements, taking either the broadcast path (ReadDataBc) or the contiguous
// path (ReadData) depending on use_broadcast[Index].
template <int Index, int VecSize, bool IsBoundary, int LoadType>
struct BroadcastDataLoader {
  template <typename Array1, typename Array2, typename Array3, typename ArgsT>
  static __device__ __forceinline__ void Apply(const Array1 &ins,
                                               ArgsT *args,
                                               const Array2 &configs,
                                               const Array3 &use_broadcast,
                                               const int block_offset,
                                               const int num,
                                               const uint32_t numel,
                                               int read_lens) {
    using Type = std::tuple_element_t<Index, ArgsT>;
#ifdef PADDLE_WITH_XPU_KP
    kps::Init<Type, ArgsT, Index, VecSize>(
        args, static_cast<Type>(1.0f), read_lens);
    if (use_broadcast[Index]) {
      kps::ReadDataBc<Type, VecSize, 1, ArgsT, Index, IsBoundary>(
          args,
          reinterpret_cast<const _ptr_ Type *>(ins[Index]),
          block_offset,
          configs[Index],
          numel,
          read_lens);
    } else {
      kps::ReadData<Type, VecSize, 1, ArgsT, Index, IsBoundary>(
          args,
          reinterpret_cast<const _ptr_ Type *>(ins[Index]) + block_offset,
          num,
          read_lens);
    }
#else
    kps::Init<Type, ArgsT, Index, VecSize>(args, static_cast<Type>(1.0f));
    if (use_broadcast[Index]) {
      kps::ReadDataBc<Type, VecSize, 1, ArgsT, Index, IsBoundary>(
          args,
          reinterpret_cast<const _ptr_ Type *>(ins[Index]),
          block_offset,
          configs[Index],
          numel,
          VecSize);
    }
    // NOTE: If use if...else... with condition `use_broadcast[Index]` here,
    // there will be some errs with clang12 while compiling in ROCm.
    // When the compiler is upgraded, if...else... may be used.
    if (!use_broadcast[Index]) {
      kps::ReadData<Type, VecSize, 1, ArgsT, Index, IsBoundary>(
          args,
          reinterpret_cast<const _ptr_ Type *>(ins[Index]) + block_offset,
          num,
          VecSize);
    }
#endif
  }
};

/* BroadcastDataLoaders Partial specialization */
#ifndef PADDLE_WITH_XPU_KP
// Scalar elementwise Loader with consideration of IsBoundary.
// NOTE(review): the template parameter/argument lists in this section were
// stripped during extraction and are reconstructed from the upstream
// PaddlePaddle header -- confirm against the repository before merging.

// Boundary specialization for the pure-elementwise case: loads scalar by
// scalar with an `index < numel` guard so the grid tail reads no OOB data.
// Out-of-range slots keep the neutral value 1.
template <int Index, int VecSize>
struct BroadcastDataLoader<Index, VecSize, true, kElementwise> {
  template <typename Array1, typename Array2, typename Array3, typename ArgsT>
  static __device__ __forceinline__ void Apply(const Array1 &ins,
                                               ArgsT *args,
                                               const Array2 &configs,
                                               const Array3 &use_broadcast,
                                               const int block_offset,
                                               const int num,
                                               const uint32_t numel,
                                               int read_lens) {
    using Type = std::tuple_element_t<Index, ArgsT>;
    int thread_offset = threadIdx.x * VecSize + block_offset;
#pragma unroll
    for (int idx = 0; idx < VecSize; ++idx) {
      // Default to 1 so lanes past the tail still hold a defined value.
      std::get<Index>(args[idx]) = static_cast<Type>(1);
      int index = thread_offset + idx;
      if (index < numel) {
        std::get<Index>(args[idx]) =
            reinterpret_cast<const _ptr_ Type *>(ins[Index])[index];
      }
    }
  }
};

// Vectorized elementwise Loader without consideration of IsBoundary: one
// VecSize-wide vector load per thread, no tail guard (caller guarantees a
// full vector is in range).
template <int Index, int VecSize>
struct BroadcastDataLoader<Index, VecSize, false, kElementwise> {
  template <typename Array1, typename Array2, typename Array3, typename ArgsT>
  static __device__ __forceinline__ void Apply(const Array1 &ins,
                                               ArgsT *args,
                                               const Array2 &configs,
                                               const Array3 &use_broadcast,
                                               const int block_offset,
                                               const int num,
                                               const uint32_t numel,
                                               int read_lens) {
    using Type = std::tuple_element_t<Index, ArgsT>;
    using VecType = phi::kps::details::VectorType<Type, VecSize>;
    VecType vec_temp;

    // One vector element per thread; coalesced across the warp.
    int thread_offset = threadIdx.x + blockIdx.x * blockDim.x;
    const VecType *__restrict__ vec_input =
        reinterpret_cast<const VecType *__restrict__>(ins[Index]);
    vec_temp = vec_input[thread_offset];
#pragma unroll
    for (int idx = 0; idx < VecSize; ++idx) {
      std::get<Index>(args[idx]) = vec_temp.val[idx];
    }
  }
};

// Fills the `Index`-th slot of all VecSize arg tuples with the neutral
// value 1 before any data is gathered.
template <int Index, int VecSize>
struct BroadcastDataInit {
  template <typename ArgsT>
  static __device__ __forceinline__ void Apply(ArgsT *args) {
    using Type = std::tuple_element_t<Index, ArgsT>;
#pragma unroll
    for (int k = 0; k < VecSize; ++k) {
      std::get<Index>(args[k]) = static_cast<Type>(1);
    }
  }
};

// Gathers the `Index`-th input through precomputed broadcast indices
// (index_bc) into the corresponding arg-tuple slots.
template <int Index, int VecSize>
struct BroadcastDataSetter {
  template <typename Array, typename ArgsT>
  static __device__ __forceinline__ void Apply(const Array &ins,
                                               ArgsT *args,
                                               uint32_t index_bc[][VecSize]) {
    using Type = std::tuple_element_t<Index, ArgsT>;
#pragma unroll
    for (int k = 0; k < VecSize; ++k) {
      std::get<Index>(args[k]) =
          reinterpret_cast<const _ptr_ Type *>(ins[Index])[index_bc[Index][k]];
    }
  }
};

#endif

// static broadcast unroller
template