diff --git a/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc index fb856d97403a4d2d982c4f37537ef6d28d89f6b2..a3ba1085fb910db6b6c153ce6d45a503b883f7eb 100644 --- a/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc @@ -65,17 +65,13 @@ class MatMulFactory { public: void CreateAndExecute(const ExecutionContext& ctx) { SetDNNLEngine(ctx); - if (IsInitialized()) { - UpdateDataPointers(ctx); - Execute(); - SetOutputFormat(ctx); - return; + if (!IsInitialized()) { + CreateMemories(ctx); + CreatePrimitive(ctx); + SetInitialized(); } - CreateMemories(ctx); - CreatePrimitive(ctx); - Execute(); + Execute(ctx); SetOutputFormat(ctx); - SetInitialized(); } private: @@ -181,41 +177,63 @@ class MatMulFactory { } MatMulDims GetMatmulDims(const ExecutionContext& ctx) { - math::MatDescriptor mat_dim_x; + math::MatDescriptor x_mat_dims; memory::dims strides_x; - std::tie(mat_dim_x, strides_x) = GetInputDimsAndStrides(ctx, "X"); - math::MatDescriptor mat_dim_y; + std::tie(x_mat_dims, strides_x) = GetInputDimsAndStrides(ctx, "X"); + math::MatDescriptor y_mat_dims; memory::dims strides_y; - std::tie(mat_dim_y, strides_y) = GetInputDimsAndStrides(ctx, "Y"); + std::tie(y_mat_dims, strides_y) = GetInputDimsAndStrides(ctx, "Y"); + + auto x_mat_bs = x_mat_dims.batch_size_; + auto y_mat_bs = y_mat_dims.batch_size_; + PADDLE_ENFORCE_EQ(x_mat_bs > 0 && y_mat_bs > 0 && x_mat_bs != y_mat_bs, + false, platform::errors::InvalidArgument( + "If batch sizes of X and Y are positive," + "they have to be equal.")); + memory::dim out_mat_bs = + x_mat_bs || y_mat_bs ? std::max(x_mat_bs, y_mat_bs) : 1; + + const memory::dim M = x_mat_dims.height_; + const memory::dim N = y_mat_dims.width_; + const memory::dim K = x_mat_dims.width_; + + // Find total batch size of the data + const memory::dim x_bs = (x_mat_bs) ? x_mat_bs : 1; + const memory::dim y_bs = (y_mat_bs) ? y_mat_bs : 1; + const memory::dim total_bs = std::max(x_bs, y_bs); + + // Find batch size for oneDNN primitive + memory::dim onednn_bs = std::min(x_bs, y_bs); + + // Find the number of times the oneDNN primitive has to be executed + execute_loop_steps_ = total_bs / onednn_bs; + if (execute_loop_steps_ > 1) { + x_mat_bs /= execute_loop_steps_; + y_mat_bs /= execute_loop_steps_; + out_mat_bs /= execute_loop_steps_; + } - const auto x_bs = mat_dim_x.batch_size_; - const auto y_bs = mat_dim_y.batch_size_; - PADDLE_ENFORCE_EQ(x_bs > 0 && y_bs > 0 && x_bs != y_bs, false, - platform::errors::InvalidArgument( - "If batch sizes of X and Y are positive," - "they have to be equal.")); - - // Store 1 if both batches are zero, otherwise save the nonzero batch - const memory::dim BS = x_bs || y_bs ? std::max(x_bs, y_bs) : 1; - const memory::dim M = mat_dim_x.height_; - const memory::dim N = mat_dim_y.width_; - const memory::dim K = mat_dim_x.width_; - - batch_size_ = 1; - auto b = BS; - if (BS > 1 && (IsOutputFused(ctx) || IsInputFused(ctx))) { - auto& x_dims = ctx.Input("X")->dims(); - auto& y_dims = ctx.Input("Y")->dims(); - batch_size_ = x_bs > y_bs ? x_dims[0] : y_dims[0]; - b = BS / batch_size_; + // Take original format batch size into account + if (out_mat_bs > 1 && (IsOutputFused(ctx) || IsInputFused(ctx))) { + auto x_orig_bs = ctx.Input("X")->dims()[0]; + auto y_orig_bs = ctx.Input("Y")->dims()[0]; + auto orig_bs = x_mat_bs > y_mat_bs ? x_orig_bs : y_orig_bs; + execute_loop_steps_ *= orig_bs; + onednn_bs /= orig_bs; + x_mat_bs /= orig_bs; + y_mat_bs /= orig_bs; + out_mat_bs /= orig_bs; } - memory::dims x_dims = {b, M, K}; - memory::dims y_dims = {b, K, N}; - memory::dims out_dims = {b, M, N}; - x_offset_ = b * M * K * sizeof(XT); - y_offset_ = b * K * N * sizeof(YT); - out_offset_ = b * M * N * sizeof(OT); + // Set dimensions for the oneDNN memories + memory::dims x_dims = {onednn_bs, M, K}; + memory::dims y_dims = {onednn_bs, K, N}; + memory::dims out_dims = {onednn_bs, M, N}; + + // Find data offsets for each oneDNN primitive execution step + x_offset_ = x_mat_bs * M * K * sizeof(XT); + y_offset_ = y_mat_bs * K * N * sizeof(YT); + out_offset_ = out_mat_bs * M * N * sizeof(OT); // Translate transA and transB if (strides_x.empty()) @@ -226,7 +244,7 @@ class MatMulFactory { : memory::dims{N * K, 1, K}; memory::dims out_strides = memory::dims{M * N, N, 1}; - CorrectStridesWhenFloatOutputFused(ctx, N, b, &out_strides); + CorrectStridesWhenFloatOutputFused(ctx, N, out_mat_bs, &out_strides); return {x_dims, y_dims, out_dims, strides_x, strides_y, out_strides}; } @@ -266,13 +284,15 @@ class MatMulFactory { matmul_prim_ = dnnl::matmul(matmul_pd); } - void Execute() { + void Execute(const ExecutionContext& ctx) { dnnl::stream stream(engine_); - - void* x_ptr = x_mem_.get_data_handle(); - void* y_ptr = y_mem_.get_data_handle(); - void* out_ptr = out_mem_.get_data_handle(); - for (uint16_t i = 0; i < batch_size_; i++) { + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* out = ctx.Output("Out"); + void* x_ptr = to_void_cast(x->data()); + void* y_ptr = to_void_cast(y->data()); + void* out_ptr = to_void_cast(out->mutable_data(ctx.GetPlace())); + for (uint16_t i = 0; i < execute_loop_steps_; i++) { x_mem_.set_data_handle(x_ptr); y_mem_.set_data_handle(y_ptr); out_mem_.set_data_handle(out_ptr); @@ -297,15 +317,6 @@ class MatMulFactory { out->set_layout(DataLayout::kMKLDNN); } - void UpdateDataPointers(const ExecutionContext& ctx) { - auto* x = ctx.Input("X"); - auto* y = ctx.Input("Y"); - auto* out = ctx.Output("Out"); - x_mem_.set_data_handle(to_void_cast(x->data())); - y_mem_.set_data_handle(to_void_cast(y->data())); - out_mem_.set_data_handle(out->mutable_data(ctx.GetPlace())); - } - // If initialized, x memory should've been already initialized bool IsInitialized() { return initialized_; } @@ -326,7 +337,7 @@ class MatMulFactory { uint32_t x_offset_; uint32_t y_offset_; uint32_t out_offset_; - uint16_t batch_size_; + uint16_t execute_loop_steps_; bool initialized_ = false; }; diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_mkldnn_op.py index 9a5443eed1af70e3070d1291a198684ed9d3c15d..2f557f0bf145eec64b88bc67c7192116bfb771c9 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_mkldnn_op.py @@ -48,6 +48,20 @@ class TestDnnlMatMulOp(OpTest): self.check_output() +class TestDnnlMatMulOpMixedDims1(TestDnnlMatMulOp): + def generate_data(self): + self.x = np.random.random((17, 2, 3)).astype("float32") + self.y = np.random.random((3, 4)).astype("float32") + self.out = np.matmul(self.x, self.y) + + +class TestDnnlMatMulOpMixedDims2(TestDnnlMatMulOp): + def generate_data(self): + self.x = np.random.random((2, 3)).astype("float32") + self.y = np.random.random((17, 3, 4)).astype("float32") + self.out = np.matmul(self.x, self.y) + + class TestDnnlMatMulOpAlpha(TestDnnlMatMulOp): def generate_data(self): self.x = np.random.random((17, 2, 3)).astype("float32") @@ -396,10 +410,10 @@ class TestMatMulOpTransposeReshapeBasicFloat( TestMatMulOpTransposeReshapeEmptyFloat): def generate_data(self): self.bs = 8 - self.x = np.random.random( - [self.bs, 12, 128, 128]).astype(self.data_type_) - self.y = np.random.random( - [self.bs, 12, 128, 64]).astype(self.data_type_) + self.x = np.random.random([self.bs, 12, 128, + 128]).astype(self.data_type_) + self.y = np.random.random([self.bs, 12, 128, + 64]).astype(self.data_type_) def init_params_and_out(self): self.transpose_out = [0, 2, 1, 3]