diff --git a/paddle/fluid/operators/jit/more/mix/mix.cc b/paddle/fluid/operators/jit/more/mix/mix.cc index 463e45f6ce1bc8f9085e8efe05ae53486242b181..4f309501b6ddd6d9ac5e2a25e223241e318d41ca 100644 --- a/paddle/fluid/operators/jit/more/mix/mix.cc +++ b/paddle/fluid/operators/jit/more/mix/mix.cc @@ -50,7 +50,7 @@ void VTanh(const T* x, T* y, int n) { compute_addbias(&b, y, y, n); } -void Softmax(const T* x, T* y, int n, int bs, int m) { +void Softmax(const T* x, T* y, int n, int bs, int remain) { auto compute_hmax = KernelFuncs, CPUPlace>::Cache().At(n); auto compute_hsum = KernelFuncs, CPUPlace>::Cache().At(n); auto compute_vscal = KernelFuncs, CPUPlace>::Cache().At(n); @@ -66,15 +66,15 @@ void Softmax(const T* x, T* y, int n, int bs, int m) { scalar = static_cast(0) - scalar; compute_vaddbias(&scalar, x, y, n); // x - max compute_vexp(y, y, n); - if (m == 1) { + if (remain == 1) { compute_hsum(y, &scalar, n); scalar = static_cast(1) / scalar; compute_vscal(&scalar, y, y, n); } else { - for (int j = 0; j < m; ++j) { - compute_stridesum(&y[j], &scalar, n, m); + for (int j = 0; j < remain; ++j) { + compute_stridesum(&y[j], &scalar, n, remain); scalar = static_cast(1) / scalar; - compute_stridescal(&scalar, &y[j], &y[j], n, m); + compute_stridescal(&scalar, &y[j], &y[j], n, remain); } } x += n; diff --git a/paddle/fluid/operators/jit/more/mix/mix.h b/paddle/fluid/operators/jit/more/mix/mix.h index a0079506f8d71216f8e89f0d1ec386f2161cdb93..035425317edca95bc574807fa029ff373a7e10b8 100644 --- a/paddle/fluid/operators/jit/more/mix/mix.h +++ b/paddle/fluid/operators/jit/more/mix/mix.h @@ -26,7 +26,7 @@ using T = float; void VSigmoid(const T* x, T* y, int n); void VTanh(const T* x, T* y, int n); -void Softmax(const T* x, T* y, int n, int bs, int m); +void Softmax(const T* x, T* y, int n, int bs, int remain); void LSTMCtHt(lstm_t* step, const lstm_attr_t* attr); void LSTMC1H1(lstm_t* step, const lstm_attr_t* attr); diff --git a/paddle/fluid/operators/jit/more/mkl/mkl.cc b/paddle/fluid/operators/jit/more/mkl/mkl.cc index 9e21e2b8d39c4827e78c63f9dcf6a08f9be67cf9..fc8800ec7241167d32aff51d94d61cdd888944f6 100644 --- a/paddle/fluid/operators/jit/more/mkl/mkl.cc +++ b/paddle/fluid/operators/jit/more/mkl/mkl.cc @@ -81,7 +81,7 @@ void VScal(const double* a, const double* x, double* y, int n) { template <> void StrideScal(const float* a, const float* x, float* y, int n, int stride) { if (x == y) { - platform::dynload::cblas_sscal(n, *a, y, stride); + platform::dynload::cblas_sscal(n/stride, *a, y, stride); } else { refer::StrideScal(a, x, y, n, stride); } @@ -90,7 +90,7 @@ void StrideScal(const float* a, const float* x, float* y, int n, int stri template <> void StrideScal(const double* a, const double* x, double* y, int n, int stride) { if (x == y) { - platform::dynload::cblas_dscal(n, *a, y, stride); + platform::dynload::cblas_dscal(n/stride, *a, y, stride); } else { refer::StrideScal(a, x, y, n, stride); } @@ -148,12 +148,12 @@ void ASum(const double* x, double* res, int n) { template <> void StrideASum(const float* x, float* res, int n, int stride) { - res[0] = platform::dynload::cblas_sasum(n, x, stride); + res[0] = platform::dynload::cblas_sasum(n/stride, x, stride); } template <> void StrideASum(const double* x, double* res, int n, int stride) { - res[0] = platform::dynload::cblas_dasum(n, x, stride); + res[0] = platform::dynload::cblas_dasum(n/stride, x, stride); } // TODO(TJ): tuning me carefully on AVX, AVX2 and AVX512 diff --git a/paddle/fluid/operators/jit/more/mkl/mkl.h b/paddle/fluid/operators/jit/more/mkl/mkl.h index 2f135f9e7a4a1fae903ff239d6328fd5ee717b12..1fbb87b0cf9ef1030d21a4c98da1e45503a397ec 100644 --- a/paddle/fluid/operators/jit/more/mkl/mkl.h +++ b/paddle/fluid/operators/jit/more/mkl/mkl.h @@ -135,7 +135,7 @@ template void StrideScal(const T* a, const T* x, T* y, int n, int stride); template -void Softmax(const T* x, T* y, int n, int bs, int m=1) { +void Softmax(const T* x, T* y, int n, int bs, int remain=1) { std::vector entities(bs); for (int i = 0; i < bs; ++i) { entities[i] = x[i * n]; @@ -149,15 +149,15 @@ void Softmax(const T* x, T* y, int n, int bs, int m=1) { VExp(y, y, n * bs); for (int i = 0; i < bs; ++i) { T sum; - if (m == 1) { + if (remain == 1) { ASum(&y[i * n], &sum, n); sum = static_cast(1) / sum; VScal(&sum, &y[i * n], &y[i * n], n); } else { - for (int j = 0; j < m; ++j) { - StrideASum(&y[i * n + j], &sum, n/m, m); + for (int j = 0; j < remain; ++j) { + StrideASum(&y[i * n + j], &sum, n, remain); sum = static_cast(1) / sum; - StrideScal(&sum, &y[i * n + j], &y[i * n + j], n/m, m); + StrideScal(&sum, &y[i * n + j], &y[i * n + j], n, remain); } } } diff --git a/paddle/fluid/operators/jit/refer/refer.h b/paddle/fluid/operators/jit/refer/refer.h index e3387f60a6ec798130a894751cfbec0526243b64..c62925232ba0b42aa02f37e0f6d6dbac2bebd13c 100644 --- a/paddle/fluid/operators/jit/refer/refer.h +++ b/paddle/fluid/operators/jit/refer/refer.h @@ -421,30 +421,34 @@ void StrideASum(const T* x, T* res, int n, int stride) { template void StrideScal(const T* a, const T* x, T* y, int n , int stride) { - for (int i = 0; i < n; i+=stride) { - y[i] = x[i] * a[0]; + for (int i = 0; i < n; ++i) { + if (i % stride == 0) { + y[i] = x[i] * a[0]; + } else { + y[i] = x[i]; + } } } // y = e^(x - max(x)) // y = y / sum(y) template -void Softmax(const T* x, T* y, int n, int bs = 1, int m = 1) { +void Softmax(const T* x, T* y, int n, int bs = 1, int remain = 1) { for (int i = 0; i < bs; ++i) { T scalar; HMax(x, &scalar, n); scalar = static_cast(0) - scalar; VAddBias(&scalar, x, y, n); // x - max VExp(y, y, n); - if (m == 1) { + if (remain == 1) { HSum(y, &scalar, n); scalar = static_cast(1) / scalar; VScal(&scalar, y, y, n); } else { - for (int j = 0; j < m; j++) { - StrideASum(&y[j], &scalar, n, m); + for (int j = 0; j < remain; j++) { + StrideASum(&y[j], &scalar, n, remain); scalar = static_cast(1) / scalar; - StrideScal(&scalar, &y[j], &y[j], n, m); + StrideScal(&scalar, &y[j], &y[j], n, remain); } } x += n; diff --git a/paddle/fluid/operators/jit/test.cc b/paddle/fluid/operators/jit/test.cc index c47ec01d3e01d2a348a226798d0985b3c193a99a..1397e5be181e3dfa7ea72b361f1b28fc6f5f83e6 100644 --- a/paddle/fluid/operators/jit/test.cc +++ b/paddle/fluid/operators/jit/test.cc @@ -723,11 +723,10 @@ void TestKernelSoftmax() { VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type); for (int bs : {1, 2, 10}) { for (int n : TestSizes()) { - for (int m : {1, 2}) { + for (int m : {1, 2, 3}) { // remain if (m > n || n % m != 0) { continue; } - VLOG(10) << "Softmax: " << bs << ", " << n << ", " << m; auto ref = jit::GetReferFunc(); EXPECT_TRUE(ref != nullptr); std::vector x(bs * n), y(bs * n); @@ -766,6 +765,86 @@ void TestKernelSoftmax() { } } +template +void TestKernelStrideASum() { + using T = typename KernelTuple::data_type; + VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type); + for (int d : TestSizes()) { + for (int m : {1, 2, 3}) { // stride + if (m > d || d % m != 0) { + continue; + } + auto ref = jit::GetReferFunc(); + EXPECT_TRUE(ref != nullptr); + std::vector x(d); + RandomVec(d, x.data()); + T ref_res; + ref(x.data(), &ref_res, d, m); + + auto verifier = [](const typename KernelTuple::func_type tgt, + const std::vector& x, const T ref_res, + const int m) { + EXPECT_TRUE(tgt != nullptr); + T tgt_res; + tgt(x.data(), &tgt_res, x.size(), m); + ExpectEQ(&tgt_res, &ref_res, 1); + }; + TestAllImpls(d, verifier, x, ref_res, m); + } + } +} + +template +void TestKernelStrideScal() { + using T = typename KernelTuple::data_type; + VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type); + // for (int d : TestSizes()) { + // for (int m : {1, 2, 3}) { // stride + for (int d : {4}) { + for (int m : {2}) { // stride + if (m > d || d % m != 0) { + continue; + } + auto ref = jit::GetReferFunc(); + EXPECT_TRUE(ref != nullptr); + + const T a = static_cast(3); + std::vector x(d), yref(d); + std::vector xinp(d); // inplace test + RandomVec(d, x.data()); + std::copy(x.begin(), x.end(), xinp.begin()); + + const T* x_data = x.data(); + T* yref_data = yref.data(); + T* xinp_data = xinp.data(); + // test refer code inplace + ref(&a, x_data, yref_data, d, m); + ref(&a, xinp_data, xinp_data, d, m); + ExpectEQ(xinp_data, yref_data, d); + + auto verifier = [](const typename KernelTuple::func_type tgt, const T a, + const std::vector& x, const std::vector& yref, + const int m) { + EXPECT_TRUE(tgt != nullptr); + EXPECT_EQ(yref.size(), x.size()); + const T* x_data = x.data(); + const T* yref_data = yref.data(); + const int d = yref.size(); + std::vector ytgt(d); + T* ytgt_data = ytgt.data(); + // test normal + tgt(&a, x_data, ytgt_data, d, m); + ExpectEQ(ytgt_data, yref_data, d); + // test inplace x + std::copy(x.begin(), x.end(), ytgt.begin()); + tgt(&a, ytgt_data, ytgt_data, d, m); + ExpectEQ(ytgt_data, yref_data, d); + }; + TestAllImpls(d, verifier, a, x, yref, m); + } + } +} + template void TestKernelSgd() { using T = typename KernelTuple::data_type; @@ -918,7 +997,7 @@ TEST(JITKernel_pool, more) { EXPECT_EQ(kers.size(), 10UL); #else #ifdef PADDLE_WITH_MKLML - EXPECT_EQ(kers.size(), 21UL); + EXPECT_EQ(kers.size(), 22UL); #else EXPECT_EQ(kers.size(), 8UL); #endif @@ -927,7 +1006,7 @@ TEST(JITKernel_pool, more) { TEST(JITKernel_pool, refer) { const auto& kers = jit::ReferKernelPool::Instance().AllKernels(); - EXPECT_EQ(kers.size(), 29UL); + EXPECT_EQ(kers.size(), 31UL); } // test helper @@ -1298,3 +1377,6 @@ TEST_CPU_KERNEL(MatMul); TEST_CPU_KERNEL(Softmax); TEST_CPU_KERNEL(Sgd); TEST_CPU_KERNEL(VBroadcast); + +TEST_CPU_KERNEL(StrideASum); +TEST_CPU_KERNEL(StrideScal);