Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
f1eda7d0
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
f1eda7d0
编写于
5月 12, 2022
作者:
T
tiancaishaonvjituizi
提交者:
GitHub
5月 12, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
【Hackathon No.60】refactor unary sparse ops and add sparse sqrt, tanh, sin (#41356)
上级
ddb3868e
变更
20
隐藏空白更改
内联
并排
Showing
20 changed file
with
867 addition
and
334 deletion
+867
-334
paddle/phi/kernels/activation_grad_kernel.h
paddle/phi/kernels/activation_grad_kernel.h
+1
-0
paddle/phi/kernels/sparse/activation_grad_kernel.cc
paddle/phi/kernels/sparse/activation_grad_kernel.cc
+0
-70
paddle/phi/kernels/sparse/activation_grad_kernel.h
paddle/phi/kernels/sparse/activation_grad_kernel.h
+0
-29
paddle/phi/kernels/sparse/activation_kernel.cc
paddle/phi/kernels/sparse/activation_kernel.cc
+0
-66
paddle/phi/kernels/sparse/activation_kernel.h
paddle/phi/kernels/sparse/activation_kernel.h
+0
-39
paddle/phi/kernels/sparse/unary_grad_kernel.cc
paddle/phi/kernels/sparse/unary_grad_kernel.cc
+183
-0
paddle/phi/kernels/sparse/unary_grad_kernel.h
paddle/phi/kernels/sparse/unary_grad_kernel.h
+41
-0
paddle/phi/kernels/sparse/unary_kernel.cc
paddle/phi/kernels/sparse/unary_kernel.cc
+177
-0
paddle/phi/kernels/sparse/unary_kernel.h
paddle/phi/kernels/sparse/unary_kernel.h
+48
-0
paddle/phi/tests/kernels/test_sparse_activation_dev_api.cc
paddle/phi/tests/kernels/test_sparse_activation_dev_api.cc
+3
-3
python/paddle/fluid/tests/unittests/test_sparse_activation_op.py
...paddle/fluid/tests/unittests/test_sparse_activation_op.py
+0
-50
python/paddle/fluid/tests/unittests/test_sparse_unary_op.py
python/paddle/fluid/tests/unittests/test_sparse_unary_op.py
+133
-0
python/paddle/sparse/__init__.py
python/paddle/sparse/__init__.py
+10
-6
python/paddle/sparse/functional/__init__.py
python/paddle/sparse/functional/__init__.py
+5
-2
python/paddle/sparse/functional/activation.py
python/paddle/sparse/functional/activation.py
+0
-53
python/paddle/sparse/functional/unary.py
python/paddle/sparse/functional/unary.py
+177
-0
python/paddle/sparse/layer/__init__.py
python/paddle/sparse/layer/__init__.py
+1
-1
python/paddle/sparse/layer/unary.py
python/paddle/sparse/layer/unary.py
+0
-0
python/paddle/utils/code_gen/sparse_api.yaml
python/paddle/utils/code_gen/sparse_api.yaml
+60
-8
python/paddle/utils/code_gen/sparse_bw_api.yaml
python/paddle/utils/code_gen/sparse_bw_api.yaml
+28
-7
未找到文件。
paddle/phi/kernels/activation_grad_kernel.h
浏览文件 @
f1eda7d0
...
...
@@ -187,6 +187,7 @@ DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Log1p);
DECLARE_ACTIVATION_GRAD_KERNEL_DEPOUT
(
Relu
);
DECLARE_ACTIVATION_GRAD_KERNEL_DEPOUT
(
Tanh
);
DECLARE_ACTIVATION_GRAD_KERNEL_DEPOUT
(
Sigmoid
);
DECLARE_ACTIVATION_GRAD_KERNEL_DEPOUT
(
Sqrt
);
DECLARE_ACTIVATION_GRAD_KERNEL_NODEP
(
Round
);
DECLARE_ACTIVATION_GRAD_KERNEL_NODEP
(
Floor
);
...
...
paddle/phi/kernels/sparse/activation_grad_kernel.cc
已删除
100644 → 0
浏览文件 @
ddb3868e
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/sparse/activation_grad_kernel.h"
#include "paddle/phi/kernels/activation_grad_kernel.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace
phi
{
namespace
sparse
{
template
<
typename
T
,
typename
Context
>
void
SparseReluGradKernel
(
const
Context
&
dev_ctx
,
const
SparseCooTensor
&
x
,
const
SparseCooTensor
&
out_grad
,
SparseCooTensor
*
x_grad
)
{
DenseTensor
non_zero_indices
=
phi
::
EmptyLike
<
T
,
Context
>
(
dev_ctx
,
x
.
non_zero_indices
());
DenseTensor
non_zero_elements
=
phi
::
EmptyLike
<
T
,
Context
>
(
dev_ctx
,
x
.
non_zero_elements
());
phi
::
Copy
(
dev_ctx
,
x
.
non_zero_indices
(),
dev_ctx
.
GetPlace
(),
false
,
&
non_zero_indices
);
phi
::
ReluGradKernel
<
T
,
Context
>
(
dev_ctx
,
x
.
non_zero_elements
(),
out_grad
.
non_zero_elements
(),
&
non_zero_elements
);
x_grad
->
SetMember
(
non_zero_indices
,
non_zero_elements
,
x
.
dims
(),
true
);
}
}
// namespace sparse
}
// namespace phi
PD_REGISTER_KERNEL
(
sparse_relu_grad
,
CPU
,
ALL_LAYOUT
,
phi
::
sparse
::
SparseReluGradKernel
,
float
,
double
)
{
kernel
->
InputAt
(
0
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_COO
);
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL
(
sparse_relu_grad
,
GPU
,
ALL_LAYOUT
,
phi
::
sparse
::
SparseReluGradKernel
,
float
,
double
,
phi
::
dtype
::
float16
)
{
kernel
->
InputAt
(
0
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_COO
);
}
#endif
paddle/phi/kernels/sparse/activation_grad_kernel.h
已删除
100644 → 0
浏览文件 @
ddb3868e
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/sparse_coo_tensor.h"
namespace
phi
{
namespace
sparse
{
template
<
typename
T
,
typename
Context
>
void
SparseReluGradKernel
(
const
Context
&
dev_ctx
,
const
SparseCooTensor
&
x
,
const
SparseCooTensor
&
out_grad
,
SparseCooTensor
*
x_grad
);
}
// namespace sparse
}
// namespace phi
paddle/phi/kernels/sparse/activation_kernel.cc
已删除
100644 → 0
浏览文件 @
ddb3868e
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/sparse/activation_kernel.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace
phi
{
namespace
sparse
{
template
<
typename
T
,
typename
Context
>
void
SparseReluKernel
(
const
Context
&
dev_ctx
,
const
SparseCooTensor
&
x
,
SparseCooTensor
*
out
)
{
DenseTensor
non_zero_indices
=
phi
::
EmptyLike
<
T
,
Context
>
(
dev_ctx
,
x
.
non_zero_indices
());
DenseTensor
non_zero_elements
=
phi
::
EmptyLike
<
T
,
Context
>
(
dev_ctx
,
x
.
non_zero_elements
());
phi
::
Copy
(
dev_ctx
,
x
.
non_zero_indices
(),
dev_ctx
.
GetPlace
(),
false
,
&
non_zero_indices
);
phi
::
ReluKernel
<
T
,
Context
>
(
dev_ctx
,
x
.
non_zero_elements
(),
&
non_zero_elements
);
out
->
SetMember
(
non_zero_indices
,
non_zero_elements
,
x
.
dims
(),
true
);
}
}
// namespace sparse
}
// namespace phi
PD_REGISTER_KERNEL
(
sparse_relu
,
CPU
,
ALL_LAYOUT
,
phi
::
sparse
::
SparseReluKernel
,
float
,
double
)
{
kernel
->
InputAt
(
0
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_COO
);
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL
(
sparse_relu
,
GPU
,
ALL_LAYOUT
,
phi
::
sparse
::
SparseReluKernel
,
float
,
double
,
phi
::
dtype
::
float16
)
{
kernel
->
InputAt
(
0
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_COO
);
}
#endif
paddle/phi/kernels/sparse/activation_kernel.h
已删除
100644 → 0
浏览文件 @
ddb3868e
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/kernels/activation_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
namespace
phi
{
namespace
sparse
{
template
<
typename
T
,
typename
Context
>
void
SparseReluKernel
(
const
Context
&
dev_ctx
,
const
SparseCooTensor
&
x
,
SparseCooTensor
*
out
);
template
<
typename
T
,
typename
Context
>
SparseCooTensor
SparseRelu
(
const
Context
&
dev_ctx
,
const
SparseCooTensor
&
x
)
{
DenseTensor
indices
,
values
;
SparseCooTensor
coo
(
indices
,
values
,
x
.
dims
());
SparseReluKernel
<
T
,
Context
>
(
dev_ctx
,
x
,
&
coo
);
return
coo
;
}
}
// namespace sparse
}
// namespace phi
paddle/phi/kernels/sparse/unary_grad_kernel.cc
0 → 100644
浏览文件 @
f1eda7d0
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/sparse/unary_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h"
#include "paddle/phi/kernels/activation_grad_kernel.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#define DEFINE_SPARSE_UNARY_GRAD_KERNEL(DenseKernelFunc) \
namespace phi { \
namespace sparse { \
\
template <typename T, typename Context> \
void SparseCoo##DenseKernelFunc(const Context& dev_ctx, \
const SparseCooTensor& x_or_out, \
const SparseCooTensor& out_grad, \
SparseCooTensor* x_grad) { \
DenseTensor non_zero_indices = \
phi::EmptyLike<T, Context>(dev_ctx, x_or_out.non_zero_indices()); \
DenseTensor non_zero_elements = \
phi::EmptyLike<T, Context>(dev_ctx, x_or_out.non_zero_elements()); \
phi::Copy(dev_ctx, \
x_or_out.non_zero_indices(), \
dev_ctx.GetPlace(), \
false, \
&non_zero_indices); \
phi::DenseKernelFunc<T, Context>(dev_ctx, \
x_or_out.non_zero_elements(), \
out_grad.non_zero_elements(), \
&non_zero_elements); \
x_grad->SetMember( \
non_zero_indices, non_zero_elements, x_or_out.dims(), true); \
} \
\
template <typename T, typename Context> \
void SparseCsr##DenseKernelFunc(const Context& dev_ctx, \
const SparseCsrTensor& x_or_out, \
const SparseCsrTensor& out_grad, \
SparseCsrTensor* out) { \
DenseTensor non_zero_crows = \
phi::EmptyLike<T, Context>(dev_ctx, x_or_out.non_zero_crows()); \
DenseTensor non_zero_cols = \
phi::EmptyLike<T, Context>(dev_ctx, x_or_out.non_zero_cols()); \
DenseTensor non_zero_elements = \
phi::EmptyLike<T, Context>(dev_ctx, x_or_out.non_zero_elements()); \
phi::Copy(dev_ctx, \
x_or_out.non_zero_crows(), \
dev_ctx.GetPlace(), \
false, \
&non_zero_crows); \
phi::Copy(dev_ctx, \
x_or_out.non_zero_cols(), \
dev_ctx.GetPlace(), \
false, \
&non_zero_cols); \
phi::DenseKernelFunc<T, Context>(dev_ctx, \
x_or_out.non_zero_elements(), \
out_grad.non_zero_elements(), \
&non_zero_elements); \
out->SetMember( \
non_zero_crows, non_zero_cols, non_zero_elements, x_or_out.dims()); \
} \
} \
}
#define REGISTER_CPU_SPARSE_UNARY_KERNEL(kernel_name, DenseKernelFunc) \
PD_REGISTER_KERNEL(sparse_coo_##kernel_name, \
CPU, \
ALL_LAYOUT, \
phi::sparse::SparseCoo##DenseKernelFunc, \
float, \
double) { \
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); \
} \
PD_REGISTER_KERNEL(sparse_csr_##kernel_name, \
CPU, \
ALL_LAYOUT, \
phi::sparse::SparseCsr##DenseKernelFunc, \
float, \
double) { \
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR); \
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#define REGISTER_GPU_SPARSE_UNARY_KERNEL(kernel_name, DenseKernelFunc) \
PD_REGISTER_KERNEL(sparse_coo_##kernel_name, \
GPU, \
ALL_LAYOUT, \
phi::sparse::SparseCoo##DenseKernelFunc, \
float, \
double, \
phi::dtype::float16) { \
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); \
} \
\
PD_REGISTER_KERNEL(sparse_csr_##kernel_name, \
GPU, \
ALL_LAYOUT, \
phi::sparse::SparseCsr##DenseKernelFunc, \
float, \
double, \
phi::dtype::float16) { \
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR); \
}
#else
// This macro definition is empty when GPU is disabled
#define REGISTER_GPU_SPARSE_UNARY_KERNEL(sparse_kernel_name, DenseKernelFunc)
#endif
#define REGISTER_SPARSE_UNARY_KERNEL(kernel_name, DenseKernelFunc) \
REGISTER_CPU_SPARSE_UNARY_KERNEL(kernel_name, DenseKernelFunc) \
REGISTER_GPU_SPARSE_UNARY_KERNEL(kernel_name, DenseKernelFunc)
#define DEFINE_AND_REGISTER_SPARSE_UNARY_GRAD_KERNEL(kernel_name, \
DenseKernelFunc) \
DEFINE_SPARSE_UNARY_GRAD_KERNEL(DenseKernelFunc) \
REGISTER_SPARSE_UNARY_KERNEL(kernel_name, DenseKernelFunc)
// NOTE: the following code is to bypass the restriction of Paddle
// kernel registration mechanism. Do NOT refactor them unless you
// know what you are doing.
// If you want to implement any new kernel, please follow `sin_grad`,
// `tanh_grad` etc, do NOT follow the following `relu_grad`.
DEFINE_SPARSE_UNARY_GRAD_KERNEL
(
ReluGradKernel
)
PD_REGISTER_KERNEL
(
sparse_coo_relu_grad
,
CPU
,
ALL_LAYOUT
,
phi
::
sparse
::
SparseCooReluGradKernel
,
float
,
double
)
{
kernel
->
InputAt
(
0
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_COO
);
}
PD_REGISTER_KERNEL
(
sparse_csr_relu_grad
,
CPU
,
ALL_LAYOUT
,
phi
::
sparse
::
SparseCsrReluGradKernel
,
float
,
double
)
{
kernel
->
InputAt
(
0
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_CSR
);
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL
(
sparse_coo_relu_grad
,
GPU
,
ALL_LAYOUT
,
phi
::
sparse
::
SparseCooReluGradKernel
,
float
,
double
,
phi
::
dtype
::
float16
)
{
kernel
->
InputAt
(
0
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_COO
);
}
PD_REGISTER_KERNEL
(
sparse_csr_relu_grad
,
GPU
,
ALL_LAYOUT
,
phi
::
sparse
::
SparseCsrReluGradKernel
,
float
,
double
,
phi
::
dtype
::
float16
)
{
kernel
->
InputAt
(
0
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_CSR
);
}
#endif
DEFINE_AND_REGISTER_SPARSE_UNARY_GRAD_KERNEL
(
sin_grad
,
SinGradKernel
)
DEFINE_AND_REGISTER_SPARSE_UNARY_GRAD_KERNEL
(
sqrt_grad
,
SqrtGradKernel
)
DEFINE_AND_REGISTER_SPARSE_UNARY_GRAD_KERNEL
(
tanh_grad
,
TanhGradKernel
)
paddle/phi/kernels/sparse/unary_grad_kernel.h
0 → 100644
浏览文件 @
f1eda7d0
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h"
#define DECLARE_SPARSE_UNARY_GRAD_KERNEL(name) \
template <typename T, typename Context> \
void SparseCoo##name##GradKernel(const Context& dev_ctx, \
const SparseCooTensor& x, \
const SparseCooTensor& out_grad, \
SparseCooTensor* x_grad); \
\
template <typename T, typename Context> \
void SparseCsr##name##GradKernel(const Context& dev_ctx, \
const SparseCsrTensor& x, \
const SparseCsrTensor& out_grad, \
SparseCsrTensor* x_grad);
namespace
phi
{
namespace
sparse
{
DECLARE_SPARSE_UNARY_GRAD_KERNEL
(
Relu
)
DECLARE_SPARSE_UNARY_GRAD_KERNEL
(
Sqrt
)
DECLARE_SPARSE_UNARY_GRAD_KERNEL
(
Sin
)
}
// namespace sparse
}
// namespace phi
paddle/phi/kernels/sparse/unary_kernel.cc
0 → 100644
浏览文件 @
f1eda7d0
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/sparse/unary_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h"
#include "paddle/phi/kernels/activation_kernel.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#define DEFINE_SPARSE_UNARY_KERNEL(DenseKernelFunc) \
namespace phi { \
namespace sparse { \
\
template <typename T, typename Context> \
void SparseCoo##DenseKernelFunc(const Context& dev_ctx, \
const SparseCooTensor& x, \
SparseCooTensor* out) { \
DenseTensor non_zero_indices = \
phi::EmptyLike<T, Context>(dev_ctx, x.non_zero_indices()); \
DenseTensor non_zero_elements = \
phi::EmptyLike<T, Context>(dev_ctx, x.non_zero_elements()); \
phi::Copy(dev_ctx, \
x.non_zero_indices(), \
dev_ctx.GetPlace(), \
false, \
&non_zero_indices); \
phi::DenseKernelFunc<T, Context>( \
dev_ctx, x.non_zero_elements(), &non_zero_elements); \
out->SetMember(non_zero_indices, non_zero_elements, x.dims(), true); \
} \
\
template <typename T, typename Context> \
void SparseCsr##DenseKernelFunc(const Context& dev_ctx, \
const SparseCsrTensor& x, \
SparseCsrTensor* out) { \
DenseTensor non_zero_crows = \
phi::EmptyLike<T, Context>(dev_ctx, x.non_zero_crows()); \
DenseTensor non_zero_cols = \
phi::EmptyLike<T, Context>(dev_ctx, x.non_zero_cols()); \
DenseTensor non_zero_elements = \
phi::EmptyLike<T, Context>(dev_ctx, x.non_zero_elements()); \
phi::Copy(dev_ctx, \
x.non_zero_crows(), \
dev_ctx.GetPlace(), \
false, \
&non_zero_crows); \
phi::Copy(dev_ctx, \
x.non_zero_cols(), \
dev_ctx.GetPlace(), \
false, \
&non_zero_cols); \
phi::DenseKernelFunc<T, Context>( \
dev_ctx, x.non_zero_elements(), &non_zero_elements); \
out->SetMember( \
non_zero_crows, non_zero_cols, non_zero_elements, x.dims()); \
} \
} \
}
#define REGISTER_CPU_SPARSE_UNARY_KERNEL(kernel_name, DenseKernelFunc) \
PD_REGISTER_KERNEL(sparse_coo_##kernel_name, \
CPU, \
ALL_LAYOUT, \
phi::sparse::SparseCoo##DenseKernelFunc, \
float, \
double) { \
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); \
} \
PD_REGISTER_KERNEL(sparse_csr_##kernel_name, \
CPU, \
ALL_LAYOUT, \
phi::sparse::SparseCsr##DenseKernelFunc, \
float, \
double) { \
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR); \
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#define REGISTER_GPU_SPARSE_UNARY_KERNEL(kernel_name, DenseKernelFunc) \
PD_REGISTER_KERNEL(sparse_coo_##kernel_name, \
GPU, \
ALL_LAYOUT, \
phi::sparse::SparseCoo##DenseKernelFunc, \
float, \
double, \
phi::dtype::float16) { \
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); \
} \
\
PD_REGISTER_KERNEL(sparse_csr_##kernel_name, \
GPU, \
ALL_LAYOUT, \
phi::sparse::SparseCsr##DenseKernelFunc, \
float, \
double, \
phi::dtype::float16) { \
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR); \
}
#else
// This macro definition is empty when GPU is disabled
#define REGISTER_GPU_SPARSE_UNARY_KERNEL(sparse_kernel_name, DenseKernelFunc)
#endif
#define REGISTER_SPARSE_UNARY_KERNEL(kernel_name, DenseKernelFunc) \
REGISTER_CPU_SPARSE_UNARY_KERNEL(kernel_name, DenseKernelFunc) \
REGISTER_GPU_SPARSE_UNARY_KERNEL(kernel_name, DenseKernelFunc)
#define DEFINE_AND_REGISTER_SPARSE_UNARY_KERNEL(kernel_name, DenseKernelFunc) \
DEFINE_SPARSE_UNARY_KERNEL(DenseKernelFunc) \
REGISTER_SPARSE_UNARY_KERNEL(kernel_name, DenseKernelFunc)
// NOTE: the following code is to bypass the restriction of Paddle
// kernel registration mechanism. Do NOT refactor them unless you
// know what you are doing.
// If you want to implement any new kernel, please follow `sin`,
// `tanh` etc, do NOT follow `sqrt`.
DEFINE_SPARSE_UNARY_KERNEL
(
SqrtKernel
)
PD_REGISTER_KERNEL
(
sparse_coo_sqrt
,
CPU
,
ALL_LAYOUT
,
phi
::
sparse
::
SparseCooSqrtKernel
,
float
,
double
)
{
kernel
->
InputAt
(
0
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_COO
);
}
PD_REGISTER_KERNEL
(
sparse_csr_sqrt
,
CPU
,
ALL_LAYOUT
,
phi
::
sparse
::
SparseCsrSqrtKernel
,
float
,
double
)
{
kernel
->
InputAt
(
0
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_CSR
);
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL
(
sparse_coo_sqrt
,
GPU
,
ALL_LAYOUT
,
phi
::
sparse
::
SparseCooSqrtKernel
,
float
,
double
,
phi
::
dtype
::
float16
)
{
kernel
->
InputAt
(
0
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_COO
);
}
PD_REGISTER_KERNEL
(
sparse_csr_sqrt
,
GPU
,
ALL_LAYOUT
,
phi
::
sparse
::
SparseCsrSqrtKernel
,
float
,
double
,
phi
::
dtype
::
float16
)
{
kernel
->
InputAt
(
0
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_CSR
);
}
#endif
DEFINE_AND_REGISTER_SPARSE_UNARY_KERNEL
(
sin
,
SinKernel
)
DEFINE_AND_REGISTER_SPARSE_UNARY_KERNEL
(
tanh
,
TanhKernel
)
DEFINE_AND_REGISTER_SPARSE_UNARY_KERNEL
(
relu
,
ReluKernel
)
paddle/phi/kernels/sparse/unary_kernel.h
0 → 100644
浏览文件 @
f1eda7d0
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h"
#include "paddle/phi/kernels/activation_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#define DECLARE_SPARSE_UNARY_KERNEL(name) \
template <typename T, typename Context> \
void SparseCoo##name##Kernel( \
const Context& dev_ctx, const SparseCooTensor& x, SparseCooTensor* out); \
\
template <typename T, typename Context> \
void SparseCsr##name##Kernel( \
const Context& dev_ctx, const SparseCsrTensor& x, SparseCsrTensor* out);
namespace
phi
{
namespace
sparse
{
DECLARE_SPARSE_UNARY_KERNEL
(
Relu
)
DECLARE_SPARSE_UNARY_KERNEL
(
Sqrt
)
DECLARE_SPARSE_UNARY_KERNEL
(
Sin
)
template
<
typename
T
,
typename
Context
>
SparseCooTensor
SparseRelu
(
const
Context
&
dev_ctx
,
const
SparseCooTensor
&
x
)
{
DenseTensor
indices
,
values
;
SparseCooTensor
coo
(
indices
,
values
,
x
.
dims
());
SparseCooReluKernel
<
T
,
Context
>
(
dev_ctx
,
x
,
&
coo
);
return
coo
;
}
}
// namespace sparse
}
// namespace phi
paddle/phi/tests/kernels/test_sparse_activation_dev_api.cc
浏览文件 @
f1eda7d0
...
...
@@ -24,9 +24,9 @@ limitations under the License. */
#include "paddle/phi/kernels/activation_grad_kernel.h"
#include "paddle/phi/kernels/activation_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/sparse/activation_grad_kernel.h"
#include "paddle/phi/kernels/sparse/activation_kernel.h"
#include "paddle/phi/kernels/sparse/sparse_utils_kernel.h"
#include "paddle/phi/kernels/sparse/unary_grad_kernel.h"
#include "paddle/phi/kernels/sparse/unary_kernel.h"
namespace
phi
{
namespace
tests
{
...
...
@@ -70,7 +70,7 @@ TEST(DEV_API, sparse_relu) {
SparseCooTensor
sparse_out_grad
(
sparse_coo
.
non_zero_indices
(),
dense_out
,
{
3
,
4
});
sparse
::
SparseReluGradKernel
<
float
>
(
sparse
::
Sparse
Coo
ReluGradKernel
<
float
>
(
dev_ctx_cpu
,
sparse_coo
,
sparse_out_grad
,
&
sparse_grad_x
);
cmp
=
memcmp
(
dense_grad_x
.
data
<
float
>
(),
...
...
python/paddle/fluid/tests/unittests/test_sparse_activation_op.py
已删除
100644 → 0
浏览文件 @
ddb3868e
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
numpy
as
np
import
paddle
from
paddle.fluid.framework
import
_test_eager_guard
class
TestSparseActivation
(
unittest
.
TestCase
):
def
test_sparse_relu
(
self
):
with
_test_eager_guard
():
x
=
[[
0
,
-
1
,
0
,
2
],
[
0
,
0
,
-
3
,
0
],
[
4
,
5
,
0
,
0
]]
def
dense_relu
(
x
):
dense_x
=
paddle
.
to_tensor
(
x
,
dtype
=
'float32'
,
stop_gradient
=
False
)
dense_relu
=
paddle
.
nn
.
ReLU
()
dense_out
=
dense_relu
(
dense_x
)
dense_out
.
backward
(
dense_out
)
return
dense_out
,
dense_x
.
grad
dense_x
=
paddle
.
to_tensor
(
x
,
dtype
=
'float32'
,
stop_gradient
=
False
)
sparse_dim
=
2
sparse_x
=
dense_x
.
to_sparse_coo
(
sparse_dim
)
sparse_relu
=
paddle
.
sparse
.
ReLU
()
sparse_out
=
sparse_relu
(
sparse_x
)
sparse_out
.
backward
(
sparse_out
)
dense_out
,
dense_x_grad
=
dense_relu
(
x
)
assert
np
.
array_equal
(
dense_out
.
numpy
(),
sparse_out
.
to_dense
().
numpy
())
assert
np
.
array_equal
(
dense_x_grad
.
numpy
(),
sparse_x
.
grad
.
to_dense
().
numpy
())
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_sparse_unary_op.py
0 → 100644
浏览文件 @
f1eda7d0
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
from
typing
import
Union
,
Callable
import
numpy
as
np
import
paddle
from
paddle.fluid.framework
import
_test_eager_guard
from
paddle
import
_C_ops
class
TestSparseUnary
(
unittest
.
TestCase
):
def
assert_raises_on_dense_tensor
(
self
,
sparse_func
):
with
_test_eager_guard
():
dense_x
=
paddle
.
ones
((
2
,
3
))
with
self
.
assertRaises
(
ValueError
):
sparse_func
(
dense_x
)
def
compare_with_dense
(
self
,
x
,
to_sparse
:
Callable
[[
paddle
.
Tensor
],
paddle
.
Tensor
],
dense_func
:
Callable
[[
paddle
.
Tensor
],
paddle
.
Tensor
],
sparse_func
:
Callable
[[
paddle
.
Tensor
],
paddle
.
Tensor
],
test_gradient
:
bool
,
):
def
tensor_allclose
(
dense_tensor
:
paddle
.
Tensor
,
sparse_tensor
:
paddle
.
Tensor
):
dense_numpy
=
dense_tensor
.
numpy
()
mask
=
~
np
.
isnan
(
dense_numpy
)
return
np
.
allclose
(
dense_numpy
[
mask
],
sparse_tensor
.
to_dense
().
numpy
()[
mask
])
with
_test_eager_guard
():
dense_x
=
paddle
.
to_tensor
(
x
,
dtype
=
"float32"
,
stop_gradient
=
not
test_gradient
)
sparse_x
=
to_sparse
(
dense_x
)
sparse_out
=
sparse_func
(
sparse_x
)
dense_x
=
paddle
.
to_tensor
(
x
,
dtype
=
"float32"
,
stop_gradient
=
not
test_gradient
)
dense_out
=
dense_func
(
dense_x
)
assert
tensor_allclose
(
dense_out
,
sparse_out
)
if
test_gradient
:
dense_out
.
backward
(
dense_out
)
sparse_out
.
backward
(
sparse_out
)
assert
tensor_allclose
(
dense_x
.
grad
,
sparse_x
.
grad
)
def
test_sparse_relu
(
self
):
x
=
[[
0
,
-
1
,
0
,
2
],
[
0
,
0
,
-
3
,
0
],
[
4
,
5
,
0
,
0
]]
sparse_dim
=
2
self
.
compare_with_dense
(
x
,
lambda
x
:
x
.
to_sparse_coo
(
sparse_dim
),
paddle
.
nn
.
ReLU
(),
paddle
.
sparse
.
ReLU
(),
True
,
)
self
.
compare_with_dense
(
x
,
lambda
x
:
x
.
to_sparse_csr
(),
paddle
.
nn
.
ReLU
(),
paddle
.
sparse
.
ReLU
(),
False
,
)
self
.
assert_raises_on_dense_tensor
(
paddle
.
sparse
.
ReLU
())
def
test_sparse_sqrt
(
self
):
x
=
[[
0
,
16
,
0
,
0
],
[
0
,
0
,
0
,
0
],
[
0
,
4
,
2
,
0
]]
sparse_dim
=
2
self
.
compare_with_dense
(
x
,
lambda
x
:
x
.
to_sparse_coo
(
sparse_dim
),
paddle
.
sqrt
,
paddle
.
sparse
.
sqrt
,
True
,
)
self
.
compare_with_dense
(
x
,
lambda
x
:
x
.
to_sparse_csr
(),
paddle
.
sqrt
,
paddle
.
sparse
.
sqrt
,
False
,
)
self
.
assert_raises_on_dense_tensor
(
paddle
.
sparse
.
sqrt
)
def
test_sparse_sin
(
self
):
x
=
[[
0
,
16
,
0
,
0
],
[
0
,
0
,
0
,
0
],
[
0
,
4
,
2
,
0
]]
sparse_dim
=
2
self
.
compare_with_dense
(
x
,
lambda
x
:
x
.
to_sparse_coo
(
sparse_dim
),
paddle
.
sin
,
paddle
.
sparse
.
sin
,
True
,
)
self
.
compare_with_dense
(
x
,
lambda
x
:
x
.
to_sparse_csr
(),
paddle
.
sin
,
paddle
.
sparse
.
sin
,
False
,
)
self
.
assert_raises_on_dense_tensor
(
paddle
.
sparse
.
sin
)
def
test_sparse_tanh
(
self
):
x
=
[[
0
,
16
,
0
,
0
],
[
0
,
0
,
0
,
0
],
[
0
,
-
4
,
2
,
0
]]
sparse_dim
=
2
self
.
compare_with_dense
(
x
,
lambda
x
:
x
.
to_sparse_coo
(
sparse_dim
),
paddle
.
tanh
,
paddle
.
sparse
.
tanh
,
True
,
)
self
.
compare_with_dense
(
x
,
lambda
x
:
x
.
to_sparse_csr
(),
paddle
.
tanh
,
paddle
.
sparse
.
tanh
,
False
,
)
self
.
assert_raises_on_dense_tensor
(
paddle
.
sparse
.
tanh
)
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/sparse/__init__.py
浏览文件 @
f1eda7d0
...
...
@@ -14,15 +14,19 @@
from
.creation
import
sparse_coo_tensor
from
.creation
import
sparse_csr_tensor
from
.layer
.activation
import
ReLU
from
.layer
.norm
import
BatchNorm
from
.layer
import
ReLU
from
.layer
import
BatchNorm
from
.layer
.conv
import
Conv3D
from
.layer
.conv
import
SubmConv3D
from
.layer
import
Conv3D
from
.layer
import
SubmConv3D
from
.layer.pooling
import
MaxPool3D
from
.layer
import
MaxPool3D
from
.functional
import
sqrt
from
.functional
import
sin
from
.functional
import
tanh
__all__
=
[
'sparse_coo_tensor'
,
'sparse_csr_tensor'
,
'ReLU'
,
'Conv3D'
,
'SubmConv3D'
,
'BatchNorm'
,
'MaxPool3D'
'BatchNorm'
,
'MaxPool3D'
,
'sqrt'
,
'sin'
,
'tanh'
]
python/paddle/sparse/functional/__init__.py
浏览文件 @
f1eda7d0
...
...
@@ -12,9 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from
.activation
import
relu
# noqa: F401
from
.unary
import
relu
# noqa: F401
from
.unary
import
tanh
# noqa: F401
from
.unary
import
sqrt
# noqa: F401
from
.unary
import
sin
# noqa: F401
from
.conv
import
conv3d
# noqa: F401
from
.conv
import
subm_conv3d
# noqa: F401
from
.pooling
import
max_pool3d
# noqa: F401
__all__
=
[
'relu'
,
'
conv3d'
,
'subm_conv3d'
,
'max_pool3d
'
]
__all__
=
[
'relu'
,
'
tanh'
,
'conv3d'
,
'subm_conv3d'
,
'max_pool3d'
,
'sqrt'
,
'sin
'
]
python/paddle/sparse/functional/activation.py
已删除
100644 → 0
浏览文件 @
ddb3868e
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__
=
[]
from
paddle
import
_C_ops
,
in_dynamic_mode
def
relu
(
x
,
name
=
None
):
"""
sparse relu activation.
.. math::
out = max(x, 0)
Parameters:
x (Tensor): The input Sparse Tensor with data type float32, float64.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
A Sparse Tensor with the same data type and shape as ``x`` .
Examples:
.. code-block:: python
import paddle
import numpy as np
from paddle.fluid.framework import _test_eager_guard
with _test_eager_guard():
dense_x = paddle.to_tensor(np.array([-2, 0, 1]).astype('float32'))
sparse_x = dense_x.to_sparse_coo(1)
out = paddle.sparse.functional.relu(sparse_x)
"""
assert
in_dynamic_mode
(),
"Currently, Sparse API only support dynamic mode"
assert
x
.
is_sparse_coo
(
),
"Currently, sparse.relu only support the input of SparseCooTensor"
return
_C_ops
.
final_state_sparse_relu
(
x
)
python/paddle/sparse/functional/unary.py
0 → 100644
浏览文件 @
f1eda7d0
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__
=
[]
from
paddle
import
_C_ops
,
in_dynamic_mode
def
relu
(
x
,
name
=
None
):
"""
sparse relu activation, requiring x to be a sparse coo or sparse csr tensor.
.. math::
out = max(x, 0)
Parameters:
x (Tensor): The input Sparse Tensor with data type float32, float64.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
A Sparse Tensor with the same data type and shape as ``x`` .
Examples:
.. code-block:: python
import paddle
from paddle.fluid.framework import _test_eager_guard
with _test_eager_guard():
dense_x = paddle.to_tensor([-2, 0, 1], dtype='float32')
sparse_x = dense_x.to_sparse_coo(1)
out = paddle.sparse.functional.relu(sparse_x)
"""
assert
in_dynamic_mode
(),
"Currently, Sparse API only support dynamic mode"
if
x
.
is_sparse_coo
():
return
_C_ops
.
final_state_sparse_coo_relu
(
x
)
elif
x
.
is_sparse_csr
():
return
_C_ops
.
final_state_sparse_csr_relu
(
x
)
else
:
raise
ValueError
(
"Currently, sparse.relu only support the input of SparseCooTensor or SparseCsrTensor"
)
def
tanh
(
x
,
name
=
None
):
"""
sparse tanh activation, requiring x to be a sparse coo or sparse csr tensor.
.. math::
out = tanh(x)
Parameters:
x (Tensor): The input Sparse Tensor with data type float32, float64.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
A Sparse Tensor with the same data type and shape as ``x`` .
Examples:
.. code-block:: python
import paddle
from paddle.fluid.framework import _test_eager_guard
with _test_eager_guard():
dense_x = paddle.to_tensor([-2, 0, 1], dtype='float32')
sparse_x = dense_x.to_sparse_coo(1)
out = paddle.sparse.tanh(sparse_x)
"""
assert
in_dynamic_mode
(),
"Currently, Sparse API only support dynamic mode"
if
x
.
is_sparse_coo
():
return
_C_ops
.
final_state_sparse_coo_tanh
(
x
)
elif
x
.
is_sparse_csr
():
return
_C_ops
.
final_state_sparse_csr_tanh
(
x
)
else
:
raise
ValueError
(
"Currently, sparse.tanh only support the input of SparseCooTensor or SparseCsrTensor"
)
def
sqrt
(
x
,
name
=
None
):
"""
Calculate square root of x, requiring x to be a sparse coo or sparse csr tensor.
.. math::
out = sqrt(x)
Parameters:
x (Tensor): The input Sparse Tensor with data type float32, float64.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
A Sparse Tensor with the same data type and shape as ``x`` .
Examples:
.. code-block:: python
import paddle
from paddle.fluid.framework import _test_eager_guard
with _test_eager_guard():
dense_x = paddle.to_tensor([4, 0, 1], dtype='float32')
sparse_x = dense_x.to_sparse_coo(1)
out = paddle.sparse.sqrt(sparse_x)
"""
assert
in_dynamic_mode
(),
"Currently, Sparse API only support dynamic mode"
if
x
.
is_sparse_coo
():
return
_C_ops
.
final_state_sparse_coo_sqrt
(
x
)
elif
x
.
is_sparse_csr
():
return
_C_ops
.
final_state_sparse_csr_sqrt
(
x
)
else
:
raise
ValueError
(
"Currently, sparse.sqrt only support the input of SparseCooTensor or SparseCsrTensor"
)
def
sin
(
x
,
name
=
None
):
"""
Calculate sin of x, requiring x to be a sparse coo or sparse csr tensor.
.. math::
out = sin(x)
Parameters:
x (Tensor): The input Sparse Tensor with data type float32, float64.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
A Sparse Tensor with the same data type and shape as ``x`` .
Examples:
.. code-block:: python
import paddle
from paddle.fluid.framework import _test_eager_guard
with _test_eager_guard():
dense_x = paddle.to_tensor([-2, 0, 3], dtype='float32')
sparse_x = dense_x.to_sparse_coo(1)
out = paddle.sparse.sin(sparse_x)
"""
assert
in_dynamic_mode
(),
"Currently, Sparse API only support dynamic mode"
if
x
.
is_sparse_coo
():
return
_C_ops
.
final_state_sparse_coo_sin
(
x
)
elif
x
.
is_sparse_csr
():
return
_C_ops
.
final_state_sparse_csr_sin
(
x
)
else
:
raise
ValueError
(
"Currently, sparse.sin only support the input of SparseCooTensor or SparseCsrTensor"
)
python/paddle/sparse/layer/__init__.py
浏览文件 @
f1eda7d0
...
...
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from
.
activation
import
ReLU
from
.
unary
import
ReLU
from
.norm
import
BatchNorm
from
.conv
import
Conv3D
from
.conv
import
SubmConv3D
...
...
python/paddle/sparse/layer/
activation
.py
→
python/paddle/sparse/layer/
unary
.py
浏览文件 @
f1eda7d0
文件已移动
python/paddle/utils/code_gen/sparse_api.yaml
浏览文件 @
f1eda7d0
...
...
@@ -7,6 +7,38 @@
intermediate
:
rulebook
backward
:
conv3d_grad
-
api
:
coo_relu
args
:
(Tensor x)
output
:
Tensor(out@SparseCooTensor)
kernel
:
func
:
sparse_coo_relu
layout
:
x
backward
:
sparse_coo_relu_grad
-
api
:
coo_sin
args
:
(Tensor x)
output
:
Tensor(out@SparseCooTensor)
kernel
:
func
:
sparse_coo_sin
layout
:
x
backward
:
sparse_coo_sin_grad
-
api
:
coo_sqrt
args
:
(Tensor x)
output
:
Tensor(out@SparseCooTensor)
kernel
:
func
:
sparse_coo_sqrt
layout
:
x
backward
:
sparse_coo_sqrt_grad
-
api
:
coo_tanh
args
:
(Tensor x)
output
:
Tensor(out@SparseCooTensor)
kernel
:
func
:
sparse_coo_tanh
layout
:
x
backward
:
sparse_coo_tanh_grad
-
api
:
coo_to_dense
args
:
(Tensor x)
output
:
Tensor(out@DenseTensor)
...
...
@@ -30,6 +62,34 @@
data_type
:
values
backward
:
create_sparse_coo_tensor_grad
-
api
:
csr_relu
args
:
(Tensor x)
output
:
Tensor(out@SparseCsrTensor)
kernel
:
func
:
sparse_csr_relu
layout
:
x
-
api
:
csr_sin
args
:
(Tensor x)
output
:
Tensor(out@SparseCsrTensor)
kernel
:
func
:
sparse_csr_sin
layout
:
x
-
api
:
csr_sqrt
args
:
(Tensor x)
output
:
Tensor(out@SparseCsrTensor)
kernel
:
func
:
sparse_csr_sqrt
layout
:
x
-
api
:
csr_tanh
args
:
(Tensor x)
output
:
Tensor(out@SparseCsrTensor)
kernel
:
func
:
sparse_csr_tanh
layout
:
x
-
api
:
csr_values
args
:
(Tensor x)
output
:
Tensor(out@DenseTensor)
...
...
@@ -43,14 +103,6 @@
invoke
:
to_sparse_coo_impl(x, sparse_dim)
backward
:
dense_to_coo_grad
-
api
:
relu
args
:
(Tensor x)
output
:
Tensor(out@SparseCooTensor)
kernel
:
func
:
sparse_relu
layout
:
x
backward
:
sparse_relu_grad
-
api
:
to_dense
args
:
(Tensor x)
output
:
Tensor(out@DenseTensor)
...
...
python/paddle/utils/code_gen/sparse_bw_api.yaml
浏览文件 @
f1eda7d0
...
...
@@ -32,16 +32,37 @@
output
:
Tensor(x_grad@DenseTensor)
invoke
:
to_dense_impl(out_grad)
-
backward_api
:
sparse_
maxpool
_grad
forward
:
sparse_
maxpool(Tensor x, int[] kernel_sizes, int[] paddings, int[] dilations, int[] strides) -> Tensor(out@SparseCooTensor), Tensor(rulebook@Dense
Tensor)
args
:
(Tensor
x, Tensor rulebook, Tensor out, Tensor out_grad, int[] kernel_sizes
)
-
backward_api
:
sparse_
coo_relu
_grad
forward
:
sparse_
coo_relu(Tensor x) -> Tensor(out@SparseCoo
Tensor)
args
:
(Tensor
out, Tensor out_grad
)
output
:
Tensor(x_grad@SparseCooTensor)
kernel
:
func
:
sparse_
maxpool
_grad
func
:
sparse_
coo_relu
_grad
-
backward_api
:
sparse_
relu
_grad
forward
:
sparse_
relu
(Tensor x) -> Tensor(out@SparseCooTensor)
-
backward_api
:
sparse_
coo_sin
_grad
forward
:
sparse_
coo_sin
(Tensor x) -> Tensor(out@SparseCooTensor)
args
:
(Tensor x, Tensor out_grad)
output
:
Tensor(x_grad@SparseCooTensor)
kernel
:
func
:
sparse_relu_grad
func
:
sparse_coo_sin_grad
-
backward_api
:
sparse_coo_sqrt_grad
forward
:
sparse_coo_sqrt(Tensor x) -> Tensor(out@SparseCooTensor)
args
:
(Tensor out, Tensor out_grad)
output
:
Tensor(x_grad@SparseCooTensor)
kernel
:
func
:
sparse_coo_sqrt_grad
-
backward_api
:
sparse_coo_tanh_grad
forward
:
sparse_coo_tanh(Tensor x) -> Tensor(out@SparseCooTensor)
args
:
(Tensor out, Tensor out_grad)
output
:
Tensor(x_grad@SparseCooTensor)
kernel
:
func
:
sparse_coo_tanh_grad
-
backward_api
:
sparse_maxpool_grad
forward
:
sparse_maxpool(Tensor x, int[] kernel_sizes, int[] paddings, int[] dilations, int[] strides) -> Tensor(out@SparseCooTensor), Tensor(rulebook@DenseTensor)
args
:
(Tensor x, Tensor rulebook, Tensor out, Tensor out_grad, int[] kernel_sizes)
output
:
Tensor(x_grad@SparseCooTensor)
kernel
:
func
:
sparse_maxpool_grad
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录