未验证 提交 503569a0 作者: zhangkaihuo 提交者: GitHub

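Fix conv3d backward: scatter x_grad with phi::funcs::ScatterCUDAKernel instead of SortedAndUniqueIndex + phi::funcs::sparse::ScatterKernel (#42502)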

上级 d73eb38c
...@@ -22,7 +22,7 @@ limitations under the License. */ ...@@ -22,7 +22,7 @@ limitations under the License. */
#include "paddle/phi/kernels/copy_kernel.h" #include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/sparse/scatter.cu.h" #include "paddle/phi/kernels/funcs/scatter.cu.h"
#include "paddle/phi/kernels/sparse/convolution_grad_kernel.h" #include "paddle/phi/kernels/sparse/convolution_grad_kernel.h"
#include "paddle/phi/kernels/sparse/gpu/convolution.cu.h" #include "paddle/phi/kernels/sparse/gpu/convolution.cu.h"
...@@ -203,38 +203,19 @@ void Conv3dGradGPUKernel(const GPUContext& dev_ctx, ...@@ -203,38 +203,19 @@ void Conv3dGradGPUKernel(const GPUContext& dev_ctx,
} }
// 4. scatter // 4. scatter
// x_grad->ResizeAndAllocate(x.non_zero_elements().dims());
DenseTensorMeta index_meta(DataType::INT32, {rulebook_len}, DataLayout::NCHW);
DenseTensor out_index = phi::Empty(dev_ctx, std::move(index_meta));
DenseTensor unique_key = phi::Empty(
dev_ctx,
DenseTensorMeta(paddle::experimental::CppTypeToDataType<IntT>::Type(),
{rulebook_len},
DataLayout::NCHW));
DenseTensor unique_value = phi::Empty(dev_ctx, std::move(index_meta));
SortedAndUniqueIndex<GPUContext, IntT>(dev_ctx,
rulebook_ptr + rulebook_len,
rulebook_len,
&out_index,
&unique_key,
&unique_value);
config = phi::backends::gpu::GetGpuLaunchConfig1D( config = phi::backends::gpu::GetGpuLaunchConfig1D(
dev_ctx, rulebook_len * in_channels, 1); dev_ctx, rulebook_len * in_channels, 1);
phi::funcs::sparse::ScatterKernel<T><<<config.block_per_grid.x, phi::funcs::ScatterCUDAKernel<<<config.block_per_grid,
config.thread_per_block.x, config.thread_per_block,
0, 0,
dev_ctx.stream()>>>( dev_ctx.stream()>>>(
d_x_features_ptr, d_x_features_ptr,
unique_value.data<int>(), rulebook_ptr + rulebook_len,
out_index.data<int>(), x_grad_values_ptr,
x.nnz(),
rulebook_len, rulebook_len,
in_channels, in_channels,
x_grad_values_ptr, false);
subm);
} }
template <typename T, typename Context> template <typename T, typename Context>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册