Commit 09c5786e authored by tensor-tang

add square jitkernel

Parent 4461a458
......@@ -254,6 +254,7 @@ int main(int argc, char* argv[]) {
  // xyn
  BenchXYNKernel<jit::kVRelu, T, PlaceType>();
  BenchXYNKernel<jit::kVIdentity, T, PlaceType>();
  BenchXYNKernel<jit::kVSquare, T, PlaceType>();
  BenchXYNKernel<jit::kVExp, T, PlaceType>();
  BenchXYNKernel<jit::kVSigmoid, T, PlaceType>();
  BenchXYNKernel<jit::kVTanh, T, PlaceType>();
......
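For context on how the new kernel is reached at runtime: callers go through the jit dispatch helper rather than naming an implementation directly. A minimal sketch, assuming the jit::Get API used by the benchmarks and tests in this tree (SquareExample is a hypothetical wrapper, not part of this commit):

```cpp
// Minimal sketch, assuming the jit::Get dispatch helper declared in this
// tree's helper.h; SquareExample itself is hypothetical.
#include "paddle/fluid/operators/jit/helper.h"

namespace jit = paddle::operators::jit;

void SquareExample(const float* x, float* y, int n) {
  // Returns the best registered kVSquare implementation for size n
  // (generated jitcode if available, else an MKL "more" kernel,
  // else the reference fallback).
  auto square = jit::Get<jit::kVSquare, jit::XYNTuples<float>,
                         paddle::platform::CPUPlace>(n);
  square(x, y, n);  // y[i] = x[i] * x[i]
}
```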
......@@ -36,6 +36,7 @@ const char* to_string(KernelType kt) {
    ONE_CASE(kVRelu);
    ONE_CASE(kVIdentity);
    ONE_CASE(kVExp);
    ONE_CASE(kVSquare);
    ONE_CASE(kVSigmoid);
    ONE_CASE(kVTanh);
    ONE_CASE(kLSTMCtHt);
ONE_CASE(kLSTMCtHt);
......
......@@ -30,6 +30,7 @@ typedef enum {
  kVAddBias,
  kVRelu,
  kVIdentity,
  kVSquare,
  kVExp,
  kVSigmoid,
  kVTanh,
......
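All elementwise x-to-y kernels above (kVRelu, kVSquare, kVExp, ...) share one signature, which the XYNTuples descriptor captures. A sketch of its assumed shape, inferred from how it is used in the declarations elsewhere in this diff:

```cpp
// Assumed shape of the XYNTuples descriptor, inferred from its uses in
// this commit: one input array, one output array, and the length n as the
// only attribute.
template <typename T>
struct XYNTuples {
  typedef T data_type;
  typedef int attr_type;                         // the attribute is just n
  typedef void (*func_type)(const T*, T*, int);  // f(x, y, n)
};
```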
......@@ -8,6 +8,7 @@ USE_JITKERNEL_MORE(kVMul, mkl)
USE_JITKERNEL_MORE(kVAdd, mkl)
USE_JITKERNEL_MORE(kVScal, mkl)
USE_JITKERNEL_MORE(kVExp, mkl)
USE_JITKERNEL_MORE(kVSquare, mkl)
USE_JITKERNEL_MORE(kVSigmoid, mkl)
USE_JITKERNEL_MORE(kVTanh, mkl)
USE_JITKERNEL_MORE(kSeqPool, mkl)
......@@ -86,6 +86,16 @@ void VExp<double>(const double* x, double* y, int n) {
  platform::dynload::vdExp(n, x, y);
}

template <>
void VSquare<float>(const float* x, float* y, int n) {
  platform::dynload::vsSqr(n, x, y);
}

template <>
void VSquare<double>(const double* x, double* y, int n) {
  platform::dynload::vdSqr(n, x, y);
}

template <>
void VCopy<float>(const float* x, float* y, int n) {
  platform::dynload::cblas_scopy(n, x, 1, y, 1);
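The new wrappers forward to MKL's VML v?Sqr routines, which compute the elementwise square. Stripped of Paddle's dynload indirection, the underlying call behaves as in this standalone sketch (assumes an MKL installation linked directly; Paddle resolves the symbols lazily instead):

```cpp
// Standalone sketch of the MKL VML call the wrappers forward to.
#include <mkl.h>
#include <cstdio>

int main() {
  const float x[4] = {1.f, -2.f, 3.f, -4.f};
  float y[4];
  vsSqr(4, x, y);  // VML elementwise square: y[i] = x[i] * x[i]
  for (int i = 0; i < 4; ++i) std::printf("%g ", y[i]);  // prints: 1 4 9 16
  std::printf("\n");
  return 0;
}
```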
......@@ -132,6 +142,11 @@ bool VExpKernel<float>::UseMe(const int& d) const {
  return d > 7;
}

template <>
bool VSquareKernel<float>::UseMe(const int& d) const {
  return d > 7;
}

template <>
bool VSigmoidKernel<float>::UseMe(const int& d) const {
  return d > 7;
......@@ -165,6 +180,7 @@ AWALYS_USE_ME_WITH_DOUBLE(VScal);
AWALYS_USE_ME_WITH_DOUBLE(VExp);
AWALYS_USE_ME_WITH_DOUBLE(VSigmoid);
AWALYS_USE_ME_WITH_DOUBLE(VTanh);
AWALYS_USE_ME_WITH_DOUBLE(VSquare);
#undef AWALYS_USE_ME_WITH_DOUBLE
} // namespace mkl
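UseMe is the size heuristic consulted during kernel selection: the MKL float kernels only claim inputs longer than 7 elements, where the vectorized routine outweighs its call overhead, while AWALYS_USE_ME_WITH_DOUBLE makes them claim all double inputs. A hypothetical Pick helper illustrating the selection rule (the real lookup lives behind jit::Get, not in a function like this):

```cpp
// Hypothetical illustration of the selection rule, not Paddle's actual code:
// prefer the first optimized kernel whose UseMe(d) accepts the size,
// otherwise fall back to the reference implementation.
#include <vector>

template <typename Kernel>
const Kernel* Pick(const std::vector<const Kernel*>& more,
                   const Kernel* refer, int d) {
  for (const Kernel* k : more) {
    if (k->UseMe(d)) return k;  // e.g. MKL VSquareKernel<float>: d > 7
  }
  return refer;
}
```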
......@@ -184,6 +200,7 @@ REGISTER_MKL_KERNEL(kVMul, VMul);
REGISTER_MKL_KERNEL(kVAdd, VAdd);
REGISTER_MKL_KERNEL(kVScal, VScal);
REGISTER_MKL_KERNEL(kVExp, VExp);
REGISTER_MKL_KERNEL(kVSquare, VSquare);
REGISTER_MKL_KERNEL(kVSigmoid, VSigmoid);
REGISTER_MKL_KERNEL(kVTanh, VTanh);
REGISTER_MKL_KERNEL(kSeqPool, SeqPool);
......
......@@ -39,6 +39,9 @@ void VScal(const T* a, const T* x, T* y, int n);
template <typename T>
void VExp(const T* x, T* y, int n);

template <typename T>
void VSquare(const T* x, T* y, int n);

template <typename T>
void VCopy(const T* x, T* y, int n);
......@@ -110,6 +113,7 @@ DECLARE_MKL_KERNEL(VScal, AXYNTuples);
DECLARE_MKL_KERNEL(VExp, XYNTuples);
DECLARE_MKL_KERNEL(VSigmoid, XYNTuples);
DECLARE_MKL_KERNEL(VTanh, XYNTuples);
DECLARE_MKL_KERNEL(VSquare, XYNTuples);
DECLARE_MKL_KERNEL(SeqPool, SeqPoolTuples);
......
......@@ -28,3 +28,4 @@ USE_JITKERNEL_REFER(kLayerNorm)
USE_JITKERNEL_REFER(kNCHW16CMulNC)
USE_JITKERNEL_REFER(kSeqPool)
USE_JITKERNEL_REFER(kMatMul)
USE_JITKERNEL_REFER(kVSquare)
......@@ -31,6 +31,7 @@ REGISTER_REFER_KERNEL(kVAddBias, VAddBias);
REGISTER_REFER_KERNEL(kVRelu, VRelu);
REGISTER_REFER_KERNEL(kVIdentity, VIdentity);
REGISTER_REFER_KERNEL(kVSquare, VSquare);
REGISTER_REFER_KERNEL(kVExp, VExp);
REGISTER_REFER_KERNEL(kVSigmoid, VSigmoid);
REGISTER_REFER_KERNEL(kVTanh, VTanh);
......
......@@ -83,6 +83,13 @@ inline void VIdentity(const T* x, T* y, int n) {
  }
}

template <typename T>
inline void VSquare(const T* x, T* y, int n) {
  for (int i = 0; i < n; ++i) {
    y[i] = x[i] * x[i];
  }
}

template <typename T>
void VExp(const T* x, T* y, int n) {
  for (int i = 0; i < n; ++i) {
......@@ -394,6 +401,7 @@ DECLARE_REFER_KERNEL(VIdentity, XYNTuples);
DECLARE_REFER_KERNEL(VExp, XYNTuples);
DECLARE_REFER_KERNEL(VSigmoid, XYNTuples);
DECLARE_REFER_KERNEL(VTanh, XYNTuples);
DECLARE_REFER_KERNEL(VSquare, XYNTuples);
// lstm_t*, const lstm_attr_t*
DECLARE_REFER_KERNEL(LSTMCtHt, LSTMTuples);
......
......@@ -604,6 +604,12 @@ TEST(JITKernel, kVIdentity) {
  TestXYNKernel<jit::kVIdentity, double, paddle::platform::CPUPlace>();
}

TEST(JITKernel, kVSquare) {
  namespace jit = paddle::operators::jit;
  TestXYNKernel<jit::kVSquare, float, paddle::platform::CPUPlace>();
  TestXYNKernel<jit::kVSquare, double, paddle::platform::CPUPlace>();
}

TEST(JITKernel, kVExp) {
  namespace jit = paddle::operators::jit;
  TestXYNKernel<jit::kVExp, float, paddle::platform::CPUPlace>();
......
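TestXYNKernel follows this file's usual pattern: fill a random input, compute the reference result, and compare every retrievable implementation against it. A condensed, hypothetical version of that check for kVSquare (CheckVSquare is not the real helper; the actual test also exercises the jitcode and "more" backends individually):

```cpp
// Condensed, hypothetical version of the comparison the test performs,
// assuming the jit::Get dispatch API from helper.h.
#include <random>
#include <vector>
#include "gtest/gtest.h"
#include "paddle/fluid/operators/jit/helper.h"

template <typename T>
void CheckVSquare(int n) {
  namespace jit = paddle::operators::jit;
  std::mt19937 rng(0);
  std::uniform_real_distribution<T> dist(-2, 2);
  std::vector<T> x(n), y(n), y_ref(n);
  for (auto& v : x) v = dist(rng);
  // Ground truth from the same formula the refer kernel uses.
  for (int i = 0; i < n; ++i) y_ref[i] = x[i] * x[i];
  // Result from the dispatched kernel under test.
  auto f = jit::Get<jit::kVSquare, jit::XYNTuples<T>,
                    paddle::platform::CPUPlace>(n);
  f(x.data(), y.data(), n);
  for (int i = 0; i < n; ++i) {
    EXPECT_NEAR(y[i], y_ref[i], static_cast<T>(1e-5));
  }
}
```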