提交 ae179269 编写于 作者: T tensor-tang

enable jitkernel mkl vmul, vadd and vscal

上级 77907a35
......@@ -45,6 +45,8 @@ PaddlePaddle/Paddle/paddle/fluid/
-`KernelType` 中添加 `your_key` .
- 实现Reference 的逻辑,每个jitkernel的Reference 实现是必须的。不要依赖任何第三方库。并在`refer/CMakeLists.txt`中添加`USE_JITKERNEL_REFER(your_key)`。
- (optional) 实现更多的算法在`more`目录下,可以依赖mkl,openblas,或者mkldnn等第三方库。
- (optional) 实现基于Xbyak的生成code,在`gen`目录下。
- 必要时可以添加新的`KernelTuples`,可以参考`XYZNTuples`,新加的Attr类型需要特例化`JitCodeKey`方法。
- 添加unit test,需要测试float和double
- 添加benchmark,确保通过get方法得到的kernel速度是最快的。
......@@ -4,3 +4,5 @@ set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} dynload_mklml jit_kernel_mkl PARENT_SCOPE
# Register the MKL-backed "more" implementations, keyed by kernel name.
# Each key must match an entry added to KernelType (see the docs above).
USE_JITKERNEL_MORE(vmul, mkl)
USE_JITKERNEL_MORE(vadd, mkl)
USE_JITKERNEL_MORE(vscal, mkl)
......@@ -13,7 +13,9 @@
* limitations under the License. */
#include "paddle/fluid/operators/jit/more/mkl/mkl.h"
#include "paddle/fluid/operators/jit/refer/refer.h"
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/dynload/mklml.h"
namespace paddle {
......@@ -32,6 +34,61 @@ void VMul<double>(const double* x, const double* y, double* z, int n) {
platform::dynload::vdMul(n, x, y, z);
}
// Element-wise add via MKL VML: z[i] = x[i] + y[i] for i in [0, n).
template <>
void VAdd<float>(const float* x, const float* y, float* z, int n) {
platform::dynload::vsAdd(n, x, y, z);
}
// Double-precision element-wise add via MKL VML (vdAdd).
template <>
void VAdd<double>(const double* x, const double* y, double* z, int n) {
platform::dynload::vdAdd(n, x, y, z);
}
// Scale a vector: y[i] = a[0] * x[i] for i in [0, n).
// cblas_?scal only scales a buffer in place, so MKL is used only when the
// output aliases the input (x == y); otherwise fall back to the reference
// out-of-place implementation.
template <>
void VScal<float>(const float* a, const float* x, float* y, int n) {
if (x == y) {
platform::dynload::cblas_sscal(n, *a, y, 1);
} else {
refer::VScal<float>(a, x, y, n);
}
}
// Double-precision variant of VScal: in-place calls go to MKL cblas_dscal,
// out-of-place calls fall back to the reference implementation (see the
// float specialization for the rationale).
template <>
void VScal<double>(const double* a, const double* x, double* y, int n) {
if (x == y) {
platform::dynload::cblas_dscal(n, *a, y, 1);
} else {
refer::VScal<double>(a, x, y, n);
}
}
// TODO(TJ): tuning me carefully on AVX, AVX2 and AVX512
// For float, only prefer the MKL path on AVX512-capable CPUs and for
// sizes above 512 elements; smaller sizes are left to other (e.g. JIT
// generated) implementations. The 512 threshold is presumably empirical —
// see the TODO above about tuning per ISA.
template <>
bool VMulKernel<float>::UseMe(int d) const {
return platform::MayIUse(platform::avx512f) && d > 512;
}
template <>
bool VAddKernel<float>::UseMe(int d) const {
return platform::MayIUse(platform::avx512f) && d > 512;
}
template <>
bool VScalKernel<float>::UseMe(int d) const {
return platform::MayIUse(platform::avx512f) && d > 512;
}
// For double precision, always prefer the MKL implementation regardless of
// problem size. Fix the macro-name typo "AWALYS" -> "ALWAYS"; the macro is
// defined, used and #undef'd entirely within this region, so the rename is
// local and has no effect on any other translation unit.
#define ALWAYS_USE_ME_WITH_DOUBLE(func)           \
  template <>                                     \
  bool func##Kernel<double>::UseMe(int d) const { \
    return true;                                  \
  }

ALWAYS_USE_ME_WITH_DOUBLE(VMul);
ALWAYS_USE_ME_WITH_DOUBLE(VAdd);
ALWAYS_USE_ME_WITH_DOUBLE(VScal);

