diff --git a/paddle/fluid/operators/cinn/cinn_launch_context.cc b/paddle/fluid/operators/cinn/cinn_launch_context.cc index 3c983994925144fd7210dd152602388c29a3587a..a796050c7ec2ac8d9d12ccddd887e2a069f00382 100644 --- a/paddle/fluid/operators/cinn/cinn_launch_context.cc +++ b/paddle/fluid/operators/cinn/cinn_launch_context.cc @@ -270,8 +270,9 @@ void CinnLaunchContext::AssignExternalVariable(const std::string& var_name) { [this, var_name](void* ctx, cinn_buffer_t* buffer) { auto* tensor = cached_scope_->GetVar(var_name)->GetMutable(); tensor->Resize(framework::DDim(buffer->dims, buffer->dimensions)); - buffer->memory = reinterpret_cast( - tensor->mutable_data(*cached_place_)); + buffer->memory = reinterpret_cast(tensor->mutable_data( + *cached_place_, + framework::paddle2cinn::TransToPaddleDataType(buffer->type))); return 0; }); @@ -295,8 +296,9 @@ void CinnLaunchContext::AssignInternalVariable(const std::string& var_name) { auto* tensor = cached_temp_scope_->Var(var_name)->GetMutable(); tensor->Resize(framework::DDim(buffer->dims, buffer->dimensions)); - buffer->memory = reinterpret_cast( - tensor->mutable_data(*cached_place_)); + buffer->memory = reinterpret_cast(tensor->mutable_data( + *cached_place_, + framework::paddle2cinn::TransToPaddleDataType(buffer->type))); return 0; }); @@ -437,7 +439,8 @@ ParallelExecutor* CinnLaunchContext::InitializePE(const platform::Place& place, auto* buffer = GetCinnBufferOfVar(var_name); auto dim = framework::DDim(buffer->dims, buffer->dimensions); var->GetMutable()->Resize(dim); - var->GetMutable()->mutable_data(place); + var->GetMutable()->mutable_data( + place, framework::paddle2cinn::TransToPaddleDataType(buffer->type)); } return parallel_executor_.get(); }