未验证 提交 b5da73c5 编写于 作者: H Huang Jiyi 提交者: GitHub

[phi decoupling] clean TensorCopy usage in phi (#50538)

* rm framework::tensor_util in phi

* clean TensoCopy

* fix bugs

* fix bugs

* fix bugs

* repalce mutable_data

* revert custom_device_test.cc
上级 3027c58a
...@@ -120,8 +120,7 @@ void BroadcastTensorsGradKernel(const Context& ctx, ...@@ -120,8 +120,7 @@ void BroadcastTensorsGradKernel(const Context& ctx,
ctx.template Alloc<T>(output_tensor); ctx.template Alloc<T>(output_tensor);
if (just_copy) { if (just_copy) {
// If this turns out to be a No-Op, simply perform a tensor copy // If this turns out to be a No-Op, simply perform a tensor copy
paddle::framework::TensorCopy( phi::Copy(ctx, *input_tensor, ctx.GetPlace(), false, output_tensor);
*input_tensor, ctx.GetPlace(), ctx, output_tensor);
} else { } else {
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(
reduce_dims_vec.size(), reduce_dims_vec.size(),
......
...@@ -429,7 +429,7 @@ static void Interpolate1DCPUBwd( ...@@ -429,7 +429,7 @@ static void Interpolate1DCPUBwd(
zero(dev_ctx, input_grad, static_cast<T>(0.0)); zero(dev_ctx, input_grad, static_cast<T>(0.0));
if (in_w == out_w) { if (in_w == out_w) {
paddle::framework::TensorCopy(output_grad, dev_ctx.GetPlace(), input_grad); phi::Copy(dev_ctx, output_grad, dev_ctx.GetPlace(), false, input_grad);
return; return;
} }
...@@ -552,7 +552,7 @@ static void Interpolate2DCPUBwd( ...@@ -552,7 +552,7 @@ static void Interpolate2DCPUBwd(
zero(dev_ctx, input_grad, static_cast<T>(0.0)); zero(dev_ctx, input_grad, static_cast<T>(0.0));
if (in_h == out_h && in_w == out_w) { if (in_h == out_h && in_w == out_w) {
paddle::framework::TensorCopy(output_grad, dev_ctx.GetPlace(), input_grad); phi::Copy(dev_ctx, output_grad, dev_ctx.GetPlace(), false, input_grad);
return; return;
} }
...@@ -732,7 +732,7 @@ static void Interpolate3DCPUBwd( ...@@ -732,7 +732,7 @@ static void Interpolate3DCPUBwd(
zero(dev_ctx, input_grad, static_cast<T>(0.0)); zero(dev_ctx, input_grad, static_cast<T>(0.0));
if (in_d == out_d && in_h == out_h && in_w == out_w) { if (in_d == out_d && in_h == out_h && in_w == out_w) {
paddle::framework::TensorCopy(output_grad, dev_ctx.GetPlace(), input_grad); phi::Copy(dev_ctx, output_grad, dev_ctx.GetPlace(), false, input_grad);
return; return;
} }
......
...@@ -57,7 +57,7 @@ void MatrixSolveFunctor<Context, T>::operator()(const Context& context, ...@@ -57,7 +57,7 @@ void MatrixSolveFunctor<Context, T>::operator()(const Context& context,
tmp_a.Resize(a.dims()); tmp_a.Resize(a.dims());
context.template Alloc<T>(&tmp_a); context.template Alloc<T>(&tmp_a);
paddle::framework::TensorCopy(a, context.GetPlace(), &tmp_a); phi::Copy(context, a, context.GetPlace(), false, &tmp_a);
// copy input B to a temporary tensor tmp_b, and transpose tmp_b, // copy input B to a temporary tensor tmp_b, and transpose tmp_b,
// because cuBlas assumes column-major while Paddle uses row-majar. // because cuBlas assumes column-major while Paddle uses row-majar.
......
...@@ -117,7 +117,7 @@ void HandleLargeDimGrad(const Context& dev_ctx, ...@@ -117,7 +117,7 @@ void HandleLargeDimGrad(const Context& dev_ctx,
std::vector<int> origin_axis(x_dim.size()); std::vector<int> origin_axis(x_dim.size());
GetOriginDimFromShuffled(x_dim, dims, &origin_axis); GetOriginDimFromShuffled(x_dim, dims, &origin_axis);
DenseTensor dx_tmp; DenseTensor dx_tmp;
paddle::framework::TensorCopy(*dx, dev_ctx.GetPlace(), &dx_tmp); phi::Copy(dev_ctx, *dx, dev_ctx.GetPlace(), false, &dx_tmp);
dx_tmp.Resize(shuffled_dim); dx_tmp.Resize(shuffled_dim);
dx->Resize(x_dim); dx->Resize(x_dim);
phi::funcs::TransposeNormal<Context, T> trans; phi::funcs::TransposeNormal<Context, T> trans;
......
...@@ -417,8 +417,7 @@ class SegmentPoolGradFunctor<phi::GPUContext, T, IndexT> { ...@@ -417,8 +417,7 @@ class SegmentPoolGradFunctor<phi::GPUContext, T, IndexT> {
DenseTensor mean_grad; DenseTensor mean_grad;
mean_grad.Resize(input.dims()); mean_grad.Resize(input.dims());
dev_ctx.template Alloc<T>(&mean_grad); dev_ctx.template Alloc<T>(&mean_grad);
paddle::framework::TensorCopy( phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, &mean_grad);
out_grad, dev_ctx.GetPlace(), dev_ctx, &mean_grad);
int len = output.dims()[0]; int len = output.dims()[0];
int dim = output.numel() / len; int dim = output.numel() / len;
auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, len); auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, len);
......
...@@ -78,10 +78,8 @@ void BincountCUDAInner(const Context& dev_ctx, ...@@ -78,10 +78,8 @@ void BincountCUDAInner(const Context& dev_ctx,
input_min_scala.device(*place) = input_x.minimum(); input_min_scala.device(*place) = input_x.minimum();
DenseTensor input_min_cpu, input_max_cpu; DenseTensor input_min_cpu, input_max_cpu;
paddle::framework::TensorCopySync( phi::Copy(dev_ctx, input_min_t, phi::CPUPlace(), true, &input_min_cpu);
input_max_t, phi::CPUPlace(), &input_max_cpu); phi::Copy(dev_ctx, input_max_t, phi::CPUPlace(), true, &input_max_cpu);
paddle::framework::TensorCopySync(
input_min_t, phi::CPUPlace(), &input_min_cpu);
InputT input_min = input_min_cpu.data<InputT>()[0]; InputT input_min = input_min_cpu.data<InputT>()[0];
......
...@@ -110,10 +110,8 @@ void HistogramKernel(const Context& dev_ctx, ...@@ -110,10 +110,8 @@ void HistogramKernel(const Context& dev_ctx,
input_max_scala.device(*place) = input_x.maximum(); input_max_scala.device(*place) = input_x.maximum();
DenseTensor input_min_cpu, input_max_cpu; DenseTensor input_min_cpu, input_max_cpu;
paddle::framework::TensorCopySync( phi::Copy(dev_ctx, input_min_t, phi::CPUPlace(), true, &input_min_cpu);
input_min_t, phi::CPUPlace(), &input_min_cpu); phi::Copy(dev_ctx, input_max_t, phi::CPUPlace(), true, &input_max_cpu);
paddle::framework::TensorCopySync(
input_max_t, phi::CPUPlace(), &input_max_cpu);
output_min = input_min_cpu.data<T>()[0]; output_min = input_min_cpu.data<T>()[0];
output_max = input_max_cpu.data<T>()[0]; output_max = input_max_cpu.data<T>()[0];
......
...@@ -790,8 +790,8 @@ static void Interpolate1DCUDABwd( ...@@ -790,8 +790,8 @@ static void Interpolate1DCUDABwd(
if (out_size) { if (out_size) {
DenseTensor sizes; DenseTensor sizes;
paddle::framework::TensorCopySync( phi::Copy(dev_ctx, *out_size, phi::CPUPlace(), true, &sizes);
*out_size, paddle::platform::CPUPlace(), &sizes);
auto size_data = sizes.data<int>(); auto size_data = sizes.data<int>();
out_w = size_data[0]; out_w = size_data[0];
} }
...@@ -815,7 +815,7 @@ static void Interpolate1DCUDABwd( ...@@ -815,7 +815,7 @@ static void Interpolate1DCUDABwd(
zero(dev_ctx, input_grad, static_cast<T>(0.0)); zero(dev_ctx, input_grad, static_cast<T>(0.0));
if (in_w == out_w) { if (in_w == out_w) {
paddle::framework::TensorCopy(output_grad, dev_ctx.GetPlace(), input_grad); phi::Copy(dev_ctx, output_grad, dev_ctx.GetPlace(), false, input_grad);
return; return;
} }
...@@ -928,8 +928,7 @@ static void Interpolate2DCUDABwd( ...@@ -928,8 +928,7 @@ static void Interpolate2DCUDABwd(
if (out_size) { if (out_size) {
DenseTensor sizes; DenseTensor sizes;
paddle::framework::TensorCopySync( phi::Copy(dev_ctx, *out_size, phi::CPUPlace(), true, &sizes);
*out_size, paddle::platform::CPUPlace(), &sizes);
auto size_data = sizes.data<int>(); auto size_data = sizes.data<int>();
out_h = size_data[0]; out_h = size_data[0];
out_w = size_data[1]; out_w = size_data[1];
...@@ -954,7 +953,7 @@ static void Interpolate2DCUDABwd( ...@@ -954,7 +953,7 @@ static void Interpolate2DCUDABwd(
zero(dev_ctx, input_grad, static_cast<T>(0.0)); zero(dev_ctx, input_grad, static_cast<T>(0.0));
if (in_h == out_h && in_w == out_w) { if (in_h == out_h && in_w == out_w) {
paddle::framework::TensorCopy(output_grad, dev_ctx.GetPlace(), input_grad); phi::Copy(dev_ctx, output_grad, dev_ctx.GetPlace(), false, input_grad);
return; return;
} }
...@@ -1210,8 +1209,7 @@ static void Interpolate3DCUDABwd( ...@@ -1210,8 +1209,7 @@ static void Interpolate3DCUDABwd(
if (out_size) { if (out_size) {
DenseTensor sizes; DenseTensor sizes;
paddle::framework::TensorCopySync( phi::Copy(dev_ctx, *out_size, phi::CPUPlace(), true, &sizes);
*out_size, paddle::platform::CPUPlace(), &sizes);
auto size_data = sizes.data<int>(); auto size_data = sizes.data<int>();
out_d = size_data[0]; out_d = size_data[0];
out_h = size_data[1]; out_h = size_data[1];
...@@ -1238,7 +1236,7 @@ static void Interpolate3DCUDABwd( ...@@ -1238,7 +1236,7 @@ static void Interpolate3DCUDABwd(
zero(dev_ctx, input_grad, static_cast<T>(0.0)); zero(dev_ctx, input_grad, static_cast<T>(0.0));
if (in_d == out_d && in_h == out_h && in_w == out_w) { if (in_d == out_d && in_h == out_h && in_w == out_w) {
paddle::framework::TensorCopy(output_grad, dev_ctx.GetPlace(), input_grad); phi::Copy(dev_ctx, output_grad, dev_ctx.GetPlace(), false, input_grad);
return; return;
} }
......
...@@ -351,7 +351,7 @@ void MatrixRankTolKernel(const Context& dev_ctx, ...@@ -351,7 +351,7 @@ void MatrixRankTolKernel(const Context& dev_ctx,
// Must Copy X once, because the gesvdj will destory the content when exit. // Must Copy X once, because the gesvdj will destory the content when exit.
DenseTensor x_tmp; DenseTensor x_tmp;
paddle::framework::TensorCopy(x, dev_ctx.GetPlace(), &x_tmp); phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, &x_tmp);
auto info = paddle::memory::Alloc( auto info = paddle::memory::Alloc(
dev_ctx.GetPlace(), dev_ctx.GetPlace(),
sizeof(int) * batches, sizeof(int) * batches,
......
...@@ -39,7 +39,7 @@ void MatrixPowerGradFunction(const DenseTensor* X, ...@@ -39,7 +39,7 @@ void MatrixPowerGradFunction(const DenseTensor* X,
return; return;
} else if (n == 1) { } else if (n == 1) {
// \nabla X = \nabla Out // \nabla X = \nabla Out
paddle::framework::TensorCopy(*dOut, ctx.GetPlace(), ctx, dX); phi::Copy(ctx, *dOut, ctx.GetPlace(), false, dX);
return; return;
} }
...@@ -74,7 +74,7 @@ void MatrixPowerGradFunction(const DenseTensor* X, ...@@ -74,7 +74,7 @@ void MatrixPowerGradFunction(const DenseTensor* X,
int new_n = n; int new_n = n;
if (n > 0) { if (n > 0) {
// newX = X // newX = X
paddle::framework::TensorCopy(*X, ctx.GetPlace(), ctx, &new_x); phi::Copy(ctx, *X, ctx.GetPlace(), false, &new_x);
} else { } else {
// newX = X^{-1}, n = -n // newX = X^{-1}, n = -n
phi::funcs::MatrixInverseFunctor<Context, T> mat_inv; phi::funcs::MatrixInverseFunctor<Context, T> mat_inv;
...@@ -158,7 +158,7 @@ void MatrixPowerGradFunction(const DenseTensor* X, ...@@ -158,7 +158,7 @@ void MatrixPowerGradFunction(const DenseTensor* X,
if (n > 0) { if (n > 0) {
// \nabla X = \nabla newX // \nabla X = \nabla newX
paddle::framework::TensorCopy(dx_new, ctx.GetPlace(), ctx, dX); phi::Copy(ctx, dx_new, ctx.GetPlace(), false, dX);
} else { } else {
// \nabla X = newX^{T} * \nabla newX * newX^{T} // \nabla X = newX^{T} * \nabla newX * newX^{T}
DenseTensor temp_dx; DenseTensor temp_dx;
......
...@@ -61,7 +61,7 @@ void MatrixPowerFunction(const DenseTensor* X, ...@@ -61,7 +61,7 @@ void MatrixPowerFunction(const DenseTensor* X,
int new_n = n; int new_n = n;
if (n > 0) { if (n > 0) {
// newX = X // newX = X
paddle::framework::TensorCopy(*X, ctx.GetPlace(), ctx, &new_x); phi::Copy(ctx, *X, ctx.GetPlace(), false, &new_x);
} else { } else {
// newX = X^{-1}, n = -n // newX = X^{-1}, n = -n
phi::funcs::MatrixInverseFunctor<Context, T> mat_inv; phi::funcs::MatrixInverseFunctor<Context, T> mat_inv;
...@@ -70,7 +70,7 @@ void MatrixPowerFunction(const DenseTensor* X, ...@@ -70,7 +70,7 @@ void MatrixPowerFunction(const DenseTensor* X,
} }
if (new_n == 1) { if (new_n == 1) {
paddle::framework::TensorCopy(new_x, ctx.GetPlace(), ctx, Out); phi::Copy(ctx, new_x, ctx.GetPlace(), false, Out);
return; return;
} }
...@@ -153,11 +153,11 @@ void MatrixPowerFunction(const DenseTensor* X, ...@@ -153,11 +153,11 @@ void MatrixPowerFunction(const DenseTensor* X,
static_cast<T>(1), static_cast<T>(1),
&temp_z, &temp_z,
static_cast<T>(0)); static_cast<T>(0));
paddle::framework::TensorCopy(temp_z, ctx.GetPlace(), ctx, &z); phi::Copy(ctx, temp_z, ctx.GetPlace(), false, &z);
} else { } else {
z.Resize(X->dims()); z.Resize(X->dims());
ctx.template Alloc<T>(&z); ctx.template Alloc<T>(&z);
paddle::framework::TensorCopy(new_x, ctx.GetPlace(), ctx, &z); phi::Copy(ctx, new_x, ctx.GetPlace(), false, &z);
} }
if (bit == 1) { if (bit == 1) {
if (out_inited == true) { if (out_inited == true) {
...@@ -168,9 +168,9 @@ void MatrixPowerFunction(const DenseTensor* X, ...@@ -168,9 +168,9 @@ void MatrixPowerFunction(const DenseTensor* X,
static_cast<T>(1), static_cast<T>(1),
&temp_out, &temp_out,
static_cast<T>(0)); static_cast<T>(0));
paddle::framework::TensorCopy(temp_out, ctx.GetPlace(), ctx, Out); phi::Copy(ctx, temp_out, ctx.GetPlace(), false, Out);
} else { } else {
paddle::framework::TensorCopy(z, ctx.GetPlace(), ctx, Out); phi::Copy(ctx, z, ctx.GetPlace(), false, Out);
out_inited = true; out_inited = true;
} }
} }
......
...@@ -59,14 +59,14 @@ TEST(math_function, notrans_mul_trans_fp32) { ...@@ -59,14 +59,14 @@ TEST(math_function, notrans_mul_trans_fp32) {
float arr[6] = {0, 1, 2, 3, 4, 5}; float arr[6] = {0, 1, 2, 3, 4, 5};
memcpy(input1_ptr, arr, 6 * sizeof(float)); memcpy(input1_ptr, arr, 6 * sizeof(float));
paddle::framework::TensorCopySync(input1, gpu_place, &input1_gpu); phi::Copy(*context, input1, gpu_place, true, &input1_gpu);
paddle::framework::TensorCopySync(input1, gpu_place, &input2_gpu); phi::Copy(*context, input1, gpu_place, true, &input2_gpu);
out_gpu.mutable_data<float>({2, 2}, gpu_place); out_gpu.mutable_data<float>({2, 2}, gpu_place);
GetBlas<float>(*context).MatMul( GetBlas<float>(*context).MatMul(
input1_gpu, false, input2_gpu, true, 1, &out_gpu, 0); input1_gpu, false, input2_gpu, true, 1, &out_gpu, 0);
paddle::framework::TensorCopySync(out_gpu, cpu_place, &out); phi::Copy(*context, out_gpu, cpu_place, true, &out);
float* out_ptr = out.data<float>(); float* out_ptr = out.data<float>();
context->Wait(); context->Wait();
...@@ -97,8 +97,8 @@ TEST(math_function, notrans_mul_trans_fp16) { ...@@ -97,8 +97,8 @@ TEST(math_function, notrans_mul_trans_fp16) {
input1.mutable_data<phi::dtype::float16>({2, 3}, cpu_place); input1.mutable_data<phi::dtype::float16>({2, 3}, cpu_place);
fill_fp16_data(input1_ptr, input1.numel(), {0, 1, 2, 3, 4, 5}); fill_fp16_data(input1_ptr, input1.numel(), {0, 1, 2, 3, 4, 5});
paddle::framework::TensorCopySync(input1, gpu_place, &input1_gpu); phi::Copy(*context, input1, gpu_place, true, &input1_gpu);
paddle::framework::TensorCopySync(input1, gpu_place, &input2_gpu); phi::Copy(*context, input1, gpu_place, true, &input2_gpu);
out_gpu.mutable_data<phi::dtype::float16>({2, 2}, gpu_place); out_gpu.mutable_data<phi::dtype::float16>({2, 2}, gpu_place);
...@@ -110,7 +110,7 @@ TEST(math_function, notrans_mul_trans_fp16) { ...@@ -110,7 +110,7 @@ TEST(math_function, notrans_mul_trans_fp16) {
&out_gpu, &out_gpu,
phi::dtype::float16(0)); phi::dtype::float16(0));
paddle::framework::TensorCopySync(out_gpu, cpu_place, &out); phi::Copy(*context, out_gpu, cpu_place, true, &out);
phi::dtype::float16* out_ptr = out.data<phi::dtype::float16>(); phi::dtype::float16* out_ptr = out.data<phi::dtype::float16>();
context->Wait(); context->Wait();
...@@ -136,15 +136,15 @@ TEST(math_function, trans_mul_notrans_fp32) { ...@@ -136,15 +136,15 @@ TEST(math_function, trans_mul_notrans_fp32) {
float arr[6] = {0, 1, 2, 3, 4, 5}; float arr[6] = {0, 1, 2, 3, 4, 5};
memcpy(input1_ptr, arr, 6 * sizeof(float)); memcpy(input1_ptr, arr, 6 * sizeof(float));
paddle::framework::TensorCopySync(input1, gpu_place, &input1_gpu); phi::Copy(*context, input1, gpu_place, true, &input1_gpu);
paddle::framework::TensorCopySync(input1, gpu_place, &input2_gpu); phi::Copy(*context, input1, gpu_place, true, &input2_gpu);
out_gpu.mutable_data<float>({3, 3}, gpu_place); out_gpu.mutable_data<float>({3, 3}, gpu_place);
GetBlas<float>(*context).MatMul( GetBlas<float>(*context).MatMul(
input1_gpu, true, input2_gpu, false, 1, &out_gpu, 0); input1_gpu, true, input2_gpu, false, 1, &out_gpu, 0);
paddle::framework::TensorCopySync(out_gpu, cpu_place, &out); phi::Copy(*context, out_gpu, cpu_place, true, &out);
float* out_ptr = out.data<float>(); float* out_ptr = out.data<float>();
context->Wait(); context->Wait();
...@@ -180,8 +180,8 @@ TEST(math_function, trans_mul_notrans_fp16) { ...@@ -180,8 +180,8 @@ TEST(math_function, trans_mul_notrans_fp16) {
input1.mutable_data<phi::dtype::float16>({2, 3}, cpu_place); input1.mutable_data<phi::dtype::float16>({2, 3}, cpu_place);
fill_fp16_data(input1_ptr, input1.numel(), {0, 1, 2, 3, 4, 5}); fill_fp16_data(input1_ptr, input1.numel(), {0, 1, 2, 3, 4, 5});
paddle::framework::TensorCopySync(input1, gpu_place, &input1_gpu); phi::Copy(*context, input1, gpu_place, true, &input1_gpu);
paddle::framework::TensorCopySync(input1, gpu_place, &input2_gpu); phi::Copy(*context, input1, gpu_place, true, &input2_gpu);
out_gpu.mutable_data<phi::dtype::float16>({3, 3}, gpu_place); out_gpu.mutable_data<phi::dtype::float16>({3, 3}, gpu_place);
...@@ -193,7 +193,7 @@ TEST(math_function, trans_mul_notrans_fp16) { ...@@ -193,7 +193,7 @@ TEST(math_function, trans_mul_notrans_fp16) {
&out_gpu, &out_gpu,
phi::dtype::float16(0)); phi::dtype::float16(0));
paddle::framework::TensorCopySync(out_gpu, cpu_place, &out); phi::Copy(*context, out_gpu, cpu_place, true, &out);
phi::dtype::float16* out_ptr = out.data<phi::dtype::float16>(); phi::dtype::float16* out_ptr = out.data<phi::dtype::float16>();
context->Wait(); context->Wait();
...@@ -234,9 +234,9 @@ TEST(math_function, gemm_notrans_cublas_fp32) { ...@@ -234,9 +234,9 @@ TEST(math_function, gemm_notrans_cublas_fp32) {
float arr3[8] = {0, 1, 2, 3, 4, 5, 6, 7}; float arr3[8] = {0, 1, 2, 3, 4, 5, 6, 7};
memcpy(input3_ptr, arr3, 8 * sizeof(float)); memcpy(input3_ptr, arr3, 8 * sizeof(float));
paddle::framework::TensorCopySync(input1, gpu_place, &input1_gpu); phi::Copy(*context, input1, gpu_place, true, &input1_gpu);
paddle::framework::TensorCopySync(input2, gpu_place, &input2_gpu); phi::Copy(*context, input2, gpu_place, true, &input2_gpu);
paddle::framework::TensorCopySync(input3, gpu_place, &input3_gpu); phi::Copy(*context, input3, gpu_place, true, &input3_gpu);
float* a = input1_gpu.data<float>(); float* a = input1_gpu.data<float>();
float* b = input2_gpu.data<float>(); float* b = input2_gpu.data<float>();
float* c = input3_gpu.mutable_data<float>(gpu_place); float* c = input3_gpu.mutable_data<float>(gpu_place);
...@@ -244,7 +244,7 @@ TEST(math_function, gemm_notrans_cublas_fp32) { ...@@ -244,7 +244,7 @@ TEST(math_function, gemm_notrans_cublas_fp32) {
GetBlas<float>(*context).GEMM( GetBlas<float>(*context).GEMM(
false, false, m, n, k, 1, a, 3, b + 1, 4, 1, c + 1, 4); false, false, m, n, k, 1, a, 3, b + 1, 4, 1, c + 1, 4);
paddle::framework::TensorCopySync(input3_gpu, cpu_place, &input3); phi::Copy(*context, input3_gpu, cpu_place, true, &input3);
// numpy code: // numpy code:
// a = np.arange(6).reshape(2, 3) // a = np.arange(6).reshape(2, 3)
...@@ -295,9 +295,9 @@ TEST(math_function, gemm_notrans_cublas_fp16) { ...@@ -295,9 +295,9 @@ TEST(math_function, gemm_notrans_cublas_fp16) {
input3.mutable_data<phi::dtype::float16>({2, 4}, cpu_place); input3.mutable_data<phi::dtype::float16>({2, 4}, cpu_place);
fill_fp16_data(input3_ptr, input3.numel(), {0, 1, 2, 3, 4, 5, 6, 7}); fill_fp16_data(input3_ptr, input3.numel(), {0, 1, 2, 3, 4, 5, 6, 7});
paddle::framework::TensorCopySync(input1, gpu_place, &input1_gpu); phi::Copy(*context, input1, gpu_place, true, &input1_gpu);
paddle::framework::TensorCopySync(input2, gpu_place, &input2_gpu); phi::Copy(*context, input2, gpu_place, true, &input2_gpu);
paddle::framework::TensorCopySync(input3, gpu_place, &input3_gpu); phi::Copy(*context, input3, gpu_place, true, &input3_gpu);
phi::dtype::float16* a = input1_gpu.data<phi::dtype::float16>(); phi::dtype::float16* a = input1_gpu.data<phi::dtype::float16>();
phi::dtype::float16* b = input2_gpu.data<phi::dtype::float16>(); phi::dtype::float16* b = input2_gpu.data<phi::dtype::float16>();
phi::dtype::float16* c = phi::dtype::float16* c =
...@@ -318,7 +318,7 @@ TEST(math_function, gemm_notrans_cublas_fp16) { ...@@ -318,7 +318,7 @@ TEST(math_function, gemm_notrans_cublas_fp16) {
c + 1, c + 1,
4); 4);
paddle::framework::TensorCopySync(input3_gpu, cpu_place, &input3); phi::Copy(*context, input3_gpu, cpu_place, true, &input3);
// numpy code: // numpy code:
// a = np.arange(6).reshape(2, 3) // a = np.arange(6).reshape(2, 3)
...@@ -363,9 +363,9 @@ TEST(math_function, gemm_trans_cublas_fp32) { ...@@ -363,9 +363,9 @@ TEST(math_function, gemm_trans_cublas_fp32) {
float arr3[8] = {0, 1, 2, 3, 4, 5, 6, 7}; float arr3[8] = {0, 1, 2, 3, 4, 5, 6, 7};
memcpy(input3_ptr, arr3, 8 * sizeof(float)); memcpy(input3_ptr, arr3, 8 * sizeof(float));
paddle::framework::TensorCopySync(input1, gpu_place, &input1_gpu); phi::Copy(*context, input1, gpu_place, true, &input1_gpu);
paddle::framework::TensorCopySync(input2, gpu_place, &input2_gpu); phi::Copy(*context, input2, gpu_place, true, &input2_gpu);
paddle::framework::TensorCopySync(input3, gpu_place, &input3_gpu); phi::Copy(*context, input3, gpu_place, true, &input3_gpu);
float* a = input1_gpu.data<float>(); float* a = input1_gpu.data<float>();
float* b = input2_gpu.data<float>(); float* b = input2_gpu.data<float>();
float* c = input3_gpu.mutable_data<float>(gpu_place); float* c = input3_gpu.mutable_data<float>(gpu_place);
...@@ -373,7 +373,7 @@ TEST(math_function, gemm_trans_cublas_fp32) { ...@@ -373,7 +373,7 @@ TEST(math_function, gemm_trans_cublas_fp32) {
GetBlas<float>(*context).GEMM( GetBlas<float>(*context).GEMM(
false, true, m, n, k, 1, a, 3, b + 3, 3, 1, c + 1, 4); false, true, m, n, k, 1, a, 3, b + 3, 3, 1, c + 1, 4);
paddle::framework::TensorCopySync(input3_gpu, cpu_place, &input3); phi::Copy(*context, input3_gpu, cpu_place, true, &input3);
context->Wait(); context->Wait();
EXPECT_EQ(input3_ptr[0], 0); EXPECT_EQ(input3_ptr[0], 0);
...@@ -418,9 +418,9 @@ TEST(math_function, gemm_trans_cublas_fp16) { ...@@ -418,9 +418,9 @@ TEST(math_function, gemm_trans_cublas_fp16) {
input3.mutable_data<phi::dtype::float16>({2, 4}, cpu_place); input3.mutable_data<phi::dtype::float16>({2, 4}, cpu_place);
fill_fp16_data(input3_ptr, input3.numel(), {0, 1, 2, 3, 4, 5, 6, 7}); fill_fp16_data(input3_ptr, input3.numel(), {0, 1, 2, 3, 4, 5, 6, 7});
paddle::framework::TensorCopySync(input1, gpu_place, &input1_gpu); phi::Copy(*context, input1, gpu_place, true, &input1_gpu);
paddle::framework::TensorCopySync(input2, gpu_place, &input2_gpu); phi::Copy(*context, input2, gpu_place, true, &input2_gpu);
paddle::framework::TensorCopySync(input3, gpu_place, &input3_gpu); phi::Copy(*context, input3, gpu_place, true, &input3_gpu);
phi::dtype::float16* a = input1_gpu.data<phi::dtype::float16>(); phi::dtype::float16* a = input1_gpu.data<phi::dtype::float16>();
phi::dtype::float16* b = input2_gpu.data<phi::dtype::float16>(); phi::dtype::float16* b = input2_gpu.data<phi::dtype::float16>();
phi::dtype::float16* c = phi::dtype::float16* c =
...@@ -441,7 +441,7 @@ TEST(math_function, gemm_trans_cublas_fp16) { ...@@ -441,7 +441,7 @@ TEST(math_function, gemm_trans_cublas_fp16) {
c + 1, c + 1,
4); 4);
paddle::framework::TensorCopySync(input3_gpu, cpu_place, &input3); phi::Copy(*context, input3_gpu, cpu_place, true, &input3);
context->Wait(); context->Wait();
EXPECT_EQ(static_cast<float>(input3_ptr[0]), 0); EXPECT_EQ(static_cast<float>(input3_ptr[0]), 0);
...@@ -483,8 +483,8 @@ void GemvTest(int m, int n, bool trans) { ...@@ -483,8 +483,8 @@ void GemvTest(int m, int n, bool trans) {
data_b[i] = static_cast<T>(i); data_b[i] = static_cast<T>(i);
} }
paddle::framework::TensorCopySync(mat_a, gpu_place, &g_mat_a); phi::Copy(*context, mat_a, gpu_place, true, &g_mat_a);
paddle::framework::TensorCopySync(vec_b, gpu_place, &g_vec_b); phi::Copy(*context, vec_b, gpu_place, true, &g_vec_b);
GetBlas<T>(*context).GEMV(trans, GetBlas<T>(*context).GEMV(trans,
static_cast<int>(m), static_cast<int>(m),
...@@ -495,7 +495,7 @@ void GemvTest(int m, int n, bool trans) { ...@@ -495,7 +495,7 @@ void GemvTest(int m, int n, bool trans) {
0., 0.,
g_data_c); g_data_c);
paddle::framework::TensorCopySync(g_vec_c, cpu_place, &vec_c); phi::Copy(*context, g_vec_c, cpu_place, true, &vec_c);
if (!trans) { if (!trans) {
for (int i = 0; i < m; ++i) { for (int i = 0; i < m; ++i) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册