Commit 084893a9 authored by T tensor-tang

add vadd kernel

Parent eeff268a
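In brief: this commit adds an element-wise vector-add kernel (VAddKernel) mirroring the existing VMulKernel, folds the hand-written per-dtype KernelPool::Get specializations into DEFINE_WITH_DTYPE / REGISTER_BLAS_JITKERNEL macros, and adds a string-keyed KernelPool::Get that returns nullptr for unknown keys. A minimal usage sketch against the new API (the buffers and the size 8 are illustrative, not part of the commit):

namespace jit = paddle::operators::math::jitkernel;

float x[8], y[8], z[8];  // illustrative buffers; assume x and y are filled in
// Creates (or fetches from the pool cache) the VAddKernel for size 8.
const auto& vadd =
    jit::KernelPool::Instance().template Get<jit::VAddKernel<float>>(8);
vadd->Compute(8, x, y, z);  // z[i] = x[i] + y[i]
// The cached instance can also be re-fetched by its string key, "vadd" + "f" + "8".
const auto& same = jit::KernelPool::Instance().Get("vaddf8");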
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/math/jit_kernel.h"
+#include <iostream>
 #include <string>
 
 namespace paddle {
@@ -27,29 +28,35 @@ KernelPool& KernelPool::Instance() {
   return g_jit_kernels;
 }
 
-template <>
-const std::shared_ptr<VMulKernel<float>> KernelPool::Get<VMulKernel<float>>(
-    int d) {
-  std::string key = "f" + std::to_string(d);
+const std::shared_ptr<Kernel> KernelPool::Get(const std::string& key) const {
   if (kers_.find(key) == kers_.end()) {
-    auto p = std::make_shared<VMulKernel<float>>(d);
-    kers_.insert({key, std::dynamic_pointer_cast<Kernel>(p)});
-    return p;
+    return nullptr;
   }
-  return std::dynamic_pointer_cast<VMulKernel<float>>(kers_.at(key));
+  return kers_.at(key);
 }
 
-template <>
-const std::shared_ptr<VMulKernel<double>> KernelPool::Get<VMulKernel<double>>(
-    int d) {
-  std::string key = "d" + std::to_string(d);
-  if (kers_.find(key) == kers_.end()) {
-    auto p = std::make_shared<VMulKernel<double>>(d);
-    kers_.insert({key, std::dynamic_pointer_cast<Kernel>(p)});
-    return p;
-  }
-  return std::dynamic_pointer_cast<VMulKernel<double>>(kers_.at(key));
-}
+#define DEFINE_WITH_DTYPE(ker_key, ker_class, ker_dtype, dtype_key)        \
+  template <>                                                              \
+  const std::shared_ptr<ker_class<ker_dtype>>                              \
+  KernelPool::Get<ker_class<ker_dtype>>(int d) {                           \
+    std::string key = #ker_key #dtype_key + std::to_string(d);             \
+    if (kers_.find(key) == kers_.end()) {                                  \
+      auto p = std::make_shared<ker_class<ker_dtype>>(d);                  \
+      kers_.insert({key, std::dynamic_pointer_cast<Kernel>(p)});           \
+      return p;                                                            \
+    }                                                                      \
+    return std::dynamic_pointer_cast<ker_class<ker_dtype>>(kers_.at(key)); \
+  }
+
+#define REGISTER_BLAS_JITKERNEL(ker_key, ker_class) \
+  DEFINE_WITH_DTYPE(ker_key, ker_class, float, f);  \
+  DEFINE_WITH_DTYPE(ker_key, ker_class, double, d)
+
+REGISTER_BLAS_JITKERNEL(vmul, VMulKernel);
+REGISTER_BLAS_JITKERNEL(vadd, VAddKernel);
+#undef REGISTER_BLAS_JITKERNEL
+#undef DEFINE_WITH_DTYPE
 
 template <>
 const std::shared_ptr<LSTMKernel<float>>
@@ -57,7 +64,8 @@ KernelPool::Get<LSTMKernel<float>, int, const std::string&, const std::string&,
                 const std::string&>(int d, const std::string& act_gate,
                                     const std::string& act_cand,
                                     const std::string& act_cell) {
-  std::string key = "f" + std::to_string(d) + act_gate + act_cand + act_cell;
+  std::string key =
+      "lstmf" + std::to_string(d) + act_gate + act_cand + act_cell;
   if (kers_.find(key) == kers_.end()) {
     auto p =
         std::make_shared<LSTMKernel<float>>(d, act_gate, act_cand, act_cell);
......
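For reference, the macro introduced above is a mechanical generalization of the hand-written specializations it deletes; for instance, DEFINE_WITH_DTYPE(vadd, VAddKernel, float, f) expands to:

template <>
const std::shared_ptr<VAddKernel<float>>
KernelPool::Get<VAddKernel<float>>(int d) {
  // The stringized tokens "vadd" "f" concatenate to the key prefix "vaddf".
  std::string key = "vadd" "f" + std::to_string(d);
  if (kers_.find(key) == kers_.end()) {
    auto p = std::make_shared<VAddKernel<float>>(d);
    kers_.insert({key, std::dynamic_pointer_cast<Kernel>(p)});
    return p;
  }
  return std::dynamic_pointer_cast<VAddKernel<float>>(kers_.at(key));
}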
@@ -54,6 +54,8 @@ class KernelPool {
   template <typename Ker, typename... ARGS>
   const std::shared_ptr<Ker> Get(ARGS... args);
 
+  const std::shared_ptr<Kernel> Get(const std::string &key) const;
+
  private:
   KernelPool() = default;
   std::unordered_map<std::string, std::shared_ptr<Kernel>> kers_;
@@ -68,6 +70,13 @@ class VMulKernel : public Kernel {
   void (*Compute)(const int n, const T *, const T *, T *);
 };
 
+template <typename T>
+class VAddKernel : public Kernel {
+ public:
+  explicit VAddKernel(int n);
+  void (*Compute)(const int n, const T *, const T *, T *);
+};
+
 template <typename T>
 class LSTMKernel : public Kernel {
  public:
......
@@ -74,15 +74,22 @@ namespace jit = platform::jit;
   FOR_EACH_ALL_BLOCK(macro_, jit::avx)                    \
   FOR_EACH_ALL_BLOCK(macro_, jit::any)
 
-/* VMUL JitKernel */
-#define VMUL_ANY                \
-  for (int i = 0; i < n; ++i) { \
-    z[i] = x[i] * y[i];         \
+#define BIND_KERNEL_WITH_DTYPE(ker_class, ker_func, ker_dtype) \
+  template <>                                                  \
+  ker_class<ker_dtype>::ker_class(int d) {                     \
+    SEARCH_ISA_BLOCK(ker_func, ker_dtype);                     \
   }
 
+#define BIND_KERNEL(ker_class, ker_func)              \
+  BIND_KERNEL_WITH_DTYPE(ker_class, ker_func, float); \
+  BIND_KERNEL_WITH_DTYPE(ker_class, ker_func, double)
+
+/* VMUL JitKernel */
 template <typename T, platform::jit::cpu_isa_t isa, jit_block>
 static void VMulCompute(const int n, const T* x, const T* y, T* z) {
-  VMUL_ANY
+  for (int i = 0; i < n; ++i) {
+    z[i] = x[i] * y[i];
+  }
 }
 
 #ifdef PADDLE_USE_MKLML
@@ -107,6 +114,8 @@ FOR_EACH_ISA_ALL_BLOCK(VMUL_MKL_DOUBLE)
 /// lt8
 #ifdef PADDLE_USE_MKLML
 VMUL_MKL_FLOAT(jit::avx, kLT8)
+VMUL_MKL_FLOAT(jit::avx2, kLT8)
+VMUL_MKL_FLOAT(jit::avx512f, kLT8)
 #endif
 
 /// eq8
@@ -143,20 +152,93 @@ VMUL_MKL_FLOAT(jit::avx2, kEQ16)
 VMUL_MKL_FLOAT(jit::avx512f, kEQ16)
 #endif
 
-#define USE_VMUL_KERNEL(T, func)     \
-  template <>                        \
-  VMulKernel<T>::VMulKernel(int d) { \
-    SEARCH_ISA_BLOCK(func, T);       \
-  }
-
-USE_VMUL_KERNEL(float, VMulCompute);
-USE_VMUL_KERNEL(double, VMulCompute);
-
-#undef VMUL_ANY
 #undef VMUL_INTRI8_FLOAT
 #undef VMUL_MKL_FLOAT
 #undef VMUL_MKL_DOUBLE
-#undef USE_VMUL_KERNEL
 
+/* VADD */
+template <typename T, platform::jit::cpu_isa_t isa, jit_block>
+static void VAddCompute(const int n, const T* x, const T* y, T* z) {
+  for (int i = 0; i < n; ++i) {
+    z[i] = x[i] + y[i];
+  }
+}
+
+#ifdef PADDLE_USE_MKLML
+#define VADD_MKL_FLOAT(isa, block)                                 \
+  template <>                                                      \
+  void VAddCompute<float, isa, block>(const int n, const float* x, \
+                                      const float* y, float* z) {  \
+    platform::dynload::vsAdd(n, x, y, z);                          \
+  }
+
+#define VADD_MKL_DOUBLE(isa, block)                                  \
+  template <>                                                        \
+  void VAddCompute<double, isa, block>(const int n, const double* x, \
+                                       const double* y, double* z) { \
+    platform::dynload::vdAdd(n, x, y, z);                            \
+  }
+
+FOR_EACH_ISA_COMMON_BLOCK(VADD_MKL_FLOAT)
+FOR_EACH_ISA_ALL_BLOCK(VADD_MKL_DOUBLE)
+#endif
+
+/// lt8
+#ifdef PADDLE_USE_MKLML
+VADD_MKL_FLOAT(jit::avx, kLT8)
+VADD_MKL_FLOAT(jit::avx2, kLT8)
+VADD_MKL_FLOAT(jit::avx512f, kLT8)
+#endif
+
+/// eq8
+#define VADD_INTRI8_FLOAT(isa)                                    \
+  template <>                                                     \
+  void VAddCompute<float, isa, kEQ8>(const int n, const float* x, \
+                                     const float* y, float* z) {  \
+    __m256 tmpx, tmpy;                                            \
+    tmpx = _mm256_loadu_ps(x);                                    \
+    tmpy = _mm256_loadu_ps(y);                                    \
+    tmpx = _mm256_add_ps(tmpx, tmpy);                             \
+    _mm256_storeu_ps(z, tmpx);                                    \
+  }
+
+// mkl > avx > for-loop, where ">" means "performs better"
+#ifdef PADDLE_USE_MKLML
+VADD_MKL_FLOAT(jit::avx, kEQ8)
+#elif defined __AVX__
+VADD_INTRI8_FLOAT(jit::avx)
+#endif
+// avx2 > mkl > for-loop
+#ifdef __AVX2__
+VADD_INTRI8_FLOAT(jit::avx2)
+#elif defined PADDLE_USE_MKLML
+VADD_MKL_FLOAT(jit::avx2, kEQ8)
+#endif
+// TODO(TJ): test and complete avx512
+
+/// eq16
+#ifdef PADDLE_USE_MKLML
+// TODO(TJ): test and complete me
+VADD_MKL_FLOAT(jit::avx, kEQ16)
+VADD_MKL_FLOAT(jit::avx2, kEQ16)
+VADD_MKL_FLOAT(jit::avx512f, kEQ16)
+#endif
+
+#undef VADD_INTRI8_FLOAT
+#undef VADD_MKL_FLOAT
+#undef VADD_MKL_DOUBLE
+
+BIND_KERNEL(VMulKernel, VMulCompute);
+BIND_KERNEL(VAddKernel, VAddCompute);
+
+#undef BIND_KERNEL
+#undef BIND_KERNEL_WITH_DTYPE
 #undef FOR_EACH_ISA_ALL_BLOCK
 #undef FOR_EACH_ALL_BLOCK
 #undef FOR_EACH_ISA_COMMON_BLOCK
 #undef FOR_EACH_COMMON_BLOCK
 #undef SEARCH_ISA_BLOCK
 #undef SEARCH_BLOCK
 
 }  // namespace jitkernel
 }  // namespace math
......
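Likewise, BIND_KERNEL(VAddKernel, VAddCompute) above expands to one constructor definition per dtype; each delegates to SEARCH_ISA_BLOCK (defined in a collapsed region of this file), which selects the VAddCompute specialization matching the runtime ISA and block size:

template <>
VAddKernel<float>::VAddKernel(int d) {
  SEARCH_ISA_BLOCK(VAddCompute, float);
}
template <>
VAddKernel<double>::VAddKernel(int d) {
  SEARCH_ISA_BLOCK(VAddCompute, double);
}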
@@ -23,25 +23,30 @@ TEST(JitKernel, pool) {
   namespace jit = paddle::operators::math::jitkernel;
   const int frame_size = 4;
   std::string act_gate = "sigmoid", act_cand = "tanh", act_cell = "tanh";
-  const auto& p1 =
+  const auto& plstm1 =
       jit::KernelPool::Instance()
           .template Get<jit::LSTMKernel<float>, int, const std::string&,
                         const std::string&, const std::string&>(
               frame_size, act_gate, act_cand, act_cell);
-  const auto& p2 =
+  const auto& plstm2 =
       jit::KernelPool::Instance()
           .template Get<jit::LSTMKernel<float>, int, const std::string&,
                         const std::string&, const std::string&>(
              frame_size, act_gate, act_cand, act_cell);
-  EXPECT_EQ(p1, p2);
+  EXPECT_EQ(plstm1, plstm2);
 
-  const auto& p3 =
+  const auto& pvmul_f =
       jit::KernelPool::Instance().template Get<jit::VMulKernel<float>>(4);
-  EXPECT_TRUE(std::dynamic_pointer_cast<jit::Kernel>(p2) !=
-              std::dynamic_pointer_cast<jit::Kernel>(p3));
+  EXPECT_TRUE(std::dynamic_pointer_cast<jit::Kernel>(plstm2) !=
+              std::dynamic_pointer_cast<jit::Kernel>(pvmul_f));
 
-  const auto& p4 =
+  const auto& pvmul_d =
       jit::KernelPool::Instance().template Get<jit::VMulKernel<double>>(4);
-  EXPECT_TRUE(std::dynamic_pointer_cast<jit::Kernel>(p3) !=
-              std::dynamic_pointer_cast<jit::Kernel>(p4));
+  EXPECT_TRUE(std::dynamic_pointer_cast<jit::Kernel>(pvmul_f) !=
+              std::dynamic_pointer_cast<jit::Kernel>(pvmul_d));
+
+  const auto& pvmul_from_key = jit::KernelPool::Instance().Get("vmulf4");
+  EXPECT_TRUE(pvmul_f == pvmul_from_key);
+  const auto& pvmul_from_key2 = jit::KernelPool::Instance().Get("vmulf5");
+  EXPECT_TRUE(pvmul_from_key2 == nullptr);
 }