Commit 7ba14d74 authored by phlrain

fix lookup table speed issue

Parent 755ad257
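In summary, this commit drops the compile-time BlockDimX/BlockDimY/GridDimX template parameters from the forward and backward lookup-table kernels in favor of the runtime blockDim/gridDim built-ins, sizes the launch grid from the device's SM count instead of a hard-coded value, switches the CUDA backward path to a vectorized per-block atomic add, and zeroes the gradient table with an async memset rather than an Eigen fill. Both kernels share the 2-D grid-stride pattern shown in the sketch below; the sketch is illustrative, not Paddle code.

// Illustrative CUDA sketch of the 2-D grid-stride loop both kernels use:
// each (blockIdx.x, threadIdx.y) pair walks the K rows with stride
// blockDim.y * gridDim.x, while threadIdx.x strides the D columns.
template <typename T>
__global__ void CopyRows(T* dst, const T* src, int64_t K, int64_t D) {
  int idx = threadIdx.x;
  int idy = blockIdx.x + threadIdx.y * gridDim.x;
  while (idy < K) {
    for (int i = idx; i < D; i += blockDim.x) {
      dst[idy * D + i] = src[idy * D + i];
    }
    idy += blockDim.y * gridDim.x;
  }
}

Because the strides are read from blockDim/gridDim at runtime, the same compiled kernel is correct for any launch configuration.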
@@ -35,7 +35,7 @@ __global__ void InputTypeConvert(const InT* in_ids,
   }
 }
 
-template <typename T, typename IdT, int BlockDimX, int BlockDimY, int GridDimX>
+template <typename T, typename IdT>
 __global__ void LookupTableV2Grad(T* table,
                                   const T* output,
                                   const IdT* ids,
@@ -43,16 +43,20 @@ __global__ void LookupTableV2Grad(T* table,
                                   const int64_t K,
                                   const int64_t D) {
   int idx = threadIdx.x;
-  int idy = blockIdx.x + threadIdx.y * GridDimX;
+  int idy = blockIdx.x + threadIdx.y * gridDim.x;
 
   while (idy < K) {
     auto id = static_cast<int64_t>(ids[idy]);
     const T* out = output + idy * D;
     T* tab = table + id * D;
-    for (int i = idx; i < D; i += BlockDimX) {
+#ifdef PADDLE_WITH_CUDA
+    paddle::platform::VectorizedAtomicAddPerBlock(D, idx, blockDim.x, out, tab);
+#else
+    for (int i = idx; i < D; i += blockDim.x) {
       paddle::platform::CudaAtomicAdd(&tab[i], out[i]);
     }
-    idy += BlockDimY * GridDimX;
+#endif
+    idy += blockDim.y * gridDim.x;
   }
 }
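On the CUDA path the element-wise CudaAtomicAdd loop is replaced by paddle::platform::VectorizedAtomicAddPerBlock. Its implementation is not shown in this diff; semantically it must match the #else fallback above, which the hedged sketch below spells out (the real helper additionally vectorizes the loads, which is where the speedup comes from).

// Semantics-only sketch (names and signature are illustrative, not Paddle's).
// The block's threads stride the row and atomically accumulate every element
// of `src` into `dst`.
template <typename T>
__device__ void AtomicAddPerBlockSketch(int64_t len, int tid, int tid_stride,
                                        const T* src, T* dst) {
  for (int i = tid; i < len; i += tid_stride) {
    atomicAdd(&dst[i], src[i]);  // one atomic add per element
  }
}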
@@ -83,20 +87,22 @@ struct LookupTableV2GradCUDAFunctor {
     int D = weight_grad_->dims()[1];
     int K = input_.numel();
 
-    dim3 threads(128, 8);
-    dim3 grids(8, 1);
     const T* d_output = d_output_t.template data<T>();
     const auto* ids = input_.template data<IdT>();
     T* d_table = d_table_t->mutable_data<T>(dev_ctx_.GetPlace());
 
-    auto t = EigenVector<T>::Flatten(*d_table_t);
-    t.device(*dev_ctx_.eigen_device()) = t.constant(static_cast<T>(0));
+#ifdef PADDLE_WITH_HIP
+    PADDLE_ENFORCE_GPU_SUCCESS(
+        hipMemsetAsync(d_table, 0, N * D * sizeof(T), dev_ctx_.stream()));
+#else
+    PADDLE_ENFORCE_GPU_SUCCESS(
+        cudaMemsetAsync(d_table, 0, N * D * sizeof(T), dev_ctx_.stream()));
+#endif
 
-    LookupTableV2Grad<T,
-                      IdT,
-                      128,
-                      8,
-                      8><<<grids, threads, 0, dev_ctx_.stream()>>>(
+    const int gridx = 2 * dev_ctx_.GetSMCount();
+    dim3 threads(128, 8);
+    dim3 grids(gridx, 1);
+    LookupTableV2Grad<T, IdT><<<grids, threads, 0, dev_ctx_.stream()>>>(
         d_table, d_output, ids, N, K, D);
   }
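Two host-side changes land in this hunk. First, the gradient table is zeroed with a stream-ordered memset instead of an Eigen constant-fill, which avoids launching a separate Eigen kernel; byte-wise zeroing is equivalent because the all-zero bit pattern represents 0.0 for the IEEE-754 float types involved. A minimal standalone sketch of the same call:

// Minimal sketch: zero an N x D float table on a stream (in the Paddle code
// above, PADDLE_ENFORCE_GPU_SUCCESS wraps the error check).
#include <cuda_runtime.h>
void ZeroTable(float* d_table, int64_t N, int64_t D, cudaStream_t stream) {
  cudaError_t err = cudaMemsetAsync(d_table, 0, N * D * sizeof(float), stream);
  if (err != cudaSuccess) { /* handle or propagate the error */ }
}

Second, the fixed grids(8, 1) launch is replaced by a grid sized from the SM count; a standalone sketch of that computation follows the second file's hunks below.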
@@ -23,12 +23,7 @@
 
 namespace phi {
 
-template <typename T,
-          typename IdT,
-          int BlockDimX,
-          int BlockDimY,
-          int GridDimX,
-          bool PaddingFlag>
+template <typename T, typename IdT, bool PaddingFlag>
 __global__ void LookupTableV2(T *output,
                               const T *table,
                               const IdT *ids,
@@ -37,13 +32,13 @@ __global__ void LookupTableV2(T *output,
                               const int64_t D,
                               const int64_t padding_idx) {
   int idx = threadIdx.x;
-  int idy = blockIdx.x + threadIdx.y * GridDimX;
+  int idy = blockIdx.x + threadIdx.y * gridDim.x;
 
   while (idy < K) {
     auto id = static_cast<int64_t>(ids[idy]);
     T *out = output + idy * D;
     const T *tab = table + id * D;
-    for (int i = idx; i < D; i += BlockDimX) {
+    for (int i = idx; i < D; i += blockDim.x) {
       if (PaddingFlag) {
         if (id == padding_idx)
           out[i] = static_cast<T>(0);
@@ -53,7 +48,7 @@ __global__ void LookupTableV2(T *output,
           out[i] = tab[i];
       }
     }
-    idy += BlockDimY * GridDimX;
+    idy += blockDim.y * gridDim.x;
   }
 }
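Note that PaddingFlag stays a template parameter even though the block and grid dimensions became runtime values: a compile-time bool lets the compiler drop the dead padding branch from each instantiation, so the inner loop carries no per-element check. A generic sketch of the idiom, with hypothetical names:

// Sketch of branch elimination via a template bool (illustrative names).
// Each instantiation compiles to a loop with no runtime `if (ZeroOut)`.
template <bool ZeroOut, typename T>
__global__ void FillOrZero(T* out, const T* in, int64_t n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = ZeroOut ? static_cast<T>(0) : in[i];
  }
}

The host-side dispatch in the next hunk picks the true or false instantiation once per launch, exactly as before.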
@@ -76,19 +71,20 @@ struct LookupTableV2CUDAFunctor {
     size_t D = weight_.dims()[1];
     size_t K = input_.numel();
 
+    const int gridx = 2 * dev_ctx_.GetSMCount();
     dim3 threads(256, 4);
-    dim3 grids(80, 1);
+    dim3 grids(gridx, 1);
 
-    const auto *table = weight_.template data<T>();
-    const auto *ids = input_.template data<IdT>();
+    const T *table = weight_.template data<T>();
+    const IdT *ids = input_.template data<IdT>();
     auto *output = out_->template mutable_data<T>(dev_ctx_.GetPlace());
     auto stream = dev_ctx_.stream();
 
     if (padding_idx_ == -1) {
-      LookupTableV2<T, IdT, 256, 4, 80, false><<<grids, threads, 0, stream>>>(
+      LookupTableV2<T, IdT, false><<<grids, threads, 0, stream>>>(
           output, table, ids, N, K, D, padding_idx_);
     } else {
-      LookupTableV2<T, IdT, 256, 4, 80, true><<<grids, threads, 0, stream>>>(
+      LookupTableV2<T, IdT, true><<<grids, threads, 0, stream>>>(
          output, table, ids, N, K, D, padding_idx_);
    }
  }
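Both functors now size gridDim.x as 2 * dev_ctx_.GetSMCount() rather than a hard-coded 8 or 80, so the launch saturates whatever GPU it runs on. A standalone sketch of the same computation with the raw CUDA runtime, assuming GetSMCount() reports the device's multiprocessor count:

// Standalone sketch: derive the grid size from the SM count, as the
// functors above do via dev_ctx_.GetSMCount().
#include <cuda_runtime.h>
dim3 GridsFromSMCount() {
  int device = 0;
  cudaGetDevice(&device);
  int sm_count = 0;
  cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device);
  return dim3(2 * sm_count, 1);  // two blocks per SM
}

Combined with the grid-stride loop, oversubscribing each SM by a factor of two keeps the device busy while the kernels stay correct for any K.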