未验证 提交 346b64e1 编写于 作者: Y Yan Chunwei 提交者: GitHub

fix mode match check (#134)

上级 19275c3d
...@@ -78,7 +78,8 @@ std::string LogReader::GenReadableTag(const std::string& mode, ...@@ -78,7 +78,8 @@ std::string LogReader::GenReadableTag(const std::string& mode,
bool LogReader::TagMatchMode(const std::string& tag, const std::string& mode) { bool LogReader::TagMatchMode(const std::string& tag, const std::string& mode) {
if (tag.size() <= mode.size()) return false; if (tag.size() <= mode.size()) return false;
return tag.substr(0, mode.size()) == mode; return tag.substr(0, mode.size()) == mode &&
(tag[mode.size()] == '/' || tag[mode.size()] == '%');
} }
namespace components { namespace components {
......
...@@ -4,6 +4,12 @@ from visualdl import core ...@@ -4,6 +4,12 @@ from visualdl import core
dtypes = ("float", "double", "int32", "int64") dtypes = ("float", "double", "int32", "int64")
def check_tag_name_valid(tag):
assert '%' not in tag, "character % is a reserved word, it is not allowed in tag."
def check_mode_name_valid(tag):
for char in ['%', '/']:
assert char not in tag, "character %s is a reserved word, it is not allowed in mode." % char
class LogReader(object): class LogReader(object):
"""LogReader is a Python wrapper to read and analysis the data that """LogReader is a Python wrapper to read and analysis the data that
...@@ -31,6 +37,7 @@ class LogReader(object): ...@@ -31,6 +37,7 @@ class LogReader(object):
generated during testing can be marked test. generated during testing can be marked test.
:return: the reader itself :return: the reader itself
""" """
check_mode_name_valid(mode)
self.reader.set_mode(mode) self.reader.set_mode(mode)
return self return self
...@@ -38,6 +45,7 @@ class LogReader(object): ...@@ -38,6 +45,7 @@ class LogReader(object):
""" """
create a new LogReader with mode and return it to user. create a new LogReader with mode and return it to user.
""" """
check_mode_name_valid(mode)
tmp = LogReader(dir, self.reader.as_mode(mode)) tmp = LogReader(dir, self.reader.as_mode(mode))
return tmp return tmp
...@@ -60,6 +68,7 @@ class LogReader(object): ...@@ -60,6 +68,7 @@ class LogReader(object):
""" """
Get a scalar reader with tag and data type Get a scalar reader with tag and data type
""" """
check_tag_name_valid(tag)
type2scalar = { type2scalar = {
'float': self.reader.get_scalar_float, 'float': self.reader.get_scalar_float,
'double': self.reader.get_scalar_double, 'double': self.reader.get_scalar_double,
...@@ -71,6 +80,7 @@ class LogReader(object): ...@@ -71,6 +80,7 @@ class LogReader(object):
""" """
Get a image reader with tag Get a image reader with tag
""" """
check_tag_name_valid(tag)
return self.reader.get_image(tag) return self.reader.get_image(tag)
def histogram(self, tag, type='float'): def histogram(self, tag, type='float'):
...@@ -82,6 +92,7 @@ class LogReader(object): ...@@ -82,6 +92,7 @@ class LogReader(object):
'double': self.reader.get_histogram_double, 'double': self.reader.get_histogram_double,
'int': self.reader.get_histogram_int, 'int': self.reader.get_histogram_int,
} }
check_tag_name_valid(tag)
return type2scalar[type](tag) return type2scalar[type](tag)
def __enter__(self): def __enter__(self):
...@@ -105,6 +116,7 @@ class LogWriter(object): ...@@ -105,6 +116,7 @@ class LogWriter(object):
self.writer = writer if writer else core.LogWriter(dir, sync_cycle) self.writer = writer if writer else core.LogWriter(dir, sync_cycle)
def mode(self, mode): def mode(self, mode):
check_mode_name_valid(mode)
self.writer.set_mode(mode) self.writer.set_mode(mode)
return self return self
...@@ -112,6 +124,7 @@ class LogWriter(object): ...@@ -112,6 +124,7 @@ class LogWriter(object):
""" """
create a new LogWriter with mode and return it. create a new LogWriter with mode and return it.
""" """
check_mode_name_valid(mode)
LogWriter.cur_mode = LogWriter(self.dir, self.sync_cycle, self.writer.as_mode(mode)) LogWriter.cur_mode = LogWriter(self.dir, self.sync_cycle, self.writer.as_mode(mode))
return LogWriter.cur_mode return LogWriter.cur_mode
...@@ -119,6 +132,7 @@ class LogWriter(object): ...@@ -119,6 +132,7 @@ class LogWriter(object):
""" """
Create a scalar writer with tag and type to write scalar data. Create a scalar writer with tag and type to write scalar data.
""" """
check_tag_name_valid(tag)
type2scalar = { type2scalar = {
'float': self.writer.new_scalar_float, 'float': self.writer.new_scalar_float,
'double': self.writer.new_scalar_double, 'double': self.writer.new_scalar_double,
...@@ -130,6 +144,7 @@ class LogWriter(object): ...@@ -130,6 +144,7 @@ class LogWriter(object):
""" """
Create an image writer that used to write image data. Create an image writer that used to write image data.
""" """
check_tag_name_valid(tag)
return self.writer.new_image(tag, num_samples, step_cycle) return self.writer.new_image(tag, num_samples, step_cycle)
def histogram(self, tag, num_buckets, type='float'): def histogram(self, tag, num_buckets, type='float'):
...@@ -137,6 +152,7 @@ class LogWriter(object): ...@@ -137,6 +152,7 @@ class LogWriter(object):
Create a histogram writer that used to write Create a histogram writer that used to write
histogram related data. histogram related data.
""" """
check_tag_name_valid(tag)
types = { types = {
'float': self.writer.new_histogram_float, 'float': self.writer.new_histogram_float,
'double': self.writer.new_histogram_double, 'double': self.writer.new_histogram_double,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册