Unverified commit 3f4917f6, authored by iSerendipity, committed by GitHub

【Hackathon No.67】remove operator.h in blas.h (#50989)

* remove operator.h from blas.h and remove paddle::framework::ExecutionContext

* remove the dependencies for GetBlas(exe_ctx)

* fix error
Parent: b33673be
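
The change is mechanical at every call site: instead of handing GetBlas the whole framework::ExecutionContext, the kernel first extracts the device context and passes that. A minimal sketch of the before/after pattern inside a kernel's Compute method (the surrounding kernel body is assumed, not shown here):

    // Before: GetBlas took a framework::ExecutionContext, which forced
    // blas.h to include paddle/fluid/framework/operator.h.
    auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx);

    // After: the caller extracts the device context and passes it directly,
    // so blas.h only needs to know about the DeviceContext type.
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);

The diffs below apply this rewrite to each affected operator, then delete the GetBlas(exe_ctx) overload and the operator.h include from blas.h.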
@@ -424,10 +424,10 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
     T* lstm_x_data = lstm_x->mutable_data<T>(ctx.GetPlace());
     T* lstm_out_data = lstm_out->mutable_data<T>(ctx.GetPlace());
-    auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(ctx);
     // x(TxM) * fc (Mx1) part of atten_wgt(M+D)x1
     auto& dev_ctx = ctx.template device_context<phi::CPUContext>();
+    auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
     phi::funcs::FCFunctor<DeviceContext, T> fc;
     fc(dev_ctx,
        total_T,
...
@@ -86,7 +86,7 @@ class CenterLossKernel : public framework::OpKernel<T> {
     int numel = centers_diffacc.numel();
     std::memset(centers_diffacc_data, 0, sizeof(T) * numel);
-    auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx);
+    auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
     int tLabel;
     const T *x_index;
...
@@ -37,7 +37,8 @@ class FSPOpKernel : public framework::OpKernel<T> {
     auto height = x_dims[2];
     auto width = x_dims[3];
-    auto blas = phi::funcs::GetBlas<DeviceContext, T>(context);
+    auto& dev_ctx = context.template device_context<DeviceContext>();
+    auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
     phi::funcs::MatDescriptor x_mat_desc;
     x_mat_desc.height_ = x_channel;
@@ -81,7 +82,8 @@ class FSPGradOpKernel : public framework::OpKernel<T> {
     int64_t h = 0;
     int64_t w = 0;
-    auto blas = phi::funcs::GetBlas<DeviceContext, T>(context);
+    auto& dev_ctx = context.template device_context<DeviceContext>();
+    auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
     phi::funcs::SetConstant<DeviceContext, T> set_zero;
     if (d_x != nullptr) {
       d_x->mutable_data<T>(context.GetPlace());
...
@@ -411,7 +411,8 @@ class FusedEmbeddingFCLSTMKernel : public framework::OpKernel<T> {
     T* xx_data = xx->mutable_data<T>(place);
     T* h_out_data = hidden_out->mutable_data<T>(place);
     T* c_out_data = cell_out->mutable_data<T>(place);
-    auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx);
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
     for (int64_t i = 0; i < ids_numel; ++i) {
       PADDLE_ENFORCE_LT(
...
@@ -197,7 +197,9 @@ class FusedEmbeddingSeqPoolKernel : public framework::OpKernel<T> {
     const int m = batch_size * idx_width;
     const int n = table_width;
     const int k = table_height;
-    auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(context);
+    auto &dev_ctx = context.template device_context<phi::CPUContext>();
+    auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
     blas.CSRMM(&transa,
                &m,
                &n,
@@ -316,7 +318,8 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> {
                  padding_idx);
     auto *d_output_data = d_output->data<T>();
-    auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(context);
+    auto &dev_ctx = context.template device_context<phi::CPUContext>();
+    auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
     int width = static_cast<int>(table_dim[1]);
     int num_seq = batch_size * idx_width;
     LOG(INFO) << "num seq = " << num_seq << " width = " << width;
...
@@ -310,9 +310,10 @@ class FusionGRUKernel : public framework::OpKernel<T> {
     const T* h0_data = h0 ? h0->data<T>() : nullptr;
     const T* wh_state_data = wh_data + D * D2;
     T* hidden_out_data = hidden_out->mutable_data<T>(place);
-    auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx);
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
     phi::funcs::FCFunctor<DeviceContext, T> fc;
     fc(dev_ctx,
        total_T,
...
@@ -377,9 +377,9 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
     T* xx_data = xx->mutable_data<T>(place);
     T* h_out_data = hidden_out->mutable_data<T>(place);
     T* c_out_data = cell_out->mutable_data<T>(place);
-    auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx);
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
     phi::funcs::FCFunctor<DeviceContext, T> fc;
     fc(dev_ctx, total_T, D4, M, x_data, wx_data, xx_data, bias->data<T>());
...
@@ -239,9 +239,9 @@ class FusionSeqExpandConcatFCOpKernel : public framework::OpKernel<T> {
     T* out_data = out->mutable_data<T>(ctx.GetPlace());
     T* fc_out_data = fc_out->mutable_data<T>(ctx.GetPlace());
-    auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx);
     auto& dev_ctx = ctx.template device_context<phi::CPUContext>();
+    auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
     phi::funcs::FCFunctor<DeviceContext, T> fc;
     fc(dev_ctx,
        total_T,
...
@@ -89,7 +89,8 @@ class GRUUnitKernel : public framework::OpKernel<T> {
     const T* weight_data = weight->data<T>();
     T* gate_data = gate->data<T>();
     T* reset_hidden_prev_data = reset_hidden_prev->data<T>();
-    auto blas = phi::funcs::GetBlas<DeviceContext, T>(context);
+    auto& dev_ctx = context.template device_context<DeviceContext>();
+    auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
     blas.GEMM(false,
               false,
               batch_size,
@@ -251,7 +252,8 @@ class GRUUnitGradKernel : public framework::OpKernel<T> {
                   d_h * u);
     }
     // backward for reset_hidden_prev
-    auto blas = phi::funcs::GetBlas<DeviceContext, T>(context);
+    auto& dev_ctx = context.template device_context<DeviceContext>();
+    auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
     blas.GEMM(false,
               true,
               batch_size,
...
@@ -119,7 +119,8 @@ struct IndexSelectAdd<
                   const T* src_pointer,
                   const T* p_pointer,
                   T* dist_pointer) {
-    auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx);
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
     blas.VADD(slice_size, src_pointer, p_pointer, dist_pointer);
   }
 };
...
@@ -114,7 +114,9 @@ class LookupTableKernel : public framework::OpKernel<T> {
                        table + id_index * row_width,
                        row_width * sizeof(T));
           } else {
-            auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(context);
+            auto &dev_ctx =
+                context.template device_context<phi::CPUContext>();
+            auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
             blas.VCOPY(row_width,
                        table + id_index * row_width,
                        output + i * row_width);
@@ -145,7 +147,9 @@ class LookupTableKernel : public framework::OpKernel<T> {
                        table + id_index * row_width,
                        row_width * sizeof(T));
           } else {
-            auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(context);
+            auto &dev_ctx =
+                context.template device_context<phi::CPUContext>();
+            auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
             blas.VCOPY(row_width,
                        table + id_index * row_width,
                        output + i * row_width);
...
@@ -130,7 +130,8 @@ struct LookupTableV2CPUFunctor {
                  table + id_index * row_width,
                  row_width * sizeof(T));
     } else {
-      auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(context_);
+      auto &dev_ctx = context_.template device_context<phi::CPUContext>();
+      auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
       blas.VCOPY(row_width,
                  table + id_index * row_width,
                  output + i * row_width);
...
@@ -45,9 +45,9 @@ struct LRNFunctor<phi::CPUContext, T> {
                   T beta,
                   const DataLayout data_layout) {
     auto place = ctx.GetPlace();
-    auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(ctx);
-    phi::funcs::Transpose<phi::CPUContext, T, 4> transpose;
     auto& dev_ctx = ctx.template device_context<phi::CPUContext>();
+    auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
+    phi::funcs::Transpose<phi::CPUContext, T, 4> transpose;
     phi::DenseTensor in_transpose, mid_transpose, out_transpose;
     // if channel_last, transpose to channel_first
     if (data_layout == DataLayout::kNHWC) {
...
@@ -275,7 +275,8 @@ class CPUMatchMatrixTensorOPKernel : public framework::OpKernel<T> {
     memset(
         bottom_l_trans_data, 0.0, tmp->dims()[0] * tmp->dims()[1] * sizeof(T));
-    auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(ctx);
+    auto& dev_ctx = ctx.template device_context<phi::CPUContext>();
+    auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
     call_gemm(blas,
               CblasNoTrans,
@@ -297,7 +298,7 @@ class CPUMatchMatrixTensorOPKernel : public framework::OpKernel<T> {
         const auto* l_t_data =
             bottom_l_trans_data + offset_l[b] * dim_t * dim_in + t * dim_in;
         const auto* r_data = bottom_r_data + offset_r[b] * dim_in;
-        auto blas_2 = phi::funcs::GetBlas<phi::CPUContext, T>(ctx);
+        auto blas_2 = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
         call_gemm_with_lda(blas_2,
                            CblasNoTrans,
                            CblasTrans,
@@ -390,7 +391,8 @@ class CPUMatchMatrixTensorOPGradKernel : public framework::OpKernel<T> {
       }
     }
-    auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(ctx);
+    auto& dev_ctx = ctx.template device_context<phi::CPUContext>();
+    auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
     auto* t_data = w->data<T>();
     auto* d_w = ctx.Output<phi::DenseTensor>(framework::GradVarName("W"));
...
@@ -69,7 +69,7 @@ class MatMulKernel : public framework::OpKernel<T> {
     auto &dev_ctx = context.template device_context<DeviceContext>();
     dev_ctx.template Alloc<T>(out, out->numel() * sizeof(T));
-    auto blas = phi::funcs::GetBlas<DeviceContext, T>(context);
+    auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
     auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(
         RowMatrixFromVector(x.dims()), 0, context.Attr<bool>("transpose_X"));
     auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(
@@ -237,7 +237,8 @@ class MatMulGradKernel : public framework::OpKernel<T> {
                 bool trans_b,
                 phi::DenseTensor *out) const {
     out->mutable_data<T>(context.GetPlace());
-    auto blas = phi::funcs::GetBlas<DeviceContext, T>(context);
+    auto &dev_ctx = context.template device_context<DeviceContext>();
+    auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
     auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(a.dims(), 0, trans_a);
     auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(b.dims(), 0, trans_b);
@@ -376,7 +377,8 @@ class MatMulDoubleGradKernel : public framework::OpKernel<T> {
                 bool flag,
                 phi::DenseTensor *out) const {
     out->mutable_data<T>(context.GetPlace());
-    auto blas = phi::funcs::GetBlas<DeviceContext, T>(context);
+    auto &dev_ctx = context.template device_context<DeviceContext>();
+    auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
     auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(a.dims(), 0, trans_a);
     auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(b.dims(), 0, trans_b);
...
@@ -61,7 +61,8 @@ void call_gemm(const framework::ExecutionContext& ctx,
               T* C) {
   int lda = (TransA == CblasNoTrans) ? K : M;
   int ldb = (TransB == CblasNoTrans) ? N : K;
-  auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(ctx);
+  auto& dev_ctx = ctx.template device_context<phi::CPUContext>();
+  auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
   blas.GEMM(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, N);
 }
...
@@ -698,7 +698,8 @@ struct DeviceIndependenceTensorOperations {
  private:
   const framework::ExecutionContext& context;
   phi::funcs::BlasT<DeviceContext, T> GetBlas() {
-    return phi::funcs::GetBlas<DeviceContext, T>(context);
+    auto& dev_ctx = context.template device_context<DeviceContext>();
+    return phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
   }
   platform::ForRange<DeviceContext> GetForRange(int numel) {
     auto& dev_ctx = context.template device_context<DeviceContext>();
...
@@ -326,7 +326,8 @@ class CPUVarConv2dOPKernel : public framework::OpKernel<T> {
     auto* w_data = w->data<T>();
     auto* col_data = col->data<T>();
-    auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(ctx);
+    auto& dev_ctx = ctx.template device_context<phi::CPUContext>();
+    auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
     for (int b = 0; b < batch; ++b) {
       int top_im_size = (top_offset[b + 1] - top_offset[b]) / output_channel;
       if (top_im_size == 0) {
@@ -484,7 +485,8 @@ class CPUVarConv2dOPGradKernel : public framework::OpKernel<T> {
     int batch = x->lod()[0].size() - 1;
     const auto& top_offset = out->lod()[0];
     const auto& col_offset = col->lod()[0];
-    auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(ctx);
+    auto& dev_ctx = ctx.template device_context<phi::CPUContext>();
+    auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
     for (int b = 0; b < batch; ++b) {
       int top_im_size = (top_offset[b + 1] - top_offset[b]) / output_channel;
       if (top_im_size == 0) {
...
@@ -14,7 +14,6 @@
 #pragma once
-#include "paddle/fluid/framework/operator.h"
 #include "paddle/phi/core/dense_tensor.h"
 #ifdef PADDLE_WITH_MKLML
@@ -579,13 +578,6 @@ class BlasT : private Blas<DeviceContext> {
   }
 };
-template <typename DeviceContext, typename T>
-inline BlasT<DeviceContext, T> GetBlas(
-    const paddle::framework::ExecutionContext& exe_ctx) {
-  return BlasT<DeviceContext, T>(
-      exe_ctx.template device_context<DeviceContext>());
-}
 template <typename DeviceContext, typename T>
 inline BlasT<DeviceContext, T> GetBlas(const DeviceContext& dev_ctx) {
   return BlasT<DeviceContext, T>(dev_ctx);
...
@@ -114,21 +114,6 @@ template struct SetConstant<phi::GPUContext, bool>;
 template struct SetConstant<phi::GPUContext, phi::dtype::complex<float>>;
 template struct SetConstant<phi::GPUContext, phi::dtype::complex<double>>;
-template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, float16>;
-template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext,
-                            bfloat16>;
-template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, float>;
-template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, double>;
-template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, uint8_t>;
-template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, int>;
-template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, int16_t>;
-template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, int64_t>;
-template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, bool>;
-template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext,
-                            phi::dtype::complex<float>>;
-template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext,
-                            phi::dtype::complex<double>>;
 #define DEFINE_GPU_TRANS(RANK)                              \
   template struct Transpose<phi::GPUContext, bool, RANK>;   \
   template struct Transpose<phi::GPUContext, unsigned char, RANK>; \
...