未验证 提交 ad0c106c 编写于 作者: Z zhangkaihuo 提交者: GitHub

Fix sparse conv and verify sparse conv backward (#40961)

上级 9e764d82
......@@ -25,37 +25,37 @@ namespace sparse {
template <typename T, typename Context>
void Conv3dGradKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const DenseTensor& rulebook,
const DenseTensor& kernel,
const DenseTensor& out_grad,
const DenseTensor& rulebook,
const SparseCooTensor& out_grad,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const int groups,
const bool subm,
DenseTensor* x_grad,
SparseCooTensor* x_grad,
DenseTensor* kernel_grad);
template <typename T, typename Context>
std::vector<DenseTensor> Conv3dGrad(const Context& dev_ctx,
const SparseCooTensor& x,
const DenseTensor& rulebook,
const DenseTensor& kernel,
const DenseTensor& out_grad,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const int groups,
const bool subm) {
DenseTensor x_grad =
phi::Empty<Context>(dev_ctx, DenseTensorMeta(x.dtype(), {1}, x.layout()));
std::tuple<SparseCooTensor, DenseTensor> Conv3dGrad(
const Context& dev_ctx,
const SparseCooTensor& x,
const DenseTensor& kernel,
const DenseTensor& rulebook,
const SparseCooTensor& out_grad,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const int groups,
const bool subm) {
SparseCooTensor x_grad;
DenseTensor kernel_grad = phi::Empty<Context>(
dev_ctx, DenseTensorMeta(kernel.dtype(), {1}, kernel.layout()));
// TODO(zhangkaihuo): call InferMeta func here
Conv3dGradKernel<T, Context>(dev_ctx,
x,
rulebook,
kernel,
rulebook,
out_grad,
paddings,
dilations,
......@@ -64,10 +64,7 @@ std::vector<DenseTensor> Conv3dGrad(const Context& dev_ctx,
subm,
&x_grad,
&kernel_grad);
std::vector<DenseTensor> out(2);
out[0] = x_grad;
out[1] = kernel_grad;
return out;
return std::make_tuple(x_grad, kernel_grad);
}
} // namespace sparse
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/sparse/convolution_grad_kernel.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/sparse/cpu/convolution.h"
......@@ -31,15 +32,15 @@ namespace sparse {
template <typename T, typename Context>
void Conv3dGradKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const DenseTensor& rulebook,
const DenseTensor& kernel,
const DenseTensor& out_grad,
const DenseTensor& rulebook,
const SparseCooTensor& out_grad,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const int groups,
const bool subm,
DenseTensor* x_grad,
SparseCooTensor* x_grad,
DenseTensor* kernel_grad) {
const auto& kernel_dims = kernel.dims();
const int kernel_size = kernel_dims[0] * kernel_dims[1] * kernel_dims[2];
......@@ -73,11 +74,18 @@ void Conv3dGradKernel(const Context& dev_ctx,
int half_kernel_size = kernel_size / 2;
auto blas = phi::funcs::GetBlas<Context, T>(dev_ctx);
x_grad->Resize(x.non_zero_elements().dims());
dev_ctx.Alloc(x_grad, x_grad->dtype(), sizeof(T) * x_grad->numel());
T* x_grad_values_ptr = x_grad->data<T>();
memset(x_grad_values_ptr, 0, sizeof(T) * x_grad->numel());
DenseTensor x_grad_indices =
phi::EmptyLike<int>(dev_ctx, x.non_zero_indices());
DenseTensor x_grad_values = phi::EmptyLike<T>(dev_ctx, x.non_zero_elements());
T* x_grad_values_ptr = x_grad_values.data<T>();
memset(x_grad_values_ptr, 0, sizeof(T) * x_grad_values.numel());
memset(d_x_features_ptr, 0, sizeof(T) * d_x_features.numel());
phi::Copy<Context>(dev_ctx,
x.non_zero_indices(),
dev_ctx.GetPlace(),
false,
&x_grad_indices);
x_grad->SetMember(x_grad_indices, x_grad_values, x.dims(), true);
std::vector<int> offsets(kernel_size + 1), counter(kernel_size, 0);
for (int i = 0; i < rulebook_len; i++) {
......@@ -97,12 +105,12 @@ void Conv3dGradKernel(const Context& dev_ctx,
phi::funcs::sparse::SubmPreProcess<T, Context>(dev_ctx,
x,
kernel,
out_grad,
out_grad.non_zero_elements(),
in_channels,
out_channels,
half_kernel_size,
kernel_grad,
x_grad);
&x_grad_values);
if (max_count == 0) {
return;
}
......@@ -113,7 +121,7 @@ void Conv3dGradKernel(const Context& dev_ctx,
rulebook_len,
in_channels,
in_features_ptr);
Gather<T>(out_grad.data<T>(),
Gather<T>(out_grad.non_zero_elements().data<T>(),
rulebook_ptr + rulebook_len * 2,
rulebook_len,
out_channels,
......
......@@ -12,11 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "glog/logging.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_meta.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/sparse/convolution_grad_kernel.h"
......@@ -36,15 +38,15 @@ namespace sparse {
template <typename T, typename Context>
void Conv3dGradKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const DenseTensor& rulebook,
const DenseTensor& kernel,
const DenseTensor& out_grad,
const DenseTensor& rulebook,
const SparseCooTensor& out_grad,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const int groups,
const bool subm,
DenseTensor* x_grad,
SparseCooTensor* x_grad,
DenseTensor* kernel_grad) {
const auto& kernel_dims = kernel.dims();
const int kernel_size = kernel_dims[0] * kernel_dims[1] * kernel_dims[2];
......@@ -70,17 +72,25 @@ void Conv3dGradKernel(const Context& dev_ctx,
T* in_features_ptr = in_features.data<T>();
T* d_x_features_ptr = d_x_features.data<T>();
T* out_grad_features_ptr = out_grad_features.data<T>();
kernel_grad->ResizeAndAllocate(kernel_dims);
*kernel_grad = phi::EmptyLike<T>(dev_ctx, kernel);
T* d_kernel_ptr = kernel_grad->data<T>();
phi::funcs::SetConstant<Context, T> set_zero;
set_zero(dev_ctx, kernel_grad, static_cast<T>(0.0f));
int half_kernel_size = kernel_size / 2;
auto blas = phi::funcs::GetBlas<Context, T>(dev_ctx);
x_grad->ResizeAndAllocate(x.non_zero_elements().dims());
T* x_grad_values_ptr = x_grad->data<T>();
set_zero(dev_ctx, x_grad, static_cast<T>(0.0f));
DenseTensor x_grad_indices =
phi::EmptyLike<int>(dev_ctx, x.non_zero_indices());
DenseTensor x_grad_values = phi::EmptyLike<T>(dev_ctx, x.non_zero_elements());
T* x_grad_values_ptr = x_grad_values.data<T>();
set_zero(dev_ctx, &x_grad_values, static_cast<T>(0.0f));
set_zero(dev_ctx, &d_x_features, static_cast<T>(0.0f));
phi::Copy<Context>(dev_ctx,
x.non_zero_indices(),
dev_ctx.GetPlace(),
false,
&x_grad_indices);
x_grad->SetMember(x_grad_indices, x_grad_values, x.dims(), true);
std::vector<int> offsets(kernel_size + 1), counter(kernel_size, 0),
h_counter(rulebook_len, 0);
......@@ -113,12 +123,12 @@ void Conv3dGradKernel(const Context& dev_ctx,
phi::funcs::sparse::SubmPreProcess<T, Context>(dev_ctx,
x,
kernel,
out_grad,
out_grad.non_zero_elements(),
in_channels,
out_channels,
half_kernel_size,
kernel_grad,
x_grad);
&x_grad_values);
if (max_count == 0) {
return;
}
......@@ -140,11 +150,12 @@ void Conv3dGradKernel(const Context& dev_ctx,
GatherKernel<T, int><<<config.block_per_grid.x,
config.thread_per_block.x,
0,
dev_ctx.stream()>>>(out_grad.data<T>(),
rulebook_ptr + rulebook_len * 2,
out_grad_features_ptr,
rulebook_len,
out_channels);
dev_ctx.stream()>>>(
out_grad.non_zero_elements().data<T>(),
rulebook_ptr + rulebook_len * 2,
out_grad_features_ptr,
rulebook_len,
out_channels);
const T* kernel_ptr = kernel.data<T>();
for (int i = 0; i < kernel_size; i++) {
......@@ -189,7 +200,7 @@ void Conv3dGradKernel(const Context& dev_ctx,
}
// 4. scatter
x_grad->ResizeAndAllocate(x.non_zero_elements().dims());
// x_grad->ResizeAndAllocate(x.non_zero_elements().dims());
DenseTensorMeta index_meta(DataType::INT32, {rulebook_len}, DataLayout::NCHW);
DenseTensor out_index = phi::Empty(dev_ctx, std::move(index_meta));
DenseTensor unique_key = phi::Empty(dev_ctx, std::move(index_meta));
......
......@@ -71,6 +71,10 @@ void TestConv3dBase(const std::vector<int>& indices,
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CPUPlace())
.get());
dev_ctx_cpu.SetHostAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CPUPlace())
.get());
dev_ctx_cpu.Init();
const int in_channels = kernel_dims[3];
......@@ -132,19 +136,19 @@ void TestConv3dBase(const std::vector<int>& indices,
f_verify(out.non_zero_elements().data<T>(), correct_out_features);
if (backward) {
std::vector<DenseTensor> grads =
std::tuple<SparseCooTensor, DenseTensor> grads =
sparse::Conv3dGrad<T>(dev_ctx_cpu,
x_tensor,
rulebook,
kernel_tensor,
out.non_zero_elements(),
rulebook,
out,
paddings,
dilations,
strides,
1,
subm);
f_verify(grads[0].data<T>(), features_grad);
f_verify(grads[1].data<T>(), kernel_grad);
f_verify(std::get<0>(grads).non_zero_elements().data<T>(), features_grad);
f_verify(std::get<1>(grads).data<T>(), kernel_grad);
}
}
......@@ -233,23 +237,28 @@ void TestConv3dBase(const std::vector<int>& indices,
f_verify(h_features_tensor.data<T>(), correct_out_features);
if (backward) {
std::vector<DenseTensor> grads =
std::tuple<SparseCooTensor, DenseTensor> grads =
sparse::Conv3dGrad<T>(dev_ctx_gpu,
d_x_tensor,
d_rulebook,
d_kernel_tensor,
d_out.non_zero_elements(),
d_rulebook,
d_out,
paddings,
dilations,
strides,
1,
subm);
DenseTensor h_features_grad = phi::EmptyLike<T>(dev_ctx_cpu, grads[0]);
phi::Copy(dev_ctx_gpu, grads[0], phi::CPUPlace(), true, &h_features_grad);
DenseTensor d_features_grad = std::get<0>(grads).non_zero_elements();
DenseTensor d_kernel_grad = std::get<1>(grads);
DenseTensor h_features_grad =
phi::EmptyLike<T>(dev_ctx_cpu, d_features_grad);
phi::Copy(
dev_ctx_gpu, d_features_grad, phi::CPUPlace(), true, &h_features_grad);
f_verify(h_features_grad.data<T>(), features_grad);
DenseTensor h_kernel_grad = phi::EmptyLike<T>(dev_ctx_cpu, grads[1]);
phi::Copy(dev_ctx_gpu, grads[1], phi::CPUPlace(), true, &h_kernel_grad);
DenseTensor h_kernel_grad = phi::EmptyLike<T>(dev_ctx_cpu, d_kernel_grad);
phi::Copy(
dev_ctx_gpu, std::get<1>(grads), phi::CPUPlace(), true, &h_kernel_grad);
f_verify(h_kernel_grad.data<T>(), kernel_grad);
}
#endif
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
from paddle import _C_ops
from paddle.fluid import core
from paddle.fluid.framework import _test_eager_guard
class TestSparseConv(unittest.TestCase):
def test_conv3d(self):
with _test_eager_guard():
kernel = [[[[[1], [1], [1]], [[1], [1], [1]], [[1], [1], [1]]]]]
dense_kernel = paddle.to_tensor(
kernel, dtype='float32', stop_gradient=False)
dense_kernel = paddle.reshape(dense_kernel, [1, 3, 3, 1, 1])
paddings = [0, 0, 0]
strides = [1, 1, 1]
dilations = [1, 1, 1]
indices = [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 2], [1, 3, 2, 3]]
values = [1, 2, 3, 4]
indices = paddle.to_tensor(indices, dtype='int32')
values = paddle.to_tensor(values, dtype='float32')
dense_shape = [1, 1, 3, 4, 1]
correct_out_values = [[4], [10]]
sparse_input = core.eager.sparse_coo_tensor(indices, values,
dense_shape, False)
out = _C_ops.final_state_sparse_conv3d(sparse_input, dense_kernel,
paddings, dilations, strides,
1, False)
out.backward(out)
#At present, only backward can be verified to work normally
#TODO(zhangkaihuo): compare the result with dense conv
print(sparse_input.grad.non_zero_elements())
assert np.array_equal(correct_out_values,
out.non_zero_elements().numpy())
# TODO: Add more test cases
- backward_api : conv3d_grad
forward : conv3d (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups, bool subm) -> Tensor(out@SparseCooTensor), Tensor(rulebook@DenseTensor)
args : (Tensor x, Tensor kernel, Tensor rulebook, Tensor out_grad, int[] paddings, int[] dilations, int[] strides, int groups, bool subm)
output : Tensor(x_grad@DenseTensor), Tensor(kernel_grad@DenseTensor)
output : Tensor(x_grad@SparseCooTensor), Tensor(kernel_grad@DenseTensor)
kernel :
func : sparse_conv3d_grad
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册