/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #pragma once #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/kernels/funcs/broadcast_function.h" #include "paddle/phi/kernels/funcs/common_shape.h" #include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/eigen/common.h" namespace phi { // FORWARD CODE // Add template struct SameDimsAddFunctor { void operator()(const DevCtx& dev_ctx, const DenseTensor& x, const DenseTensor& y, DenseTensor* z); }; template struct SameDimsAddFunctor< DevCtx, T, typename std::enable_if::value>::type> { void operator()(const DevCtx& dev_ctx, const DenseTensor& x, const DenseTensor& y, DenseTensor* z) { auto blas = phi::funcs::GetBlas(dev_ctx); blas.VADD( x.numel(), x.data(), y.data(), dev_ctx.template Alloc(z)); } }; template struct SameDimsAddFunctor< DevCtx, T, typename std::enable_if::value>::type> { void operator()(const DevCtx& dev_ctx, const DenseTensor& x, const DenseTensor& y, DenseTensor* z) { dev_ctx.template Alloc(z); auto eigen_x = phi::EigenVector::Flatten(x); auto eigen_y = phi::EigenVector::Flatten(y); auto eigen_z = phi::EigenVector::Flatten(*z); auto& place = *dev_ctx.eigen_device(); eigen_z.device(place) = eigen_x + eigen_y; } }; // Subtract template struct SameDimsSubtractFunctor { void operator()(const DevCtx& dev_ctx, const DenseTensor& x, const DenseTensor& y, DenseTensor* z); }; template struct SameDimsSubtractFunctor< DevCtx, T, typename std::enable_if::value>::type> { void operator()(const DevCtx& dev_ctx, const DenseTensor& x, const DenseTensor& y, DenseTensor* z) { auto blas = phi::funcs::GetBlas(dev_ctx); blas.VSUB( x.numel(), x.data(), y.data(), dev_ctx.template Alloc(z)); } }; template struct SameDimsSubtractFunctor< DevCtx, T, typename std::enable_if::value>::type> { void operator()(const DevCtx& dev_ctx, const DenseTensor& x, const DenseTensor& y, DenseTensor* z) { auto eigen_x = phi::EigenVector::Flatten(x); auto eigen_y = phi::EigenVector::Flatten(y); auto eigen_z = phi::EigenVector::Flatten(*z); auto& place = *dev_ctx.eigen_device(); eigen_z.device(place) = eigen_x - eigen_y; } }; // Divide template struct SameDimsDivideFunctor { void operator()(const DevCtx& dev_ctx, const DenseTensor& x, const DenseTensor& y, DenseTensor* z); }; template struct SameDimsDivideFunctor< DevCtx, T, typename std::enable_if::value>::type> { void operator()(const DevCtx& dev_ctx, const DenseTensor& x, const DenseTensor& y, DenseTensor* z) { phi::errors::InvalidArgument( "If use SameDimsDivideFunctor, template args(T) must be floating " "point. "); } }; template struct SameDimsDivideFunctor< DevCtx, T, typename std::enable_if::value>::type> { void operator()(const DevCtx& dev_ctx, const DenseTensor& x, const DenseTensor& y, DenseTensor* z) { auto blas = phi::funcs::GetBlas(dev_ctx); blas.VDIV( x.numel(), x.data(), y.data(), dev_ctx.template Alloc(z)); } }; // Multiply template struct SameDimsMultiplyFunctor { void operator()(const DevCtx& dev_ctx, const DenseTensor& x, const DenseTensor& y, DenseTensor* z); }; template struct SameDimsMultiplyFunctor< DevCtx, T, typename std::enable_if::value>::type> { void operator()(const DevCtx& dev_ctx, const DenseTensor& x, const DenseTensor& y, DenseTensor* z) { auto blas = phi::funcs::GetBlas(dev_ctx); blas.VMUL( x.numel(), x.data(), y.data(), dev_ctx.template Alloc(z)); } }; template struct SameDimsMultiplyFunctor< DevCtx, T, typename std::enable_if::value>::type> { void operator()(const DevCtx& dev_ctx, const DenseTensor& x, const DenseTensor& y, DenseTensor* z) { auto eigen_x = phi::EigenVector::Flatten(x); auto eigen_y = phi::EigenVector::Flatten(y); auto eigen_z = phi::EigenVector::Flatten(*z); auto& place = *dev_ctx.eigen_device(); eigen_z.device(place) = eigen_x * eigen_y; } }; template struct SameDimsElementwiseCompute { void operator()(const CPUContext& dev_ctx, const DenseTensor& x, const DenseTensor& y, DenseTensor* z) { Functor()(dev_ctx, x, y, z); } }; } // namespace phi