Unverified commit 9adca1e7, authored by Wang Xin, committed by GitHub

move "gpu_primitives.h" to phi (#48015)

Parent: e4ebf383
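At each call site this commit swaps the fluid include and namespace for the phi ones; a minimal sketch of the typical change, based on the hunks below (out, i, and val are placeholders):

    // before: #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
    //         paddle::platform::CudaAtomicAdd(out + i, val);
    // after:  #include "paddle/phi/backends/gpu/gpu_primitives.h"
    //         phi::CudaAtomicAdd(out + i, val);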
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef PADDLE_WITH_CUDA
#include <cuda.h>
#endif
#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
#endif
#include <stdio.h>
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/common/float16.h"
template <typename T>
using complex = phi::dtype::complex<T>;
using float16 = phi::dtype::float16;
using bfloat16 = phi::dtype::bfloat16;
namespace phi {
#define CUDA_ATOMIC_WRAPPER(op, T) \
__device__ __forceinline__ T CudaAtomic##op(T *address, const T val)
#define USE_CUDA_ATOMIC(op, T) \
CUDA_ATOMIC_WRAPPER(op, T) { return atomic##op(address, val); }
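// For reference, an illustrative sketch of what these macros expand to:
// USE_CUDA_ATOMIC(Add, float) produces
//   __device__ __forceinline__ float CudaAtomicAdd(float *address,
//                                                  const float val) {
//     return atomicAdd(address, val);
//   }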
// Default thread count per block (or block size).
// TODO(typhoonzero): need to benchmark against setting this value
// to 1024.
constexpr int PADDLE_CUDA_NUM_THREADS = 512;
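// A typical launch derives its grid size from this constant; a minimal sketch
// (n, MyKernel, and stream are placeholders, not part of this header):
//   int blocks = (n + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS;
//   MyKernel<<<blocks, PADDLE_CUDA_NUM_THREADS, 0, stream>>>(...);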
// For atomicAdd.
USE_CUDA_ATOMIC(Add, float);
USE_CUDA_ATOMIC(Add, int);
USE_CUDA_ATOMIC(Add, unsigned int);
// The CUDA API uses unsigned long long int; we cannot use uint64_t here,
// because unsigned long long int is not necessarily the same type as uint64_t.
USE_CUDA_ATOMIC(Add, unsigned long long int); // NOLINT
CUDA_ATOMIC_WRAPPER(Add, int64_t) {
// Here, we check that long long int has the same size as int64_t.
static_assert(sizeof(int64_t) == sizeof(long long int), // NOLINT
"long long should be int64");
return CudaAtomicAdd(
reinterpret_cast<unsigned long long int *>(address), // NOLINT
static_cast<unsigned long long int>(val)); // NOLINT
}
#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600)
USE_CUDA_ATOMIC(Add, double);
#else
CUDA_ATOMIC_WRAPPER(Add, double) {
unsigned long long int *address_as_ull = // NOLINT
reinterpret_cast<unsigned long long int *>(address); // NOLINT
unsigned long long int old = *address_as_ull, assumed; // NOLINT
do {
assumed = old;
old = atomicCAS(address_as_ull,
assumed,
__double_as_longlong(val + __longlong_as_double(assumed)));
// Note: uses integer comparison to avoid hang in case of NaN
} while (assumed != old);
return __longlong_as_double(old);
}
#endif
#ifdef PADDLE_CUDA_FP16
// NOTE(dzhwinter): CUDA does not provide atomicCAS for half.
// We therefore treat the half's address as part of an aligned 32-bit word and
// run atomicCAS on that word; depending on whether the value sits in the high
// or the low 16 bits, a different sum and CAS are performed.
// Since most warp threads will fail the atomicCAS, this implementation should
// be avoided under high concurrency; it is slower than converting the value
// to 32 bits and doing a full-width atomic.
// The helpers below convert the value to float, do the add in float, and then
// pack the result back into the proper 16 bits of the uint32 word.
inline static __device__ uint32_t add_to_low_half(uint32_t val, float x) {
float16 low_half;
// The float16 stored in the lower 16 bits.
low_half.x = static_cast<uint16_t>(val & 0xFFFFu);
low_half = static_cast<float16>(static_cast<float>(low_half) + x);
return (val & 0xFFFF0000u) | low_half.x;
}
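// Worked example (illustrative, not part of the original comments):
// add_to_low_half(0xABCD3C00u, 1.0f) reads the low half 0x3C00 (float16 1.0),
// adds 1.0f to get 2.0f (float16 0x4000), and returns 0xABCD4000u, leaving the
// high half untouched.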
inline static __device__ uint32_t add_to_high_half(uint32_t val, float x) {
float16 high_half;
// The float16 stored in the higher 16 bits.
high_half.x = static_cast<uint16_t>(val >> 16);
high_half = static_cast<float16>(static_cast<float>(high_half) + x);
return (val & 0xFFFFu) | (static_cast<uint32_t>(high_half.x) << 16);
}
#if CUDA_VERSION >= 10000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
static __device__ __forceinline__ float16 CUDAFP16ToPDFP16(__half x) {
return *reinterpret_cast<float16 *>(&x);
}
static __device__ __forceinline__ __half PDFP16ToCUDAFP16(float16 x) {
return *reinterpret_cast<__half *>(&x);
}
CUDA_ATOMIC_WRAPPER(Add, float16) {
return CUDAFP16ToPDFP16(
atomicAdd(reinterpret_cast<__half *>(address), PDFP16ToCUDAFP16(val)));
}
#else
CUDA_ATOMIC_WRAPPER(Add, float16) {
// The packed float16 value may live in either the lower or the higher 16 bits
// of the aligned 32-bit word containing the address.
uint32_t *address_as_ui = reinterpret_cast<uint32_t *>(
reinterpret_cast<char *>(address) -
(reinterpret_cast<uintptr_t>(address) & 0x02));
float val_f = static_cast<float>(val);
uint32_t old = *address_as_ui;
uint32_t sum;
uint32_t newval;
uint32_t assumed;
if (((uintptr_t)address & 0x02) == 0) {
// The float16 value stays in the lower 16 bits of the address.
do {
assumed = old;
old = atomicCAS(address_as_ui, assumed, add_to_low_half(assumed, val_f));
} while (old != assumed);
float16 ret;
ret.x = old & 0xFFFFu;
return ret;
} else {
// The float16 value stays in the higher 16 bits of the address.
do {
assumed = old;
old = atomicCAS(address_as_ui, assumed, add_to_high_half(assumed, val_f));
} while (old != assumed);
float16 ret;
ret.x = old >> 16;
return ret;
}
}
#endif
// The performance of "atomicAdd(half* )" is bad, but for "atomicAdd(half2* )"
// is good. So for fp16 type, we can use "atomicAdd(half2* )" to speed up.
template <
typename T,
typename std::enable_if<std::is_same<float16, T>::value>::type * = nullptr>
__device__ __forceinline__ void fastAtomicAdd(T *tensor,
size_t index,
const size_t numel,
T value) {
#if ((CUDA_VERSION < 10000) || \
(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)))
CudaAtomicAdd(reinterpret_cast<float16 *>(tensor) + index,
static_cast<float16>(value));
#else
// Check whether the target address is aligned for __half2 access.
__half *target_addr = reinterpret_cast<__half *>(tensor + index);
bool aligned_half2 =
(reinterpret_cast<std::uintptr_t>(target_addr) % sizeof(__half2) == 0);
if (aligned_half2 && index < (numel - 1)) {
__half2 value2;
value2.x = *reinterpret_cast<__half *>(&value);
value2.y = __int2half_rz(0);
atomicAdd(reinterpret_cast<__half2 *>(target_addr), value2);
} else if (!aligned_half2 && index > 0) {
__half2 value2;
value2.x = __int2half_rz(0);
value2.y = *reinterpret_cast<__half *>(&value);
atomicAdd(reinterpret_cast<__half2 *>(target_addr - 1), value2);
} else {
atomicAdd(reinterpret_cast<__half *>(tensor) + index,
*reinterpret_cast<__half *>(&value));
}
#endif
}
template <
typename T,
typename std::enable_if<!std::is_same<float16, T>::value>::type * = nullptr>
__device__ __forceinline__ void fastAtomicAdd(T *arr,
size_t index,
const size_t numel,
T value) {
CudaAtomicAdd(arr + index, value);
}
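// Usage sketch (illustrative; ScatterAddExample, dst, idx, src, num, and numel
// are hypothetical names, not part of this header): scatter-add values with
// fastAtomicAdd.
//   template <typename T>
//   __global__ void ScatterAddExample(T *dst, const int64_t *idx, const T *src,
//                                     size_t num, size_t numel) {
//     for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
//          i += static_cast<size_t>(gridDim.x) * blockDim.x) {
//       phi::fastAtomicAdd(dst, static_cast<size_t>(idx[i]), numel, src[i]);
//     }
//   }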
#endif
// NOTE(zhangbo): CUDA does not provide atomicCAS for __nv_bfloat16.
inline static __device__ uint32_t bf16_add_to_low_half(uint32_t val, float x) {
bfloat16 low_half;
// The bfloat16 stored in the lower 16 bits.
low_half.x = static_cast<uint16_t>(val & 0xFFFFu);
low_half = static_cast<bfloat16>(static_cast<float>(low_half) + x);
return (val & 0xFFFF0000u) | low_half.x;
}
inline static __device__ uint32_t bf16_add_to_high_half(uint32_t val, float x) {
bfloat16 high_half;
// The bfloat16 stored in the higher 16 bits.
high_half.x = static_cast<uint16_t>(val >> 16);
high_half = static_cast<bfloat16>(static_cast<float>(high_half) + x);
return (val & 0xFFFFu) | (static_cast<uint32_t>(high_half.x) << 16);
}
#if CUDA_VERSION >= 11000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
static __device__ __forceinline__ bfloat16 CUDABF16ToPDBF16(__nv_bfloat16 x) {
return *reinterpret_cast<bfloat16 *>(&x);
}
static __device__ __forceinline__ __nv_bfloat16 PDBF16ToCUDABF16(bfloat16 x) {
return *reinterpret_cast<__nv_bfloat16 *>(&x);
}
CUDA_ATOMIC_WRAPPER(Add, bfloat16) {
return CUDABF16ToPDBF16(atomicAdd(reinterpret_cast<__nv_bfloat16 *>(address),
PDBF16ToCUDABF16(val)));
}
#else
CUDA_ATOMIC_WRAPPER(Add, bfloat16) {
// The packed bfloat16 value may live in either the lower or the higher 16 bits
// of the aligned 32-bit word containing the address.
uint32_t *address_as_ui = reinterpret_cast<uint32_t *>(
reinterpret_cast<char *>(address) -
(reinterpret_cast<uintptr_t>(address) & 0x02));
float val_f = static_cast<float>(val);
uint32_t old = *address_as_ui;
uint32_t sum;
uint32_t newval;
uint32_t assumed;
if (((uintptr_t)address & 0x02) == 0) {
// The bfloat16 value stays in the lower 16 bits of the address.
do {
assumed = old;
old = atomicCAS(
address_as_ui, assumed, bf16_add_to_low_half(assumed, val_f));
} while (old != assumed);
bfloat16 ret;
ret.x = old & 0xFFFFu;
return ret;
} else {
// The bfloat16 value stays in the higher 16 bits of the address.
do {
assumed = old;
old = atomicCAS(
address_as_ui, assumed, bf16_add_to_high_half(assumed, val_f));
} while (old != assumed);
bfloat16 ret;
ret.x = old >> 16;
return ret;
}
}
#endif
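// Note (not part of the original comments): the complex specializations below
// perform two independent scalar atomics on the real and imaginary parts, so
// each component is updated atomically on its own, but the complex value is
// not updated as a single atomic transaction.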
CUDA_ATOMIC_WRAPPER(Add, complex<float>) {
float *real = reinterpret_cast<float *>(address);
float *imag = real + 1;
return complex<float>(CudaAtomicAdd(real, val.real),
CudaAtomicAdd(imag, val.imag));
}
CUDA_ATOMIC_WRAPPER(Add, complex<double>) {
double *real = reinterpret_cast<double *>(address);
double *imag = real + 1;
return complex<double>(CudaAtomicAdd(real, val.real),
CudaAtomicAdd(imag, val.imag));
}
// For atomicMax
USE_CUDA_ATOMIC(Max, int);
USE_CUDA_ATOMIC(Max, unsigned int);
// The CUDA API uses unsigned long long int; we cannot use uint64_t here,
// because unsigned long long int is not necessarily the same type as uint64_t.
#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350)
USE_CUDA_ATOMIC(Max, unsigned long long int); // NOLINT
#else
CUDA_ATOMIC_WRAPPER(Max, unsigned long long int) { // NOLINT
if (*address >= val) {
return *address;
}
unsigned long long int old = *address, assumed; // NOLINT
do {
assumed = old;
if (assumed >= val) {
break;
}
old = atomicCAS(address, assumed, val);
} while (assumed != old);
return old;
}
#endif
CUDA_ATOMIC_WRAPPER(Max, int64_t) {
// Here, we check that long long int has the same size as int64_t.
static_assert(sizeof(int64_t) == sizeof(long long int), // NOLINT
"long long should be int64");
long long int res = *address; // NOLINT
while (val > res) {
long long int old = res; // NOLINT
res = (long long int)atomicCAS((unsigned long long int *)address, // NOLINT
(unsigned long long int)old, // NOLINT
(unsigned long long int)val); // NOLINT
if (res == old) {
break;
}
}
return res;
}
CUDA_ATOMIC_WRAPPER(Max, float) {
if (*address >= val) {
return *address;
}
int *const address_as_i = reinterpret_cast<int *>(address);
int old = *address_as_i, assumed;
do {
assumed = old;
if (__int_as_float(assumed) >= val) {
break;
}
old = atomicCAS(address_as_i, assumed, __float_as_int(val));
} while (assumed != old);
return __int_as_float(old);
}
CUDA_ATOMIC_WRAPPER(Max, double) {
if (*address >= val) {
return *address;
}
unsigned long long int *const address_as_ull = // NOLINT
reinterpret_cast<unsigned long long int *>(address); // NOLINT
unsigned long long int old = *address_as_ull, assumed; // NOLINT
do {
assumed = old;
if (__longlong_as_double(assumed) >= val) {
break;
}
old = atomicCAS(address_as_ull, assumed, __double_as_longlong(val));
} while (assumed != old);
return __longlong_as_double(old);
}
#ifdef PADDLE_CUDA_FP16
inline static __device__ uint32_t max_to_low_half(uint32_t val, float x) {
float16 low_half;
// The float16 stored in the lower 16 bits.
low_half.x = static_cast<uint16_t>(val & 0xFFFFu);
low_half = static_cast<float16>(max(static_cast<float>(low_half), x));
return (val & 0xFFFF0000u) | low_half.x;
}
inline static __device__ uint32_t max_to_high_half(uint32_t val, float x) {
float16 high_half;
// The float16 stored in the higher 16 bits.
high_half.x = static_cast<uint16_t>(val >> 16);
high_half = static_cast<float16>(max(static_cast<float>(high_half), x));
return (val & 0xFFFFu) | (static_cast<uint32_t>(high_half.x) << 16);
}
CUDA_ATOMIC_WRAPPER(Max, float16) {
if (*address >= val) {
return *address;
}
uint32_t *address_as_ui = reinterpret_cast<uint32_t *>(
reinterpret_cast<char *>(address) -
(reinterpret_cast<uintptr_t>(address) & 0x02));
float val_f = static_cast<float>(val);
uint32_t old = *address_as_ui;
uint32_t assumed;
if (((uintptr_t)address & 0x02) == 0) {
// The float16 value stays in the lower 16 bits of the address.
do {
assumed = old;
old = atomicCAS(address_as_ui, assumed, max_to_low_half(assumed, val_f));
} while (old != assumed);
float16 ret;
ret.x = old & 0xFFFFu;
return ret;
} else {
// The float16 value stays in the higher 16 bits of the address.
do {
assumed = old;
old = atomicCAS(address_as_ui, assumed, max_to_high_half(assumed, val_f));
} while (old != assumed);
float16 ret;
ret.x = old >> 16;
return ret;
}
}
#endif
// For atomicMin
USE_CUDA_ATOMIC(Min, int);
USE_CUDA_ATOMIC(Min, unsigned int);
// The CUDA API uses unsigned long long int; we cannot use uint64_t here,
// because unsigned long long int is not necessarily the same type as uint64_t.
#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350)
USE_CUDA_ATOMIC(Min, unsigned long long int); // NOLINT
#else
CUDA_ATOMIC_WRAPPER(Min, unsigned long long int) { // NOLINT
if (*address <= val) {
return *address;
}
unsigned long long int old = *address, assumed; // NOLINT
do {
assumed = old;
if (assumed <= val) {
break;
}
old = atomicCAS(address, assumed, val);
} while (assumed != old);
return old;
}
#endif
CUDA_ATOMIC_WRAPPER(Min, int64_t) {
// Here, we check that long long int has the same size as int64_t.
static_assert(sizeof(int64_t) == sizeof(long long int), // NOLINT
"long long should be int64");
long long int res = *address; // NOLINT
while (val < res) {
long long int old = res; // NOLINT
res = (long long int)atomicCAS((unsigned long long int *)address, // NOLINT
(unsigned long long int)old, // NOLINT
(unsigned long long int)val); // NOLINT
if (res == old) {
break;
}
}
return res;
}
CUDA_ATOMIC_WRAPPER(Min, float) {
if (*address <= val) {
return *address;
}
int *const address_as_i = reinterpret_cast<int *>(address);
int old = *address_as_i, assumed;
do {
assumed = old;
if (__int_as_float(assumed) <= val) {
break;
}
old = atomicCAS(address_as_i, assumed, __float_as_int(val));
} while (assumed != old);
return __int_as_float(old);
}
CUDA_ATOMIC_WRAPPER(Min, double) {
if (*address <= val) {
return *address;
}
unsigned long long int *const address_as_ull = // NOLINT
reinterpret_cast<unsigned long long int *>(address); // NOLINT
unsigned long long int old = *address_as_ull, assumed; // NOLINT
do {
assumed = old;
if (__longlong_as_double(assumed) <= val) {
break;
}
old = atomicCAS(address_as_ull, assumed, __double_as_longlong(val));
} while (assumed != old);
return __longlong_as_double(old);
}
#ifdef PADDLE_CUDA_FP16
inline static __device__ uint32_t min_to_low_half(uint32_t val, float x) {
float16 low_half;
// The float16 stored in the lower 16 bits.
low_half.x = static_cast<uint16_t>(val & 0xFFFFu);
low_half = static_cast<float16>(min(static_cast<float>(low_half), x));
return (val & 0xFFFF0000u) | low_half.x;
}
inline static __device__ uint32_t min_to_high_half(uint32_t val, float x) {
float16 high_half;
// The float16 stored in the higher 16 bits.
high_half.x = static_cast<uint16_t>(val >> 16);
high_half = static_cast<float16>(min(static_cast<float>(high_half), x));
return (val & 0xFFFFu) | (static_cast<uint32_t>(high_half.x) << 16);
}
CUDA_ATOMIC_WRAPPER(Min, float16) {
if (*address <= val) {
return *address;
}
uint32_t *address_as_ui = reinterpret_cast<uint32_t *>(
reinterpret_cast<char *>(address) -
(reinterpret_cast<uintptr_t>(address) & 0x02));
float val_f = static_cast<float>(val);
uint32_t old = *address_as_ui;
uint32_t assumed;
if (((uintptr_t)address & 0x02) == 0) {
// The float16 value stays in the lower 16 bits of the address.
do {
assumed = old;
old = atomicCAS(address_as_ui, assumed, min_to_low_half(assumed, val_f));
} while (old != assumed);
float16 ret;
ret.x = old & 0xFFFFu;
return ret;
} else {
// The float16 value stays in the higher 16 bits of the address.
do {
assumed = old;
old = atomicCAS(address_as_ui, assumed, min_to_high_half(assumed, val_f));
} while (old != assumed);
float16 ret;
ret.x = old >> 16;
return ret;
}
}
#endif
#ifdef PADDLE_CUDA_FP16
#ifdef PADDLE_WITH_CUDA
/*
 * One thread block performs an element-wise atomicAdd over a vector of length len.
 * @in:  [x1, x2, x3, ...]
 * @out: [y1+x1, y2+x2, y3+x3, ...]
 * */
template <
typename T,
typename std::enable_if<!std::is_same<float16, T>::value>::type * = nullptr>
__device__ __forceinline__ void VectorizedAtomicAddPerBlock(
const int64_t len, int tid, int threads_per_block, const T *in, T *out) {
for (int i = tid; i < len; i += threads_per_block) {
CudaAtomicAdd(&out[i], in[i]);
}
}
// Note: this assumes that len is even; if len is odd, call fastAtomicAdd directly.
template <
typename T,
typename std::enable_if<std::is_same<float16, T>::value>::type * = nullptr>
__device__ __forceinline__ void VectorizedAtomicAddPerBlock(
const int64_t len, int tid, int threads_per_block, const T *in, T *out) {
#if ((CUDA_VERSION < 10000) || \
(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)))
for (int i = tid; i < len; i += threads_per_block) {
CudaAtomicAdd(&out[i], in[i]);
}
#else
int i = 0;
int loops = len / 2 * 2;
bool aligned_half2 =
(reinterpret_cast<std::uintptr_t>(out) % sizeof(__half2) == 0);
if (aligned_half2) {
for (i = tid * 2; i < loops; i += threads_per_block * 2) {
__half2 value2;
T value_1 = in[i];
T value_2 = in[i + 1];
value2.x = *reinterpret_cast<__half *>(&value_1);
value2.y = *reinterpret_cast<__half *>(&value_2);
atomicAdd(reinterpret_cast<__half2 *>(&out[i]), value2);
}
for (; i < len; i += threads_per_block) {
fastAtomicAdd(out, i, len, in[i]);
}
} else {
for (int i = tid; i < len; i += threads_per_block) {
fastAtomicAdd(out, i, len, in[i]);
}
}
#endif
}
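// Usage sketch (illustrative, mirroring the EmbeddingGrad call site changed in
// this commit): one block accumulates a row of width D from src_row into
// dst_row; D, src_row, and dst_row are placeholders.
//   phi::VectorizedAtomicAddPerBlock(D, threadIdx.x, blockDim.x, src_row,
//                                    dst_row);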
#endif
#endif
} // namespace phi
......@@ -15,8 +15,8 @@ limitations under the License. */
#pragma once
#include <type_traits>
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/kernels/funcs/detail/activation_functions.h"
#include "paddle/phi/kernels/funcs/gru_compute.h"
......
......@@ -15,8 +15,8 @@ limitations under the License. */
#pragma once
#include <type_traits>
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/kernels/funcs/detail/activation_functions.h"
#include "paddle/phi/kernels/funcs/lstm_compute.h"
......@@ -202,15 +202,12 @@ __global__ void KeLstmBackward(Op op,
if (is_batch) {
if (value.prev_state_value) {
if (grad.check_ig_grad)
paddle::platform::CudaAtomicAdd(grad.check_ig_grad + frame_idx,
r_checkIGrad);
phi::CudaAtomicAdd(grad.check_ig_grad + frame_idx, r_checkIGrad);
if (grad.check_fg_grad)
paddle::platform::CudaAtomicAdd(grad.check_fg_grad + frame_idx,
r_checkFGrad);
phi::CudaAtomicAdd(grad.check_fg_grad + frame_idx, r_checkFGrad);
}
if (grad.check_og_grad)
paddle::platform::CudaAtomicAdd(grad.check_og_grad + frame_idx,
r_checkOGrad);
phi::CudaAtomicAdd(grad.check_og_grad + frame_idx, r_checkOGrad);
} else {
if (value.prev_state_value) {
if (grad.check_ig_grad) grad.check_ig_grad[frame_idx] += r_checkIGrad;
......
......@@ -18,8 +18,8 @@ limitations under the License. */
#include "paddle/fluid/memory/memcpy.h"
// TODO(paddle-dev): move gpu_primitives.h to phi
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/math_function.h"
......@@ -217,7 +217,7 @@ __global__ void GatherGradGPUKernel(const T* input,
int64_t out_index =
inner_dim_index * (outer_dim_size * out_index_dim_size) +
index[index_dim_index] * outer_dim_size + out_dim_index;
paddle::platform::CudaAtomicAdd(out + out_index, *(input + idx));
phi::CudaAtomicAdd(out + out_index, *(input + idx));
}
}
......
......@@ -15,8 +15,8 @@ limitations under the License. */
#include <algorithm>
#include <vector>
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/kernels/funcs/pooling.h"
#include "paddle/phi/kernels/funcs/reduce_function.h"
#include "paddle/phi/kernels/primitive/datamover_primitives.h"
......@@ -428,8 +428,7 @@ __global__ void KernelMaxPool2DGrad(const int nthreads,
if (maxIndex != -1) {
// atomic add
paddle::platform::CudaAtomicAdd(input_grad + maxIndex,
output_grad[index]);
phi::CudaAtomicAdd(input_grad + maxIndex, output_grad[index]);
}
}
}
......@@ -1330,7 +1329,7 @@ __global__ void KernelMaxPool3DGrad(const int nthreads,
}
if (maxIdx != -1) {
// atomic add
paddle::platform::CudaAtomicAdd(input_grad + maxIdx, output_grad[index]);
phi::CudaAtomicAdd(input_grad + maxIdx, output_grad[index]);
}
}
}
......@@ -2359,7 +2358,7 @@ __global__ void KernelMaxPool3DWithIdxGrad(
w_offset;
int max_index = mask[output_index];
if (max_index != -1) {
paddle::platform::CudaAtomicAdd(
phi::CudaAtomicAdd(
&input_grad[nc_offset * input_depth * input_height * input_width +
max_index],
output_grad[output_index]);
......
......@@ -16,8 +16,8 @@ limitations under the License. */
#include <unordered_set>
#include <vector>
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/math_function.h"
......@@ -70,7 +70,7 @@ __global__ void ScatterCUDAKernel(const T* params,
if (overwrite) {
*(output + out_i) = *(params + i);
} else {
paddle::platform::CudaAtomicAdd(output + out_i, *(params + i));
phi::CudaAtomicAdd(output + out_i, *(params + i));
}
}
}
......@@ -104,7 +104,7 @@ __global__ void ScatterNdCUDAKernel(const T* update,
temp *= output_dims[j];
}
int64_t output_i = gather_i + slice_i;
paddle::platform::CudaAtomicAdd(output + output_i, *(update + i));
phi::CudaAtomicAdd(output + output_i, *(update + i));
}
}
......
......@@ -14,9 +14,9 @@ limitations under the License. */
#include <algorithm>
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/kernels/funcs/gather.cu.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/segment_pooling.h"
......@@ -60,7 +60,7 @@ __global__ void SegmentSumIdsKernel(const Index* segment_ids,
}
if (j > 0) {
if (last_segment_id == first_segment_id) {
paddle::platform::CudaAtomicAdd(summed_ids + last_segment_id, sum);
phi::CudaAtomicAdd(summed_ids + last_segment_id, sum);
} else {
*(summed_ids + last_segment_id) = sum;
}
......@@ -70,7 +70,7 @@ __global__ void SegmentSumIdsKernel(const Index* segment_ids,
sum += T(1);
last_segment_id = current_segment_id;
}
paddle::platform::CudaAtomicAdd(summed_ids + last_segment_id, sum);
phi::CudaAtomicAdd(summed_ids + last_segment_id, sum);
}
}
......@@ -111,8 +111,8 @@ __global__ void SegmentMeanKernel(const Index* segment_ids,
last_segment_id * inner_dim_size + segment_offset;
if (last_segment_id == first_segment_id) {
paddle::platform::CudaAtomicAdd(
output + output_index, sum / *(summed_ids + last_segment_id));
phi::CudaAtomicAdd(output + output_index,
sum / *(summed_ids + last_segment_id));
} else {
*(output + output_index) = sum / *(summed_ids + last_segment_id);
}
......@@ -123,8 +123,8 @@ __global__ void SegmentMeanKernel(const Index* segment_ids,
last_segment_id = current_segment_id;
}
Index output_index = last_segment_id * inner_dim_size + segment_offset;
paddle::platform::CudaAtomicAdd(output + output_index,
sum / *(summed_ids + last_segment_id));
phi::CudaAtomicAdd(output + output_index,
sum / *(summed_ids + last_segment_id));
}
}
......@@ -215,7 +215,7 @@ class MaxPool {
DEVICE inline T initial() { return static_cast<T>(-FLT_MAX); }
DEVICE inline void compute(const T& x, T* y) { *y = *y > x ? *y : x; }
DEVICE inline T atomic(T* address, const T val) {
return paddle::platform::CudaAtomicMax(address, val);
return phi::CudaAtomicMax(address, val);
}
};
......@@ -225,7 +225,7 @@ class MinPool {
DEVICE inline T initial() { return static_cast<T>(FLT_MAX); }
DEVICE inline void compute(const T& x, T* y) { *y = *y < x ? *y : x; }
DEVICE inline T atomic(T* address, const T val) {
return paddle::platform::CudaAtomicMin(address, val);
return phi::CudaAtomicMin(address, val);
}
};
......@@ -235,7 +235,7 @@ class SumPool {
DEVICE inline T initial() { return static_cast<T>(0); }
DEVICE inline void compute(const T& x, T* y) { *y = *y + x; }
DEVICE inline T atomic(T* address, const T val) {
return paddle::platform::CudaAtomicAdd(address, val);
return phi::CudaAtomicAdd(address, val);
}
};
......
......@@ -15,7 +15,7 @@ limitations under the License. */
#include <set>
#include <vector>
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/kernels/funcs/math_function.h"
......@@ -127,7 +127,7 @@ __global__ void SelectedRowsAddTensorKernel(const T* selected_rows,
// Since index in rows of SelectedRows can be duplicate, we can not use
// tensor_out[index] += selected_rows[index]; Instead, we have to use
// AtomicAdd to avoid concurrent write error.
paddle::platform::CudaAtomicAdd(tensor_out + index, selected_rows[index]);
phi::CudaAtomicAdd(tensor_out + index, selected_rows[index]);
}
}
} // namespace
......@@ -279,7 +279,7 @@ __global__ void SelectedRowsAddToTensorKernel(const T* selected_rows,
for (int index = tid; index < row_numel; index += block_size) {
// Since index in rows of SelectedRows can be duplicate, we have to use
// Atomic Operation to avoid concurrent write error.
paddle::platform::CudaAtomicAdd(tensor_out + index, selected_rows[index]);
phi::CudaAtomicAdd(tensor_out + index, selected_rows[index]);
}
}
} // namespace
......@@ -360,7 +360,7 @@ __global__ void MergeAddKernel(const T* input,
input += ty * row_numel;
out += out_idx * row_numel;
for (int index = tid; index < row_numel; index += block_size) {
paddle::platform::CudaAtomicAdd(out + index, input[index]);
phi::CudaAtomicAdd(out + index, input[index]);
}
}
......@@ -623,9 +623,9 @@ struct UpdateToTensor<phi::GPUContext, T> {
auto* in1_data = in1_value.template data<T>();
auto* in2_data = input2->data<T>();
dim3 threads(paddle::platform::PADDLE_CUDA_NUM_THREADS, 1);
dim3 threads(phi::PADDLE_CUDA_NUM_THREADS, 1);
dim3 grid(in1_rows.size(), 1);
UpdateToTensorKernel<T, paddle::platform::PADDLE_CUDA_NUM_THREADS>
UpdateToTensorKernel<T, phi::PADDLE_CUDA_NUM_THREADS>
<<<grid, threads, 0, context.stream()>>>(
in1_data, in1_rows.cuda_data(), op, in2_data, in1_row_numel);
}
......
......@@ -17,14 +17,14 @@
#include <thrust/execution_policy.h>
#include <thrust/reduce.h>
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
using paddle::platform::PADDLE_CUDA_NUM_THREADS;
using phi::PADDLE_CUDA_NUM_THREADS;
template <int BlockSize>
__global__ void AccuracyCudaKernel(const int N,
......
......@@ -14,8 +14,8 @@
#include "paddle/phi/kernels/adagrad_kernel.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/selected_rows_functor.h"
......@@ -47,7 +47,7 @@ __global__ void MergeGradKernel(const T* grad,
grad += ty * row_numel;
grad_merge += grad_merge_idx * row_numel;
for (int index = tid; index < row_numel; index += block_size) {
paddle::platform::CudaAtomicAdd(grad_merge + index, grad[index]);
phi::CudaAtomicAdd(grad_merge + index, grad[index]);
}
}
......@@ -69,9 +69,9 @@ __global__ void SparseAdagradFunctorKernel(const T* grad,
for (int index = tid; index < row_numel; index += block_size) {
// Since index in rows of SelectedRows can be duplicate, we have to use
// Atomic Operation to avoid concurrent write error.
paddle::platform::CudaAtomicAdd(param + index,
-1.0 * learning_rate[0] * grad[index] /
(sqrt(moment[index]) + epsilon));
phi::CudaAtomicAdd(param + index,
-1.0 * learning_rate[0] * grad[index] /
(sqrt(moment[index]) + epsilon));
}
}
......
......@@ -18,9 +18,9 @@
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/affine_grid_utils.h"
......@@ -75,18 +75,14 @@ __global__ void affine_grid_grad_kernel_4d(const int count,
int theta_offset = n * 6; // 2 * 3;
T out_grad_x = out_grad[index * 2];
paddle::platform::CudaAtomicAdd(theta_grad + theta_offset,
out_grad_x * w_coor);
paddle::platform::CudaAtomicAdd(theta_grad + theta_offset + 1,
out_grad_x * h_coor);
paddle::platform::CudaAtomicAdd(theta_grad + theta_offset + 2, out_grad_x);
phi::CudaAtomicAdd(theta_grad + theta_offset, out_grad_x * w_coor);
phi::CudaAtomicAdd(theta_grad + theta_offset + 1, out_grad_x * h_coor);
phi::CudaAtomicAdd(theta_grad + theta_offset + 2, out_grad_x);
T out_grad_y = out_grad[index * 2 + 1];
paddle::platform::CudaAtomicAdd(theta_grad + theta_offset + 3,
out_grad_y * w_coor);
paddle::platform::CudaAtomicAdd(theta_grad + theta_offset + 4,
out_grad_y * h_coor);
paddle::platform::CudaAtomicAdd(theta_grad + theta_offset + 5, out_grad_y);
phi::CudaAtomicAdd(theta_grad + theta_offset + 3, out_grad_y * w_coor);
phi::CudaAtomicAdd(theta_grad + theta_offset + 4, out_grad_y * h_coor);
phi::CudaAtomicAdd(theta_grad + theta_offset + 5, out_grad_y);
}
}
......@@ -116,31 +112,22 @@ __global__ void affine_grid_grad_kernel_5d(const int count,
int theta_offset = n * 12; // 3 * 4;
T out_grad_x = out_grad[index * 3];
paddle::platform::CudaAtomicAdd(theta_grad + theta_offset,
out_grad_x * w_coor);
paddle::platform::CudaAtomicAdd(theta_grad + theta_offset + 1,
out_grad_x * h_coor);
paddle::platform::CudaAtomicAdd(theta_grad + theta_offset + 2,
out_grad_x * d_coor);
paddle::platform::CudaAtomicAdd(theta_grad + theta_offset + 3, out_grad_x);
phi::CudaAtomicAdd(theta_grad + theta_offset, out_grad_x * w_coor);
phi::CudaAtomicAdd(theta_grad + theta_offset + 1, out_grad_x * h_coor);
phi::CudaAtomicAdd(theta_grad + theta_offset + 2, out_grad_x * d_coor);
phi::CudaAtomicAdd(theta_grad + theta_offset + 3, out_grad_x);
T out_grad_y = out_grad[index * 3 + 1];
paddle::platform::CudaAtomicAdd(theta_grad + theta_offset + 4,
out_grad_y * w_coor);
paddle::platform::CudaAtomicAdd(theta_grad + theta_offset + 5,
out_grad_y * h_coor);
paddle::platform::CudaAtomicAdd(theta_grad + theta_offset + 6,
out_grad_y * d_coor);
paddle::platform::CudaAtomicAdd(theta_grad + theta_offset + 7, out_grad_y);
phi::CudaAtomicAdd(theta_grad + theta_offset + 4, out_grad_y * w_coor);
phi::CudaAtomicAdd(theta_grad + theta_offset + 5, out_grad_y * h_coor);
phi::CudaAtomicAdd(theta_grad + theta_offset + 6, out_grad_y * d_coor);
phi::CudaAtomicAdd(theta_grad + theta_offset + 7, out_grad_y);
T out_grad_z = out_grad[index * 3 + 2];
paddle::platform::CudaAtomicAdd(theta_grad + theta_offset + 8,
out_grad_z * w_coor);
paddle::platform::CudaAtomicAdd(theta_grad + theta_offset + 9,
out_grad_z * h_coor);
paddle::platform::CudaAtomicAdd(theta_grad + theta_offset + 10,
out_grad_z * d_coor);
paddle::platform::CudaAtomicAdd(theta_grad + theta_offset + 11, out_grad_z);
phi::CudaAtomicAdd(theta_grad + theta_offset + 8, out_grad_z * w_coor);
phi::CudaAtomicAdd(theta_grad + theta_offset + 9, out_grad_z * h_coor);
phi::CudaAtomicAdd(theta_grad + theta_offset + 10, out_grad_z * d_coor);
phi::CudaAtomicAdd(theta_grad + theta_offset + 11, out_grad_z);
}
}
......
......@@ -18,9 +18,9 @@
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/affine_grid_utils.h"
......
......@@ -14,13 +14,13 @@
#include "paddle/phi/kernels/auc_kernel.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
using paddle::platform::PADDLE_CUDA_NUM_THREADS;
using phi::PADDLE_CUDA_NUM_THREADS;
__global__ void ClearObsoleteDataKernel(int64_t *pos,
int64_t *neg,
......@@ -74,9 +74,9 @@ __global__ void AddDataKernel(const int64_t *label_data,
"The predict data must gather or equal 0.");
uint32_t binIdx = static_cast<uint32_t>(predict_data * num_thresholds);
if (label_data[i]) {
paddle::platform::CudaAtomicAdd(pos + cur_step_begin + binIdx, 1);
phi::CudaAtomicAdd(pos + cur_step_begin + binIdx, 1);
} else {
paddle::platform::CudaAtomicAdd(neg + cur_step_begin + binIdx, 1);
phi::CudaAtomicAdd(neg + cur_step_begin + binIdx, 1);
}
}
}
......
......@@ -14,15 +14,15 @@
#include "paddle/phi/kernels/bincount_kernel.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
using paddle::platform::PADDLE_CUDA_NUM_THREADS;
using phi::PADDLE_CUDA_NUM_THREADS;
inline int GET_BLOCKS(const int N) {
return (N + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS;
......@@ -36,12 +36,11 @@ __global__ void KernelBincount(const InputT* input,
OutT* output) {
if (!has_weights) {
for (int i = threadIdx.x; i < total_elements; i += blockDim.x) {
paddle::platform::CudaAtomicAdd(&output[input[i]], 1L);
phi::CudaAtomicAdd(&output[input[i]], 1L);
}
} else {
for (int i = threadIdx.x; i < total_elements; i += blockDim.x) {
paddle::platform::CudaAtomicAdd(&output[input[i]],
static_cast<OutT>(weights[i]));
phi::CudaAtomicAdd(&output[input[i]], static_cast<OutT>(weights[i]));
}
}
}
......
......@@ -18,8 +18,8 @@
#include <thrust/host_vector.h>
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/box_coder.h"
......
......@@ -14,8 +14,8 @@
#include "paddle/phi/kernels/deformable_conv_grad_kernel.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/deformable_conv_grad_kernel_impl.h"
......@@ -107,8 +107,8 @@ __global__ void ModulatedDeformableCol2imGpuKernel(
height,
width);
paddle::platform::CudaAtomicAdd(grad_im + cur_bottom_grad_pos,
weight * cur_top_grad);
phi::CudaAtomicAdd(grad_im + cur_bottom_grad_pos,
weight * cur_top_grad);
}
}
}
......
......@@ -28,7 +28,7 @@ namespace cub = hipcub;
#endif
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
......@@ -981,7 +981,7 @@ __device__ __inline__ void KernelDepthwiseConvFilterGradNHWC(
}
#undef gaid
}
platform::CudaAtomicAdd(&filter_grad_data[gbid], s);
phi::CudaAtomicAdd(&filter_grad_data[gbid], s);
}
}
......@@ -1057,7 +1057,7 @@ __device__ __inline__ void KernelDepthwiseConvFilterGradCFilterNHWC(
}
for (int i = 0; i < c_filter * c_filter; ++i) {
T* weight = filter_grad_data + i * output_channels + kernel_id;
platform::CudaAtomicAdd(&weight[0], r_weight[i]);
phi::CudaAtomicAdd(&weight[0], r_weight[i]);
}
}
}
......
......@@ -15,13 +15,13 @@
#include "paddle/phi/kernels/diagonal_grad_kernel.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/diagonal.h"
namespace phi {
using paddle::platform::PADDLE_CUDA_NUM_THREADS;
using phi::PADDLE_CUDA_NUM_THREADS;
template <typename T, typename Context>
void DiagonalGradKernel(const Context& dev_ctx,
......
......@@ -15,12 +15,12 @@
#include "paddle/phi/kernels/diagonal_kernel.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/diagonal.h"
namespace phi {
using paddle::platform::PADDLE_CUDA_NUM_THREADS;
using phi::PADDLE_CUDA_NUM_THREADS;
template <typename T, typename Context>
void DiagonalKernel(const Context& dev_ctx,
const DenseTensor& x,
......
......@@ -32,7 +32,7 @@ namespace cub = hipcub;
#include "paddle/fluid/memory/allocation/allocator.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/detection/bbox_util.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
namespace phi {
......@@ -69,7 +69,7 @@ __global__ void GPUDistFpnProposalsHelper(const int nthreads,
tgt_lvl = min(max_level, max(tgt_lvl, min_level));
target_lvls[i] = tgt_lvl;
// compute number of rois in the same batch and same target level
paddle::platform::CudaAtomicAdd(
phi::CudaAtomicAdd(
sub_lod_list + (tgt_lvl - min_level) * lod_size + roi_batch_ind, 1);
}
}
......
......@@ -18,14 +18,14 @@
#include <vector>
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
using paddle::platform::PADDLE_CUDA_NUM_THREADS;
using phi::PADDLE_CUDA_NUM_THREADS;
template <typename T>
__global__ void FillFirstRow(T* dist, const int N) {
......
......@@ -16,8 +16,8 @@
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
......@@ -51,10 +51,10 @@ __global__ void EmbeddingGrad(T* table,
const T* out = output + idy * D;
T* tab = table + id * D;
#ifdef PADDLE_WITH_CUDA
paddle::platform::VectorizedAtomicAddPerBlock(D, idx, blockDim.x, out, tab);
phi::VectorizedAtomicAddPerBlock(D, idx, blockDim.x, out, tab);
#else
for (int i = idx; i < D; i += blockDim.x) {
paddle::platform::CudaAtomicAdd(&tab[i], out[i]);
phi::CudaAtomicAdd(&tab[i], out[i]);
}
#endif
idy += blockDim.y * gridDim.x;
......
......@@ -14,8 +14,8 @@
#pragma once
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/kernels/graph_reindex_kernel.h"
......
......@@ -19,8 +19,8 @@
#include <algorithm>
#include <vector>
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/kernels/send_u_recv_kernel.h"
......@@ -32,7 +32,7 @@ struct GraphSendRecvSumCUDAFunctor {
T* output,
const IndexT& in_i,
const IndexT& out_i) {
paddle::platform::CudaAtomicAdd(output + out_i, *(params + in_i));
phi::CudaAtomicAdd(output + out_i, *(params + in_i));
}
};
......@@ -42,7 +42,7 @@ struct GraphSendRecvMaxCUDAFunctor {
T* output,
const IndexT& in_i,
const IndexT& out_i) {
paddle::platform::CudaAtomicMax(output + out_i, *(params + in_i));
phi::CudaAtomicMax(output + out_i, *(params + in_i));
}
};
......@@ -52,7 +52,7 @@ struct GraphSendRecvMinCUDAFunctor {
T* output,
const IndexT& in_i,
const IndexT& out_i) {
paddle::platform::CudaAtomicMin(output + out_i, *(params + in_i));
phi::CudaAtomicMin(output + out_i, *(params + in_i));
}
};
......@@ -106,7 +106,7 @@ __global__ void ComputeCountCUDAKernel(int32_t* count,
size_t index_size) {
CUDA_KERNEL_LOOP_TYPE(i, index_size, int64_t) {
IndexT dst_i = dst_indices[i];
paddle::platform::CudaAtomicAdd(count + dst_i, 1);
phi::CudaAtomicAdd(count + dst_i, 1);
}
}
......@@ -140,8 +140,8 @@ __global__ void ManipulateMeanGradCUDAKernel(const T* params,
IndexT dst_i = dst_indices[indices_i];
int64_t in_i = src_i * slice_size + slice_i;
int64_t out_i = dst_i * slice_size + slice_i;
paddle::platform::CudaAtomicAdd(
output + out_i, *(params + in_i) / static_cast<T>(dst_count[src_i]));
phi::CudaAtomicAdd(output + out_i,
*(params + in_i) / static_cast<T>(dst_count[src_i]));
}
}
......@@ -162,10 +162,9 @@ __global__ void ManipulateMinMaxGradCUDAKernel(const T* params,
IndexT dst_i = dst_indices[indices_i];
int64_t in_i = src_i * slice_size + slice_i;
int64_t out_i = dst_i * slice_size + slice_i;
paddle::platform::CudaAtomicAdd(
output + out_i,
*(params + in_i) *
static_cast<T>(*(ptr_input + out_i) == *(ptr_output + in_i)));
phi::CudaAtomicAdd(output + out_i,
*(params + in_i) * static_cast<T>(*(ptr_input + out_i) ==
*(ptr_output + in_i)));
}
}
......
......@@ -17,8 +17,8 @@
#include <thrust/device_vector.h>
#include <thrust/fill.h>
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/kernels/impl/graph_message_passing_impl.h"
......@@ -29,25 +29,25 @@ namespace phi {
#define CUDA_MAX_NUM_BLOCKS_Z 0xFFFF
inline void CopyBCastOff(const BroadCastInfo& bcast_info,
thrust::device_vector<int64_t>& l_bcastoff,
thrust::device_vector<int64_t>& r_bcastoff) {
l_bcastoff.resize(bcast_info.out_len);
r_bcastoff.resize(bcast_info.out_len);
thrust::device_vector<int64_t>* l_bcastoff,
thrust::device_vector<int64_t>* r_bcastoff) {
l_bcastoff->resize(bcast_info.out_len);
r_bcastoff->resize(bcast_info.out_len);
#ifdef PADDLE_WITH_HIP
hipMemcpy(thrust::raw_pointer_cast(l_bcastoff.data()),
hipMemcpy(thrust::raw_pointer_cast(l_bcastoff->data()),
bcast_info.l_offset.data(),
sizeof(int64_t) * bcast_info.out_len,
hipMemcpyHostToDevice);
hipMemcpy(thrust::raw_pointer_cast(r_bcastoff.data()),
hipMemcpy(thrust::raw_pointer_cast(r_bcastoff->data()),
bcast_info.r_offset.data(),
sizeof(int64_t) * bcast_info.out_len,
hipMemcpyHostToDevice);
#else
cudaMemcpy(thrust::raw_pointer_cast(l_bcastoff.data()),
cudaMemcpy(thrust::raw_pointer_cast(l_bcastoff->data()),
bcast_info.l_offset.data(),
sizeof(int64_t) * bcast_info.out_len,
cudaMemcpyHostToDevice);
cudaMemcpy(thrust::raw_pointer_cast(r_bcastoff.data()),
cudaMemcpy(thrust::raw_pointer_cast(r_bcastoff->data()),
bcast_info.r_offset.data(),
sizeof(int64_t) * bcast_info.out_len,
cudaMemcpyHostToDevice);
......@@ -102,21 +102,21 @@ inline int FindNumBlocks(char axis, int nblocks, int max_num_blocks = -1) {
template <typename T>
struct GraphSendUERecvSumCUDAFunctor {
DEVICE inline void operator()(T* output, T val) {
paddle::platform::CudaAtomicAdd(output, val);
phi::CudaAtomicAdd(output, val);
}
};
template <typename T>
struct GraphSendUERecvMaxCUDAFunctor {
DEVICE inline void operator()(T* output, T val) {
paddle::platform::CudaAtomicMax(output, val);
phi::CudaAtomicMax(output, val);
}
};
template <typename T>
struct GraphSendUERecvMinCUDAFunctor {
DEVICE inline void operator()(T* output, T val) {
paddle::platform::CudaAtomicMin(output, val);
phi::CudaAtomicMin(output, val);
}
};
......@@ -192,8 +192,7 @@ __global__ void ManipulateMeanGradCUDAKernelForMulX(const T* out_grad_data,
int64_t o_add = use_bcast ? l_bcastoff[tx] : tx;
int64_t e_add = use_bcast ? r_bcastoff[tx] : tx;
T val = out_grad_off[o_add] * e_off[e_add];
paddle::platform::CudaAtomicAdd(x_grad_off + tx,
val / static_cast<T>(dst_count[src]));
phi::CudaAtomicAdd(x_grad_off + tx, val / static_cast<T>(dst_count[src]));
tx += stride_x;
}
ty += stride_y;
......@@ -222,7 +221,7 @@ __global__ void ManipulateSumGradCUDAKernelForAddE(const T* out_grad_data,
const T* out_grad_off = out_grad_data + dst * out_len;
while (tx < out_len) {
int64_t e_add = use_bcast ? r_bcastoff[tx] : tx;
paddle::platform::CudaAtomicAdd(e_grad_off + e_add, out_grad_off[tx]);
phi::CudaAtomicAdd(e_grad_off + e_add, out_grad_off[tx]);
tx += stride_x;
}
ty += stride_y;
......@@ -258,8 +257,7 @@ __global__ void ManipulateSumGradCUDAKernelForMulE(const T* x_data,
while (tx < out_len) {
int64_t x_add = use_bcast ? l_bcastoff[tx] : tx;
int64_t e_add = use_bcast ? r_bcastoff[tx] : tx;
paddle::platform::CudaAtomicAdd(e_grad_off + e_add,
out_grad_off[tx] * x_off[x_add]);
phi::CudaAtomicAdd(e_grad_off + e_add, out_grad_off[tx] * x_off[x_add]);
tx += stride_x;
}
ty += stride_y;
......@@ -289,9 +287,8 @@ __global__ void ManipulateMeanGradCUDAKernelForAddE(const T* out_grad_data,
const T* out_grad_off = out_grad_data + dst * out_len;
while (tx < out_len) {
int64_t e_add = use_bcast ? r_bcastoff[tx] : tx;
paddle::platform::CudaAtomicAdd(
e_grad_off + e_add,
out_grad_off[tx] / static_cast<T>(dst_count[dst]));
phi::CudaAtomicAdd(e_grad_off + e_add,
out_grad_off[tx] / static_cast<T>(dst_count[dst]));
tx += stride_x;
}
ty += stride_y;
......@@ -328,7 +325,7 @@ __global__ void ManipulateMeanGradCUDAKernelForMulE(const T* x_data,
while (tx < out_len) {
int64_t x_add = use_bcast ? l_bcastoff[tx] : tx;
int64_t e_add = use_bcast ? r_bcastoff[tx] : tx;
paddle::platform::CudaAtomicAdd(
phi::CudaAtomicAdd(
e_grad_off + e_add,
out_grad_off[tx] * x_off[x_add] / static_cast<T>(dst_count[dst]));
tx += stride_x;
......@@ -373,12 +370,10 @@ __global__ void ManipulateMinMaxGradCUDAKernelForAdd(const T* x_data,
int64_t x_add = use_bcast ? xbcast_off[tx] : tx;
int64_t e_add = use_bcast ? ebcast_off[tx] : tx;
T val = x_off[x_add] + e_off[e_add];
paddle::platform::CudaAtomicAdd(
x_grad_off + x_add,
out_grad_off[tx] * static_cast<T>(val == out_off[tx]));
paddle::platform::CudaAtomicAdd(
e_grad_off + e_add,
out_grad_off[tx] * static_cast<T>(val == out_off[tx]));
phi::CudaAtomicAdd(x_grad_off + x_add,
out_grad_off[tx] * static_cast<T>(val == out_off[tx]));
phi::CudaAtomicAdd(e_grad_off + e_add,
out_grad_off[tx] * static_cast<T>(val == out_off[tx]));
tx += stride_x;
}
ty += stride_y;
......@@ -421,10 +416,10 @@ __global__ void ManipulateMinMaxGradCUDAKernelForMul(const T* x_data,
int64_t x_add = use_bcast ? xbcast_off[tx] : tx;
int64_t e_add = use_bcast ? ebcast_off[tx] : tx;
T val = x_off[x_add] * e_off[e_add];
paddle::platform::CudaAtomicAdd(
phi::CudaAtomicAdd(
x_grad_off + x_add,
out_grad_off[tx] * static_cast<T>(val == out_off[tx]) * e_off[e_add]);
paddle::platform::CudaAtomicAdd(
phi::CudaAtomicAdd(
e_grad_off + e_add,
out_grad_off[tx] * static_cast<T>(val == out_off[tx]) * x_off[x_add]);
tx += stride_x;
......
......@@ -15,9 +15,9 @@
#include "paddle/phi/kernels/grid_sample_grad_kernel.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/gpu/grid_sample_utils.h"
......@@ -28,7 +28,7 @@ template <typename T>
static __forceinline__ __device__ void AtomicAdd(
T* data, int h, int w, int sH, int sW, int H, int W, T delta) {
if (InBounds(h, w, H, W)) {
paddle::platform::CudaAtomicAdd(data + h * sH + w * sW, delta);
phi::CudaAtomicAdd(data + h * sH + w * sW, delta);
}
}
......@@ -45,7 +45,7 @@ static __forceinline__ __device__ void AtomicAdd3D(T* data,
int W,
T delta) {
if (InBounds3D(d, h, w, D, H, W)) {
paddle::platform::CudaAtomicAdd(data + d * sD + h * sH + w * sW, delta);
phi::CudaAtomicAdd(data + d * sD + h * sH + w * sW, delta);
}
}
......
......@@ -71,14 +71,14 @@ __global__ void GroupNormBackwardGetMeanAndVar(const T* x,
if (flags & kHasScale) {
#if CUDA_VERSION >= 11070
paddle::platform::CudaAtomicAdd(&(d_scale[ccid]), d_scale_data);
phi::CudaAtomicAdd(&(d_scale[ccid]), d_scale_data);
#else
CudaAtomicAddWithWarp(&(d_scale[ccid]), d_scale_data);
#endif
}
if (flags & kHasBias) {
#if CUDA_VERSION >= 11070
paddle::platform::CudaAtomicAdd(&(d_bias[ccid]), d_bias_data);
phi::CudaAtomicAdd(&(d_bias[ccid]), d_bias_data);
#else
CudaAtomicAddWithWarp(&(d_bias[ccid]), d_bias_data);
#endif
......
......@@ -23,7 +23,7 @@ namespace cub = hipcub;
#endif
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/kernels/primitive/kernel_primitives.h"
namespace phi {
......@@ -51,7 +51,7 @@ __device__ __inline__ void CudaAtomicAddWithWarp(T* sum, T value) {
typedef cub::WarpReduce<T> WarpReduce;
typename WarpReduce::TempStorage temp_storage;
value = WarpReduce(temp_storage).Sum(value);
if (cub::LaneId() == 0) paddle::platform::CudaAtomicAdd(sum, value);
if (cub::LaneId() == 0) phi::CudaAtomicAdd(sum, value);
}
template <typename T, typename AccT, int VecSize, int Num>
......
......@@ -14,9 +14,9 @@
#include "paddle/phi/kernels/histogram_kernel.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
......@@ -25,7 +25,7 @@
namespace phi {
using IndexType = int64_t;
using paddle::platform::PADDLE_CUDA_NUM_THREADS;
using phi::PADDLE_CUDA_NUM_THREADS;
inline int GET_BLOCKS(const int N) {
return (N + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS;
......@@ -61,13 +61,13 @@ __global__ void KernelHistogram(const T* input,
if (input_value >= min_value && input_value <= max_value) {
const IndexType output_index =
GetBin<T, IndexType>(input_value, min_value, max_value, nbins);
paddle::platform::CudaAtomicAdd(&buf_hist[output_index], 1);
phi::CudaAtomicAdd(&buf_hist[output_index], 1);
}
}
__syncthreads();
for (int i = threadIdx.x; i < nbins; i += blockDim.x) {
paddle::platform::CudaAtomicAdd(&output[i], buf_hist[i]);
phi::CudaAtomicAdd(&output[i], buf_hist[i]);
}
}
......
......@@ -14,9 +14,9 @@
#include "paddle/phi/kernels/index_add_grad_kernel.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/funcs/math_function.h"
......@@ -24,7 +24,7 @@
namespace phi {
using paddle::platform::PADDLE_CUDA_NUM_THREADS;
using phi::PADDLE_CUDA_NUM_THREADS;
template <typename T, typename Context>
void IndexAddGradKernel(const Context& ctx,
......
......@@ -14,9 +14,9 @@
#include "paddle/phi/kernels/index_add_kernel.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
......@@ -24,7 +24,7 @@ DECLARE_bool(cudnn_deterministic);
namespace phi {
using paddle::platform::PADDLE_CUDA_NUM_THREADS;
using phi::PADDLE_CUDA_NUM_THREADS;
template <typename T, typename IndexT>
__global__ void index_add_cuda_kernel(const T* input,
......@@ -41,7 +41,7 @@ __global__ void index_add_cuda_kernel(const T* input,
IndexT src_dim_idx = index[dim_idx];
int64_t input_idx =
idx + (delta * pre_idx + src_dim_idx - dim_idx) * stride;
paddle::platform::CudaAtomicAdd(&output[input_idx], add_value[idx]);
phi::CudaAtomicAdd(&output[input_idx], add_value[idx]);
}
}
......
......@@ -18,9 +18,9 @@
#include <vector>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
......@@ -50,8 +50,8 @@ __global__ void IndexSampleGrad(const IndexT* index,
unsigned int in_idx = index_j * input_length + index_i;
IndexT sample_idx = index[index_idx];
if (same_data_in_row) {
paddle::platform::CudaAtomicAdd(
&(in_grad[in_idx - index_i + sample_idx]), out_grad[sample_idx]);
phi::CudaAtomicAdd(&(in_grad[in_idx - index_i + sample_idx]),
out_grad[sample_idx]);
} else {
in_grad[in_idx - index_i + sample_idx] = out_grad[index_idx];
}
......
......@@ -14,9 +14,9 @@
#include "paddle/phi/kernels/index_select_grad_kernel.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/funcs/math_function.h"
......@@ -25,7 +25,7 @@ DECLARE_bool(cudnn_deterministic);
namespace phi {
using paddle::platform::PADDLE_CUDA_NUM_THREADS;
using phi::PADDLE_CUDA_NUM_THREADS;
template <typename T, typename IndexT>
__global__ void index_select_grad_cuda_kernel(const T* output_grad,
......@@ -42,7 +42,7 @@ __global__ void index_select_grad_cuda_kernel(const T* output_grad,
IndexT src_dim_idx = index[dim_idx];
int64_t input_idx =
idx + (delta * pre_idx + src_dim_idx - dim_idx) * stride;
paddle::platform::CudaAtomicAdd(&input_grad[input_idx], output_grad[idx]);
phi::CudaAtomicAdd(&input_grad[input_idx], output_grad[idx]);
}
}
......
......@@ -14,15 +14,15 @@
#pragma once
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
namespace phi {
using paddle::platform::PADDLE_CUDA_NUM_THREADS;
using phi::PADDLE_CUDA_NUM_THREADS;
template <typename T, typename IndexT>
__global__ void index_select_cuda_kernel(const T* input,
......
......@@ -14,16 +14,16 @@
#include "paddle/phi/kernels/index_select_kernel.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/gpu/index_select_impl.h"
namespace phi {
using paddle::platform::PADDLE_CUDA_NUM_THREADS;
using phi::PADDLE_CUDA_NUM_THREADS;
template <typename T, typename Context>
void IndexSelectKernel(const Context& ctx,
......
......@@ -14,9 +14,9 @@
#include "paddle/phi/kernels/interpolate_grad_kernel.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/core/kernel_registry.h"
......@@ -96,12 +96,11 @@ __global__ void KeLinearInterpBw(T* in,
const T* out_pos = &out[out_id_w];
if (data_layout == DataLayout::kNCHW) {
paddle::platform::CudaAtomicAdd(&in_pos[0], w2lambda * out_pos[0]);
paddle::platform::CudaAtomicAdd(&in_pos[w_id], w1lambda * out_pos[0]);
phi::CudaAtomicAdd(&in_pos[0], w2lambda * out_pos[0]);
phi::CudaAtomicAdd(&in_pos[w_id], w1lambda * out_pos[0]);
} else {
paddle::platform::CudaAtomicAdd(&in_pos[0], w2lambda * out_pos[0]);
paddle::platform::CudaAtomicAdd(&in_pos[w_id * num_channels],
w1lambda * out_pos[0]);
phi::CudaAtomicAdd(&in_pos[0], w2lambda * out_pos[0]);
phi::CudaAtomicAdd(&in_pos[w_id * num_channels], w1lambda * out_pos[0]);
}
}
}
......@@ -141,7 +140,7 @@ __global__ void KeNearestNeighborInterpNCHWBw(T* in,
while (nc_id < nc) {
T* in_pos = &in[in_index];
const T out_pos = out[out_index];
paddle::platform::CudaAtomicAdd(in_pos, out_pos);
phi::CudaAtomicAdd(in_pos, out_pos);
in_index += in_index_stride;
out_index += out_index_stride;
nc_id += nc_stride;
......@@ -194,7 +193,7 @@ __global__ void KeNearestNeighborInterpBw(
in_img_idx * num_channels + channel_id];
const T out_pos = out[tid];
paddle::platform::CudaAtomicAdd(in_pos, out_pos);
phi::CudaAtomicAdd(in_pos, out_pos);
}
}
......@@ -218,7 +217,7 @@ __inline__ __device__ T PartialBlockMin(T val,
}
} else {
shared_last_val = std::numeric_limits<T>::max();
paddle::platform::CudaAtomicMin(&shared_last_val, val);
phi::CudaAtomicMin(&shared_last_val, val);
shared[wid] = shared_last_val;
shared_last_idx = wid;
}
......@@ -308,33 +307,27 @@ __global__ void KeBilinearInterpBwShareMemory(T* in,
? (in_top_max_index - in_top_min_index)
: (in_bot_max_index - in_bot_min_index);
if (h_id != 0) {
paddle::platform::CudaAtomicAdd(
&s_data[0][input_index - in_top_min_index],
h2lambda * w2lambda * value);
paddle::platform::CudaAtomicAdd(
&s_data[0][top_right_index - in_top_min_index],
h2lambda * w1lambda * value);
paddle::platform::CudaAtomicAdd(
&s_data[1][bot_left_index - in_bot_min_index],
h1lambda * w2lambda * value);
paddle::platform::CudaAtomicAdd(
&s_data[1][bot_right_index - in_bot_min_index],
h1lambda * w1lambda * value);
phi::CudaAtomicAdd(&s_data[0][input_index - in_top_min_index],
h2lambda * w2lambda * value);
phi::CudaAtomicAdd(&s_data[0][top_right_index - in_top_min_index],
h2lambda * w1lambda * value);
phi::CudaAtomicAdd(&s_data[1][bot_left_index - in_bot_min_index],
h1lambda * w2lambda * value);
phi::CudaAtomicAdd(&s_data[1][bot_right_index - in_bot_min_index],
h1lambda * w1lambda * value);
} else {
paddle::platform::CudaAtomicAdd(
&s_data[0][top_right_index - in_top_min_index],
(h2lambda + h1lambda) * w1lambda * value);
paddle::platform::CudaAtomicAdd(
&s_data[1][bot_left_index - in_bot_min_index],
(h1lambda + h2lambda) * w2lambda * value);
phi::CudaAtomicAdd(&s_data[0][top_right_index - in_top_min_index],
(h2lambda + h1lambda) * w1lambda * value);
phi::CudaAtomicAdd(&s_data[1][bot_left_index - in_bot_min_index],
(h1lambda + h2lambda) * w2lambda * value);
}
__syncthreads();
if (threadIdx.x <= upper_limit_share_idx) {
paddle::platform::CudaAtomicAdd(&in[in_top_min_index + threadIdx.x],
s_data[0][threadIdx.x]);
paddle::platform::CudaAtomicAdd(&in[in_bot_min_index + threadIdx.x],
s_data[1][threadIdx.x]);
phi::CudaAtomicAdd(&in[in_top_min_index + threadIdx.x],
s_data[0][threadIdx.x]);
phi::CudaAtomicAdd(&in[in_bot_min_index + threadIdx.x],
s_data[1][threadIdx.x]);
}
}
}
......@@ -387,17 +380,14 @@ __global__ void KeBilinearInterpNCHWBw(T* in,
T d2val = out[index];
paddle::platform::CudaAtomicAdd(in + GetInputIndex(nc, in_h, in_w, h1, w1),
h0lambda * w0lambda * d2val);
paddle::platform::CudaAtomicAdd(
in + GetInputIndex(nc, in_h, in_w, h1, w1 + x_id),
h0lambda * w1lambda * d2val);
paddle::platform::CudaAtomicAdd(
in + GetInputIndex(nc, in_h, in_w, h1 + y_id, w1),
h1lambda * w0lambda * d2val);
paddle::platform::CudaAtomicAdd(
in + GetInputIndex(nc, in_h, in_w, h1 + y_id, w1 + x_id),
h1lambda * w1lambda * d2val);
phi::CudaAtomicAdd(in + GetInputIndex(nc, in_h, in_w, h1, w1),
h0lambda * w0lambda * d2val);
phi::CudaAtomicAdd(in + GetInputIndex(nc, in_h, in_w, h1, w1 + x_id),
h0lambda * w1lambda * d2val);
phi::CudaAtomicAdd(in + GetInputIndex(nc, in_h, in_w, h1 + y_id, w1),
h1lambda * w0lambda * d2val);
phi::CudaAtomicAdd(in + GetInputIndex(nc, in_h, in_w, h1 + y_id, w1 + x_id),
h1lambda * w1lambda * d2val);
}
}
......@@ -446,12 +436,12 @@ __global__ void KeBilinearInterpBw(T* in,
T value = out[tid];
T* in_pos = &in[out_id_h * in_chw + in_img_idy * in_w * num_channels +
in_img_idx * num_channels + channel_id];
paddle::platform::CudaAtomicAdd(&in_pos[0], h2lambda * w2lambda * value);
paddle::platform::CudaAtomicAdd(&in_pos[w_id * num_channels],
h2lambda * w1lambda * value);
paddle::platform::CudaAtomicAdd(&in_pos[h_id * in_w * num_channels],
h1lambda * w2lambda * value);
paddle::platform::CudaAtomicAdd(
phi::CudaAtomicAdd(&in_pos[0], h2lambda * w2lambda * value);
phi::CudaAtomicAdd(&in_pos[w_id * num_channels],
h2lambda * w1lambda * value);
phi::CudaAtomicAdd(&in_pos[h_id * in_w * num_channels],
h1lambda * w2lambda * value);
phi::CudaAtomicAdd(
&in_pos[h_id * in_w * num_channels + w_id * num_channels],
h1lambda * w1lambda * value);
}
......@@ -530,8 +520,8 @@ __global__ void KeBicubicInterpBw(T* in,
in_pos = &in[out_id_h * input_w + access_y * in_img_w * num_channels +
access_x * num_channels + channel_id];
}
paddle::platform::CudaAtomicAdd(
&in_pos[0], (out_pos[0] * y_coeffs[j] * x_coeffs[i]));
phi::CudaAtomicAdd(&in_pos[0],
(out_pos[0] * y_coeffs[j] * x_coeffs[i]));
}
}
}
......@@ -629,26 +619,22 @@ __global__ void KeTrilinearInterpBw(T* in,
const T* out_pos = &out[out_id_h * output_w + out_id_w];
// trilinear interpolation grad
paddle::platform::CudaAtomicAdd(
&in_pos1[0], d2lambda * h2lambda * w2lambda * out_pos[0]);
paddle::platform::CudaAtomicAdd(
&in_pos1[w_id], d2lambda * h2lambda * w1lambda * out_pos[0]);
paddle::platform::CudaAtomicAdd(
&in_pos1[h_id * in_img_w],
d2lambda * h1lambda * w2lambda * out_pos[0]);
paddle::platform::CudaAtomicAdd(
&in_pos1[h_id * in_img_w + w_id],
d2lambda * h1lambda * w1lambda * out_pos[0]);
paddle::platform::CudaAtomicAdd(
&in_pos2[0], d1lambda * h2lambda * w2lambda * out_pos[0]);
paddle::platform::CudaAtomicAdd(
&in_pos2[w_id], d1lambda * h2lambda * w1lambda * out_pos[0]);
paddle::platform::CudaAtomicAdd(
&in_pos2[h_id * in_img_w],
d1lambda * h1lambda * w2lambda * out_pos[0]);
paddle::platform::CudaAtomicAdd(
&in_pos2[h_id * in_img_w + w_id],
d1lambda * h1lambda * w1lambda * out_pos[0]);
phi::CudaAtomicAdd(&in_pos1[0],
d2lambda * h2lambda * w2lambda * out_pos[0]);
phi::CudaAtomicAdd(&in_pos1[w_id],
d2lambda * h2lambda * w1lambda * out_pos[0]);
phi::CudaAtomicAdd(&in_pos1[h_id * in_img_w],
d2lambda * h1lambda * w2lambda * out_pos[0]);
phi::CudaAtomicAdd(&in_pos1[h_id * in_img_w + w_id],
d2lambda * h1lambda * w1lambda * out_pos[0]);
phi::CudaAtomicAdd(&in_pos2[0],
d1lambda * h2lambda * w2lambda * out_pos[0]);
phi::CudaAtomicAdd(&in_pos2[w_id],
d1lambda * h2lambda * w1lambda * out_pos[0]);
phi::CudaAtomicAdd(&in_pos2[h_id * in_img_w],
d1lambda * h1lambda * w2lambda * out_pos[0]);
phi::CudaAtomicAdd(&in_pos2[h_id * in_img_w + w_id],
d1lambda * h1lambda * w1lambda * out_pos[0]);
} else {
int in_pos1_idx = out_id_h * input_w +
in_img_idt * in_img_h * in_img_w * num_channels +
......@@ -661,26 +647,22 @@ __global__ void KeTrilinearInterpBw(T* in,
const T* out_pos = &out[out_id_h * output_w + out_id_w];
// trilinear interpolation grad
paddle::platform::CudaAtomicAdd(
&in_pos1[0], d2lambda * h2lambda * w2lambda * out_pos[0]);
paddle::platform::CudaAtomicAdd(
&in_pos1[w_id * num_channels],
d2lambda * h2lambda * w1lambda * out_pos[0]);
paddle::platform::CudaAtomicAdd(
&in_pos1[h_id * in_img_w * num_channels],
d2lambda * h1lambda * w2lambda * out_pos[0]);
paddle::platform::CudaAtomicAdd(
phi::CudaAtomicAdd(&in_pos1[0],
d2lambda * h2lambda * w2lambda * out_pos[0]);
phi::CudaAtomicAdd(&in_pos1[w_id * num_channels],
d2lambda * h2lambda * w1lambda * out_pos[0]);
phi::CudaAtomicAdd(&in_pos1[h_id * in_img_w * num_channels],
d2lambda * h1lambda * w2lambda * out_pos[0]);
phi::CudaAtomicAdd(
&in_pos1[h_id * in_img_w * num_channels + w_id * num_channels],
d2lambda * h1lambda * w1lambda * out_pos[0]);
paddle::platform::CudaAtomicAdd(
&in_pos2[0], d1lambda * h2lambda * w2lambda * out_pos[0]);
paddle::platform::CudaAtomicAdd(
&in_pos2[w_id * num_channels],
d1lambda * h2lambda * w1lambda * out_pos[0]);
paddle::platform::CudaAtomicAdd(
&in_pos2[h_id * in_img_w * num_channels],
d1lambda * h1lambda * w2lambda * out_pos[0]);
paddle::platform::CudaAtomicAdd(
phi::CudaAtomicAdd(&in_pos2[0],
d1lambda * h2lambda * w2lambda * out_pos[0]);
phi::CudaAtomicAdd(&in_pos2[w_id * num_channels],
d1lambda * h2lambda * w1lambda * out_pos[0]);
phi::CudaAtomicAdd(&in_pos2[h_id * in_img_w * num_channels],
d1lambda * h1lambda * w2lambda * out_pos[0]);
phi::CudaAtomicAdd(
&in_pos2[h_id * in_img_w * num_channels + w_id * num_channels],
d1lambda * h1lambda * w1lambda * out_pos[0]);
}
......@@ -751,7 +733,7 @@ __global__ void KeNearestNeighbor3DInterpBw(T* in,
in_img_idx * num_channels + channel_id];
}
const T out_pos = out[out_id_h * output_w + out_id_w];
paddle::platform::CudaAtomicAdd(in_pos, out_pos);
phi::CudaAtomicAdd(in_pos, out_pos);
}
}
......
......@@ -15,9 +15,9 @@
#include "paddle/phi/kernels/interpolate_kernel.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/common/layout.h"
......
......@@ -14,8 +14,8 @@
#include "paddle/phi/kernels/linspace_kernel.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/math_function.h"
......
......@@ -14,9 +14,9 @@
#include "paddle/phi/kernels/nanmedian_grad_kernel.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_meta.h"
#include "paddle/phi/kernels/funcs/math_function.h"
......@@ -24,7 +24,7 @@
namespace phi {
using paddle::platform::PADDLE_CUDA_NUM_THREADS;
using phi::PADDLE_CUDA_NUM_THREADS;
inline int GET_BLOCKS(const int N) {
return (N + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS;
}
......
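GET_BLOCKS above simply rounds the element count up to whole blocks of phi::PADDLE_CUDA_NUM_THREADS threads. A minimal usage sketch built on those two helpers; the kernel and launcher names are placeholders, not taken from this file:

// Placeholder kernel: scales n elements in place, one thread per element.
template <typename T>
__global__ void ScaleKernel(T* data, T alpha, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] *= alpha;
}

template <typename T, typename Context>
void LaunchScale(const Context& ctx, T* data, T alpha, int n) {
  // PADDLE_CUDA_NUM_THREADS threads per block, GET_BLOCKS(n) blocks.
  ScaleKernel<T><<<GET_BLOCKS(n), PADDLE_CUDA_NUM_THREADS, 0, ctx.stream()>>>(
      data, alpha, n);
}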
......@@ -15,9 +15,9 @@
#include "paddle/phi/kernels/nanmedian_kernel.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/impl/nanmedian_kernel_impl.h"
......@@ -25,7 +25,7 @@
namespace phi {
using paddle::platform::PADDLE_CUDA_NUM_THREADS;
using phi::PADDLE_CUDA_NUM_THREADS;
inline int GET_BLOCKS(const int N) {
return (N + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS;
......@@ -56,15 +56,15 @@ __global__ void KernelNanCounts(const T* input,
const T x = input[index];
if (isnan(static_cast<float>(x))) {
auto bin = static_cast<int64_t>(index / stride);
paddle::platform::CudaAtomicAdd(&buf[bin], 1);
phi::CudaAtomicAdd(&buf[bin], 1);
}
}
__syncthreads();
for (int i = threadIdx.x; i < pre_dim; i += blockDim.x) {
paddle::platform::CudaAtomicAdd(&nan_counts[i], buf[i]);
paddle::platform::CudaAtomicAdd(&nan_total[0], buf[i]);
paddle::platform::CudaAtomicMax(&nan_total[1], stride - buf[i]);
phi::CudaAtomicAdd(&nan_counts[i], buf[i]);
phi::CudaAtomicAdd(&nan_total[0], buf[i]);
phi::CudaAtomicMax(&nan_total[1], stride - buf[i]);
}
}
......
......@@ -20,7 +20,7 @@
#include <string>
#include "paddle/fluid/operators/math.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/hostdevice.h"
......@@ -270,8 +270,8 @@ __global__ void GPUNLLLossForward2D_with_reduce(T* out_data,
partial_sums, blockDim.x, acc_weight, thrust::plus<T>(), (T)0);
if (threadIdx.x == 0) {
paddle::platform::CudaAtomicAdd(total_weight_data, acc_weight);
paddle::platform::CudaAtomicAdd(out_data, input_sum);
phi::CudaAtomicAdd(total_weight_data, acc_weight);
phi::CudaAtomicAdd(out_data, input_sum);
}
}
......
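The pattern in GPUNLLLossForward2D_with_reduce above is worth noting: the block reduces its partial sums first, and only thread 0 flushes the block's result to global memory with one atomic. A self-contained sketch of that idea, with a simplified shared-memory reduction and illustrative names (not the kernel from the patch):

// Illustrative block-sum + single-atomic-flush pattern; assumes blockDim.x
// is a power of two and at most 1024.
template <typename T>
__global__ void BlockSumThenAtomic(const T* in, T* total, int64_t n) {
  __shared__ T partial[1024];
  int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
  partial[threadIdx.x] = (i < n) ? in[i] : static_cast<T>(0);
  __syncthreads();
  for (int offset = blockDim.x / 2; offset > 0; offset >>= 1) {
    if (threadIdx.x < offset) {
      partial[threadIdx.x] += partial[threadIdx.x + offset];
    }
    __syncthreads();
  }
  if (threadIdx.x == 0) {
    // One atomic per block instead of one per thread keeps contention low.
    phi::CudaAtomicAdd(total, partial[0]);
  }
}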
......@@ -16,8 +16,8 @@
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
......
......@@ -15,14 +15,14 @@
#include "paddle/phi/kernels/one_hot_kernel.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
using paddle::platform::PADDLE_CUDA_NUM_THREADS;
using phi::PADDLE_CUDA_NUM_THREADS;
template <typename InT, typename OutT>
__global__ void FillOutputKernel(const InT* p_in_data,
......
......@@ -15,14 +15,14 @@
#include "paddle/phi/kernels/pad3d_grad_kernel.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
using paddle::platform::PADDLE_CUDA_NUM_THREADS;
using phi::PADDLE_CUDA_NUM_THREADS;
template <typename T>
__global__ void Pad3DGradConstNCDHW(const int in_size,
......@@ -133,7 +133,7 @@ __global__ void Pad3DGradReflectNCDHW(const int out_size,
in_h = min(in_h, 2 * in_height - in_h - 2);
in_w = min(in_w, 2 * in_width - in_w - 2);
paddle::platform::CudaAtomicAdd(
phi::CudaAtomicAdd(
&d_in_data[nc * in_depth * in_height * in_width +
in_d * in_height * in_width + in_h * in_width + in_w],
d_out_data[out_index]);
......@@ -176,7 +176,7 @@ __global__ void Pad3DGradReflectNDHWC(const int out_size,
in_d = min(in_d, in_depth * 2 - in_d - 2);
in_h = min(in_h, in_height * 2 - in_h - 2);
in_w = min(in_w, in_width * 2 - in_w - 2);
paddle::platform::CudaAtomicAdd(
phi::CudaAtomicAdd(
&d_in_data[n * in_depth * in_height * in_width * channels +
in_d * in_height * in_width * channels +
in_h * in_width * channels + in_w * channels + c],
......@@ -211,7 +211,7 @@ __global__ void Pad3DGradReplicateNCDHW(const int out_size,
const int in_h = min(in_height - 1, max(out_h - pad_top, 0));
const int in_w = min(in_width - 1, max(out_w - pad_left, 0));
paddle::platform::CudaAtomicAdd(
phi::CudaAtomicAdd(
&d_in_data[nc * in_depth * in_height * in_width +
in_d * in_height * in_width + in_h * in_width + in_w],
d_out_data[out_index]);
......@@ -247,7 +247,7 @@ __global__ void Pad3DGradReplicateNDHWC(const int out_size,
const int in_h = min(in_height - 1, max(out_h - pad_top, 0));
const int in_w = min(in_width - 1, max(out_w - pad_left, 0));
paddle::platform::CudaAtomicAdd(
phi::CudaAtomicAdd(
&d_in_data[n * in_depth * in_height * in_width * channels +
in_d * in_height * in_width * channels +
in_h * in_width * channels + in_w * channels + c],
......@@ -282,7 +282,7 @@ __global__ void Pad3DGradCircularNCDHW(const int out_size,
int in_h = ((out_h - pad_top) % in_height + in_height) % in_height;
int in_w = ((out_w - pad_left) % in_width + in_width) % in_width;
paddle::platform::CudaAtomicAdd(
phi::CudaAtomicAdd(
&d_in_data[nc * in_depth * in_height * in_width +
in_d * in_height * in_width + in_h * in_width + in_w],
d_out_data[out_index]);
......@@ -318,7 +318,7 @@ __global__ void Pad3DGradCircularNDHWC(const int out_size,
int in_h = ((out_h - pad_top) % in_height + in_height) % in_height;
int in_w = ((out_w - pad_left) % in_width + in_width) % in_width;
paddle::platform::CudaAtomicAdd(
phi::CudaAtomicAdd(
&d_in_data[n * in_depth * in_height * in_width * channels +
in_d * in_height * in_width * channels +
in_h * in_width * channels + in_w * channels + c],
......
......@@ -17,14 +17,14 @@
#include <algorithm>
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
using paddle::platform::PADDLE_CUDA_NUM_THREADS;
using phi::PADDLE_CUDA_NUM_THREADS;
template <typename T>
__global__ void Pad3DConstNCDHW(const int nthreads,
......
......@@ -16,7 +16,7 @@
#include <vector>
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
......@@ -97,7 +97,7 @@ __global__ void GPUPSROIPoolBackward(const int nthreads,
for (int ih = hstart; ih < hend; ++ih) {
for (int iw = wstart; iw < wend; ++iw) {
int input_index = ih * width + iw;
paddle::platform::CudaAtomicAdd(offset_dx_data + input_index, diff_val);
phi::CudaAtomicAdd(offset_dx_data + input_index, diff_val);
}
}
}
......
......@@ -15,9 +15,9 @@
#include "paddle/phi/kernels/roi_align_grad_kernel.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/empty_kernel.h"
......@@ -153,14 +153,11 @@ __global__ void GPURoiAlignBackward(const int nthreads,
T diff3 = out_grad_this_bin * w3 / count;
T diff4 = out_grad_this_bin * w4 / count;
if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
paddle::platform::CudaAtomicAdd(
offset_input_grad + y_low * width + x_low, diff1);
paddle::platform::CudaAtomicAdd(
offset_input_grad + y_low * width + x_high, diff2);
paddle::platform::CudaAtomicAdd(
offset_input_grad + y_high * width + x_low, diff3);
paddle::platform::CudaAtomicAdd(
offset_input_grad + y_high * width + x_high, diff4);
phi::CudaAtomicAdd(offset_input_grad + y_low * width + x_low, diff1);
phi::CudaAtomicAdd(offset_input_grad + y_low * width + x_high, diff2);
phi::CudaAtomicAdd(offset_input_grad + y_high * width + x_low, diff3);
phi::CudaAtomicAdd(offset_input_grad + y_high * width + x_high,
diff4);
}
}
}
......
......@@ -15,9 +15,9 @@
#include "paddle/phi/kernels/roi_pool_grad_kernel.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
......@@ -63,7 +63,7 @@ __global__ void GPURoiPoolBackward(const int nthreads,
int arg_max = offset_arg_max_data[ph * pooled_width + pw];
if (arg_max != -1) {
paddle::platform::CudaAtomicAdd(
phi::CudaAtomicAdd(
offset_input_grad + arg_max,
static_cast<T>(offset_output_grad[ph * pooled_width + pw]));
}
......
......@@ -20,7 +20,7 @@
namespace phi {
using paddle::platform::PADDLE_CUDA_NUM_THREADS;
using phi::PADDLE_CUDA_NUM_THREADS;
template <typename T, typename Context>
void RollGradKernel(const Context& dev_ctx,
......
......@@ -21,7 +21,7 @@
namespace phi {
using paddle::platform::PADDLE_CUDA_NUM_THREADS;
using phi::PADDLE_CUDA_NUM_THREADS;
template <typename T, typename Context>
void RollKernel(const Context& dev_ctx,
......
......@@ -14,13 +14,13 @@
#pragma once
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/core/utils/array.h"
#include "paddle/phi/kernels/primitive/kernel_primitives.h"
namespace phi {
using paddle::platform::PADDLE_CUDA_NUM_THREADS;
using phi::PADDLE_CUDA_NUM_THREADS;
template <typename T, size_t Rank>
__global__ void RollCudaKernel(const T* input,
......
......@@ -45,7 +45,7 @@ void CalculateXEGradForMinMax(const Context& ctx,
const auto& bcast_info = phi::CalcBCastInfo(x_dims, e_dims);
thrust::device_vector<int64_t> l_bcastoff, r_bcastoff;
if (bcast_info.use_bcast) {
CopyBCastOff(bcast_info, l_bcastoff, r_bcastoff);
CopyBCastOff(bcast_info, &l_bcastoff, &r_bcastoff);
}
int64_t out_len = bcast_info.out_len;
......@@ -177,7 +177,7 @@ void CalculateXGrad(const Context& ctx,
const auto& bcast_info = phi::CalcBCastInfo(out_grad_dims, e_dims);
thrust::device_vector<int64_t> l_bcastoff, r_bcastoff;
if (bcast_info.use_bcast) {
CopyBCastOff(bcast_info, l_bcastoff, r_bcastoff);
CopyBCastOff(bcast_info, &l_bcastoff, &r_bcastoff);
}
int64_t out_len = bcast_info.out_len;
const int ntx = FindNumThreads(out_len, ctx.GetMaxThreadsPerBlock());
......@@ -300,7 +300,7 @@ void CalculateXGrad(const Context& ctx,
const auto& bcast_info = phi::CalcBCastInfo(out_grad_dims, e_dims);
thrust::device_vector<int64_t> l_bcastoff, r_bcastoff;
if (bcast_info.use_bcast) {
CopyBCastOff(bcast_info, l_bcastoff, r_bcastoff);
CopyBCastOff(bcast_info, &l_bcastoff, &r_bcastoff);
}
int64_t out_len = bcast_info.out_len;
const int ntx = FindNumThreads(out_len, ctx.GetMaxThreadsPerBlock());
......@@ -386,7 +386,7 @@ void CalculateEGrad(const Context& ctx,
const auto& bcast_info = phi::CalcBCastInfo(x_dims, e_dims);
thrust::device_vector<int64_t> l_bcastoff, r_bcastoff;
if (bcast_info.use_bcast) {
CopyBCastOff(bcast_info, l_bcastoff, r_bcastoff);
CopyBCastOff(bcast_info, &l_bcastoff, &r_bcastoff);
}
int64_t out_len = bcast_info.out_len;
const int ntx = FindNumThreads(out_len, ctx.GetMaxThreadsPerBlock());
......
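Besides the header move, the graph_send_* hunks above also change CopyBCastOff to take its two thrust::device_vector outputs by pointer rather than by non-const reference, so the call sites (&l_bcastoff, &r_bcastoff) make the mutation explicit. A hedged sketch of that out-parameter style; FillOffsets below is illustrative and is not the real CopyBCastOff:

#include <cstdint>
#include <vector>
#include <thrust/device_vector.h>

// Illustrative only: copy host-side offset tables into device vectors that
// the caller passes by pointer, mirroring the new CopyBCastOff call shape.
void FillOffsets(const std::vector<int64_t>& l_src,
                 const std::vector<int64_t>& r_src,
                 thrust::device_vector<int64_t>* l_offset,
                 thrust::device_vector<int64_t>* r_offset) {
  l_offset->assign(l_src.begin(), l_src.end());
  r_offset->assign(r_src.begin(), r_src.end());
}

// Usage, matching the new call shape in the diff:
//   thrust::device_vector<int64_t> l_bcastoff, r_bcastoff;
//   FillOffsets(host_l, host_r, &l_bcastoff, &r_bcastoff);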
......@@ -89,7 +89,7 @@ void GraphSendUERecvOpCUDAKernelLaunchHelper(const Context& ctx,
thrust::device_vector<int64_t> x_bcastoff, e_bcastoff;
if (bcast_info.use_bcast) {
CopyBCastOff(bcast_info, x_bcastoff, e_bcastoff);
CopyBCastOff(bcast_info, &x_bcastoff, &e_bcastoff);
}
int64_t out_len = bcast_info.out_len;
......
......@@ -45,7 +45,7 @@ __global__ void GraphSendUVGradCUDAKernel(const T* out_grad,
const T* out_grad_off = out_grad + ty * slice_size;
T* x_grad_off = x_grad + dst * slice_size;
while (tx < slice_size) {
paddle::platform::CudaAtomicAdd(x_grad_off + tx, out_grad_off[tx]);
phi::CudaAtomicAdd(x_grad_off + tx, out_grad_off[tx]);
tx += stride_x;
}
ty += stride_y;
......@@ -127,7 +127,7 @@ void CalculateGrad(const Context& ctx,
const auto& bcast_info = phi::CalcBCastInfo(y.dims(), out_grad_dims);
thrust::device_vector<int64_t> l_bcastoff, r_bcastoff;
if (bcast_info.use_bcast) {
CopyBCastOff(bcast_info, l_bcastoff, r_bcastoff);
CopyBCastOff(bcast_info, &l_bcastoff, &r_bcastoff);
}
int64_t out_len = bcast_info.out_len;
const int ntx = FindNumThreads(out_len, ctx.GetMaxThreadsPerBlock());
......
......@@ -94,7 +94,7 @@ void GraphSendUVOpCUDAKernelLaunchHelper(const Context& ctx,
thrust::device_vector<int64_t> x_bcastoff, y_bcastoff;
if (bcast_info.use_bcast) {
CopyBCastOff(bcast_info, x_bcastoff, y_bcastoff);
CopyBCastOff(bcast_info, &x_bcastoff, &y_bcastoff);
}
int64_t out_len = bcast_info.out_len;
......
......@@ -16,9 +16,9 @@
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_helper.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
......@@ -56,7 +56,7 @@ __global__ void SparseSGDFunctorKernel(const T* selected_rows,
for (int64_t index = threadIdx.x; index < row_numel; index += blockDim.x) {
// Since index in rows of SelectedRows can be duplicate, we have to use
// Atomic Operation to avoid concurrent write error.
paddle::platform::CudaAtomicAdd(
phi::CudaAtomicAdd(
tensor_out_ptr + index,
-static_cast<T>(1.0) * learning_rate[0] * selected_rows_ptr[index]);
}
......
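The comment preserved in SparseSGDFunctorKernel explains why these updates stay atomic through the rename: SelectedRows may list the same row index more than once, so two blocks can update the same output row concurrently. A minimal sketch of that situation (hypothetical kernel, not from the patch):

// Hypothetical scatter-add over possibly duplicated row indices.
// A plain `dst[...] += ...` here could lose updates when rows repeat;
// phi::CudaAtomicAdd makes each read-modify-write indivisible.
template <typename T>
__global__ void ScatterAddRows(const T* src,
                               const int64_t* rows,  // may contain duplicates
                               int64_t row_numel,
                               T* dst) {
  const int64_t out_row = rows[blockIdx.x];
  const T* src_row = src + blockIdx.x * row_numel;
  T* dst_row = dst + out_row * row_numel;
  for (int64_t j = threadIdx.x; j < row_numel; j += blockDim.x) {
    phi::CudaAtomicAdd(dst_row + j, src_row[j]);
  }
}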
......@@ -14,13 +14,13 @@
#include "paddle/phi/kernels/shard_index_kernel.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
using paddle::platform::PADDLE_CUDA_NUM_THREADS;
using phi::PADDLE_CUDA_NUM_THREADS;
template <typename T>
__global__ void ShardIndexInner(const T* in_data,
......
......@@ -18,9 +18,9 @@
#include "paddle/fluid/framework/gpu_utils.h"
#include "paddle/fluid/operators/transpose_op.cu.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/transpose_grad_kernel_impl.h"
......
......@@ -14,14 +14,14 @@
#include "paddle/phi/kernels/trunc_grad_kernel.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
using paddle::platform::PADDLE_CUDA_NUM_THREADS;
using phi::PADDLE_CUDA_NUM_THREADS;
template <typename T>
__global__ void TruncGrad(T* dx, int64_t N) {
......
......@@ -14,14 +14,14 @@
#include "paddle/phi/kernels/trunc_kernel.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
using paddle::platform::PADDLE_CUDA_NUM_THREADS;
using phi::PADDLE_CUDA_NUM_THREADS;
template <typename T>
class TruncFunctor {
......
......@@ -18,9 +18,9 @@
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/core/kernel_registry.h"
......
......@@ -18,9 +18,9 @@
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/core/kernel_registry.h"
......
......@@ -18,7 +18,7 @@
#include "paddle/phi/kernels/cpu/index_select_impl.h"
#include "paddle/phi/kernels/repeat_interleave_grad_kernel.h"
#if defined(__NVCC__) || defined(__HIPCC__)
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/kernels/primitive/functor_primitives.h"
#ifdef __NVCC__
#include "cub/cub.cuh"
......@@ -33,7 +33,7 @@ namespace cub = hipcub;
namespace phi {
#if defined(__NVCC__) || defined(__HIPCC__)
using paddle::platform::PADDLE_CUDA_NUM_THREADS;
using phi::PADDLE_CUDA_NUM_THREADS;
template <typename T, typename IndexT>
__global__ void index_select_grad_cuda_kernel(const T* output_grad,
......@@ -53,7 +53,7 @@ __global__ void index_select_grad_cuda_kernel(const T* output_grad,
int64_t dim_idx = idx % (stride * size) / stride;
IndexT src_dim_idx = index[dim_idx];
int64_t input_idx = idx + (delta * pre_idx + src_dim_idx - dim_idx) * stride;
paddle::platform::CudaAtomicAdd(&input_grad[input_idx], output_grad[idx]);
phi::CudaAtomicAdd(&input_grad[input_idx], output_grad[idx]);
}
template <typename T>
......
......@@ -18,9 +18,9 @@
#include "paddle/phi/kernels/cpu/index_select_impl.h"
#include "paddle/phi/kernels/repeat_interleave_kernel.h"
#if defined(__NVCC__) || defined(__HIPCC__)
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_decls.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_resources.h"
#include "paddle/phi/kernels/primitive/functor_primitives.h"
#endif
......@@ -30,7 +30,7 @@
namespace phi {
#if defined(__NVCC__) || defined(__HIPCC__)
using paddle::platform::PADDLE_CUDA_NUM_THREADS;
using phi::PADDLE_CUDA_NUM_THREADS;
template <typename T, typename IndexT>
__global__ void index_select_cuda_kernel(const T* input,
T* output,
......@@ -81,9 +81,8 @@ void RepeatInterleaveKernel(const Context& ctx,
output_dim[dim] = index_size;
out->Resize(phi::make_ddim(output_dim));
phi::IndexSelectInner<Context, T, int>(ctx, &x_copy, index, out, dim);
}
#if defined(__NVCC__) || defined(__HIPCC__)
else {
} else {
auto stride_dim = phi::stride(input_dim);
int64_t stride = stride_dim[dim];
paddle::framework::TensorFromVector<int>(index_vec, ctx, &index);
......@@ -105,6 +104,8 @@ void RepeatInterleaveKernel(const Context& ctx,
stream>>>(
x.data<T>(), out_data, index_data, numel, stride, size, delta);
}
#else
}
#endif
}
......@@ -163,9 +164,8 @@ void RepeatInterleaveWithTensorIndexKernel(const Context& ctx,
out->Resize(phi::make_ddim(output_dim));
IndexSelectInner<Context, T, int64_t>(ctx, &x_copy, index, out, dim);
}
}
#if defined(__NVCC__) || defined(__HIPCC__)
else {
} else {
auto stride_dim = phi::stride(input_dim);
int64_t stride = stride_dim[dim];
auto stream = ctx.stream();
......@@ -209,6 +209,8 @@ void RepeatInterleaveWithTensorIndexKernel(const Context& ctx,
in_data, out_data, index_data, numel, stride, size, delta);
}
}
#else
}
#endif
}
......
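The brace shuffle in the two repeat_interleave kernels above deserves a note: the `} else {` now sits inside the `#if defined(__NVCC__) || defined(__HIPCC__)` guard, and a new `#else` branch supplies only the closing brace, so CPU-only builds still see balanced braces while the GPU branch is compiled out. A stripped-down skeleton of the pattern; names and bodies are placeholders:

template <typename T, typename Context>
void DispatchExample(const Context& ctx, bool take_common_path) {
  if (take_common_path) {
    // Device-agnostic path (always compiled).
#if defined(__NVCC__) || defined(__HIPCC__)
  } else {
    // GPU-only path: the CUDA/HIP kernel launch goes here.
  }
#else
  }
#endif
}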