提交 700bcbf7 编写于 作者: T Tomasz Patejko 提交者: Michal Gallus

MKLDNN elementwise_mul: h and w loops implemented in xbyak

上级 ad09faca
......@@ -30,16 +30,42 @@ struct vector_mul : public Xbyak::CodeGenerator {
// RDI is ptr X
// RSI is ptr Y
// RDX is ptr Z
// RCX is h
// r8 is w
vmovups(zmm2, ptr[rdi]);
push(rbx);
xor_(rax, rax);
xor_(r10, r10);
vmovups(zmm3, ptr[rsi]);
vmulps(zmm1, zmm2, zmm3);
vmovups(ptr[rdx], zmm1);
L("h_loop");
xor_(rbx, rbx);
L("w_loop");
vmovups(zmm2, ptr[rdi + rax]);
vmulps(zmm1, zmm2, zmm3);
vmovups(ptr[rdx + rax], zmm1);
add(rax, 64);
inc(rbx);
cmp(r8, rbx);
jnz("w_loop");
inc(r10);
cmp(r10, rcx);
jnz("h_loop");
pop(rbx);
ret();
}
};
void check(const float* x, const float* y, float* z, int w) {
for (int wi = 0; wi < w; wi++) {
for (int i = 0; i < 16; i++) {
z[wi * 16 + i] = x[wi * 16 + i] * y[i];
}
}
}
template <typename T>
class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
public:
......@@ -65,7 +91,6 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
PADDLE_THROW("Not implemented when post is 1");
} else {
// Just check whether it works for RE-Resnext.
PADDLE_ENFORCE_EQ(x_dims.size(), 4, "X should have 4 dimensions");
int n = x_dims[0];
......@@ -81,26 +106,21 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
vector_mul mul;
using mul_func_t = void (*)(const float*, const float*, float*);
using mul_func_t =
void (*)(const float*, const float*, float*, int, int);
mul_func_t mul_func = (mul_func_t)mul.getCode();
for (int ni = 0; ni < n; ni++) {
for (int ci = 0; ci < C; ci++) {
for (int hi = 0; hi < h; hi++) {
for (int wi = 0; wi < w; wi++) {
auto ptr_x = x_data + ni * C * h * w * simd_width +
ci * h * w * simd_width + hi * w * simd_width +
wi * simd_width;
auto ptr_y = y_data + ni * C * simd_width + ci * simd_width;
auto ptr_z = z_data + ni * C * h * w * simd_width +
ci * h * w * simd_width + hi * w * simd_width +
wi * simd_width;
mul_func(ptr_x, ptr_y, ptr_z);
}
}
auto ptr_x =
x_data + ni * C * h * w * simd_width + ci * h * w * simd_width;
auto ptr_y = y_data + ni * C * simd_width + ci * simd_width;
auto ptr_z =
z_data + ni * C * h * w * simd_width + ci * h * w * simd_width;
mul_func(ptr_x, ptr_y, ptr_z, h, w);
}
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册