// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/kernels/funcs/activation_functor.h" #include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/fluid/platform/device_context.h" namespace phi { #define ToString(x) #x template void ActivationImpl(const Context& dev_ctx, const DenseTensor& X, DenseTensor* Out, const Functor& functor) { PADDLE_ENFORCE_NOT_NULL(Out, errors::NotFound("Output Out should not be nullptr")); dev_ctx.template Alloc(Out); auto x = phi::EigenVector::Flatten( GET_DATA_SAFELY(&X, "Input", "X", "Activation")); auto out = phi::EigenVector::Flatten( GET_DATA_SAFELY(Out, "Output", "Out", "Activation")); auto* place = dev_ctx.eigen_device(); // use 32bit index to speed up computation bool use_32bit_index = out.size() < Eigen::NumTraits::highest(); bool is_gpu_place = paddle::platform::is_gpu_place(dev_ctx.GetPlace()); if (use_32bit_index && is_gpu_place) { functor(*place, To32BitIndex(x), To32BitIndex(out)); } else { functor(*place, x, out); } } template void PowKernel(const Context& dev_ctx, const DenseTensor& x, const Scalar& factor, DenseTensor* out) { PADDLE_ENFORCE_NOT_NULL(out, errors::NotFound("Output Out should not be nullptr")); dev_ctx.template Alloc(out); auto x_flatten = phi::EigenVector::Flatten( GET_DATA_SAFELY(&x, "Input", "X", "Activation")); auto out_flatten = phi::EigenVector::Flatten( GET_DATA_SAFELY(out, "Output", "Out", "Activation")); auto* place = dev_ctx.eigen_device(); phi::funcs::PowFunctor functor; auto attrs = functor.GetAttrs(); *(attrs[0].second) = factor.to(); functor(*place, x_flatten, out_flatten); } } // namespace phi