From 01a9646323e306c505efde60aa1c67669052de8f Mon Sep 17 00:00:00 2001 From: Zhang Ting <709968123@qq.com> Date: Wed, 20 Nov 2019 11:07:25 +0800 Subject: [PATCH] optimize assign op to avoid copy data from GPU to GPU (#21181) * optimize assign op to avoid copy data from GPU to GPU, test=develop * modified GetkernelTypeForVar and just avoid device transform, test=develop --- paddle/fluid/operators/assign_op.cc | 8 ++++++++ paddle/fluid/operators/assign_op.h | 4 ++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/assign_op.cc b/paddle/fluid/operators/assign_op.cc index 0a89b2e416b..5c69ad94b36 100644 --- a/paddle/fluid/operators/assign_op.cc +++ b/paddle/fluid/operators/assign_op.cc @@ -41,6 +41,14 @@ class AssignOp : public framework::OperatorWithKernel { } protected: + framework::OpKernelType GetKernelTypeForVar( + const std::string &var_name, const framework::Tensor &tensor, + const framework::OpKernelType &expected_kernel_type) const override { + return framework::OpKernelType(expected_kernel_type.data_type_, + expected_kernel_type.place_, + tensor.layout()); + } + framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { return framework::OpKernelType( diff --git a/paddle/fluid/operators/assign_op.h b/paddle/fluid/operators/assign_op.h index 6718999d7f7..6ce04d19fc4 100644 --- a/paddle/fluid/operators/assign_op.h +++ b/paddle/fluid/operators/assign_op.h @@ -47,7 +47,7 @@ class AssignFunctor { out_rows.set_height(rows.height()); auto &t = rows.value(); auto *m = out_rows.mutable_value(); - framework::TensorCopy(t, t.place(), dev_ctx_, m); + framework::TensorCopy(t, dev_ctx_.GetPlace(), dev_ctx_, m); } template @@ -60,7 +60,7 @@ class AssignFunctor { framework::LoDTensor *out) const { if (lod_tensor.numel() == 0) return; auto &out_tensor = *out; - TensorCopy(lod_tensor, lod_tensor.place(), dev_ctx_, &out_tensor); + TensorCopy(lod_tensor, dev_ctx_.GetPlace(), dev_ctx_, &out_tensor); out_tensor.set_lod(lod_tensor.lod()); } -- GitLab