diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_mkldnn_op.cc b/paddle/fluid/operators/elementwise/elementwise_mul_mkldnn_op.cc
index 10290a4aeff6b6a023fb28961d12728aff891e83..c600d1e3d76f7a989dd61e72caf4967aa5923c6f 100644
--- a/paddle/fluid/operators/elementwise/elementwise_mul_mkldnn_op.cc
+++ b/paddle/fluid/operators/elementwise/elementwise_mul_mkldnn_op.cc
@@ -19,36 +19,21 @@ limitations under the License. */
 #include "paddle/fluid/platform/mkldnn_helper.h"
 
 #include "paddle/fluid/operators/math/jit_kernel.h"
-#include "xbyak.h"
-#include "xbyak_util.h"
+#include "xbyak/xbyak.h"
+#include "xbyak/xbyak_util.h"
 
 namespace paddle {
 namespace operators {
 
 using framework::DataLayout;
 using mkldnn::memory;
-
-static mkldnn::memory::format StringToMKLDNNFormat(std::string& format) {
-  std::transform(format.begin(), format.end(), format.begin(), ::tolower);
-
-  if (!format.compare("nchw")) {
-    return memory::format::nchw;
-  } else if (!format.compare("nchw16c")) {
-    return memory::format::nChw16c;
-  } else if (!format.compare("nchw8c")) {
-    return memory::format::nChw8c;
-  } else if (!format.compare("nhwc")) {
-    return memory::format::nhwc;
-  } else {
-    return memory::format::any;
-  }
-}
+using platform::StringToMKLDNNFormat;
 
 static void UpdateDataFormat(const framework::ExecutionContext& ctx,
                              framework::Tensor* tensor, const char* attribute) {
   if (ctx.op().HasAttr(attribute)) {
     auto format_as_string = ctx.Attr<std::string>(attribute);
-    auto format = StringToMKLDNNFormat(format_as_string);
+    auto format = StringToMKLDNNFormat(&format_as_string);
     if (format != memory::format::any) {
       tensor->set_format(format);
     }
@@ -93,8 +78,8 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
     auto y_dims_untrimmed = y->dims();
     auto x_int_dims = paddle::framework::vectorize2int(x_dims);
 
-    UpdateDataFormat(ctx, (Tensor*)x, "x_data_format");
-    UpdateDataFormat(ctx, (Tensor*)y, "y_data_format");
+    UpdateDataFormat(ctx, const_cast<Tensor*>(x), "x_data_format");
+    UpdateDataFormat(ctx, const_cast<Tensor*>(y), "y_data_format");
 
     Xbyak::util::Cpu cpu;
     const bool is_avx512_enabled = cpu.has(Xbyak::util::Cpu::tAVX512F);
@@ -156,10 +141,10 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
       auto& dev_ctx = ctx.template device_context<platform::MKLDNNDeviceContext>();
       const auto& mkldnn_engine = dev_ctx.GetEngine();
 
       if (!(is_x_nchw || is_x_nc))
-        ReorderInput<T>((Tensor*)x, ctx.GetPlace(), mkldnn_engine,
-                        x->dims().size() == 4);
+        ReorderInput<T>(const_cast<Tensor*>(x), ctx.GetPlace(), mkldnn_engine,
+                        x->dims().size() == 4);
       if (!(is_y_nchw || is_y_nc))
-        ReorderInput<T>((Tensor*)y, ctx.GetPlace(), mkldnn_engine,
-                        y->dims().size() == 4);
+        ReorderInput<T>(const_cast<Tensor*>(y), ctx.GetPlace(), mkldnn_engine,
+                        y->dims().size() == 4);
     }
diff --git a/paddle/fluid/platform/mkldnn_helper.h b/paddle/fluid/platform/mkldnn_helper.h
index 814012e6c1fad414d10f5a64af283bed57e11fe3..761a9815e098098cb4c4080bd8605dde7f6870a4 100644
--- a/paddle/fluid/platform/mkldnn_helper.h
+++ b/paddle/fluid/platform/mkldnn_helper.h
@@ -14,6 +14,7 @@ limitations under the License. */
 #pragma once
 
 #include <mkldnn.h>
+#include <algorithm>
 #include <string>
 #include <vector>
 #include "paddle/fluid/framework/operator.h"
@@ -292,5 +293,21 @@ inline mkldnn::memory::format data_format_to_memory_format(
   }
 }
 
+inline mkldnn::memory::format StringToMKLDNNFormat(std::string* format) {
+  std::transform(format->begin(), format->end(), format->begin(), ::tolower);
+
+  if (!format->compare("nchw")) {
+    return mkldnn::memory::format::nchw;
+  } else if (!format->compare("nchw16c")) {
+    return mkldnn::memory::format::nChw16c;
+  } else if (!format->compare("nchw8c")) {
+    return mkldnn::memory::format::nChw8c;
+  } else if (!format->compare("nhwc")) {
+    return mkldnn::memory::format::nhwc;
+  } else {
+    return mkldnn::memory::format::any;
+  }
+}
+
 }  // namespace platform
 }  // namespace paddle