Commit 93701dba authored by D dengkaipeng

add jit kernel for softmax axis. test=develop

Parent 6c641827
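Note: the softmax-related kernels in this commit gain a trailing argument m (the stride between elements that belong to the same softmax group), so that softmax along a non-last axis can be computed on data already flattened to [bs, n]. A minimal standalone sketch of that semantics follows; softmax_with_stride is a hypothetical helper name used only for illustration, the real entry points are the refer::Softmax / mix::Softmax changes in the diffs below.

    // Sketch only: softmax over every stride-th element of a length-n row,
    // mirroring the m > 1 path added to refer::Softmax in this commit.
    #include <algorithm>
    #include <cmath>

    void softmax_with_stride(const float* x, float* y, int n, int stride) {
      // Subtract the row-wise max for numerical stability; the constant
      // cancels inside each stride group, so the result is unchanged.
      float max_v = *std::max_element(x, x + n);
      for (int i = 0; i < n; ++i) y[i] = std::exp(x[i] - max_v);
      // Normalize each of the `stride` interleaved groups independently,
      // which is what the new StrideSum / StrideScal kernels implement.
      for (int j = 0; j < stride; ++j) {
        float sum = 0.f;
        for (int i = j; i < n; i += stride) sum += y[i];
        for (int i = j; i < n; i += stride) y[i] /= sum;
      }
    }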
@@ -386,7 +386,7 @@ void BenchKernelSoftmax() {
       RandomVec<T>(bs * n, x.mutable_data<T>(PlaceType()), -2.f, 2.f);
       const T* x_data = x.data<T>();
       T* y_data = y.mutable_data<T>(PlaceType());
-      BenchAllImpls<KernelTuple, PlaceType>(n, x_data, y_data, n, bs);
+      BenchAllImpls<KernelTuple, PlaceType>(n, x_data, y_data, n, bs, 1);
     }
   }
 }
......
@@ -34,6 +34,7 @@ const char* to_string(KernelType kt) {
     ONE_CASE(kVAddRelu);
     ONE_CASE(kVSub);
     ONE_CASE(kVScal);
+    ONE_CASE(kStrideScal);
     ONE_CASE(kVAddBias);
     ONE_CASE(kVRelu);
     ONE_CASE(kVBroadcast);
@@ -55,6 +56,7 @@ const char* to_string(KernelType kt) {
     ONE_CASE(kMatMul);
     ONE_CASE(kHMax);
     ONE_CASE(kHSum);
+    ONE_CASE(kStrideSum);
     ONE_CASE(kSoftmax);
     ONE_CASE(kEmbSeqPool);
     ONE_CASE(kSgd);
......
@@ -53,6 +53,8 @@ typedef enum {
   kVSquare,
   kVSub,
   kVTanh,
+  kStrideSum,
+  kStrideScal,
 } KernelType;

 typedef enum {
@@ -74,6 +76,14 @@ struct XYZNTuple {
 template <typename T>
 struct AXYNTuple : public XYZNTuple<T> {};

+// a, x, y, n, stride
+template <typename T>
+struct AXYNSTuple {
+  typedef T data_type;
+  typedef int attr_type;
+  typedef void (*func_type)(const T*, const T*, T*, int, int);
+};

 // x, y, n
 template <typename T>
 struct XYNTuple {
@@ -86,6 +96,14 @@ struct XYNTuple {
 template <typename T>
 struct XRNTuple : public XYNTuple<T> {};

+// x, returned value, n, stride
+template <typename T>
+struct XRNSTuple {
+  typedef T data_type;
+  typedef int attr_type;
+  typedef void (*func_type)(const T*, T*, int, int);
+};

 #define DECLARE_KERNELTUPLE(kernel_tuple, type)  \
   template <typename T>                          \
   struct type##Tuple : public kernel_tuple<T> {  \
@@ -101,6 +119,8 @@ DECLARE_KERNELTUPLE(XYZNTuple, VSub);
 DECLARE_KERNELTUPLE(AXYNTuple, VScal);
 DECLARE_KERNELTUPLE(AXYNTuple, VAddBias);

+DECLARE_KERNELTUPLE(AXYNSTuple, StrideScal);

 DECLARE_KERNELTUPLE(XYNTuple, VRelu);
 DECLARE_KERNELTUPLE(XYNTuple, VIdentity);
 DECLARE_KERNELTUPLE(XYNTuple, VSquare);
@@ -112,6 +132,8 @@ DECLARE_KERNELTUPLE(XYNTuple, VCopy);
 DECLARE_KERNELTUPLE(XRNTuple, HMax);
 DECLARE_KERNELTUPLE(XRNTuple, HSum);

+DECLARE_KERNELTUPLE(XRNSTuple, StrideSum);

 typedef struct {
   void* gates;  // gates: x_ch, x_ih, x_fh, x_oh
   const void* ct_1;
@@ -285,7 +307,7 @@ struct SoftmaxTuple {
   static constexpr KernelType kernel_type = kSoftmax;
   typedef T data_type;
   typedef int attr_type;
-  typedef void (*func_type)(const T*, T*, int, int);
+  typedef void (*func_type)(const T*, T*, int, int, int);
 };

 // nChw16c = nChw16c .* NC
......
@@ -50,10 +50,12 @@ void VTanh(const T* x, T* y, int n) {
   compute_addbias(&b, y, y, n);
 }

-void Softmax(const T* x, T* y, int n, int bs) {
+void Softmax(const T* x, T* y, int n, int bs, int m) {
   auto compute_hmax = KernelFuncs<HMaxTuple<T>, CPUPlace>::Cache().At(n);
   auto compute_hsum = KernelFuncs<HSumTuple<T>, CPUPlace>::Cache().At(n);
   auto compute_vscal = KernelFuncs<VScalTuple<T>, CPUPlace>::Cache().At(n);
+  auto compute_stridesum = KernelFuncs<StrideSumTuple<T>, CPUPlace>::Cache().At(n);
+  auto compute_stridescal = KernelFuncs<StrideScalTuple<T>, CPUPlace>::Cache().At(n);
   auto compute_vaddbias =
       KernelFuncs<VAddBiasTuple<T>, CPUPlace>::Cache().At(n);
   auto compute_vexp = KernelFuncs<VExpTuple<T>, CPUPlace>::Cache().At(n);
@@ -64,9 +66,17 @@ void Softmax(const T* x, T* y, int n, int bs) {
     scalar = static_cast<T>(0) - scalar;
     compute_vaddbias(&scalar, x, y, n);  // x - max
     compute_vexp(y, y, n);
+    if (m == 1) {
       compute_hsum(y, &scalar, n);
       scalar = static_cast<T>(1) / scalar;
       compute_vscal(&scalar, y, y, n);
+    } else {
+      for (int j = 0; j < m; ++j) {
+        compute_stridesum(&y[j], &scalar, n, m);
+        scalar = static_cast<T>(1) / scalar;
+        compute_stridescal(&scalar, &y[j], &y[j], n, m);
+      }
+    }
     x += n;
     y += n;
   }
......
@@ -26,7 +26,7 @@ using T = float;

 void VSigmoid(const T* x, T* y, int n);
 void VTanh(const T* x, T* y, int n);
-void Softmax(const T* x, T* y, int n, int bs);
+void Softmax(const T* x, T* y, int n, int bs, int m);

 void LSTMCtHt(lstm_t* step, const lstm_attr_t* attr);
 void LSTMC1H1(lstm_t* step, const lstm_attr_t* attr);
......
@@ -7,6 +7,7 @@ USE_JITKERNEL_MORE(kMatMul, mkl)
 USE_JITKERNEL_MORE(kVMul, mkl)
 USE_JITKERNEL_MORE(kVAdd, mkl)
 USE_JITKERNEL_MORE(kVScal, mkl)
+USE_JITKERNEL_MORE(kStrideScal, mkl)
 USE_JITKERNEL_MORE(kVExp, mkl)
 USE_JITKERNEL_MORE(kVSquare, mkl)
 USE_JITKERNEL_MORE(kVCopy, mkl)
......
@@ -78,6 +78,24 @@ void VScal<double>(const double* a, const double* x, double* y, int n) {
   }
 }

+template <>
+void StrideScal<float>(const float* a, const float* x, float* y, int n, int stride) {
+  if (x == y) {
+    platform::dynload::cblas_sscal(n, *a, y, stride);
+  } else {
+    refer::StrideScal<float>(a, x, y, n, stride);
+  }
+}
+
+template <>
+void StrideScal<double>(const double* a, const double* x, double* y, int n, int stride) {
+  if (x == y) {
+    platform::dynload::cblas_dscal(n, *a, y, stride);
+  } else {
+    refer::StrideScal<double>(a, x, y, n, stride);
+  }
+}
+
 template <>
 void VExp<float>(const float* x, float* y, int n) {
   platform::dynload::vsExp(n, x, y);
@@ -128,6 +146,16 @@ void ASum<double>(const double* x, double* res, int n) {
   res[0] = platform::dynload::cblas_dasum(n, x, 1);
 }

+template <>
+void StrideSum<float>(const float* x, float* res, int n, int stride) {
+  res[0] = platform::dynload::cblas_sasum(n, x, stride);
+}
+
+template <>
+void StrideSum<double>(const double* x, double* res, int n, int stride) {
+  res[0] = platform::dynload::cblas_dasum(n, x, stride);
+}
+
 // TODO(TJ): tuning me carefully on AVX, AVX2 and AVX512
 template <>
 bool VMulKernel<float>::CanBeUsed(const int& d) const {
@@ -144,6 +172,11 @@ bool VScalKernel<float>::CanBeUsed(const int& d) const {
   return platform::MayIUse(platform::avx512f) && d > 512;
 }

+template <>
+bool StrideScalKernel<float>::CanBeUsed(const int& d) const {
+  return platform::MayIUse(platform::avx512f) && d > 512;
+}
+
 template <>
 bool VExpKernel<float>::CanBeUsed(const int& d) const {
   return d > 7;
@@ -235,6 +268,7 @@ bool SoftmaxKernel<float>::CanBeUsed(const int& d) const {
 AWALYS_USE_ME_WITH_DOUBLE(VMul);
 AWALYS_USE_ME_WITH_DOUBLE(VAdd);
 AWALYS_USE_ME_WITH_DOUBLE(VScal);
+AWALYS_USE_ME_WITH_DOUBLE(StrideScal);
 AWALYS_USE_ME_WITH_DOUBLE(VExp);
 AWALYS_USE_ME_WITH_DOUBLE(VSigmoid);
 AWALYS_USE_ME_WITH_DOUBLE(VTanh);
@@ -259,6 +293,7 @@ REGISTER_MKL_KERNEL(MatMul);
 REGISTER_MKL_KERNEL(VMul);
 REGISTER_MKL_KERNEL(VAdd);
 REGISTER_MKL_KERNEL(VScal);
+REGISTER_MKL_KERNEL(StrideScal);
 REGISTER_MKL_KERNEL(VExp);
 REGISTER_MKL_KERNEL(VSquare);
 REGISTER_MKL_KERNEL(VCopy);
......
@@ -129,7 +129,13 @@ template <typename T>
 void ASum(const T* x, T* res, int n);

 template <typename T>
-void Softmax(const T* x, T* y, int n, int bs) {
+void StrideSum(const T* x, T* res, int n, int stride);
+
+template <typename T>
+void StrideScal(const T* a, const T* x, T* y, int n, int stride);
+
+template <typename T>
+void Softmax(const T* x, T* y, int n, int bs, int m = 1) {
   std::vector<T> entities(bs);
   for (int i = 0; i < bs; ++i) {
     entities[i] = x[i * n];
@@ -143,9 +149,17 @@ void Softmax(const T* x, T* y, int n, int bs) {
   VExp(y, y, n * bs);
   for (int i = 0; i < bs; ++i) {
     T sum;
+    if (m == 1) {
       ASum(&y[i * n], &sum, n);
       sum = static_cast<T>(1) / sum;
       VScal(&sum, &y[i * n], &y[i * n], n);
+    } else {
+      for (int j = 0; j < m; ++j) {
+        StrideSum(&y[i * n + j], &sum, n / m, m);
+        sum = static_cast<T>(1) / sum;
+        StrideScal(&sum, &y[i * n + j], &y[i * n + j], n / m, m);
+      }
+    }
   }
 }
@@ -193,6 +207,7 @@ DECLARE_MKL_KERNEL(VAdd);

 // AXYN
 DECLARE_MKL_KERNEL(VScal);
+DECLARE_MKL_KERNEL(StrideScal);

 // XYN
 DECLARE_MKL_KERNEL(VExp);
......
@@ -12,6 +12,7 @@ USE_JITKERNEL_REFER(kVAdd)
 USE_JITKERNEL_REFER(kVAddRelu)
 USE_JITKERNEL_REFER(kVSub)
 USE_JITKERNEL_REFER(kVScal)
+USE_JITKERNEL_REFER(kStrideScal)
 USE_JITKERNEL_REFER(kVAddBias)
 USE_JITKERNEL_REFER(kVCopy)
 USE_JITKERNEL_REFER(kVRelu)
@@ -32,6 +33,7 @@ USE_JITKERNEL_REFER(kMatMul)
 USE_JITKERNEL_REFER(kVSquare)
 USE_JITKERNEL_REFER(kHSum)
 USE_JITKERNEL_REFER(kHMax)
+USE_JITKERNEL_REFER(kStrideSum)
 USE_JITKERNEL_REFER(kSoftmax)
 USE_JITKERNEL_REFER(kEmbSeqPool)
 USE_JITKERNEL_REFER(kSgd)
......
@@ -27,6 +27,7 @@ REGISTER_REFER_KERNEL(VAddRelu);
 REGISTER_REFER_KERNEL(VSub);
 REGISTER_REFER_KERNEL(VScal);
+REGISTER_REFER_KERNEL(StrideScal);
 REGISTER_REFER_KERNEL(VAddBias);
 REGISTER_REFER_KERNEL(VRelu);
@@ -51,6 +52,7 @@ REGISTER_REFER_KERNEL(SeqPool);
 REGISTER_REFER_KERNEL(MatMul);
 REGISTER_REFER_KERNEL(HMax);
 REGISTER_REFER_KERNEL(HSum);
+REGISTER_REFER_KERNEL(StrideSum);
 REGISTER_REFER_KERNEL(Softmax);
 REGISTER_REFER_KERNEL(EmbSeqPool);
 REGISTER_REFER_KERNEL(Sgd);
......
@@ -411,19 +411,42 @@ void HSum(const T* x, T* res, int n) {
   }
 }

+template <typename T>
+void StrideSum(const T* x, T* res, int n, int stride) {
+  res[0] = x[0];
+  for (int i = stride; i < n; i += stride) {
+    res[0] += x[i];
+  }
+}
+
+template <typename T>
+void StrideScal(const T* a, const T* x, T* y, int n, int stride) {
+  for (int i = 0; i < n; i += stride) {
+    y[i] = x[i] * a[0];
+  }
+}
+
 // y = e^(x - max(x))
 // y = y / sum(y)
 template <typename T>
-void Softmax(const T* x, T* y, int n, int bs = 1) {
+void Softmax(const T* x, T* y, int n, int bs = 1, int m = 1) {
   for (int i = 0; i < bs; ++i) {
     T scalar;
     HMax(x, &scalar, n);
     scalar = static_cast<T>(0) - scalar;
     VAddBias(&scalar, x, y, n);  // x - max
     VExp(y, y, n);
+    if (m == 1) {
       HSum(y, &scalar, n);
       scalar = static_cast<T>(1) / scalar;
       VScal(&scalar, y, y, n);
+    } else {
+      for (int j = 0; j < m; j++) {
+        StrideSum(&y[j], &scalar, n, m);
+        scalar = static_cast<T>(1) / scalar;
+        StrideScal(&scalar, &y[j], &y[j], n, m);
+      }
+    }
     x += n;
     y += n;
   }
@@ -507,6 +530,9 @@ DECLARE_REFER_KERNEL(VSub);
 DECLARE_REFER_KERNEL(VScal);
 DECLARE_REFER_KERNEL(VAddBias);

+// const T* a, const T* x, T* y, int n, int stride
+DECLARE_REFER_KERNEL(StrideScal);
+
 // const T* x, T* y, int n
 DECLARE_REFER_KERNEL(VRelu);
 DECLARE_REFER_KERNEL(VIdentity);
@@ -528,6 +554,8 @@ DECLARE_REFER_KERNEL(GRUHtPart2);
 DECLARE_REFER_KERNEL(HMax);
 DECLARE_REFER_KERNEL(HSum);

+DECLARE_REFER_KERNEL(StrideSum);
+
 // others
 DECLARE_REFER_KERNEL(CRFDecoding);
 DECLARE_REFER_KERNEL(LayerNorm);
......
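A quick worked example of the two refer kernels added above (the numbers are illustrative, not from the tests): with y = {1, 2, 3, 4, 5, 6}, n = 6 and stride = 2, StrideSum(&y[0], &res, 6, 2) adds y[0] + y[2] + y[4] = 9, and StrideScal with the same stride scales only those interleaved elements. The j-loop in the new Softmax path therefore normalizes each of the m interleaved groups of a row independently.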
@@ -723,6 +723,10 @@ void TestKernelSoftmax() {
   VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type);
   for (int bs : {1, 2, 10}) {
     for (int n : TestSizes()) {
+      for (int m : {1, 2}) {
+        if (m > n || n % m != 0) {
+          continue;
+        }
       auto ref = jit::GetReferFunc<KernelTuple>();
       EXPECT_TRUE(ref != nullptr);
       std::vector<T> x(bs * n), y(bs * n);
@@ -732,14 +736,14 @@ void TestKernelSoftmax() {
       std::vector<T> xinp(x.size());  // inplace test
       std::copy(x.begin(), x.end(), xinp.begin());
-      ref(x_data, y_data, n, bs);
+      ref(x_data, y_data, n, bs, m);
       T* xinp_data = xinp.data();
-      ref(xinp_data, xinp_data, n, bs);
+      ref(xinp_data, xinp_data, n, bs, m);
       ExpectEQ<T>(xinp_data, y_data, n * bs);

       auto verifier = [](const typename KernelTuple::func_type tgt,
                          const std::vector<T>& x, const std::vector<T>& yref,
-                         int n, int bs) {
+                         int n, int bs, int m) {
         EXPECT_TRUE(tgt != nullptr);
         EXPECT_EQ(yref.size(), x.size());
         EXPECT_EQ(x.size(), static_cast<size_t>(n * bs));
@@ -748,14 +752,15 @@ void TestKernelSoftmax() {
         std::vector<T> ytgt(n * bs);
         T* ytgt_data = ytgt.data();
         // test normal
-        tgt(x_data, ytgt_data, n, bs);
+        tgt(x_data, ytgt_data, n, bs, m);
         ExpectEQ<T>(ytgt_data, yref_data, n * bs);
         // test inplace x
         std::copy(x.begin(), x.end(), ytgt.begin());
-        tgt(ytgt_data, ytgt_data, n, bs);
+        tgt(ytgt_data, ytgt_data, n, bs, m);
         ExpectEQ<T>(ytgt_data, yref_data, n * bs);
       };
-      TestAllImpls<KernelTuple, PlaceType>(n, verifier, x, y, n, bs);
+      TestAllImpls<KernelTuple, PlaceType>(n, verifier, x, y, n, bs, m);
+      }
     }
   }
 }
......
@@ -76,8 +76,8 @@ using enable_if_CPU = typename std::enable_if<
 template <typename DeviceContext>
 class SoftmaxFunctor<DeviceContext, float, true, enable_if_CPU<DeviceContext>> {
-  void operator()(const DeviceContext& context, const framework::Tensor* X,
-                  framework::Tensor* Y) {
+  void operator()(const DeviceContext& context, const int axis_dim,
+                  const framework::Tensor* X, framework::Tensor* Y) {
     auto in_dims = X->dims();
     const float* in_data = X->data<float>();
     float* out_data = Y->data<float>();
@@ -87,7 +87,8 @@ class SoftmaxFunctor<DeviceContext, float, true, enable_if_CPU<DeviceContext>> {
     auto compute_softmax =
         jit::KernelFuncs<jit::SoftmaxTuple<float>, platform::CPUPlace>::Cache()
             .At(in_dims[kClassDim]);
-    compute_softmax(in_data, out_data, in_dims[kClassDim], in_dims[kBatchDim]);
+    compute_softmax(in_data, out_data, in_dims[kClassDim], in_dims[kBatchDim],
+                    in_dims[kClassDim] / axis_dim);
   }
 };
......
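To make the new last argument concrete: assuming the op has already flattened X to a 2-D tensor whose first dimension is everything before the softmax axis and whose second dimension is everything from the axis onward (the SoftmaxKernel hunk further down shows the Resize({n, d}) but not how n and d are computed), the functor above passes m = in_dims[kClassDim] / axis_dim. For an illustrative input of shape [2, 3, 4, 5] with axis = 1, the kernel would be called with bs = 2, n = 60 and m = 60 / 3 = 20, so each group of 3 elements spaced 20 apart inside a row is normalized together.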
@@ -42,9 +42,18 @@ class SoftmaxOp : public framework::OperatorWithKernel {
     auto dim_x = ctx->GetInputDim("X");
     auto rank_x = dim_x.size();
     auto axis = ctx->Attrs().Get<int>("axis");
-    PADDLE_ENFORCE(axis >= -1 && axis < rank_x,
-                   "Attr(axis) value should larger equal then -1"
-                   "and less then the rank of Input(X)");
+    PADDLE_ENFORCE(axis >= -rank_x && axis < rank_x,
+                   "Attr(axis) value should be in range [-R, R-1], "
+                   "R is the rank of Input(X).");
+
+    auto use_cudnn = ctx->Attrs().Get<bool>("use_cudnn");
+    auto use_mkldnn = ctx->Attrs().Get<bool>("use_mkldnn");
+    if (axis != rank_x - 1 && axis != -1) {
+      PADDLE_ENFORCE(!use_cudnn,
+                     "CUDNN kernel only support axis as -1.");
+      PADDLE_ENFORCE(!use_mkldnn,
+                     "MKLDNN kernel only support axis as -1.");
+    }

     ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
     ctx->ShareLoD("X", /*->*/ "Out");
......
@@ -63,8 +63,6 @@ class SoftmaxKernel : public framework::OpKernel<T> {
     Tensor X_2d, Out_2d;
     X_2d.ShareDataWith(*X).Resize({n, d});
     Out_2d.ShareDataWith(*Out).Resize({n, d});
-    // Tensor X_2d = framework::ReshapeToMatrix(*X, axis - 1);
-    // Tensor Out_2d = framework::ReshapeToMatrix(*Out, axis - 1);

 #ifdef PADDLE_ON_INFERENCE
     math::SoftmaxFunctor<DeviceContext, T, true>()(
@@ -96,9 +94,6 @@ class SoftmaxGradKernel : public framework::OpKernel<T> {
     dX_2d.ShareDataWith(*dX).Resize({n, d});
     Out_2d.ShareDataWith(*Out).Resize({n, d});
     dOut_2d.ShareDataWith(*dOut).Resize({n, d});
-    // Tensor Out_2d = framework::ReshapeToMatrix(*Out, axis - 1);
-    // Tensor dOut_2d = framework::ReshapeToMatrix(*dOut, axis - 1);
-    // Tensor dX_2d = framework::ReshapeToMatrix(*dX, axis - 1);

     math::SoftmaxGradFunctor<DeviceContext, T>()(
         context.template device_context<DeviceContext>(), axis_dim, &Out_2d, &dOut_2d,
......
@@ -125,26 +125,6 @@ class TestSoftmaxCUDNNOp2(TestSoftmaxCUDNNOp):
         return [2, 3, 4, 5]


-@unittest.skipIf(not core.is_compiled_with_cuda(),
-                 "core is not compiled with CUDA")
-class TestSoftmaxCUDNNOp3(TestSoftmaxCUDNNOp):
-    def get_x_shape(self):
-        return [2, 3, 4, 5]
-
-    def get_axis(self):
-        return 0
-
-
-@unittest.skipIf(not core.is_compiled_with_cuda(),
-                 "core is not compiled with CUDA")
-class TestSoftmaxCUDNNOp4(TestSoftmaxCUDNNOp):
-    def get_x_shape(self):
-        return [2, 3, 4, 5]
-
-    def get_axis(self):
-        return 1
-
-
 @unittest.skipIf(not core.is_compiled_with_cuda(),
                  "core is not compiled with CUDA")
 class TestSoftmaxCUDNNOp5(TestSoftmaxCUDNNOp):
@@ -152,7 +132,7 @@ class TestSoftmaxCUDNNOp5(TestSoftmaxCUDNNOp):
         return [2, 3, 4, 5]

     def get_axis(self):
-        return 2
+        return 3


 @unittest.skipIf(not core.is_compiled_with_cuda(),
......