未验证 提交 cf95db58 编写于 作者: Z zhangkaihuo 提交者: GitHub

[sparse]Fix mask_kernel name (#50713)

上级 500a8bc2
...@@ -28,7 +28,7 @@ namespace phi { ...@@ -28,7 +28,7 @@ namespace phi {
namespace sparse { namespace sparse {
template <typename T, typename IntT> template <typename T, typename IntT>
void SparseMaskCPUKernel(const CPUContext& dev_ctx, void MaskCooCPUKernel(const CPUContext& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const SparseCooTensor& mask, const SparseCooTensor& mask,
SparseCooTensor* out) { SparseCooTensor* out) {
...@@ -75,18 +75,18 @@ void SparseMaskCPUKernel(const CPUContext& dev_ctx, ...@@ -75,18 +75,18 @@ void SparseMaskCPUKernel(const CPUContext& dev_ctx,
* x and mask must have the same shape. * x and mask must have the same shape.
**/ **/
template <typename T, typename Context> template <typename T, typename Context>
void SparseMaskKernel(const Context& dev_ctx, void MaskCooKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const SparseCooTensor& mask, const SparseCooTensor& mask,
SparseCooTensor* out) { SparseCooTensor* out) {
PD_VISIT_BASE_INTEGRAL_TYPES( PD_VISIT_BASE_INTEGRAL_TYPES(
mask.indices().dtype(), "SparseMaskCPUKernel", ([&] { mask.indices().dtype(), "MaskCooCPUKernel", ([&] {
SparseMaskCPUKernel<T, data_t>(dev_ctx, x, mask, out); MaskCooCPUKernel<T, data_t>(dev_ctx, x, mask, out);
})); }));
} }
template <typename T, typename IntT> template <typename T, typename IntT>
void SparseMaskHelperCPUKernel(const CPUContext& dev_ctx, void MaskHelperCooCPUKernel(const CPUContext& dev_ctx,
const SparseCooTensor& x, const SparseCooTensor& x,
const DenseTensor& mask_indices, const DenseTensor& mask_indices,
DenseTensor* out) { DenseTensor* out) {
...@@ -142,23 +142,23 @@ void SparseMaskHelperCPUKernel(const CPUContext& dev_ctx, ...@@ -142,23 +142,23 @@ void SparseMaskHelperCPUKernel(const CPUContext& dev_ctx,
* @brief filter values from x.values() using mask_indices * @brief filter values from x.values() using mask_indices
*/ */
template <typename T, typename Context> template <typename T, typename Context>
void SparseMaskHelperKernel(const Context& dev_ctx, void MaskHelperCooKernel(const Context& dev_ctx,
const SparseCooTensor& x, const SparseCooTensor& x,
const DenseTensor& mask_indices, const DenseTensor& mask_indices,
DenseTensor* out) { DenseTensor* out) {
PD_VISIT_BASE_INTEGRAL_TYPES( PD_VISIT_BASE_INTEGRAL_TYPES(
x.indices().dtype(), "SparseMaskHelperCPUKernel", ([&] { x.indices().dtype(), "MaskHelperCooCPUKernel", ([&] {
SparseMaskHelperCPUKernel<T, data_t>(dev_ctx, x, mask_indices, out); MaskHelperCooCPUKernel<T, data_t>(dev_ctx, x, mask_indices, out);
})); }));
} }
} // namespace sparse } // namespace sparse
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL(sparse_mask, PD_REGISTER_KERNEL(mask_coo,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::sparse::SparseMaskKernel, phi::sparse::MaskCooKernel,
float, float,
double, double,
uint8_t, uint8_t,
...@@ -169,10 +169,10 @@ PD_REGISTER_KERNEL(sparse_mask, ...@@ -169,10 +169,10 @@ PD_REGISTER_KERNEL(sparse_mask,
kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_COO); kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_COO);
} }
PD_REGISTER_KERNEL(sparse_mask_helper, PD_REGISTER_KERNEL(mask_helper_coo,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::sparse::SparseMaskHelperKernel, phi::sparse::MaskHelperCooKernel,
float, float,
double, double,
uint8_t, uint8_t,
......
...@@ -50,7 +50,7 @@ __global__ void MaskKernel(const T* x_ptr, ...@@ -50,7 +50,7 @@ __global__ void MaskKernel(const T* x_ptr,
} }
template <typename T, typename IntT> template <typename T, typename IntT>
void SparseMaskGPUKernel(const GPUContext& dev_ctx, void MaskCooGPUKernel(const GPUContext& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const SparseCooTensor& mask, const SparseCooTensor& mask,
SparseCooTensor* out) { SparseCooTensor* out) {
...@@ -108,13 +108,13 @@ void SparseMaskGPUKernel(const GPUContext& dev_ctx, ...@@ -108,13 +108,13 @@ void SparseMaskGPUKernel(const GPUContext& dev_ctx,
* x and mask must have the same shape. * x and mask must have the same shape.
**/ **/
template <typename T, typename Context> template <typename T, typename Context>
void SparseMaskKernel(const Context& dev_ctx, void MaskCooKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const SparseCooTensor& mask, const SparseCooTensor& mask,
SparseCooTensor* out) { SparseCooTensor* out) {
PD_VISIT_BASE_INTEGRAL_TYPES( PD_VISIT_BASE_INTEGRAL_TYPES(
mask.indices().dtype(), "SparseMaskGPUKernel", ([&] { mask.indices().dtype(), "MaskCooGPUKernel", ([&] {
SparseMaskGPUKernel<T, data_t>(dev_ctx, x, mask, out); MaskCooGPUKernel<T, data_t>(dev_ctx, x, mask, out);
})); }));
} }
...@@ -155,7 +155,7 @@ __global__ void MaskCopy(const IntT* mask_indexs, ...@@ -155,7 +155,7 @@ __global__ void MaskCopy(const IntT* mask_indexs,
} }
template <typename T, typename IntT> template <typename T, typename IntT>
void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx, void MaskHelperCooGPUKernel(const GPUContext& dev_ctx,
const SparseCooTensor& x, const SparseCooTensor& x,
const DenseTensor& mask_indices, const DenseTensor& mask_indices,
DenseTensor* out) { DenseTensor* out) {
...@@ -279,23 +279,23 @@ void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx, ...@@ -279,23 +279,23 @@ void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx,
} }
template <typename T, typename Context> template <typename T, typename Context>
void SparseMaskHelperKernel(const Context& dev_ctx, void MaskHelperCooKernel(const Context& dev_ctx,
const SparseCooTensor& x, const SparseCooTensor& x,
const DenseTensor& mask_indices, const DenseTensor& mask_indices,
DenseTensor* out) { DenseTensor* out) {
PD_VISIT_BASE_INTEGRAL_TYPES( PD_VISIT_BASE_INTEGRAL_TYPES(
x.indices().dtype(), "SparseMaskHelperGPUKernel", ([&] { x.indices().dtype(), "MaskHelperCooGPUKernel", ([&] {
SparseMaskHelperGPUKernel<T, data_t>(dev_ctx, x, mask_indices, out); MaskHelperCooGPUKernel<T, data_t>(dev_ctx, x, mask_indices, out);
})); }));
} }
} // namespace sparse } // namespace sparse
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL(mask, PD_REGISTER_KERNEL(mask_coo,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::sparse::SparseMaskKernel, phi::sparse::MaskCooKernel,
float, float,
double, double,
phi::dtype::float16, phi::dtype::float16,
...@@ -307,10 +307,10 @@ PD_REGISTER_KERNEL(mask, ...@@ -307,10 +307,10 @@ PD_REGISTER_KERNEL(mask,
kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_COO); kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_COO);
} }
PD_REGISTER_KERNEL(mask_helper, PD_REGISTER_KERNEL(mask_helper_coo,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::sparse::SparseMaskHelperKernel, phi::sparse::MaskHelperCooKernel,
float, float,
double, double,
phi::dtype::float16, phi::dtype::float16,
......
...@@ -21,13 +21,13 @@ namespace phi { ...@@ -21,13 +21,13 @@ namespace phi {
namespace sparse { namespace sparse {
template <typename T, typename Context> template <typename T, typename Context>
void SparseMaskKernel(const Context& dev_ctx, void MaskCooKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const SparseCooTensor& mask, const SparseCooTensor& mask,
SparseCooTensor* out); SparseCooTensor* out);
template <typename T, typename Context> template <typename T, typename Context>
void SparseMaskHelperKernel(const Context& dev_ctx, void MaskHelperCooKernel(const Context& dev_ctx,
const SparseCooTensor& x, const SparseCooTensor& x,
const DenseTensor& mask_indices, const DenseTensor& mask_indices,
DenseTensor* out); DenseTensor* out);
......
...@@ -32,7 +32,7 @@ void CooToDenseGradKernel(const Context& dev_ctx, ...@@ -32,7 +32,7 @@ void CooToDenseGradKernel(const Context& dev_ctx,
const SparseCooTensor& x, const SparseCooTensor& x,
const DenseTensor& out_grad, const DenseTensor& out_grad,
SparseCooTensor* x_grad) { SparseCooTensor* x_grad) {
SparseMaskKernel<T, Context>(dev_ctx, out_grad, x, x_grad); MaskCooKernel<T, Context>(dev_ctx, out_grad, x, x_grad);
} }
} // namespace sparse } // namespace sparse
......
...@@ -38,7 +38,7 @@ void SparseCooTensorGradKernel(const Context& dev_ctx, ...@@ -38,7 +38,7 @@ void SparseCooTensorGradKernel(const Context& dev_ctx,
const DenseTensor& indices, const DenseTensor& indices,
const SparseCooTensor& out_grad, const SparseCooTensor& out_grad,
DenseTensor* values_grad) { DenseTensor* values_grad) {
SparseMaskHelperKernel<T, Context>(dev_ctx, out_grad, indices, values_grad); MaskHelperCooKernel<T, Context>(dev_ctx, out_grad, indices, values_grad);
} }
} // namespace sparse } // namespace sparse
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册