diff --git a/paddle/fluid/operators/fusion_gru_op.cc b/paddle/fluid/operators/fusion_gru_op.cc index a04c1c1263fba659e2d3f623b607e9f476bb40ed..120b2ab440156f6020fd6005dd64a48e9a6918ec 100644 --- a/paddle/fluid/operators/fusion_gru_op.cc +++ b/paddle/fluid/operators/fusion_gru_op.cc @@ -16,10 +16,9 @@ limitations under the License. */ #include // for memcpy #include #include "paddle/fluid/operators/math/blas.h" -#include "paddle/fluid/operators/math/cpu_vec.h" #include "paddle/fluid/operators/math/fc_compute.h" +#include "paddle/fluid/operators/math/jit_kernel.h" #include "paddle/fluid/operators/math/sequence2batch.h" -#include "paddle/fluid/platform/cpu_info.h" namespace paddle { namespace operators { @@ -174,58 +173,44 @@ class FusionGRUKernel : public framework::OpKernel { } } -#define INIT_VEC_FUNC \ - std::function act_gate, act_state; \ - std::function cross; \ - auto& act_gate_str = ctx.Attr("gate_activation"); \ - auto& act_state_str = ctx.Attr("activation"); \ - if (platform::jit::MayIUse(platform::jit::avx)) { \ - math::VecActivations act_functor; \ - act_gate = act_functor(act_gate_str); \ - act_state = act_functor(act_state_str); \ - cross = math::vec_cross; \ - } else { \ - math::VecActivations act_functor; \ - act_gate = act_functor(act_gate_str); \ - act_state = act_functor(act_state_str); \ - cross = math::vec_cross; \ - } - -#define INIT_BASE_INPUT_OUTPUT \ - auto* h0 = ctx.Input("H0"); \ - auto* wx = ctx.Input("WeightX"); \ - auto* wh = ctx.Input("WeightH"); \ - auto* bias = ctx.Input("Bias"); \ - auto* xx = ctx.Output("XX"); \ - auto* hidden_out = ctx.Output("Hidden"); \ - bool is_reverse = ctx.Attr("is_reverse"); - -#define INIT_BASE_SIZES \ - auto x_dims = x->dims(); /* T x M*/ \ - auto wh_dims = wh->dims(); /* D x 3D*/ \ - const int total_T = x_dims[0]; \ - const int M = x_dims[1]; \ - const int D = wh_dims[0]; \ - const int D3 = wh_dims[1]; \ - const int D2 = D * 2; +#define INIT_BASE_DEFINES \ + auto* x = ctx.Input("X"); \ + auto* wh = ctx.Input("WeightH"); \ + auto* xx = ctx.Output("XX"); \ + auto x_lod = x->lod(); \ + auto x_dims = x->dims(); /* T x M*/ \ + auto wh_dims = wh->dims(); /* D x 3D*/ \ + const int total_T = x_dims[0]; \ + const int D3 = wh_dims[1] + +#define INIT_OTHER_DEFINES \ + auto* h0 = ctx.Input("H0"); \ + auto* wx = ctx.Input("WeightX"); \ + auto* bias = ctx.Input("Bias"); \ + auto* hidden_out = ctx.Output("Hidden"); \ + bool is_reverse = ctx.Attr("is_reverse"); \ + const int M = x_dims[1]; \ + const int D = wh_dims[0]; \ + const int D2 = D * 2; \ + const auto& ker = math::jitkernel::KernelPool::Instance() \ + .template Get, \ + const std::string&, const std::string&>( \ + ctx.Attr("gate_activation"), \ + ctx.Attr("activation"), D); \ + const T* x_data = x->data(); \ + const T* wx_data = wx->data(); \ + const T* wh_data = wh->data(); \ + auto place = ctx.GetPlace(); \ + T* xx_data = xx->mutable_data(place) void SeqCompute(const framework::ExecutionContext& ctx) const { using DeviceContext = paddle::platform::CPUDeviceContext; - auto* x = ctx.Input("X"); - INIT_BASE_INPUT_OUTPUT - INIT_BASE_SIZES - INIT_VEC_FUNC - - auto x_lod = x->lod(); + INIT_BASE_DEFINES; + INIT_OTHER_DEFINES; const int N = x_lod[0].size() - 1; - const T* x_data = x->data(); const T* h0_data = h0 ? h0->data() : nullptr; - const T* wx_data = wx->data(); - const T* wh_data = wh->data(); const T* wh_state_data = wh_data + D * D2; - T* xx_data = xx->mutable_data(ctx.GetPlace()); - T* hidden_out_data = hidden_out->mutable_data(ctx.GetPlace()); - + T* hidden_out_data = hidden_out->mutable_data(place); auto blas = math::GetBlas(ctx); math::FCCompute(blas, total_T, D3, M, x_data, wx_data, xx_data, @@ -252,14 +237,7 @@ class FusionGRUKernel : public framework::OpKernel { if (h0_data) { prev_hidden_data = h0_data + bid * D; } else { - // W: {W_update, W_reset; W_state} - // update gate - act_gate(D, xx_data, xx_data); - // state gate - act_state(D, xx_data + D2, xx_data + D2); - // out = a*b - blas.VMUL(D, xx_data, xx_data + D2, hidden_out_data); - // save prev + ker->ComputeH1(xx_data, hidden_out_data); prev_hidden_data = hidden_out_data; tstart = 1; move_step(); @@ -269,17 +247,12 @@ class FusionGRUKernel : public framework::OpKernel { blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D2, D, static_cast(1), prev_hidden_data, D, wh_data, D2, static_cast(1), xx_data, D3); - act_gate(D2, xx_data, xx_data); - // rt = rt*ht_1 inplace result - blas.VMUL(D, prev_hidden_data, xx_data + D, hidden_out_data); - + ker->ComputeHtPart1(xx_data, prev_hidden_data, hidden_out_data); // gemm rt * Ws blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D, D, static_cast(1), hidden_out_data, D, wh_state_data, D, static_cast(1), xx_data + D2, D3); - act_state(D, xx_data + D2, xx_data + D2); - // out = zt*ht~ + (1-zt)*ht_1 - cross(D, xx_data, xx_data + D2, prev_hidden_data, hidden_out_data); + ker->ComputeHtPart2(xx_data, prev_hidden_data, hidden_out_data); // save prev prev_hidden_data = hidden_out_data; move_step(); @@ -289,28 +262,19 @@ class FusionGRUKernel : public framework::OpKernel { void BatchCompute(const framework::ExecutionContext& ctx) const { using DeviceContext = paddle::platform::CPUDeviceContext; - auto* x = ctx.Input("X"); - INIT_BASE_INPUT_OUTPUT - INIT_BASE_SIZES - if (x->lod()[0].size() == 2) { + INIT_BASE_DEFINES; + if (x_lod[0].size() == 2) { xx->Resize({total_T, D3}); SeqCompute(ctx); return; } - INIT_VEC_FUNC - + INIT_OTHER_DEFINES; auto* reordered_h0 = ctx.Output("ReorderedH0"); auto* batched_input = ctx.Output("BatchedInput"); auto* batched_out = ctx.Output("BatchedOut"); - - const T* x_data = x->data(); - const T* wx_data = wx->data(); - const T* wh_data = wh->data(); - T* xx_data = xx->mutable_data(ctx.GetPlace()); - T* batched_input_data = batched_input->mutable_data(ctx.GetPlace()); - T* batched_out_data = batched_out->mutable_data(ctx.GetPlace()); - hidden_out->mutable_data(ctx.GetPlace()); - + T* batched_input_data = batched_input->mutable_data(place); + T* batched_out_data = batched_out->mutable_data(place); + hidden_out->mutable_data(place); auto& dev_ctx = ctx.template device_context(); auto blas = math::GetBlas(dev_ctx); math::LoDTensor2BatchFunctor to_batch; @@ -336,7 +300,7 @@ class FusionGRUKernel : public framework::OpKernel { T* prev_hidden_data = nullptr; if (h0) { // reorder h0 - T* reordered_h0_data = reordered_h0->mutable_data(ctx.GetPlace()); + T* reordered_h0_data = reordered_h0->mutable_data(place); const T* h0_data = h0->data(); prev_hidden_data = reordered_h0_data; size_t sz = sizeof(T) * D; @@ -350,12 +314,7 @@ class FusionGRUKernel : public framework::OpKernel { T* cur_out_data = batched_out_data; // W: {W_update, W_reset; W_state} for (int i = 0; i < max_bs; ++i) { - // update gate - act_gate(D, cur_in_data, cur_in_data); - // state gate - act_state(D, cur_in_data + D2, cur_in_data + D2); - // out = a*b - blas.VMUL(D, cur_in_data, cur_in_data + D2, cur_out_data); + ker->ComputeH1(cur_in_data, cur_out_data); // add offset cur_in_data += D3; cur_out_data += D; @@ -380,10 +339,8 @@ class FusionGRUKernel : public framework::OpKernel { T* cur_out_data = batched_out_data; T* cur_prev_hidden_data = prev_hidden_data; for (int i = 0; i < cur_bs; ++i) { - act_gate(D2, cur_batched_data, cur_batched_data); - // rt = rt*ht_1 inplace result - blas.VMUL(D, cur_prev_hidden_data, cur_batched_data + D, cur_out_data); - + ker->ComputeHtPart1(cur_batched_data, cur_prev_hidden_data, + cur_out_data); cur_batched_data += D3; cur_prev_hidden_data += D; cur_out_data += D; @@ -397,12 +354,8 @@ class FusionGRUKernel : public framework::OpKernel { cur_prev_hidden_data = prev_hidden_data; for (int i = 0; i < cur_bs; ++i) { - // ht~ = act_state(...) - act_state(D, cur_batched_data + D2, cur_batched_data + D2); - // out = zt*ht~ + (1-zt)*ht_1 - cross(D, cur_batched_data, cur_batched_data + D2, cur_prev_hidden_data, - cur_out_data); - + ker->ComputeHtPart2(cur_batched_data, cur_prev_hidden_data, + cur_out_data); cur_batched_data += D3; cur_prev_hidden_data += D; cur_out_data += D; @@ -416,9 +369,8 @@ class FusionGRUKernel : public framework::OpKernel { batched_out->set_lod(batched_lod); to_seq(dev_ctx, *batched_out, hidden_out); } -#undef INIT_VEC_FUNC -#undef INIT_BASE_SIZES -#undef INIT_BASE_INPUT_OUTPUT +#undef INIT_OTHER_DEFINES +#undef INIT_BASE_DEFINES }; } // namespace operators diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt index 5d0c0b4228d8e2890c8b8d8bd10e0df080251350..e5b438cece3cd517c237d2d580d319aa808b7776 100644 --- a/paddle/fluid/operators/math/CMakeLists.txt +++ b/paddle/fluid/operators/math/CMakeLists.txt @@ -75,6 +75,6 @@ endif() cc_test(concat_test SRCS concat_test.cc DEPS concat_and_split) cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info) cc_library(jit_kernel - SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_lstm.cc + SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_rnn.cc DEPS cpu_info cblas) cc_test(jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel) diff --git a/paddle/fluid/operators/math/jit_kernel.h b/paddle/fluid/operators/math/jit_kernel.h index e91e4e8e5adfdfff8163efe7fc1451bc602504e0..9088d0c7a6307c3fbd9707c719ec9e6f6c85fbdb 100644 --- a/paddle/fluid/operators/math/jit_kernel.h +++ b/paddle/fluid/operators/math/jit_kernel.h @@ -142,6 +142,15 @@ class LSTMKernel : public Kernel { const T *wp_data = nullptr) const = 0; }; +template +class GRUKernel : public Kernel { + public: + // compute h1 without h0 + virtual void ComputeH1(T *gates, T *ht) const = 0; + virtual void ComputeHtPart1(T *gates, const T *ht_1, T *ht) const = 0; + virtual void ComputeHtPart2(T *gates, const T *ht_1, T *ht) const = 0; +}; + } // namespace jitkernel } // namespace math } // namespace operators diff --git a/paddle/fluid/operators/math/jit_kernel_lstm.cc b/paddle/fluid/operators/math/jit_kernel_rnn.cc similarity index 65% rename from paddle/fluid/operators/math/jit_kernel_lstm.cc rename to paddle/fluid/operators/math/jit_kernel_rnn.cc index 26bd26e2e171feea569fbd646a9caf03bebbaa46..c0847f0bee415393255983468349ad314716bcb3 100644 --- a/paddle/fluid/operators/math/jit_kernel_lstm.cc +++ b/paddle/fluid/operators/math/jit_kernel_rnn.cc @@ -136,6 +136,21 @@ static std::shared_ptr> GetActKernel( return nullptr; } +template +static std::unique_ptr GetAVXAct(const std::string& type) { + if (type == "sigmoid") { + return std::unique_ptr(new AVXActImpl()); + } else if (type == "relu") { + return std::unique_ptr(new AVXActImpl()); + } else if (type == "tanh") { + return std::unique_ptr(new AVXActImpl()); + } else if (type == "identity" || type == "") { + return std::unique_ptr(new AVXActImpl()); + } + PADDLE_THROW("Not support type: %s", type); + return nullptr; +} + /* LSTM JitKernel */ template class LSTMKernelImpl : public LSTMKernel { @@ -192,61 +207,49 @@ class LSTMKernelImpl : public LSTMKernel { #endif }; -#define INTRI8_FLOAT(isa) \ - template <> \ - LSTMKernelImpl::LSTMKernelImpl( \ - const std::string& act_gate, const std::string& act_cand, \ - const std::string& act_cell, int d) \ - : LSTMKernel() { \ - auto GetAVXAct = [&](const std::string& type) -> std::unique_ptr { \ - if (type == "sigmoid") { \ - return std::unique_ptr(new AVXActImpl()); \ - } else if (type == "relu") { \ - return std::unique_ptr(new AVXActImpl()); \ - } else if (type == "tanh") { \ - return std::unique_ptr(new AVXActImpl()); \ - } else if (type == "identity" || type == "") { \ - return std::unique_ptr(new AVXActImpl()); \ - } \ - PADDLE_THROW("Not support type: %s", type); \ - }; \ - avx_act_gate_ = GetAVXAct(act_gate); \ - avx_act_cand_ = GetAVXAct(act_cand); \ - avx_act_cell_ = GetAVXAct(act_cell); \ - } \ - template <> \ - void LSTMKernelImpl::ComputeCtHt( \ - float* gates, const float* ct_1, float* ct, float* ht, \ - const float* wp_data, float* checked) const { \ - /* gates: W_ch, W_ih, W_fh, W_oh */ \ - __m256 c, i, f, o; \ - c = _mm256_loadu_ps(gates); \ - i = _mm256_loadu_ps(gates + 8); \ - f = _mm256_loadu_ps(gates + 16); \ - o = _mm256_loadu_ps(gates + 24); \ - /* C_t = C_t-1 * fgated + cand_gated * igated*/ \ - c = _mm256_mul_ps(avx_act_cand_->Compute(c), avx_act_gate_->Compute(i)); \ - i = _mm256_loadu_ps(ct_1); \ - f = _mm256_mul_ps(i, avx_act_gate_->Compute(f)); \ - f = _mm256_add_ps(c, f); \ - _mm256_storeu_ps(ct, f); \ - /* H_t = act_cell(C_t) * ogated */ \ - o = _mm256_mul_ps(avx_act_cell_->Compute(f), avx_act_gate_->Compute(o)); \ - _mm256_storeu_ps(ht, o); \ - } \ - template <> \ - void LSTMKernelImpl::ComputeC1H1( \ - float* gates, float* ct, float* ht, const float* wp_data) const { \ - __m256 c, i, o; \ - c = _mm256_loadu_ps(gates); \ - i = _mm256_loadu_ps(gates + 8); \ - o = _mm256_loadu_ps(gates + 24); \ - /* C_t = igated * cgated*/ \ - c = _mm256_mul_ps(avx_act_gate_->Compute(i), avx_act_cand_->Compute(c)); \ - _mm256_storeu_ps(ct, c); \ - /* H_t = act_cell(C_t) * ogated */ \ - o = _mm256_mul_ps(avx_act_cell_->Compute(c), avx_act_gate_->Compute(o)); \ - _mm256_storeu_ps(ht, o); \ +#define INTRI8_FLOAT(isa) \ + template <> \ + LSTMKernelImpl::LSTMKernelImpl( \ + const std::string& act_gate, const std::string& act_cand, \ + const std::string& act_cell, int d) \ + : LSTMKernel() { \ + avx_act_gate_ = GetAVXAct(act_gate); \ + avx_act_cand_ = GetAVXAct(act_cand); \ + avx_act_cell_ = GetAVXAct(act_cell); \ + } \ + template <> \ + void LSTMKernelImpl::ComputeCtHt( \ + float* gates, const float* ct_1, float* ct, float* ht, \ + const float* wp_data, float* checked) const { \ + /* gates: W_ch, W_ih, W_fh, W_oh */ \ + __m256 c, i, f, o; \ + c = _mm256_loadu_ps(gates); \ + i = _mm256_loadu_ps(gates + 8); \ + f = _mm256_loadu_ps(gates + 16); \ + o = _mm256_loadu_ps(gates + 24); \ + /* C_t = C_t-1 * fgated + cand_gated * igated*/ \ + c = _mm256_mul_ps(avx_act_cand_->Compute(c), avx_act_gate_->Compute(i)); \ + i = _mm256_loadu_ps(ct_1); \ + f = _mm256_mul_ps(i, avx_act_gate_->Compute(f)); \ + f = _mm256_add_ps(c, f); \ + _mm256_storeu_ps(ct, f); \ + /* H_t = act_cell(C_t) * ogated */ \ + o = _mm256_mul_ps(avx_act_cell_->Compute(f), avx_act_gate_->Compute(o)); \ + _mm256_storeu_ps(ht, o); \ + } \ + template <> \ + void LSTMKernelImpl::ComputeC1H1( \ + float* gates, float* ct, float* ht, const float* wp_data) const { \ + __m256 c, i, o; \ + c = _mm256_loadu_ps(gates); \ + i = _mm256_loadu_ps(gates + 8); \ + o = _mm256_loadu_ps(gates + 24); \ + /* C_t = igated * cgated*/ \ + c = _mm256_mul_ps(avx_act_gate_->Compute(i), avx_act_cand_->Compute(c)); \ + _mm256_storeu_ps(ct, c); \ + /* H_t = act_cell(C_t) * ogated */ \ + o = _mm256_mul_ps(avx_act_cell_->Compute(c), avx_act_gate_->Compute(o)); \ + _mm256_storeu_ps(ht, o); \ } // TODO(TJ): optimize keq16 @@ -354,6 +357,126 @@ REGISTER_JITKERNEL_ARGS(lstm, LSTMKernel, JITKERNEL_DECLARE_LSTM, #undef JITKERNEL_DECLARE_LSTM #undef JITKERNEL_KEY_LSTM #undef JITKERNEL_NEW_LSTM_IMPL + +/* GRU JitKernel */ +template +class GRUKernelImpl : public GRUKernel { + public: + explicit GRUKernelImpl(const std::string& act_gate, + const std::string& act_state, int d) + : GRUKernel() { + d_ = d; + d2_ = d * 2; + act_gate_d2_ = GetActKernel(act_gate, d2_); + act_gate_d_ = GetActKernel(act_gate, d); + act_state_d_ = GetActKernel(act_state, d); + vmul_d_ = KernelPool::Instance().template Get>(d); + } + + void ComputeH1(T* gates, T* ht) const override { + act_gate_d_->Compute(gates, gates); + act_state_d_->Compute(gates + d2_, gates + d2_); + vmul_d_->Compute(gates, gates + d2_, ht); + } + + void ComputeHtPart1(T* gates, const T* ht_1, T* ht) const override { + // W: {W_update, W_reset; W_state} + act_gate_d2_->Compute(gates, gates); + vmul_d_->Compute(ht_1, gates + d_, ht); + } + + void ComputeHtPart2(T* gates, const T* ht_1, T* ht) const override { + T* y = gates + d2_; + act_state_d_->Compute(y, y); + // out = zt*ht~ + (1-zt)*ht_1 + for (int i = 0; i < d_; ++i) { + ht[i] = gates[i] * y[i] + (static_cast(1) - gates[i]) * ht_1[i]; + } + } + + private: + int d_, d2_; + std::shared_ptr> act_gate_d2_, act_gate_d_, act_state_d_; + std::shared_ptr> vmul_d_; +#ifdef __AVX__ + std::unique_ptr avx_act_gate_, avx_act_state_; +#endif +}; + +#define INTRI8_FLOAT(isa) \ + template <> \ + GRUKernelImpl::GRUKernelImpl( \ + const std::string& act_gate, const std::string& act_state, int d) \ + : GRUKernel() { \ + avx_act_gate_ = GetAVXAct(act_gate); \ + avx_act_state_ = GetAVXAct(act_state); \ + } \ + template <> \ + void GRUKernelImpl::ComputeH1(float* gates, float* ht) \ + const { \ + __m256 u, s; \ + /* W: {W_update, W_reset; W_state} */ \ + u = _mm256_loadu_ps(gates); \ + s = _mm256_loadu_ps(gates + 16); \ + s = _mm256_mul_ps(avx_act_gate_->Compute(u), avx_act_state_->Compute(s)); \ + _mm256_storeu_ps(ht, s); \ + } \ + template <> \ + void GRUKernelImpl::ComputeHtPart1( \ + float* gates, const float* ht_1, float* ht) const { \ + /* not exactly equal the any implementation */ \ + __m256 r, ht0; \ + r = _mm256_loadu_ps(gates + 8); \ + ht0 = _mm256_loadu_ps(ht_1); \ + r = _mm256_mul_ps(avx_act_gate_->Compute(r), ht0); \ + _mm256_storeu_ps(ht, r); \ + } \ + template <> \ + void GRUKernelImpl::ComputeHtPart2( \ + float* gates, const float* ht_1, float* ht) const { \ + /* not exactly equal the any implementation */ \ + __m256 u, s, ht0; \ + u = _mm256_loadu_ps(gates); \ + s = _mm256_loadu_ps(gates + 16); \ + ht0 = _mm256_loadu_ps(ht_1); \ + u = avx_act_gate_->Compute(u); \ + s = _mm256_mul_ps(u, avx_act_state_->Compute(s)); \ + u = _mm256_sub_ps(_mm256_set1_ps(1.f), u); \ + u = _mm256_mul_ps(u, ht0); \ + u = _mm256_add_ps(s, u); \ + _mm256_storeu_ps(ht, u); \ + } + +#ifdef __AVX__ +INTRI8_FLOAT(jit::avx); +#endif +#ifdef __AVX2__ +INTRI8_FLOAT(jit::avx2); +#endif +#ifdef __AVX512F__ +INTRI8_FLOAT(jit::avx512f); +#endif + +#define JITKERNEL_DECLARE_GRU(ker_class, ker_dtype) \ + template <> \ + std::shared_ptr> KernelPool::Get< \ + GRUKernel, const std::string&, const std::string&, int>( \ + const std::string& act_gate, const std::string& act_state, int d) + +#define JITKERNEL_KEY_GRU(ker_key, dtype_key) \ + #ker_key #dtype_key + std::to_string(d) + act_gate + act_state + +#define JITKERNEL_NEW_GRU_IMPL(ker, dtype, isa, k) \ + p = std::dynamic_pointer_cast>( \ + std::make_shared>(act_gate, act_state, d)); + +REGISTER_JITKERNEL_ARGS(gru, GRUKernel, JITKERNEL_DECLARE_GRU, + JITKERNEL_KEY_GRU, JITKERNEL_NEW_GRU_IMPL); + +#undef INTRI8_FLOAT +#undef JITKERNEL_NEW_GRU_IMPL +#undef JITKERNEL_KEY_GRU +#undef JITKERNEL_DECLARE_GRU } // namespace jitkernel } // namespace math } // namespace operators diff --git a/python/paddle/fluid/tests/unittests/test_fusion_gru_op.py b/python/paddle/fluid/tests/unittests/test_fusion_gru_op.py index 36ebc8fb6ea9efdcd1807f5c8917ab1428b3381e..377454e7802e40f90c371987adfe50cce922c764 100644 --- a/python/paddle/fluid/tests/unittests/test_fusion_gru_op.py +++ b/python/paddle/fluid/tests/unittests/test_fusion_gru_op.py @@ -125,6 +125,12 @@ class TestFusionGRUOpMD2(TestFusionGRUOp): self.D = 8 +class TestFusionGRUOpMD3(TestFusionGRUOp): + def set_confs(self): + self.M = 17 + self.D = 15 + + class TestFusionGRUOpBS1(TestFusionGRUOp): def set_confs(self): self.lod = [[3]]