From 2781740ba778f7e3ba719e04f9b6da4489ba156d Mon Sep 17 00:00:00 2001 From: Haohongxiang <86215757+haohongxiang@users.noreply.github.com> Date: Thu, 28 Jul 2022 15:06:07 +0800 Subject: [PATCH] fix bugs of lstsq (#44689) --- paddle/fluid/operators/lstsq_op.cu | 2 +- paddle/fluid/operators/lstsq_op.h | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/lstsq_op.cu b/paddle/fluid/operators/lstsq_op.cu index 82a56af7eb..f063716b20 100644 --- a/paddle/fluid/operators/lstsq_op.cu +++ b/paddle/fluid/operators/lstsq_op.cu @@ -157,7 +157,7 @@ class LstsqCUDAKernel : public framework::OpKernel { Tensor trans_q = dito.Transpose(new_x); Tensor slice_q = dito.Slice(trans_q, {-1}, {0}, {m}); Tensor solu_tensor = dito.Matmul(slice_q, *solution, false, false); - framework::TensorCopy(solu_tensor, solution->place(), solution); + framework::TensorCopy(solu_tensor, context.GetPlace(), solution); } } }; diff --git a/paddle/fluid/operators/lstsq_op.h b/paddle/fluid/operators/lstsq_op.h index b3e5894a94..7e71d17364 100644 --- a/paddle/fluid/operators/lstsq_op.h +++ b/paddle/fluid/operators/lstsq_op.h @@ -112,8 +112,8 @@ class LstsqCPUKernel : public framework::OpKernel { Tensor input_x_trans = dito.Transpose(new_x); Tensor input_y_trans = dito.Transpose(*solution); - framework::TensorCopy(input_x_trans, new_x.place(), &new_x); - framework::TensorCopy(input_y_trans, solution->place(), solution); + framework::TensorCopy(input_x_trans, context.GetPlace(), &new_x); + framework::TensorCopy(input_y_trans, context.GetPlace(), solution); auto* x_vector = new_x.data(); auto* y_vector = solution->data(); @@ -310,7 +310,7 @@ class LstsqCPUKernel : public framework::OpKernel { } Tensor tmp_s = dito.Transpose(*solution); - framework::TensorCopy(tmp_s, solution->place(), solution); + framework::TensorCopy(tmp_s, context.GetPlace(), solution); if (m > n) { auto* solu_data = solution->data(); -- GitLab