#undef ALWAYS_USE_ME_WITH_DOUBLE
} // namespace mkl
} // namespace more
} // namespace jit
......@@ -40,5 +97,12 @@ void VMul<double>(const double* x, const double* y, double* z, int n) {
namespace mkl = paddle::operators::jit::more::mkl;
REGISTER_JITKERNEL_MORE(vmul, mkl, mkl::VMulKernel<float>,
mkl::VMulKernel<double>);
// Register both the float and double MKL kernels for a given kernel key in
// one line; keys must match those declared in USE_JITKERNEL_MORE.
#define REGISTER_MKL_KERNEL(key, func)                        \
  REGISTER_JITKERNEL_MORE(key, mkl, mkl::func##Kernel<float>, \
                          mkl::func##Kernel<double>)

REGISTER_MKL_KERNEL(vmul, VMul);
REGISTER_MKL_KERNEL(vadd, VAdd);
REGISTER_MKL_KERNEL(vscal, VScal);

#undef REGISTER_MKL_KERNEL
......@@ -16,7 +16,6 @@
#include <type_traits>
#include "paddle/fluid/operators/jit/kernel_base.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle {
namespace operators {
......@@ -28,17 +27,27 @@ template <typename T>
void VMul(const T* x, const T* y, T* z, int n);
template <typename T>
class VMulKernel : public KernelImpl<XYZNTuples<T>> {
public:
VMulKernel() { this->func = VMul<T>; }
bool UseMe(int d) const override {
if (std::is_same<T, float>::value) {
return platform::MayIUse(platform::avx512f) && d > 512;
} else {
return true;
}
void VAdd(const T* x, const T* y, T* z, int n);
template <typename T>
void VScal(const T* a, const T* x, T* y, int n);
#define DECLARE_MKL_KERNEL(name, tuples) \
template <typename T> \
class name##Kernel : public KernelImpl<tuples<T>> { \
public: \
name##Kernel() { this->func = name<T>; } \
bool UseMe(typename tuples<T>::attr_type) const override; \
}
};
// XYZN
DECLARE_MKL_KERNEL(VMul, XYZNTuples);
DECLARE_MKL_KERNEL(VAdd, XYZNTuples);
// AXYN
DECLARE_MKL_KERNEL(VScal, AXYNTuples);
#undef DECLARE_MKL_KERNEL
} // namespace mkl
} // namespace more
......
......@@ -31,56 +31,6 @@ namespace operators {
namespace math {
namespace jitkernel {
#ifdef PADDLE_WITH_MKLML
// Legacy math/jitkernel MKL wrapper: z[i] = x[i] * y[i] for i in [0, n),
// dispatched to vsMul/vdMul by element type.
template <typename T>
void VMulMKL(const T* x, const T* y, T* z, int n);
template <>
void VMulMKL<float>(const float* x, const float* y, float* z, int n) {
platform::dynload::vsMul(n, x, y, z);
}
template <>
void VMulMKL<double>(const double* x, const double* y, double* z, int n) {
platform::dynload::vdMul(n, x, y, z);
}
// Legacy math/jitkernel MKL wrapper: z[i] = x[i] + y[i] for i in [0, n),
// dispatched to vsAdd/vdAdd by element type.
template <typename T>
void VAddMKL(const T* x, const T* y, T* z, int n);
template <>
void VAddMKL<float>(const float* x, const float* y, float* z, int n) {
platform::dynload::vsAdd(n, x, y, z);
}
template <>
void VAddMKL<double>(const double* x, const double* y, double* z, int n) {
platform::dynload::vdAdd(n, x, y, z);
}
// Legacy math/jitkernel MKL wrapper: y[i] = a[0] * x[i] for i in [0, n).
// cblas_?scal only works in place, so MKL is used only when x aliases y;
// otherwise the reference out-of-place implementation is used.
template <typename T>
void VScalMKL(const T* a, const T* x, T* y, int n);
template <>
void VScalMKL<float>(const float* a, const float* x, float* y, int n) {
if (x == y) {
platform::dynload::cblas_sscal(n, *a, y, 1);
} else {
refer::VScal<float>(a, x, y, n);
}
}
template <>
void VScalMKL<double>(const double* a, const double* x, double* y, int n) {
if (x == y) {
platform::dynload::cblas_dscal(n, *a, y, 1);
} else {
refer::VScal<double>(a, x, y, n);
}
}
#endif
/* VMUL JitKernel */
template <typename T>
class VMulKernelImpl : public VMulKernel<T> {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册