From 0fad2218fdd7ec814b53c242ff2150f70ad556fc Mon Sep 17 00:00:00 2001
From: wangshuide2020 <7511764+wangshuide2020@user.noreply.gitee.com>
Date: Mon, 22 Jun 2020 14:27:39 +0800
Subject: [PATCH] add null byte check in the api of get_plugins

---
 mindinsight/datavisual/common/exceptions.py   |  9 +++++++
 .../data_transform/summary_watcher.py         | 26 +++----------------
 .../processors/train_task_manager.py          |  4 +++
 mindinsight/datavisual/utils/utils.py         | 21 +++++++++++++++
 mindinsight/utils/constant.py                 |  1 +
 .../taskmanager/test_plugins_restful_api.py   | 24 +++++++++++++++--
 6 files changed, 61 insertions(+), 24 deletions(-)

diff --git a/mindinsight/datavisual/common/exceptions.py b/mindinsight/datavisual/common/exceptions.py
index e6bb667..35e0f92 100644
--- a/mindinsight/datavisual/common/exceptions.py
+++ b/mindinsight/datavisual/common/exceptions.py
@@ -107,6 +107,15 @@ class TrainJobNotExistError(MindInsightException):
                                                     http_code=400)
 
 
+class QueryStringContainsNullByteError(MindInsightException):
+    """Query string contains null byte error."""
+    def __init__(self, error_detail):
+        error_msg = f"Query string contains null byte error. Detail: {error_detail}"
+        super(QueryStringContainsNullByteError, self).__init__(DataVisualErrors.QUERY_STRING_CONTAINS_NULL_BYTE,
+                                                               error_msg,
+                                                               http_code=400)
+
+
 class PluginNotAvailableError(MindInsightException):
     """The given plugin is not available."""
     def __init__(self, error_detail):
diff --git a/mindinsight/datavisual/data_transform/summary_watcher.py b/mindinsight/datavisual/data_transform/summary_watcher.py
index e87d537..14f49ca 100644
--- a/mindinsight/datavisual/data_transform/summary_watcher.py
+++ b/mindinsight/datavisual/data_transform/summary_watcher.py
@@ -22,6 +22,7 @@ from pathlib import Path
 from mindinsight.datavisual.common.log import logger
 from mindinsight.datavisual.common.validation import Validation
 from mindinsight.datavisual.utils.tools import Counter
+from mindinsight.datavisual.utils.utils import contains_null_byte
 from mindinsight.datavisual.common.exceptions import MaxCountExceededError
 from mindinsight.utils.exceptions import FileSystemPermissionError
 
@@ -61,7 +62,7 @@ class SummaryWatcher:
             >>> summary_watcher = SummaryWatcher()
             >>> directories = summary_watcher.list_summary_directories('/summary/base/dir')
         """
-        if self._contains_null_byte(summary_base_dir=summary_base_dir):
+        if contains_null_byte(summary_base_dir=summary_base_dir):
             return []
 
         relative_path = os.path.join('.', '')
@@ -148,25 +149,6 @@ class SummaryWatcher:
                 pass
             self._update_summary_dict(summary_dict, summary_base_dir, subdir_relative_path, subdir_entry)
 
-    def _contains_null_byte(self, **kwargs):
-        """
-        Check if arg contains null byte.
-
-        Args:
-            kwargs (Any): Check if arg contains null byte.
-
-        Returns:
-            bool, indicates if any arg contains null byte.
-        """
-        for key, value in kwargs.items():
-            if not isinstance(value, str):
-                continue
-            if '\x00' in value:
-                logger.warning('%s contains null byte \\x00.', key)
-                return True
-
-        return False
-
     def _is_valid_summary_directory(self, summary_base_dir, relative_path):
         """
         Check if the given summary directory is valid.
@@ -276,7 +258,7 @@ class SummaryWatcher:
             >>> summary_watcher = SummaryWatcher()
             >>> summaries = summary_watcher.is_summary_directory('/summary/base/dir', './job-01')
         """
-        if self._contains_null_byte(summary_base_dir=summary_base_dir, relative_path=relative_path):
+        if contains_null_byte(summary_base_dir=summary_base_dir, relative_path=relative_path):
             return False
 
         if not self._is_valid_summary_directory(summary_base_dir, relative_path):
@@ -371,7 +353,7 @@ class SummaryWatcher:
             >>> summary_watcher = SummaryWatcher()
             >>> summaries = summary_watcher.list_summaries('/summary/base/dir', './job-01')
         """
-        if self._contains_null_byte(summary_base_dir=summary_base_dir, relative_path=relative_path):
+        if contains_null_byte(summary_base_dir=summary_base_dir, relative_path=relative_path):
             return []
 
         if not self._is_valid_summary_directory(summary_base_dir, relative_path):
diff --git a/mindinsight/datavisual/processors/train_task_manager.py b/mindinsight/datavisual/processors/train_task_manager.py
index b1aeb87..540ea39 100644
--- a/mindinsight/datavisual/processors/train_task_manager.py
+++ b/mindinsight/datavisual/processors/train_task_manager.py
@@ -19,7 +19,9 @@ from mindinsight.datavisual.common.log import logger
 from mindinsight.datavisual.common import exceptions
 from mindinsight.datavisual.common.enums import PluginNameEnum
 from mindinsight.datavisual.common.enums import CacheStatus
