Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
60b86b2f
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
60b86b2f
编写于
3月 09, 2022
作者:
Z
zhangkaihuo
提交者:
GitHub
3月 09, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Sparse Conv3d gpu backward (#40143)
Sparse conv3d backward(gpu)
上级
3e9601ba
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
430 addition
and
215 deletion
+430
-215
paddle/phi/kernels/sparse/convolution_grad_kernel.h
paddle/phi/kernels/sparse/convolution_grad_kernel.h
+4
-2
paddle/phi/kernels/sparse/convolution_kernel.h
paddle/phi/kernels/sparse/convolution_kernel.h
+4
-14
paddle/phi/kernels/sparse/cpu/convolution.h
paddle/phi/kernels/sparse/cpu/convolution.h
+0
-5
paddle/phi/kernels/sparse/cpu/convolution_grad_kernel.cc
paddle/phi/kernels/sparse/cpu/convolution_grad_kernel.cc
+2
-9
paddle/phi/kernels/sparse/cpu/convolution_kernel.cc
paddle/phi/kernels/sparse/cpu/convolution_kernel.cc
+0
-5
paddle/phi/kernels/sparse/gpu/convolution.cu.h
paddle/phi/kernels/sparse/gpu/convolution.cu.h
+139
-0
paddle/phi/kernels/sparse/gpu/convolution_grad_kernel.cu
paddle/phi/kernels/sparse/gpu/convolution_grad_kernel.cu
+217
-0
paddle/phi/kernels/sparse/gpu/convolution_kernel.cu
paddle/phi/kernels/sparse/gpu/convolution_kernel.cu
+30
-142
paddle/phi/tests/kernels/test_sparse_conv3d_dev_api.cc
paddle/phi/tests/kernels/test_sparse_conv3d_dev_api.cc
+34
-38
未找到文件。
paddle/phi/kernels/sparse/convolution_grad_kernel.h
浏览文件 @
60b86b2f
...
...
@@ -45,8 +45,10 @@ std::vector<DenseTensor> Conv3dGrad(const Context& dev_ctx,
const
std
::
vector
<
int
>&
dilations
,
const
std
::
vector
<
int
>&
strides
,
const
int
groups
)
{
DenseTensor
x_grad
=
phi
::
Empty
<
T
,
Context
>
(
dev_ctx
);
DenseTensor
kernel_grad
=
phi
::
Empty
<
T
,
Context
>
(
dev_ctx
);
DenseTensor
x_grad
=
phi
::
Empty
<
Context
>
(
dev_ctx
,
DenseTensorMeta
(
x
.
dtype
(),
{
1
},
x
.
layout
()));
DenseTensor
kernel_grad
=
phi
::
Empty
<
Context
>
(
dev_ctx
,
DenseTensorMeta
(
kernel
.
dtype
(),
{
1
},
kernel
.
layout
()));
// TODO(zhangkaihuo): call InferMeta func here
Conv3dGradKernel
<
T
,
Context
>
(
dev_ctx
,
x
,
...
...
paddle/phi/kernels/sparse/convolution_kernel.h
浏览文件 @
60b86b2f
...
...
@@ -20,18 +20,6 @@ limitations under the License. */
#include "paddle/phi/kernels/empty_kernel.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
DenseTensor
Empty
(
const
Context
&
dev_ctx
)
{
phi
::
DenseTensor
dense_out
(
phi
::
make_intrusive
<
paddle
::
experimental
::
SharedStorage
>
(
dev_ctx
.
GetPlace
()),
{
paddle
::
experimental
::
CppTypeToDataType
<
T
>::
Type
(),
{
-
1
},
DataLayout
::
NCHW
});
return
dense_out
;
}
namespace
sparse
{
struct
Dims4D
{
...
...
@@ -149,8 +137,10 @@ SparseCooTensor Conv3d(const Context& dev_ctx,
const
std
::
vector
<
int
>&
strides
,
const
int
groups
,
DenseTensor
*
rulebook
)
{
DenseTensor
indices
=
phi
::
Empty
<
T
,
Context
>
(
dev_ctx
);
DenseTensor
values
=
phi
::
Empty
<
T
,
Context
>
(
dev_ctx
);
DenseTensor
indices
=
phi
::
Empty
<
Context
>
(
dev_ctx
,
DenseTensorMeta
(
DataType
::
INT32
,
{
1
},
DataLayout
::
NCHW
));
DenseTensor
values
=
phi
::
Empty
<
Context
>
(
dev_ctx
,
DenseTensorMeta
(
x
.
dtype
(),
{
1
},
x
.
layout
()));
SparseCooTensor
coo
(
indices
,
values
,
x
.
dims
());
Conv3dKernel
<
T
,
Context
>
(
dev_ctx
,
x
,
kernel
,
paddings
,
dilations
,
strides
,
groups
,
&
coo
,
rulebook
);
...
...
paddle/phi/kernels/sparse/cpu/convolution.h
浏览文件 @
60b86b2f
...
...
@@ -45,9 +45,6 @@ void ProductRuleBook(const Context& dev_ctx,
const
int64_t
non_zero_num
=
x
.
nnz
();
const
auto
&
non_zero_indices
=
x
.
non_zero_indices
();
const
int
*
indices_ptr
=
non_zero_indices
.
data
<
int
>
();
dev_ctx
.
Alloc
(
counter_per_kernel
,
counter_per_kernel
->
dtype
(),
sizeof
(
int
)
*
counter_per_kernel
->
numel
());
int
*
counter_ptr
=
counter_per_kernel
->
data
<
int
>
();
int
kernel_size
=
kernel_dims
[
0
]
*
kernel_dims
[
1
]
*
kernel_dims
[
2
];
memset
(
counter_ptr
,
0
,
kernel_size
*
sizeof
(
int
));
...
...
@@ -138,8 +135,6 @@ void UpdateRulebookAndOutIndex(const Context& dev_ctx,
x
.
dtype
(),
{
out_non_zero_num
,
out_channels
},
x
.
layout
());
phi
::
DenseTensor
out_indices
=
phi
::
Empty
(
dev_ctx
,
std
::
move
(
indices_meta
));
phi
::
DenseTensor
out_values
=
phi
::
Empty
(
dev_ctx
,
std
::
move
(
values_meta
));
dev_ctx
.
Alloc
(
&
out_indices
,
out_indices
.
dtype
(),
out_indices
.
numel
()
*
sizeof
(
int
));
int
*
out_indices_ptr
=
out_indices
.
data
<
int
>
();
int
i
=
0
;
for
(
auto
it
=
out_indexs
.
begin
();
it
!=
out_indexs
.
end
();
it
++
,
i
++
)
{
...
...
paddle/phi/kernels/sparse/cpu/convolution_grad_kernel.cc
浏览文件 @
60b86b2f
...
...
@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/phi/kernels/sparse/convolution_grad_kernel.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/sparse/cpu/convolution.h"
namespace
phi
{
...
...
@@ -60,15 +61,8 @@ void Conv3dGradKernel(const Context& dev_ctx,
phi
::
DenseTensor
out_grad_features
=
phi
::
Empty
(
dev_ctx
,
std
::
move
(
out_grad_features_meta
));
dev_ctx
.
Alloc
(
&
in_features
,
in_features
.
dtype
(),
sizeof
(
T
)
*
in_features
.
numel
());
T
*
in_features_ptr
=
in_features
.
data
<
T
>
();
dev_ctx
.
Alloc
(
&
d_x_features
,
d_x_features
.
dtype
(),
sizeof
(
T
)
*
d_x_features
.
numel
());
T
*
d_x_features_ptr
=
d_x_features
.
data
<
T
>
();
dev_ctx
.
Alloc
(
&
out_grad_features
,
out_grad_features
.
dtype
(),
sizeof
(
T
)
*
out_grad_features
.
numel
());
T
*
out_grad_features_ptr
=
out_grad_features
.
data
<
T
>
();
kernel_grad
->
Resize
(
kernel_dims
);
dev_ctx
.
Alloc
(
...
...
@@ -156,12 +150,11 @@ void Conv3dGradKernel(const Context& dev_ctx,
}
// namespace sparse
}
// namespace phi
PD_REGISTER_KERNEL
(
sparse_conv_grad
,
PD_REGISTER_KERNEL
(
sparse_conv
3d
_grad
,
CPU
,
ALL_LAYOUT
,
phi
::
sparse
::
Conv3dGradKernel
,
float
,
double
)
{
kernel
->
InputAt
(
0
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_COO
);
kernel
->
InputAt
(
3
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_COO
);
}
paddle/phi/kernels/sparse/cpu/convolution_kernel.cc
浏览文件 @
60b86b2f
...
...
@@ -81,8 +81,6 @@ void Conv3dKernel(const Context& dev_ctx,
phi
::
Empty
(
dev_ctx
,
std
::
move
(
in_features_meta
));
phi
::
DenseTensor
out_features
=
phi
::
Empty
(
dev_ctx
,
std
::
move
(
out_features_meta
));
dev_ctx
.
Alloc
(
&
in_features
,
x
.
dtype
(),
sizeof
(
T
)
*
in_features
.
numel
());
dev_ctx
.
Alloc
(
&
out_features
,
x
.
dtype
(),
sizeof
(
T
)
*
out_features
.
numel
());
T
*
in_features_ptr
=
in_features
.
data
<
T
>
();
T
*
out_features_ptr
=
out_features
.
data
<
T
>
();
...
...
@@ -128,9 +126,6 @@ void Conv3dKernel(const Context& dev_ctx,
}
// 4. scatter
dev_ctx
.
Alloc
(
out
->
mutable_non_zero_elements
(),
out
->
mutable_non_zero_elements
()
->
dtype
(),
sizeof
(
T
)
*
in_features
.
numel
());
T
*
out_values_ptr
=
out
->
mutable_non_zero_elements
()
->
data
<
T
>
();
memset
(
out_values_ptr
,
0
,
sizeof
(
T
)
*
out
->
nnz
()
*
out_channels
);
Scatter
<
T
>
(
out_features_ptr
,
...
...
paddle/phi/kernels/sparse/gpu/convolution.cu.h
0 → 100644
浏览文件 @
60b86b2f
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <thrust/execution_policy.h>
#include <thrust/remove.h>
#include <thrust/sort.h>
#include <thrust/unique.h>
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/kernels/funcs/index_impl.cu.h"
#include "paddle/phi/kernels/sparse/convolution_kernel.h"
namespace
phi
{
namespace
sparse
{
// TODO(zhangkaihuo): After the GatherCUDAKernel is migrated to phi, replace
// this kernel with phi::GatherCUDAKernel;
// Vectorization can be used to improve read and write bandwidth
/**
* brief: gather data from params according to indices
* params: the inputs
* indices: the indices you want to gather
* output: the outputs
* index_size: the size of indices
* slice_size: slice size corresponding to each index, here is the channel size
**/
template
<
typename
T
,
typename
IndexT
=
int
>
__global__
void
GatherKernel
(
const
T
*
params
,
const
IndexT
*
indices
,
T
*
output
,
size_t
index_size
,
size_t
slice_size
)
{
CUDA_KERNEL_LOOP_TYPE
(
i
,
index_size
*
slice_size
,
int64_t
)
{
int64_t
indices_i
=
i
/
slice_size
;
int64_t
slice_i
=
i
-
indices_i
*
slice_size
;
// offset inside the slice
IndexT
gather_i
=
indices
[
indices_i
];
int64_t
params_i
=
gather_i
*
slice_size
+
slice_i
;
*
(
output
+
i
)
=
*
(
params
+
params_i
);
}
}
/**
* brief: scatter add
* input: the inputs
* unique_value: refer to UpdateIndexKernel notes
* out_index: the output feature index
* non_zero_num: the number of output features
* rulebook_len: the length of rulebook
* channels: the output channel size
* out: the outputs
**/
template
<
typename
T
>
__global__
void
ScatterKernel
(
const
T
*
input
,
const
int
*
unique_value
,
const
int
*
out_index
,
const
int
non_zero_num
,
const
int
rulebook_len
,
const
int
channels
,
T
*
out
)
{
int
tid
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
for
(
int
i
=
tid
;
i
<
non_zero_num
*
channels
;
i
+=
gridDim
.
x
*
blockDim
.
x
)
{
int
indices_i
=
i
/
channels
;
int
channels_i
=
i
-
indices_i
*
channels
;
int
start
=
unique_value
[
indices_i
];
int
end
=
indices_i
==
non_zero_num
-
1
?
rulebook_len
:
unique_value
[
indices_i
+
1
];
// max(end-start) = kernel_size
T
sum
=
static_cast
<
T
>
(
0
);
for
(
int
j
=
start
;
j
<
end
;
j
++
)
{
const
int
out_feature_i
=
out_index
[
j
];
sum
+=
input
[
out_feature_i
*
channels
+
channels_i
];
}
out
[
indices_i
*
channels
+
channels_i
]
=
sum
;
}
}
template
<
typename
Context
>
inline
int
*
SortedAndUniqueIndex
(
const
Context
&
dev_ctx
,
const
int
*
rulebook_ptr
,
const
int
len
,
DenseTensor
*
out_index
,
DenseTensor
*
unique_key
,
DenseTensor
*
unique_value
)
{
phi
::
IndexKernel
<
int
,
kps
::
IdentityFunctor
<
int
>>
(
dev_ctx
,
out_index
,
kps
::
IdentityFunctor
<
int
>
());
phi
::
IndexKernel
<
int
,
kps
::
IdentityFunctor
<
int
>>
(
dev_ctx
,
unique_value
,
kps
::
IdentityFunctor
<
int
>
());
phi
::
backends
::
gpu
::
GpuMemcpyAsync
(
unique_key
->
data
<
int
>
(),
rulebook_ptr
,
sizeof
(
int
)
*
len
,
#ifdef PADDLE_WITH_HIP
hipMemcpyDeviceToDevice
,
#else
cudaMemcpyDeviceToDevice
,
#endif
dev_ctx
.
stream
());
// compared with thrust::sort_by_key, thrust::merge_by_key may achieved higher
// performance, but thrust::merge_by_key limited by data size
#ifdef PADDLE_WITH_HIP
thrust
::
sort_by_key
(
thrust
::
hip
::
par
.
on
(
dev_ctx
.
stream
()),
#else
thrust
::
sort_by_key
(
thrust
::
cuda
::
par
.
on
(
dev_ctx
.
stream
()),
#endif
unique_key
->
data
<
int
>
(),
unique_key
->
data
<
int
>
()
+
len
,
out_index
->
data
<
int
>
());
// 4. unique
thrust
::
pair
<
int
*
,
int
*>
new_end
=
#ifdef PADDLE_WITH_HIP
thrust
::
unique_by_key
(
thrust
::
hip
::
par
.
on
(
dev_ctx
.
stream
()),
#else
thrust
::
unique_by_key
(
thrust
::
cuda
::
par
.
on
(
dev_ctx
.
stream
()),
#endif
unique_key
->
data
<
int
>
(),
unique_key
->
data
<
int
>
()
+
len
,
unique_value
->
data
<
int
>
());
return
new_end
.
first
;
}
}
// namespace sparse
}
// namespace phi
paddle/phi/kernels/sparse/gpu/convolution_grad_kernel.cu
0 → 100644
浏览文件 @
60b86b2f
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_meta.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/sparse/convolution_grad_kernel.h"
#include "paddle/phi/kernels/sparse/gpu/convolution.cu.h"
namespace
phi
{
namespace
sparse
{
// rulebook[3, rulebook_len]:
//[
// [kernel_index],
// [in_i],
// [out_i],
//]
// x_grad = out_grad * transpose(kenrel)
// kernel_grad = transpose(x) * out_grad
template
<
typename
T
,
typename
Context
>
void
Conv3dGradKernel
(
const
Context
&
dev_ctx
,
const
SparseCooTensor
&
x
,
const
DenseTensor
&
rulebook
,
const
DenseTensor
&
kernel
,
const
SparseCooTensor
&
out_grad
,
const
std
::
vector
<
int
>&
paddings
,
const
std
::
vector
<
int
>&
dilations
,
const
std
::
vector
<
int
>&
strides
,
const
int
groups
,
DenseTensor
*
x_grad
,
DenseTensor
*
kernel_grad
)
{
const
auto
&
kernel_dims
=
kernel
.
dims
();
const
int
kernel_size
=
kernel_dims
[
0
]
*
kernel_dims
[
1
]
*
kernel_dims
[
2
];
const
int
in_channels
=
kernel_dims
[
3
];
const
int
out_channels
=
kernel_dims
[
4
];
const
int
*
rulebook_ptr
=
rulebook
.
data
<
int
>
();
const
int
rulebook_len
=
rulebook
.
dims
()[
1
];
DenseTensorMeta
in_features_meta
(
x
.
dtype
(),
{
rulebook_len
,
in_channels
},
DataLayout
::
NCHW
);
DenseTensorMeta
d_x_features_meta
(
x
.
dtype
(),
{
rulebook_len
,
in_channels
},
DataLayout
::
NCHW
);
DenseTensorMeta
out_grad_features_meta
(
x
.
dtype
(),
{
rulebook_len
,
out_channels
},
DataLayout
::
NCHW
);
phi
::
DenseTensor
in_features
=
phi
::
Empty
(
dev_ctx
,
std
::
move
(
in_features_meta
));
phi
::
DenseTensor
d_x_features
=
phi
::
Empty
(
dev_ctx
,
std
::
move
(
d_x_features_meta
));
phi
::
DenseTensor
out_grad_features
=
phi
::
Empty
(
dev_ctx
,
std
::
move
(
out_grad_features_meta
));
T
*
in_features_ptr
=
in_features
.
data
<
T
>
();
T
*
d_x_features_ptr
=
d_x_features
.
data
<
T
>
();
T
*
out_grad_features_ptr
=
out_grad_features
.
data
<
T
>
();
kernel_grad
->
Resize
(
kernel_dims
);
dev_ctx
.
Alloc
(
kernel_grad
,
kernel_grad
->
dtype
(),
kernel_grad
->
numel
()
*
sizeof
(
T
));
T
*
d_kernel_ptr
=
kernel_grad
->
data
<
T
>
();
phi
::
funcs
::
SetConstant
<
Context
,
T
>
set_zero
;
set_zero
(
dev_ctx
,
kernel_grad
,
static_cast
<
T
>
(
0.0
f
));
auto
config
=
phi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
dev_ctx
,
rulebook_len
*
in_channels
,
1
);
GatherKernel
<
T
,
int
><<<
config
.
block_per_grid
.
x
,
config
.
thread_per_block
.
x
,
0
,
dev_ctx
.
stream
()
>>>
(
x
.
non_zero_elements
().
data
<
T
>
(),
rulebook_ptr
+
rulebook_len
,
in_features_ptr
,
rulebook_len
,
in_channels
);
config
=
phi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
dev_ctx
,
rulebook_len
*
out_channels
,
1
);
GatherKernel
<
T
,
int
><<<
config
.
block_per_grid
.
x
,
config
.
thread_per_block
.
x
,
0
,
dev_ctx
.
stream
()
>>>
(
out_grad
.
non_zero_elements
().
data
<
T
>
(),
rulebook_ptr
+
rulebook_len
*
2
,
out_grad_features_ptr
,
rulebook_len
,
out_channels
);
auto
blas
=
phi
::
funcs
::
GetBlas
<
Context
,
T
>
(
dev_ctx
);
std
::
vector
<
int
>
offsets
(
kernel_size
+
1
),
counter
(
kernel_size
,
0
),
h_counter
(
rulebook_len
,
0
);
phi
::
backends
::
gpu
::
GpuMemcpyAsync
(
&
h_counter
[
0
],
rulebook_ptr
,
rulebook_len
*
sizeof
(
int
),
#ifdef PADDLE_WITH_HIP
hipMemcpyDeviceToHost
,
#else
cudaMemcpyDeviceToHost
,
#endif
dev_ctx
.
stream
());
dev_ctx
.
Wait
();
for
(
int
i
=
0
;
i
<
rulebook_len
;
i
++
)
{
counter
[
h_counter
[
i
]]
+=
1
;
}
int
offset
=
0
;
for
(
int
i
=
0
;
i
<
kernel_size
;
i
++
)
{
offsets
[
i
]
=
offset
;
offset
+=
counter
[
i
];
}
offsets
[
kernel_size
]
=
offset
;
const
T
*
kernel_ptr
=
kernel
.
data
<
T
>
();
for
(
int
i
=
0
;
i
<
kernel_size
;
i
++
)
{
if
(
counter
[
i
]
<=
0
)
{
continue
;
}
const
int
M
=
counter
[
i
];
const
int
K
=
in_channels
;
const
int
N
=
out_channels
;
T
*
tmp_in_ptr
=
in_features_ptr
+
offsets
[
i
]
*
in_channels
;
T
*
tmp_out_grad_ptr
=
out_grad_features_ptr
+
offsets
[
i
]
*
out_channels
;
const
T
*
tmp_kernel_ptr
=
kernel_ptr
+
i
*
in_channels
*
out_channels
;
T
*
tmp_d_x_ptr
=
d_x_features_ptr
+
offsets
[
i
]
*
out_channels
;
T
*
tmp_d_kernel_ptr
=
d_kernel_ptr
+
i
*
in_channels
*
out_channels
;
// call gemm: d_kernel = transpose(x) * out_grad
// (in_channels, n) * (n, out_channels)
blas
.
GEMM
(
CblasTrans
,
CblasNoTrans
,
M
,
N
,
K
,
static_cast
<
T
>
(
1
),
tmp_in_ptr
,
tmp_out_grad_ptr
,
static_cast
<
T
>
(
0
),
tmp_d_kernel_ptr
);
// call gemm: d_x = out_grad * transpose(kernel)
// (n, out_channels) * (out_channels, in_channels)
blas
.
GEMM
(
CblasNoTrans
,
CblasTrans
,
M
,
K
,
N
,
static_cast
<
T
>
(
1
),
tmp_out_grad_ptr
,
tmp_kernel_ptr
,
static_cast
<
T
>
(
0
),
tmp_d_x_ptr
);
}
// 4. scatter
x_grad
->
Resize
(
x
.
non_zero_elements
().
dims
());
dev_ctx
.
Alloc
(
x_grad
,
x_grad
->
dtype
(),
sizeof
(
T
)
*
x_grad
->
numel
());
T
*
x_grad_values_ptr
=
x_grad
->
data
<
T
>
();
DenseTensor
out_index
=
phi
::
Empty
(
dev_ctx
,
DenseTensorMeta
(
DataType
::
INT32
,
{
rulebook_len
},
DataLayout
::
NCHW
));
DenseTensor
unique_key
=
phi
::
Empty
(
dev_ctx
,
DenseTensorMeta
(
DataType
::
INT32
,
{
rulebook_len
},
DataLayout
::
NCHW
));
DenseTensor
unique_value
=
phi
::
Empty
(
dev_ctx
,
DenseTensorMeta
(
DataType
::
INT32
,
{
rulebook_len
},
DataLayout
::
NCHW
));
SortedAndUniqueIndex
(
dev_ctx
,
rulebook_ptr
+
rulebook_len
,
rulebook_len
,
&
out_index
,
&
unique_key
,
&
unique_value
);
config
=
phi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
dev_ctx
,
rulebook_len
*
in_channels
,
1
);
ScatterKernel
<
T
><<<
config
.
block_per_grid
.
x
,
config
.
thread_per_block
.
x
,
0
,
dev_ctx
.
stream
()
>>>
(
d_x_features_ptr
,
unique_value
.
data
<
int
>
(),
out_index
.
data
<
int
>
(),
x
.
nnz
(),
rulebook_len
,
in_channels
,
x_grad_values_ptr
);
}
}
// namespace sparse
}
// namespace phi
PD_REGISTER_KERNEL
(
sparse_conv3d_grad
,
GPU
,
ALL_LAYOUT
,
phi
::
sparse
::
Conv3dGradKernel
,
float
,
double
,
phi
::
dtype
::
float16
)
{
kernel
->
InputAt
(
0
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_COO
);
}
paddle/phi/kernels/sparse/gpu/convolution_kernel.cu
浏览文件 @
60b86b2f
...
...
@@ -17,7 +17,6 @@ limitations under the License. */
#include <thrust/sort.h>
#include <thrust/unique.h>
#include "glog/logging.h"
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
...
...
@@ -28,19 +27,11 @@ limitations under the License. */
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/primitive/compute_primitives.h"
#include "paddle/phi/kernels/sparse/convolution_kernel.h"
#include "paddle/phi/kernels/sparse/gpu/convolution.cu.h"
namespace
phi
{
namespace
sparse
{
// TODO(zhangkaihuo) replace this kernel with KP::InitWithDataIndex
__global__
void
InitByIndexKernel
(
const
int
n
,
int
*
out1
,
int
*
out2
)
{
int
tid
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
for
(
int
i
=
tid
;
i
<
n
;
i
+=
gridDim
.
x
*
blockDim
.
x
)
{
out1
[
i
]
=
i
;
out2
[
i
]
=
i
;
}
}
/**
* @brief: update the out index and indices
* unique_keys: save the index of the output feature list
...
...
@@ -124,7 +115,7 @@ __global__ void ProductRuleBookKernel(const int* x_indices,
int
in_z
=
x_indices
[
i
+
non_zero_num
];
int
in_y
=
x_indices
[
i
+
2
*
non_zero_num
];
int
in_x
=
x_indices
[
i
+
3
*
non_zero_num
];
int
in_i
=
-
1
,
out_index
=
-
1
;
int
in_i
=
-
1
,
out_index
=
-
1
,
kernel_i
=
-
1
;
if
(
Check
(
x_dims
,
kernel_dims
,
paddings
,
...
...
@@ -143,9 +134,11 @@ __global__ void ProductRuleBookKernel(const int* x_indices,
out_index
=
PointToIndex
<
Dims4D
>
(
batch
,
out_x
,
out_y
,
out_z
,
out_dims
);
atomicAdd
(
&
counter_buf
[
kernel_index
],
1
);
kernel_i
=
kernel_index
;
}
rulebook
[
kernel_index
*
non_zero_num
+
i
]
=
in_i
;
rulebook
[
kernel_index
*
non_zero_num
+
offset
+
i
]
=
out_index
;
rulebook
[
kernel_index
*
non_zero_num
+
i
]
=
kernel_i
;
rulebook
[
kernel_index
*
non_zero_num
+
offset
+
i
]
=
in_i
;
rulebook
[
kernel_index
*
non_zero_num
+
offset
*
2
+
i
]
=
out_index
;
++
kernel_index
;
}
}
...
...
@@ -157,68 +150,6 @@ __global__ void ProductRuleBookKernel(const int* x_indices,
}
}
// TODO(zhangkaihuo): After the GatherCUDAKernel is migrated to phi, replace
// this kernel with phi::GatherCUDAKernel;
// Vectorization can be used to improve read and write bandwidth
/**
* brief: gather data from params according to indices
* params: the inputs
* indices: the indices you want to gather
* output: the outputs
* index_size: the size of indices
* slice_size: slice size corresponding to each index, here is the channel size
**/
template
<
typename
T
,
typename
IndexT
=
int
>
__global__
void
GatherKernel
(
const
T
*
params
,
const
IndexT
*
indices
,
T
*
output
,
size_t
index_size
,
size_t
slice_size
)
{
CUDA_KERNEL_LOOP_TYPE
(
i
,
index_size
*
slice_size
,
int64_t
)
{
int64_t
indices_i
=
i
/
slice_size
;
int64_t
slice_i
=
i
-
indices_i
*
slice_size
;
// offset inside the slice
IndexT
gather_i
=
indices
[
indices_i
];
int64_t
params_i
=
gather_i
*
slice_size
+
slice_i
;
*
(
output
+
i
)
=
*
(
params
+
params_i
);
}
}
/**
* brief: scatter add
* input: the inputs
* unique_value: refer to UpdateIndexKernel notes
* out_index: the output feature index
* non_zero_num: the number of output features
* rulebook_len: the length of rulebook
* channels: the output channel size
* out: the outputs
**/
template
<
typename
T
>
__global__
void
ScatterKernel
(
const
T
*
input
,
const
int
*
unique_value
,
const
int
*
out_index
,
const
int
non_zero_num
,
const
int
rulebook_len
,
const
int
channels
,
T
*
out
)
{
int
tid
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
for
(
int
i
=
tid
;
i
<
non_zero_num
*
channels
;
i
+=
gridDim
.
x
*
blockDim
.
x
)
{
int
indices_i
=
i
/
channels
;
int
channels_i
=
i
-
indices_i
*
channels
;
int
start
=
unique_value
[
indices_i
];
int
end
=
indices_i
==
non_zero_num
-
1
?
rulebook_len
:
unique_value
[
indices_i
+
1
];
// max(end-start) = kernel_size
T
sum
=
static_cast
<
T
>
(
0
);
for
(
int
j
=
start
;
j
<
end
;
j
++
)
{
const
int
out_feature_i
=
out_index
[
j
];
sum
+=
input
[
out_feature_i
*
channels
+
channels_i
];
}
out
[
indices_i
*
channels
+
channels_i
]
=
sum
;
}
}
// brief: calculation the distance between start and end
__global__
void
DistanceKernel
(
const
int
*
start
,
const
int
*
end
,
...
...
@@ -264,16 +195,12 @@ int ProductRuleBook(const Context& dev_ctx,
const
int64_t
non_zero_num
=
x
.
nnz
();
const
auto
&
non_zero_indices
=
x
.
non_zero_indices
();
const
int
*
indices_ptr
=
non_zero_indices
.
data
<
int
>
();
dev_ctx
.
Alloc
(
counter_per_kernel
,
counter_per_kernel
->
dtype
(),
sizeof
(
int
)
*
counter_per_kernel
->
numel
());
int
*
counter_ptr
=
counter_per_kernel
->
data
<
int
>
();
dev_ctx
.
Alloc
(
offsets_per_kernel
,
offsets_per_kernel
->
dtype
(),
sizeof
(
int
)
*
offsets_per_kernel
->
numel
());
int
*
offsets_ptr
=
offsets_per_kernel
->
data
<
int
>
();
int
kernel_size
=
kernel_dims
[
0
]
*
kernel_dims
[
1
]
*
kernel_dims
[
2
];
rulebook
->
ResizeAndAllocate
({
2
,
kernel_size
*
non_zero_num
});
const
int
rulebook_rows
=
3
;
const
int
rulebook_cols
=
kernel_size
*
non_zero_num
;
rulebook
->
ResizeAndAllocate
({
rulebook_rows
,
rulebook_cols
});
dev_ctx
.
Alloc
(
rulebook
,
rulebook
->
dtype
(),
sizeof
(
int
)
*
rulebook
->
numel
());
int
*
rulebook_ptr
=
rulebook
->
data
<
int
>
();
...
...
@@ -312,7 +239,7 @@ int ProductRuleBook(const Context& dev_ctx,
int
*
last
=
thrust
::
remove
(
thrust
::
cuda
::
par
.
on
(
dev_ctx
.
stream
()),
#endif
rulebook_ptr
,
rulebook_ptr
+
2
*
kernel_size
*
non_zero_num
,
rulebook_ptr
+
rulebook_rows
*
rulebook_cols
,
-
1
);
#ifdef PADDLE_WITH_HIP
...
...
@@ -350,6 +277,7 @@ int ProductRuleBook(const Context& dev_ctx,
dev_ctx
.
Wait
();
int
rulebook_len
=
(
*
h_counter
)[
kernel_size
-
1
]
+
(
*
h_offsets
)[
kernel_size
-
1
];
rulebook
->
Resize
({
rulebook_rows
,
rulebook_len
});
// 3. sorted or merge the out index
out_index
->
ResizeAndAllocate
({
rulebook_len
});
...
...
@@ -365,66 +293,30 @@ int ProductRuleBook(const Context& dev_ctx,
unique_key
,
unique_key
->
dtype
(),
sizeof
(
int
)
*
unique_key
->
numel
());
int
*
unique_key_ptr
=
unique_key
->
data
<
int
>
();
config
=
phi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
dev_ctx
,
rulebook_len
,
1
);
InitByIndexKernel
<<<
config
.
block_per_grid
.
x
,
config
.
thread_per_block
.
x
,
0
,
dev_ctx
.
stream
()
>>>
(
rulebook_len
,
out_index_ptr
,
unique_value_ptr
);
#ifdef PADDLE_WITH_HIP
phi
::
backends
::
gpu
::
GpuMemcpyAsync
(
unique_key_ptr
,
rulebook_ptr
+
rulebook_len
,
rulebook_len
*
sizeof
(
int
),
hipMemcpyDeviceToDevice
,
dev_ctx
.
stream
());
#else
phi
::
backends
::
gpu
::
GpuMemcpyAsync
(
unique_key_ptr
,
rulebook_ptr
+
rulebook_len
,
rulebook_len
*
sizeof
(
int
),
cudaMemcpyDeviceToDevice
,
dev_ctx
.
stream
());
#endif
// compared with thrust::sort_by_key, thrust::merge_by_key may achieved higher
// performance, but thrust::merge_by_key limited by data size
#ifdef PADDLE_WITH_HIP
thrust
::
sort_by_key
(
thrust
::
hip
::
par
.
on
(
dev_ctx
.
stream
()),
#else
thrust
::
sort_by_key
(
thrust
::
cuda
::
par
.
on
(
dev_ctx
.
stream
()),
#endif
unique_key_ptr
,
unique_key_ptr
+
rulebook_len
,
out_index_ptr
);
// 4. unique
thrust
::
pair
<
int
*
,
int
*>
new_end
=
#ifdef PADDLE_WITH_HIP
thrust
::
unique_by_key
(
thrust
::
hip
::
par
.
on
(
dev_ctx
.
stream
()),
#else
thrust
::
unique_by_key
(
thrust
::
cuda
::
par
.
on
(
dev_ctx
.
stream
()),
#endif
unique_key_ptr
,
unique_key_ptr
+
rulebook_len
,
unique_value_ptr
);
int
*
new_end
=
SortedAndUniqueIndex
(
dev_ctx
,
rulebook_ptr
+
2
*
rulebook_len
,
rulebook_len
,
out_index
,
unique_key
,
unique_value
);
// thrust::distance doesn't support stream parameters
// const int out_non_zero_num = thrust::distance(unique_key_ptr,
// new_end.first);
DistanceKernel
<<<
1
,
1
>>>
(
unique_key_ptr
,
new_end
.
first
,
rulebook_ptr
+
2
*
kernel_size
*
non_zero_num
-
1
);
new_end
,
rulebook_ptr
+
rulebook_rows
*
rulebook_cols
-
1
);
int
out_non_zero_num
=
0
;
#ifdef PADDLE_WITH_HIP
phi
::
backends
::
gpu
::
GpuMemcpyAsync
(
&
out_non_zero_num
,
rulebook_ptr
+
2
*
kernel_size
*
non_zero_num
-
1
,
rulebook_ptr
+
rulebook_rows
*
rulebook_cols
-
1
,
sizeof
(
int
),
hipMemcpyDeviceToHost
,
dev_ctx
.
stream
());
#else
phi
::
backends
::
gpu
::
GpuMemcpyAsync
(
&
out_non_zero_num
,
rulebook_ptr
+
2
*
kernel_size
*
non_zero_num
-
1
,
rulebook_ptr
+
rulebook_rows
*
rulebook_cols
-
1
,
sizeof
(
int
),
cudaMemcpyDeviceToHost
,
dev_ctx
.
stream
());
...
...
@@ -440,8 +332,6 @@ int ProductRuleBook(const Context& dev_ctx,
phi
::
DenseTensor
out_indices
=
phi
::
Empty
(
dev_ctx
,
std
::
move
(
indices_meta
));
phi
::
DenseTensor
out_values
=
phi
::
Empty
(
dev_ctx
,
std
::
move
(
values_meta
));
dev_ctx
.
Alloc
(
&
out_indices
,
out_indices
.
dtype
(),
sizeof
(
int
)
*
out_indices
.
numel
());
int
*
out_indices_ptr
=
out_indices
.
data
<
int
>
();
config
=
...
...
@@ -456,7 +346,7 @@ int ProductRuleBook(const Context& dev_ctx,
rulebook_len
,
d_out_dims
,
out_indices_ptr
,
rulebook_ptr
+
rulebook_len
);
rulebook_ptr
+
2
*
rulebook_len
);
out
->
SetMember
(
out_indices
,
out_values
,
out_dims
,
true
);
return
rulebook_len
;
}
...
...
@@ -499,9 +389,12 @@ void Conv3dKernel(const Context& dev_ctx,
DataType
::
INT32
,
{
kernel_size
},
DataLayout
::
NCHW
);
DenseTensor
counter_per_kernel
=
phi
::
Empty
(
dev_ctx
,
std
::
move
(
counter_meta
));
DenseTensor
offsets_per_kernel
=
phi
::
Empty
(
dev_ctx
,
std
::
move
(
offsets_meta
));
DenseTensor
out_index
=
phi
::
Empty
<
int
,
Context
>
(
dev_ctx
);
DenseTensor
unique_key
=
phi
::
Empty
<
int
,
Context
>
(
dev_ctx
);
DenseTensor
unique_value
=
phi
::
Empty
<
int
,
Context
>
(
dev_ctx
);
DenseTensor
out_index
=
phi
::
Empty
(
dev_ctx
,
DenseTensorMeta
(
DataType
::
INT32
,
{
1
},
DataLayout
::
NCHW
));
DenseTensor
unique_key
=
phi
::
Empty
(
dev_ctx
,
DenseTensorMeta
(
DataType
::
INT32
,
{
1
},
DataLayout
::
NCHW
));
DenseTensor
unique_value
=
phi
::
Empty
(
dev_ctx
,
DenseTensorMeta
(
DataType
::
INT32
,
{
1
},
DataLayout
::
NCHW
));
int
n
=
ProductRuleBook
<
T
,
Context
>
(
dev_ctx
,
x
,
...
...
@@ -522,6 +415,7 @@ void Conv3dKernel(const Context& dev_ctx,
const
int
*
counter_ptr
=
counter_per_kernel
.
data
<
int
>
();
const
int
*
offsets_ptr
=
counter_per_kernel
.
data
<
int
>
();
const
int
*
rulebook_ptr
=
rulebook
->
data
<
int
>
();
// 2. gather
DenseTensorMeta
in_features_meta
(
...
...
@@ -532,11 +426,7 @@ void Conv3dKernel(const Context& dev_ctx,
phi
::
Empty
(
dev_ctx
,
std
::
move
(
in_features_meta
));
phi
::
DenseTensor
out_features
=
phi
::
Empty
(
dev_ctx
,
std
::
move
(
out_features_meta
));
dev_ctx
.
Alloc
(
&
in_features
,
in_features
.
dtype
(),
sizeof
(
T
)
*
in_features
.
numel
());
T
*
in_features_ptr
=
in_features
.
data
<
T
>
();
dev_ctx
.
Alloc
(
&
out_features
,
out_features
.
dtype
(),
sizeof
(
T
)
*
out_features
.
numel
());
T
*
out_features_ptr
=
out_features
.
data
<
T
>
();
auto
config
=
...
...
@@ -545,7 +435,7 @@ void Conv3dKernel(const Context& dev_ctx,
config
.
thread_per_block
.
x
,
0
,
dev_ctx
.
stream
()
>>>
(
x
.
non_zero_elements
().
data
<
T
>
(),
rulebook
->
data
<
int
>
()
,
rulebook
_ptr
+
n
,
in_features_ptr
,
n
,
in_channels
);
...
...
@@ -553,8 +443,6 @@ void Conv3dKernel(const Context& dev_ctx,
// 3. call gemm for every werght
auto
blas
=
phi
::
funcs
::
GetBlas
<
Context
,
T
>
(
dev_ctx
);
auto
*
out_values
=
out
->
mutable_non_zero_elements
();
dev_ctx
.
Alloc
(
out_values
,
out_values
->
dtype
(),
sizeof
(
T
)
*
out_values
->
numel
());
T
*
out_values_ptr
=
out_values
->
data
<
T
>
();
const
T
*
kernel_ptr
=
kernel
.
data
<
T
>
();
...
...
paddle/phi/tests/kernels/test_sparse_conv3d_dev_api.cc
浏览文件 @
60b86b2f
...
...
@@ -78,9 +78,6 @@ void TestConv3dBase(const std::vector<int>& indices,
DenseTensor
indices_tensor
=
phi
::
Empty
(
dev_ctx_cpu
,
DenseTensorMeta
(
DataType
::
INT32
,
{
4
,
non_zero_num
},
DataLayout
::
NCHW
));
dev_ctx_cpu
.
Alloc
(
&
indices_tensor
,
indices_tensor
.
dtype
(),
sizeof
(
int
)
*
indices_tensor
.
numel
());
memcpy
(
indices_tensor
.
data
<
int
>
(),
indices
.
data
(),
indices
.
size
()
*
sizeof
(
int
));
DenseTensor
features_tensor
=
phi
::
Empty
(
...
...
@@ -88,9 +85,6 @@ void TestConv3dBase(const std::vector<int>& indices,
DenseTensorMeta
(
paddle
::
experimental
::
CppTypeToDataType
<
T
>::
Type
(),
{
non_zero_num
,
in_channels
},
DataLayout
::
NHWC
));
dev_ctx_cpu
.
Alloc
(
&
features_tensor
,
features_tensor
.
dtype
(),
features_tensor
.
numel
()
*
sizeof
(
T
));
memcpy
(
features_tensor
.
data
<
T
>
(),
features
.
data
(),
features
.
size
()
*
sizeof
(
T
));
...
...
@@ -101,12 +95,18 @@ void TestConv3dBase(const std::vector<int>& indices,
DenseTensorMeta
(
paddle
::
experimental
::
CppTypeToDataType
<
T
>::
Type
(),
kernel_dims
,
DataLayout
::
NHWC
));
dev_ctx_cpu
.
Alloc
(
&
kernel_tensor
,
kernel_tensor
.
dtype
(),
kernel_tensor
.
numel
()
*
sizeof
(
T
));
memcpy
(
kernel_tensor
.
data
<
T
>
(),
kernel
.
data
(),
kernel
.
size
()
*
sizeof
(
T
));
auto
f_verify
=
[
&
](
const
T
*
real_data
,
const
std
::
vector
<
T
>&
correct_data
)
{
for
(
uint64_t
i
=
0
;
i
<
correct_data
.
size
();
i
++
)
{
float
tmp
=
std
::
fabs
(
static_cast
<
float
>
(
correct_data
[
i
]
-
real_data
[
i
]));
ASSERT_LT
(
tmp
,
diff
);
}
};
if
(
!
std
::
is_same
<
T
,
phi
::
dtype
::
float16
>::
value
)
{
DenseTensor
rulebook
=
phi
::
Empty
<
int
,
phi
::
CPUContext
>
(
dev_ctx_cpu
);
DenseTensor
rulebook
=
phi
::
Empty
(
dev_ctx_cpu
,
DenseTensorMeta
(
DataType
::
INT32
,
{
1
},
DataLayout
::
NCHW
));
SparseCooTensor
out
=
sparse
::
Conv3d
<
T
>
(
dev_ctx_cpu
,
x_tensor
,
kernel_tensor
,
...
...
@@ -127,15 +127,6 @@ void TestConv3dBase(const std::vector<int>& indices,
correct_out_indices
.
size
()
*
sizeof
(
int
));
ASSERT_EQ
(
cmp_indices
,
0
);
auto
f_verify
=
[
&
](
const
T
*
real_data
,
const
std
::
vector
<
T
>&
correct_data
)
{
for
(
uint64_t
i
=
0
;
i
<
correct_data
.
size
();
i
++
)
{
float
tmp
=
std
::
fabs
(
static_cast
<
float
>
(
correct_data
[
i
]
-
real_data
[
i
]));
ASSERT_LT
(
tmp
,
diff
);
}
};
f_verify
(
out
.
non_zero_elements
().
data
<
T
>
(),
correct_out_features
);
if
(
backward
)
{
...
...
@@ -170,9 +161,6 @@ void TestConv3dBase(const std::vector<int>& indices,
DenseTensor
d_indices_tensor
=
phi
::
Empty
(
dev_ctx_gpu
,
DenseTensorMeta
(
DataType
::
INT32
,
{
4
,
non_zero_num
},
DataLayout
::
NCHW
));
dev_ctx_gpu
.
Alloc
(
&
d_indices_tensor
,
d_indices_tensor
.
dtype
(),
sizeof
(
int
)
*
d_indices_tensor
.
numel
());
phi
::
Copy
(
dev_ctx_gpu
,
indices_tensor
,
phi
::
GPUPlace
(),
true
,
&
d_indices_tensor
);
...
...
@@ -181,9 +169,6 @@ void TestConv3dBase(const std::vector<int>& indices,
DenseTensorMeta
(
paddle
::
experimental
::
CppTypeToDataType
<
T
>::
Type
(),
{
non_zero_num
,
in_channels
},
DataLayout
::
NHWC
));
dev_ctx_gpu
.
Alloc
(
&
d_features_tensor
,
d_features_tensor
.
dtype
(),
sizeof
(
T
)
*
d_features_tensor
.
numel
());
phi
::
Copy
(
dev_ctx_gpu
,
features_tensor
,
phi
::
GPUPlace
(),
true
,
&
d_features_tensor
);
...
...
@@ -194,13 +179,11 @@ void TestConv3dBase(const std::vector<int>& indices,
DenseTensorMeta
(
paddle
::
experimental
::
CppTypeToDataType
<
T
>::
Type
(),
kernel_dims
,
DataLayout
::
NHWC
));
dev_ctx_gpu
.
Alloc
(
&
d_kernel_tensor
,
d_kernel_tensor
.
dtype
(),
sizeof
(
T
)
*
d_kernel_tensor
.
numel
());
phi
::
Copy
(
dev_ctx_gpu
,
kernel_tensor
,
phi
::
GPUPlace
(),
true
,
&
d_kernel_tensor
);
DenseTensor
d_rulebook
=
phi
::
Empty
<
int
,
phi
::
GPUContext
>
(
dev_ctx_gpu
);
DenseTensor
d_rulebook
=
phi
::
Empty
(
dev_ctx_gpu
,
DenseTensorMeta
(
DataType
::
INT32
,
{
1
},
DataLayout
::
NCHW
));
SparseCooTensor
d_out
=
sparse
::
Conv3d
<
T
>
(
dev_ctx_gpu
,
d_x_tensor
,
d_kernel_tensor
,
...
...
@@ -219,9 +202,6 @@ void TestConv3dBase(const std::vector<int>& indices,
DenseTensor
h_indices_tensor
=
phi
::
Empty
(
dev_ctx_cpu
,
DenseTensorMeta
(
DataType
::
INT32
,
{
4
,
d_out
.
nnz
()},
DataLayout
::
NCHW
));
dev_ctx_cpu
.
Alloc
(
&
h_indices_tensor
,
h_indices_tensor
.
dtype
(),
sizeof
(
int
)
*
h_indices_tensor
.
numel
());
phi
::
Copy
(
dev_ctx_gpu
,
d_out
.
non_zero_indices
(),
phi
::
CPUPlace
(),
...
...
@@ -239,18 +219,34 @@ void TestConv3dBase(const std::vector<int>& indices,
{
d_out
.
nnz
()},
d_out
.
layout
()));
dev_ctx_cpu
.
Alloc
(
&
h_features_tensor
,
h_features_tensor
.
dtype
(),
sizeof
(
T
)
*
h_features_tensor
.
numel
());
phi
::
Copy
(
dev_ctx_gpu
,
d_out
.
non_zero_elements
(),
phi
::
CPUPlace
(),
true
,
&
h_features_tensor
);
for
(
uint64_t
i
=
0
;
i
<
correct_out_features
.
size
();
i
++
)
{
float
tmp
=
std
::
fabs
(
static_cast
<
float
>
(
correct_out_features
[
i
]
-
h_features_tensor
.
data
<
T
>
()[
i
]));
ASSERT_LT
(
tmp
,
diff
);
f_verify
(
h_features_tensor
.
data
<
T
>
(),
correct_out_features
);
if
(
backward
)
{
std
::
vector
<
DenseTensor
>
grads
=
sparse
::
Conv3dGrad
<
T
>
(
dev_ctx_gpu
,
d_x_tensor
,
d_rulebook
,
d_kernel_tensor
,
d_out
,
paddings
,
dilations
,
strides
,
1
);
DenseTensor
h_features_grad
=
phi
::
Empty
(
dev_ctx_cpu
,
DenseTensorMeta
(
grads
[
0
].
dtype
(),
grads
[
0
].
dims
(),
grads
[
0
].
layout
()));
phi
::
Copy
(
dev_ctx_gpu
,
grads
[
0
],
phi
::
CPUPlace
(),
true
,
&
h_features_grad
);
f_verify
(
h_features_grad
.
data
<
T
>
(),
features_grad
);
DenseTensor
h_kernel_grad
=
phi
::
Empty
(
dev_ctx_cpu
,
DenseTensorMeta
(
grads
[
1
].
dtype
(),
grads
[
1
].
dims
(),
grads
[
1
].
layout
()));
phi
::
Copy
(
dev_ctx_gpu
,
grads
[
1
],
phi
::
CPUPlace
(),
true
,
&
h_kernel_grad
);
f_verify
(
h_kernel_grad
.
data
<
T
>
(),
kernel_grad
);
}
#endif
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录