Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
0d78e491
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
0d78e491
编写于
3月 11, 2022
作者:
Z
zhangkaihuo
提交者:
GitHub
3月 11, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Submanifold convolution (#40363)
submanifold convolution
上级
17d8a5e0
变更
13
显示空白变更内容
内联
并排
Showing
13 changed file
with
521 addition
and
101 deletion
+521
-101
paddle/phi/kernels/sparse/convolution_grad_kernel.h
paddle/phi/kernels/sparse/convolution_grad_kernel.h
+4
-1
paddle/phi/kernels/sparse/convolution_kernel.h
paddle/phi/kernels/sparse/convolution_kernel.h
+12
-2
paddle/phi/kernels/sparse/cpu/convolution.h
paddle/phi/kernels/sparse/cpu/convolution.h
+24
-3
paddle/phi/kernels/sparse/cpu/convolution_grad_kernel.cc
paddle/phi/kernels/sparse/cpu/convolution_grad_kernel.cc
+54
-17
paddle/phi/kernels/sparse/cpu/convolution_kernel.cc
paddle/phi/kernels/sparse/cpu/convolution_kernel.cc
+2
-0
paddle/phi/kernels/sparse/cpu/submanifold_convolution_kernel.cu
.../phi/kernels/sparse/cpu/submanifold_convolution_kernel.cu
+30
-0
paddle/phi/kernels/sparse/gpu/convolution.cu.h
paddle/phi/kernels/sparse/gpu/convolution.cu.h
+5
-1
paddle/phi/kernels/sparse/gpu/convolution_grad_kernel.cu
paddle/phi/kernels/sparse/gpu/convolution_grad_kernel.cu
+73
-42
paddle/phi/kernels/sparse/gpu/convolution_kernel.cu
paddle/phi/kernels/sparse/gpu/convolution_kernel.cu
+202
-24
paddle/phi/tests/api/test_sparse_conv_api.cc
paddle/phi/tests/api/test_sparse_conv_api.cc
+1
-1
paddle/phi/tests/kernels/test_sparse_conv3d_dev_api.cc
paddle/phi/tests/kernels/test_sparse_conv3d_dev_api.cc
+110
-6
python/paddle/utils/code_gen/sparse_api.yaml
python/paddle/utils/code_gen/sparse_api.yaml
+1
-1
python/paddle/utils/code_gen/sparse_bw_api.yaml
python/paddle/utils/code_gen/sparse_bw_api.yaml
+3
-3
未找到文件。
paddle/phi/kernels/sparse/convolution_grad_kernel.h
浏览文件 @
0d78e491
...
@@ -32,6 +32,7 @@ void Conv3dGradKernel(const Context& dev_ctx,
...
@@ -32,6 +32,7 @@ void Conv3dGradKernel(const Context& dev_ctx,
const
std
::
vector
<
int
>&
dilations
,
const
std
::
vector
<
int
>&
dilations
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
strides
,
const
int
groups
,
const
int
groups
,
const
bool
subm
,
DenseTensor
*
x_grad
,
DenseTensor
*
x_grad
,
DenseTensor
*
kernel_grad
);
DenseTensor
*
kernel_grad
);
...
@@ -44,7 +45,8 @@ std::vector<DenseTensor> Conv3dGrad(const Context& dev_ctx,
...
@@ -44,7 +45,8 @@ std::vector<DenseTensor> Conv3dGrad(const Context& dev_ctx,
const
std
::
vector
<
int
>&
paddings
,
const
std
::
vector
<
int
>&
paddings
,
const
std
::
vector
<
int
>&
dilations
,
const
std
::
vector
<
int
>&
dilations
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
strides
,
const
int
groups
)
{
const
int
groups
,
const
bool
subm
)
{
DenseTensor
x_grad
=
DenseTensor
x_grad
=
phi
::
Empty
<
Context
>
(
dev_ctx
,
DenseTensorMeta
(
x
.
dtype
(),
{
1
},
x
.
layout
()));
phi
::
Empty
<
Context
>
(
dev_ctx
,
DenseTensorMeta
(
x
.
dtype
(),
{
1
},
x
.
layout
()));
DenseTensor
kernel_grad
=
phi
::
Empty
<
Context
>
(
DenseTensor
kernel_grad
=
phi
::
Empty
<
Context
>
(
...
@@ -59,6 +61,7 @@ std::vector<DenseTensor> Conv3dGrad(const Context& dev_ctx,
...
@@ -59,6 +61,7 @@ std::vector<DenseTensor> Conv3dGrad(const Context& dev_ctx,
dilations
,
dilations
,
strides
,
strides
,
groups
,
groups
,
subm
,
&
x_grad
,
&
x_grad
,
&
kernel_grad
);
&
kernel_grad
);
std
::
vector
<
DenseTensor
>
out
(
2
);
std
::
vector
<
DenseTensor
>
out
(
2
);
...
...
paddle/phi/kernels/sparse/convolution_kernel.h
浏览文件 @
0d78e491
...
@@ -125,6 +125,7 @@ void Conv3dKernel(const Context& dev_ctx,
...
@@ -125,6 +125,7 @@ void Conv3dKernel(const Context& dev_ctx,
const
std
::
vector
<
int
>&
dilations
,
const
std
::
vector
<
int
>&
dilations
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
strides
,
const
int
groups
,
const
int
groups
,
const
bool
subm
,
SparseCooTensor
*
out
,
SparseCooTensor
*
out
,
DenseTensor
*
rulebook
);
DenseTensor
*
rulebook
);
...
@@ -136,14 +137,23 @@ SparseCooTensor Conv3d(const Context& dev_ctx,
...
@@ -136,14 +137,23 @@ SparseCooTensor Conv3d(const Context& dev_ctx,
const
std
::
vector
<
int
>&
dilations
,
const
std
::
vector
<
int
>&
dilations
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
strides
,
const
int
groups
,
const
int
groups
,
const
bool
subm
,
DenseTensor
*
rulebook
)
{
DenseTensor
*
rulebook
)
{
DenseTensor
indices
=
phi
::
Empty
<
Context
>
(
DenseTensor
indices
=
phi
::
Empty
<
Context
>
(
dev_ctx
,
DenseTensorMeta
(
DataType
::
INT32
,
{
1
},
DataLayout
::
NCHW
));
dev_ctx
,
DenseTensorMeta
(
DataType
::
INT32
,
{
1
},
DataLayout
::
NCHW
));
DenseTensor
values
=
DenseTensor
values
=
phi
::
Empty
<
Context
>
(
dev_ctx
,
DenseTensorMeta
(
x
.
dtype
(),
{
1
},
x
.
layout
()));
phi
::
Empty
<
Context
>
(
dev_ctx
,
DenseTensorMeta
(
x
.
dtype
(),
{
1
},
x
.
layout
()));
SparseCooTensor
coo
(
indices
,
values
,
x
.
dims
());
SparseCooTensor
coo
(
indices
,
values
,
x
.
dims
());
Conv3dKernel
<
T
,
Context
>
(
Conv3dKernel
<
T
,
Context
>
(
dev_ctx
,
dev_ctx
,
x
,
kernel
,
paddings
,
dilations
,
strides
,
groups
,
&
coo
,
rulebook
);
x
,
kernel
,
paddings
,
dilations
,
strides
,
groups
,
subm
,
&
coo
,
rulebook
);
return
coo
;
return
coo
;
}
}
...
...
paddle/phi/kernels/sparse/cpu/convolution.h
浏览文件 @
0d78e491
...
@@ -39,6 +39,7 @@ void ProductRuleBook(const Context& dev_ctx,
...
@@ -39,6 +39,7 @@ void ProductRuleBook(const Context& dev_ctx,
const
std
::
vector
<
int
>&
dilations
,
const
std
::
vector
<
int
>&
dilations
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
strides
,
const
DDim
&
out_dims
,
const
DDim
&
out_dims
,
const
bool
subm
,
DenseTensor
*
rulebook
,
DenseTensor
*
rulebook
,
DenseTensor
*
counter_per_kernel
)
{
DenseTensor
*
counter_per_kernel
)
{
const
auto
&
kernel_dims
=
kernel
.
dims
();
const
auto
&
kernel_dims
=
kernel
.
dims
();
...
@@ -59,11 +60,24 @@ void ProductRuleBook(const Context& dev_ctx,
...
@@ -59,11 +60,24 @@ void ProductRuleBook(const Context& dev_ctx,
const
Dims4D
c_strides
(
1
,
strides
[
2
],
strides
[
1
],
strides
[
0
]);
const
Dims4D
c_strides
(
1
,
strides
[
2
],
strides
[
1
],
strides
[
0
]);
const
Dims4D
c_dilations
(
1
,
dilations
[
2
],
dilations
[
1
],
dilations
[
0
]);
const
Dims4D
c_dilations
(
1
,
dilations
[
2
],
dilations
[
1
],
dilations
[
0
]);
std
::
set
<
int
>
hash_in
;
if
(
subm
)
{
for
(
int
i
=
0
;
i
<
non_zero_num
;
i
++
)
{
int
batch
=
indices_ptr
[
i
];
int
in_z
=
indices_ptr
[
i
+
non_zero_num
];
int
in_y
=
indices_ptr
[
i
+
2
*
non_zero_num
];
int
in_x
=
indices_ptr
[
i
+
3
*
non_zero_num
];
int
index
=
PointToIndex
<
DDim
>
(
batch
,
in_x
,
in_y
,
in_z
,
x_dims
);
hash_in
.
insert
(
index
);
}
}
auto
f_calc_rulebook
=
[
&
](
int
*
rulebook_ptr
)
{
auto
f_calc_rulebook
=
[
&
](
int
*
rulebook_ptr
)
{
int
kernel_index
=
0
,
rulebook_index
=
0
;
int
kernel_index
=
0
,
rulebook_index
=
0
;
for
(
int
kz
=
0
;
kz
<
kernel_dims
[
0
];
kz
++
)
{
for
(
int
kz
=
0
;
kz
<
kernel_dims
[
0
];
kz
++
)
{
for
(
int
ky
=
0
;
ky
<
kernel_dims
[
1
];
ky
++
)
{
for
(
int
ky
=
0
;
ky
<
kernel_dims
[
1
];
ky
++
)
{
for
(
int
kx
=
0
;
kx
<
kernel_dims
[
2
];
kx
++
)
{
for
(
int
kx
=
0
;
kx
<
kernel_dims
[
2
];
kx
++
)
{
++
kernel_index
;
for
(
int64_t
i
=
0
;
i
<
non_zero_num
;
i
++
)
{
for
(
int64_t
i
=
0
;
i
<
non_zero_num
;
i
++
)
{
int
batch
=
indices_ptr
[
i
];
int
batch
=
indices_ptr
[
i
];
int
in_z
=
indices_ptr
[
i
+
non_zero_num
];
int
in_z
=
indices_ptr
[
i
+
non_zero_num
];
...
@@ -83,11 +97,19 @@ void ProductRuleBook(const Context& dev_ctx,
...
@@ -83,11 +97,19 @@ void ProductRuleBook(const Context& dev_ctx,
kx
,
kx
,
ky
,
ky
,
kz
))
{
kz
))
{
if
(
subm
)
{
int
out_index
=
PointToIndex
<
DDim
>
(
batch
,
out_x
,
out_y
,
out_z
,
out_dims
);
if
(
hash_in
.
find
(
out_index
)
==
hash_in
.
end
())
{
continue
;
}
}
if
(
rulebook_ptr
==
nullptr
)
{
if
(
rulebook_ptr
==
nullptr
)
{
counter_ptr
[
kernel_index
]
+=
1
;
counter_ptr
[
kernel_index
-
1
]
+=
1
;
++
rulebook_len
;
++
rulebook_len
;
}
else
{
}
else
{
rulebook_ptr
[
rulebook_index
]
=
kernel_index
;
rulebook_ptr
[
rulebook_index
]
=
kernel_index
-
1
;
rulebook_ptr
[
rulebook_index
+
rulebook_len
]
=
i
;
// in_i
rulebook_ptr
[
rulebook_index
+
rulebook_len
]
=
i
;
// in_i
rulebook_ptr
[
rulebook_index
+
rulebook_len
*
2
]
=
rulebook_ptr
[
rulebook_index
+
rulebook_len
*
2
]
=
PointToIndex
<
DDim
>
(
PointToIndex
<
DDim
>
(
...
@@ -96,7 +118,6 @@ void ProductRuleBook(const Context& dev_ctx,
...
@@ -96,7 +118,6 @@ void ProductRuleBook(const Context& dev_ctx,
}
}
}
}
}
}
++
kernel_index
;
}
}
}
}
}
}
...
...
paddle/phi/kernels/sparse/cpu/convolution_grad_kernel.cc
浏览文件 @
0d78e491
...
@@ -38,6 +38,7 @@ void Conv3dGradKernel(const Context& dev_ctx,
...
@@ -38,6 +38,7 @@ void Conv3dGradKernel(const Context& dev_ctx,
const
std
::
vector
<
int
>&
dilations
,
const
std
::
vector
<
int
>&
dilations
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
strides
,
const
int
groups
,
const
int
groups
,
const
bool
subm
,
DenseTensor
*
x_grad
,
DenseTensor
*
x_grad
,
DenseTensor
*
kernel_grad
)
{
DenseTensor
*
kernel_grad
)
{
const
auto
&
kernel_dims
=
kernel
.
dims
();
const
auto
&
kernel_dims
=
kernel
.
dims
();
...
@@ -70,32 +71,72 @@ void Conv3dGradKernel(const Context& dev_ctx,
...
@@ -70,32 +71,72 @@ void Conv3dGradKernel(const Context& dev_ctx,
T
*
d_kernel_ptr
=
kernel_grad
->
data
<
T
>
();
T
*
d_kernel_ptr
=
kernel_grad
->
data
<
T
>
();
memset
(
d_kernel_ptr
,
0
,
sizeof
(
T
)
*
kernel_grad
->
numel
());
memset
(
d_kernel_ptr
,
0
,
sizeof
(
T
)
*
kernel_grad
->
numel
());
Gather
<
T
>
(
x
.
non_zero_elements
().
data
<
T
>
(),
int
half_kernel_size
=
kernel_size
/
2
;
rulebook_ptr
+
rulebook_len
,
rulebook_len
,
in_channels
,
in_features_ptr
);
Gather
<
T
>
(
out_grad
.
non_zero_elements
().
data
<
T
>
(),
rulebook_ptr
+
rulebook_len
*
2
,
rulebook_len
,
out_channels
,
out_grad_features_ptr
);
auto
blas
=
phi
::
funcs
::
GetBlas
<
Context
,
T
>
(
dev_ctx
);
auto
blas
=
phi
::
funcs
::
GetBlas
<
Context
,
T
>
(
dev_ctx
);
x_grad
->
Resize
(
x
.
non_zero_elements
().
dims
());
dev_ctx
.
Alloc
(
x_grad
,
x_grad
->
dtype
(),
sizeof
(
T
)
*
x_grad
->
numel
());
T
*
x_grad_values_ptr
=
x_grad
->
data
<
T
>
();
memset
(
x_grad_values_ptr
,
0
,
sizeof
(
T
)
*
x_grad
->
numel
());
memset
(
d_x_features_ptr
,
0
,
sizeof
(
T
)
*
d_x_features
.
numel
());
std
::
vector
<
int
>
offsets
(
kernel_size
+
1
),
counter
(
kernel_size
,
0
);
std
::
vector
<
int
>
offsets
(
kernel_size
+
1
),
counter
(
kernel_size
,
0
);
for
(
int
i
=
0
;
i
<
rulebook_len
;
i
++
)
{
for
(
int
i
=
0
;
i
<
rulebook_len
;
i
++
)
{
counter
[
rulebook_ptr
[
i
]]
+=
1
;
counter
[
rulebook_ptr
[
i
]]
+=
1
;
}
}
int
offset
=
0
;
int
offset
=
0
,
max_count
=
0
;
for
(
int
i
=
0
;
i
<
kernel_size
;
i
++
)
{
for
(
int
i
=
0
;
i
<
kernel_size
;
i
++
)
{
offsets
[
i
]
=
offset
;
offsets
[
i
]
=
offset
;
offset
+=
counter
[
i
];
offset
+=
counter
[
i
];
if
(
i
<
half_kernel_size
)
{
max_count
=
std
::
max
(
max_count
,
counter
[
i
]);
}
}
}
offsets
[
kernel_size
]
=
offset
;
offsets
[
kernel_size
]
=
offset
;
if
(
subm
)
{
blas
.
GEMM
(
CblasTrans
,
CblasNoTrans
,
x
.
non_zero_elements
().
dims
()[
1
],
out_grad
.
non_zero_elements
().
dims
()[
1
],
x
.
non_zero_elements
().
dims
()[
0
],
static_cast
<
T
>
(
1
),
x
.
non_zero_elements
().
data
<
T
>
(),
out_grad
.
non_zero_elements
().
data
<
T
>
(),
static_cast
<
T
>
(
0
),
d_kernel_ptr
+
half_kernel_size
*
in_channels
*
out_channels
);
// call gemm: d_x = out_grad * transpose(kernel)
// (n, out_channels) * (out_channels, in_channels)
T
*
x_grad_ptr
=
x_grad
->
data
<
T
>
();
blas
.
GEMM
(
CblasNoTrans
,
CblasTrans
,
out_grad
.
non_zero_elements
().
dims
()[
0
],
in_channels
,
out_grad
.
non_zero_elements
().
dims
()[
1
],
static_cast
<
T
>
(
1
),
out_grad
.
non_zero_elements
().
data
<
T
>
(),
kernel
.
data
<
T
>
()
+
half_kernel_size
*
in_channels
*
out_channels
,
static_cast
<
T
>
(
0
),
x_grad_ptr
);
if
(
max_count
==
0
)
{
return
;
}
}
Gather
<
T
>
(
x
.
non_zero_elements
().
data
<
T
>
(),
rulebook_ptr
+
rulebook_len
,
rulebook_len
,
in_channels
,
in_features_ptr
);
Gather
<
T
>
(
out_grad
.
non_zero_elements
().
data
<
T
>
(),
rulebook_ptr
+
rulebook_len
*
2
,
rulebook_len
,
out_channels
,
out_grad_features_ptr
);
const
T
*
kernel_ptr
=
kernel
.
data
<
T
>
();
const
T
*
kernel_ptr
=
kernel
.
data
<
T
>
();
for
(
int
i
=
0
;
i
<
kernel_size
;
i
++
)
{
for
(
int
i
=
0
;
i
<
kernel_size
;
i
++
)
{
if
(
counter
[
i
]
<=
0
)
{
if
(
counter
[
i
]
<=
0
||
(
subm
&&
i
==
half_kernel_size
)
)
{
continue
;
continue
;
}
}
...
@@ -136,10 +177,6 @@ void Conv3dGradKernel(const Context& dev_ctx,
...
@@ -136,10 +177,6 @@ void Conv3dGradKernel(const Context& dev_ctx,
}
}
// 4. scatter
// 4. scatter
x_grad
->
Resize
(
x
.
non_zero_elements
().
dims
());
dev_ctx
.
Alloc
(
x_grad
,
x_grad
->
dtype
(),
sizeof
(
T
)
*
x_grad
->
numel
());
T
*
x_grad_values_ptr
=
x_grad
->
data
<
T
>
();
memset
(
x_grad_values_ptr
,
0
,
sizeof
(
T
)
*
x_grad
->
numel
());
Scatter
<
T
>
(
d_x_features_ptr
,
Scatter
<
T
>
(
d_x_features_ptr
,
rulebook
.
data
<
int
>
()
+
rulebook_len
,
rulebook
.
data
<
int
>
()
+
rulebook_len
,
rulebook_len
,
rulebook_len
,
...
...
paddle/phi/kernels/sparse/cpu/convolution_kernel.cc
浏览文件 @
0d78e491
...
@@ -35,6 +35,7 @@ void Conv3dKernel(const Context& dev_ctx,
...
@@ -35,6 +35,7 @@ void Conv3dKernel(const Context& dev_ctx,
const
std
::
vector
<
int
>&
dilations
,
const
std
::
vector
<
int
>&
dilations
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
strides
,
const
int
groups
,
const
int
groups
,
const
bool
subm
,
SparseCooTensor
*
out
,
SparseCooTensor
*
out
,
DenseTensor
*
rulebook
)
{
DenseTensor
*
rulebook
)
{
// update padding and dilation
// update padding and dilation
...
@@ -63,6 +64,7 @@ void Conv3dKernel(const Context& dev_ctx,
...
@@ -63,6 +64,7 @@ void Conv3dKernel(const Context& dev_ctx,
dilations
,
dilations
,
strides
,
strides
,
out_dims
,
out_dims
,
subm
,
rulebook
,
rulebook
,
&
counter_per_kernel
);
&
counter_per_kernel
);
...
...
paddle/phi/kernels/sparse/cpu/submanifold_convolution_kernel.cu
0 → 100644
浏览文件 @
0d78e491
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <set>
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/tensor_meta.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/sparse/submanifold_convolution_kernel.h"
namespace
phi
{
namespace
sparse
{}
// namespace sparse
}
// namespace phi
paddle/phi/kernels/sparse/gpu/convolution.cu.h
浏览文件 @
0d78e491
...
@@ -71,7 +71,8 @@ __global__ void ScatterKernel(const T* input,
...
@@ -71,7 +71,8 @@ __global__ void ScatterKernel(const T* input,
const
int
non_zero_num
,
const
int
non_zero_num
,
const
int
rulebook_len
,
const
int
rulebook_len
,
const
int
channels
,
const
int
channels
,
T
*
out
)
{
T
*
out
,
const
bool
subm
=
false
)
{
int
tid
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
int
tid
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
for
(
int
i
=
tid
;
i
<
non_zero_num
*
channels
;
i
+=
gridDim
.
x
*
blockDim
.
x
)
{
for
(
int
i
=
tid
;
i
<
non_zero_num
*
channels
;
i
+=
gridDim
.
x
*
blockDim
.
x
)
{
int
indices_i
=
i
/
channels
;
int
indices_i
=
i
/
channels
;
...
@@ -82,6 +83,9 @@ __global__ void ScatterKernel(const T* input,
...
@@ -82,6 +83,9 @@ __global__ void ScatterKernel(const T* input,
:
unique_value
[
indices_i
+
1
];
:
unique_value
[
indices_i
+
1
];
// max(end-start) = kernel_size
// max(end-start) = kernel_size
T
sum
=
static_cast
<
T
>
(
0
);
T
sum
=
static_cast
<
T
>
(
0
);
if
(
subm
)
{
sum
=
out
[
indices_i
*
channels
+
channels_i
];
}
for
(
int
j
=
start
;
j
<
end
;
j
++
)
{
for
(
int
j
=
start
;
j
<
end
;
j
++
)
{
const
int
out_feature_i
=
out_index
[
j
];
const
int
out_feature_i
=
out_index
[
j
];
sum
+=
input
[
out_feature_i
*
channels
+
channels_i
];
sum
+=
input
[
out_feature_i
*
channels
+
channels_i
];
...
...
paddle/phi/kernels/sparse/gpu/convolution_grad_kernel.cu
浏览文件 @
0d78e491
...
@@ -43,6 +43,7 @@ void Conv3dGradKernel(const Context& dev_ctx,
...
@@ -43,6 +43,7 @@ void Conv3dGradKernel(const Context& dev_ctx,
const
std
::
vector
<
int
>&
dilations
,
const
std
::
vector
<
int
>&
dilations
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
strides
,
const
int
groups
,
const
int
groups
,
const
bool
subm
,
DenseTensor
*
x_grad
,
DenseTensor
*
x_grad
,
DenseTensor
*
kernel_grad
)
{
DenseTensor
*
kernel_grad
)
{
const
auto
&
kernel_dims
=
kernel
.
dims
();
const
auto
&
kernel_dims
=
kernel
.
dims
();
...
@@ -69,37 +70,18 @@ void Conv3dGradKernel(const Context& dev_ctx,
...
@@ -69,37 +70,18 @@ void Conv3dGradKernel(const Context& dev_ctx,
T
*
in_features_ptr
=
in_features
.
data
<
T
>
();
T
*
in_features_ptr
=
in_features
.
data
<
T
>
();
T
*
d_x_features_ptr
=
d_x_features
.
data
<
T
>
();
T
*
d_x_features_ptr
=
d_x_features
.
data
<
T
>
();
T
*
out_grad_features_ptr
=
out_grad_features
.
data
<
T
>
();
T
*
out_grad_features_ptr
=
out_grad_features
.
data
<
T
>
();
kernel_grad
->
Resize
(
kernel_dims
);
kernel_grad
->
ResizeAndAllocate
(
kernel_dims
);
dev_ctx
.
Alloc
(
kernel_grad
,
kernel_grad
->
dtype
(),
kernel_grad
->
numel
()
*
sizeof
(
T
));
T
*
d_kernel_ptr
=
kernel_grad
->
data
<
T
>
();
T
*
d_kernel_ptr
=
kernel_grad
->
data
<
T
>
();
phi
::
funcs
::
SetConstant
<
Context
,
T
>
set_zero
;
phi
::
funcs
::
SetConstant
<
Context
,
T
>
set_zero
;
set_zero
(
dev_ctx
,
kernel_grad
,
static_cast
<
T
>
(
0.0
f
));
set_zero
(
dev_ctx
,
kernel_grad
,
static_cast
<
T
>
(
0.0
f
));
auto
config
=
phi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
int
half_kernel_size
=
kernel_size
/
2
;
dev_ctx
,
rulebook_len
*
in_channels
,
1
);
GatherKernel
<
T
,
int
><<<
config
.
block_per_grid
.
x
,
config
.
thread_per_block
.
x
,
0
,
dev_ctx
.
stream
()
>>>
(
x
.
non_zero_elements
().
data
<
T
>
(),
rulebook_ptr
+
rulebook_len
,
in_features_ptr
,
rulebook_len
,
in_channels
);
config
=
phi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
dev_ctx
,
rulebook_len
*
out_channels
,
1
);
GatherKernel
<
T
,
int
><<<
config
.
block_per_grid
.
x
,
config
.
thread_per_block
.
x
,
0
,
dev_ctx
.
stream
()
>>>
(
out_grad
.
non_zero_elements
().
data
<
T
>
(),
rulebook_ptr
+
rulebook_len
*
2
,
out_grad_features_ptr
,
rulebook_len
,
out_channels
);
auto
blas
=
phi
::
funcs
::
GetBlas
<
Context
,
T
>
(
dev_ctx
);
auto
blas
=
phi
::
funcs
::
GetBlas
<
Context
,
T
>
(
dev_ctx
);
x_grad
->
ResizeAndAllocate
(
x
.
non_zero_elements
().
dims
());
T
*
x_grad_values_ptr
=
x_grad
->
data
<
T
>
();
set_zero
(
dev_ctx
,
x_grad
,
static_cast
<
T
>
(
0.0
f
));
set_zero
(
dev_ctx
,
&
d_x_features
,
static_cast
<
T
>
(
0.0
f
));
std
::
vector
<
int
>
offsets
(
kernel_size
+
1
),
counter
(
kernel_size
,
0
),
std
::
vector
<
int
>
offsets
(
kernel_size
+
1
),
counter
(
kernel_size
,
0
),
h_counter
(
rulebook_len
,
0
);
h_counter
(
rulebook_len
,
0
);
phi
::
backends
::
gpu
::
GpuMemcpyAsync
(
&
h_counter
[
0
],
phi
::
backends
::
gpu
::
GpuMemcpyAsync
(
&
h_counter
[
0
],
...
@@ -117,16 +99,72 @@ void Conv3dGradKernel(const Context& dev_ctx,
...
@@ -117,16 +99,72 @@ void Conv3dGradKernel(const Context& dev_ctx,
for
(
int
i
=
0
;
i
<
rulebook_len
;
i
++
)
{
for
(
int
i
=
0
;
i
<
rulebook_len
;
i
++
)
{
counter
[
h_counter
[
i
]]
+=
1
;
counter
[
h_counter
[
i
]]
+=
1
;
}
}
int
offset
=
0
;
int
offset
=
0
,
max_count
=
0
;
for
(
int
i
=
0
;
i
<
kernel_size
;
i
++
)
{
for
(
int
i
=
0
;
i
<
kernel_size
;
i
++
)
{
offsets
[
i
]
=
offset
;
offsets
[
i
]
=
offset
;
offset
+=
counter
[
i
];
offset
+=
counter
[
i
];
if
(
i
<
half_kernel_size
)
{
max_count
=
std
::
max
(
max_count
,
counter
[
i
]);
}
}
}
offsets
[
kernel_size
]
=
offset
;
offsets
[
kernel_size
]
=
offset
;
if
(
subm
)
{
blas
.
GEMM
(
CblasTrans
,
CblasNoTrans
,
x
.
non_zero_elements
().
dims
()[
1
],
out_grad
.
non_zero_elements
().
dims
()[
1
],
x
.
non_zero_elements
().
dims
()[
0
],
static_cast
<
T
>
(
1
),
x
.
non_zero_elements
().
data
<
T
>
(),
out_grad
.
non_zero_elements
().
data
<
T
>
(),
static_cast
<
T
>
(
0
),
d_kernel_ptr
+
half_kernel_size
*
in_channels
*
out_channels
);
// call gemm: d_x = out_grad * transpose(kernel)
// (n, out_channels) * (out_channels, in_channels)
T
*
x_grad_ptr
=
x_grad
->
data
<
T
>
();
blas
.
GEMM
(
CblasNoTrans
,
CblasTrans
,
out_grad
.
non_zero_elements
().
dims
()[
0
],
in_channels
,
out_grad
.
non_zero_elements
().
dims
()[
1
],
static_cast
<
T
>
(
1
),
out_grad
.
non_zero_elements
().
data
<
T
>
(),
kernel
.
data
<
T
>
()
+
half_kernel_size
*
in_channels
*
out_channels
,
static_cast
<
T
>
(
0
),
x_grad_ptr
);
if
(
max_count
==
0
)
{
return
;
}
}
auto
config
=
phi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
dev_ctx
,
rulebook_len
*
in_channels
,
1
);
GatherKernel
<
T
,
int
><<<
config
.
block_per_grid
.
x
,
config
.
thread_per_block
.
x
,
0
,
dev_ctx
.
stream
()
>>>
(
x
.
non_zero_elements
().
data
<
T
>
(),
rulebook_ptr
+
rulebook_len
,
in_features_ptr
,
rulebook_len
,
in_channels
);
config
=
phi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
dev_ctx
,
rulebook_len
*
out_channels
,
1
);
GatherKernel
<
T
,
int
><<<
config
.
block_per_grid
.
x
,
config
.
thread_per_block
.
x
,
0
,
dev_ctx
.
stream
()
>>>
(
out_grad
.
non_zero_elements
().
data
<
T
>
(),
rulebook_ptr
+
rulebook_len
*
2
,
out_grad_features_ptr
,
rulebook_len
,
out_channels
);
const
T
*
kernel_ptr
=
kernel
.
data
<
T
>
();
const
T
*
kernel_ptr
=
kernel
.
data
<
T
>
();
for
(
int
i
=
0
;
i
<
kernel_size
;
i
++
)
{
for
(
int
i
=
0
;
i
<
kernel_size
;
i
++
)
{
if
(
counter
[
i
]
<=
0
)
{
if
(
counter
[
i
]
<=
0
||
(
subm
&&
i
==
half_kernel_size
)
)
{
continue
;
continue
;
}
}
...
@@ -167,19 +205,11 @@ void Conv3dGradKernel(const Context& dev_ctx,
...
@@ -167,19 +205,11 @@ void Conv3dGradKernel(const Context& dev_ctx,
}
}
// 4. scatter
// 4. scatter
x_grad
->
Resize
(
x
.
non_zero_elements
().
dims
());
x_grad
->
ResizeAndAllocate
(
x
.
non_zero_elements
().
dims
());
dev_ctx
.
Alloc
(
x_grad
,
x_grad
->
dtype
(),
sizeof
(
T
)
*
x_grad
->
numel
());
DenseTensorMeta
index_meta
(
DataType
::
INT32
,
{
rulebook_len
},
DataLayout
::
NCHW
);
T
*
x_grad_values_ptr
=
x_grad
->
data
<
T
>
();
DenseTensor
out_index
=
phi
::
Empty
(
dev_ctx
,
std
::
move
(
index_meta
));
DenseTensor
unique_key
=
phi
::
Empty
(
dev_ctx
,
std
::
move
(
index_meta
));
DenseTensor
out_index
=
phi
::
Empty
(
DenseTensor
unique_value
=
phi
::
Empty
(
dev_ctx
,
std
::
move
(
index_meta
));
dev_ctx
,
DenseTensorMeta
(
DataType
::
INT32
,
{
rulebook_len
},
DataLayout
::
NCHW
));
DenseTensor
unique_key
=
phi
::
Empty
(
dev_ctx
,
DenseTensorMeta
(
DataType
::
INT32
,
{
rulebook_len
},
DataLayout
::
NCHW
));
DenseTensor
unique_value
=
phi
::
Empty
(
dev_ctx
,
DenseTensorMeta
(
DataType
::
INT32
,
{
rulebook_len
},
DataLayout
::
NCHW
));
SortedAndUniqueIndex
(
dev_ctx
,
SortedAndUniqueIndex
(
dev_ctx
,
rulebook_ptr
+
rulebook_len
,
rulebook_ptr
+
rulebook_len
,
...
@@ -200,7 +230,8 @@ void Conv3dGradKernel(const Context& dev_ctx,
...
@@ -200,7 +230,8 @@ void Conv3dGradKernel(const Context& dev_ctx,
x
.
nnz
(),
x
.
nnz
(),
rulebook_len
,
rulebook_len
,
in_channels
,
in_channels
,
x_grad_values_ptr
);
x_grad_values_ptr
,
subm
);
}
}
}
// namespace sparse
}
// namespace sparse
...
...
paddle/phi/kernels/sparse/gpu/convolution_kernel.cu
浏览文件 @
0d78e491
...
@@ -24,6 +24,7 @@ limitations under the License. */
...
@@ -24,6 +24,7 @@ limitations under the License. */
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_meta.h"
#include "paddle/phi/core/tensor_meta.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/index_impl.cu.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/primitive/compute_primitives.h"
#include "paddle/phi/kernels/primitive/compute_primitives.h"
#include "paddle/phi/kernels/sparse/convolution_kernel.h"
#include "paddle/phi/kernels/sparse/convolution_kernel.h"
...
@@ -32,6 +33,34 @@ limitations under the License. */
...
@@ -32,6 +33,34 @@ limitations under the License. */
namespace
phi
{
namespace
phi
{
namespace
sparse
{
namespace
sparse
{
__global__
void
SetFlagAndUpdateCounterKernel
(
const
int
*
indexs
,
const
int
n
,
const
int
rulebook_len
,
const
int
kernel_size
,
int
*
rulebook_ptr
,
int
*
counter_ptr
)
{
int
tid
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
extern
__shared__
int
cache_count
[];
// kernel_size
for
(
int
i
=
threadIdx
.
x
;
i
<
kernel_size
;
i
+=
blockDim
.
x
)
{
cache_count
[
i
]
=
0
;
}
__syncthreads
();
for
(
int
i
=
tid
;
i
<
n
;
i
+=
gridDim
.
x
*
blockDim
.
x
)
{
int
index
=
indexs
[
i
];
int
kernel_index
=
rulebook_ptr
[
index
];
rulebook_ptr
[
index
+
rulebook_len
]
=
-
1
;
rulebook_ptr
[
index
+
2
*
rulebook_len
]
=
-
1
;
rulebook_ptr
[
index
]
=
-
1
;
atomicAdd
(
&
cache_count
[
kernel_index
],
1
);
}
__syncthreads
();
for
(
int
i
=
threadIdx
.
x
;
i
<
kernel_size
;
i
+=
blockDim
.
x
)
{
atomicSub
(
&
counter_ptr
[
i
],
cache_count
[
i
]);
}
}
/**
/**
* @brief: update the out index and indices
* @brief: update the out index and indices
* unique_keys: save the index of the output feature list
* unique_keys: save the index of the output feature list
...
@@ -95,8 +124,10 @@ __global__ void ProductRuleBookKernel(const int* x_indices,
...
@@ -95,8 +124,10 @@ __global__ void ProductRuleBookKernel(const int* x_indices,
const
Dims4D
paddings
,
const
Dims4D
paddings
,
const
Dims4D
dilations
,
const
Dims4D
dilations
,
const
Dims4D
strides
,
const
Dims4D
strides
,
const
bool
subm
,
int
*
rulebook
,
int
*
rulebook
,
int
*
counter
)
{
int
*
counter
,
int
*
in_indexs
)
{
int
tid
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
int
tid
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
extern
__shared__
int
counter_buf
[];
// kernel_size
extern
__shared__
int
counter_buf
[];
// kernel_size
const
int
kernel_size
=
kernel_dims
[
3
]
*
kernel_dims
[
2
]
*
kernel_dims
[
1
];
const
int
kernel_size
=
kernel_dims
[
3
]
*
kernel_dims
[
2
]
*
kernel_dims
[
1
];
...
@@ -108,13 +139,16 @@ __global__ void ProductRuleBookKernel(const int* x_indices,
...
@@ -108,13 +139,16 @@ __global__ void ProductRuleBookKernel(const int* x_indices,
for
(
int
i
=
tid
;
i
<
non_zero_num
;
i
+=
gridDim
.
x
*
blockDim
.
x
)
{
for
(
int
i
=
tid
;
i
<
non_zero_num
;
i
+=
gridDim
.
x
*
blockDim
.
x
)
{
int
kernel_index
=
0
;
int
kernel_index
=
0
;
for
(
int
kz
=
0
;
kz
<
kernel_dims
[
1
];
kz
++
)
{
for
(
int
ky
=
0
;
ky
<
kernel_dims
[
2
];
ky
++
)
{
for
(
int
kx
=
0
;
kx
<
kernel_dims
[
3
];
kx
++
)
{
int
batch
=
x_indices
[
i
];
int
batch
=
x_indices
[
i
];
int
in_z
=
x_indices
[
i
+
non_zero_num
];
int
in_z
=
x_indices
[
i
+
non_zero_num
];
int
in_y
=
x_indices
[
i
+
2
*
non_zero_num
];
int
in_y
=
x_indices
[
i
+
2
*
non_zero_num
];
int
in_x
=
x_indices
[
i
+
3
*
non_zero_num
];
int
in_x
=
x_indices
[
i
+
3
*
non_zero_num
];
if
(
subm
)
{
in_indexs
[
i
]
=
PointToIndex
(
batch
,
in_x
,
in_y
,
in_z
,
x_dims
);
}
for
(
int
kz
=
0
;
kz
<
kernel_dims
[
1
];
kz
++
)
{
for
(
int
ky
=
0
;
ky
<
kernel_dims
[
2
];
ky
++
)
{
for
(
int
kx
=
0
;
kx
<
kernel_dims
[
3
];
kx
++
)
{
int
in_i
=
-
1
,
out_index
=
-
1
,
kernel_i
=
-
1
;
int
in_i
=
-
1
,
out_index
=
-
1
,
kernel_i
=
-
1
;
if
(
Check
(
x_dims
,
if
(
Check
(
x_dims
,
kernel_dims
,
kernel_dims
,
...
@@ -182,6 +216,7 @@ int ProductRuleBook(const Context& dev_ctx,
...
@@ -182,6 +216,7 @@ int ProductRuleBook(const Context& dev_ctx,
const
std
::
vector
<
int
>&
dilations
,
const
std
::
vector
<
int
>&
dilations
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
strides
,
const
DDim
&
out_dims
,
const
DDim
&
out_dims
,
const
bool
subm
,
DenseTensor
*
rulebook
,
DenseTensor
*
rulebook
,
DenseTensor
*
counter_per_kernel
,
DenseTensor
*
counter_per_kernel
,
DenseTensor
*
offsets_per_kernel
,
DenseTensor
*
offsets_per_kernel
,
...
@@ -195,13 +230,14 @@ int ProductRuleBook(const Context& dev_ctx,
...
@@ -195,13 +230,14 @@ int ProductRuleBook(const Context& dev_ctx,
const
int64_t
non_zero_num
=
x
.
nnz
();
const
int64_t
non_zero_num
=
x
.
nnz
();
const
auto
&
non_zero_indices
=
x
.
non_zero_indices
();
const
auto
&
non_zero_indices
=
x
.
non_zero_indices
();
const
int
*
indices_ptr
=
non_zero_indices
.
data
<
int
>
();
const
int
*
indices_ptr
=
non_zero_indices
.
data
<
int
>
();
DenseTensor
in_indexs
=
phi
::
Empty
<
Context
>
(
dev_ctx
,
DenseTensorMeta
(
DataType
::
INT32
,
{
x
.
nnz
()},
DataLayout
::
NCHW
));
int
*
counter_ptr
=
counter_per_kernel
->
data
<
int
>
();
int
*
counter_ptr
=
counter_per_kernel
->
data
<
int
>
();
int
*
offsets_ptr
=
offsets_per_kernel
->
data
<
int
>
();
int
*
offsets_ptr
=
offsets_per_kernel
->
data
<
int
>
();
int
kernel_size
=
kernel_dims
[
0
]
*
kernel_dims
[
1
]
*
kernel_dims
[
2
];
int
kernel_size
=
kernel_dims
[
0
]
*
kernel_dims
[
1
]
*
kernel_dims
[
2
];
const
int
rulebook_rows
=
3
;
const
int
rulebook_rows
=
3
;
const
int
rulebook_cols
=
kernel_size
*
non_zero_num
;
const
int
rulebook_cols
=
kernel_size
*
non_zero_num
;
rulebook
->
ResizeAndAllocate
({
rulebook_rows
,
rulebook_cols
});
rulebook
->
ResizeAndAllocate
({
rulebook_rows
,
rulebook_cols
});
dev_ctx
.
Alloc
(
rulebook
,
rulebook
->
dtype
(),
sizeof
(
int
)
*
rulebook
->
numel
());
int
*
rulebook_ptr
=
rulebook
->
data
<
int
>
();
int
*
rulebook_ptr
=
rulebook
->
data
<
int
>
();
const
auto
x_dims
=
x
.
dims
();
const
auto
x_dims
=
x
.
dims
();
...
@@ -229,8 +265,10 @@ int ProductRuleBook(const Context& dev_ctx,
...
@@ -229,8 +265,10 @@ int ProductRuleBook(const Context& dev_ctx,
d_paddings
,
d_paddings
,
d_dilations
,
d_dilations
,
d_strides
,
d_strides
,
subm
,
rulebook_ptr
,
rulebook_ptr
,
counter_ptr
);
counter_ptr
,
in_indexs
.
data
<
int
>
());
// 2. remove -1
// 2. remove -1
#ifdef PADDLE_WITH_HIP
#ifdef PADDLE_WITH_HIP
...
@@ -242,6 +280,144 @@ int ProductRuleBook(const Context& dev_ctx,
...
@@ -242,6 +280,144 @@ int ProductRuleBook(const Context& dev_ctx,
rulebook_ptr
+
rulebook_rows
*
rulebook_cols
,
rulebook_ptr
+
rulebook_rows
*
rulebook_cols
,
-
1
);
-
1
);
DistanceKernel
<<<
1
,
1
,
0
,
dev_ctx
.
stream
()
>>>
(
rulebook_ptr
,
last
,
rulebook_ptr
+
3
*
kernel_size
*
non_zero_num
-
1
);
int
rulebook_len
=
0
;
phi
::
backends
::
gpu
::
GpuMemcpyAsync
(
&
rulebook_len
,
rulebook_ptr
+
3
*
kernel_size
*
non_zero_num
-
1
,
sizeof
(
int
),
#ifdef PADDLE_WITH_HIP
hipMemcpyDeviceToHost
,
#else
cudaMemcpyDeviceToHost
,
#endif
dev_ctx
.
stream
());
rulebook_len
/=
3
;
dev_ctx
.
Wait
();
if
(
subm
)
{
// At present, hashtable is not used to map the input and output indexes.
// At present, the intermediate output index is generated by normal
// convolution,
// and then the intermediate output index is subtracted from the input index
// to obain the rulebook.
// get difference
int32_t
*
A_key_ptr
=
rulebook_ptr
+
2
*
rulebook_len
;
int32_t
*
B_key_ptr
=
in_indexs
.
data
<
int
>
();
DenseTensor
A_val
=
phi
::
Empty
<
Context
>
(
dev_ctx
,
DenseTensorMeta
(
DataType
::
INT32
,
{
rulebook_len
},
DataLayout
::
NCHW
));
DenseTensor
B_val
=
phi
::
Empty
<
Context
>
(
dev_ctx
,
DenseTensorMeta
(
DataType
::
INT32
,
{
x
.
nnz
()},
DataLayout
::
NCHW
));
phi
::
IndexKernel
<
int
,
kps
::
IdentityFunctor
<
int
>>
(
dev_ctx
,
&
A_val
,
kps
::
IdentityFunctor
<
int
>
());
phi
::
IndexKernel
<
int
,
kps
::
IdentityFunctor
<
int
>>
(
dev_ctx
,
&
B_val
,
kps
::
IdentityFunctor
<
int
>
());
DenseTensor
key_result
=
phi
::
Empty
<
Context
>
(
dev_ctx
,
DenseTensorMeta
(
DataType
::
INT32
,
{
rulebook_len
+
1
},
DataLayout
::
NCHW
));
DenseTensor
val_result
=
phi
::
Empty
<
Context
>
(
dev_ctx
,
DenseTensorMeta
(
DataType
::
INT32
,
{
rulebook_len
},
DataLayout
::
NCHW
));
#ifdef PADDLE_WITH_HIP
thrust
::
exclusive_scan
(
thrust
::
hip
::
par
.
on
(
dev_ctx
.
stream
()),
#else
thrust
::
exclusive_scan
(
thrust
::
cuda
::
par
.
on
(
dev_ctx
.
stream
()),
#endif
counter_ptr
,
counter_ptr
+
kernel_size
,
offsets_ptr
);
std
::
vector
<
int
>
offsets
(
kernel_size
,
0
);
// TODO(zhangkaihuo): used unified memcpy interface
phi
::
backends
::
gpu
::
GpuMemcpyAsync
(
offsets
.
data
(),
offsets_ptr
,
kernel_size
*
sizeof
(
int
),
#ifdef PADDLE_WITH_HIP
hipMemcpyDeviceToHost
,
#else
cudaMemcpyDeviceToHost
,
#endif
dev_ctx
.
stream
());
dev_ctx
.
Wait
();
thrust
::
pair
<
int
*
,
int
*>
end
;
// Because set_diff does not support duplicate data, set_diff is performed
// separately for each segment of data.
// TODO(zhangkaihuo): Using hashtable here may get better performance,
// further tests ared needed.
for
(
int
i
=
0
;
i
<
kernel_size
;
i
++
)
{
int
start
=
offsets
[
i
];
int
stop
=
i
==
kernel_size
-
1
?
rulebook_len
:
offsets
[
i
+
1
];
int
*
key_result_start
=
(
i
==
0
?
key_result
.
data
<
int
>
()
:
end
.
first
);
int
*
val_result_start
=
i
==
0
?
val_result
.
data
<
int
>
()
:
end
.
second
;
end
=
#ifdef PADDLE_WITH_HIP
thrust
::
set_difference_by_key
(
thrust
::
hip
::
par
.
on
(
dev_ctx
.
stream
()),
#else
thrust
::
set_difference_by_key
(
thrust
::
cuda
::
par
.
on
(
dev_ctx
.
stream
()),
#endif
A_key_ptr
+
start
,
A_key_ptr
+
stop
,
B_key_ptr
,
B_key_ptr
+
x
.
nnz
(),
A_val
.
data
<
int
>
()
+
start
,
B_val
.
data
<
int
>
(),
key_result_start
,
val_result_start
);
}
DistanceKernel
<<<
1
,
1
,
0
,
dev_ctx
.
stream
()
>>>
(
key_result
.
data
<
int
>
(),
end
.
first
,
key_result
.
data
<
int
>
()
+
rulebook_len
);
int
len
=
0
;
phi
::
backends
::
gpu
::
GpuMemcpyAsync
(
&
len
,
key_result
.
data
<
int
>
()
+
rulebook_len
,
sizeof
(
int
),
#ifdef PADDLE_WITH_HIP
hipMemcpyDeviceToHost
,
#else
cudaMemcpyDeviceToHost
,
#endif
dev_ctx
.
stream
());
dev_ctx
.
Wait
();
// set the diff value = -1, and update counter
auto
config
=
phi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
dev_ctx
,
len
,
1
);
SetFlagAndUpdateCounterKernel
<<<
config
.
block_per_grid
.
x
,
config
.
thread_per_block
,
kernel_size
*
sizeof
(
int
),
dev_ctx
.
stream
()
>>>
(
val_result
.
data
<
int
>
(),
len
,
rulebook_len
,
kernel_size
,
rulebook_ptr
,
counter_ptr
);
// remove -1
#ifdef PADDLE_WITH_HIP
int
*
last
=
thrust
::
remove
(
thrust
::
hip
::
par
.
on
(
dev_ctx
.
stream
()),
#else
int
*
last
=
thrust
::
remove
(
thrust
::
cuda
::
par
.
on
(
dev_ctx
.
stream
()),
#endif
rulebook_ptr
,
rulebook_ptr
+
3
*
rulebook_len
,
-
1
);
DistanceKernel
<<<
1
,
1
,
0
,
dev_ctx
.
stream
()
>>>
(
rulebook_ptr
,
last
,
key_result
.
data
<
int
>
()
+
rulebook_len
);
phi
::
backends
::
gpu
::
GpuMemcpyAsync
(
&
rulebook_len
,
key_result
.
data
<
int
>
()
+
rulebook_len
,
sizeof
(
int
),
#ifdef PADDLE_WITH_HIP
hipMemcpyDeviceToHost
,
#else
cudaMemcpyDeviceToHost
,
#endif
dev_ctx
.
stream
());
dev_ctx
.
Wait
();
rulebook_len
/=
3
;
}
#ifdef PADDLE_WITH_HIP
#ifdef PADDLE_WITH_HIP
thrust
::
exclusive_scan
(
thrust
::
hip
::
par
.
on
(
dev_ctx
.
stream
()),
thrust
::
exclusive_scan
(
thrust
::
hip
::
par
.
on
(
dev_ctx
.
stream
()),
#else
#else
...
@@ -274,23 +450,14 @@ int ProductRuleBook(const Context& dev_ctx,
...
@@ -274,23 +450,14 @@ int ProductRuleBook(const Context& dev_ctx,
cudaMemcpyDeviceToHost
,
cudaMemcpyDeviceToHost
,
dev_ctx
.
stream
());
dev_ctx
.
stream
());
#endif
#endif
dev_ctx
.
Wait
();
int
rulebook_len
=
(
*
h_counter
)[
kernel_size
-
1
]
+
(
*
h_offsets
)[
kernel_size
-
1
];
rulebook
->
Resize
({
rulebook_rows
,
rulebook_len
});
rulebook
->
Resize
({
rulebook_rows
,
rulebook_len
});
// 3. sorted or merge the out index
// 3. sorted or merge the out index
out_index
->
ResizeAndAllocate
({
rulebook_len
});
out_index
->
ResizeAndAllocate
({
rulebook_len
});
unique_value
->
ResizeAndAllocate
({
rulebook_len
});
unique_value
->
ResizeAndAllocate
({
rulebook_len
});
unique_key
->
ResizeAndAllocate
({
rulebook_len
});
unique_key
->
ResizeAndAllocate
({
rulebook_len
});
dev_ctx
.
Alloc
(
out_index
,
out_index
->
dtype
(),
sizeof
(
int
)
*
out_index
->
numel
());
int
*
out_index_ptr
=
out_index
->
data
<
int
>
();
int
*
out_index_ptr
=
out_index
->
data
<
int
>
();
dev_ctx
.
Alloc
(
unique_value
,
unique_value
->
dtype
(),
sizeof
(
int
)
*
unique_value
->
numel
());
int
*
unique_value_ptr
=
unique_value
->
data
<
int
>
();
int
*
unique_value_ptr
=
unique_value
->
data
<
int
>
();
dev_ctx
.
Alloc
(
unique_key
,
unique_key
->
dtype
(),
sizeof
(
int
)
*
unique_key
->
numel
());
int
*
unique_key_ptr
=
unique_key
->
data
<
int
>
();
int
*
unique_key_ptr
=
unique_key
->
data
<
int
>
();
int
*
new_end
=
SortedAndUniqueIndex
(
dev_ctx
,
int
*
new_end
=
SortedAndUniqueIndex
(
dev_ctx
,
...
@@ -364,6 +531,7 @@ void Conv3dKernel(const Context& dev_ctx,
...
@@ -364,6 +531,7 @@ void Conv3dKernel(const Context& dev_ctx,
const
std
::
vector
<
int
>&
dilations
,
const
std
::
vector
<
int
>&
dilations
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
strides
,
const
int
groups
,
const
int
groups
,
const
bool
subm
,
SparseCooTensor
*
out
,
SparseCooTensor
*
out
,
DenseTensor
*
rulebook
)
{
DenseTensor
*
rulebook
)
{
// update padding and dilation
// update padding and dilation
...
@@ -389,20 +557,28 @@ void Conv3dKernel(const Context& dev_ctx,
...
@@ -389,20 +557,28 @@ void Conv3dKernel(const Context& dev_ctx,
DataType
::
INT32
,
{
kernel_size
},
DataLayout
::
NCHW
);
DataType
::
INT32
,
{
kernel_size
},
DataLayout
::
NCHW
);
DenseTensor
counter_per_kernel
=
phi
::
Empty
(
dev_ctx
,
std
::
move
(
counter_meta
));
DenseTensor
counter_per_kernel
=
phi
::
Empty
(
dev_ctx
,
std
::
move
(
counter_meta
));
DenseTensor
offsets_per_kernel
=
phi
::
Empty
(
dev_ctx
,
std
::
move
(
offsets_meta
));
DenseTensor
offsets_per_kernel
=
phi
::
Empty
(
dev_ctx
,
std
::
move
(
offsets_meta
));
DenseTensor
out_index
=
phi
::
Empty
(
DenseTensorMeta
index_meta
(
DataType
::
INT32
,
{
1
},
DataLayout
::
NCHW
);
dev_ctx
,
DenseTensorMeta
(
DataType
::
INT32
,
{
1
},
DataLayout
::
NCHW
));
DenseTensor
out_index
=
phi
::
Empty
(
dev_ctx
,
std
::
move
(
index_meta
));
DenseTensor
unique_key
=
phi
::
Empty
(
DenseTensor
unique_key
=
phi
::
Empty
(
dev_ctx
,
std
::
move
(
index_meta
));
dev_ctx
,
DenseTensorMeta
(
DataType
::
INT32
,
{
1
},
DataLayout
::
NCHW
));
DenseTensor
unique_value
=
phi
::
Empty
(
dev_ctx
,
std
::
move
(
index_meta
));
DenseTensor
unique_value
=
phi
::
Empty
(
dev_ctx
,
DenseTensorMeta
(
DataType
::
INT32
,
{
1
},
DataLayout
::
NCHW
));
std
::
vector
<
int
>
subm_paddings
(
paddings
),
subm_strides
(
strides
);
if
(
subm
)
{
auto
kernel_dims
=
kernel
.
dims
();
for
(
int
i
=
0
;
i
<
paddings
.
size
();
i
++
)
{
subm_paddings
[
i
]
=
kernel_dims
[
i
]
/
2
;
subm_strides
[
i
]
=
1
;
}
}
int
n
=
ProductRuleBook
<
T
,
Context
>
(
dev_ctx
,
int
n
=
ProductRuleBook
<
T
,
Context
>
(
dev_ctx
,
x
,
x
,
kernel
,
kernel
,
paddings
,
subm_
paddings
,
dilations
,
dilations
,
strides
,
s
ubm_s
trides
,
out_dims
,
out_dims
,
subm
,
rulebook
,
rulebook
,
&
counter_per_kernel
,
&
counter_per_kernel
,
&
offsets_per_kernel
,
&
offsets_per_kernel
,
...
@@ -428,6 +604,8 @@ void Conv3dKernel(const Context& dev_ctx,
...
@@ -428,6 +604,8 @@ void Conv3dKernel(const Context& dev_ctx,
phi
::
Empty
(
dev_ctx
,
std
::
move
(
out_features_meta
));
phi
::
Empty
(
dev_ctx
,
std
::
move
(
out_features_meta
));
T
*
in_features_ptr
=
in_features
.
data
<
T
>
();
T
*
in_features_ptr
=
in_features
.
data
<
T
>
();
T
*
out_features_ptr
=
out_features
.
data
<
T
>
();
T
*
out_features_ptr
=
out_features
.
data
<
T
>
();
phi
::
funcs
::
SetConstant
<
Context
,
T
>
set_zero
;
set_zero
(
dev_ctx
,
&
out_features
,
static_cast
<
T
>
(
0.0
f
));
auto
config
=
auto
config
=
phi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
dev_ctx
,
n
*
in_channels
,
1
);
phi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
dev_ctx
,
n
*
in_channels
,
1
);
...
...
paddle/phi/tests/api/test_sparse_conv_api.cc
浏览文件 @
0d78e491
...
@@ -78,7 +78,7 @@ void TestConv3dBase(const std::vector<int>& indices,
...
@@ -78,7 +78,7 @@ void TestConv3dBase(const std::vector<int>& indices,
if
(
!
std
::
is_same
<
T
,
phi
::
dtype
::
float16
>::
value
)
{
if
(
!
std
::
is_same
<
T
,
phi
::
dtype
::
float16
>::
value
)
{
auto
outs
=
paddle
::
experimental
::
sparse
::
conv3d
(
auto
outs
=
paddle
::
experimental
::
sparse
::
conv3d
(
x
,
weight
,
paddings
,
dilations
,
strides
,
1
);
x
,
weight
,
paddings
,
dilations
,
strides
,
1
,
false
);
auto
out
=
std
::
dynamic_pointer_cast
<
phi
::
SparseCooTensor
>
(
auto
out
=
std
::
dynamic_pointer_cast
<
phi
::
SparseCooTensor
>
(
std
::
get
<
0
>
(
outs
).
impl
());
std
::
get
<
0
>
(
outs
).
impl
());
...
...
paddle/phi/tests/kernels/test_sparse_conv3d_dev_api.cc
浏览文件 @
0d78e491
...
@@ -64,7 +64,8 @@ void TestConv3dBase(const std::vector<int>& indices,
...
@@ -64,7 +64,8 @@ void TestConv3dBase(const std::vector<int>& indices,
const
float
diff
=
1e-3
,
const
float
diff
=
1e-3
,
const
bool
backward
=
false
,
const
bool
backward
=
false
,
const
std
::
vector
<
T
>
features_grad
=
{},
const
std
::
vector
<
T
>
features_grad
=
{},
const
std
::
vector
<
T
>
kernel_grad
=
{})
{
const
std
::
vector
<
T
>
kernel_grad
=
{},
const
bool
subm
=
false
)
{
phi
::
CPUContext
dev_ctx_cpu
;
phi
::
CPUContext
dev_ctx_cpu
;
dev_ctx_cpu
.
SetAllocator
(
dev_ctx_cpu
.
SetAllocator
(
paddle
::
memory
::
allocation
::
AllocatorFacade
::
Instance
()
paddle
::
memory
::
allocation
::
AllocatorFacade
::
Instance
()
...
@@ -114,6 +115,7 @@ void TestConv3dBase(const std::vector<int>& indices,
...
@@ -114,6 +115,7 @@ void TestConv3dBase(const std::vector<int>& indices,
dilations
,
dilations
,
strides
,
strides
,
1
,
1
,
subm
,
&
rulebook
);
&
rulebook
);
ASSERT_EQ
(
correct_out_dims
.
size
(),
out
.
dims
().
size
());
ASSERT_EQ
(
correct_out_dims
.
size
(),
out
.
dims
().
size
());
...
@@ -138,7 +140,8 @@ void TestConv3dBase(const std::vector<int>& indices,
...
@@ -138,7 +140,8 @@ void TestConv3dBase(const std::vector<int>& indices,
paddings
,
paddings
,
dilations
,
dilations
,
strides
,
strides
,
1
);
1
,
subm
);
f_verify
(
grads
[
0
].
data
<
T
>
(),
features_grad
);
f_verify
(
grads
[
0
].
data
<
T
>
(),
features_grad
);
f_verify
(
grads
[
1
].
data
<
T
>
(),
kernel_grad
);
f_verify
(
grads
[
1
].
data
<
T
>
(),
kernel_grad
);
}
}
...
@@ -191,6 +194,7 @@ void TestConv3dBase(const std::vector<int>& indices,
...
@@ -191,6 +194,7 @@ void TestConv3dBase(const std::vector<int>& indices,
dilations
,
dilations
,
strides
,
strides
,
1
,
1
,
subm
,
&
d_rulebook
);
&
d_rulebook
);
ASSERT_EQ
(
correct_out_dims
.
size
(),
d_out
.
dims
().
size
());
ASSERT_EQ
(
correct_out_dims
.
size
(),
d_out
.
dims
().
size
());
...
@@ -235,7 +239,8 @@ void TestConv3dBase(const std::vector<int>& indices,
...
@@ -235,7 +239,8 @@ void TestConv3dBase(const std::vector<int>& indices,
paddings
,
paddings
,
dilations
,
dilations
,
strides
,
strides
,
1
);
1
,
subm
);
DenseTensor
h_features_grad
=
phi
::
Empty
(
DenseTensor
h_features_grad
=
phi
::
Empty
(
dev_ctx_cpu
,
dev_ctx_cpu
,
DenseTensorMeta
(
grads
[
0
].
dtype
(),
grads
[
0
].
dims
(),
grads
[
0
].
layout
()));
DenseTensorMeta
(
grads
[
0
].
dtype
(),
grads
[
0
].
dims
(),
grads
[
0
].
layout
()));
...
@@ -266,7 +271,8 @@ void TestConv3d(const std::vector<int>& indices,
...
@@ -266,7 +271,8 @@ void TestConv3d(const std::vector<int>& indices,
const
float
diff
=
1e-3
,
const
float
diff
=
1e-3
,
const
bool
backward
=
false
,
const
bool
backward
=
false
,
const
std
::
vector
<
float
>
features_grad
=
{},
const
std
::
vector
<
float
>
features_grad
=
{},
const
std
::
vector
<
float
>
kernel_grad
=
{})
{
const
std
::
vector
<
float
>
kernel_grad
=
{},
const
bool
subm
=
false
)
{
// test float
// test float
TestConv3dBase
<
float
>
(
indices
,
TestConv3dBase
<
float
>
(
indices
,
features
,
features
,
...
@@ -283,7 +289,8 @@ void TestConv3d(const std::vector<int>& indices,
...
@@ -283,7 +289,8 @@ void TestConv3d(const std::vector<int>& indices,
diff
,
diff
,
backward
,
backward
,
features_grad
,
features_grad
,
kernel_grad
);
kernel_grad
,
subm
);
// test double
// test double
TestConv3dBase
<
double
>
(
indices
,
TestConv3dBase
<
double
>
(
indices
,
cast
<
float
,
double
>
(
features
),
cast
<
float
,
double
>
(
features
),
...
@@ -300,7 +307,8 @@ void TestConv3d(const std::vector<int>& indices,
...
@@ -300,7 +307,8 @@ void TestConv3d(const std::vector<int>& indices,
diff
,
diff
,
backward
,
backward
,
cast
<
float
,
double
>
(
features_grad
),
cast
<
float
,
double
>
(
features_grad
),
cast
<
float
,
double
>
(
kernel_grad
));
cast
<
float
,
double
>
(
kernel_grad
),
subm
);
}
}
TEST
(
DEV_API
,
sparse_conv3d
)
{
TEST
(
DEV_API
,
sparse_conv3d
)
{
...
@@ -661,5 +669,101 @@ TEST(DEV_API, sparse_conv3d_backward) {
...
@@ -661,5 +669,101 @@ TEST(DEV_API, sparse_conv3d_backward) {
kernel_grad
);
kernel_grad
);
}
}
TEST
(
DEV_API
,
sparse_conv2d_subm
)
{
const
int
in_channels
=
1
;
const
int
out_channels
=
1
;
DDim
x_dims
=
{
1
,
1
,
4
,
5
,
in_channels
};
DDim
kernel_dims
=
{
1
,
3
,
3
,
in_channels
,
out_channels
};
DDim
out_dims
=
{
1
,
1
,
4
,
5
,
out_channels
};
std
::
vector
<
int
>
paddings
=
{
0
,
1
,
1
};
std
::
vector
<
int
>
strides
=
{
1
,
1
,
1
};
std
::
vector
<
int
>
dilations
=
{
1
,
1
,
1
};
const
int
non_zero_num
=
4
;
std
::
vector
<
int
>
indices_flatten
=
{
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
3
,
3
,
3
,
2
,
2
,
3
};
std
::
vector
<
float
>
features
=
{
0.8854
,
0.6505
,
-
0.1999
,
0.3583
};
// 3*3*3=27
std
::
vector
<
float
>
kernel
=
{
0.9364
,
0.9460
,
0.6564
,
0.7999
,
0.2013
,
0.3812
,
0.5474
,
0.1016
,
0.3368
};
std
::
vector
<
int
>
out_indices_flatten
=
{
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
3
,
3
,
3
,
2
,
2
,
3
};
std
::
vector
<
float
>
out_features
=
{
0.1782
,
0.2313
,
0.7117
,
0.5214
};
std
::
vector
<
float
>
features_grad
=
{
0.0359
,
1.2080
,
0.5838
,
0.4541
};
std
::
vector
<
float
>
kernel_grad
=
{
0.3391
,
0.4630
,
0.0000
,
-
0.1042
,
0.3528
,
0.2550
,
0.0000
,
-
0.0462
,
0.0829
};
TestConv3d
(
indices_flatten
,
features
,
x_dims
,
kernel
,
kernel_dims
,
out_indices_flatten
,
out_features
,
out_dims
,
non_zero_num
,
paddings
,
strides
,
dilations
,
1e-3
,
true
,
features_grad
,
kernel_grad
,
true
);
}
TEST
(
DEV_API
,
sparse_conv3d_subm
)
{
const
int
in_channels
=
1
;
const
int
out_channels
=
1
;
DDim
x_dims
=
{
1
,
4
,
4
,
5
,
in_channels
};
DDim
kernel_dims
=
{
3
,
3
,
3
,
in_channels
,
out_channels
};
DDim
out_dims
=
{
1
,
4
,
4
,
5
,
out_channels
};
std
::
vector
<
int
>
paddings
=
{
1
,
1
,
1
};
std
::
vector
<
int
>
strides
=
{
1
,
1
,
1
};
std
::
vector
<
int
>
dilations
=
{
1
,
1
,
1
};
const
int
non_zero_num
=
3
;
std
::
vector
<
int
>
indices_flatten
=
{
0
,
0
,
0
,
1
,
3
,
3
,
2
,
0
,
2
,
0
,
3
,
1
};
std
::
vector
<
float
>
features
=
{
-
0.9578
,
0.1572
,
0.1036
};
// 3*3*3=27
std
::
vector
<
float
>
kernel
=
{
0.1367
,
0.4534
,
0.2138
,
0.8264
,
0.7534
,
0.3270
,
0.2880
,
0.1562
,
0.7770
,
0.6902
,
0.1981
,
0.1369
,
0.6582
,
0.7582
,
0.5640
,
0.8894
,
0.7350
,
0.1845
,
0.6892
,
0.3654
,
0.6076
,
0.0326
,
0.8412
,
0.5289
,
0.9824
,
0.8235
,
0.9802
};
std
::
vector
<
int
>
out_indices_flatten
=
{
0
,
0
,
0
,
1
,
3
,
3
,
2
,
0
,
2
,
0
,
3
,
1
};
std
::
vector
<
float
>
out_features
=
{
-
0.7262
,
0.1192
,
0.0785
};
std
::
vector
<
float
>
features_grad
=
{
-
0.5506
,
0.0904
,
0.0595
};
std
::
vector
<
float
>
kernel_grad
=
{
0.0000
,
0.0000
,
0.0000
,
0.0000
,
0.0000
,
0.0000
,
0.0000
,
0.0000
,
0.0000
,
0.0000
,
0.0000
,
0.0000
,
0.0000
,
0.7224
,
0.0000
,
0.0000
,
0.0000
,
0.0000
,
0.0000
,
0.0000
,
0.0000
,
0.0000
,
0.0000
,
0.0000
,
0.0000
,
0.0000
,
0.0000
};
TestConv3d
(
indices_flatten
,
features
,
x_dims
,
kernel
,
kernel_dims
,
out_indices_flatten
,
out_features
,
out_dims
,
non_zero_num
,
paddings
,
strides
,
dilations
,
1e-3
,
true
,
features_grad
,
kernel_grad
,
true
);
}
}
// namespace tests
}
// namespace tests
}
// namespace phi
}
// namespace phi
python/paddle/utils/code_gen/sparse_api.yaml
浏览文件 @
0d78e491
-
api
:
conv3d
-
api
:
conv3d
args
:
(Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups)
args
:
(Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups
, bool subm
)
output
:
Tensor(out@SparseCooTensor), Tensor(rulebook@DenseTensor)
output
:
Tensor(out@SparseCooTensor), Tensor(rulebook@DenseTensor)
kernel
:
kernel
:
func
:
sparse_conv3d
func
:
sparse_conv3d
...
...
python/paddle/utils/code_gen/sparse_bw_api.yaml
浏览文件 @
0d78e491
-
backward_api
:
conv3d_grad
-
backward_api
:
conv3d_grad
forward
:
conv3d (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups) -> Tensor(out@SparseCooTensor), Tensor(rulebook@DenseTensor)
forward
:
conv3d (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups
, bool subm
) -> Tensor(out@SparseCooTensor), Tensor(rulebook@DenseTensor)
args
:
(Tensor x, Tensor kernel, Tensor rulebook, Tensor out_grad, int[] paddings, int[] dilations, int[] strides, int groups)
args
:
(Tensor x, Tensor kernel, Tensor rulebook, Tensor out_grad, int[] paddings, int[] dilations, int[] strides, int groups
, bool subm
)
output
:
Tensor(x_grad@DenseTensor), Tensor(kernel_grad@DenseTensor)
output
:
Tensor(x_grad@DenseTensor), Tensor(kernel_grad@DenseTensor)
kernel
:
kernel
:
func
:
sparse_conv_grad
func
:
sparse_conv
3d
_grad
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录