提交 4bdf80be 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!158 Validate train_id and profile_dir

Merge pull request !158 from yuximiao/master
......@@ -49,10 +49,12 @@ def get_profile_op_info():
ParamValueError: If the search condition contains some errors.
Examples:
>>> POST http://xxxx/v1/mindinsight/profile/op
>>> POST http://xxxx/v1/mindinsight/profile/ops/search
"""
profiler_dir = get_profiler_dir(request)
train_id = get_train_id(request)
if not profiler_dir or not train_id:
raise ParamValueError("No profiler_dir or train_id.")
search_condition = request.stream.read()
try:
......@@ -90,10 +92,13 @@ def get_profile_device_list():
ParamValueError: If the search condition contains some errors.
Examples:
>>> POST http://xxxx/v1/mindinsight/profile/device_list
>>> POST http://xxxx/v1/mindinsight/profile/devices
"""
profiler_dir = get_profiler_dir(request)
train_id = get_train_id(request)
if not profiler_dir or not train_id:
raise ParamValueError("No profiler_dir or train_id.")
profiler_dir_abs = os.path.join(settings.SUMMARY_BASE_DIR, train_id, profiler_dir)
try:
profiler_dir_abs = validate_and_normalize_path(profiler_dir_abs, "profiler")
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test profiler restful api."""
import json
from unittest import TestCase, mock
from flask import Response
from mindinsight.backend.application import APP
class TestProfilerRestfulApi(TestCase):
"""Test the restful api of profiler."""
def setUp(self):
"""Test init."""
APP.response_class = Response
self.app_client = APP.test_client()
self.url = '/v1/mindinsight/profile/ops/search?train_id=run1&profile=profiler'
@mock.patch('mindinsight.backend.lineagemgr.lineage_api.settings')
@mock.patch('mindinsight.profiler.analyser.base_analyser.BaseAnalyser.query')
def test_ops_search_success(self, *args):
"""Test the success of ops/search."""
base_dir = '/path/to/test_profiler_base'
expect_result = {
'object': ["test"],
'count': 1
}
args[0].return_value = expect_result
args[1].SUMMARY_BASE_DIR = base_dir
body_data = {
"op_type": "aicore_type"
}
response = self.app_client.post(self.url, data=json.dumps(body_data))
self.assertEqual(200, response.status_code)
self.assertDictEqual(expect_result, response.get_json())
@mock.patch('mindinsight.backend.lineagemgr.lineage_api.settings')
@mock.patch('mindinsight.profiler.analyser.base_analyser.BaseAnalyser.query')
def test_ops_search_failed(self, *args):
"""Test the failed of ops/search."""
base_dir = '/path/to/test_profiler_base'
expect_result = {
'object': ["test"],
'count': 1
}
args[0].return_value = expect_result
args[1].SUMMARY_BASE_DIR = base_dir
response = self.app_client.post(self.url, data=json.dumps(1))
self.assertEqual(400, response.status_code)
expect_result = {
'error_code': '50546082',
'error_msg': "Param type error. Invalid search_condition type, it should be dict."
}
self.assertDictEqual(expect_result, response.get_json())
body_data = {"op_type": "1"}
response = self.app_client.post(self.url, data=json.dumps(body_data))
self.assertEqual(400, response.status_code)
expect_result = {
'error_code': '50546183',
}
result = response.get_json()
del result["error_msg"]
self.assertDictEqual(expect_result, result)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册