Commit 74292f41 authored by tensor-tang

enable eltwise nchw16c mul nc

Parent: 720b55cb
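For context: this commit adds a JIT kernel, nchw16cmulnc, that multiplies a tensor stored in MKL-DNN's nChw16c layout element-wise by an NC tensor (the 16 floats of one channel group are broadcast over all h * w positions of that channel block), and it switches ElementwiseMulMKLDNNKernel from the old math::jitkernel::KernelPool interface to the new paddle/fluid/operators/jit one. A minimal sketch of the intended semantics for a single (n, channel-block) slice, mirroring the reference kernel added further down (the function name here is illustrative):

```cpp
// x and z hold one nChw16c block: h * w consecutive groups of 16 floats.
// y holds the matching 16 floats of the NC operand, reused at every pixel.
void nchw16c_mul_nc_slice(const float* x, const float* y, float* z,
                          int height, int width) {
  int offset = 0;
  for (int h = 0; h < height; ++h) {
    for (int w = 0; w < width; ++w) {
      for (int i = 0; i < 16; ++i) {
        z[offset + i] = x[offset + i] * y[i];
      }
      offset += 16;  // next 16-float group
    }
  }
}
```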
@@ -18,7 +18,7 @@ limitations under the License. */
 #include "paddle/fluid/platform/mkldnn_helper.h"
-#include "paddle/fluid/operators/math/jit_kernel.h"
+#include "paddle/fluid/operators/jit/kernels.h"
 #include "xbyak/xbyak.h"
 #include "xbyak/xbyak_util.h"
@@ -108,10 +108,8 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
       constexpr int simd_width = 16;
       int C = c / simd_width;
-      const auto& multiply =
-          math::jitkernel::KernelPool::Instance()
-              .template Get<math::jitkernel::EltwiseMulnChw16cNCKernel<T>>(n);
+      auto multiply = jit::Get<jit::nchw16cmulnc, jit::NCHW16CMulNCTuples<T>,
+                               platform::CPUPlace>(0);
 #pragma omp parallel for collapse(2)
       for (int ni = 0; ni < n; ni++) {
         for (int ci = 0; ci < C; ci++) {
@@ -122,7 +120,7 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
           auto ptr_z =
               z_data + ni * C * h * w * simd_width + ci * h * w * simd_width;
-          multiply->Compute(ptr_x, ptr_y, ptr_z, h, w);
+          multiply(ptr_x, ptr_y, ptr_z, h, w);
         }
       }
     }
...
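Note the pattern at this call site: with the new API the kernel is fetched once via jit::Get as a plain function pointer and then invoked directly inside the loops, rather than going through a pooled kernel object's Compute method. The `collapse(2)` pragma parallelizes jointly over the batch and channel-block indices, so each iteration handles one independent 16-channel block. The attribute argument `0` is unused by this kernel (see the `int d /*unused*/` constructor parameter below).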
@@ -25,3 +25,4 @@ USE_JITKERNEL_GEN(lstmc1h1)
 USE_JITKERNEL_GEN(gruh1)
 USE_JITKERNEL_GEN(gruhtpart1)
 USE_JITKERNEL_GEN(gruhtpart2)
+USE_JITKERNEL_GEN(nchw16cmulnc)
...
@@ -104,6 +104,48 @@ void VXXJitCode::genCode() {
   ret();
 }
+void NCHW16CMulNCJitCode::genCode() {
+  // RDI is ptr x_input
+  // RSI is ptr y_input
+  // RDX is ptr output
+  // RCX is height
+  // r8 is width
+  push(rbx);
+  xor_(rax, rax);  // rax: byte offset into x and z
+  xor_(r10, r10);  // r10: height counter
+  vmovups(zmm3, ptr[rsi]);  // load the 16 floats of y once; reused below
+  L("h_loop");
+  xor_(rbx, rbx);  // rbx: width counter
+  L("w_loop");
+  vmovups(zmm2, ptr[rdi + rax]);
+  vmulps(zmm1, zmm2, zmm3);
+  vmovups(ptr[rdx + rax], zmm1);
+  add(rax, 64);  // advance 64 bytes = 16 floats
+  inc(rbx);
+  cmp(r8, rbx);
+  jnz("w_loop");
+  inc(r10);
+  cmp(r10, rcx);
+  jnz("h_loop");
+  pop(rbx);
+  ret();
+}
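To make the loop structure explicit: y stays resident in zmm3 for the whole call, and each w_loop iteration loads, multiplies, and stores one 16-float group. Under the SysV x86-64 calling convention assumed here (x in rdi, y in rsi, z in rdx, height in rcx, width in r8), the generated code behaves like this C++ sketch:

```cpp
// Behavioral equivalent of NCHW16CMulNCJitCode::genCode() for float data.
void nchw16c_mul_nc_generated(const float* x, const float* y, float* z,
                              int height, int width) {
  // vmovups zmm3, [rsi]: y[0..15] is loaded once, reused at every position
  int offset = 0;                         // rax (the asm counts in bytes)
  for (int h = 0; h < height; ++h) {      // r10 drives h_loop
    for (int w = 0; w < width; ++w) {     // rbx drives w_loop
      for (int i = 0; i < 16; ++i) {      // one vmovups/vmulps/vmovups
        z[offset + i] = x[offset + i] * y[i];
      }
      offset += 16;                       // add rax, 64
    }
  }
}
```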
+class NCHW16CMulNCCreator : public JitCodeCreator<int> {
+ public:
+  bool UseMe(const int& attr) const override {
+    return platform::MayIUse(platform::avx512f);
+  }
+  size_t CodeSize(const int& d) const override { return 256 * 1024; }
+  std::unique_ptr<GenBase> CreateJitCode(const int& attr) const override {
+    return make_unique<NCHW16CMulNCJitCode>(attr, CodeSize(attr));
+  }
+};
 #define DECLARE_BLAS_CREATOR(name) \
   class name##Creator : public JitCodeCreator<int> { \
    public: \
@@ -141,3 +183,4 @@ REGISTER_JITKERNEL_GEN(vadd, gen::VAddCreator);
 REGISTER_JITKERNEL_GEN(vaddrelu, gen::VAddReluCreator);
 REGISTER_JITKERNEL_GEN(vscal, gen::VScalCreator);
 REGISTER_JITKERNEL_GEN(vaddbias, gen::VAddBiasCreator);
+REGISTER_JITKERNEL_GEN(nchw16cmulnc, gen::NCHW16CMulNCCreator);
...
@@ -99,6 +99,18 @@ DECLARE_BLAS_JITCODE(VAddBias, operand_type::add, 1, false);
 #undef DECLARE_BLAS_JITCODE
+// nChw16c = nChw16c .* NC
+class NCHW16CMulNCJitCode : public JitCode {
+ public:
+  DECLARE_JIT_CODE(NCHW16CMulNCJitCode);
+  explicit NCHW16CMulNCJitCode(int d /*unused*/, size_t code_size,
+                               void* code_ptr = nullptr)
+      : JitCode(code_size, code_ptr) {
+    this->genCode();
+  }
+  void genCode() override;
+};
 }  // namespace gen
 }  // namespace jit
 }  // namespace operators
...
@@ -44,8 +44,9 @@ const char* to_string(KernelType kt) {
     ONE_CASE(gruhtpart2);
     ONE_CASE(crfdecoding);
     ONE_CASE(layernorm);
+    ONE_CASE(nchw16cmulnc);
     default:
-      PADDLE_THROW("Not support type: %d", kt);
+      PADDLE_THROW("Not support type: %d, or forget to add it.", kt);
       return "NOT JITKernel";
   }
   return nullptr;
...
@@ -93,6 +93,7 @@ inline typename KernelTuples::func_type GetRefer() {
 template <KernelType KT, typename KernelTuples,
           typename PlaceType = platform::CPUPlace>
+// TODO(TJ): const & attr
 typename KernelTuples::func_type Get(typename KernelTuples::attr_type attr) {
   auto jitfunc = GetJitCode<KT, KernelTuples, PlaceType>(attr);
   if (jitfunc) {
...
@@ -39,7 +39,8 @@ typedef enum {
   gruhtpart1,
   gruhtpart2,
   crfdecoding,
-  layernorm
+  layernorm,
+  nchw16cmulnc,
 } KernelType;
 template <typename T>
@@ -126,6 +127,14 @@ struct LayerNormTuples {
                             const float, int);
 };
+// nChw16c = nChw16c .* NC
+template <typename T>
+struct NCHW16CMulNCTuples {
+  typedef T data_type;
+  typedef int attr_type;
+  typedef void (*func_type)(const T*, const T*, T*, int, int);
+};
 // Just for adding to kernel pool without template
 class Kernel {
  public:
...
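NCHW16CMulNCTuples is what ties the pieces together: its func_type fixes the signature that every implementation of this kernel (JIT-generated or reference) must expose, and attr_type is the lookup key. A hedged sketch of how a caller retrieves and uses the kernel, following the pattern in the test added at the end of this commit (ptr_x, ptr_y, ptr_z are assumed to point into one 16-channel block):

```cpp
namespace jit = paddle::operators::jit;
using Tuples = jit::NCHW16CMulNCTuples<float>;

// jit::Get tries a generated kernel first and falls back to other
// implementations, ultimately the reference one; jit::GetRefer returns
// the reference kernel directly. Both yield a plain function pointer of
// type Tuples::func_type, i.e.
// void (*)(const float* x, const float* y, float* z, int h, int w).
auto tgt = jit::Get<jit::nchw16cmulnc, Tuples, paddle::platform::CPUPlace>(0);
auto ref = jit::GetRefer<jit::nchw16cmulnc, Tuples>();
tgt(ptr_x, ptr_y, ptr_z, h, w);  // one 16-channel block, h * w positions
ref(ptr_x, ptr_y, ptr_z, h, w);
```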
@@ -25,3 +25,4 @@ USE_JITKERNEL_REFER(gruhtpart1)
 USE_JITKERNEL_REFER(gruhtpart2)
 USE_JITKERNEL_REFER(crfdecoding)
 USE_JITKERNEL_REFER(layernorm)
+USE_JITKERNEL_REFER(nchw16cmulnc)
...
@@ -45,4 +45,6 @@ REGISTER_REFER_KERNEL(gruhtpart2, GRUHtPart2);
 REGISTER_REFER_KERNEL(crfdecoding, CRFDecoding);
 REGISTER_REFER_KERNEL(layernorm, LayerNorm);
+REGISTER_REFER_KERNEL(nchw16cmulnc, NCHW16CMulNC);
 #undef REGISTER_REFER_KERNEL
...
@@ -319,6 +319,19 @@ void LayerNorm(T* x, T* out, T* mean, T* var, const T* scale, const T* bias,
   }
 }
+template <typename T>
+void NCHW16CMulNC(const T* x, const T* y, T* z, int height, int width) {
+  int offset = 0;
+  for (int h = 0; h < height; ++h) {
+    for (int w = 0; w < width; ++w) {
+      for (int i = 0; i < 16; ++i) {
+        z[i + offset] = y[i] * x[i + offset];
+      }
+      offset += ZMM_FLOAT_BLOCK;
+    }
+  }
+}
 #define DECLARE_REFER_KERNEL(name, tuples) \
   template <typename T> \
   class name##Kernel : public ReferKernel<tuples<T>> { \
@@ -355,6 +368,8 @@ DECLARE_REFER_KERNEL(GRUHtPart2, GRUTuples);
 DECLARE_REFER_KERNEL(CRFDecoding, CRFDecodingTuples);
 DECLARE_REFER_KERNEL(LayerNorm, LayerNormTuples);
+DECLARE_REFER_KERNEL(NCHW16CMulNC, NCHW16CMulNCTuples);
 #undef DECLARE_REFER_KERNEL
 }  // namespace refer
...
@@ -19,6 +19,7 @@
 #include "glog/logging.h"
 #include "gtest/gtest.h"
 #include "paddle/fluid/operators/jit/kernels.h"
+#include "paddle/fluid/platform/cpu_info.h"
 #include "paddle/fluid/platform/place.h"
 template <typename T>
@@ -414,6 +415,59 @@ void TestGRUKernel() {
   }
 }
+template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
+void TestNCHW16CMulNCKernel() {
+  VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
+  const int n = 3, c = 16 * 4, h = 10, w = 10;
+  auto ref = jit::GetRefer<KT, jit::NCHW16CMulNCTuples<T>>();
+  EXPECT_TRUE(ref != nullptr);
+  int sz = n * c * h * w;
+  std::vector<T> x(sz), y(n * c), zref(sz);
+  std::vector<T> ztgt(sz), zjit(sz);
+  RandomVec<T>(sz, x.data(), -2.f, 2.f);
+  RandomVec<T>(n * c, y.data(), -2.f, 2.f);
+  const T* x_data = x.data();
+  const T* y_data = y.data();
+  T* zref_data = zref.data();
+  T* ztgt_data = ztgt.data();
+  T* zjit_data = zjit.data();
+  constexpr int simd_width = ZMM_FLOAT_BLOCK;
+  int C = c / simd_width;
+  auto tgt = jit::Get<KT, jit::NCHW16CMulNCTuples<T>, PlaceType>(0);
+  auto jitcode = jit::GetJitCode<KT, jit::NCHW16CMulNCTuples<T>, PlaceType>(0);
+  EXPECT_TRUE(tgt != nullptr);
+  if (std::is_same<T, float>::value &&
+      paddle::platform::MayIUse(paddle::platform::avx512f)) {
+    EXPECT_TRUE(jitcode != nullptr);
+  }
+  for (int ni = 0; ni < n; ni++) {
+    for (int ci = 0; ci < C; ci++) {
+      auto ptr_x =
+          x_data + ni * C * h * w * simd_width + ci * h * w * simd_width;
+      auto ptr_y = y_data + ni * C * simd_width + ci * simd_width;
+      auto ptr_zref =
+          zref_data + ni * C * h * w * simd_width + ci * h * w * simd_width;
+      auto ptr_ztgt =
+          ztgt_data + ni * C * h * w * simd_width + ci * h * w * simd_width;
+      ref(ptr_x, ptr_y, ptr_zref, h, w);
+      tgt(ptr_x, ptr_y, ptr_ztgt, h, w);
+      if (jitcode) {
+        auto ptr_zjit =
+            zjit_data + ni * C * h * w * simd_width + ci * h * w * simd_width;
+        jitcode(ptr_x, ptr_y, ptr_zjit, h, w);
+      }
+    }
+  }
+  ExpectEQ<T>(ztgt_data, zref_data, sz);
+  if (jitcode) {
+    ExpectEQ<T>(zjit_data, zref_data, sz);
+  }
+}
 // XYZNTuple
 TEST(JITKernel, vmul) {
   namespace jit = paddle::operators::jit;
@@ -515,6 +569,14 @@ TEST(JITKernel, gruhtpart2) {
   TestGRUKernel<jit::gruhtpart2, double, paddle::platform::CPUPlace>();
 }
+TEST(JITKernel, nchw16cmulnc) {
+  namespace jit = paddle::operators::jit;
+  TestNCHW16CMulNCKernel<jit::nchw16cmulnc, float,
+                         paddle::platform::CPUPlace>();
+  TestNCHW16CMulNCKernel<jit::nchw16cmulnc, double,
+                         paddle::platform::CPUPlace>();
+}
 // TODO(yihua/TJ): add crf decoding and layer norm unit tests
 TEST(JITKernel, pool) {
...