Commit c69c4160 authored by Michal Gallus

MKLDNN elementwise_mul: Move Kernel to KernelPool to avoid segfaults

test=develop
Parent 99e3e36a
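The crux of the change is the lifetime of the JIT-generated code. Previously the Xbyak::CodeGenerator (vector_mul) was constructed on the stack inside the operator kernel, and the raw pointer returned by getCode() pointed into that local object's code buffer; now the generator is built once, cached in the singleton KernelPool, and looked up by kernel type and key, so the executable buffer outlives any single operator invocation. That lifetime difference is presumably the root of the segfaults the title mentions. A minimal before/after sketch of the call site, using only identifiers that appear in this diff (not a standalone compilable unit):

// Before: the generator was a stack local; getCode() returned a pointer
// into its code buffer, which vanished with the enclosing scope.
vector_mul mul;
mul_func_t mul_func = (mul_func_t)mul.getCode();
mul_func(ptr_x, ptr_y, ptr_z, h, w);

// After: the generated kernel is created once and cached in KernelPool.
const auto& multiply =
    math::jitkernel::KernelPool::Instance()
        .template Get<math::jitkernel::EltwiseMulnChw16cNCKernel<T>>(n);
multiply->Compute(ptr_x, ptr_y, ptr_z, h, w);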
@@ -13,13 +13,14 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include <mkldnn/include/mkldnn.hpp>
-#include "paddle/fluid/operators/elementwise_op.h"
-#include "paddle/fluid/operators/elementwise_op_function.h"
+#include "paddle/fluid/operators/elementwise/elementwise_op.h"
+#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
 
 #include "paddle/fluid/platform/mkldnn_helper.h"
 
-#include "xbyak/xbyak.h"
-#include "xbyak/xbyak_util.h"
+#include "paddle/fluid/operators/math/jit_kernel.h"
+#include "xbyak.h"
+#include "xbyak_util.h"
 
 namespace paddle {
 namespace operators {
@@ -27,47 +28,6 @@ namespace operators {
 using framework::DataLayout;
 using mkldnn::memory;
 
-struct vector_mul : public Xbyak::CodeGenerator {
-  vector_mul() {
-    // RDI is ptr X
-    // RSI is ptr Y
-    // RDX is ptr Z
-    // RCX is h
-    // r8 is w
-
-    push(rbx);
-
-    xor_(rax, rax);
-    xor_(r10, r10);
-    vmovups(zmm3, ptr[rsi]);
-
-    L("h_loop");
-    xor_(rbx, rbx);
-    L("w_loop");
-    vmovups(zmm2, ptr[rdi + rax]);
-    vmulps(zmm1, zmm2, zmm3);
-    vmovups(ptr[rdx + rax], zmm1);
-    add(rax, 64);
-    inc(rbx);
-    cmp(r8, rbx);
-    jnz("w_loop");
-    inc(r10);
-    cmp(r10, rcx);
-    jnz("h_loop");
-
-    pop(rbx);
-    ret();
-  }
-};
-
-void check(const float* x, const float* y, float* z, int w) {
-  for (int wi = 0; wi < w; wi++) {
-    for (int i = 0; i < 16; i++) {
-      z[wi * 16 + i] = x[wi * 16 + i] * y[i];
-    }
-  }
-}
-
 static mkldnn::memory::format StringToMKLDNNFormat(std::string& format) {
   std::transform(format.begin(), format.end(), format.begin(), ::tolower);
@@ -163,12 +123,9 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
       constexpr int simd_width = 16;
       int C = c / simd_width;
 
-      vector_mul mul;
-
-      using mul_func_t =
-          void (*)(const float*, const float*, float*, int, int);
-
-      mul_func_t mul_func = (mul_func_t)mul.getCode();
+      const auto& multiply =
+          math::jitkernel::KernelPool::Instance()
+              .template Get<math::jitkernel::EltwiseMulnChw16cNCKernel<T>>(n);
 
 #pragma omp parallel for collapse(2)
       for (int ni = 0; ni < n; ni++) {
@@ -180,7 +137,7 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
           auto ptr_z =
               z_data + ni * C * h * w * simd_width + ci * h * w * simd_width;
 
-          mul_func(ptr_x, ptr_y, ptr_z, h, w);
+          multiply->Compute(ptr_x, ptr_y, ptr_z, h, w);
         }
       }
     }
@@ -156,6 +156,42 @@ class VActJitCode : public JitCode {
   ymm_t ymm_dst = ymm_t(1);
 };
 
+#ifdef PADDLE_WITH_MKLDNN
+struct EltwiseMulnChw16cNC : public Xbyak::CodeGenerator {
+  explicit EltwiseMulnChw16cNC(size_t code_size = 256 * 1024)
+      : Xbyak::CodeGenerator(code_size) {
+    // RDI is ptr x_input
+    // RSI is ptr y_input
+    // RDX is ptr output
+    // RCX is height
+    // r8 is width
+
+    push(rbx);
+
+    xor_(rax, rax);
+    xor_(r10, r10);
+    vmovups(zmm3, ptr[rsi]);
+
+    L("h_loop");
+    xor_(rbx, rbx);
+    L("w_loop");
+    vmovups(zmm2, ptr[rdi + rax]);
+    vmulps(zmm1, zmm2, zmm3);
+    vmovups(ptr[rdx + rax], zmm1);
+    add(rax, 64);
+    inc(rbx);
+    cmp(r8, rbx);
+    jnz("w_loop");
+    inc(r10);
+    cmp(r10, rcx);
+    jnz("h_loop");
+
+    pop(rbx);
+    ret();
+  }
+};
+#endif
+
 }  // namespace gen
 }  // namespace jitkernel
 }  // namespace math
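For readers who do not read AVX-512 assembly: the generator above emits a two-level loop that loads the 16-float y block into zmm3 exactly once, then for each of the h * w consecutive 64-byte blocks of x performs one vmulps against zmm3 and stores the result into z. Below is a scalar C++ reference with the same contract, a sketch extrapolated from the generated code and from the check() helper removed earlier in this diff (the function name is illustrative):

#include <cstddef>

// Scalar reference for EltwiseMulnChw16cNC: x and z are nChw16c data with
// h * w positions of 16 floats each; y is a single 16-float NC block that
// is broadcast across every position (the asm loads it into zmm3 once).
void eltwise_mul_nchw16c_ref(const float* x, const float* y, float* z,
                             int h, int w) {
  for (int hi = 0; hi < h; hi++) {     // corresponds to h_loop
    for (int wi = 0; wi < w; wi++) {   // w_loop: one 64-byte block per step
      const size_t off = (static_cast<size_t>(hi) * w + wi) * 16;
      for (int i = 0; i < 16; i++) {
        z[off + i] = x[off + i] * y[i];  // vmulps(zmm1, zmm2, zmm3)
      }
    }
  }
}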
@@ -94,6 +94,15 @@ class VAddBiasKernel : public Kernel {
   void (*Compute)(const T *, const T *, T *, int);
 };
 
+#ifdef PADDLE_WITH_MKLDNN
+template <typename T>
+class EltwiseMulnChw16cNCKernel : public Kernel {
+ public:
+  // nChw16c = nChw16c .* NC
+  void (*Compute)(const float *, const float *, float *, int, int);
+};
+#endif
+
 template <typename T>
 class VActKernel : public Kernel {
  public:
@@ -226,6 +226,44 @@ bool VAddKernelImpl<double>::useMKL(int d) {
 }
 #endif
 
+#ifdef PADDLE_WITH_MKLDNN
+/* EltwiseMul for nChw16c & NC inputs JitKernel */
+template <typename T>
+class EltwiseMulnChw16cNCKernelImpl
+    : public math::jitkernel::EltwiseMulnChw16cNCKernel<T> {
+ public:
+  JITKERNEL_DECLARE_STATIC_FUNC;
+  explicit EltwiseMulnChw16cNCKernelImpl(int d)
+      : EltwiseMulnChw16cNCKernel<T>() {
+    using mul_func_t = void (*)(const float*, const float*, float*, int, int);
+#ifdef PADDLE_WITH_XBYAK
+    if (useJIT(d)) {
+      // roughly estimate the size of code
+      size_t sz = 96 + d / YMM_FLOAT_BLOCK * 4 * 8;
+      sz = sz > 4096 ? sz : 4096;
+      jitcode_.reset(new gen::EltwiseMulnChw16cNC(sz));
+      this->Compute = (mul_func_t)jitcode_->getCode();
+      return;
+    }
+#endif
+    PADDLE_THROW(
+        "This kernel shouldn't be used in Non-Xbyak, Non-MKL-DNN "
+        "environment");
+  }
+
+#ifdef PADDLE_WITH_XBYAK
+
+ private:
+  std::unique_ptr<gen::EltwiseMulnChw16cNC> jitcode_{nullptr};
+};
+
+template <>
+bool EltwiseMulnChw16cNCKernelImpl<float>::useJIT(int d) {
+  return true;
+}
+#endif
+#endif
+
 /* VAddRelu JitKernel */
 template <typename T>
 class VAddReluKernelImpl : public VAddReluKernel<T> {
@@ -394,6 +432,9 @@ REGISTER_JITKERNEL(vscal, VScalKernel);
 REGISTER_JITKERNEL(vaddbias, VAddBiasKernel);
 REGISTER_JITKERNEL(vrelu, VReluKernel);
 REGISTER_JITKERNEL(videntity, VIdentityKernel);
+#ifdef PADDLE_WITH_MKLDNN
+REGISTER_JITKERNEL(eltwise_mul_nchw16c, EltwiseMulnChw16cNCKernel);
+#endif
 
 }  // namespace jitkernel
 }  // namespace math