Commit e81f0228 authored by: Adam, committed by: Tao Luo

MKL-DNN 1.0 Update (#20162)

* MKLDNN v1.0 rebase to Paddle 1.6
test=develop

* Add hacky paddle::string::to_string() implementation

* vectorize<int64_t>() -> vectorize() cleanup
test=develop

* PADDLE_ENFORCE and void_cast fixes
test=develop

* Rebase changes
test=develop

* Cosmetics
test=develop

* Delete MKL from mkldnn.cmake
test=develop

* CMake debug commands
test=develop

* Delete MKLDNN_VERBOSE and rebase fixes
test=develop

* Rebase fixes
test=develop

* Temporarily disable int8 resnet101 vgg16 and vgg19 tests
test=develop

* Add libmkldnn.so.1 to python setup
test=develop

* Add libmkldnn.so.1 to inference_lib cmake after rebase
test=develop

* Post rebase fixes + FC int8 changes
test=develop

* Fix LRN NHWC
test=develop

* Fix NHWC conv3d
test=develop

* Windows build fix + next conv3d fix
test=develop

* Fix conv2d on AVX2 machines
test=develop
Parent 7f5d532a
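
The diff below migrates every MKL-DNN call site from the 0.x execution model (collect primitives into a std::vector<mkldnn::primitive> pipeline and run it via mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait()) to the 1.0 model, where a stream is constructed from an engine and each primitive is run through execute() followed by astream.wait(). A minimal standalone sketch of the new pattern, not part of this commit (shapes, data, and variable names below are illustrative only):

// Illustrative sketch only (not from the commit): a reorder written against
// the MKL-DNN 1.0 C++ API, showing the stream/execute pattern used below.
#include <vector>
#include "mkldnn.hpp"

int main() {
  using namespace mkldnn;
  engine cpu_engine(engine::kind::cpu, 0);
  stream astream(cpu_engine);  // v1.0: streams are created from an engine

  // Hypothetical 1x3x4x4 fp32 tensor, reordered from NCHW to NHWC.
  memory::dims dims = {1, 3, 4, 4};
  std::vector<float> src_data(1 * 3 * 4 * 4, 1.0f);
  std::vector<float> dst_data(src_data.size(), 0.0f);

  // v1.0: memory::format became memory::format_tag; undef replaces format_undef
  auto src_md = memory::desc(dims, memory::data_type::f32, memory::format_tag::nchw);
  auto dst_md = memory::desc(dims, memory::data_type::f32, memory::format_tag::nhwc);
  memory src_mem(src_md, cpu_engine, src_data.data());
  memory dst_mem(dst_md, cpu_engine, dst_data.data());

  auto reorder_p = reorder(reorder::primitive_desc(src_mem, dst_mem));

  // v1.0: no pipeline/submit; primitives execute on an explicit stream
  reorder_p.execute(astream, src_mem, dst_mem);
  astream.wait();
  return 0;
}

Primitives with several inputs (sum, concat, convolution with bias) use the argument-map overload instead, e.g. conv_p->execute(astream, {{MKLDNN_ARG_SRC, *src}, {MKLDNN_ARG_WEIGHTS, *weights}, {MKLDNN_ARG_DST, *dst}}), as seen throughout the kernels in this diff.
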
@@ -19,7 +19,7 @@ SET(MKLDNN_PREFIX_DIR ${THIRD_PARTY_PATH}/mkldnn)
SET(MKLDNN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/mkldnn) SET(MKLDNN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/mkldnn)
SET(MKLDNN_INC_DIR "${MKLDNN_INSTALL_DIR}/include" CACHE PATH "mkldnn include directory." FORCE) SET(MKLDNN_INC_DIR "${MKLDNN_INSTALL_DIR}/include" CACHE PATH "mkldnn include directory." FORCE)
SET(MKLDNN_REPOSITORY https://github.com/intel/mkl-dnn.git) SET(MKLDNN_REPOSITORY https://github.com/intel/mkl-dnn.git)
SET(MKLDNN_TAG aef88b7c233f48f8b945da310f1b973da31ad033) SET(MKLDNN_TAG 518a316a8cd6deb82dc7866bc04bd0355a25c3a4)
# Introduce variables: # Introduce variables:
# * CMAKE_INSTALL_LIBDIR # * CMAKE_INSTALL_LIBDIR
@@ -35,13 +35,6 @@ SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${MKLDNN_INSTALL_DIR}/${LIBDIR
INCLUDE_DIRECTORIES(${MKLDNN_INC_DIR}) # For MKLDNN code to include internal headers. INCLUDE_DIRECTORIES(${MKLDNN_INC_DIR}) # For MKLDNN code to include internal headers.
IF(${CBLAS_PROVIDER} STREQUAL "MKLML")
SET(MKLDNN_DEPENDS ${MKLML_PROJECT})
MESSAGE(STATUS "Build MKLDNN with MKLML ${MKLML_ROOT}")
ELSE()
MESSAGE(FATAL_ERROR "Should enable MKLML when build MKLDNN")
ENDIF()
IF(NOT WIN32) IF(NOT WIN32)
SET(MKLDNN_FLAG "-Wno-error=strict-overflow -Wno-error=unused-result -Wno-error=array-bounds") SET(MKLDNN_FLAG "-Wno-error=strict-overflow -Wno-error=unused-result -Wno-error=array-bounds")
SET(MKLDNN_FLAG "${MKLDNN_FLAG} -Wno-unused-result -Wno-unused-value") SET(MKLDNN_FLAG "${MKLDNN_FLAG} -Wno-unused-result -Wno-unused-value")
@@ -63,7 +56,8 @@ ExternalProject_Add(
DEPENDS ${MKLDNN_DEPENDS} DEPENDS ${MKLDNN_DEPENDS}
PREFIX ${MKLDNN_PREFIX_DIR} PREFIX ${MKLDNN_PREFIX_DIR}
SOURCE_DIR ${MKLDNN_SOURCE_DIR} SOURCE_DIR ${MKLDNN_SOURCE_DIR}
UPDATE_COMMAND "" BUILD_ALWAYS 1
# UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
-DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE} -DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}
@@ -77,9 +71,8 @@ ExternalProject_Add(
-DMKLROOT=${MKLML_ROOT} -DMKLROOT=${MKLML_ROOT}
-DCMAKE_C_FLAGS=${MKLDNN_CFLAG} -DCMAKE_C_FLAGS=${MKLDNN_CFLAG}
-DCMAKE_CXX_FLAGS=${MKLDNN_CXXFLAG} -DCMAKE_CXX_FLAGS=${MKLDNN_CXXFLAG}
-DWITH_TEST=OFF -DWITH_EXAMPLE=OFF -DMKLDNN_BUILD_TESTS=OFF -DMKLDNN_BUILD_EXAMPLES=OFF
CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${MKLDNN_INSTALL_DIR} CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${MKLDNN_INSTALL_DIR}
-DMKLROOT:PATH=${MKLML_ROOT}
) )
if(WIN32) if(WIN32)
SET(MKLDNN_LIB "${MKLDNN_INSTALL_DIR}/${LIBDIR}/mkldnn.lib" CACHE FILEPATH "mkldnn library." FORCE) SET(MKLDNN_LIB "${MKLDNN_INSTALL_DIR}/${LIBDIR}/mkldnn.lib" CACHE FILEPATH "mkldnn library." FORCE)
@@ -98,7 +91,7 @@ add_definitions(-DPADDLE_WITH_MKLDNN)
SET(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/mkldnn_dummy.c) SET(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/mkldnn_dummy.c)
FILE(WRITE ${dummyfile} "const char * dummy = \"${dummyfile}\";") FILE(WRITE ${dummyfile} "const char * dummy = \"${dummyfile}\";")
ADD_LIBRARY(mkldnn STATIC ${dummyfile}) ADD_LIBRARY(mkldnn STATIC ${dummyfile})
TARGET_LINK_LIBRARIES(mkldnn ${MKLDNN_LIB} ${MKLML_LIB} ${MKLML_IOMP_LIB}) TARGET_LINK_LIBRARIES(mkldnn ${MKLDNN_LIB} ${MKLML_IOMP_LIB})
ADD_DEPENDENCIES(mkldnn ${MKLDNN_PROJECT}) ADD_DEPENDENCIES(mkldnn ${MKLDNN_PROJECT})
# copy the real so.0 lib to install dir # copy the real so.0 lib to install dir
@@ -107,6 +100,9 @@ if(WIN32)
SET(MKLDNN_SHARED_LIB ${MKLDNN_INSTALL_DIR}/bin/mkldnn.dll) SET(MKLDNN_SHARED_LIB ${MKLDNN_INSTALL_DIR}/bin/mkldnn.dll)
else(WIN32) else(WIN32)
SET(MKLDNN_SHARED_LIB ${MKLDNN_INSTALL_DIR}/libmkldnn.so.0) SET(MKLDNN_SHARED_LIB ${MKLDNN_INSTALL_DIR}/libmkldnn.so.0)
SET(MKLDNN_SHARED_LIB_1 ${MKLDNN_INSTALL_DIR}/libmkldnn.so.1)
ADD_CUSTOM_COMMAND(TARGET ${MKLDNN_PROJECT} POST_BUILD ADD_CUSTOM_COMMAND(TARGET ${MKLDNN_PROJECT} POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy ${MKLDNN_LIB} ${MKLDNN_SHARED_LIB}) COMMAND ${CMAKE_COMMAND} -E copy ${MKLDNN_LIB} ${MKLDNN_SHARED_LIB})
ADD_CUSTOM_COMMAND(TARGET ${MKLDNN_PROJECT} POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy ${MKLDNN_LIB} ${MKLDNN_SHARED_LIB_1})
endif(WIN32) endif(WIN32)
@@ -84,8 +84,8 @@ function(copy_part_of_thrid_party TARGET DST)
DSTS ${dst_dir} ${dst_dir}/lib ${dst_dir}/lib) DSTS ${dst_dir} ${dst_dir}/lib ${dst_dir}/lib)
else() else()
copy(${TARGET} copy(${TARGET}
SRCS ${MKLDNN_INC_DIR} ${MKLDNN_SHARED_LIB} SRCS ${MKLDNN_INC_DIR} ${MKLDNN_SHARED_LIB} ${MKLDNN_SHARED_LIB_1}
DSTS ${dst_dir} ${dst_dir}/lib) DSTS ${dst_dir} ${dst_dir}/lib ${dst_dir}/lib)
endif() endif()
endif() endif()
......
@@ -105,8 +105,6 @@ void* GetDataFromTensor(const Tensor& tensor, mkldnn::memory::data_type type) {
return platform::to_void_cast(tensor.data<int8_t>()); return platform::to_void_cast(tensor.data<int8_t>());
case mkldnn::memory::data_type::u8: case mkldnn::memory::data_type::u8:
return platform::to_void_cast(tensor.data<unsigned char>()); return platform::to_void_cast(tensor.data<unsigned char>());
case mkldnn::memory::data_type::s16:
return platform::to_void_cast(tensor.data<int16_t>());
case mkldnn::memory::data_type::s32: case mkldnn::memory::data_type::s32:
return platform::to_void_cast(tensor.data<int32_t>()); return platform::to_void_cast(tensor.data<int32_t>());
default: default:
@@ -134,7 +132,7 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout, void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
const Tensor& in, Tensor* out, const Tensor& in, Tensor* out,
platform::Place place) { platform::Place place) {
PADDLE_ENFORCE_NE(in.format(), MKLDNNMemoryFormat::format_undef, PADDLE_ENFORCE_NE(in.format(), MKLDNNMemoryFormat::undef,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Input tensor format is invalid. Input tensor should " "Input tensor format is invalid. Input tensor should "
"have specified memory format.")); "have specified memory format."));
@@ -151,12 +149,12 @@ void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
auto* dev_ctx = dynamic_cast<platform::MKLDNNDeviceContext*>(pool.Get(place)); auto* dev_ctx = dynamic_cast<platform::MKLDNNDeviceContext*>(pool.Get(place));
auto& cpu_engine = dev_ctx->GetEngine(); auto& cpu_engine = dev_ctx->GetEngine();
auto in_tz = paddle::framework::vectorize<int>(in.dims()); auto in_tz = paddle::framework::vectorize<int64_t>(in.dims());
auto out_tz = in_tz; auto out_tz = in_tz;
memory::data_type in_type = ToMKLDNNDataType(in.type()); memory::data_type in_type = ToMKLDNNDataType(in.type());
PADDLE_ENFORCE(in_type != memory::data_type::data_undef, PADDLE_ENFORCE_NE(in_type, memory::data_type::undef,
"Input tensor type is not supported: %s", in.type()); "Input tensor type is not supported: %s", in.type());
auto in_format = platform::MKLDNNFormatForSize(in_tz.size(), in.format()); auto in_format = platform::MKLDNNFormatForSize(in_tz.size(), in.format());
auto out_format = auto out_format =
@@ -167,8 +165,8 @@ void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
if (in_format != out_format) { if (in_format != out_format) {
void* in_data = GetDataFromTensor(in, in_type); void* in_data = GetDataFromTensor(in, in_type);
const std::string key = platform::CreateKey(in_tz, in_format, out_format, const std::string key =
std::to_string(in_type)); platform::CreateKey(in_tz, in_format, out_format, in_type);
platform::ReorderMKLDNNHandler handler(in_tz, in.type(), in_type, *dev_ctx, platform::ReorderMKLDNNHandler handler(in_tz, in.type(), in_type, *dev_ctx,
cpu_engine, key); cpu_engine, key);
@@ -179,9 +177,9 @@ void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
auto reorder_p = auto reorder_p =
handler.AcquireReorder(reorder_dst_memory_p, reorder_src_memory_p); handler.AcquireReorder(reorder_dst_memory_p, reorder_src_memory_p);
std::vector<mkldnn::primitive> pipeline; mkldnn::stream astream(cpu_engine);
pipeline.push_back(*reorder_p); reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); astream.wait();
} else { } else {
out->ShareDataWith(in); out->ShareDataWith(in);
} }
@@ -193,7 +191,7 @@ void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
} }
out->set_layout(out_layout); out->set_layout(out_layout);
// reset format since the out tensor will be feed to non-MKLDNN OPkernel // reset format since the out tensor will be feed to non-MKLDNN OPkernel
out->set_format(MKLDNNMemoryFormat::format_undef); out->set_format(MKLDNNMemoryFormat::undef);
} }
#endif #endif
......
@@ -59,11 +59,10 @@ inline MKLDNNDataType ToMKLDNNDataType(proto::VarType::Type type) {
{DataTypeTrait<float>::DataType(), MKLDNNDataType::f32}, {DataTypeTrait<float>::DataType(), MKLDNNDataType::f32},
{DataTypeTrait<int8_t>::DataType(), MKLDNNDataType::s8}, {DataTypeTrait<int8_t>::DataType(), MKLDNNDataType::s8},
{DataTypeTrait<uint8_t>::DataType(), MKLDNNDataType::u8}, {DataTypeTrait<uint8_t>::DataType(), MKLDNNDataType::u8},
{DataTypeTrait<int16_t>::DataType(), MKLDNNDataType::s16},
{DataTypeTrait<int32_t>::DataType(), MKLDNNDataType::s32}}; {DataTypeTrait<int32_t>::DataType(), MKLDNNDataType::s32}};
auto iter = dict.find(static_cast<int>(type)); auto iter = dict.find(static_cast<int>(type));
if (iter != dict.end()) return iter->second; if (iter != dict.end()) return iter->second;
return MKLDNNDataType::data_undef; return MKLDNNDataType::undef;
} }
void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout, void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
......
@@ -38,9 +38,9 @@ class Tensor {
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
public: public:
inline mkldnn::memory::format format() const { return format_; } inline mkldnn::memory::format_tag format() const { return format_; }
inline void set_format(const mkldnn::memory::format format) { inline void set_format(const mkldnn::memory::format_tag format) {
format_ = format; format_ = format;
} }
@@ -54,7 +54,7 @@ class Tensor {
* this field. * this field.
*/ */
mkldnn::memory::format format_ = mkldnn::memory::format::format_undef; mkldnn::memory::format_tag format_ = mkldnn::memory::format_tag::undef;
#endif #endif
public: public:
......
@@ -248,19 +248,22 @@ if(WITH_MKLDNN)
inference_analysis_api_int8_test_run(test_analyzer_int8_mobilenetv2 ${INT8_IMG_CLASS_TEST_APP} ${INT8_MOBILENETV2_MODEL_DIR} ${IMAGENET_DATA_PATH}) inference_analysis_api_int8_test_run(test_analyzer_int8_mobilenetv2 ${INT8_IMG_CLASS_TEST_APP} ${INT8_MOBILENETV2_MODEL_DIR} ${IMAGENET_DATA_PATH})
# resnet101 int8 # resnet101 int8
set(INT8_RESNET101_MODEL_DIR "${INT8_DATA_DIR}/resnet101") # TODO(grygielski) Enable after MKL-DNN 1.0 merge
download_int8_data(${INT8_RESNET101_MODEL_DIR} "Res101_int8_model.tar.gz" ) # set(INT8_RESNET101_MODEL_DIR "${INT8_DATA_DIR}/resnet101")
inference_analysis_api_int8_test_run(test_analyzer_int8_resnet101 ${INT8_IMG_CLASS_TEST_APP} ${INT8_RESNET101_MODEL_DIR} ${IMAGENET_DATA_PATH}) # download_int8_data(${INT8_RESNET101_MODEL_DIR} "Res101_int8_model.tar.gz" )
# inference_analysis_api_int8_test_run(test_analyzer_int8_resnet101 ${INT8_IMG_CLASS_TEST_APP} ${INT8_RESNET101_MODEL_DIR} ${IMAGENET_DATA_PATH})
# vgg16 int8 # vgg16 int8
set(INT8_VGG16_MODEL_DIR "${INT8_DATA_DIR}/vgg16") # TODO(grygielski) Enable after MKL-DNN 1.0 merge
download_int8_data(${INT8_VGG16_MODEL_DIR} "VGG16_int8_model.tar.gz" ) # set(INT8_VGG16_MODEL_DIR "${INT8_DATA_DIR}/vgg16")
inference_analysis_api_int8_test_run(test_analyzer_int8_vgg16 ${INT8_IMG_CLASS_TEST_APP} ${INT8_VGG16_MODEL_DIR} ${IMAGENET_DATA_PATH}) # download_int8_data(${INT8_VGG16_MODEL_DIR} "VGG16_int8_model.tar.gz" )
# inference_analysis_api_int8_test_run(test_analyzer_int8_vgg16 ${INT8_IMG_CLASS_TEST_APP} ${INT8_VGG16_MODEL_DIR} ${IMAGENET_DATA_PATH})
# vgg19 int8 # vgg19 int8
set(INT8_VGG19_MODEL_DIR "${INT8_DATA_DIR}/vgg19") # TODO(grygielski) Enable after MKL-DNN 1.0 merge
download_int8_data(${INT8_VGG19_MODEL_DIR} "VGG19_int8_model.tar.gz" ) # set(INT8_VGG19_MODEL_DIR "${INT8_DATA_DIR}/vgg19")
inference_analysis_api_int8_test_run(test_analyzer_int8_vgg19 ${INT8_IMG_CLASS_TEST_APP} ${INT8_VGG19_MODEL_DIR} ${IMAGENET_DATA_PATH}) # download_int8_data(${INT8_VGG19_MODEL_DIR} "VGG19_int8_model.tar.gz" )
# inference_analysis_api_int8_test_run(test_analyzer_int8_vgg19 ${INT8_IMG_CLASS_TEST_APP} ${INT8_VGG19_MODEL_DIR} ${IMAGENET_DATA_PATH})
# googlenet int8 # googlenet int8
set(INT8_GOOGLENET_MODEL_DIR "${INT8_DATA_DIR}/googlenet") set(INT8_GOOGLENET_MODEL_DIR "${INT8_DATA_DIR}/googlenet")
......
@@ -31,7 +31,7 @@ class ElementwiseMulOp : public ElementwiseOp {
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
static bool AreDimsAndFormatCorrect(const framework::ExecutionContext& ctx, static bool AreDimsAndFormatCorrect(const framework::ExecutionContext& ctx,
int simd_width, int simd_width,
mkldnn::memory::format x_format) { mkldnn::memory::format_tag x_format) {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
using paddle::framework::vectorize; using paddle::framework::vectorize;
using mkldnn::memory; using mkldnn::memory;
@@ -54,7 +54,7 @@ class ElementwiseMulOp : public ElementwiseOp {
if (platform::CanMKLDNNBeUsed(ctx)) { if (platform::CanMKLDNNBeUsed(ctx)) {
bool can_use_avx512_kernel = bool can_use_avx512_kernel =
platform::MayIUse(platform::avx512f) && platform::MayIUse(platform::avx512f) &&
AreDimsAndFormatCorrect(ctx, 16, memory::format::nChw16c); AreDimsAndFormatCorrect(ctx, 16, memory::format_tag::nChw16c);
if (can_use_avx512_kernel) { if (can_use_avx512_kernel) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(), return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN, framework::DataLayout::kMKLDNN,
......
@@ -50,12 +50,14 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
auto y_dims_untrimed = y->dims(); auto y_dims_untrimed = y->dims();
auto z_dims = z->dims(); auto z_dims = z->dims();
mkldnn::stream astream(mkldnn_engine);
// Execute default elementwise_add operator when // Execute default elementwise_add operator when
// broadcast operations need to performed. // broadcast operations need to performed.
if (x_dims != y_dims_untrimed) { if (x_dims != y_dims_untrimed) {
Tensor _x; Tensor _x;
MKLDNNMemoryFormat format; MKLDNNMemoryFormat format;
std::vector<int> src_x_tz = framework::vectorize<int>(x_dims); auto src_x_tz = framework::vectorize<int64_t>(x_dims);
if ((src_x_tz.size() == 3 && if ((src_x_tz.size() == 3 &&
x->format() != (format = MKLDNNMemoryFormat::ncw)) || x->format() != (format = MKLDNNMemoryFormat::ncw)) ||
@@ -69,8 +71,8 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
auto out_format = platform::MKLDNNFormatForSize( auto out_format = platform::MKLDNNFormatForSize(
x_dims.size(), MKLDNNMemoryFormat::nchw); x_dims.size(), MKLDNNMemoryFormat::nchw);
const std::string key = platform::CreateKey( const std::string key =
src_x_tz, x->format(), out_format, std::to_string(in_type)); platform::CreateKey(src_x_tz, x->format(), out_format, in_type);
platform::ReorderMKLDNNHandler handler(src_x_tz, x->type(), in_type, platform::ReorderMKLDNNHandler handler(src_x_tz, x->type(), in_type,
dev_ctx, mkldnn_engine, key); dev_ctx, mkldnn_engine, key);
@@ -83,9 +85,8 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
auto x_reorder = handler.AcquireReorder(x_memory_p, user_x_memory_p); auto x_reorder = handler.AcquireReorder(x_memory_p, user_x_memory_p);
std::vector<primitive> pipeline; x_reorder->execute(astream, *user_x_memory_p, *x_memory_p);
pipeline.push_back(*x_reorder); astream.wait();
stream(stream::kind::eager).submit(pipeline).wait();
} else { } else {
format = x->format(); format = x->format();
_x.ShareDataWith(*x); _x.ShareDataWith(*x);
@@ -122,19 +123,18 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
} else { } else {
PADDLE_ENFORCE_EQ(x->layout(), DataLayout::kMKLDNN, PADDLE_ENFORCE_EQ(x->layout(), DataLayout::kMKLDNN,
"Wrong layout set for X tensor"); "Wrong layout set for X tensor");
PADDLE_ENFORCE_NE(x->format(), MKLDNNMemoryFormat::format_undef, PADDLE_ENFORCE_NE(x->format(), MKLDNNMemoryFormat::undef,
"Wrong format set for X tensor"); "Wrong format set for X tensor");
PADDLE_ENFORCE_EQ(y->layout(), DataLayout::kMKLDNN, PADDLE_ENFORCE_EQ(y->layout(), DataLayout::kMKLDNN,
"Wrong layout set for Y tensor"); "Wrong layout set for Y tensor");
PADDLE_ENFORCE_NE(y->format(), MKLDNNMemoryFormat::format_undef, PADDLE_ENFORCE_NE(y->format(), MKLDNNMemoryFormat::undef,
"Wrong format set for Y tensor"); "Wrong format set for Y tensor");
std::vector<int> src_x_tz = framework::vectorize<int>(x_dims); auto src_x_tz = framework::vectorize<int64_t>(x_dims);
std::vector<int> src_y_tz = framework::vectorize<int>(y_dims_untrimed); auto src_y_tz = framework::vectorize<int64_t>(y_dims_untrimed);
std::vector<int> dst_tz = framework::vectorize<int>(z_dims); auto dst_tz = framework::vectorize<int64_t>(z_dims);
std::vector<memory::primitive_desc> srcs_pd;
std::vector<float> scales = {1.0f, 1.0f}; std::vector<float> scales = {1.0f, 1.0f};
const std::string key = const std::string key =
@@ -156,18 +156,17 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
auto sum_pd = handler.AcquireSumPrimitiveDescriptor( auto sum_pd = handler.AcquireSumPrimitiveDescriptor(
{src_x_memory, src_y_memory}, scales, dst_md); {src_x_memory, src_y_memory}, scales, dst_md);
T* z_data = z->mutable_data<T>(ctx.GetPlace(), T* z_data =
sum_pd->dst_primitive_desc().get_size()); z->mutable_data<T>(ctx.GetPlace(), sum_pd->dst_desc().get_size());
auto dst_memory = handler.AcquireDstMemoryFromPrimitive(z_data); auto dst_memory = handler.AcquireDstMemoryFromPrimitive(z_data);
std::vector<primitive::at> inputs({*src_x_memory, *src_y_memory}); auto sum_prim = handler.AcquireSum();
auto sum_prim = handler.AcquireSum(dst_memory, &inputs);
std::vector<primitive> pipeline; sum_prim->execute(astream, {{MKLDNN_ARG_MULTIPLE_SRC, *src_x_memory},
pipeline.push_back(*sum_prim); {MKLDNN_ARG_MULTIPLE_SRC + 1, *src_y_memory},
stream(stream::kind::eager).submit(pipeline).wait(); {MKLDNN_ARG_DST, *dst_memory}});
astream.wait();
z->set_layout(DataLayout::kMKLDNN); z->set_layout(DataLayout::kMKLDNN);
z->set_format(platform::GetMKLDNNFormat(*dst_memory)); z->set_format(platform::GetMKLDNNFormat(*dst_memory));
......
@@ -70,7 +70,7 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
auto x_dims = x->dims(); auto x_dims = x->dims();
auto y_dims_untrimmed = y->dims(); auto y_dims_untrimmed = y->dims();
auto x_int_dims = paddle::framework::vectorize<int>(x_dims); auto x_int_dims = paddle::framework::vectorize<int64_t>(x_dims);
int pre, num, post, is_run_common_broadcast; int pre, num, post, is_run_common_broadcast;
get_mid_dims(x_dims, y_dims_untrimmed, axis, &pre, &num, &post, get_mid_dims(x_dims, y_dims_untrimmed, axis, &pre, &num, &post,
......
@@ -35,7 +35,7 @@ class MKLDNNActivationKernel
const auto *x = ctx.Input<Tensor>("X"); const auto *x = ctx.Input<Tensor>("X");
PADDLE_ENFORCE_EQ(x->layout(), DataLayout::kMKLDNN, PADDLE_ENFORCE_EQ(x->layout(), DataLayout::kMKLDNN,
"Wrong layout set for X tensor"); "Wrong layout set for X tensor");
PADDLE_ENFORCE_NE(x->format(), MKLDNNMemoryFormat::format_undef, PADDLE_ENFORCE_NE(x->format(), MKLDNNMemoryFormat::undef,
"Wrong format set for X tensor"); "Wrong format set for X tensor");
Functor functor; Functor functor;
@@ -51,7 +51,7 @@ class MKLDNNActivationGradKernel
const auto *diff_y = ctx.Input<Tensor>(framework::GradVarName("Out")); const auto *diff_y = ctx.Input<Tensor>(framework::GradVarName("Out"));
PADDLE_ENFORCE_EQ(diff_y->layout(), DataLayout::kMKLDNN, PADDLE_ENFORCE_EQ(diff_y->layout(), DataLayout::kMKLDNN,
"Wrong layout set for Input OutGrad tensor"); "Wrong layout set for Input OutGrad tensor");
PADDLE_ENFORCE_NE(diff_y->format(), MKLDNNMemoryFormat::format_undef, PADDLE_ENFORCE_NE(diff_y->format(), MKLDNNMemoryFormat::undef,
"Wrong format set for Input OutGrad tensor"); "Wrong format set for Input OutGrad tensor");
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
@@ -80,7 +80,7 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
x->dims().size() == 2 || x->dims().size() == 3 || x->dims().size() == 4, x->dims().size() == 2 || x->dims().size() == 3 || x->dims().size() == 4,
"Input dim must be with 2, 3 or 4"); "Input dim must be with 2, 3 or 4");
auto src_tz = framework::vectorize<int>(x->dims()); auto src_tz = framework::vectorize<int64_t>(x->dims());
auto src_format = src_tz.size() == 2 ? MKLDNNMemoryFormat::nc : x->format(); auto src_format = src_tz.size() == 2 ? MKLDNNMemoryFormat::nc : x->format();
@@ -92,13 +92,12 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
auto src_memory_p = handler.AcquireSrcMemory(x); auto src_memory_p = handler.AcquireSrcMemory(x);
auto dst_memory_p = handler.AcquireDstMemory(y); auto dst_memory_p = handler.AcquireDstMemory(y);
auto activation_p = auto activation_p = handler.AcquireForwardPrimitive();
handler.AcquireForwardPrimitive(*src_memory_p, *dst_memory_p);
// push primitive to stream and wait until it's executed mkldnn::stream astream(dev_ctx.GetEngine());
std::vector<primitive> pipeline; activation_p->execute(astream, {{MKLDNN_ARG_FROM, *src_memory_p},
pipeline.push_back(*activation_p); {MKLDNN_ARG_TO, *dst_memory_p}});
stream(stream::kind::eager).submit(pipeline).wait(); astream.wait();
y->set_layout(DataLayout::kMKLDNN); y->set_layout(DataLayout::kMKLDNN);
y->set_format(GetMKLDNNFormat(*dst_memory_p)); y->set_format(GetMKLDNNFormat(*dst_memory_p));
@@ -116,7 +115,7 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
const T alpha = ctx.HasAttr("alpha") ? ctx.Attr<T>("alpha") : 0; const T alpha = ctx.HasAttr("alpha") ? ctx.Attr<T>("alpha") : 0;
const T beta = ctx.HasAttr("beta") ? ctx.Attr<T>("beta") : 0; const T beta = ctx.HasAttr("beta") ? ctx.Attr<T>("beta") : 0;
auto diff_dst_tz = framework::vectorize<int>(diff_y->dims()); auto diff_dst_tz = framework::vectorize<int64_t>(diff_y->dims());
// diff_dst and src dims should be the same // diff_dst and src dims should be the same
auto src_format = auto src_format =
@@ -132,13 +131,14 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
auto src_memory_p = handler.AcquireBackwardSrcMemory(x); auto src_memory_p = handler.AcquireBackwardSrcMemory(x);
auto diff_dst_memory_p = handler.AcquireDiffDstMemory(diff_y); auto diff_dst_memory_p = handler.AcquireDiffDstMemory(diff_y);
auto diff_src_memory_p = handler.AcquireDiffSrcMemory(diff_x); auto diff_src_memory_p = handler.AcquireDiffSrcMemory(diff_x);
auto activation_backward_p = handler.AcquireBackwardPrimitive( auto activation_backward_p = handler.AcquireBackwardPrimitive();
*src_memory_p, *diff_dst_memory_p, *diff_src_memory_p);
mkldnn::stream astream(dev_ctx.GetEngine());
// push primitive to stream and wait until it's executed activation_backward_p->execute(astream,
std::vector<primitive> pipeline; {{MKLDNN_ARG_SRC, *src_memory_p},
pipeline.push_back(*activation_backward_p); {MKLDNN_ARG_DIFF_DST, *diff_dst_memory_p},
stream(stream::kind::eager).submit(pipeline).wait(); {MKLDNN_ARG_DIFF_SRC, *diff_src_memory_p}});
astream.wait();
diff_x->set_layout(DataLayout::kMKLDNN); diff_x->set_layout(DataLayout::kMKLDNN);
diff_x->set_format(GetMKLDNNFormat(*diff_src_memory_p)); diff_x->set_format(GetMKLDNNFormat(*diff_src_memory_p));
......
@@ -31,9 +31,9 @@ class BatchNormMKLDNNHandler
: public platform::MKLDNNHandlerT<T, mkldnn::batch_normalization_forward, : public platform::MKLDNNHandlerT<T, mkldnn::batch_normalization_forward,
mkldnn::batch_normalization_backward> { mkldnn::batch_normalization_backward> {
public: public:
BatchNormMKLDNNHandler(const std::vector<int> &dims, const float &epsilon, BatchNormMKLDNNHandler(const std::vector<int64_t> &dims, const float &epsilon,
const unsigned &flags, const bool &global_stats, const mkldnn::normalization_flags &flags,
const MKLDNNMemoryFormat fmt, const bool &global_stats, const MKLDNNMemoryFormat fmt,
const platform::MKLDNNDeviceContext &dev_ctx, const platform::MKLDNNDeviceContext &dev_ctx,
platform::Place cpu_place, platform::Place cpu_place,
const std::string &uniq_name) const std::string &uniq_name)
@@ -48,8 +48,8 @@ class BatchNormMKLDNNHandler
: mkldnn::prop_kind::forward_training, : mkldnn::prop_kind::forward_training,
md, epsilon, flags); md, epsilon, flags);
} }
BatchNormMKLDNNHandler(const std::vector<int> &dims, const float &epsilon, BatchNormMKLDNNHandler(const std::vector<int64_t> &dims, const float &epsilon,
const unsigned &flags, const mkldnn::normalization_flags &flags,
const MKLDNNMemoryFormat diff_fmt, const MKLDNNMemoryFormat diff_fmt,
const MKLDNNMemoryFormat src_fmt, const MKLDNNMemoryFormat src_fmt,
const platform::MKLDNNDeviceContext &dev_ctx, const platform::MKLDNNDeviceContext &dev_ctx,
@@ -70,47 +70,44 @@ class BatchNormMKLDNNHandler
std::shared_ptr<mkldnn::memory> AcquireScaleShiftMemory(T *scaleshift_data) { std::shared_ptr<mkldnn::memory> AcquireScaleShiftMemory(T *scaleshift_data) {
return this->AcquireMemoryFromPrimitive( return this->AcquireMemoryFromPrimitive(
this->fwd_pd_->weights_primitive_desc(), scaleshift_data, this->fwd_pd_->weights_desc(), scaleshift_data, "@scaleshift_mem_p");
"@scaleshift_mem_p");
} }
std::shared_ptr<mkldnn::memory> AcquireDiffScaleShiftMemory( std::shared_ptr<mkldnn::memory> AcquireDiffScaleShiftMemory(
T *diff_scaleshift_data) { T *diff_scaleshift_data) {
return this->AcquireMemoryFromPrimitive( return this->AcquireMemoryFromPrimitive(this->bwd_pd_->diff_weights_desc(),
this->bwd_pd_->diff_weights_primitive_desc(), diff_scaleshift_data, diff_scaleshift_data,
"@diff_scaleshift_mem_p"); "@diff_scaleshift_mem_p");
} }
std::shared_ptr<mkldnn::memory> AcquireMeanMemory( std::shared_ptr<mkldnn::memory> AcquireMeanMemory(
const framework::Tensor *mean) { const framework::Tensor *mean) {
const T *mean_data = mean->data<T>(); const T *mean_data = mean->data<T>();
return this->AcquireMemoryFromPrimitive( return this->AcquireMemoryFromPrimitive(
this->fwd_pd_->mean_primitive_desc(), to_void_cast<T>(mean_data), this->fwd_pd_->mean_desc(), to_void_cast<T>(mean_data), "@mean_mem_p");
"@mean_mem_p");
} }
std::shared_ptr<mkldnn::memory> AcquireMeanMemory(framework::Tensor *mean) { std::shared_ptr<mkldnn::memory> AcquireMeanMemory(framework::Tensor *mean) {
T *mean_data = mean->mutable_data<T>( T *mean_data = mean->mutable_data<T>(this->place_,
this->place_, this->fwd_pd_->mean_primitive_desc().get_size()); this->fwd_pd_->mean_desc().get_size());
return this->AcquireMemoryFromPrimitive( return this->AcquireMemoryFromPrimitive(this->fwd_pd_->mean_desc(),
this->fwd_pd_->mean_primitive_desc(), mean_data, "@mean_mem_p"); mean_data, "@mean_mem_p");
} }
std::shared_ptr<mkldnn::memory> AcquireVarianceMemory( std::shared_ptr<mkldnn::memory> AcquireVarianceMemory(
const framework::Tensor *variance) { const framework::Tensor *variance) {
const T *variance_data = variance->data<T>(); const T *variance_data = variance->data<T>();
return this->AcquireMemoryFromPrimitive( return this->AcquireMemoryFromPrimitive(this->fwd_pd_->variance_desc(),
this->fwd_pd_->variance_primitive_desc(), to_void_cast<T>(variance_data),
to_void_cast<T>(variance_data), "@variance_mem_p"); "@variance_mem_p");
} }
std::shared_ptr<mkldnn::memory> AcquireVarianceMemory( std::shared_ptr<mkldnn::memory> AcquireVarianceMemory(
framework::Tensor *variance) { framework::Tensor *variance) {
T *variance_data = variance->mutable_data<T>( T *variance_data = variance->mutable_data<T>(
this->place_, this->fwd_pd_->variance_primitive_desc().get_size()); this->place_, this->fwd_pd_->variance_desc().get_size());
return this->AcquireMemoryFromPrimitive( return this->AcquireMemoryFromPrimitive(this->fwd_pd_->variance_desc(),
this->fwd_pd_->variance_primitive_desc(), variance_data, variance_data, "@variance_mem_p");
"@variance_mem_p");
} }
}; };
@@ -140,11 +137,11 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(x->layout(), DataLayout::kMKLDNN, PADDLE_ENFORCE_EQ(x->layout(), DataLayout::kMKLDNN,
"Wrong layout set for X tensor"); "Wrong layout set for X tensor");
PADDLE_ENFORCE_NE(x->format(), MKLDNNMemoryFormat::format_undef, PADDLE_ENFORCE_NE(x->format(), MKLDNNMemoryFormat::undef,
"Wrong format set for X tensor"); "Wrong format set for X tensor");
auto src_tz = paddle::framework::vectorize<int>(x->dims()); auto src_tz = paddle::framework::vectorize<int64_t>(x->dims());
auto scale_tz = paddle::framework::vectorize<int>(scale->dims()); auto scale_tz = paddle::framework::vectorize<int64_t>(scale->dims());
PADDLE_ENFORCE(scale_tz.size() == 1, "Dims of scale tensor is NOT 1"); PADDLE_ENFORCE(scale_tz.size() == 1, "Dims of scale tensor is NOT 1");
const unsigned int C = scale_tz[0]; const unsigned int C = scale_tz[0];
@@ -156,9 +153,11 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
shift->data<T>() + C); shift->data<T>() + C);
// Flags are added by bitwise OR operation // Flags are added by bitwise OR operation
unsigned flags = mkldnn::use_scale_shift; // 001 auto flags = mkldnn::normalization_flags::use_scale_shift; // 001
if (global_stats) flags |= mkldnn::use_global_stats; // 010 if (global_stats)
if (fuse_with_relu && is_test) flags |= mkldnn::fuse_bn_relu; // 100 flags |= mkldnn::normalization_flags::use_global_stats; // 010
if (fuse_with_relu && is_test)
flags |= mkldnn::normalization_flags::fuse_norm_relu; // 100
BatchNormMKLDNNHandler<T> handler( BatchNormMKLDNNHandler<T> handler(
src_tz, epsilon, flags, global_stats, src_tz, epsilon, flags, global_stats,
@@ -170,38 +169,35 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
handler.AcquireScaleShiftMemory(scaleshift_data.data()); handler.AcquireScaleShiftMemory(scaleshift_data.data());
auto dst_memory = handler.AcquireDstMemory(y); auto dst_memory = handler.AcquireDstMemory(y);
std::shared_ptr<mkldnn::batch_normalization_forward> batch_norm_p; auto batch_norm_p = handler.AcquireForwardPrimitive();
std::shared_ptr<memory> mean_memory;
std::shared_ptr<memory> variance_memory;
if (global_stats) { if (global_stats) {
// mean and variance are taken from input Tensor // mean and variance are taken from input Tensor
const auto *mean = ctx.Input<Tensor>("Mean"); const auto *mean = ctx.Input<Tensor>("Mean");
const auto *variance = ctx.Input<Tensor>("Variance"); const auto *variance = ctx.Input<Tensor>("Variance");
std::shared_ptr<memory> mean_memory = handler.AcquireMeanMemory(mean); mean_memory = handler.AcquireMeanMemory(mean);
std::shared_ptr<memory> variance_memory = variance_memory = handler.AcquireVarianceMemory(variance);
handler.AcquireVarianceMemory(variance);
batch_norm_p = handler.AcquireForwardPrimitive(
*src_memory, (const mkldnn::primitive::at &)*mean_memory,
(const mkldnn::primitive::at &)*variance_memory, *scaleshift_memory,
*dst_memory);
} else { } else {
// mean and variance are calculated and saved in output Tensor // mean and variance are calculated and saved in output Tensor
std::shared_ptr<memory> mean_memory = mean_memory = handler.AcquireMeanMemory(batch_mean);
handler.AcquireMeanMemory(batch_mean); variance_memory = handler.AcquireVarianceMemory(batch_variance);
std::shared_ptr<memory> variance_memory =
handler.AcquireVarianceMemory(batch_variance);
batch_norm_p = handler.AcquireForwardPrimitive(
*src_memory, *scaleshift_memory, *dst_memory, *mean_memory,
*variance_memory);
} }
y->set_layout(DataLayout::kMKLDNN); y->set_layout(DataLayout::kMKLDNN);
y->set_format(platform::GetMKLDNNFormat(*dst_memory)); y->set_format(platform::GetMKLDNNFormat(*dst_memory));
std::vector<mkldnn::primitive> pipeline; mkldnn::stream astream(dev_ctx.GetEngine());
pipeline.push_back(*batch_norm_p); batch_norm_p->execute(astream,
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); {{MKLDNN_ARG_SRC, *src_memory},
{MKLDNN_ARG_SCALE_SHIFT, *scaleshift_memory},
{MKLDNN_ARG_MEAN, *mean_memory},
{MKLDNN_ARG_VARIANCE, *variance_memory},
{MKLDNN_ARG_DST, *dst_memory}});
astream.wait();
if (!global_stats) { if (!global_stats) {
// mkldnn only compute stats for current batch // mkldnn only compute stats for current batch
@@ -245,11 +241,11 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(diff_y->layout(), DataLayout::kMKLDNN, PADDLE_ENFORCE_EQ(diff_y->layout(), DataLayout::kMKLDNN,
"Wrong layout set for Input diff_y tensor"); "Wrong layout set for Input diff_y tensor");
PADDLE_ENFORCE_NE(diff_y->format(), MKLDNNMemoryFormat::format_undef, PADDLE_ENFORCE_NE(diff_y->format(), MKLDNNMemoryFormat::undef,
"Wrong format set for Input diff_y tensor"); "Wrong format set for Input diff_y tensor");
auto src_tz = paddle::framework::vectorize<int>(x->dims()); auto src_tz = paddle::framework::vectorize<int64_t>(x->dims());
auto scale_tz = paddle::framework::vectorize<int>(scale->dims()); auto scale_tz = paddle::framework::vectorize<int64_t>(scale->dims());
PADDLE_ENFORCE(scale_tz.size() == 1, "Dims of scale tensor is NOT 1"); PADDLE_ENFORCE(scale_tz.size() == 1, "Dims of scale tensor is NOT 1");
const unsigned int C = scale_tz[0]; const unsigned int C = scale_tz[0];
@@ -261,8 +257,9 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
platform::MKLDNNFormatForSize(src_tz.size(), x->format()); platform::MKLDNNFormatForSize(src_tz.size(), x->format());
BatchNormMKLDNNHandler<T> handler( BatchNormMKLDNNHandler<T> handler(
src_tz, epsilon, mkldnn::use_scale_shift, dst_format, input_format, src_tz, epsilon, mkldnn::normalization_flags::use_scale_shift,
dev_ctx, ctx.GetPlace(), ctx.InputName("SavedMean")); dst_format, input_format, dev_ctx, ctx.GetPlace(),
ctx.InputName("SavedMean"));
// MKLDNN requires a single piece of memory for scale and shift/bias data // MKLDNN requires a single piece of memory for scale and shift/bias data
const size_t scaleshift_size = 2 * C; const size_t scaleshift_size = 2 * C;
@@ -285,13 +282,18 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
handler.AcquireDiffScaleShiftMemory(diff_scaleshift_data.data()); handler.AcquireDiffScaleShiftMemory(diff_scaleshift_data.data());
// finally create batch_norm backward primitive // finally create batch_norm backward primitive
auto batch_norm_bwd_p = handler.AcquireBackwardPrimitive( auto batch_norm_bwd_p = handler.AcquireBackwardPrimitive();
*src_memory, *mean_memory, *variance_memory, *diff_dst_memory,
*scaleshift_memory, *diff_src_memory, *diff_scaleshift_memory); mkldnn::stream astream(dev_ctx.GetEngine());
batch_norm_bwd_p->execute(
std::vector<primitive> pipeline; astream, {{MKLDNN_ARG_SRC, *src_memory},
pipeline.push_back(*batch_norm_bwd_p); {MKLDNN_ARG_MEAN, *mean_memory},
stream(stream::kind::eager).submit(pipeline).wait(); {MKLDNN_ARG_VARIANCE, *variance_memory},
{MKLDNN_ARG_DIFF_DST, *diff_dst_memory},
{MKLDNN_ARG_SCALE_SHIFT, *scaleshift_memory},
{MKLDNN_ARG_DIFF_SRC, *diff_src_memory},
{MKLDNN_ARG_DIFF_SCALE_SHIFT, *diff_scaleshift_memory}});
astream.wait();
T *diff_scale_data = diff_scale->mutable_data<T>(ctx.GetPlace()); T *diff_scale_data = diff_scale->mutable_data<T>(ctx.GetPlace());
T *diff_shift_data = diff_shift->mutable_data<T>(ctx.GetPlace()); T *diff_shift_data = diff_shift->mutable_data<T>(ctx.GetPlace());
......
@@ -32,19 +32,17 @@ static void EnforceLayouts(const std::vector<const Tensor*> inputs) {
for (auto* input : inputs) { for (auto* input : inputs) {
PADDLE_ENFORCE_EQ(input->layout(), DataLayout::kMKLDNN, PADDLE_ENFORCE_EQ(input->layout(), DataLayout::kMKLDNN,
"Wrong layout set for Input tensor"); "Wrong layout set for Input tensor");
PADDLE_ENFORCE_NE(input->format(), MKLDNNMemoryFormat::format_undef, PADDLE_ENFORCE_NE(input->format(), MKLDNNMemoryFormat::undef,
"Wrong format set for Input tensor"); "Wrong format set for Input tensor");
} }
} }
static memory::primitive_desc CreateMemPrimDesc(const Tensor& input, static memory::desc CreateMemDesc(const Tensor& input,
const mkldnn::engine& engine, const memory::data_type& dt) {
const memory::data_type& dt) { const auto dims = paddle::framework::vectorize<int64_t>(input.dims());
const auto dims = paddle::framework::vectorize<int>(input.dims());
const auto format = input.format(); const auto format = input.format();
auto description = memory::desc(dims, dt, format); auto mem_desc = memory::desc(dims, dt, format);
auto mem_prim_desc = memory::primitive_desc(description, engine); return mem_desc;
return mem_prim_desc;
} }
static platform::CPUPlace GetCpuPlace( static platform::CPUPlace GetCpuPlace(
@@ -70,14 +68,15 @@ class ConcatPrimitiveFactory {
const memory::data_type& dt = memory::data_type::f32) { const memory::data_type& dt = memory::data_type::f32) {
CreateSourcesDescriptors(multi_input, mkldnn_engine, dt); CreateSourcesDescriptors(multi_input, mkldnn_engine, dt);
auto dst_desc = CreateDstMemDescriptor(output, dt); auto dst_desc = CreateDstMemDescriptor(output, dt);
return concat::primitive_desc(dst_desc, concat_axis, srcs_pd); return concat::primitive_desc(dst_desc, concat_axis, srcs_d, mkldnn_engine);
} }
concat CreateConcatPrimitive(const concat::primitive_desc& concat_pd, concat CreateConcatPrimitive(const concat::primitive_desc& concat_pd,
Tensor* output, platform::CPUPlace place) { Tensor* output, platform::CPUPlace place,
CreateSourcePrimitiveAts(); const mkldnn::engine& mkldnn_engine) {
dst_mem = CreateDstMemory(concat_pd, output, place); dst_mem = mkldnn::memory(concat_pd.dst_desc(), mkldnn_engine,
return concat(concat_pd, inputs, dst_mem.get()); output->mutable_data<T>(place));
return concat(concat_pd);
} }
void SetSrcDataHandleByIndex(const std::vector<memory>& srcs, const size_t& i, void SetSrcDataHandleByIndex(const std::vector<memory>& srcs, const size_t& i,
@@ -96,41 +95,25 @@ class ConcatPrimitiveFactory {
private: private:
memory::desc CreateDstMemDescriptor(Tensor* output, memory::desc CreateDstMemDescriptor(Tensor* output,
const memory::data_type& dt) { const memory::data_type& dt) {
auto dst_dims = paddle::framework::vectorize<int>(output->dims()); auto dst_dims = paddle::framework::vectorize<int64_t>(output->dims());
return memory::desc(dst_dims, dt, MKLDNNMemoryFormat::any); return memory::desc(dst_dims, dt, MKLDNNMemoryFormat::any);
} }
mkldnn::memory CreateDstMemory(const concat::primitive_desc& concat_pd,
Tensor* output,
const platform::CPUPlace& place) {
return memory(concat_pd.dst_primitive_desc(),
output->mutable_data<T>(place));
}
void CreateSourcesDescriptors(const std::vector<const Tensor*> multi_input, void CreateSourcesDescriptors(const std::vector<const Tensor*> multi_input,
const mkldnn::engine& mkldnn_engine, const mkldnn::engine& mkldnn_engine,
const memory::data_type& dt) { const memory::data_type& dt) {
for (size_t i = 0; i < multi_input.size(); i++) { for (size_t i = 0; i < multi_input.size(); i++) {
auto mem_prim_desc = auto mem_desc = CreateMemDesc(*multi_input[i], dt);
CreateMemPrimDesc(*multi_input[i], mkldnn_engine, dt); srcs_d.push_back(mem_desc);
srcs_pd.push_back(mem_prim_desc); srcs.push_back(memory(mem_desc, mkldnn_engine,
srcs.push_back( to_void_cast(multi_input[i]->data<T>())));
memory(mem_prim_desc, to_void_cast(multi_input[i]->data<T>())));
}
}
void CreateSourcePrimitiveAts() {
inputs.reserve(srcs.size());
for (size_t i = 0; i < srcs.size(); i++) {
inputs.push_back(srcs[i]);
} }
} }
private: private:
std::vector<memory::primitive_desc> srcs_pd; std::vector<memory::desc> srcs_d;
std::vector<memory> srcs; std::vector<mkldnn::memory> srcs;
std::vector<primitive::at> inputs; boost::optional<mkldnn::memory> dst_mem;
boost::optional<memory> dst_mem;
}; };
template <typename T> template <typename T>
@@ -140,7 +123,7 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto multi_input = ctx.MultiInput<Tensor>("X"); auto multi_input = ctx.MultiInput<Tensor>("X");
EnforceLayouts(multi_input); EnforceLayouts(multi_input);
Tensor* output = ctx.Output<Tensor>("Out"); Tensor* output = ctx.Output<Tensor>("Out");
int64_t concat_axis = static_cast<int64_t>(ctx.Attr<int>("axis")); int concat_axis = ctx.Attr<int>("axis");
auto& dev_ctx = auto& dev_ctx =
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>(); ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
auto place = GetCpuPlace(ctx); auto place = GetCpuPlace(ctx);
@@ -152,6 +135,7 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::string key = platform::CreateKey( std::string key = platform::CreateKey(
paddle::framework::vectorize<int>(multi_input[0]->dims()), paddle::framework::vectorize<int>(multi_input[0]->dims()),
ctx.OutputName("Out"), dt, platform::ThreadIDasStr()); ctx.OutputName("Out"), dt, platform::ThreadIDasStr());
const std::string key_prim = key + "@concat_p"; const std::string key_prim = key + "@concat_p";
const std::string key_concat_pd = key + "@concat_pd"; const std::string key_concat_pd = key + "@concat_pd";
const std::string key_srcs = key + "@concat_srcs"; const std::string key_srcs = key + "@concat_srcs";
@@ -162,14 +146,13 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::shared_ptr<memory> dst_mem; std::shared_ptr<memory> dst_mem;
auto concat_p = std::static_pointer_cast<concat>(dev_ctx.GetBlob(key_prim)); auto concat_p = std::static_pointer_cast<concat>(dev_ctx.GetBlob(key_prim));
const auto& mkldnn_engine = dev_ctx.GetEngine();
if (concat_p == nullptr) { if (concat_p == nullptr) {
const auto& mkldnn_engine = dev_ctx.GetEngine();
concat_pd = std::make_shared<concat::primitive_desc>( concat_pd = std::make_shared<concat::primitive_desc>(
prim_creator.CreateConcatPrimDescriptor(multi_input, output, prim_creator.CreateConcatPrimDescriptor(
static_cast<int>(concat_axis), multi_input, output, concat_axis, mkldnn_engine, dt));
mkldnn_engine, dt)); concat_p = std::make_shared<concat>(prim_creator.CreateConcatPrimitive(
concat_p = std::make_shared<concat>( *concat_pd, output, place, mkldnn_engine));
prim_creator.CreateConcatPrimitive(*concat_pd, output, place));
srcs = std::make_shared<std::vector<memory>>(prim_creator.GetSrcs()); srcs = std::make_shared<std::vector<memory>>(prim_creator.GetSrcs());
dst_mem = std::make_shared<memory>(prim_creator.GetDst()); dst_mem = std::make_shared<memory>(prim_creator.GetDst());
dev_ctx.SetBlob(key_prim, concat_p); dev_ctx.SetBlob(key_prim, concat_p);
@@ -189,7 +172,15 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
prim_creator.SetDstDataHandle(*dst_mem, output->mutable_data<T>(place)); prim_creator.SetDstDataHandle(*dst_mem, output->mutable_data<T>(place));
} }
stream(stream::kind::eager).submit({*concat_p}).wait(); mkldnn::stream astream(mkldnn_engine);
std::unordered_map<int, memory> args;
for (size_t i = 0; i < multi_input.size(); ++i) {
args.insert({MKLDNN_ARG_MULTIPLE_SRC + i, (*srcs).at(i)});
}
args.insert({MKLDNN_ARG_DST, *dst_mem});
concat_p->execute(astream, args);
astream.wait();
output->set_layout(DataLayout::kMKLDNN); output->set_layout(DataLayout::kMKLDNN);
output->set_format(platform::GetMKLDNNFormat(*dst_mem)); output->set_format(platform::GetMKLDNNFormat(*dst_mem));
......
@@ -29,8 +29,8 @@ using mkldnn::stream;
using platform::to_void_cast; using platform::to_void_cast;
using platform::GetMKLDNNFormat; using platform::GetMKLDNNFormat;
inline void GetWeightsTz(std::vector<int>& weights_tz, int groups, // NOLINT inline void GetWeightsTz(std::vector<int64_t>& weights_tz, // NOLINT
bool is_conv3d) { int groups, bool is_conv3d) {
if (groups > 1) { if (groups > 1) {
if (is_conv3d) { if (is_conv3d) {
int output = weights_tz[0]; int output = weights_tz[0];
@@ -131,12 +131,12 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(input->layout(), DataLayout::kMKLDNN, PADDLE_ENFORCE_EQ(input->layout(), DataLayout::kMKLDNN,
"Wrong layout set for Input tensor"); "Wrong layout set for Input tensor");
PADDLE_ENFORCE_NE(input->format(), MKLDNNMemoryFormat::format_undef, PADDLE_ENFORCE_NE(input->format(), MKLDNNMemoryFormat::undef,
"Wrong format set for Input tensor"); "Wrong format set for Input tensor");
PADDLE_ENFORCE_EQ(filter->layout(), DataLayout::kMKLDNN, PADDLE_ENFORCE_EQ(filter->layout(), DataLayout::kMKLDNN,
"Wrong layout set for Filter tensor"); "Wrong layout set for Filter tensor");
PADDLE_ENFORCE_NE(filter->format(), MKLDNNMemoryFormat::format_undef, PADDLE_ENFORCE_NE(filter->format(), MKLDNNMemoryFormat::undef,
"Wrong format set for Filter tensor"); "Wrong format set for Filter tensor");
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(
@@ -156,16 +156,22 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
if (bias) { if (bias) {
PADDLE_ENFORCE_EQ(bias->layout(), DataLayout::kMKLDNN, PADDLE_ENFORCE_EQ(bias->layout(), DataLayout::kMKLDNN,
"Wrong layout set for Bias tensor"); "Wrong layout set for Bias tensor");
PADDLE_ENFORCE_NE(bias->format(), MKLDNNMemoryFormat::format_undef, PADDLE_ENFORCE_NE(bias->format(), MKLDNNMemoryFormat::undef,
"Wrong format set for Bias tensor"); "Wrong format set for Bias tensor");
PADDLE_ENFORCE_EQ(bias->dims().size(), 1, PADDLE_ENFORCE_EQ(bias->dims().size(), 1,
"Bias must only have 1 dimension, i.e. X"); "Bias must only have 1 dimension, i.e. X");
} }
std::vector<int> strides = ctx.Attr<std::vector<int>>("strides"); std::vector<int> strides_temp = ctx.Attr<std::vector<int>>("strides");
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings"); std::vector<int64_t> strides(begin(strides_temp), end(strides_temp));
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
std::vector<int> paddings_temp = ctx.Attr<std::vector<int>>("paddings");
std::vector<int64_t> paddings(begin(paddings_temp), end(paddings_temp));
std::vector<int> dilations_temp = ctx.Attr<std::vector<int>>("dilations");
std::vector<int64_t> dilations(begin(dilations_temp), end(dilations_temp));
std::string fuse_activation = ctx.Attr<std::string>("fuse_activation"); std::string fuse_activation = ctx.Attr<std::string>("fuse_activation");
float fuse_alpha = ctx.Attr<float>("fuse_alpha"); float fuse_alpha = ctx.Attr<float>("fuse_alpha");
float fuse_beta = ctx.Attr<float>("fuse_beta"); float fuse_beta = ctx.Attr<float>("fuse_beta");
@@ -180,11 +186,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto filter_data_dims = auto filter_data_dims =
framework::slice_ddim(filter_dims, 2, filter_dims.size()); framework::slice_ddim(filter_dims, 2, filter_dims.size());
auto ksize = framework::vectorize<int>(filter_data_dims); auto ksize = framework::vectorize(filter_data_dims);
UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm, UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
data_dims, strides, ksize); data_dims, strides, ksize);
std::vector<primitive> pipeline;
PADDLE_ENFORCE( PADDLE_ENFORCE(
is_conv3d is_conv3d
? dilations.size() == 3 && dilations[0] == 1 && dilations[1] == 1 && ? dilations.size() == 3 && dilations[0] == 1 && dilations[1] == 1 &&
@@ -195,18 +203,18 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const T* input_data = input->data<T>(); const T* input_data = input->data<T>();
const T* filter_data = filter->data<T>(); const T* filter_data = filter->data<T>();
auto src_tz = paddle::framework::vectorize<int>(input->dims()); auto src_tz = paddle::framework::vectorize(input->dims());
auto weights_tz = paddle::framework::vectorize<int>(filter->dims()); auto weights_tz = paddle::framework::vectorize(filter->dims());
int g = std::max(groups, 1); int g = std::max(groups, 1);
GetWeightsTz(weights_tz, g, is_conv3d); GetWeightsTz(weights_tz, g, is_conv3d);
auto dst_tz = paddle::framework::vectorize<int>(output->dims());
auto dst_tz = paddle::framework::vectorize(output->dims());
// Get unique name for storing MKLDNN primitives // Get unique name for storing MKLDNN primitives
const std::string key = platform::CreateKey( const std::string key = platform::CreateKey(
src_tz, ctx.InputName("Input") + ctx.InputName("Filter")); src_tz, ctx.InputName("Input") + ctx.InputName("Filter"));
std::vector<primitive> pipeline;
auto src_format = input->format(); auto src_format = input->format();
MKLDNNMemoryFormat weights_format = MKLDNNMemoryFormat weights_format =
GetWeightsFormat(filter->format(), g, is_conv3d); GetWeightsFormat(filter->format(), g, is_conv3d);
@@ -242,7 +250,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format); src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
auto weights_md = platform::MKLDNNMemDesc( auto weights_md = platform::MKLDNNMemDesc(
weights_tz, platform::MKLDNNGetDataType<T>(), weights_format); weights_tz, platform::MKLDNNGetDataType<T>(), weights_format);
std::vector<int> bias_tz; std::vector<int64_t> bias_tz;
auto dst_md = platform::MKLDNNMemDesc( auto dst_md = platform::MKLDNNMemDesc(
dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format); dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
@@ -253,7 +261,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto fwd_prop_kind = is_test ? mkldnn::prop_kind::forward_inference auto fwd_prop_kind = is_test ? mkldnn::prop_kind::forward_inference
: mkldnn::prop_kind::forward_training; : mkldnn::prop_kind::forward_training;
if (bias) { if (bias) {
bias_tz = paddle::framework::vectorize<int>(bias->dims()); bias_tz = paddle::framework::vectorize(bias->dims());
auto bias_md = platform::MKLDNNMemDesc( auto bias_md = platform::MKLDNNMemDesc(
bias_tz, platform::MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::x); bias_tz, platform::MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::x);
conv_pd = handler.AcquireConvolutionPrimitiveDescriptor( conv_pd = handler.AcquireConvolutionPrimitiveDescriptor(
@@ -296,7 +304,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto output_data = auto output_data =
output->mutable_data<T>(ctx.GetPlace(), handler.GetDstMemorySize()); output->mutable_data<T>(ctx.GetPlace(), handler.GetDstMemorySize());
auto residual_data_tz = auto residual_data_tz =
paddle::framework::vectorize<int>(residual_param->dims()); paddle::framework::vectorize(residual_param->dims());
auto residual_data_type = auto residual_data_type =
paddle::framework::ToMKLDNNDataType(residual_param->type()); paddle::framework::ToMKLDNNDataType(residual_param->type());
@@ -320,28 +328,30 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
handler.AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data)); handler.AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data));
} }
// create convolution op primitive auto conv_p = handler.AcquireConvolution();
std::shared_ptr<mkldnn::convolution_forward> conv_p;
std::shared_ptr<mkldnn::memory> user_bias_memory_p, bias_memory_p; mkldnn::stream astream(mkldnn_engine);
if (bias) { if (bias) {
const T* bias_data = bias->data<T>(); const T* bias_data = bias->data<T>();
auto user_bias_md = platform::MKLDNNMemDesc( auto user_bias_md = platform::MKLDNNMemDesc(
{bias_tz}, platform::MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::x); {bias_tz}, platform::MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::x);
user_bias_memory_p = auto user_bias_memory_p =
handler.AcquireBiasMemory(user_bias_md, to_void_cast<T>(bias_data)); handler.AcquireBiasMemory(user_bias_md, to_void_cast<T>(bias_data));
bias_memory_p = auto bias_memory_p =
handler.AcquireBiasMemoryFromPrimitive(user_bias_memory_p, pipeline); handler.AcquireBiasMemoryFromPrimitive(user_bias_memory_p, pipeline);
conv_p = handler.AcquireConvolution(src_memory_p, weights_memory_p,
bias_memory_p, dst_memory_p); conv_p->execute(astream, {{MKLDNN_ARG_SRC, *src_memory_p},
{MKLDNN_ARG_WEIGHTS, *weights_memory_p},
{MKLDNN_ARG_BIAS, *bias_memory_p},
{MKLDNN_ARG_DST, *dst_memory_p}});
} else { } else {
conv_p = handler.AcquireConvolution(src_memory_p, weights_memory_p, conv_p->execute(astream, {{MKLDNN_ARG_SRC, *src_memory_p},
dst_memory_p); {MKLDNN_ARG_WEIGHTS, *weights_memory_p},
{MKLDNN_ARG_DST, *dst_memory_p}});
} }
astream.wait();
// push primitive to stream and wait until it's executed
pipeline.push_back(*conv_p);
stream(stream::kind::eager).submit(pipeline).wait();
output->set_layout(DataLayout::kMKLDNN); output->set_layout(DataLayout::kMKLDNN);
output->set_format(GetMKLDNNFormat(*dst_memory_p)); output->set_format(GetMKLDNNFormat(*dst_memory_p));
...@@ -359,7 +369,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -359,7 +369,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(input->layout(), DataLayout::kMKLDNN, PADDLE_ENFORCE_EQ(input->layout(), DataLayout::kMKLDNN,
"Wrong layout set for Input tensor"); "Wrong layout set for Input tensor");
PADDLE_ENFORCE_NE(input->format(), MKLDNNMemoryFormat::format_undef, PADDLE_ENFORCE_NE(input->format(), MKLDNNMemoryFormat::undef,
"Wrong format set for Input tensor"); "Wrong format set for Input tensor");
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(
...@@ -376,7 +386,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -376,7 +386,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const T* input_data = input->data<T>(); const T* input_data = input->data<T>();
auto src_tz = paddle::framework::vectorize<int>(input->dims()); auto src_tz = paddle::framework::vectorize(input->dims());
mkldnn::memory::data_type src_dt = mkldnn::memory::data_type src_dt =
paddle::framework::ToMKLDNNDataType(input->type()); paddle::framework::ToMKLDNNDataType(input->type());
...@@ -385,7 +395,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -385,7 +395,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
src_tz, src_dt, ctx.InputName("Input") + ctx.InputName("Filter")); src_tz, src_dt, ctx.InputName("Input") + ctx.InputName("Filter"));
const std::string key_conv_pd = key + "@conv_pd"; const std::string key_conv_pd = key + "@conv_pd";
bool need_s8_to_u8 = false; bool need_s8_to_u8 = false;
std::shared_ptr<mkldnn::convolution_forward> conv_p; std::shared_ptr<mkldnn::convolution_forward> conv_p;
std::shared_ptr<mkldnn::memory> src_memory_p; std::shared_ptr<mkldnn::memory> src_memory_p;
...@@ -407,13 +416,18 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -407,13 +416,18 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto prim_key = key + key_tid + "@conv_p"; auto prim_key = key + key_tid + "@conv_p";
auto dst_key = key + key_tid + "@dst_mem_p"; auto dst_key = key + key_tid + "@dst_mem_p";
auto src_key = key + key_tid + "@src_mem_p"; auto src_key = key + key_tid + "@src_mem_p";
auto weights_key = key + key_tid + "@weights_mem_p";
auto bias_key = key + key_tid + "@bias_mem_p";
auto user_src_key = key + key_tid + "@user_src_mem_p"; auto user_src_key = key + key_tid + "@user_src_mem_p";
auto user_residual_key = key + key_tid + "@user_residual_data_mem_p";
auto src_reorder_key = key + key_tid + "@src_mem_preorder_p"; auto src_reorder_key = key + key_tid + "@src_mem_preorder_p";
auto residual_reorder_key = key + key_tid + "@residual_data_mem_preorder_p"; auto residual_reorder_key = key + key_tid + "@residual_data_mem_preorder_p";
conv_p = std::static_pointer_cast<mkldnn::convolution_forward>( conv_p = std::static_pointer_cast<mkldnn::convolution_forward>(
dev_ctx.GetBlob(prim_key)); dev_ctx.GetBlob(prim_key));
mkldnn::stream astream(mkldnn_engine);
if (conv_p == nullptr || !is_test) { if (conv_p == nullptr || !is_test) {
float fuse_alpha = ctx.Attr<float>("fuse_alpha"); float fuse_alpha = ctx.Attr<float>("fuse_alpha");
float fuse_beta = ctx.Attr<float>("fuse_beta"); float fuse_beta = ctx.Attr<float>("fuse_beta");
...@@ -423,7 +437,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -423,7 +437,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(filter->layout(), DataLayout::kMKLDNN, PADDLE_ENFORCE_EQ(filter->layout(), DataLayout::kMKLDNN,
"Wrong layout set for Filter tensor"); "Wrong layout set for Filter tensor");
PADDLE_ENFORCE_NE(filter->format(), MKLDNNMemoryFormat::format_undef, PADDLE_ENFORCE_NE(filter->format(), MKLDNNMemoryFormat::undef,
"Wrong format set for Filter tensor"); "Wrong format set for Filter tensor");
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(
...@@ -442,16 +456,23 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -442,16 +456,23 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
if (bias) { if (bias) {
PADDLE_ENFORCE_EQ(bias->layout(), DataLayout::kMKLDNN, PADDLE_ENFORCE_EQ(bias->layout(), DataLayout::kMKLDNN,
"Wrong layout set for Bias tensor"); "Wrong layout set for Bias tensor");
PADDLE_ENFORCE_NE(bias->format(), MKLDNNMemoryFormat::format_undef, PADDLE_ENFORCE_NE(bias->format(), MKLDNNMemoryFormat::undef,
"Wrong format set for Bias tensor"); "Wrong format set for Bias tensor");
PADDLE_ENFORCE_EQ(bias->dims().size(), 1, PADDLE_ENFORCE_EQ(bias->dims().size(), 1,
"Bias must only have 1 dimension, i.e. X"); "Bias must only have 1 dimension, i.e. X");
} }
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings"); std::vector<int> strides_temp = ctx.Attr<std::vector<int>>("strides");
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations"); std::vector<int64_t> strides(begin(strides_temp), end(strides_temp));
std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
std::vector<int> paddings_temp = ctx.Attr<std::vector<int>>("paddings");
std::vector<int64_t> paddings(begin(paddings_temp), end(paddings_temp));
std::vector<int> dilations_temp = ctx.Attr<std::vector<int>>("dilations");
std::vector<int64_t> dilations(begin(dilations_temp),
end(dilations_temp));
std::string padding_algorithm = std::string padding_algorithm =
ctx.Attr<std::string>("padding_algorithm"); ctx.Attr<std::string>("padding_algorithm");
...@@ -466,17 +487,17 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -466,17 +487,17 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto filter_data_dims = auto filter_data_dims =
framework::slice_ddim(filter_dims, 2, filter_dims.size()); framework::slice_ddim(filter_dims, 2, filter_dims.size());
auto ksize = framework::vectorize<int>(filter_data_dims); auto ksize = framework::vectorize(filter_data_dims);
UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm, UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
data_dims, strides, ksize); data_dims, strides, ksize);
int groups = ctx.Attr<int>("groups"); int groups = ctx.Attr<int>("groups");
auto weights_tz = paddle::framework::vectorize<int>(filter->dims()); auto weights_tz = paddle::framework::vectorize(filter->dims());
int g = std::max(groups, 1); int g = std::max(groups, 1);
GetWeightsTz(weights_tz, g, is_conv3d); GetWeightsTz(weights_tz, g, is_conv3d);
auto dst_tz = paddle::framework::vectorize<int>(output->dims()); auto dst_tz = paddle::framework::vectorize(output->dims());
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
is_conv3d is_conv3d
...@@ -526,7 +547,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -526,7 +547,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
*/ */
auto chosen_memory_format = MKLDNNMemoryFormat::any; auto chosen_memory_format = MKLDNNMemoryFormat::any;
std::vector<int> bias_tz; std::vector<int64_t> bias_tz;
auto src_md = auto src_md =
platform::MKLDNNMemDesc(src_tz, src_dt, chosen_memory_format); platform::MKLDNNMemDesc(src_tz, src_dt, chosen_memory_format);
...@@ -542,7 +563,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -542,7 +563,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
: mkldnn::prop_kind::forward_training; : mkldnn::prop_kind::forward_training;
if (bias) { if (bias) {
bias_tz = paddle::framework::vectorize<int>(bias->dims()); bias_tz = paddle::framework::vectorize(bias->dims());
auto bias_md = platform::MKLDNNMemDesc(bias_tz, memory::data_type::s32, auto bias_md = platform::MKLDNNMemDesc(bias_tz, memory::data_type::s32,
MKLDNNMemoryFormat::x); MKLDNNMemoryFormat::x);
conv_pd = handler->AcquireConvolutionPrimitiveDescriptor( conv_pd = handler->AcquireConvolutionPrimitiveDescriptor(
...@@ -582,7 +603,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -582,7 +603,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
paddle::framework::ToMKLDNNDataType(residual_param->type()); paddle::framework::ToMKLDNNDataType(residual_param->type());
if (residual_param->format() != handler->GetDstFormat()) { if (residual_param->format() != handler->GetDstFormat()) {
auto residual_data_tz = auto residual_data_tz =
paddle::framework::vectorize<int>(residual_param->dims()); paddle::framework::vectorize(residual_param->dims());
auto user_residual_md = platform::MKLDNNMemDesc( auto user_residual_md = platform::MKLDNNMemDesc(
residual_data_tz, residual_dt, residual_param->format()); residual_data_tz, residual_dt, residual_param->format());
dst_memory_p = platform::SetDstMemory<T_out>( dst_memory_p = platform::SetDstMemory<T_out>(
...@@ -601,6 +622,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -601,6 +622,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
// create convolution op primitive // create convolution op primitive
auto scale_bias_key = key + "@scale_bias"; auto scale_bias_key = key + "@scale_bias";
conv_p = handler->AcquireConvolution();
if (bias) { if (bias) {
const K* bias_data = bias->data<K>(); const K* bias_data = bias->data<K>();
auto user_bias_md = platform::MKLDNNMemDesc( auto user_bias_md = platform::MKLDNNMemDesc(
...@@ -621,16 +643,17 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -621,16 +643,17 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
bias_memory_p = handler->AcquireBiasMemoryFromPrimitive( bias_memory_p = handler->AcquireBiasMemoryFromPrimitive(
user_bias_memory_p, pipeline, is_test, true, scale_bias_data, user_bias_memory_p, pipeline, is_test, true, scale_bias_data,
mask_reorder); mask_reorder);
conv_p = handler->AcquireConvolution(src_memory_p, weights_memory_p, conv_p->execute(astream, {{MKLDNN_ARG_SRC, *src_memory_p},
bias_memory_p, dst_memory_p); {MKLDNN_ARG_WEIGHTS, *weights_memory_p},
{MKLDNN_ARG_BIAS, *bias_memory_p},
{MKLDNN_ARG_DST, *dst_memory_p}});
} else { } else {
conv_p = handler->AcquireConvolution(src_memory_p, weights_memory_p, conv_p->execute(astream, {{MKLDNN_ARG_SRC, *src_memory_p},
dst_memory_p); {MKLDNN_ARG_WEIGHTS, *weights_memory_p},
{MKLDNN_ARG_DST, *dst_memory_p}});
} }
// push primitive to stream and wait until it's executed
pipeline.push_back(*conv_p);
} else { } else {
auto src_memory_reorder_p = std::static_pointer_cast<mkldnn::memory>( auto src_memory_reorder_p = std::static_pointer_cast<mkldnn::reorder>(
dev_ctx.GetBlob(src_reorder_key)); dev_ctx.GetBlob(src_reorder_key));
src_memory_p = src_memory_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(src_key)); std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(src_key));
...@@ -638,10 +661,14 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -638,10 +661,14 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
user_src_memory_p = std::static_pointer_cast<mkldnn::memory>( user_src_memory_p = std::static_pointer_cast<mkldnn::memory>(
dev_ctx.GetBlob(user_src_key)); dev_ctx.GetBlob(user_src_key));
user_src_memory_p->set_data_handle(to_void_cast<T>(input_data)); user_src_memory_p->set_data_handle(to_void_cast<T>(input_data));
src_memory_reorder_p->execute(astream, *user_src_memory_p,
*src_memory_p);
astream.wait();
} else if (src_memory_p) { } else if (src_memory_p) {
src_memory_p->set_data_handle(to_void_cast<T>(input_data)); src_memory_p->set_data_handle(to_void_cast<T>(input_data));
} }
auto weights_memory_p = std::static_pointer_cast<mkldnn::memory>(
dev_ctx.GetBlob(weights_key));
dst_memory_p = dst_memory_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(dst_key)); std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(dst_key));
conv_pd = conv_pd =
...@@ -661,19 +688,31 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -661,19 +688,31 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
} }
platform::SetDstMemoryHandler<T_out>(ctx, output, handler, dst_memory_p); platform::SetDstMemoryHandler<T_out>(ctx, output, handler, dst_memory_p);
if (src_memory_reorder_p) { auto residual_reorder_p = std::static_pointer_cast<mkldnn::reorder>(
pipeline.push_back(*src_memory_reorder_p);
}
auto residual_reorder_p = std::static_pointer_cast<mkldnn::memory>(
dev_ctx.GetBlob(residual_reorder_key)); dev_ctx.GetBlob(residual_reorder_key));
if (residual_reorder_p) { if (residual_reorder_p) {
pipeline.push_back(*residual_reorder_p); auto user_residual_data_p = std::static_pointer_cast<mkldnn::memory>(
dev_ctx.GetBlob(user_residual_key));
residual_reorder_p->execute(astream, *user_residual_data_p,
*dst_memory_p);
astream.wait();
}
auto bias_memory_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(bias_key));
if (bias_memory_p) {
conv_p->execute(astream, {{MKLDNN_ARG_SRC, *src_memory_p},
{MKLDNN_ARG_WEIGHTS, *weights_memory_p},
{MKLDNN_ARG_BIAS, *bias_memory_p},
{MKLDNN_ARG_DST, *dst_memory_p}});
} else {
conv_p->execute(astream, {{MKLDNN_ARG_SRC, *src_memory_p},
{MKLDNN_ARG_WEIGHTS, *weights_memory_p},
{MKLDNN_ARG_DST, *dst_memory_p}});
} }
pipeline.push_back(*conv_p);
} }
// push primitive to stream and wait until it's executed astream.wait();
stream(stream::kind::eager).submit(pipeline).wait();
if (need_s8_to_u8) { if (need_s8_to_u8) {
output->mutable_data<uint8_t>(ctx.GetPlace()); output->mutable_data<uint8_t>(ctx.GetPlace());
} }
...@@ -702,17 +741,17 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -702,17 +741,17 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(input->layout(), DataLayout::kMKLDNN, PADDLE_ENFORCE_EQ(input->layout(), DataLayout::kMKLDNN,
"Wrong layout set for Input tensor"); "Wrong layout set for Input tensor");
PADDLE_ENFORCE_NE(input->format(), MKLDNNMemoryFormat::format_undef, PADDLE_ENFORCE_NE(input->format(), MKLDNNMemoryFormat::undef,
"Wrong format set for Input tensor"); "Wrong format set for Input tensor");
PADDLE_ENFORCE_EQ(filter->layout(), DataLayout::kMKLDNN, PADDLE_ENFORCE_EQ(filter->layout(), DataLayout::kMKLDNN,
"Wrong layout set for Filter tensor"); "Wrong layout set for Filter tensor");
PADDLE_ENFORCE_NE(filter->format(), MKLDNNMemoryFormat::format_undef, PADDLE_ENFORCE_NE(filter->format(), MKLDNNMemoryFormat::undef,
"Wrong format set for Filter tensor"); "Wrong format set for Filter tensor");
PADDLE_ENFORCE_EQ(output_grad->layout(), DataLayout::kMKLDNN, PADDLE_ENFORCE_EQ(output_grad->layout(), DataLayout::kMKLDNN,
"Wrong layout set for output_grad tensor"); "Wrong layout set for output_grad tensor");
PADDLE_ENFORCE_NE(output_grad->format(), MKLDNNMemoryFormat::format_undef, PADDLE_ENFORCE_NE(output_grad->format(), MKLDNNMemoryFormat::undef,
"Wrong format set for output_grad tensor"); "Wrong format set for output_grad tensor");
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
...@@ -721,10 +760,17 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -721,10 +760,17 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
if (!input_grad && !filter_grad) return; if (!input_grad && !filter_grad) return;
std::vector<int> strides = ctx.Attr<std::vector<int>>("strides"); std::vector<int> strides_temp = ctx.Attr<std::vector<int>>("strides");
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings"); std::vector<int64_t> strides(begin(strides_temp), end(strides_temp));
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
std::vector<int> paddings_temp = ctx.Attr<std::vector<int>>("paddings");
std::vector<int64_t> paddings(begin(paddings_temp), end(paddings_temp));
std::vector<int> dilations_temp = ctx.Attr<std::vector<int>>("dilations");
std::vector<int64_t> dilations(begin(dilations_temp), end(dilations_temp));
std::string padding_algorithm = ctx.Attr<std::string>("padding_algorithm"); std::string padding_algorithm = ctx.Attr<std::string>("padding_algorithm");
int groups = ctx.Attr<int>("groups"); int groups = ctx.Attr<int>("groups");
bool is_conv3d = strides.size() == 3U; bool is_conv3d = strides.size() == 3U;
...@@ -740,16 +786,18 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -740,16 +786,18 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
auto filter_data_dims = auto filter_data_dims =
framework::slice_ddim(filter_dims, 2, filter_dims.size()); framework::slice_ddim(filter_dims, 2, filter_dims.size());
auto ksize = framework::vectorize<int>(filter_data_dims); auto ksize = framework::vectorize(filter_data_dims);
UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm, UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
data_dims, strides, ksize); data_dims, strides, ksize);
auto src_tz = paddle::framework::vectorize<int>(input->dims()); auto src_tz = paddle::framework::vectorize(input->dims());
auto weights_tz = paddle::framework::vectorize<int>(filter->dims()); auto weights_tz = paddle::framework::vectorize(filter->dims());
int g = std::max(groups, 1); int g = std::max(groups, 1);
GetWeightsTz(weights_tz, g, is_conv3d); GetWeightsTz(weights_tz, g, is_conv3d);
auto dst_tz = paddle::framework::vectorize<int>(output_grad->dims()); auto dst_tz = paddle::framework::vectorize(output_grad->dims());
auto src_format = input->format(); auto src_format = input->format();
MKLDNNMemoryFormat weights_format = MKLDNNMemoryFormat weights_format =
GetWeightsFormat(filter->format(), g, is_conv3d); GetWeightsFormat(filter->format(), g, is_conv3d);
...@@ -803,7 +851,6 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -803,7 +851,6 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
weights_tz, platform::MKLDNNGetDataType<T>(), weights_format); weights_tz, platform::MKLDNNGetDataType<T>(), weights_format);
auto diff_dst_md = platform::MKLDNNMemDesc( auto diff_dst_md = platform::MKLDNNMemDesc(
dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format); dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
// Retrieve conv_pd from device context // Retrieve conv_pd from device context
auto conv_pd = auto conv_pd =
std::static_pointer_cast<mkldnn::convolution_forward::primitive_desc>( std::static_pointer_cast<mkldnn::convolution_forward::primitive_desc>(
...@@ -815,18 +862,18 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -815,18 +862,18 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
// create backward convolution weights primitive descriptor // create backward convolution weights primitive descriptor
auto conv_bwd_weights_desc = mkldnn::convolution_backward_weights::desc( auto conv_bwd_weights_desc = mkldnn::convolution_backward_weights::desc(
mkldnn::convolution_direct, src_md, diff_weights_md, diff_dst_md, mkldnn::algorithm::convolution_direct, src_md, diff_weights_md,
strides, mkldnn_paddings[0], mkldnn_paddings[1], diff_dst_md, strides, mkldnn_paddings[0], mkldnn_paddings[1]);
mkldnn::padding_kind::zero);
auto conv_bwd_weights_pd = auto conv_bwd_weights_pd =
std::make_shared<mkldnn::convolution_backward_weights::primitive_desc>( std::make_shared<mkldnn::convolution_backward_weights::primitive_desc>(
conv_bwd_weights_desc, mkldnn_engine, *conv_pd); conv_bwd_weights_desc, mkldnn_engine, *conv_pd);
// create backward convolution data primitive descriptor // create backward convolution data primitive descriptor
auto conv_bwd_data_desc = mkldnn::convolution_backward_data::desc( auto conv_bwd_data_desc = mkldnn::convolution_backward_data::desc(
mkldnn::convolution_direct, diff_src_md, weights_md, diff_dst_md, mkldnn::algorithm::convolution_direct, diff_src_md, weights_md,
strides, mkldnn_paddings[0], mkldnn_paddings[1], diff_dst_md, strides, mkldnn_paddings[0], mkldnn_paddings[1]);
mkldnn::padding_kind::zero);
auto conv_bwd_data_pd = auto conv_bwd_data_pd =
std::make_shared<mkldnn::convolution_backward_data::primitive_desc>( std::make_shared<mkldnn::convolution_backward_data::primitive_desc>(
conv_bwd_data_desc, mkldnn_engine, *conv_pd); conv_bwd_data_desc, mkldnn_engine, *conv_pd);
...@@ -842,8 +889,7 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -842,8 +889,7 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
user_weights_md, to_void_cast<T>(filter_data)); user_weights_md, to_void_cast<T>(filter_data));
auto user_diff_dst_memory_p = handler.AcquireDiffDstMemory( auto user_diff_dst_memory_p = handler.AcquireDiffDstMemory(
user_diff_dst_md, to_void_cast<T>(output_grad_data)); user_diff_dst_md, to_void_cast<T>(output_grad_data));
mkldnn::stream astream(mkldnn_engine);
// create backward conv primitive for weights
if (filter_grad) { if (filter_grad) {
auto src_memory_p = handler.AcquireSrcMemoryFromWeightsPrimitive( auto src_memory_p = handler.AcquireSrcMemoryFromWeightsPrimitive(
user_src_memory_p, pipeline); user_src_memory_p, pipeline);
...@@ -859,16 +905,18 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -859,16 +905,18 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
handler.AcquireDiffWeightsMemoryFromWeightsPrimitive( handler.AcquireDiffWeightsMemoryFromWeightsPrimitive(
reinterpret_cast<void*>(filter_grad_data)); reinterpret_cast<void*>(filter_grad_data));
auto conv_bwd_weights_p = handler.AcquireConvolutionBackwardWeights( auto conv_bwd_weights_p = handler.AcquireConvolutionBackwardWeights();
src_memory_p, diff_dst_memory_4filter_p, diff_weights_memory_p);
// push primitive to stream and wait until it's executed // TODO(grygielski) why no bias_diff?
pipeline.push_back(*conv_bwd_weights_p); conv_bwd_weights_p->execute(
astream, {{MKLDNN_ARG_SRC, *src_memory_p},
{MKLDNN_ARG_DIFF_DST, *diff_dst_memory_4filter_p},
{MKLDNN_ARG_DIFF_WEIGHTS, *diff_weights_memory_p}});
astream.wait();
filter_grad->set_layout(DataLayout::kMKLDNN); filter_grad->set_layout(DataLayout::kMKLDNN);
filter_grad->set_format(GetMKLDNNFormat(*diff_weights_memory_p)); filter_grad->set_format(GetMKLDNNFormat(*diff_weights_memory_p));
} }
if (input_grad) { if (input_grad) {
auto weights_memory_p = handler.AcquireWeightsMemoryFromDataPrimitive( auto weights_memory_p = handler.AcquireWeightsMemoryFromDataPrimitive(
user_weights_memory_p, pipeline); user_weights_memory_p, pipeline);
...@@ -883,15 +931,17 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -883,15 +931,17 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
auto diff_src_memory_p = handler.AcquireDiffSrcMemoryFromDataPrimitive( auto diff_src_memory_p = handler.AcquireDiffSrcMemoryFromDataPrimitive(
reinterpret_cast<void*>(input_grad_data)); reinterpret_cast<void*>(input_grad_data));
auto conv_bwd_data_p = handler.AcquireConvolutionBackwardData( auto conv_bwd_data_p = handler.AcquireConvolutionBackwardData();
diff_dst_memory_4data_p, weights_memory_p, diff_src_memory_p);
pipeline.push_back(*conv_bwd_data_p); conv_bwd_data_p->execute(astream,
{{MKLDNN_ARG_WEIGHTS, *weights_memory_p},
{MKLDNN_ARG_DIFF_DST, *diff_dst_memory_4data_p},
{MKLDNN_ARG_DIFF_SRC, *diff_src_memory_p}});
astream.wait();
input_grad->set_layout(DataLayout::kMKLDNN); input_grad->set_layout(DataLayout::kMKLDNN);
input_grad->set_format(GetMKLDNNFormat(*diff_src_memory_p)); input_grad->set_format(GetMKLDNNFormat(*diff_src_memory_p));
} }
stream(stream::kind::eager).submit(pipeline).wait();
} }
}; };
......
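Note on the conv hunks above: under MKL-DNN 0.x the kernel bound its memories into the primitive at construction time, collected primitives in a std::vector<primitive> pipeline and ran them via stream(stream::kind::eager).submit(pipeline).wait(); under 1.0 the primitive is built from its primitive_desc alone and executed with an argument map on an mkldnn::stream. A minimal sketch of that 1.x pattern, assuming pre-built memory descriptors (the function name and parameters below are illustrative, not part of this patch):

#include <unordered_map>
#include "mkldnn.hpp"

// Sketch only: runs a direct convolution the mkldnn 1.x way. All names here
// (run_conv_v1, src_md, ...) are hypothetical and not taken from this commit.
void run_conv_v1(const mkldnn::memory::desc& src_md,
                 const mkldnn::memory::desc& weights_md,
                 const mkldnn::memory::desc& dst_md,
                 const mkldnn::memory::dims& strides,
                 const mkldnn::memory::dims& padding,
                 void* src_data, void* weights_data, void* dst_data) {
  mkldnn::engine eng(mkldnn::engine::kind::cpu, 0);

  // 1.x descriptors drop the trailing mkldnn::padding_kind::zero argument.
  mkldnn::convolution_forward::desc desc(
      mkldnn::prop_kind::forward_inference,
      mkldnn::algorithm::convolution_direct, src_md, weights_md, dst_md,
      strides, padding, padding);
  mkldnn::convolution_forward::primitive_desc pd(desc, eng);

  // Memories are built from (desc, engine, handle); memory::primitive_desc is gone.
  mkldnn::memory src(pd.src_desc(), eng, src_data);
  mkldnn::memory weights(pd.weights_desc(), eng, weights_data);
  mkldnn::memory dst(pd.dst_desc(), eng, dst_data);

  // The primitive no longer owns its arguments; they are passed at execute time.
  mkldnn::convolution_forward conv(pd);
  mkldnn::stream astream(eng);
  conv.execute(astream, {{MKLDNN_ARG_SRC, src},
                         {MKLDNN_ARG_WEIGHTS, weights},
                         {MKLDNN_ARG_DST, dst}});
  astream.wait();
}

The biased variant adds an MKLDNN_ARG_BIAS entry to the same map, which is exactly the branching the kernel above performs.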
...@@ -48,12 +48,12 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -48,12 +48,12 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(input->layout(), DataLayout::kMKLDNN, PADDLE_ENFORCE_EQ(input->layout(), DataLayout::kMKLDNN,
"Wrong layout set for Input tensor"); "Wrong layout set for Input tensor");
PADDLE_ENFORCE_NE(input->format(), MKLDNNMemoryFormat::format_undef, PADDLE_ENFORCE_NE(input->format(), MKLDNNMemoryFormat::undef,
"Wrong format set for Input tensor"); "Wrong format set for Input tensor");
PADDLE_ENFORCE_EQ(filter->layout(), DataLayout::kMKLDNN, PADDLE_ENFORCE_EQ(filter->layout(), DataLayout::kMKLDNN,
"Wrong layout set for Filter tensor"); "Wrong layout set for Filter tensor");
PADDLE_ENFORCE_NE(filter->format(), MKLDNNMemoryFormat::format_undef, PADDLE_ENFORCE_NE(filter->format(), MKLDNNMemoryFormat::undef,
"Wrong format set for Filter tensor"); "Wrong format set for Filter tensor");
PADDLE_ENFORCE_EQ(input->dims().size(), 4, PADDLE_ENFORCE_EQ(input->dims().size(), 4,
...@@ -64,16 +64,22 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -64,16 +64,22 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
if (bias) { if (bias) {
PADDLE_ENFORCE_EQ(bias->layout(), DataLayout::kMKLDNN, PADDLE_ENFORCE_EQ(bias->layout(), DataLayout::kMKLDNN,
"Wrong layout set for Bias tensor"); "Wrong layout set for Bias tensor");
PADDLE_ENFORCE_NE(bias->format(), MKLDNNMemoryFormat::format_undef, PADDLE_ENFORCE_NE(bias->format(), MKLDNNMemoryFormat::undef,
"Wrong format set for Bias tensor"); "Wrong format set for Bias tensor");
PADDLE_ENFORCE_EQ(bias->dims().size(), 1, PADDLE_ENFORCE_EQ(bias->dims().size(), 1,
"Bias must only have 1 dimension, i.e. X"); "Bias must only have 1 dimension, i.e. X");
} }
std::vector<int> strides = ctx.Attr<std::vector<int>>("strides"); std::vector<int> strides_temp = ctx.Attr<std::vector<int>>("strides");
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings"); std::vector<int64_t> strides(begin(strides_temp), end(strides_temp));
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
std::vector<int> paddings_temp = ctx.Attr<std::vector<int>>("paddings");
std::vector<int64_t> paddings(begin(paddings_temp), end(paddings_temp));
std::vector<int> dilations_temp = ctx.Attr<std::vector<int>>("dilations");
std::vector<int64_t> dilations(begin(dilations_temp), end(dilations_temp));
int groups = ctx.Attr<int>("groups"); int groups = ctx.Attr<int>("groups");
std::string padding_algorithm = ctx.Attr<std::string>("padding_algorithm"); std::string padding_algorithm = ctx.Attr<std::string>("padding_algorithm");
...@@ -83,7 +89,7 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -83,7 +89,7 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto filter_data_dims = auto filter_data_dims =
framework::slice_ddim(filter_dims, 2, filter_dims.size()); framework::slice_ddim(filter_dims, 2, filter_dims.size());
auto ksize = framework::vectorize<int>(filter_data_dims); auto ksize = framework::vectorize(filter_data_dims);
UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm, UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
data_dims, strides, ksize); data_dims, strides, ksize);
...@@ -95,8 +101,9 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -95,8 +101,9 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const T* input_data = input->data<T>(); const T* input_data = input->data<T>();
const T* filter_data = filter->data<T>(); const T* filter_data = filter->data<T>();
auto src_tz = paddle::framework::vectorize<int>(input->dims()); auto src_tz = paddle::framework::vectorize<int64_t>(input->dims());
auto iohw_weights_tz = paddle::framework::vectorize<int>(filter->dims()); auto iohw_weights_tz =
paddle::framework::vectorize<int64_t>(filter->dims());
auto weights_tz = iohw_weights_tz; auto weights_tz = iohw_weights_tz;
// IOHW -> OIHW // IOHW -> OIHW
...@@ -137,7 +144,7 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -137,7 +144,7 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
weights_tz[3] = h; weights_tz[3] = h;
weights_tz[4] = w; weights_tz[4] = w;
} }
auto dst_tz = paddle::framework::vectorize<int>(output->dims()); auto dst_tz = paddle::framework::vectorize<int64_t>(output->dims());
// Get unique name for storing MKLDNN primitives // Get unique name for storing MKLDNN primitives
...@@ -165,7 +172,7 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -165,7 +172,7 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format); src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
auto weights_md = platform::MKLDNNMemDesc( auto weights_md = platform::MKLDNNMemDesc(
weights_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format); weights_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
std::vector<int> bias_tz; std::vector<int64_t> bias_tz;
auto dst_md = platform::MKLDNNMemDesc( auto dst_md = platform::MKLDNNMemDesc(
dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format); dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
...@@ -177,7 +184,7 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -177,7 +184,7 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto fwd_prop_kind = is_test ? mkldnn::prop_kind::forward_inference auto fwd_prop_kind = is_test ? mkldnn::prop_kind::forward_inference
: mkldnn::prop_kind::forward_training; : mkldnn::prop_kind::forward_training;
if (bias) { if (bias) {
bias_tz = paddle::framework::vectorize<int>(bias->dims()); bias_tz = paddle::framework::vectorize<int64_t>(bias->dims());
auto bias_md = platform::MKLDNNMemDesc( auto bias_md = platform::MKLDNNMemDesc(
bias_tz, platform::MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::x); bias_tz, platform::MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::x);
conv_transpose_pd = handler.AcquireConvolutionPrimitiveDescriptor( conv_transpose_pd = handler.AcquireConvolutionPrimitiveDescriptor(
...@@ -203,15 +210,14 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -203,15 +210,14 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto weights_memory_p = handler.AcquireWeightsMemoryFromPrimitive( auto weights_memory_p = handler.AcquireWeightsMemoryFromPrimitive(
user_weights_memory_p, pipeline, is_test); user_weights_memory_p, pipeline, is_test);
std::shared_ptr<mkldnn::memory> dst_memory_p;
auto output_data = auto output_data =
output->mutable_data<T>(ctx.GetPlace(), handler.GetDstMemorySize()); output->mutable_data<T>(ctx.GetPlace(), handler.GetDstMemorySize());
dst_memory_p = handler.AcquireDstMemoryFromPrimitive( auto dst_memory_p = handler.AcquireDstMemoryFromPrimitive(
platform::to_void_cast<T>(output_data)); platform::to_void_cast<T>(output_data));
// create convolution op primitive auto conv_p = handler.AcquireConvolution();
std::shared_ptr<mkldnn::deconvolution_forward> conv_p;
mkldnn::stream astream(mkldnn_engine);
if (bias) { if (bias) {
const T* bias_data = bias->data<T>(); const T* bias_data = bias->data<T>();
auto user_bias_md = platform::MKLDNNMemDesc( auto user_bias_md = platform::MKLDNNMemDesc(
...@@ -221,16 +227,17 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -221,16 +227,17 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto bias_memory_p = auto bias_memory_p =
handler.AcquireBiasMemoryFromPrimitive(user_bias_memory_p, pipeline); handler.AcquireBiasMemoryFromPrimitive(user_bias_memory_p, pipeline);
conv_p = handler.AcquireConvolution(src_memory_p, weights_memory_p,
bias_memory_p, dst_memory_p); conv_p->execute(astream, {{MKLDNN_ARG_SRC, *src_memory_p},
{MKLDNN_ARG_WEIGHTS, *weights_memory_p},
{MKLDNN_ARG_BIAS, *bias_memory_p},
{MKLDNN_ARG_DST, *dst_memory_p}});
} else { } else {
conv_p = handler.AcquireConvolution(src_memory_p, weights_memory_p, conv_p->execute(astream, {{MKLDNN_ARG_SRC, *src_memory_p},
dst_memory_p); {MKLDNN_ARG_WEIGHTS, *weights_memory_p},
{MKLDNN_ARG_DST, *dst_memory_p}});
} }
astream.wait();
// push primitive to stream and wait until it's executed
pipeline.push_back(*conv_p);
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
output->set_layout(DataLayout::kMKLDNN); output->set_layout(DataLayout::kMKLDNN);
output->set_format(platform::GetMKLDNNFormat(*dst_memory_p)); output->set_format(platform::GetMKLDNNFormat(*dst_memory_p));
......
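The strides/paddings/dilations *_temp copies introduced above (and the switch to vectorize<int64_t>()) exist because mkldnn 1.x describes dimensions as 64-bit integers (mkldnn::memory::dims) while the op attributes stay std::vector<int>. A minimal sketch of that conversion with a hypothetical helper name (the patch inlines the conversion instead of using a helper):

#include <cstdint>
#include <vector>

// Widen int attribute vectors to the int64_t dims that mkldnn 1.x expects.
std::vector<int64_t> ToMkldnnDims(const std::vector<int>& attr) {
  return std::vector<int64_t>(attr.begin(), attr.end());
}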
...@@ -46,9 +46,8 @@ class DeQuantOpKernel : public framework::OpKernel<T> { ...@@ -46,9 +46,8 @@ class DeQuantOpKernel : public framework::OpKernel<T> {
float* output_data = output->mutable_data<float>(ctx.GetPlace()); float* output_data = output->mutable_data<float>(ctx.GetPlace());
std::vector<float> reorder_scale = {1.0f / scale_data}; std::vector<float> reorder_scale = {1.0f / scale_data};
std::vector<primitive> pipeline; auto src_tz = paddle::framework::vectorize<int64_t>(input->dims());
auto src_tz = paddle::framework::vectorize<int>(input->dims()); auto dst_tz = paddle::framework::vectorize<int64_t>(output->dims());
auto dst_tz = paddle::framework::vectorize<int>(output->dims());
mkldnn::memory::data_type src_dt = mkldnn::memory::data_type src_dt =
paddle::framework::ToMKLDNNDataType(input->type()); paddle::framework::ToMKLDNNDataType(input->type());
MKLDNNMemoryFormat src_fmt = input->format(); MKLDNNMemoryFormat src_fmt = input->format();
...@@ -69,23 +68,20 @@ class DeQuantOpKernel : public framework::OpKernel<T> { ...@@ -69,23 +68,20 @@ class DeQuantOpKernel : public framework::OpKernel<T> {
attri.set_output_scales(mask, reorder_scale); attri.set_output_scales(mask, reorder_scale);
auto src_md = platform::MKLDNNMemDesc({src_tz}, src_dt, src_fmt); auto src_md = platform::MKLDNNMemDesc({src_tz}, src_dt, src_fmt);
auto src_pd = mkldnn::memory::primitive_desc(src_md, engine); src_memory = std::make_shared<mkldnn::memory>(
src_memory = src_md, engine, to_void_cast<T>(input_data));
std::make_shared<mkldnn::memory>(src_pd, to_void_cast<T>(input_data));
std::shared_ptr<primitive::at> src_memory_p = auto dst_md =
std::shared_ptr<primitive::at>(new primitive::at(*src_memory)); platform::MKLDNNMemDesc({dst_tz}, memory::data_type::f32,
platform::MKLDNNFormatForSize(
auto dst_md = platform::MKLDNNMemDesc( dst_tz.size(), MKLDNNMemoryFormat::nchw));
{dst_tz}, memory::data_type::f32,
platform::MKLDNNFormatForSize(dst_tz.size(), memory::format::nchw));
auto dst_pd = mkldnn::memory::primitive_desc(dst_md, engine);
dst_memory = std::make_shared<mkldnn::memory>( dst_memory = std::make_shared<mkldnn::memory>(
dst_pd, to_void_cast<float>(output_data)); dst_md, engine, to_void_cast<float>(output_data));
auto reorder_pd = std::shared_ptr<reorder::primitive_desc>( auto reorder_pd = std::shared_ptr<reorder::primitive_desc>(
new reorder::primitive_desc(src_pd, dst_pd, attri)); new reorder::primitive_desc(*src_memory, *dst_memory, attri));
reorder_p = std::shared_ptr<reorder>( reorder_p = std::shared_ptr<reorder>(new reorder(*reorder_pd));
new reorder(*reorder_pd, *src_memory_p, *dst_memory));
dev_ctx.SetBlob(key_prim, reorder_p); dev_ctx.SetBlob(key_prim, reorder_p);
dev_ctx.SetBlob(key_src_mem, src_memory); dev_ctx.SetBlob(key_src_mem, src_memory);
dev_ctx.SetBlob(key_dst_mem, dst_memory); dev_ctx.SetBlob(key_dst_mem, dst_memory);
...@@ -99,8 +95,9 @@ class DeQuantOpKernel : public framework::OpKernel<T> { ...@@ -99,8 +95,9 @@ class DeQuantOpKernel : public framework::OpKernel<T> {
dst_memory->set_data_handle(output->mutable_data<float>(ctx.GetPlace())); dst_memory->set_data_handle(output->mutable_data<float>(ctx.GetPlace()));
} }
pipeline.push_back(*reorder_p); mkldnn::stream astream(engine);
stream(stream::kind::eager).submit(pipeline).wait(); reorder_p->execute(astream, *src_memory, *dst_memory);
astream.wait();
output->set_layout(DataLayout::kMKLDNN); output->set_layout(DataLayout::kMKLDNN);
output->set_format(GetMKLDNNFormat(*dst_memory)); output->set_format(GetMKLDNNFormat(*dst_memory));
......
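The dequantize hunk above keeps the same scheme (a reorder carrying an output-scale attribute) but moves it to the 1.x API: memories come straight from memory::desc, the reorder::primitive_desc is built from the two memories plus the primitive_attr, and the primitive runs on a stream. A self-contained sketch under those assumptions (function and parameter names are illustrative, not from the patch):

#include <vector>
#include "mkldnn.hpp"

// Sketch only: dequantize by reordering with an output scale of 1/scale.
void dequantize_v1(mkldnn::engine& eng, const mkldnn::memory::desc& src_md,
                   const mkldnn::memory::desc& dst_md, void* src_data,
                   void* dst_data, float scale) {
  mkldnn::primitive_attr attr;
  attr.set_output_scales(/*mask=*/0, {1.0f / scale});

  mkldnn::memory src_mem(src_md, eng, src_data);
  mkldnn::memory dst_mem(dst_md, eng, dst_data);

  // 1.x: the primitive_desc takes the memories (and attr); the reorder itself
  // takes only the primitive_desc.
  mkldnn::reorder::primitive_desc pd(src_mem, dst_mem, attr);
  mkldnn::reorder reorder_p(pd);

  mkldnn::stream astream(eng);
  reorder_p.execute(astream, src_mem, dst_mem);
  astream.wait();
}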
...@@ -42,16 +42,16 @@ class FCPrimitiveFactory { ...@@ -42,16 +42,16 @@ class FCPrimitiveFactory {
public: public:
explicit FCPrimitiveFactory(const mkldnn::engine& engine) : engine_(engine) {} explicit FCPrimitiveFactory(const mkldnn::engine& engine) : engine_(engine) {}
inner_product_forward CreateFcPrimitive(const LoDTensor* input, void ExecuteFcPrimitive(const LoDTensor* input, const Tensor* weights,
const Tensor* weights, const Tensor* bias, LoDTensor* output,
const Tensor* bias, LoDTensor* output, const ExecutionContext& ctx) {
const ExecutionContext& ctx) {
RecomputeOutputDims(ctx, input, weights, output); RecomputeOutputDims(ctx, input, weights, output);
// If the primitive has already been created and cached, don't create a new one, // If the primitive has already been created and cached, don't create a new one,
// but update the input and output data pointers and return it. // but update the input and output data pointers and return it.
if (fc_) { if (fc_) {
UpdateDataPointers(ctx, output, input); UpdateDataPointers(ctx, output, input);
return *fc_; this->Execute();
return;
} }
auto src_desc = CreateMemDescriptor<T_in>(input, input->format()); auto src_desc = CreateMemDescriptor<T_in>(input, input->format());
input_ = CreateMemory<T_in>(src_desc, input); input_ = CreateMemory<T_in>(src_desc, input);
...@@ -72,7 +72,22 @@ class FCPrimitiveFactory { ...@@ -72,7 +72,22 @@ class FCPrimitiveFactory {
auto dst_desc = CreateMemDescriptor<T_out>(output, MKLDNNMemoryFormat::any); auto dst_desc = CreateMemDescriptor<T_out>(output, MKLDNNMemoryFormat::any);
fc_ = CreateFcPrimitive(*input_, *weights_, dst_desc, bias, output, ctx); fc_ = CreateFcPrimitive(*input_, *weights_, dst_desc, bias, output, ctx);
return *fc_; this->Execute();
}
void Execute() {
mkldnn::stream astream(engine_);
if (bias_) {
fc_->execute(astream, {{MKLDNN_ARG_SRC, *input_},
{MKLDNN_ARG_WEIGHTS, *weights_},
{MKLDNN_ARG_BIAS, *bias_},
{MKLDNN_ARG_DST, *output_}});
} else {
fc_->execute(astream, {{MKLDNN_ARG_SRC, *input_},
{MKLDNN_ARG_WEIGHTS, *weights_},
{MKLDNN_ARG_DST, *output_}});
}
astream.wait();
} }
private: private:
...@@ -83,7 +98,7 @@ class FCPrimitiveFactory { ...@@ -83,7 +98,7 @@ class FCPrimitiveFactory {
// If the primitive exists, but the output tensor has changed its // If the primitive exists, but the output tensor has changed its
// variable, update its format to what has been determined in first // variable, update its format to what has been determined in first
// call to CreateFcPrimitive method. // call to CreateFcPrimitive method.
if (out->format() == MKLDNNMemoryFormat::format_undef) { if (out->format() == MKLDNNMemoryFormat::undef) {
auto output_format = platform::GetMKLDNNFormat(*output_); auto output_format = platform::GetMKLDNNFormat(*output_);
out->set_format((MKLDNNMemoryFormat)output_format); out->set_format((MKLDNNMemoryFormat)output_format);
} }
...@@ -94,36 +109,37 @@ class FCPrimitiveFactory { ...@@ -94,36 +109,37 @@ class FCPrimitiveFactory {
using format = MKLDNNMemoryFormat; using format = MKLDNNMemoryFormat;
switch (fmt) { switch (fmt) {
case format::nChw16c: case format::nChw16c:
return format::oIhw16i; return format::aBcd16b;
case format::nChw8c: case format::nChw8c:
return format::oIhw8i; return format::aBcd8b;
case format::nchw: case format::nchw:
return format::oihw; return format::oihw;
case format::nhwc: case format::nhwc:
return format::hwio; return format::hwio;
default: default:
return format::format_undef; return format::undef;
} }
} }
// Convert data from one data format to another // Convert data from one data format to another
mkldnn::memory Reorder(const memory::desc& src_desc, mkldnn::memory Reorder(const memory::desc& src_desc,
const memory::desc& dst_desc, const void* src_data) { const memory::desc& dst_desc, void* src_data) {
auto src_mem = memory({src_desc, engine_}, const_cast<void*>(src_data)); auto src_mem = memory(src_desc, engine_, src_data);
auto dst_mem = memory({dst_desc, engine_}); auto dst_mem = memory(dst_desc, engine_);
auto reorder = mkldnn::reorder(src_mem, dst_mem); auto reorder = mkldnn::reorder(src_mem, dst_mem);
stream(stream::kind::eager).submit({reorder}).wait(); mkldnn::stream astream(engine_);
reorder.execute(astream, src_mem, dst_mem);
astream.wait();
return dst_mem; return dst_mem;
} }
// Convert data from one data format to another and rescale it. // Convert data from one data format to another and rescale it.
// If the desired data type is (un)signed int8, quantization occurs here. // If the desired data type is (un)signed int8, quantization occurs here.
mkldnn::memory Reorder(const memory& src_mem, mkldnn::memory Reorder(const memory& src_mem, const memory::desc& dst_md,
const memory::primitive_desc& dst_pd,
const std::vector<float>& scale_data) { const std::vector<float>& scale_data) {
mkldnn::memory dst_mem = mkldnn::memory(dst_pd); mkldnn::memory dst_mem = mkldnn::memory(dst_md, engine_);
mkldnn::primitive_attr attributes; mkldnn::primitive_attr attributes;
// According to MKL-DNN's documentation, mask determines along which // According to MKL-DNN's documentation, mask determines along which
// dimensions the scale should be applied. // dimensions the scale should be applied.
...@@ -133,19 +149,19 @@ class FCPrimitiveFactory { ...@@ -133,19 +149,19 @@ class FCPrimitiveFactory {
// because we perform per-output-channel quantization // because we perform per-output-channel quantization
int mask = CreateMask(0, scale_data.size() > 1); int mask = CreateMask(0, scale_data.size() > 1);
attributes.set_output_scales(mask, scale_data); attributes.set_output_scales(mask, scale_data);
auto reorder = auto reorder = mkldnn::reorder({src_mem, dst_mem, attributes});
mkldnn::reorder(mkldnn::reorder::primitive_desc(
src_mem.get_primitive_desc(), dst_pd, attributes),
src_mem, dst_mem);
stream(stream::kind::eager).submit({reorder}).wait(); mkldnn::stream astream(engine_);
reorder.execute(astream,
{{MKLDNN_ARG_FROM, src_mem}, {MKLDNN_ARG_TO, dst_mem}});
astream.wait();
return dst_mem; return dst_mem;
} }
template <typename T> template <typename T>
static mkldnn::memory::desc CreateMemDescriptor(const std::vector<int>& dims, static mkldnn::memory::desc CreateMemDescriptor(
MKLDNNMemoryFormat format) { const std::vector<int64_t>& dims, MKLDNNMemoryFormat format) {
return platform::MKLDNNMemDesc(dims, platform::MKLDNNGetDataType<T>(), return platform::MKLDNNMemDesc(dims, platform::MKLDNNGetDataType<T>(),
format); format);
} }
...@@ -153,28 +169,28 @@ class FCPrimitiveFactory { ...@@ -153,28 +169,28 @@ class FCPrimitiveFactory {
template <typename T> template <typename T>
static mkldnn::memory::desc CreateMemDescriptor(const Tensor* tensor, static mkldnn::memory::desc CreateMemDescriptor(const Tensor* tensor,
MKLDNNMemoryFormat format) { MKLDNNMemoryFormat format) {
auto dims = framework::vectorize<int>(tensor->dims()); auto dims = framework::vectorize(tensor->dims());
return CreateMemDescriptor<T>(dims, format); return CreateMemDescriptor<T>(dims, format);
} }
template <typename T> template <typename T>
mkldnn::memory CreateMemory(const mkldnn::memory::desc& desc, mkldnn::memory CreateMemory(const mkldnn::memory::desc& desc,
const Tensor* tensor) { const Tensor* tensor) {
return CreateMemory(desc, tensor->data<T>()); return CreateMemory(desc, platform::to_void_cast<T>(tensor->data<T>()));
} }
mkldnn::memory CreateMemory(const mkldnn::memory::desc& desc, mkldnn::memory CreateMemory(const mkldnn::memory::desc& desc, void* data) {
const void* data) { return memory(desc, engine_, data);
return memory({desc, engine_}, const_cast<void*>(data));
} }
// Transpose weights through MKL-DNN's reorder from io to oi format. // Transpose weights through MKL-DNN's reorder from io to oi format.
mkldnn::memory TransposeWeights(const Tensor* weights) { mkldnn::memory TransposeWeights(const Tensor* weights) {
auto dims = framework::vectorize<int>(weights->dims()); auto dims = framework::vectorize(weights->dims());
std::swap(dims[0], dims[1]); // Correct output dimensions std::swap(dims[0], dims[1]); // Correct output dimensions
auto src_desc = CreateMemDescriptor<float>(dims, MKLDNNMemoryFormat::io); auto src_desc = CreateMemDescriptor<float>(dims, MKLDNNMemoryFormat::io);
auto dst_desc = CreateMemDescriptor<float>(dims, MKLDNNMemoryFormat::oi); auto dst_desc = CreateMemDescriptor<float>(dims, MKLDNNMemoryFormat::oi);
return Reorder(src_desc, dst_desc, weights->data<float>()); return Reorder(src_desc, dst_desc,
platform::to_void_cast<float>(weights->data<float>()));
} }
// Compute the bias scales so that its values correspond to the // Compute the bias scales so that its values correspond to the
...@@ -232,17 +248,17 @@ class FCPrimitiveFactory { ...@@ -232,17 +248,17 @@ class FCPrimitiveFactory {
} }
void QuantizeWeights(const ExecutionContext& ctx) { void QuantizeWeights(const ExecutionContext& ctx) {
auto quantized_desc = weights_->get_primitive_desc().desc(); auto quantized_desc = weights_->get_desc();
quantized_desc.data.data_type = quantized_desc.data.data_type =
(mkldnn_data_type_t)platform::MKLDNNGetDataType<T_w>(); (mkldnn_data_type_t)platform::MKLDNNGetDataType<T_w>();
weights_ = Reorder(*weights_, {quantized_desc, engine_}, weights_ = Reorder(*weights_, quantized_desc,
ctx.Attr<std::vector<float>>("Scale_weights")); ctx.Attr<std::vector<float>>("Scale_weights"));
} }
void QuantizeBias(const inner_product_forward::primitive_desc& fc_prim_desc, void QuantizeBias(const inner_product_forward::primitive_desc& fc_prim_desc,
const ExecutionContext& ctx) { const ExecutionContext& ctx) {
auto bias_scales = ComputeBiasScales(ctx); auto bias_scales = ComputeBiasScales(ctx);
bias_ = Reorder(*bias_, fc_prim_desc.bias_primitive_desc(), bias_scales); bias_ = Reorder(*bias_, fc_prim_desc.bias_desc(), bias_scales);
} }
// Fuse relu into FC when the activation_type attribute is set to 'relu' // Fuse relu into FC when the activation_type attribute is set to 'relu'
...@@ -273,8 +289,8 @@ class FCPrimitiveFactory { ...@@ -273,8 +289,8 @@ class FCPrimitiveFactory {
const ExecutionContext& ctx) { const ExecutionContext& ctx) {
// Acquire descriptors needed for creation of inner_product primitive // Acquire descriptors needed for creation of inner_product primitive
// descriptor // descriptor
const auto weights_desc = weights_memory.get_primitive_desc().desc(); const auto weights_desc = weights_memory.get_desc();
const auto src_desc = src_memory.get_primitive_desc().desc(); const auto src_desc = src_memory.get_desc();
// Based on provided attributes, create attributes used by MKL-DNN to // Based on provided attributes, create attributes used by MKL-DNN to
// enable fused post-op activations such as 'relu' // enable fused post-op activations such as 'relu'
const auto attrs = CreatePostOps(ctx); const auto attrs = CreatePostOps(ctx);
...@@ -294,15 +310,12 @@ class FCPrimitiveFactory { ...@@ -294,15 +310,12 @@ class FCPrimitiveFactory {
output_ = CreateDstMemory(fc_prim_desc, ctx, output); output_ = CreateDstMemory(fc_prim_desc, ctx, output);
// Return MKL-DNN primitive ready to be fed into pipeline and executed // Return MKL-DNN primitive ready to be fed into pipeline and executed
return inner_product_forward(fc_prim_desc, src_memory, weights_memory, return inner_product_forward(fc_prim_desc);
*bias_, *output_);
} else { } else {
auto fc_prim_desc = auto fc_prim_desc =
CreateFcPrimDesc(src_desc, weights_desc, dst_desc, attrs); CreateFcPrimDesc(src_desc, weights_desc, dst_desc, attrs);
output_ = CreateDstMemory(fc_prim_desc, ctx, output); output_ = CreateDstMemory(fc_prim_desc, ctx, output);
return inner_product_forward(fc_prim_desc);
return inner_product_forward(fc_prim_desc, src_memory, weights_memory,
*output_);
} }
} }
...@@ -345,8 +358,8 @@ class FCPrimitiveFactory { ...@@ -345,8 +358,8 @@ class FCPrimitiveFactory {
// perform a conversion. // perform a conversion.
mkldnn::memory CreateFourDimWeightsMemory(const Tensor* input, mkldnn::memory CreateFourDimWeightsMemory(const Tensor* input,
const Tensor* weights) { const Tensor* weights) {
auto input_dims = framework::vectorize<int>(input->dims()); auto input_dims = framework::vectorize(input->dims());
auto weight_dims = framework::vectorize<int>(weights->dims()); auto weight_dims = framework::vectorize(weights->dims());
auto dims = {weight_dims[1], input_dims[1], input_dims[2], input_dims[3]}; auto dims = {weight_dims[1], input_dims[1], input_dims[2], input_dims[3]};
auto dst_format = MatchWeightFormat(input->format()); auto dst_format = MatchWeightFormat(input->format());
...@@ -361,11 +374,11 @@ class FCPrimitiveFactory { ...@@ -361,11 +374,11 @@ class FCPrimitiveFactory {
mkldnn::memory CreateDstMemory( mkldnn::memory CreateDstMemory(
const mkldnn::inner_product_forward::primitive_desc& fc_prim_desc, const mkldnn::inner_product_forward::primitive_desc& fc_prim_desc,
const ExecutionContext& ctx, Tensor* output) { const ExecutionContext& ctx, Tensor* output) {
auto dst_prim_desc = fc_prim_desc.dst_primitive_desc(); auto dst_desc = fc_prim_desc.dst_desc();
auto buffer_size = dst_prim_desc.get_size(); auto buffer_size = dst_desc.get_size();
T_out* output_data = T_out* output_data =
output->mutable_data<T_out>(ctx.GetPlace(), buffer_size); output->mutable_data<T_out>(ctx.GetPlace(), buffer_size);
memory dst_mem(dst_prim_desc, to_void_cast<T_out>(output_data)); memory dst_mem(dst_desc, engine_, to_void_cast<T_out>(output_data));
output->set_format(platform::GetMKLDNNFormat(dst_mem)); output->set_format(platform::GetMKLDNNFormat(dst_mem));
return dst_mem; return dst_mem;
} }
...@@ -421,25 +434,24 @@ GetPrimitiveFactory(const MKLDNNDeviceContext& dev_ctx, ...@@ -421,25 +434,24 @@ GetPrimitiveFactory(const MKLDNNDeviceContext& dev_ctx,
// Choose appropriate primitive factory implementation based on inferred // Choose appropriate primitive factory implementation based on inferred
// output type (uint8, int8 or float). // output type (uint8, int8 or float).
template <typename T_in, typename T_w> template <typename T_in, typename T_w>
static inner_product_forward GetFcPrimitive( static void ExecuteFc(const MKLDNNDeviceContext& dev_ctx,
const MKLDNNDeviceContext& dev_ctx, const ExecutionContext& ctx, const ExecutionContext& ctx, const LoDTensor* input,
const LoDTensor* input, const Tensor* w, const Tensor* bias, const Tensor* w, const Tensor* bias, LoDTensor* output,
LoDTensor* output, const mkldnn::engine& mkldnn_engine, bool fuse_relu, const mkldnn::engine& mkldnn_engine, bool fuse_relu,
bool force_fp32_output) { bool force_fp32_output) {
constexpr bool is_int8 = constexpr bool is_int8 =
std::is_same<T_in, int8_t>::value || std::is_same<T_in, uint8_t>::value; std::is_same<T_in, int8_t>::value || std::is_same<T_in, uint8_t>::value;
if (!is_int8 || force_fp32_output) { if (!is_int8 || force_fp32_output) {
return GetPrimitiveFactory<T_in, T_w, float>(dev_ctx, ctx, input, w, GetPrimitiveFactory<T_in, T_w, float>(dev_ctx, ctx, input, w, mkldnn_engine)
mkldnn_engine) ->ExecuteFcPrimitive(input, w, bias, output, ctx);
->CreateFcPrimitive(input, w, bias, output, ctx);
} else if (fuse_relu) { } else if (fuse_relu) {
return GetPrimitiveFactory<T_in, T_w, uint8_t>(dev_ctx, ctx, input, w, GetPrimitiveFactory<T_in, T_w, uint8_t>(dev_ctx, ctx, input, w,
mkldnn_engine) mkldnn_engine)
->CreateFcPrimitive(input, w, bias, output, ctx); ->ExecuteFcPrimitive(input, w, bias, output, ctx);
} else { } else {
return GetPrimitiveFactory<T_in, T_w, int8_t>(dev_ctx, ctx, input, w, GetPrimitiveFactory<T_in, T_w, int8_t>(dev_ctx, ctx, input, w,
mkldnn_engine) mkldnn_engine)
->CreateFcPrimitive(input, w, bias, output, ctx); ->ExecuteFcPrimitive(input, w, bias, output, ctx);
} }
} }
...@@ -461,10 +473,8 @@ class FCMKLDNNOpKernel : public framework::OpKernel<T_in> { ...@@ -461,10 +473,8 @@ class FCMKLDNNOpKernel : public framework::OpKernel<T_in> {
bool fuse_relu = ctx.Attr<std::string>("activation_type") == "relu"; bool fuse_relu = ctx.Attr<std::string>("activation_type") == "relu";
bool force_fp32_output = ctx.Attr<bool>("force_fp32_output"); bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
auto fc = ExecuteFc<T_in, T_w>(dev_ctx, ctx, input, w, bias, output, mkldnn_engine,
GetFcPrimitive<T_in, T_w>(dev_ctx, ctx, input, w, bias, output, fuse_relu, force_fp32_output);
mkldnn_engine, fuse_relu, force_fp32_output);
stream(stream::kind::eager).submit({fc}).wait();
output->set_layout(DataLayout::kMKLDNN); output->set_layout(DataLayout::kMKLDNN);
} }
......
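In the FC hunks above, the other recurring 1.x change is the memory API: memory objects are constructed from (desc, engine, handle), their descriptor is read with get_desc(), and primitive descriptors expose src_desc()/weights_desc()/dst_desc()/bias_desc() rather than the old *_primitive_desc() accessors. A small sketch of that construction path (names are hypothetical):

#include "mkldnn.hpp"

// Sketch only: build an f32 2-D memory the mkldnn 1.x way and query its desc.
mkldnn::memory make_f32_memory(mkldnn::engine& eng,
                               const mkldnn::memory::dims& dims, void* data) {
  mkldnn::memory::desc md(dims, mkldnn::memory::data_type::f32,
                          mkldnn::memory::format_tag::nc);
  mkldnn::memory mem(md, eng, data);  // 0.x: memory({md, eng}, data)
  auto queried_md = mem.get_desc();   // 0.x: mem.get_primitive_desc().desc()
  (void)queried_md;                   // descriptor available for further checks
  return mem;
}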
...@@ -41,7 +41,7 @@ class GaussianMKLDNNKernel : public paddle::framework::OpKernel<T> { ...@@ -41,7 +41,7 @@ class GaussianMKLDNNKernel : public paddle::framework::OpKernel<T> {
} }
tensor->set_layout(DataLayout::kMKLDNN); tensor->set_layout(DataLayout::kMKLDNN);
tensor->set_format(mkldnn::memory::format::oihw); tensor->set_format(mkldnn::memory::format_tag::oihw);
} }
}; };
} // namespace operators } // namespace operators
......
...@@ -49,7 +49,7 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -49,7 +49,7 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const float k = ctx.Attr<float>("k"); const float k = ctx.Attr<float>("k");
bool is_test = ctx.Attr<bool>("is_test"); bool is_test = ctx.Attr<bool>("is_test");
auto dims = paddle::framework::vectorize<int>(x->dims()); auto dims = paddle::framework::vectorize<int64_t>(x->dims());
platform::LRNMKLDNNHandler<T> handler(dims, n, alpha, beta, k, x->format(), platform::LRNMKLDNNHandler<T> handler(dims, n, alpha, beta, k, x->format(),
is_test, dev_ctx, ctx.GetPlace(), is_test, dev_ctx, ctx.GetPlace(),
...@@ -58,14 +58,17 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -58,14 +58,17 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto src_memory = handler.AcquireSrcMemory(x); auto src_memory = handler.AcquireSrcMemory(x);
auto dst_memory = handler.AcquireDstMemory(out); auto dst_memory = handler.AcquireDstMemory(out);
std::shared_ptr<mkldnn::memory> workspace_memory; auto lrn_p = handler.AcquireForwardPrimitive();
std::shared_ptr<mkldnn::lrn_forward> lrn_p;
if (is_test == false) { auto workspace_memory = handler.AcquireWorkspaceMemory(mid);
workspace_memory = handler.AcquireWorkspaceMemory(mid); mid->set_layout(framework::DataLayout::kMKLDNN);
mid->set_layout(framework::DataLayout::kMKLDNN);
mkldnn::stream astream(dev_ctx.GetEngine());
if (!workspace_memory->get_desc().is_zero()) {
mid->set_format(platform::GetMKLDNNFormat(*workspace_memory)); mid->set_format(platform::GetMKLDNNFormat(*workspace_memory));
lrn_p = handler.AcquireForwardPrimitive(*src_memory, *workspace_memory, lrn_p->execute(astream, {{MKLDNN_ARG_SRC, *src_memory},
*dst_memory); {MKLDNN_ARG_DST, *dst_memory},
{MKLDNN_ARG_WORKSPACE, *workspace_memory}});
} else { } else {
// mid has to be allocated and filled // mid has to be allocated and filled
// k to pass LRN unit tests // k to pass LRN unit tests
...@@ -73,11 +76,12 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -73,11 +76,12 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
mid->mutable_data<T>(ctx.GetPlace()); mid->mutable_data<T>(ctx.GetPlace());
auto e_mid = framework::EigenTensor<T, 4>::From(*mid); auto e_mid = framework::EigenTensor<T, 4>::From(*mid);
e_mid = e_mid.constant(k); e_mid = e_mid.constant(k);
lrn_p = handler.AcquireForwardPrimitive(*src_memory, *dst_memory); mid->set_format(platform::GetMKLDNNFormat(*dst_memory));
}
std::vector<mkldnn::primitive> pipeline = {*lrn_p}; lrn_p->execute(astream, {{MKLDNN_ARG_SRC, *src_memory},
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); {MKLDNN_ARG_DST, *dst_memory}});
}
astream.wait();
out->set_layout(framework::DataLayout::kMKLDNN); out->set_layout(framework::DataLayout::kMKLDNN);
out->set_format(platform::GetMKLDNNFormat(*dst_memory)); out->set_format(platform::GetMKLDNNFormat(*dst_memory));
...@@ -109,7 +113,7 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -109,7 +113,7 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>(); auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
auto dims = paddle::framework::vectorize<int>(x->dims()); auto dims = paddle::framework::vectorize<int64_t>(x->dims());
platform::LRNMKLDNNHandler<T> handler(dims, n, alpha, beta, k, x->format(), platform::LRNMKLDNNHandler<T> handler(dims, n, alpha, beta, k, x->format(),
out_grad->format(), dev_ctx, out_grad->format(), dev_ctx,
...@@ -120,11 +124,14 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -120,11 +124,14 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
auto diff_dst_memory = handler.AcquireDiffDstMemory(out_grad); auto diff_dst_memory = handler.AcquireDiffDstMemory(out_grad);
auto diff_src_memory = handler.AcquireDiffSrcMemory(x_grad); auto diff_src_memory = handler.AcquireDiffSrcMemory(x_grad);
auto lrn_bwd = handler.AcquireBackwardPrimitive( auto lrn_bwd = handler.AcquireBackwardPrimitive();
*src_memory, *diff_dst_memory, *workspace, *diff_src_memory);
std::vector<mkldnn::primitive> pipeline = {*lrn_bwd}; mkldnn::stream astream(dev_ctx.GetEngine());
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); lrn_bwd->execute(astream, {{MKLDNN_ARG_SRC, *src_memory},
{MKLDNN_ARG_DIFF_DST, *diff_dst_memory},
{MKLDNN_ARG_DIFF_SRC, *diff_src_memory},
{MKLDNN_ARG_WORKSPACE, *workspace}});
astream.wait();
x_grad->set_layout(framework::DataLayout::kMKLDNN); x_grad->set_layout(framework::DataLayout::kMKLDNN);
x_grad->set_format(platform::GetMKLDNNFormat(*diff_src_memory)); x_grad->set_format(platform::GetMKLDNNFormat(*diff_src_memory));
......
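The LRN hunks above use the execution model that MKL-DNN 1.0 introduces for every kernel in this commit: stream(stream::kind::eager).submit(pipeline).wait() is gone, a primitive is constructed from its primitive descriptor alone, and its memories are bound per call through an argument map keyed by MKLDNN_ARG_*. The standalone sketch below shows the same pattern in isolation; it is not part of the commit, and the softmax op, the shapes and the engine index are illustrative assumptions.

#include <mkldnn.hpp>

#include <vector>

int main() {
  mkldnn::engine eng(mkldnn::engine::kind::cpu, 0);
  mkldnn::stream astream(eng);

  // MKL-DNN 1.0 dims are int64_t based, hence the vectorize<int64_t>() changes.
  mkldnn::memory::desc md({2, 8}, mkldnn::memory::data_type::f32,
                          mkldnn::memory::format_tag::nc);
  std::vector<float> src_buf(16, 1.0f), dst_buf(16, 0.0f);
  mkldnn::memory src(md, eng, src_buf.data());
  mkldnn::memory dst(md, eng, dst_buf.data());

  // The primitive is built from its primitive descriptor only...
  mkldnn::softmax_forward::desc sm_d(mkldnn::prop_kind::forward_inference, md,
                                     1 /*axis*/);
  mkldnn::softmax_forward::primitive_desc sm_pd(sm_d, eng);
  mkldnn::softmax_forward softmax(sm_pd);

  // ...and memories are supplied at execute() time through an argument map,
  // which replaces submitting a pipeline to an eager stream.
  softmax.execute(astream, {{MKLDNN_ARG_SRC, src}, {MKLDNN_ARG_DST, dst}});
  astream.wait();
  return 0;
}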
...@@ -59,6 +59,7 @@ class MulPrimitiveFactory { ...@@ -59,6 +59,7 @@ class MulPrimitiveFactory {
if (mul_) { if (mul_) {
UpdateDataPointers(ctx, output, &x_matrix); UpdateDataPointers(ctx, output, &x_matrix);
Execute();
return *mul_; return *mul_;
} }
...@@ -68,9 +69,18 @@ class MulPrimitiveFactory { ...@@ -68,9 +69,18 @@ class MulPrimitiveFactory {
auto dst_desc = CreateMemDescriptor<OT>(output, MKLDNNMemoryFormat::any); auto dst_desc = CreateMemDescriptor<OT>(output, MKLDNNMemoryFormat::any);
mul_ = CreateMulPrimitive(*x_input_, *y_input_, dst_desc, output, ctx); mul_ = CreateMulPrimitive(*x_input_, *y_input_, dst_desc, output, ctx);
Execute();
return *mul_; return *mul_;
} }
void Execute() {
mkldnn::stream astream(engine_);
(*mul_).execute(astream, {{MKLDNN_ARG_SRC, *x_input_},
{MKLDNN_ARG_WEIGHTS, *y_input_},
{MKLDNN_ARG_DST, *output_}});
astream.wait();
}
protected: protected:
template <typename T> template <typename T>
Tensor UpdateDataFormat(const Tensor *data, int num_col_dims, Tensor UpdateDataFormat(const Tensor *data, int num_col_dims,
...@@ -92,7 +102,7 @@ class MulPrimitiveFactory { ...@@ -92,7 +102,7 @@ class MulPrimitiveFactory {
to_void_cast<T>(x_tmp.data<T>())); to_void_cast<T>(x_tmp.data<T>()));
x_tmp.Resize(data->dims()); x_tmp.Resize(data->dims());
x_tmp.set_format((MKLDNNMemoryFormat)dst_mdesc.data.format); x_tmp.set_format(platform::GetMKLDNNFormat(dst_mdesc));
data_matrix = framework::ReshapeToMatrix(x_tmp, num_col_dims); data_matrix = framework::ReshapeToMatrix(x_tmp, num_col_dims);
} else { } else {
data_matrix = framework::ReshapeToMatrix(*data, num_col_dims); data_matrix = framework::ReshapeToMatrix(*data, num_col_dims);
...@@ -106,7 +116,7 @@ class MulPrimitiveFactory { ...@@ -106,7 +116,7 @@ class MulPrimitiveFactory {
x_input_->set_data_handle(to_void_cast<XT>(in->data<XT>())); x_input_->set_data_handle(to_void_cast<XT>(in->data<XT>()));
output_->set_data_handle(out->mutable_data<OT>(ctx.GetPlace())); output_->set_data_handle(out->mutable_data<OT>(ctx.GetPlace()));
if (out->format() == MKLDNNMemoryFormat::format_undef) { if (out->format() == MKLDNNMemoryFormat::undef) {
auto output_format = platform::GetMKLDNNFormat(*output_); auto output_format = platform::GetMKLDNNFormat(*output_);
out->set_format((MKLDNNMemoryFormat)output_format); out->set_format((MKLDNNMemoryFormat)output_format);
} }
...@@ -116,48 +126,50 @@ class MulPrimitiveFactory { ...@@ -116,48 +126,50 @@ class MulPrimitiveFactory {
memory::desc CreateMemDescriptor( memory::desc CreateMemDescriptor(
const Tensor *tensor, MKLDNNMemoryFormat format, const Tensor *tensor, MKLDNNMemoryFormat format,
memory::data_type type = platform::MKLDNNGetDataType<T>()) { memory::data_type type = platform::MKLDNNGetDataType<T>()) {
auto dims = framework::vectorize<int>(tensor->dims()); auto dims = framework::vectorize<int64_t>(tensor->dims());
return platform::MKLDNNMemDesc(dims, type, format); return platform::MKLDNNMemDesc(dims, type, format);
} }
template <typename T> template <typename T>
memory::desc CreateMemDescriptor( memory::desc CreateMemDescriptor(
const std::vector<int> &dims, MKLDNNMemoryFormat format, const std::vector<int64_t> &dims, MKLDNNMemoryFormat format,
memory::data_type type = platform::MKLDNNGetDataType<T>()) { memory::data_type type = platform::MKLDNNGetDataType<T>()) {
return platform::MKLDNNMemDesc(dims, type, format); return platform::MKLDNNMemDesc(dims, type, format);
} }
template <typename T> template <typename T>
memory CreateMemory(const memory::desc &desc, const Tensor *tensor) { memory CreateMemory(const memory::desc &desc, const Tensor *tensor) {
return memory({desc, engine_}, to_void_cast<T>(tensor->data<T>())); return memory(desc, engine_, to_void_cast<T>(tensor->data<T>()));
} }
memory CreateDstMemory( memory CreateDstMemory(
const inner_product_forward::primitive_desc &mul_prim_desc, const inner_product_forward::primitive_desc &mul_prim_desc,
const ExecutionContext &ctx, Tensor *output) { const ExecutionContext &ctx, Tensor *output) {
auto dst_prim_desc = mul_prim_desc.dst_primitive_desc(); auto dst_desc = mul_prim_desc.dst_desc();
auto buffer_size = dst_prim_desc.get_size(); auto buffer_size = dst_desc.get_size();
OT *output_data = output->mutable_data<OT>(ctx.GetPlace(), buffer_size); OT *output_data = output->mutable_data<OT>(ctx.GetPlace(), buffer_size);
memory dst_mem(dst_prim_desc, to_void_cast<OT>(output_data)); output->set_format(paddle::platform::GetMKLDNNFormat(dst_desc));
output->set_format(platform::GetMKLDNNFormat(dst_mem)); return memory(dst_desc, engine_, to_void_cast<OT>(output_data));
return dst_mem;
} }
memory Reorder(const memory::desc &src_desc, const memory::desc &dst_desc, memory Reorder(const memory::desc &src_desc, const memory::desc &dst_desc,
void *src_data, void *dst_data = NULL) { void *src_data, void *dst_data = NULL) {
auto src_mem = memory({src_desc, engine_}, src_data); auto src_mem = memory(src_desc, engine_, src_data);
auto dst_mem = dst_data ? memory({dst_desc, engine_}, dst_data) auto dst_mem = dst_data ? memory(dst_desc, engine_, dst_data)
: memory({dst_desc, engine_}); : memory(dst_desc, engine_);
auto reorder = mkldnn::reorder(src_mem, dst_mem); auto reorder = mkldnn::reorder(src_mem, dst_mem);
stream(stream::kind::eager).submit({reorder}).wait();
mkldnn::stream astream(engine_);
reorder.execute(astream, src_mem, dst_mem);
astream.wait();
return dst_mem; return dst_mem;
} }
memory TransposeInputY(const Tensor *input_y) { memory TransposeInputY(const Tensor *input_y) {
auto dims = framework::vectorize<int>(input_y->dims()); auto dims = framework::vectorize<int64_t>(input_y->dims());
std::swap(dims[0], dims[1]); // Correct output dimensions std::swap(dims[0], dims[1]); // Correct output dimensions
auto src_desc = CreateMemDescriptor<YT>(dims, MKLDNNMemoryFormat::io); auto src_desc = CreateMemDescriptor<YT>(dims, MKLDNNMemoryFormat::io);
auto dst_desc = CreateMemDescriptor<YT>(dims, MKLDNNMemoryFormat::oi); auto dst_desc = CreateMemDescriptor<YT>(dims, MKLDNNMemoryFormat::oi);
...@@ -169,13 +181,13 @@ class MulPrimitiveFactory { ...@@ -169,13 +181,13 @@ class MulPrimitiveFactory {
const memory::desc &dst_desc, const memory::desc &dst_desc,
Tensor *output, Tensor *output,
const ExecutionContext &ctx) { const ExecutionContext &ctx) {
const auto y_desc = y_memory.get_primitive_desc().desc(); const auto y_desc = y_memory.get_desc();
const auto x_desc = x_memory.get_primitive_desc().desc(); const auto x_desc = x_memory.get_desc();
auto mul_prim_desc = CreateMulPrimDesc(x_desc, y_desc, dst_desc); auto mul_prim_desc = CreateMulPrimDesc(x_desc, y_desc, dst_desc);
output_ = CreateDstMemory(mul_prim_desc, ctx, output); output_ = CreateDstMemory(mul_prim_desc, ctx, output);
return inner_product_forward(mul_prim_desc, x_memory, y_memory, *output_); return inner_product_forward(mul_prim_desc);
} }
inner_product_forward::primitive_desc CreateMulPrimDesc( inner_product_forward::primitive_desc CreateMulPrimDesc(
...@@ -228,6 +240,7 @@ class QuantMulPrimitiveFactory : public MulPrimitiveFactory<XT, YT, OT> { ...@@ -228,6 +240,7 @@ class QuantMulPrimitiveFactory : public MulPrimitiveFactory<XT, YT, OT> {
if (this->mul_) { if (this->mul_) {
this->UpdateDataPointers(ctx, output, &x_matrix); this->UpdateDataPointers(ctx, output, &x_matrix);
this->Execute();
return *(this->mul_); return *(this->mul_);
} }
...@@ -243,6 +256,7 @@ class QuantMulPrimitiveFactory : public MulPrimitiveFactory<XT, YT, OT> { ...@@ -243,6 +256,7 @@ class QuantMulPrimitiveFactory : public MulPrimitiveFactory<XT, YT, OT> {
this->mul_ = CreateMulPrimitive(*(this->x_input_), *(this->y_input_), this->mul_ = CreateMulPrimitive(*(this->x_input_), *(this->y_input_),
dst_desc, output, ctx); dst_desc, output, ctx);
this->Execute();
return *(this->mul_); return *(this->mul_);
} }
...@@ -253,22 +267,24 @@ class QuantMulPrimitiveFactory : public MulPrimitiveFactory<XT, YT, OT> { ...@@ -253,22 +267,24 @@ class QuantMulPrimitiveFactory : public MulPrimitiveFactory<XT, YT, OT> {
mkldnn::primitive_attr attr; mkldnn::primitive_attr attr;
attr.set_output_scales(mask, scale); attr.set_output_scales(mask, scale);
auto src_mem = memory({src_desc, this->engine_}, src_data); auto src_mem = memory(src_desc, this->engine_, src_data);
auto dst_mem = memory({dst_desc, this->engine_}); auto dst_mem = memory(dst_desc, this->engine_);
auto reorder_pd = mkldnn::reorder::primitive_desc(src_mem, dst_mem, attr);
auto reorder_pd = mkldnn::reorder::primitive_desc( auto reorder = mkldnn::reorder(reorder_pd);
src_mem.get_primitive_desc(), dst_mem.get_primitive_desc(), attr);
auto reorder = mkldnn::reorder(reorder_pd, src_mem, dst_mem); mkldnn::stream astream(this->engine_);
stream(stream::kind::eager).submit({reorder}).wait(); reorder.execute(astream, src_mem, dst_mem);
astream.wait();
return dst_mem; return dst_mem;
} }
memory QuantInputY(memory input_y, const std::vector<float> &scale_y) { memory QuantInputY(memory input_y, const std::vector<float> &scale_y) {
const auto &dims = input_y.get_primitive_desc().desc().data.dims; const auto &dims = input_y.get_desc().data.dims;
auto ndims = input_y.get_primitive_desc().desc().data.ndims; auto ndims = input_y.get_desc().data.ndims;
auto y_dims = std::vector<int>(dims, dims + ndims); auto y_dims = std::vector<int64_t>(dims, dims + ndims);
auto user_y_desc = auto user_y_desc =
this->template CreateMemDescriptor<YT>(y_dims, MKLDNNMemoryFormat::oi); this->template CreateMemDescriptor<YT>(y_dims, MKLDNNMemoryFormat::oi);
...@@ -309,8 +325,8 @@ class QuantMulPrimitiveFactory : public MulPrimitiveFactory<XT, YT, OT> { ...@@ -309,8 +325,8 @@ class QuantMulPrimitiveFactory : public MulPrimitiveFactory<XT, YT, OT> {
const memory::desc &dst_desc, const memory::desc &dst_desc,
Tensor *output, Tensor *output,
const ExecutionContext &ctx) { const ExecutionContext &ctx) {
const auto x_desc = x_memory.get_primitive_desc().desc(); const auto x_desc = x_memory.get_desc();
const auto y_desc = y_memory.get_primitive_desc().desc(); const auto y_desc = y_memory.get_desc();
bool force_fp32_output = ctx.Attr<bool>("force_fp32_output"); bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
mkldnn::primitive_attr mul_attr = CreateMulAttr(ctx, force_fp32_output); mkldnn::primitive_attr mul_attr = CreateMulAttr(ctx, force_fp32_output);
...@@ -318,8 +334,7 @@ class QuantMulPrimitiveFactory : public MulPrimitiveFactory<XT, YT, OT> { ...@@ -318,8 +334,7 @@ class QuantMulPrimitiveFactory : public MulPrimitiveFactory<XT, YT, OT> {
this->output_ = this->CreateDstMemory(mul_prim_desc, ctx, output); this->output_ = this->CreateDstMemory(mul_prim_desc, ctx, output);
return inner_product_forward(mul_prim_desc, x_memory, y_memory, return inner_product_forward(mul_prim_desc);
*(this->output_));
} }
inner_product_forward::primitive_desc CreateMulPrimDesc( inner_product_forward::primitive_desc CreateMulPrimDesc(
...@@ -340,9 +355,8 @@ std::shared_ptr<MulPrimitiveFactory<XT, YT, OT>> GetPrimitiveFactory( ...@@ -340,9 +355,8 @@ std::shared_ptr<MulPrimitiveFactory<XT, YT, OT>> GetPrimitiveFactory(
const Tensor *input_x, const Tensor *input_y, const Tensor *input_x, const Tensor *input_y,
const mkldnn::engine &mkldnn_engine, bool enable_quant) { const mkldnn::engine &mkldnn_engine, bool enable_quant) {
const std::string key = platform::CreateKey( const std::string key = platform::CreateKey(
input_x->type(), framework::vectorize<int>(input_x->dims()), input_x->type(), framework::vectorize(input_x->dims()), input_y->type(),
input_y->type(), framework::vectorize<int>(input_y->dims()), framework::vectorize(input_y->dims()), ctx.OutputName("Out"));
ctx.OutputName("Out"));
auto prim_creator = std::static_pointer_cast<MulPrimitiveFactory<XT, YT, OT>>( auto prim_creator = std::static_pointer_cast<MulPrimitiveFactory<XT, YT, OT>>(
dev_ctx.GetBlob(key)); dev_ctx.GetBlob(key));
...@@ -399,14 +413,12 @@ class MulMKLDNNKernel : public framework::OpKernel<XT> { ...@@ -399,14 +413,12 @@ class MulMKLDNNKernel : public framework::OpKernel<XT> {
auto mul = GetMulPrimitive<XT, YT>(dev_ctx, ctx, x, y, out, mkldnn_engine); auto mul = GetMulPrimitive<XT, YT>(dev_ctx, ctx, x, y, out, mkldnn_engine);
stream(stream::kind::eager).submit({mul}).wait();
if (out_dims.size() != 2) { if (out_dims.size() != 2) {
out->Resize(out_dims); out->Resize(out_dims);
} }
out->set_layout(DataLayout::kMKLDNN); out->set_layout(DataLayout::kMKLDNN);
out->set_format(platform::MKLDNNFormatForSize( out->set_format(platform::MKLDNNFormatForSize(out_dims.size(),
out_dims.size(), mkldnn::memory::format::nchw)); MKLDNNMemoryFormat::nchw));
} }
}; };
......
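The mul kernel changes above also capture the MKL-DNN 1.0 memory API: memory::primitive_desc no longer exists, a mkldnn::memory is built from a memory::desc plus an engine (optionally with a user-provided pointer), and introspection goes through get_desc() instead of get_primitive_desc().desc(). A small standalone sketch of those constructors follows; the shapes are illustrative assumptions, not values from the commit.

#include <mkldnn.hpp>

#include <vector>

int main() {
  mkldnn::engine eng(mkldnn::engine::kind::cpu, 0);

  mkldnn::memory::desc md({4, 3}, mkldnn::memory::data_type::f32,
                          mkldnn::memory::format_tag::nc);
  std::vector<float> buf(12, 0.0f);

  // User buffer: memory(desc, engine, handle) replaces
  // memory({desc, engine}, handle).
  mkldnn::memory user_mem(md, eng, buf.data());

  // Library-allocated buffer: memory(desc, engine) replaces
  // memory({desc, engine}).
  mkldnn::memory lib_mem(md, eng);

  // Introspection: get_desc() replaces get_primitive_desc().desc().
  mkldnn::memory::desc queried = user_mem.get_desc();
  (void)queried;
  (void)lib_mem;
  return 0;
}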
...@@ -43,13 +43,20 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -43,13 +43,20 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(input->layout(), DataLayout::kMKLDNN, PADDLE_ENFORCE_EQ(input->layout(), DataLayout::kMKLDNN,
"Wrong layout set for Input tensor"); "Wrong layout set for Input tensor");
PADDLE_ENFORCE_NE(input->format(), MKLDNNMemoryFormat::format_undef, PADDLE_ENFORCE_NE(input->format(), MKLDNNMemoryFormat::undef,
"Wrong format set for Input tensor"); "Wrong format set for Input tensor");
std::string pooling_type = ctx.Attr<std::string>("pooling_type"); std::string pooling_type = ctx.Attr<std::string>("pooling_type");
std::vector<int> ksize = ctx.Attr<std::vector<int>>("ksize");
std::vector<int> strides = ctx.Attr<std::vector<int>>("strides"); std::vector<int> ksize_temp = ctx.Attr<std::vector<int>>("ksize");
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings"); std::vector<int64_t> ksize(begin(ksize_temp), end(ksize_temp));
std::vector<int> strides_temp = ctx.Attr<std::vector<int>>("strides");
std::vector<int64_t> strides(begin(strides_temp), end(strides_temp));
std::vector<int> paddings_temp = ctx.Attr<std::vector<int>>("paddings");
std::vector<int64_t> paddings(begin(paddings_temp), end(paddings_temp));
bool global_pooling = ctx.Attr<bool>("global_pooling"); bool global_pooling = ctx.Attr<bool>("global_pooling");
std::string padding_algorithm = ctx.Attr<std::string>("padding_algorithm"); std::string padding_algorithm = ctx.Attr<std::string>("padding_algorithm");
...@@ -71,8 +78,8 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -71,8 +78,8 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
UpdatePadding(&paddings, global_pooling, 0, padding_algorithm, data_dims, UpdatePadding(&paddings, global_pooling, 0, padding_algorithm, data_dims,
strides, ksize); strides, ksize);
auto src_tz = paddle::framework::vectorize<int>(input->dims()); auto src_tz = paddle::framework::vectorize<int64_t>(input->dims());
auto dst_tz = paddle::framework::vectorize<int>(output->dims()); auto dst_tz = paddle::framework::vectorize<int64_t>(output->dims());
auto is_test = ctx.Attr<bool>("is_test"); auto is_test = ctx.Attr<bool>("is_test");
...@@ -85,22 +92,21 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -85,22 +92,21 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto src_memory = handler.AcquireSrcMemory(input); auto src_memory = handler.AcquireSrcMemory(input);
auto dst_memory = handler.AcquireDstMemory(output); auto dst_memory = handler.AcquireDstMemory(output);
std::shared_ptr<mkldnn::pooling_forward> pool_p; auto pool_p = handler.AcquireForwardPrimitive();
std::shared_ptr<mkldnn::memory> workspace_memory;
mkldnn::stream astream(dev_ctx.GetEngine());
if ((is_test == false) && (pooling_type == "max")) { if ((is_test == false) && (pooling_type == "max")) {
// Training // Training
workspace_memory = handler.AcquireWorkspaceMemory(); auto workspace_memory = handler.AcquireWorkspaceMemory();
pool_p = handler.AcquireForwardPrimitive(*src_memory, *dst_memory, pool_p->execute(astream, {{MKLDNN_ARG_SRC, *src_memory},
*workspace_memory); {MKLDNN_ARG_DST, *dst_memory},
{MKLDNN_ARG_WORKSPACE, *workspace_memory}});
} else { } else {
// Inference // Inference
pool_p = handler.AcquireForwardPrimitive(*src_memory, *dst_memory); pool_p->execute(astream, {{MKLDNN_ARG_SRC, *src_memory},
{MKLDNN_ARG_DST, *dst_memory}});
} }
astream.wait();
// push primitive to stream and wait until it's executed
std::vector<mkldnn::primitive> pipeline{*pool_p};
stream(stream::kind::eager).submit(pipeline).wait();
output->set_layout(DataLayout::kMKLDNN); output->set_layout(DataLayout::kMKLDNN);
output->set_format(platform::GetMKLDNNFormat(*dst_memory)); output->set_format(platform::GetMKLDNNFormat(*dst_memory));
...@@ -120,12 +126,12 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -120,12 +126,12 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(in_x->layout(), DataLayout::kMKLDNN, PADDLE_ENFORCE_EQ(in_x->layout(), DataLayout::kMKLDNN,
"Wrong layout set for Input tensor"); "Wrong layout set for Input tensor");
PADDLE_ENFORCE_NE(in_x->format(), MKLDNNMemoryFormat::format_undef, PADDLE_ENFORCE_NE(in_x->format(), MKLDNNMemoryFormat::undef,
"Wrong format set for Input tensor"); "Wrong format set for Input tensor");
PADDLE_ENFORCE_EQ(out_grad->layout(), DataLayout::kMKLDNN, PADDLE_ENFORCE_EQ(out_grad->layout(), DataLayout::kMKLDNN,
"Wrong layout set for Input output_grad tensor"); "Wrong layout set for Input output_grad tensor");
PADDLE_ENFORCE_NE(out_grad->format(), MKLDNNMemoryFormat::format_undef, PADDLE_ENFORCE_NE(out_grad->format(), MKLDNNMemoryFormat::undef,
"Wrong format set for Input output_grad tensor"); "Wrong format set for Input output_grad tensor");
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
...@@ -133,9 +139,16 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -133,9 +139,16 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
"is_test attribute should be set to False in training phase."); "is_test attribute should be set to False in training phase.");
std::string pooling_type = ctx.Attr<std::string>("pooling_type"); std::string pooling_type = ctx.Attr<std::string>("pooling_type");
std::vector<int> ksize = ctx.Attr<std::vector<int>>("ksize");
std::vector<int> strides = ctx.Attr<std::vector<int>>("strides"); std::vector<int> ksize_temp = ctx.Attr<std::vector<int>>("ksize");
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings"); std::vector<int64_t> ksize(begin(ksize_temp), end(ksize_temp));
std::vector<int> strides_temp = ctx.Attr<std::vector<int>>("strides");
std::vector<int64_t> strides(begin(strides_temp), end(strides_temp));
std::vector<int> paddings_temp = ctx.Attr<std::vector<int>>("paddings");
std::vector<int64_t> paddings(begin(paddings_temp), end(paddings_temp));
bool global_pooling = ctx.Attr<bool>("global_pooling"); bool global_pooling = ctx.Attr<bool>("global_pooling");
std::string padding_algorithm = ctx.Attr<std::string>("padding_algorithm"); std::string padding_algorithm = ctx.Attr<std::string>("padding_algorithm");
...@@ -155,8 +168,8 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -155,8 +168,8 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
std::vector<mkldnn::primitive> pipeline; std::vector<mkldnn::primitive> pipeline;
auto diff_src_tz = paddle::framework::vectorize<int>(in_x_grad->dims()); auto diff_src_tz = paddle::framework::vectorize<int64_t>(in_x_grad->dims());
auto diff_dst_tz = paddle::framework::vectorize<int>(out_grad->dims()); auto diff_dst_tz = paddle::framework::vectorize<int64_t>(out_grad->dims());
    // Get a unique name from the "argument" name of the "Out" variable // Get a unique name from the "argument" name of the "Out" variable
    // This name will be used as the key when retrieving info from the device context // This name will be used as the key when retrieving info from the device context
...@@ -173,22 +186,21 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -173,22 +186,21 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
auto diff_dst_memory = handler.AcquireDiffDstMemory(out_grad); auto diff_dst_memory = handler.AcquireDiffDstMemory(out_grad);
auto diff_src_memory = handler.AcquireDiffSrcMemory(in_x_grad); auto diff_src_memory = handler.AcquireDiffSrcMemory(in_x_grad);
std::shared_ptr<mkldnn::pooling_backward> pool_bwd_p; auto pool_bwd_p = handler.AcquireBackwardPrimitive();
std::shared_ptr<mkldnn::memory> workspace_memory;
mkldnn::stream astream(dev_ctx.GetEngine());
if (pooling_type == "max") { if (pooling_type == "max") {
      // Max pooling needs a workspace // Max pooling needs a workspace
workspace_memory = handler.AcquireWorkspaceMemory(); auto workspace_memory = handler.AcquireWorkspaceMemory();
pool_bwd_p = handler.AcquireBackwardPrimitive( pool_bwd_p->execute(astream, {{MKLDNN_ARG_DIFF_SRC, *diff_src_memory},
*diff_dst_memory, *workspace_memory, *diff_src_memory); {MKLDNN_ARG_DIFF_DST, *diff_dst_memory},
{MKLDNN_ARG_WORKSPACE, *workspace_memory}});
} else { } else {
// Average Pooling // Average Pooling
pool_bwd_p = pool_bwd_p->execute(astream, {{MKLDNN_ARG_DIFF_SRC, *diff_src_memory},
handler.AcquireBackwardPrimitive(*diff_dst_memory, *diff_src_memory); {MKLDNN_ARG_DIFF_DST, *diff_dst_memory}});
} }
astream.wait();
pipeline.push_back(*pool_bwd_p);
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
in_x_grad->set_layout(DataLayout::kMKLDNN); in_x_grad->set_layout(DataLayout::kMKLDNN);
in_x_grad->set_format(platform::GetMKLDNNFormat(*diff_src_memory)); in_x_grad->set_format(platform::GetMKLDNNFormat(*diff_src_memory));
......
...@@ -42,8 +42,8 @@ class QuantOpKernel : public framework::OpKernel<T> { ...@@ -42,8 +42,8 @@ class QuantOpKernel : public framework::OpKernel<T> {
const auto& engine = dev_ctx.GetEngine(); const auto& engine = dev_ctx.GetEngine();
std::vector<primitive> pipeline; std::vector<primitive> pipeline;
auto src_tz = paddle::framework::vectorize<int>(input->dims()); auto src_tz = paddle::framework::vectorize<int64_t>(input->dims());
auto dst_tz = paddle::framework::vectorize<int>(output->dims()); auto dst_tz = paddle::framework::vectorize<int64_t>(output->dims());
const T* input_data = input->data<T>(); const T* input_data = input->data<T>();
...@@ -66,24 +66,20 @@ class QuantOpKernel : public framework::OpKernel<T> { ...@@ -66,24 +66,20 @@ class QuantOpKernel : public framework::OpKernel<T> {
auto src_md = platform::MKLDNNMemDesc({src_tz}, memory::data_type::f32, auto src_md = platform::MKLDNNMemDesc({src_tz}, memory::data_type::f32,
input->format()); input->format());
auto src_pd = mkldnn::memory::primitive_desc(src_md, engine); src_memory = std::make_shared<mkldnn::memory>(
src_memory = src_md, engine, to_void_cast<T>(input_data));
std::make_shared<mkldnn::memory>(src_pd, to_void_cast<T>(input_data));
std::shared_ptr<primitive::at> src_memory_p =
std::shared_ptr<primitive::at>(new primitive::at(*src_memory));
std::shared_ptr<mkldnn::memory::primitive_desc> dst_pd; std::shared_ptr<mkldnn::memory::desc> dst_md;
if (is_negative) { if (is_negative) {
platform::SetDstMemoryQuantized<int8_t>(ctx, output, dst_tz, engine, platform::SetDstMemoryQuantized<int8_t>(ctx, output, dst_tz, engine,
dst_pd, dst_memory); dst_md, dst_memory);
} else { } else {
platform::SetDstMemoryQuantized<uint8_t>(ctx, output, dst_tz, engine, platform::SetDstMemoryQuantized<uint8_t>(ctx, output, dst_tz, engine,
dst_pd, dst_memory); dst_md, dst_memory);
} }
auto reorder_pd = std::shared_ptr<reorder::primitive_desc>( auto reorder_pd = std::shared_ptr<reorder::primitive_desc>(
new reorder::primitive_desc(src_pd, *dst_pd, attri)); new reorder::primitive_desc(*src_memory, *dst_memory, attri));
reorder_p = std::shared_ptr<reorder>( reorder_p = std::shared_ptr<reorder>(new reorder(*reorder_pd));
new reorder(*reorder_pd, *src_memory_p, *dst_memory));
dev_ctx.SetBlob(key_prim, reorder_p); dev_ctx.SetBlob(key_prim, reorder_p);
dev_ctx.SetBlob(key_src_mem, src_memory); dev_ctx.SetBlob(key_src_mem, src_memory);
...@@ -103,8 +99,10 @@ class QuantOpKernel : public framework::OpKernel<T> { ...@@ -103,8 +99,10 @@ class QuantOpKernel : public framework::OpKernel<T> {
} }
} }
pipeline.push_back(*reorder_p); mkldnn::stream astream(engine);
stream(stream::kind::eager).submit(pipeline).wait(); reorder_p->execute(astream, *src_memory, *dst_memory);
astream.wait();
output->set_layout(DataLayout::kMKLDNN); output->set_layout(DataLayout::kMKLDNN);
output->set_format(GetMKLDNNFormat(*dst_memory)); output->set_format(GetMKLDNNFormat(*dst_memory));
} }
......
...@@ -43,8 +43,8 @@ class ReQuantOpKernel : public framework::OpKernel<T> { ...@@ -43,8 +43,8 @@ class ReQuantOpKernel : public framework::OpKernel<T> {
const auto& engine = dev_ctx.GetEngine(); const auto& engine = dev_ctx.GetEngine();
std::vector<primitive> pipeline; std::vector<primitive> pipeline;
auto src_tz = paddle::framework::vectorize<int>(input->dims()); auto src_tz = paddle::framework::vectorize<int64_t>(input->dims());
auto dst_tz = paddle::framework::vectorize<int>(output->dims()); auto dst_tz = paddle::framework::vectorize<int64_t>(output->dims());
mkldnn::memory::data_type src_dt = mkldnn::memory::data_type src_dt =
paddle::framework::ToMKLDNNDataType(input->type()); paddle::framework::ToMKLDNNDataType(input->type());
mkldnn::memory::data_type dst_dt = src_dt; mkldnn::memory::data_type dst_dt = src_dt;
...@@ -60,23 +60,21 @@ class ReQuantOpKernel : public framework::OpKernel<T> { ...@@ -60,23 +60,21 @@ class ReQuantOpKernel : public framework::OpKernel<T> {
attri.set_output_scales(mask, {scale_shift}); attri.set_output_scales(mask, {scale_shift});
auto src_md = platform::MKLDNNMemDesc({src_tz}, src_dt, src_fmt); auto src_md = platform::MKLDNNMemDesc({src_tz}, src_dt, src_fmt);
auto src_pd = mkldnn::memory::primitive_desc(src_md, engine); auto src_memory = std::make_shared<mkldnn::memory>(
auto src_memory = src_md, engine, to_void_cast<T>(input_data));
std::make_shared<mkldnn::memory>(src_pd, to_void_cast<T>(input_data));
std::shared_ptr<primitive::at> src_memory_p =
std::shared_ptr<primitive::at>(new primitive::at(*src_memory));
auto dst_md = platform::MKLDNNMemDesc({dst_tz}, dst_dt, dst_fmt); auto dst_md = platform::MKLDNNMemDesc({dst_tz}, dst_dt, dst_fmt);
auto dst_pd = mkldnn::memory::primitive_desc(dst_md, engine); auto dst_memory =
auto dst_memory = mkldnn::memory(dst_pd, to_void_cast<T>(output_data)); mkldnn::memory(dst_md, engine, to_void_cast<T>(output_data));
auto reorder_pd = std::shared_ptr<reorder::primitive_desc>( auto reorder_pd = std::shared_ptr<reorder::primitive_desc>(
new reorder::primitive_desc(src_pd, dst_pd, attri)); new reorder::primitive_desc(*src_memory, dst_memory, attri));
auto reorder_p = std::shared_ptr<reorder>( auto reorder_p = std::shared_ptr<reorder>(new reorder(*reorder_pd));
new reorder(*reorder_pd, *src_memory_p, dst_memory));
pipeline.push_back(*reorder_p); mkldnn::stream astream(engine);
stream(stream::kind::eager).submit(pipeline).wait(); reorder_p->execute(astream, *src_memory, dst_memory);
astream.wait();
output->set_layout(DataLayout::kMKLDNN); output->set_layout(DataLayout::kMKLDNN);
output->set_format(GetMKLDNNFormat(dst_memory)); output->set_format(GetMKLDNNFormat(dst_memory));
......
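The quantize and requantize kernels above both reduce to a reorder whose primitive descriptor carries the output scales through a primitive_attr. In MKL-DNN 1.0 that descriptor is built directly from the two memories plus the attribute (the old src_pd/dst_pd form is gone) and the reorder is run with the stream and the two memories. A standalone sketch follows; the shapes and the scale value are illustrative assumptions, not taken from the commit.

#include <mkldnn.hpp>

#include <cstdint>
#include <vector>

int main() {
  mkldnn::engine eng(mkldnn::engine::kind::cpu, 0);
  mkldnn::stream astream(eng);

  mkldnn::memory::desc src_md({1, 4}, mkldnn::memory::data_type::f32,
                              mkldnn::memory::format_tag::nc);
  mkldnn::memory::desc dst_md({1, 4}, mkldnn::memory::data_type::s8,
                              mkldnn::memory::format_tag::nc);
  std::vector<float> src_buf{-1.0f, 0.0f, 0.5f, 1.0f};
  std::vector<int8_t> dst_buf(4, 0);
  mkldnn::memory src(src_md, eng, src_buf.data());
  mkldnn::memory dst(dst_md, eng, dst_buf.data());

  // The scale here is only an example; the kernels above derive it from the
  // operator's quantization attributes.
  mkldnn::primitive_attr attr;
  attr.set_output_scales(0 /*mask*/, {127.0f});

  // reorder::primitive_desc(src, dst, attr) replaces the
  // (src_pd, dst_pd, attr) constructor of v0.x.
  mkldnn::reorder::primitive_desc reorder_pd(src, dst, attr);
  mkldnn::reorder reorder_p(reorder_pd);
  reorder_p.execute(astream, src, dst);
  astream.wait();
  return 0;
}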
...@@ -38,7 +38,7 @@ class SoftmaxMKLDNNHandler ...@@ -38,7 +38,7 @@ class SoftmaxMKLDNNHandler
: public platform::MKLDNNHandlerT<T, mkldnn::softmax_forward, : public platform::MKLDNNHandlerT<T, mkldnn::softmax_forward,
mkldnn::softmax_backward> { mkldnn::softmax_backward> {
public: public:
SoftmaxMKLDNNHandler(const std::vector<int>& dims, SoftmaxMKLDNNHandler(const std::vector<int64_t>& dims,
const MKLDNNMemoryFormat fmt, const int& axis, const MKLDNNMemoryFormat fmt, const int& axis,
const platform::MKLDNNDeviceContext& dev_ctx, const platform::MKLDNNDeviceContext& dev_ctx,
platform::Place cpu_place, const std::string& uniq_name) platform::Place cpu_place, const std::string& uniq_name)
...@@ -52,7 +52,7 @@ class SoftmaxMKLDNNHandler ...@@ -52,7 +52,7 @@ class SoftmaxMKLDNNHandler
axis); axis);
} }
SoftmaxMKLDNNHandler(const std::vector<int>& dims, SoftmaxMKLDNNHandler(const std::vector<int64_t>& dims,
const MKLDNNMemoryFormat fmt, const MKLDNNMemoryFormat fmt,
const MKLDNNMemoryFormat diff_fmt, const int& axis, const MKLDNNMemoryFormat diff_fmt, const int& axis,
const platform::MKLDNNDeviceContext& dev_ctx, const platform::MKLDNNDeviceContext& dev_ctx,
...@@ -87,25 +87,24 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> { ...@@ -87,25 +87,24 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
auto dims = input->dims(); // input and output share the same shape auto dims = input->dims(); // input and output share the same shape
const int axis = CanonicalAxis(ctx.Attr<int>("axis"), dims.size()); const int axis = CanonicalAxis(ctx.Attr<int>("axis"), dims.size());
auto softmax_tz = paddle::framework::vectorize<int>(dims); auto softmax_tz = paddle::framework::vectorize<int64_t>(dims);
SoftmaxMKLDNNHandler<T> handler(softmax_tz, input->format(), axis, dev_ctx, SoftmaxMKLDNNHandler<T> handler(softmax_tz, input->format(), axis, dev_ctx,
ctx.GetPlace(), ctx.OutputName("Out")); ctx.GetPlace(), ctx.OutputName("Out"));
auto softmax_src_memory_p = handler.AcquireSrcMemory(input); auto softmax_src_memory_p = handler.AcquireSrcMemory(input);
auto softmax_dst_memory_p = handler.AcquireDstMemory(output); auto softmax_dst_memory_p = handler.AcquireDstMemory(output);
auto softmax_p = handler.AcquireForwardPrimitive(*softmax_src_memory_p, auto softmax_p = handler.AcquireForwardPrimitive();
*softmax_dst_memory_p);
std::vector<primitive> pipeline{*softmax_p}; mkldnn::stream astream(dev_ctx.GetEngine());
stream(stream::kind::eager).submit(pipeline).wait(); softmax_p->execute(astream, {{MKLDNN_ARG_SRC, *softmax_src_memory_p},
{MKLDNN_ARG_DST, *softmax_dst_memory_p}});
astream.wait();
const bool is_test = ctx.Attr<bool>("is_test"); const bool is_test = ctx.Attr<bool>("is_test");
if (!is_test) { if (!is_test) {
T* output_data = output->mutable_data<T>(ctx.GetPlace()); T* output_data = output->mutable_data<T>(ctx.GetPlace());
int size = std::accumulate(begin(softmax_tz), end(softmax_tz), 1, std::for_each(output_data, &output_data[output->numel()], [](T& val) {
std::multiplies<int>());
std::for_each(output_data, &output_data[size], [](T& val) {
val = std::max(val, static_cast<T>(exp(-64))); val = std::max(val, static_cast<T>(exp(-64)));
}); });
} }
...@@ -136,7 +135,7 @@ class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel<T> { ...@@ -136,7 +135,7 @@ class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel<T> {
auto dims = dout->dims(); // input and output share the same shape auto dims = dout->dims(); // input and output share the same shape
const int axis = CanonicalAxis(ctx.Attr<int>("axis"), dims.size()); const int axis = CanonicalAxis(ctx.Attr<int>("axis"), dims.size());
std::vector<int> softmax_tz = paddle::framework::vectorize<int>(dims); auto softmax_tz = paddle::framework::vectorize<int64_t>(dims);
SoftmaxMKLDNNHandler<T> handler(softmax_tz, output->format(), SoftmaxMKLDNNHandler<T> handler(softmax_tz, output->format(),
dout->format(), axis, dev_ctx, dout->format(), axis, dev_ctx,
...@@ -146,11 +145,14 @@ class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel<T> { ...@@ -146,11 +145,14 @@ class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel<T> {
auto diff_dst_memory_p = handler.AcquireDiffDstMemory(dout); auto diff_dst_memory_p = handler.AcquireDiffDstMemory(dout);
auto diff_src_memory_p = handler.AcquireDiffSrcMemory(dx); auto diff_src_memory_p = handler.AcquireDiffSrcMemory(dx);
auto softmax_bwd_p = handler.AcquireBackwardPrimitive( auto softmax_bwd_p = handler.AcquireBackwardPrimitive();
*dst_memory_p, *diff_dst_memory_p, *diff_src_memory_p);
std::vector<primitive> pipeline{*softmax_bwd_p}; mkldnn::stream astream(dev_ctx.GetEngine());
stream(stream::kind::eager).submit(pipeline).wait(); softmax_bwd_p->execute(astream,
{{MKLDNN_ARG_DST, *dst_memory_p},
{MKLDNN_ARG_DIFF_DST, *diff_dst_memory_p},
{MKLDNN_ARG_DIFF_SRC, *diff_src_memory_p}});
astream.wait();
dx->set_layout(framework::DataLayout::kMKLDNN); dx->set_layout(framework::DataLayout::kMKLDNN);
dx->set_format(dout->format()); dx->set_format(dout->format());
......
...@@ -63,11 +63,11 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -63,11 +63,11 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
LoDTensor* output = ctx.Output<LoDTensor>("Out"); LoDTensor* output = ctx.Output<LoDTensor>("Out");
T* output_data = output->mutable_data<T>(ctx.GetPlace()); T* output_data = output->mutable_data<T>(ctx.GetPlace());
auto dst_tz = framework::vectorize<int>(output->dims()); auto dst_tz = framework::vectorize<int64_t>(output->dims());
auto src_tz = dst_tz; auto src_tz = dst_tz;
MKLDNNMemoryFormat output_format{MKLDNNMemoryFormat::format_undef}; MKLDNNMemoryFormat output_format{MKLDNNMemoryFormat::undef};
std::vector<float> scales; std::vector<float> scales;
std::vector<memory::primitive_desc> srcs_mpd; std::vector<memory::desc> srcs_md;
std::vector<mkldnn::memory> srcs_mem; std::vector<mkldnn::memory> srcs_mem;
PADDLE_ENFORCE_EQ(in_vars[0]->IsType<LoDTensor>(), true, PADDLE_ENFORCE_EQ(in_vars[0]->IsType<LoDTensor>(), true,
...@@ -75,7 +75,7 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -75,7 +75,7 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto& input0 = in_vars[0]->Get<LoDTensor>(); auto& input0 = in_vars[0]->Get<LoDTensor>();
PADDLE_ENFORCE_EQ(input0.layout(), DataLayout::kMKLDNN, PADDLE_ENFORCE_EQ(input0.layout(), DataLayout::kMKLDNN,
"Wrong layout set for inputs[0] tensor"); "Wrong layout set for inputs[0] tensor");
PADDLE_ENFORCE_NE(input0.format(), MKLDNNMemoryFormat::format_undef, PADDLE_ENFORCE_NE(input0.format(), MKLDNNMemoryFormat::undef,
"Wrong format set for inputs[0] tensor"); "Wrong format set for inputs[0] tensor");
MKLDNNMemoryFormat input_format = input0.format(); MKLDNNMemoryFormat input_format = input0.format();
...@@ -86,7 +86,7 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -86,7 +86,7 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto& input = in_vars[i]->Get<LoDTensor>(); auto& input = in_vars[i]->Get<LoDTensor>();
PADDLE_ENFORCE_EQ(input.layout(), DataLayout::kMKLDNN, PADDLE_ENFORCE_EQ(input.layout(), DataLayout::kMKLDNN,
"Wrong layout set for inputs"); "Wrong layout set for inputs");
PADDLE_ENFORCE_NE(input.format(), MKLDNNMemoryFormat::format_undef, PADDLE_ENFORCE_NE(input.format(), MKLDNNMemoryFormat::undef,
"Wrong format set for inputs"); "Wrong format set for inputs");
if (input.numel() == 0) { if (input.numel() == 0) {
...@@ -97,9 +97,8 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -97,9 +97,8 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto src_md = auto src_md =
memory::desc(src_tz, memory::data_type::f32, input_format); memory::desc(src_tz, memory::data_type::f32, input_format);
auto src_mpd = memory::primitive_desc(src_md, mkldnn_engine); auto src_mem = memory(src_md, mkldnn_engine, to_void_cast(input_data));
auto src_mem = memory(src_mpd, to_void_cast(input_data)); srcs_md.push_back(src_md);
srcs_mpd.push_back(src_mpd);
srcs_mem.push_back(src_mem); srcs_mem.push_back(src_mem);
scales.push_back(1.0); scales.push_back(1.0);
} }
...@@ -107,36 +106,43 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -107,36 +106,43 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto dst_md = auto dst_md =
memory::desc(dst_tz, memory::data_type::f32, MKLDNNMemoryFormat::any); memory::desc(dst_tz, memory::data_type::f32, MKLDNNMemoryFormat::any);
auto sum_pd = sum::primitive_desc(dst_md, scales, srcs_mpd); auto sum_pd = sum::primitive_desc(dst_md, scales, srcs_md, mkldnn_engine);
std::shared_ptr<memory> dst_mem; std::shared_ptr<memory> dst_mem;
if (in_place) { if (in_place) {
dst_mem.reset(new memory(sum_pd.dst_primitive_desc())); dst_mem.reset(new memory(sum_pd.dst_desc(), mkldnn_engine));
} else { } else {
dst_mem.reset(new memory(sum_pd.dst_primitive_desc(), output_data)); dst_mem.reset(
} new memory(sum_pd.dst_desc(), mkldnn_engine, output_data));
std::vector<mkldnn::primitive::at> inputs;
for (size_t i = 0; i < srcs_mem.size(); ++i) {
inputs.push_back(srcs_mem[i]);
} }
auto sum_prim = mkldnn::sum(sum_pd, inputs, *dst_mem); auto sum_prim = mkldnn::sum(sum_pd);
output_format = (MKLDNNMemoryFormat)platform::GetMKLDNNFormat(sum_pd); output_format = platform::GetMKLDNNFormat(sum_pd.dst_desc());
primitive reorder_prim; std::shared_ptr<mkldnn::reorder> reorder_p;
std::shared_ptr<memory> target_mem; std::shared_ptr<memory> target_mem;
if (in_place) { if (in_place) {
output_format = input_format; output_format = input_format;
target_mem.reset(new memory( target_mem.reset(
{{{src_tz}, memory::data_type::f32, output_format}, mkldnn_engine}, new memory({{src_tz}, memory::data_type::f32, output_format},
output_data)); mkldnn_engine, output_data));
reorder_prim = reorder(*dst_mem, *target_mem); reorder_p = std::make_shared<reorder>(*dst_mem, *target_mem);
}
mkldnn::stream astream(mkldnn_engine);
std::unordered_map<int, memory> args;
for (size_t i = 0; i < srcs_mem.size(); ++i) {
args.insert({MKLDNN_ARG_MULTIPLE_SRC + i, srcs_mem.at(i)});
} }
args.insert({MKLDNN_ARG_DST, *dst_mem});
sum_prim.execute(astream, args);
astream.wait();
std::vector<primitive> pipeline; if (in_place) {
pipeline.push_back(sum_prim); reorder_p->execute(astream, *dst_mem, *target_mem);
if (in_place) pipeline.push_back(reorder_prim); astream.wait();
stream(stream::kind::eager).submit(pipeline).wait(); }
output->set_layout(DataLayout::kMKLDNN); output->set_layout(DataLayout::kMKLDNN);
output->set_format(output_format); output->set_format(output_format);
......
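The sum kernel above is the one multi-input case in this commit: sum::primitive_desc now takes the engine and a vector of memory::desc, the destination is created from sum_pd.dst_desc(), and every source memory is bound at execute() time under the key MKLDNN_ARG_MULTIPLE_SRC + i. A self-contained sketch of that argument map follows; the dims and scales are illustrative assumptions.

#include <mkldnn.hpp>

#include <unordered_map>
#include <vector>

int main() {
  mkldnn::engine eng(mkldnn::engine::kind::cpu, 0);
  mkldnn::stream astream(eng);

  mkldnn::memory::desc md({2, 3}, mkldnn::memory::data_type::f32,
                          mkldnn::memory::format_tag::nc);
  std::vector<float> a(6, 1.0f), b(6, 2.0f), out(6, 0.0f);
  std::vector<mkldnn::memory> srcs{mkldnn::memory(md, eng, a.data()),
                                   mkldnn::memory(md, eng, b.data())};
  std::vector<mkldnn::memory::desc> srcs_md{md, md};
  std::vector<float> scales{1.0f, 1.0f};

  // v1.0: the engine is passed explicitly and the sources are plain
  // memory::desc objects (memory::primitive_desc is gone).
  mkldnn::sum::primitive_desc sum_pd(md, scales, srcs_md, eng);
  mkldnn::memory dst(sum_pd.dst_desc(), eng, out.data());
  mkldnn::sum sum_prim(sum_pd);

  // Each source is keyed as MKLDNN_ARG_MULTIPLE_SRC + i, mirroring the args
  // map assembled in the kernel above.
  std::unordered_map<int, mkldnn::memory> args{{MKLDNN_ARG_DST, dst}};
  for (size_t i = 0; i < srcs.size(); ++i) {
    args.insert({MKLDNN_ARG_MULTIPLE_SRC + static_cast<int>(i), srcs[i]});
  }
  sum_prim.execute(astream, args);
  astream.wait();
  return 0;  // out now holds six values equal to 3.0f
}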
...@@ -44,7 +44,7 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -44,7 +44,7 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
return; return;
} }
auto nchw_tz = paddle::framework::vectorize<int>(input->dims()); auto nchw_tz = paddle::framework::vectorize<int64_t>(input->dims());
const std::string key = platform::CreateKey(nchw_tz, ctx.OutputName("Out")); const std::string key = platform::CreateKey(nchw_tz, ctx.OutputName("Out"));
...@@ -58,12 +58,13 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -58,12 +58,13 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto transpose_p = handler.AcquireTranspose(transpose_dst_memory_p, auto transpose_p = handler.AcquireTranspose(transpose_dst_memory_p,
transpose_src_memory_p); transpose_src_memory_p);
std::vector<mkldnn::primitive> pipeline; mkldnn::stream astream(mkldnn_engine);
pipeline.push_back(*transpose_p); transpose_p->execute(astream, *transpose_src_memory_p,
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); *transpose_dst_memory_p);
astream.wait();
output->set_layout(DataLayout::kNCHW); output->set_layout(DataLayout::kNCHW);
output->set_format(MKLDNNMemoryFormat::format_undef); output->set_format(MKLDNNMemoryFormat::undef);
} }
}; };
...@@ -95,7 +96,7 @@ class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -95,7 +96,7 @@ class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
const T* out_grad_data = out_grad->data<T>(); const T* out_grad_data = out_grad->data<T>();
x_grad->mutable_data<T>(ctx.GetPlace()); x_grad->mutable_data<T>(ctx.GetPlace());
auto nchw_tz = paddle::framework::vectorize<int>(out_grad->dims()); auto nchw_tz = paddle::framework::vectorize<int64_t>(out_grad->dims());
const std::string key = platform::CreateKey( const std::string key = platform::CreateKey(
nchw_tz, ctx.OutputName(framework::GradVarName("X"))); nchw_tz, ctx.OutputName(framework::GradVarName("X")));
...@@ -110,9 +111,10 @@ class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -110,9 +111,10 @@ class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
auto transpose_p = handler.AcquireTranspose(transpose_dst_memory_p, auto transpose_p = handler.AcquireTranspose(transpose_dst_memory_p,
transpose_src_memory_p); transpose_src_memory_p);
std::vector<mkldnn::primitive> pipeline; mkldnn::stream astream(mkldnn_engine);
pipeline.push_back(*transpose_p); transpose_p->execute(astream, *transpose_src_memory_p,
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); *transpose_dst_memory_p);
astream.wait();
} }
}; };
......
...@@ -376,7 +376,9 @@ Place CUDAPinnedDeviceContext::GetPlace() const { return place_; } ...@@ -376,7 +376,9 @@ Place CUDAPinnedDeviceContext::GetPlace() const { return place_; }
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place) MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
: CPUDeviceContext(place), engine_(mkldnn::engine::cpu, 0), p_blobmap_() { : CPUDeviceContext(place),
engine_(mkldnn::engine::kind::cpu, 0),
p_blobmap_() {
p_blobmap_.reset(new BlobMap()); p_blobmap_.reset(new BlobMap());
p_mutex_.reset(new std::mutex()); p_mutex_.reset(new std::mutex());
} }
......
...@@ -23,7 +23,7 @@ limitations under the License. */ ...@@ -23,7 +23,7 @@ limitations under the License. */
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
namespace paddle { namespace paddle {
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
using MKLDNNMemoryFormat = mkldnn::memory::format; using MKLDNNMemoryFormat = mkldnn::memory::format_tag;
#endif #endif
namespace platform { namespace platform {
...@@ -71,11 +71,10 @@ tf_pd<Type> MKLDNNBwdPrimitiveDesc(const Engine& e, const Primitive& p, ...@@ -71,11 +71,10 @@ tf_pd<Type> MKLDNNBwdPrimitiveDesc(const Engine& e, const Primitive& p,
return tf_pd<Type>(desc, e, p); return tf_pd<Type>(desc, e, p);
} }
inline mkldnn::memory::desc MKLDNNMemDesc(const std::vector<int>& dims, inline mkldnn::memory::desc MKLDNNMemDesc(const std::vector<int64_t>& dims,
mkldnn::memory::data_type data_type, mkldnn::memory::data_type data_type,
MKLDNNMemoryFormat format) { MKLDNNMemoryFormat format) {
mkldnn::memory::dims tz = dims; return mkldnn::memory::desc({dims}, data_type, format);
return mkldnn::memory::desc({tz}, data_type, format);
} }
inline bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx) { inline bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx) {
...@@ -85,7 +84,7 @@ inline bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx) { ...@@ -85,7 +84,7 @@ inline bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx) {
template <typename Type> template <typename Type>
mkldnn::memory::data_type MKLDNNGetDataType() { mkldnn::memory::data_type MKLDNNGetDataType() {
return mkldnn::memory::data_type::data_undef; return mkldnn::memory::data_type::undef;
} }
template <> template <>
...@@ -105,22 +104,136 @@ inline mkldnn::memory::data_type MKLDNNGetDataType<uint8_t>() { ...@@ -105,22 +104,136 @@ inline mkldnn::memory::data_type MKLDNNGetDataType<uint8_t>() {
return mkldnn::memory::data_type::u8; return mkldnn::memory::data_type::u8;
} }
inline void Reorder(const mkldnn::memory& src, const mkldnn::memory& dst) { inline void Reorder(mkldnn::memory src, mkldnn::memory dst,
const mkldnn::engine& engine) {
auto reorder_prim = mkldnn::reorder(src, dst); auto reorder_prim = mkldnn::reorder(src, dst);
std::vector<mkldnn::primitive> pipeline; mkldnn::stream astream(engine);
pipeline.push_back(reorder_prim); reorder_prim.execute(astream, src, dst);
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); astream.wait();
} }
inline MKLDNNMemoryFormat GetMKLDNNFormat(const mkldnn::memory memory) { inline mkldnn::memory::format_tag GetMKLDNNFormat(
return static_cast<MKLDNNMemoryFormat>( mkldnn::memory::desc mem_desc) {
memory.get_primitive_desc().desc().data.format); auto ndims = mem_desc.data.ndims;
auto strides = mem_desc.data.format_desc.blocking.strides;
auto inner_nblks = mem_desc.data.format_desc.blocking.inner_nblks;
auto inner_blks = mem_desc.data.format_desc.blocking.inner_blks;
auto inner_idxs = mem_desc.data.format_desc.blocking.inner_idxs;
if (ndims == 1) {
return mkldnn::memory::format_tag::x;
} else if (ndims == 2) {
if (inner_nblks == 0) {
if (strides[0] >= strides[1]) {
return mkldnn::memory::format_tag::nc;
} else {
return mkldnn::memory::format_tag::cn;
}
}
} else if (ndims == 3) {
if (inner_nblks == 0) {
if (strides[0] >= strides[1] && strides[1] >= strides[2]) {
return mkldnn::memory::format_tag::ncw;
} else {
return mkldnn::memory::format_tag::nwc;
}
}
} else if (ndims == 4) {
if (inner_nblks == 0) {
if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
strides[2] >= strides[3]) {
return mkldnn::memory::format_tag::nchw;
} else {
return mkldnn::memory::format_tag::nhwc;
}
} else if (inner_nblks == 1) {
if (inner_blks[0] == 16 && inner_idxs[0] == 1) {
return mkldnn::memory::format_tag::nChw16c;
} else if (inner_blks[0] == 8 && inner_idxs[0] == 1) {
return mkldnn::memory::format_tag::nChw8c;
} else if (inner_blks[0] == 8 && inner_idxs[0] == 0) {
if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
strides[3] >= strides[1]) {
return mkldnn::memory::format_tag::Acdb8a;
}
} else if (inner_blks[0] == 4 && inner_idxs[0] == 1) {
return mkldnn::memory::format_tag::nChw4c;
} else if (inner_blks[0] == 16 && inner_idxs[0] == 0) {
if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
strides[3] >= strides[1]) {
return mkldnn::memory::format_tag::Acdb16a;
}
}
} else if (inner_nblks == 2) {
if (inner_blks[0] == 16 && inner_blks[1] == 16) {
if (inner_idxs[0] == 1 && inner_idxs[1] == 0) {
return mkldnn::memory::format_tag::OIhw16i16o;
}
} else if (inner_blks[0] == 8 && inner_blks[1] == 8) {
if (inner_idxs[0] == 1 && inner_idxs[1] == 0) {
return mkldnn::memory::format_tag::OIhw8i8o;
}
}
}
} else if (ndims == 5) {
if (inner_nblks == 0) {
if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
strides[2] >= strides[3] && strides[3] >= strides[4]) {
return mkldnn::memory::format_tag::ncdhw;
} else {
return mkldnn::memory::format_tag::ndhwc;
}
} else if (inner_nblks == 1) {
if (inner_blks[0] == 8 && inner_idxs[0] == 0) {
if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
strides[3] >= strides[4] && strides[4] >= strides[1]) {
return mkldnn::memory::format_tag::Acdeb8a;
}
} else if (inner_blks[0] == 8 && inner_idxs[0] == 1) {
if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
strides[2] >= strides[3] && strides[3] >= strides[4]) {
return mkldnn::memory::format_tag::aBcde8b;
}
} else if (inner_blks[0] == 16 && inner_idxs[0] == 0) {
if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
strides[3] >= strides[4] && strides[4] >= strides[1]) {
return mkldnn::memory::format_tag::Acdeb16a;
}
} else if (inner_blks[0] == 16 && inner_idxs[0] == 1) {
if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
strides[2] >= strides[3] && strides[3] >= strides[4]) {
return mkldnn::memory::format_tag::aBcde16b;
}
}
}
} else if (ndims == 6) {
if (inner_nblks == 0) {
if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
strides[2] >= strides[3] && strides[3] >= strides[4] &&
strides[4] >= strides[5]) {
return mkldnn::memory::format_tag::abcdef;
}
}
}
  // DEBUG CODE - KEEP UNTIL TENSOR.MEMORY_DESC IS IMPLEMENTED
// std::cout<<"@@@@@@@@@@ UNDEFINED FORMAT @@@@@@@@@@@@@@@@@@@"<<std::endl;
// std::cout<<"NDIMS: "<<ndims<<std::endl;
// std::cout<<"INNER_NBLKS: "<<inner_nblks<<std::endl;
// for (int i=0;i<ndims;++i) {
// std::cout<<"STRIDE["<<i<<"]: "<<strides[i]<<std::endl;
// }
// for (int i=0;i<inner_nblks;++i) {
// std::cout<<"INNER_BLKS["<<i<<"]: "<<inner_blks[i]<<std::endl;
// }
// for (int i=0;i<inner_nblks;++i) {
// std::cout<<"INNER_IDXS["<<i<<"]: "<<inner_idxs[i]<<std::endl;
// }
return mkldnn::memory::format_tag::undef;
} }
inline MKLDNNMemoryFormat GetMKLDNNFormat( inline mkldnn::memory::format_tag GetMKLDNNFormat(const mkldnn::memory memory) {
const mkldnn::sum::primitive_desc& memory) { auto mem_desc = memory.get_desc();
return static_cast<MKLDNNMemoryFormat>( return GetMKLDNNFormat(mem_desc);
memory.dst_primitive_desc().desc().data.format);
} }
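GetMKLDNNFormat above has to reconstruct a format_tag because a v1.0 memory::desc stores only a blocked layout (ndims, strides, inner_nblks, inner_blks, inner_idxs) rather than the explicit format enum of v0.x. The sketch below reads exactly those fields for a plain NCHW descriptor; the shape is an illustrative assumption, not a value from the commit.

#include <mkldnn.hpp>

#include <cassert>

int main() {
  mkldnn::memory::desc md({8, 3, 32, 32}, mkldnn::memory::data_type::f32,
                          mkldnn::memory::format_tag::nchw);

  // These are the fields the helper above inspects to recover a tag.
  const auto &blk = md.data.format_desc.blocking;
  assert(md.data.ndims == 4);
  assert(blk.inner_nblks == 0);  // no channel blocking for plain nchw
  assert(blk.strides[0] >= blk.strides[1] &&
         blk.strides[1] >= blk.strides[2] &&
         blk.strides[2] >= blk.strides[3]);  // descending strides => nchw
  return 0;
}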
inline MKLDNNMemoryFormat MKLDNNFormatForSize(size_t dims_size, inline MKLDNNMemoryFormat MKLDNNFormatForSize(size_t dims_size,
...@@ -190,13 +303,37 @@ inline void AppendKey(std::string* key, const T& num) { ...@@ -190,13 +303,37 @@ inline void AppendKey(std::string* key, const T& num) {
key->append(std::to_string(num)); key->append(std::to_string(num));
} }
template <>
inline void AppendKey(std::string* key,
const mkldnn::memory::format_tag& format) {
key->append(std::to_string(static_cast<int>(format)));
}
template <>
inline void AppendKey(std::string* key,
const mkldnn::memory::data_type& data_type) {
key->append(std::to_string(static_cast<int>(data_type)));
}
template <>
inline void AppendKey(std::string* key, const mkldnn::algorithm& algorithm) {
key->append(std::to_string(static_cast<int>(algorithm)));
}
template <>
inline void AppendKey(std::string* key,
const mkldnn::normalization_flags& flags) {
key->append(std::to_string(static_cast<int>(flags)));
}
inline void AppendKey(std::string* key, const std::string& str) { inline void AppendKey(std::string* key, const std::string& str) {
key->append(str); key->append(str);
} }
inline void AppendKey(std::string* key, const char* str) { key->append(str); } inline void AppendKey(std::string* key, const char* str) { key->append(str); }
inline void AppendKey(std::string* key, const std::vector<int>& dims) { template <typename T>
inline void AppendKey(std::string* key, const std::vector<T>& dims) {
for (size_t i = 0; i < dims.size(); i++) { for (size_t i = 0; i < dims.size(); i++) {
AppendKey(key, std::to_string(dims[i])); AppendKey(key, std::to_string(dims[i]));
} }
...@@ -211,8 +348,8 @@ inline std::string CreateKey(ArgTypes&&... args) { ...@@ -211,8 +348,8 @@ inline std::string CreateKey(ArgTypes&&... args) {
return key; return key;
} }
inline std::vector<std::vector<int>> ToMkldnnPadding( inline std::vector<std::vector<int64_t>> ToMkldnnPadding(
const std::vector<int>& paddings) { const std::vector<int64_t>& paddings) {
if (paddings.size() == 6) { if (paddings.size() == 6) {
int padding_front = paddings[0]; int padding_front = paddings[0];
int padding_back = paddings[1]; int padding_back = paddings[1];
......
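The AppendKey specializations above extend the blob-cache key to the new enum types (format_tag, data_type, algorithm, normalization_flags), and the vector overload becomes a template so int64_t dims can participate. The sketch below mirrors that variadic key-building idea with local stand-ins; it deliberately avoids the Paddle headers, and all names and values are illustrative.

#include <mkldnn.hpp>

#include <cstdint>
#include <initializer_list>
#include <string>
#include <vector>

// Local stand-ins that mirror the platform::AppendKey/CreateKey pattern.
template <typename T>
void AppendKey(std::string *key, const T &num) {
  key->append(std::to_string(num));
}
void AppendKey(std::string *key, const std::string &str) { key->append(str); }
void AppendKey(std::string *key, mkldnn::memory::format_tag format) {
  key->append(std::to_string(static_cast<int>(format)));
}
template <typename T>
void AppendKey(std::string *key, const std::vector<T> &dims) {
  for (const auto &d : dims) AppendKey(key, d);
}

template <typename... ArgTypes>
std::string CreateKey(ArgTypes &&... args) {
  std::string key;
  key.reserve(64);
  // Expand the pack left to right, appending every argument to the key.
  (void)std::initializer_list<int>{(AppendKey(&key, args), 0)...};
  return key;
}

int main() {
  std::vector<int64_t> dims{8, 3, 32, 32};  // v1.0 dims are int64_t
  std::string key = CreateKey(dims, mkldnn::memory::format_tag::nchw,
                              std::string("fc_out"));
  return key.empty() ? 1 : 0;
}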
...@@ -49,27 +49,23 @@ class MKLDNNHandlerT { ...@@ -49,27 +49,23 @@ class MKLDNNHandlerT {
} }
} }
template <typename... Args> std::shared_ptr<TForward> AcquireForwardPrimitive() {
std::shared_ptr<TForward> AcquireForwardPrimitive(Args&&... args) {
const std::string key_p = key_ + "@forward_p"; const std::string key_p = key_ + "@forward_p";
auto forward_p = auto forward_p =
std::static_pointer_cast<TForward>(dev_ctx_.GetBlob(key_p)); std::static_pointer_cast<TForward>(dev_ctx_.GetBlob(key_p));
if (forward_p == nullptr) { if (forward_p == nullptr) {
forward_p = forward_p = std::make_shared<TForward>(*fwd_pd_);
std::make_shared<TForward>(*fwd_pd_, std::forward<Args>(args)...);
dev_ctx_.SetBlob(key_p, forward_p); dev_ctx_.SetBlob(key_p, forward_p);
} }
return forward_p; return forward_p;
} }
template <typename... Args> std::shared_ptr<TBackward> AcquireBackwardPrimitive() {
std::shared_ptr<TBackward> AcquireBackwardPrimitive(Args&&... args) {
const std::string key_p = key_ + "@backward_p"; const std::string key_p = key_ + "@backward_p";
auto backward_p = auto backward_p =
std::static_pointer_cast<TBackward>(dev_ctx_.GetBlob(key_p)); std::static_pointer_cast<TBackward>(dev_ctx_.GetBlob(key_p));
if (backward_p == nullptr) { if (backward_p == nullptr) {
backward_p = backward_p = std::make_shared<TBackward>(*bwd_pd_);
std::make_shared<TBackward>(*bwd_pd_, std::forward<Args>(args)...);
dev_ctx_.SetBlob(key_p, backward_p); dev_ctx_.SetBlob(key_p, backward_p);
} }
return backward_p; return backward_p;
...@@ -78,40 +74,36 @@ class MKLDNNHandlerT { ...@@ -78,40 +74,36 @@ class MKLDNNHandlerT {
std::shared_ptr<mkldnn::memory> AcquireSrcMemory( std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
const framework::Tensor* input) { const framework::Tensor* input) {
const T* input_data = input->data<T>(); const T* input_data = input->data<T>();
return this->AcquireMemoryFromPrimitive(fwd_pd_->src_primitive_desc(), return this->AcquireMemoryFromPrimitive(
to_void_cast<T>(input_data), fwd_pd_->src_desc(), to_void_cast<T>(input_data), "@src_mem_p");
"@src_mem_p");
} }
std::shared_ptr<mkldnn::memory> AcquireDstMemory(framework::Tensor* output) { std::shared_ptr<mkldnn::memory> AcquireDstMemory(framework::Tensor* output) {
T* ptr = output->mutable_data<T>(place_, T* ptr = output->mutable_data<T>(place_, fwd_pd_->dst_desc().get_size());
fwd_pd_->dst_primitive_desc().get_size()); return this->AcquireMemoryFromPrimitive(fwd_pd_->dst_desc(), ptr,
return this->AcquireMemoryFromPrimitive(fwd_pd_->dst_primitive_desc(), ptr,
"@dst_mem_p"); "@dst_mem_p");
} }
std::shared_ptr<mkldnn::memory> AcquireDstMemory( std::shared_ptr<mkldnn::memory> AcquireDstMemory(
const framework::Tensor* output) { const framework::Tensor* output) {
const T* output_data = output->data<T>(); const T* output_data = output->data<T>();
return this->AcquireMemoryFromPrimitive(bwd_pd_->dst_primitive_desc(), return this->AcquireMemoryFromPrimitive(
to_void_cast<T>(output_data), bwd_pd_->dst_desc(), to_void_cast<T>(output_data), "@bwd-dst_mem_p");
"@bwd-dst_mem_p");
} }
std::shared_ptr<mkldnn::memory> AcquireDiffDstMemory( std::shared_ptr<mkldnn::memory> AcquireDiffDstMemory(
const framework::Tensor* diffdst) { const framework::Tensor* diffdst) {
const T* ptr = diffdst->data<T>(); const T* ptr = diffdst->data<T>();
return this->AcquireMemoryFromPrimitive(bwd_pd_->diff_dst_primitive_desc(), return this->AcquireMemoryFromPrimitive(
to_void_cast<T>(ptr), bwd_pd_->diff_dst_desc(), to_void_cast<T>(ptr), "@diff_dst_mem_p");
"@diff_dst_mem_p");
} }
std::shared_ptr<mkldnn::memory> AcquireDiffSrcMemory( std::shared_ptr<mkldnn::memory> AcquireDiffSrcMemory(
framework::Tensor* diffsrc) { framework::Tensor* diffsrc) {
T* ptr = diffsrc->mutable_data<T>( T* ptr =
place_, bwd_pd_->diff_src_primitive_desc().get_size()); diffsrc->mutable_data<T>(place_, bwd_pd_->diff_src_desc().get_size());
return this->AcquireMemoryFromPrimitive(bwd_pd_->diff_src_primitive_desc(), return this->AcquireMemoryFromPrimitive(bwd_pd_->diff_src_desc(), ptr,
ptr, "@diff_src_mem_p"); "@diff_src_mem_p");
} }
protected: protected:
...@@ -156,13 +148,12 @@ class MKLDNNHandlerT { ...@@ -156,13 +148,12 @@ class MKLDNNHandlerT {
} }
std::shared_ptr<mkldnn::memory> AcquireMemoryFromPrimitive( std::shared_ptr<mkldnn::memory> AcquireMemoryFromPrimitive(
mkldnn::memory::primitive_desc mdp, void* ptr, mkldnn::memory::desc md, void* ptr, const std::string& suffix) {
const std::string& suffix) {
auto local_key = key_ + suffix; auto local_key = key_ + suffix;
auto mem_p = auto mem_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key)); std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
if (mem_p == nullptr) { if (mem_p == nullptr) {
mem_p = std::make_shared<mkldnn::memory>(mdp, ptr); mem_p = std::make_shared<mkldnn::memory>(md, engine_, ptr);
dev_ctx_.SetBlob(local_key, mem_p); dev_ctx_.SetBlob(local_key, mem_p);
} else { } else {
mem_p->set_data_handle(ptr); mem_p->set_data_handle(ptr);
...@@ -214,13 +205,12 @@ class MKLDNNHandler { ...@@ -214,13 +205,12 @@ class MKLDNNHandler {
} }
std::shared_ptr<mkldnn::memory> AcquireMemoryFromPrimitive( std::shared_ptr<mkldnn::memory> AcquireMemoryFromPrimitive(
mkldnn::memory::primitive_desc mdp, void* ptr, mkldnn::memory::desc md, void* ptr, const std::string& suffix) {
const std::string& suffix) {
auto local_key = key_ + suffix; auto local_key = key_ + suffix;
auto mem_p = auto mem_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key)); std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
if (mem_p == nullptr) { if (mem_p == nullptr) {
mem_p = std::make_shared<mkldnn::memory>(mdp, ptr); mem_p = std::make_shared<mkldnn::memory>(md, engine_, ptr);
dev_ctx_.SetBlob(local_key, mem_p); dev_ctx_.SetBlob(local_key, mem_p);
} else { } else {
mem_p->set_data_handle(ptr); mem_p->set_data_handle(ptr);
...@@ -245,8 +235,7 @@ class MKLDNNHandler { ...@@ -245,8 +235,7 @@ class MKLDNNHandler {
ptr = reinterpret_cast<void*>(reordered_data.get()); ptr = reinterpret_cast<void*>(reordered_data.get());
} }
mem_p = std::make_shared<mkldnn::memory>( mem_p = std::make_shared<mkldnn::memory>(md, engine_, ptr);
mkldnn::memory::primitive_desc{md, engine_}, ptr);
dev_ctx_.SetBlob(local_key, mem_p); dev_ctx_.SetBlob(local_key, mem_p);
} else { } else {
mem_p->set_data_handle(ptr); mem_p->set_data_handle(ptr);
...@@ -255,7 +244,7 @@ class MKLDNNHandler { ...@@ -255,7 +244,7 @@ class MKLDNNHandler {
} }
std::shared_ptr<mkldnn::memory> AcquireMemory( std::shared_ptr<mkldnn::memory> AcquireMemory(
const std::vector<int>& dims, const mkldnn::memory::data_type dtype, const std::vector<int64_t>& dims, const mkldnn::memory::data_type dtype,
const MKLDNNMemoryFormat& fmt, void* ptr, const std::string& suffix) { const MKLDNNMemoryFormat& fmt, void* ptr, const std::string& suffix) {
/*Generate key*/ /*Generate key*/
auto local_key = key_ + suffix; auto local_key = key_ + suffix;
...@@ -264,8 +253,7 @@ class MKLDNNHandler { ...@@ -264,8 +253,7 @@ class MKLDNNHandler {
if (mem_p == nullptr) { if (mem_p == nullptr) {
auto md = mkldnn::memory::desc(dims, dtype, fmt); auto md = mkldnn::memory::desc(dims, dtype, fmt);
mem_p = std::make_shared<mkldnn::memory>( mem_p = std::make_shared<mkldnn::memory>(md, engine_, ptr);
mkldnn::memory::primitive_desc{md, engine_}, ptr);
dev_ctx_.SetBlob(local_key, mem_p); dev_ctx_.SetBlob(local_key, mem_p);
} else { } else {
mem_p->set_data_handle(ptr); mem_p->set_data_handle(ptr);
...@@ -290,15 +278,18 @@ class MKLDNNHandler { ...@@ -290,15 +278,18 @@ class MKLDNNHandler {
auto reorder_p = auto reorder_p =
std::make_shared<mkldnn::reorder>(*user_memory_p, *target_memory_p); std::make_shared<mkldnn::reorder>(*user_memory_p, *target_memory_p);
dev_ctx_.SetBlob(key_reorder_p, reorder_p); dev_ctx_.SetBlob(key_reorder_p, reorder_p);
pipeline.push_back(*reorder_p); mkldnn::stream astream(engine_);
reorder_p->execute(astream, {{MKLDNN_ARG_FROM, *user_memory_p},
{MKLDNN_ARG_TO, *target_memory_p}});
astream.wait();
} }
return target_memory_p; return target_memory_p;
} }
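Where the 0.x code pushed the reorder onto a pipeline vector to be run later, the 1.0 code above executes it immediately on a stream, binding source and destination memories through an argument map. A self-contained sketch of that execution model (shapes and formats are illustrative only):

#include <vector>
#include "mkldnn.hpp"

int main() {
  using namespace mkldnn;
  engine eng(engine::kind::cpu, 0);
  stream astream(eng);

  memory::dims dims = {1, 8, 4, 4};
  memory::desc nchw_md(dims, memory::data_type::f32, memory::format_tag::nchw);
  memory::desc nhwc_md(dims, memory::data_type::f32, memory::format_tag::nhwc);

  std::vector<float> src_buf(nchw_md.get_size() / sizeof(float), 1.0f);
  memory src(nchw_md, eng, src_buf.data());
  memory dst(nhwc_md, eng);  // library-allocated destination buffer

  // v1.0: primitives carry no memory; arguments are bound at execute() time.
  reorder r(src, dst);
  r.execute(astream, {{MKLDNN_ARG_FROM, src}, {MKLDNN_ARG_TO, dst}});
  astream.wait();
  return 0;
}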
std::shared_ptr<mkldnn::memory> AcquireMemory( std::shared_ptr<mkldnn::memory> AcquireMemory(
mkldnn::memory::primitive_desc& mpd, // NOLINT mkldnn::memory::desc& md, // NOLINT
mkldnn::memory::primitive_desc& user_mpd, // NOLINT mkldnn::memory::desc& user_md, // NOLINT
const std::shared_ptr<mkldnn::memory> user_memory_p, const std::shared_ptr<mkldnn::memory> user_memory_p,
const std::string& suffix, const std::string& suffix,
std::vector<mkldnn::primitive>& pipeline, // NOLINT std::vector<mkldnn::primitive>& pipeline, // NOLINT
...@@ -310,27 +301,34 @@ class MKLDNNHandler { ...@@ -310,27 +301,34 @@ class MKLDNNHandler {
auto target_memory_p = auto target_memory_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key)); std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
mkldnn::stream astream(engine_);
if (target_memory_p == nullptr) { if (target_memory_p == nullptr) {
target_memory_p = user_memory_p; target_memory_p = user_memory_p;
std::shared_ptr<mkldnn::primitive> reorder_p; if (md != user_md) {
if (mpd != user_mpd) { target_memory_p = std::make_shared<mkldnn::memory>(md, engine_);
target_memory_p = std::make_shared<mkldnn::memory>(mpd); std::shared_ptr<mkldnn::reorder::primitive_desc> reorder_pd;
std::shared_ptr<mkldnn::reorder> reorder_p;
if (is_INT8) { if (is_INT8) {
mkldnn::primitive_attr mkldnn::primitive_attr
attri; // attribute for int8 weights and bias data reorder. attri; // attribute for int8 weights and bias data reorder.
attri.set_output_scales(mask, scale_data); attri.set_output_scales(mask, scale_data);
auto reorder_pd = std::shared_ptr<mkldnn::reorder::primitive_desc>( reorder_pd = std::shared_ptr<mkldnn::reorder::primitive_desc>(
new mkldnn::reorder::primitive_desc(user_mpd, mpd, attri)); new mkldnn::reorder::primitive_desc(*user_memory_p,
reorder_p = std::shared_ptr<mkldnn::reorder>(new mkldnn::reorder( *target_memory_p, attri));
*reorder_pd, *user_memory_p, *target_memory_p));
} else { } else {
reorder_p = std::make_shared<mkldnn::reorder>(*user_memory_p, reorder_pd = std::shared_ptr<mkldnn::reorder::primitive_desc>(
*target_memory_p); new mkldnn::reorder::primitive_desc(*user_memory_p,
*target_memory_p));
} }
auto reorder_p =
std::shared_ptr<mkldnn::reorder>(new mkldnn::reorder(*reorder_pd));
dev_ctx_.SetBlob(key_reorder_p, reorder_p); dev_ctx_.SetBlob(key_reorder_p, reorder_p);
pipeline.push_back(*reorder_p);
reorder_p->execute(astream, {{MKLDNN_ARG_FROM, *user_memory_p},
{MKLDNN_ARG_TO, *target_memory_p}});
astream.wait();
} }
dev_ctx_.SetBlob(local_key, target_memory_p); dev_ctx_.SetBlob(local_key, target_memory_p);
} else if (!is_persistent) { } else if (!is_persistent) {
...@@ -338,7 +336,9 @@ class MKLDNNHandler { ...@@ -338,7 +336,9 @@ class MKLDNNHandler {
auto reorder_p = std::static_pointer_cast<mkldnn::reorder>( auto reorder_p = std::static_pointer_cast<mkldnn::reorder>(
dev_ctx_.GetBlob(key_reorder_p)); dev_ctx_.GetBlob(key_reorder_p));
if (reorder_p != nullptr) { if (reorder_p != nullptr) {
pipeline.push_back(*reorder_p); reorder_p->execute(astream, {{MKLDNN_ARG_FROM, *user_memory_p},
{MKLDNN_ARG_TO, *target_memory_p}});
astream.wait();
} }
} }
return target_memory_p; return target_memory_p;
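For INT8 weights and bias the reorder additionally applies output scales through a primitive_attr, and in 1.0 the reorder::primitive_desc is built from the two memory objects plus that attribute, exactly as in the branch above. A hedged sketch of the scaled-reorder path (mask 0 and a single scale are chosen only for illustration):

#include <vector>
#include "mkldnn.hpp"

int main() {
  using namespace mkldnn;
  engine eng(engine::kind::cpu, 0);
  stream astream(eng);

  memory::dims dims = {8, 8, 3, 3};  // illustrative OIHW weights shape
  memory::desc f32_md(dims, memory::data_type::f32, memory::format_tag::oihw);
  memory::desc s8_md(dims, memory::data_type::s8, memory::format_tag::oihw);

  std::vector<float> wei(f32_md.get_size() / sizeof(float), 0.5f);
  memory user_mem(f32_md, eng, wei.data());
  memory target_mem(s8_md, eng);

  primitive_attr attr;
  attr.set_output_scales(/*mask=*/0, {127.0f});  // quantization scale

  reorder::primitive_desc reorder_pd(user_mem, target_mem, attr);
  reorder(reorder_pd).execute(
      astream, {{MKLDNN_ARG_FROM, user_mem}, {MKLDNN_ARG_TO, target_mem}});
  astream.wait();
  return 0;
}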
...@@ -366,12 +366,13 @@ class SumMKLDNNHandler : public MKLDNNHandler { ...@@ -366,12 +366,13 @@ class SumMKLDNNHandler : public MKLDNNHandler {
dev_ctx_.GetBlob(key_sum_pd)); dev_ctx_.GetBlob(key_sum_pd));
if (sum_pd_ == nullptr) { if (sum_pd_ == nullptr) {
// Get vector of inputs primitive descriptors // Get vector of inputs primitive descriptors
std::vector<mkldnn::memory::primitive_desc> src_pds; std::vector<mkldnn::memory::desc> src_ds;
for (auto& input_mem : src_mems) { for (auto& input_mem : src_mems) {
src_pds.push_back(input_mem->get_primitive_desc()); src_ds.push_back(input_mem->get_desc());
} }
sum_pd_.reset(new mkldnn::sum::primitive_desc(dst_md, scales, src_pds)); sum_pd_.reset(
new mkldnn::sum::primitive_desc(dst_md, scales, src_ds, engine_));
dev_ctx_.SetBlob(key_sum_pd, sum_pd_); dev_ctx_.SetBlob(key_sum_pd, sum_pd_);
} }
...@@ -379,7 +380,7 @@ class SumMKLDNNHandler : public MKLDNNHandler { ...@@ -379,7 +380,7 @@ class SumMKLDNNHandler : public MKLDNNHandler {
} }
std::shared_ptr<mkldnn::memory> AcquireDstMemoryFromPrimitive(void* ptr) { std::shared_ptr<mkldnn::memory> AcquireDstMemoryFromPrimitive(void* ptr) {
return this->AcquireMemoryFromPrimitive(sum_pd_->dst_primitive_desc(), ptr, return this->AcquireMemoryFromPrimitive(sum_pd_->dst_desc(), ptr,
"@dst_mem_p"); "@dst_mem_p");
} }
...@@ -388,14 +389,12 @@ class SumMKLDNNHandler : public MKLDNNHandler { ...@@ -388,14 +389,12 @@ class SumMKLDNNHandler : public MKLDNNHandler {
return this->AcquireMemory(md, ptr, "@user_src2_mem_p"); return this->AcquireMemory(md, ptr, "@user_src2_mem_p");
} }
std::shared_ptr<mkldnn::sum> AcquireSum( std::shared_ptr<mkldnn::sum> AcquireSum() {
std::shared_ptr<mkldnn::memory> dst_memory,
std::vector<mkldnn::primitive::at>* inputs) {
auto prim_key = key_ + "@sum_p"; auto prim_key = key_ + "@sum_p";
auto sum_p = auto sum_p =
std::static_pointer_cast<mkldnn::sum>(dev_ctx_.GetBlob(prim_key)); std::static_pointer_cast<mkldnn::sum>(dev_ctx_.GetBlob(prim_key));
if (sum_p == nullptr) { if (sum_p == nullptr) {
sum_p = std::make_shared<mkldnn::sum>(*(sum_pd_), *inputs, *(dst_memory)); sum_p = std::make_shared<mkldnn::sum>(*sum_pd_);
dev_ctx_.SetBlob(prim_key, sum_p); dev_ctx_.SetBlob(prim_key, sum_p);
} }
return sum_p; return sum_p;
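With the 1.0 API the sum primitive descriptor takes the destination descriptor, the scales, the source descriptors, and the engine, and the primitive no longer holds its inputs; they are supplied at execution time via MKLDNN_ARG_MULTIPLE_SRC. A small sketch of summing two tensors under those assumptions (shapes and scales are illustrative):

#include <vector>
#include "mkldnn.hpp"

int main() {
  using namespace mkldnn;
  engine eng(engine::kind::cpu, 0);
  stream astream(eng);

  memory::dims dims = {2, 4, 8, 8};
  memory::desc md(dims, memory::data_type::f32, memory::format_tag::nchw);

  std::vector<float> a(md.get_size() / sizeof(float), 1.0f);
  std::vector<float> b(md.get_size() / sizeof(float), 2.0f);
  memory src0(md, eng, a.data()), src1(md, eng, b.data()), dst(md, eng);

  std::vector<float> scales = {1.0f, 1.0f};
  std::vector<memory::desc> src_ds = {src0.get_desc(), src1.get_desc()};
  sum::primitive_desc sum_pd(md, scales, src_ds, eng);

  sum(sum_pd).execute(astream, {{MKLDNN_ARG_MULTIPLE_SRC + 0, src0},
                                {MKLDNN_ARG_MULTIPLE_SRC + 1, src1},
                                {MKLDNN_ARG_DST, dst}});
  astream.wait();
  return 0;
}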
...@@ -410,7 +409,7 @@ class ActivationMKLDNNHandler ...@@ -410,7 +409,7 @@ class ActivationMKLDNNHandler
: public MKLDNNHandlerT<T, mkldnn::eltwise_forward, : public MKLDNNHandlerT<T, mkldnn::eltwise_forward,
mkldnn::eltwise_backward> { mkldnn::eltwise_backward> {
public: public:
ActivationMKLDNNHandler(const std::vector<int>& dims, ActivationMKLDNNHandler(const std::vector<int64_t>& dims,
mkldnn::algorithm algorithm, float alpha, float beta, mkldnn::algorithm algorithm, float alpha, float beta,
const MKLDNNMemoryFormat fmt, bool is_test, const MKLDNNMemoryFormat fmt, bool is_test,
const platform::MKLDNNDeviceContext& dev_ctx, const platform::MKLDNNDeviceContext& dev_ctx,
...@@ -429,7 +428,7 @@ class ActivationMKLDNNHandler ...@@ -429,7 +428,7 @@ class ActivationMKLDNNHandler
algorithm, md, alpha, beta); algorithm, md, alpha, beta);
} }
ActivationMKLDNNHandler(const std::vector<int>& dims, ActivationMKLDNNHandler(const std::vector<int64_t>& dims,
mkldnn::algorithm algorithm, float alpha, float beta, mkldnn::algorithm algorithm, float alpha, float beta,
const MKLDNNMemoryFormat fmt, const MKLDNNMemoryFormat fmt,
const MKLDNNMemoryFormat diff_fmt, const MKLDNNMemoryFormat diff_fmt,
...@@ -453,7 +452,7 @@ class ActivationMKLDNNHandler ...@@ -453,7 +452,7 @@ class ActivationMKLDNNHandler
std::shared_ptr<mkldnn::memory> AcquireBackwardSrcMemory( std::shared_ptr<mkldnn::memory> AcquireBackwardSrcMemory(
const framework::Tensor* input) { const framework::Tensor* input) {
const T* input_data = input->data<T>(); const T* input_data = input->data<T>();
return this->AcquireMemoryFromPrimitive(this->bwd_pd_->src_primitive_desc(), return this->AcquireMemoryFromPrimitive(this->bwd_pd_->src_desc(),
to_void_cast<T>(input_data), to_void_cast<T>(input_data),
"@bwd-src_mem_p"); "@bwd-src_mem_p");
} }
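The activation handler forwards its arguments to AcquireForwardPrimitiveDescriptor, which in MKL-DNN 1.0 amounts to building an eltwise_forward::desc from the algorithm, the source descriptor, and alpha/beta, then a primitive descriptor bound to an engine. A minimal ReLU forward sketch under those assumptions (shape and constants are illustrative):

#include <vector>
#include "mkldnn.hpp"

int main() {
  using namespace mkldnn;
  engine eng(engine::kind::cpu, 0);
  stream astream(eng);

  memory::dims dims = {1, 16, 8, 8};
  memory::desc md(dims, memory::data_type::f32, memory::format_tag::nchw);
  std::vector<float> data(md.get_size() / sizeof(float), -1.0f);
  memory src(md, eng, data.data());
  memory dst(md, eng);

  eltwise_forward::desc relu_d(prop_kind::forward_inference,
                               algorithm::eltwise_relu, md,
                               /*alpha=*/0.0f, /*beta=*/0.0f);
  eltwise_forward::primitive_desc relu_pd(relu_d, eng);

  eltwise_forward(relu_pd).execute(
      astream, {{MKLDNN_ARG_SRC, src}, {MKLDNN_ARG_DST, dst}});
  astream.wait();
  return 0;
}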
...@@ -463,8 +462,8 @@ template <typename T> ...@@ -463,8 +462,8 @@ template <typename T>
class LRNMKLDNNHandler class LRNMKLDNNHandler
: public MKLDNNHandlerT<T, mkldnn::lrn_forward, mkldnn::lrn_backward> { : public MKLDNNHandlerT<T, mkldnn::lrn_forward, mkldnn::lrn_backward> {
public: public:
LRNMKLDNNHandler(const std::vector<int>& dims, const int n, const float alpha, LRNMKLDNNHandler(const std::vector<int64_t>& dims, const int n,
const float beta, const float k, const float alpha, const float beta, const float k,
const MKLDNNMemoryFormat fmt, bool is_test, const MKLDNNMemoryFormat fmt, bool is_test,
const platform::MKLDNNDeviceContext& dev_ctx, const platform::MKLDNNDeviceContext& dev_ctx,
platform::Place cpu_place, const std::string& unique_name) platform::Place cpu_place, const std::string& unique_name)
...@@ -477,11 +476,11 @@ class LRNMKLDNNHandler ...@@ -477,11 +476,11 @@ class LRNMKLDNNHandler
this->AcquireForwardPrimitiveDescriptor( this->AcquireForwardPrimitiveDescriptor(
is_test ? mkldnn::prop_kind::forward_inference is_test ? mkldnn::prop_kind::forward_inference
: mkldnn::prop_kind::forward_training, : mkldnn::prop_kind::forward_training,
mkldnn::lrn_across_channels, src_md, n, alpha, beta, k); mkldnn::algorithm::lrn_across_channels, src_md, n, alpha, beta, k);
} }
LRNMKLDNNHandler(const std::vector<int>& dims, const int n, const float alpha, LRNMKLDNNHandler(const std::vector<int64_t>& dims, const int n,
const float beta, const float k, const float alpha, const float beta, const float k,
const MKLDNNMemoryFormat fmt, const MKLDNNMemoryFormat fmt,
const MKLDNNMemoryFormat diff_fmt, const MKLDNNMemoryFormat diff_fmt,
const platform::MKLDNNDeviceContext& dev_ctx, const platform::MKLDNNDeviceContext& dev_ctx,
...@@ -496,23 +495,24 @@ class LRNMKLDNNHandler ...@@ -496,23 +495,24 @@ class LRNMKLDNNHandler
mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), diff_fmt); mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), diff_fmt);
this->AcquireBackwardPrimitiveDescriptor( this->AcquireBackwardPrimitiveDescriptor(
mkldnn::lrn_across_channels, src_md, diff_md, n, alpha, beta, k); mkldnn::algorithm::lrn_across_channels, src_md, diff_md, n, alpha, beta,
k);
} }
std::shared_ptr<mkldnn::memory> AcquireWorkspaceMemory( std::shared_ptr<mkldnn::memory> AcquireWorkspaceMemory(
framework::Tensor* workspace) { framework::Tensor* workspace) {
T* ptr = workspace->mutable_data<T>( T* ptr = workspace->mutable_data<T>(
this->place_, this->fwd_pd_->workspace_primitive_desc().get_size()); this->place_, this->fwd_pd_->workspace_desc().get_size());
return this->AcquireMemoryFromPrimitive( return this->AcquireMemoryFromPrimitive(this->fwd_pd_->workspace_desc(),
this->fwd_pd_->workspace_primitive_desc(), ptr, "@wrk_mem_p"); ptr, "@wrk_mem_p");
} }
std::shared_ptr<mkldnn::memory> AcquireBackwardWorkspaceMemory( std::shared_ptr<mkldnn::memory> AcquireBackwardWorkspaceMemory(
const framework::Tensor* workspace) { const framework::Tensor* workspace) {
const T* workspace_data = workspace->data<T>(); const T* workspace_data = workspace->data<T>();
return this->AcquireMemoryFromPrimitive( return this->AcquireMemoryFromPrimitive(this->fwd_pd_->workspace_desc(),
this->fwd_pd_->workspace_primitive_desc(), to_void_cast<T>(workspace_data),
to_void_cast<T>(workspace_data), "@bwd-wrk_mem_p"); "@bwd-wrk_mem_p");
} }
}; };
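In training mode LRN produces a workspace that the backward pass reuses; in 1.0 its descriptor is queried through workspace_desc() on the forward primitive descriptor, which is what the two workspace helpers above rely on. A hedged sketch of querying and binding that workspace (local size and coefficients are arbitrary):

#include <vector>
#include "mkldnn.hpp"

int main() {
  using namespace mkldnn;
  engine eng(engine::kind::cpu, 0);
  stream astream(eng);

  memory::dims dims = {1, 16, 8, 8};
  memory::desc src_md(dims, memory::data_type::f32, memory::format_tag::nchw);
  std::vector<float> data(src_md.get_size() / sizeof(float), 1.0f);
  memory src(src_md, eng, data.data()), dst(src_md, eng);

  lrn_forward::desc lrn_d(prop_kind::forward_training,
                          algorithm::lrn_across_channels, src_md,
                          /*local_size=*/5, /*alpha=*/1e-4f, /*beta=*/0.75f,
                          /*k=*/1.0f);
  lrn_forward::primitive_desc lrn_pd(lrn_d, eng);

  memory workspace(lrn_pd.workspace_desc(), eng);  // queried, not hard-coded
  lrn_forward(lrn_pd).execute(astream, {{MKLDNN_ARG_SRC, src},
                                        {MKLDNN_ARG_DST, dst},
                                        {MKLDNN_ARG_WORKSPACE, workspace}});
  astream.wait();
  return 0;
}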
...@@ -521,11 +521,11 @@ class PoolingMKLDNNHandler : public MKLDNNHandlerT<T, mkldnn::pooling_forward, ...@@ -521,11 +521,11 @@ class PoolingMKLDNNHandler : public MKLDNNHandlerT<T, mkldnn::pooling_forward,
mkldnn::pooling_backward> { mkldnn::pooling_backward> {
public: public:
PoolingMKLDNNHandler( PoolingMKLDNNHandler(
const std::vector<int>& src_dims, const std::vector<int>& dst_dims, const std::vector<int64_t>& src_dims,
const std::vector<int>& ksize, const std::vector<int>& strides, const std::vector<int64_t>& dst_dims, const std::vector<int64_t>& ksize,
const std::vector<int>& paddings, const std::string& pooling_type, const std::vector<int64_t>& strides, const std::vector<int64_t>& paddings,
bool ceil_mode, const MKLDNNMemoryFormat fmt, const std::string& pooling_type, bool ceil_mode,
mkldnn::memory::data_type dt, bool is_test, const MKLDNNMemoryFormat fmt, mkldnn::memory::data_type dt, bool is_test,
const platform::MKLDNNDeviceContext& dev_ctx, platform::Place cpu_place, const platform::MKLDNNDeviceContext& dev_ctx, platform::Place cpu_place,
const std::string& unique_name, bool exclude_padding) const std::string& unique_name, bool exclude_padding)
: platform::MKLDNNHandlerT<T, mkldnn::pooling_forward, : platform::MKLDNNHandlerT<T, mkldnn::pooling_forward,
...@@ -554,17 +554,16 @@ class PoolingMKLDNNHandler : public MKLDNNHandlerT<T, mkldnn::pooling_forward, ...@@ -554,17 +554,16 @@ class PoolingMKLDNNHandler : public MKLDNNHandlerT<T, mkldnn::pooling_forward,
: (exclude_padding : (exclude_padding
? mkldnn::algorithm::pooling_avg_exclude_padding ? mkldnn::algorithm::pooling_avg_exclude_padding
: mkldnn::algorithm::pooling_avg_include_padding), : mkldnn::algorithm::pooling_avg_include_padding),
src_md, dst_md, strides, ksize, mkldnn_paddings[0], mkldnn_paddings[1], src_md, dst_md, strides, ksize, mkldnn_paddings[0], mkldnn_paddings[1]);
mkldnn::padding_kind::zero);
} }
PoolingMKLDNNHandler( PoolingMKLDNNHandler(
const std::vector<int>& diff_dst_dims, const std::vector<int64_t>& diff_dst_dims,
const std::vector<int>& diff_src_dims, const std::vector<int>& ksize, const std::vector<int64_t>& diff_src_dims,
const std::vector<int>& strides, const std::vector<int>& paddings, const std::vector<int64_t>& ksize, const std::vector<int64_t>& strides,
const std::string& pooling_type, bool ceil_mode, const std::vector<int64_t>& paddings, const std::string& pooling_type,
const MKLDNNMemoryFormat fmt, const MKLDNNMemoryFormat diff_dst_fmt, bool ceil_mode, const MKLDNNMemoryFormat fmt,
mkldnn::memory::data_type dt, const MKLDNNMemoryFormat diff_dst_fmt, mkldnn::memory::data_type dt,
const platform::MKLDNNDeviceContext& dev_ctx, platform::Place cpu_place, const platform::MKLDNNDeviceContext& dev_ctx, platform::Place cpu_place,
const std::string& unique_name, bool exclude_padding) const std::string& unique_name, bool exclude_padding)
: platform::MKLDNNHandlerT<T, mkldnn::pooling_forward, : platform::MKLDNNHandlerT<T, mkldnn::pooling_forward,
...@@ -586,12 +585,11 @@ class PoolingMKLDNNHandler : public MKLDNNHandlerT<T, mkldnn::pooling_forward, ...@@ -586,12 +585,11 @@ class PoolingMKLDNNHandler : public MKLDNNHandlerT<T, mkldnn::pooling_forward,
? mkldnn::algorithm::pooling_avg_exclude_padding ? mkldnn::algorithm::pooling_avg_exclude_padding
: mkldnn::algorithm::pooling_avg_include_padding), : mkldnn::algorithm::pooling_avg_include_padding),
diff_src_md, diff_dst_md, strides, ksize, mkldnn_paddings[0], diff_src_md, diff_dst_md, strides, ksize, mkldnn_paddings[0],
mkldnn_paddings[1], mkldnn::padding_kind::zero); mkldnn_paddings[1]);
} }
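Both pooling constructors drop the trailing mkldnn::padding_kind::zero argument: in MKL-DNN 1.0 zero padding is implied and the descriptor takes only the left and right padding dims. A small max-pooling forward-descriptor sketch (dims are chosen purely for illustration):

#include "mkldnn.hpp"

int main() {
  using namespace mkldnn;
  engine eng(engine::kind::cpu, 0);

  memory::dims src_dims = {1, 8, 8, 8}, dst_dims = {1, 8, 4, 4};
  memory::dims kernel = {2, 2}, strides = {2, 2};
  memory::dims pad_l = {0, 0}, pad_r = {0, 0};

  memory::desc src_md(src_dims, memory::data_type::f32, memory::format_tag::nchw);
  memory::desc dst_md(dst_dims, memory::data_type::f32, memory::format_tag::nchw);

  // v1.0: no padding_kind argument; zero padding is the only mode.
  pooling_forward::desc pool_d(prop_kind::forward_inference,
                               algorithm::pooling_max, src_md, dst_md,
                               strides, kernel, pad_l, pad_r);
  pooling_forward::primitive_desc pool_pd(pool_d, eng);
  (void)pool_pd;
  return 0;
}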
std::shared_ptr<mkldnn::memory> AcquireWorkspaceMemory(void) { std::shared_ptr<mkldnn::memory> AcquireWorkspaceMemory(void) {
mkldnn::memory::primitive_desc workspace_mpd = mkldnn::memory::desc workspace_md = this->fwd_pd_->workspace_desc();
this->fwd_pd_->workspace_primitive_desc();
// Pooling PD has to be passed to Grad op that // may be executed by different thread, hence
// may be executed by different thread, hence // for that one we use key that does not contain TID
// for that one we use key that does not contain TID // for that one we use key that does not contain TID
...@@ -605,7 +603,7 @@ class PoolingMKLDNNHandler : public MKLDNNHandlerT<T, mkldnn::pooling_forward, ...@@ -605,7 +603,7 @@ class PoolingMKLDNNHandler : public MKLDNNHandlerT<T, mkldnn::pooling_forward,
mem_p = std::static_pointer_cast<mkldnn::memory>( mem_p = std::static_pointer_cast<mkldnn::memory>(
this->dev_ctx_.GetBlob(local_key)); this->dev_ctx_.GetBlob(local_key));
if (mem_p == nullptr) { if (mem_p == nullptr) {
mem_p = std::make_shared<mkldnn::memory>(workspace_mpd); mem_p = std::make_shared<mkldnn::memory>(workspace_md, this->engine_);
this->dev_ctx_.SetBlob(local_key, mem_p); this->dev_ctx_.SetBlob(local_key, mem_p);
} }
} }
...@@ -619,10 +617,10 @@ class PoolingMKLDNNHandler : public MKLDNNHandlerT<T, mkldnn::pooling_forward, ...@@ -619,10 +617,10 @@ class PoolingMKLDNNHandler : public MKLDNNHandlerT<T, mkldnn::pooling_forward,
} }
static inline void CorrectOutputSize( static inline void CorrectOutputSize(
const std::vector<int>& src_tz, const std::vector<int>& dst_tz, const std::vector<int64_t>& src_tz, const std::vector<int64_t>& dst_tz,
const std::vector<int>& kernel_size, const std::vector<int>& paddings, const std::vector<int64_t>& kernel_size,
const std::vector<int>& strides, const std::vector<int64_t>& paddings, const std::vector<int64_t>& strides,
std::vector<int>& right_bot_padding) { // NOLINT std::vector<int64_t>& right_bot_padding) { // NOLINT
for (size_t i = 0; i < right_bot_padding.size(); i++) { for (size_t i = 0; i < right_bot_padding.size(); i++) {
int desired_size = ComputeCeiledOutput(src_tz[i + 2], kernel_size[i], int desired_size = ComputeCeiledOutput(src_tz[i + 2], kernel_size[i],
paddings[i], strides[i]); paddings[i], strides[i]);
...@@ -636,8 +634,8 @@ class PoolingMKLDNNHandler : public MKLDNNHandlerT<T, mkldnn::pooling_forward, ...@@ -636,8 +634,8 @@ class PoolingMKLDNNHandler : public MKLDNNHandlerT<T, mkldnn::pooling_forward,
template <typename T> template <typename T>
class TransposeMKLDNNHandler : public MKLDNNHandler { class TransposeMKLDNNHandler : public MKLDNNHandler {
public: public:
TransposeMKLDNNHandler(std::vector<int>& dims, // NOLINT TransposeMKLDNNHandler(std::vector<int64_t>& dims, // NOLINT
std::vector<int>& axis, // NOLINT std::vector<int>& axis, // NOLINT
const platform::MKLDNNDeviceContext& dev_ctx, const platform::MKLDNNDeviceContext& dev_ctx,
mkldnn::engine engine, const std::string& base_key) mkldnn::engine engine, const std::string& base_key)
: platform::MKLDNNHandler(dev_ctx, engine, base_key), : platform::MKLDNNHandler(dev_ctx, engine, base_key),
...@@ -657,12 +655,11 @@ class TransposeMKLDNNHandler : public MKLDNNHandler { ...@@ -657,12 +655,11 @@ class TransposeMKLDNNHandler : public MKLDNNHandler {
logical_axis_[i] = i; logical_axis_[i] = i;
} }
auto src_md = fmt != mkldnn::memory::format::nchw auto src_md = fmt != MKLDNNMemoryFormat::nchw
? platform::MKLDNNMemDesc( ? platform::MKLDNNMemDesc(
dims_, platform::MKLDNNGetDataType<T>(), fmt) dims_, platform::MKLDNNGetDataType<T>(), fmt)
: Axis2MemoryDesc(dims_, logical_axis_); : Axis2MemoryDesc(dims_, logical_axis_);
mem_p = std::make_shared<mkldnn::memory>( mem_p = std::make_shared<mkldnn::memory>(src_md, engine_, ptr);
mkldnn::memory::primitive_desc{src_md, engine_}, ptr);
dev_ctx_.SetBlob(local_key, mem_p); dev_ctx_.SetBlob(local_key, mem_p);
} else { } else {
mem_p->set_data_handle(ptr); mem_p->set_data_handle(ptr);
...@@ -676,12 +673,11 @@ class TransposeMKLDNNHandler : public MKLDNNHandler { ...@@ -676,12 +673,11 @@ class TransposeMKLDNNHandler : public MKLDNNHandler {
auto mem_p = auto mem_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key)); std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
if (mem_p == nullptr) { if (mem_p == nullptr) {
auto dst_mdp = mkldnn::memory::primitive_desc{ auto dst_md = Axis2MemoryDesc(dims_, axis_);
Axis2MemoryDesc(dims_, axis_), engine_};
auto dst_data = output->mutable_data<T>(place, dst_mdp.get_size()); auto dst_data = output->mutable_data<T>(place, dst_md.get_size());
mem_p = std::make_shared<mkldnn::memory>(dst_mdp, dst_data); mem_p = std::make_shared<mkldnn::memory>(dst_md, engine_, dst_data);
dev_ctx_.SetBlob(local_key, mem_p); dev_ctx_.SetBlob(local_key, mem_p);
} else { } else {
auto dst_data = output->mutable_data<T>(place); auto dst_data = output->mutable_data<T>(place);
...@@ -705,49 +701,32 @@ class TransposeMKLDNNHandler : public MKLDNNHandler { ...@@ -705,49 +701,32 @@ class TransposeMKLDNNHandler : public MKLDNNHandler {
} }
protected: protected:
mkldnn_memory_desc_t Axis2MemoryDesc( mkldnn::memory::desc Axis2MemoryDesc(std::vector<int64_t>& nchw_tz, // NOLINT
const std::vector<int>& nchw_tz, // NOLINT std::vector<int>& axis // NOLINT
const std::vector<int>& axis) { ) {
mkldnn_memory_desc_t mem_fmt; size_t ndims = axis.size();
mem_fmt.primitive_kind = mkldnn_memory;
mem_fmt.ndims = axis.size();
for (unsigned int i = 0; i < nchw_tz.size(); ++i) {
mem_fmt.dims[i] = nchw_tz[i]; // logical dimensions (nchw format,
// regardless physical layout)
}
if (platform::MKLDNNGetDataType<T>() == mkldnn::memory::data_type::s8)
mem_fmt.data_type = mkldnn_s8;
else if (platform::MKLDNNGetDataType<T>() == mkldnn::memory::data_type::u8)
mem_fmt.data_type = mkldnn_u8;
else
mem_fmt.data_type = mkldnn_f32;
mem_fmt.format = mkldnn_blocked;
std::vector<int64_t> strides(ndims);
unsigned int total_stride = 1; unsigned int total_stride = 1;
for (int i = nchw_tz.size() - 1; i >= 0; --i) { for (int i = ndims - 1; i >= 0; --i) {
mem_fmt.layout_desc.blocking.padding_dims[i] = strides[axis[i]] = total_stride;
nchw_tz[i]; // logical dimensions (nchw format, regardless physical
// layout)
mem_fmt.layout_desc.blocking.block_dims[i] = 1;
mem_fmt.layout_desc.blocking.offset_padding_to_data[i] = 0; // no offset
mem_fmt.layout_desc.blocking.strides[0][axis[i]] = total_stride;
mem_fmt.layout_desc.blocking.strides[1][axis[i]] = 1;
total_stride *= nchw_tz[axis[i]]; total_stride *= nchw_tz[axis[i]];
} }
mem_fmt.layout_desc.blocking.offset_padding = 0; // no initial offset mkldnn::memory::desc mem_d(nchw_tz, platform::MKLDNNGetDataType<T>(),
return mem_fmt; strides);
return mem_d;
} }
private: private:
std::vector<int> dims_; std::vector<int64_t> dims_;
std::vector<int> axis_; std::vector<int> axis_;
std::vector<int> logical_axis_; std::vector<int> logical_axis_;
}; };
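Axis2MemoryDesc replaces the hand-filled mkldnn_memory_desc_t with a memory::desc built from explicit strides: walking the axes from innermost to outermost, it assigns the running product of the permuted dimensions as the stride of each logical axis. For example, assuming logical NCHW dims {2, 3, 4, 5} and axis = {0, 2, 3, 1} (NCHW to NHWC), the loop yields strides {60, 1, 15, 3}, i.e. an NHWC physical layout over NCHW logical dims. A standalone sketch of the same computation (f32 is hard-coded where the handler uses MKLDNNGetDataType<T>()):

#include <cassert>
#include <vector>
#include "mkldnn.hpp"

// Build a blocked memory::desc whose physical layout is the permutation of the
// logical dims given by `axis` (the same idea as Axis2MemoryDesc above).
mkldnn::memory::desc PermutedDesc(const std::vector<int64_t>& dims,
                                  const std::vector<int>& axis) {
  std::vector<int64_t> strides(axis.size());
  int64_t total_stride = 1;
  for (int i = static_cast<int>(axis.size()) - 1; i >= 0; --i) {
    strides[axis[i]] = total_stride;
    total_stride *= dims[axis[i]];
  }
  return mkldnn::memory::desc(dims, mkldnn::memory::data_type::f32, strides);
}

int main() {
  auto md = PermutedDesc({2, 3, 4, 5}, {0, 2, 3, 1});  // NCHW -> NHWC
  assert(md.get_size() == 2 * 3 * 4 * 5 * sizeof(float));
  return 0;
}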
class ReorderMKLDNNHandler : public MKLDNNHandler { class ReorderMKLDNNHandler : public MKLDNNHandler {
public: public:
ReorderMKLDNNHandler(std::vector<int>& dims, // NOLINT ReorderMKLDNNHandler(std::vector<int64_t>& dims, // NOLINT
framework::proto::VarType::Type vtype, framework::proto::VarType::Type vtype,
mkldnn::memory::data_type dtype, mkldnn::memory::data_type dtype,
const platform::MKLDNNDeviceContext& dev_ctx, const platform::MKLDNNDeviceContext& dev_ctx,
...@@ -770,11 +749,10 @@ class ReorderMKLDNNHandler : public MKLDNNHandler { ...@@ -770,11 +749,10 @@ class ReorderMKLDNNHandler : public MKLDNNHandler {
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key)); std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
if (mem_p == nullptr) { if (mem_p == nullptr) {
auto dst_md = platform::MKLDNNMemDesc(dims_, dtype_, fmt); auto dst_md = platform::MKLDNNMemDesc(dims_, dtype_, fmt);
auto dst_mdp = mkldnn::memory::primitive_desc{dst_md, engine_};
auto dst_data = output->mutable_data(place, vtype_); auto dst_data = output->mutable_data(place, vtype_);
mem_p = std::make_shared<mkldnn::memory>(dst_mdp, dst_data); mem_p = std::make_shared<mkldnn::memory>(dst_md, engine_, dst_data);
dev_ctx_.SetBlob(local_key, mem_p); dev_ctx_.SetBlob(local_key, mem_p);
} else { } else {
auto dst_data = output->mutable_data(place, vtype_); auto dst_data = output->mutable_data(place, vtype_);
...@@ -798,7 +776,7 @@ class ReorderMKLDNNHandler : public MKLDNNHandler { ...@@ -798,7 +776,7 @@ class ReorderMKLDNNHandler : public MKLDNNHandler {
} }
private: private:
std::vector<int> dims_; std::vector<int64_t> dims_;
framework::proto::VarType::Type vtype_; framework::proto::VarType::Type vtype_;
mkldnn::memory::data_type dtype_; mkldnn::memory::data_type dtype_;
}; };
...@@ -850,28 +828,25 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler { ...@@ -850,28 +828,25 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
key_ += "-BWD"; key_ += "-BWD";
} }
size_t GetDstMemorySize() const { size_t GetDstMemorySize() const { return conv_pd_->dst_desc().get_size(); }
return conv_pd_->dst_primitive_desc().get_size();
}
MKLDNNMemoryFormat GetDstFormat() const { MKLDNNMemoryFormat GetDstFormat() const {
return static_cast<MKLDNNMemoryFormat>( return paddle::platform::GetMKLDNNFormat(conv_pd_->dst_desc());
conv_pd_->dst_primitive_desc().desc().data.format);
} }
size_t GetDiffWeightsMemorySize() const { size_t GetDiffWeightsMemorySize() const {
return conv_bwd_weights_pd_->diff_weights_primitive_desc().get_size(); return conv_bwd_weights_pd_->diff_weights_desc().get_size();
} }
size_t GetDiffSourceMemorySize() const { size_t GetDiffSourceMemorySize() const {
return conv_bwd_data_pd_->diff_src_primitive_desc().get_size(); return conv_bwd_data_pd_->diff_src_desc().get_size();
} }
std::shared_ptr<mkldnn::memory> AcquireSrcMemoryFromWeightsPrimitive( std::shared_ptr<mkldnn::memory> AcquireSrcMemoryFromWeightsPrimitive(
const std::shared_ptr<mkldnn::memory> user_memory_p, const std::shared_ptr<mkldnn::memory> user_memory_p,
std::vector<mkldnn::primitive>& pipeline) { // NOLINT std::vector<mkldnn::primitive>& pipeline) { // NOLINT
auto src_pd = conv_bwd_weights_pd_->src_primitive_desc(); auto src_pd = conv_bwd_weights_pd_->src_desc();
auto user_pd = user_memory_p->get_primitive_desc(); auto user_pd = user_memory_p->get_desc();
return this->AcquireMemory(src_pd, user_pd, user_memory_p, return this->AcquireMemory(src_pd, user_pd, user_memory_p,
"@weights-src_mem_p", pipeline); "@weights-src_mem_p", pipeline);
} }
...@@ -879,8 +854,8 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler { ...@@ -879,8 +854,8 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
std::shared_ptr<mkldnn::memory> AcquireDiffDstMemoryFromWeightsPrimitive( std::shared_ptr<mkldnn::memory> AcquireDiffDstMemoryFromWeightsPrimitive(
const std::shared_ptr<mkldnn::memory> user_memory_p, const std::shared_ptr<mkldnn::memory> user_memory_p,
std::vector<mkldnn::primitive>& pipeline) { // NOLINT std::vector<mkldnn::primitive>& pipeline) { // NOLINT
auto diff_dst_pd = conv_bwd_weights_pd_->diff_dst_primitive_desc(); auto diff_dst_pd = conv_bwd_weights_pd_->diff_dst_desc();
auto user_pd = user_memory_p->get_primitive_desc(); auto user_pd = user_memory_p->get_desc();
return this->AcquireMemory(diff_dst_pd, user_pd, user_memory_p, return this->AcquireMemory(diff_dst_pd, user_pd, user_memory_p,
"@weights-diff_dst_mem_p", pipeline); "@weights-diff_dst_mem_p", pipeline);
} }
...@@ -888,15 +863,14 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler { ...@@ -888,15 +863,14 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
std::shared_ptr<mkldnn::memory> AcquireDiffWeightsMemoryFromWeightsPrimitive( std::shared_ptr<mkldnn::memory> AcquireDiffWeightsMemoryFromWeightsPrimitive(
void* ptr) { void* ptr) {
return this->AcquireMemoryFromPrimitive( return this->AcquireMemoryFromPrimitive(
conv_bwd_weights_pd_->diff_weights_primitive_desc(), ptr, conv_bwd_weights_pd_->diff_weights_desc(), ptr, "@diff_weights_mem_p");
"@diff_weights_mem_p");
} }
std::shared_ptr<mkldnn::memory> AcquireDiffDstMemoryFromDataPrimitive( std::shared_ptr<mkldnn::memory> AcquireDiffDstMemoryFromDataPrimitive(
const std::shared_ptr<mkldnn::memory> user_memory_p, const std::shared_ptr<mkldnn::memory> user_memory_p,
std::vector<mkldnn::primitive>& pipeline) { // NOLINT std::vector<mkldnn::primitive>& pipeline) { // NOLINT
auto diff_dst_pd = conv_bwd_data_pd_->diff_dst_primitive_desc(); auto diff_dst_pd = conv_bwd_data_pd_->diff_dst_desc();
auto user_pd = user_memory_p->get_primitive_desc(); auto user_pd = user_memory_p->get_desc();
return this->AcquireMemory(diff_dst_pd, user_pd, user_memory_p, return this->AcquireMemory(diff_dst_pd, user_pd, user_memory_p,
"@data-diff_dst_mem_p", pipeline); "@data-diff_dst_mem_p", pipeline);
} }
...@@ -904,8 +878,8 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler { ...@@ -904,8 +878,8 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
std::shared_ptr<mkldnn::memory> AcquireWeightsMemoryFromDataPrimitive( std::shared_ptr<mkldnn::memory> AcquireWeightsMemoryFromDataPrimitive(
const std::shared_ptr<mkldnn::memory> user_weights_memory_p, const std::shared_ptr<mkldnn::memory> user_weights_memory_p,
std::vector<mkldnn::primitive>& pipeline) { // NOLINT std::vector<mkldnn::primitive>& pipeline) { // NOLINT
auto weights_pd = conv_bwd_data_pd_->weights_primitive_desc(); auto weights_pd = conv_bwd_data_pd_->weights_desc();
auto user_pd = user_weights_memory_p->get_primitive_desc(); auto user_pd = user_weights_memory_p->get_desc();
return this->AcquireMemory(weights_pd, user_pd, user_weights_memory_p, return this->AcquireMemory(weights_pd, user_pd, user_weights_memory_p,
"@data-weights_mem_p", pipeline); "@data-weights_mem_p", pipeline);
} }
...@@ -926,20 +900,20 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler { ...@@ -926,20 +900,20 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
std::shared_ptr<mkldnn::memory> AcquireDiffSrcMemoryFromDataPrimitive( std::shared_ptr<mkldnn::memory> AcquireDiffSrcMemoryFromDataPrimitive(
void* ptr) { void* ptr) {
return this->AcquireMemoryFromPrimitive( return this->AcquireMemoryFromPrimitive(conv_bwd_data_pd_->diff_src_desc(),
conv_bwd_data_pd_->diff_src_primitive_desc(), ptr, "@diff_src_mem_p"); ptr, "@diff_src_mem_p");
} }
std::shared_ptr<mkldnn::memory> AcquireDstMemoryFromPrimitive(void* ptr) { std::shared_ptr<mkldnn::memory> AcquireDstMemoryFromPrimitive(void* ptr) {
return this->AcquireMemoryFromPrimitive(conv_pd_->dst_primitive_desc(), ptr, return this->AcquireMemoryFromPrimitive(conv_pd_->dst_desc(), ptr,
"@dst_mem_p"); "@dst_mem_p");
} }
std::shared_ptr<mkldnn::memory> AcquireSrcMemoryFromPrimitive( std::shared_ptr<mkldnn::memory> AcquireSrcMemoryFromPrimitive(
const std::shared_ptr<mkldnn::memory> user_memory_p, const std::shared_ptr<mkldnn::memory> user_memory_p,
std::vector<mkldnn::primitive>& pipeline) { // NOLINT std::vector<mkldnn::primitive>& pipeline) { // NOLINT
auto src_pd = conv_pd_->src_primitive_desc(); auto src_pd = conv_pd_->src_desc();
auto user_pd = user_memory_p->get_primitive_desc(); auto user_pd = user_memory_p->get_desc();
return this->AcquireMemory(src_pd, user_pd, user_memory_p, "@src_mem_p", return this->AcquireMemory(src_pd, user_pd, user_memory_p, "@src_mem_p",
pipeline); pipeline);
} }
...@@ -960,8 +934,8 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler { ...@@ -960,8 +934,8 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
std::vector<mkldnn::primitive>& pipeline, // NOLINT std::vector<mkldnn::primitive>& pipeline, // NOLINT
bool is_persistent = false, bool is_INT8 = false, bool is_persistent = false, bool is_INT8 = false,
std::vector<float> scale_data = {1.0f}, int mask = 0) { std::vector<float> scale_data = {1.0f}, int mask = 0) {
auto user_weights_pd = user_weights_memory_p->get_primitive_desc(); auto user_weights_pd = user_weights_memory_p->get_desc();
auto weights_pd = conv_pd_->weights_primitive_desc(); auto weights_pd = conv_pd_->weights_desc();
return this->AcquireMemory( return this->AcquireMemory(
weights_pd, user_weights_pd, user_weights_memory_p, "@weights_mem_p", weights_pd, user_weights_pd, user_weights_memory_p, "@weights_mem_p",
pipeline, is_persistent, is_INT8, scale_data, mask); pipeline, is_persistent, is_INT8, scale_data, mask);
...@@ -973,8 +947,8 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler { ...@@ -973,8 +947,8 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
bool is_persistent = false, bool is_INT8 = false, bool is_persistent = false, bool is_INT8 = false,
std::vector<float> scale_data = {1.0f}, std::vector<float> scale_data = {1.0f},
int mask = 0) { // NOLINT int mask = 0) { // NOLINT
auto user_bias_pd = user_bias_memory_p->get_primitive_desc(); auto user_bias_pd = user_bias_memory_p->get_desc();
auto bias_pd = conv_pd_->bias_primitive_desc(); auto bias_pd = conv_pd_->bias_desc();
return this->AcquireMemory(bias_pd, user_bias_pd, user_bias_memory_p, return this->AcquireMemory(bias_pd, user_bias_pd, user_bias_memory_p,
"@bias_mem_p", pipeline, is_persistent, is_INT8, "@bias_mem_p", pipeline, is_persistent, is_INT8,
scale_data, mask); scale_data, mask);
...@@ -1020,8 +994,8 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler { ...@@ -1020,8 +994,8 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
AcquireConvolutionPrimitiveDescriptor( AcquireConvolutionPrimitiveDescriptor(
const mkldnn::memory::desc& src, const mkldnn::memory::desc& weights, const mkldnn::memory::desc& src, const mkldnn::memory::desc& weights,
boost::optional<const mkldnn::memory::desc&> bias, boost::optional<const mkldnn::memory::desc&> bias,
const mkldnn::memory::desc& dst, const std::vector<int>& strides, const mkldnn::memory::desc& dst, const std::vector<int64_t>& strides,
const std::vector<int>& paddings, const mkldnn::engine& engine, const std::vector<int64_t>& paddings, const mkldnn::engine& engine,
const std::string& fuse_activation, float fuse_alpha, float fuse_beta, const std::string& fuse_activation, float fuse_alpha, float fuse_beta,
const bool fuse_residual_conn, mkldnn::prop_kind fwd_prop_kind, const bool fuse_residual_conn, mkldnn::prop_kind fwd_prop_kind,
const std::vector<float> output_shift_scale = {}, const std::vector<float> output_shift_scale = {},
...@@ -1047,15 +1021,14 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler { ...@@ -1047,15 +1021,14 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
auto mkldnn_paddings = ToMkldnnPadding(paddings); auto mkldnn_paddings = ToMkldnnPadding(paddings);
auto conv_desc = auto conv_desc =
bias bias ? typename forward_t::desc(
? typename forward_t::desc( fwd_prop_kind, convolutional_algorithm<forward_t>::T,
fwd_prop_kind, convolutional_algorithm<forward_t>::T, src, src, weights, *bias, dst, stride_dims,
weights, *bias, dst, stride_dims, mkldnn_paddings[0], mkldnn_paddings[0], mkldnn_paddings[1])
mkldnn_paddings[1], mkldnn::padding_kind::zero) : typename forward_t::desc(
: typename forward_t::desc( fwd_prop_kind, convolutional_algorithm<forward_t>::T,
fwd_prop_kind, convolutional_algorithm<forward_t>::T, src, src, weights, dst, stride_dims, mkldnn_paddings[0],
weights, dst, stride_dims, mkldnn_paddings[0], mkldnn_paddings[1]);
mkldnn_paddings[1], mkldnn::padding_kind::zero);
mkldnn::primitive_attr conv_attr = mkldnn::primitive_attr conv_attr =
CreatePostOps(fuse_activation, fuse_alpha, fuse_beta, CreatePostOps(fuse_activation, fuse_alpha, fuse_beta,
...@@ -1071,68 +1044,37 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler { ...@@ -1071,68 +1044,37 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
return conv_pd_; return conv_pd_;
} }
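The convolution descriptor likewise loses the padding_kind::zero argument, and fusion is expressed through the primitive_attr returned by CreatePostOps, whose body is not part of this hunk. As a generic illustration only, not necessarily what Paddle's helper does, a residual-sum plus ReLU fusion would be attached like this in MKL-DNN 1.0:

#include "mkldnn.hpp"

// Illustrative only: build a primitive_attr that accumulates into the output
// (residual connection) and then applies ReLU, the two post-ops referenced by
// fuse_residual_conn / fuse_activation above.
mkldnn::primitive_attr MakeConvAttr(bool fuse_residual_conn, bool fuse_relu) {
  mkldnn::post_ops ops;
  if (fuse_residual_conn) ops.append_sum(1.0f);
  if (fuse_relu)
    ops.append_eltwise(1.0f, mkldnn::algorithm::eltwise_relu, 0.0f, 0.0f);
  mkldnn::primitive_attr attr;
  attr.set_post_ops(ops);
  return attr;
}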
std::shared_ptr<forward_t> AcquireConvolution( std::shared_ptr<forward_t> AcquireConvolution() {
std::shared_ptr<mkldnn::memory> src_memory_p,
std::shared_ptr<mkldnn::memory> weights_memory_p,
std::shared_ptr<mkldnn::memory> dst_memory_p) {
auto prim_key = key_ + "@conv_p";
auto conv_p =
std::static_pointer_cast<forward_t>(dev_ctx_.GetBlob(prim_key));
if (conv_p == nullptr) {
conv_p = std::make_shared<forward_t>(*conv_pd_, *src_memory_p,
*weights_memory_p, *dst_memory_p);
dev_ctx_.SetBlob(prim_key, conv_p);
}
return conv_p;
}
std::shared_ptr<forward_t> AcquireConvolution(
std::shared_ptr<mkldnn::memory> src_memory_p,
std::shared_ptr<mkldnn::memory> weights_memory_p,
std::shared_ptr<mkldnn::memory> bias_memory_p,
std::shared_ptr<mkldnn::memory> dst_memory_p) {
auto prim_key = key_ + "@conv_p"; auto prim_key = key_ + "@conv_p";
auto conv_p = auto conv_p =
std::static_pointer_cast<forward_t>(dev_ctx_.GetBlob(prim_key)); std::static_pointer_cast<forward_t>(dev_ctx_.GetBlob(prim_key));
if (conv_p == nullptr) { if (conv_p == nullptr) {
conv_p = std::make_shared<forward_t>(*conv_pd_, *src_memory_p, conv_p = std::make_shared<forward_t>(*conv_pd_);
*weights_memory_p, *bias_memory_p,
*dst_memory_p);
dev_ctx_.SetBlob(prim_key, conv_p); dev_ctx_.SetBlob(prim_key, conv_p);
} }
return conv_p; return conv_p;
} }
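Because AcquireConvolution() now returns a primitive created from the descriptor alone, the caller binds memories at execution time, and the biased and bias-free overloads collapse into one function that differs only in the argument map. A sketch of how such a primitive would be run; RunConv and have-bias handling are hypothetical helpers added here for illustration, the MKLDNN_ARG_* keys are the library's:

#include <memory>
#include <unordered_map>
#include "mkldnn.hpp"

// Assuming conv_p, src_mem, weights_mem, bias_mem and dst_mem were acquired
// through the handler above; bias_mem may be null when the op has no bias.
void RunConv(const std::shared_ptr<mkldnn::convolution_forward>& conv_p,
             mkldnn::memory& src_mem, mkldnn::memory& weights_mem,
             mkldnn::memory* bias_mem, mkldnn::memory& dst_mem,
             const mkldnn::engine& eng) {
  mkldnn::stream astream(eng);
  std::unordered_map<int, mkldnn::memory> args = {
      {MKLDNN_ARG_SRC, src_mem},
      {MKLDNN_ARG_WEIGHTS, weights_mem},
      {MKLDNN_ARG_DST, dst_mem}};
  if (bias_mem != nullptr) args.insert({MKLDNN_ARG_BIAS, *bias_mem});
  conv_p->execute(astream, args);
  astream.wait();
}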
std::shared_ptr<backward_weights_t> AcquireConvolutionBackwardWeights( std::shared_ptr<backward_weights_t> AcquireConvolutionBackwardWeights() {
std::shared_ptr<mkldnn::memory> src_memory_p,
std::shared_ptr<mkldnn::memory> diff_dst_memory_p,
std::shared_ptr<mkldnn::memory> diff_weights_memory_p) {
auto prim_key = key_ + "@conv_bwd_weights_p"; auto prim_key = key_ + "@conv_bwd_weights_p";
auto conv_bwd_weights_p = std::static_pointer_cast<backward_weights_t>( auto conv_bwd_weights_p = std::static_pointer_cast<backward_weights_t>(
dev_ctx_.GetBlob(prim_key)); dev_ctx_.GetBlob(prim_key));
if (conv_bwd_weights_p == nullptr) { if (conv_bwd_weights_p == nullptr) {
// create backward conv primitive for weights // create backward conv primitive for weights
conv_bwd_weights_p = std::make_shared<backward_weights_t>( conv_bwd_weights_p =
*conv_bwd_weights_pd_, *src_memory_p, *diff_dst_memory_p, std::make_shared<backward_weights_t>(*conv_bwd_weights_pd_);
*diff_weights_memory_p);
dev_ctx_.SetBlob(prim_key, conv_bwd_weights_p); dev_ctx_.SetBlob(prim_key, conv_bwd_weights_p);
} }
return conv_bwd_weights_p; return conv_bwd_weights_p;
} }
std::shared_ptr<backward_data_t> AcquireConvolutionBackwardData( std::shared_ptr<backward_data_t> AcquireConvolutionBackwardData() {
std::shared_ptr<mkldnn::memory> diff_dst_memory_p,
std::shared_ptr<mkldnn::memory> weights_memory_p,
std::shared_ptr<mkldnn::memory> diff_src_memory_p) {
auto prim_key = key_ + "@conv_bwd_data_p"; auto prim_key = key_ + "@conv_bwd_data_p";
auto conv_bwd_data_p = auto conv_bwd_data_p =
std::static_pointer_cast<backward_data_t>(dev_ctx_.GetBlob(prim_key)); std::static_pointer_cast<backward_data_t>(dev_ctx_.GetBlob(prim_key));
if (conv_bwd_data_p == nullptr) { if (conv_bwd_data_p == nullptr) {
conv_bwd_data_p = std::make_shared<backward_data_t>( conv_bwd_data_p = std::make_shared<backward_data_t>(*conv_bwd_data_pd_);
*conv_bwd_data_pd_, *diff_dst_memory_p, *weights_memory_p,
*diff_src_memory_p);
dev_ctx_.SetBlob(prim_key, conv_bwd_data_p); dev_ctx_.SetBlob(prim_key, conv_bwd_data_p);
} }
return conv_bwd_data_p; return conv_bwd_data_p;
...@@ -1199,9 +1141,9 @@ static void SetDstMemoryHandler( ...@@ -1199,9 +1141,9 @@ static void SetDstMemoryHandler(
template <typename T> template <typename T>
static void SetDstMemoryQuantized( static void SetDstMemoryQuantized(
const framework::ExecutionContext& ctx, framework::Tensor* output, const framework::ExecutionContext& ctx, framework::Tensor* output,
std::vector<int> dst_tz, const mkldnn::engine& engine, std::vector<int64_t> dst_tz, const mkldnn::engine& engine,
std::shared_ptr<mkldnn::memory::primitive_desc>& dst_pd, // NOLINT std::shared_ptr<mkldnn::memory::desc>& dst_md, // NOLINT
std::shared_ptr<mkldnn::memory>& dst_memory) { // NOLINT std::shared_ptr<mkldnn::memory>& dst_memory) { // NOLINT
T* output_data = output->mutable_data<T>(ctx.GetPlace()); T* output_data = output->mutable_data<T>(ctx.GetPlace());
const size_t dst_dims = dst_tz.size(); const size_t dst_dims = dst_tz.size();
MKLDNNMemoryFormat dst_fmt; MKLDNNMemoryFormat dst_fmt;
...@@ -1209,12 +1151,13 @@ static void SetDstMemoryQuantized( ...@@ -1209,12 +1151,13 @@ static void SetDstMemoryQuantized(
"Dst memory for quantization can not have dims > 5"); "Dst memory for quantization can not have dims > 5");
dst_fmt = platform::MKLDNNFormatForSize(dst_dims, MKLDNNMemoryFormat::nhwc); dst_fmt = platform::MKLDNNFormatForSize(dst_dims, MKLDNNMemoryFormat::nhwc);
auto dst_md = platform::MKLDNNMemDesc( auto tmp_dst_md = platform::MKLDNNMemDesc(
{dst_tz}, paddle::framework::ToMKLDNNDataType( {dst_tz}, paddle::framework::ToMKLDNNDataType(
framework::DataTypeTrait<T>::DataType()), framework::DataTypeTrait<T>::DataType()),
dst_fmt); dst_fmt);
dst_pd.reset(new mkldnn::memory::primitive_desc(dst_md, engine)); dst_md.reset(new mkldnn::memory::desc(tmp_dst_md));
dst_memory.reset(new mkldnn::memory(*dst_pd, to_void_cast<T>(output_data))); dst_memory.reset(
new mkldnn::memory(*dst_md, engine, to_void_cast<T>(output_data)));
} }
} // namespace platform } // namespace platform
......
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once #pragma once
#include <sstream> #include <sstream>
#include <string> #include <string>
#include <type_traits>
#include <typeindex> #include <typeindex>
namespace paddle { namespace paddle {
...@@ -24,13 +25,20 @@ inline std::ostream& operator<<(std::ostream& s, const std::type_index& t) { ...@@ -24,13 +25,20 @@ inline std::ostream& operator<<(std::ostream& s, const std::type_index& t) {
return s; return s;
} }
template <typename T> template <typename T,
typename std::enable_if<!std::is_enum<T>::value, int>::type = 0>
inline std::string to_string(T v) { inline std::string to_string(T v) {
std::ostringstream sout; std::ostringstream sout;
sout << v; sout << v;
return sout.str(); return sout.str();
} }
template <typename T,
typename std::enable_if<std::is_enum<T>::value, int>::type = 0>
inline std::string to_string(T v) {
return std::to_string(static_cast<int>(v));
}
template <> template <>
inline std::string to_string(std::type_index t) { inline std::string to_string(std::type_index t) {
return t.name(); return t.name();
......
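The enable_if split above exists presumably because MKL-DNN 1.0 turns formats and data types into scoped enums, which have no operator<<; enum arguments are therefore routed to a static_cast<int> overload while everything else keeps the ostream path. A minimal sketch of the same dispatch with a stand-in enum (Color is purely illustrative):

#include <iostream>
#include <sstream>
#include <string>
#include <type_traits>

enum class Color { kRed = 1, kGreen = 2 };  // stand-in for a scoped enum

template <typename T,
          typename std::enable_if<!std::is_enum<T>::value, int>::type = 0>
std::string to_string(T v) {
  std::ostringstream sout;
  sout << v;  // relies on operator<< being defined for T
  return sout.str();
}

template <typename T,
          typename std::enable_if<std::is_enum<T>::value, int>::type = 0>
std::string to_string(T v) {
  return std::to_string(static_cast<int>(v));  // enums: print underlying value
}

int main() {
  std::cout << to_string(3.14) << " " << to_string(Color::kGreen) << "\n";
  return 0;
}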
...@@ -203,14 +203,16 @@ if '${WITH_MKLDNN}' == 'ON': ...@@ -203,14 +203,16 @@ if '${WITH_MKLDNN}' == 'ON':
# TODO(typhoonzero): use install_name_tool to patch mkl libs once # TODO(typhoonzero): use install_name_tool to patch mkl libs once
# we can support mkl on mac. # we can support mkl on mac.
# #
# change rpath of libmkldnn.so.0, add $ORIGIN/ to it. # change rpath of libmkldnn.so.1, add $ORIGIN/ to it.
# The reason is that all thirdparty libraries in the same directory, # The reason is that all thirdparty libraries in the same directory,
# thus, libmkldnn.so.0 will find libmklml_intel.so and libiomp5.so. # thus, libmkldnn.so.1 will find libmklml_intel.so and libiomp5.so.
command = "patchelf --set-rpath '$ORIGIN/' ${MKLDNN_SHARED_LIB}" command = "patchelf --set-rpath '$ORIGIN/' ${MKLDNN_SHARED_LIB}"
if os.system(command) != 0: if os.system(command) != 0:
raise Exception("patch libmkldnn.so failed, command: %s" % command) raise Exception("patch libmkldnn.so failed, command: %s" % command)
package_data['paddle.libs']+=['libmkldnn.so.0' if os.name != 'nt' else ('mkldnn' + ext_name)] package_data['paddle.libs']+=['libmkldnn.so.0','libmkldnn.so.1' if os.name != 'nt' else ('mkldnn' + ext_name)]
shutil.copy('${MKLDNN_SHARED_LIB}', libs_path) shutil.copy('${MKLDNN_SHARED_LIB}', libs_path)
if os.name != 'nt':
shutil.copy('${MKLDNN_SHARED_LIB_1}', libs_path)
if '${WITH_NGRAPH}' == 'ON': if '${WITH_NGRAPH}' == 'ON':
# only change rpath in Release mode, # only change rpath in Release mode,
# since in Debug mode, nGraph lib may be too large to be changed? # since in Debug mode, nGraph lib may be too large to be changed?
......