diff --git a/paddle/fluid/operators/attention_lstm_op.cc b/paddle/fluid/operators/attention_lstm_op.cc
index c4617138553917f7644006fb27cfcd7ddb00111b..5ee8b2c7efbb26db49a24c1d150af8ec42a78a7a 100644
--- a/paddle/fluid/operators/attention_lstm_op.cc
+++ b/paddle/fluid/operators/attention_lstm_op.cc
@@ -424,10 +424,10 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
     T* lstm_x_data = lstm_x->mutable_data<T>(ctx.GetPlace());
     T* lstm_out_data = lstm_out->mutable_data<T>(ctx.GetPlace());
 
-    auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx);
-
     // x(TxM) * fc (Mx1) part of atten_wgt(M+D)x1
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
+
     phi::funcs::FCFunctor<DeviceContext, T> fc;
     fc(dev_ctx,
        total_T,
diff --git a/paddle/fluid/operators/center_loss_op.h b/paddle/fluid/operators/center_loss_op.h
index 7632482c97b3f8d02ca16064cba6eefcad102de3..2db3202b41abbf424e20bf3a0fa842fdd2e97d22 100644
--- a/paddle/fluid/operators/center_loss_op.h
+++ b/paddle/fluid/operators/center_loss_op.h
@@ -86,7 +86,7 @@ class CenterLossKernel : public framework::OpKernel<T> {
     int numel = centers_diffacc.numel();
     std::memset(centers_diffacc_data, 0, sizeof(T) * numel);
 
-    auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx);
+    auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
     int tLabel;
 
     const T *x_index;
diff --git a/paddle/fluid/operators/fsp_op.h b/paddle/fluid/operators/fsp_op.h
index c5b903559a07b60c7f03cc3b36f53c771d2cf0fb..2136bc1933692e6e72a35b1099ce81339e4872d3 100644
--- a/paddle/fluid/operators/fsp_op.h
+++ b/paddle/fluid/operators/fsp_op.h
@@ -37,7 +37,8 @@ class FSPOpKernel : public framework::OpKernel<T> {
     auto height = x_dims[2];
     auto width = x_dims[3];
 
-    auto blas = phi::funcs::GetBlas<DeviceContext, T>(context);
+    auto& dev_ctx = context.template device_context<DeviceContext>();
+    auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
 
     phi::funcs::MatDescriptor x_mat_desc;
     x_mat_desc.height_ = x_channel;
@@ -81,7 +82,8 @@ class FSPGradOpKernel : public framework::OpKernel<T> {
     int64_t h = 0;
     int64_t w = 0;
 
-    auto blas = phi::funcs::GetBlas<DeviceContext, T>(context);
+    auto& dev_ctx = context.template device_context<DeviceContext>();
+    auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
     phi::funcs::SetConstant<DeviceContext, T> set_zero;
     if (d_x != nullptr) {
       d_x->mutable_data<T>(context.GetPlace());
diff --git a/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc b/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc
index bec18220e9afdb949769e2dd7f025409b1f6857c..07207b6e028696a81bdb17b4774bb5e9be18768a 100644
--- a/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc
+++ b/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc
@@ -411,7 +411,8 @@ class FusedEmbeddingFCLSTMKernel : public framework::OpKernel<T> {
     T* xx_data = xx->mutable_data<T>(place);
     T* h_out_data = hidden_out->mutable_data<T>(place);
     T* c_out_data = cell_out->mutable_data<T>(place);
-    auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx);
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
 
     for (int64_t i = 0; i < ids_numel; ++i) {
       PADDLE_ENFORCE_LT(
diff --git a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h
index 7fdfe706ef638774d8964d31a68d9d08b6e1a233..8c86f9d5f471fde6d72a1ce08fa688488ca0848e 100644
--- a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h
+++ b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h
@@ -197,7 +197,9 @@ class FusedEmbeddingSeqPoolKernel : public framework::OpKernel<T> {
       const int m = batch_size * idx_width;
       const int n = table_width;
       const int k = table_height;
-      auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(context);
+
+      auto &dev_ctx = context.template device_context<phi::CPUContext>();
+      auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
       blas.CSRMM(&transa,
                  &m,
                  &n,
@@ -316,7 +318,8 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> {
                             padding_idx);
 
       auto *d_output_data = d_output->data<T>();
-      auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(context);
+      auto &dev_ctx = context.template device_context<phi::CPUContext>();
+      auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
       int width = static_cast<int>(table_dim[1]);
       int num_seq = batch_size * idx_width;
       LOG(INFO) << "num seq = " << num_seq << " width = " << width;
diff --git a/paddle/fluid/operators/fused/fusion_gru_op.cc b/paddle/fluid/operators/fused/fusion_gru_op.cc
index 932a777d3e1ae0d96f6ad34c6fdbd7ab193848fc..55f60c500b97dc29331cb293a31c61f557227db0 100644
--- a/paddle/fluid/operators/fused/fusion_gru_op.cc
+++ b/paddle/fluid/operators/fused/fusion_gru_op.cc
@@ -310,9 +310,10 @@ class FusionGRUKernel : public framework::OpKernel<T> {
     const T* h0_data = h0 ? h0->data<T>() : nullptr;
     const T* wh_state_data = wh_data + D * D2;
     T* hidden_out_data = hidden_out->mutable_data<T>(place);
 
-    auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx);
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
+
     phi::funcs::FCFunctor<DeviceContext, T> fc;
     fc(dev_ctx,
        total_T,
diff --git a/paddle/fluid/operators/fused/fusion_lstm_op.cc b/paddle/fluid/operators/fused/fusion_lstm_op.cc
index a1a0d490dd756ecd96fb68f0c708326fc9e50e5f..112a65a24eadd01b32739d25d2c38fb437d0532a 100644
--- a/paddle/fluid/operators/fused/fusion_lstm_op.cc
+++ b/paddle/fluid/operators/fused/fusion_lstm_op.cc
@@ -377,9 +377,9 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
     T* xx_data = xx->mutable_data<T>(place);
     T* h_out_data = hidden_out->mutable_data<T>(place);
     T* c_out_data = cell_out->mutable_data<T>(place);
 
-    auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx);
-
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
+
     phi::funcs::FCFunctor<DeviceContext, T> fc;
     fc(dev_ctx, total_T, D4, M, x_data, wx_data, xx_data, bias->data<T>());
diff --git a/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc b/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc
index 86eb7053f88e89a9996af6d844641294df1b638f..084692768367827273bded882622a93636cc0f01 100644
--- a/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc
+++ b/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc
@@ -239,9 +239,9 @@ class FusionSeqExpandConcatFCOpKernel : public framework::OpKernel<T> {
     T* out_data = out->mutable_data<T>(ctx.GetPlace());
     T* fc_out_data = fc_out->mutable_data<T>(ctx.GetPlace());
 
-    auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx);
-
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
+
     phi::funcs::FCFunctor<DeviceContext, T> fc;
     fc(dev_ctx,
        total_T,
diff --git a/paddle/fluid/operators/gru_unit_op.h b/paddle/fluid/operators/gru_unit_op.h
index d46e6cf429f6fcf1f1cb4cbb070736c0ab810fca..c22a82c5ae8fd62bb080780f9810cd9ef03a5531 100644
--- a/paddle/fluid/operators/gru_unit_op.h
+++ b/paddle/fluid/operators/gru_unit_op.h
@@ -89,7 +89,8 @@ class GRUUnitKernel : public framework::OpKernel<T> {
     const T* weight_data = weight->data<T>();
     T* gate_data = gate->data<T>();
     T* reset_hidden_prev_data = reset_hidden_prev->data<T>();
-    auto blas = phi::funcs::GetBlas<DeviceContext, T>(context);
+    auto& dev_ctx = context.template device_context<DeviceContext>();
+    auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
     blas.GEMM(false,
               false,
               batch_size,
@@ -251,7 +252,8 @@ class GRUUnitGradKernel : public framework::OpKernel<T> {
                             d_h * u);
     }
     // backward for reset_hidden_prev
-    auto blas = phi::funcs::GetBlas<DeviceContext, T>(context);
+    auto& dev_ctx = context.template device_context<DeviceContext>();
+    auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
     blas.GEMM(false,
               true,
               batch_size,
diff --git a/paddle/fluid/operators/index_select_op.h b/paddle/fluid/operators/index_select_op.h
index ad1542666fd39d8c853ef8e941974386c53d4bd3..b13d83a57ee97473d853be3b4e37d60ffbe659c9 100644
--- a/paddle/fluid/operators/index_select_op.h
+++ b/paddle/fluid/operators/index_select_op.h
@@ -119,7 +119,8 @@ struct IndexSelectAdd<
                   const T* src_pointer,
                   const T* p_pointer,
                   T* dist_pointer) {
-    auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx);
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
     blas.VADD(slice_size, src_pointer, p_pointer, dist_pointer);
   }
 };
diff --git a/paddle/fluid/operators/lookup_table_op.h b/paddle/fluid/operators/lookup_table_op.h
index 04153eecc392779bb17ff78f706de854195a44f8..b467428eeafd3e41405473d91cfc6b3c655dac6d 100644
--- a/paddle/fluid/operators/lookup_table_op.h
+++ b/paddle/fluid/operators/lookup_table_op.h
@@ -114,7 +114,9 @@ class LookupTableKernel : public framework::OpKernel<T> {
                  table + id_index * row_width,
                  row_width * sizeof(T));
         } else {
-          auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(context);
+          auto &dev_ctx =
+              context.template device_context<phi::CPUContext>();
+          auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
           blas.VCOPY(row_width,
                      table + id_index * row_width,
                      output + i * row_width);
@@ -145,7 +147,9 @@ class LookupTableKernel : public framework::OpKernel<T> {
                  table + id_index * row_width,
                  row_width * sizeof(T));
         } else {
-          auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(context);
+          auto &dev_ctx =
+              context.template device_context<phi::CPUContext>();
+          auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
           blas.VCOPY(row_width,
                      table + id_index * row_width,
                      output + i * row_width);
diff --git a/paddle/fluid/operators/lookup_table_v2_op.h b/paddle/fluid/operators/lookup_table_v2_op.h
index f43fccb19e0b6f0d58dd75ba9eb108710164bff9..52c93f26b7e8a88603176368ea3a2b55819e3935 100644
--- a/paddle/fluid/operators/lookup_table_v2_op.h
+++ b/paddle/fluid/operators/lookup_table_v2_op.h
@@ -130,7 +130,8 @@ struct LookupTableV2CPUFunctor {
                  table + id_index * row_width,
                  row_width * sizeof(T));
         } else {
-          auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(context_);
+          auto &dev_ctx = context_.template device_context<phi::CPUContext>();
+          auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
           blas.VCOPY(row_width,
                      table + id_index * row_width,
                      output + i * row_width);
diff --git a/paddle/fluid/operators/lrn_op.cc b/paddle/fluid/operators/lrn_op.cc
index 5a6ed730477a1845ae7674c4de4aa0ed3082533e..96d5a115991b01f0aab214423ceeb0985b40e01f 100644
--- a/paddle/fluid/operators/lrn_op.cc
+++ b/paddle/fluid/operators/lrn_op.cc
@@ -45,9 +45,9 @@ struct LRNFunctor<phi::CPUContext, T> {
                   T beta,
                   const DataLayout data_layout) {
     auto place = ctx.GetPlace();
-    auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(ctx);
-    phi::funcs::Transpose<phi::CPUContext, T, 4> transpose;
     auto& dev_ctx = ctx.template device_context<phi::CPUContext>();
+    auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
+    phi::funcs::Transpose<phi::CPUContext, T, 4> transpose;
     phi::DenseTensor in_transpose, mid_transpose, out_transpose;
     // if channel_last, transpose to channel_first
     if (data_layout == DataLayout::kNHWC) {
diff --git a/paddle/fluid/operators/match_matrix_tensor_op.cc b/paddle/fluid/operators/match_matrix_tensor_op.cc
index 3473a051b7324c36fec8b72b4316c835e892f57a..773d9f223f83442c0293ae89607ebf013a51780e 100644
--- a/paddle/fluid/operators/match_matrix_tensor_op.cc
+++ b/paddle/fluid/operators/match_matrix_tensor_op.cc
@@ -275,7 +275,8 @@ class CPUMatchMatrixTensorOPKernel : public framework::OpKernel<T> {
     memset(
         bottom_l_trans_data, 0.0, tmp->dims()[0] * tmp->dims()[1] * sizeof(T));
 
-    auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx);
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
 
     call_gemm(blas,
               CblasNoTrans,
@@ -297,7 +298,7 @@ class CPUMatchMatrixTensorOPKernel : public framework::OpKernel<T> {
         const auto* l_t_data =
             bottom_l_trans_data + offset_l[b] * dim_t * dim_in + t * dim_in;
         const auto* r_data = bottom_r_data + offset_r[b] * dim_in;
-        auto blas_2 = phi::funcs::GetBlas<DeviceContext, T>(ctx);
+        auto blas_2 = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
         call_gemm_with_lda(blas_2,
                            CblasNoTrans,
                            CblasTrans,
@@ -390,7 +391,8 @@ class CPUMatchMatrixTensorOPGradKernel : public framework::OpKernel<T> {
       }
     }
 
-    auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx);
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
 
     auto* t_data = w->data<T>();
     auto* d_w = ctx.Output<phi::DenseTensor>(framework::GradVarName("W"));
diff --git a/paddle/fluid/operators/matmul_op.cc b/paddle/fluid/operators/matmul_op.cc
index bc5e5aa6ea1514b3c8fb50acd542179532bc5832..3036cdc5615ade227cc0d59e26b6d5d1be2defa9 100644
--- a/paddle/fluid/operators/matmul_op.cc
+++ b/paddle/fluid/operators/matmul_op.cc
@@ -69,7 +69,7 @@ class MatMulKernel : public framework::OpKernel<T> {
     auto &dev_ctx = context.template device_context<DeviceContext>();
     dev_ctx.template Alloc<T>(out, out->numel() * sizeof(T));
 
-    auto blas = phi::funcs::GetBlas<DeviceContext, T>(context);
+    auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
     auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(
         RowMatrixFromVector(x.dims()), 0, context.Attr<bool>("transpose_X"));
     auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(
@@ -237,7 +237,8 @@ class MatMulGradKernel : public framework::OpKernel<T> {
                 bool trans_b,
                 phi::DenseTensor *out) const {
     out->mutable_data<T>(context.GetPlace());
-    auto blas = phi::funcs::GetBlas<DeviceContext, T>(context);
+    auto &dev_ctx = context.template device_context<DeviceContext>();
+    auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
     auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(a.dims(), 0, trans_a);
     auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(b.dims(), 0, trans_b);
 
@@ -376,7 +377,8 @@ class MatMulDoubleGradKernel : public framework::OpKernel<T> {
                 bool flag,
                 phi::DenseTensor *out) const {
     out->mutable_data<T>(context.GetPlace());
-    auto blas = phi::funcs::GetBlas<DeviceContext, T>(context);
+    auto &dev_ctx = context.template device_context<DeviceContext>();
+    auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
     auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(a.dims(), 0, trans_a);
     auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(b.dims(), 0, trans_b);
 
diff --git a/paddle/fluid/operators/search_compute.h b/paddle/fluid/operators/search_compute.h
index 15f87803f5ab8e7ccec5de9996a39b0f27384a19..e2156483320107742a3bb30eda221a24346a053c 100644
--- a/paddle/fluid/operators/search_compute.h
+++ b/paddle/fluid/operators/search_compute.h
@@ -61,7 +61,8 @@ void call_gemm(const framework::ExecutionContext& ctx,
               T* C) {
   int lda = (TransA == CblasNoTrans) ? K : M;
   int ldb = (TransB == CblasNoTrans) ? N : K;
-  auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(ctx);
+  auto& dev_ctx = ctx.template device_context<phi::CPUContext>();
+  auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
 
   blas.GEMM(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, N);
 }
diff --git a/paddle/fluid/operators/svd_helper.h b/paddle/fluid/operators/svd_helper.h
index 6358722e94390521a8919c0b69b4a7d867c1684f..86ef05df0e8f3ea83209eac116a8af30602e888d 100644
--- a/paddle/fluid/operators/svd_helper.h
+++ b/paddle/fluid/operators/svd_helper.h
@@ -698,7 +698,8 @@ struct DeviceIndependenceTensorOperations {
  private:
   const framework::ExecutionContext& context;
   phi::funcs::BlasT<DeviceContext, T> GetBlas() {
-    return phi::funcs::GetBlas<DeviceContext, T>(context);
+    auto& dev_ctx = context.template device_context<DeviceContext>();
+    return phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
   }
   platform::ForRange<DeviceContext> GetForRange(int numel) {
     auto& dev_ctx = context.template device_context<DeviceContext>();
diff --git a/paddle/fluid/operators/var_conv_2d_op.cc b/paddle/fluid/operators/var_conv_2d_op.cc
index b470874f26083682d637a7cb37e171ec7575de66..f60190f00cb5561ac2a2ad2cd1b36735bb382c66 100644
--- a/paddle/fluid/operators/var_conv_2d_op.cc
+++ b/paddle/fluid/operators/var_conv_2d_op.cc
@@ -326,7 +326,8 @@ class CPUVarConv2dOPKernel : public framework::OpKernel<T> {
     auto* w_data = w->data<T>();
     auto* col_data = col->data<T>();
 
-    auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx);
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
     for (int b = 0; b < batch; ++b) {
       int top_im_size = (top_offset[b + 1] - top_offset[b]) / output_channel;
       if (top_im_size == 0) {
@@ -484,7 +485,8 @@ class CPUVarConv2dOPGradKernel : public framework::OpKernel<T> {
     int batch = x->lod()[0].size() - 1;
     const auto& top_offset = out->lod()[0];
     const auto& col_offset = col->lod()[0];
-    auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx);
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
     for (int b = 0; b < batch; ++b) {
       int top_im_size = (top_offset[b + 1] - top_offset[b]) / output_channel;
       if (top_im_size == 0) {
diff --git a/paddle/phi/kernels/funcs/blas/blas.h b/paddle/phi/kernels/funcs/blas/blas.h
index a44c24e971a47ae7c8712262a1d5b93bfc2cefe9..9e970cf1b549a902697e968180d5d0d68ec20e6f 100644
--- a/paddle/phi/kernels/funcs/blas/blas.h
+++ b/paddle/phi/kernels/funcs/blas/blas.h
@@ -14,7 +14,6 @@
 
 #pragma once
 
-#include "paddle/fluid/framework/operator.h"
 #include "paddle/phi/core/dense_tensor.h"
 
 #ifdef PADDLE_WITH_MKLML
@@ -579,13 +578,6 @@ class BlasT : private Blas<DeviceContext> {
   }
 };
 
-template <typename DeviceContext, typename T>
-inline BlasT<DeviceContext, T> GetBlas(
-    const paddle::framework::ExecutionContext& exe_ctx) {
-  return BlasT<DeviceContext, T>(
-      exe_ctx.template device_context<DeviceContext>());
-}
-
 template <typename DeviceContext, typename T>
 inline BlasT<DeviceContext, T> GetBlas(const DeviceContext& dev_ctx) {
   return BlasT<DeviceContext, T>(dev_ctx);
diff --git a/paddle/phi/kernels/funcs/math_function.cu b/paddle/phi/kernels/funcs/math_function.cu
index 9225c39d617ad88fb4948104ef065897a1a9771b..71c902a8dc88ff8565f90c242f158331af27e76b 100644
--- a/paddle/phi/kernels/funcs/math_function.cu
+++ b/paddle/phi/kernels/funcs/math_function.cu
@@ -114,21 +114,6 @@ template struct SetConstant<phi::GPUContext, bool>;
 template struct SetConstant<phi::GPUContext, phi::dtype::complex<float>>;
 template struct SetConstant<phi::GPUContext, phi::dtype::complex<double>>;
 
-template struct SetConstant<paddle::platform::CUDADeviceContext,
-                            phi::dtype::float16>;
-template struct SetConstant<paddle::platform::CUDADeviceContext,
-                            phi::dtype::bfloat16>;
-template struct SetConstant<paddle::platform::CUDADeviceContext, float>;
-template struct SetConstant<paddle::platform::CUDADeviceContext, double>;
-template struct SetConstant<paddle::platform::CUDADeviceContext, uint8_t>;
-template struct SetConstant<paddle::platform::CUDADeviceContext, int>;
-template struct SetConstant<paddle::platform::CUDADeviceContext, int16_t>;
-template struct SetConstant<paddle::platform::CUDADeviceContext, int64_t>;
-template struct SetConstant<paddle::platform::CUDADeviceContext, bool>;
-template struct SetConstant<paddle::platform::CUDADeviceContext,
-                            phi::dtype::complex<float>>;
-template struct SetConstant<paddle::platform::CUDADeviceContext,
-                            phi::dtype::complex<double>>;
 #define DEFINE_GPU_TRANS(RANK)                            \
   template struct Transpose<phi::GPUContext, bool, RANK>; \
   template struct Transpose<phi::GPUContext, float, RANK>; \
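
Note for reviewers: every call-site change above follows the same mechanical pattern, so one sketch covers them all. The kernel below is hypothetical (not part of this PR); it only illustrates the before/after of the `GetBlas` API, using the `VADD` wrapper that already appears in `index_select_op.h`:

```cpp
#include "paddle/phi/kernels/funcs/blas/blas.h"

// Hypothetical kernel body; DeviceContext and T stand in for the usual
// kernel template parameters.
template <typename DeviceContext, typename T>
void AddWithBlas(const paddle::framework::ExecutionContext& ctx,
                 int n, const T* x, const T* y, T* out) {
  // Before this PR, GetBlas itself unpacked the ExecutionContext, which is
  // why blas.h had to include paddle/fluid/framework/operator.h:
  //   auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx);

  // After this PR, the caller resolves the device context once and passes it
  // to the surviving GetBlas(const DeviceContext&) overload, so blas.h no
  // longer depends on fluid.
  auto& dev_ctx = ctx.template device_context<DeviceContext>();
  auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
  blas.VADD(n, x, y, out);  // element-wise out[i] = x[i] + y[i]
}
```

The design point is dependency direction: phi must stay buildable without fluid, so the `ExecutionContext`-taking overload is deleted from `paddle/phi/kernels/funcs/blas/blas.h` and the remaining fluid operators adapt at their call sites.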