未验证 提交 b7841a2b 编写于 作者: W Wang Xin 提交者: GitHub

move "function_traits.h" from fluid to phi (#48065)

上级 ff44df18
...@@ -176,7 +176,7 @@ __device__ void VectorizedBroadcastKernelImpl( ...@@ -176,7 +176,7 @@ __device__ void VectorizedBroadcastKernelImpl(
#endif #endif
constexpr bool kCallElementwiseAny = constexpr bool kCallElementwiseAny =
paddle::platform::FunctionTraits<Functor>::has_pointer_args; phi::funcs::FunctionTraits<Functor>::has_pointer_args;
phi::funcs::ElementwisePrimitiveCaller<InT, phi::funcs::ElementwisePrimitiveCaller<InT,
ConditionalT<OutT, NumOuts>, ConditionalT<OutT, NumOuts>,
VecSize, VecSize,
...@@ -787,7 +787,7 @@ void BroadcastKernelForDifferentVecSize( ...@@ -787,7 +787,7 @@ void BroadcastKernelForDifferentVecSize(
std::vector<DenseTensor *> *outs, std::vector<DenseTensor *> *outs,
int axis, int axis,
Functor func) { Functor func) {
using Traits = paddle::platform::FunctionTraits<Functor>; using Traits = phi::funcs::FunctionTraits<Functor>;
const int kArity = const int kArity =
Traits::has_pointer_args ? static_cast<int>(ET) : Traits::arity; Traits::has_pointer_args ? static_cast<int>(ET) : Traits::arity;
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
......
...@@ -23,9 +23,9 @@ limitations under the License. */ ...@@ -23,9 +23,9 @@ limitations under the License. */
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
#if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__) #if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__)
#include "paddle/fluid/platform/function_traits.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h" #include "paddle/phi/kernels/funcs/aligned_vector.h"
#include "paddle/phi/kernels/funcs/function_traits.h"
#include "paddle/phi/kernels/primitive/kernel_primitives.h" #include "paddle/phi/kernels/primitive/kernel_primitives.h"
#define HOSTDEVICE __host__ __device__ #define HOSTDEVICE __host__ __device__
...@@ -563,7 +563,7 @@ int GetVectorizedSizeForTensors(const std::vector<const DenseTensor *> &ins, ...@@ -563,7 +563,7 @@ int GetVectorizedSizeForTensors(const std::vector<const DenseTensor *> &ins,
#ifdef PADDLE_WITH_XPU_KP #ifdef PADDLE_WITH_XPU_KP
int vec_size = 256; int vec_size = 256;
#else #else
using Traits = paddle::platform::FunctionTraits<Functor>; using Traits = phi::funcs::FunctionTraits<Functor>;
using ArgsT = typename Traits::ArgsTuple; using ArgsT = typename Traits::ArgsTuple;
const int Arity = Traits::arity; const int Arity = Traits::arity;
int vec_size = 4; int vec_size = 4;
...@@ -736,7 +736,7 @@ __device__ void VectorizedElementwiseKernelImpl( ...@@ -736,7 +736,7 @@ __device__ void VectorizedElementwiseKernelImpl(
int num, int num,
int read_lens, int read_lens,
Functor func) { Functor func) {
using Traits = paddle::platform::FunctionTraits<Functor>; using Traits = phi::funcs::FunctionTraits<Functor>;
using ArgsT = typename Traits::ArgsTuple; using ArgsT = typename Traits::ArgsTuple;
ArgsT args[VecSize]; ArgsT args[VecSize];
ConditionalT<OutT, NumOuts> result[VecSize]; ConditionalT<OutT, NumOuts> result[VecSize];
...@@ -831,7 +831,7 @@ void ElementwiseKernel(const KPDevice &ctx, ...@@ -831,7 +831,7 @@ void ElementwiseKernel(const KPDevice &ctx,
const std::vector<const DenseTensor *> &ins, const std::vector<const DenseTensor *> &ins,
std::vector<DenseTensor *> *outs, std::vector<DenseTensor *> *outs,
Functor func) { Functor func) {
using Traits = paddle::platform::FunctionTraits<Functor>; using Traits = phi::funcs::FunctionTraits<Functor>;
const int kArity = Traits::arity; const int kArity = Traits::arity;
PADDLE_ENFORCE_EQ(ins.size(), PADDLE_ENFORCE_EQ(ins.size(),
kArity, kArity,
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.1 (the "License"); Licensed under the Apache License, Version 2.1 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
...@@ -16,8 +16,8 @@ limitations under the License. */ ...@@ -16,8 +16,8 @@ limitations under the License. */
#include <tuple> #include <tuple>
namespace paddle { namespace phi {
namespace platform { namespace funcs {
template <int Arity, typename... Args> template <int Arity, typename... Args>
struct IsPointerArgs { struct IsPointerArgs {
static_assert(Arity == sizeof...(Args), "Arity and Args not match!"); static_assert(Arity == sizeof...(Args), "Arity and Args not match!");
...@@ -57,5 +57,5 @@ struct FunctionTraits<ReturnType(Args...)> { ...@@ -57,5 +57,5 @@ struct FunctionTraits<ReturnType(Args...)> {
using ArgsTuple = std::tuple<Args...>; using ArgsTuple = std::tuple<Args...>;
}; };
} // namespace platform } // namespace funcs
} // namespace paddle } // namespace phi
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册