/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

// std::conditional_t (used by OutType) must not rely on a transitive include.
#include <type_traits>

#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
#include "paddle/fluid/platform/aligned_vector.h"
#include "paddle/fluid/platform/function_traits.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/kernels/hybird/general/elementwise_base.h"

namespace pten {

namespace kps = paddle::operators::kernel_primitives;

// Number of inputs taken by an elementwise functor. The fixed arities below
// select the ElementwisePrimitiveCaller specializations in this header.
// kAny (-1) is not referenced in this section; presumably it tags functors
// dispatched through kps::ElementwiseAny -- confirm at the call sites.
enum ElementwiseType { kUnary = 1, kBinary = 2, kTernary = 3, kAny = -1 };

/* Packs the scalar type T (float, int, etc.) into an Array type to support
 * multiple outputs in the elementwise system:
 *   OutType<T, 1>   is plain T;
 *   OutType<T, Num> is paddle::framework::Array<T, Num> for any other Num. */
template <typename T, int Num>
using OutType =
    std::conditional_t<Num == 1, T, paddle::framework::Array<T, Num>>;

// Applies `func` to VecSize elements held by the calling thread by
// dispatching, at compile time, to the matching kps elementwise primitive:
//   * CallElementwiseAny == true -> kps::ElementwiseAny, for any Arity;
//   * otherwise Arity 1 / 2 / 3  -> ElementwiseUnary / Binary / Ternary.
// args[k] holds the VecSize values of the k-th input; `result` receives the
// VecSize outputs.
//
// The primary template's operator() is declared but never defined, so an
// unsupported Arity with CallElementwiseAny == false is a build error
// (unresolved symbol) rather than silently compiling.
//
// NOTE(review): the `1, 1` passed after VecSize in every kps call below are
// kps's NY and BlockSize template parameters (presumably a 1-D per-thread
// tile) -- confirm against kernel_primitives.
template <typename InT,
          typename OutT,
          int VecSize,
          typename Functor,
          int Arity,
          bool CallElementwiseAny = false>
struct ElementwisePrimitiveCaller {
  __device__ inline void operator()(Functor func,
                                    InT (*args)[VecSize],
                                    OutT *result);
};

// Any arity: the whole args table is handed to kps::ElementwiseAny.
template <typename InT, typename OutT, int VecSize, typename Functor, int Arity>
struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, Arity, true> {
  __device__ inline void operator()(Functor func,
                                    InT (*args)[VecSize],
                                    OutT *result) {
    kps::ElementwiseAny<InT, OutT, VecSize, 1, 1, Arity, Functor>(
        result, args, func);
  }
};

// One input: only args[0] is read.
template <typename InT, typename OutT, int VecSize, typename Functor>
struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, kUnary, false> {
  __device__ inline void operator()(Functor func,
                                    InT (*args)[VecSize],
                                    OutT *result) {
    kps::ElementwiseUnary<InT, OutT, VecSize, 1, 1, Functor>(
        result, args[0], func);
  }
};

// Two inputs: args[0] and args[1].
template <typename InT, typename OutT, int VecSize, typename Functor>
struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, kBinary, false> {
  __device__ inline void operator()(Functor func,
                                    InT (*args)[VecSize],
                                    OutT *result) {
    kps::ElementwiseBinary<InT, OutT, VecSize, 1, 1, Functor>(
        result, args[0], args[1], func);
  }
};

// Three inputs: args[0], args[1] and args[2].
template <typename InT, typename OutT, int VecSize, typename Functor>
struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, kTernary, false> {
  __device__ inline void operator()(Functor func,
                                    InT (*args)[VecSize],
                                    OutT *result) {
    kps::ElementwiseTernary<InT, OutT, VecSize, 1, 1, Functor>(
        result, args[0], args[1], args[2], func);
  }
};

// Writes the calling thread's VecSize results to NumOuts output buffers.
//
// src[i] packs the NumOuts values computed for element i (see OutType). They
// are first regrouped into one contiguous VecSize buffer per output
// (dst[k][i] = src[i][k]). Each buffer is then written to
// outs[k] + block_offset.
//
// `num` and IsBoundary are forwarded unchanged to kps::WriteData. Presumably
// IsBoundary == true limits the store to `num` elements for a partial tail
// block -- confirm against kernel_primitives.
template <typename OutT, int VecSize, bool IsBoundary, int NumOuts>
struct ElementwiseWriteDataCaller {
  __device__ __forceinline__ void operator()(
      paddle::framework::Array<OutT *, NumOuts> outs,
      OutType<OutT, NumOuts> src[VecSize],
      int block_offset,
      int num) {
    OutT dst[NumOuts][VecSize];
#pragma unroll
    for (int i = 0; i < VecSize; ++i) {
#pragma unroll
      for (int j = 0; j < NumOuts; ++j) {
        dst[j][i] = (src[i])[j];
      }
    }
#pragma unroll
    for (int i = 0; i < NumOuts; ++i) {
      kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(
          outs[i] + block_offset, dst[i], num);
    }
  }
};

// Single output: OutType<OutT, 1> is plain OutT, so no regrouping is needed
// and src is written directly to outs[0] + block_offset.
template <typename OutT, int VecSize, bool IsBoundary>
struct ElementwiseWriteDataCaller<OutT, VecSize, IsBoundary, 1> {
  __device__ __forceinline__ void operator()(
      paddle::framework::Array<OutT *, 1> outs,
      OutT src[VecSize],
      int block_offset,
      int num) {
    kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(
        outs[0] + block_offset, src, num);
  }
};

}  // namespace pten