提交 09c5786e 编写于 作者: T tensor-tang

add square jitkernel

上级 4461a458
...@@ -254,6 +254,7 @@ int main(int argc, char* argv[]) { ...@@ -254,6 +254,7 @@ int main(int argc, char* argv[]) {
// xyn // xyn
BenchXYNKernel<jit::kVRelu, T, PlaceType>(); BenchXYNKernel<jit::kVRelu, T, PlaceType>();
BenchXYNKernel<jit::kVIdentity, T, PlaceType>(); BenchXYNKernel<jit::kVIdentity, T, PlaceType>();
BenchXYNKernel<jit::kVSquare, T, PlaceType>();
BenchXYNKernel<jit::kVExp, T, PlaceType>(); BenchXYNKernel<jit::kVExp, T, PlaceType>();
BenchXYNKernel<jit::kVSigmoid, T, PlaceType>(); BenchXYNKernel<jit::kVSigmoid, T, PlaceType>();
BenchXYNKernel<jit::kVTanh, T, PlaceType>(); BenchXYNKernel<jit::kVTanh, T, PlaceType>();
......
...@@ -36,6 +36,7 @@ const char* to_string(KernelType kt) { ...@@ -36,6 +36,7 @@ const char* to_string(KernelType kt) {
ONE_CASE(kVRelu); ONE_CASE(kVRelu);
ONE_CASE(kVIdentity); ONE_CASE(kVIdentity);
ONE_CASE(kVExp); ONE_CASE(kVExp);
ONE_CASE(kVSquare);
ONE_CASE(kVSigmoid); ONE_CASE(kVSigmoid);
ONE_CASE(kVTanh); ONE_CASE(kVTanh);
ONE_CASE(kLSTMCtHt); ONE_CASE(kLSTMCtHt);
......
...@@ -30,6 +30,7 @@ typedef enum { ...@@ -30,6 +30,7 @@ typedef enum {
kVAddBias, kVAddBias,
kVRelu, kVRelu,
kVIdentity, kVIdentity,
kVSquare,
kVExp, kVExp,
kVSigmoid, kVSigmoid,
kVTanh, kVTanh,
......
...@@ -8,6 +8,7 @@ USE_JITKERNEL_MORE(kVMul, mkl) ...@@ -8,6 +8,7 @@ USE_JITKERNEL_MORE(kVMul, mkl)
USE_JITKERNEL_MORE(kVAdd, mkl) USE_JITKERNEL_MORE(kVAdd, mkl)
USE_JITKERNEL_MORE(kVScal, mkl) USE_JITKERNEL_MORE(kVScal, mkl)
USE_JITKERNEL_MORE(kVExp, mkl) USE_JITKERNEL_MORE(kVExp, mkl)
USE_JITKERNEL_MORE(kVSquare, mkl)
USE_JITKERNEL_MORE(kVSigmoid, mkl) USE_JITKERNEL_MORE(kVSigmoid, mkl)
USE_JITKERNEL_MORE(kVTanh, mkl) USE_JITKERNEL_MORE(kVTanh, mkl)
USE_JITKERNEL_MORE(kSeqPool, mkl) USE_JITKERNEL_MORE(kSeqPool, mkl)
...@@ -86,6 +86,16 @@ void VExp<double>(const double* x, double* y, int n) { ...@@ -86,6 +86,16 @@ void VExp<double>(const double* x, double* y, int n) {
platform::dynload::vdExp(n, x, y); platform::dynload::vdExp(n, x, y);
} }
template <>
// Float specialization of the element-wise square kernel: delegates
// y[i] = x[i] * x[i] for i in [0, n) to Intel MKL VML's vsSqr through
// the dynload wrapper (lazy-loaded MKL symbols).
void VSquare<float>(const float* x, float* y, int n) {
platform::dynload::vsSqr(n, x, y);
}
template <>
// Double specialization of the element-wise square kernel: delegates
// y[i] = x[i] * x[i] for i in [0, n) to Intel MKL VML's vdSqr through
// the dynload wrapper (lazy-loaded MKL symbols).
void VSquare<double>(const double* x, double* y, int n) {
platform::dynload::vdSqr(n, x, y);
}
template <> template <>
void VCopy<float>(const float* x, float* y, int n) { void VCopy<float>(const float* x, float* y, int n) {
platform::dynload::cblas_scopy(n, x, 1, y, 1); platform::dynload::cblas_scopy(n, x, 1, y, 1);
...@@ -132,6 +142,11 @@ bool VExpKernel<float>::UseMe(const int& d) const { ...@@ -132,6 +142,11 @@ bool VExpKernel<float>::UseMe(const int& d) const {
return d > 7; return d > 7;
} }
template <>
// Selects the MKL implementation for float only when the vector length
// exceeds 7, so the MKL call overhead is amortized; shorter inputs fall
// back to another registered kernel. Same threshold as VExpKernel<float>
// above — presumably tuned empirically (TODO: confirm with benchmarks).
bool VSquareKernel<float>::UseMe(const int& d) const {
return d > 7;
}
template <> template <>
bool VSigmoidKernel<float>::UseMe(const int& d) const { bool VSigmoidKernel<float>::UseMe(const int& d) const {
return d > 7; return d > 7;
...@@ -165,6 +180,7 @@ AWALYS_USE_ME_WITH_DOUBLE(VScal); ...@@ -165,6 +180,7 @@ AWALYS_USE_ME_WITH_DOUBLE(VScal);
AWALYS_USE_ME_WITH_DOUBLE(VExp); AWALYS_USE_ME_WITH_DOUBLE(VExp);
AWALYS_USE_ME_WITH_DOUBLE(VSigmoid); AWALYS_USE_ME_WITH_DOUBLE(VSigmoid);
AWALYS_USE_ME_WITH_DOUBLE(VTanh); AWALYS_USE_ME_WITH_DOUBLE(VTanh);
AWALYS_USE_ME_WITH_DOUBLE(VSquare);
#undef AWALYS_USE_ME_WITH_DOUBLE #undef AWALYS_USE_ME_WITH_DOUBLE
} // namespace mkl } // namespace mkl
...@@ -184,6 +200,7 @@ REGISTER_MKL_KERNEL(kVMul, VMul); ...@@ -184,6 +200,7 @@ REGISTER_MKL_KERNEL(kVMul, VMul);
REGISTER_MKL_KERNEL(kVAdd, VAdd); REGISTER_MKL_KERNEL(kVAdd, VAdd);
REGISTER_MKL_KERNEL(kVScal, VScal); REGISTER_MKL_KERNEL(kVScal, VScal);
REGISTER_MKL_KERNEL(kVExp, VExp); REGISTER_MKL_KERNEL(kVExp, VExp);
REGISTER_MKL_KERNEL(kVSquare, VSquare);
REGISTER_MKL_KERNEL(kVSigmoid, VSigmoid); REGISTER_MKL_KERNEL(kVSigmoid, VSigmoid);
REGISTER_MKL_KERNEL(kVTanh, VTanh); REGISTER_MKL_KERNEL(kVTanh, VTanh);
REGISTER_MKL_KERNEL(kSeqPool, SeqPool); REGISTER_MKL_KERNEL(kSeqPool, SeqPool);
......
...@@ -39,6 +39,9 @@ void VScal(const T* a, const T* x, T* y, int n); ...@@ -39,6 +39,9 @@ void VScal(const T* a, const T* x, T* y, int n);
template <typename T> template <typename T>
void VExp(const T* x, T* y, int n); void VExp(const T* x, T* y, int n);
template <typename T>
void VSquare(const T* x, T* y, int n);
template <typename T> template <typename T>
void VCopy(const T* x, T* y, int n); void VCopy(const T* x, T* y, int n);
...@@ -110,6 +113,7 @@ DECLARE_MKL_KERNEL(VScal, AXYNTuples); ...@@ -110,6 +113,7 @@ DECLARE_MKL_KERNEL(VScal, AXYNTuples);
DECLARE_MKL_KERNEL(VExp, XYNTuples); DECLARE_MKL_KERNEL(VExp, XYNTuples);
DECLARE_MKL_KERNEL(VSigmoid, XYNTuples); DECLARE_MKL_KERNEL(VSigmoid, XYNTuples);
DECLARE_MKL_KERNEL(VTanh, XYNTuples); DECLARE_MKL_KERNEL(VTanh, XYNTuples);
DECLARE_MKL_KERNEL(VSquare, XYNTuples);
DECLARE_MKL_KERNEL(SeqPool, SeqPoolTuples); DECLARE_MKL_KERNEL(SeqPool, SeqPoolTuples);
......
...@@ -28,3 +28,4 @@ USE_JITKERNEL_REFER(kLayerNorm) ...@@ -28,3 +28,4 @@ USE_JITKERNEL_REFER(kLayerNorm)
USE_JITKERNEL_REFER(kNCHW16CMulNC) USE_JITKERNEL_REFER(kNCHW16CMulNC)
USE_JITKERNEL_REFER(kSeqPool) USE_JITKERNEL_REFER(kSeqPool)
USE_JITKERNEL_REFER(kMatMul) USE_JITKERNEL_REFER(kMatMul)
USE_JITKERNEL_REFER(kVSquare)
...@@ -31,6 +31,7 @@ REGISTER_REFER_KERNEL(kVAddBias, VAddBias); ...@@ -31,6 +31,7 @@ REGISTER_REFER_KERNEL(kVAddBias, VAddBias);
REGISTER_REFER_KERNEL(kVRelu, VRelu); REGISTER_REFER_KERNEL(kVRelu, VRelu);
REGISTER_REFER_KERNEL(kVIdentity, VIdentity); REGISTER_REFER_KERNEL(kVIdentity, VIdentity);
REGISTER_REFER_KERNEL(kVSquare, VSquare);
REGISTER_REFER_KERNEL(kVExp, VExp); REGISTER_REFER_KERNEL(kVExp, VExp);
REGISTER_REFER_KERNEL(kVSigmoid, VSigmoid); REGISTER_REFER_KERNEL(kVSigmoid, VSigmoid);
REGISTER_REFER_KERNEL(kVTanh, VTanh); REGISTER_REFER_KERNEL(kVTanh, VTanh);
......
...@@ -83,6 +83,13 @@ inline void VIdentity(const T* x, T* y, int n) { ...@@ -83,6 +83,13 @@ inline void VIdentity(const T* x, T* y, int n) {
} }
} }
// Reference (scalar) implementation of the square kernel:
// writes y[i] = x[i] * x[i] for each of the n input elements.
template <typename T>
inline void VSquare(const T* x, T* y, int n) {
  const T* stop = x + n;
  while (x != stop) {
    const T value = *x++;
    *y++ = value * value;
  }
}
template <typename T> template <typename T>
void VExp(const T* x, T* y, int n) { void VExp(const T* x, T* y, int n) {
for (int i = 0; i < n; ++i) { for (int i = 0; i < n; ++i) {
...@@ -394,6 +401,7 @@ DECLARE_REFER_KERNEL(VIdentity, XYNTuples); ...@@ -394,6 +401,7 @@ DECLARE_REFER_KERNEL(VIdentity, XYNTuples);
DECLARE_REFER_KERNEL(VExp, XYNTuples); DECLARE_REFER_KERNEL(VExp, XYNTuples);
DECLARE_REFER_KERNEL(VSigmoid, XYNTuples); DECLARE_REFER_KERNEL(VSigmoid, XYNTuples);
DECLARE_REFER_KERNEL(VTanh, XYNTuples); DECLARE_REFER_KERNEL(VTanh, XYNTuples);
DECLARE_REFER_KERNEL(VSquare, XYNTuples);
// lstm_t*, const lstm_attr_t* // lstm_t*, const lstm_attr_t*
DECLARE_REFER_KERNEL(LSTMCtHt, LSTMTuples); DECLARE_REFER_KERNEL(LSTMCtHt, LSTMTuples);
......
...@@ -604,6 +604,12 @@ TEST(JITKernel, kVIdentity) { ...@@ -604,6 +604,12 @@ TEST(JITKernel, kVIdentity) {
TestXYNKernel<jit::kVIdentity, double, paddle::platform::CPUPlace>(); TestXYNKernel<jit::kVIdentity, double, paddle::platform::CPUPlace>();
} }
// Unit test for the kVSquare XYN-shape kernel (input vector x -> output
// vector y of length n) on CPU, for both float and double. TestXYNKernel
// presumably runs every registered implementation (refer/MKL/JIT) and
// cross-checks their outputs — confirm against its definition.
TEST(JITKernel, kVSquare) {
namespace jit = paddle::operators::jit;
TestXYNKernel<jit::kVSquare, float, paddle::platform::CPUPlace>();
TestXYNKernel<jit::kVSquare, double, paddle::platform::CPUPlace>();
}
TEST(JITKernel, kVExp) { TEST(JITKernel, kVExp) {
namespace jit = paddle::operators::jit; namespace jit = paddle::operators::jit;
TestXYNKernel<jit::kVExp, float, paddle::platform::CPUPlace>(); TestXYNKernel<jit::kVExp, float, paddle::platform::CPUPlace>();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册