diff --git a/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc b/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc index 6371c9f83937cfc8b8969ca4a8f0dbe1751b1702..216c7ed9c6671767b4292b9b011ce5ebee7ed3d5 100644 --- a/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc +++ b/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc @@ -136,10 +136,13 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel { UpdateDataFormat(ctx, (Tensor*)x, "x_data_format"); UpdateDataFormat(ctx, (Tensor*)y, "y_data_format"); + Xbyak::util::Cpu cpu; + const bool is_avx512_enabled = cpu.has(Xbyak::util::Cpu::tAVX512F); const bool are_dims_divisable = !(x_int_dims[1] % 16); const bool is_x_format_correct = x->format() == memory::format::nChw16c; const bool is_y_format_correct = y->format() == memory::format::nc; - if (is_x_format_correct && is_y_format_correct && are_dims_divisable) { + if (is_x_format_correct && is_y_format_correct && are_dims_divisable && + is_avx512_enabled) { int pre, n, post; get_mid_dims(x_dims, y_dims_untrimmed, axis, &pre, &n, &post);