Commit 7ba14d74 authored by phlrain

fix lookup speed error

Parent 755ad257
@@ -35,7 +35,7 @@ __global__ void InputTypeConvert(const InT* in_ids,
   }
 }
 
-template <typename T, typename IdT, int BlockDimX, int BlockDimY, int GridDimX>
+template <typename T, typename IdT>
 __global__ void LookupTableV2Grad(T* table,
                                   const T* output,
                                   const IdT* ids,
@@ -43,16 +43,20 @@ __global__ void LookupTableV2Grad(T* table,
                                   const int64_t K,
                                   const int64_t D) {
   int idx = threadIdx.x;
-  int idy = blockIdx.x + threadIdx.y * GridDimX;
+  int idy = blockIdx.x + threadIdx.y * gridDim.x;
 
   while (idy < K) {
     auto id = static_cast<int64_t>(ids[idy]);
     const T* out = output + idy * D;
     T* tab = table + id * D;
-    for (int i = idx; i < D; i += BlockDimX) {
+#ifdef PADDLE_WITH_CUDA
+    paddle::platform::VectorizedAtomicAddPerBlock(D, idx, blockDim.x, out, tab);
+#else
+    for (int i = idx; i < D; i += blockDim.x) {
       paddle::platform::CudaAtomicAdd(&tab[i], out[i]);
     }
-    idy += BlockDimY * GridDimX;
+#endif
+    idy += blockDim.y * gridDim.x;
   }
 }
@@ -83,20 +87,22 @@ struct LookupTableV2GradCUDAFunctor {
     int D = weight_grad_->dims()[1];
     int K = input_.numel();
 
-    dim3 threads(128, 8);
-    dim3 grids(8, 1);
     const T* d_output = d_output_t.template data<T>();
     const auto* ids = input_.template data<IdT>();
     T* d_table = d_table_t->mutable_data<T>(dev_ctx_.GetPlace());
 
-    auto t = EigenVector<T>::Flatten(*d_table_t);
-    t.device(*dev_ctx_.eigen_device()) = t.constant(static_cast<T>(0));
+#ifdef PADDLE_WITH_HIP
+    PADDLE_ENFORCE_GPU_SUCCESS(
+        hipMemsetAsync(d_table, 0, N * D * sizeof(T), dev_ctx_.stream()));
+#else
+    PADDLE_ENFORCE_GPU_SUCCESS(
+        cudaMemsetAsync(d_table, 0, N * D * sizeof(T), dev_ctx_.stream()));
+#endif
 
-    LookupTableV2Grad<T,
-                      IdT,
-                      128,
-                      8,
-                      8><<<grids, threads, 0, dev_ctx_.stream()>>>(
+    const int gridx = 2 * dev_ctx_.GetSMCount();
+    dim3 threads(128, 8);
+    dim3 grids(gridx, 1);
+    LookupTableV2Grad<T, IdT><<<grids, threads, 0, dev_ctx_.stream()>>>(
         d_table, d_output, ids, N, K, D);
   }
 }
...
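The gradient-path change above does three things: the kernel reads its launch geometry from blockDim/gridDim at runtime instead of the BlockDimX/BlockDimY/GridDimX template parameters, the gradient table is zeroed with a stream-ordered memset instead of an Eigen assignment, and the grid is sized from twice the device's SM count rather than a fixed 8 blocks. The CUDA sketch below is a minimal standalone illustration of that launch pattern, not Paddle code: plain atomicAdd, cudaMemsetAsync, and cudaGetDeviceProperties stand in for Paddle's CudaAtomicAdd/VectorizedAtomicAddPerBlock helpers, PADDLE_ENFORCE_GPU_SUCCESS, and GetSMCount, and the kernel and buffer names are hypothetical.

// Standalone sketch (not Paddle code): grid-stride scatter-add with runtime
// blockDim/gridDim, stream-ordered zeroing, and an SM-count-based grid size.
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>

template <typename T, typename IdT>
__global__ void EmbeddingGradRef(T* table, const T* output, const IdT* ids,
                                 int64_t K, int64_t D) {
  int idx = threadIdx.x;
  int idy = blockIdx.x + threadIdx.y * gridDim.x;  // row handled by this thread row
  while (idy < K) {
    int64_t id = static_cast<int64_t>(ids[idy]);
    const T* out = output + idy * D;
    T* tab = table + id * D;
    // plain atomicAdd stands in for CudaAtomicAdd / VectorizedAtomicAddPerBlock
    for (int i = idx; i < D; i += blockDim.x) {
      atomicAdd(&tab[i], out[i]);
    }
    idy += blockDim.y * gridDim.x;  // grid-stride over rows
  }
}

int main() {
  const int64_t N = 1000, D = 64, K = 256;  // vocab rows, width, number of ids
  float *d_table, *d_out;
  int* d_ids;
  cudaMalloc(&d_table, N * D * sizeof(float));
  cudaMalloc(&d_out, K * D * sizeof(float));
  cudaMalloc(&d_ids, K * sizeof(int));
  cudaMemset(d_ids, 0, K * sizeof(int));        // all ids = 0 for this toy run
  cudaMemset(d_out, 0, K * D * sizeof(float));
  // zero the gradient table on the stream, like the cudaMemsetAsync in the diff
  cudaMemsetAsync(d_table, 0, N * D * sizeof(float));
  // derive the grid from the SM count, mirroring 2 * dev_ctx_.GetSMCount()
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, 0);
  dim3 threads(128, 8);
  dim3 grids(2 * prop.multiProcessorCount, 1);
  EmbeddingGradRef<float, int><<<grids, threads>>>(d_table, d_out, d_ids, K, D);
  cudaDeviceSynchronize();
  printf("grad scatter-add finished\n");
  return 0;
}

Because the loop strides by the runtime blockDim/gridDim, the same kernel instantiation stays correct for any grid the caller picks, which is what lets the launch switch to an SM-count-based grid without re-templating the kernel.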
@@ -23,12 +23,7 @@
 namespace phi {
 
-template <typename T,
-          typename IdT,
-          int BlockDimX,
-          int BlockDimY,
-          int GridDimX,
-          bool PaddingFlag>
+template <typename T, typename IdT, bool PaddingFlag>
 __global__ void LookupTableV2(T *output,
                               const T *table,
                               const IdT *ids,
@@ -37,13 +32,13 @@ __global__ void LookupTableV2(T *output,
                               const int64_t D,
                               const int64_t padding_idx) {
   int idx = threadIdx.x;
-  int idy = blockIdx.x + threadIdx.y * GridDimX;
+  int idy = blockIdx.x + threadIdx.y * gridDim.x;
 
   while (idy < K) {
     auto id = static_cast<int64_t>(ids[idy]);
     T *out = output + idy * D;
     const T *tab = table + id * D;
-    for (int i = idx; i < D; i += BlockDimX) {
+    for (int i = idx; i < D; i += blockDim.x) {
       if (PaddingFlag) {
         if (id == padding_idx)
           out[i] = static_cast<T>(0);
@@ -53,7 +48,7 @@ __global__ void LookupTableV2(T *output,
         out[i] = tab[i];
       }
     }
-    idy += BlockDimY * GridDimX;
+    idy += blockDim.y * gridDim.x;
   }
 }
@@ -76,19 +71,20 @@ struct LookupTableV2CUDAFunctor {
     size_t D = weight_.dims()[1];
     size_t K = input_.numel();
 
+    const int gridx = 2 * dev_ctx_.GetSMCount();
     dim3 threads(256, 4);
-    dim3 grids(80, 1);
-    const auto *table = weight_.template data<T>();
-    const auto *ids = input_.template data<IdT>();
+    dim3 grids(gridx, 1);
+    const T *table = weight_.template data<T>();
+    const IdT *ids = input_.template data<IdT>();
     auto *output = out_->template mutable_data<T>(dev_ctx_.GetPlace());
     auto stream = dev_ctx_.stream();
 
     if (padding_idx_ == -1) {
-      LookupTableV2<T, IdT, 256, 4, 80, false><<<grids, threads, 0, stream>>>(
+      LookupTableV2<T, IdT, false><<<grids, threads, 0, stream>>>(
          output, table, ids, N, K, D, padding_idx_);
     } else {
-      LookupTableV2<T, IdT, 256, 4, 80, true><<<grids, threads, 0, stream>>>(
+      LookupTableV2<T, IdT, true><<<grids, threads, 0, stream>>>(
          output, table, ids, N, K, D, padding_idx_);
     }
   }
...
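The forward kernel gets the same treatment: only the element and index types plus the PaddingFlag remain as template parameters, and the hard-coded 80-block grid becomes 2 * GetSMCount(). The sketch below is an illustrative standalone version of that lookup pattern with a small host-side check; the names (LookupRef, h_out, and the toy sizes) are hypothetical and not part of Paddle.

// Standalone sketch (not Paddle code): forward embedding lookup whose only
// compile-time parameters are the types and the padding flag.
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>
#include <vector>

template <typename T, typename IdT, bool PaddingFlag>
__global__ void LookupRef(T* output, const T* table, const IdT* ids,
                          int64_t K, int64_t D, int64_t padding_idx) {
  int idx = threadIdx.x;
  int idy = blockIdx.x + threadIdx.y * gridDim.x;
  while (idy < K) {
    int64_t id = static_cast<int64_t>(ids[idy]);
    T* out = output + idy * D;
    const T* tab = table + id * D;
    for (int i = idx; i < D; i += blockDim.x) {
      // rows that hit padding_idx are written as zeros when the flag is set
      out[i] = (PaddingFlag && id == padding_idx) ? static_cast<T>(0) : tab[i];
    }
    idy += blockDim.y * gridDim.x;
  }
}

int main() {
  const int64_t N = 8, D = 4, K = 3;  // toy vocab rows, width, number of ids
  std::vector<float> h_table(N * D);
  for (size_t i = 0; i < h_table.size(); ++i) h_table[i] = static_cast<float>(i);
  std::vector<int> h_ids = {2, 5, 7};

  float *d_table, *d_out;
  int* d_ids;
  cudaMalloc(&d_table, N * D * sizeof(float));
  cudaMalloc(&d_out, K * D * sizeof(float));
  cudaMalloc(&d_ids, K * sizeof(int));
  cudaMemcpy(d_table, h_table.data(), N * D * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_ids, h_ids.data(), K * sizeof(int), cudaMemcpyHostToDevice);

  // grid sized from the SM count, mirroring 2 * dev_ctx_.GetSMCount()
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, 0);
  dim3 threads(256, 4);
  dim3 grids(2 * prop.multiProcessorCount, 1);
  LookupRef<float, int, false><<<grids, threads>>>(d_out, d_table, d_ids, K, D, -1);

  std::vector<float> h_out(K * D);
  cudaMemcpy(h_out.data(), d_out, K * D * sizeof(float), cudaMemcpyDeviceToHost);
  printf("row 0 starts at %.0f (expected %.0f)\n", h_out[0], h_table[2 * D]);
  return 0;
}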