diff --git a/paddle/fluid/operators/bilinear_tensor_product_op.h b/paddle/fluid/operators/bilinear_tensor_product_op.h index 7191711a731676298219d2f4bd95fdece27250da..f23336f7b98d6d71d155373cff3515a8463aecbe 100644 --- a/paddle/fluid/operators/bilinear_tensor_product_op.h +++ b/paddle/fluid/operators/bilinear_tensor_product_op.h @@ -16,7 +16,7 @@ limitations under the License. */ #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/operators/math/blas.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/conv_op.h b/paddle/fluid/operators/conv_op.h index d6f86a5c88e37970379da0afe2a1d46e18b653f4..c51898abb422663a6731a17e0717c62ebf0701f8 100644 --- a/paddle/fluid/operators/conv_op.h +++ b/paddle/fluid/operators/conv_op.h @@ -17,9 +17,9 @@ limitations under the License. */ #include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/depthwise_conv.h" #include "paddle/fluid/operators/math/im2col.h" -#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/vol2col.h" namespace paddle { @@ -161,6 +161,7 @@ class GemmConvKernel : public framework::OpKernel { math::Im2ColFunctor im2col; auto& dev_ctx = context.template device_context(); + auto blas = math::GetBlas(dev_ctx); for (int i = 0; i < batch_size; i++) { Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape); @@ -186,8 +187,7 @@ class GemmConvKernel : public framework::OpKernel { // gemm Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); - math::matmul(dev_ctx, filter_slice, false, col_matrix, - false, T(1.0), &out_slice, T(0.0)); + blas.MatMul(filter_slice, col_matrix, &out_slice); } } } @@ -274,6 +274,7 @@ class GemmConvGradKernel : public framework::OpKernel { math::SetConstant set_zero; auto& dev_ctx = context.template device_context(); + auto blas = math::GetBlas(dev_ctx); if (input_grad) { input_grad->mutable_data(context.GetPlace()); @@ -303,9 +304,7 @@ class GemmConvGradKernel : public framework::OpKernel { col_matrix.ShareDataWith(in_grad_slice); col_matrix.Resize(col_matrix_shape); } - math::matmul(dev_ctx, filter_slice, true, - out_grad_slice, false, T(1.0), - &col_matrix, T(0.0)); + blas.MatMul(filter_slice, true, out_grad_slice, false, &col_matrix); if (is_expand && data_dim == 2U) { col2im(dev_ctx, col, dilations, strides, @@ -352,9 +351,8 @@ class GemmConvGradKernel : public framework::OpKernel { // gemm Tensor filter_grad_slice = filter_grad_.Slice(g * out_step, (g + 1) * out_step); - math::matmul(dev_ctx, out_grad_slice, false, - col_matrix, true, T(1.0), - &filter_grad_slice, T(1.0)); + blas.MatMul(out_grad_slice, false, col_matrix, true, + &filter_grad_slice); } } } diff --git a/paddle/fluid/operators/conv_transpose_op.h b/paddle/fluid/operators/conv_transpose_op.h index bfc0177c2a0da1627fbca532764fdae8167b6b2a..9276e5bfef71a58741c2dfa25b31c2bd07c309b8 100644 --- a/paddle/fluid/operators/conv_transpose_op.h +++ b/paddle/fluid/operators/conv_transpose_op.h @@ -16,8 +16,8 @@ limitations under the License. */ #include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/im2col.h" -#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/vol2col.h" namespace paddle { @@ -118,6 +118,7 @@ class GemmConvTransposeKernel : public framework::OpKernel { output->mutable_data(context.GetPlace()); math::SetConstant set_zero; auto& dev_ctx = context.template device_context(); + auto blas = math::GetBlas(dev_ctx); set_zero(dev_ctx, output, static_cast(0)); math::Col2ImFunctor col2im; @@ -134,9 +135,7 @@ class GemmConvTransposeKernel : public framework::OpKernel { // col_matrix = filter * input_batch // of shape (c * k_h * k_w, h * w) or (c * k_d * k_h * k_w, d * h * w) - math::matmul(dev_ctx, filter, true, input_batch, false, - static_cast(1.0), &col_matrix, - static_cast(0.0)); + blas.MatMul(filter, true, input_batch, false, &col_matrix); if (data_dim == 2U) { // col2im: col_matrix -> dy @@ -213,6 +212,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel { // im2col + gemm (similar to conv-forward) // input need to compute gradient auto& dev_ctx = context.template device_context(); + auto blas = math::GetBlas(dev_ctx); if (input_grad || filter_grad) { Tensor col; col.mutable_data(col_shape, context.GetPlace()); @@ -267,9 +267,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel { // or // (m, c * k_d * k_h * k_w) * (c * k_d * k_h * k_w, d * h * w) -> (m, // d, h, w) - math::matmul( - dev_ctx, filter, false, col_matrix, false, static_cast(1.0), - &input_grad_batch, static_cast(0.0)); + blas.MatMul(filter, false, col_matrix, false, &input_grad_batch); } if (filter_grad) { // input batch @@ -279,9 +277,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel { // or // (m, d * h * w) * (d * h * w, c * k_d * k_h * k_w) -> (m, c * k_d * // k_h * k_w) - math::matmul(dev_ctx, in_batch, false, col_matrix, - true, static_cast(1.0), - &filter_grad_, static_cast(1.0)); + blas.MatMul(in_batch, false, col_matrix, true, &filter_grad_); } } } diff --git a/paddle/fluid/operators/gru_unit_op.h b/paddle/fluid/operators/gru_unit_op.h index 49e657a272cdcf919f4ae88c159d7ef6bdae9c93..2d9faed648aef78da60706e13db3862080c96514 100644 --- a/paddle/fluid/operators/gru_unit_op.h +++ b/paddle/fluid/operators/gru_unit_op.h @@ -14,11 +14,10 @@ limitations under the License. */ #pragma once -#include "paddle/fluid/operators/activation_op.h" -#include "paddle/fluid/operators/math/math_function.h" - #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/activation_op.h" +#include "paddle/fluid/operators/math/blas.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/layer_norm_op.h b/paddle/fluid/operators/layer_norm_op.h index 7b84ba0a7daf10e9e636f62eea6bd759ebec9541..2e54bb497dec11eaeda03a1aa6acfd4cc261dbfe 100644 --- a/paddle/fluid/operators/layer_norm_op.h +++ b/paddle/fluid/operators/layer_norm_op.h @@ -15,8 +15,8 @@ limitations under the License. */ #pragma once #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" - #include "paddle/fluid/operators/elementwise_op_function.h" +#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/math_function.h" namespace paddle { @@ -46,9 +46,9 @@ class RowwiseMean2D { } void operator()(const platform::CUDADeviceContext& context, const framework::Tensor& input, framework::Tensor* out) { - math::gemv( - context, false, left_, right_, 1., input.data(), divisor_.data(), - 0., out->data()); + math::GetBlas(context).GEMV( + false, left_, right_, 1., input.data(), divisor_.data(), 0., + out->data()); } private: @@ -93,9 +93,9 @@ class ColwiseSum2D { void operator()(const platform::CUDADeviceContext& context, const framework::Tensor& input, framework::Tensor* out) { - math::gemv( - context, true, left_, right_, 1., input.data(), divisor_.data(), - 0., out->data()); + math::GetBlas(context).GEMV( + true, left_, right_, 1., input.data(), divisor_.data(), 0., + out->data()); } private: diff --git a/paddle/fluid/operators/lstm_op.h b/paddle/fluid/operators/lstm_op.h index 0707aded8c9aa37d6be92373c274b59b7d6b34b6..7d62d2d020ec2e3a29ad8720a8f04fead3a90a63 100644 --- a/paddle/fluid/operators/lstm_op.h +++ b/paddle/fluid/operators/lstm_op.h @@ -15,9 +15,9 @@ limitations under the License. */ #pragma once #include #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/detail/activation_functions.h" #include "paddle/fluid/operators/math/lstm_compute.h" -#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/sequence2batch.h" namespace paddle { @@ -114,6 +114,7 @@ class LSTMKernel : public framework::OpKernel { auto cand_act = math::detail::GetActivationType( ctx.Attr("candidate_activation")); + auto blas = math::GetBlas(device_ctx); for (size_t n = 0; n < num_batch; n++) { int bstart = static_cast(batch_starts[n]); int bend = static_cast(batch_starts[n + 1]); @@ -129,9 +130,8 @@ class LSTMKernel : public framework::OpKernel { int pre_h_start = static_cast(batch_starts[n - 1]); int pre_h_end = pre_h_start + cur_batch_size; auto pre_hidden_t = batch_hidden.Slice(pre_h_start, pre_h_end); - math::matmul(device_ctx, pre_hidden_t, false, *weight, - false, static_cast(1.0), &gate_t, - static_cast(1.0)); + blas.MatMul(pre_hidden_t, false, *weight, false, static_cast(1.0), + &gate_t, static_cast(1.0)); } else if (hidden_t0) { // If n == 0 and there is no initialized hidden state, that is to say // the H0 is zeros, the calculation W_h * H0 will be skiped. @@ -143,9 +143,8 @@ class LSTMKernel : public framework::OpKernel { Tensor ordered_h0; ReorderInitState(device_ctx, *hidden_t0, order, &ordered_h0, true); - math::matmul(device_ctx, ordered_h0, false, *weight, - false, static_cast(1.0), &gate_t, - static_cast(1.0)); + blas.MatMul(ordered_h0, false, *weight, false, static_cast(1.0), + &gate_t, static_cast(1.0)); } lstm_value.gate_value = gate_t.data(); @@ -282,6 +281,7 @@ class LSTMGradKernel : public framework::OpKernel { auto batch_starts = batch_gate->lod()[0]; size_t num_batch = batch_starts.size() - 1; + auto blas = math::GetBlas(device_ctx); for (int n = static_cast(num_batch) - 1; n >= 0; n--) { int bstart = static_cast(batch_starts[n]); int bend = static_cast(batch_starts[n + 1]); @@ -320,29 +320,25 @@ class LSTMGradKernel : public framework::OpKernel { int pre_h_start = static_cast(batch_starts[n - 1]); int pre_h_end = pre_h_start + cur_batch_size; auto pre_hidden_g = batch_hidden_g.Slice(pre_h_start, pre_h_end); - math::matmul(device_ctx, gate_g, false, *weight, true, - static_cast(1.0), &pre_hidden_g, - static_cast(1.0)); + blas.MatMul(gate_g, false, *weight, true, static_cast(1.0), + &pre_hidden_g, static_cast(1.0)); if (weight_g) { /* backward weight */ auto pre_hidden = batch_hidden.Slice(pre_h_start, pre_h_end); - math::matmul(device_ctx, pre_hidden, true, gate_g, - false, static_cast(1.0), weight_g, - static_cast(1.0)); + blas.MatMul(pre_hidden, true, gate_g, false, static_cast(1.0), + weight_g, static_cast(1.0)); } } else { if (h0 && weight_g) { ReorderInitState(device_ctx, *h0, order, &ordered_h0, true); - math::matmul(device_ctx, ordered_h0, true, gate_g, - false, static_cast(1.0), weight_g, - static_cast(1.0)); + blas.MatMul(ordered_h0, true, gate_g, false, static_cast(1.0), + weight_g, static_cast(1.0)); } if (h0 && h0_g) { ordered_h0_g.mutable_data(h0_g->dims(), ctx.GetPlace()); - math::matmul(device_ctx, gate_g, false, *weight, - true, static_cast(1.0), - &ordered_h0_g, static_cast(0.0)); + blas.MatMul(gate_g, false, *weight, true, static_cast(1.0), + &ordered_h0_g, static_cast(0.0)); } } } diff --git a/paddle/fluid/operators/lstmp_op.h b/paddle/fluid/operators/lstmp_op.h index 628936a3105b95577bef080f05b0bd556b514918..370dd04d1449a8e211febf9a4f9e90e6f5008e20 100644 --- a/paddle/fluid/operators/lstmp_op.h +++ b/paddle/fluid/operators/lstmp_op.h @@ -14,15 +14,14 @@ limitations under the License. */ #pragma once #include +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/activation_op.h" +#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/detail/activation_functions.h" #include "paddle/fluid/operators/math/lstm_compute.h" -#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/sequence2batch.h" -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/op_registry.h" - namespace paddle { namespace operators { @@ -143,7 +142,7 @@ class LSTMPKernel : public framework::OpKernel { auto proj_act = math::detail::GetActivationType( ctx.Attr("proj_activation")); auto& place = *ctx.template device_context().eigen_device(); - + auto blas = math::GetBlas(device_ctx); for (size_t n = 0; n < num_batch; n++) { int bstart = static_cast(batch_starts[n]); int bend = static_cast(batch_starts[n + 1]); @@ -160,9 +159,8 @@ class LSTMPKernel : public framework::OpKernel { int pre_h_start = static_cast(batch_starts[n - 1]); int pre_h_end = pre_h_start + cur_batch_size; auto pre_proj_t = batch_proj.Slice(pre_h_start, pre_h_end); - math::matmul(device_ctx, pre_proj_t, false, *weight, - false, static_cast(1.0), &gate_t, - static_cast(1.0)); + blas.MatMul(pre_proj_t, false, *weight, false, static_cast(1.0), + &gate_t, static_cast(1.0)); } else if (hidden_t0) { // If n == 0 and there is no initialized hidden state, that is to say // the H0 is zeros, the calculation W_h * H0 will be skiped. @@ -176,16 +174,14 @@ class LSTMPKernel : public framework::OpKernel { ordered_proj0->mutable_data(ctx.GetPlace()); ReorderInitState(device_ctx, *hidden_t0, order, &ordered_h0, true); - math::matmul(device_ctx, ordered_h0, false, - *proj_weight, false, static_cast(1.0), - ordered_proj0, static_cast(0.0)); + blas.MatMul(ordered_h0, false, *proj_weight, false, static_cast(1.0), + ordered_proj0, static_cast(0.0)); if (proj_act != math::detail::ActivationType::kIdentity) { auto proj0_dev = EigenMatrix::From(*ordered_proj0); ActCompute(cell_act, place, proj0_dev, proj0_dev); } - math::matmul(device_ctx, *ordered_proj0, false, - *weight, false, static_cast(1.0), - &gate_t, static_cast(1.0)); + blas.MatMul(*ordered_proj0, false, *weight, false, static_cast(1.0), + &gate_t, static_cast(1.0)); } lstmp_value.gate_value = gate_t.data(); @@ -196,9 +192,8 @@ class LSTMPKernel : public framework::OpKernel { device_ctx, lstmp_value, frame_size, cur_batch_size, gate_act, cell_act, cand_act); lstmp_value.prev_state_value = lstmp_value.state_value; - math::matmul(device_ctx, hidden_t, false, *proj_weight, - false, static_cast(1.0), &proj_t, - static_cast(0.0)); + blas.MatMul(hidden_t, false, *proj_weight, false, static_cast(1.0), + &proj_t, static_cast(0.0)); if (proj_act != math::detail::ActivationType::kIdentity) { auto proj_t_dev = EigenMatrix::From(proj_t); ActCompute(cell_act, place, proj_t_dev, proj_t_dev); @@ -361,6 +356,7 @@ class LSTMPGradKernel : public framework::OpKernel { auto batch_starts = batch_gate->lod()[0]; size_t num_batch = batch_starts.size() - 1; + auto blas = math::GetBlas(device_ctx); for (int n = static_cast(num_batch) - 1; n >= 0; n--) { int bstart = static_cast(batch_starts[n]); int bend = static_cast(batch_starts[n + 1]); @@ -375,15 +371,13 @@ class LSTMPGradKernel : public framework::OpKernel { } /* hidden state backwarad */ Tensor out_g = batch_hidden_g.Slice(bstart, bend); - math::matmul(device_ctx, proj_g, false, *proj_weight, - true, static_cast(1.0), &out_g, - static_cast(0.0)); + blas.MatMul(proj_g, false, *proj_weight, true, static_cast(1.0), + &out_g, static_cast(0.0)); /* projection weight backward*/ if (proj_weight_g) { Tensor hidden_t = batch_hidden->Slice(bstart, bend); - math::matmul(device_ctx, hidden_t, true, proj_g, - false, static_cast(1.0), - proj_weight_g, static_cast(1.0)); + blas.MatMul(hidden_t, true, proj_g, false, static_cast(1.0), + proj_weight_g, static_cast(1.0)); } Tensor gate = batch_gate->Slice(bstart, bend); @@ -419,24 +413,21 @@ class LSTMPGradKernel : public framework::OpKernel { int pre_h_start = static_cast(batch_starts[n - 1]); int pre_h_end = pre_h_start + cur_batch_size; auto pre_proj_g = batch_proj_g.Slice(pre_h_start, pre_h_end); - math::matmul(device_ctx, gate_g, false, *weight, true, - static_cast(1.0), &pre_proj_g, - static_cast(1.0)); + blas.MatMul(gate_g, false, *weight, true, static_cast(1.0), + &pre_proj_g, static_cast(1.0)); if (weight_g) { /* weight backward*/ auto pre_proj = batch_proj.Slice(pre_h_start, pre_h_end); - math::matmul(device_ctx, pre_proj, true, gate_g, - false, static_cast(1.0), weight_g, - static_cast(1.0)); + blas.MatMul(pre_proj, true, gate_g, false, static_cast(1.0), + weight_g, static_cast(1.0)); } } else { if (h0 && weight_g) { ReorderInitState(device_ctx, *h0, order, &ordered_h0, true); if (weight_g) { - math::matmul(device_ctx, *ordered_proj0, true, - gate_g, false, static_cast(1.0), - weight_g, static_cast(1.0)); + blas.MatMul(*ordered_proj0, true, gate_g, false, + static_cast(1.0), weight_g, static_cast(1.0)); } } if (h0 && (h0_g || proj_weight_g)) { @@ -444,9 +435,8 @@ class LSTMPGradKernel : public framework::OpKernel { Tensor proj0_g; proj0_g.Resize({in_dims[0], proj_weight->dims()[1]}); proj0_g.mutable_data(ctx.GetPlace()); - math::matmul(device_ctx, gate_g, false, *weight, - true, static_cast(1.0), &proj0_g, - static_cast(0.0)); + blas.MatMul(gate_g, false, *weight, true, static_cast(1.0), + &proj0_g, static_cast(0.0)); if (proj_act != math::detail::ActivationType::kIdentity) { auto proj0_dev = EigenMatrix::From(*ordered_proj0); auto proj0_g_dev = EigenMatrix::From(proj0_g); @@ -454,14 +444,12 @@ class LSTMPGradKernel : public framework::OpKernel { proj0_g_dev); } if (h0_g) { - math::matmul( - device_ctx, proj0_g, false, *proj_weight, true, - static_cast(1.0), &ordered_h0_g, static_cast(0.0)); + blas.MatMul(proj0_g, false, *proj_weight, true, static_cast(1.0), + &ordered_h0_g, static_cast(0.0)); } if (proj_weight_g) { - math::matmul(device_ctx, ordered_h0, true, - proj0_g, false, static_cast(1.0), - proj_weight_g, static_cast(1.0)); + blas.MatMul(ordered_h0, true, proj0_g, false, static_cast(1.0), + proj_weight_g, static_cast(1.0)); } } } diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt index ee0e91132bce52998e9c45b37335618e4354e1cd..f36e9444dfb6dce2a7ea9eba153cda174f1ed6f1 100644 --- a/paddle/fluid/operators/math/CMakeLists.txt +++ b/paddle/fluid/operators/math/CMakeLists.txt @@ -41,7 +41,8 @@ math_library(depthwise_conv) math_library(gru_compute DEPS activation_functions math_function) math_library(im2col) math_library(lstm_compute DEPS activation_functions) -math_library(math_function DEPS cblas) +cc_library(blas SRCS blas.cc DEPS cblas framework_proto) +math_library(math_function DEPS blas) math_library(maxouting) math_library(pooling) math_library(selected_rows_functor DEPS selected_rows math_function) diff --git a/paddle/fluid/operators/math/blas.cc b/paddle/fluid/operators/math/blas.cc new file mode 100644 index 0000000000000000000000000000000000000000..3eeb77546b97a0337b46216d837a4f4cff12c89f --- /dev/null +++ b/paddle/fluid/operators/math/blas.cc @@ -0,0 +1,22 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/math/blas.h" +namespace paddle { +namespace operators { +namespace math { +// Do nothing. Blas is a header only library. +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/math/blas.h b/paddle/fluid/operators/math/blas.h new file mode 100644 index 0000000000000000000000000000000000000000..5cd2f855d1135e6dd8343efdaa9855d2526a3520 --- /dev/null +++ b/paddle/fluid/operators/math/blas.h @@ -0,0 +1,152 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/tensor.h" + +#ifdef PADDLE_WITH_MKLML +#include +#include +#include +#endif + +#ifdef PADDLE_USE_OPENBLAS +#include +#include +#endif + +#ifndef LAPACK_FOUND +extern "C" { +#include // NOLINT +int LAPACKE_sgetrf(int matrix_layout, int m, int n, float* a, int lda, + int* ipiv); +int LAPACKE_dgetrf(int matrix_layout, int m, int n, double* a, int lda, + int* ipiv); +int LAPACKE_sgetri(int matrix_layout, int n, float* a, int lda, + const int* ipiv); +int LAPACKE_dgetri(int matrix_layout, int n, double* a, int lda, + const int* ipiv); +} +#endif + +namespace paddle { +namespace operators { +namespace math { + +template +class Blas { + public: + explicit Blas(const DeviceContext& context) : context_(context) {} + + template + void GEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, + T alpha, const T* A, const T* B, T beta, T* C) const; + + template + void GEMM(bool transA, bool transB, int M, int N, int K, T alpha, const T* A, + int lda, const T* B, int ldb, T beta, T* C, int ldc) const; + + template + void MatMul(const framework::Tensor& mat_a, bool trans_a, + const framework::Tensor& mat_b, bool trans_b, T alpha, + framework::Tensor* mat_out, T beta) const; + + template + void MatMul(const framework::Tensor& mat_a, bool trans_a, + const framework::Tensor& mat_b, bool trans_b, + framework::Tensor* mat_out) const { + MatMul(mat_a, trans_a, mat_b, trans_b, static_cast(1.0), mat_out, + static_cast(0.0)); + } + + template + void MatMul(const framework::Tensor& mat_a, const framework::Tensor& mat_b, + framework::Tensor* mat_out) const { + this->template MatMul(mat_a, false, mat_b, false, mat_out); + } + + template + void AXPY(int n, T alpha, const T* x, T* y) const; + + template + void GEMV(bool trans_a, int M, int N, T alpha, const T* A, const T* B, T beta, + T* C) const; + + template + void BatchedGEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, + int K, T alpha, const T* A, const T* B, T beta, T* C, + int batchCount, int64_t strideA, int64_t strideB) const; + + private: + const DeviceContext& context_; +}; + +template +class BlasT : private Blas { + public: + using Blas::Blas; + + template + void GEMM(ARGS... args) const { + Base()->template GEMM(args...); + } + + template + void MatMul(ARGS... args) const { + Base()->template MatMul(args...); + } + + template + void AXPY(ARGS... args) const { + Base()->template AXPY(args...); + } + + template + void GEMV(ARGS... args) const { + Base()->template GEMV(args...); + } + + template + void BatchedGEMM(ARGS... args) const { + Base()->template BatchedGEMM(args...); + } + + private: + const Blas* Base() const { + return static_cast*>(this); + } +}; + +template +inline BlasT GetBlas( + const framework::ExecutionContext& exe_ctx) { + return BlasT( + exe_ctx.template device_context()); +} + +template +inline BlasT GetBlas(const DeviceContext& dev_ctx) { + return BlasT(dev_ctx); +} + +} // namespace math +} // namespace operators +} // namespace paddle + +#include "paddle/fluid/operators/math/blas_impl.h" +#ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/operators/math/blas_impl.cu.h" +#endif diff --git a/paddle/fluid/operators/math/blas_impl.cu.h b/paddle/fluid/operators/math/blas_impl.cu.h index 89935829ab35a52dd85bcaf906b53e41d576cf3f..c76fc17d78cce514b5e35ce8e5ca890d7cec1e98 100644 --- a/paddle/fluid/operators/math/blas_impl.cu.h +++ b/paddle/fluid/operators/math/blas_impl.cu.h @@ -30,6 +30,25 @@ struct CUBlas { static void GEMM(ARGS... args) { PADDLE_ENFORCE(platform::dynload::cublasSgemm(args...)); } + + template + static void AXPY(ARGS... args) { + PADDLE_ENFORCE(platform::dynload::cublasSaxpy(args...)); + } + + template + static void GEMV(ARGS... args) { + PADDLE_ENFORCE(platform::dynload::cublasSgemv(args...)); + } + + template + static void GEMM_BATCH(ARGS... args) { +#if CUDA_VERSION >= 8000 + PADDLE_ENFORCE(platform::dynload::cublasSgemmStridedBatched(args...)); +#else + PADDLE_THROW("SgemmStridedBatched is not supported on cuda <= 7.5"); +#endif + } }; template <> @@ -38,6 +57,25 @@ struct CUBlas { static void GEMM(ARGS... args) { PADDLE_ENFORCE(platform::dynload::cublasDgemm(args...)); } + + template + static void AXPY(ARGS... args) { + PADDLE_ENFORCE(platform::dynload::cublasDaxpy(args...)); + } + + template + static void GEMV(ARGS... args) { + PADDLE_ENFORCE(platform::dynload::cublasDgemv(args...)); + } + + template + static void GEMM_BATCH(ARGS... args) { +#if CUDA_VERSION >= 8000 + PADDLE_ENFORCE(platform::dynload::cublasDgemmStridedBatched(args...)); +#else + PADDLE_THROW("DgemmStridedBatched is not supported on cuda <= 7.5"); +#endif + } }; template <> @@ -57,16 +95,23 @@ struct CUBlas { reinterpret_cast(beta), reinterpret_cast<__half *>(C), ldc)); } + + template + static void GEMM_BATCH(ARGS... args) { +#if CUDA_VERSION >= 8000 + PADDLE_ENFORCE(platform::dynload::cublasHgemmStridedBatched(args...)); +#else + PADDLE_THROW("HgemmStridedBatched is not supported on cuda <= 7.5"); +#endif + } }; template <> template -void Blas::GEMM(const CBLAS_TRANSPOSE transA, - const CBLAS_TRANSPOSE transB, - const int M, const int N, - const int K, const T alpha, - const T *A, const T *B, - const T beta, T *C) const { +void Blas::GEMM(CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, int M, + int N, int K, T alpha, const T *A, + const T *B, T beta, T *C) const { // Note that cublas follows fortran order, so the order is different from // the cblas convention. int lda = (transA == CblasNoTrans) ? K : M; @@ -83,10 +128,10 @@ void Blas::GEMM(const CBLAS_TRANSPOSE transA, template <> template <> inline void Blas::GEMM( - const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M, - const int N, const int K, const platform::float16 alpha, - const platform::float16 *A, const platform::float16 *B, - const platform::float16 beta, platform::float16 *C) const { + CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, + platform::float16 alpha, const platform::float16 *A, + const platform::float16 *B, platform::float16 beta, + platform::float16 *C) const { // Note that cublas follows fortran order, so the order is different from // the cblas convention. int lda = (transA == CblasNoTrans) ? K : M; @@ -134,18 +179,58 @@ inline void Blas::GEMM( template <> template -void Blas::GEMM( - const bool transA, const bool transB, const int M, const int N, const int K, - const T alpha, const T *A, const int lda, const T *B, const int ldb, - const T beta, T *C, const int ldc) const { +void Blas::GEMM(bool transA, bool transB, int M, + int N, int K, T alpha, const T *A, + int lda, const T *B, int ldb, + T beta, T *C, int ldc) const { // Note that cublas follows fortran order, so the order is different from // the cblas convention. - cublasOperation_t cuTransA = transA == false ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = transB == false ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasOperation_t cuTransA = transA ? CUBLAS_OP_T : CUBLAS_OP_N; + cublasOperation_t cuTransB = transB ? CUBLAS_OP_T : CUBLAS_OP_N; CUBlas::GEMM(context_.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, ldc); } +template <> +template +void Blas::AXPY(int n, T alpha, const T *x, + T *y) const { + CUBlas::AXPY(context_.cublas_handle(), n, &alpha, x, 1, y, 1); +} + +template <> +template +void Blas::GEMV(bool trans_a, int M, int N, + T alpha, const T *A, const T *B, + T beta, T *C) const { + cublasOperation_t cuTransA = !trans_a ? CUBLAS_OP_T : CUBLAS_OP_N; + + CUBlas::GEMV(context_.cublas_handle(), cuTransA, N, M, &alpha, A, N, B, 1, + &beta, C, 1); +} + +template <> +template +void Blas::BatchedGEMM( + CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, + T alpha, const T *A, const T *B, T beta, T *C, int batchCount, + int64_t strideA, int64_t strideB) const { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + int ldc = N; + cublasOperation_t cuTransA = + (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasOperation_t cuTransB = + (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + const int64_t strideC = M * N; + + CUBlas::GEMM_BATCH(context_.cublas_handle(), cuTransB, cuTransA, N, M, K, + &alpha, B, ldb, strideB, A, lda, strideA, &beta, C, ldc, + strideC, batchCount); +} + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h index f6d6669765865386116532c5c65c689aa170eaa6..7360cc0a90da499c372c6fb3f8d40a26f9093dd8 100644 --- a/paddle/fluid/operators/math/blas_impl.h +++ b/paddle/fluid/operators/math/blas_impl.h @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. #pragma once - +#include #include "paddle/fluid/operators/math/math_function.h" namespace paddle { @@ -28,6 +28,23 @@ struct CBlas { static void GEMM(ARGS... args) { cblas_sgemm(args...); } + + template + static void AXPY(ARGS... args) { + cblas_saxpy(args...); + } + + template + static void GEMV(ARGS... args) { + cblas_sgemv(args...); + } + +#ifdef PADDLE_WITH_MKLML + template + static void GEMM_BATCH(ARGS... args) { + cblas_sgemm_batch(args...); + } +#endif }; template <> @@ -36,21 +53,41 @@ struct CBlas { static void GEMM(ARGS... args) { cblas_dgemm(args...); } + + template + static void AXPY(ARGS... args) { + cblas_daxpy(args...); + } + + template + static void GEMV(ARGS... args) { + cblas_dgemv(args...); + } + +#ifdef PADDLE_WITH_MKLML + template + static void GEMM_BATCH(ARGS... args) { + cblas_dgemm_batch(args...); + } +#endif }; template <> struct CBlas { static void GEMM(...) { PADDLE_THROW("float16 GEMM not supported on CPU"); } +#ifdef PADDLE_WITH_MKLML + static void GEMM_BATCH(...) { + PADDLE_THROW("float16 GEMM_BATCH not supported on CPU"); + } +#endif }; template <> template -void Blas::GEMM(const CBLAS_TRANSPOSE transA, - const CBLAS_TRANSPOSE transB, - const int M, const int N, - const int K, const T alpha, - const T *A, const T *B, - const T beta, T *C) const { +void Blas::GEMM(CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, int M, + int N, int K, T alpha, const T *A, + const T *B, T beta, T *C) const { int lda = (transA == CblasNoTrans) ? K : M; int ldb = (transB == CblasNoTrans) ? N : K; int ldc = N; @@ -60,15 +97,89 @@ void Blas::GEMM(const CBLAS_TRANSPOSE transA, template <> template -void Blas::GEMM( - const bool transA, const bool transB, const int M, const int N, const int K, - const T alpha, const T *A, const int lda, const T *B, const int ldb, - const T beta, T *C, const int ldc) const { +void Blas::GEMM(bool transA, bool transB, int M, + int N, int K, T alpha, const T *A, + int lda, const T *B, int ldb, + T beta, T *C, int ldc) const { CBlas::GEMM(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans, transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); } +template +template +void Blas::MatMul(const framework::Tensor &mat_a, bool trans_a, + const framework::Tensor &mat_b, bool trans_b, + T alpha, framework::Tensor *mat_out, + T beta) const { + auto dim_a = mat_a.dims(); + auto dim_b = mat_b.dims(); + auto dim_out = mat_out->dims(); + PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2, + "The input and output of matmul be matrix"); + PADDLE_ENFORCE( + mat_a.place() == mat_b.place() && mat_a.place() == mat_out->place(), + "The places of matrices must be same"); + + int M = dim_out[0]; + int N = dim_out[1]; + int K = !trans_a ? dim_a[1] : dim_a[0]; + + CBLAS_TRANSPOSE transA = !trans_a ? CblasNoTrans : CblasTrans; + CBLAS_TRANSPOSE transB = !trans_b ? CblasNoTrans : CblasTrans; + + this->GEMM(transA, transB, M, N, K, alpha, mat_a.data(), mat_b.data(), + beta, mat_out->data()); +} + +template <> +template +void Blas::AXPY(int n, T alpha, const T *x, + T *y) const { + CBlas::AXPY(n, alpha, x, 1, y, 1); +} + +template <> +template +void Blas::GEMV(bool trans_a, int M, int N, T alpha, + const T *A, const T *B, T beta, + T *C) const { + CBLAS_TRANSPOSE transA = !trans_a ? CblasNoTrans : CblasTrans; + CBlas::GEMV(CblasRowMajor, transA, M, N, alpha, A, N, B, 1, beta, C, 1); +} + +template <> +template +void Blas::BatchedGEMM( + CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, + T alpha, const T *A, const T *B, T beta, T *C, int batchCount, + int64_t strideA, int64_t strideB) const { +#ifdef PADDLE_WITH_MKLML + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + int ldc = N; + auto a_array = std::vector(batchCount); + auto b_array = std::vector(batchCount); + auto c_array = std::vector(batchCount); + for (int k = 0; k < batchCount; ++k) { + a_array[k] = &A[k * strideA]; + b_array[k] = &B[k * strideB]; + c_array[k] = &C[k * M * N]; + } + + CBlas::GEMM_BATCH(CblasRowMajor, &transA, &transB, &M, &N, &K, &alpha, + a_array.data(), &lda, b_array.data(), &ldb, &beta, + c_array.data(), &ldc, 1 /* group_count */, &batchCount); +#else + for (int k = 0; k < batchCount; ++k) { + const float *Ak = &A[k * strideA]; + const float *Bk = &B[k * strideB]; + float *Ck = &C[k * M * N]; + this->template GEMM(transA, transB, M, N, K, alpha, Ak, Bk, beta, Ck); + } +#endif +} + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/math/context_project.h b/paddle/fluid/operators/math/context_project.h index 027a019a284cac097eea50553e7d0dad5b09a218..bc0df3f3551c7a100d5d285cab585bb81c07fc5e 100644 --- a/paddle/fluid/operators/math/context_project.h +++ b/paddle/fluid/operators/math/context_project.h @@ -17,8 +17,8 @@ limitations under the License. */ #include #include #include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/im2col.h" -#include "paddle/fluid/operators/math/math_function.h" namespace paddle { namespace operators { @@ -211,6 +211,7 @@ class ContextProjectGradFunctor { int input_row_begin, input_row_end; int sequence_height, sequence_width; sequence_width = in.dims()[1]; + auto blas = math::GetBlas(context); if (input_grad) { for (int i = 0; i < static_cast(lod_level_0.size()) - 1; ++i) { @@ -262,8 +263,8 @@ class ContextProjectGradFunctor { Tensor out_t_sub = out_t.Slice(k * context_length, k * context_length + padding_size); Tensor w_sub = padding_data->Slice(k, k + padding_size); - axpy(context, w_sub.numel(), static_cast(1), - out_t_sub.data(), w_sub.data()); + blas.AXPY(w_sub.numel(), static_cast(1), out_t_sub.data(), + w_sub.data()); } } if (down_pad > 0) { @@ -294,8 +295,8 @@ class ContextProjectGradFunctor { (down_pad_begin_row + t) * context_length); Tensor w_sub = padding_data->Slice( up_pad + padding_idx, up_pad + padding_idx + padding_size); - axpy(context, w_sub.numel(), static_cast(1), - out_t_sub.data(), w_sub.data()); + blas.AXPY(w_sub.numel(), static_cast(1), out_t_sub.data(), + w_sub.data()); } } out_t.Resize({sequence_height, context_length * sequence_width}); diff --git a/paddle/fluid/operators/math/gru_compute.cc b/paddle/fluid/operators/math/gru_compute.cc index d786250271231179b46ae704c9bd013efe26d910..0e15b81deef43a932d4b2d3f545393b0ad9e080c 100644 --- a/paddle/fluid/operators/math/gru_compute.cc +++ b/paddle/fluid/operators/math/gru_compute.cc @@ -10,9 +10,9 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/math/gru_compute.h" +#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/detail/gru_cpu_kernel.h" #include "paddle/fluid/operators/math/detail/gru_kernel.h" -#include "paddle/fluid/operators/math/math_function.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/math/gru_compute.cu b/paddle/fluid/operators/math/gru_compute.cu index f26bec41095789c197841f4d8362a229b07a2af0..1327d914952d57aab6e5d17090d0ea976a6d4755 100644 --- a/paddle/fluid/operators/math/gru_compute.cu +++ b/paddle/fluid/operators/math/gru_compute.cu @@ -10,10 +10,10 @@ See the License for the specific language governing permissions and limitations under the License. */ #include +#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/detail/gru_gpu_kernel.h" #include "paddle/fluid/operators/math/detail/gru_kernel.h" #include "paddle/fluid/operators/math/gru_compute.h" -#include "paddle/fluid/operators/math/math_function.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/math/math_function.cc b/paddle/fluid/operators/math/math_function.cc index b63676f961bcd488797aca887c281a7d351cfca0..d62ea387cc55c7399973b6f35bace491a49666dc 100644 --- a/paddle/fluid/operators/math/math_function.cc +++ b/paddle/fluid/operators/math/math_function.cc @@ -24,200 +24,6 @@ namespace math { using float16 = paddle::platform::float16; -template <> -void matmul( - const platform::CPUDeviceContext& context, - const framework::Tensor& matrix_a, bool trans_a, - const framework::Tensor& matrix_b, bool trans_b, float16 alpha, - framework::Tensor* matrix_out, float16 beta) { - PADDLE_THROW("float16 matmul not supported on CPU"); -} - -template <> -void matmul( - const platform::CPUDeviceContext& context, - const framework::Tensor& matrix_a, bool trans_a, - const framework::Tensor& matrix_b, bool trans_b, float alpha, - framework::Tensor* matrix_out, float beta) { - auto dim_a = matrix_a.dims(); - auto dim_b = matrix_b.dims(); - auto dim_out = matrix_out->dims(); - PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2, - "The input and output of matmul be matrix"); - - PADDLE_ENFORCE(platform::is_cpu_place(matrix_a.place()) && - platform::is_cpu_place(matrix_b.place()) && - platform::is_cpu_place(matrix_out->place()), - "Matrix must all be in CPUPlace"); - - int M = dim_out[0]; - int N = dim_out[1]; - int K = (trans_a == false) ? dim_a[1] : dim_a[0]; - - CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans; - CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans; - - Blas(context).GEMM( - transA, transB, M, N, K, alpha, matrix_a.data(), - matrix_b.data(), beta, matrix_out->data()); -} - -template <> -void matmul( - const platform::CPUDeviceContext& context, - const framework::Tensor& matrix_a, bool trans_a, - const framework::Tensor& matrix_b, bool trans_b, double alpha, - framework::Tensor* matrix_out, double beta) { - auto dim_a = matrix_a.dims(); - auto dim_b = matrix_b.dims(); - auto dim_out = matrix_out->dims(); - PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2, - "The input and output of matmul be matrix"); - - PADDLE_ENFORCE(platform::is_cpu_place(matrix_a.place()) && - platform::is_cpu_place(matrix_b.place()) && - platform::is_cpu_place(matrix_out->place()), - "Matrix must all be in CPUPlace"); - - int M = dim_out[0]; - int N = dim_out[1]; - int K = (trans_a == false) ? dim_a[1] : dim_a[0]; - - CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans; - CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans; - - Blas(context).GEMM( - transA, transB, M, N, K, alpha, matrix_a.data(), - matrix_b.data(), beta, matrix_out->data()); -} - -template <> -void batched_gemm( - const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA, - const CBLAS_TRANSPOSE transB, const int M, const int N, const int K, - const float16 alpha, const float16* A, const float16* B, const float16 beta, - float16* C, const int batchCount, const int64_t strideA, - const int64_t strideB) { - PADDLE_THROW("float16 batched_gemm not supported on CPU"); -} - -#ifdef PADDLE_WITH_MKLML -// Use cblas_{s,d}gemm_batched if available: Run with 1 group of size batchSize. -template <> -void batched_gemm( - const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA, - const CBLAS_TRANSPOSE transB, const int M, const int N, const int K, - const float alpha, const float* A, const float* B, const float beta, - float* C, const int batchCount, const int64_t strideA, - const int64_t strideB) { - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - int ldc = N; - auto a_array = std::vector(batchCount); - auto b_array = std::vector(batchCount); - auto c_array = std::vector(batchCount); - for (int k = 0; k < batchCount; ++k) { - a_array[k] = &A[k * strideA]; - b_array[k] = &B[k * strideB]; - c_array[k] = &C[k * M * N]; - } - cblas_sgemm_batch(CblasRowMajor, &transA, &transB, &M, &N, &K, &alpha, - a_array.data(), &lda, b_array.data(), &ldb, &beta, - c_array.data(), &ldc, 1 /* group_count */, &batchCount); -} - -template <> -void batched_gemm( - const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA, - const CBLAS_TRANSPOSE transB, const int M, const int N, const int K, - const double alpha, const double* A, const double* B, const double beta, - double* C, const int batchCount, const int64_t strideA, - const int64_t strideB) { - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - int ldc = N; - auto a_array = std::vector(batchCount); - auto b_array = std::vector(batchCount); - auto c_array = std::vector(batchCount); - for (int k = 0; k < batchCount; ++k) { - a_array[k] = &A[k * strideA]; - b_array[k] = &B[k * strideB]; - c_array[k] = &C[k * M * N]; - } - cblas_dgemm_batch(CblasRowMajor, &transA, &transB, &M, &N, &K, &alpha, - a_array.data(), &lda, b_array.data(), &ldb, &beta, - c_array.data(), &ldc, 1 /* group_count */, &batchCount); -} -#else -// The below is a naive but correct serial implementation that just loops -// over the batch dimension. This is a fallback for when the batched gemm -// functions of Intel MKL are not available. In the future, this computation -// should be parallelized. -template <> -void batched_gemm( - const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA, - const CBLAS_TRANSPOSE transB, const int M, const int N, const int K, - const float alpha, const float* A, const float* B, const float beta, - float* C, const int batchCount, const int64_t strideA, - const int64_t strideB) { - for (int k = 0; k < batchCount; ++k) { - const float* Ak = &A[k * strideA]; - const float* Bk = &B[k * strideB]; - float* Ck = &C[k * M * N]; - Blas(context).GEMM(transA, transB, M, N, K, - alpha, Ak, Bk, beta, Ck); - } -} - -template <> -void batched_gemm( - const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA, - const CBLAS_TRANSPOSE transB, const int M, const int N, const int K, - const double alpha, const double* A, const double* B, const double beta, - double* C, const int batchCount, const int64_t strideA, - const int64_t strideB) { - for (int k = 0; k < batchCount; ++k) { - const double* Ak = &A[k * strideA]; - const double* Bk = &B[k * strideB]; - double* Ck = &C[k * M * N]; - Blas(context).GEMM(transA, transB, M, N, K, - alpha, Ak, Bk, beta, Ck); - } -} -#endif - -template <> -void gemv( - const platform::CPUDeviceContext& context, const bool trans_a, const int M, - const int N, const float alpha, const float* A, const float* B, - const float beta, float* C) { - CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans; - cblas_sgemv(CblasRowMajor, transA, M, N, alpha, A, N, B, 1, beta, C, 1); -} - -template <> -void gemv( - const platform::CPUDeviceContext& context, const bool trans_a, const int M, - const int N, const double alpha, const double* A, const double* B, - const double beta, double* C) { - CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans; - cblas_dgemv(CblasRowMajor, transA, M, N, alpha, A, N, B, 1, beta, C, 1); -} - -template <> -void axpy( - const platform::CPUDeviceContext& context, const int n, const float alpha, - const float* x, float* y) { - cblas_saxpy(n, alpha, x, 1, y, 1); -} - -template <> -void axpy( - const platform::CPUDeviceContext& context, const int n, const double alpha, - const double* x, double* y) { - cblas_daxpy(n, alpha, x, 1, y, 1); -} - template struct SetConstant; template struct SetConstant; template struct SetConstant; diff --git a/paddle/fluid/operators/math/math_function.cu b/paddle/fluid/operators/math/math_function.cu index 7bf816ac190a9b848b12ea07e655449802a26bc3..b5bf84e5178c143de35ec6dcb16b1bde5577c166 100644 --- a/paddle/fluid/operators/math/math_function.cu +++ b/paddle/fluid/operators/math/math_function.cu @@ -15,6 +15,7 @@ limitations under the License. */ #define EIGEN_USE_GPU #include #include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function_impl.h" #include "paddle/fluid/platform/float16.h" @@ -25,223 +26,6 @@ namespace math { using float16 = paddle::platform::float16; -template <> -void matmul( - const platform::CUDADeviceContext& context, - const framework::Tensor& matrix_a, bool trans_a, - const framework::Tensor& matrix_b, bool trans_b, float16 alpha, - framework::Tensor* matrix_out, float16 beta) { - auto dim_a = matrix_a.dims(); - auto dim_b = matrix_b.dims(); - auto dim_out = matrix_out->dims(); - PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2, - "The input and output of matmul be matrix"); - - PADDLE_ENFORCE(platform::is_gpu_place(matrix_a.place()) && - platform::is_gpu_place(matrix_b.place()) && - platform::is_gpu_place(matrix_out->place()), - "Matrix must all be in CUDAPlace"); - - int M = dim_out[0]; - int N = dim_out[1]; - int K = (trans_a == false) ? dim_a[1] : dim_a[0]; - - CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans; - CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans; - - Blas(context).GEMM( - transA, transB, M, N, K, alpha, matrix_a.data(), - matrix_b.data(), beta, matrix_out->data()); -} - -template <> -void matmul( - const platform::CUDADeviceContext& context, - const framework::Tensor& matrix_a, bool trans_a, - const framework::Tensor& matrix_b, bool trans_b, float alpha, - framework::Tensor* matrix_out, float beta) { - auto dim_a = matrix_a.dims(); - auto dim_b = matrix_b.dims(); - auto dim_out = matrix_out->dims(); - PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2, - "The input and output of matmul be matrix"); - - PADDLE_ENFORCE(platform::is_gpu_place(matrix_a.place()) && - platform::is_gpu_place(matrix_b.place()) && - platform::is_gpu_place(matrix_out->place()), - "Matrix must all be in CUDAPlace"); - - int M = dim_out[0]; - int N = dim_out[1]; - int K = (trans_a == false) ? dim_a[1] : dim_a[0]; - - CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans; - CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans; - - Blas(context).GEMM( - transA, transB, M, N, K, alpha, matrix_a.data(), - matrix_b.data(), beta, matrix_out->data()); -} - -template <> -void matmul( - const platform::CUDADeviceContext& context, - const framework::Tensor& matrix_a, bool trans_a, - const framework::Tensor& matrix_b, bool trans_b, double alpha, - framework::Tensor* matrix_out, double beta) { - auto dim_a = matrix_a.dims(); - auto dim_b = matrix_b.dims(); - auto dim_out = matrix_out->dims(); - PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2, - "The input and output of matmul be matrix"); - - PADDLE_ENFORCE(platform::is_gpu_place(matrix_a.place()) && - platform::is_gpu_place(matrix_b.place()) && - platform::is_gpu_place(matrix_out->place()), - "Matrix must all be in CUDAPlace"); - - int M = dim_out[0]; - int N = dim_out[1]; - int K = (trans_a == false) ? dim_a[1] : dim_a[0]; - - CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans; - CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans; - - Blas(context).GEMM( - transA, transB, M, N, K, alpha, matrix_a.data(), - matrix_b.data(), beta, matrix_out->data()); -} - -template <> -void batched_gemm( - const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA, - const CBLAS_TRANSPOSE transB, const int M, const int N, const int K, - const float16 alpha, const float16* A, const float16* B, const float16 beta, - float16* C, const int batchCount, const int64_t strideA, - const int64_t strideB) { -#if CUDA_VERSION >= 8000 - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - int ldc = N; - cublasOperation_t cuTransA = - (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = - (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - const int64_t strideC = M * N; - - const half h_alpha = static_cast(alpha); - const half h_beta = static_cast(beta); - const half* h_A = reinterpret_cast(A); - const half* h_B = reinterpret_cast(B); - half* h_C = reinterpret_cast(C); - - // TODO(kexinzhao): add processing code for compute capability < 53 case - PADDLE_ENFORCE_GE(context.GetComputeCapability(), 53, - "cublas Hgemm requires GPU compute capability >= 53"); - - PADDLE_ENFORCE(platform::dynload::cublasHgemmStridedBatched( - context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, h_B, ldb, - strideB, h_A, lda, strideA, &h_beta, h_C, ldc, strideC, batchCount)); -#else - PADDLE_ENFORCE(false, "HgemmStridedBatched is not supported on cuda <= 7.5"); -#endif -} - -template <> -void batched_gemm( - const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA, - const CBLAS_TRANSPOSE transB, const int M, const int N, const int K, - const float alpha, const float* A, const float* B, const float beta, - float* C, const int batchCount, const int64_t strideA, - const int64_t strideB) { -#if CUDA_VERSION >= 8000 - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - int ldc = N; - cublasOperation_t cuTransA = - (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = - (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - const int64_t strideC = M * N; - - PADDLE_ENFORCE(platform::dynload::cublasSgemmStridedBatched( - context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb, - strideB, A, lda, strideA, &beta, C, ldc, strideC, batchCount)); -#else - PADDLE_ENFORCE(false, "SgemmStridedBatched is not supported on cuda <= 7.5"); -#endif -} - -template <> -void batched_gemm( - const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA, - const CBLAS_TRANSPOSE transB, const int M, const int N, const int K, - const double alpha, const double* A, const double* B, const double beta, - double* C, const int batchCount, const int64_t strideA, - const int64_t strideB) { -#if CUDA_VERSION >= 8000 - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - int ldc = N; - cublasOperation_t cuTransA = - (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = - (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - const int64_t strideC = M * N; - - PADDLE_ENFORCE(platform::dynload::cublasDgemmStridedBatched( - context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb, - strideB, A, lda, strideA, &beta, C, ldc, strideC, batchCount)); -#else - PADDLE_ENFORCE(false, "DgemmStridedBatched is not supported on cuda <= 7.5"); -#endif -} - -template <> -void gemv( - const platform::CUDADeviceContext& context, const bool trans_a, const int M, - const int N, const float alpha, const float* A, const float* B, - const float beta, float* C) { - cublasOperation_t cuTransA = (trans_a == false) ? CUBLAS_OP_T : CUBLAS_OP_N; - - PADDLE_ENFORCE(platform::dynload::cublasSgemv(context.cublas_handle(), - cuTransA, N, M, &alpha, A, N, B, - 1, &beta, C, 1)); -} - -template <> -void gemv( - const platform::CUDADeviceContext& context, const bool trans_a, const int M, - const int N, const double alpha, const double* A, const double* B, - const double beta, double* C) { - cublasOperation_t cuTransA = (trans_a == false) ? CUBLAS_OP_T : CUBLAS_OP_N; - PADDLE_ENFORCE(platform::dynload::cublasDgemv(context.cublas_handle(), - cuTransA, N, M, &alpha, A, N, B, - 1, &beta, C, 1)); -} - -template <> -void axpy( - const platform::CUDADeviceContext& context, const int n, const float alpha, - const float* x, float* y) { - PADDLE_ENFORCE(platform::dynload::cublasSaxpy(context.cublas_handle(), n, - &alpha, x, 1, y, 1)); -} - -template <> -void axpy( - const platform::CUDADeviceContext& context, const int n, const double alpha, - const double* x, double* y) { - PADDLE_ENFORCE(platform::dynload::cublasDaxpy(context.cublas_handle(), n, - &alpha, x, 1, y, 1)); -} - template struct SetConstant; template struct SetConstant; template struct SetConstant; @@ -333,10 +117,9 @@ void ColwiseSum::operator()( one.mutable_data({in_dims[0]}, context.GetPlace()); SetConstant set; set(context, &one, static_cast(1.0)); - gemv( - context, true, static_cast(in_dims[0]), static_cast(in_dims[1]), - 1.0, input.data(), one.data(), 0.0, - vector->data()); + GetBlas(context).GEMV( + true, static_cast(in_dims[0]), static_cast(in_dims[1]), 1.0, + input.data(), one.data(), 0.0, vector->data()); } template struct RowwiseSum; @@ -355,10 +138,9 @@ void RowwiseSum::operator()( one.mutable_data({size}, context.GetPlace()); SetConstant set; set(context, &one, static_cast(1.0)); - gemv( - context, true, static_cast(in_dims[1]), static_cast(in_dims[0]), - 1.0, one.data(), input.data(), 0.0, - vector->data()); + GetBlas(context).GEMV( + true, static_cast(in_dims[1]), static_cast(in_dims[0]), 1.0, + one.data(), input.data(), 0.0, vector->data()); } template struct RowwiseMean; diff --git a/paddle/fluid/operators/math/math_function.h b/paddle/fluid/operators/math/math_function.h index 9950c09ea618d6c4250d66beb480d6f707813b54..d4b0e17ed44da61e2633b9bd97faeb62f9967c3c 100644 --- a/paddle/fluid/operators/math/math_function.h +++ b/paddle/fluid/operators/math/math_function.h @@ -51,78 +51,6 @@ int LAPACKE_dgetri(int matrix_layout, int n, double* a, int lda, namespace paddle { namespace operators { namespace math { - -// Support continuous memory now -// If transA = N, and transB = N -// Then matrixA: M * K, matrixB: K * N, matrixC : M * N -// For more detailed info, please refer to -// http://www.netlib.org/lapack/explore-html/d4/de2/sgemm_8f.html - -template -class Blas { - public: - explicit Blas(const DeviceContext& context) : context_(context) {} - - template - void GEMM(const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, - const int M, const int N, const int K, const T alpha, const T* A, - const T* B, const T beta, T* C) const; - - template - void GEMM(const bool transA, const bool transB, const int M, const int N, - const int K, const T alpha, const T* A, const int lda, const T* B, - const int ldb, const T beta, T* C, const int ldc) const; - - private: - const DeviceContext& context_; -}; - -template -class BlasT : private Blas { - public: - using Blas::Blas; - - template - void GEMM(ARGS... args) const { - static_cast*>(this)->template GEMM(args...); - } -}; - -template -inline BlasT GetBlas( - const framework::ExecutionContext& exe_ctx) { - return BlasT( - exe_ctx.template device_context()); -} - -template -inline BlasT GetBlas(const DeviceContext& dev_ctx) { - return BlasT(dev_ctx); -} - -// matrix multiply with continuous memory -template -void matmul(const DeviceContext& context, const framework::Tensor& matrix_a, - bool trans_a, const framework::Tensor& matrix_b, bool trans_b, - T alpha, framework::Tensor* matrix_out, T beta); - -// Batched gemm -template -void batched_gemm(const DeviceContext& context, const CBLAS_TRANSPOSE transA, - const CBLAS_TRANSPOSE transB, const int M, const int N, - const int K, const T alpha, const T* A, const T* B, - const T beta, T* C, const int batchCount, - const int64_t strideA, const int64_t strideB); - -template -void gemv(const DeviceContext& context, const bool trans_a, const int M, - const int N, const T alpha, const T* A, const T* B, const T beta, - T* C); - -template -void axpy(const DeviceContext& context, const int n, const T alpha, const T* x, - T* y); - template struct Transpose { void operator()(const DeviceContext& context, const framework::Tensor& in, @@ -169,8 +97,3 @@ struct RowwiseMean { } // namespace math } // namespace operators } // namespace paddle - -#include "paddle/fluid/operators/math/blas_impl.h" -#ifdef PADDLE_WITH_CUDA -#include "paddle/fluid/operators/math/blas_impl.cu.h" -#endif diff --git a/paddle/fluid/operators/math/math_function_test.cc b/paddle/fluid/operators/math/math_function_test.cc index 6d11dc8c76799a72bd144e4103a6c65d5c94a649..3719a264e90ea7d1a99eb9589ce4fd0d8e074781 100644 --- a/paddle/fluid/operators/math/math_function_test.cc +++ b/paddle/fluid/operators/math/math_function_test.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/fluid/operators/math/math_function.h" #include "gtest/gtest.h" +#include "paddle/fluid/operators/math/blas.h" template inline paddle::operators::math::BlasT @@ -129,9 +130,8 @@ void GemvTest(int m, int n, bool trans) { } paddle::platform::CPUDeviceContext context(*cpu_place); - paddle::operators::math::gemv( - context, trans, static_cast(m), static_cast(n), 1., data_a, - data_b, 0., data_c); + GetBlas(context).GEMV(trans, static_cast(m), static_cast(n), 1., + data_a, data_b, 0., data_c); if (!trans) { for (int i = 0; i < m; ++i) { diff --git a/paddle/fluid/operators/math/math_function_test.cu b/paddle/fluid/operators/math/math_function_test.cu index 3d03981b9f8a5ee5c302acce3d31157a16d8b67b..bcbb4a8274f149240b9f0990f38d9f38bdd0e5b1 100644 --- a/paddle/fluid/operators/math/math_function_test.cu +++ b/paddle/fluid/operators/math/math_function_test.cu @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. #include "gtest/gtest.h" +#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/platform/device_context.h" @@ -23,6 +24,13 @@ void fill_fp16_data(paddle::platform::float16* in_ptr, size_t size, } } +template +inline paddle::operators::math::BlasT +GetBlas(const paddle::platform::CUDADeviceContext& context) { + return paddle::operators::math::GetBlas(context); +} + TEST(math_function, notrans_mul_trans_fp32) { paddle::framework::Tensor input1; paddle::framework::Tensor input1_gpu; @@ -42,9 +50,8 @@ TEST(math_function, notrans_mul_trans_fp32) { paddle::framework::TensorCopySync(input1, gpu_place, &input2_gpu); out_gpu.mutable_data({2, 2}, gpu_place); - - paddle::operators::math::matmul( - context, input1_gpu, false, input2_gpu, true, 1, &out_gpu, 0); + GetBlas(context).MatMul(input1_gpu, false, input2_gpu, true, 1, + &out_gpu, 0); paddle::framework::TensorCopySync(out_gpu, cpu_place, &out); @@ -81,10 +88,9 @@ TEST(math_function, notrans_mul_trans_fp16) { out_gpu.mutable_data({2, 2}, gpu_place); - paddle::operators::math::matmul( - context, input1_gpu, false, input2_gpu, true, - paddle::platform::float16(1), &out_gpu, paddle::platform::float16(0)); + GetBlas(context).MatMul( + input1_gpu, false, input2_gpu, true, paddle::platform::float16(1), + &out_gpu, paddle::platform::float16(0)); paddle::framework::TensorCopySync(out_gpu, cpu_place, &out); @@ -116,8 +122,8 @@ TEST(math_function, trans_mul_notrans_fp32) { out_gpu.mutable_data({3, 3}, gpu_place); - paddle::operators::math::matmul( - context, input1_gpu, true, input2_gpu, false, 1, &out_gpu, 0); + GetBlas(context).MatMul(input1_gpu, true, input2_gpu, false, 1, + &out_gpu, 0); paddle::framework::TensorCopySync(out_gpu, cpu_place, &out); @@ -159,10 +165,9 @@ TEST(math_function, trans_mul_notrans_fp16) { out_gpu.mutable_data({3, 3}, gpu_place); - paddle::operators::math::matmul( - context, input1_gpu, true, input2_gpu, false, - paddle::platform::float16(1), &out_gpu, paddle::platform::float16(0)); + GetBlas(context).MatMul( + input1_gpu, true, input2_gpu, false, paddle::platform::float16(1), + &out_gpu, paddle::platform::float16(0)); paddle::framework::TensorCopySync(out_gpu, cpu_place, &out); @@ -179,13 +184,6 @@ TEST(math_function, trans_mul_notrans_fp16) { EXPECT_EQ(static_cast(out_ptr[8]), 29); } -template -inline paddle::operators::math::BlasT -GetBlas(const paddle::platform::CUDADeviceContext& context) { - return paddle::operators::math::GetBlas(context); -} - TEST(math_function, gemm_notrans_cublas_fp32) { paddle::framework::Tensor input1; paddle::framework::Tensor input2; @@ -437,9 +435,8 @@ void GemvTest(int m, int n, bool trans) { paddle::framework::TensorCopySync(mat_a, gpu_place, &g_mat_a); paddle::framework::TensorCopySync(vec_b, gpu_place, &g_vec_b); - paddle::operators::math::gemv( - context, trans, static_cast(m), static_cast(n), 1., g_data_a, - g_data_b, 0., g_data_c); + GetBlas(context).GEMV(trans, static_cast(m), static_cast(n), 1., + g_data_a, g_data_b, 0., g_data_c); paddle::framework::TensorCopySync(g_vec_c, cpu_place, &vec_c); diff --git a/paddle/fluid/operators/math/matmul.h b/paddle/fluid/operators/math/matmul.h index 67efd1be5322b633e5dbc804e6b0a3db6519f497..87fd38a324e007bcc939c31b6ae8e5d38c3e658c 100644 --- a/paddle/fluid/operators/math/matmul.h +++ b/paddle/fluid/operators/math/matmul.h @@ -15,7 +15,7 @@ limitations under the License. */ #pragma once #include #include -#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/operators/math/blas.h" namespace paddle { namespace operators { @@ -129,16 +129,17 @@ class MatMulFunctor { CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans; CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans; + auto blas = GetBlas(context); + if (!batchCount) { // regular matrix multiplication - Blas(context).GEMM(transA, transB, M, N, kA, alpha, - a.data(), b.data(), beta, - out->data()); + blas.GEMM(transA, transB, M, N, kA, alpha, a.data(), b.data(), beta, + out->data()); } else { // batched matrix multiplication - batched_gemm( - context, transA, transB, M, N, kA, alpha, a.data(), b.data(), - beta, out->data(), batchCount, strideA, strideB); + blas.BatchedGEMM(transA, transB, M, N, kA, alpha, a.data(), + b.data(), beta, out->data(), batchCount, strideA, + strideB); } } }; diff --git a/paddle/fluid/operators/mul_op.h b/paddle/fluid/operators/mul_op.h index b1260d36ebe11f65529ac274c959479dcb38ee5f..15dd975e3bbf80b2e616e6628555e812d025f70a 100644 --- a/paddle/fluid/operators/mul_op.h +++ b/paddle/fluid/operators/mul_op.h @@ -14,9 +14,9 @@ limitations under the License. */ #pragma once -#include "paddle/fluid/operators/math/math_function.h" - #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/math/math_function.h" namespace paddle { namespace operators { @@ -46,9 +46,10 @@ class MulKernel : public framework::OpKernel { if (z_dim.size() != 2) { z->Resize({x_matrix.dims()[0], y_matrix.dims()[1]}); } - math::matmul( - context.template device_context(), x_matrix, false, - y_matrix, false, static_cast(1), z, static_cast(0)); + + auto blas = math::GetBlas(context); + + blas.MatMul(x_matrix, y_matrix, z); if (z_dim.size() != 2) { z->Resize(z_dim); } @@ -79,6 +80,7 @@ class MulGradKernel : public framework::OpKernel { Tensor* dx = ctx.Output(framework::GradVarName("X")); Tensor* dy = ctx.Output(framework::GradVarName("Y")); auto& dev_ctx = ctx.template device_context(); + auto blas = math::GetBlas(dev_ctx); if (dx) { dx->mutable_data(ctx.GetPlace()); Tensor dx_matrix = dx->dims().size() > 2 @@ -86,8 +88,7 @@ class MulGradKernel : public framework::OpKernel { : *dx; // dx = dout * y'. dx: M x K, dout : M x N, y : K x N - math::matmul(dev_ctx, dout_mat, false, y_matrix, true, - 1, &dx_matrix, 0); + blas.MatMul(dout_mat, false, y_matrix, true, &dx_matrix); } if (dy) { dy->mutable_data(ctx.GetPlace()); @@ -95,8 +96,7 @@ class MulGradKernel : public framework::OpKernel { ? framework::ReshapeToMatrix(*dy, y_num_col_dims) : *dy; // dy = x' * dout. dy K x N, dout : M x N, x : M x K - math::matmul(dev_ctx, x_matrix, true, dout_mat, false, - 1, &dy_matrix, 0); + blas.MatMul(x_matrix, true, dout_mat, false, &dy_matrix); } } }; diff --git a/paddle/fluid/operators/sequence_conv_op.h b/paddle/fluid/operators/sequence_conv_op.h index 3916cdbb6a69c5a18f7a21ec60bad2732b4c3e58..ee70281d51673b94a1451f636e607fad3404863b 100644 --- a/paddle/fluid/operators/sequence_conv_op.h +++ b/paddle/fluid/operators/sequence_conv_op.h @@ -58,17 +58,15 @@ class SequenceConvKernel : public framework::OpKernel { // Because if padding_trainable is false, padding data should be zeros. math::SetConstant set_zero; auto& dev_ctx = context.template device_context(); + auto blas = math::GetBlas(dev_ctx); set_zero(dev_ctx, &col, static_cast(0)); - math::ContextProjectFunctor seq_project_functor; seq_project_functor(dev_ctx, *in, *padding_data, padding_trainable, context_start, context_length, context_stride, up_pad, down_pad, &col); - math::matmul(dev_ctx, col, false, filter, false, - static_cast(1.0), out, - static_cast(0.0)); + blas.MatMul(col, filter, out); } }; @@ -99,6 +97,7 @@ class SequenceConvGradKernel : public framework::OpKernel { math::SetConstant set_zero; auto& dev_ctx = context.template device_context(); + auto blas = math::GetBlas(dev_ctx); // use col_shape in the im2col calculation framework::DDim col_shape = {in->dims()[0], sequence_width * context_length}; @@ -108,8 +107,7 @@ class SequenceConvGradKernel : public framework::OpKernel { col.mutable_data(col_shape, context.GetPlace()); // Because if padding_trainable is false, padding data should be zeros. set_zero(dev_ctx, &col, static_cast(0)); - math::matmul(dev_ctx, *out_g, false, *filter, true, - T(1.0), &col, T(1.0)); + blas.MatMul(*out_g, false, *filter, true, &col); } math::ContextProjectFunctor seq_project_functor; math::ContextProjectGradFunctor seq_project_grad_functor; @@ -150,8 +148,7 @@ class SequenceConvGradKernel : public framework::OpKernel { context_start, context_length, context_stride, up_pad, down_pad, &col); - math::matmul(dev_ctx, col, true, out_grad, false, - T(1.0), &filter_grad, T(1.0)); + blas.MatMul(col, true, out_grad, false, &filter_grad); } } };