diff --git a/paddle/phi/api/yaml/sparse_api.yaml b/paddle/phi/api/yaml/sparse_api.yaml
index e816824b82f722a2827715c18041a2c653ebe9bf..fcde699e71dfd54caeadaa33f4830ba275916f5e 100644
--- a/paddle/phi/api/yaml/sparse_api.yaml
+++ b/paddle/phi/api/yaml/sparse_api.yaml
@@ -316,10 +316,10 @@
   args : (Tensor x, int[] kernel_sizes, int[] paddings, int[] dilations, int[] strides)
   output : Tensor(out), Tensor(rulebook)
   kernel :
-    func : sparse_maxpool{sparse_coo -> sparse_coo, dense}
+    func : maxpool_coo{sparse_coo -> sparse_coo, dense}
     layout : x
   intermediate : rulebook
-  backward : sparse_maxpool_grad
+  backward : maxpool_grad
 
 - api: mv
   args : (Tensor x, Tensor vec)
diff --git a/paddle/phi/api/yaml/sparse_bw_api.yaml b/paddle/phi/api/yaml/sparse_bw_api.yaml
index 68e6020ac3626f7d7c70cd7d4a7d2760d9dce4c9..ab0070840f7fdbf3ed9c1f16528fcfe823d50396 100644
--- a/paddle/phi/api/yaml/sparse_bw_api.yaml
+++ b/paddle/phi/api/yaml/sparse_bw_api.yaml
@@ -137,6 +137,13 @@
            matmul_coo_dense_grad {sparse_coo, dense, dense -> sparse_coo, dense},
            matmul_coo_coo_grad {sparse_coo, sparse_coo, sparse_coo -> sparse_coo, sparse_coo}
 
+- backward_api : maxpool_grad
+  forward : maxpool(Tensor x, int[] kernel_sizes, int[] paddings, int[] dilations, int[] strides) -> Tensor(out), Tensor(rulebook)
+  args : (Tensor x, Tensor rulebook, Tensor out, Tensor out_grad, int[] kernel_sizes)
+  output : Tensor(x_grad)
+  kernel :
+    func : maxpool_coo_grad {sparse_coo, dense, sparse_coo, sparse_coo -> sparse_coo}
+
 - backward_api : multiply_grad
   forward : multiply(Tensor x, Tensor y) -> Tensor(out)
   args : (Tensor x, Tensor y, Tensor out_grad)
@@ -198,13 +205,6 @@
   kernel :
     func : softmax_csr_grad{sparse_csr, sparse_csr -> sparse_csr}
 
-- backward_api : sparse_maxpool_grad
-  forward : sparse_maxpool(Tensor x, int[] kernel_sizes, int[] paddings, int[] dilations, int[] strides) -> Tensor(out), Tensor(rulebook)
-  args : (Tensor x, Tensor rulebook, Tensor out, Tensor out_grad, int[] kernel_sizes)
-  output : Tensor(x_grad)
-  kernel :
-    func : sparse_maxpool_grad {sparse_coo, dense, sparse_coo, sparse_coo -> sparse_coo}
-
 - backward_api : sqrt_grad
   forward : sqrt(Tensor x) -> Tensor(out)
   args : (Tensor out, Tensor out_grad)
@@ -255,7 +255,7 @@
 - backward_api: fused_attention_grad
   forward : fused_attention_csr(Tensor query, Tensor key, Tensor value, Tensor sparse_mask, Tensor key_padding_mask, Tensor attn_mask) -> Tensor(out), Tensor(softmax)
   args: (Tensor query, Tensor key, Tensor value, Tensor softmax, Tensor out_grad)
-  output : Tensor(query_grad), Tensor(key_grad), Tensor(value_grad) 
+  output : Tensor(query_grad), Tensor(key_grad), Tensor(value_grad)
   kernel :
     func : fused_attention_csr_grad{dense, dense, dense, sparse_csr, dense -> dense, dense, dense}
     layout : softmax
diff --git a/paddle/phi/kernels/sparse/cpu/sparse_pool_grad_kernel.cc b/paddle/phi/kernels/sparse/cpu/pool_grad_kernel.cc
similarity index 82%
rename from paddle/phi/kernels/sparse/cpu/sparse_pool_grad_kernel.cc
rename to paddle/phi/kernels/sparse/cpu/pool_grad_kernel.cc
index 64c843c07a6ef4a6198fe6c38af66faa2bbb48bf..dfdd00433680a2425d4381545a92b7daeaf7f8dc 100644
--- a/paddle/phi/kernels/sparse/cpu/sparse_pool_grad_kernel.cc
+++ b/paddle/phi/kernels/sparse/cpu/pool_grad_kernel.cc
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/phi/kernels/sparse/sparse_pool_grad_kernel.h"
+#include "paddle/phi/kernels/sparse/pool_grad_kernel.h"
 
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/tensor_utils.h"
@@ -25,13 +25,13 @@ namespace phi {
 namespace sparse {
 
 template <typename T, typename IntT = int>
-void MaxPoolGradCPUKernel(const CPUContext& dev_ctx,
-                          const SparseCooTensor& x,
-                          const DenseTensor& rulebook,
-                          const SparseCooTensor& out,
-                          const SparseCooTensor& out_grad,
-                          const std::vector<int>& kernel_sizes,
-                          SparseCooTensor* x_grad) {
+void MaxPoolCooGradCPUKernel(const CPUContext& dev_ctx,
+                             const SparseCooTensor& x,
+                             const DenseTensor& rulebook,
+                             const SparseCooTensor& out,
+                             const SparseCooTensor& out_grad,
+                             const std::vector<int>& kernel_sizes,
+                             SparseCooTensor* x_grad) {
   int kernel_size = kernel_sizes[0] * kernel_sizes[1] * kernel_sizes[2];
   const int channels = x.dims()[4];
   int rulebook_len = rulebook.dims()[1];
@@ -75,16 +75,16 @@ void MaxPoolGradCPUKernel(const CPUContext& dev_ctx,
 }
 
 template <typename T, typename Context>
-void MaxPoolGradKernel(const Context& dev_ctx,
-                       const SparseCooTensor& x,
-                       const DenseTensor& rulebook,
-                       const SparseCooTensor& out,
-                       const SparseCooTensor& out_grad,
-                       const std::vector<int>& kernel_sizes,
-                       SparseCooTensor* x_grad) {
+void MaxPoolCooGradKernel(const Context& dev_ctx,
+                          const SparseCooTensor& x,
+                          const DenseTensor& rulebook,
+                          const SparseCooTensor& out,
+                          const SparseCooTensor& out_grad,
+                          const std::vector<int>& kernel_sizes,
+                          SparseCooTensor* x_grad) {
   PD_VISIT_INTEGRAL_TYPES(
-      x.non_zero_indices().dtype(), "MaxPoolGradCPUKernel", ([&] {
-        MaxPoolGradCPUKernel<T, data_t>(
+      x.non_zero_indices().dtype(), "MaxPoolCooGradCPUKernel", ([&] {
+        MaxPoolCooGradCPUKernel<T, data_t>(
             dev_ctx, x, rulebook, out, out_grad, kernel_sizes, x_grad);
       }));
 }
@@ -92,10 +92,10 @@ void MaxPoolGradKernel(const Context& dev_ctx,
 }  // namespace sparse
 }  // namespace phi
 
-PD_REGISTER_KERNEL(sparse_maxpool_grad,
+PD_REGISTER_KERNEL(maxpool_coo_grad,
                    CPU,
                    ALL_LAYOUT,
-                   phi::sparse::MaxPoolGradKernel,
+                   phi::sparse::MaxPoolCooGradKernel,
                    float,
                    double) {
   kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
diff --git a/paddle/phi/kernels/sparse/cpu/sparse_pool_kernel.cc b/paddle/phi/kernels/sparse/cpu/pool_kernel.cc
similarity index 79%
rename from paddle/phi/kernels/sparse/cpu/sparse_pool_kernel.cc
rename to paddle/phi/kernels/sparse/cpu/pool_kernel.cc
index 7655913374dbd74a598a9e54a0a2e8da37293af1..ae32b6cc1d695a78db190f3b9ac0df5a50f8c586 100644
--- a/paddle/phi/kernels/sparse/cpu/sparse_pool_kernel.cc
+++ b/paddle/phi/kernels/sparse/cpu/pool_kernel.cc
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/phi/kernels/sparse/sparse_pool_kernel.h"
+#include "paddle/phi/kernels/sparse/pool_kernel.h"
 
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/tensor_meta.h"
@@ -30,14 +30,14 @@ namespace sparse {
  * out: (N, D, H, W, OC)
  **/
 template <typename T, typename IntT = int>
-void MaxPoolCPUKernel(const CPUContext& dev_ctx,
-                      const SparseCooTensor& x,
-                      const std::vector<int>& kernel_sizes,
-                      const std::vector<int>& paddings,
-                      const std::vector<int>& dilations,
-                      const std::vector<int>& strides,
-                      SparseCooTensor* out,
-                      DenseTensor* rulebook) {
+void MaxPoolCooCPUKernel(const CPUContext& dev_ctx,
+                         const SparseCooTensor& x,
+                         const std::vector<int>& kernel_sizes,
+                         const std::vector<int>& paddings,
+                         const std::vector<int>& dilations,
+                         const std::vector<int>& strides,
+                         SparseCooTensor* out,
+                         DenseTensor* rulebook) {
   const auto& x_dims = x.dims();
   int kernel_size = kernel_sizes[0] * kernel_sizes[1] * kernel_sizes[2];
   const std::vector<int>& real_kernel_sizes =
@@ -98,34 +98,34 @@ void MaxPoolCPUKernel(const CPUContext& dev_ctx,
 }
 
 template <typename T, typename Context>
-void MaxPoolKernel(const Context& dev_ctx,
-                   const SparseCooTensor& x,
-                   const std::vector<int>& kernel_sizes,
-                   const std::vector<int>& paddings,
-                   const std::vector<int>& dilations,
-                   const std::vector<int>& strides,
-                   SparseCooTensor* out,
-                   DenseTensor* rulebook) {
+void MaxPoolCooKernel(const Context& dev_ctx,
+                      const SparseCooTensor& x,
+                      const std::vector<int>& kernel_sizes,
+                      const std::vector<int>& paddings,
+                      const std::vector<int>& dilations,
+                      const std::vector<int>& strides,
+                      SparseCooTensor* out,
+                      DenseTensor* rulebook) {
   PD_VISIT_INTEGRAL_TYPES(
-      x.non_zero_indices().dtype(), "MaxPoolCPUKernel", ([&] {
-        MaxPoolCPUKernel<T, data_t>(dev_ctx,
-                                    x,
-                                    kernel_sizes,
-                                    paddings,
-                                    dilations,
-                                    strides,
-                                    out,
-                                    rulebook);
+      x.non_zero_indices().dtype(), "MaxPoolCooCPUKernel", ([&] {
+        MaxPoolCooCPUKernel<T, data_t>(dev_ctx,
+                                       x,
+                                       kernel_sizes,
+                                       paddings,
+                                       dilations,
+                                       strides,
+                                       out,
+                                       rulebook);
       }));
 }
 
 }  // namespace sparse
 }  // namespace phi
 
-PD_REGISTER_KERNEL(sparse_maxpool,
+PD_REGISTER_KERNEL(maxpool_coo,
                    CPU,
                    ALL_LAYOUT,
-                   phi::sparse::MaxPoolKernel,
+                   phi::sparse::MaxPoolCooKernel,
                    float,
                    double) {
   kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
diff --git a/paddle/phi/kernels/sparse/gpu/sparse_pool_grad_kernel.cu b/paddle/phi/kernels/sparse/gpu/pool_grad_kernel.cu
similarity index 88%
rename from paddle/phi/kernels/sparse/gpu/sparse_pool_grad_kernel.cu
rename to paddle/phi/kernels/sparse/gpu/pool_grad_kernel.cu
index 5fe6e68c1e83f978ada43a6db697006cdd5bc6b9..724072443a9ed5aa8633444622dc828c4ce30b3d 100644
--- a/paddle/phi/kernels/sparse/gpu/sparse_pool_grad_kernel.cu
+++ b/paddle/phi/kernels/sparse/gpu/pool_grad_kernel.cu
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/phi/kernels/sparse/sparse_pool_grad_kernel.h"
+#include "paddle/phi/kernels/sparse/pool_grad_kernel.h"
 
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/backends/gpu/gpu_info.h"
@@ -52,13 +52,13 @@ __global__ void MaxPoolGradCudaKernel(const T* in_features_ptr,
 }
 
 template <typename T, typename IntT = int>
-void MaxPoolGradGPUKernel(const GPUContext& dev_ctx,
-                          const SparseCooTensor& x,
-                          const DenseTensor& rulebook,
-                          const SparseCooTensor& out,
-                          const SparseCooTensor& out_grad,
-                          const std::vector<int>& kernel_sizes,
-                          SparseCooTensor* x_grad) {
+void MaxPoolCooGradGPUKernel(const GPUContext& dev_ctx,
+                             const SparseCooTensor& x,
+                             const DenseTensor& rulebook,
+                             const SparseCooTensor& out,
+                             const SparseCooTensor& out_grad,
+                             const std::vector<int>& kernel_sizes,
+                             SparseCooTensor* x_grad) {
   int kernel_size = kernel_sizes[0] * kernel_sizes[1] * kernel_sizes[2];
   const int in_channels = x.dims()[4];
   int rulebook_len = rulebook.dims()[1];
@@ -121,16 +121,16 @@ void MaxPoolGradGPUKernel(const GPUContext& dev_ctx,
 }
 
 template <typename T, typename Context>
-void MaxPoolGradKernel(const Context& dev_ctx,
-                       const SparseCooTensor& x,
-                       const DenseTensor& rulebook,
-                       const SparseCooTensor& out,
-                       const SparseCooTensor& out_grad,
-                       const std::vector<int>& kernel_sizes,
-                       SparseCooTensor* x_grad) {
+void MaxPoolCooGradKernel(const Context& dev_ctx,
+                          const SparseCooTensor& x,
+                          const DenseTensor& rulebook,
+                          const SparseCooTensor& out,
+                          const SparseCooTensor& out_grad,
+                          const std::vector<int>& kernel_sizes,
+                          SparseCooTensor* x_grad) {
   PD_VISIT_INTEGRAL_TYPES(
-      x.non_zero_indices().dtype(), "MaxPoolGradGPUKernel", ([&] {
-        MaxPoolGradGPUKernel<T, data_t>(
+      x.non_zero_indices().dtype(), "MaxPoolCooGradGPUKernel", ([&] {
+        MaxPoolCooGradGPUKernel<T, data_t>(
             dev_ctx, x, rulebook, out, out_grad, kernel_sizes, x_grad);
       }));
 }
@@ -138,10 +138,10 @@ void MaxPoolGradKernel(const Context& dev_ctx,
 }  // namespace sparse
 }  // namespace phi
 
-PD_REGISTER_KERNEL(sparse_maxpool_grad,
+PD_REGISTER_KERNEL(maxpool_coo_grad,
                    GPU,
                    ALL_LAYOUT,
-                   phi::sparse::MaxPoolGradKernel,
+                   phi::sparse::MaxPoolCooGradKernel,
                    float,
                    double) {
   kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
diff --git a/paddle/phi/kernels/sparse/gpu/sparse_pool_kernel.cu b/paddle/phi/kernels/sparse/gpu/pool_kernel.cu
similarity index 84%
rename from paddle/phi/kernels/sparse/gpu/sparse_pool_kernel.cu
rename to paddle/phi/kernels/sparse/gpu/pool_kernel.cu
index bc6723d26b7a6067ab1ef2bfa4e5e1d3a61a862f..0d24594f0a85f26b352597b1bfab9ea715050e82 100644
--- a/paddle/phi/kernels/sparse/gpu/sparse_pool_kernel.cu
+++ b/paddle/phi/kernels/sparse/gpu/pool_kernel.cu
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/phi/kernels/sparse/sparse_pool_kernel.h"
+#include "paddle/phi/kernels/sparse/pool_kernel.h"
 
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/tensor_meta.h"
@@ -48,14 +48,14 @@ __global__ void MaxPoolCudaKernel(const T* in_features_ptr,
  * out: (N, D, H, W, OC)
  **/
 template <typename T, typename IntT = int>
-void MaxPoolGPUKernel(const GPUContext& dev_ctx,
-                      const SparseCooTensor& x,
-                      const std::vector<int>& kernel_sizes,
-                      const std::vector<int>& paddings,
-                      const std::vector<int>& dilations,
-                      const std::vector<int>& strides,
-                      SparseCooTensor* out,
-                      DenseTensor* rulebook) {
+void MaxPoolCooGPUKernel(const GPUContext& dev_ctx,
+                         const SparseCooTensor& x,
+                         const std::vector<int>& kernel_sizes,
+                         const std::vector<int>& paddings,
+                         const std::vector<int>& dilations,
+                         const std::vector<int>& strides,
+                         SparseCooTensor* out,
+                         DenseTensor* rulebook) {
   const auto& x_dims = x.dims();
   int kernel_size = kernel_sizes[0] * kernel_sizes[1] * kernel_sizes[2];
   const std::vector<int>& real_kernel_sizes =
@@ -127,34 +127,34 @@ void MaxPoolGPUKernel(const GPUContext& dev_ctx,
 }
 
 template <typename T, typename Context>
-void MaxPoolKernel(const Context& dev_ctx,
-                   const SparseCooTensor& x,
-                   const std::vector<int>& kernel_sizes,
-                   const std::vector<int>& paddings,
-                   const std::vector<int>& dilations,
-                   const std::vector<int>& strides,
-                   SparseCooTensor* out,
-                   DenseTensor* rulebook) {
+void MaxPoolCooKernel(const Context& dev_ctx,
+                      const SparseCooTensor& x,
+                      const std::vector<int>& kernel_sizes,
+                      const std::vector<int>& paddings,
+                      const std::vector<int>& dilations,
+                      const std::vector<int>& strides,
+                      SparseCooTensor* out,
+                      DenseTensor* rulebook) {
   PD_VISIT_INTEGRAL_TYPES(
-      x.non_zero_indices().dtype(), "MaxPoolGPUKernel", ([&] {
-        MaxPoolGPUKernel<T, data_t>(dev_ctx,
-                                    x,
-                                    kernel_sizes,
-                                    paddings,
-                                    dilations,
-                                    strides,
-                                    out,
-                                    rulebook);
+      x.non_zero_indices().dtype(), "MaxPoolCooGPUKernel", ([&] {
+        MaxPoolCooGPUKernel<T, data_t>(dev_ctx,
+                                       x,
+                                       kernel_sizes,
+                                       paddings,
+                                       dilations,
+                                       strides,
+                                       out,
+                                       rulebook);
       }));
 }
 
 }  // namespace sparse
 }  // namespace phi
 
-PD_REGISTER_KERNEL(sparse_maxpool,
+PD_REGISTER_KERNEL(maxpool_coo,
                    GPU,
                    ALL_LAYOUT,
-                   phi::sparse::MaxPoolKernel,
+                   phi::sparse::MaxPoolCooKernel,
                    float,
                    double,
                    phi::dtype::float16) {
diff --git a/paddle/phi/kernels/sparse/sparse_pool_grad_kernel.h b/paddle/phi/kernels/sparse/pool_grad_kernel.h
similarity index 56%
rename from paddle/phi/kernels/sparse/sparse_pool_grad_kernel.h
rename to paddle/phi/kernels/sparse/pool_grad_kernel.h
index 2f7366a010aaa4772737b5f131d6ea0eeb434467..6afcbfea6ca2685d52ef9fe05cce46fcf16cad84 100644
--- a/paddle/phi/kernels/sparse/sparse_pool_grad_kernel.h
+++ b/paddle/phi/kernels/sparse/pool_grad_kernel.h
@@ -22,23 +22,23 @@ namespace phi {
 namespace sparse {
 
 template <typename T, typename Context>
-void MaxPoolGradKernel(const Context& dev_ctx,
-                       const SparseCooTensor& x,
-                       const DenseTensor& rulebook,
-                       const SparseCooTensor& out,
-                       const SparseCooTensor& out_grad,
-                       const std::vector<int>& kernel_sizes,
-                       SparseCooTensor* x_grad);
+void MaxPoolCooGradKernel(const Context& dev_ctx,
+                          const SparseCooTensor& x,
+                          const DenseTensor& rulebook,
+                          const SparseCooTensor& out,
+                          const SparseCooTensor& out_grad,
+                          const std::vector<int>& kernel_sizes,
+                          SparseCooTensor* x_grad);
 
 template <typename T, typename Context>
-SparseCooTensor MaxPoolGrad(const Context& dev_ctx,
-                            const SparseCooTensor& x,
-                            const DenseTensor& rulebook,
-                            const SparseCooTensor& out,
-                            const SparseCooTensor& out_grad,
-                            const std::vector<int>& kernel_sizes) {
+SparseCooTensor MaxPoolCooGrad(const Context& dev_ctx,
+                               const SparseCooTensor& x,
+                               const DenseTensor& rulebook,
+                               const SparseCooTensor& out,
+                               const SparseCooTensor& out_grad,
+                               const std::vector<int>& kernel_sizes) {
   SparseCooTensor x_grad;
-  MaxPoolGradKernel<T, Context>(
+  MaxPoolCooGradKernel<T, Context>(
       dev_ctx, x, rulebook, out, out_grad, kernel_sizes, &x_grad);
   return x_grad;
 }
diff --git a/paddle/phi/kernels/sparse/sparse_pool_kernel.h b/paddle/phi/kernels/sparse/pool_kernel.h
similarity index 54%
rename from paddle/phi/kernels/sparse/sparse_pool_kernel.h
rename to paddle/phi/kernels/sparse/pool_kernel.h
index d5248a1ad250ed4141245c2ff401b049afa82542..95349291efb691ed3c31aee798475161b4abf7c7 100644
--- a/paddle/phi/kernels/sparse/sparse_pool_kernel.h
+++ b/paddle/phi/kernels/sparse/pool_kernel.h
@@ -22,25 +22,25 @@ namespace phi {
 namespace sparse {
 
 template <typename T, typename Context>
-void MaxPoolKernel(const Context& dev_ctx,
-                   const SparseCooTensor& x,
-                   const std::vector<int>& kernel_sizes,
-                   const std::vector<int>& paddings,
-                   const std::vector<int>& dilations,
-                   const std::vector<int>& strides,
-                   SparseCooTensor* out,
-                   DenseTensor* rulebook);
+void MaxPoolCooKernel(const Context& dev_ctx,
+                      const SparseCooTensor& x,
+                      const std::vector<int>& kernel_sizes,
+                      const std::vector<int>& paddings,
+                      const std::vector<int>& dilations,
+                      const std::vector<int>& strides,
+                      SparseCooTensor* out,
+                      DenseTensor* rulebook);
 
 template <typename T, typename Context>
-SparseCooTensor MaxPool(const Context& dev_ctx,
-                        const SparseCooTensor& x,
-                        const std::vector<int>& kernel_sizes,
-                        const std::vector<int>& paddings,
-                        const std::vector<int>& dilations,
-                        const std::vector<int>& strides,
-                        DenseTensor* rulebook) {
+SparseCooTensor MaxPoolCoo(const Context& dev_ctx,
+                           const SparseCooTensor& x,
+                           const std::vector<int>& kernel_sizes,
+                           const std::vector<int>& paddings,
+                           const std::vector<int>& dilations,
+                           const std::vector<int>& strides,
+                           DenseTensor* rulebook) {
   SparseCooTensor coo;
-  MaxPoolKernel<T, Context>(
+  MaxPoolCooKernel<T, Context>(
       dev_ctx, x, kernel_sizes, paddings, dilations, strides, &coo, rulebook);
   return coo;
 }
diff --git a/paddle/phi/tests/kernels/test_sparse_pool_dev_api.cc b/paddle/phi/tests/kernels/test_sparse_pool_dev_api.cc
index eeba9cdc131d828e09f3a4c0da893d70f86ea29c..08f8cd8a732736bfc9f16457c21dc40b00b33bd2 100644
--- a/paddle/phi/tests/kernels/test_sparse_pool_dev_api.cc
+++ b/paddle/phi/tests/kernels/test_sparse_pool_dev_api.cc
@@ -23,8 +23,8 @@ limitations under the License. */
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/kernels/sparse/coalesce_kernel.h"
-#include "paddle/phi/kernels/sparse/sparse_pool_grad_kernel.h"
-#include "paddle/phi/kernels/sparse/sparse_pool_kernel.h"
+#include "paddle/phi/kernels/sparse/pool_grad_kernel.h"
+#include "paddle/phi/kernels/sparse/pool_kernel.h"
 
 namespace phi {
 namespace tests {
@@ -91,13 +91,13 @@ void TestMaxPoolBase(const std::vector<IntT>& indices,
 
   if (!std::is_same<T, phi::dtype::float16>::value) {
     DenseTensor rulebook;
-    SparseCooTensor out = sparse::MaxPool<T>(dev_ctx_cpu,
-                                             x_tensor,
-                                             kernel_sizes,
-                                             paddings,
-                                             dilations,
-                                             strides,
-                                             &rulebook);
+    SparseCooTensor out = sparse::MaxPoolCoo<T>(dev_ctx_cpu,
+                                                x_tensor,
+                                                kernel_sizes,
+                                                paddings,
+                                                dilations,
+                                                strides,
+                                                &rulebook);
 
     ASSERT_EQ(correct_out_dims.size(), out.dims().size());
     for (int i = 0; i < correct_out_dims.size(); i++) {
@@ -113,7 +113,7 @@
     f_verify(out.non_zero_elements().data<T>(), correct_out_features);
 
     if (backward) {
-      SparseCooTensor x_grad = sparse::MaxPoolGrad<T>(
+      SparseCooTensor x_grad = sparse::MaxPoolCooGrad<T>(
           dev_ctx_cpu, x_tensor, rulebook, out, out, kernel_sizes);
       f_verify(x_grad.non_zero_elements().data<T>(), features_grad);
     }
@@ -151,13 +151,13 @@
   SparseCooTensor d_x_tensor(d_indices_tensor, d_features_tensor, x_dims);
 
   DenseTensor d_rulebook;
-  SparseCooTensor d_out = sparse::MaxPool<T>(dev_ctx_gpu,
-                                             d_x_tensor,
-                                             kernel_sizes,
-                                             paddings,
-                                             dilations,
-                                             strides,
-                                             &d_rulebook);
+  SparseCooTensor d_out = sparse::MaxPoolCoo<T>(dev_ctx_gpu,
+                                                d_x_tensor,
+                                                kernel_sizes,
+                                                paddings,
+                                                dilations,
+                                                strides,
+                                                &d_rulebook);
   SparseCooTensor tmp_d_out = sparse::Coalesce<T>(dev_ctx_gpu, d_out);
 
   ASSERT_EQ(correct_out_dims.size(), d_out.dims().size());
@@ -191,7 +191,7 @@
   f_verify(h_features_tensor.data<T>(), correct_out_features);
 
   if (backward) {
-    SparseCooTensor x_grad = sparse::MaxPoolGrad<T>(
+    SparseCooTensor x_grad = sparse::MaxPoolCooGrad<T>(
         dev_ctx_gpu, d_x_tensor, d_rulebook, d_out, d_out, kernel_sizes);
     DenseTensor h_features_grad =
         phi::EmptyLike<T>(dev_ctx_cpu, x_grad.non_zero_elements());
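
Usage sketch: after this rename, C++ callers reach the kernels through the MaxPoolCoo/MaxPoolCooGrad wrappers declared in pool_kernel.h and pool_grad_kernel.h. Below is a minimal sketch of the CPU call path, modeled on the call sites in test_sparse_pool_dev_api.cc; the 3x3x3 window, unit strides/dilations, and the float instantiation are illustrative assumptions, not values taken from this patch.

// Minimal sketch, assuming x is a valid (N, D, H, W, C) SparseCooTensor
// and that a float instantiation is wanted.
#include "paddle/phi/kernels/sparse/pool_grad_kernel.h"
#include "paddle/phi/kernels/sparse/pool_kernel.h"

phi::SparseCooTensor RunMaxPoolCoo(const phi::CPUContext& dev_ctx,
                                   const phi::SparseCooTensor& x) {
  const std::vector<int> kernel_sizes = {3, 3, 3};
  const std::vector<int> paddings = {0, 0, 0};
  const std::vector<int> dilations = {1, 1, 1};
  const std::vector<int> strides = {1, 1, 1};

  // The rulebook produced by the forward pass is consumed by the backward
  // kernel; it is marked "intermediate" in sparse_api.yaml.
  phi::DenseTensor rulebook;
  phi::SparseCooTensor out = phi::sparse::MaxPoolCoo<float>(
      dev_ctx, x, kernel_sizes, paddings, dilations, strides, &rulebook);

  // Backward pass; the device-API test above passes `out` for both the
  // forward output and its gradient.
  return phi::sparse::MaxPoolCooGrad<float>(
      dev_ctx, x, rulebook, out, out, kernel_sizes);
}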