diff --git a/paddle/fluid/operators/assign_op.cc b/paddle/fluid/operators/assign_op.cc index 0a89b2e416b09475b3f19785306199e554b9bf1d..5c69ad94b36c92a9ac8ddc32de56fe2ca1b37730 100644 --- a/paddle/fluid/operators/assign_op.cc +++ b/paddle/fluid/operators/assign_op.cc @@ -41,6 +41,14 @@ class AssignOp : public framework::OperatorWithKernel { } protected: + framework::OpKernelType GetKernelTypeForVar( + const std::string &var_name, const framework::Tensor &tensor, + const framework::OpKernelType &expected_kernel_type) const override { + return framework::OpKernelType(expected_kernel_type.data_type_, + expected_kernel_type.place_, + tensor.layout()); + } + framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { return framework::OpKernelType( diff --git a/paddle/fluid/operators/assign_op.h b/paddle/fluid/operators/assign_op.h index 6718999d7f70d1182de958c1b7f574284c7b449f..6ce04d19fc4376e4263712e2904e480e26590553 100644 --- a/paddle/fluid/operators/assign_op.h +++ b/paddle/fluid/operators/assign_op.h @@ -47,7 +47,7 @@ class AssignFunctor { out_rows.set_height(rows.height()); auto &t = rows.value(); auto *m = out_rows.mutable_value(); - framework::TensorCopy(t, t.place(), dev_ctx_, m); + framework::TensorCopy(t, dev_ctx_.GetPlace(), dev_ctx_, m); } template @@ -60,7 +60,7 @@ class AssignFunctor { framework::LoDTensor *out) const { if (lod_tensor.numel() == 0) return; auto &out_tensor = *out; - TensorCopy(lod_tensor, lod_tensor.place(), dev_ctx_, &out_tensor); + TensorCopy(lod_tensor, dev_ctx_.GetPlace(), dev_ctx_, &out_tensor); out_tensor.set_lod(lod_tensor.lod()); }