From 756905848d0cfdce6adcfec35ba57657a4fb842b Mon Sep 17 00:00:00 2001
From: fwenguang <95677191+fwenguang@users.noreply.github.com>
Date: Wed, 17 Aug 2022 10:33:20 +0800
Subject: [PATCH] [MLU] fix copy error (#45194)

---
 paddle/fluid/operators/gaussian_random_op_mlu.cc | 1 +
 paddle/fluid/operators/optimizers/adam_op_mlu.cc | 2 ++
 paddle/fluid/pybind/tensor_py.h                  | 6 +++++-
 3 files changed, 8 insertions(+), 1 deletion(-)

diff --git a/paddle/fluid/operators/gaussian_random_op_mlu.cc b/paddle/fluid/operators/gaussian_random_op_mlu.cc
index 90b29892a3..4b5229b9e6 100644
--- a/paddle/fluid/operators/gaussian_random_op_mlu.cc
+++ b/paddle/fluid/operators/gaussian_random_op_mlu.cc
@@ -45,6 +45,7 @@ class MLUGaussianRandomKernel : public framework::OpKernel<T> {
     auto& dev_ctx =
         context.template device_context<platform::MLUDeviceContext>();
     framework::TensorCopy(cpu_tensor, context.GetPlace(), dev_ctx, tensor);
+    dev_ctx.Wait();
   }
 };
 
diff --git a/paddle/fluid/operators/optimizers/adam_op_mlu.cc b/paddle/fluid/operators/optimizers/adam_op_mlu.cc
index aa62e2412b..ecc527d5c7 100644
--- a/paddle/fluid/operators/optimizers/adam_op_mlu.cc
+++ b/paddle/fluid/operators/optimizers/adam_op_mlu.cc
@@ -68,6 +68,7 @@ class AdamMLUKernel : public framework::OpKernel<T> {
       std::vector<bool> skip_update_vec;
       paddle::framework::TensorToVector(
           *skip_update_tensor, ctx.device_context(), &skip_update_vec);
+      ctx.device_context().Wait();
       skip_update = skip_update_vec[0];
     }
     // skip_update=true, just copy input to output, and TensorCopy will call
@@ -286,6 +287,7 @@ class AdamWMLUKernel : public AdamMLUKernel<T> {
       std::vector<bool> skip_update_vec;
       paddle::framework::TensorToVector(
           *skip_update_tensor, ctx.device_context(), &skip_update_vec);
+      ctx.device_context().Wait();
       skip_update = skip_update_vec[0];
     }
     bool with_decay = ctx.Attr<bool>("with_decay");
diff --git a/paddle/fluid/pybind/tensor_py.h b/paddle/fluid/pybind/tensor_py.h
index 3b0a9f8fb0..4a8ef8795e 100644
--- a/paddle/fluid/pybind/tensor_py.h
+++ b/paddle/fluid/pybind/tensor_py.h
@@ -439,7 +439,11 @@ void SetTensorFromPyArrayT(
     platform::Place tmp_place = place;
     platform::MLUDeviceGuard guard(tmp_place.device);
     auto dst = self->mutable_data<T>(place);
-    paddle::platform::MLUMemcpyH2DSync(dst, array.data(), array.nbytes());
+    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
+    auto dev_ctx = static_cast<platform::MLUDeviceContext *>(pool.Get(place));
+    paddle::platform::MLUMemcpyH2DAsync(
+        dst, array.data(), array.nbytes(), dev_ctx->stream());
+    dev_ctx->Wait();
 #else
     PADDLE_THROW(platform::errors::PermissionDenied(
         "Cannot use MLUPlace in CPU/GPU version, "
-- 
GitLab
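
Note (not part of the commit): every hunk in this patch applies the same pattern. TensorCopy, TensorToVector, and MLUMemcpyH2DAsync enqueue the copy on the device stream and return to the host immediately, so the host must block on the stream before reading the destination; that is what each added Wait() call does. Below is a minimal standalone C++ sketch of the race and the fix. It is an illustration under stated assumptions, not Paddle or MLU code: std::async stands in for the device stream, and AsyncCopy is a hypothetical name.

// Minimal sketch of "async copy needs an explicit wait before the result is
// read". Assumption: std::async models the device stream; no Paddle/MLU APIs.
#include <cassert>
#include <future>
#include <vector>

// Simulates an async H2D memcpy: the copy completes on another thread,
// some time after this call has already returned to the host.
std::future<void> AsyncCopy(const std::vector<int>& src, std::vector<int>* dst) {
  return std::async(std::launch::async, [&src, dst] { *dst = src; });
}

int main() {
  std::vector<int> host = {1, 2, 3};
  std::vector<int> device;
  std::future<void> done = AsyncCopy(host, &device);
  // Without this wait, reading `device` here races with the in-flight copy;
  // this is the bug the patch fixes by inserting dev_ctx.Wait() after each
  // asynchronous copy.
  done.wait();
  assert(device == host);
  return 0;
}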