未验证 提交 346b64e1 编写于 作者: Y Yan Chunwei 提交者: GitHub

fix mode match check (#134)

上级 19275c3d
......@@ -78,7 +78,8 @@ std::string LogReader::GenReadableTag(const std::string& mode,
bool LogReader::TagMatchMode(const std::string& tag, const std::string& mode) {
if (tag.size() <= mode.size()) return false;
return tag.substr(0, mode.size()) == mode;
return tag.substr(0, mode.size()) == mode &&
(tag[mode.size()] == '/' || tag[mode.size()] == '%');
}
namespace components {
......
......@@ -4,6 +4,12 @@ from visualdl import core
dtypes = ("float", "double", "int32", "int64")
def check_tag_name_valid(tag):
assert '%' not in tag, "character % is a reserved word, it is not allowed in tag."
def check_mode_name_valid(tag):
for char in ['%', '/']:
assert char not in tag, "character %s is a reserved word, it is not allowed in mode." % char
class LogReader(object):
"""LogReader is a Python wrapper to read and analysis the data that
......@@ -31,6 +37,7 @@ class LogReader(object):
generated during testing can be marked test.
:return: the reader itself
"""
check_mode_name_valid(mode)
self.reader.set_mode(mode)
return self
......@@ -38,6 +45,7 @@ class LogReader(object):
"""
create a new LogReader with mode and return it to user.
"""
check_mode_name_valid(mode)
tmp = LogReader(dir, self.reader.as_mode(mode))
return tmp
......@@ -60,6 +68,7 @@ class LogReader(object):
"""
Get a scalar reader with tag and data type
"""
check_tag_name_valid(tag)
type2scalar = {
'float': self.reader.get_scalar_float,
'double': self.reader.get_scalar_double,
......@@ -71,6 +80,7 @@ class LogReader(object):
"""
Get a image reader with tag
"""
check_tag_name_valid(tag)
return self.reader.get_image(tag)
def histogram(self, tag, type='float'):
......@@ -82,6 +92,7 @@ class LogReader(object):
'double': self.reader.get_histogram_double,
'int': self.reader.get_histogram_int,
}
check_tag_name_valid(tag)
return type2scalar[type](tag)
def __enter__(self):
......@@ -105,6 +116,7 @@ class LogWriter(object):
self.writer = writer if writer else core.LogWriter(dir, sync_cycle)
def mode(self, mode):
check_mode_name_valid(mode)
self.writer.set_mode(mode)
return self
......@@ -112,6 +124,7 @@ class LogWriter(object):
"""
create a new LogWriter with mode and return it.
"""
check_mode_name_valid(mode)
LogWriter.cur_mode = LogWriter(self.dir, self.sync_cycle, self.writer.as_mode(mode))
return LogWriter.cur_mode
......@@ -119,6 +132,7 @@ class LogWriter(object):
"""
Create a scalar writer with tag and type to write scalar data.
"""
check_tag_name_valid(tag)
type2scalar = {
'float': self.writer.new_scalar_float,
'double': self.writer.new_scalar_double,
......@@ -130,6 +144,7 @@ class LogWriter(object):
"""
Create an image writer that used to write image data.
"""
check_tag_name_valid(tag)
return self.writer.new_image(tag, num_samples, step_cycle)
def histogram(self, tag, num_buckets, type='float'):
......@@ -137,6 +152,7 @@ class LogWriter(object):
Create a histogram writer that used to write
histogram related data.
"""
check_tag_name_valid(tag)
types = {
'float': self.writer.new_histogram_float,
'double': self.writer.new_histogram_double,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册