未验证 提交 b9d91531 编写于 作者: Y YuanRisheng 提交者: GitHub

Revert "【Hackathon No.67】remove operator.h in blas.h (#50989)" (#51467)

This reverts commit 3f4917f6.
上级 8d40e02f
......@@ -424,10 +424,10 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
T* lstm_x_data = lstm_x->mutable_data<T>(ctx.GetPlace());
T* lstm_out_data = lstm_out->mutable_data<T>(ctx.GetPlace());
auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(ctx);
// x(TxM) * fc (Mx1) part of atten_wgt(M+D)x1
auto& dev_ctx = ctx.template device_context<phi::CPUContext>();
auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
phi::funcs::FCFunctor<DeviceContext, T> fc;
fc(dev_ctx,
total_T,
......
......@@ -86,7 +86,7 @@ class CenterLossKernel : public framework::OpKernel<T> {
int numel = centers_diffacc.numel();
std::memset(centers_diffacc_data, 0, sizeof(T) * numel);
auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx);
int tLabel;
const T *x_index;
......
......@@ -37,8 +37,7 @@ class FSPOpKernel : public framework::OpKernel<T> {
auto height = x_dims[2];
auto width = x_dims[3];
auto& dev_ctx = context.template device_context<DeviceContext>();
auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
auto blas = phi::funcs::GetBlas<DeviceContext, T>(context);
phi::funcs::MatDescriptor x_mat_desc;
x_mat_desc.height_ = x_channel;
......@@ -82,8 +81,7 @@ class FSPGradOpKernel : public framework::OpKernel<T> {
int64_t h = 0;
int64_t w = 0;
auto& dev_ctx = context.template device_context<DeviceContext>();
auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
auto blas = phi::funcs::GetBlas<DeviceContext, T>(context);
phi::funcs::SetConstant<DeviceContext, T> set_zero;
if (d_x != nullptr) {
d_x->mutable_data<T>(context.GetPlace());
......
......@@ -411,8 +411,7 @@ class FusedEmbeddingFCLSTMKernel : public framework::OpKernel<T> {
T* xx_data = xx->mutable_data<T>(place);
T* h_out_data = hidden_out->mutable_data<T>(place);
T* c_out_data = cell_out->mutable_data<T>(place);
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx);
for (int64_t i = 0; i < ids_numel; ++i) {
PADDLE_ENFORCE_LT(
......
......@@ -197,9 +197,7 @@ class FusedEmbeddingSeqPoolKernel : public framework::OpKernel<T> {
const int m = batch_size * idx_width;
const int n = table_width;
const int k = table_height;
auto &dev_ctx = context.template device_context<phi::CPUContext>();
auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(context);
blas.CSRMM(&transa,
&m,
&n,
......@@ -318,8 +316,7 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> {
padding_idx);
auto *d_output_data = d_output->data<T>();
auto &dev_ctx = context.template device_context<phi::CPUContext>();
auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(context);
int width = static_cast<int>(table_dim[1]);
int num_seq = batch_size * idx_width;
LOG(INFO) << "num seq = " << num_seq << " width = " << width;
......
......@@ -310,10 +310,9 @@ class FusionGRUKernel : public framework::OpKernel<T> {
const T* h0_data = h0 ? h0->data<T>() : nullptr;
const T* wh_state_data = wh_data + D * D2;
T* hidden_out_data = hidden_out->mutable_data<T>(place);
auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx);
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
phi::funcs::FCFunctor<DeviceContext, T> fc;
fc(dev_ctx,
total_T,
......
......@@ -377,9 +377,9 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
T* xx_data = xx->mutable_data<T>(place);
T* h_out_data = hidden_out->mutable_data<T>(place);
T* c_out_data = cell_out->mutable_data<T>(place);
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx);
auto& dev_ctx = ctx.template device_context<DeviceContext>();
phi::funcs::FCFunctor<DeviceContext, T> fc;
fc(dev_ctx, total_T, D4, M, x_data, wx_data, xx_data, bias->data<T>());
......
......@@ -239,9 +239,9 @@ class FusionSeqExpandConcatFCOpKernel : public framework::OpKernel<T> {
T* out_data = out->mutable_data<T>(ctx.GetPlace());
T* fc_out_data = fc_out->mutable_data<T>(ctx.GetPlace());
auto& dev_ctx = ctx.template device_context<phi::CPUContext>();
auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx);
auto& dev_ctx = ctx.template device_context<phi::CPUContext>();
phi::funcs::FCFunctor<DeviceContext, T> fc;
fc(dev_ctx,
total_T,
......
......@@ -89,8 +89,7 @@ class GRUUnitKernel : public framework::OpKernel<T> {
const T* weight_data = weight->data<T>();
T* gate_data = gate->data<T>();
T* reset_hidden_prev_data = reset_hidden_prev->data<T>();
auto& dev_ctx = context.template device_context<DeviceContext>();
auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
auto blas = phi::funcs::GetBlas<DeviceContext, T>(context);
blas.GEMM(false,
false,
batch_size,
......@@ -252,8 +251,7 @@ class GRUUnitGradKernel : public framework::OpKernel<T> {
d_h * u);
}
// backward for reset_hidden_prev
auto& dev_ctx = context.template device_context<DeviceContext>();
auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
auto blas = phi::funcs::GetBlas<DeviceContext, T>(context);
blas.GEMM(false,
true,
batch_size,
......
......@@ -119,8 +119,7 @@ struct IndexSelectAdd<
const T* src_pointer,
const T* p_pointer,
T* dist_pointer) {
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx);
blas.VADD(slice_size, src_pointer, p_pointer, dist_pointer);
}
};
......
......@@ -114,9 +114,7 @@ class LookupTableKernel : public framework::OpKernel<T> {
table + id_index * row_width,
row_width * sizeof(T));
} else {
auto &dev_ctx =
context.template device_context<phi::CPUContext>();
auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(context);
blas.VCOPY(row_width,
table + id_index * row_width,
output + i * row_width);
......@@ -147,9 +145,7 @@ class LookupTableKernel : public framework::OpKernel<T> {
table + id_index * row_width,
row_width * sizeof(T));
} else {
auto &dev_ctx =
context.template device_context<phi::CPUContext>();
auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(context);
blas.VCOPY(row_width,
table + id_index * row_width,
output + i * row_width);
......
......@@ -130,8 +130,7 @@ struct LookupTableV2CPUFunctor {
table + id_index * row_width,
row_width * sizeof(T));
} else {
auto &dev_ctx = context_.template device_context<phi::CPUContext>();
auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(context_);
blas.VCOPY(row_width,
table + id_index * row_width,
output + i * row_width);
......
......@@ -45,9 +45,9 @@ struct LRNFunctor<phi::CPUContext, T> {
T beta,
const DataLayout data_layout) {
auto place = ctx.GetPlace();
auto& dev_ctx = ctx.template device_context<phi::CPUContext>();
auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(ctx);
phi::funcs::Transpose<phi::CPUContext, T, 4> transpose;
auto& dev_ctx = ctx.template device_context<phi::CPUContext>();
phi::DenseTensor in_transpose, mid_transpose, out_transpose;
// if channel_last, transpose to channel_first
if (data_layout == DataLayout::kNHWC) {
......
......@@ -275,8 +275,7 @@ class CPUMatchMatrixTensorOPKernel : public framework::OpKernel<T> {
memset(
bottom_l_trans_data, 0.0, tmp->dims()[0] * tmp->dims()[1] * sizeof(T));
auto& dev_ctx = ctx.template device_context<phi::CPUContext>();
auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(ctx);
call_gemm(blas,
CblasNoTrans,
......@@ -298,7 +297,7 @@ class CPUMatchMatrixTensorOPKernel : public framework::OpKernel<T> {
const auto* l_t_data =
bottom_l_trans_data + offset_l[b] * dim_t * dim_in + t * dim_in;
const auto* r_data = bottom_r_data + offset_r[b] * dim_in;
auto blas_2 = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
auto blas_2 = phi::funcs::GetBlas<phi::CPUContext, T>(ctx);
call_gemm_with_lda(blas_2,
CblasNoTrans,
CblasTrans,
......@@ -391,8 +390,7 @@ class CPUMatchMatrixTensorOPGradKernel : public framework::OpKernel<T> {
}
}
auto& dev_ctx = ctx.template device_context<phi::CPUContext>();
auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(ctx);
auto* t_data = w->data<T>();
auto* d_w = ctx.Output<phi::DenseTensor>(framework::GradVarName("W"));
......
......@@ -69,7 +69,7 @@ class MatMulKernel : public framework::OpKernel<T> {
auto &dev_ctx = context.template device_context<DeviceContext>();
dev_ctx.template Alloc<T>(out, out->numel() * sizeof(T));
auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
auto blas = phi::funcs::GetBlas<DeviceContext, T>(context);
auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(
RowMatrixFromVector(x.dims()), 0, context.Attr<bool>("transpose_X"));
auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(
......@@ -237,8 +237,7 @@ class MatMulGradKernel : public framework::OpKernel<T> {
bool trans_b,
phi::DenseTensor *out) const {
out->mutable_data<T>(context.GetPlace());
auto &dev_ctx = context.template device_context<DeviceContext>();
auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
auto blas = phi::funcs::GetBlas<DeviceContext, T>(context);
auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(a.dims(), 0, trans_a);
auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(b.dims(), 0, trans_b);
......@@ -377,8 +376,7 @@ class MatMulDoubleGradKernel : public framework::OpKernel<T> {
bool flag,
phi::DenseTensor *out) const {
out->mutable_data<T>(context.GetPlace());
auto &dev_ctx = context.template device_context<DeviceContext>();
auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
auto blas = phi::funcs::GetBlas<DeviceContext, T>(context);
auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(a.dims(), 0, trans_a);
auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(b.dims(), 0, trans_b);
......
......@@ -61,8 +61,7 @@ void call_gemm(const framework::ExecutionContext& ctx,
T* C) {
int lda = (TransA == CblasNoTrans) ? K : M;
int ldb = (TransB == CblasNoTrans) ? N : K;
auto& dev_ctx = ctx.template device_context<phi::CPUContext>();
auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(ctx);
blas.GEMM(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, N);
}
......
......@@ -698,8 +698,7 @@ struct DeviceIndependenceTensorOperations {
private:
const framework::ExecutionContext& context;
phi::funcs::BlasT<DeviceContext, T> GetBlas() {
auto& dev_ctx = context.template device_context<DeviceContext>();
return phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
return phi::funcs::GetBlas<DeviceContext, T>(context);
}
platform::ForRange<DeviceContext> GetForRange(int numel) {
auto& dev_ctx = context.template device_context<DeviceContext>();
......
......@@ -326,8 +326,7 @@ class CPUVarConv2dOPKernel : public framework::OpKernel<T> {
auto* w_data = w->data<T>();
auto* col_data = col->data<T>();
auto& dev_ctx = ctx.template device_context<phi::CPUContext>();
auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(ctx);
for (int b = 0; b < batch; ++b) {
int top_im_size = (top_offset[b + 1] - top_offset[b]) / output_channel;
if (top_im_size == 0) {
......@@ -485,8 +484,7 @@ class CPUVarConv2dOPGradKernel : public framework::OpKernel<T> {
int batch = x->lod()[0].size() - 1;
const auto& top_offset = out->lod()[0];
const auto& col_offset = col->lod()[0];
auto& dev_ctx = ctx.template device_context<phi::CPUContext>();
auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(ctx);
for (int b = 0; b < batch; ++b) {
int top_im_size = (top_offset[b + 1] - top_offset[b]) / output_channel;
if (top_im_size == 0) {
......
......@@ -14,6 +14,7 @@
#pragma once
#include "paddle/fluid/framework/operator.h"
#include "paddle/phi/core/dense_tensor.h"
#ifdef PADDLE_WITH_MKLML
......@@ -578,6 +579,13 @@ class BlasT : private Blas<DeviceContext> {
}
};
template <typename DeviceContext, typename T>
inline BlasT<DeviceContext, T> GetBlas(
const paddle::framework::ExecutionContext& exe_ctx) {
return BlasT<DeviceContext, T>(
exe_ctx.template device_context<DeviceContext>());
}
template <typename DeviceContext, typename T>
inline BlasT<DeviceContext, T> GetBlas(const DeviceContext& dev_ctx) {
return BlasT<DeviceContext, T>(dev_ctx);
......
......@@ -114,6 +114,21 @@ template struct SetConstant<phi::GPUContext, bool>;
template struct SetConstant<phi::GPUContext, phi::dtype::complex<float>>;
template struct SetConstant<phi::GPUContext, phi::dtype::complex<double>>;
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, float16>;
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext,
bfloat16>;
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, float>;
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, double>;
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, uint8_t>;
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, int>;
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, int16_t>;
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, int64_t>;
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, bool>;
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext,
phi::dtype::complex<float>>;
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext,
phi::dtype::complex<double>>;
#define DEFINE_GPU_TRANS(RANK) \
template struct Transpose<phi::GPUContext, bool, RANK>; \
template struct Transpose<phi::GPUContext, unsigned char, RANK>; \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册