diff --git a/cmake/operators.cmake b/cmake/operators.cmake index 757da1c829a9c67fadae60320d884141683553dc..0343ff3cc292d97dcc77108735baa69c804468af 100644 --- a/cmake/operators.cmake +++ b/cmake/operators.cmake @@ -197,7 +197,7 @@ function(op_library TARGET) "tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op" "fusion_transpose_flatten_concat_op" "fusion_conv_inception_op" "sync_batch_norm_op" "dgc_op" "fused_fc_elementwise_layernorm_op" -"skip_layernorm_op" "multihead_matmul_op" "fusion_group_op" "fused_bn_activation_op" "fused_embedding_eltwise_layernorm_op" "fusion_gru_op" +"skip_layernorm_op" "multihead_matmul_op" "fusion_group_op" "fused_bn_activation_op" "fused_embedding_eltwise_layernorm_op" "fusion_gru_op" "fusion_lstm_op" "fused_bn_add_activation_op") if ("${TARGET}" STREQUAL "${manual_pybind_op}") set(pybind_flag 1) diff --git a/paddle/fluid/operators/fused/CMakeLists.txt b/paddle/fluid/operators/fused/CMakeLists.txt index 466e016d99db560a8ba790a48bb51f66cbbffd73..95ae807c6ae0444e42a2db69b223df4cc0b899ef 100644 --- a/paddle/fluid/operators/fused/CMakeLists.txt +++ b/paddle/fluid/operators/fused/CMakeLists.txt @@ -14,11 +14,15 @@ register_operators(EXCLUDES fused_embedding_eltwise_layernorm_op fusion_group_op fusion_gru_op + fusion_lstm_op fused_bn_add_activation_op) # fusion_gru_op does not have CUDA kernel op_library(fusion_gru_op) +op_library(fusion_lstm_op) file(APPEND ${pybind_file} "USE_CPU_ONLY_OP(fusion_gru);\n") +file(APPEND ${pybind_file} "USE_CPU_ONLY_OP(fusion_lstm);\n") + if (WITH_GPU) # fused_bn_activation_op needs cudnn 7.4.1 above diff --git a/paddle/fluid/operators/fused/fusion_lstm_op.cc b/paddle/fluid/operators/fused/fusion_lstm_op.cc index 65cf4c170ac91823bfef2d3a202f4893a46dba3c..f14a05142512aeb9dead4d406804312e5148a32f 100644 --- a/paddle/fluid/operators/fused/fusion_lstm_op.cc +++ b/paddle/fluid/operators/fused/fusion_lstm_op.cc @@ -18,6 +18,9 @@ limitations under the License. */ #include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/fc.h" #include "paddle/fluid/operators/math/sequence2batch.h" +#ifdef PADDLE_WITH_MKLDNN +#include "paddle/fluid/platform/mkldnn_helper.h" +#endif namespace paddle { namespace operators { @@ -145,8 +148,16 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { framework::OpKernelType FusionLSTMOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.device_context()); + framework::LibraryType library = framework::LibraryType::kPlain; + framework::DataLayout layout = framework::DataLayout::kAnyLayout; + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); +#ifdef PADDLE_WITH_MKLDNN + if (this->CanMKLDNNBeUsed(ctx, data_type)) { + library = framework::LibraryType::kMKLDNN; + layout = framework::DataLayout::kMKLDNN; + } +#endif + return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library); } void FusionLSTMOpMaker::Make() { @@ -235,6 +246,9 @@ void FusionLSTMOpMaker::Make() { "`tanh` by default.") .SetDefault("tanh") .InEnum({"sigmoid", "tanh", "relu", "identity"}); + AddAttr("use_mkldnn", + "(bool, default false) Only used in mkldnn kernel") + .SetDefault(false); AddComment(R"DOC( Fusion Long-Short Term Memory (LSTM) Operator. This operator fuse the X into LSTM, more details can refer to LSTM op. diff --git a/paddle/fluid/operators/fused/mkldnn/fusion_gru_mkldnn_op.cc b/paddle/fluid/operators/fused/mkldnn/fusion_gru_mkldnn_op.cc index da811faa41bc765cb65442e2372c30b40458bcfe..a3b59419b7f4c920e0ac15016bb1f2a46d53c483 100644 --- a/paddle/fluid/operators/fused/mkldnn/fusion_gru_mkldnn_op.cc +++ b/paddle/fluid/operators/fused/mkldnn/fusion_gru_mkldnn_op.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/fused/fusion_gru_op.h" -#include "paddle/fluid/platform/mkldnn_reuse.h" +#include "paddle/fluid/operators/fused/mkldnn/fusion_rnn_mkldnn.h" namespace paddle { namespace operators { @@ -27,7 +27,7 @@ using paddle::platform::MKLDNNMemDesc; using platform::to_void_cast; template -class GRUMKLDNNHandler : public platform::MKLDNNHandlerT { +class GRUMKLDNNHandler : public RNNMKLDNNHandler { public: GRUMKLDNNHandler(const paddle::framework::ExecutionContext& ctx, const platform::MKLDNNDeviceContext& dev_ctx, @@ -37,37 +37,12 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT { const bool is_reverse, const int64_t N, const int64_t Ti, const int64_t IC, const int64_t OC, const std::string& unique_name) - : platform::MKLDNNHandlerT( - dev_ctx, dev_ctx.GetEngine(), cpu_place, - CreateKey(dev_ctx, unique_name, MKLDNNGetDataType(), Ti)), - N(N), - Ti(Ti), - IC(IC), - OC(OC) { - // Create memory key without Ti because weights, bias and h0 memories - // do not depend on Ti size but primitive and input/output memory do - memory_key_ = platform::ExtendKeyWithThreadInfoIfNeeded( - dev_ctx, CreateKey(dev_ctx, unique_name, MKLDNNGetDataType())); - - // Is it int8 kernel + : RNNMKLDNNHandler( + ctx, dev_ctx, mkldnn_engine, ctx.GetPlace(), input, weight_h, h0, + is_reverse, N, Ti, IC, OC, 3, + ctx.InputName("X") + ctx.InputName("WeightH")) { const bool is_INT8 = std::is_same::value; - if (is_INT8) { - // Int8 attributes - const float scale_data = ctx.Attr("Scale_data"); - const float shift_data = ctx.Attr("Shift_data"); - const auto scale_weights = ctx.Attr>("Scale_weights"); - - const int weights_scale_mask = - 0 + - (1 << 3) // bit, indicating the unique scales for `g` dim in `ldigo` - + - (1 << 4); // bit, indicating the unique scales for `o` dim in `ldigo` - - attr_.set_rnn_data_qparams(scale_data, shift_data); - attr_.set_rnn_weights_qparams(weights_scale_mask, scale_weights); - } - if (!this->isCached()) { // oneDNN kernel has hardcoded activation functions PADDLE_ENFORCE_EQ( @@ -108,176 +83,35 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT { : dnnl::rnn_direction::unidirectional_left2right; this->AcquireForwardPrimitiveDescriptor( - attr_, dnnl::prop_kind::forward_inference, direction, input_md, h0_md, - weight_x_md, weight_h_md, bias_md, hidden_md, dnnl::memory::desc()); - } - } - - bool is_NTC() { - return (platform::GetMKLDNNFormat(this->fwd_pd_->dst_desc()) == - dnnl::memory::format_tag::ntc); - } - - void reorderRNNdata(void* input_data, void* output_data, - std::vector lod, const bool is_reverse, - platform::RNNReorderType reorder_type) { - switch (reorder_type) { - // Reorder input memory [WORDS, C] + LoD -> [N, T, C] - case platform::RNNReorderType::PP_NTC: { - auto* input_data_iter = reinterpret_cast(input_data); - auto* output_data_iter = reinterpret_cast(output_data); - for (int n = 0; n < N; ++n) { - const auto num_elements = (lod[n + 1] - lod[n]) * IC; - const auto offset = is_reverse ? (Ti * IC - num_elements) : 0; - memcpy(output_data_iter + n * Ti * IC + offset, input_data_iter, - sizeof(T) * num_elements); - input_data_iter += num_elements; - } - } break; - // Reorder input memory [WORDS, C] + LoD -> [T, N, C] - case platform::RNNReorderType::PP_TNC: { - auto* input_data_iter = reinterpret_cast(input_data); - auto* output_data_iter = reinterpret_cast(output_data); - for (int n = 0; n < N; ++n) { - const auto num_elements = (lod[n + 1] - lod[n]); - const auto offset = is_reverse ? (Ti - num_elements) : 0; - for (size_t t = 0; t < num_elements; ++t) { - memcpy(output_data_iter + (t + offset) * N * IC + n * IC, - input_data_iter, sizeof(T) * IC); - input_data_iter += IC; - } - } - } break; - // Reorder output values to PP format [N, T, C] -> [WORDS, C] - case platform::RNNReorderType::NTC_PP: { - auto* input_data_iter = reinterpret_cast(input_data); - auto* output_data_iter = reinterpret_cast(output_data); - for (int n = 0; n < N; ++n) { - const auto num_elements = (lod[n + 1] - lod[n]) * OC; - const auto offset = is_reverse ? (Ti * OC - num_elements) : 0; - memcpy(output_data_iter, input_data_iter + n * Ti * OC + offset, - sizeof(T_out) * num_elements); - output_data_iter += num_elements; - } - } break; - // Reorder output values to PP format [T, N, C] -> [WORDS, C] - case platform::RNNReorderType::TNC_PP: { - auto* input_data_iter = reinterpret_cast(input_data); - auto* output_data_iter = reinterpret_cast(output_data); - for (int n = 0; n < N; ++n) { - const auto num_elements = lod[n + 1] - lod[n]; - const auto offset = is_reverse ? (Ti - num_elements) : 0; - for (size_t t = 0; t < num_elements; ++t) { - memcpy(output_data_iter, - input_data_iter + (t + offset) * N * OC + n * OC, - sizeof(T_out) * OC); - output_data_iter += OC; - } - } - } break; + this->attr_, dnnl::prop_kind::forward_inference, direction, input_md, + h0_md, weight_x_md, weight_h_md, bias_md, hidden_md, + dnnl::memory::desc()); } } - std::shared_ptr AcquireInputMemoryWithReorder( - const LoDTensor* input, const bool is_reverse) { - const auto name = this->key_ + "@input_mem"; - auto memory_p = - std::static_pointer_cast(this->dev_ctx_.GetBlob(name)); - - if (!memory_p) { - memory_p = std::make_shared(this->fwd_pd_->src_desc(), - this->engine_); - this->dev_ctx_.SetBlob(name, memory_p); - } - - const auto& input_lod = input->lod()[0]; - auto* x_data = to_void_cast(input->data()); - - auto* x_onednn_data = memory_p->get_data_handle(); - memset(x_onednn_data, 0, sizeof(T) * N * Ti * IC); - - if (platform::GetMKLDNNFormat(this->fwd_pd_->src_desc()) == - dnnl::memory::format_tag::ntc) { - reorderRNNdata(x_data, x_onednn_data, input_lod, is_reverse, - platform::RNNReorderType::PP_NTC); - } else { - reorderRNNdata(x_data, x_onednn_data, input_lod, is_reverse, - platform::RNNReorderType::PP_TNC); - } - return memory_p; - } - - std::shared_ptr AcquireOutputMemory() { - const auto name = this->key_ + "@output_mem"; - auto memory_p = - std::static_pointer_cast(this->dev_ctx_.GetBlob(name)); - - if (!memory_p) { - memory_p = std::make_shared(this->fwd_pd_->dst_desc(), - this->engine_); - this->dev_ctx_.SetBlob(name, memory_p); - } - return memory_p; - } - - // TODO(grygielski) H0 is for now persistable - // TODO(jczaja) H0 should be updated each iter and of T type (Fusion pass does - // not support in yet) - std::shared_ptr AcquireH0Memory(const Tensor* h0) { - const std::string h0_key = memory_key_ + "@h0"; - auto memory_p = - std::static_pointer_cast(this->dev_ctx_.GetBlob(h0_key)); - - if (!memory_p) { - auto user_h0_memory = dnnl::memory(); - if (h0) { - user_h0_memory = - dnnl::memory({{1, 1, N, OC}, - MKLDNNGetDataType(), - MKLDNNMemoryFormat::ldnc}, - this->engine_, to_void_cast(h0->data())); - } else { - user_h0_memory = dnnl::memory({{1, 1, N, OC}, - MKLDNNGetDataType(), - MKLDNNMemoryFormat::ldnc}, - this->engine_); - memset(user_h0_memory.get_data_handle(), 0, sizeof(float) * N * OC); - } - memory_p = std::make_shared(this->fwd_pd_->src_iter_desc(), - this->engine_); - - auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream(); - dnnl::reorder(user_h0_memory, *memory_p, attr_) - .execute(astream, user_h0_memory, *memory_p); - - this->dev_ctx_.SetBlob(h0_key, memory_p); - } - return memory_p; - } - std::shared_ptr AcquireWeightXMemory(const Tensor* weight_x, const bool origin_mode) { - const std::string wx_key = memory_key_ + "@weight_x"; + const std::string wx_key = this->memory_key_ + "@weight_x"; auto memory_p = std::static_pointer_cast(this->dev_ctx_.GetBlob(wx_key)); if (!memory_p) { auto user_md = - MKLDNNMemDesc({1, 1, IC, 3, OC}, MKLDNNGetDataType(), - MKLDNNMemoryFormat::ldigo); + MKLDNNMemDesc({1, 1, this->IC, this->G, this->OC}, + MKLDNNGetDataType(), MKLDNNMemoryFormat::ldigo); auto user_memory = dnnl::memory(user_md, this->engine_); auto* weight_x_data = reinterpret_cast(user_memory.get_data_handle()); memcpy(weight_x_data, weight_x->data(), - sizeof(float) * IC * 3 * OC); + sizeof(float) * this->IC * this->G * this->OC); if (origin_mode == false) { - for (int64_t i = 0; i < IC; ++i) { - for (int64_t j = 0; j < OC; ++j) { + for (int64_t i = 0; i < this->IC; ++i) { + for (int64_t j = 0; j < this->OC; ++j) { weight_x_data[j] *= -1; } - weight_x_data += 3 * OC; + weight_x_data += 3 * this->OC; } } @@ -285,7 +119,7 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT { this->fwd_pd_->weights_layer_desc(), this->engine_); auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream(); - dnnl::reorder(user_memory, *memory_p, attr_) + dnnl::reorder(user_memory, *memory_p, this->attr_) .execute(astream, user_memory, *memory_p); this->dev_ctx_.SetBlob(wx_key, memory_p); @@ -295,14 +129,14 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT { std::shared_ptr AcquireWeightHMemory(const Tensor* weight_h, const bool origin_mode) { - const std::string wh_key = memory_key_ + "@weight_h"; + const std::string wh_key = this->memory_key_ + "@weight_h"; auto memory_p = std::static_pointer_cast(this->dev_ctx_.GetBlob(wh_key)); if (!memory_p) { auto user_md = - MKLDNNMemDesc({1, 1, OC, 3, OC}, MKLDNNGetDataType(), - MKLDNNMemoryFormat::ldigo); + MKLDNNMemDesc({1, 1, this->OC, this->G, this->OC}, + MKLDNNGetDataType(), MKLDNNMemoryFormat::ldigo); auto user_memory = dnnl::memory(user_md, this->engine_); // Reorder weights_h from PP format [OC, 2OC] + [OC, OC] to @@ -312,25 +146,26 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT { auto* user_weight_h_data = weight_h->data(); auto src1_iter = user_weight_h_data; - auto src2_iter = user_weight_h_data + 2 * OC * OC; + auto src2_iter = user_weight_h_data + 2 * this->OC * this->OC; - for (int64_t c = 0; c < OC; ++c) { - memcpy(weight_h_data, src1_iter, 2 * OC * sizeof(float)); - memcpy(weight_h_data + 2 * OC, src2_iter, OC * sizeof(float)); + for (int64_t c = 0; c < this->OC; ++c) { + memcpy(weight_h_data, src1_iter, 2 * this->OC * sizeof(float)); + memcpy(weight_h_data + 2 * this->OC, src2_iter, + this->OC * sizeof(float)); - src1_iter += 2 * OC; - src2_iter += OC; - weight_h_data += 3 * OC; + src1_iter += 2 * this->OC; + src2_iter += this->OC; + weight_h_data += 3 * this->OC; } weight_h_data = reinterpret_cast(user_memory.get_data_handle()); if (origin_mode == false) { - for (int64_t i = 0; i < OC; ++i) { - for (int64_t j = 0; j < OC; ++j) { + for (int64_t i = 0; i < this->OC; ++i) { + for (int64_t j = 0; j < this->OC; ++j) { weight_h_data[j] *= -1; } - weight_h_data += 3 * OC; + weight_h_data += 3 * this->OC; } } @@ -338,7 +173,7 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT { this->fwd_pd_->weights_iter_desc(), this->engine_); auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream(); - dnnl::reorder(user_memory, *memory_p, attr_) + dnnl::reorder(user_memory, *memory_p, this->attr_) .execute(astream, user_memory, *memory_p); this->dev_ctx_.SetBlob(wh_key, memory_p); @@ -348,7 +183,7 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT { std::shared_ptr AcquireBiasMemory(const Tensor* bias, const bool origin_mode) { - const std::string bias_key = memory_key_ + "@bias"; + const std::string bias_key = this->memory_key_ + "@bias"; auto memory_p = std::static_pointer_cast( this->dev_ctx_.GetBlob(bias_key)); @@ -359,15 +194,15 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT { if (bias) { const float* user_bias_data = bias->data(); // Bias in oneDNN is always float - memcpy(bias_data, user_bias_data, sizeof(float) * 3 * OC); + memcpy(bias_data, user_bias_data, sizeof(float) * this->G * this->OC); } else { // oneDNN always need bias memory, if it's not provided in PP, let // oneDNN allocate memory and set it to 0 - memset(bias_data, 0, sizeof(float) * 3 * OC); + memset(bias_data, 0, sizeof(float) * this->G * this->OC); } if (origin_mode == false && bias) { - for (int64_t i = 0; i < OC; ++i) { + for (int64_t i = 0; i < this->OC; ++i) { bias_data[i] *= -1; } } @@ -375,19 +210,6 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT { } return memory_p; } - - private: - // RNN dimensions - // N - Batch Size - // Ti - Max sentence length - // IC - Input Channels - // OC - Output Channels - const int64_t N, Ti, IC, OC; - - // Memory size of weights, bias and h0 does not depend - // on Ti size, thus we need another key to cache them - std::string memory_key_; - dnnl::primitive_attr attr_; }; template diff --git a/paddle/fluid/operators/fused/mkldnn/fusion_lstm_mkldnn_op.cc b/paddle/fluid/operators/fused/mkldnn/fusion_lstm_mkldnn_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..f5ad0644c6aeda6080a155794c99ebf12caf2b1f --- /dev/null +++ b/paddle/fluid/operators/fused/mkldnn/fusion_lstm_mkldnn_op.cc @@ -0,0 +1,377 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/fused/fusion_lstm_op.h" +#include "paddle/fluid/operators/fused/mkldnn/fusion_rnn_mkldnn.h" + +namespace paddle { +namespace operators { + +using paddle::framework::LoDTensor; +using paddle::framework::Tensor; +using paddle::platform::CPUDeviceContext; +using paddle::platform::CreateKey; +using paddle::platform::MKLDNNGetDataType; +using paddle::platform::MKLDNNMemDesc; +using platform::to_void_cast; + +template +class LSTMMKLDNNHandler + : public RNNMKLDNNHandler { + public: + LSTMMKLDNNHandler(const paddle::framework::ExecutionContext& ctx, + const platform::MKLDNNDeviceContext& dev_ctx, + const mkldnn::engine mkldnn_engine, + platform::Place cpu_place, const LoDTensor* input, + const Tensor* weight_h, const Tensor* h0, const Tensor* c0, + const bool is_reverse, const int64_t N, const int64_t Ti, + const int64_t IC, const int64_t OC, + const std::string& unique_name) + : RNNMKLDNNHandler( + ctx, dev_ctx, mkldnn_engine, ctx.GetPlace(), input, weight_h, h0, + is_reverse, N, Ti, IC, OC, 4, + ctx.InputName("X") + ctx.InputName("WeightH")) { + if (!this->isCached()) { + const bool is_INT8 = std::is_same::value; + const bool use_peepholes = ctx.Attr("use_peepholes"); + // oneDNN kernel has hardcoded activation functions + PADDLE_ENFORCE_EQ( + ctx.Attr("gate_activation"), "sigmoid", + platform::errors::Unimplemented("oneDNN fusion_lstm supports only " + "sigmoid as a gate activation.")); + PADDLE_ENFORCE_EQ( + ctx.Attr("cell_activation"), "tanh", + platform::errors::Unimplemented( + "oneDNN fusion_lstm supports only tanh as a cell activation.")); + PADDLE_ENFORCE_EQ( + ctx.Attr("candidate_activation"), "tanh", + platform::errors::Unimplemented( + "oneDNN fusion_lstm supports only tanh a candidate activation.")); + + // Weights for int8 kernel are of a type s8 + const auto weights_dt = + is_INT8 ? dnnl::memory::data_type::s8 : MKLDNNGetDataType(); + + // oneDNN RNN dimensions + const int64_t D = 1; // Directions + const int64_t L = 1; // Layers (PP supports only 1 stacked layer) + const int64_t G = 4; // Number of Gates, 4 for LSTM + + // Create memory descriptors + auto input_md = MKLDNNMemDesc({Ti, N, IC}, MKLDNNGetDataType(), + MKLDNNMemoryFormat::tnc); + auto weight_x_md = + MKLDNNMemDesc({L, D, IC, G, OC}, weights_dt, MKLDNNMemoryFormat::any); + auto weight_h_md = + MKLDNNMemDesc({L, D, OC, G, OC}, weights_dt, MKLDNNMemoryFormat::any); + auto bias_md = MKLDNNMemDesc({L, D, G, OC}, MKLDNNGetDataType(), + MKLDNNMemoryFormat::ldgo); + auto hidden_md = MKLDNNMemDesc({Ti, N, OC}, MKLDNNGetDataType(), + MKLDNNMemoryFormat::tnc); + auto h0_md = MKLDNNMemDesc({L, D, N, OC}, MKLDNNGetDataType(), + MKLDNNMemoryFormat::ldnc); + auto c0_md = MKLDNNMemDesc({L, D, N, OC}, MKLDNNGetDataType(), + MKLDNNMemoryFormat::ldnc); + + // Create LSTM oneDNN primitive + const auto direction = + is_reverse ? dnnl::rnn_direction::unidirectional_right2left + : dnnl::rnn_direction::unidirectional_left2right; + if (!use_peepholes) { + this->AcquireForwardPrimitiveDescriptor( + this->attr_, dnnl::prop_kind::forward_inference, direction, + input_md, h0_md, c0_md, weight_x_md, weight_h_md, bias_md, + hidden_md, dnnl::memory::desc(), dnnl::memory::desc()); + } else { + auto weight_peephole_md = + MKLDNNMemDesc({L, D, 3, OC}, MKLDNNGetDataType(), + MKLDNNMemoryFormat::ldgo); + this->AcquireForwardPrimitiveDescriptor( + this->attr_, dnnl::prop_kind::forward_inference, direction, + input_md, h0_md, c0_md, weight_x_md, weight_h_md, + weight_peephole_md, bias_md, hidden_md, dnnl::memory::desc(), + dnnl::memory::desc()); + } + } + } + + // PaddlePaddle has different order of weights than oneDNN, so a reorder is + // needed + // PaddlePaddle: {c, i, f, o} + // oneDNN: {i, f, c, o} + void ReorderGates(float* weights, int64_t I) { + size_t inner_block_size = this->OC; + size_t block_size = inner_block_size * this->G; + for (size_t i = 0; i < (size_t)I; ++i) { + size_t offset = i * block_size; + + float* base_pos = weights + offset; + std::swap_ranges(base_pos, base_pos + inner_block_size, + base_pos + inner_block_size); // c <-> i + std::swap_ranges(base_pos + inner_block_size, + base_pos + 2 * inner_block_size, + base_pos + 2 * inner_block_size); // c <-> f + } + } + + std::shared_ptr AcquireWeightXMemory(const Tensor* weight_x) { + const std::string wx_key = this->memory_key_ + "@weight_x"; + auto memory_p = + std::static_pointer_cast(this->dev_ctx_.GetBlob(wx_key)); + + if (!memory_p) { + auto user_md = + MKLDNNMemDesc({1, 1, this->IC, this->G, this->OC}, + MKLDNNGetDataType(), MKLDNNMemoryFormat::ldigo); + auto user_memory = dnnl::memory(user_md, this->engine_); + + auto* weight_x_data = + reinterpret_cast(user_memory.get_data_handle()); + memcpy(weight_x_data, weight_x->data(), + sizeof(float) * this->IC * this->G * this->OC); + + ReorderGates(weight_x_data, this->IC); + + memory_p = std::make_shared( + this->fwd_pd_->weights_layer_desc(), this->engine_); + + auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream(); + dnnl::reorder(user_memory, *memory_p, this->attr_) + .execute(astream, user_memory, *memory_p); + + this->dev_ctx_.SetBlob(wx_key, memory_p); + } + return memory_p; + } + + std::shared_ptr AcquireWeightHMemory(const Tensor* weight_h) { + const std::string wh_key = this->memory_key_ + "@weight_h"; + auto memory_p = + std::static_pointer_cast(this->dev_ctx_.GetBlob(wh_key)); + + if (!memory_p) { + auto user_md = + MKLDNNMemDesc({1, 1, this->OC, this->G, this->OC}, + MKLDNNGetDataType(), MKLDNNMemoryFormat::ldigo); + auto user_memory = dnnl::memory(user_md, this->engine_); + + auto* weight_h_data = + reinterpret_cast(user_memory.get_data_handle()); + memcpy(weight_h_data, weight_h->data(), + sizeof(float) * this->OC * this->G * this->OC); + + ReorderGates(weight_h_data, this->OC); + + memory_p = std::make_shared( + this->fwd_pd_->weights_iter_desc(), this->engine_); + + auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream(); + dnnl::reorder(user_memory, *memory_p, this->attr_) + .execute(astream, user_memory, *memory_p); + + this->dev_ctx_.SetBlob(wh_key, memory_p); + } + return memory_p; + } + + std::shared_ptr AcquireBiasMemory(const Tensor* bias) { + const std::string bias_key = this->memory_key_ + "@bias"; + auto memory_p = std::static_pointer_cast( + this->dev_ctx_.GetBlob(bias_key)); + + if (!memory_p) { + memory_p = std::make_shared(this->fwd_pd_->bias_desc(), + this->engine_); + auto* bias_data = reinterpret_cast(memory_p->get_data_handle()); + if (bias) { + const float* user_bias_data = + bias->data(); // Bias in oneDNN is always float + + memcpy(bias_data, user_bias_data, sizeof(float) * this->G * this->OC); + + ReorderGates(bias_data, 1); + } else { + // oneDNN always need bias memory, if it's not provided in PP, let + // oneDNN allocate memory and set it to 0 + memset(bias_data, 0, sizeof(float) * this->G * this->OC); + } + + this->dev_ctx_.SetBlob(bias_key, memory_p); + } + return memory_p; + } + + std::shared_ptr AcquirePeepholeWeights(const Tensor* bias) { + const std::string peepholes_key = this->memory_key_ + "@peepholes_weights"; + auto memory_p = std::static_pointer_cast( + this->dev_ctx_.GetBlob(peepholes_key)); + + if (!memory_p) { + auto user_md = + MKLDNNMemDesc({1, 1, 3, this->OC}, MKLDNNGetDataType(), + MKLDNNMemoryFormat::ldgo); + auto user_memory = dnnl::memory(user_md, this->engine_); + memory_p = std::make_shared( + this->fwd_pd_->weights_peephole_desc(), this->engine_); + auto* peephole_weights_data = + reinterpret_cast(memory_p->get_data_handle()); + + const float* user_bias_data = + bias->data(); // Bias in oneDNN is always float + memcpy(peephole_weights_data, user_bias_data + 4 * this->OC, + sizeof(float) * 3 * this->OC); + + this->dev_ctx_.SetBlob(peepholes_key, memory_p); + } + return memory_p; + } + + std::shared_ptr AcquireC0Memory(const Tensor* c0) { + const std::string c0_key = this->memory_key_ + "@c0"; + auto memory_p = + std::static_pointer_cast(this->dev_ctx_.GetBlob(c0_key)); + + if (!memory_p) { + auto user_c0_memory = dnnl::memory(); + if (c0) { + user_c0_memory = + dnnl::memory({{1, 1, this->N, this->OC}, + MKLDNNGetDataType(), + MKLDNNMemoryFormat::ldnc}, + this->engine_, to_void_cast(c0->data())); + } else { + user_c0_memory = dnnl::memory({{1, 1, this->N, this->OC}, + MKLDNNGetDataType(), + MKLDNNMemoryFormat::ldnc}, + this->engine_); + memset(user_c0_memory.get_data_handle(), 0, + sizeof(float) * this->N * this->OC); + } + memory_p = std::make_shared(this->fwd_pd_->src_iter_desc(), + this->engine_); + + auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream(); + dnnl::reorder(user_c0_memory, *memory_p, this->attr_) + .execute(astream, user_c0_memory, *memory_p); + + this->dev_ctx_.SetBlob(c0_key, memory_p); + } + return memory_p; + } +}; + +template +class FusionLSTMMKLDNNKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + RunKernel(ctx); + } + + template + void RunKernel(const framework::ExecutionContext& ctx) const { + auto& dev_ctx = + ctx.template device_context(); + const auto& mkldnn_engine = dev_ctx.GetEngine(); + + // Get Tensors + const auto* input = ctx.Input("X"); + const auto* h0 = ctx.Input("H0"); + const auto* c0 = ctx.Input("C0"); + const auto* weight_x = ctx.Input("WeightX"); + const auto* weight_h = ctx.Input("WeightH"); + const auto* bias = ctx.Input("Bias"); + auto* hidden = ctx.Output("Hidden"); + auto* cell = ctx.Output("Cell"); + cell = cell; + auto x_dims = input->dims(); + auto x_mat_dims = (x_dims.size() == 3 && x_dims[1] == 1) + ? framework::flatten_to_2d(x_dims, 1) + : x_dims; + // Get attributes + const bool is_reverse = ctx.Attr("is_reverse"); + const bool use_peepholes = ctx.Attr("use_peepholes"); + + // Get tensor dimensions + const auto x_mat_dims_vec = framework::vectorize(x_mat_dims); + const auto weight_h_dims = framework::vectorize(weight_h->dims()); + const auto& input_lod = input->lod()[0]; + + // Calculate RNN dimensions + const int64_t N = input_lod.size() - 1; // Number of sentences (batches) + const int64_t Ti = // Max length of the sentence in a batch + [&input_lod]() { + size_t res = 0; + for (size_t i = 0; i < (input_lod.size() - 1); ++i) { + res = std::max(res, input_lod[i + 1] - input_lod[i]); + } + return res; + }(); + const int64_t IC = x_mat_dims_vec[1]; // Input channels + const int64_t OC = weight_h_dims[0]; // Output channels + + LSTMMKLDNNHandler handler( + ctx, dev_ctx, mkldnn_engine, ctx.GetPlace(), input, weight_h, h0, c0, + is_reverse, N, Ti, IC, OC, + ctx.InputName("X") + ctx.InputName("WeightH")); + + auto input_memory_p = + handler.AcquireInputMemoryWithReorder(input, is_reverse); + auto h0_memory_p = handler.AcquireH0Memory(h0); + auto c0_memory_p = handler.AcquireC0Memory(c0); + auto weight_x_memory_p = handler.AcquireWeightXMemory(weight_x); + auto weight_h_memory_p = handler.AcquireWeightHMemory(weight_h); + auto bias_memory_p = handler.AcquireBiasMemory(bias); + auto hidden_onednn_memory_p = handler.AcquireOutputMemory(); + + std::unordered_map lstm_args = { + {DNNL_ARG_SRC_LAYER, *input_memory_p}, + {DNNL_ARG_SRC_ITER, *h0_memory_p}, + {DNNL_ARG_SRC_ITER_C, *c0_memory_p}, + {DNNL_ARG_WEIGHTS_LAYER, *weight_x_memory_p}, + {DNNL_ARG_WEIGHTS_ITER, *weight_h_memory_p}, + {DNNL_ARG_BIAS, *bias_memory_p}, + {DNNL_ARG_DST_LAYER, *hidden_onednn_memory_p}}; + + if (use_peepholes) { + auto peephole_weight_p = handler.AcquirePeepholeWeights(bias); + std::pair peepholes_weights(DNNL_ARG_WEIGHTS_PEEPHOLE, + *peephole_weight_p); + lstm_args.insert(peepholes_weights); + } + + auto lstm_forward_p = handler.AcquireForwardPrimitive(); + + auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream(); + lstm_forward_p->execute(astream, lstm_args); + astream.wait(); + + auto* hidden_onednn_data = hidden_onednn_memory_p->get_data_handle(); + auto* hidden_data = + to_void_cast(hidden->mutable_data(ctx.GetPlace())); + if (handler.is_NTC()) { + handler.reorderRNNdata(hidden_onednn_data, hidden_data, input_lod, + is_reverse, platform::RNNReorderType::NTC_PP); + } else { + handler.reorderRNNdata(hidden_onednn_data, hidden_data, input_lod, + is_reverse, platform::RNNReorderType::TNC_PP); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_KERNEL(fusion_lstm, MKLDNN, paddle::platform::CPUPlace, + ops::FusionLSTMMKLDNNKernel); diff --git a/paddle/fluid/operators/fused/mkldnn/fusion_rnn_mkldnn.h b/paddle/fluid/operators/fused/mkldnn/fusion_rnn_mkldnn.h new file mode 100644 index 0000000000000000000000000000000000000000..f102c535fdf56a076766b36b758deacfe4455266 --- /dev/null +++ b/paddle/fluid/operators/fused/mkldnn/fusion_rnn_mkldnn.h @@ -0,0 +1,229 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/platform/mkldnn_reuse.h" + +namespace paddle { +namespace operators { + +using paddle::framework::LoDTensor; +using paddle::framework::Tensor; +using paddle::platform::CPUDeviceContext; +using paddle::platform::CreateKey; +using paddle::platform::MKLDNNGetDataType; +using paddle::platform::MKLDNNMemDesc; +using platform::to_void_cast; + +template +class RNNMKLDNNHandler : public platform::MKLDNNHandlerT { + public: + RNNMKLDNNHandler(const paddle::framework::ExecutionContext& ctx, + const platform::MKLDNNDeviceContext& dev_ctx, + const mkldnn::engine mkldnn_engine, + platform::Place cpu_place, const LoDTensor* input, + const Tensor* weight_h, const Tensor* h0, + const bool is_reverse, const int64_t N, const int64_t Ti, + const int64_t IC, const int64_t OC, const int64_t G, + const std::string& unique_name) + : platform::MKLDNNHandlerT( + dev_ctx, dev_ctx.GetEngine(), cpu_place, + CreateKey(dev_ctx, unique_name, MKLDNNGetDataType(), Ti)), + N(N), + Ti(Ti), + IC(IC), + OC(OC), + G(G) { + // Create memory key without Ti because weights, bias and h0 memories + // do not depend on Ti size but primitive and input/output memory do + memory_key_ = platform::ExtendKeyWithThreadInfoIfNeeded( + dev_ctx, CreateKey(dev_ctx, unique_name, MKLDNNGetDataType())); + + // Is it int8 kernel + const bool is_INT8 = std::is_same::value; + + if (is_INT8) { + // Int8 attributes + const float scale_data = ctx.Attr("Scale_data"); + const float shift_data = ctx.Attr("Shift_data"); + const auto scale_weights = ctx.Attr>("Scale_weights"); + + const int weights_scale_mask = + 0 + + (1 << 3) // bit, indicating the unique scales for `g` dim in `ldigo` + + + (1 << 4); // bit, indicating the unique scales for `o` dim in `ldigo` + + attr_.set_rnn_data_qparams(scale_data, shift_data); + attr_.set_rnn_weights_qparams(weights_scale_mask, scale_weights); + } + } + + bool is_NTC() { + return (platform::GetMKLDNNFormat(this->fwd_pd_->dst_desc()) == + dnnl::memory::format_tag::ntc); + } + + void reorderRNNdata(void* input_data, void* output_data, + std::vector lod, const bool is_reverse, + platform::RNNReorderType reorder_type) { + switch (reorder_type) { + // Reorder input memory [WORDS, C] + LoD -> [N, T, C] + case platform::RNNReorderType::PP_NTC: { + auto* input_data_iter = reinterpret_cast(input_data); + auto* output_data_iter = reinterpret_cast(output_data); + for (int n = 0; n < N; ++n) { + const auto num_elements = (lod[n + 1] - lod[n]) * IC; + const auto offset = is_reverse ? (Ti * IC - num_elements) : 0; + memcpy(output_data_iter + n * Ti * IC + offset, input_data_iter, + sizeof(T) * num_elements); + input_data_iter += num_elements; + } + } break; + // Reorder input memory [WORDS, C] + LoD -> [T, N, C] + case platform::RNNReorderType::PP_TNC: { + auto* input_data_iter = reinterpret_cast(input_data); + auto* output_data_iter = reinterpret_cast(output_data); + for (int n = 0; n < N; ++n) { + const auto num_elements = (lod[n + 1] - lod[n]); + const auto offset = is_reverse ? (Ti - num_elements) : 0; + for (size_t t = 0; t < num_elements; ++t) { + memcpy(output_data_iter + (t + offset) * N * IC + n * IC, + input_data_iter, sizeof(T) * IC); + input_data_iter += IC; + } + } + } break; + // Reorder output values to PP format [N, T, C] -> [WORDS, C] + case platform::RNNReorderType::NTC_PP: { + auto* input_data_iter = reinterpret_cast(input_data); + auto* output_data_iter = reinterpret_cast(output_data); + for (int n = 0; n < N; ++n) { + const auto num_elements = (lod[n + 1] - lod[n]) * OC; + const auto offset = is_reverse ? (Ti * OC - num_elements) : 0; + memcpy(output_data_iter, input_data_iter + n * Ti * OC + offset, + sizeof(T_out) * num_elements); + output_data_iter += num_elements; + } + } break; + // Reorder output values to PP format [T, N, C] -> [WORDS, C] + case platform::RNNReorderType::TNC_PP: { + auto* input_data_iter = reinterpret_cast(input_data); + auto* output_data_iter = reinterpret_cast(output_data); + for (int n = 0; n < N; ++n) { + const auto num_elements = lod[n + 1] - lod[n]; + const auto offset = is_reverse ? (Ti - num_elements) : 0; + for (size_t t = 0; t < num_elements; ++t) { + memcpy(output_data_iter, + input_data_iter + (t + offset) * N * OC + n * OC, + sizeof(T_out) * OC); + output_data_iter += OC; + } + } + } break; + } + } + + std::shared_ptr AcquireInputMemoryWithReorder( + const LoDTensor* input, const bool is_reverse) { + const auto name = this->key_ + "@input_mem"; + auto memory_p = + std::static_pointer_cast(this->dev_ctx_.GetBlob(name)); + + if (!memory_p) { + memory_p = std::make_shared(this->fwd_pd_->src_desc(), + this->engine_); + this->dev_ctx_.SetBlob(name, memory_p); + } + + const auto& input_lod = input->lod()[0]; + auto* x_data = to_void_cast(input->data()); + + auto* x_onednn_data = memory_p->get_data_handle(); + memset(x_onednn_data, 0, sizeof(T) * N * Ti * IC); + + if (platform::GetMKLDNNFormat(this->fwd_pd_->src_desc()) == + dnnl::memory::format_tag::ntc) { + reorderRNNdata(x_data, x_onednn_data, input_lod, is_reverse, + platform::RNNReorderType::PP_NTC); + } else { + reorderRNNdata(x_data, x_onednn_data, input_lod, is_reverse, + platform::RNNReorderType::PP_TNC); + } + return memory_p; + } + + std::shared_ptr AcquireOutputMemory() { + const auto name = this->key_ + "@output_mem"; + auto memory_p = + std::static_pointer_cast(this->dev_ctx_.GetBlob(name)); + + if (!memory_p) { + memory_p = std::make_shared(this->fwd_pd_->dst_desc(), + this->engine_); + this->dev_ctx_.SetBlob(name, memory_p); + } + return memory_p; + } + + // TODO(grygielski) H0 is for now persistable + // TODO(jczaja) H0 should be updated each iter and of T type (Fusion pass does + // not support in yet) + std::shared_ptr AcquireH0Memory(const Tensor* h0) { + const std::string h0_key = memory_key_ + "@h0"; + auto memory_p = + std::static_pointer_cast(this->dev_ctx_.GetBlob(h0_key)); + + if (!memory_p) { + auto user_h0_memory = dnnl::memory(); + if (h0) { + user_h0_memory = + dnnl::memory({{1, 1, N, OC}, + MKLDNNGetDataType(), + MKLDNNMemoryFormat::ldnc}, + this->engine_, to_void_cast(h0->data())); + } else { + user_h0_memory = dnnl::memory({{1, 1, N, OC}, + MKLDNNGetDataType(), + MKLDNNMemoryFormat::ldnc}, + this->engine_); + memset(user_h0_memory.get_data_handle(), 0, sizeof(float) * N * OC); + } + memory_p = std::make_shared(this->fwd_pd_->src_iter_desc(), + this->engine_); + + auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream(); + dnnl::reorder(user_h0_memory, *memory_p, attr_) + .execute(astream, user_h0_memory, *memory_p); + + this->dev_ctx_.SetBlob(h0_key, memory_p); + } + return memory_p; + } + + protected: + // RNN dimensions + // N - Batch Size + // Ti - Max sentence length + // IC - Input Channels + // OC - Output Channels + // G - Number of gates + const int64_t N, Ti, IC, OC, G; + + // Memory size of weights, bias and h0 does not depend + // on Ti size, thus we need another key to cache them + std::string memory_key_; + dnnl::primitive_attr attr_; +}; +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_fusion_gru_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_fusion_gru_mkldnn_op.py index cfbbf7de22087d13aed1f8293d362aead5ae03b3..3c70380493d9a079755d3553c5b2b3eb2445c02e 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_fusion_gru_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_fusion_gru_mkldnn_op.py @@ -75,4 +75,6 @@ class TestFusionGRUMKLDNNOpBS1(TestFusionGRUOp): if __name__ == "__main__": + from paddle import enable_static + enable_static() unittest.main() diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_fusion_lstm_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_fusion_lstm_mkldnn_op.py new file mode 100644 index 0000000000000000000000000000000000000000..9988a033a7d898551984e825999aade0013fcd16 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_fusion_lstm_mkldnn_op.py @@ -0,0 +1,81 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +from paddle.fluid.tests.unittests.test_fusion_lstm_op import TestFusionLSTMOp + + +class TestFusionLSTMONEDNNOp(TestFusionLSTMOp): + def set_conf(self): + self.use_mkldnn = True + + def test_check_output(self): + for use_seq in {True, False}: + self.attrs['use_seq'] = use_seq + self.check_output(check_dygraph=False, no_check_set=["Cell"]) + + +class TestFusionLSTMONEDNNOpReverse(TestFusionLSTMONEDNNOp): + def set_conf(self): + self.is_reverse = True + self.use_mkldnn = True + + +class TestFusionLSTMONEDNNOpInitReverse(TestFusionLSTMONEDNNOp): + def set_conf(self): + self.has_initial_state = True + self.is_reverse = True + self.use_mkldnn = True + + +class TestFusionLSTMONEDNNOpMD1(TestFusionLSTMONEDNNOp): + def set_conf(self): + self.M = 36 + self.D = 8 + self.use_mkldnn = True + + +class TestFusionLSTMONEDNNOpMD2(TestFusionLSTMONEDNNOp): + def set_conf(self): + self.M = 8 + self.D = 8 + self.use_mkldnn = True + + +class TestFusionLSTMONEDNNOpMD3(TestFusionLSTMONEDNNOp): + def set_conf(self): + self.M = 15 + self.D = 3 + self.use_mkldnn = True + + +class TestFusionLSTMONEDNNOpBS1(TestFusionLSTMONEDNNOp): + def set_conf(self): + self.lod = [[3]] + self.D = 16 + self.use_mkldnn = True + + +class TestFusionLSTMONEDNNOpPeepholesInit(TestFusionLSTMONEDNNOp): + def set_conf(self): + self.use_peepholes = True + self.has_initial_state = True + self.use_mkldnn = True + + +if __name__ == '__main__': + from paddle import enable_static + enable_static() + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_fusion_gru_op.py b/python/paddle/fluid/tests/unittests/test_fusion_gru_op.py index d8a5816a42a2fd03ecfaa11f22b602f89a422cda..1e25b8034da0a3449453a759da9fe501c07944c1 100644 --- a/python/paddle/fluid/tests/unittests/test_fusion_gru_op.py +++ b/python/paddle/fluid/tests/unittests/test_fusion_gru_op.py @@ -144,4 +144,6 @@ class TestFusionGRUOpBS1(TestFusionGRUOp): if __name__ == "__main__": + from paddle import enable_static + enable_static() unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py b/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py index e829797ddbbdbb77d6b23e78cdbbb3816b8cce92..3928b6fa034efd123599149651f41a863cb6263e 100644 --- a/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py +++ b/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py @@ -58,6 +58,7 @@ class TestFusionLSTMOp(OpTest): self.act_gate = 'sigmoid' self.act_cell = 'tanh' self.act_cand = 'tanh' + self.use_mkldnn = False self.set_conf() T = sum(self.lod[0]) @@ -110,7 +111,8 @@ class TestFusionLSTMOp(OpTest): 'is_reverse': self.is_reverse, 'gate_activation': self.act_gate, 'cell_activation': self.act_cell, - 'candidate_activation': self.act_cand + 'candidate_activation': self.act_cand, + 'use_mkldnn': self.use_mkldnn } def test_check_output(self): @@ -191,4 +193,6 @@ class TestFusionLSTMOpPeepholesBS1(TestFusionLSTMOp): if __name__ == '__main__': + from paddle import enable_static + enable_static() unittest.main() diff --git a/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py b/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py index 24c89408b55fe38142bf26f6a784f59a55c640e4..f81011717040a35375bdb5bed87392c997f5ab29 100644 --- a/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py +++ b/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py @@ -29,4 +29,5 @@ no_check_set_white_list = [ 'update_loss_scaling', 'cudnn_lstm', 'rnn', + 'fusion_lstm', ] diff --git a/tools/static_mode_white_list.py b/tools/static_mode_white_list.py index ba510d49a8c3bb3e60e0168923f223cd3a7d207c..958aad3cfbaa1b086a2ec6e24ee692ffe89d08e0 100644 --- a/tools/static_mode_white_list.py +++ b/tools/static_mode_white_list.py @@ -601,6 +601,7 @@ STATIC_MODE_TESTING_LIST = [ 'test_bilinear_interp_mkldnn_op', 'test_fusion_gru_int8_mkldnn_op', 'test_fusion_gru_mkldnn_op', + 'test_fusion_lstm_mkldnn_op', 'test_gaussian_random_mkldnn_op', 'test_lrn_mkldnn_op', 'test_matmul_mkldnn_op',