Unverified · Commit 75690584 · Authored by fwenguang · Committed by GitHub

[MLU] fix copy error (#45194)

Parent 216d25ac
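The change itself is small: after each copy onto or off the MLU device (TensorCopy, TensorToVector, and the pybind host-to-device memcpy), a Wait() on the corresponding device context is added, presumably so the host does not read the destination buffer or release the source buffer while the asynchronous transfer is still in flight. Below is a minimal, self-contained sketch of that pattern in plain C++, with std::future standing in for the MLU stream; none of the names in the sketch are Paddle or MLU APIs.

// Minimal sketch of the synchronization pattern applied in this commit
// (illustrative only; "Stream" is a stand-in for an MLU stream, not a
// Paddle/MLU API).
#include <cassert>
#include <functional>
#include <future>
#include <utility>
#include <vector>

// A toy "stream": work runs asynchronously, Wait() blocks until it is done.
struct Stream {
  std::future<void> pending;
  void Enqueue(std::function<void()> fn) {
    pending = std::async(std::launch::async, std::move(fn));
  }
  void Wait() {
    if (pending.valid()) pending.get();
  }
};

int main() {
  std::vector<int> src(1024, 7);  // host buffer ("CPU tensor")
  std::vector<int> dst(1024, 0);  // pretend device buffer

  Stream stream;
  // Copy enqueued asynchronously, like TensorCopy / MLUMemcpyH2DAsync.
  stream.Enqueue([&] { dst = src; });

  // Without this Wait(), reading dst below would race with the copy above;
  // this is the role played by dev_ctx.Wait() / ctx.device_context().Wait()
  // in the diff.
  stream.Wait();

  assert(dst[0] == 7);
  return 0;
}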
@@ -45,6 +45,7 @@ class MLUGaussianRandomKernel : public framework::OpKernel<T> {
     auto& dev_ctx =
         context.template device_context<paddle::platform::MLUDeviceContext>();
     framework::TensorCopy(cpu_tensor, context.GetPlace(), dev_ctx, tensor);
+    dev_ctx.Wait();
   }
 };
...
@@ -68,6 +68,7 @@ class AdamMLUKernel : public framework::OpKernel<T> {
       std::vector<bool> skip_update_vec;
       paddle::framework::TensorToVector(
           *skip_update_tensor, ctx.device_context(), &skip_update_vec);
+      ctx.device_context().Wait();
       skip_update = skip_update_vec[0];
     }
     // skip_update=true, just copy input to output, and TensorCopy will call
@@ -286,6 +287,7 @@ class AdamWMLUKernel : public AdamMLUKernel<T> {
       std::vector<bool> skip_update_vec;
       paddle::framework::TensorToVector(
           *skip_update_tensor, ctx.device_context(), &skip_update_vec);
+      ctx.device_context().Wait();
       skip_update = skip_update_vec[0];
     }
     bool with_decay = ctx.Attr<bool>("with_decay");
...
@@ -439,7 +439,11 @@ void SetTensorFromPyArrayT(
     platform::Place tmp_place = place;
     platform::MLUDeviceGuard guard(tmp_place.device);
     auto dst = self->mutable_data<T>(place);
-    paddle::platform::MLUMemcpyH2DSync(dst, array.data(), array.nbytes());
+    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
+    auto dev_ctx = static_cast<platform::MLUDeviceContext *>(pool.Get(place));
+    paddle::platform::MLUMemcpyH2DAsync(
+        dst, array.data(), array.nbytes(), dev_ctx->stream());
+    dev_ctx->Wait();
 #else
     PADDLE_THROW(platform::errors::PermissionDenied(
         "Cannot use MLUPlace in CPU/GPU version, "
...
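Note that the pybind hunk differs slightly from the kernel-side ones: the synchronous MLUMemcpyH2DSync call is replaced by MLUMemcpyH2DAsync issued on the stream of the MLUDeviceContext fetched from the DeviceContextPool, followed by an explicit dev_ctx->Wait(). The copy still blocks the host, but it now goes through the same stream as the rest of the MLU work, which is presumably what resolves the copy error named in the commit title.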