提交 b8cefff4 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!337 add null byte check in the api of get_plugins

Merge pull request !337 from wangshuide/wsd_fuzz_fix
...@@ -107,6 +107,15 @@ class TrainJobNotExistError(MindInsightException): ...@@ -107,6 +107,15 @@ class TrainJobNotExistError(MindInsightException):
http_code=400) http_code=400)
class QueryStringContainsNullByteError(MindInsightException):
"""Query string contains null byte error."""
def __init__(self, error_detail):
error_msg = f"Query string contains null byte error. Detail: {error_detail}"
super(QueryStringContainsNullByteError, self).__init__(DataVisualErrors.QUERY_STRING_CONTAINS_NULL_BYTE,
error_msg,
http_code=400)
class PluginNotAvailableError(MindInsightException): class PluginNotAvailableError(MindInsightException):
"""The given plugin is not available.""" """The given plugin is not available."""
def __init__(self, error_detail): def __init__(self, error_detail):
......
...@@ -22,6 +22,7 @@ from pathlib import Path ...@@ -22,6 +22,7 @@ from pathlib import Path
from mindinsight.datavisual.common.log import logger from mindinsight.datavisual.common.log import logger
from mindinsight.datavisual.common.validation import Validation from mindinsight.datavisual.common.validation import Validation
from mindinsight.datavisual.utils.tools import Counter from mindinsight.datavisual.utils.tools import Counter
from mindinsight.datavisual.utils.utils import contains_null_byte
from mindinsight.datavisual.common.exceptions import MaxCountExceededError from mindinsight.datavisual.common.exceptions import MaxCountExceededError
from mindinsight.utils.exceptions import FileSystemPermissionError from mindinsight.utils.exceptions import FileSystemPermissionError
...@@ -61,7 +62,7 @@ class SummaryWatcher: ...@@ -61,7 +62,7 @@ class SummaryWatcher:
>>> summary_watcher = SummaryWatcher() >>> summary_watcher = SummaryWatcher()
>>> directories = summary_watcher.list_summary_directories('/summary/base/dir') >>> directories = summary_watcher.list_summary_directories('/summary/base/dir')
""" """
if self._contains_null_byte(summary_base_dir=summary_base_dir): if contains_null_byte(summary_base_dir=summary_base_dir):
return [] return []
relative_path = os.path.join('.', '') relative_path = os.path.join('.', '')
...@@ -148,25 +149,6 @@ class SummaryWatcher: ...@@ -148,25 +149,6 @@ class SummaryWatcher:
pass pass
self._update_summary_dict(summary_dict, summary_base_dir, subdir_relative_path, subdir_entry) self._update_summary_dict(summary_dict, summary_base_dir, subdir_relative_path, subdir_entry)
def _contains_null_byte(self, **kwargs):
"""
Check if arg contains null byte.
Args:
kwargs (Any): Check if arg contains null byte.
Returns:
bool, indicates if any arg contains null byte.
"""
for key, value in kwargs.items():
if not isinstance(value, str):
continue
if '\x00' in value:
logger.warning('%s contains null byte \\x00.', key)
return True
return False
def _is_valid_summary_directory(self, summary_base_dir, relative_path): def _is_valid_summary_directory(self, summary_base_dir, relative_path):
""" """
Check if the given summary directory is valid. Check if the given summary directory is valid.
...@@ -276,7 +258,7 @@ class SummaryWatcher: ...@@ -276,7 +258,7 @@ class SummaryWatcher:
>>> summary_watcher = SummaryWatcher() >>> summary_watcher = SummaryWatcher()
>>> summaries = summary_watcher.is_summary_directory('/summary/base/dir', './job-01') >>> summaries = summary_watcher.is_summary_directory('/summary/base/dir', './job-01')
""" """
if self._contains_null_byte(summary_base_dir=summary_base_dir, relative_path=relative_path): if contains_null_byte(summary_base_dir=summary_base_dir, relative_path=relative_path):
return False return False
if not self._is_valid_summary_directory(summary_base_dir, relative_path): if not self._is_valid_summary_directory(summary_base_dir, relative_path):
...@@ -371,7 +353,7 @@ class SummaryWatcher: ...@@ -371,7 +353,7 @@ class SummaryWatcher:
>>> summary_watcher = SummaryWatcher() >>> summary_watcher = SummaryWatcher()
>>> summaries = summary_watcher.list_summaries('/summary/base/dir', './job-01') >>> summaries = summary_watcher.list_summaries('/summary/base/dir', './job-01')
""" """
if self._contains_null_byte(summary_base_dir=summary_base_dir, relative_path=relative_path): if contains_null_byte(summary_base_dir=summary_base_dir, relative_path=relative_path):
return [] return []
if not self._is_valid_summary_directory(summary_base_dir, relative_path): if not self._is_valid_summary_directory(summary_base_dir, relative_path):
......
...@@ -19,7 +19,9 @@ from mindinsight.datavisual.common.log import logger ...@@ -19,7 +19,9 @@ from mindinsight.datavisual.common.log import logger
from mindinsight.datavisual.common import exceptions from mindinsight.datavisual.common import exceptions
from mindinsight.datavisual.common.enums import PluginNameEnum from mindinsight.datavisual.common.enums import PluginNameEnum
from mindinsight.datavisual.common.enums import CacheStatus from mindinsight.datavisual.common.enums import CacheStatus
from mindinsight.datavisual.common.exceptions import QueryStringContainsNullByteError
from mindinsight.datavisual.common.validation import Validation from mindinsight.datavisual.common.validation import Validation
from mindinsight.datavisual.utils.utils import contains_null_byte
from mindinsight.datavisual.processors.base_processor import BaseProcessor from mindinsight.datavisual.processors.base_processor import BaseProcessor
from mindinsight.datavisual.data_transform.data_manager import DATAVISUAL_PLUGIN_KEY, DATAVISUAL_CACHE_KEY from mindinsight.datavisual.data_transform.data_manager import DATAVISUAL_PLUGIN_KEY, DATAVISUAL_CACHE_KEY
...@@ -57,6 +59,8 @@ class TrainTaskManager(BaseProcessor): ...@@ -57,6 +59,8 @@ class TrainTaskManager(BaseProcessor):
dict, refer to restful api. dict, refer to restful api.
""" """
Validation.check_param_empty(train_id=train_id) Validation.check_param_empty(train_id=train_id)
if contains_null_byte(train_id=train_id):
raise QueryStringContainsNullByteError("train job id: {} contains null byte.".format(train_id))
if manual_update: if manual_update:
self._data_manager.cache_train_job(train_id) self._data_manager.cache_train_job(train_id)
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# ============================================================================ # ============================================================================
"""Utils.""" """Utils."""
import math import math
from mindinsight.datavisual.common.log import logger
def calc_histogram_bins(count): def calc_histogram_bins(count):
...@@ -45,3 +46,23 @@ def calc_histogram_bins(count): ...@@ -45,3 +46,23 @@ def calc_histogram_bins(count):
return math.ceil(count / number_per_bucket) + 1 return math.ceil(count / number_per_bucket) + 1
return max_bins return max_bins
def contains_null_byte(**kwargs):
"""
Check if arg contains null byte.
Args:
kwargs (Any): Check if arg contains null byte.
Returns:
bool, indicates if any arg contains null byte.
"""
for key, value in kwargs.items():
if not isinstance(value, str):
continue
if '\x00' in value:
logger.warning('%s contains null byte \\x00.', key)
return True
return False
...@@ -70,6 +70,7 @@ class DataVisualErrors(Enum): ...@@ -70,6 +70,7 @@ class DataVisualErrors(Enum):
SCALAR_NOT_EXIST = 14 SCALAR_NOT_EXIST = 14
HISTOGRAM_NOT_EXIST = 15 HISTOGRAM_NOT_EXIST = 15
TRAIN_JOB_DETAIL_NOT_IN_CACHE = 16 TRAIN_JOB_DETAIL_NOT_IN_CACHE = 16
QUERY_STRING_CONTAINS_NULL_BYTE = 17
class ScriptConverterErrors(Enum): class ScriptConverterErrors(Enum):
......
...@@ -79,9 +79,9 @@ class TestPlugins: ...@@ -79,9 +79,9 @@ class TestPlugins:
@pytest.mark.platform_x86_gpu_training @pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training @pytest.mark.platform_x86_ascend_training
@pytest.mark.usefixtures("init_summary_logs") @pytest.mark.usefixtures("init_summary_logs")
@pytest.mark.parametrize("train_id", ["@#$", "./\x00home", "././/not_exist_id", dict()]) @pytest.mark.parametrize("train_id", ["@#$", "././/not_exist_id", dict()])
def test_plugins_with_special_train_id(self, client, train_id): def test_plugins_with_special_train_id(self, client, train_id):
"""Test passing train_id with special character, null_byte, invalid id, and wrong type.""" """Test passing train_id with special character, invalid id, and wrong type."""
params = dict(train_id=train_id) params = dict(train_id=train_id)
url = get_url(BASE_URL, params) url = get_url(BASE_URL, params)
...@@ -92,6 +92,26 @@ class TestPlugins: ...@@ -92,6 +92,26 @@ class TestPlugins:
assert response['error_code'] == '50545005' assert response['error_code'] == '50545005'
assert response['error_msg'] == "Train job is not exist. Detail: Can not find the train job in data manager." assert response['error_msg'] == "Train job is not exist. Detail: Can not find the train job in data manager."
@pytest.mark.level1
@pytest.mark.env_single
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.usefixtures("init_summary_logs")
@pytest.mark.parametrize("train_id", ["./\x00home"])
def test_plugins_with_null_byte_train_id(self, client, train_id):
"""Test passing train_id with null_byte."""
params = dict(train_id=train_id, manual_update=True)
url = get_url(BASE_URL, params)
response = client.get(url)
assert response.status_code == 400
response = response.get_json()
assert response['error_code'] == '50545011'
assert "Query string contains null byte error. " in response['error_msg']
@pytest.mark.level1 @pytest.mark.level1
@pytest.mark.env_single @pytest.mark.env_single
@pytest.mark.platform_x86_cpu @pytest.mark.platform_x86_cpu
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册