Commit ada6790e authored by jackzhang235

add BindConstRawData()

Parent 922f1193
@@ -35,6 +35,7 @@ class Graph {
   Graph() { CNML_CALL(cnmlCreateFusionOp(&fusion_op_)); }
   ~Graph() {
+    FreeConstData();
     CNML_CALL(cnmlDestroyFusionOp(&fusion_op_));
     for (auto op : ops_) {
       CNML_CALL(cnmlDestroyBaseOp(&op));
@@ -99,6 +100,49 @@ class Graph {
     CNRT_CALL(cnrtSyncQueue(que));
   }
+
+  template <typename T>
+  void* RegisterConstData(size_t len) {
+    void* addr = malloc(len * sizeof(T));
+    const_data_storage_.push_back(addr);
+    return addr;
+  }
+  void FreeConstData() {
+    for (auto& addr : const_data_storage_) {
+      free(addr);
+    }
+  }
+
+  void BindConstRawData(std::string tensor_name,
+                        const float* data,
+                        size_t len,
+                        bool alloc = true) {
+    void* alloc_data;
+    if (fp_type_ == CNML_DATA_FLOAT32) {
+      if (alloc) {
+        alloc_data = RegisterConstData<float>(len);
+        memcpy(alloc_data, data, len * sizeof(float));
+      } else {
+        alloc_data = const_cast<void*>(static_cast<const void*>(data));
+      }
+      CNML_CALL(cnmlBindConstData_V2(
+          nodes_[tensor_name]->mlu_tensor(), alloc_data, false));
+    } else if (fp_type_ == CNML_DATA_FLOAT16) {
+      void* data_fp16 = RegisterConstData<::paddle::lite::fluid::float16>(len);
+      CNRT_CALL(
+          cnrtCastDataType(const_cast<void*>(static_cast<const void*>(data)),
+                           CNRT_FLOAT32,
+                           data_fp16,
+                           CNRT_FLOAT16,
+                           len,
+                           nullptr));
+      CNML_CALL(cnmlBindConstData_V2(
+          nodes_[tensor_name]->mlu_tensor(), data_fp16, false));
+    } else {
+      CHECK(0);
+    }
+  }
+
   void BindConstData(std::string tensor_name, ::paddle::lite::Tensor* tensor) {
     const float* data = tensor->data<float>();
     size_t len = tensor->data_size();
@@ -158,6 +202,7 @@ class Graph {
   std::vector<std::shared_ptr<MLUTensor>> output_tensors_;
   std::vector<cnmlBaseOp_t> ops_;
   cnmlFusionOp_t fusion_op_;
+  std::vector<void*> const_data_storage_;
 };

 }  // namespace mlu
...
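
For reference, here is a minimal usage sketch of the new BindConstRawData() (not part of the commit; the tensor name, weight values, and the surrounding function are illustrative, and the paddle::lite::subgraph::mlu namespace is assumed from the file's location):

#include <vector>

// Illustrative only: bind constant filter weights to a node previously
// registered in the graph under the (hypothetical) name "conv1_filter".
void BindFilterExample(paddle::lite::subgraph::mlu::Graph* graph) {
  std::vector<float> weights(3 * 3 * 16 * 16, 0.5f);  // hypothetical fp32 data

  // alloc = true (the default): the graph copies the buffer into storage
  // tracked by const_data_storage_, and ~Graph() releases it through
  // FreeConstData(), so `weights` may be destroyed right after this call.
  graph->BindConstRawData("conv1_filter", weights.data(), weights.size());

  // alloc = false would bind the caller's buffer directly (fp32 graphs only;
  // the fp16 path always allocates a cast copy), so that buffer would have
  // to outlive compilation and execution of the fusion op.
}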
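
The CNML_DATA_FLOAT16 branch casts the fp32 host buffer to fp16 with cnrtCastDataType() before binding. A standalone sketch of just that cast step, reusing the diff's CNRT_CALL macro (the helper name and the raw malloc in place of RegisterConstData() are illustrative):

#include <cstdint>
#include <cstdlib>

// Sketch of the host-side fp32 -> fp16 conversion done inside
// BindConstRawData(). In the graph the destination buffer comes from
// RegisterConstData<::paddle::lite::fluid::float16>(len), so it is freed
// by FreeConstData() when the Graph is destroyed.
void* CastFp32ToFp16(const float* src, size_t len) {
  void* dst = malloc(len * sizeof(uint16_t));  // one fp16 element is 2 bytes
  CNRT_CALL(cnrtCastDataType(const_cast<void*>(static_cast<const void*>(src)),
                             CNRT_FLOAT32,
                             dst,
                             CNRT_FLOAT16,
                             len,
                             /*quant_param=*/nullptr));
  return dst;
}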