未验证 提交 2781740b 编写于 作者: H Haohongxiang 提交者: GitHub

fix bugs of lstsq (#44689)

上级 55aaeb39
......@@ -157,7 +157,7 @@ class LstsqCUDAKernel : public framework::OpKernel<T> {
Tensor trans_q = dito.Transpose(new_x);
Tensor slice_q = dito.Slice(trans_q, {-1}, {0}, {m});
Tensor solu_tensor = dito.Matmul(slice_q, *solution, false, false);
framework::TensorCopy(solu_tensor, solution->place(), solution);
framework::TensorCopy(solu_tensor, context.GetPlace(), solution);
}
}
};
......
......@@ -112,8 +112,8 @@ class LstsqCPUKernel : public framework::OpKernel<T> {
Tensor input_x_trans = dito.Transpose(new_x);
Tensor input_y_trans = dito.Transpose(*solution);
framework::TensorCopy(input_x_trans, new_x.place(), &new_x);
framework::TensorCopy(input_y_trans, solution->place(), solution);
framework::TensorCopy(input_x_trans, context.GetPlace(), &new_x);
framework::TensorCopy(input_y_trans, context.GetPlace(), solution);
auto* x_vector = new_x.data<T>();
auto* y_vector = solution->data<T>();
......@@ -310,7 +310,7 @@ class LstsqCPUKernel : public framework::OpKernel<T> {
}
Tensor tmp_s = dito.Transpose(*solution);
framework::TensorCopy(tmp_s, solution->place(), solution);
framework::TensorCopy(tmp_s, context.GetPlace(), solution);
if (m > n) {
auto* solu_data = solution->data<T>();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册