未验证 提交 ef53e1b4 编写于 作者: T TeFeng Chen 提交者: GitHub

cinn_launch op: fix dtype of tensor is always mutable_data<float> (#45835)

上级 2b0857be
......@@ -270,8 +270,9 @@ void CinnLaunchContext::AssignExternalVariable(const std::string& var_name) {
[this, var_name](void* ctx, cinn_buffer_t* buffer) {
auto* tensor = cached_scope_->GetVar(var_name)->GetMutable<LoDTensor>();
tensor->Resize(framework::DDim(buffer->dims, buffer->dimensions));
buffer->memory = reinterpret_cast<uint8_t*>(
tensor->mutable_data<float>(*cached_place_));
buffer->memory = reinterpret_cast<uint8_t*>(tensor->mutable_data(
*cached_place_,
framework::paddle2cinn::TransToPaddleDataType(buffer->type)));
return 0;
});
......@@ -295,8 +296,9 @@ void CinnLaunchContext::AssignInternalVariable(const std::string& var_name) {
auto* tensor =
cached_temp_scope_->Var(var_name)->GetMutable<LoDTensor>();
tensor->Resize(framework::DDim(buffer->dims, buffer->dimensions));
buffer->memory = reinterpret_cast<uint8_t*>(
tensor->mutable_data<float>(*cached_place_));
buffer->memory = reinterpret_cast<uint8_t*>(tensor->mutable_data(
*cached_place_,
framework::paddle2cinn::TransToPaddleDataType(buffer->type)));
return 0;
});
......@@ -437,7 +439,8 @@ ParallelExecutor* CinnLaunchContext::InitializePE(const platform::Place& place,
auto* buffer = GetCinnBufferOfVar(var_name);
auto dim = framework::DDim(buffer->dims, buffer->dimensions);
var->GetMutable<LoDTensor>()->Resize(dim);
var->GetMutable<LoDTensor>()->mutable_data<float>(place);
var->GetMutable<LoDTensor>()->mutable_data(
place, framework::paddle2cinn::TransToPaddleDataType(buffer->type));
}
return parallel_executor_.get();
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册