Unverified commit 2ddbc647, authored by zhouweiwei2014, committed by GitHub

[Sparse] Add new API/OP (CSR -> CSR) for SparseTensor softmax (#43475)

* add new API/OP (CSR -> CSR) of SparseTensor softmax

* fix comment
Parent: a4cfa5ae
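In short, this commit adds `paddle.incubate.sparse.nn.functional.softmax` and a `paddle.incubate.sparse.nn.Softmax` layer that apply a row-wise softmax to the stored non-zeros of a sparse CSR tensor. A minimal usage sketch, assembled from the docstring examples and the test below (eager mode is assumed, as the sparse API required it at the time of this PR):

import numpy as np
import paddle
from paddle.fluid.framework import _test_eager_guard

with _test_eager_guard():
    mask = np.random.rand(3, 4) < 0.5
    np_x = np.random.rand(3, 4) * mask
    csr = paddle.to_tensor(np_x).to_sparse_csr()

    # Functional form: softmax over each row's stored non-zeros (axis=-1 only).
    out = paddle.incubate.sparse.nn.functional.softmax(csr)

    # Layer form, equivalent to the functional call above.
    m = paddle.incubate.sparse.nn.Softmax()
    out = m(csr)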
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/sparse/softmax_grad_kernel.h"
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/funcs/cpu_vec.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"
namespace plt = paddle::platform;
namespace phi {
namespace sparse {
template <typename T, typename Context>
void SoftmaxCsrGradKernel(const Context& dev_ctx,
const SparseCsrTensor& out,
const SparseCsrTensor& dout,
int axis,
SparseCsrTensor* dx) {
PADDLE_ENFORCE_EQ(axis,
-1,
phi::errors::Unimplemented(
"SparseCsrTensor only support axis=-1 for softmax, "
"which is faster when reading data by row (axis=-1)"));
EmptyLikeCsrKernel<T, Context>(dev_ctx, dout, dx);
auto out_dim = out.dims();
int row_number = 1;
for (int i = 0; i < out_dim.size() - 1; ++i) {
row_number *= out_dim[i];
}
const DenseTensor& out_crows = out.non_zero_crows();
const DenseTensor& out_values = out.non_zero_elements();
const DenseTensor& dout_values = dout.non_zero_elements();
DenseTensor* dx_values = dx->mutable_non_zero_elements();
int row_first = 0;
int row_nnz = 0;
const T* out_data = out_values.data<T>();
const T* dout_data = dout_values.data<T>();
T* dx_data = dx_values->data<T>();
// dx = (dout - sum(dout * out)) * out
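// Derivation: out = softmax(x) row-wise, so d(out_i)/d(x_j) = out_i * (delta_ij - out_j).
// By the chain rule, dx_j = sum_i(dout_i * out_i * (delta_ij - out_j))
//                         = out_j * (dout_j - sum_i(dout_i * out_i)),
// computed per row below as a fused multiply-reduce, a bias add, then an
// elementwise multiply.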
PD_VISIT_INTEGRAL_TYPES(
out.non_zero_crows().dtype(), "SoftmaxCsrGradKernel", ([&] {
const data_t* out_crows_data = out_crows.data<data_t>();
for (int i = 0; i < row_number; ++i) {
row_first = static_cast<int>(out_crows_data[i]);
row_nnz = static_cast<int>(out_crows_data[i + 1] - out_crows_data[i]);
out_data = out_data + row_first;
dout_data = dout_data + row_first;
dx_data = dx_data + row_first;
T sum = 0;
phi::funcs::vec_mul_reduce<T, plt::avx>(
row_nnz, dout_data, out_data, &sum);
phi::funcs::vec_add_bias<T, plt::avx>(
row_nnz, static_cast<T>(-1) * sum, dout_data, dx_data);
phi::funcs::vec_mul<T, plt::avx>(row_nnz, dx_data, out_data, dx_data);
}
}));
}
} // namespace sparse
} // namespace phi
PD_REGISTER_KERNEL(softmax_csr_grad,
CPU,
ALL_LAYOUT,
phi::sparse::SoftmaxCsrGradKernel,
float,
double) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);
}
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/sparse/softmax_kernel.h"
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/funcs/cpu_vec.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"
namespace plt = paddle::platform;
namespace phi {
namespace sparse {
template <typename T, typename Context>
void SoftmaxCsrKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
int axis,
SparseCsrTensor* out) {
PADDLE_ENFORCE_EQ(axis,
-1,
phi::errors::Unimplemented(
"SparseCsrTensor only support axis=-1 for softmax, "
"which is faster when reading data by row (axis=-1)"));
EmptyLikeCsrKernel<T, Context>(dev_ctx, x, out);
auto x_dim = x.dims();
int row_number = 1;
for (int i = 0; i < x_dim.size() - 1; ++i) {
row_number *= x_dim[i];
}
const DenseTensor& x_crows = x.non_zero_crows();
const DenseTensor& x_values = x.non_zero_elements();
DenseTensor* out_values = out->mutable_non_zero_elements();
int row_first = 0;
int row_nnz = 0;
T row_max_val = 0;
const T* x_data = x_values.data<T>();
T* out_data = out_values->data<T>();
// out = exp(x - x_max) / sum(exp(x - x_max))
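// Shifting by the row maximum leaves the result unchanged, since
// exp(x - m) / sum(exp(x - m)) == exp(x) / sum(exp(x)), but keeps the
// exponentials bounded and avoids overflow for large inputs.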
PD_VISIT_INTEGRAL_TYPES(
x.non_zero_crows().dtype(), "CsrSoftmaxKernel", ([&] {
const data_t* x_crows_data = x_crows.data<data_t>();
for (int i = 0; i < row_number; ++i) {
row_first = static_cast<int>(x_crows_data[i]);
row_nnz = static_cast<int>(x_crows_data[i + 1] - x_crows_data[i]);
x_data = x_data + row_first;
out_data = out_data + row_first;
row_max_val = *std::max_element(x_data, x_data + row_nnz);
phi::funcs::vec_add_bias<T, plt::avx>(
row_nnz, static_cast<T>(-1) * row_max_val, x_data, out_data);
phi::funcs::vec_exp<T>(row_nnz, out_data, out_data);
T sum = 0;
phi::funcs::vec_sum<T, plt::avx>(row_nnz, out_data, &sum);
phi::funcs::vec_scal<T, plt::avx>(
row_nnz, static_cast<T>(1) / sum, out_data, out_data);
}
}));
}
} // namespace sparse
} // namespace phi
PD_REGISTER_KERNEL(softmax_csr,
CPU,
ALL_LAYOUT,
phi::sparse::SoftmaxCsrKernel,
float,
double) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);
}
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/sparse/empty_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/copy_kernel.h"
namespace phi {
namespace sparse {
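// Builds an output CSR tensor with the same sparsity pattern as `x`:
// crows and cols are copied, while the values buffer is only allocated,
// left for the caller to fill.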
template <typename T, typename Context>
void EmptyLikeCsrKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
SparseCsrTensor* out) {
const DenseTensor& x_crows = x.non_zero_crows();
const DenseTensor& x_cols = x.non_zero_cols();
const DenseTensor& x_values = x.non_zero_elements();
DenseTensor* out_crows = out->mutable_non_zero_crows();
DenseTensor* out_cols = out->mutable_non_zero_cols();
DenseTensor* out_values = out->mutable_non_zero_elements();
out->set_dims(x.dims());
phi::Copy(dev_ctx, x_crows, dev_ctx.GetPlace(), false, out_crows);
phi::Copy(dev_ctx, x_cols, dev_ctx.GetPlace(), false, out_cols);
out_values->Resize(x_values.dims());
dev_ctx.template Alloc<T>(out_values);
}
} // namespace sparse
} // namespace phi
PD_REGISTER_KERNEL(empty_like_csr,
CPU,
ALL_LAYOUT,
phi::sparse::EmptyLikeCsrKernel,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(empty_like_csr,
GPU,
ALL_LAYOUT,
phi::sparse::EmptyLikeCsrKernel,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);
}
#endif
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/sparse_csr_tensor.h"
namespace phi {
namespace sparse {
template <typename T, typename Context>
void EmptyLikeCsrKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
SparseCsrTensor* out);
} // namespace sparse
} // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/funcs/math_cuda_utils.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"
#include "paddle/phi/kernels/sparse/softmax_grad_kernel.h"
namespace phi {
namespace sparse {
template <typename T, typename IntT = int>
__global__ void SoftmaxGradGpuKernel(const IntT* out_crows,
const T* out_values,
const T* dout_values,
T* dx_values,
int row_number) {
// dx = (dout - sum(dout * out)) * out
int row = blockIdx.x * blockDim.y + threadIdx.y;
int non_zero_idx = threadIdx.x;
if (row >= row_number) return;
int row_first = static_cast<int>(out_crows[row]);
int row_nnz = static_cast<int>(out_crows[row + 1] - out_crows[row]);
if (row_nnz == 0) return;
int kIteration = (row_nnz + warpSize - 1) / warpSize;
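// Each of the 32 lanes in the warp strides over the row's non-zeros
// (lane, lane + warpSize, ...), accumulating a partial dot product of
// out and dout that warpReduceSum then combines across the warp.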
T mul_result = 0;
for (int i = 0; i < kIteration; ++i) {
int idx = non_zero_idx + i * warpSize;
if (idx >= row_nnz) break;
mul_result += out_values[row_first + idx] * dout_values[row_first + idx];
}
T sum = phi::funcs::warpReduceSum<T>(mul_result, 0xFFFFFFFF);
for (int i = 0; i < kIteration; ++i) {
int idx = non_zero_idx + i * warpSize;
if (idx >= row_nnz) break;
dx_values[row_first + idx] =
(dout_values[row_first + idx] - sum) * out_values[row_first + idx];
}
}
template <typename T, typename Context>
void SoftmaxCsrGradKernel(const Context& dev_ctx,
const SparseCsrTensor& out,
const SparseCsrTensor& dout,
int axis,
SparseCsrTensor* dx) {
PADDLE_ENFORCE_EQ(axis,
-1,
phi::errors::Unimplemented(
"SparseCsrTensor only support axis=-1 for softmax, "
"which is faster when reading data by row (axis=-1)"));
EmptyLikeCsrKernel<T, Context>(dev_ctx, dout, dx);
auto out_dim = out.dims();
int row_number = 1;
for (int i = 0; i < out_dim.size() - 1; ++i) {
row_number *= out_dim[i];
}
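// One 32-thread warp (block.x) cooperates on each row, with four rows
// per block (block.y), hence (row_number + 3) / 4 blocks.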
dim3 grid((row_number + 3) / 4);
dim3 block(32, 4);
PD_VISIT_INTEGRAL_TYPES(
out.non_zero_crows().dtype(), "SoftmaxCsrGradKernel", ([&] {
SoftmaxGradGpuKernel<T, data_t><<<grid, block, 0, dev_ctx.stream()>>>(
out.non_zero_crows().data<data_t>(),
out.non_zero_elements().data<T>(),
dout.non_zero_elements().data<T>(),
dx->mutable_non_zero_elements()->data<T>(),
row_number);
}));
}
} // namespace sparse
} // namespace phi
PD_REGISTER_KERNEL(softmax_csr_grad,
GPU,
ALL_LAYOUT,
phi::sparse::SoftmaxCsrGradKernel,
float,
double) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);
}
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/activation_functor.h"
#include "paddle/phi/kernels/funcs/math_cuda_utils.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"
#include "paddle/phi/kernels/sparse/softmax_kernel.h"
namespace phi {
namespace sparse {
template <typename T, typename IntT = int>
__global__ void SoftmaxGpuKernel(const IntT* x_crows,
const T* x_values,
T* out_values,
int row_number) {
// out = exp(x-x_max) / sum(exp(x-x_max))
int row = blockIdx.x * blockDim.y + threadIdx.y;
int non_zero_idx = threadIdx.x;
if (row >= row_number) return;
int row_first = static_cast<int>(x_crows[row]);
int row_nnz = static_cast<int>(x_crows[row + 1] - x_crows[row]);
if (row_nnz == 0) return;
int kIteration = (row_nnz + warpSize - 1) / warpSize;
T max_val = -std::numeric_limits<T>::infinity();
for (int i = 0; i < kIteration; ++i) {
int idx = non_zero_idx + i * warpSize;
if (idx >= row_nnz) break;
if (max_val < x_values[row_first + idx]) {
max_val = x_values[row_first + idx];
}
}
T row_max_val = phi::funcs::warpReduceMax<T>(max_val, 0xFFFFFFFF);
T exp_sum = 0;
for (int i = 0; i < kIteration; ++i) {
int idx = non_zero_idx + i * warpSize;
if (idx >= row_nnz) break;
auto functor = phi::funcs::CudaExpFunctor<T>();
T exp = functor(x_values[row_first + idx] - row_max_val);
exp_sum += exp;
out_values[row_first + idx] = exp;
}
T row_exp_sum = phi::funcs::warpReduceSum<T>(exp_sum, 0xFFFFFFFF);
for (int i = 0; i < kIteration; ++i) {
int idx = non_zero_idx + i * warpSize;
if (idx >= row_nnz) break;
out_values[row_first + idx] = out_values[row_first + idx] / row_exp_sum;
}
}
template <typename T, typename Context>
void SoftmaxCsrKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
int axis,
SparseCsrTensor* out) {
PADDLE_ENFORCE_EQ(axis,
-1,
phi::errors::Unimplemented(
"SparseCsrTensor only support axis=-1 for softmax, "
"which is faster when reading data by row (axis=-1)"));
EmptyLikeCsrKernel<T, Context>(dev_ctx, x, out);
auto x_dim = x.dims();
int row_number = 1;
for (int i = 0; i < x_dim.size() - 1; ++i) {
row_number *= x_dim[i];
}
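// Same launch configuration as the backward kernel: one 32-thread warp
// per row, four rows per block.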
dim3 grid((row_number + 3) / 4);
dim3 block(32, 4);
DenseTensor tmp_tensor =
phi::EmptyLike<T, Context>(dev_ctx, x.non_zero_elements());
PD_VISIT_INTEGRAL_TYPES(x.non_zero_crows().dtype(), "CsrSoftmaxKernel", ([&] {
SoftmaxGpuKernel<T, data_t>
<<<grid, block, 0, dev_ctx.stream()>>>(
x.non_zero_crows().data<data_t>(),
x.non_zero_elements().data<T>(),
out->mutable_non_zero_elements()->data<T>(),
row_number);
}));
}
} // namespace sparse
} // namespace phi
PD_REGISTER_KERNEL(softmax_csr,
GPU,
ALL_LAYOUT,
phi::sparse::SoftmaxCsrKernel,
float,
double) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);
}
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/sparse_csr_tensor.h"
namespace phi {
namespace sparse {
template <typename T, typename Context>
void SoftmaxCsrGradKernel(const Context& dev_ctx,
const SparseCsrTensor& out,
const SparseCsrTensor& dout,
int axis,
SparseCsrTensor* dx);
} // namespace sparse
} // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/sparse_csr_tensor.h"
namespace phi {
namespace sparse {
template <typename T, typename Context>
void SoftmaxCsrKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
int axis,
SparseCsrTensor* out);
} // namespace sparse
} // namespace phi
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle.fluid.framework import _test_eager_guard
import numpy as np
import scipy
import scipy.sparse as sp
import unittest
import os
import re
import math

np.random.seed(2022)


class TestCsrSoftmax(unittest.TestCase):

    def test_softmax(self):
        with _test_eager_guard():
            mask = np.random.rand(1, 5) < 0.5
            np_x = np.random.rand(1, 5) * mask
            np_csr = sp.csr_matrix(np_x)

            row_number = np_csr.shape[0]
            np_out = np.array([])
            for i in range(row_number):
                start = np_csr.indptr[i]
                end = np_csr.indptr[i + 1]
                if start == end:
                    continue
                x = np_csr.data[start:end]
                x_max = np.max(x, keepdims=True)
                x_exp = np.exp(x - x_max)
                x_exp_sum = np.sum(x_exp, keepdims=True)
                np_out = np.concatenate([np_out, x_exp / x_exp_sum])

            csr = paddle.to_tensor(np_x, stop_gradient=False).to_sparse_csr()
            m = paddle.incubate.sparse.nn.Softmax()
            out = m(csr)
            self.assertTrue(np.allclose(out.crows().numpy(), np_csr.indptr))
            self.assertTrue(np.allclose(out.cols().numpy(), np_csr.indices))
            self.assertTrue(np.allclose(out.values().numpy(), np_out))

            # dx = (dout - sum(dout * out)) * out, where dout is np_x itself
            out.backward(csr.detach())
            for i in range(row_number):
                start = np_csr.indptr[i]
                end = np_csr.indptr[i + 1]
                if start == end:
                    continue
                out = np_out[start:end]
                dout = np_csr.data[start:end]
                sum = np.sum(dout * out, keepdims=True)
                dx = (dout - sum) * out
                self.assertTrue(
                    np.allclose(csr.grad.crows().numpy(), np_csr.indptr))
                self.assertTrue(
                    np.allclose(csr.grad.cols().numpy(), np_csr.indices))
                self.assertTrue(np.allclose(csr.grad.values().numpy(), dx))


if __name__ == "__main__":
    unittest.main()
@@ -27,7 +27,7 @@ class TestSparseUnary(unittest.TestCase):

     def assert_raises_on_dense_tensor(self, sparse_func):
         with _test_eager_guard():
             dense_x = paddle.ones((2, 3))
-            with self.assertRaises(ValueError):
+            with self.assertRaises(NotImplementedError):
                 sparse_func(dense_x)

     def compare_with_dense(
......
@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from paddle.common_ops_import import dygraph_only
 from paddle import _C_ops
+from paddle.fluid.framework import dygraph_only

 __all__ = []
......
@@ -15,6 +15,7 @@
 from . import functional
 from .layer.activation import ReLU
+from .layer.activation import Softmax
 from .layer.norm import BatchNorm
 from .layer.conv import Conv3D
 from .layer.conv import SubmConv3D
@@ -22,6 +23,7 @@ from .layer.pooling import MaxPool3D
 __all__ = [
     'ReLU',
+    'Softmax',
     'BatchNorm',
     'Conv3D',
     'SubmConv3D',
......
@@ -16,10 +16,12 @@ from .conv import conv3d  # noqa: F401
 from .conv import subm_conv3d  # noqa: F401
 from .pooling import max_pool3d  # noqa: F401
 from .activation import relu  # noqa: F401
+from .activation import softmax  # noqa: F401

 __all__ = [
     'conv3d',
     'subm_conv3d',
     'max_pool3d',
     'relu',
+    'softmax',
 ]
@@ -15,8 +15,10 @@
 __all__ = []

 from paddle import _C_ops, in_dynamic_mode
+from paddle.fluid.framework import dygraph_only


+@dygraph_only
 def relu(x, name=None):
     """
     sparse relu activation, requiring x to be a sparse coo or sparse csr tensor.
@@ -44,12 +46,63 @@ def relu(x, name=None):
         sparse_x = dense_x.to_sparse_coo(1)
         out = paddle.incubate.sparse.nn.functional.relu(sparse_x)
     """
-    return _C_ops.final_state_sparse_relu(x)
+    assert in_dynamic_mode(), "Currently, Sparse API only supports dynamic mode"
+    if x.is_sparse_coo() or x.is_sparse_csr():
+        return _C_ops.final_state_sparse_relu(x)
+    else:
+        raise ValueError(
+            "Currently, sparse.relu only supports the input of SparseCooTensor or SparseCsrTensor"
+        )
+
+
+@dygraph_only
+def softmax(x, axis=-1, name=None):
+    """
+    sparse softmax activation, requiring x to be a SparseCsrTensor or SparseCooTensor.
+
+    Note:
+        Only axis=-1 is supported for SparseCsrTensor, which is faster when
+        reading data by row (axis=-1).
+
+    From the point of view of a dense matrix, for each row :math:`i` and each
+    column :math:`j` in the matrix, we have:
+
+    .. math::
+
+        softmax_{ij} = \frac{\exp(x_{ij} - \max_j(x_{ij}))}{\sum_j \exp(x_{ij} - \max_j(x_{ij}))}
+
+    Parameters:
+        x (Tensor): The input tensor. It can be SparseCooTensor/SparseCsrTensor. The data type can be float32 or float64.
+        axis (int, optional): The axis along which to perform softmax calculations. Only -1 is supported for SparseCsrTensor.
+        name (str, optional): Name for the operation (optional, default is None).
+            For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        Tensor: SparseCoo or SparseCsr, whose layout is the same as that of `x`.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            import numpy as np
+            from paddle.fluid.framework import _test_eager_guard
+
+            paddle.seed(100)
+
+            with _test_eager_guard():
+                mask = np.random.rand(3, 4) < 0.5
+                np_x = np.random.rand(3, 4) * mask
+                # [[0.         0.         0.96823406 0.19722934]
+                #  [0.94373937 0.         0.02060066 0.71456372]
+                #  [0.         0.         0.         0.98275049]]
+
+                csr = paddle.to_tensor(np_x).to_sparse_csr()
+                # Tensor(shape=[3, 4], dtype=paddle.float64, place=Place(gpu:0), stop_gradient=True,
+                #        crows=[0, 2, 5, 6],
+                #        cols=[2, 3, 0, 2, 3, 3],
+                #        values=[0.96823406, 0.19722934, 0.94373937, 0.02060066, 0.71456372,
+                #                0.98275049])
+
+                out = paddle.incubate.sparse.nn.functional.softmax(csr)
+                # Tensor(shape=[3, 4], dtype=paddle.float64, place=Place(gpu:0), stop_gradient=True,
+                #        crows=[0, 2, 5, 6],
+                #        cols=[2, 3, 0, 2, 3, 3],
+                #        values=[0.68373820, 0.31626180, 0.45610887, 0.18119845, 0.36269269,
+                #                1.        ])
+    """
+    return _C_ops.final_state_sparse_softmax(x, axis)
@@ -59,3 +59,72 @@ class ReLU(Layer):

     def extra_repr(self):
         name_str = 'name={}'.format(self._name) if self._name else ''
         return name_str
+
+
+class Softmax(Layer):
+    """
+    sparse softmax activation, requiring x to be a SparseCsrTensor or SparseCooTensor.
+
+    Note:
+        Only axis=-1 is supported for SparseCsrTensor, which is faster when
+        reading data by row (axis=-1).
+
+    From the point of view of a dense matrix, for each row :math:`i` and each
+    column :math:`j` in the matrix, we have:
+
+    .. math::
+
+        softmax_{ij} = \frac{\exp(x_{ij} - \max_j(x_{ij}))}{\sum_j \exp(x_{ij} - \max_j(x_{ij}))}
+
+    Parameters:
+        axis (int, optional): The axis along which to perform softmax calculations. Only -1 is supported for SparseCsrTensor.
+        name (str, optional): Name for the operation (optional, default is None).
+            For more information, please refer to :ref:`api_guide_Name`.
+
+    Shape:
+        - input: SparseCooTensor / SparseCsrTensor with any shape.
+        - output: sparse tensor with the same shape as the input.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            import numpy as np
+            from paddle.fluid.framework import _test_eager_guard
+
+            paddle.seed(100)
+
+            with _test_eager_guard():
+                mask = np.random.rand(3, 4) < 0.5
+                np_x = np.random.rand(3, 4) * mask
+                # [[0.         0.         0.96823406 0.19722934]
+                #  [0.94373937 0.         0.02060066 0.71456372]
+                #  [0.         0.         0.         0.98275049]]
+
+                csr = paddle.to_tensor(np_x).to_sparse_csr()
+                # Tensor(shape=[3, 4], dtype=paddle.float64, place=Place(gpu:0), stop_gradient=True,
+                #        crows=[0, 2, 5, 6],
+                #        cols=[2, 3, 0, 2, 3, 3],
+                #        values=[0.96823406, 0.19722934, 0.94373937, 0.02060066, 0.71456372,
+                #                0.98275049])
+
+                m = paddle.incubate.sparse.nn.Softmax()
+                out = m(csr)
+                # Tensor(shape=[3, 4], dtype=paddle.float64, place=Place(gpu:0), stop_gradient=True,
+                #        crows=[0, 2, 5, 6],
+                #        cols=[2, 3, 0, 2, 3, 3],
+                #        values=[0.68373820, 0.31626180, 0.45610887, 0.18119845, 0.36269269,
+                #                1.        ])
+    """
+
+    def __init__(self, axis=-1, name=None):
+        super(Softmax, self).__init__()
+        self._axis = axis
+        self._name = name
+
+    def forward(self, x):
+        return F.softmax(x, self._axis, self._name)
+
+    def extra_repr(self):
+        name_str = 'name={}'.format(self._name) if self._name else ''
+        return name_str
@@ -12,11 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from paddle import _C_ops, in_dynamic_mode
-
-__all__ = []
+__all__ = []
+
+from paddle import _C_ops
+from paddle.fluid.framework import dygraph_only


+@dygraph_only
 def tanh(x, name=None):
     """
     sparse tanh activation, requiring x to be a sparse coo or sparse csr tensor.
@@ -44,17 +46,10 @@ def tanh(x, name=None):
         sparse_x = dense_x.to_sparse_coo(1)
         out = paddle.incubate.sparse.tanh(sparse_x)
     """
-    assert in_dynamic_mode(), "Currently, Sparse API only support dynamic mode"
-    if x.is_sparse_coo() or x.is_sparse_csr():
-        return _C_ops.final_state_sparse_tanh(x)
-    else:
-        raise ValueError(
-            "Currently, sparse.tanh only support the input of SparseCooTensor or SparseCsrTensor"
-        )
+    return _C_ops.final_state_sparse_tanh(x)


+@dygraph_only
 def sqrt(x, name=None):
     """
     Calculate square root of x, requiring x to be a sparse coo or sparse csr tensor.
@@ -82,17 +77,10 @@ def sqrt(x, name=None):
         sparse_x = dense_x.to_sparse_coo(1)
         out = paddle.incubate.sparse.sqrt(sparse_x)
     """
-    assert in_dynamic_mode(), "Currently, Sparse API only support dynamic mode"
-    if x.is_sparse_coo() or x.is_sparse_csr():
-        return _C_ops.final_state_sparse_sqrt(x)
-    else:
-        raise ValueError(
-            "Currently, sparse.sqrt only support the input of SparseCooTensor or SparseCsrTensor"
-        )
+    return _C_ops.final_state_sparse_sqrt(x)


+@dygraph_only
 def sin(x, name=None):
     """
     Calculate sin of x, requiring x to be a sparse coo or sparse csr tensor.
@@ -120,12 +108,4 @@ def sin(x, name=None):
         sparse_x = dense_x.to_sparse_coo(1)
         out = paddle.incubate.sparse.sin(sparse_x)
     """
-    assert in_dynamic_mode(), "Currently, Sparse API only support dynamic mode"
-    if x.is_sparse_coo() or x.is_sparse_csr():
-        return _C_ops.final_state_sparse_sin(x)
-    else:
-        raise ValueError(
-            "Currently, sparse.sin only support the input of SparseCooTensor or SparseCsrTensor"
-        )
+    return _C_ops.final_state_sparse_sin(x)
@@ -46,6 +46,14 @@
     layout : x
   backward : sin_grad

+- api : softmax
+  args : (Tensor x, int axis=-1)
+  output : Tensor(out)
+  kernel :
+    func : softmax_csr{sparse_csr -> sparse_csr}
+    layout : x
+  backward : softmax_grad
+
 - api : sqrt
   args : (Tensor x)
   output : Tensor(out)
......
@@ -53,6 +53,13 @@
   kernel :
     func : sparse_coo_sin_grad {sparse_coo, sparse_coo -> sparse_coo}

+- backward_api : softmax_grad
+  forward : softmax(Tensor x, int axis=-1) -> Tensor(out)
+  args : (Tensor out, Tensor out_grad, int axis)
+  output : Tensor(x_grad)
+  kernel :
+    func : softmax_csr_grad{sparse_csr, sparse_csr -> sparse_csr}
+
 - backward_api : sparse_maxpool_grad
   forward : sparse_maxpool(Tensor x, int[] kernel_sizes, int[] paddings, int[] dilations, int[] strides) -> Tensor(out), Tensor(rulebook)
   args : (Tensor x, Tensor rulebook, Tensor out, Tensor out_grad, int[] kernel_sizes)
......