未验证 提交 cf95db58 编写于 作者: Z zhangkaihuo 提交者: GitHub

[sparse]Fix mask_kernel name (#50713)

上级 500a8bc2
......@@ -28,10 +28,10 @@ namespace phi {
namespace sparse {
template <typename T, typename IntT>
void SparseMaskCPUKernel(const CPUContext& dev_ctx,
const DenseTensor& x,
const SparseCooTensor& mask,
SparseCooTensor* out) {
void MaskCooCPUKernel(const CPUContext& dev_ctx,
const DenseTensor& x,
const SparseCooTensor& mask,
SparseCooTensor* out) {
const DDim& dims = x.dims();
PADDLE_ENFORCE_EQ(
x.dims(),
......@@ -75,21 +75,21 @@ void SparseMaskCPUKernel(const CPUContext& dev_ctx,
* x and mask must have the same shape.
**/
template <typename T, typename Context>
void SparseMaskKernel(const Context& dev_ctx,
const DenseTensor& x,
const SparseCooTensor& mask,
SparseCooTensor* out) {
void MaskCooKernel(const Context& dev_ctx,
const DenseTensor& x,
const SparseCooTensor& mask,
SparseCooTensor* out) {
PD_VISIT_BASE_INTEGRAL_TYPES(
mask.indices().dtype(), "SparseMaskCPUKernel", ([&] {
SparseMaskCPUKernel<T, data_t>(dev_ctx, x, mask, out);
mask.indices().dtype(), "MaskCooCPUKernel", ([&] {
MaskCooCPUKernel<T, data_t>(dev_ctx, x, mask, out);
}));
}
template <typename T, typename IntT>
void SparseMaskHelperCPUKernel(const CPUContext& dev_ctx,
const SparseCooTensor& x,
const DenseTensor& mask_indices,
DenseTensor* out) {
void MaskHelperCooCPUKernel(const CPUContext& dev_ctx,
const SparseCooTensor& x,
const DenseTensor& mask_indices,
DenseTensor* out) {
PADDLE_ENFORCE_EQ(
mask_indices.dims().size(),
2,
......@@ -142,23 +142,23 @@ void SparseMaskHelperCPUKernel(const CPUContext& dev_ctx,
* @brief filter values from x.values() using mask_indices
*/
template <typename T, typename Context>
void SparseMaskHelperKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const DenseTensor& mask_indices,
DenseTensor* out) {
void MaskHelperCooKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const DenseTensor& mask_indices,
DenseTensor* out) {
PD_VISIT_BASE_INTEGRAL_TYPES(
x.indices().dtype(), "SparseMaskHelperCPUKernel", ([&] {
SparseMaskHelperCPUKernel<T, data_t>(dev_ctx, x, mask_indices, out);
x.indices().dtype(), "MaskHelperCooCPUKernel", ([&] {
MaskHelperCooCPUKernel<T, data_t>(dev_ctx, x, mask_indices, out);
}));
}
} // namespace sparse
} // namespace phi
PD_REGISTER_KERNEL(sparse_mask,
PD_REGISTER_KERNEL(mask_coo,
CPU,
ALL_LAYOUT,
phi::sparse::SparseMaskKernel,
phi::sparse::MaskCooKernel,
float,
double,
uint8_t,
......@@ -169,10 +169,10 @@ PD_REGISTER_KERNEL(sparse_mask,
kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_COO);
}
PD_REGISTER_KERNEL(sparse_mask_helper,
PD_REGISTER_KERNEL(mask_helper_coo,
CPU,
ALL_LAYOUT,
phi::sparse::SparseMaskHelperKernel,
phi::sparse::MaskHelperCooKernel,
float,
double,
uint8_t,
......
......@@ -50,10 +50,10 @@ __global__ void MaskKernel(const T* x_ptr,
}
template <typename T, typename IntT>
void SparseMaskGPUKernel(const GPUContext& dev_ctx,
const DenseTensor& x,
const SparseCooTensor& mask,
SparseCooTensor* out) {
void MaskCooGPUKernel(const GPUContext& dev_ctx,
const DenseTensor& x,
const SparseCooTensor& mask,
SparseCooTensor* out) {
const DDim& dims = x.dims();
PADDLE_ENFORCE_EQ(
x.dims(),
......@@ -108,13 +108,13 @@ void SparseMaskGPUKernel(const GPUContext& dev_ctx,
* x and mask must have the same shape.
**/
template <typename T, typename Context>
void SparseMaskKernel(const Context& dev_ctx,
const DenseTensor& x,
const SparseCooTensor& mask,
SparseCooTensor* out) {
void MaskCooKernel(const Context& dev_ctx,
const DenseTensor& x,
const SparseCooTensor& mask,
SparseCooTensor* out) {
PD_VISIT_BASE_INTEGRAL_TYPES(
mask.indices().dtype(), "SparseMaskGPUKernel", ([&] {
SparseMaskGPUKernel<T, data_t>(dev_ctx, x, mask, out);
mask.indices().dtype(), "MaskCooGPUKernel", ([&] {
MaskCooGPUKernel<T, data_t>(dev_ctx, x, mask, out);
}));
}
......@@ -155,10 +155,10 @@ __global__ void MaskCopy(const IntT* mask_indexs,
}
template <typename T, typename IntT>
void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx,
const SparseCooTensor& x,
const DenseTensor& mask_indices,
DenseTensor* out) {
void MaskHelperCooGPUKernel(const GPUContext& dev_ctx,
const SparseCooTensor& x,
const DenseTensor& mask_indices,
DenseTensor* out) {
PADDLE_ENFORCE_EQ(
mask_indices.dims().size(),
2,
......@@ -279,23 +279,23 @@ void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx,
}
template <typename T, typename Context>
void SparseMaskHelperKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const DenseTensor& mask_indices,
DenseTensor* out) {
void MaskHelperCooKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const DenseTensor& mask_indices,
DenseTensor* out) {
PD_VISIT_BASE_INTEGRAL_TYPES(
x.indices().dtype(), "SparseMaskHelperGPUKernel", ([&] {
SparseMaskHelperGPUKernel<T, data_t>(dev_ctx, x, mask_indices, out);
x.indices().dtype(), "MaskHelperCooGPUKernel", ([&] {
MaskHelperCooGPUKernel<T, data_t>(dev_ctx, x, mask_indices, out);
}));
}
} // namespace sparse
} // namespace phi
PD_REGISTER_KERNEL(mask,
PD_REGISTER_KERNEL(mask_coo,
GPU,
ALL_LAYOUT,
phi::sparse::SparseMaskKernel,
phi::sparse::MaskCooKernel,
float,
double,
phi::dtype::float16,
......@@ -307,10 +307,10 @@ PD_REGISTER_KERNEL(mask,
kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_COO);
}
PD_REGISTER_KERNEL(mask_helper,
PD_REGISTER_KERNEL(mask_helper_coo,
GPU,
ALL_LAYOUT,
phi::sparse::SparseMaskHelperKernel,
phi::sparse::MaskHelperCooKernel,
float,
double,
phi::dtype::float16,
......
......@@ -21,16 +21,16 @@ namespace phi {
namespace sparse {
template <typename T, typename Context>
void SparseMaskKernel(const Context& dev_ctx,
const DenseTensor& x,
const SparseCooTensor& mask,
SparseCooTensor* out);
void MaskCooKernel(const Context& dev_ctx,
const DenseTensor& x,
const SparseCooTensor& mask,
SparseCooTensor* out);
template <typename T, typename Context>
void SparseMaskHelperKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const DenseTensor& mask_indices,
DenseTensor* out);
void MaskHelperCooKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const DenseTensor& mask_indices,
DenseTensor* out);
} // namespace sparse
} // namespace phi
......@@ -32,7 +32,7 @@ void CooToDenseGradKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const DenseTensor& out_grad,
SparseCooTensor* x_grad) {
SparseMaskKernel<T, Context>(dev_ctx, out_grad, x, x_grad);
MaskCooKernel<T, Context>(dev_ctx, out_grad, x, x_grad);
}
} // namespace sparse
......
......@@ -38,7 +38,7 @@ void SparseCooTensorGradKernel(const Context& dev_ctx,
const DenseTensor& indices,
const SparseCooTensor& out_grad,
DenseTensor* values_grad) {
SparseMaskHelperKernel<T, Context>(dev_ctx, out_grad, indices, values_grad);
MaskHelperCooKernel<T, Context>(dev_ctx, out_grad, indices, values_grad);
}
} // namespace sparse
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册