未验证 提交 88216f63 编写于 作者: Z zyfncg 提交者: GitHub

fix tensor copy bug (#43299)

上级 99c6497b
@@ -14,6 +14,7 @@ limitations under the License. */
 #include "paddle/phi/api/lib/tensor_copy.h"
+#include "paddle/phi/api/include/context_pool.h"
 #include "paddle/phi/api/lib/api_gen_utils.h"
 #include "paddle/phi/api/lib/kernel_dispatch.h"
 #include "paddle/phi/core/compat/convert_utils.h"
@@ -24,18 +25,21 @@ limitations under the License. */
 namespace paddle {
 namespace experimental {
-void copy(const Tensor& src, Place place, bool blocking, Tensor* dst) {
+void copy(const Tensor& src, const Place& place, bool blocking, Tensor* dst) {
   auto kernel_key_set = ParseKernelKeyByInputArgs(src);
   kernel_key_set.backend_set =
       kernel_key_set.backend_set | BackendSet(phi::TransToPhiBackend(place));
   auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
-  auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
-      "copy", kernel_key);
   VLOG(6) << "copy API kernel key: " << kernel_key;
+  auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
+      "copy", kernel_key);
   VLOG(6) << "copy API kernel: " << kernel;
-  auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
+  auto target_place = phi::TransToPhiPlace(kernel_key.backend());
+  auto& pool = paddle::experimental::DeviceContextPool::Instance();
+  auto* dev_ctx = pool.GetMutable(
+      target_place.GetType() == place.GetType() ? place : target_place);
   auto dense_x = TensorToDenseTensor(src);
......
@@ -19,7 +19,7 @@ limitations under the License. */
 namespace paddle {
 namespace experimental {
-void copy(const Tensor& src, Place place, bool blocking, Tensor* dst);
+void copy(const Tensor& src, const Place& place, bool blocking, Tensor* dst);
 }  // namespace experimental
 }  // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册