Commit 9455be0b authored by Michal Gallus

EltwiseMul: Extract StringToFormat to MKLDNN helper

test=develop
Parent 726f2cef
@@ -19,36 +19,21 @@ limitations under the License. */
 #include "paddle/fluid/platform/mkldnn_helper.h"
 #include "paddle/fluid/operators/math/jit_kernel.h"
-#include "xbyak.h"
-#include "xbyak_util.h"
+#include "xbyak/xbyak.h"
+#include "xbyak/xbyak_util.h"

 namespace paddle {
 namespace operators {

 using framework::DataLayout;
 using mkldnn::memory;
+using platform::StringToMKLDNNFormat;

-static mkldnn::memory::format StringToMKLDNNFormat(std::string& format) {
-  std::transform(format.begin(), format.end(), format.begin(), ::tolower);
-
-  if (!format.compare("nchw")) {
-    return memory::format::nchw;
-  } else if (!format.compare("nchw16c")) {
-    return memory::format::nChw16c;
-  } else if (!format.compare("nchw8c")) {
-    return memory::format::nChw8c;
-  } else if (!format.compare("nhwc")) {
-    return memory::format::nhwc;
-  } else {
-    return memory::format::any;
-  }
-}
-
 static void UpdateDataFormat(const framework::ExecutionContext& ctx,
                              framework::Tensor* tensor, const char* attribute) {
   if (ctx.op().HasAttr(attribute)) {
     auto format_as_string = ctx.Attr<std::string>(attribute);
-    auto format = StringToMKLDNNFormat(format_as_string);
+    auto format = StringToMKLDNNFormat(&format_as_string);
     if (format != memory::format::any) {
       tensor->set_format(format);
     }
@@ -93,8 +78,8 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
     auto y_dims_untrimmed = y->dims();
     auto x_int_dims = paddle::framework::vectorize2int(x_dims);

-    UpdateDataFormat(ctx, (Tensor*)x, "x_data_format");
-    UpdateDataFormat(ctx, (Tensor*)y, "y_data_format");
+    UpdateDataFormat(ctx, const_cast<Tensor*>(x), "x_data_format");
+    UpdateDataFormat(ctx, const_cast<Tensor*>(y), "y_data_format");

     Xbyak::util::Cpu cpu;
     const bool is_avx512_enabled = cpu.has(Xbyak::util::Cpu::tAVX512F);
@@ -156,10 +141,10 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
       auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
       const auto& mkldnn_engine = dev_ctx.GetEngine();

       if (!(is_x_nchw || is_x_nc))
-        ReorderInput<T>((Tensor*)x, ctx.GetPlace(), mkldnn_engine,
-                        x->dims().size() == 4);
+        ReorderInput<T>(const_cast<Tensor*>(x), ctx.GetPlace(), mkldnn_engine,
+                        x->dims().size() == 4);
       if (!(is_y_nchw || is_y_nc))
-        ReorderInput<T>((Tensor*)y, ctx.GetPlace(), mkldnn_engine,
-                        y->dims().size() == 4);
+        ReorderInput<T>(const_cast<Tensor*>(y), ctx.GetPlace(), mkldnn_engine,
+                        y->dims().size() == 4);
     }
@@ -14,6 +14,7 @@ limitations under the License. */
 #pragma once

 #include <mkldnn.h>
+#include <algorithm>
 #include <string>
 #include <vector>
 #include "paddle/fluid/framework/operator.h"
@@ -292,5 +293,21 @@ inline mkldnn::memory::format data_format_to_memory_format(
   }
 }

+inline mkldnn::memory::format StringToMKLDNNFormat(std::string* format) {
+  std::transform(format->begin(), format->end(), format->begin(), ::tolower);
+
+  if (!format->compare("nchw")) {
+    return mkldnn::memory::format::nchw;
+  } else if (!format->compare("nchw16c")) {
+    return mkldnn::memory::format::nChw16c;
+  } else if (!format->compare("nchw8c")) {
+    return mkldnn::memory::format::nChw8c;
+  } else if (!format->compare("nhwc")) {
+    return mkldnn::memory::format::nhwc;
+  } else {
+    return mkldnn::memory::format::any;
+  }
+}
+
 }  // namespace platform
 }  // namespace paddle