Commit 51536f7f authored by dengkaipeng

StrideASum. test=develop

Parent 93701dba
@@ -56,7 +56,7 @@ const char* to_string(KernelType kt) {
   ONE_CASE(kMatMul);
   ONE_CASE(kHMax);
   ONE_CASE(kHSum);
-  ONE_CASE(kStrideSum);
+  ONE_CASE(kStrideASum);
   ONE_CASE(kSoftmax);
   ONE_CASE(kEmbSeqPool);
   ONE_CASE(kSgd);
......
@@ -53,7 +53,7 @@ typedef enum {
   kVSquare,
   kVSub,
   kVTanh,
-  kStrideSum,
+  kStrideASum,
   kStrideScal,
 } KernelType;
@@ -132,7 +132,7 @@ DECLARE_KERNELTUPLE(XYNTuple, VCopy);
 DECLARE_KERNELTUPLE(XRNTuple, HMax);
 DECLARE_KERNELTUPLE(XRNTuple, HSum);
-DECLARE_KERNELTUPLE(XRNSTuple, StrideSum);
+DECLARE_KERNELTUPLE(XRNSTuple, StrideASum);

 typedef struct {
   void* gates;  // gates: x_ch, x_ih, x_fh, x_oh
......
@@ -54,7 +54,7 @@ void Softmax(const T* x, T* y, int n, int bs, int m) {
   auto compute_hmax = KernelFuncs<HMaxTuple<T>, CPUPlace>::Cache().At(n);
   auto compute_hsum = KernelFuncs<HSumTuple<T>, CPUPlace>::Cache().At(n);
   auto compute_vscal = KernelFuncs<VScalTuple<T>, CPUPlace>::Cache().At(n);
-  auto compute_stridesum = KernelFuncs<StrideSumTuple<T>, CPUPlace>::Cache().At(n);
+  auto compute_stridesum = KernelFuncs<StrideASumTuple<T>, CPUPlace>::Cache().At(n);
   auto compute_stridescal = KernelFuncs<StrideScalTuple<T>, CPUPlace>::Cache().At(n);
   auto compute_vaddbias =
       KernelFuncs<VAddBiasTuple<T>, CPUPlace>::Cache().At(n);
......
@@ -147,12 +147,12 @@ void ASum<double>(const double* x, double* res, int n) {
 }

 template <>
-void StrideSum<float>(const float* x, float* res, int n, int stride) {
+void StrideASum<float>(const float* x, float* res, int n, int stride) {
   res[0] = platform::dynload::cblas_sasum(n, x, stride);
 }

 template <>
-void StrideSum<double>(const double* x, double* res, int n, int stride) {
+void StrideASum<double>(const double* x, double* res, int n, int stride) {
   res[0] = platform::dynload::cblas_dasum(n, x, stride);
 }
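Note: BLAS `?asum` returns the sum of absolute values of N elements spaced `incX` apart, which is exactly the semantic the rename to `StrideASum` encodes. Also visible in this diff: the MKL wrapper treats its `n` as the number of elements visited (the `Softmax` in mkl.h passes `n/m`), while the reference kernel further down treats `n` as the buffer length and steps by `stride` (its `Softmax` passes `n`). A minimal standalone CBLAS sketch of the semantics (plain `cblas.h`, outside the `platform::dynload` wrappers):

```cpp
#include <cblas.h>
#include <cstdio>

int main() {
  // Six floats; reading every second one visits 1, -2, 3.
  float x[] = {1.f, 9.f, -2.f, 9.f, 3.f, 9.f};
  // cblas_sasum(count, x, incx) sums |x[k * incx]| for k in [0, count).
  float res = cblas_sasum(3, x, 2);
  std::printf("%f\n", res);  // prints 6.0 = |1| + |-2| + |3|
  return 0;
}
```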
@@ -174,7 +174,7 @@ bool VScalKernel<float>::CanBeUsed(const int& d) const {
 template <>
 bool StrideScalKernel<float>::CanBeUsed(const int& d) const {
-  return platform::MayIUse(platform::avx512f) && d > 512;
+  return true;
 }

 template <>
......
@@ -129,7 +129,7 @@ template <typename T>
 void ASum(const T* x, T* res, int n);

 template <typename T>
-void StrideSum(const T* x, T* res, int n, int stride);
+void StrideASum(const T* x, T* res, int n, int stride);

 template <typename T>
 void StrideScal(const T* a, const T* x, T* y, int n, int stride);
@@ -155,7 +155,7 @@ void Softmax(const T* x, T* y, int n, int bs, int m=1) {
       VScal(&sum, &y[i * n], &y[i * n], n);
     } else {
       for (int j = 0; j < m; ++j) {
-        StrideSum(&y[i * n + j], &sum, n/m, m);
+        StrideASum(&y[i * n + j], &sum, n/m, m);
         sum = static_cast<T>(1) / sum;
         StrideScal(&sum, &y[i * n + j], &y[i * n + j], n/m, m);
       }
......
@@ -33,7 +33,7 @@ USE_JITKERNEL_REFER(kMatMul)
 USE_JITKERNEL_REFER(kVSquare)
 USE_JITKERNEL_REFER(kHSum)
 USE_JITKERNEL_REFER(kHMax)
-USE_JITKERNEL_REFER(kStrideSum)
+USE_JITKERNEL_REFER(kStrideASum)
 USE_JITKERNEL_REFER(kSoftmax)
 USE_JITKERNEL_REFER(kEmbSeqPool)
 USE_JITKERNEL_REFER(kSgd)
......
@@ -52,7 +52,7 @@ REGISTER_REFER_KERNEL(SeqPool);
 REGISTER_REFER_KERNEL(MatMul);
 REGISTER_REFER_KERNEL(HMax);
 REGISTER_REFER_KERNEL(HSum);
-REGISTER_REFER_KERNEL(StrideSum);
+REGISTER_REFER_KERNEL(StrideASum);
 REGISTER_REFER_KERNEL(Softmax);
 REGISTER_REFER_KERNEL(EmbSeqPool);
 REGISTER_REFER_KERNEL(Sgd);
......
@@ -412,10 +412,10 @@ void HSum(const T* x, T* res, int n) {
 }

 template <typename T>
-void StrideSum(const T* x, T* res, int n, int stride) {
+void StrideASum(const T* x, T* res, int n, int stride) {
   res[0] = x[0];
   for (int i = stride; i < n; i+=stride) {
-    res[0] += x[i];
+    res[0] += std::abs(x[i]);
   }
 }
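A quick standalone check of the reference loop above (hypothetical test harness, not part of the commit). Note that the seed `res[0] = x[0]` is not wrapped in `std::abs`, so the result matches BLAS `asum` only when the first visited element is non-negative:

```cpp
#include <cmath>
#include <cstdio>

// Reference kernel as it stands after this hunk.
template <typename T>
void StrideASum(const T* x, T* res, int n, int stride) {
  res[0] = x[0];
  for (int i = stride; i < n; i += stride) {
    res[0] += std::abs(x[i]);
  }
}

int main() {
  float x[] = {1.f, 9.f, -2.f, 9.f, 3.f, 9.f};
  float res;
  StrideASum(x, &res, 6, 2);  // visits indices 0, 2, 4
  std::printf("%f\n", res);   // 1 + |-2| + |3| = 6
  return 0;
}
```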
@@ -442,7 +442,7 @@ void Softmax(const T* x, T* y, int n, int bs = 1, int m = 1) {
     VScal(&scalar, y, y, n);
   } else {
     for (int j = 0; j < m; j++) {
-      StrideSum(&y[j], &scalar, n, m);
+      StrideASum(&y[j], &scalar, n, m);
       scalar = static_cast<T>(1) / scalar;
       StrideScal(&scalar, &y[j], &y[j], n, m);
     }
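For context on the `else` branch above: when the softmax axis is interleaved with stride `m` (the test below enforces `n % m == 0`), normalization runs `m` independent strided passes over the length-`n` row. A minimal sketch of that pattern with stand-in kernels (not the Paddle implementations):

```cpp
#include <cmath>

// Stand-ins mirroring the refer conventions: n is the buffer length,
// elements are visited at the given stride.
template <typename T>
void StrideASum(const T* x, T* res, int n, int stride) {
  res[0] = static_cast<T>(0);
  for (int i = 0; i < n; i += stride) res[0] += std::abs(x[i]);
}

template <typename T>
void StrideScal(const T* a, const T* x, T* y, int n, int stride) {
  for (int i = 0; i < n; i += stride) y[i] = x[i] * a[0];
}

// One row of the else-branch: group j holds y[j], y[j+m], y[j+2m], ...
// Passing n from offset j stays in bounds because j < m and n % m == 0.
template <typename T>
void NormalizeGroups(T* y, int n, int m) {
  for (int j = 0; j < m; ++j) {
    T scalar;
    StrideASum(&y[j], &scalar, n, m);         // strided reduction
    scalar = static_cast<T>(1) / scalar;
    StrideScal(&scalar, &y[j], &y[j], n, m);  // strided rescale
  }
}
```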
@@ -554,7 +554,7 @@ DECLARE_REFER_KERNEL(GRUHtPart2);
 DECLARE_REFER_KERNEL(HMax);
 DECLARE_REFER_KERNEL(HSum);
-DECLARE_REFER_KERNEL(StrideSum);
+DECLARE_REFER_KERNEL(StrideASum);

 // others
 DECLARE_REFER_KERNEL(CRFDecoding);
......
@@ -727,6 +727,7 @@ void TestKernelSoftmax() {
       if (m > n || n % m != 0) {
         continue;
       }
+      VLOG(10) << "Softmax: " << bs << ", " << n << ", " << m;
       auto ref = jit::GetReferFunc<KernelTuple>();
       EXPECT_TRUE(ref != nullptr);
       std::vector<T> x(bs * n), y(bs * n);
......