未验证 提交 0a5d625b 编写于 作者: Z zhangkaihuo 提交者: GitHub

Opt sparse mask_kernel (#44302)

* opt sparse_mask
上级 ae8ca764
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/phi/kernels/sparse/sparse_mask_kernel.h" #include "paddle/phi/kernels/sparse/mask_kernel.h"
#include "paddle/phi/api/ext/dispatch.h" #include "paddle/phi/api/ext/dispatch.h"
#include "paddle/phi/core/ddim.h" #include "paddle/phi/core/ddim.h"
......
...@@ -12,9 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,9 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/phi/kernels/sparse/sparse_mask_kernel.h" #include "paddle/phi/kernels/sparse/mask_kernel.h"
#include <thrust/binary_search.h>
#include "paddle/phi/backends/gpu/gpu_info.h" #include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h"
...@@ -24,6 +22,7 @@ limitations under the License. */ ...@@ -24,6 +22,7 @@ limitations under the License. */
#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/core/visit_type.h" #include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/sparse/flatten_indices.cu.h" #include "paddle/phi/kernels/funcs/sparse/flatten_indices.cu.h"
...@@ -72,11 +71,7 @@ void SparseMaskGPUKernel(const GPUContext& dev_ctx, ...@@ -72,11 +71,7 @@ void SparseMaskGPUKernel(const GPUContext& dev_ctx,
phi::backends::gpu::GpuMemcpyAsync(sparse_offsets.data<int64_t>(), phi::backends::gpu::GpuMemcpyAsync(sparse_offsets.data<int64_t>(),
&h_sparse_offsets[0], &h_sparse_offsets[0],
sizeof(int64_t) * sparse_dim, sizeof(int64_t) * sparse_dim,
#ifdef PADDLE_WITH_HIP gpuMemcpyHostToDevice,
hipMemcpyHostToDevice,
#else
cudaMemcpyHostToDevice,
#endif
dev_ctx.stream()); dev_ctx.stream());
DenseTensor out_indices = phi::EmptyLike<T>(dev_ctx, indices); DenseTensor out_indices = phi::EmptyLike<T>(dev_ctx, indices);
...@@ -93,14 +88,15 @@ void SparseMaskGPUKernel(const GPUContext& dev_ctx, ...@@ -93,14 +88,15 @@ void SparseMaskGPUKernel(const GPUContext& dev_ctx,
auto config = auto config =
phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, non_zero_num * cols, 1); phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, non_zero_num * cols, 1);
MaskKernel<T, IntT><<<config.block_per_grid, config.thread_per_block>>>( MaskKernel<T, IntT>
x_ptr, <<<config.block_per_grid, config.thread_per_block, 0, dev_ctx.stream()>>>(
indices_ptr, x_ptr,
sparse_offsets.data<int64_t>(), indices_ptr,
non_zero_num, sparse_offsets.data<int64_t>(),
cols, non_zero_num,
sparse_dim, cols,
out_values_ptr); sparse_dim,
out_values_ptr);
out->SetMember(out_indices, out_values, dims, true); out->SetMember(out_indices, out_values, dims, true);
} }
...@@ -121,19 +117,31 @@ void SparseMaskKernel(const Context& dev_ctx, ...@@ -121,19 +117,31 @@ void SparseMaskKernel(const Context& dev_ctx,
})); }));
} }
template <typename T, typename IntT> template <typename IntT>
__global__ void SparseMaskCopyKernel(const IntT* x_indexs, __global__ void MaskTable(const IntT* x_indexs, const int n, int* table) {
const IntT* mask_indexs, CUDA_KERNEL_LOOP_TYPE(i, n, int64_t) {
const IntT* bound_out, int index = x_indexs[i];
const T* x_values, table[index] = i == 0 ? -1 : i;
const int64_t n, }
const int64_t stride, }
T* out_values) {
template <typename T, typename IntT, int VecSize>
__global__ void MaskCopy(const IntT* mask_indexs,
const int* table,
const int n,
const int stride,
const T* x_values,
T* out_values) {
using LoadT = phi::AlignedVector<T, VecSize>;
using StoreT = phi::AlignedVector<T, VecSize>;
CUDA_KERNEL_LOOP_TYPE(i, n, int64_t) { CUDA_KERNEL_LOOP_TYPE(i, n, int64_t) {
const IntT j = bound_out[i]; int j = table[mask_indexs[i]];
if (j >= 0 && j < n && mask_indexs[i] == x_indexs[j]) { if (j != 0) {
for (int k = 0; k < stride; k++) { if (j == -1) j = 0;
out_values[i * stride + k] = x_values[j * stride + k]; for (int k = 0; k < stride; k += VecSize) {
LoadT vec_x;
phi::Load<T, VecSize>(x_values + j * stride + k, &vec_x);
phi::Store<T, VecSize>(vec_x, out_values + i * stride + k);
} }
} }
} }
...@@ -179,11 +187,7 @@ void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx, ...@@ -179,11 +187,7 @@ void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx,
phi::backends::gpu::GpuMemcpyAsync(d_sparse_offsets.data<IntT>(), phi::backends::gpu::GpuMemcpyAsync(d_sparse_offsets.data<IntT>(),
sparse_offsets.data(), sparse_offsets.data(),
sizeof(IntT) * sparse_dim, sizeof(IntT) * sparse_dim,
#ifdef PADDLE_WITH_HIP gpuMemcpyHostToDevice,
hipMemcpyHostToDevice,
#else
cudaMemcpyHostToDevice,
#endif
dev_ctx.stream()); dev_ctx.stream());
// 3. flatten x indices and mask indices // 3. flatten x indices and mask indices
...@@ -210,37 +214,54 @@ void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx, ...@@ -210,37 +214,54 @@ void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx,
mask_indexs.numel(), mask_indexs.numel(),
sparse_dim, sparse_dim,
mask_indexs_ptr); mask_indexs_ptr);
// 4. call thrust::lower_bound
#ifdef PADDLE_WITH_HIP
thrust::lower_bound(thrust::hip::par.on(dev_ctx.stream()),
#else
thrust::lower_bound(thrust::cuda::par.on(dev_ctx.stream()),
#endif
x_indexs_ptr,
x_indexs_ptr + x_indexs.numel(),
mask_indexs_ptr,
mask_indexs_ptr + mask_indexs.numel(),
bound_out_ptr);
// 5. copy value to out int table_size = 1;
auto x_dims = x.dims();
for (int i = 0; i < x_dims.size() - 1; i++) {
table_size *= x_dims[i];
}
DenseTensor table = phi::Empty<int>(dev_ctx, {table_size});
phi::backends::gpu::GpuMemsetAsync(
table.data<int>(), 0, table_size * sizeof(int), dev_ctx.stream());
const int64_t stride =
x.dims().size() == sparse_dim ? 1 : x.non_zero_elements().dims()[1];
*out = phi::EmptyLike<T>(dev_ctx, x.non_zero_elements()); *out = phi::EmptyLike<T>(dev_ctx, x.non_zero_elements());
phi::funcs::SetConstant<GPUContext, T> set_zero; phi::funcs::SetConstant<GPUContext, T> set_zero;
set_zero(dev_ctx, out, static_cast<T>(0)); set_zero(dev_ctx, out, static_cast<T>(0));
T* out_ptr = out->data<T>(); T* out_ptr = out->data<T>();
config =
const int64_t stride = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, x_indexs.numel(), 1);
x.dims().size() == sparse_dim ? 1 : x.non_zero_elements().dims()[1]; MaskTable<<<config.block_per_grid,
config.thread_per_block,
SparseMaskCopyKernel<<<config.block_per_grid, 0,
config.thread_per_block, dev_ctx.stream()>>>(
0, x_indexs_ptr, x_indexs.numel(), table.data<int>());
dev_ctx.stream()>>>(x_indexs_ptr, config =
mask_indexs_ptr, phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, mask_indexs.numel(), 1);
bound_out_ptr, const int VecBytes = 16;
x.non_zero_elements().data<T>(), const int VecSize = VecBytes / sizeof(T);
mask_indexs.numel(), if (stride % VecSize == 0) {
stride, MaskCopy<T, IntT, VecSize>
out_ptr); <<<config.block_per_grid,
config.thread_per_block,
0,
dev_ctx.stream()>>>(mask_indexs_ptr,
table.data<int>(),
mask_indexs.numel(),
stride,
x.non_zero_elements().data<T>(),
out_ptr);
} else {
MaskCopy<T, IntT, 1><<<config.block_per_grid,
config.thread_per_block,
0,
dev_ctx.stream()>>>(mask_indexs_ptr,
table.data<int>(),
mask_indexs.numel(),
stride,
x.non_zero_elements().data<T>(),
out_ptr);
}
} }
template <typename T, typename Context> template <typename T, typename Context>
...@@ -257,7 +278,7 @@ void SparseMaskHelperKernel(const Context& dev_ctx, ...@@ -257,7 +278,7 @@ void SparseMaskHelperKernel(const Context& dev_ctx,
} // namespace sparse } // namespace sparse
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL(sparse_mask, PD_REGISTER_KERNEL(mask,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::sparse::SparseMaskKernel, phi::sparse::SparseMaskKernel,
...@@ -272,7 +293,7 @@ PD_REGISTER_KERNEL(sparse_mask, ...@@ -272,7 +293,7 @@ PD_REGISTER_KERNEL(sparse_mask,
kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_COO); kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_COO);
} }
PD_REGISTER_KERNEL(sparse_mask_helper, PD_REGISTER_KERNEL(mask_helper,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::sparse::SparseMaskHelperKernel, phi::sparse::SparseMaskHelperKernel,
......
...@@ -15,7 +15,6 @@ limitations under the License. */ ...@@ -15,7 +15,6 @@ limitations under the License. */
#include "paddle/phi/kernels/sparse/sparse_utils_grad_kernel.h" #include "paddle/phi/kernels/sparse/sparse_utils_grad_kernel.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/sparse/sparse_mask_kernel.h"
namespace phi { namespace phi {
namespace sparse { namespace sparse {
......
...@@ -16,7 +16,7 @@ limitations under the License. */ ...@@ -16,7 +16,7 @@ limitations under the License. */
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/sparse_coo_tensor.h" #include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/kernels/sparse/sparse_mask_kernel.h" #include "paddle/phi/kernels/sparse/mask_kernel.h"
namespace phi { namespace phi {
namespace sparse { namespace sparse {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册