提交 92bf6782 编写于 作者: W Wojciech Uss

use oneDNN GRU INT8 optimized version

use faster data format
上级 1cab4dcd
...@@ -20,7 +20,7 @@ SET(MKLDNN_SOURCE_DIR ${THIRD_PARTY_PATH}/mkldnn/src/extern_mkldnn) ...@@ -20,7 +20,7 @@ SET(MKLDNN_SOURCE_DIR ${THIRD_PARTY_PATH}/mkldnn/src/extern_mkldnn)
SET(MKLDNN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/mkldnn) SET(MKLDNN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/mkldnn)
SET(MKLDNN_INC_DIR "${MKLDNN_INSTALL_DIR}/include" CACHE PATH "mkldnn include directory." FORCE) SET(MKLDNN_INC_DIR "${MKLDNN_INSTALL_DIR}/include" CACHE PATH "mkldnn include directory." FORCE)
SET(MKLDNN_REPOSITORY https://github.com/oneapi-src/oneDNN.git) SET(MKLDNN_REPOSITORY https://github.com/oneapi-src/oneDNN.git)
SET(MKLDNN_TAG 64a48f9565aa72f6359917b3406328075a409939) SET(MKLDNN_TAG 361725600224f41b7347a1c6bee9b04d1e6c14d7)
# Introduce variables: # Introduce variables:
# * CMAKE_INSTALL_LIBDIR # * CMAKE_INSTALL_LIBDIR
......
...@@ -95,7 +95,7 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::gru_forward> { ...@@ -95,7 +95,7 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::gru_forward> {
// Create memory descriptors // Create memory descriptors
auto input_md = MKLDNNMemDesc({Ti, N, IC}, MKLDNNGetDataType<T>(), auto input_md = MKLDNNMemDesc({Ti, N, IC}, MKLDNNGetDataType<T>(),
MKLDNNMemoryFormat::any); MKLDNNMemoryFormat::ntc);
auto weight_x_md = auto weight_x_md =
MKLDNNMemDesc({L, D, IC, G, OC}, weights_dt, MKLDNNMemoryFormat::any); MKLDNNMemDesc({L, D, IC, G, OC}, weights_dt, MKLDNNMemoryFormat::any);
auto weight_h_md = auto weight_h_md =
...@@ -103,7 +103,7 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::gru_forward> { ...@@ -103,7 +103,7 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::gru_forward> {
auto bias_md = MKLDNNMemDesc({L, D, G, OC}, MKLDNNGetDataType<float>(), auto bias_md = MKLDNNMemDesc({L, D, G, OC}, MKLDNNGetDataType<float>(),
MKLDNNMemoryFormat::ldgo); MKLDNNMemoryFormat::ldgo);
auto hidden_md = MKLDNNMemDesc({Ti, N, OC}, MKLDNNGetDataType<T_out>(), auto hidden_md = MKLDNNMemDesc({Ti, N, OC}, MKLDNNGetDataType<T_out>(),
MKLDNNMemoryFormat::any); MKLDNNMemoryFormat::ntc);
auto h0_md = MKLDNNMemDesc({L, D, N, OC}, MKLDNNGetDataType<T>(), auto h0_md = MKLDNNMemDesc({L, D, N, OC}, MKLDNNGetDataType<T>(),
MKLDNNMemoryFormat::ldnc); MKLDNNMemoryFormat::ldnc);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册