提交 7adfb613 编写于 作者: 郭叶军's avatar 郭叶军 提交者: Pedro Arthur

libavfilter/dnn: avoid memcpy for tensorflow dnn output

use TF_Tensor's cpu address to avoid extra memcpy.
Signed-off-by: 郭叶军's avatarGuo, Yejun <yejun.guo@intel.com>
Signed-off-by: NPedro Arthur <bygrandao@gmail.com>
上级 e2b92896
...@@ -35,6 +35,7 @@ typedef struct TFModel{ ...@@ -35,6 +35,7 @@ typedef struct TFModel{
TF_Status *status; TF_Status *status;
TF_Output input, output; TF_Output input, output;
TF_Tensor *input_tensor; TF_Tensor *input_tensor;
TF_Tensor *output_tensor;
} TFModel; } TFModel;
static void free_buffer(void *data, size_t length) static void free_buffer(void *data, size_t length)
...@@ -460,13 +461,11 @@ DNNModel *ff_dnn_load_model_tf(const char *model_filename) ...@@ -460,13 +461,11 @@ DNNModel *ff_dnn_load_model_tf(const char *model_filename)
return NULL; return NULL;
} }
tf_model = av_malloc(sizeof(TFModel)); tf_model = av_mallocz(sizeof(TFModel));
if (!tf_model){ if (!tf_model){
av_freep(&model); av_freep(&model);
return NULL; return NULL;
} }
tf_model->session = NULL;
tf_model->input_tensor = NULL;
if (load_tf_model(tf_model, model_filename) != DNN_SUCCESS){ if (load_tf_model(tf_model, model_filename) != DNN_SUCCESS){
if (load_native_model(tf_model, model_filename) != DNN_SUCCESS){ if (load_native_model(tf_model, model_filename) != DNN_SUCCESS){
...@@ -488,36 +487,22 @@ DNNModel *ff_dnn_load_model_tf(const char *model_filename) ...@@ -488,36 +487,22 @@ DNNModel *ff_dnn_load_model_tf(const char *model_filename)
DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, DNNData *output) DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, DNNData *output)
{ {
TFModel *tf_model = (TFModel *)model->model; TFModel *tf_model = (TFModel *)model->model;
TF_Tensor *output_tensor; if (tf_model->output_tensor)
uint64_t count; TF_DeleteTensor(tf_model->output_tensor);
uint64_t old_count = output->height * output->width * output->channels * sizeof(float);
TF_SessionRun(tf_model->session, NULL, TF_SessionRun(tf_model->session, NULL,
&tf_model->input, &tf_model->input_tensor, 1, &tf_model->input, &tf_model->input_tensor, 1,
&tf_model->output, &output_tensor, 1, &tf_model->output, &tf_model->output_tensor, 1,
NULL, 0, NULL, tf_model->status); NULL, 0, NULL, tf_model->status);
if (TF_GetCode(tf_model->status) != TF_OK){ if (TF_GetCode(tf_model->status) != TF_OK){
return DNN_ERROR; return DNN_ERROR;
} }
output->height = TF_Dim(output_tensor, 1); output->height = TF_Dim(tf_model->output_tensor, 1);
output->width = TF_Dim(output_tensor, 2); output->width = TF_Dim(tf_model->output_tensor, 2);
output->channels = TF_Dim(output_tensor, 3); output->channels = TF_Dim(tf_model->output_tensor, 3);
count = output->height * output->width * output->channels * sizeof(float); output->data = TF_TensorData(tf_model->output_tensor);
if (output->data) {
if (count > old_count) {
av_freep(&output->data);
}
}
if (!output->data) {
output->data = av_malloc(count);
if (!output->data){
return DNN_ERROR;
}
}
memcpy(output->data, TF_TensorData(output_tensor), count);
TF_DeleteTensor(output_tensor);
return DNN_SUCCESS; return DNN_SUCCESS;
} }
...@@ -541,6 +526,9 @@ void ff_dnn_free_model_tf(DNNModel **model) ...@@ -541,6 +526,9 @@ void ff_dnn_free_model_tf(DNNModel **model)
if (tf_model->input_tensor){ if (tf_model->input_tensor){
TF_DeleteTensor(tf_model->input_tensor); TF_DeleteTensor(tf_model->input_tensor);
} }
if (tf_model->output_tensor){
TF_DeleteTensor(tf_model->output_tensor);
}
av_freep(&tf_model); av_freep(&tf_model);
av_freep(model); av_freep(model);
} }
......
...@@ -274,9 +274,6 @@ static av_cold void uninit(AVFilterContext *context) ...@@ -274,9 +274,6 @@ static av_cold void uninit(AVFilterContext *context)
int i; int i;
SRContext *sr_context = context->priv; SRContext *sr_context = context->priv;
if (sr_context->backend_type == DNN_TF)
av_freep(&sr_context->output.data);
if (sr_context->dnn_module){ if (sr_context->dnn_module){
(sr_context->dnn_module->free_model)(&sr_context->model); (sr_context->dnn_module->free_model)(&sr_context->model);
av_freep(&sr_context->dnn_module); av_freep(&sr_context->dnn_module);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册