未验证 提交 8760817a 编写于 作者: Z zyfncg 提交者: GitHub

fix tensor copy bug (#43299) (#43728)

上级 a4c898cf
......@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/api/lib/tensor_copy.h"
#include "paddle/phi/api/include/context_pool.h"
#include "paddle/phi/api/lib/api_gen_utils.h"
#include "paddle/phi/api/lib/kernel_dispatch.h"
#include "paddle/phi/api/lib/utils/storage.h"
......@@ -24,18 +26,21 @@ limitations under the License. */
namespace paddle {
namespace experimental {
void copy(const Tensor& src, Place place, bool blocking, Tensor* dst) {
void copy(const Tensor& src, const Place& place, bool blocking, Tensor* dst) {
auto kernel_key_set = ParseKernelKeyByInputArgs(src);
kernel_key_set.backend_set =
kernel_key_set.backend_set | BackendSet(phi::TransToPhiBackend(place));
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"copy", kernel_key);
VLOG(6) << "copy API kernel key: " << kernel_key;
auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"copy", kernel_key);
VLOG(6) << "copy API kernel: " << kernel;
auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
auto target_place = phi::TransToPhiPlace(kernel_key.backend());
auto& pool = paddle::experimental::DeviceContextPool::Instance();
auto* dev_ctx = pool.GetMutable(
target_place.GetType() == place.GetType() ? place : target_place);
auto dense_x = TensorToDenseTensor(src);
......
......@@ -19,7 +19,7 @@ limitations under the License. */
namespace paddle {
namespace experimental {
void copy(const Tensor& src, Place place, bool blocking, Tensor* dst);
void copy(const Tensor& src, const Place& place, bool blocking, Tensor* dst);
} // namespace experimental
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册