Unverified commit 427712df · Author: iSerendipity · Committer: GitHub

[PHI] remove operator.h in blas.h (rebase to latest codebase) (#51472)

* Revert "Revert "【Hackathon No.67】remove operator.h in blas.h (#50989)" (#51467)"

This reverts commit b9d91531.

* remove cout

* add header

* fix missing header

* fix refer fluid error

* fix missing header

* Update repeat_interleave_grad_kernel_impl.h

Change to phi style datatype.

* Update repeat_interleave_grad_kernel_impl.h

Fix missing header

* datatype fluid -> phi

* paddle::experimental -> phi

* fix reference error

* fix reference error

* fix reference error

* fix errors

* fix missing FLAGS

* fix missing headers

* fix missing headers

* fix missing headers

* fix missing headers

* fix missing header

* fix missing header

* fix errors
Parent commit: effe2c11
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "paddle/fluid/framework/ir/mkldnn/layer_norm_onednn_optimization_pass.h" #include "paddle/fluid/framework/ir/mkldnn/layer_norm_onednn_optimization_pass.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/backends/onednn/onednn_reuse.h" #include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/utils/string/pretty_log.h" #include "paddle/utils/string/pretty_log.h"
......
...@@ -424,10 +424,10 @@ class AttentionLSTMKernel : public framework::OpKernel<T> { ...@@ -424,10 +424,10 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
T* lstm_x_data = lstm_x->mutable_data<T>(ctx.GetPlace()); T* lstm_x_data = lstm_x->mutable_data<T>(ctx.GetPlace());
T* lstm_out_data = lstm_out->mutable_data<T>(ctx.GetPlace()); T* lstm_out_data = lstm_out->mutable_data<T>(ctx.GetPlace());
auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(ctx);
// x(TxM) * fc (Mx1) part of atten_wgt(M+D)x1 // x(TxM) * fc (Mx1) part of atten_wgt(M+D)x1
auto& dev_ctx = ctx.template device_context<phi::CPUContext>(); auto& dev_ctx = ctx.template device_context<phi::CPUContext>();
auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
phi::funcs::FCFunctor<DeviceContext, T> fc; phi::funcs::FCFunctor<DeviceContext, T> fc;
fc(dev_ctx, fc(dev_ctx,
total_T, total_T,
......
...@@ -86,7 +86,7 @@ class CenterLossKernel : public framework::OpKernel<T> { ...@@ -86,7 +86,7 @@ class CenterLossKernel : public framework::OpKernel<T> {
int numel = centers_diffacc.numel(); int numel = centers_diffacc.numel();
std::memset(centers_diffacc_data, 0, sizeof(T) * numel); std::memset(centers_diffacc_data, 0, sizeof(T) * numel);
auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx); auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
int tLabel; int tLabel;
const T *x_index; const T *x_index;
......
...@@ -37,7 +37,8 @@ class FSPOpKernel : public framework::OpKernel<T> { ...@@ -37,7 +37,8 @@ class FSPOpKernel : public framework::OpKernel<T> {
auto height = x_dims[2]; auto height = x_dims[2];
auto width = x_dims[3]; auto width = x_dims[3];
auto blas = phi::funcs::GetBlas<DeviceContext, T>(context); auto& dev_ctx = context.template device_context<DeviceContext>();
auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
phi::funcs::MatDescriptor x_mat_desc; phi::funcs::MatDescriptor x_mat_desc;
x_mat_desc.height_ = x_channel; x_mat_desc.height_ = x_channel;
...@@ -81,7 +82,8 @@ class FSPGradOpKernel : public framework::OpKernel<T> { ...@@ -81,7 +82,8 @@ class FSPGradOpKernel : public framework::OpKernel<T> {
int64_t h = 0; int64_t h = 0;
int64_t w = 0; int64_t w = 0;
auto blas = phi::funcs::GetBlas<DeviceContext, T>(context); auto& dev_ctx = context.template device_context<DeviceContext>();
auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
phi::funcs::SetConstant<DeviceContext, T> set_zero; phi::funcs::SetConstant<DeviceContext, T> set_zero;
if (d_x != nullptr) { if (d_x != nullptr) {
d_x->mutable_data<T>(context.GetPlace()); d_x->mutable_data<T>(context.GetPlace());
......
...@@ -411,7 +411,8 @@ class FusedEmbeddingFCLSTMKernel : public framework::OpKernel<T> { ...@@ -411,7 +411,8 @@ class FusedEmbeddingFCLSTMKernel : public framework::OpKernel<T> {
T* xx_data = xx->mutable_data<T>(place); T* xx_data = xx->mutable_data<T>(place);
T* h_out_data = hidden_out->mutable_data<T>(place); T* h_out_data = hidden_out->mutable_data<T>(place);
T* c_out_data = cell_out->mutable_data<T>(place); T* c_out_data = cell_out->mutable_data<T>(place);
auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx); auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
for (int64_t i = 0; i < ids_numel; ++i) { for (int64_t i = 0; i < ids_numel; ++i) {
PADDLE_ENFORCE_LT( PADDLE_ENFORCE_LT(
......
...@@ -197,7 +197,9 @@ class FusedEmbeddingSeqPoolKernel : public framework::OpKernel<T> { ...@@ -197,7 +197,9 @@ class FusedEmbeddingSeqPoolKernel : public framework::OpKernel<T> {
const int m = batch_size * idx_width; const int m = batch_size * idx_width;
const int n = table_width; const int n = table_width;
const int k = table_height; const int k = table_height;
auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(context);
auto &dev_ctx = context.template device_context<phi::CPUContext>();
auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
blas.CSRMM(&transa, blas.CSRMM(&transa,
&m, &m,
&n, &n,
...@@ -316,7 +318,8 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> { ...@@ -316,7 +318,8 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> {
padding_idx); padding_idx);
auto *d_output_data = d_output->data<T>(); auto *d_output_data = d_output->data<T>();
auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(context); auto &dev_ctx = context.template device_context<phi::CPUContext>();
auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
int width = static_cast<int>(table_dim[1]); int width = static_cast<int>(table_dim[1]);
int num_seq = batch_size * idx_width; int num_seq = batch_size * idx_width;
LOG(INFO) << "num seq = " << num_seq << " width = " << width; LOG(INFO) << "num seq = " << num_seq << " width = " << width;
......
...@@ -310,9 +310,10 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -310,9 +310,10 @@ class FusionGRUKernel : public framework::OpKernel<T> {
const T* h0_data = h0 ? h0->data<T>() : nullptr; const T* h0_data = h0 ? h0->data<T>() : nullptr;
const T* wh_state_data = wh_data + D * D2; const T* wh_state_data = wh_data + D * D2;
T* hidden_out_data = hidden_out->mutable_data<T>(place); T* hidden_out_data = hidden_out->mutable_data<T>(place);
auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx);
auto& dev_ctx = ctx.template device_context<DeviceContext>(); auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
phi::funcs::FCFunctor<DeviceContext, T> fc; phi::funcs::FCFunctor<DeviceContext, T> fc;
fc(dev_ctx, fc(dev_ctx,
total_T, total_T,
......
...@@ -377,9 +377,9 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -377,9 +377,9 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
T* xx_data = xx->mutable_data<T>(place); T* xx_data = xx->mutable_data<T>(place);
T* h_out_data = hidden_out->mutable_data<T>(place); T* h_out_data = hidden_out->mutable_data<T>(place);
T* c_out_data = cell_out->mutable_data<T>(place); T* c_out_data = cell_out->mutable_data<T>(place);
auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx);
auto& dev_ctx = ctx.template device_context<DeviceContext>(); auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
phi::funcs::FCFunctor<DeviceContext, T> fc; phi::funcs::FCFunctor<DeviceContext, T> fc;
fc(dev_ctx, total_T, D4, M, x_data, wx_data, xx_data, bias->data<T>()); fc(dev_ctx, total_T, D4, M, x_data, wx_data, xx_data, bias->data<T>());
......
...@@ -239,9 +239,9 @@ class FusionSeqExpandConcatFCOpKernel : public framework::OpKernel<T> { ...@@ -239,9 +239,9 @@ class FusionSeqExpandConcatFCOpKernel : public framework::OpKernel<T> {
T* out_data = out->mutable_data<T>(ctx.GetPlace()); T* out_data = out->mutable_data<T>(ctx.GetPlace());
T* fc_out_data = fc_out->mutable_data<T>(ctx.GetPlace()); T* fc_out_data = fc_out->mutable_data<T>(ctx.GetPlace());
auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx);
auto& dev_ctx = ctx.template device_context<phi::CPUContext>(); auto& dev_ctx = ctx.template device_context<phi::CPUContext>();
auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
phi::funcs::FCFunctor<DeviceContext, T> fc; phi::funcs::FCFunctor<DeviceContext, T> fc;
fc(dev_ctx, fc(dev_ctx,
total_T, total_T,
......
...@@ -89,7 +89,8 @@ class GRUUnitKernel : public framework::OpKernel<T> { ...@@ -89,7 +89,8 @@ class GRUUnitKernel : public framework::OpKernel<T> {
const T* weight_data = weight->data<T>(); const T* weight_data = weight->data<T>();
T* gate_data = gate->data<T>(); T* gate_data = gate->data<T>();
T* reset_hidden_prev_data = reset_hidden_prev->data<T>(); T* reset_hidden_prev_data = reset_hidden_prev->data<T>();
auto blas = phi::funcs::GetBlas<DeviceContext, T>(context); auto& dev_ctx = context.template device_context<DeviceContext>();
auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
blas.GEMM(false, blas.GEMM(false,
false, false,
batch_size, batch_size,
...@@ -251,7 +252,8 @@ class GRUUnitGradKernel : public framework::OpKernel<T> { ...@@ -251,7 +252,8 @@ class GRUUnitGradKernel : public framework::OpKernel<T> {
d_h * u); d_h * u);
} }
// backward for reset_hidden_prev // backward for reset_hidden_prev
auto blas = phi::funcs::GetBlas<DeviceContext, T>(context); auto& dev_ctx = context.template device_context<DeviceContext>();
auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
blas.GEMM(false, blas.GEMM(false,
true, true,
batch_size, batch_size,
......
...@@ -119,7 +119,8 @@ struct IndexSelectAdd< ...@@ -119,7 +119,8 @@ struct IndexSelectAdd<
const T* src_pointer, const T* src_pointer,
const T* p_pointer, const T* p_pointer,
T* dist_pointer) { T* dist_pointer) {
auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx); auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
blas.VADD(slice_size, src_pointer, p_pointer, dist_pointer); blas.VADD(slice_size, src_pointer, p_pointer, dist_pointer);
} }
}; };
......
...@@ -114,7 +114,9 @@ class LookupTableKernel : public framework::OpKernel<T> { ...@@ -114,7 +114,9 @@ class LookupTableKernel : public framework::OpKernel<T> {
table + id_index * row_width, table + id_index * row_width,
row_width * sizeof(T)); row_width * sizeof(T));
} else { } else {
auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(context); auto &dev_ctx =
context.template device_context<phi::CPUContext>();
auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
blas.VCOPY(row_width, blas.VCOPY(row_width,
table + id_index * row_width, table + id_index * row_width,
output + i * row_width); output + i * row_width);
...@@ -145,7 +147,9 @@ class LookupTableKernel : public framework::OpKernel<T> { ...@@ -145,7 +147,9 @@ class LookupTableKernel : public framework::OpKernel<T> {
table + id_index * row_width, table + id_index * row_width,
row_width * sizeof(T)); row_width * sizeof(T));
} else { } else {
auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(context); auto &dev_ctx =
context.template device_context<phi::CPUContext>();
auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
blas.VCOPY(row_width, blas.VCOPY(row_width,
table + id_index * row_width, table + id_index * row_width,
output + i * row_width); output + i * row_width);
......
...@@ -130,7 +130,8 @@ struct LookupTableV2CPUFunctor { ...@@ -130,7 +130,8 @@ struct LookupTableV2CPUFunctor {
table + id_index * row_width, table + id_index * row_width,
row_width * sizeof(T)); row_width * sizeof(T));
} else { } else {
auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(context_); auto &dev_ctx = context_.template device_context<phi::CPUContext>();
auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
blas.VCOPY(row_width, blas.VCOPY(row_width,
table + id_index * row_width, table + id_index * row_width,
output + i * row_width); output + i * row_width);
......
...@@ -45,9 +45,9 @@ struct LRNFunctor<phi::CPUContext, T> { ...@@ -45,9 +45,9 @@ struct LRNFunctor<phi::CPUContext, T> {
T beta, T beta,
const DataLayout data_layout) { const DataLayout data_layout) {
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(ctx);
phi::funcs::Transpose<phi::CPUContext, T, 4> transpose;
auto& dev_ctx = ctx.template device_context<phi::CPUContext>(); auto& dev_ctx = ctx.template device_context<phi::CPUContext>();
auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
phi::funcs::Transpose<phi::CPUContext, T, 4> transpose;
phi::DenseTensor in_transpose, mid_transpose, out_transpose; phi::DenseTensor in_transpose, mid_transpose, out_transpose;
// if channel_last, transpose to channel_first // if channel_last, transpose to channel_first
if (data_layout == DataLayout::kNHWC) { if (data_layout == DataLayout::kNHWC) {
......
...@@ -275,7 +275,8 @@ class CPUMatchMatrixTensorOPKernel : public framework::OpKernel<T> { ...@@ -275,7 +275,8 @@ class CPUMatchMatrixTensorOPKernel : public framework::OpKernel<T> {
memset( memset(
bottom_l_trans_data, 0.0, tmp->dims()[0] * tmp->dims()[1] * sizeof(T)); bottom_l_trans_data, 0.0, tmp->dims()[0] * tmp->dims()[1] * sizeof(T));
auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(ctx); auto& dev_ctx = ctx.template device_context<phi::CPUContext>();
auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
call_gemm(blas, call_gemm(blas,
CblasNoTrans, CblasNoTrans,
...@@ -297,7 +298,7 @@ class CPUMatchMatrixTensorOPKernel : public framework::OpKernel<T> { ...@@ -297,7 +298,7 @@ class CPUMatchMatrixTensorOPKernel : public framework::OpKernel<T> {
const auto* l_t_data = const auto* l_t_data =
bottom_l_trans_data + offset_l[b] * dim_t * dim_in + t * dim_in; bottom_l_trans_data + offset_l[b] * dim_t * dim_in + t * dim_in;
const auto* r_data = bottom_r_data + offset_r[b] * dim_in; const auto* r_data = bottom_r_data + offset_r[b] * dim_in;
auto blas_2 = phi::funcs::GetBlas<phi::CPUContext, T>(ctx); auto blas_2 = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
call_gemm_with_lda(blas_2, call_gemm_with_lda(blas_2,
CblasNoTrans, CblasNoTrans,
CblasTrans, CblasTrans,
...@@ -390,7 +391,8 @@ class CPUMatchMatrixTensorOPGradKernel : public framework::OpKernel<T> { ...@@ -390,7 +391,8 @@ class CPUMatchMatrixTensorOPGradKernel : public framework::OpKernel<T> {
} }
} }
auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(ctx); auto& dev_ctx = ctx.template device_context<phi::CPUContext>();
auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
auto* t_data = w->data<T>(); auto* t_data = w->data<T>();
auto* d_w = ctx.Output<phi::DenseTensor>(framework::GradVarName("W")); auto* d_w = ctx.Output<phi::DenseTensor>(framework::GradVarName("W"));
......
...@@ -15,13 +15,18 @@ limitations under the License. */ ...@@ -15,13 +15,18 @@ limitations under the License. */
#include "paddle/phi/kernels/funcs/selected_rows_functor.h" #include "paddle/phi/kernels/funcs/selected_rows_functor.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/phi/backends/context_pool.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/errors.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
TEST(selected_rows_functor, gpu_add) { TEST(selected_rows_functor, gpu_add) {
paddle::platform::CUDAPlace gpu_place(0); phi::GPUPlace gpu_place(0);
paddle::platform::CPUPlace cpu_place; phi::CPUPlace cpu_place;
phi::GPUContext& ctx = *reinterpret_cast<phi::GPUContext*>( phi::GPUContext& ctx = *reinterpret_cast<phi::GPUContext*>(
paddle::platform::DeviceContextPool::Instance().Get(gpu_place)); phi::DeviceContextPool::Instance().Get(gpu_place));
phi::funcs::SetConstant<phi::GPUContext, float> functor; phi::funcs::SetConstant<phi::GPUContext, float> functor;
int64_t height = 10; int64_t height = 10;
int64_t row_numel = 10; int64_t row_numel = 10;
...@@ -37,12 +42,12 @@ TEST(selected_rows_functor, gpu_add) { ...@@ -37,12 +42,12 @@ TEST(selected_rows_functor, gpu_add) {
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_EQ(hipDeviceSynchronize(), PADDLE_ENFORCE_EQ(hipDeviceSynchronize(),
0, 0,
paddle::platform::errors::PreconditionNotMet( phi::errors::PreconditionNotMet(
"The all synchronization on the cuda is error!")); "The all synchronization on the cuda is error!"));
#else #else
PADDLE_ENFORCE_EQ(cudaDeviceSynchronize(), PADDLE_ENFORCE_EQ(cudaDeviceSynchronize(),
0, 0,
paddle::platform::errors::PreconditionNotMet( phi::errors::PreconditionNotMet(
"The all synchronization on the cuda is error!")); "The all synchronization on the cuda is error!"));
#endif #endif
...@@ -80,8 +85,7 @@ TEST(selected_rows_functor, gpu_add) { ...@@ -80,8 +85,7 @@ TEST(selected_rows_functor, gpu_add) {
EXPECT_EQ(out_rows[6], 9); EXPECT_EQ(out_rows[6], 9);
phi::DenseTensor out_cpu; phi::DenseTensor out_cpu;
paddle::framework::TensorCopy(*out_value, cpu_place, ctx, &out_cpu); phi::Copy(ctx, *out_value, cpu_place, true, &out_cpu);
ctx.Wait();
auto* out_cpu_data = out_cpu.data<float>(); auto* out_cpu_data = out_cpu.data<float>();
// input1 value // input1 value
...@@ -107,8 +111,7 @@ TEST(selected_rows_functor, gpu_add) { ...@@ -107,8 +111,7 @@ TEST(selected_rows_functor, gpu_add) {
add_tensor_functor(ctx, *output, *tensor1, tensor2.get()); add_tensor_functor(ctx, *output, *tensor1, tensor2.get());
phi::DenseTensor tensor2_cpu; phi::DenseTensor tensor2_cpu;
paddle::framework::TensorCopy(*tensor2, cpu_place, ctx, &tensor2_cpu); phi::Copy(ctx, *tensor2, cpu_place, true, &tensor2_cpu);
ctx.Wait();
auto* tensor2_cpu_data = tensor2_cpu.data<float>(); auto* tensor2_cpu_data = tensor2_cpu.data<float>();
// row0: 1.0 + 2.0 + 3.0 // row0: 1.0 + 2.0 + 3.0
...@@ -128,10 +131,10 @@ TEST(selected_rows_functor, gpu_add) { ...@@ -128,10 +131,10 @@ TEST(selected_rows_functor, gpu_add) {
} }
TEST(selected_rows_functor, gpu_add_to) { TEST(selected_rows_functor, gpu_add_to) {
paddle::platform::CUDAPlace gpu_place(0); phi::GPUPlace gpu_place(0);
paddle::platform::CPUPlace cpu_place; phi::CPUPlace cpu_place;
phi::GPUContext& ctx = *reinterpret_cast<phi::GPUContext*>( phi::GPUContext& ctx = *reinterpret_cast<phi::GPUContext*>(
paddle::platform::DeviceContextPool::Instance().Get(gpu_place)); phi::DeviceContextPool::Instance().Get(gpu_place));
phi::funcs::SetConstant<phi::GPUContext, float> functor; phi::funcs::SetConstant<phi::GPUContext, float> functor;
int64_t height = 10; int64_t height = 10;
int64_t row_numel = 10; int64_t row_numel = 10;
...@@ -181,8 +184,7 @@ TEST(selected_rows_functor, gpu_add_to) { ...@@ -181,8 +184,7 @@ TEST(selected_rows_functor, gpu_add_to) {
EXPECT_EQ(out_rows[6], 9); EXPECT_EQ(out_rows[6], 9);
phi::DenseTensor out_cpu; phi::DenseTensor out_cpu;
paddle::framework::TensorCopy(*out_value, cpu_place, ctx, &out_cpu); phi::Copy(ctx, *out_value, cpu_place, true, &out_cpu);
ctx.Wait();
auto* out_cpu_data = out_cpu.data<float>(); auto* out_cpu_data = out_cpu.data<float>();
// input1 value // input1 value
...@@ -206,8 +208,7 @@ TEST(selected_rows_functor, gpu_add_to) { ...@@ -206,8 +208,7 @@ TEST(selected_rows_functor, gpu_add_to) {
add_to_tensor_functor(ctx, *output, tensor1.get()); add_to_tensor_functor(ctx, *output, tensor1.get());
phi::DenseTensor tensor1_cpu; phi::DenseTensor tensor1_cpu;
paddle::framework::TensorCopy(*tensor1, cpu_place, ctx, &tensor1_cpu); phi::Copy(ctx, *tensor1, cpu_place, true, &tensor1_cpu);
ctx.Wait();
auto* tensor1_cpu_data = tensor1_cpu.data<float>(); auto* tensor1_cpu_data = tensor1_cpu.data<float>();
// row0: 1.0 + 2.0 + 3.0 // row0: 1.0 + 2.0 + 3.0
...@@ -227,10 +228,10 @@ TEST(selected_rows_functor, gpu_add_to) { ...@@ -227,10 +228,10 @@ TEST(selected_rows_functor, gpu_add_to) {
} }
TEST(selected_rows_functor, gpu_merge_add) { TEST(selected_rows_functor, gpu_merge_add) {
paddle::platform::CUDAPlace gpu_place(0); phi::GPUPlace gpu_place(0);
paddle::platform::CPUPlace cpu_place; phi::CPUPlace cpu_place;
phi::GPUContext& ctx = *reinterpret_cast<phi::GPUContext*>( phi::GPUContext& ctx = *reinterpret_cast<phi::GPUContext*>(
paddle::platform::DeviceContextPool::Instance().Get(gpu_place)); phi::DeviceContextPool::Instance().Get(gpu_place));
phi::funcs::SetConstant<phi::GPUContext, float> set_const; phi::funcs::SetConstant<phi::GPUContext, float> set_const;
int64_t height = 10; int64_t height = 10;
...@@ -264,8 +265,7 @@ TEST(selected_rows_functor, gpu_merge_add) { ...@@ -264,8 +265,7 @@ TEST(selected_rows_functor, gpu_merge_add) {
merge_add_functor(ctx, inputs, output.get()); merge_add_functor(ctx, inputs, output.get());
phi::DenseTensor output_cpu; phi::DenseTensor output_cpu;
paddle::framework::TensorCopy(output->value(), cpu_place, ctx, &output_cpu); phi::Copy(ctx, output->value(), cpu_place, true, &output_cpu);
ctx.Wait();
EXPECT_EQ(output->height(), height); EXPECT_EQ(output->height(), height);
EXPECT_EQ(output->value().dims(), phi::make_ddim({3, row_numel})); EXPECT_EQ(output->value().dims(), phi::make_ddim({3, row_numel}));
......
...@@ -69,7 +69,7 @@ class MatMulKernel : public framework::OpKernel<T> { ...@@ -69,7 +69,7 @@ class MatMulKernel : public framework::OpKernel<T> {
auto &dev_ctx = context.template device_context<DeviceContext>(); auto &dev_ctx = context.template device_context<DeviceContext>();
dev_ctx.template Alloc<T>(out, out->numel() * sizeof(T)); dev_ctx.template Alloc<T>(out, out->numel() * sizeof(T));
auto blas = phi::funcs::GetBlas<DeviceContext, T>(context); auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
auto mat_dim_a = phi::funcs::CreateMatrixDescriptor( auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(
RowMatrixFromVector(x.dims()), 0, context.Attr<bool>("transpose_X")); RowMatrixFromVector(x.dims()), 0, context.Attr<bool>("transpose_X"));
auto mat_dim_b = phi::funcs::CreateMatrixDescriptor( auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(
...@@ -237,7 +237,8 @@ class MatMulGradKernel : public framework::OpKernel<T> { ...@@ -237,7 +237,8 @@ class MatMulGradKernel : public framework::OpKernel<T> {
bool trans_b, bool trans_b,
phi::DenseTensor *out) const { phi::DenseTensor *out) const {
out->mutable_data<T>(context.GetPlace()); out->mutable_data<T>(context.GetPlace());
auto blas = phi::funcs::GetBlas<DeviceContext, T>(context); auto &dev_ctx = context.template device_context<DeviceContext>();
auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(a.dims(), 0, trans_a); auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(a.dims(), 0, trans_a);
auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(b.dims(), 0, trans_b); auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(b.dims(), 0, trans_b);
...@@ -376,7 +377,8 @@ class MatMulDoubleGradKernel : public framework::OpKernel<T> { ...@@ -376,7 +377,8 @@ class MatMulDoubleGradKernel : public framework::OpKernel<T> {
bool flag, bool flag,
phi::DenseTensor *out) const { phi::DenseTensor *out) const {
out->mutable_data<T>(context.GetPlace()); out->mutable_data<T>(context.GetPlace());
auto blas = phi::funcs::GetBlas<DeviceContext, T>(context); auto &dev_ctx = context.template device_context<DeviceContext>();
auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(a.dims(), 0, trans_a); auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(a.dims(), 0, trans_a);
auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(b.dims(), 0, trans_b); auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(b.dims(), 0, trans_b);
......
...@@ -61,7 +61,8 @@ void call_gemm(const framework::ExecutionContext& ctx, ...@@ -61,7 +61,8 @@ void call_gemm(const framework::ExecutionContext& ctx,
T* C) { T* C) {
int lda = (TransA == CblasNoTrans) ? K : M; int lda = (TransA == CblasNoTrans) ? K : M;
int ldb = (TransB == CblasNoTrans) ? N : K; int ldb = (TransB == CblasNoTrans) ? N : K;
auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(ctx); auto& dev_ctx = ctx.template device_context<phi::CPUContext>();
auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
blas.GEMM(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, N); blas.GEMM(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, N);
} }
......
...@@ -698,7 +698,8 @@ struct DeviceIndependenceTensorOperations { ...@@ -698,7 +698,8 @@ struct DeviceIndependenceTensorOperations {
private: private:
const framework::ExecutionContext& context; const framework::ExecutionContext& context;
phi::funcs::BlasT<DeviceContext, T> GetBlas() { phi::funcs::BlasT<DeviceContext, T> GetBlas() {
return phi::funcs::GetBlas<DeviceContext, T>(context); auto& dev_ctx = context.template device_context<DeviceContext>();
return phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
} }
platform::ForRange<DeviceContext> GetForRange(int numel) { platform::ForRange<DeviceContext> GetForRange(int numel) {
auto& dev_ctx = context.template device_context<DeviceContext>(); auto& dev_ctx = context.template device_context<DeviceContext>();
......
...@@ -326,7 +326,8 @@ class CPUVarConv2dOPKernel : public framework::OpKernel<T> { ...@@ -326,7 +326,8 @@ class CPUVarConv2dOPKernel : public framework::OpKernel<T> {
auto* w_data = w->data<T>(); auto* w_data = w->data<T>();
auto* col_data = col->data<T>(); auto* col_data = col->data<T>();
auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(ctx); auto& dev_ctx = ctx.template device_context<phi::CPUContext>();
auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
for (int b = 0; b < batch; ++b) { for (int b = 0; b < batch; ++b) {
int top_im_size = (top_offset[b + 1] - top_offset[b]) / output_channel; int top_im_size = (top_offset[b + 1] - top_offset[b]) / output_channel;
if (top_im_size == 0) { if (top_im_size == 0) {
...@@ -484,7 +485,8 @@ class CPUVarConv2dOPGradKernel : public framework::OpKernel<T> { ...@@ -484,7 +485,8 @@ class CPUVarConv2dOPGradKernel : public framework::OpKernel<T> {
int batch = x->lod()[0].size() - 1; int batch = x->lod()[0].size() - 1;
const auto& top_offset = out->lod()[0]; const auto& top_offset = out->lod()[0];
const auto& col_offset = col->lod()[0]; const auto& col_offset = col->lod()[0];
auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(ctx); auto& dev_ctx = ctx.template device_context<phi::CPUContext>();
auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
for (int b = 0; b < batch; ++b) { for (int b = 0; b < batch; ++b) {
int top_im_size = (top_offset[b + 1] - top_offset[b]) / output_channel; int top_im_size = (top_offset[b + 1] - top_offset[b]) / output_channel;
if (top_im_size == 0) { if (top_im_size == 0) {
......
...@@ -92,8 +92,7 @@ cudaDataType_t ToCudaDataType() { ...@@ -92,8 +92,7 @@ cudaDataType_t ToCudaDataType() {
} else { } else {
PADDLE_THROW(phi::errors::InvalidArgument( PADDLE_THROW(phi::errors::InvalidArgument(
"DataType %s is unsupported for CUDA.", "DataType %s is unsupported for CUDA.",
paddle::experimental::DataTypeToString( DataTypeToString(paddle::experimental::CppTypeToDataType<T>::Type())));
paddle::experimental::CppTypeToDataType<T>::Type())));
} }
} }
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include "paddle/phi/core/enforce.h" #include "paddle/phi/core/enforce.h"
#if defined(PADDLE_WITH_XPU) #if defined(PADDLE_WITH_XPU)
#include "paddle/phi/backends/xpu/xpu_op_list.h" #include "paddle/phi/backends/xpu/xpu_op_list.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/compat/convert_utils.h" #include "paddle/phi/core/compat/convert_utils.h"
#endif #endif
#if defined(PADDLE_WITH_CUSTOM_DEVICE) #if defined(PADDLE_WITH_CUSTOM_DEVICE)
...@@ -134,8 +135,7 @@ bool KernelFactory::HasKernel(const std::string& kernel_name, ...@@ -134,8 +135,7 @@ bool KernelFactory::HasKernel(const std::string& kernel_name,
} }
void KernelFactory::AddToLowPrecisionKernelList( void KernelFactory::AddToLowPrecisionKernelList(
const std::string& name, const std::string& name, const phi::DataType& kernel_key_type) {
const paddle::experimental::DataType& kernel_key_type) {
if (FLAGS_low_precision_op_list >= 1) { if (FLAGS_low_precision_op_list >= 1) {
auto op_name = phi::TransToFluidOpName(name); auto op_name = phi::TransToFluidOpName(name);
if (op_name.find("_grad") != std::string::npos) { if (op_name.find("_grad") != std::string::npos) {
...@@ -469,14 +469,13 @@ std::string KernelSelectionErrorMessage(const std::string& kernel_name, ...@@ -469,14 +469,13 @@ std::string KernelSelectionErrorMessage(const std::string& kernel_name,
if (kernel_key.dtype() == target_key.dtype()) { if (kernel_key.dtype() == target_key.dtype()) {
support_dtype = true; support_dtype = true;
} }
dtype_set.insert( dtype_set.insert(DataTypeToString(kernel_key.dtype()));
paddle::experimental::DataTypeToString(kernel_key.dtype()));
} }
backend_set.insert( backend_set.insert(
paddle::experimental::BackendToString(kernel_key.backend())); paddle::experimental::BackendToString(kernel_key.backend()));
all_kernel_key[paddle::experimental::BackendToString(kernel_key.backend()) + all_kernel_key[paddle::experimental::BackendToString(kernel_key.backend()) +
", " + phi::DataLayoutToString(kernel_key.layout())] ", " + phi::DataLayoutToString(kernel_key.layout())]
.push_back(paddle::experimental::DataTypeToString(kernel_key.dtype())); .push_back(DataTypeToString(kernel_key.dtype()));
} }
// 1. If target_key not supports target backend, output "Selected wrong // 1. If target_key not supports target backend, output "Selected wrong
// Backend ..." // Backend ..."
...@@ -490,8 +489,7 @@ std::string KernelSelectionErrorMessage(const std::string& kernel_name, ...@@ -490,8 +489,7 @@ std::string KernelSelectionErrorMessage(const std::string& kernel_name,
// DataType ..." // DataType ..."
if (!support_dtype) { if (!support_dtype) {
std::string error_message = paddle::string::join_strings(dtype_set, ", "); std::string error_message = paddle::string::join_strings(dtype_set, ", ");
return "Selected wrong DataType `" + return "Selected wrong DataType `" + DataTypeToString(target_key.dtype()) +
paddle::experimental::DataTypeToString(target_key.dtype()) +
"`. Paddle support following DataTypes: " + error_message + "."; "`. Paddle support following DataTypes: " + error_message + ".";
} }
// 3. `target_key` is still not supported, output all kernel keys of // 3. `target_key` is still not supported, output all kernel keys of
......
...@@ -32,8 +32,6 @@ ...@@ -32,8 +32,6 @@
#include "paddle/utils/small_vector.h" #include "paddle/utils/small_vector.h"
namespace phi { namespace phi {
using DataType = paddle::experimental::DataType;
struct OpCount { struct OpCount {
OpCount() { OpCount() {
fp16_called_ = 0; fp16_called_ = 0;
...@@ -337,9 +335,8 @@ class KernelFactory { ...@@ -337,9 +335,8 @@ class KernelFactory {
const KernelArgsDef& GetFirstKernelArgsDef( const KernelArgsDef& GetFirstKernelArgsDef(
const std::string& kernel_name) const; const std::string& kernel_name) const;
void AddToLowPrecisionKernelList( void AddToLowPrecisionKernelList(const std::string& name,
const std::string& name, const DataType& kernel_key_type);
const paddle::experimental::DataType& kernel_key_type);
std::map<const std::string, OpCount> GetLowPrecisionKernelList(); std::map<const std::string, OpCount> GetLowPrecisionKernelList();
......
...@@ -141,9 +141,9 @@ void ArgMinMaxInferMeta(const MetaTensor& x, ...@@ -141,9 +141,9 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
phi::errors::InvalidArgument( phi::errors::InvalidArgument(
"The attribute of dtype in argmin/argmax must be [%s] or [%s], but " "The attribute of dtype in argmin/argmax must be [%s] or [%s], but "
"received [%s]", "received [%s]",
phi::DataTypeToString(DataType::INT32), DataTypeToString(DataType::INT32),
phi::DataTypeToString(DataType::INT64), DataTypeToString(DataType::INT64),
phi::DataTypeToString(phi::TransToPhiDataType(dtype)))); DataTypeToString(phi::TransToPhiDataType(dtype))));
if (!config.is_runtime && axis.FromTensor()) { if (!config.is_runtime && axis.FromTensor()) {
std::vector<int64_t> vec; std::vector<int64_t> vec;
......
...@@ -81,9 +81,9 @@ void IndexSampleGradKernel(const Context& ctx, ...@@ -81,9 +81,9 @@ void IndexSampleGradKernel(const Context& ctx,
errors::InvalidArgument( errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but " "Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s", "desires to be %s or %s",
phi::DataTypeToString(index_type), DataTypeToString(index_type),
phi::DataTypeToString(DataType::INT32), DataTypeToString(DataType::INT32),
phi::DataTypeToString(DataType::INT64))); DataTypeToString(DataType::INT64)));
if (index_type == DataType::INT32) { if (index_type == DataType::INT32) {
IndexSampleGradInner<T, Context, int>(ctx, out_grad, index, x_grad); IndexSampleGradInner<T, Context, int>(ctx, out_grad, index, x_grad);
} else if (index_type == DataType::INT64) { } else if (index_type == DataType::INT64) {
......
...@@ -94,9 +94,9 @@ void IndexSampleKernel(const Context &ctx, ...@@ -94,9 +94,9 @@ void IndexSampleKernel(const Context &ctx,
errors::InvalidArgument( errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but " "Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s", "desires to be %s or %s",
phi::DataTypeToString(index_type), DataTypeToString(index_type),
phi::DataTypeToString(DataType::INT32), DataTypeToString(DataType::INT32),
phi::DataTypeToString(DataType::INT64))); DataTypeToString(DataType::INT64)));
if (index_type == DataType::INT32) { if (index_type == DataType::INT32) {
IndexSampleInner<T, Context, int>(ctx, x, index, out); IndexSampleInner<T, Context, int>(ctx, x, index, out);
} else if (index_type == DataType::INT64) { } else if (index_type == DataType::INT64) {
......
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
#include "paddle/phi/kernels/matrix_nms_kernel.h" #include "paddle/phi/kernels/matrix_nms_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
namespace phi { namespace phi {
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "paddle/phi/kernels/repeat_interleave_grad_kernel.h" #include "paddle/phi/kernels/repeat_interleave_grad_kernel.h"
#include "paddle/phi/backends/context_pool.h"
#include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h" #include "paddle/phi/core/utils/data_type.h"
...@@ -54,9 +55,9 @@ void RepeatInterleaveWithTensorIndexGradKernel( ...@@ -54,9 +55,9 @@ void RepeatInterleaveWithTensorIndexGradKernel(
phi::errors::InvalidArgument( phi::errors::InvalidArgument(
"Input(Repeats) holds the wrong type, it holds %s, but " "Input(Repeats) holds the wrong type, it holds %s, but "
"desires to be %s or %s", "desires to be %s or %s",
phi::DataTypeToString(index_type), DataTypeToString(index_type),
phi::DataTypeToString(phi::DataType::INT32), DataTypeToString(phi::DataType::INT32),
phi::DataTypeToString(phi::DataType::INT64))); DataTypeToString(phi::DataType::INT64)));
phi::DeviceContextPool::Instance().Get(repeats_tensor.place()); phi::DeviceContextPool::Instance().Get(repeats_tensor.place());
if (index_type == phi::DataType::INT32) { if (index_type == phi::DataType::INT32) {
......
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
#pragma once #pragma once
#include "paddle/fluid/framework/operator.h"
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/dense_tensor.h"
#ifdef PADDLE_WITH_MKLML #ifdef PADDLE_WITH_MKLML
...@@ -579,13 +578,6 @@ class BlasT : private Blas<DeviceContext> { ...@@ -579,13 +578,6 @@ class BlasT : private Blas<DeviceContext> {
} }
}; };
template <typename DeviceContext, typename T>
inline BlasT<DeviceContext, T> GetBlas(
const paddle::framework::ExecutionContext& exe_ctx) {
return BlasT<DeviceContext, T>(
exe_ctx.template device_context<DeviceContext>());
}
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
inline BlasT<DeviceContext, T> GetBlas(const DeviceContext& dev_ctx) { inline BlasT<DeviceContext, T> GetBlas(const DeviceContext& dev_ctx) {
return BlasT<DeviceContext, T>(dev_ctx); return BlasT<DeviceContext, T>(dev_ctx);
......
...@@ -14,6 +14,12 @@ limitations under the License. */ ...@@ -14,6 +14,12 @@ limitations under the License. */
#include "paddle/phi/kernels/funcs/selected_rows_functor.h" #include "paddle/phi/kernels/funcs/selected_rows_functor.h"
#include <algorithm>
#include <map>
#include <set>
#include <vector>
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/mixed_vector.h" #include "paddle/phi/core/mixed_vector.h"
#ifdef PADDLE_WITH_XPU #ifdef PADDLE_WITH_XPU
......
...@@ -17,6 +17,8 @@ limitations under the License. */ ...@@ -17,6 +17,8 @@ limitations under the License. */
#include <vector> #include <vector>
#include "paddle/phi/backends/all_context.h" #include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/mixed_vector.h"
#include "paddle/phi/core/selected_rows.h"
#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
......
...@@ -84,9 +84,9 @@ struct UniqueOpFunctor { ...@@ -84,9 +84,9 @@ struct UniqueOpFunctor {
phi::errors::InvalidArgument( phi::errors::InvalidArgument(
"Index holds the wrong type, it holds %s, " "Index holds the wrong type, it holds %s, "
"but desires to be %s or %s", "but desires to be %s or %s",
phi::DataTypeToString(index_type), DataTypeToString(index_type),
phi::DataTypeToString(DataType::INT32), DataTypeToString(DataType::INT32),
phi::DataTypeToString(DataType::INT64))); DataTypeToString(DataType::INT64)));
if (index_type == DataType::INT32) { if (index_type == DataType::INT32) {
for (auto i = 0; i < in_->numel(); ++i) { for (auto i = 0; i < in_->numel(); ++i) {
......
...@@ -75,9 +75,9 @@ void IndexSampleGradKernel(const Context& ctx, ...@@ -75,9 +75,9 @@ void IndexSampleGradKernel(const Context& ctx,
errors::InvalidArgument( errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but " "Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s", "desires to be %s or %s",
phi::DataTypeToString(index_type), DataTypeToString(index_type),
phi::DataTypeToString(DataType::INT32), DataTypeToString(DataType::INT32),
phi::DataTypeToString(DataType::INT64))); DataTypeToString(DataType::INT64)));
auto stream = reinterpret_cast<const phi::GPUContext&>(ctx).stream(); auto stream = reinterpret_cast<const phi::GPUContext&>(ctx).stream();
auto input_num = x.numel(); auto input_num = x.numel();
......
...@@ -64,9 +64,9 @@ void IndexSampleKernel(const Context& ctx, ...@@ -64,9 +64,9 @@ void IndexSampleKernel(const Context& ctx,
errors::InvalidArgument( errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but " "Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s", "desires to be %s or %s",
phi::DataTypeToString(index_type), DataTypeToString(index_type),
phi::DataTypeToString(DataType::INT32), DataTypeToString(DataType::INT32),
phi::DataTypeToString(DataType::INT64))); DataTypeToString(DataType::INT64)));
const T* in_data = x.data<T>(); const T* in_data = x.data<T>();
T* out_data = ctx.template Alloc<T>(out); T* out_data = ctx.template Alloc<T>(out);
auto stream = reinterpret_cast<const phi::GPUContext&>(ctx).stream(); auto stream = reinterpret_cast<const phi::GPUContext&>(ctx).stream();
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/lamb_functors.h" #include "paddle/phi/kernels/funcs/lamb_functors.h"
...@@ -255,8 +256,8 @@ void ComputeImpl(const Context& dev_ctx, ...@@ -255,8 +256,8 @@ void ComputeImpl(const Context& dev_ctx,
auto pn = phi::funcs::ToVector(p_norm_ptr, 1, dev_ctx.GetPlace()); auto pn = phi::funcs::ToVector(p_norm_ptr, 1, dev_ctx.GetPlace());
auto tn = auto tn =
phi::funcs::ToVector(trust_ratio_div_norm_ptr, 1, dev_ctx.GetPlace()); phi::funcs::ToVector(trust_ratio_div_norm_ptr, 1, dev_ctx.GetPlace());
auto dtype = paddle::framework::DataTypeToString( auto dtype =
paddle::framework::DataTypeTrait<T>::DataType()); DataTypeToString(paddle::experimental::CppTypeToDataType<T>::Type());
VLOG(1) << "Param " << dtype << " " << name << " pn = " << pn[0] VLOG(1) << "Param " << dtype << " " << name << " pn = " << pn[0]
<< " , tn = " << tn[0]; << " , tn = " << tn[0];
} }
......
...@@ -224,7 +224,6 @@ void MultiDotKernel(const Context& ctx, ...@@ -224,7 +224,6 @@ void MultiDotKernel(const Context& ctx,
phi::DDim tmp_dim = phi::make_ddim({Ka, Nc}); phi::DDim tmp_dim = phi::make_ddim({Ka, Nc});
tmp_out.Resize(tmp_dim); tmp_out.Resize(tmp_dim);
ctx.template Alloc<T>(&tmp_out); ctx.template Alloc<T>(&tmp_out);
std::cout << tmp_out << std::endl;
blas.MatMul( blas.MatMul(
*ins[1], mat_dim_b, *ins[2], mat_dim_c, scale, &tmp_out, T(0)); *ins[1], mat_dim_b, *ins[2], mat_dim_c, scale, &tmp_out, T(0));
auto mat_dim_tmp = phi::funcs::CreateMatrixDescriptor(tmp_dim, 0, false); auto mat_dim_tmp = phi::funcs::CreateMatrixDescriptor(tmp_dim, 0, false);
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#pragma once #pragma once
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/cpu/index_select_impl.h" #include "paddle/phi/kernels/cpu/index_select_impl.h"
#include "paddle/phi/kernels/repeat_interleave_grad_kernel.h" #include "paddle/phi/kernels/repeat_interleave_grad_kernel.h"
...@@ -91,22 +92,18 @@ void RepeatInterleaveWithTensorIndexGradKernel( ...@@ -91,22 +92,18 @@ void RepeatInterleaveWithTensorIndexGradKernel(
repeats_tensor.dims()[0], repeats_tensor.dims()[0],
x_grad->dims()[dim])); x_grad->dims()[dim]));
const auto& index_type = const auto& index_type = repeats_tensor.dtype();
paddle::framework::TransToProtoVarType(repeats_tensor.dtype());
bool index_type_match = bool index_type_match =
index_type == paddle::framework::proto::VarType::INT32 || index_type == DataType::INT32 || index_type == DataType::INT64;
index_type == paddle::framework::proto::VarType::INT64;
PADDLE_ENFORCE_EQ(index_type_match, PADDLE_ENFORCE_EQ(index_type_match,
true, true,
phi::errors::InvalidArgument( phi::errors::InvalidArgument(
"Input(Repeats) holds the wrong type, it holds %s, but " "Input(Repeats) holds the wrong type, it holds %s, but "
"desires to be %s or %s", "desires to be %s or %s",
paddle::framework::DataTypeToString(index_type), DataTypeToString(index_type),
paddle::framework::DataTypeToString( DataTypeToString(DataType::INT32),
paddle::framework::proto::VarType::INT32), DataTypeToString(DataType::INT64)));
paddle::framework::DataTypeToString(
paddle::framework::proto::VarType::INT64)));
#if defined(__NVCC__) || defined(__HIPCC__) #if defined(__NVCC__) || defined(__HIPCC__)
auto output_dim = out_grad.dims(); auto output_dim = out_grad.dims();
...@@ -126,7 +123,7 @@ void RepeatInterleaveWithTensorIndexGradKernel( ...@@ -126,7 +123,7 @@ void RepeatInterleaveWithTensorIndexGradKernel(
0, 0,
stream>>>(in_grad_data, numel); stream>>>(in_grad_data, numel);
if (index_type == paddle::framework::proto::VarType::INT64) { if (index_type == DataType::INT64) {
phi::funcs::RepeatsTensor2IndexTensor<Context, int64_t>( phi::funcs::RepeatsTensor2IndexTensor<Context, int64_t>(
ctx, repeats_tensor, &index); ctx, repeats_tensor, &index);
int64_t index_nums = index.numel(); int64_t index_nums = index.numel();
......
...@@ -140,9 +140,9 @@ void RepeatInterleaveWithTensorIndexKernel(const Context& ctx, ...@@ -140,9 +140,9 @@ void RepeatInterleaveWithTensorIndexKernel(const Context& ctx,
phi::errors::InvalidArgument( phi::errors::InvalidArgument(
"Input(RepeatsTensor) holds the wrong type, it holds %s, but " "Input(RepeatsTensor) holds the wrong type, it holds %s, but "
"desires to be %s or %s", "desires to be %s or %s",
phi::DataTypeToString(index_type), DataTypeToString(index_type),
phi::DataTypeToString(phi::DataType::INT32), DataTypeToString(phi::DataType::INT32),
phi::DataTypeToString(phi::DataType::INT64))); DataTypeToString(phi::DataType::INT64)));
if (place == cpu_place) { if (place == cpu_place) {
auto x_copy = x; auto x_copy = x;
if (index_type == phi::DataType::INT32) { if (index_type == phi::DataType::INT32) {
......
...@@ -21,6 +21,8 @@ ...@@ -21,6 +21,8 @@
#include "paddle/phi/kernels/funcs/adam_functors.h" #include "paddle/phi/kernels/funcs/adam_functors.h"
#include "paddle/phi/kernels/funcs/selected_rows_functor.h" #include "paddle/phi/kernels/funcs/selected_rows_functor.h"
DECLARE_int32(inner_op_parallelism);
namespace phi { namespace phi {
namespace sr { namespace sr {
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
#include "paddle/phi/kernels/selected_rows/hsigmoid_loss_grad_kernel.h" #include "paddle/phi/kernels/selected_rows/hsigmoid_loss_grad_kernel.h"
#include <set>
#include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/mixed_vector.h" #include "paddle/phi/core/mixed_vector.h"
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/selected_rows.h" #include "paddle/phi/core/selected_rows.h"
#include "paddle/phi/kernels/funcs/lamb_functors.h" #include "paddle/phi/kernels/funcs/lamb_functors.h"
...@@ -309,8 +310,8 @@ void ComputeRowImpl(const Context& dev_ctx, ...@@ -309,8 +310,8 @@ void ComputeRowImpl(const Context& dev_ctx,
auto pn = phi::funcs::ToVector(p_norm_ptr, 1, dev_ctx.GetPlace()); auto pn = phi::funcs::ToVector(p_norm_ptr, 1, dev_ctx.GetPlace());
auto tn = auto tn =
phi::funcs::ToVector(trust_ratio_div_norm_ptr, 1, dev_ctx.GetPlace()); phi::funcs::ToVector(trust_ratio_div_norm_ptr, 1, dev_ctx.GetPlace());
auto dtype = paddle::framework::DataTypeToString( auto dtype =
paddle::framework::DataTypeTrait<T>::DataType()); DataTypeToString(paddle::experimental::CppTypeToDataType<T>::Type());
VLOG(1) << "Param " << dtype << " " << name << " pn = " << pn[0] VLOG(1) << "Param " << dtype << " " << name << " pn = " << pn[0]
<< " , tn = " << tn[0]; << " , tn = " << tn[0];
} }
......
...@@ -32,9 +32,9 @@ void IndexSampleKernel(const Context& ctx, ...@@ -32,9 +32,9 @@ void IndexSampleKernel(const Context& ctx,
errors::InvalidArgument( errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but " "Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s", "desires to be %s or %s",
phi::DataTypeToString(index_type), DataTypeToString(index_type),
phi::DataTypeToString(DataType::INT32), DataTypeToString(DataType::INT32),
phi::DataTypeToString(DataType::INT64))); DataTypeToString(DataType::INT64)));
using XPUType = typename XPUTypeTrait<T>::Type; using XPUType = typename XPUTypeTrait<T>::Type;
......
...@@ -12,7 +12,10 @@ ...@@ -12,7 +12,10 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <set>
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/phi/backends/context_pool.h"
#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/phi/backends/context_pool.h" #include "paddle/phi/backends/context_pool.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册