未验证 提交 0600b370 编写于 作者: T tensor-tang 提交者: GitHub

[CPU] refine softmax op fwd on CPU (#17522)

* refine softmax fwd

test=develop

* fix compile issue wih gpu

test=develop

* add value clip to avoid exp
上级 3ee3611a
......@@ -54,7 +54,14 @@ inline void vec_scal(const int n, const T a, T* x) {
#ifdef PADDLE_WITH_MKLML
template <>
inline void vec_exp<float>(const int n, const float* x, float* y) {
platform::dynload::vsExp(n, x, y);
constexpr int small_enough = 128;
if (n < small_enough) {
for (int i = 0; i < n; ++i) {
y[i] = std::exp(x[i]);
}
} else {
platform::dynload::vsExp(n, x, y);
}
}
template <>
......@@ -128,6 +135,47 @@ inline void vec_scal<float, platform::avx512f>(const int n, const float a,
vec_scal<float, platform::avx2>(n, a, x, y);
}
template <typename T, platform::cpu_isa_t isa = platform::isa_any>
inline void vec_sum(const size_t n, const T* x, T* s) {
s[0] = x[0];
for (size_t i = 1; i < n; ++i) {
s[0] += x[i];
}
}
template <>
inline void vec_sum<float, platform::avx>(const size_t n, const float* x,
float* s) {
#ifdef __AVX__
constexpr unsigned int block = YMM_FLOAT_BLOCK;
if (n < block) {
vec_sum<float, platform::isa_any>(n, x, s);
return;
}
unsigned int i, end;
i = end = 0;
s[0] = 0.f;
end = n & ~(block - 1);
__m256 tmp = _mm256_setzero_ps();
for (i = 0; i < end; i += block) {
tmp = _mm256_add_ps(tmp, _mm256_load_ps(x + i));
}
__m256 hsum = _mm256_hadd_ps(tmp, tmp);
hsum = _mm256_add_ps(hsum, _mm256_permute2f128_ps(hsum, hsum, 0x1));
_mm_store_ss(s, _mm_hadd_ps(_mm256_castps256_ps128(hsum),
_mm256_castps256_ps128(hsum)));
for (; i < n; i++) {
s[0] += x[i];
}
#else
vec_sum<float, platform::isa_any>(n, x, s);
#endif
}
template <typename T, platform::cpu_isa_t isa = platform::isa_any>
inline void vec_bias_sub(const int n, const T a, const T* x, T* y) {
for (int i = 0; i < n; ++i) {
......@@ -242,6 +290,39 @@ inline void vec_cross<float, platform::avx512f>(const int n, const float* x,
vec_cross<float, platform::avx>(n, x, y, z, out);
}
template <typename T, platform::cpu_isa_t isa = platform::isa_any>
inline void vec_clip(const size_t n, const T a, const T* x, T* y) {
for (size_t i = 0; i < n; ++i) {
y[i] = x[i] < a ? a : x[i];
}
}
template <>
inline void vec_clip<float, platform::avx>(const size_t n, const float a,
const float* x, float* y) {
#ifdef __AVX__
constexpr unsigned int block = YMM_FLOAT_BLOCK;
if (n < block) {
vec_clip<float, platform::isa_any>(n, a, x, y);
return;
}
unsigned int i = 0, end = 0;
end = n & ~(block - 1);
__m256 threshold = _mm256_set1_ps(a);
for (i = 0; i < end; i += block) {
_mm256_storeu_ps(y + i, _mm256_max_ps(_mm256_loadu_ps(x + i), threshold));
}
for (; i < n; i++) {
y[i] = x[i] < a ? a : x[i];
}
#else
vec_clip<float, platform::isa_any>(n, a, x, y);
#endif
}
template <typename T, platform::cpu_isa_t isa = platform::isa_any>
inline void vec_add_bias(const int n, const T a, const T* x, T* y) {
for (int i = 0; i < n; ++i) {
......
......@@ -65,12 +65,11 @@ void ref_relu(const int n, const T* x, T* y) {
}
template <typename T>
void RandomVec(const int n, T* a) {
void RandomVec(const int n, T* a, const T lower = static_cast<T>(-20.f),
const T upper = static_cast<T>(20.f)) {
static unsigned int seed = 100;
std::mt19937 rng(seed++);
std::uniform_real_distribution<double> uniform_dist(0, 1);
const T lower = static_cast<T>(-20.f);
const T upper = static_cast<T>(20.f);
for (int i = 0; i < n; ++i) {
a[i] = static_cast<T>(uniform_dist(rng) * (upper - lower) + lower);
}
......@@ -144,6 +143,62 @@ TEST(CpuVecTest, relu) {
TestAndBench<double>(30, vec_relu<double>, ref_relu<double>);
}
template <typename T>
void compare_sum(size_t n, std::function<void(const size_t, const T*, T*)> tgt,
std::function<void(const size_t, const T*, T*)> ref) {
std::vector<T> x(n);
T ytgt_data, yref_data;
RandomVec<T>(n, x.data(), static_cast<T>(-2), static_cast<T>(2));
const T* x_data = x.data();
tgt(n, x_data, &ytgt_data);
ref(n, x_data, &yref_data);
EXPECT_NEAR(ytgt_data, yref_data, 1e-3);
}
TEST(CpuVecTest, vec_sum) {
namespace platform = paddle::platform;
using namespace paddle::operators::math; // NOLINT
for (size_t sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) {
compare_sum<float>(sz, vec_sum<float>, vec_sum<float, platform::isa_any>);
compare_sum<float>(sz, vec_sum<float, platform::avx>,
vec_sum<float, platform::isa_any>);
}
compare_sum<double>(30U, vec_sum<double>, vec_sum<double, platform::isa_any>);
}
template <typename T>
void compare_clip(
size_t n, T threshold,
std::function<void(const size_t, const T, const T*, T*)> tgt,
std::function<void(const size_t, const T, const T*, T*)> ref) {
std::vector<T> x(n);
std::vector<T> ytgt(n), yref(n);
RandomVec<T>(n, x.data(), static_cast<T>(-2), static_cast<T>(2));
const T* x_data = x.data();
T* yref_data = yref.data();
T* ytgt_data = ytgt.data();
tgt(n, threshold, x_data, ytgt_data);
ref(n, threshold, x_data, yref_data);
for (int i = 0; i < n; ++i) {
EXPECT_NEAR(ytgt_data[i], yref_data[i], 1e-3);
}
}
TEST(CpuVecTest, vec_clip) {
namespace platform = paddle::platform;
using namespace paddle::operators::math; // NOLINT
for (size_t sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) {
compare_clip<float>(sz, -4.f, vec_clip<float>,
vec_clip<float, platform::isa_any>);
compare_clip<float>(sz, -1.1f, vec_clip<float, platform::avx>,
vec_clip<float, platform::isa_any>);
}
compare_clip<double>(30U, 1.0, vec_clip<double>,
vec_clip<double, platform::isa_any>);
}
template <typename T>
void TestInplace(const int n, std::function<void(const int, const T*, T*)> tgt,
std::function<void(const int, const T*, T*)> ref) {
......
......@@ -17,6 +17,8 @@ limitations under the License. */
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle {
namespace operators {
......@@ -34,16 +36,15 @@ struct ValueClip {
}
};
template <typename DeviceContext, typename T, bool is_test, typename Enable>
void SoftmaxFunctor<DeviceContext, T, is_test, Enable>::operator()(
const DeviceContext& context, const int axis_dim,
const framework::Tensor* X, framework::Tensor* Y) {
template <typename DeviceContext, typename T, bool is_test>
void SoftmaxEigen(const DeviceContext& context, const int axis_dim,
const framework::Tensor* X, framework::Tensor* Y) {
constexpr int kBatchDim = 0;
constexpr int kClassDim = 1;
auto logits = EigenMatrix<T>::From(*X);
auto softmax = EigenMatrix<T>::From(*Y);
const int kBatchDim = 0;
const int kClassDim = 1;
const int batch_size = logits.dimension(kBatchDim);
const int num_classes = logits.dimension(kClassDim);
const int num_remain = num_classes / axis_dim;
......@@ -70,12 +71,58 @@ void SoftmaxFunctor<DeviceContext, T, is_test, Enable>::operator()(
.broadcast(one_axis));
}
template <typename DeviceContext, typename T, bool is_test, typename Enable>
void SoftmaxFunctor<DeviceContext, T, is_test, Enable>::operator()(
const DeviceContext& context, const int axis_dim,
const framework::Tensor* X, framework::Tensor* Y) {
SoftmaxEigen<DeviceContext, T, is_test>(context, axis_dim, X, Y);
}
template <class DeviceContext>
using enable_if_CPU = typename std::enable_if<
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type;
template <typename DeviceContext, typename T, bool is_test>
class SoftmaxFunctor<DeviceContext, T, is_test, enable_if_CPU<DeviceContext>> {
public:
void operator()(const DeviceContext& context, const int axis_dim,
const framework::Tensor* X, framework::Tensor* Y) {
auto in_dims = X->dims();
constexpr int kBatchDim = 0;
constexpr int kClassDim = 1;
const int num_classes = in_dims[kClassDim];
const int batch_size = in_dims[kBatchDim];
const int num_remain = num_classes / axis_dim;
if (num_remain == 1 && platform::MayIUse(platform::avx)) {
const T* in_data = X->data<T>();
T* out_data = Y->data<T>();
for (int bs = 0; bs < batch_size; ++bs) {
T max_val = *std::max_element(in_data, in_data + num_classes);
max_val *= static_cast<T>(-1);
vec_add_bias<T, platform::avx>(num_classes, max_val, in_data, out_data);
vec_clip<T, platform::avx>(num_classes, static_cast<T>(-64), out_data,
out_data);
vec_exp<T>(num_classes, out_data, out_data);
T sum = 0;
vec_sum<T, platform::avx>(num_classes, out_data, &sum);
sum = static_cast<T>(1) / sum;
vec_scal<T, platform::avx>(num_classes, sum, out_data, out_data);
in_data += num_classes;
out_data += num_classes;
}
} else {
SoftmaxEigen<DeviceContext, T, is_test>(context, axis_dim, X, Y);
}
}
};
template <typename DeviceContext>
class SoftmaxFunctor<DeviceContext, float, true, enable_if_CPU<DeviceContext>> {
public:
void operator()(const DeviceContext& context, const int axis_dim,
const framework::Tensor* X, framework::Tensor* Y) {
auto in_dims = X->dims();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册