diff --git a/lite/kernels/mlu/bridges/graph.h b/lite/kernels/mlu/bridges/graph.h
index 140900a2dde004281945e50fb1c72d09b58befa1..bc4da057bfd39e120a48aa7b0525710c2282358b 100644
--- a/lite/kernels/mlu/bridges/graph.h
+++ b/lite/kernels/mlu/bridges/graph.h
@@ -35,6 +35,7 @@ class Graph {
   Graph() { CNML_CALL(cnmlCreateFusionOp(&fusion_op_)); }
 
   ~Graph() {
+    FreeConstData();
     CNML_CALL(cnmlDestroyFusionOp(&fusion_op_));
     for (auto op : ops_) {
       CNML_CALL(cnmlDestroyBaseOp(&op));
@@ -99,6 +100,49 @@ class Graph {
     CNRT_CALL(cnrtSyncQueue(que));
   }
 
+  template <typename T>
+  void* RegisterConstData(size_t len) {
+    void* addr = malloc(len * sizeof(T));
+    const_data_storage_.push_back(addr);
+    return addr;
+  }
+
+  void FreeConstData() {
+    for (auto& addr : const_data_storage_) {
+      free(addr);
+    }
+  }
+
+  void BindConstRawData(std::string tensor_name,
+                        const float* data,
+                        size_t len,
+                        bool alloc = true) {
+    void* alloc_data;
+    if (fp_type_ == CNML_DATA_FLOAT32) {
+      if (alloc) {
+        alloc_data = RegisterConstData<float>(len);
+        memcpy(alloc_data, data, len * sizeof(float));
+      } else {
+        alloc_data = const_cast<void*>(static_cast<const void*>(data));
+      }
+      CNML_CALL(cnmlBindConstData_V2(
+          nodes_[tensor_name]->mlu_tensor(), alloc_data, false));
+    } else if (fp_type_ == CNML_DATA_FLOAT16) {
+      void* data_fp16 = RegisterConstData<::paddle::lite::fluid::float16>(len);
+      CNRT_CALL(
+          cnrtCastDataType(const_cast<void*>(static_cast<const void*>(data)),
+                           CNRT_FLOAT32,
+                           data_fp16,
+                           CNRT_FLOAT16,
+                           len,
+                           nullptr));
+      CNML_CALL(cnmlBindConstData_V2(
+          nodes_[tensor_name]->mlu_tensor(), data_fp16, false));
+    } else {
+      CHECK(0);
+    }
+  }
+
   void BindConstData(std::string tensor_name, ::paddle::lite::Tensor* tensor) {
     const float* data = tensor->data<float>();
     size_t len = tensor->data_size();
@@ -158,6 +202,7 @@ class Graph {
   std::vector<std::shared_ptr<MLUTensorNode>> output_tensors_;
   std::vector<cnmlBaseOp_t> ops_;
   cnmlFusionOp_t fusion_op_;
+  std::vector<void*> const_data_storage_;
 };
 
 }  // namespace mlu