From 3f4917f69ad4ad91afa7d8451992e305f78dd2b5 Mon Sep 17 00:00:00 2001
From: iLeGend <824040212@qq.com>
Date: Fri, 10 Mar 2023 10:38:59 +0800
Subject: [PATCH] =?UTF-8?q?=E3=80=90Hackathon=20No.67=E3=80=91remove=20ope?=
 =?UTF-8?q?rator.h=20in=20blas.h=20(#50989)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* remove operator.h from blas.h and remove paddle::framework::ExecutionContext

* remove the deps for GetBlas(exe_ctx)

* fix error
---
 paddle/fluid/operators/attention_lstm_op.cc           |  4 ++--
 paddle/fluid/operators/center_loss_op.h               |  2 +-
 paddle/fluid/operators/fsp_op.h                       |  6 ++++--
 .../operators/fused/fused_embedding_fc_lstm_op.cc     |  3 ++-
 .../operators/fused/fused_embedding_seq_pool_op.h     |  7 +++++--
 paddle/fluid/operators/fused/fusion_gru_op.cc         |  3 ++-
 paddle/fluid/operators/fused/fusion_lstm_op.cc        |  4 ++--
 .../fused/fusion_seqexpand_concat_fc_op.cc            |  4 ++--
 paddle/fluid/operators/gru_unit_op.h                  |  6 ++++--
 paddle/fluid/operators/index_select_op.h              |  3 ++-
 paddle/fluid/operators/lookup_table_op.h              |  8 ++++++--
 paddle/fluid/operators/lookup_table_v2_op.h           |  3 ++-
 paddle/fluid/operators/lrn_op.cc                      |  4 ++--
 paddle/fluid/operators/match_matrix_tensor_op.cc      |  8 +++++---
 paddle/fluid/operators/matmul_op.cc                   |  8 +++++---
 paddle/fluid/operators/search_compute.h               |  3 ++-
 paddle/fluid/operators/svd_helper.h                   |  3 ++-
 paddle/fluid/operators/var_conv_2d_op.cc              |  6 ++++--
 paddle/phi/kernels/funcs/blas/blas.h                  |  8 --------
 paddle/phi/kernels/funcs/math_function.cu             | 15 ---------------
 20 files changed, 54 insertions(+), 54 deletions(-)

diff --git a/paddle/fluid/operators/attention_lstm_op.cc b/paddle/fluid/operators/attention_lstm_op.cc
index c4617138553..5ee8b2c7efb 100644
--- a/paddle/fluid/operators/attention_lstm_op.cc
+++ b/paddle/fluid/operators/attention_lstm_op.cc
@@ -424,10 +424,10 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
     T* lstm_x_data = lstm_x->mutable_data<T>(ctx.GetPlace());
     T* lstm_out_data = lstm_out->mutable_data<T>(ctx.GetPlace());
 
-    auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(ctx);
-
     // x(TxM) * fc (Mx1) part of atten_wgt(M+D)x1
     auto& dev_ctx = ctx.template device_context<phi::CPUContext>();
+    auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
+
     phi::funcs::FCFunctor<phi::CPUContext, T> fc;
     fc(dev_ctx,
        total_T,
diff --git a/paddle/fluid/operators/center_loss_op.h b/paddle/fluid/operators/center_loss_op.h
index 7632482c97b..2db3202b41a 100644
--- a/paddle/fluid/operators/center_loss_op.h
+++ b/paddle/fluid/operators/center_loss_op.h
@@ -86,7 +86,7 @@ class CenterLossKernel : public framework::OpKernel<T> {
     int numel = centers_diffacc.numel();
     std::memset(centers_diffacc_data, 0, sizeof(T) * numel);
 
-    auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx);
+    auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
     int tLabel;
 
     const T *x_index;
diff --git a/paddle/fluid/operators/fsp_op.h b/paddle/fluid/operators/fsp_op.h
index c5b903559a0..2136bc19336 100644
--- a/paddle/fluid/operators/fsp_op.h
+++ b/paddle/fluid/operators/fsp_op.h
@@ -37,7 +37,8 @@ class FSPOpKernel : public framework::OpKernel<T> {
     auto height = x_dims[2];
     auto width = x_dims[3];
 
-    auto blas = phi::funcs::GetBlas<DeviceContext, T>(context);
+    auto& dev_ctx = context.template device_context<DeviceContext>();
+    auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
 
     phi::funcs::MatDescriptor x_mat_desc;
     x_mat_desc.height_ = x_channel;
@@ -81,7 +82,8 @@ class FSPGradOpKernel : public framework::OpKernel<T> {
     int64_t h = 0;
     int64_t w = 0;
 
-    auto blas = phi::funcs::GetBlas<DeviceContext, T>(context);
+    auto& dev_ctx = context.template device_context<DeviceContext>();
+    auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
     phi::funcs::SetConstant<DeviceContext, T> set_zero;
     if (d_x != nullptr) {
       d_x->mutable_data<T>(context.GetPlace());
diff --git a/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc b/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc
index bec18220e9a..07207b6e028 100644
--- a/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc
+++ b/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc
@@ -411,7 +411,8 @@ class FusedEmbeddingFCLSTMKernel : public framework::OpKernel<T> {
     T* xx_data = xx->mutable_data<T>(place);
     T* h_out_data = hidden_out->mutable_data<T>(place);
     T* c_out_data = cell_out->mutable_data<T>(place);
-    auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(ctx);
+    auto& dev_ctx = ctx.template device_context<phi::CPUContext>();
+    auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
 
     for (int64_t i = 0; i < ids_numel; ++i) {
       PADDLE_ENFORCE_LT(
diff --git a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h
index 7fdfe706ef6..8c86f9d5f47 100644
--- a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h
+++ b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h
@@ -197,7 +197,9 @@ class FusedEmbeddingSeqPoolKernel : public framework::OpKernel<T> {
       const int m = batch_size * idx_width;
       const int n = table_width;
       const int k = table_height;
-      auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(context);
+
+      auto &dev_ctx = context.template device_context<phi::CPUContext>();
+      auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
       blas.CSRMM(&transa,
                  &m,
                  &n,
@@ -316,7 +318,8 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> {
                            padding_idx);
 
     auto *d_output_data = d_output->data<T>();
-    auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(context);
+    auto &dev_ctx = context.template device_context<phi::CPUContext>();
+    auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
     int width = static_cast<int>(table_dim[1]);
     int num_seq = batch_size * idx_width;
     LOG(INFO) << "num seq = " << num_seq << " width = " << width;
diff --git a/paddle/fluid/operators/fused/fusion_gru_op.cc b/paddle/fluid/operators/fused/fusion_gru_op.cc
index 932a777d3e1..55f60c500b9 100644
--- a/paddle/fluid/operators/fused/fusion_gru_op.cc
+++ b/paddle/fluid/operators/fused/fusion_gru_op.cc
@@ -310,9 +310,10 @@ class FusionGRUKernel : public framework::OpKernel<T> {
     const T* h0_data = h0 ? h0->data<T>() : nullptr;
     const T* wh_state_data = wh_data + D * D2;
     T* hidden_out_data = hidden_out->mutable_data<T>(place);
-    auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(ctx);
 
     auto& dev_ctx = ctx.template device_context<phi::CPUContext>();
+    auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
+
     phi::funcs::FCFunctor<phi::CPUContext, T> fc;
     fc(dev_ctx,
        total_T,
diff --git a/paddle/fluid/operators/fused/fusion_lstm_op.cc b/paddle/fluid/operators/fused/fusion_lstm_op.cc
index a1a0d490dd7..112a65a24ea 100644
--- a/paddle/fluid/operators/fused/fusion_lstm_op.cc
+++ b/paddle/fluid/operators/fused/fusion_lstm_op.cc
@@ -377,9 +377,9 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
     T* xx_data = xx->mutable_data<T>(place);
     T* h_out_data = hidden_out->mutable_data<T>(place);
     T* c_out_data = cell_out->mutable_data<T>(place);
-    auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(ctx);
-
     auto& dev_ctx = ctx.template device_context<phi::CPUContext>();
+    auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
+
     phi::funcs::FCFunctor<phi::CPUContext, T> fc;
     fc(dev_ctx, total_T, D4, M, x_data, wx_data, xx_data, bias->data<T>());
 
diff --git a/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc b/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc
index 86eb7053f88..08469276836 100644
--- a/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc
+++ b/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc
@@ -239,9 +239,9 @@ class FusionSeqExpandConcatFCOpKernel : public framework::OpKernel<T> {
 
     T* out_data = out->mutable_data<T>(ctx.GetPlace());
     T* fc_out_data = fc_out->mutable_data<T>(ctx.GetPlace());
-    auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(ctx);
-
     auto& dev_ctx = ctx.template device_context<phi::CPUContext>();
+    auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
+
     phi::funcs::FCFunctor<phi::CPUContext, T> fc;
     fc(dev_ctx,
        total_T,
diff --git a/paddle/fluid/operators/gru_unit_op.h b/paddle/fluid/operators/gru_unit_op.h
index d46e6cf429f..c22a82c5ae8 100644
--- a/paddle/fluid/operators/gru_unit_op.h
+++ b/paddle/fluid/operators/gru_unit_op.h
@@ -89,7 +89,8 @@ class GRUUnitKernel : public framework::OpKernel<T> {
     const T* weight_data = weight->data<T>();
     T* gate_data = gate->data<T>();
     T* reset_hidden_prev_data = reset_hidden_prev->data<T>();
-    auto blas = phi::funcs::GetBlas<DeviceContext, T>(context);
+    auto& dev_ctx = context.template device_context<DeviceContext>();
+    auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
     blas.GEMM(false,
               false,
               batch_size,
@@ -251,7 +252,8 @@ class GRUUnitGradKernel : public framework::OpKernel<T> {
                       d_h * u);
     }
     // backward for reset_hidden_prev
-    auto blas = phi::funcs::GetBlas<DeviceContext, T>(context);
+    auto& dev_ctx = context.template device_context<DeviceContext>();
+    auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
     blas.GEMM(false,
               true,
               batch_size,
diff --git a/paddle/fluid/operators/index_select_op.h b/paddle/fluid/operators/index_select_op.h
index ad1542666fd..b13d83a57ee 100644
--- a/paddle/fluid/operators/index_select_op.h
+++ b/paddle/fluid/operators/index_select_op.h
@@ -119,7 +119,8 @@ struct IndexSelectAdd<
                   const T* src_pointer,
                   const T* p_pointer,
                   T* dist_pointer) {
-    auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx);
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
     blas.VADD(slice_size, src_pointer, p_pointer, dist_pointer);
   }
 };
diff --git a/paddle/fluid/operators/lookup_table_op.h b/paddle/fluid/operators/lookup_table_op.h
index 04153eecc39..b467428eeaf 100644
--- a/paddle/fluid/operators/lookup_table_op.h
+++ b/paddle/fluid/operators/lookup_table_op.h
@@ -114,7 +114,9 @@ class LookupTableKernel : public framework::OpKernel<T> {
                  table + id_index * row_width,
                  row_width * sizeof(T));
         } else {
-          auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(context);
+          auto &dev_ctx =
+              context.template device_context<phi::CPUContext>();
+          auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
           blas.VCOPY(row_width,
                      table + id_index * row_width,
                      output + i * row_width);
@@ -145,7 +147,9 @@ class LookupTableKernel : public framework::OpKernel<T> {
                  table + id_index * row_width,
                  row_width * sizeof(T));
         } else {
-          auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(context);
+          auto &dev_ctx =
+              context.template device_context<phi::CPUContext>();
+          auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
           blas.VCOPY(row_width,
                      table + id_index * row_width,
                      output + i * row_width);
diff --git a/paddle/fluid/operators/lookup_table_v2_op.h b/paddle/fluid/operators/lookup_table_v2_op.h
index f43fccb19e0..52c93f26b7e 100644
--- a/paddle/fluid/operators/lookup_table_v2_op.h
+++ b/paddle/fluid/operators/lookup_table_v2_op.h
@@ -130,7 +130,8 @@ struct LookupTableV2CPUFunctor {
                  table + id_index * row_width,
                  row_width * sizeof(T));
         } else {
-          auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(context_);
+          auto &dev_ctx = context_.template device_context<phi::CPUContext>();
+          auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
           blas.VCOPY(row_width,
                      table + id_index * row_width,
                      output + i * row_width);
diff --git a/paddle/fluid/operators/lrn_op.cc b/paddle/fluid/operators/lrn_op.cc
index 5a6ed730477..96d5a115991 100644
--- a/paddle/fluid/operators/lrn_op.cc
+++ b/paddle/fluid/operators/lrn_op.cc
@@ -45,9 +45,9 @@ struct LRNFunctor<phi::CPUContext, T> {
                   T beta,
                   const DataLayout data_layout) {
     auto place = ctx.GetPlace();
-    auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(ctx);
-    phi::funcs::Transpose<phi::CPUContext, T, 4> transpose;
     auto& dev_ctx = ctx.template device_context<phi::CPUContext>();
+    auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
+    phi::funcs::Transpose<phi::CPUContext, T, 4> transpose;
     phi::DenseTensor in_transpose, mid_transpose, out_transpose;
     // if channel_last, transpose to channel_first
     if (data_layout == DataLayout::kNHWC) {
diff --git a/paddle/fluid/operators/match_matrix_tensor_op.cc b/paddle/fluid/operators/match_matrix_tensor_op.cc
index 3473a051b73..773d9f223f8 100644
--- a/paddle/fluid/operators/match_matrix_tensor_op.cc
+++ b/paddle/fluid/operators/match_matrix_tensor_op.cc
@@ -275,7 +275,8 @@ class CPUMatchMatrixTensorOPKernel : public framework::OpKernel<T> {
     memset(
         bottom_l_trans_data, 0.0, tmp->dims()[0] * tmp->dims()[1] * sizeof(T));
 
-    auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(ctx);
+    auto& dev_ctx = ctx.template device_context<phi::CPUContext>();
+    auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
 
     call_gemm(blas,
               CblasNoTrans,
@@ -297,7 +298,7 @@ class CPUMatchMatrixTensorOPKernel : public framework::OpKernel<T> {
           const auto* l_t_data =
               bottom_l_trans_data + offset_l[b] * dim_t * dim_in + t * dim_in;
           const auto* r_data = bottom_r_data + offset_r[b] * dim_in;
-          auto blas_2 = phi::funcs::GetBlas<phi::CPUContext, T>(ctx);
+          auto blas_2 = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
           call_gemm_with_lda(blas_2,
                              CblasNoTrans,
                              CblasTrans,
@@ -390,7 +391,8 @@ class CPUMatchMatrixTensorOPGradKernel : public framework::OpKernel<T> {
       }
     }
 
-    auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(ctx);
+    auto& dev_ctx = ctx.template device_context<phi::CPUContext>();
+    auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
 
     auto* t_data = w->data<T>();
     auto* d_w = ctx.Output<phi::DenseTensor>(framework::GradVarName("W"));
diff --git a/paddle/fluid/operators/matmul_op.cc b/paddle/fluid/operators/matmul_op.cc
index bc5e5aa6ea1..3036cdc5615 100644
--- a/paddle/fluid/operators/matmul_op.cc
+++ b/paddle/fluid/operators/matmul_op.cc
@@ -69,7 +69,7 @@ class MatMulKernel : public framework::OpKernel<T> {
     auto &dev_ctx = context.template device_context<DeviceContext>();
     dev_ctx.template Alloc<T>(out, out->numel() * sizeof(T));
 
-    auto blas = phi::funcs::GetBlas<DeviceContext, T>(context);
+    auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
     auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(
         RowMatrixFromVector(x.dims()), 0, context.Attr<bool>("transpose_X"));
     auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(
@@ -237,7 +238,8 @@ class MatMulGradKernel : public framework::OpKernel<T> {
               bool trans_b,
               phi::DenseTensor *out) const {
     out->mutable_data<T>(context.GetPlace());
-    auto blas = phi::funcs::GetBlas<DeviceContext, T>(context);
+    auto &dev_ctx = context.template device_context<DeviceContext>();
+    auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
     auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(a.dims(), 0, trans_a);
     auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(b.dims(), 0, trans_b);
 
@@ -376,7 +377,8 @@ class MatMulDoubleGradKernel : public framework::OpKernel<T> {
               bool flag,
               phi::DenseTensor *out) const {
     out->mutable_data<T>(context.GetPlace());
-    auto blas = phi::funcs::GetBlas<DeviceContext, T>(context);
+    auto &dev_ctx = context.template device_context<DeviceContext>();
+    auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
     auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(a.dims(), 0, trans_a);
     auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(b.dims(), 0, trans_b);
 
diff --git a/paddle/fluid/operators/search_compute.h b/paddle/fluid/operators/search_compute.h
index 15f87803f5a..e2156483320 100644
--- a/paddle/fluid/operators/search_compute.h
+++ b/paddle/fluid/operators/search_compute.h
@@ -61,7 +61,8 @@ void call_gemm(const framework::ExecutionContext& ctx,
                T* C) {
   int lda = (TransA == CblasNoTrans) ? K : M;
   int ldb = (TransB == CblasNoTrans) ? N : K;
-  auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(ctx);
+  auto& dev_ctx = ctx.template device_context<phi::CPUContext>();
+  auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
   blas.GEMM(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, N);
 }
 
diff --git a/paddle/fluid/operators/svd_helper.h b/paddle/fluid/operators/svd_helper.h
index 6358722e943..86ef05df0e8 100644
--- a/paddle/fluid/operators/svd_helper.h
+++ b/paddle/fluid/operators/svd_helper.h
@@ -698,7 +698,8 @@ struct DeviceIndependenceTensorOperations {
  private:
   const framework::ExecutionContext& context;
   phi::funcs::BlasT<DeviceContext, T> GetBlas() {
-    return phi::funcs::GetBlas<DeviceContext, T>(context);
+    auto& dev_ctx = context.template device_context<DeviceContext>();
+    return phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
   }
   platform::ForRange<DeviceContext> GetForRange(int numel) {
     auto& dev_ctx = context.template device_context<DeviceContext>();
diff --git a/paddle/fluid/operators/var_conv_2d_op.cc b/paddle/fluid/operators/var_conv_2d_op.cc
index b470874f260..f60190f00cb 100644
--- a/paddle/fluid/operators/var_conv_2d_op.cc
+++ b/paddle/fluid/operators/var_conv_2d_op.cc
@@ -326,7 +326,8 @@ class CPUVarConv2dOPKernel : public framework::OpKernel<T> {
     auto* w_data = w->data<T>();
     auto* col_data = col->data<T>();
 
-    auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(ctx);
+    auto& dev_ctx = ctx.template device_context<phi::CPUContext>();
+    auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
     for (int b = 0; b < batch; ++b) {
       int top_im_size = (top_offset[b + 1] - top_offset[b]) / output_channel;
       if (top_im_size == 0) {
@@ -484,7 +485,8 @@ class CPUVarConv2dOPGradKernel : public framework::OpKernel<T> {
     int batch = x->lod()[0].size() - 1;
     const auto& top_offset = out->lod()[0];
     const auto& col_offset = col->lod()[0];
-    auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(ctx);
+    auto& dev_ctx = ctx.template device_context<phi::CPUContext>();
+    auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
     for (int b = 0; b < batch; ++b) {
       int top_im_size = (top_offset[b + 1] - top_offset[b]) / output_channel;
       if (top_im_size == 0) {
diff --git a/paddle/phi/kernels/funcs/blas/blas.h b/paddle/phi/kernels/funcs/blas/blas.h
index a44c24e971a..9e970cf1b54 100644
--- a/paddle/phi/kernels/funcs/blas/blas.h
+++ b/paddle/phi/kernels/funcs/blas/blas.h
@@ -14,7 +14,6 @@
 
 #pragma once
 
-#include "paddle/fluid/framework/operator.h"
"paddle/phi/core/dense_tensor.h" #ifdef PADDLE_WITH_MKLML @@ -579,13 +578,6 @@ class BlasT : private Blas { } }; -template -inline BlasT GetBlas( - const paddle::framework::ExecutionContext& exe_ctx) { - return BlasT( - exe_ctx.template device_context()); -} - template inline BlasT GetBlas(const DeviceContext& dev_ctx) { return BlasT(dev_ctx); diff --git a/paddle/phi/kernels/funcs/math_function.cu b/paddle/phi/kernels/funcs/math_function.cu index 9225c39d617..71c902a8dc8 100644 --- a/paddle/phi/kernels/funcs/math_function.cu +++ b/paddle/phi/kernels/funcs/math_function.cu @@ -114,21 +114,6 @@ template struct SetConstant; template struct SetConstant>; template struct SetConstant>; -template struct SetConstant; -template struct SetConstant; -template struct SetConstant; -template struct SetConstant; -template struct SetConstant; -template struct SetConstant; -template struct SetConstant; -template struct SetConstant; -template struct SetConstant; -template struct SetConstant>; -template struct SetConstant>; - #define DEFINE_GPU_TRANS(RANK) \ template struct Transpose; \ template struct Transpose; \ -- GitLab