Commit 51536f7f authored by dengkaipeng

StrideASum: rename StrideSum to StrideASum (absolute-value strided sum). test=develop

Parent commit: 93701dba
......@@ -56,7 +56,7 @@ const char* to_string(KernelType kt) {
ONE_CASE(kMatMul);
ONE_CASE(kHMax);
ONE_CASE(kHSum);
ONE_CASE(kStrideSum);
ONE_CASE(kStrideASum);
ONE_CASE(kSoftmax);
ONE_CASE(kEmbSeqPool);
ONE_CASE(kSgd);
......
......@@ -53,7 +53,7 @@ typedef enum {
kVSquare,
kVSub,
kVTanh,
kStrideSum,
kStrideASum,
kStrideScal,
} KernelType;
......@@ -132,7 +132,7 @@ DECLARE_KERNELTUPLE(XYNTuple, VCopy);
DECLARE_KERNELTUPLE(XRNTuple, HMax);
DECLARE_KERNELTUPLE(XRNTuple, HSum);
DECLARE_KERNELTUPLE(XRNSTuple, StrideSum);
DECLARE_KERNELTUPLE(XRNSTuple, StrideASum);
typedef struct {
void* gates; // gates: x_ch, x_ih, x_fh, x_oh
......
......@@ -54,7 +54,7 @@ void Softmax(const T* x, T* y, int n, int bs, int m) {
auto compute_hmax = KernelFuncs<HMaxTuple<T>, CPUPlace>::Cache().At(n);
auto compute_hsum = KernelFuncs<HSumTuple<T>, CPUPlace>::Cache().At(n);
auto compute_vscal = KernelFuncs<VScalTuple<T>, CPUPlace>::Cache().At(n);
auto compute_stridesum = KernelFuncs<StrideSumTuple<T>, CPUPlace>::Cache().At(n);
auto compute_stridesum = KernelFuncs<StrideASumTuple<T>, CPUPlace>::Cache().At(n);
auto compute_stridescal = KernelFuncs<StrideScalTuple<T>, CPUPlace>::Cache().At(n);
auto compute_vaddbias =
KernelFuncs<VAddBiasTuple<T>, CPUPlace>::Cache().At(n);
......
......@@ -147,12 +147,12 @@ void ASum<double>(const double* x, double* res, int n) {
}
// StrideASum<float>: writes to res[0] the sum of absolute values of the
// strided elements of x, delegated to BLAS cblas_sasum (n = element count,
// stride = increment between consecutive elements).
// NOTE(review): the extracted diff showed both the old (StrideSum) and new
// (StrideASum) signatures; only the renamed one is kept here.
template <>
void StrideASum<float>(const float* x, float* res, int n, int stride) {
  res[0] = platform::dynload::cblas_sasum(n, x, stride);
}
// StrideASum<double>: double-precision counterpart of StrideASum<float>;
// absolute sum of n strided elements via BLAS cblas_dasum.
// NOTE(review): the extracted diff showed both the old (StrideSum) and new
// (StrideASum) signatures; only the renamed one is kept here.
template <>
void StrideASum<double>(const double* x, double* res, int n, int stride) {
  res[0] = platform::dynload::cblas_dasum(n, x, stride);
}
......@@ -174,7 +174,7 @@ bool VScalKernel<float>::CanBeUsed(const int& d) const {
// StrideScalKernel<float>::CanBeUsed: the extracted diff contained two
// consecutive return statements (the removed old line plus the added new
// one), which leaves unreachable code. Per the diff's removed-line-first
// ordering, the new body is the unconditional `return true;` — the former
// `MayIUse(avx512f) && d > 512` gate was dropped in this change.
// TODO(review): confirm against the upstream commit that the gate removal
// is intentional.
template <>
bool StrideScalKernel<float>::CanBeUsed(const int& d) const {
  return true;
}
template <>
......
......@@ -129,7 +129,7 @@ template <typename T>
void ASum(const T* x, T* res, int n);
// StrideASum: absolute sum over the elements x[0], x[stride], x[2*stride], ...
// (see the refer implementation for the exact index-bound semantics of n);
// the result is written to res[0]. Renamed from StrideSum — the duplicated
// old-name declaration left by the diff extraction is removed here.
template <typename T>
void StrideASum(const T* x, T* res, int n, int stride);
// Strided scale — presumably y[i] = a[0] * x[i] over the strided elements,
// mirroring how the Softmax caller uses it (scales by the reciprocal of the
// strided sum); definition not visible in this hunk — TODO confirm.
template <typename T>
void StrideScal(const T* a, const T* x, T* y, int n, int stride);
......@@ -155,7 +155,7 @@ void Softmax(const T* x, T* y, int n, int bs, int m=1) {
VScal(&sum, &y[i * n], &y[i * n], n);
} else {
for (int j = 0; j < m; ++j) {
StrideSum(&y[i * n + j], &sum, n/m, m);
StrideASum(&y[i * n + j], &sum, n/m, m);
sum = static_cast<T>(1) / sum;
StrideScal(&sum, &y[i * n + j], &y[i * n + j], n/m, m);
}
......
......@@ -33,7 +33,7 @@ USE_JITKERNEL_REFER(kMatMul)
USE_JITKERNEL_REFER(kVSquare)
USE_JITKERNEL_REFER(kHSum)
USE_JITKERNEL_REFER(kHMax)
USE_JITKERNEL_REFER(kStrideSum)
USE_JITKERNEL_REFER(kStrideASum)
USE_JITKERNEL_REFER(kSoftmax)
USE_JITKERNEL_REFER(kEmbSeqPool)
USE_JITKERNEL_REFER(kSgd)
......
......@@ -52,7 +52,7 @@ REGISTER_REFER_KERNEL(SeqPool);
REGISTER_REFER_KERNEL(MatMul);
REGISTER_REFER_KERNEL(HMax);
REGISTER_REFER_KERNEL(HSum);
REGISTER_REFER_KERNEL(StrideSum);
REGISTER_REFER_KERNEL(StrideASum);
REGISTER_REFER_KERNEL(Softmax);
REGISTER_REFER_KERNEL(EmbSeqPool);
REGISTER_REFER_KERNEL(Sgd);
......
......@@ -412,10 +412,10 @@ void HSum(const T* x, T* res, int n) {
}
// Reference StrideASum: res[0] = |x[0]| + |x[stride]| + |x[2*stride]| + ...
// for all indices i < n (n is an index bound, not an element count).
template <typename T>
void StrideASum(const T* x, T* res, int n, int stride) {
  // Bug fix: seed the accumulator with |x[0]|, not x[0]. An "absolute sum"
  // (ASum) must take the absolute value of the first element too, otherwise
  // a negative x[0] corrupts the result. (The diff garble also duplicated
  // the old StrideSum signature and the non-abs accumulation line; both
  // removed-lines are dropped here.)
  res[0] = std::abs(x[0]);
  for (int i = stride; i < n; i += stride) {
    res[0] += std::abs(x[i]);
  }
}
......@@ -442,7 +442,7 @@ void Softmax(const T* x, T* y, int n, int bs = 1, int m = 1) {
VScal(&scalar, y, y, n);
} else {
for (int j = 0; j < m; j++) {
StrideSum(&y[j], &scalar, n, m);
StrideASum(&y[j], &scalar, n, m);
scalar = static_cast<T>(1) / scalar;
StrideScal(&scalar, &y[j], &y[j], n, m);
}
......@@ -554,7 +554,7 @@ DECLARE_REFER_KERNEL(GRUHtPart2);
DECLARE_REFER_KERNEL(HMax);
DECLARE_REFER_KERNEL(HSum);
DECLARE_REFER_KERNEL(StrideSum);
DECLARE_REFER_KERNEL(StrideASum);
// others
DECLARE_REFER_KERNEL(CRFDecoding);
......
......@@ -727,6 +727,7 @@ void TestKernelSoftmax() {
if (m > n || n % m != 0) {
continue;
}
VLOG(10) << "Softmax: " << bs << ", " << n << ", " << m;
auto ref = jit::GetReferFunc<KernelTuple>();
EXPECT_TRUE(ref != nullptr);
std::vector<T> x(bs * n), y(bs * n);
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register or sign in.