机器未来 / Paddle (forked from PaddlePaddle / Paddle)
Commit 55d31980 (unverified)
Authored on Sep 21, 2022 by zhangkaihuo; committed via GitHub on Sep 21, 2022
[Sparse] add_coo_dense (#46322)
* for add_bias
Parent 0f9dde43
Showing 12 changed files with 155 additions and 16 deletions (+155 −16)
paddle/phi/api/lib/api_gen_utils.cc                                  +3   -0
paddle/phi/api/yaml/sparse_backward.yaml                             +8   -7
paddle/phi/api/yaml/sparse_ops.yaml                                  +3   -2
paddle/phi/kernels/sparse/cpu/elementwise_grad_kernel.cc            +11   -0
paddle/phi/kernels/sparse/cpu/elementwise_kernel.cc                 +26   -0
paddle/phi/kernels/sparse/elementwise_grad_kernel.h                 +25   -0
paddle/phi/kernels/sparse/elementwise_kernel.h                      +20   -0
paddle/phi/kernels/sparse/gpu/elementwise_grad_kernel.cu            +15   -0
paddle/phi/kernels/sparse/gpu/elementwise_kernel.cu                 +14   -0
python/paddle/fluid/tests/unittests/test_sparse_elementwise_op.py   +26   -0
python/paddle/incubate/sparse/binary.py                              +1   -1
python/paddle/incubate/sparse/nn/functional/conv.py                  +3   -6
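The user-visible effect of this commit is that paddle.incubate.sparse.add now accepts a dense Tensor as the second operand of a sparse COO tensor, which is how the new bias-add path works. A minimal usage sketch, adapted from the test_add_bias case added below (shapes and values mirror that test; anything beyond it is illustrative only):

import paddle
from paddle.incubate import sparse

# A COO tensor of shape [2, 4, 2] with two stored entries; values have shape [nnz, 2].
indices = [[0, 1], [0, 3]]
values = [[1.0, 1.0], [2.0, 2.0]]
sp_a = sparse.sparse_coo_tensor(indices, values, [2, 4, 2])

# Dense 1-D bias whose length matches the last dimension of sp_a.
bias = paddle.to_tensor([1.0, 2.0])

# out keeps sp_a's sparsity pattern; out.values() == sp_a.values() + bias.
out = sparse.add(sp_a, bias)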
paddle/phi/api/lib/api_gen_utils.cc
@@ -250,6 +250,9 @@ phi::SelectedRows* SetSelectedRowsKernelOutput(Tensor* out) {
 }
 
 phi::TensorBase* SetSparseKernelOutput(Tensor* out, TensorType type) {
+  if (!out) {
+    return nullptr;
+  }
   if (!out->initialized()) {
     if (type == TensorType::SPARSE_COO) {
       auto sparse_tensor = std::make_shared<phi::SparseCooTensor>(
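The new early return covers the case where a kernel output slot was never materialized, e.g. an optional output that the caller did not request; presumably this is what the coo + dense backward hits when the dense operand needs no gradient. A hedged Python-level sketch of such a call (illustrative only; whether this exact scenario exercises the guard is an assumption):

import paddle
from paddle.incubate import sparse

x = sparse.sparse_coo_tensor([[0, 1], [0, 3]], [[1.0, 1.0], [2.0, 2.0]], [2, 4, 2],
                             stop_gradient=False)
bias = paddle.to_tensor([1.0, 2.0])   # stop_gradient=True by default,
                                      # so its grad output can stay unset
out = sparse.add(x, bias)
out.backward()                        # only x receives a gradient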
paddle/phi/api/yaml/sparse_backward.yaml
@@ -36,11 +36,12 @@
   args : (Tensor x, Tensor y, Tensor out_grad)
   output : Tensor(x_grad), Tensor(y_grad)
   infer_meta :
     func : GeneralBinaryGradInferMeta
     param : [x, y]
   kernel :
     func : add_coo_coo_grad{sparse_coo, sparse_coo, sparse_coo -> sparse_coo, sparse_coo},
-           add_csr_csr_grad{sparse_csr, sparse_csr, sparse_csr -> sparse_csr, sparse_csr}
+           add_csr_csr_grad{sparse_csr, sparse_csr, sparse_csr -> sparse_csr, sparse_csr},
+           add_coo_dense_grad{sparse_coo, dense, sparse_coo -> sparse_coo, dense}
 
 - backward_op : addmm_grad
   forward : addmm(Tensor input, Tensor x, Tensor y, float alpha=1.0, float beta=1.0) -> Tensor(out)

@@ -104,7 +105,7 @@
   args : (Tensor x, Tensor out_grad, DataType value_dtype)
   output : Tensor(x_grad)
   infer_meta :
     func : UnchangedInferMeta
     param : [x]
   kernel :
     func : cast_coo_grad {sparse_coo, sparse_coo -> sparse_coo},

@@ -126,7 +127,7 @@
   args : (Tensor x, Tensor y, Tensor out, Tensor out_grad)
   output : Tensor(x_grad), Tensor(y_grad)
   infer_meta :
     func : GeneralBinaryGradInferMeta
     param : [x, y]
   kernel :
     func : divide_coo_coo_grad{sparse_coo, sparse_coo, sparse_coo, sparse_coo -> sparse_coo, sparse_coo},

@@ -209,7 +210,7 @@
   args : (Tensor x, Tensor y, Tensor out_grad)
   output : Tensor(x_grad), Tensor(y_grad)
   infer_meta :
     func : GeneralBinaryGradInferMeta
     param : [x, y]
   kernel :
     func : multiply_coo_coo_grad{sparse_coo, sparse_coo, sparse_coo -> sparse_coo, sparse_coo},

@@ -337,7 +338,7 @@
   args : (Tensor x, Tensor y, Tensor out_grad)
   output : Tensor(x_grad), Tensor(y_grad)
   infer_meta :
     func : GeneralBinaryGradInferMeta
     param : [x, y]
   kernel :
     func : subtract_coo_coo_grad{sparse_coo, sparse_coo, sparse_coo -> sparse_coo, sparse_coo},

@@ -399,7 +400,7 @@
   args : (Tensor query, Tensor key, Tensor value, Tensor softmax, Tensor out_grad)
   output : Tensor(query_grad), Tensor(key_grad), Tensor(value_grad)
   infer_meta :
     func : sparse::FusedAttentionGradInferMeta
   kernel :
     func : fused_attention_csr_grad{dense, dense, dense, sparse_csr, dense -> dense, dense, dense}
   layout : softmax
paddle/phi/api/yaml/sparse_ops.yaml
@@ -35,10 +35,11 @@
   args : (Tensor x, Tensor y)
   output : Tensor(out)
   infer_meta :
     func : ElementwiseInferMeta
   kernel :
     func : add_coo_coo{sparse_coo, sparse_coo -> sparse_coo},
            add_csr_csr{sparse_csr, sparse_csr -> sparse_csr}
+           add_coo_dense{sparse_coo, dense -> sparse_coo},
   layout : x
   backward : add_grad

@@ -114,7 +115,7 @@
   args : (Tensor x, Tensor y)
   output : Tensor(out)
   infer_meta :
     func : ElementwiseInferMeta
   kernel :
     func : divide_coo_coo{sparse_coo, sparse_coo -> sparse_coo},
            divide_csr_csr{sparse_csr, sparse_csr -> sparse_csr}
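Each entry in the kernel func list above names one concrete kernel together with the input layouts it dispatches on, so the single op named add now covers coo+coo, csr+csr and coo+dense operands behind the same Python call. A hedged sketch of the three dispatch paths (tensor contents are illustrative; the use of Tensor.to_sparse_csr for the CSR case is an assumption, not part of this diff):

import paddle
from paddle.incubate import sparse

# COO tensors of shape [2, 4, 2]: two sparse dims, one dense channel dim of size 2.
x = sparse.sparse_coo_tensor([[0, 1], [0, 3]], [[1.0, 1.0], [2.0, 2.0]], [2, 4, 2])
y = sparse.sparse_coo_tensor([[0, 1], [0, 3]], [[3.0, 3.0], [4.0, 4.0]], [2, 4, 2])

sparse.add(x, y)                              # dispatches to add_coo_coo
sparse.add(x, paddle.to_tensor([1.0, 2.0]))   # dispatches to add_coo_dense (new here)

d = paddle.rand([4, 4])
sparse.add(d.to_sparse_csr(), d.to_sparse_csr())   # dispatches to add_csr_csr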
paddle/phi/kernels/sparse/cpu/elementwise_grad_kernel.cc
@@ -415,3 +415,14 @@ PD_REGISTER_KERNEL(divide_coo_coo_grad,
   kernel->InputAt(2).SetDataLayout(phi::DataLayout::SPARSE_COO);
   kernel->InputAt(3).SetDataLayout(phi::DataLayout::SPARSE_COO);
 }
+
+PD_REGISTER_KERNEL(add_coo_dense_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::sparse::ElementWiseAddDenseGradKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t) {
+  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
+}
paddle/phi/kernels/sparse/cpu/elementwise_kernel.cc
@@ -156,6 +156,21 @@ void ElementWiseCooKernelImpl(const Context& dev_ctx,
                           "shape = [%s], Y's shape = [%s].",
                           x.dims(),
                           y.dims()));
+  // temporary policy: for broadcast add
+  // TODO(zhangkaihuo): implement a correct function
+  const bool is_add = std::is_same<Functor, funcs::AddFunctor<T>>::value;
+  if (is_add && x.indices().numel() == y.indices().numel()) {
+    int compare_indices = memcmp(x.indices().data<IntT>(),
+                                 y.indices().data<IntT>(),
+                                 sizeof(IntT) * x.indices().numel());
+    if (compare_indices == 0) {
+      EmptyLikeCooKernel<T, Context>(dev_ctx, x, out);
+      phi::AddKernel<T, Context>(
+          dev_ctx, x.values(), y.values(), out->mutable_values());
+      return;
+    }
+  }
   int64_t element_size = 1;
   for (auto j = 1; j < x.values().dims().size(); ++j) {
     element_size *= x.values().dims()[j];

@@ -435,3 +450,14 @@ PD_REGISTER_KERNEL(divide_coo_coo,
   kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
   kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_COO);
 }
+
+PD_REGISTER_KERNEL(add_coo_dense,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::sparse::ElementWiseAddDenseKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t) {
+  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
+}
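The block added to ElementWiseCooKernelImpl above is a fast path: when the functor is addition and the two COO operands have byte-identical indices, the kernel skips the general index merge and adds the value tensors directly. A small sketch of the situation that hits this path (values are illustrative):

import paddle
from paddle.incubate import sparse

indices = [[0, 1], [0, 3]]
x = sparse.sparse_coo_tensor(indices, [[1.0, 1.0], [2.0, 2.0]], [2, 4, 2])
y = sparse.sparse_coo_tensor(indices, [[0.5, 0.5], [0.5, 0.5]], [2, 4, 2])

# Identical sparsity pattern on both sides: out.indices() matches x.indices()
# and out.values() == x.values() + y.values(), computed without a merge step.
out = sparse.add(x, y)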
paddle/phi/kernels/sparse/elementwise_grad_kernel.h
@@ -14,6 +14,9 @@ limitations under the License. */
 #pragma once
 
 #include "paddle/phi/kernels/elementwise_add_grad_kernel.h"
 #include "paddle/phi/kernels/sparse/empty_kernel.h"
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/sparse_coo_tensor.h"
 #include "paddle/phi/core/sparse_csr_tensor.h"

@@ -119,5 +122,27 @@ std::vector<SparseCooTensor> ElementWiseDivideCooGrad(
   return std::vector<SparseCooTensor>{dx, dy};
 }
 
+template <typename T, typename Context>
+void ElementWiseAddDenseGradKernel(const Context& dev_ctx,
+                                   const SparseCooTensor& x,
+                                   const DenseTensor& y,
+                                   const SparseCooTensor& dout,
+                                   SparseCooTensor* dx,
+                                   DenseTensor* dy) {
+  DenseTensor* x_values_grad = nullptr;
+  DenseTensor* y_grad = nullptr;
+  if (dx) {
+    EmptyLikeCooKernel<T, Context>(dev_ctx, x, dx);
+    x_values_grad = dx->mutable_values();
+  }
+  if (dy) {
+    *dy = phi::EmptyLike<T>(dev_ctx, y);
+    y_grad = dy;
+  }
+  phi::AddGradKernel<T, Context>(
+      dev_ctx, x.values(), y, dout.values(), -1, x_values_grad, y_grad);
+}
+
 }  // namespace sparse
 }  // namespace phi
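ElementWiseAddDenseGradKernel above routes the backward pass through the dense AddGradKernel with axis = -1: dx (when requested) reuses x's sparsity pattern and receives the gradient of the values, while dy is a dense tensor whose gradient is reduced over the broadcast dimension. A hedged sketch of what that means at the Python level, mirroring the reference computation used by the new test shown further below:

import paddle
from paddle.incubate import sparse

values = [[1.0, 1.0], [2.0, 2.0]]
x = sparse.sparse_coo_tensor([[0, 1], [0, 3]], values, [2, 4, 2], stop_gradient=False)
bias = paddle.to_tensor([1.0, 2.0], stop_gradient=False)

out = sparse.add(x, bias)
out.backward()

# x.grad keeps x's indices and its values gradient matches the dense reference
# values + bias (all ones here); bias.grad is the reduction over the stored
# rows, i.e. [2., 2.] for the two entries above.
print(x.grad.values().numpy(), bias.grad.numpy())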
paddle/phi/kernels/sparse/elementwise_kernel.h
@@ -14,6 +14,10 @@ limitations under the License. */
 #pragma once
 
 #include "paddle/phi/kernels/elementwise_add_kernel.h"
 #include "paddle/phi/kernels/sparse/elementwise_kernel.h"
 #include "paddle/phi/kernels/sparse/empty_kernel.h"
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/sparse_coo_tensor.h"
 #include "paddle/phi/core/sparse_csr_tensor.h"

@@ -78,5 +82,21 @@ DEFINE_ELEMENTWISE_KERNEL_FUNC(Subtract)
 DEFINE_ELEMENTWISE_KERNEL_FUNC(Multiply)
 DEFINE_ELEMENTWISE_KERNEL_FUNC(Divide)
 
+template <typename T, typename Context>
+void ElementWiseAddDenseKernel(const Context& dev_ctx,
+                               const SparseCooTensor& x,
+                               const DenseTensor& y,
+                               SparseCooTensor* out) {
+  // TODO(zhangkaiuo): to support universal sparse + dense
+  if (y.dims().size() == 1 && y.dims()[0] == x.dims()[x.dims().size() - 1]) {
+    EmptyLikeCooKernel<T, Context>(dev_ctx, x, out);
+    phi::AddKernel<T, Context>(dev_ctx, x.values(), y, out->mutable_values());
+    out->SetIndicesDict(x.GetIndicesDict());
+  } else {
+    PADDLE_THROW(
+        errors::Unimplemented("Not support Sparse + Dense in GPU mode"));
+  }
+}
+
 }  // namespace sparse
 }  // namespace phi
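ElementWiseAddDenseKernel above only handles the bias-like case: the dense operand must be 1-D and its length must equal the last dimension of the COO tensor, otherwise the kernel raises an Unimplemented error. A hedged sketch of the accepted shape (tensor values are illustrative):

import paddle
from paddle.incubate import sparse

# x has shape [2, 4, 2]; its last dimension is 2.
x = sparse.sparse_coo_tensor([[0, 1], [0, 3]], [[1.0, 1.0], [2.0, 2.0]], [2, 4, 2])

ok_bias = paddle.to_tensor([1.0, 2.0])      # shape [2] == x.shape[-1] -> supported
out = sparse.add(x, ok_bias)

bad_bias = paddle.to_tensor([[1.0, 2.0]])   # 2-D dense operand -> Unimplemented
# sparse.add(x, bad_bias)  # would raise, per the check in ElementWiseAddDenseKernel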
paddle/phi/kernels/sparse/gpu/elementwise_grad_kernel.cu
@@ -15,6 +15,9 @@ limitations under the License. */
 #include "paddle/phi/kernels/sparse/elementwise_grad_kernel.h"
 
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/kernels/empty_kernel.h"
 #include "paddle/phi/kernels/funcs/elementwise_grad_base.h"
 #include "paddle/phi/kernels/funcs/reduce_function.h"
 #include "paddle/phi/kernels/sparse/empty_kernel.h"
 
 namespace phi {

@@ -54,3 +57,15 @@ PD_REGISTER_KERNEL(add_coo_coo_grad,
   kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
   kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_COO);
 }
+
+PD_REGISTER_KERNEL(add_coo_dense_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::sparse::ElementWiseAddDenseGradKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   phi::dtype::float16) {
+  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
+}
paddle/phi/kernels/sparse/gpu/elementwise_kernel.cu
@@ -31,6 +31,7 @@ void ElementWiseAddCooGPUKernel(const GPUContext& dev_ctx,
                                 const SparseCooTensor& x,
                                 const SparseCooTensor& y,
                                 SparseCooTensor* out) {
+  // TODO(zhangkaiuo): to support universal sparse + sparse
   const auto& x_indices = x.indices();
   const auto& y_indices = y.indices();
   PADDLE_ENFORCE_EQ(

@@ -57,6 +58,7 @@ void ElementWiseAddCooGPUKernel(const GPUContext& dev_ctx,
   EmptyLikeCooKernel<T, GPUContext>(dev_ctx, x, out);
   phi::AddKernel<T, GPUContext>(
       dev_ctx, x.values(), y.values(), out->mutable_values());
+  out->SetIndicesDict(x.GetIndicesDict());
 }
 
 template <typename T, typename Context>

@@ -86,3 +88,15 @@ PD_REGISTER_KERNEL(add_coo_coo,
   kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
   kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_COO);
 }
+
+PD_REGISTER_KERNEL(add_coo_dense,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::sparse::ElementWiseAddDenseKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   phi::dtype::float16) {
+  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
+}
python/paddle/fluid/tests/unittests/test_sparse_elementwise_op.py
@@ -163,6 +163,32 @@ class TestSparseElementWiseAPI(unittest.TestCase):
         np.testing.assert_allclose(sp_b.grad.values().numpy(),
                                    values2.grad.numpy())
 
+    def test_add_bias(self):
+        indices_data = [[0, 1], [0, 3]]
+        values_data = [[1.0, 1.0], [2.0, 2.0]]
+        shape = [2, 4, 2]
+        sp_a = sparse.sparse_coo_tensor(indices_data,
+                                        values_data,
+                                        shape,
+                                        stop_gradient=False)
+        bias_values = [1.0, 2.0]
+
+        values1 = paddle.to_tensor(values_data, stop_gradient=False)
+        values2 = paddle.to_tensor(bias_values, stop_gradient=False)
+        values3 = paddle.to_tensor(bias_values, stop_gradient=False)
+
+        #c.values() = a.values() + b
+        sp_c = sparse.add(sp_a, values2)
+        sp_c.backward()
+        ref_c = values1 + values3
+        ref_c.backward()
+
+        np.testing.assert_allclose(sp_c.values().numpy(), ref_c.numpy())
+        np.testing.assert_allclose(sp_a.grad.values().numpy(),
+                                   values1.grad.numpy())
+        np.testing.assert_allclose(values2.grad.numpy(), values3.grad.numpy())
+
 
 if __name__ == "__main__":
     paddle.device.set_device('cpu')
python/paddle/incubate/sparse/binary.py
@@ -253,7 +253,7 @@ def add(x, y, name=None):
     """
     if y.dtype != x.dtype:
-        y = _C_ops.sparse_cast(y, None, x.dtype)
+        y = cast(y, None, x.dtype)
     return _C_ops.sparse_add(x, y)
python/paddle/incubate/sparse/nn/functional/conv.py
@@ -18,6 +18,8 @@ from paddle import _C_ops, _legacy_C_ops, in_dynamic_mode
 from paddle.fluid.layers.utils import convert_to_list
 from paddle.fluid.layers.nn import elementwise_add
 from ...creation import sparse_coo_tensor
+from ...binary import add
 from paddle.tensor import arange
 from paddle.nn.functional.conv import _update_padding_nd

@@ -67,12 +69,7 @@ def _conv3d(x,
                             groups, subm,
                             key if key is not None else "")
     if bias is not None:
-        values = pre_bias.values()
-        add_bias = elementwise_add(values, bias, axis=1)
-        return sparse_coo_tensor(pre_bias.indices(),
-                                 add_bias,
-                                 shape=pre_bias.shape,
-                                 stop_gradient=pre_bias.stop_gradient)
+        return add(pre_bias, bias)
     else:
         return pre_bias
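With this change _conv3d no longer rebuilds a COO tensor by hand to apply the bias; it simply calls the sparse add, which also gives the bias a proper gradient path through add_coo_dense_grad. A hedged usage sketch (the NDHWC input layout, the weight layout [kD, kH, kW, C_in, C_out], and the conv3d call itself are assumptions about the sparse conv API, not part of this diff; values are illustrative):

import paddle
from paddle.incubate import sparse

# Sparse NDHWC input of shape [1, 3, 3, 3, 1] with one active voxel.
indices = [[0], [1], [1], [1]]          # [N, D, H, W] coordinates, one entry
values = [[1.0]]                        # [nnz, C_in] with C_in = 1
x = sparse.sparse_coo_tensor(indices, values, [1, 3, 3, 3, 1])

weight = paddle.randn([3, 3, 3, 1, 2])  # assumed [kD, kH, kW, C_in, C_out]
bias = paddle.randn([2])                # one value per output channel

# The bias is now applied via sparse.add(pre_bias, bias) inside _conv3d.
y = sparse.nn.functional.conv3d(x, weight, bias, padding=1)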