diff --git a/paddle/fluid/operators/gaussian_random_op_mlu.cc b/paddle/fluid/operators/gaussian_random_op_mlu.cc index 90b29892a3ed68747c83e9df1d2444300f5db577..4b5229b9e63ea18b8d586f6b25ed4217939a0939 100644 --- a/paddle/fluid/operators/gaussian_random_op_mlu.cc +++ b/paddle/fluid/operators/gaussian_random_op_mlu.cc @@ -45,6 +45,7 @@ class MLUGaussianRandomKernel : public framework::OpKernel { auto& dev_ctx = context.template device_context(); framework::TensorCopy(cpu_tensor, context.GetPlace(), dev_ctx, tensor); + dev_ctx.Wait(); } }; diff --git a/paddle/fluid/operators/optimizers/adam_op_mlu.cc b/paddle/fluid/operators/optimizers/adam_op_mlu.cc index aa62e2412b68e033f20a103d33f520ccd953ffc7..ecc527d5c72bf0757b71b382febe29e4c594a175 100644 --- a/paddle/fluid/operators/optimizers/adam_op_mlu.cc +++ b/paddle/fluid/operators/optimizers/adam_op_mlu.cc @@ -68,6 +68,7 @@ class AdamMLUKernel : public framework::OpKernel { std::vector skip_update_vec; paddle::framework::TensorToVector( *skip_update_tensor, ctx.device_context(), &skip_update_vec); + ctx.device_context().Wait(); skip_update = skip_update_vec[0]; } // skip_update=true, just copy input to output, and TensorCopy will call @@ -286,6 +287,7 @@ class AdamWMLUKernel : public AdamMLUKernel { std::vector skip_update_vec; paddle::framework::TensorToVector( *skip_update_tensor, ctx.device_context(), &skip_update_vec); + ctx.device_context().Wait(); skip_update = skip_update_vec[0]; } bool with_decay = ctx.Attr("with_decay"); diff --git a/paddle/fluid/pybind/tensor_py.h b/paddle/fluid/pybind/tensor_py.h index 3b0a9f8fb0ce295c76752e7b0a44c9b50f147f4e..4a8ef8795e089267f586766aba6dd11b1146e9bc 100644 --- a/paddle/fluid/pybind/tensor_py.h +++ b/paddle/fluid/pybind/tensor_py.h @@ -439,7 +439,11 @@ void SetTensorFromPyArrayT( platform::Place tmp_place = place; platform::MLUDeviceGuard guard(tmp_place.device); auto dst = self->mutable_data(place); - paddle::platform::MLUMemcpyH2DSync(dst, array.data(), array.nbytes()); + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto dev_ctx = static_cast(pool.Get(place)); + paddle::platform::MLUMemcpyH2DAsync( + dst, array.data(), array.nbytes(), dev_ctx->stream()); + dev_ctx->Wait(); #else PADDLE_THROW(platform::errors::PermissionDenied( "Cannot use MLUPlace in CPU/GPU version, "