+from mindinsight.datavisual.common.exceptions import QueryStringContainsNullByteError
 from mindinsight.datavisual.common.validation import Validation
+from mindinsight.datavisual.utils.utils import contains_null_byte
 from mindinsight.datavisual.processors.base_processor import BaseProcessor
 from mindinsight.datavisual.data_transform.data_manager import DATAVISUAL_PLUGIN_KEY, DATAVISUAL_CACHE_KEY
 
@@ -57,6 +59,8 @@ class TrainTaskManager(BaseProcessor):
             dict, refer to restful api.
         """
         Validation.check_param_empty(train_id=train_id)
+        if contains_null_byte(train_id=train_id):
+            raise QueryStringContainsNullByteError("train job id: {} contains null byte.".format(train_id))
 
         if manual_update:
             self._data_manager.cache_train_job(train_id)
diff --git a/mindinsight/datavisual/utils/utils.py b/mindinsight/datavisual/utils/utils.py
index 2dac02b..6ce0488 100644
--- a/mindinsight/datavisual/utils/utils.py
+++ b/mindinsight/datavisual/utils/utils.py
@@ -14,6 +14,7 @@
 # ============================================================================
 """Utils."""
 import math
+from mindinsight.datavisual.common.log import logger
 
 
 def calc_histogram_bins(count):
@@ -45,3 +46,23 @@ def calc_histogram_bins(count):
         return math.ceil(count / number_per_bucket) + 1
 
     return max_bins
+
+
+def contains_null_byte(**kwargs):
+    """
+    Check if arg contains null byte.
+
+    Args:
+        kwargs (Any): Check if arg contains null byte.
+
+    Returns:
+        bool, indicates if any arg contains null byte.
+    """
+    for key, value in kwargs.items():
+        if not isinstance(value, str):
+            continue
+        if '\x00' in value:
+            logger.warning('%s contains null byte \\x00.', key)
+            return True
+
+    return False
diff --git a/mindinsight/utils/constant.py b/mindinsight/utils/constant.py
index 87f39d6..0bac03c 100644
--- a/mindinsight/utils/constant.py
+++ b/mindinsight/utils/constant.py
@@ -70,6 +70,7 @@ class DataVisualErrors(Enum):
     SCALAR_NOT_EXIST = 14
     HISTOGRAM_NOT_EXIST = 15
     TRAIN_JOB_DETAIL_NOT_IN_CACHE = 16
+    QUERY_STRING_CONTAINS_NULL_BYTE = 17
 
 
 class ScriptConverterErrors(Enum):
diff --git a/tests/st/func/datavisual/taskmanager/test_plugins_restful_api.py b/tests/st/func/datavisual/taskmanager/test_plugins_restful_api.py
index ccd2a9a..e028216 100644
--- a/tests/st/func/datavisual/taskmanager/test_plugins_restful_api.py
+++ b/tests/st/func/datavisual/taskmanager/test_plugins_restful_api.py
@@ -79,9 +79,9 @@ class TestPlugins:
     @pytest.mark.platform_x86_gpu_training
     @pytest.mark.platform_x86_ascend_training
     @pytest.mark.usefixtures("init_summary_logs")
-    @pytest.mark.parametrize("train_id", ["@#$", "./\x00home", "././/not_exist_id", dict()])
+    @pytest.mark.parametrize("train_id", ["@#$", "././/not_exist_id", dict()])
     def test_plugins_with_special_train_id(self, client, train_id):
-        """Test passing train_id with special character, null_byte, invalid id, and wrong type."""
+        """Test passing train_id with special character, invalid id, and wrong type."""
         params = dict(train_id=train_id)
         url = get_url(BASE_URL, params)
 
@@ -92,6 +92,26 @@ class TestPlugins:
         assert response['error_code'] == '50545005'
         assert response['error_msg'] == "Train job is not exist. Detail: Can not find the train job in data manager."
 
+    @pytest.mark.level1
+    @pytest.mark.env_single
+    @pytest.mark.platform_x86_cpu
+    @pytest.mark.platform_arm_ascend_training
+    @pytest.mark.platform_x86_gpu_training
+    @pytest.mark.platform_x86_ascend_training
+    @pytest.mark.usefixtures("init_summary_logs")
+    @pytest.mark.parametrize("train_id", ["./\x00home"])
+    def test_plugins_with_null_byte_train_id(self, client, train_id):
+        """Test passing train_id with null_byte."""
+        params = dict(train_id=train_id, manual_update=True)
+        url = get_url(BASE_URL, params)
+
+        response = client.get(url)
+        assert response.status_code == 400
+
+        response = response.get_json()
+        assert response['error_code'] == '50545011'
+        assert "Query string contains null byte error. " in response['error_msg']
+
     @pytest.mark.level1
     @pytest.mark.env_single
     @pytest.mark.platform_x86_cpu
-- 
GitLab