提交 17966d76 编写于 作者: S shippingwang

refine code

上级 cd1a1963
......@@ -432,10 +432,49 @@ class TallMetrics(Metrics):
self.calculator = tall_metrics.MetricsCalculator(cfg=cfg, name=self.name, mode=self.mode)
def calculator_and_log_out(self, fetch_list, info=""):
loss = np.array(fetch_list[0])
logger.info(info +'\tLoss = {}'.format('%.6f' % np.mean(loss)))
if self.mode == "train":
loss = np.array(fetch_list[0])
logger.info(info +'\tLoss = {}'.format('%.6f' % np.mean(loss)))
elif self.mode == "valid":
outs = fetch_list[0]
outputs = np.squeeze(outs)
start = fetch_list[1]
end = fetch_list[2]
k = fetch_list[3]
t = fetch_list[4]
movie_clip_sentences = fetch_list[5]
movie_clip_featmaps = fetch_lkist[6]
sentence_image_mat = np.zeros([len(movie_clip_sentences), len(movie_clip_featmaps)])
sentence_image_reg_mat = np.zeros([len(movie_clip_sentences), len(movie_clip_featmaps ), 2])
sentence_image_mat[k, t] = outputs[0]
# sentence_image_mat[k, t] = expit(outputs[0]) * conf_score
reg_end = end + outputs[2]
reg_start = start + outputs[1]
sentence_image_reg_mat[k, t, 0] = reg_start
sentence_image_reg_mat[k, t, 1] = reg_end
clips = [b[0] for b in movie_clip_featmaps]
sclips = [b[0] for b in movie_clip_sentences]
for i in range(len(IoU_thresh)):
IoU = IoU_thresh[i]
correct_num_10 = compute_IoU_recall_top_n_forreg(10, IoU, sentence_image_mat, sentence_image_reg_mat, sclips, iclips)
correct_num_5 = compute_IoU_recall_top_n_forreg(5, IoU, sentence_image_mat, sentence_image_reg_mat, sclips, iclips)
correct_num_1 = compute_IoU_recall_top_n_forreg(1, IoU, sentence_image_mat, sentence_image_reg_mat, sclips, iclips)
logger.info(info + " IoU=" + str(IoU) + ", R@10: " + str(correct_num_10 / len(sclips)) + "; IoU=" + str(IoU) + ", R@5: " + str(correct_num_5 / len(sclips)) + "; IoU=" + str(IoU) + ", R@1: " + str(correct_num_1 / len(sclips)))
all_correct_num_10[i] += correct_num_10
all_correct_num_5[i] += correct_num_5
all_correct_num_1[i] += correct_num_1
all_retrievd += len(sclips)
else:
pass
def accumalate()
def accumalate():
def finalize_and_log_out(self, info="", savedir="/"):
......
......@@ -12,11 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import unicode_literals
from __future__ import print_function
from __future__ import division
import numpy as np
from six.moves import xrange
import time
import pickle
import operator
import datetime
import logging
class MetricsCalculator():
......@@ -27,7 +31,7 @@ class MetricsCalculator():
def reset(self):
logger.info("Resetting {} metrics...".format(self.mode))
return
def finalize_metrics(self):
return
def calculate_metrics(self,):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册