未验证 提交 e8e3b997 编写于 作者: Z zhangkaihuo 提交者: GitHub

fix sparse mask (#42305)

上级 e51fad5f
......@@ -115,4 +115,12 @@ void SparseCooTensor::SetMember(const DenseTensor& non_zero_indices,
this->coalesced_ = coalesced;
}
/// \brief Report the sparse dimension count of this COO tensor.
/// \return The first extent of the non-zero indices tensor's dims,
///         converted to int32_t.
int32_t SparseCooTensor::sparse_dim() const {
  const auto& indices_shape = non_zero_indices_.dims();
  // Extent 0 of the indices shape is what this accessor exposes.
  return static_cast<int32_t>(indices_shape[0]);
}
/// \brief Report the dense dimension count of this COO tensor.
/// \return The size of dims_ minus sparse_dim().
int32_t SparseCooTensor::dense_dim() const {
  const auto total_rank = dims_.size();
  const auto sparse_rank = sparse_dim();
  // Whatever is not counted by sparse_dim() is reported as dense.
  return total_rank - sparse_rank;
}
} // namespace phi
......@@ -150,6 +150,12 @@ class SparseCooTensor : public TensorBase,
/// \brief set the dims of original dense tensor
void set_dims(const DDim& dims) { this->dims_ = dims; }
/// \brief get the sparse dim
/// \return The sparse dimension count as int32_t.
int32_t sparse_dim() const;
/// \brief get the dense dim
/// \return The dense dimension count as int32_t.
int32_t dense_dim() const;
private:
// save the indices of non zero elements in original dense tensor
DenseTensor non_zero_indices_;
......
......@@ -39,7 +39,7 @@ void SparseMaskCPUKernel(const CPUContext& dev_ctx,
phi::errors::InvalidArgument("the input x and mask must have the shape"));
const DenseTensor& indices = mask.non_zero_indices();
const DenseTensor& values = mask.non_zero_elements();
int sparse_dim = indices.dims().size();
const int sparse_dim = mask.sparse_dim();
DenseTensor out_indices = phi::EmptyLike<T>(dev_ctx, indices);
DenseTensor out_values = phi::EmptyLike<T>(dev_ctx, values);
......@@ -95,7 +95,7 @@ void SparseMaskHelperCPUKernel(const CPUContext& dev_ctx,
2,
phi::errors::InvalidArgument("the mask_indices must be 2-D tensor"));
const int64_t sparse_dim = x.non_zero_indices().dims()[0];
const int32_t sparse_dim = x.sparse_dim();
std::vector<IntT> sparse_offsets(sparse_dim), x_indexs(x.nnz()),
mask_indexs(mask_indices.dims()[1]);
......
......@@ -50,7 +50,7 @@ void MaxPoolGradCPUKernel(const CPUContext& dev_ctx,
DenseTensor x_grad_values = phi::EmptyLike<T>(dev_ctx, x.non_zero_elements());
x_grad->SetMember(x_grad_indices, x_grad_values, x.dims(), true);
T* x_grad_ptr = x_grad_values.data<T>();
memset(x_grad_ptr, 0, sizeof(T) * x_grad->numel());
memset(x_grad_ptr, 0, sizeof(T) * x_grad_values.numel());
phi::Copy<CPUContext>(dev_ctx,
x.non_zero_indices(),
dev_ctx.GetPlace(),
......
......@@ -254,7 +254,7 @@ void SparseCooToDenseKernel(const Context& dev_ctx,
if (indices_dims.size() == 1) {
sparse_dim = 1;
}
const int64_t dense_dim = values.dims().size() - 1;
const int64_t dense_dim = x.dense_dim();
const T* x_data = values.data<T>();
*out = phi::Empty(
......
......@@ -42,7 +42,7 @@ __global__ void MaskKernel(const T* x_ptr,
int64_t col_i = i - out_i * cols;
int64_t index = 0;
for (int j = 0; j < sparse_dim; j++) {
index += indices_ptr[j * non_zero_num + i] * sparse_offsets[j];
index += indices_ptr[j * non_zero_num + out_i] * sparse_offsets[j];
}
out_values_ptr[out_i * cols + col_i] = x_ptr[index * cols + col_i];
}
......@@ -60,16 +60,13 @@ void SparseMaskGPUKernel(const GPUContext& dev_ctx,
phi::errors::InvalidArgument("the input x and mask must have the shape"));
const DenseTensor& indices = mask.non_zero_indices();
const DenseTensor& values = mask.non_zero_elements();
int sparse_dim = indices.dims().size();
const int sparse_dim = mask.sparse_dim();
DenseTensor sparse_offsets = phi::Empty<GPUContext>(
dev_ctx,
DenseTensorMeta(DataType::INT64, {sparse_dim}, DataLayout::NCHW));
std::vector<int64_t> h_sparse_offsets(sparse_dim);
int64_t offset = 1;
for (int i = sparse_dim - 1; i >= 0; i--) {
h_sparse_offsets[i] = offset;
offset *= dims[i];
}
phi::funcs::sparse::CalcOffsetsPerDim(
dims, sparse_dim, h_sparse_offsets.data());
phi::backends::gpu::GpuMemcpyAsync(sparse_offsets.data<int64_t>(),
&h_sparse_offsets[0],
......@@ -151,7 +148,7 @@ void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx,
2,
phi::errors::InvalidArgument("the mask_indices must be 2-D tensor"));
const int64_t sparse_dim = x.non_zero_indices().dims()[0];
const int32_t sparse_dim = x.sparse_dim();
auto indices_dtype = paddle::experimental::CppTypeToDataType<IntT>::Type();
std::vector<IntT> sparse_offsets(sparse_dim);
......
......@@ -64,7 +64,7 @@ void MaxPoolGradGPUKernel(const GPUContext& dev_ctx,
int rulebook_len = rulebook.dims()[1];
const IntT* rulebook_ptr = rulebook.data<IntT>();
std::vector<IntT> offsets(kernel_size + 1), counter(kernel_size, 0),
h_counter(kernel_size);
h_counter(rulebook_len, 0);
phi::backends::gpu::GpuMemcpyAsync(&h_counter[0],
rulebook_ptr,
rulebook_len * sizeof(IntT),
......
......@@ -19,6 +19,7 @@ import paddle
import paddle.fluid.core as core
from paddle import _C_ops
from paddle.fluid.framework import _test_eager_guard
import copy
class TestMaxPool3DFunc(unittest.TestCase):
......@@ -44,23 +45,28 @@ class TestMaxPool3DFunc(unittest.TestCase):
def test(self):
with _test_eager_guard():
self.setUp()
self.dense_x.stop_gradient = False
sparse_x = self.dense_x.to_sparse_coo(4)
out = paddle.sparse.functional.max_pool3d(
sparse_out = paddle.sparse.functional.max_pool3d(
sparse_x,
self.kernel_sizes,
stride=self.strides,
padding=self.paddings)
out = out.to_dense()
out = sparse_out.to_dense()
out.backward(out)
dense_x = copy.deepcopy(self.dense_x)
dense_out = paddle.nn.functional.max_pool3d(
self.dense_x,
dense_x,
self.kernel_sizes,
stride=self.strides,
padding=self.paddings,
data_format='NDHWC')
dense_out.backward(dense_out)
#compare with dense
assert np.allclose(dense_out.flatten().numpy(),
out.flatten().numpy())
assert np.allclose(dense_out.numpy(), out.numpy())
assert np.allclose(dense_x.grad.numpy(), self.dense_x.grad.numpy())
class TestStride(TestMaxPool3DFunc):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册