Unverified commit 4ea1d041, authored by thunder95, committed by GitHub

[Hackathon 4th No.26] Add COO-format compute logic for the paddle.sparse.nn.Softmax sparse API (#53613)

Parent commit: 3143d8bf
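Before the kernel changes, a minimal dynamic-graph usage sketch of the new COO path. This is hedged: it assumes a Paddle build that already contains this change, and the input values are illustrative, not taken from the diff.

```python
import numpy as np
import paddle

np.random.seed(0)
mask = np.random.rand(3, 4) < 0.5
dense = (np.random.rand(3, 4) * mask).astype('float32')

coo = paddle.to_tensor(dense).to_sparse_coo(sparse_dim=2)
softmax = paddle.sparse.nn.Softmax(axis=-1)   # previously only the CSR format was accepted
out = softmax(coo)                            # SparseCooTensor with the same indices as `coo`
print(out)
```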
@@ -322,7 +322,8 @@
func : UnchangedInferMeta
param : [out]
kernel :
func : softmax_coo_grad{sparse_coo, sparse_coo -> sparse_coo},
       softmax_csr_grad{sparse_csr, sparse_csr -> sparse_csr}
- backward_op : sparse_coo_tensor_grad
forward : sparse_coo_tensor(Tensor values, Tensor indices, int64_t[] shape) -> Tensor(out)
......
@@ -286,7 +286,8 @@
func : UnchangedInferMeta
param : [x]
kernel :
func : softmax_coo{sparse_coo -> sparse_coo},
       softmax_csr{sparse_csr -> sparse_csr}
layout : x
backward : softmax_grad
......
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
namespace phi {
namespace funcs {
namespace sparse {
/* Given the indices of a sparse tensor, return a vector of offsets
for the entries in the equivalent dense tensor. */
template <typename IntT, typename Context>
inline DenseTensor GetOffsets(const Context& dev_ctx,
const DenseTensor& indices,
const std::vector<IntT>& sizes,
const IntT dim) {
#ifdef __HIPCC__
const auto& policy = thrust::hip::par.on(dev_ctx.stream());
#else
const auto& policy = thrust::cuda::par.on(dev_ctx.stream());
#endif
auto ndim = indices.dims()[0];
auto nnz = indices.dims()[1];
std::vector<IntT> host_strides(ndim, 1);
if (ndim > 1) {
for (IntT i = ndim - 2; i >= 0; i--) {
host_strides[i] = host_strides[i + 1] * (i + 1 == dim ? 1 : sizes[i + 1]);
}
}
const IntArray strides_shape(phi::vectorize<IntT>(indices.dims()));
DenseTensor strides = phi::Empty<IntT>(dev_ctx, strides_shape);
auto strides_ptr = strides.data<IntT>();
memory_utils::Copy(dev_ctx.GetPlace(),
strides_ptr,
phi::CPUPlace(),
host_strides.data(),
sizeof(IntT) * host_strides.size(),
dev_ctx.stream());
DenseTensor offsets = phi::Empty<IntT>(dev_ctx, {nnz});
auto indices_ptr = indices.data<IntT>();
thrust::transform(
policy,
thrust::make_counting_iterator(IntT(0)),
thrust::make_counting_iterator(IntT(nnz)),
thrust::device_ptr<IntT>(offsets.data<IntT>()),
[strides_ptr, indices_ptr, nnz, dim, ndim] __device__(IntT x) {
IntT pool_index = 0;
for (IntT j = 0; j < ndim; j++) {
if (j != dim) {
auto indice_cur_ptr = indices_ptr + j * nnz + x;
auto stride = strides_ptr[j];
pool_index += stride * (*indice_cur_ptr);
}
}
return pool_index;
});
return offsets;
}
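For reference, a hedged NumPy sketch of the same offset computation on the host: every non-zero entry is mapped to the linear index of its softmax "pool" in the equivalent dense tensor, with the softmax dimension excluded from the stride product. The helper name and toy indices below are illustrative, not part of the PR.

```python
import numpy as np

def get_offsets(indices, sizes, dim):
    # indices: [ndim, nnz] COO indices; sizes: dense shape; dim: softmax axis
    ndim, nnz = indices.shape
    strides = np.ones(ndim, dtype=np.int64)
    for i in range(ndim - 2, -1, -1):
        strides[i] = strides[i + 1] * (1 if i + 1 == dim else sizes[i + 1])
    offsets = np.zeros(nnz, dtype=np.int64)
    for j in range(ndim):
        if j != dim:
            offsets += strides[j] * indices[j]
    return offsets

indices = np.array([[0, 0, 1, 1], [0, 2, 1, 2]])
print(get_offsets(indices, sizes=[2, 3], dim=1))  # -> [0 0 1 1]: same row, same pool
```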
/* Return pools of indices that align with the given dimension and the
corresponding max values for each pool. */
template <typename T,
typename IntT,
typename Context,
bool requireMxRows = true>
std::tuple<DenseTensor, DenseTensor, DenseTensor, DenseTensor> ComputePoolMax(
const Context& dev_ctx,
const DenseTensor& indices,
const DenseTensor& values,
const std::vector<IntT>& sizes,
IntT nvalues,
const IntT dim) {
#ifdef __HIPCC__
const auto& policy = thrust::hip::par.on(dev_ctx.stream());
#else
const auto& policy = thrust::cuda::par.on(dev_ctx.stream());
#endif
using thrust_ptr = thrust::device_ptr<IntT>;
auto nnz = indices.dims()[1];
DenseTensor offsets = phi::funcs::sparse::GetOffsets<IntT, Context>(
dev_ctx, indices, sizes, dim);
auto offsets_ptr = offsets.data<IntT>();
phi::DenseTensor sorted_indices = phi::Empty<IntT>(dev_ctx, {nnz});
thrust_ptr sorted_indices_thrust_ptr(sorted_indices.data<IntT>());
thrust::sequence(
policy, sorted_indices_thrust_ptr, sorted_indices_thrust_ptr + nnz, 0);
/* sort indices corresponding to offsets */
thrust::sort(policy,
sorted_indices_thrust_ptr,
sorted_indices_thrust_ptr + nnz,
[offsets_ptr] __device__(IntT x, IntT y) {
return offsets_ptr[x] < offsets_ptr[y];
});
DenseTensor pool_sizes = phi::Empty<IntT>(dev_ctx, {nnz});
/* Reduce the elements grouped by pool index; for each unique offset value
this yields the size of the corresponding pool. */
auto new_end =
thrust::reduce_by_key(policy,
sorted_indices_thrust_ptr,
sorted_indices_thrust_ptr + nnz,
thrust::make_constant_iterator(IntT(1)),
thrust::make_discard_iterator(),
thrust_ptr(pool_sizes.data<IntT>()),
[offsets_ptr] __device__(IntT x, IntT y) {
return offsets_ptr[x] == offsets_ptr[y];
});
auto new_sz =
thrust::distance(thrust_ptr(pool_sizes.data<IntT>()), new_end.second);
pool_sizes.Resize(phi::make_ddim({new_sz}));
DenseTensor pool_offsets;
pool_offsets.Resize(phi::make_ddim({new_sz}));
dev_ctx.template Alloc<T>(&pool_offsets);
phi::Copy(dev_ctx, pool_sizes, dev_ctx.GetPlace(), false, &pool_offsets);
/* exclusive prefix sum of the pool sizes gives the start offset of each pool */
thrust_ptr pool_offsets_thrust_ptr(pool_offsets.data<IntT>());
thrust::exclusive_scan(policy,
pool_offsets_thrust_ptr,
pool_offsets_thrust_ptr + new_sz,
pool_offsets_thrust_ptr);
DenseTensor mx_buffer;
if (requireMxRows) {
mx_buffer = phi::Full<T>(
dev_ctx, {new_sz * nvalues}, -std::numeric_limits<T>::infinity());
auto mx_buffer_ptr = mx_buffer.data<T>();
auto pool_sizes_ptr = pool_sizes.data<IntT>();
auto sorted_indices_ptr = sorted_indices.data<IntT>();
auto pool_offsets_ptr = pool_offsets.data<IntT>();
auto values_ptr = values.data<T>();
/* calculate max value in each pool. */
thrust::for_each(policy,
thrust::make_counting_iterator(IntT(0)),
thrust::make_counting_iterator(IntT(new_sz)),
[sorted_indices_ptr,
pool_sizes_ptr,
pool_offsets_ptr,
mx_buffer_ptr,
values_ptr,
nvalues] __device__(IntT index) {
IntT curr_pool_size = pool_sizes_ptr[index];
auto mx_row = mx_buffer_ptr + index * nvalues;
IntT offset = pool_offsets_ptr[index];
for (IntT p = 0; p < curr_pool_size; p++) {
IntT i = *(sorted_indices_ptr + offset + p);
for (IntT j = 0; j < nvalues; j++) {
auto value_tmp = *(values_ptr + i * nvalues + j);
mx_row[j] = std::max(mx_row[j], value_tmp);
}
}
});
}
return std::make_tuple(sorted_indices, pool_offsets, pool_sizes, mx_buffer);
}
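A hedged NumPy sketch of what the sort / reduce_by_key / exclusive_scan pipeline above produces: the entry indices sorted by pool offset, per-pool start offsets and sizes, and (when requested) the element-wise max of the value rows in each pool. Names and sample data are illustrative.

```python
import numpy as np

def compute_pool_max(offsets, values_2d):
    order = np.argsort(offsets, kind='stable')                    # sorted_indices
    counts = np.unique(offsets[order], return_counts=True)[1]     # pool_sizes
    pool_offsets = np.concatenate(([0], np.cumsum(counts)[:-1]))  # exclusive scan
    mx = np.full((len(counts), values_2d.shape[1]), -np.inf)      # mx_buffer
    for p, (start, size) in enumerate(zip(pool_offsets, counts)):
        mx[p] = values_2d[order[start:start + size]].max(axis=0)
    return order, pool_offsets, counts, mx

offsets = np.array([3, 0, 3, 0])
values = np.array([[1., 2.], [5., 0.], [4., 1.], [2., 7.]])
print(compute_pool_max(offsets, values))
```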
inline int GetNumThreads(int nElem) {
#if defined(PADDLE_WITH_HIP)
int threadSizes[5] = {16, 32, 64, 128, 256};
#else
int threadSizes[5] = {32, 64, 128, 256, 512};
#endif
for (int i = 0; i != 5; ++i) {
if (nElem <= threadSizes[i]) {
return threadSizes[i];
}
}
return threadSizes[4];
}
} // namespace sparse
} // namespace funcs
} // namespace phi
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/tensor_utils.h"
namespace phi {
namespace funcs {
namespace sparse {
template <typename IntT>
inline void GetPoolsSoftmax(const DenseTensor& indices,
const std::vector<IntT>& sizes,
const int dim,
std::map<IntT, std::vector<IntT>>* pools) {
auto ndim = indices.dims()[0];
auto nnz = indices.dims()[1];
std::vector<IntT> strides(ndim, 1);
if (ndim > 1) {
for (IntT i = ndim - 2; i >= 0; i--) {
strides[i] = strides[i + 1] * (i + 1 == dim ? 1 : sizes[i + 1]);
}
}
auto* indices_data = indices.data<IntT>();
for (IntT i = 0; i < nnz; i++) {
IntT pool_index = 0;
for (IntT j = 0; j < ndim; j++) {
if (j == dim) continue;
pool_index += strides[j] * indices_data[j * nnz + i];
}
if (pools->find(pool_index) == pools->end()) {
std::vector<IntT> vec;
(*pools)[pool_index] = vec;
}
(*pools)[pool_index].push_back(i);
}
}
template <typename IntT>
inline std::vector<IntT> GetOffsets(const DenseTensor& indices,
const std::vector<IntT>& sizes,
const int dim) {
auto ndim = indices.dims()[0];
auto nnz = indices.dims()[1];
std::vector<IntT> offsets(nnz);
std::vector<IntT> strides(ndim, 1);
auto indices_ptr = indices.data<IntT>();
if (ndim > 1) {
for (IntT i = ndim - 2; i >= 0; i--) {
strides[i] = strides[i + 1] * sizes[i + 1];
}
}
for (int i = 0; i < nnz; i++) {
IntT acc = 0;
for (int j = 0; j < ndim; j++) {
auto indices_cur = indices_ptr + j * nnz + i;
auto stride = strides[j];
if (j != dim) {
acc += stride * (*indices_cur);
}
}
offsets[i] = acc;
}
return offsets;
}
} // namespace sparse
} // namespace funcs
} // namespace phi
@@ -16,9 +16,14 @@ limitations under the License. */
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/backends/cpu/cpu_info.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/visit_type.h" #include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/cpu_vec.h" #include "paddle/phi/kernels/funcs/cpu_vec.h"
#include "paddle/phi/kernels/funcs/sparse/softmax.h"
#include "paddle/phi/kernels/softmax_grad_kernel.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h" #include "paddle/phi/kernels/sparse/empty_kernel.h"
namespace phi { namespace phi {
...@@ -85,6 +90,119 @@ void SoftmaxCsrGradKernel(const Context& dev_ctx, ...@@ -85,6 +90,119 @@ void SoftmaxCsrGradKernel(const Context& dev_ctx,
})); }));
} }
template <typename T, typename IntT, typename Context>
void SoftmaxCooGradCPUKernel(const Context& dev_ctx,
const SparseCooTensor& out,
const SparseCooTensor& dout,
int axis,
SparseCooTensor* dx) {
auto out_indices = out.indices();
auto out_values = out.values();
const auto out_dims = out.dims();
auto sparse_dim = out.sparse_dim();
auto sizes = phi::vectorize<IntT>(out_dims);
auto grad_indices = dout.indices();
auto grad_values = dout.values();
auto grad_nnz = dout.nnz();
*(dx->mutable_indices()) = out_indices;
DenseTensor* values = dx->mutable_values();
values->Resize(out_dims);
values->set_meta(out_values.meta());
dev_ctx.template Alloc<T>(values);
auto out_offsets = phi::funcs::sparse::GetOffsets(out_indices, sizes, -1);
auto grad_offsets = phi::funcs::sparse::GetOffsets(grad_indices, sizes, -1);
int dim = axis < 0 ? out_dims.size() + axis : axis;
if (dim >= sparse_dim) {
bool is_same_offset = out_offsets == grad_offsets;
PADDLE_ENFORCE_EQ(
is_same_offset,
true,
phi::errors::Unimplemented(
"SparseCooTensor only support same offsets for softmax."));
SoftmaxGradKernel<T, Context>(
dev_ctx, out_values, grad_values, dim - sparse_dim + 1, values);
return;
}
auto nnz = out.nnz();
IntT nvalues = std::accumulate(sizes.begin() + sparse_dim,
sizes.end(),
static_cast<IntT>(1),
std::multiplies<>());
DenseTensor values_2(*values);
values_2.Resize(phi::make_ddim({nnz, nvalues}));
DenseTensor out_values_2(out_values);
out_values_2.Resize(phi::make_ddim({nnz, nvalues}));
DenseTensor grad_values_2(grad_values);
grad_values_2.Resize(phi::make_ddim({nnz, nvalues}));
std::map<IntT, std::vector<IntT>> pools;
phi::funcs::sparse::GetPoolsSoftmax(out_indices, sizes, dim, &pools);
for (auto& pool_entry : pools) {
const auto& pool_indices = pool_entry.second;
if (pool_indices.empty()) continue;
std::vector<T> tmp_row(nvalues, 0);
/* Compute tmp = - sum_j output_j * grad_j */
for (IntT i : pool_indices) {
auto out_values_row = out_values_2.data<T>() + i * nvalues;
auto low = std::lower_bound(
grad_offsets.begin(), grad_offsets.end(), out_offsets[i]);
auto j = low - grad_offsets.begin();
if (j < grad_nnz && (out_offsets[i] == grad_offsets[j])) {
auto grad_values_row = grad_values_2.data<T>() + j * nvalues;
for (IntT k = 0; k < nvalues; k++) {
tmp_row[k] -= (*(out_values_row + k)) * (*(grad_values_row + k));
}
}
}
/* Compute grad_input = output * (grad + tmp)*/
for (IntT i : pool_indices) {
auto out_values_row = out_values_2.data<T>() + i * nvalues;
auto values_row = values_2.data<T>() + i * nvalues;
auto low = std::lower_bound(
grad_offsets.begin(), grad_offsets.end(), out_offsets[i]);
auto j = low - grad_offsets.begin();
if (j < grad_nnz && (out_offsets[i] == grad_offsets[j])) {
auto grad_values_row = grad_values_2.data<T>() + j * nvalues;
for (IntT k = 0; k < nvalues; k++) {
*(values_row + k) =
(*(out_values_row + k)) * ((*(grad_values_row + k)) + tmp_row[k]);
}
} else {
for (IntT k = 0; k < nvalues; k++) {
*(values_row + k) = (*(out_values_row + k)) * tmp_row[k];
}
}
}
}
}
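The two loops above implement, per pool, the usual softmax backward rule dx = y * (dy - sum_i y_i * dy_i), applied column-wise across the rows of the pool. A hedged NumPy check with illustrative data:

```python
import numpy as np

def softmax_coo_grad_pool(y, dy):
    # y, dy: [pool_size, nvalues] rows of the softmax output and its gradient
    s = (y * dy).sum(axis=0, keepdims=True)   # the negated "tmp" accumulator above
    return y * (dy - s)

y = np.array([[0.2, 0.5], [0.8, 0.5]])        # each column of y sums to 1
dy = np.array([[1.0, 0.0], [0.0, 1.0]])
dx = softmax_coo_grad_pool(y, dy)
print(dx)                  # [[ 0.16 -0.25], [-0.16  0.25]]
print(dx.sum(axis=0))      # ~0 per column, as expected for a softmax gradient
```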
template <typename T, typename Context>
void SoftmaxCooGradKernel(const Context& dev_ctx,
const SparseCooTensor& out,
const SparseCooTensor& dout,
int axis,
SparseCooTensor* dx) {
PD_VISIT_BASE_INTEGRAL_TYPES(
out.indices().dtype(), "SoftmaxCooGradCPUKernel", ([&] {
SoftmaxCooGradCPUKernel<T, data_t, Context>(
dev_ctx, out, dout, axis, dx);
}));
}
} // namespace sparse
} // namespace phi
@@ -96,3 +214,12 @@ PD_REGISTER_KERNEL(softmax_csr_grad,
double) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);
}
PD_REGISTER_KERNEL(softmax_coo_grad,
CPU,
ALL_LAYOUT,
phi::sparse::SoftmaxCooGradKernel,
float,
double) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}
@@ -18,7 +18,10 @@ limitations under the License. */
#include "paddle/phi/backends/cpu/cpu_info.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/cpu_vec.h" #include "paddle/phi/kernels/funcs/cpu_vec.h"
#include "paddle/phi/kernels/funcs/sparse/softmax.h"
#include "paddle/phi/kernels/softmax_kernel.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h" #include "paddle/phi/kernels/sparse/empty_kernel.h"
namespace phi { namespace phi {
...@@ -85,6 +88,90 @@ void SoftmaxCsrKernel(const Context& dev_ctx, ...@@ -85,6 +88,90 @@ void SoftmaxCsrKernel(const Context& dev_ctx,
})); }));
} }
template <typename T, typename IntT, typename Context>
void SoftmaxCooCPUKernel(const Context& dev_ctx,
const SparseCooTensor& x,
int axis,
SparseCooTensor* out) {
auto indices = x.indices();
auto values = x.values();
const auto x_dims = x.dims();
const auto sparse_dim = x.sparse_dim();
DenseTensor out_indices(indices);
DenseTensor out_values = EmptyLike<T, Context>(dev_ctx, values);
out->SetMember(out_indices, out_values, x.dims(), x.coalesced());
int dim = axis < 0 ? x_dims.size() + axis : axis;
/* If dim is greater than or equal to sparse_dim, the dense softmax is used.
*/
if (dim >= sparse_dim) {
SoftmaxKernel<T, Context>(
dev_ctx, values, dim - sparse_dim + 1, &out_values);
return;
}
const std::vector<IntT> sizes = phi::vectorize<IntT>(x_dims);
std::map<IntT, std::vector<IntT>> pools;
IntT nvalues = std::accumulate(sizes.begin() + sparse_dim,
sizes.end(),
static_cast<IntT>(1),
std::multiplies<>());
phi::funcs::sparse::GetPoolsSoftmax(out_indices, sizes, dim, &pools);
auto values_ptr = values.data<T>();
auto out_values_ptr = out_values.data<T>();
for (auto& pool_entry : pools) {
const auto& pool_indices = pool_entry.second;
if (pool_indices.empty()) {
continue;
}
std::vector<T> mx_row(nvalues, -std::numeric_limits<T>::infinity());
std::vector<T> exp_sums_row(nvalues, 0);
IntT pool_size = static_cast<IntT>(pool_indices.size());
// Compute max for each pool
for (IntT i = 0; i < pool_size; i++) {
auto values_row = values_ptr + pool_indices[i] * nvalues;
for (IntT j = 0; j < nvalues; j++) {
mx_row[j] = std::max(mx_row[j], *(values_row + j));
}
}
// exp to (v - mx) and sum the results
for (IntT i = 0; i < pool_size; i++) {
auto values_row = values_ptr + pool_indices[i] * nvalues;
auto out_values_row = out_values_ptr + pool_indices[i] * nvalues;
for (IntT j = 0; j < nvalues; j++) {
auto v = std::exp(*(values_row + j) - mx_row[j]);
out_values_row[j] = v;
exp_sums_row[j] += v;
}
}
/* Normalize with the sum of exponents */
for (IntT i = 0; i < pool_size; i++) {
auto out_values_row = out_values_ptr + pool_indices[i] * nvalues;
for (IntT j = 0; j < nvalues; j++) {
out_values_row[j] *= 1.0 / exp_sums_row[j];
}
}
}
}
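The three passes above (per-pool max, exp-and-sum, normalize) are the standard numerically stable softmax, applied column-wise over the value rows of one pool. A hedged NumPy sketch with illustrative data:

```python
import numpy as np

def softmax_pool(values_rows):
    # values_rows: [pool_size, nvalues] values sharing one pool index
    mx = values_rows.max(axis=0, keepdims=True)     # pass 1: max per column
    ex = np.exp(values_rows - mx)                   # pass 2: exp(v - max)
    return ex / ex.sum(axis=0, keepdims=True)       # pass 3: normalize

rows = np.array([[1.0, -2.0], [3.0, 0.5]])
out = softmax_pool(rows)
print(out, out.sum(axis=0))   # each column sums to 1
```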
// CPU kernel entry: dispatch on the integer type of the sparse indices.
template <typename T, typename Context>
void SoftmaxCooKernel(const Context& dev_ctx,
const SparseCooTensor& x,
int axis,
SparseCooTensor* out) {
PD_VISIT_BASE_INTEGRAL_TYPES(
x.indices().dtype(), "SoftmaxCooCPUKernel", ([&] {
SoftmaxCooCPUKernel<T, data_t, Context>(dev_ctx, x, axis, out);
}));
}
} // namespace sparse
} // namespace phi
@@ -96,3 +183,12 @@ PD_REGISTER_KERNEL(softmax_csr,
double) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);
}
PD_REGISTER_KERNEL(softmax_coo,
CPU,
ALL_LAYOUT,
phi::sparse::SoftmaxCooKernel,
float,
double) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}
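When `axis` addresses a dense dimension (dim >= sparse_dim), the COO kernel above falls back to a dense softmax over the values tensor, with the axis shifted by `sparse_dim`. A hedged equivalence check, assuming a build that contains this PR; shapes and data are illustrative:

```python
import numpy as np
import paddle
import paddle.nn.functional as F

np.random.seed(0)
dense = np.random.rand(4, 3, 5).astype('float32')
coo = paddle.to_tensor(dense).to_sparse_coo(sparse_dim=1)        # values shape: [nnz, 3, 5]

axis, sparse_dim = 2, 1
sparse_out = paddle.sparse.nn.functional.softmax(coo, axis=axis)
dense_on_values = F.softmax(coo.values(), axis=axis - sparse_dim + 1)
print(np.allclose(sparse_out.values().numpy(), dense_on_values.numpy()))
```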
@@ -14,10 +14,24 @@ limitations under the License. */
#include "paddle/phi/kernels/sparse/softmax_grad_kernel.h"
#include <thrust/binary_search.h>
#include <thrust/device_ptr.h>
#include <thrust/equal.h>
#include <thrust/iterator/constant_iterator.h>
#include <thrust/iterator/discard_iterator.h>
#include <thrust/sequence.h>
#include <thrust/sort.h>
#include <thrust/transform.h>
#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/visit_type.h" #include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/math_cuda_utils.h" #include "paddle/phi/kernels/funcs/math_cuda_utils.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/sparse/softmax.cu.h"
#include "paddle/phi/kernels/softmax_grad_kernel.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h" #include "paddle/phi/kernels/sparse/empty_kernel.h"
namespace phi { namespace phi {
...@@ -104,6 +118,184 @@ void SoftmaxCsrGradKernel(const Context& dev_ctx, ...@@ -104,6 +118,184 @@ void SoftmaxCsrGradKernel(const Context& dev_ctx,
})); }));
} }
template <typename T, typename IntT>
__global__ void SoftmaxCooGradGPURawKernel(IntT* sorted_pool_indices,
IntT size,
IntT* pool_sizes,
IntT* pool_offsets,
IntT nvalues,
IntT grad_nnz,
IntT* grad_offsets,
IntT* out_offsets,
IntT* lower_bound_values,
T* values,
T* out_values,
T* grad_values,
int total_rows) {
int row = blockIdx.x * blockDim.y + threadIdx.y;
if (row >= total_rows) return;
int tid = threadIdx.x;
int index = row / nvalues;
int nval = row % nvalues;
IntT offset = pool_offsets[index];
IntT* pool_indices = sorted_pool_indices + offset;
IntT pool_indices_size = pool_sizes[index];
int kIteration = (pool_indices_size + warpSize - 1) / warpSize;
T mul_result = 0;
for (int k = 0; k < kIteration; ++k) {
int idx = tid + k * warpSize;
if (idx >= pool_indices_size) break;
auto i = pool_indices[idx];
auto cur_out_value = out_values + i * nvalues;
auto j = lower_bound_values[i];
if (j < grad_nnz && (out_offsets[i] == grad_offsets[j])) {
auto cur_grad_value = grad_values + j * nvalues;
mul_result += (*(cur_out_value + nval)) * (*(cur_grad_value + nval));
}
}
T sum = phi::funcs::WarpReduceSum<T>(mul_result, 0xFFFFFFFF);
for (int k = 0; k < kIteration; ++k) {
int idx = tid + k * warpSize;
if (idx >= pool_indices_size) break;
auto i = pool_indices[idx];
auto j = lower_bound_values[i];
auto cur_out_value = out_values + i * nvalues;
auto cur_value = values + i * nvalues;
auto cur_grad_value = grad_values + j * nvalues;
if (j < grad_nnz && (out_offsets[i] == grad_offsets[j])) {
cur_value[nval] =
(*(cur_out_value + nval)) * (*(cur_grad_value + nval) - sum);
} else {
cur_value[nval] = -(*(cur_out_value + nval)) * sum;
}
}
}
template <typename T, typename IntT, typename Context>
void SoftmaxCooGradGPUKernel(const Context& dev_ctx,
const SparseCooTensor& out,
const SparseCooTensor& dout,
int axis,
SparseCooTensor* dx) {
using thrust_ptr = thrust::device_ptr<IntT>;
auto out_indices = out.indices();
auto out_values = out.values();
auto out_values_ptr = out_values.data<T>();
const auto output_indices_dims = out.indices().dims();
const auto out_dims = out.dims();
auto sparse_dim = out.sparse_dim();
auto sizes = phi::vectorize<IntT>(out_dims);
auto grad_indices = dout.indices();
auto grad_values = dout.values();
auto grad_values_ptr = grad_values.data<T>();
auto out_nnz = out.nnz();
auto grad_nnz = dout.nnz();
auto place = dev_ctx.GetPlace();
auto stream = dev_ctx.stream();
*(dx->mutable_indices()) = out_indices;
DenseTensor* values = dx->mutable_values();
values->Resize(out_dims);
values->set_meta(out_values.meta());
dev_ctx.template Alloc<T>(values);
phi::funcs::SetConstant<GPUContext, T> set_zero;
set_zero(dev_ctx, values, static_cast<T>(0.0f));
DenseTensor out_offsets = phi::funcs::sparse::GetOffsets<IntT, Context>(
dev_ctx, out_indices, sizes, static_cast<IntT>(-1));
auto out_offsets_ptr = out_offsets.data<IntT>();
DenseTensor grad_offsets = phi::funcs::sparse::GetOffsets<IntT, Context>(
dev_ctx, grad_indices, sizes, static_cast<IntT>(-1));
auto grad_offsets_ptr = grad_offsets.data<IntT>();
#ifdef PADDLE_WITH_HIP
const auto& policy = thrust::hip::par.on(dev_ctx.stream());
#else
const auto& policy = thrust::cuda::par.on(dev_ctx.stream());
#endif
bool is_same_offset = thrust::equal(policy,
                                    out_offsets_ptr,
                                    out_offsets_ptr + out_offsets.numel(),
                                    grad_offsets_ptr);
int dim = axis < 0 ? out_dims.size() + axis : axis;
if (dim >= sparse_dim) {
PADDLE_ENFORCE_EQ(
is_same_offset,
true,
phi::errors::Unimplemented(
"SparseCooTensor only support same offsets for softmax."));
SoftmaxGradKernel<T, Context>(
dev_ctx, out_values, grad_values, dim - sparse_dim + 1, values);
return;
}
auto nnz = out.nnz();
IntT nvalues = std::accumulate(sizes.begin() + sparse_dim,
sizes.end(),
static_cast<IntT>(1),
std::multiplies<>());
DenseTensor values_2(*values);
values_2.Resize(phi::make_ddim({nnz, nvalues}));
DenseTensor sorted_indices;
DenseTensor pool_offsets;
DenseTensor pool_sizes;
std::tie(sorted_indices, pool_offsets, pool_sizes, std::ignore) =
phi::funcs::sparse::ComputePoolMax<T, IntT, Context, false>(
dev_ctx, out_indices, values_2, sizes, nvalues, dim);
DenseTensor bound =
phi::Empty<IntT>(dev_ctx, {static_cast<IntT>(out_offsets.dims()[0])});
IntT* bound_ptr = bound.data<IntT>();
thrust::lower_bound(policy,
thrust_ptr(grad_offsets_ptr),
thrust_ptr(grad_offsets_ptr + grad_offsets.dims()[0]),
thrust_ptr(out_offsets_ptr),
thrust_ptr(out_offsets_ptr) + out_offsets.dims()[0],
thrust_ptr(bound.data<IntT>()));
auto pool_size = pool_offsets.dims()[0];
int total_rows = pool_size * nvalues;
dim3 grid((total_rows + 15) / 16);
dim3 block(32, 16);
SoftmaxCooGradGPURawKernel<T, IntT>
<<<grid, block, 0, stream>>>(sorted_indices.data<IntT>(),
pool_size,
pool_sizes.data<IntT>(),
pool_offsets.data<IntT>(),
nvalues,
grad_nnz,
grad_offsets.data<IntT>(),
out_offsets.data<IntT>(),
bound_ptr,
values_2.data<T>(),
out_values.data<T>(),
grad_values.data<T>(),
total_rows);
}
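The `thrust::lower_bound` call above matches each output entry to the gradient entry with the same pool offset, if any; entries without a match only receive the `-out * sum` term in the raw kernel. A hedged NumPy sketch of that matching, with illustrative (sorted) offsets:

```python
import numpy as np

out_offsets = np.array([0, 0, 3, 5])
grad_offsets = np.array([0, 3, 4])          # sorted, as produced from coalesced indices
bound = np.searchsorted(grad_offsets, out_offsets, side='left')   # lower_bound
safe = np.minimum(bound, len(grad_offsets) - 1)
matched = (bound < len(grad_offsets)) & (grad_offsets[safe] == out_offsets)
print(bound, matched)   # -> [0 0 1 3] [ True  True  True False]
```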
template <typename T, typename Context>
void SoftmaxCooGradKernel(const Context& dev_ctx,
const SparseCooTensor& out,
const SparseCooTensor& dout,
int axis,
SparseCooTensor* dx) {
PD_VISIT_BASE_INTEGRAL_TYPES(
out.indices().dtype(), "SoftmaxCooGradGPUKernel", ([&] {
SoftmaxCooGradGPUKernel<T, data_t, Context>(
dev_ctx, out, dout, axis, dx);
}));
}
} // namespace sparse
} // namespace phi
@@ -115,3 +307,12 @@ PD_REGISTER_KERNEL(softmax_csr_grad,
double) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);
}
PD_REGISTER_KERNEL(softmax_coo_grad,
GPU,
ALL_LAYOUT,
phi::sparse::SoftmaxCooGradKernel,
float,
double) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}
@@ -14,12 +14,25 @@ limitations under the License. */
#include "paddle/phi/kernels/sparse/softmax_kernel.h"
#include <thrust/device_ptr.h>
#include <thrust/iterator/constant_iterator.h>
#include <thrust/iterator/discard_iterator.h>
#include <thrust/sequence.h>
#include <thrust/sort.h>
#include <thrust/transform.h>
#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/visit_type.h" #include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/activation_functor.h" #include "paddle/phi/kernels/funcs/activation_functor.h"
#include "paddle/phi/kernels/funcs/math_cuda_utils.h" #include "paddle/phi/kernels/funcs/math_cuda_utils.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/reduce_functor.h"
#include "paddle/phi/kernels/funcs/sparse/softmax.cu.h"
#include "paddle/phi/kernels/gpu/reduce.h"
#include "paddle/phi/kernels/softmax_kernel.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h" #include "paddle/phi/kernels/sparse/empty_kernel.h"
namespace phi { namespace phi {
...@@ -31,7 +44,6 @@ __global__ void SoftmaxGpuKernel(const IntT* x_crows, ...@@ -31,7 +44,6 @@ __global__ void SoftmaxGpuKernel(const IntT* x_crows,
T* out_values, T* out_values,
int row_number, int row_number,
int total_row_number) { int total_row_number) {
// out = exp(x-x_max) / sum(exp(x-x_max))
int row = blockIdx.x * blockDim.y + threadIdx.y;
int non_zero_idx = threadIdx.x;
if (row >= total_row_number) return;
@@ -116,6 +128,132 @@ void SoftmaxCsrKernel(const Context& dev_ctx,
}));
}
template <typename T, typename IntT>
__global__ void SoftmaxCooGPURawKernel(IntT* sorted_pool_indices,
IntT* pool_sizes,
IntT* pool_offsets,
IntT nvalues,
T* input_values,
T* output_values,
int total_rows) {
int row = blockIdx.x * blockDim.y + threadIdx.y;
if (row >= total_rows) return;
int tid = threadIdx.x;
int index = row / nvalues;
int j = row % nvalues;
IntT offset = pool_offsets[index];
IntT* pool_indices = sorted_pool_indices + offset;
IntT pool_indices_size = pool_sizes[index];
int kIteration = (pool_indices_size + warpSize - 1) / warpSize;
T max_val = -std::numeric_limits<T>::infinity();
for (int k = 0; k < kIteration; ++k) {
int idx = tid + k * warpSize;
if (idx >= pool_indices_size) break;
auto i = pool_indices[idx];
auto cur_value = input_values + j + nvalues * i;
if (*cur_value > max_val) {
max_val = *cur_value;
}
}
T row_max_val = phi::funcs::WarpReduceMax<T>(max_val, 0xFFFFFFFF);
T exp_sum = 0;
for (int k = 0; k < kIteration; ++k) {
int idx = tid + k * warpSize;
if (idx >= pool_indices_size) break;
auto i = pool_indices[idx];
auto cur_value = input_values + j + nvalues * i;
auto cur_out_value = output_values + i * nvalues + j;
auto functor = phi::funcs::CudaExpFunctor<T>();
T exp = functor(*cur_value - row_max_val);
exp_sum += exp;
*cur_out_value = exp;
}
T row_exp_sum = phi::funcs::WarpReduceSum<T>(exp_sum, 0xFFFFFFFF);
row_exp_sum = 1.0 / row_exp_sum;
for (int k = 0; k < kIteration; ++k) {
int idx = tid + k * warpSize;
if (idx >= pool_indices_size) break;
auto i = pool_indices[idx];
auto cur_out_value = output_values + i * nvalues + j;
*cur_out_value *= row_exp_sum;
}
}
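In the kernel above, each (pool, value-column) pair is handled by one warp: lane `tid` visits pool entries tid, tid+32, tid+64, ... and the per-lane partials are then combined with WarpReduceMax / WarpReduceSum. A hedged pure-Python sketch of the lane-to-entry mapping, for illustration only:

```python
def warp_strided_entries(tid, pool_size, warp_size=32):
    # entries visited by lane `tid` over kIteration passes
    k_iteration = (pool_size + warp_size - 1) // warp_size
    return [tid + k * warp_size
            for k in range(k_iteration)
            if tid + k * warp_size < pool_size]

print(warp_strided_entries(tid=1, pool_size=70))   # -> [1, 33, 65]
```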
template <typename T, typename IntT, typename Context>
void SoftmaxCooGPUKernel(const Context& dev_ctx,
const SparseCooTensor& x,
int axis,
SparseCooTensor* out) {
auto indices = x.indices();
auto values = x.values();
const auto x_dims = x.dims();
const std::vector<IntT> sizes = phi::vectorize<IntT>(x_dims);
const auto sparse_dim = x.sparse_dim();
const IntT x_nnz = x.nnz();
DenseTensor out_indices(indices);
DenseTensor out_values = EmptyLike<T, Context>(dev_ctx, values);
out->SetMember(out_indices, out_values, x.dims(), x.coalesced());
int dim = axis < 0 ? x_dims.size() + axis : axis;
/* If dim is greater than or equal to sparse_dim, the dense softmax is used.
*/
if (dim >= sparse_dim) {
SoftmaxKernel<T, Context>(
dev_ctx, values, dim - sparse_dim + 1, &out_values);
return;
}
auto stream = dev_ctx.stream();
IntT nvalues = std::accumulate(sizes.begin() + sparse_dim,
sizes.end(),
static_cast<IntT>(1),
std::multiplies<>());
auto values_2 = values.Resize({x_nnz, nvalues});
/* Compute independent pools of indices */
DenseTensor sorted_indices;
DenseTensor pool_offsets;
DenseTensor pool_sizes;
std::tie(sorted_indices, pool_offsets, pool_sizes, std::ignore) =
phi::funcs::sparse::ComputePoolMax<T, IntT, Context, false>(
dev_ctx, indices, values_2, sizes, nvalues, static_cast<IntT>(dim));
auto pool_size = pool_offsets.dims()[0];
auto out_values_ptr = out_values.data<T>();
auto values_ptr = values.data<T>();
int total_rows = pool_size * nvalues;
dim3 grid((total_rows + 15) / 16);
dim3 block(32, 16);
SoftmaxCooGPURawKernel<T, IntT>
<<<grid, block, 0, stream>>>(sorted_indices.data<IntT>(),
pool_sizes.data<IntT>(),
pool_offsets.data<IntT>(),
nvalues,
values_ptr,
out_values_ptr,
total_rows);
}
template <typename T, typename Context>
void SoftmaxCooKernel(const Context& dev_ctx,
const SparseCooTensor& x,
int axis,
SparseCooTensor* out) {
PD_VISIT_BASE_INTEGRAL_TYPES(
x.indices().dtype(), "SoftmaxCooGPUKernel", ([&] {
SoftmaxCooGPUKernel<T, data_t, Context>(dev_ctx, x, axis, out);
}));
}
} // namespace sparse
} // namespace phi
@@ -127,3 +265,12 @@ PD_REGISTER_KERNEL(softmax_csr,
double) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);
}
PD_REGISTER_KERNEL(softmax_coo,
GPU,
ALL_LAYOUT,
phi::sparse::SoftmaxCooKernel,
float,
double) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h" #include "paddle/phi/core/sparse_csr_tensor.h"
namespace phi { namespace phi {
...@@ -26,5 +27,12 @@ void SoftmaxCsrGradKernel(const Context& dev_ctx, ...@@ -26,5 +27,12 @@ void SoftmaxCsrGradKernel(const Context& dev_ctx,
int axis, int axis,
SparseCsrTensor* dx); SparseCsrTensor* dx);
template <typename T, typename Context>
void SoftmaxCooGradKernel(const Context& dev_ctx,
const SparseCooTensor& out,
const SparseCooTensor& dout,
int axis,
SparseCooTensor* dx);
} // namespace sparse
} // namespace phi
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h" #include "paddle/phi/core/sparse_csr_tensor.h"
namespace phi { namespace phi {
...@@ -25,5 +26,11 @@ void SoftmaxCsrKernel(const Context& dev_ctx, ...@@ -25,5 +26,11 @@ void SoftmaxCsrKernel(const Context& dev_ctx,
int axis, int axis,
SparseCsrTensor* out); SparseCsrTensor* out);
template <typename T, typename Context>
void SoftmaxCooKernel(const Context& dev_ctx,
const SparseCooTensor& x,
int axis,
SparseCooTensor* out);
} // namespace sparse
} // namespace phi
@@ -18,8 +18,12 @@ import numpy as np
import scipy.sparse as sp
import paddle
import paddle.nn.functional as F
np.random.seed(2022)
devices = ['cpu']
if paddle.device.get_device() != "cpu":
devices.append(paddle.device.get_device())
class TestCsrSoftmax(unittest.TestCase):
@@ -124,5 +128,168 @@ class TestCsrSoftmax(unittest.TestCase):
np.testing.assert_allclose(csr.grad.values().numpy(), dx, rtol=1e-05)
class TestCooSoftmax(unittest.TestCase):
def sparse_softmax(self, sparse, dense_shape, sparse_dim, dim):
"""
sparse softmax algorithm in Python.
"""
inf = float('inf')
indices = sparse.indices()
values = sparse.values()
size = sparse.shape
dense_size = tuple(size[sparse_dim:])
dense_dim = len(dense_size)
if dim < sparse_dim:
nnz = sparse.nnz()
# compute pool indices
strides = np.ones((sparse_dim, 1))
for i in reversed(range(sparse_dim - 1)):
strides[i, 0] = strides[i + 1, 0] * size[i + 1]
strides[dim, 0] = 0
strides = paddle.to_tensor(strides, dtype=indices.dtype)
pool = paddle.sum((indices * strides), axis=0).numpy()
i2p = {}
for i in range(nnz):
c = int(pool[i])
if c not in i2p:
i2p[c] = len(i2p)
pool[i] = i2p[c]
mx = paddle.empty((pool.max() + 1,) + dense_size).numpy()
mx[:] = -inf
np_values = values.numpy()
for n in range(nnz):
p = pool[n]
mx[p] = np.where(mx[p] > np_values[n], mx[p], np_values[n])
# apply exp to (v - mx) and sum the results
exp_values = paddle.empty_like(values).numpy()
exp_sums = np.zeros_like(mx)
for n in range(nnz):
p = pool[n]
v = exp_values[n] = np.exp(np_values[n] - mx[p])
exp_sums[p] = exp_sums[p] + v
# normalize with the sum of exponents
for n in range(nnz):
p = pool[n]
exp_values[n] = exp_values[n] / exp_sums[p]
return paddle.sparse.sparse_coo_tensor(
indices, exp_values, dense_shape
)
elif dim < sparse_dim + dense_dim:
return paddle.sparse.sparse_coo_tensor(
indices, F.softmax(values, dim - sparse_dim + 1), size
)
else:
print(
"`dim(=%s)` must be smaller than `sparse_dim(=%s) + dense_dim(=%s)`"
% (dim, sparse_dim, dense_dim)
)
def check_run(self, dense_shape):
mask = np.random.rand(*dense_shape) < 0.5
np_x = np.random.rand(*dense_shape) * mask
for device in devices:
paddle.device.set_device(device)
for sparse_dim in range(1, len(dense_shape)):
coo = (
paddle.to_tensor(np_x, stop_gradient=False)
.detach()
.to_sparse_coo(sparse_dim)
)
size = coo.shape
dense_size = tuple(size[sparse_dim:])
dense_dim = len(dense_size)
for axis in range(sparse_dim + dense_dim):
coo = (
paddle.to_tensor(np_x, stop_gradient=False)
.detach()
.to_sparse_coo(sparse_dim)
)
coo.stop_gradient = False
py_out = self.sparse_softmax(
coo, dense_shape, sparse_dim, axis
)
m = paddle.sparse.nn.Softmax(axis=axis)
out = m(coo)
np.testing.assert_allclose(
py_out.indices().numpy(),
out.indices().numpy(),
rtol=1e-05,
)
np.testing.assert_allclose(
py_out.values().numpy(),
out.values().numpy(),
rtol=1e-05,
)
out.backward(coo.detach())
dense_tensor = paddle.to_tensor(np_x, stop_gradient=False)
model_dense = paddle.nn.Softmax(axis=axis)
dense_out = model_dense(dense_tensor)
dense_out.backward(dense_tensor.detach())
dg_npy = dense_tensor.grad.numpy()
np.testing.assert_allclose(
coo.grad.to_dense().numpy(), dg_npy, rtol=1e-05
)
def test_softmax2d(self):
self.check_run((16, 128))
def test_softmax3d(self):
self.check_run((16, 16, 128))
def test_softmax2d_static(self):
for device in devices:
paddle.device.set_device(device)
np_x = np.array([[11, 0, 0, 14, 15], [0, 22, 0, 24, 0]]).astype(
'float32'
)
coo = (
paddle.to_tensor(np_x, stop_gradient=False)
.detach()
.to_sparse_coo(2)
)
m = paddle.sparse.nn.Softmax()
dy_out = m(coo)
dy_out_dense = dy_out.to_dense().numpy()
paddle.enable_static()
indices = paddle.static.data(
name='indices', shape=[2, 5], dtype='int32'
)
values = paddle.static.data(
name='values', shape=[5, 1], dtype='float32'
)
dense_shape = [2, 5]
sp_x = paddle.sparse.sparse_coo_tensor(indices, values, dense_shape)
sparse_softmax = paddle.sparse.nn.Softmax()
sp_y = sparse_softmax(sp_x)
out = sp_y.to_dense()
exe = paddle.static.Executor()
indices_data = [[0, 0, 0, 1, 1], [0, 3, 4, 1, 3]]
values_data = np.array([11, 14, 15, 22, 24]).astype('float32')
fetch = exe.run(
feed={'indices': indices_data, 'values': values_data},
fetch_list=[out],
return_numpy=True,
)
np.testing.assert_allclose(dy_out_dense, fetch[0], rtol=1e-5)
paddle.disable_static()
if __name__ == "__main__":
unittest.main()
@@ -57,7 +57,6 @@ def relu(x, name=None):
return out
@dygraph_only
def softmax(x, axis=-1, name=None):
r"""
sparse softmax activation, requiring x to be a SparseCooTensor or SparseCsrTensor.
@@ -112,8 +111,35 @@ def softmax(x, axis=-1, name=None):
# values=[0.34132850, 0.29843223, 0.36023921, 0.20176248, 0.41964680,
# 0.37859070, 0.30015594, 0.26316854, 0.16354506, 0.27313042])
coo = x.to_sparse_coo(sparse_dim=2)
print(coo)
# Tensor(shape=[3, 4], dtype=paddle.float32, place=Place(gpu:0), stop_gradient=True,
# indices=[[0, 0, 0, 1, 1, 1, 2, 2, 2, 2],
# [0, 1, 3, 0, 2, 3, 0, 1, 2, 3]],
# values=[0.83438963, 0.70008713, 0.88831252, 0.02200012, 0.75432241,
# 0.65136462, 0.96088767, 0.82938021, 0.35367414, 0.86653489])
out = paddle.sparse.nn.functional.softmax(coo)
print(out)
# Tensor(shape=[3, 4], dtype=paddle.float32, place=Place(gpu:0), stop_gradient=True,
# indices=[[0, 0, 0, 1, 1, 1, 2, 2, 2, 2],
# [0, 1, 3, 0, 2, 3, 0, 1, 2, 3]],
# values=[0.34132853, 0.29843226, 0.36023924, 0.20176250, 0.41964683,
# 0.37859073, 0.30015597, 0.26316857, 0.16354507, 0.27313042])
""" """
if in_dynamic_mode():
return _C_ops.sparse_softmax(x, axis)
else:
op_type = 'sparse_softmax'
helper = LayerHelper(op_type)
out = helper.create_sparse_variable_for_type_inference(x.dtype)
helper.append_op(
type=op_type,
inputs={'x': x},
outputs={'out': out},
attrs={'axis': axis},
)
return out
@dygraph_only
......
@@ -116,6 +116,22 @@ class Softmax(Layer):
# cols=[0, 1, 3, 1, 2, 0, 1],
# values=[0.23070428, 0.27815846, 0.49113727, 0.67227983, 0.32772022,
# 0.49353254, 0.50646752])
coo = x.to_sparse_coo(sparse_dim=2)
print(coo)
# Tensor(shape=[3, 4], dtype=paddle.float32, place=Place(gpu:0), stop_gradient=True,
# indices=[[0, 0, 0, 1, 1, 1, 2, 2, 2, 2],
# [0, 1, 3, 0, 2, 3, 0, 1, 2, 3]],
# values=[0.83438963, 0.70008713, 0.88831252, 0.02200012, 0.75432241,
# 0.65136462, 0.96088767, 0.82938021, 0.35367414, 0.86653489])
out = softmax(coo)
print(out)
# Tensor(shape=[3, 4], dtype=paddle.float32, place=Place(gpu:0), stop_gradient=True,
# indices=[[0, 0, 0, 1, 1, 1, 2, 2, 2, 2],
# [0, 1, 3, 0, 2, 3, 0, 1, 2, 3]],
# values=[0.34132853, 0.29843226, 0.36023924, 0.20176250, 0.41964683,
# 0.37859073, 0.30015597, 0.26316857, 0.16354507, 0.27313042])
""" """
def __init__(self, axis=-1, name=None): def __init__(self, axis=-1, name=None):
......