From ef53e1b4f8a355a1d8452640ebc740e6d089aa23 Mon Sep 17 00:00:00 2001 From: TeFeng Chen Date: Thu, 8 Sep 2022 14:21:53 +0800 Subject: [PATCH] cinn_launch op: fix dtype of tensor is always mutable_data (#45835) --- paddle/fluid/operators/cinn/cinn_launch_context.cc | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/cinn/cinn_launch_context.cc b/paddle/fluid/operators/cinn/cinn_launch_context.cc index 3c98399492..a796050c7e 100644 --- a/paddle/fluid/operators/cinn/cinn_launch_context.cc +++ b/paddle/fluid/operators/cinn/cinn_launch_context.cc @@ -270,8 +270,9 @@ void CinnLaunchContext::AssignExternalVariable(const std::string& var_name) { [this, var_name](void* ctx, cinn_buffer_t* buffer) { auto* tensor = cached_scope_->GetVar(var_name)->GetMutable(); tensor->Resize(framework::DDim(buffer->dims, buffer->dimensions)); - buffer->memory = reinterpret_cast( - tensor->mutable_data(*cached_place_)); + buffer->memory = reinterpret_cast(tensor->mutable_data( + *cached_place_, + framework::paddle2cinn::TransToPaddleDataType(buffer->type))); return 0; }); @@ -295,8 +296,9 @@ void CinnLaunchContext::AssignInternalVariable(const std::string& var_name) { auto* tensor = cached_temp_scope_->Var(var_name)->GetMutable(); tensor->Resize(framework::DDim(buffer->dims, buffer->dimensions)); - buffer->memory = reinterpret_cast( - tensor->mutable_data(*cached_place_)); + buffer->memory = reinterpret_cast(tensor->mutable_data( + *cached_place_, + framework::paddle2cinn::TransToPaddleDataType(buffer->type))); return 0; }); @@ -437,7 +439,8 @@ ParallelExecutor* CinnLaunchContext::InitializePE(const platform::Place& place, auto* buffer = GetCinnBufferOfVar(var_name); auto dim = framework::DDim(buffer->dims, buffer->dimensions); var->GetMutable()->Resize(dim); - var->GetMutable()->mutable_data(place); + var->GetMutable()->mutable_data( + place, framework::paddle2cinn::TransToPaddleDataType(buffer->type)); } return parallel_executor_.get(); } -- GitLab