提交 785066eb 编写于 作者: M Michal Gallus

MKLDNN elementwise_mul: Check if AVX512 is available

test=develop
上级 08f63c4d
...@@ -136,10 +136,13 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> { ...@@ -136,10 +136,13 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
UpdateDataFormat(ctx, (Tensor*)x, "x_data_format"); UpdateDataFormat(ctx, (Tensor*)x, "x_data_format");
UpdateDataFormat(ctx, (Tensor*)y, "y_data_format"); UpdateDataFormat(ctx, (Tensor*)y, "y_data_format");
Xbyak::util::Cpu cpu;
const bool is_avx512_enabled = cpu.has(Xbyak::util::Cpu::tAVX512F);
const bool are_dims_divisable = !(x_int_dims[1] % 16); const bool are_dims_divisable = !(x_int_dims[1] % 16);
const bool is_x_format_correct = x->format() == memory::format::nChw16c; const bool is_x_format_correct = x->format() == memory::format::nChw16c;
const bool is_y_format_correct = y->format() == memory::format::nc; const bool is_y_format_correct = y->format() == memory::format::nc;
if (is_x_format_correct && is_y_format_correct && are_dims_divisable) { if (is_x_format_correct && is_y_format_correct && are_dims_divisable &&
is_avx512_enabled) {
int pre, n, post; int pre, n, post;
get_mid_dims(x_dims, y_dims_untrimmed, axis, &pre, &n, &post); get_mid_dims(x_dims, y_dims_untrimmed, axis, &pre, &n, &post);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册