提交 60a1408d 编写于 作者: 文幕地方's avatar 文幕地方

add eps

上级 7c04ff55
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
class ClsMetric(object): class ClsMetric(object):
def __init__(self, main_indicator='acc', **kwargs): def __init__(self, main_indicator='acc', **kwargs):
self.main_indicator = main_indicator self.main_indicator = main_indicator
self.eps = 1e-5
self.reset() self.reset()
def __call__(self, pred_label, *args, **kwargs): def __call__(self, pred_label, *args, **kwargs):
...@@ -28,7 +29,7 @@ class ClsMetric(object): ...@@ -28,7 +29,7 @@ class ClsMetric(object):
all_num += 1 all_num += 1
self.correct_num += correct_num self.correct_num += correct_num
self.all_num += all_num self.all_num += all_num
return {'acc': correct_num / all_num, } return {'acc': correct_num / (all_num + self.eps), }
def get_metric(self): def get_metric(self):
""" """
...@@ -36,7 +37,7 @@ class ClsMetric(object): ...@@ -36,7 +37,7 @@ class ClsMetric(object):
'acc': 0 'acc': 0
} }
""" """
acc = self.correct_num / self.all_num acc = self.correct_num / (self.all_num + self.eps)
self.reset() self.reset()
return {'acc': acc} return {'acc': acc}
......
...@@ -20,6 +20,7 @@ class RecMetric(object): ...@@ -20,6 +20,7 @@ class RecMetric(object):
def __init__(self, main_indicator='acc', is_filter=False, **kwargs): def __init__(self, main_indicator='acc', is_filter=False, **kwargs):
self.main_indicator = main_indicator self.main_indicator = main_indicator
self.is_filter = is_filter self.is_filter = is_filter
self.eps = 1e-5
self.reset() self.reset()
def _normalize_text(self, text): def _normalize_text(self, text):
...@@ -47,8 +48,8 @@ class RecMetric(object): ...@@ -47,8 +48,8 @@ class RecMetric(object):
self.all_num += all_num self.all_num += all_num
self.norm_edit_dis += norm_edit_dis self.norm_edit_dis += norm_edit_dis
return { return {
'acc': correct_num / all_num, 'acc': correct_num / (all_num + self.eps),
'norm_edit_dis': 1 - norm_edit_dis / (all_num + 1e-3) 'norm_edit_dis': 1 - norm_edit_dis / (all_num + self.eps)
} }
def get_metric(self): def get_metric(self):
...@@ -58,8 +59,8 @@ class RecMetric(object): ...@@ -58,8 +59,8 @@ class RecMetric(object):
'norm_edit_dis': 0, 'norm_edit_dis': 0,
} }
""" """
acc = 1.0 * self.correct_num / (self.all_num + 1e-3) acc = 1.0 * self.correct_num / (self.all_num + self.eps)
norm_edit_dis = 1 - self.norm_edit_dis / (self.all_num + 1e-3) norm_edit_dis = 1 - self.norm_edit_dis / (self.all_num + self.eps)
self.reset() self.reset()
return {'acc': acc, 'norm_edit_dis': norm_edit_dis} return {'acc': acc, 'norm_edit_dis': norm_edit_dis}
......
...@@ -12,9 +12,12 @@ ...@@ -12,9 +12,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import numpy as np import numpy as np
class TableMetric(object): class TableMetric(object):
def __init__(self, main_indicator='acc', **kwargs): def __init__(self, main_indicator='acc', **kwargs):
self.main_indicator = main_indicator self.main_indicator = main_indicator
self.eps = 1e-5
self.reset() self.reset()
def __call__(self, pred, batch, *args, **kwargs): def __call__(self, pred, batch, *args, **kwargs):
...@@ -31,9 +34,7 @@ class TableMetric(object): ...@@ -31,9 +34,7 @@ class TableMetric(object):
correct_num += 1 correct_num += 1
self.correct_num += correct_num self.correct_num += correct_num
self.all_num += all_num self.all_num += all_num
return { return {'acc': correct_num * 1.0 / (all_num + self.eps), }
'acc': correct_num * 1.0 / all_num,
}
def get_metric(self): def get_metric(self):
""" """
...@@ -41,7 +42,7 @@ class TableMetric(object): ...@@ -41,7 +42,7 @@ class TableMetric(object):
'acc': 0, 'acc': 0,
} }
""" """
acc = 1.0 * self.correct_num / self.all_num acc = 1.0 * self.correct_num / (self.all_num + self.eps)
self.reset() self.reset()
return {'acc': acc} return {'acc': acc}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册