未验证 提交 cc3f4b81 编写于 作者: A Adam 提交者: GitHub

Add int8 GRU kernel (#27220)

* Add int8 GRU kernel with UTs

* Lint fixes

* More lint fixes
上级 255e0cf9
...@@ -19,8 +19,8 @@ SET(MKLDNN_PREFIX_DIR ${THIRD_PARTY_PATH}/mkldnn) ...@@ -19,8 +19,8 @@ SET(MKLDNN_PREFIX_DIR ${THIRD_PARTY_PATH}/mkldnn)
SET(MKLDNN_SOURCE_DIR ${THIRD_PARTY_PATH}/mkldnn/src/extern_mkldnn) SET(MKLDNN_SOURCE_DIR ${THIRD_PARTY_PATH}/mkldnn/src/extern_mkldnn)
SET(MKLDNN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/mkldnn) SET(MKLDNN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/mkldnn)
SET(MKLDNN_INC_DIR "${MKLDNN_INSTALL_DIR}/include" CACHE PATH "mkldnn include directory." FORCE) SET(MKLDNN_INC_DIR "${MKLDNN_INSTALL_DIR}/include" CACHE PATH "mkldnn include directory." FORCE)
SET(MKLDNN_REPOSITORY https://github.com/intel/mkl-dnn.git) SET(MKLDNN_REPOSITORY https://github.com/oneapi-src/oneDNN.git)
SET(MKLDNN_TAG 4c05c181b40cf7132f8943411fb3fab1786df0f7) SET(MKLDNN_TAG 64a48f9565aa72f6359917b3406328075a409939)
# Introduce variables: # Introduce variables:
# * CMAKE_INSTALL_LIBDIR # * CMAKE_INSTALL_LIBDIR
......
...@@ -206,6 +206,27 @@ void FusionGRUOpMaker::Make() { ...@@ -206,6 +206,27 @@ void FusionGRUOpMaker::Make() {
AddAttr<bool>("use_mkldnn", AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel") "(bool, default false) Only used in mkldnn kernel")
.SetDefault(false); .SetDefault(false);
AddAttr<std::string>(
"mkldnn_data_type",
"(string, default \"float32\"). Data type of mkldnn kernel")
.SetDefault("float32")
.InEnum({"float32", "int8", "bfloat16"});
AddAttr<float>("Scale_data",
"Scale to be used for int8 input/output data."
"Only used with MKL-DNN INT8.")
.SetDefault(1.0f);
AddAttr<float>("Shift_data",
"Shift to be used for int8 input/output data."
"Only used with MKL-DNN INT8.")
.SetDefault(0.0f);
AddAttr<std::vector<float>>("Scale_weights",
"Scale_weights to be used for int8 weights data."
"Only used with MKL-DNN INT8.")
.SetDefault({1.0f});
AddAttr<bool>("force_fp32_output",
"(bool, default false) Force INT8 kernel output FP32, only "
"used in MKL-DNN INT8")
.SetDefault(false);
AddComment(R"DOC( AddComment(R"DOC(
The Fusion complete GRU Operator. The Fusion complete GRU Operator.
This operator fuse the fully-connected operator into GRU, This operator fuse the fully-connected operator into GRU,
......
...@@ -21,11 +21,12 @@ namespace operators { ...@@ -21,11 +21,12 @@ namespace operators {
using paddle::framework::LoDTensor; using paddle::framework::LoDTensor;
using paddle::framework::Tensor; using paddle::framework::Tensor;
using paddle::platform::CPUDeviceContext; using paddle::platform::CPUDeviceContext;
using paddle::platform::CreateKey;
using paddle::platform::MKLDNNGetDataType; using paddle::platform::MKLDNNGetDataType;
using paddle::platform::MKLDNNMemDesc; using paddle::platform::MKLDNNMemDesc;
using platform::to_void_cast; using platform::to_void_cast;
template <typename T> template <typename T, typename T_out = T>
class GRUMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::gru_forward> { class GRUMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::gru_forward> {
public: public:
GRUMKLDNNHandler(const paddle::framework::ExecutionContext& ctx, GRUMKLDNNHandler(const paddle::framework::ExecutionContext& ctx,
...@@ -38,7 +39,7 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::gru_forward> { ...@@ -38,7 +39,7 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::gru_forward> {
const std::string& unique_name) const std::string& unique_name)
: platform::MKLDNNHandlerT<T, dnnl::gru_forward>( : platform::MKLDNNHandlerT<T, dnnl::gru_forward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place, dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(unique_name, Ti)), CreateKey(unique_name, MKLDNNGetDataType<T>(), Ti)),
N(N), N(N),
Ti(Ti), Ti(Ti),
IC(IC), IC(IC),
...@@ -47,9 +48,29 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::gru_forward> { ...@@ -47,9 +48,29 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::gru_forward> {
// do not depend on Ti size but primitive and input/output memory do // do not depend on Ti size but primitive and input/output memory do
if (platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id() != if (platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id() !=
platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default) { platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default) {
memory_key_ = unique_name; memory_key_ = CreateKey(unique_name, MKLDNNGetDataType<T>());
} else { } else {
memory_key_ = unique_name + "-t:" + platform::ThreadIDasStr(); memory_key_ = CreateKey(unique_name, MKLDNNGetDataType<T>(), "-t:",
platform::ThreadIDasStr());
}
// Is it int8 kernel
const bool is_INT8 = std::is_same<T, uint8_t>::value;
if (is_INT8) {
// Int8 attributes
const float scale_data = ctx.Attr<float>("Scale_data");
const float shift_data = ctx.Attr<float>("Shift_data");
const auto scale_weights = ctx.Attr<std::vector<float>>("Scale_weights");
const int weights_scale_mask =
0 +
(1 << 3) // bit, indicating the unique scales for `g` dim in `ldigo`
+
(1 << 4); // bit, indicating the unique scales for `o` dim in `ldigo`
attr_.set_rnn_data_qparams(scale_data, shift_data);
attr_.set_rnn_weights_qparams(weights_scale_mask, scale_weights);
} }
if (!this->isCached()) { if (!this->isCached()) {
...@@ -63,6 +84,10 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::gru_forward> { ...@@ -63,6 +84,10 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::gru_forward> {
platform::errors::Unimplemented( platform::errors::Unimplemented(
"oneDNN fusion_gru supports only tanh as an activation.")); "oneDNN fusion_gru supports only tanh as an activation."));
// Weights for int8 kernel are of a type s8
const auto weights_dt =
is_INT8 ? dnnl::memory::data_type::s8 : dnnl::memory::data_type::f32;
// oneDNN RNN dimensions // oneDNN RNN dimensions
const int64_t D = 1; // Directions const int64_t D = 1; // Directions
const int64_t L = 1; // Layers (PP supports only 1 stacked layer) const int64_t L = 1; // Layers (PP supports only 1 stacked layer)
...@@ -71,19 +96,16 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::gru_forward> { ...@@ -71,19 +96,16 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::gru_forward> {
// Create memory descriptors // Create memory descriptors
auto input_md = MKLDNNMemDesc({Ti, N, IC}, MKLDNNGetDataType<T>(), auto input_md = MKLDNNMemDesc({Ti, N, IC}, MKLDNNGetDataType<T>(),
MKLDNNMemoryFormat::any); MKLDNNMemoryFormat::any);
auto weight_x_md = MKLDNNMemDesc( auto weight_x_md =
{L, D, IC, G, OC}, MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::any); MKLDNNMemDesc({L, D, IC, G, OC}, weights_dt, MKLDNNMemoryFormat::any);
auto weight_h_md = MKLDNNMemDesc( auto weight_h_md =
{L, D, OC, G, OC}, MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::any); MKLDNNMemDesc({L, D, OC, G, OC}, weights_dt, MKLDNNMemoryFormat::any);
auto bias_md = MKLDNNMemDesc({L, D, G, OC}, MKLDNNGetDataType<float>(), auto bias_md = MKLDNNMemDesc({L, D, G, OC}, MKLDNNGetDataType<float>(),
MKLDNNMemoryFormat::ldgo); MKLDNNMemoryFormat::ldgo);
auto hidden_md = MKLDNNMemDesc({Ti, N, OC}, MKLDNNGetDataType<T>(), auto hidden_md = MKLDNNMemDesc({Ti, N, OC}, MKLDNNGetDataType<T_out>(),
MKLDNNMemoryFormat::any); MKLDNNMemoryFormat::any);
auto h0_md = dnnl::memory::desc(); auto h0_md = MKLDNNMemDesc({L, D, N, OC}, MKLDNNGetDataType<T>(),
if (h0) { MKLDNNMemoryFormat::ldnc);
h0_md = MKLDNNMemDesc({L, D, N, OC}, MKLDNNGetDataType<T>(),
MKLDNNMemoryFormat::ldnc);
}
// Create GRU oneDNN primitive // Create GRU oneDNN primitive
const auto direction = const auto direction =
...@@ -91,7 +113,7 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::gru_forward> { ...@@ -91,7 +113,7 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::gru_forward> {
: dnnl::rnn_direction::unidirectional_left2right; : dnnl::rnn_direction::unidirectional_left2right;
this->AcquireForwardPrimitiveDescriptor( this->AcquireForwardPrimitiveDescriptor(
dnnl::prop_kind::forward_inference, direction, input_md, h0_md, attr_, dnnl::prop_kind::forward_inference, direction, input_md, h0_md,
weight_x_md, weight_h_md, bias_md, hidden_md, dnnl::memory::desc()); weight_x_md, weight_h_md, bias_md, hidden_md, dnnl::memory::desc());
} }
} }
...@@ -101,29 +123,31 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::gru_forward> { ...@@ -101,29 +123,31 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::gru_forward> {
dnnl::memory::format_tag::ntc); dnnl::memory::format_tag::ntc);
} }
void reorderRNNdata(const T* input_data, T* output_data, void reorderRNNdata(void* input_data, void* output_data,
std::vector<size_t> lod, const bool is_reverse, std::vector<size_t> lod, const bool is_reverse,
platform::RNNReorderType reorder_type) { platform::RNNReorderType reorder_type) {
switch (reorder_type) { switch (reorder_type) {
// Reorder input memory [WORDS, C] + LoD -> [N, T, C] // Reorder input memory [WORDS, C] + LoD -> [N, T, C]
case platform::RNNReorderType::PP_NTC: { case platform::RNNReorderType::PP_NTC: {
auto* input_data_iter = input_data; auto* input_data_iter = reinterpret_cast<T*>(input_data);
auto* output_data_iter = reinterpret_cast<T*>(output_data);
for (int n = 0; n < N; ++n) { for (int n = 0; n < N; ++n) {
const auto num_elements = (lod[n + 1] - lod[n]) * IC; const auto num_elements = (lod[n + 1] - lod[n]) * IC;
const auto offset = is_reverse ? (Ti * IC - num_elements) : 0; const auto offset = is_reverse ? (Ti * IC - num_elements) : 0;
memcpy(output_data + n * Ti * IC + offset, input_data_iter, memcpy(output_data_iter + n * Ti * IC + offset, input_data_iter,
sizeof(T) * num_elements); sizeof(T) * num_elements);
input_data_iter += num_elements; input_data_iter += num_elements;
} }
} break; } break;
// Reorder input memory [WORDS, C] + LoD -> [T, N, C] // Reorder input memory [WORDS, C] + LoD -> [T, N, C]
case platform::RNNReorderType::PP_TNC: { case platform::RNNReorderType::PP_TNC: {
auto* input_data_iter = input_data; auto* input_data_iter = reinterpret_cast<T*>(input_data);
auto* output_data_iter = reinterpret_cast<T*>(output_data);
for (int n = 0; n < N; ++n) { for (int n = 0; n < N; ++n) {
const auto num_elements = (lod[n + 1] - lod[n]); const auto num_elements = (lod[n + 1] - lod[n]);
const auto offset = is_reverse ? (Ti - num_elements) : 0; const auto offset = is_reverse ? (Ti - num_elements) : 0;
for (size_t t = 0; t < num_elements; ++t) { for (size_t t = 0; t < num_elements; ++t) {
memcpy(output_data + (t + offset) * N * IC + n * IC, memcpy(output_data_iter + (t + offset) * N * IC + n * IC,
input_data_iter, sizeof(T) * IC); input_data_iter, sizeof(T) * IC);
input_data_iter += IC; input_data_iter += IC;
} }
...@@ -131,24 +155,27 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::gru_forward> { ...@@ -131,24 +155,27 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::gru_forward> {
} break; } break;
// Reorder output values to PP format [N, T, C] -> [WORDS, C] // Reorder output values to PP format [N, T, C] -> [WORDS, C]
case platform::RNNReorderType::NTC_PP: { case platform::RNNReorderType::NTC_PP: {
auto* output_data_iter = output_data; auto* input_data_iter = reinterpret_cast<T_out*>(input_data);
auto* output_data_iter = reinterpret_cast<T_out*>(output_data);
for (int n = 0; n < N; ++n) { for (int n = 0; n < N; ++n) {
const auto num_elements = (lod[n + 1] - lod[n]) * OC; const auto num_elements = (lod[n + 1] - lod[n]) * OC;
const auto offset = is_reverse ? (Ti * OC - num_elements) : 0; const auto offset = is_reverse ? (Ti * OC - num_elements) : 0;
memcpy(output_data_iter, input_data + n * Ti * OC + offset, memcpy(output_data_iter, input_data_iter + n * Ti * OC + offset,
sizeof(T) * num_elements); sizeof(T_out) * num_elements);
output_data_iter += num_elements; output_data_iter += num_elements;
} }
} break; } break;
// Reorder output values to PP format [T, N, C] -> [WORDS, C] // Reorder output values to PP format [T, N, C] -> [WORDS, C]
case platform::RNNReorderType::TNC_PP: { case platform::RNNReorderType::TNC_PP: {
auto* output_data_iter = output_data; auto* input_data_iter = reinterpret_cast<T_out*>(input_data);
auto* output_data_iter = reinterpret_cast<T_out*>(output_data);
for (int n = 0; n < N; ++n) { for (int n = 0; n < N; ++n) {
const auto num_elements = lod[n + 1] - lod[n]; const auto num_elements = lod[n + 1] - lod[n];
const auto offset = is_reverse ? (Ti - num_elements) : 0; const auto offset = is_reverse ? (Ti - num_elements) : 0;
for (size_t t = 0; t < num_elements; ++t) { for (size_t t = 0; t < num_elements; ++t) {
memcpy(output_data_iter, memcpy(output_data_iter,
input_data + (t + offset) * N * OC + n * OC, sizeof(T) * OC); input_data_iter + (t + offset) * N * OC + n * OC,
sizeof(T_out) * OC);
output_data_iter += OC; output_data_iter += OC;
} }
} }
...@@ -169,9 +196,9 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::gru_forward> { ...@@ -169,9 +196,9 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::gru_forward> {
} }
const auto& input_lod = input->lod()[0]; const auto& input_lod = input->lod()[0];
auto* x_data = input->data<T>(); auto* x_data = to_void_cast(input->data<T>());
auto* x_onednn_data = reinterpret_cast<T*>(memory_p->get_data_handle()); auto* x_onednn_data = memory_p->get_data_handle();
memset(x_onednn_data, 0, sizeof(T) * N * Ti * IC); memset(x_onednn_data, 0, sizeof(T) * N * Ti * IC);
if (platform::GetMKLDNNFormat(this->fwd_pd_->src_desc()) == if (platform::GetMKLDNNFormat(this->fwd_pd_->src_desc()) ==
...@@ -198,19 +225,35 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::gru_forward> { ...@@ -198,19 +225,35 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::gru_forward> {
return memory_p; return memory_p;
} }
// TODO(grygielski) H0 is for now persistable
std::shared_ptr<dnnl::memory> AcquireH0Memory(const Tensor* h0) { std::shared_ptr<dnnl::memory> AcquireH0Memory(const Tensor* h0) {
const std::string h0_key = memory_key_ + "@h0"; const std::string h0_key = memory_key_ + "@h0";
auto memory_p = auto memory_p =
std::static_pointer_cast<dnnl::memory>(this->dev_ctx_.GetBlob(h0_key)); std::static_pointer_cast<dnnl::memory>(this->dev_ctx_.GetBlob(h0_key));
auto* h0_data = to_void_cast(h0->data<T>());
if (!memory_p) { if (!memory_p) {
memory_p = std::make_shared<dnnl::memory>( auto user_h0_memory = dnnl::memory();
this->fwd_pd_->weights_layer_desc(), this->engine_, h0_data); if (h0) {
user_h0_memory =
dnnl::memory({{1, 1, N, OC},
MKLDNNGetDataType<float>(),
MKLDNNMemoryFormat::ldnc},
this->engine_, to_void_cast(h0->data<float>()));
} else {
user_h0_memory = dnnl::memory({{1, 1, N, OC},
MKLDNNGetDataType<float>(),
MKLDNNMemoryFormat::ldnc},
this->engine_);
memset(user_h0_memory.get_data_handle(), 0, sizeof(float) * N * OC);
}
memory_p = std::make_shared<dnnl::memory>(this->fwd_pd_->src_iter_desc(),
this->engine_);
dnnl::stream astream(this->engine_);
dnnl::reorder(user_h0_memory, *memory_p, attr_)
.execute(astream, user_h0_memory, *memory_p);
this->dev_ctx_.SetBlob(h0_key, memory_p); this->dev_ctx_.SetBlob(h0_key, memory_p);
} else {
memory_p->set_data_handle(h0_data);
} }
return memory_p; return memory_p;
} }
...@@ -245,7 +288,7 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::gru_forward> { ...@@ -245,7 +288,7 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::gru_forward> {
this->fwd_pd_->weights_layer_desc(), this->engine_); this->fwd_pd_->weights_layer_desc(), this->engine_);
dnnl::stream astream(this->engine_); dnnl::stream astream(this->engine_);
dnnl::reorder(user_memory, *memory_p) dnnl::reorder(user_memory, *memory_p, attr_)
.execute(astream, user_memory, *memory_p); .execute(astream, user_memory, *memory_p);
this->dev_ctx_.SetBlob(wx_key, memory_p); this->dev_ctx_.SetBlob(wx_key, memory_p);
...@@ -298,7 +341,7 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::gru_forward> { ...@@ -298,7 +341,7 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::gru_forward> {
this->fwd_pd_->weights_iter_desc(), this->engine_); this->fwd_pd_->weights_iter_desc(), this->engine_);
dnnl::stream astream(this->engine_); dnnl::stream astream(this->engine_);
dnnl::reorder(user_memory, *memory_p) dnnl::reorder(user_memory, *memory_p, attr_)
.execute(astream, user_memory, *memory_p); .execute(astream, user_memory, *memory_p);
this->dev_ctx_.SetBlob(wh_key, memory_p); this->dev_ctx_.SetBlob(wh_key, memory_p);
...@@ -347,12 +390,26 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::gru_forward> { ...@@ -347,12 +390,26 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::gru_forward> {
// Memory size of weights, bias and h0 does not depend // Memory size of weights, bias and h0 does not depend
// on Ti size, thus we need another key to cache them // on Ti size, thus we need another key to cache them
std::string memory_key_; std::string memory_key_;
dnnl::primitive_attr attr_;
}; };
template <typename T> template <typename T>
class FusionGRUMKLDNNKernel : public framework::OpKernel<T> { class FusionGRUMKLDNNKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
const bool is_INT8 = std::is_same<T, uint8_t>::value;
const bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
// TODO(grygielski) Add option for bfloat
if (!is_INT8 || force_fp32_output) {
RunKernel<float>(ctx);
} else {
RunKernel<uint8_t>(ctx);
}
}
template <typename Tout = T>
void RunKernel(const framework::ExecutionContext& ctx) const {
auto& dev_ctx = auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>(); ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine(); const auto& mkldnn_engine = dev_ctx.GetEngine();
...@@ -390,12 +447,14 @@ class FusionGRUMKLDNNKernel : public framework::OpKernel<T> { ...@@ -390,12 +447,14 @@ class FusionGRUMKLDNNKernel : public framework::OpKernel<T> {
const int64_t IC = x_mat_dims_vec[1]; // Input channels const int64_t IC = x_mat_dims_vec[1]; // Input channels
const int64_t OC = weight_h_dims[0]; // Output channels const int64_t OC = weight_h_dims[0]; // Output channels
GRUMKLDNNHandler<T> handler(ctx, dev_ctx, mkldnn_engine, ctx.GetPlace(), GRUMKLDNNHandler<T, Tout> handler(
input, weight_h, h0, is_reverse, N, Ti, IC, OC, ctx, dev_ctx, mkldnn_engine, ctx.GetPlace(), input, weight_h, h0,
ctx.InputName("X") + ctx.InputName("WeightH")); is_reverse, N, Ti, IC, OC,
ctx.InputName("X") + ctx.InputName("WeightH"));
auto input_memory_p = auto input_memory_p =
handler.AcquireInputMemoryWithReorder(input, is_reverse); handler.AcquireInputMemoryWithReorder(input, is_reverse);
auto h0_memory_p = handler.AcquireH0Memory(h0);
auto weight_x_memory_p = auto weight_x_memory_p =
handler.AcquireWeightXMemory(weight_x, origin_mode); handler.AcquireWeightXMemory(weight_x, origin_mode);
auto weight_h_memory_p = auto weight_h_memory_p =
...@@ -405,25 +464,21 @@ class FusionGRUMKLDNNKernel : public framework::OpKernel<T> { ...@@ -405,25 +464,21 @@ class FusionGRUMKLDNNKernel : public framework::OpKernel<T> {
std::unordered_map<int, dnnl::memory> gru_args = { std::unordered_map<int, dnnl::memory> gru_args = {
{DNNL_ARG_SRC_LAYER, *input_memory_p}, {DNNL_ARG_SRC_LAYER, *input_memory_p},
{DNNL_ARG_SRC_ITER, *h0_memory_p},
{DNNL_ARG_WEIGHTS_LAYER, *weight_x_memory_p}, {DNNL_ARG_WEIGHTS_LAYER, *weight_x_memory_p},
{DNNL_ARG_WEIGHTS_ITER, *weight_h_memory_p}, {DNNL_ARG_WEIGHTS_ITER, *weight_h_memory_p},
{DNNL_ARG_BIAS, *bias_memory_p}, {DNNL_ARG_BIAS, *bias_memory_p},
{DNNL_ARG_DST_LAYER, *hidden_onednn_memory_p}}; {DNNL_ARG_DST_LAYER, *hidden_onednn_memory_p}};
if (h0) {
auto h0_memory_p = handler.AcquireH0Memory(h0);
gru_args.insert({DNNL_ARG_SRC_ITER, *h0_memory_p});
}
auto gru_forward_p = handler.AcquireForwardPrimitive(); auto gru_forward_p = handler.AcquireForwardPrimitive();
dnnl::stream astream(mkldnn_engine); dnnl::stream astream(mkldnn_engine);
gru_forward_p->execute(astream, gru_args); gru_forward_p->execute(astream, gru_args);
astream.wait(); astream.wait();
auto* hidden_onednn_data = auto* hidden_onednn_data = hidden_onednn_memory_p->get_data_handle();
reinterpret_cast<T*>(hidden_onednn_memory_p->get_data_handle()); auto* hidden_data =
auto* hidden_data = hidden->mutable_data<T>(ctx.GetPlace()); to_void_cast(hidden->mutable_data<Tout>(ctx.GetPlace()));
if (handler.is_NTC()) { if (handler.is_NTC()) {
handler.reorderRNNdata(hidden_onednn_data, hidden_data, input_lod, handler.reorderRNNdata(hidden_onednn_data, hidden_data, input_lod,
is_reverse, platform::RNNReorderType::NTC_PP); is_reverse, platform::RNNReorderType::NTC_PP);
...@@ -439,4 +494,5 @@ class FusionGRUMKLDNNKernel : public framework::OpKernel<T> { ...@@ -439,4 +494,5 @@ class FusionGRUMKLDNNKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_KERNEL(fusion_gru, MKLDNN, paddle::platform::CPUPlace, REGISTER_OP_KERNEL(fusion_gru, MKLDNN, paddle::platform::CPUPlace,
ops::FusionGRUMKLDNNKernel<float>); ops::FusionGRUMKLDNNKernel<float>,
ops::FusionGRUMKLDNNKernel<uint8_t>);
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from paddle.fluid.tests.unittests.op_test import OpTest
from paddle.fluid.tests.unittests.test_fusion_gru_op import fusion_gru
from paddle.fluid.tests.unittests.test_fusion_lstm_op import fc, ACTIVATION
class TestFusionGRUINT8MKLDNNOp(OpTest):
def set_confs(self):
pass
def setUp(self):
self.op_type = "fusion_gru"
self.lod = [[2, 4, 3]]
self.IC = 3
self.OC = 5
self.is_reverse = False
self.with_h0 = False
self.with_bias = True
self.act_state = 'tanh'
self.act_gate = 'sigmoid'
self.origin_mode = True
self.use_mkldnn = True
self.force_fp32_output = True
self.error_margin = 1e-5
self.set_confs()
# RNN dimensions
T = sum(self.lod[0])
N = len(self.lod[0])
# Input data
x_f32 = np.random.rand(T, self.IC).astype('float32') * 2 - 1
scale_data = 63
shift_data = 64
x_u8 = (x_f32 * scale_data + shift_data).astype(np.uint8)
# WeightX/WeightH data
wx = np.random.rand(self.IC, 3 * self.OC).astype('float32') * 2 - 1
wh = np.random.rand(self.OC, 3 * self.OC).astype('float32') * 2 - 1
# Calculating weight scales
# scales = 63 / max(abs(channel_wise(weightsX + weightsH)))
# WeightX data shape in PP: [IC, 3 * OC]
# WeightH data shape in PP: [OC, 2 * OC] + [OC, OC]
# Scales shape in oneDNN: [3, OC]
scale_ur = 63 / np.max(np.abs(
np.concatenate(
[
wx[:, :2 * self.OC], wh.flatten()[:2 * self.OC * self.OC]
.reshape(self.OC, 2 * self.OC)
],
axis=0)),
axis=0)
scale_o = 63 / np.max(np.abs(
np.concatenate(
[
wx[:, 2 * self.OC:], wh.flatten()[2 * self.OC * self.OC:]
.reshape(self.OC, self.OC)
],
axis=0)),
axis=0)
scale_weights = np.concatenate([scale_ur, scale_o]).astype('float')
bias = np.random.rand(
1, 3 * self.OC).astype('float32') if self.with_bias else np.zeros(
(1, 3 * self.OC), dtype='float32')
h0 = np.random.rand(
N, self.OC).astype('float32') if self.with_h0 else np.zeros(
(N, self.OC), dtype='float32')
_, _, _, hidden_f32 = fusion_gru(x_f32, self.lod, h0, wx, wh, bias,
self.is_reverse, self.origin_mode,
ACTIVATION[self.act_state],
ACTIVATION[self.act_gate])
self.inputs = {'X': (x_u8, self.lod), 'WeightX': wx, 'WeightH': wh}
if self.with_bias:
self.inputs['Bias'] = bias
if self.with_h0:
self.inputs['H0'] = h0
if self.force_fp32_output:
self.error_margin = 1e-1
self.outputs = {'Hidden': (hidden_f32, self.lod)}
else:
self.error_margin = 1
hidden_u8 = (hidden_f32 * scale_data + shift_data).astype(np.uint8)
self.outputs = {'Hidden': (hidden_u8, self.lod)}
self.attrs = {
'activation': self.act_state,
'gate_activation': self.act_gate,
'is_reverse': self.is_reverse,
'origin_mode': self.origin_mode,
'use_mkldnn': self.use_mkldnn,
'force_fp32_output': self.force_fp32_output,
'Scale_data': scale_data,
'Shift_data': shift_data,
'Scale_weights': scale_weights
}
def test_check_output(self):
self.check_output(check_dygraph=False, atol=self.error_margin)
class TestFusionGRUINT8MKLDNNOp2(TestFusionGRUINT8MKLDNNOp):
def set_confs(self):
self.force_fp32_output = False
class TestFusionGRUINT8MKLDNNOp3(TestFusionGRUINT8MKLDNNOp):
def set_confs(self):
self.origin_mode = False
class TestFusionGRUINT8MKLDNNOp4(TestFusionGRUINT8MKLDNNOp):
def set_confs(self):
self.with_bias = False
class TestFusionGRUINT8MKLDNNOp5(TestFusionGRUINT8MKLDNNOp):
def set_confs(self):
self.with_h0 = False
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册