提交 93701dba 编写于 作者: D dengkaipeng

add jit kernel for softmax axis. test=develop

上级 6c641827
......@@ -386,7 +386,7 @@ void BenchKernelSoftmax() {
RandomVec<T>(bs * n, x.mutable_data<T>(PlaceType()), -2.f, 2.f);
const T* x_data = x.data<T>();
T* y_data = y.mutable_data<T>(PlaceType());
BenchAllImpls<KernelTuple, PlaceType>(n, x_data, y_data, n, bs);
BenchAllImpls<KernelTuple, PlaceType>(n, x_data, y_data, n, bs, 1);
}
}
}
......
......@@ -34,6 +34,7 @@ const char* to_string(KernelType kt) {
ONE_CASE(kVAddRelu);
ONE_CASE(kVSub);
ONE_CASE(kVScal);
ONE_CASE(kStrideScal);
ONE_CASE(kVAddBias);
ONE_CASE(kVRelu);
ONE_CASE(kVBroadcast);
......@@ -55,6 +56,7 @@ const char* to_string(KernelType kt) {
ONE_CASE(kMatMul);
ONE_CASE(kHMax);
ONE_CASE(kHSum);
ONE_CASE(kStrideSum);
ONE_CASE(kSoftmax);
ONE_CASE(kEmbSeqPool);
ONE_CASE(kSgd);
......
......@@ -53,6 +53,8 @@ typedef enum {
kVSquare,
kVSub,
kVTanh,
kStrideSum,
kStrideScal,
} KernelType;
typedef enum {
......@@ -74,6 +76,14 @@ struct XYZNTuple {
template <typename T>
struct AXYNTuple : public XYZNTuple<T> {};
// a, x, y, n, stride
template <typename T>
struct AXYNSTuple {
typedef T data_type;
typedef int attr_type;
typedef void (*func_type)(const T*, const T*, T*, int, int);
};
// x, y, n
template <typename T>
struct XYNTuple {
......@@ -86,6 +96,14 @@ struct XYNTuple {
template <typename T>
struct XRNTuple : public XYNTuple<T> {};
// x, returned value, n, stride
template <typename T>
struct XRNSTuple {
typedef T data_type;
typedef int attr_type;
typedef void (*func_type)(const T*, T*, int, int);
};
#define DECLARE_KERNELTUPLE(kernel_tuple, type) \
template <typename T> \
struct type##Tuple : public kernel_tuple<T> { \
......@@ -101,6 +119,8 @@ DECLARE_KERNELTUPLE(XYZNTuple, VSub);
DECLARE_KERNELTUPLE(AXYNTuple, VScal);
DECLARE_KERNELTUPLE(AXYNTuple, VAddBias);
DECLARE_KERNELTUPLE(AXYNSTuple, StrideScal);
DECLARE_KERNELTUPLE(XYNTuple, VRelu);
DECLARE_KERNELTUPLE(XYNTuple, VIdentity);
DECLARE_KERNELTUPLE(XYNTuple, VSquare);
......@@ -112,6 +132,8 @@ DECLARE_KERNELTUPLE(XYNTuple, VCopy);
DECLARE_KERNELTUPLE(XRNTuple, HMax);
DECLARE_KERNELTUPLE(XRNTuple, HSum);
DECLARE_KERNELTUPLE(XRNSTuple, StrideSum);
typedef struct {
void* gates; // gates: x_ch, x_ih, x_fh, x_oh
const void* ct_1;
......@@ -285,7 +307,7 @@ struct SoftmaxTuple {
static constexpr KernelType kernel_type = kSoftmax;
typedef T data_type;
typedef int attr_type;
typedef void (*func_type)(const T*, T*, int, int);
typedef void (*func_type)(const T*, T*, int, int, int);
};
// nChw16c = nChw16c .* NC
......
......@@ -50,10 +50,12 @@ void VTanh(const T* x, T* y, int n) {
compute_addbias(&b, y, y, n);
}
void Softmax(const T* x, T* y, int n, int bs) {
void Softmax(const T* x, T* y, int n, int bs, int m) {
auto compute_hmax = KernelFuncs<HMaxTuple<T>, CPUPlace>::Cache().At(n);
auto compute_hsum = KernelFuncs<HSumTuple<T>, CPUPlace>::Cache().At(n);
auto compute_vscal = KernelFuncs<VScalTuple<T>, CPUPlace>::Cache().At(n);
auto compute_stridesum = KernelFuncs<StrideSumTuple<T>, CPUPlace>::Cache().At(n);
auto compute_stridescal = KernelFuncs<StrideScalTuple<T>, CPUPlace>::Cache().At(n);
auto compute_vaddbias =
KernelFuncs<VAddBiasTuple<T>, CPUPlace>::Cache().At(n);
auto compute_vexp = KernelFuncs<VExpTuple<T>, CPUPlace>::Cache().At(n);
......@@ -64,9 +66,17 @@ void Softmax(const T* x, T* y, int n, int bs) {
scalar = static_cast<T>(0) - scalar;
compute_vaddbias(&scalar, x, y, n); // x - max
compute_vexp(y, y, n);
compute_hsum(y, &scalar, n);
scalar = static_cast<T>(1) / scalar;
compute_vscal(&scalar, y, y, n);
if (m == 1) {
compute_hsum(y, &scalar, n);
scalar = static_cast<T>(1) / scalar;
compute_vscal(&scalar, y, y, n);
} else {
for (int j = 0; j < m; ++j) {
compute_stridesum(&y[j], &scalar, n, m);
scalar = static_cast<T>(1) / scalar;
compute_stridescal(&scalar, &y[j], &y[j], n, m);
}
}
x += n;
y += n;
}
......
......@@ -26,7 +26,7 @@ using T = float;
void VSigmoid(const T* x, T* y, int n);
void VTanh(const T* x, T* y, int n);
void Softmax(const T* x, T* y, int n, int bs);
void Softmax(const T* x, T* y, int n, int bs, int m);
void LSTMCtHt(lstm_t* step, const lstm_attr_t* attr);
void LSTMC1H1(lstm_t* step, const lstm_attr_t* attr);
......
......@@ -7,6 +7,7 @@ USE_JITKERNEL_MORE(kMatMul, mkl)
USE_JITKERNEL_MORE(kVMul, mkl)
USE_JITKERNEL_MORE(kVAdd, mkl)
USE_JITKERNEL_MORE(kVScal, mkl)
USE_JITKERNEL_MORE(kStrideScal, mkl)
USE_JITKERNEL_MORE(kVExp, mkl)
USE_JITKERNEL_MORE(kVSquare, mkl)
USE_JITKERNEL_MORE(kVCopy, mkl)
......
......@@ -78,6 +78,24 @@ void VScal<double>(const double* a, const double* x, double* y, int n) {
}
}
template <>
void StrideScal<float>(const float* a, const float* x, float* y, int n, int stride) {
if (x == y) {
platform::dynload::cblas_sscal(n, *a, y, stride);
} else {
refer::StrideScal<float>(a, x, y, n, stride);
}
}
template <>
void StrideScal<double>(const double* a, const double* x, double* y, int n, int stride) {
if (x == y) {
platform::dynload::cblas_dscal(n, *a, y, stride);
} else {
refer::StrideScal<double>(a, x, y, n, stride);
}
}
template <>
void VExp<float>(const float* x, float* y, int n) {
platform::dynload::vsExp(n, x, y);
......@@ -128,6 +146,16 @@ void ASum<double>(const double* x, double* res, int n) {
res[0] = platform::dynload::cblas_dasum(n, x, 1);
}
template <>
void StrideSum<float>(const float* x, float* res, int n, int stride) {
res[0] = platform::dynload::cblas_sasum(n, x, stride);
}
template <>
void StrideSum<double>(const double* x, double* res, int n, int stride) {
res[0] = platform::dynload::cblas_dasum(n, x, stride);
}
// TODO(TJ): tuning me carefully on AVX, AVX2 and AVX512
template <>
bool VMulKernel<float>::CanBeUsed(const int& d) const {
......@@ -144,6 +172,11 @@ bool VScalKernel<float>::CanBeUsed(const int& d) const {
return platform::MayIUse(platform::avx512f) && d > 512;
}
template <>
bool StrideScalKernel<float>::CanBeUsed(const int& d) const {
return platform::MayIUse(platform::avx512f) && d > 512;
}
template <>
bool VExpKernel<float>::CanBeUsed(const int& d) const {
return d > 7;
......@@ -235,6 +268,7 @@ bool SoftmaxKernel<float>::CanBeUsed(const int& d) const {
AWALYS_USE_ME_WITH_DOUBLE(VMul);
AWALYS_USE_ME_WITH_DOUBLE(VAdd);
AWALYS_USE_ME_WITH_DOUBLE(VScal);
AWALYS_USE_ME_WITH_DOUBLE(StrideScal);
AWALYS_USE_ME_WITH_DOUBLE(VExp);
AWALYS_USE_ME_WITH_DOUBLE(VSigmoid);
AWALYS_USE_ME_WITH_DOUBLE(VTanh);
......@@ -259,6 +293,7 @@ REGISTER_MKL_KERNEL(MatMul);
REGISTER_MKL_KERNEL(VMul);
REGISTER_MKL_KERNEL(VAdd);
REGISTER_MKL_KERNEL(VScal);
REGISTER_MKL_KERNEL(StrideScal);
REGISTER_MKL_KERNEL(VExp);
REGISTER_MKL_KERNEL(VSquare);
REGISTER_MKL_KERNEL(VCopy);
......
......@@ -129,7 +129,13 @@ template <typename T>
void ASum(const T* x, T* res, int n);
template <typename T>
void Softmax(const T* x, T* y, int n, int bs) {
void StrideSum(const T* x, T* res, int n, int stride);
template <typename T>
void StrideScal(const T* a, const T* x, T* y, int n, int stride);
template <typename T>
void Softmax(const T* x, T* y, int n, int bs, int m=1) {
std::vector<T> entities(bs);
for (int i = 0; i < bs; ++i) {
entities[i] = x[i * n];
......@@ -143,9 +149,17 @@ void Softmax(const T* x, T* y, int n, int bs) {
VExp(y, y, n * bs);
for (int i = 0; i < bs; ++i) {
T sum;
ASum(&y[i * n], &sum, n);
sum = static_cast<T>(1) / sum;
VScal(&sum, &y[i * n], &y[i * n], n);
if (m == 1) {
ASum(&y[i * n], &sum, n);
sum = static_cast<T>(1) / sum;
VScal(&sum, &y[i * n], &y[i * n], n);
} else {
for (int j = 0; j < m; ++j) {
StrideSum(&y[i * n + j], &sum, n/m, m);
sum = static_cast<T>(1) / sum;
StrideScal(&sum, &y[i * n + j], &y[i * n + j], n/m, m);
}
}
}
}
......@@ -193,6 +207,7 @@ DECLARE_MKL_KERNEL(VAdd);
// AXYN
DECLARE_MKL_KERNEL(VScal);
DECLARE_MKL_KERNEL(StrideScal);
// XYN
DECLARE_MKL_KERNEL(VExp);
......
......@@ -12,6 +12,7 @@ USE_JITKERNEL_REFER(kVAdd)
USE_JITKERNEL_REFER(kVAddRelu)
USE_JITKERNEL_REFER(kVSub)
USE_JITKERNEL_REFER(kVScal)
USE_JITKERNEL_REFER(kStrideScal)
USE_JITKERNEL_REFER(kVAddBias)
USE_JITKERNEL_REFER(kVCopy)
USE_JITKERNEL_REFER(kVRelu)
......@@ -32,6 +33,7 @@ USE_JITKERNEL_REFER(kMatMul)
USE_JITKERNEL_REFER(kVSquare)
USE_JITKERNEL_REFER(kHSum)
USE_JITKERNEL_REFER(kHMax)
USE_JITKERNEL_REFER(kStrideSum)
USE_JITKERNEL_REFER(kSoftmax)
USE_JITKERNEL_REFER(kEmbSeqPool)
USE_JITKERNEL_REFER(kSgd)
......
......@@ -27,6 +27,7 @@ REGISTER_REFER_KERNEL(VAddRelu);
REGISTER_REFER_KERNEL(VSub);
REGISTER_REFER_KERNEL(VScal);
REGISTER_REFER_KERNEL(StrideScal);
REGISTER_REFER_KERNEL(VAddBias);
REGISTER_REFER_KERNEL(VRelu);
......@@ -51,6 +52,7 @@ REGISTER_REFER_KERNEL(SeqPool);
REGISTER_REFER_KERNEL(MatMul);
REGISTER_REFER_KERNEL(HMax);
REGISTER_REFER_KERNEL(HSum);
REGISTER_REFER_KERNEL(StrideSum);
REGISTER_REFER_KERNEL(Softmax);
REGISTER_REFER_KERNEL(EmbSeqPool);
REGISTER_REFER_KERNEL(Sgd);
......
......@@ -411,19 +411,42 @@ void HSum(const T* x, T* res, int n) {
}
}
template <typename T>
void StrideSum(const T* x, T* res, int n, int stride) {
res[0] = x[0];
for (int i = stride; i < n; i+=stride) {
res[0] += x[i];
}
}
template <typename T>
void StrideScal(const T* a, const T* x, T* y, int n , int stride) {
for (int i = 0; i < n; i+=stride) {
y[i] = x[i] * a[0];
}
}
// y = e^(x - max(x))
// y = y / sum(y)
template <typename T>
void Softmax(const T* x, T* y, int n, int bs = 1) {
void Softmax(const T* x, T* y, int n, int bs = 1, int m = 1) {
for (int i = 0; i < bs; ++i) {
T scalar;
HMax(x, &scalar, n);
scalar = static_cast<T>(0) - scalar;
VAddBias(&scalar, x, y, n); // x - max
VExp(y, y, n);
HSum(y, &scalar, n);
scalar = static_cast<T>(1) / scalar;
VScal(&scalar, y, y, n);
if (m == 1) {
HSum(y, &scalar, n);
scalar = static_cast<T>(1) / scalar;
VScal(&scalar, y, y, n);
} else {
for (int j = 0; j < m; j++) {
StrideSum(&y[j], &scalar, n, m);
scalar = static_cast<T>(1) / scalar;
StrideScal(&scalar, &y[j], &y[j], n, m);
}
}
x += n;
y += n;
}
......@@ -507,6 +530,9 @@ DECLARE_REFER_KERNEL(VSub);
DECLARE_REFER_KERNEL(VScal);
DECLARE_REFER_KERNEL(VAddBias);
// const T* a, const T* x, T* y, int n, int stride
DECLARE_REFER_KERNEL(StrideScal);
// const T* x, T* y, int n
DECLARE_REFER_KERNEL(VRelu);
DECLARE_REFER_KERNEL(VIdentity);
......@@ -528,6 +554,8 @@ DECLARE_REFER_KERNEL(GRUHtPart2);
DECLARE_REFER_KERNEL(HMax);
DECLARE_REFER_KERNEL(HSum);
DECLARE_REFER_KERNEL(StrideSum);
// others
DECLARE_REFER_KERNEL(CRFDecoding);
DECLARE_REFER_KERNEL(LayerNorm);
......
......@@ -723,39 +723,44 @@ void TestKernelSoftmax() {
VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type);
for (int bs : {1, 2, 10}) {
for (int n : TestSizes()) {
auto ref = jit::GetReferFunc<KernelTuple>();
EXPECT_TRUE(ref != nullptr);
std::vector<T> x(bs * n), y(bs * n);
RandomVec<T>(bs * n, x.data());
const T* x_data = x.data();
T* y_data = y.data();
for (int m : {1, 2}) {
if (m > n || n % m != 0) {
continue;
}
auto ref = jit::GetReferFunc<KernelTuple>();
EXPECT_TRUE(ref != nullptr);
std::vector<T> x(bs * n), y(bs * n);
RandomVec<T>(bs * n, x.data());
const T* x_data = x.data();
T* y_data = y.data();
std::vector<T> xinp(x.size()); // inplace test
std::copy(x.begin(), x.end(), xinp.begin());
ref(x_data, y_data, n, bs);
T* xinp_data = xinp.data();
ref(xinp_data, xinp_data, n, bs);
ExpectEQ<T>(xinp_data, y_data, n * bs);
std::vector<T> xinp(x.size()); // inplace test
std::copy(x.begin(), x.end(), xinp.begin());
ref(x_data, y_data, n, bs, m);
T* xinp_data = xinp.data();
ref(xinp_data, xinp_data, n, bs, m);
ExpectEQ<T>(xinp_data, y_data, n * bs);
auto verifier = [](const typename KernelTuple::func_type tgt,
const std::vector<T>& x, const std::vector<T>& yref,
int n, int bs) {
EXPECT_TRUE(tgt != nullptr);
EXPECT_EQ(yref.size(), x.size());
EXPECT_EQ(x.size(), static_cast<size_t>(n * bs));
const T* x_data = x.data();
const T* yref_data = yref.data();
std::vector<T> ytgt(n * bs);
T* ytgt_data = ytgt.data();
// test normal
tgt(x_data, ytgt_data, n, bs);
ExpectEQ<T>(ytgt_data, yref_data, n * bs);
// test inplace x
std::copy(x.begin(), x.end(), ytgt.begin());
tgt(ytgt_data, ytgt_data, n, bs);
ExpectEQ<T>(ytgt_data, yref_data, n * bs);
};
TestAllImpls<KernelTuple, PlaceType>(n, verifier, x, y, n, bs);
auto verifier = [](const typename KernelTuple::func_type tgt,
const std::vector<T>& x, const std::vector<T>& yref,
int n, int bs, int m) {
EXPECT_TRUE(tgt != nullptr);
EXPECT_EQ(yref.size(), x.size());
EXPECT_EQ(x.size(), static_cast<size_t>(n * bs));
const T* x_data = x.data();
const T* yref_data = yref.data();
std::vector<T> ytgt(n * bs);
T* ytgt_data = ytgt.data();
// test normal
tgt(x_data, ytgt_data, n, bs, m);
ExpectEQ<T>(ytgt_data, yref_data, n * bs);
// test inplace x
std::copy(x.begin(), x.end(), ytgt.begin());
tgt(ytgt_data, ytgt_data, n, bs, m);
ExpectEQ<T>(ytgt_data, yref_data, n * bs);
};
TestAllImpls<KernelTuple, PlaceType>(n, verifier, x, y, n, bs, m);
}
}
}
}
......
......@@ -76,8 +76,8 @@ using enable_if_CPU = typename std::enable_if<
template <typename DeviceContext>
class SoftmaxFunctor<DeviceContext, float, true, enable_if_CPU<DeviceContext>> {
void operator()(const DeviceContext& context, const framework::Tensor* X,
framework::Tensor* Y) {
void operator()(const DeviceContext& context, const int axis_dim,
const framework::Tensor* X, framework::Tensor* Y) {
auto in_dims = X->dims();
const float* in_data = X->data<float>();
float* out_data = Y->data<float>();
......@@ -87,7 +87,8 @@ class SoftmaxFunctor<DeviceContext, float, true, enable_if_CPU<DeviceContext>> {
auto compute_softmax =
jit::KernelFuncs<jit::SoftmaxTuple<float>, platform::CPUPlace>::Cache()
.At(in_dims[kClassDim]);
compute_softmax(in_data, out_data, in_dims[kClassDim], in_dims[kBatchDim]);
compute_softmax(in_data, out_data, in_dims[kClassDim], in_dims[kBatchDim],
in_dims[kClassDim] / axis_dim);
}
};
......
......@@ -42,9 +42,18 @@ class SoftmaxOp : public framework::OperatorWithKernel {
auto dim_x = ctx->GetInputDim("X");
auto rank_x = dim_x.size();
auto axis = ctx->Attrs().Get<int>("axis");
PADDLE_ENFORCE(axis >= -1 && axis < rank_x,
"Attr(axis) value should larger equal then -1"
"and less then the rank of Input(X)");
PADDLE_ENFORCE(axis >= -rank_x && axis < rank_x,
"Attr(axis) value should be in range [-R, R-1], "
"R is the rank of Input(X).");
auto use_cudnn = ctx->Attrs().Get<bool>("use_cudnn");
auto use_mkldnn = ctx->Attrs().Get<bool>("use_mkldnn");
if (axis != rank_x - 1 && axis != -1) {
PADDLE_ENFORCE(!use_cudnn,
"CUDNN kernel only support axis as -1.");
PADDLE_ENFORCE(!use_mkldnn,
"MKLDNN kernel only support axis as -1.");
}
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ "Out");
......
......@@ -63,8 +63,6 @@ class SoftmaxKernel : public framework::OpKernel<T> {
Tensor X_2d, Out_2d;
X_2d.ShareDataWith(*X).Resize({n, d});
Out_2d.ShareDataWith(*Out).Resize({n, d});
// Tensor X_2d = framework::ReshapeToMatrix(*X, axis - 1);
// Tensor Out_2d = framework::ReshapeToMatrix(*Out, axis - 1);
#ifdef PADDLE_ON_INFERENCE
math::SoftmaxFunctor<DeviceContext, T, true>()(
......@@ -96,9 +94,6 @@ class SoftmaxGradKernel : public framework::OpKernel<T> {
dX_2d.ShareDataWith(*dX).Resize({n, d});
Out_2d.ShareDataWith(*Out).Resize({n, d});
dOut_2d.ShareDataWith(*dOut).Resize({n, d});
// Tensor Out_2d = framework::ReshapeToMatrix(*Out, axis - 1);
// Tensor dOut_2d = framework::ReshapeToMatrix(*dOut, axis - 1);
// Tensor dX_2d = framework::ReshapeToMatrix(*dX, axis - 1);
math::SoftmaxGradFunctor<DeviceContext, T>()(
context.template device_context<DeviceContext>(), axis_dim, &Out_2d, &dOut_2d,
......
......@@ -125,26 +125,6 @@ class TestSoftmaxCUDNNOp2(TestSoftmaxCUDNNOp):
return [2, 3, 4, 5]
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestSoftmaxCUDNNOp3(TestSoftmaxCUDNNOp):
def get_x_shape(self):
return [2, 3, 4, 5]
def get_axis(self):
return 0
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestSoftmaxCUDNNOp4(TestSoftmaxCUDNNOp):
def get_x_shape(self):
return [2, 3, 4, 5]
def get_axis(self):
return 1
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestSoftmaxCUDNNOp5(TestSoftmaxCUDNNOp):
......@@ -152,7 +132,7 @@ class TestSoftmaxCUDNNOp5(TestSoftmaxCUDNNOp):
return [2, 3, 4, 5]
def get_axis(self):
return 2
return 3
@unittest.skipIf(not core.is_compiled_with_cuda(),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册