未验证 提交 ef53e1b4 编写于 作者: T TeFeng Chen 提交者: GitHub

cinn_launch op: fix dtype of tensor is always mutable_data<float> (#45835)

上级 2b0857be
...@@ -270,8 +270,9 @@ void CinnLaunchContext::AssignExternalVariable(const std::string& var_name) { ...@@ -270,8 +270,9 @@ void CinnLaunchContext::AssignExternalVariable(const std::string& var_name) {
[this, var_name](void* ctx, cinn_buffer_t* buffer) { [this, var_name](void* ctx, cinn_buffer_t* buffer) {
auto* tensor = cached_scope_->GetVar(var_name)->GetMutable<LoDTensor>(); auto* tensor = cached_scope_->GetVar(var_name)->GetMutable<LoDTensor>();
tensor->Resize(framework::DDim(buffer->dims, buffer->dimensions)); tensor->Resize(framework::DDim(buffer->dims, buffer->dimensions));
buffer->memory = reinterpret_cast<uint8_t*>( buffer->memory = reinterpret_cast<uint8_t*>(tensor->mutable_data(
tensor->mutable_data<float>(*cached_place_)); *cached_place_,
framework::paddle2cinn::TransToPaddleDataType(buffer->type)));
return 0; return 0;
}); });
...@@ -295,8 +296,9 @@ void CinnLaunchContext::AssignInternalVariable(const std::string& var_name) { ...@@ -295,8 +296,9 @@ void CinnLaunchContext::AssignInternalVariable(const std::string& var_name) {
auto* tensor = auto* tensor =
cached_temp_scope_->Var(var_name)->GetMutable<LoDTensor>(); cached_temp_scope_->Var(var_name)->GetMutable<LoDTensor>();
tensor->Resize(framework::DDim(buffer->dims, buffer->dimensions)); tensor->Resize(framework::DDim(buffer->dims, buffer->dimensions));
buffer->memory = reinterpret_cast<uint8_t*>( buffer->memory = reinterpret_cast<uint8_t*>(tensor->mutable_data(
tensor->mutable_data<float>(*cached_place_)); *cached_place_,
framework::paddle2cinn::TransToPaddleDataType(buffer->type)));
return 0; return 0;
}); });
...@@ -437,7 +439,8 @@ ParallelExecutor* CinnLaunchContext::InitializePE(const platform::Place& place, ...@@ -437,7 +439,8 @@ ParallelExecutor* CinnLaunchContext::InitializePE(const platform::Place& place,
auto* buffer = GetCinnBufferOfVar(var_name); auto* buffer = GetCinnBufferOfVar(var_name);
auto dim = framework::DDim(buffer->dims, buffer->dimensions); auto dim = framework::DDim(buffer->dims, buffer->dimensions);
var->GetMutable<LoDTensor>()->Resize(dim); var->GetMutable<LoDTensor>()->Resize(dim);
var->GetMutable<LoDTensor>()->mutable_data<float>(place); var->GetMutable<LoDTensor>()->mutable_data(
place, framework::paddle2cinn::TransToPaddleDataType(buffer->type));
} }
return parallel_executor_.get(); return parallel_executor_.get();
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册