未验证 提交 54474637 编写于 作者: K Kaipeng Deng 提交者: GitHub

Merge pull request #16057 from heavengate/softmax_axis

Add attr 'axis' for softmax
......@@ -95,7 +95,7 @@ paddle.fluid.layers.conv2d (ArgSpec(args=['input', 'num_filters', 'filter_size',
paddle.fluid.layers.conv3d (ArgSpec(args=['input', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, True, None, None)), ('document', '37042620f9bd3a2da6e5d3138b2f724b'))
paddle.fluid.layers.sequence_pool (ArgSpec(args=['input', 'pool_type', 'is_test'], varargs=None, keywords=None, defaults=(False,)), ('document', 'a194fb80614023f543df3949fbd0d0b8'))
paddle.fluid.layers.sequence_softmax (ArgSpec(args=['input', 'use_cudnn', 'name'], varargs=None, keywords=None, defaults=(False, None)), ('document', '19ef6f9cdd27feac8a1ae060f19c10b4'))
paddle.fluid.layers.softmax (ArgSpec(args=['input', 'use_cudnn', 'name'], varargs=None, keywords=None, defaults=(False, None)), ('document', 'f19dd380864e61134ce3814e4be0de4b'))
paddle.fluid.layers.softmax (ArgSpec(args=['input', 'use_cudnn', 'name', 'axis'], varargs=None, keywords=None, defaults=(False, None, -1)), ('document', '59b1c6bf2f0fa9dc649c85fef3a3b2ea'))
paddle.fluid.layers.pool2d (ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name', 'exclusive'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None, True)), ('document', 'bbd84e855e660cd1084bb71a2fd0cdaa'))
paddle.fluid.layers.pool3d (ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name', 'exclusive'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None, True)), ('document', '043de7333b79ee0ac55053c14ed81625'))
paddle.fluid.layers.adaptive_pool2d (ArgSpec(args=['input', 'pool_size', 'pool_type', 'require_index', 'name'], varargs=None, keywords=None, defaults=('max', False, None)), ('document', '859b887174d06f361658f69cb7c06d95'))
......
......@@ -386,7 +386,7 @@ void BenchKernelSoftmax() {
RandomVec<T>(bs * n, x.mutable_data<T>(PlaceType()), -2.f, 2.f);
const T* x_data = x.data<T>();
T* y_data = y.mutable_data<T>(PlaceType());
BenchAllImpls<KernelTuple, PlaceType>(n, x_data, y_data, n, bs);
BenchAllImpls<KernelTuple, PlaceType>(n, x_data, y_data, n, bs, 1);
}
}
}
......
......@@ -34,6 +34,7 @@ const char* to_string(KernelType kt) {
ONE_CASE(kVAddRelu);
ONE_CASE(kVSub);
ONE_CASE(kVScal);
ONE_CASE(kStrideScal);
ONE_CASE(kVAddBias);
ONE_CASE(kVRelu);
ONE_CASE(kVBroadcast);
......@@ -55,6 +56,7 @@ const char* to_string(KernelType kt) {
ONE_CASE(kMatMul);
ONE_CASE(kHMax);
ONE_CASE(kHSum);
ONE_CASE(kStrideASum);
ONE_CASE(kSoftmax);
ONE_CASE(kEmbSeqPool);
ONE_CASE(kSgd);
......
......@@ -38,6 +38,8 @@ typedef enum {
kNCHW16CMulNC,
kSeqPool,
kSoftmax,
kStrideASum,
kStrideScal,
kVAdd,
kVAddBias,
kVAddRelu,
......@@ -74,6 +76,14 @@ struct XYZNTuple {
template <typename T>
struct AXYNTuple : public XYZNTuple<T> {};
// a, x, y, n, stride
template <typename T>
struct AXYNSTuple {
typedef T data_type;
typedef int attr_type;
typedef void (*func_type)(const T*, const T*, T*, int, int);
};
// x, y, n
template <typename T>
struct XYNTuple {
......@@ -86,6 +96,14 @@ struct XYNTuple {
template <typename T>
struct XRNTuple : public XYNTuple<T> {};
// x, returned value, n, stride
template <typename T>
struct XRNSTuple {
typedef T data_type;
typedef int attr_type;
typedef void (*func_type)(const T*, T*, int, int);
};
#define DECLARE_KERNELTUPLE(kernel_tuple, type) \
template <typename T> \
struct type##Tuple : public kernel_tuple<T> { \
......@@ -101,6 +119,8 @@ DECLARE_KERNELTUPLE(XYZNTuple, VSub);
DECLARE_KERNELTUPLE(AXYNTuple, VScal);
DECLARE_KERNELTUPLE(AXYNTuple, VAddBias);
DECLARE_KERNELTUPLE(AXYNSTuple, StrideScal);
DECLARE_KERNELTUPLE(XYNTuple, VRelu);
DECLARE_KERNELTUPLE(XYNTuple, VIdentity);
DECLARE_KERNELTUPLE(XYNTuple, VSquare);
......@@ -112,6 +132,8 @@ DECLARE_KERNELTUPLE(XYNTuple, VCopy);
DECLARE_KERNELTUPLE(XRNTuple, HMax);
DECLARE_KERNELTUPLE(XRNTuple, HSum);
DECLARE_KERNELTUPLE(XRNSTuple, StrideASum);
typedef struct {
void* gates; // gates: x_ch, x_ih, x_fh, x_oh
const void* ct_1;
......@@ -285,7 +307,7 @@ struct SoftmaxTuple {
static constexpr KernelType kernel_type = kSoftmax;
typedef T data_type;
typedef int attr_type;
typedef void (*func_type)(const T*, T*, int, int);
typedef void (*func_type)(const T*, T*, int, int, int);
};
// nChw16c = nChw16c .* NC
......
......@@ -50,10 +50,15 @@ void VTanh(const T* x, T* y, int n) {
compute_addbias(&b, y, y, n);
}
void Softmax(const T* x, T* y, int n, int bs) {
// remain is the product of dimension shapes after the axis dimension
void Softmax(const T* x, T* y, int n, int bs, int remain) {
auto compute_hmax = KernelFuncs<HMaxTuple<T>, CPUPlace>::Cache().At(n);
auto compute_hsum = KernelFuncs<HSumTuple<T>, CPUPlace>::Cache().At(n);
auto compute_vscal = KernelFuncs<VScalTuple<T>, CPUPlace>::Cache().At(n);
auto compute_strideasum =
KernelFuncs<StrideASumTuple<T>, CPUPlace>::Cache().At(n);
auto compute_stridescal =
KernelFuncs<StrideScalTuple<T>, CPUPlace>::Cache().At(n);
auto compute_vaddbias =
KernelFuncs<VAddBiasTuple<T>, CPUPlace>::Cache().At(n);
auto compute_vexp = KernelFuncs<VExpTuple<T>, CPUPlace>::Cache().At(n);
......@@ -64,9 +69,17 @@ void Softmax(const T* x, T* y, int n, int bs) {
scalar = static_cast<T>(0) - scalar;
compute_vaddbias(&scalar, x, y, n); // x - max
compute_vexp(y, y, n);
compute_hsum(y, &scalar, n);
scalar = static_cast<T>(1) / scalar;
compute_vscal(&scalar, y, y, n);
if (remain == 1) {
compute_hsum(y, &scalar, n);
scalar = static_cast<T>(1) / scalar;
compute_vscal(&scalar, y, y, n);
} else {
for (int j = 0; j < remain; ++j) {
compute_strideasum(&y[j], &scalar, n, remain);
scalar = static_cast<T>(1) / scalar;
compute_stridescal(&scalar, &y[j], &y[j], n, remain);
}
}
x += n;
y += n;
}
......
......@@ -26,7 +26,7 @@ using T = float;
void VSigmoid(const T* x, T* y, int n);
void VTanh(const T* x, T* y, int n);
void Softmax(const T* x, T* y, int n, int bs);
void Softmax(const T* x, T* y, int n, int bs, int remain);
void LSTMCtHt(lstm_t* step, const lstm_attr_t* attr);
void LSTMC1H1(lstm_t* step, const lstm_attr_t* attr);
......
......@@ -7,6 +7,7 @@ USE_JITKERNEL_MORE(kMatMul, mkl)
USE_JITKERNEL_MORE(kVMul, mkl)
USE_JITKERNEL_MORE(kVAdd, mkl)
USE_JITKERNEL_MORE(kVScal, mkl)
USE_JITKERNEL_MORE(kStrideScal, mkl)
USE_JITKERNEL_MORE(kVExp, mkl)
USE_JITKERNEL_MORE(kVSquare, mkl)
USE_JITKERNEL_MORE(kVCopy, mkl)
......
......@@ -78,6 +78,26 @@ void VScal<double>(const double* a, const double* x, double* y, int n) {
}
}
template <>
void StrideScal<float>(const float* a, const float* x, float* y, int n,
int stride) {
if (x == y) {
platform::dynload::cblas_sscal(n / stride, *a, y, stride);
} else {
refer::StrideScal<float>(a, x, y, n, stride);
}
}
template <>
void StrideScal<double>(const double* a, const double* x, double* y, int n,
int stride) {
if (x == y) {
platform::dynload::cblas_dscal(n / stride, *a, y, stride);
} else {
refer::StrideScal<double>(a, x, y, n, stride);
}
}
template <>
void VExp<float>(const float* x, float* y, int n) {
platform::dynload::vsExp(n, x, y);
......@@ -128,6 +148,16 @@ void ASum<double>(const double* x, double* res, int n) {
res[0] = platform::dynload::cblas_dasum(n, x, 1);
}
template <>
void StrideASum<float>(const float* x, float* res, int n, int stride) {
res[0] = platform::dynload::cblas_sasum(n / stride, x, stride);
}
template <>
void StrideASum<double>(const double* x, double* res, int n, int stride) {
res[0] = platform::dynload::cblas_dasum(n / stride, x, stride);
}
// TODO(TJ): tuning me carefully on AVX, AVX2 and AVX512
template <>
bool VMulKernel<float>::CanBeUsed(const int& d) const {
......@@ -144,6 +174,11 @@ bool VScalKernel<float>::CanBeUsed(const int& d) const {
return platform::MayIUse(platform::avx512f) && d > 512;
}
template <>
bool StrideScalKernel<float>::CanBeUsed(const int& d) const {
return true;
}
template <>
bool VExpKernel<float>::CanBeUsed(const int& d) const {
return d > 7;
......@@ -235,6 +270,7 @@ bool SoftmaxKernel<float>::CanBeUsed(const int& d) const {
AWALYS_USE_ME_WITH_DOUBLE(VMul);
AWALYS_USE_ME_WITH_DOUBLE(VAdd);
AWALYS_USE_ME_WITH_DOUBLE(VScal);
AWALYS_USE_ME_WITH_DOUBLE(StrideScal);
AWALYS_USE_ME_WITH_DOUBLE(VExp);
AWALYS_USE_ME_WITH_DOUBLE(VSigmoid);
AWALYS_USE_ME_WITH_DOUBLE(VTanh);
......@@ -259,6 +295,7 @@ REGISTER_MKL_KERNEL(MatMul);
REGISTER_MKL_KERNEL(VMul);
REGISTER_MKL_KERNEL(VAdd);
REGISTER_MKL_KERNEL(VScal);
REGISTER_MKL_KERNEL(StrideScal);
REGISTER_MKL_KERNEL(VExp);
REGISTER_MKL_KERNEL(VSquare);
REGISTER_MKL_KERNEL(VCopy);
......
......@@ -129,7 +129,14 @@ template <typename T>
void ASum(const T* x, T* res, int n);
template <typename T>
void Softmax(const T* x, T* y, int n, int bs) {
void StrideASum(const T* x, T* res, int n, int stride);
template <typename T>
void StrideScal(const T* a, const T* x, T* y, int n, int stride);
// remain is the product of dimension shapes after the axis dimension
template <typename T>
void Softmax(const T* x, T* y, int n, int bs, int remain = 1) {
std::vector<T> entities(bs);
for (int i = 0; i < bs; ++i) {
entities[i] = x[i * n];
......@@ -143,9 +150,17 @@ void Softmax(const T* x, T* y, int n, int bs) {
VExp(y, y, n * bs);
for (int i = 0; i < bs; ++i) {
T sum;
ASum(&y[i * n], &sum, n);
sum = static_cast<T>(1) / sum;
VScal(&sum, &y[i * n], &y[i * n], n);
if (remain == 1) {
ASum(&y[i * n], &sum, n);
sum = static_cast<T>(1) / sum;
VScal(&sum, &y[i * n], &y[i * n], n);
} else {
for (int j = 0; j < remain; ++j) {
StrideASum(&y[i * n + j], &sum, n, remain);
sum = static_cast<T>(1) / sum;
StrideScal(&sum, &y[i * n + j], &y[i * n + j], n, remain);
}
}
}
}
......@@ -193,6 +208,7 @@ DECLARE_MKL_KERNEL(VAdd);
// AXYN
DECLARE_MKL_KERNEL(VScal);
DECLARE_MKL_KERNEL(StrideScal);
// XYN
DECLARE_MKL_KERNEL(VExp);
......
......@@ -12,6 +12,7 @@ USE_JITKERNEL_REFER(kVAdd)
USE_JITKERNEL_REFER(kVAddRelu)
USE_JITKERNEL_REFER(kVSub)
USE_JITKERNEL_REFER(kVScal)
USE_JITKERNEL_REFER(kStrideScal)
USE_JITKERNEL_REFER(kVAddBias)
USE_JITKERNEL_REFER(kVCopy)
USE_JITKERNEL_REFER(kVRelu)
......@@ -32,6 +33,7 @@ USE_JITKERNEL_REFER(kMatMul)
USE_JITKERNEL_REFER(kVSquare)
USE_JITKERNEL_REFER(kHSum)
USE_JITKERNEL_REFER(kHMax)
USE_JITKERNEL_REFER(kStrideASum)
USE_JITKERNEL_REFER(kSoftmax)
USE_JITKERNEL_REFER(kEmbSeqPool)
USE_JITKERNEL_REFER(kSgd)
......
......@@ -27,6 +27,7 @@ REGISTER_REFER_KERNEL(VAddRelu);
REGISTER_REFER_KERNEL(VSub);
REGISTER_REFER_KERNEL(VScal);
REGISTER_REFER_KERNEL(StrideScal);
REGISTER_REFER_KERNEL(VAddBias);
REGISTER_REFER_KERNEL(VRelu);
......@@ -51,6 +52,7 @@ REGISTER_REFER_KERNEL(SeqPool);
REGISTER_REFER_KERNEL(MatMul);
REGISTER_REFER_KERNEL(HMax);
REGISTER_REFER_KERNEL(HSum);
REGISTER_REFER_KERNEL(StrideASum);
REGISTER_REFER_KERNEL(Softmax);
REGISTER_REFER_KERNEL(EmbSeqPool);
REGISTER_REFER_KERNEL(Sgd);
......
......@@ -411,19 +411,47 @@ void HSum(const T* x, T* res, int n) {
}
}
template <typename T>
void StrideASum(const T* x, T* res, int n, int stride) {
res[0] = x[0];
for (int i = stride; i < n; i += stride) {
res[0] += std::abs(x[i]);
}
}
template <typename T>
void StrideScal(const T* a, const T* x, T* y, int n, int stride) {
for (int i = 0; i < n; ++i) {
if (i % stride == 0) {
y[i] = x[i] * a[0];
} else {
y[i] = x[i];
}
}
}
// y = e^(x - max(x))
// y = y / sum(y)
// remain is the product of dimension shapes after the axis dimension
template <typename T>
void Softmax(const T* x, T* y, int n, int bs = 1) {
void Softmax(const T* x, T* y, int n, int bs = 1, int remain = 1) {
for (int i = 0; i < bs; ++i) {
T scalar;
HMax(x, &scalar, n);
scalar = static_cast<T>(0) - scalar;
VAddBias(&scalar, x, y, n); // x - max
VExp(y, y, n);
HSum(y, &scalar, n);
scalar = static_cast<T>(1) / scalar;
VScal(&scalar, y, y, n);
if (remain == 1) {
HSum(y, &scalar, n);
scalar = static_cast<T>(1) / scalar;
VScal(&scalar, y, y, n);
} else {
for (int j = 0; j < remain; j++) {
StrideASum(&y[j], &scalar, n, remain);
scalar = static_cast<T>(1) / scalar;
StrideScal(&scalar, &y[j], &y[j], n, remain);
}
}
x += n;
y += n;
}
......@@ -507,6 +535,9 @@ DECLARE_REFER_KERNEL(VSub);
DECLARE_REFER_KERNEL(VScal);
DECLARE_REFER_KERNEL(VAddBias);
// const T* a, const T* x, T* y, int n, int stride
DECLARE_REFER_KERNEL(StrideScal);
// const T* x, T* y, int n
DECLARE_REFER_KERNEL(VRelu);
DECLARE_REFER_KERNEL(VIdentity);
......@@ -528,6 +559,8 @@ DECLARE_REFER_KERNEL(GRUHtPart2);
DECLARE_REFER_KERNEL(HMax);
DECLARE_REFER_KERNEL(HSum);
DECLARE_REFER_KERNEL(StrideASum);
// others
DECLARE_REFER_KERNEL(CRFDecoding);
DECLARE_REFER_KERNEL(LayerNorm);
......
......@@ -723,39 +723,122 @@ void TestKernelSoftmax() {
VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type);
for (int bs : {1, 2, 10}) {
for (int n : TestSizes()) {
for (int m : {1, 2, 3}) { // remain
if (m > n || n % m != 0) {
continue;
}
auto ref = jit::GetReferFunc<KernelTuple>();
EXPECT_TRUE(ref != nullptr);
std::vector<T> x(bs * n), y(bs * n);
RandomVec<T>(bs * n, x.data());
const T* x_data = x.data();
T* y_data = y.data();
std::vector<T> xinp(x.size()); // inplace test
std::copy(x.begin(), x.end(), xinp.begin());
ref(x_data, y_data, n, bs, m);
T* xinp_data = xinp.data();
ref(xinp_data, xinp_data, n, bs, m);
ExpectEQ<T>(xinp_data, y_data, n * bs);
auto verifier = [](const typename KernelTuple::func_type tgt,
const std::vector<T>& x, const std::vector<T>& yref,
int n, int bs, int m) {
EXPECT_TRUE(tgt != nullptr);
EXPECT_EQ(yref.size(), x.size());
EXPECT_EQ(x.size(), static_cast<size_t>(n * bs));
const T* x_data = x.data();
const T* yref_data = yref.data();
std::vector<T> ytgt(n * bs);
T* ytgt_data = ytgt.data();
// test normal
tgt(x_data, ytgt_data, n, bs, m);
ExpectEQ<T>(ytgt_data, yref_data, n * bs);
// test inplace x
std::copy(x.begin(), x.end(), ytgt.begin());
tgt(ytgt_data, ytgt_data, n, bs, m);
ExpectEQ<T>(ytgt_data, yref_data, n * bs);
};
TestAllImpls<KernelTuple, PlaceType>(n, verifier, x, y, n, bs, m);
}
}
}
}
template <typename KernelTuple, typename PlaceType>
void TestKernelStrideASum() {
using T = typename KernelTuple::data_type;
VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type);
for (int d : TestSizes()) {
for (int m : {1, 2, 3}) { // stride
if (m > d || d % m != 0) {
continue;
}
auto ref = jit::GetReferFunc<KernelTuple>();
EXPECT_TRUE(ref != nullptr);
std::vector<T> x(d);
RandomVec<T>(d, x.data());
T ref_res;
ref(x.data(), &ref_res, d, m);
auto verifier = [](const typename KernelTuple::func_type tgt,
const std::vector<T>& x, const T ref_res,
const int m) {
EXPECT_TRUE(tgt != nullptr);
T tgt_res;
tgt(x.data(), &tgt_res, x.size(), m);
ExpectEQ<T>(&tgt_res, &ref_res, 1);
};
TestAllImpls<KernelTuple, PlaceType>(d, verifier, x, ref_res, m);
}
}
}
template <typename KernelTuple, typename PlaceType>
void TestKernelStrideScal() {
using T = typename KernelTuple::data_type;
VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type);
for (int d : TestSizes()) {
for (int m : {1, 2, 3}) { // stride
if (m > d || d % m != 0) {
continue;
}
auto ref = jit::GetReferFunc<KernelTuple>();
EXPECT_TRUE(ref != nullptr);
std::vector<T> x(bs * n), y(bs * n);
RandomVec<T>(bs * n, x.data());
const T* x_data = x.data();
T* y_data = y.data();
std::vector<T> xinp(x.size()); // inplace test
const T a = static_cast<T>(3);
std::vector<T> x(d), yref(d);
std::vector<T> xinp(d); // inplace test
RandomVec<T>(d, x.data());
std::copy(x.begin(), x.end(), xinp.begin());
ref(x_data, y_data, n, bs);
const T* x_data = x.data();
T* yref_data = yref.data();
T* xinp_data = xinp.data();
ref(xinp_data, xinp_data, n, bs);
ExpectEQ<T>(xinp_data, y_data, n * bs);
// test refer code inplace
ref(&a, x_data, yref_data, d, m);
ref(&a, xinp_data, xinp_data, d, m);
ExpectEQ<T>(xinp_data, yref_data, d);
auto verifier = [](const typename KernelTuple::func_type tgt,
auto verifier = [](const typename KernelTuple::func_type tgt, const T a,
const std::vector<T>& x, const std::vector<T>& yref,
int n, int bs) {
const int m) {
EXPECT_TRUE(tgt != nullptr);
EXPECT_EQ(yref.size(), x.size());
EXPECT_EQ(x.size(), static_cast<size_t>(n * bs));
const T* x_data = x.data();
const T* yref_data = yref.data();
std::vector<T> ytgt(n * bs);
const int d = yref.size();
std::vector<T> ytgt(d);
T* ytgt_data = ytgt.data();
// test normal
tgt(x_data, ytgt_data, n, bs);
ExpectEQ<T>(ytgt_data, yref_data, n * bs);
tgt(&a, x_data, ytgt_data, d, m);
ExpectEQ<T>(ytgt_data, yref_data, d);
// test inplace x
std::copy(x.begin(), x.end(), ytgt.begin());
tgt(ytgt_data, ytgt_data, n, bs);
ExpectEQ<T>(ytgt_data, yref_data, n * bs);
tgt(&a, ytgt_data, ytgt_data, d, m);
ExpectEQ<T>(ytgt_data, yref_data, d);
};
TestAllImpls<KernelTuple, PlaceType>(n, verifier, x, y, n, bs);
TestAllImpls<KernelTuple, PlaceType>(d, verifier, a, x, yref, m);
}
}
}
......@@ -912,7 +995,7 @@ TEST(JITKernel_pool, more) {
EXPECT_EQ(kers.size(), 10UL);
#else
#ifdef PADDLE_WITH_MKLML
EXPECT_EQ(kers.size(), 21UL);
EXPECT_EQ(kers.size(), 22UL);
#else
EXPECT_EQ(kers.size(), 8UL);
#endif
......@@ -921,7 +1004,7 @@ TEST(JITKernel_pool, more) {
TEST(JITKernel_pool, refer) {
const auto& kers = jit::ReferKernelPool::Instance().AllKernels();
EXPECT_EQ(kers.size(), 29UL);
EXPECT_EQ(kers.size(), 31UL);
}
// test helper
......@@ -1292,3 +1375,6 @@ TEST_CPU_KERNEL(MatMul);
TEST_CPU_KERNEL(Softmax);
TEST_CPU_KERNEL(Sgd);
TEST_CPU_KERNEL(VBroadcast);
TEST_CPU_KERNEL(StrideASum);
TEST_CPU_KERNEL(StrideScal);
......@@ -23,15 +23,16 @@ template <typename DeviceContext, typename T, bool is_test,
typename Enable = void>
class SoftmaxFunctor {
public:
void operator()(const DeviceContext& context, const framework::Tensor* X,
framework::Tensor* Y);
void operator()(const DeviceContext& context, const int axis_dim,
const framework::Tensor* X, framework::Tensor* Y);
};
template <typename DeviceContext, typename T>
class SoftmaxGradFunctor {
public:
void operator()(const DeviceContext& context, const framework::Tensor* y,
const framework::Tensor* y_grad, framework::Tensor* x_grad);
void operator()(const DeviceContext& context, const int axis_dim,
const framework::Tensor* y, const framework::Tensor* y_grad,
framework::Tensor* x_grad);
};
#ifdef PADDLE_WITH_CUDA
......
......@@ -36,8 +36,8 @@ struct ValueClip {
template <typename DeviceContext, typename T, bool is_test, typename Enable>
void SoftmaxFunctor<DeviceContext, T, is_test, Enable>::operator()(
const DeviceContext& context, const framework::Tensor* X,
framework::Tensor* Y) {
const DeviceContext& context, const int axis_dim,
const framework::Tensor* X, framework::Tensor* Y) {
auto logits = EigenMatrix<T>::From(*X);
auto softmax = EigenMatrix<T>::From(*Y);
......@@ -46,10 +46,13 @@ void SoftmaxFunctor<DeviceContext, T, is_test, Enable>::operator()(
const int batch_size = logits.dimension(kBatchDim);
const int num_classes = logits.dimension(kClassDim);
const int num_remain = num_classes / axis_dim;
Eigen::DSizes<int, 1> along_class(kClassDim);
Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
Eigen::DSizes<int, 2> one_by_class(1, num_classes);
Eigen::DSizes<int, 3> batch_axis_remain(batch_size, axis_dim, num_remain);
Eigen::DSizes<int, 2> one_axis(1, axis_dim);
auto shifted_logits = (logits -
logits.maximum(along_class)
......@@ -60,11 +63,11 @@ void SoftmaxFunctor<DeviceContext, T, is_test, Enable>::operator()(
softmax.device(*context.eigen_device()) = shifted_logits.exp();
softmax.device(*context.eigen_device()) = (softmax *
softmax.sum(along_class)
softmax.reshape(batch_axis_remain)
.sum(along_class)
.inverse()
.eval()
.reshape(batch_by_one)
.broadcast(one_by_class));
.broadcast(one_axis));
}
template <class DeviceContext>
......@@ -73,8 +76,8 @@ using enable_if_CPU = typename std::enable_if<
template <typename DeviceContext>
class SoftmaxFunctor<DeviceContext, float, true, enable_if_CPU<DeviceContext>> {
void operator()(const DeviceContext& context, const framework::Tensor* X,
framework::Tensor* Y) {
void operator()(const DeviceContext& context, const int axis_dim,
const framework::Tensor* X, framework::Tensor* Y) {
auto in_dims = X->dims();
const float* in_data = X->data<float>();
float* out_data = Y->data<float>();
......@@ -84,14 +87,16 @@ class SoftmaxFunctor<DeviceContext, float, true, enable_if_CPU<DeviceContext>> {
auto compute_softmax =
jit::KernelFuncs<jit::SoftmaxTuple<float>, platform::CPUPlace>::Cache()
.At(in_dims[kClassDim]);
compute_softmax(in_data, out_data, in_dims[kClassDim], in_dims[kBatchDim]);
compute_softmax(in_data, out_data, in_dims[kClassDim], in_dims[kBatchDim],
in_dims[kClassDim] / axis_dim);
}
};
template <typename DeviceContext, typename T>
void SoftmaxGradFunctor<DeviceContext, T>::operator()(
const DeviceContext& context, const framework::Tensor* y,
const framework::Tensor* y_grad, framework::Tensor* x_grad) {
const DeviceContext& context, const int axis_dim,
const framework::Tensor* y, const framework::Tensor* y_grad,
framework::Tensor* x_grad) {
auto softmax = EigenMatrix<T>::From(*y);
auto softmax_grad = EigenMatrix<T>::From(*y_grad);
auto logits_grad = EigenMatrix<T>::From(*x_grad);
......@@ -101,16 +106,19 @@ void SoftmaxGradFunctor<DeviceContext, T>::operator()(
const int batch_size = softmax.dimension(kBatchDim);
const int num_classes = softmax.dimension(kClassDim);
const int num_remain = num_classes / axis_dim;
Eigen::DSizes<int, 1> along_class(kClassDim);
Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
Eigen::DSizes<int, 2> one_by_class(1, num_classes);
Eigen::DSizes<int, 3> batch_axis_remain(batch_size, axis_dim, num_remain);
Eigen::DSizes<int, 2> one_axis(1, axis_dim);
auto dot = (softmax * softmax_grad)
.reshape(batch_axis_remain)
.sum(along_class)
.eval()
.reshape(batch_by_one)
.broadcast(one_by_class);
.broadcast(one_axis);
logits_grad.device(*context.eigen_device()) = (softmax_grad - dot) * softmax;
}
......
......@@ -39,6 +39,20 @@ class SoftmaxOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of SoftmaxOp should not be null.");
auto dim_x = ctx->GetInputDim("X");
auto rank_x = dim_x.size();
auto axis = ctx->Attrs().Get<int>("axis");
PADDLE_ENFORCE(axis >= -rank_x && axis < rank_x,
"Attr(axis) value should be in range [-R, R-1], "
"R is the rank of Input(X).");
auto use_cudnn = ctx->Attrs().Get<bool>("use_cudnn");
auto use_mkldnn = ctx->Attrs().Get<bool>("use_mkldnn");
if (axis != rank_x - 1 && axis != -1) {
PADDLE_ENFORCE(!use_cudnn, "CUDNN kernel only support axis as -1.");
PADDLE_ENFORCE(!use_mkldnn, "MKLDNN kernel only support axis as -1.");
}
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ "Out");
}
......@@ -80,8 +94,12 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
void Make() override {
AddInput("X",
"The input tensor of softmax, "
"whose last dimension is the input_feature_dimensions.");
"whose dimension :attr:`axis` is the input_feature_dimensions.");
AddOutput("Out", "The normalized values with the same shape as X.");
AddAttr<int>("axis",
"The dimension index of Input(x) to perform softmax,"
"default -1 for last dimension")
.SetDefault(-1);
AddAttr<bool>(
"use_cudnn",
"(bool, default false) Only used in cudnn kernel, need install cudnn")
......@@ -106,12 +124,13 @@ Softmax Operator.
The input of the softmax operator is a tensor of any rank. The output tensor
has the same shape as the input.
The input tensor will first be logically flattened to a 2-D matrix. The matrix's
second dimension(row length) is as same as the last dimension of the input
The dimension :attr:`axis` of the input tensor will be permuted to the last.
Then the input tensor will be logically flattened to a 2-D matrix. The matrix's
second dimension(row length) is as same as the dimension :attr:`axis` of the input
tensor, and the first dimension(column length) is the product of all other
dimensions of the input tensor. For each row of the matrix, the softmax operator
squashes the K-dimensional(K is the width of the matrix, which is also the size
of the input tensor's last dimension) vector of arbitrary real values to a
of the input tensor's dimension :attr:`axis`) vector of arbitrary real values to a
K-dimensional vector of real values in the range [0, 1] that add up to 1.
It computes the exponential of the given dimension and the sum of exponential
values of all the other dimensions in the K-dimensional vector input.
......
......@@ -20,6 +20,30 @@ namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using DDim = framework::DDim;
static inline int CanonicalAxis(const int axis, const int rank) {
if (axis < 0) {
return axis + rank;
}
return axis;
}
static inline int SizeToAxis(const int axis, DDim dims) {
int size = 1;
for (int i = 0; i < axis; i++) {
size *= dims[i];
}
return size;
}
static inline int SizeFromAxis(const int axis, DDim dims) {
int size = 1;
for (int i = axis; i < dims.size(); i++) {
size *= dims[i];
}
return size;
}
template <typename DeviceContext, typename T>
class SoftmaxKernel : public framework::OpKernel<T> {
......@@ -27,20 +51,27 @@ class SoftmaxKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<Tensor>("X");
auto* Out = context.Output<Tensor>("Out");
const int rank = X->dims().size();
const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
int axis_dim = X->dims()[axis];
// allocate memory on device.
Out->mutable_data<T>(context.GetPlace());
int rank = X->dims().size();
Tensor X_2d = framework::ReshapeToMatrix(*X, rank - 1);
Tensor Out_2d = framework::ReshapeToMatrix(*Out, rank - 1);
const int n = SizeToAxis(axis, X->dims());
const int d = SizeFromAxis(axis, X->dims());
Tensor X_2d, Out_2d;
X_2d.ShareDataWith(*X).Resize({n, d});
Out_2d.ShareDataWith(*Out).Resize({n, d});
#ifdef PADDLE_ON_INFERENCE
math::SoftmaxFunctor<DeviceContext, T, true>()(
context.template device_context<DeviceContext>(), &X_2d, &Out_2d);
context.template device_context<DeviceContext>(), axis_dim, &X_2d,
&Out_2d);
#else
math::SoftmaxFunctor<DeviceContext, T, false>()(
context.template device_context<DeviceContext>(), &X_2d, &Out_2d);
context.template device_context<DeviceContext>(), axis_dim, &X_2d,
&Out_2d);
#endif
}
};
......@@ -52,18 +83,23 @@ class SoftmaxGradKernel : public framework::OpKernel<T> {
auto* Out = context.Input<Tensor>("Out");
auto* dOut = context.Input<Tensor>(framework::GradVarName("Out"));
auto* dX = context.Output<Tensor>(framework::GradVarName("X"));
const int rank = dX->dims().size();
const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
int axis_dim = dX->dims()[axis];
// allocate memory on device.
dX->mutable_data<T>(context.GetPlace());
int rank = Out->dims().size();
Tensor Out_2d = framework::ReshapeToMatrix(*Out, rank - 1);
Tensor dOut_2d = framework::ReshapeToMatrix(*dOut, rank - 1);
Tensor dX_2d = framework::ReshapeToMatrix(*dX, rank - 1);
const int n = SizeToAxis(axis, dX->dims());
const int d = SizeFromAxis(axis, dX->dims());
Tensor dX_2d, Out_2d, dOut_2d;
dX_2d.ShareDataWith(*dX).Resize({n, d});
Out_2d.ShareDataWith(*Out).Resize({n, d});
dOut_2d.ShareDataWith(*dOut).Resize({n, d});
math::SoftmaxGradFunctor<DeviceContext, T>()(
context.template device_context<DeviceContext>(), &Out_2d, &dOut_2d,
&dX_2d);
context.template device_context<DeviceContext>(), axis_dim, &Out_2d,
&dOut_2d, &dX_2d);
}
};
......
......@@ -40,10 +40,12 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel<T> {
softmax->mutable_data<T>(context.GetPlace());
loss->mutable_data<T>(context.GetPlace());
int axis_dim = logits->dims()[logits->dims().size() - 1];
auto& dev_ctx =
context.template device_context<platform::CPUDeviceContext>();
math::SoftmaxFunctor<platform::CPUDeviceContext, T, false>()(
dev_ctx, logits, softmax);
dev_ctx, axis_dim, logits, softmax);
math::CrossEntropyFunctor<platform::CPUDeviceContext, T>()(
dev_ctx, loss, softmax, labels, context.Attr<bool>("soft_label"),
context.Attr<int>("ignore_index"));
......
......@@ -67,9 +67,11 @@ class CudnnCTCKernel : public framework::OpKernel<T> {
softmax_logits.mutable_data<T>(logits->dims(), ctx.GetPlace());
softmax_logits.set_lod(logits_lod);
int rank = logits->dims().size();
int axis_dim = logits->dims()[rank - 1];
Tensor in_2d = framework::ReshapeToMatrix(*logits, rank - 1);
Tensor out_2d = framework::ReshapeToMatrix(softmax_logits, rank - 1);
math::SoftmaxFunctor<DeviceContext, T, false>()(dev_ctx, &in_2d, &out_2d);
math::SoftmaxFunctor<DeviceContext, T, false>()(dev_ctx, axis_dim, &in_2d,
&out_2d);
// ctc needs sequences data stored in transposed padding format
// logits and grad using padding data of layout 'TNC'
......
......@@ -1820,17 +1820,18 @@ def sequence_softmax(input, use_cudnn=False, name=None):
return softmax_out
def softmax(input, use_cudnn=False, name=None):
def softmax(input, use_cudnn=False, name=None, axis=-1):
"""
The input of the softmax operator is a tensor of any rank. The output tensor
has the same shape as the input.
The input tensor will first be logically flattened to a 2-D matrix. The matrix's
second dimension(row length) is as same as the last dimension of the input
The dimension :attr:`axis` of the input tensor will be permuted to the last.
Then the input tensor will be logically flattened to a 2-D matrix. The matrix's
second dimension(row length) is the same as the dimension :attr:`axis` of the input
tensor, and the first dimension(column length) is the product of all other
dimensions of the input tensor. For each row of the matrix, the softmax operator
squashes the K-dimensional(K is the width of the matrix, which is also the size
of the input tensor's last dimension) vector of arbitrary real values to a
of the input tensor's dimension :attr:`axis`) vector of arbitrary real values to a
K-dimensional vector of real values in the range [0, 1] that add up to 1.
It computes the exponential of the given dimension and the sum of exponential
......@@ -1852,6 +1853,9 @@ def softmax(input, use_cudnn=False, name=None):
False by default. Default: False
name (str|None): A name for this layer(optional). If set None, the layer
will be named automatically. Default: None.
axis (int): The index of dimension to perform softmax calculations, it should
be in range :math:`[-1, rank - 1]`, while :math:`rank` is the rank of
input variable. Default: -1.
Returns:
Variable: output of softmax
......@@ -1861,7 +1865,10 @@ def softmax(input, use_cudnn=False, name=None):
.. code-block:: python
fc = fluid.layers.fc(input=x, size=10)
softmax = fluid.layers.softmax(input=fc)
# perform softmax in the second dimension
softmax = fluid.layers.softmax(input=fc, axis=1)
# perform softmax in the last dimension
softmax = fluid.layers.softmax(input=fc, axis=-1)
"""
helper = LayerHelper('softmax', **locals())
......@@ -1871,7 +1878,8 @@ def softmax(input, use_cudnn=False, name=None):
type="softmax",
inputs={"X": input},
outputs={"Out": softmax_out},
attrs={"use_cudnn": use_cudnn})
attrs={"axis": axis,
"use_cudnn": use_cudnn})
return softmax_out
......
......@@ -845,7 +845,7 @@ class TestBook(unittest.TestCase):
with program_guard(program):
data = layers.data(name='data', shape=[10], dtype='float32')
hid = layers.fc(input=data, size=20)
self.assertIsNotNone(layers.softmax(hid))
self.assertIsNotNone(layers.softmax(hid, axis=1))
print(str(program))
def test_space_to_depth(self):
......
......@@ -31,6 +31,9 @@ class TestSoftmaxOp(OpTest):
def get_x_shape(self):
return [10, 10]
def get_axis(self):
return -1
def setUp(self):
self.op_type = "softmax"
self.use_cudnn = False
......@@ -38,15 +41,15 @@ class TestSoftmaxOp(OpTest):
self.dtype = np.float32
self.init_kernel_type()
self.shape = self.get_x_shape()
self.axis = self.get_axis()
x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype)
out = np.apply_along_axis(stable_softmax, 1,
x.reshape([-1, self.shape[-1]]))
out = out.reshape(self.shape)
out = np.apply_along_axis(stable_softmax, self.axis, x)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
self.attrs = {
'axis': self.axis,
'use_cudnn': self.use_cudnn,
'use_mkldnn': self.use_mkldnn
}
......@@ -76,6 +79,38 @@ class TestSoftmaxOp2(TestSoftmaxOp):
return [2, 3, 4, 5]
class TestSoftmaxOp3(TestSoftmaxOp):
def get_x_shape(self):
return [2, 3, 4, 5]
def get_axis(self):
return 0
class TestSoftmaxOp4(TestSoftmaxOp):
def get_x_shape(self):
return [2, 3, 4, 5]
def get_axis(self):
return 1
class TestSoftmaxOp5(TestSoftmaxOp):
def get_x_shape(self):
return [2, 3, 4, 5]
def get_axis(self):
return 2
class TestSoftmaxOp5(TestSoftmaxOp):
def get_x_shape(self):
return [2, 3, 4, 5]
def get_axis(self):
return 3
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestSoftmaxCUDNNOp(TestSoftmaxOp):
......@@ -90,6 +125,16 @@ class TestSoftmaxCUDNNOp2(TestSoftmaxCUDNNOp):
return [2, 3, 4, 5]
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestSoftmaxCUDNNOp5(TestSoftmaxCUDNNOp):
def get_x_shape(self):
return [2, 3, 4, 5]
def get_axis(self):
return 3
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestSoftmaxFP16Op(TestSoftmaxOp):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册