Unverified commit 75690584, authored by fwenguang, committed by GitHub

[MLU] fix copy error (#45194)

Parent 216d25ac
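All three hunks apply the same fix: a copy between host and MLU device memory is issued asynchronously on the device context's stream, and an explicit Wait() is added so the host side does not proceed until the copy has finished. In the gaussian_random kernel the host staging tensor is a local that could otherwise be destroyed while the copy is still reading it; in the Adam/AdamW kernels skip_update_vec[0] is read on the host right after TensorToVector, so the device-to-host copy has to be complete first. Below is a minimal, hypothetical sketch of the pattern, assuming Paddle's MLU kernel API as it appears in the hunks that follow; the kernel name and the host-side fill loop are invented.

// Hypothetical kernel illustrating the pattern; not part of the patch.
// TensorCopy, MLUDeviceContext and Wait() are the real Paddle APIs used in
// the hunks below; the op name and the host-side fill are made up.
template <typename T>
class FillFromHostMLUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* tensor = context.Output<framework::Tensor>("Out");

    framework::Tensor cpu_tensor;  // local host-side staging buffer
    T* cpu_data =
        cpu_tensor.mutable_data<T>(tensor->dims(), platform::CPUPlace());
    for (int64_t i = 0; i < cpu_tensor.numel(); ++i) {
      cpu_data[i] = static_cast<T>(0);  // stand-in for the real host-side fill
    }

    auto& dev_ctx =
        context.template device_context<paddle::platform::MLUDeviceContext>();
    // TensorCopy enqueues the host-to-device copy on dev_ctx's stream and
    // returns without waiting for it.
    framework::TensorCopy(cpu_tensor, context.GetPlace(), dev_ctx, tensor);
    // The added synchronization: block until the copy is done, so cpu_tensor
    // is not destroyed (and its memory reused) while the copy still reads it,
    // which appears to be the copy error the commit title refers to.
    dev_ctx.Wait();
  }
};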
@@ -45,6 +45,7 @@ class MLUGaussianRandomKernel : public framework::OpKernel<T> {
     auto& dev_ctx =
         context.template device_context<paddle::platform::MLUDeviceContext>();
     framework::TensorCopy(cpu_tensor, context.GetPlace(), dev_ctx, tensor);
+    dev_ctx.Wait();
   }
 };
@@ -68,6 +68,7 @@ class AdamMLUKernel : public framework::OpKernel<T> {
       std::vector<bool> skip_update_vec;
       paddle::framework::TensorToVector(
           *skip_update_tensor, ctx.device_context(), &skip_update_vec);
+      ctx.device_context().Wait();
       skip_update = skip_update_vec[0];
     }
     // skip_update=true, just copy input to output, and TensorCopy will call

@@ -286,6 +287,7 @@ class AdamWMLUKernel : public AdamMLUKernel<T> {
       std::vector<bool> skip_update_vec;
       paddle::framework::TensorToVector(
           *skip_update_tensor, ctx.device_context(), &skip_update_vec);
+      ctx.device_context().Wait();
       skip_update = skip_update_vec[0];
     }
     bool with_decay = ctx.Attr<bool>("with_decay");
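In both Adam hunks the flag comes from TensorToVector, which copies from device memory into a host std::vector via the device context; the added Wait() ensures that copy has landed before skip_update_vec[0] is read. A hedged sketch of reading a single-element device flag this way, using the TensorToVector overload shown above; the helper name ReadBoolFlag is invented.

// Hypothetical helper mirroring the skip_update read in the Adam kernels.
// TensorToVector and Wait() are the real calls from the hunks; the helper
// name is made up.
bool ReadBoolFlag(const framework::Tensor& flag_tensor,
                  const framework::ExecutionContext& ctx) {
  std::vector<bool> flag_vec;
  // Copies device -> host; with an asynchronous device context this may
  // return before flag_vec is actually filled.
  paddle::framework::TensorToVector(flag_tensor, ctx.device_context(), &flag_vec);
  // Synchronize before touching the host-side result.
  ctx.device_context().Wait();
  return flag_vec[0];
}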
@@ -439,7 +439,11 @@ void SetTensorFromPyArrayT(
     platform::Place tmp_place = place;
     platform::MLUDeviceGuard guard(tmp_place.device);
     auto dst = self->mutable_data<T>(place);
-    paddle::platform::MLUMemcpyH2DSync(dst, array.data(), array.nbytes());
+    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
+    auto dev_ctx = static_cast<platform::MLUDeviceContext *>(pool.Get(place));
+    paddle::platform::MLUMemcpyH2DAsync(
+        dst, array.data(), array.nbytes(), dev_ctx->stream());
+    dev_ctx->Wait();
 #else
     PADDLE_THROW(platform::errors::PermissionDenied(
         "Cannot use MLUPlace in CPU/GPU version, "
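The pybind hunk swaps the blocking MLUMemcpyH2DSync for an asynchronous copy on the device context's stream followed by Wait(), presumably so the copy stays ordered with work already queued on that stream. A sketch of the same pattern as a standalone helper, assuming a Paddle MLU build; the name CopyHostBufferToMLU is invented, while DeviceContextPool, MLUMemcpyH2DAsync and MLUDeviceContext::stream() are the calls that appear in the hunk.

// Hypothetical free function showing the async-copy-then-wait pattern used
// in SetTensorFromPyArrayT above; only the helper name is made up.
// `place` is assumed to be an MLUPlace, so the context cast is valid.
void CopyHostBufferToMLU(void *dst, const void *src, size_t nbytes,
                         const platform::Place &place) {
  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
  auto *dev_ctx = static_cast<platform::MLUDeviceContext *>(pool.Get(place));
  // Enqueue the host-to-device copy on the context's stream, so it runs in
  // order with whatever work is already queued there.
  paddle::platform::MLUMemcpyH2DAsync(dst, src, nbytes, dev_ctx->stream());
  // Block before returning: the host buffer (here, a borrowed numpy array)
  // may be released by the caller as soon as this function returns.
  dev_ctx->Wait();
}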