提交 b8cefff4 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!337 add null byte check in the api of get_plugins

Merge pull request !337 from wangshuide/wsd_fuzz_fix
......@@ -107,6 +107,15 @@ class TrainJobNotExistError(MindInsightException):
http_code=400)
class QueryStringContainsNullByteError(MindInsightException):
"""Query string contains null byte error."""
def __init__(self, error_detail):
error_msg = f"Query string contains null byte error. Detail: {error_detail}"
super(QueryStringContainsNullByteError, self).__init__(DataVisualErrors.QUERY_STRING_CONTAINS_NULL_BYTE,
error_msg,
http_code=400)
class PluginNotAvailableError(MindInsightException):
"""The given plugin is not available."""
def __init__(self, error_detail):
......
......@@ -22,6 +22,7 @@ from pathlib import Path
from mindinsight.datavisual.common.log import logger
from mindinsight.datavisual.common.validation import Validation
from mindinsight.datavisual.utils.tools import Counter
from mindinsight.datavisual.utils.utils import contains_null_byte
from mindinsight.datavisual.common.exceptions import MaxCountExceededError
from mindinsight.utils.exceptions import FileSystemPermissionError
......@@ -61,7 +62,7 @@ class SummaryWatcher:
>>> summary_watcher = SummaryWatcher()
>>> directories = summary_watcher.list_summary_directories('/summary/base/dir')
"""
if self._contains_null_byte(summary_base_dir=summary_base_dir):
if contains_null_byte(summary_base_dir=summary_base_dir):
return []
relative_path = os.path.join('.', '')
......@@ -148,25 +149,6 @@ class SummaryWatcher:
pass
self._update_summary_dict(summary_dict, summary_base_dir, subdir_relative_path, subdir_entry)
def _contains_null_byte(self, **kwargs):
"""
Check if arg contains null byte.
Args:
kwargs (Any): Check if arg contains null byte.
Returns:
bool, indicates if any arg contains null byte.
"""
for key, value in kwargs.items():
if not isinstance(value, str):
continue
if '\x00' in value:
logger.warning('%s contains null byte \\x00.', key)
return True
return False
def _is_valid_summary_directory(self, summary_base_dir, relative_path):
"""
Check if the given summary directory is valid.
......@@ -276,7 +258,7 @@ class SummaryWatcher:
>>> summary_watcher = SummaryWatcher()
>>> summaries = summary_watcher.is_summary_directory('/summary/base/dir', './job-01')
"""
if self._contains_null_byte(summary_base_dir=summary_base_dir, relative_path=relative_path):
if contains_null_byte(summary_base_dir=summary_base_dir, relative_path=relative_path):
return False
if not self._is_valid_summary_directory(summary_base_dir, relative_path):
......@@ -371,7 +353,7 @@ class SummaryWatcher:
>>> summary_watcher = SummaryWatcher()
>>> summaries = summary_watcher.list_summaries('/summary/base/dir', './job-01')
"""
if self._contains_null_byte(summary_base_dir=summary_base_dir, relative_path=relative_path):
if contains_null_byte(summary_base_dir=summary_base_dir, relative_path=relative_path):
return []
if not self._is_valid_summary_directory(summary_base_dir, relative_path):
......
......@@ -19,7 +19,9 @@ from mindinsight.datavisual.common.log import logger
from mindinsight.datavisual.common import exceptions
from mindinsight.datavisual.common.enums import PluginNameEnum
from mindinsight.datavisual.common.enums import CacheStatus
from mindinsight.datavisual.common.exceptions import QueryStringContainsNullByteError
from mindinsight.datavisual.common.validation import Validation
from mindinsight.datavisual.utils.utils import contains_null_byte
from mindinsight.datavisual.processors.base_processor import BaseProcessor
from mindinsight.datavisual.data_transform.data_manager import DATAVISUAL_PLUGIN_KEY, DATAVISUAL_CACHE_KEY
......@@ -57,6 +59,8 @@ class TrainTaskManager(BaseProcessor):
dict, refer to restful api.
"""
Validation.check_param_empty(train_id=train_id)
if contains_null_byte(train_id=train_id):
raise QueryStringContainsNullByteError("train job id: {} contains null byte.".format(train_id))
if manual_update:
self._data_manager.cache_train_job(train_id)
......
......@@ -14,6 +14,7 @@
# ============================================================================
"""Utils."""
import math
from mindinsight.datavisual.common.log import logger
def calc_histogram_bins(count):
......@@ -45,3 +46,23 @@ def calc_histogram_bins(count):
return math.ceil(count / number_per_bucket) + 1
return max_bins
def contains_null_byte(**kwargs):
"""
Check if arg contains null byte.
Args:
kwargs (Any): Check if arg contains null byte.
Returns:
bool, indicates if any arg contains null byte.
"""
for key, value in kwargs.items():
if not isinstance(value, str):
continue
if '\x00' in value:
logger.warning('%s contains null byte \\x00.', key)
return True
return False
......@@ -70,6 +70,7 @@ class DataVisualErrors(Enum):
SCALAR_NOT_EXIST = 14
HISTOGRAM_NOT_EXIST = 15
TRAIN_JOB_DETAIL_NOT_IN_CACHE = 16
QUERY_STRING_CONTAINS_NULL_BYTE = 17
class ScriptConverterErrors(Enum):
......
......@@ -79,9 +79,9 @@ class TestPlugins:
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.usefixtures("init_summary_logs")
@pytest.mark.parametrize("train_id", ["@#$", "./\x00home", "././/not_exist_id", dict()])
@pytest.mark.parametrize("train_id", ["@#$", "././/not_exist_id", dict()])
def test_plugins_with_special_train_id(self, client, train_id):
"""Test passing train_id with special character, null_byte, invalid id, and wrong type."""
"""Test passing train_id with special character, invalid id, and wrong type."""
params = dict(train_id=train_id)
url = get_url(BASE_URL, params)
......@@ -92,6 +92,26 @@ class TestPlugins:
assert response['error_code'] == '50545005'
assert response['error_msg'] == "Train job is not exist. Detail: Can not find the train job in data manager."
@pytest.mark.level1
@pytest.mark.env_single
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.usefixtures("init_summary_logs")
@pytest.mark.parametrize("train_id", ["./\x00home"])
def test_plugins_with_null_byte_train_id(self, client, train_id):
"""Test passing train_id with null_byte."""
params = dict(train_id=train_id, manual_update=True)
url = get_url(BASE_URL, params)
response = client.get(url)
assert response.status_code == 400
response = response.get_json()
assert response['error_code'] == '50545011'
assert "Query string contains null byte error. " in response['error_msg']
@pytest.mark.level1
@pytest.mark.env_single
@pytest.mark.platform_x86_cpu
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册