# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
htmlcov
.trash
.pytest_cache/
# Distribution / packaging
bin/
develop-eggs/
dist/
eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Editors/IDEs
.idea/
*.sublime-*
*.swp
*.save
# test file
.coverage
.cache
# project wide git ignore
# Compiled artifacts
*.so
*.whl
# Compiled Python files
*.pyc
# Emacs backup files
*~
*#
.#*
# Vim file artifacts
.*.sw*
# Makefile dummy artifacts
.*-dummy
# log files
*.log
# code coverage
*.cov
# Test result xml files
report.xml
*.pprof
results.xml
TESTS*.xml
# local project settings
.settings
.project
.gradle
.idea
# tox
.tox/
# vscode settings
.vscode
package-lock.json
build/lib
build/bdist.*
output/
!output/README.md
third_party/securec/build
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
recursive-include mindinsight *
recursive-exclude * .git
recursive-exclude * .gitignore
recursive-exclude * __pycache__
recursive-exclude * *.py[co] *.swp
recursive-exclude mindinsight/ui *
recursive-include mindinsight/ui/dist *
MindSpore MindInsight
Copyright 2019-2020 Huawei Technologies Co., Ltd
MindInsight provides MindSpore with easy-to-use debugging and tuning capabilities. It
enables users to visualize their experiments. The features of MindInsight are as follows.
- Visualization of the training process:
Provides visualization of training process information,
such as the computation graph, training process metrics, etc.
- Traceability of training results:
Provides visualization of model parameter information,
such as training data, model accuracy, etc.
# Index
- [More about MindInsight](#more-about-mindinsight)
- [Installation](#installation)
- [QuickStart](#quickstart)
- [Docs](#docs)
- [Community](#community)
- [Contributing](#contributing)
- [Release Notes](#release-notes)
- [License](#license)
# More about MindInsight
The architecture of MindInsight is illustrated as follows:
![MindInsight Architecture](docs/arch.png)
## Summary log file
The summary log file consists of a series of operation events. Each event contains
the necessary data for visualization.
MindSpore uses the Callback mechanism to record graph, scalar, image and model
information into the summary log file.
- Scalars and images are recorded by the Summary operator.
- The computation graph is recorded by SummaryRecord after it is compiled.
- Model parameters are recorded by TrainLineage or EvalLineage.
MindInsight provides the capability to analyze summary log files and visualize
the relevant information.
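As a rough illustration, a summary log file might be produced as follows. This is a hedged sketch against the MindSpore API of this era; the import path and signatures are assumptions and may differ between releases:

```python
# Hedged sketch: produce a summary log file (API details assumed, 0.1.0-alpha era).
from mindspore.train.summary import SummaryRecord

summary_record = SummaryRecord(log_dir='./summary_dir')  # events go into a summary log file
for step in range(1, 11):
    # ... run a training step; Summary operators inside the network collect
    # scalar/image data, and lineage callbacks collect model parameters ...
    summary_record.record(step)  # flush the events collected for this step
summary_record.close()
```

MindInsight then points its `--summary-base-dir` at the directory containing such log files.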
## Visualization
MindInsight provides users with a full-process visual GUI for
AI development, to help model developers improve
model precision efficiently.
MindInsight has the following visualization capabilities:
### Graph visualization
The GUI of MindInsight displays the structure of the neural network, and the data flow and control
flow of each operator, during the entire training process.
### Scalar visualization
The GUI of MindInsight displays the change trend of a specific scalar during the entire
training process, such as the loss value and accuracy of each iteration.
Two scalar curves can be combined and displayed in one chart.
### Image visualization
The GUI of MindInsight displays both original images and enhanced images during the entire
training process.
### Model lineage visualization
The GUI of MindInsight displays the parameters and metrics of all models, such as the
learning rate, the number of samples and the loss function of each model.
### Dataset Graph visualization
The GUI of MindInsight displays the pipeline of dataset processing and augmentation.
### Dataset Lineage visualization
The GUI of MindInsight displays the parameters and operations of the dataset processing and augmentation.
# Installation
See [Install MindInsight](https://www.mindspore.cn/install/en).
# QuickStart
See [guidance](https://www.mindspore.cn/tutorial/en/0.1.0-alpha/advanced_use/visualization_tutorials.html)
# Docs
See [API Reference](https://www.mindspore.cn/api/en/master/index.html)
# Community
- [MindSpore Slack](https://join.slack.com/t/mindspore/shared_invite/enQtOTcwMTIxMDI3NjM0LTNkMWM2MzI5NjIyZWU5ZWQ5M2EwMTQ5MWNiYzMxOGM4OWFhZjI4M2E5OGI2YTg3ODU1ODE2Njg1MThiNWI3YmQ) - Communication platform for developers.
# Contributing
Welcome contributions. See our [Contributor Wiki](https://gitee.com/mindspore/mindspore/blob/master/CONTRIBUTING.md) for more details.
# Release Notes
For the release notes, see [RELEASE](RELEASE.md).
# License
[Apache License 2.0](LICENSE)
# Release 0.1.0-alpha
## MindInsight
* Training process observation
* Provides and displays training process information, including computational graphs and training process indicators.
* Training result tracing
* Provides functions of tracing and visualizing model training parameter information, including filtering and sorting of training data, model accuracy and training hyperparameters.
# MindInsight Application Scenarios and Security Risks
1. MindInsight is a local tool that uses the HTTP protocol, which is insecure. You are not advised to use it in cloud services or scenarios with security requirements. Otherwise, data may be stolen.
2. The MindInsight source code restricts access to localhost. If you modify the source code to remove the localhost binding restriction, data leakage may occur.
# MindInsight Security Usage Suggestions
- You are advised to create an independent OS user to install and run the MindInsight service. Permissions among OS users are isolated to prevent data theft. In addition, you are advised to set a proper log directory size to prevent log recording exceptions due to insufficient disk space.
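As an illustration, the following is a minimal sketch of running the service under a dedicated user. The user name, port and paths are illustrative; `--summary-base-dir` is the startup argument registered by the datavisual hook in this repository, and `--port` is assumed to be available:

```bash
# Create an unprivileged OS user dedicated to MindInsight.
sudo useradd -m mindinsight
# Run the service as that user so its data is isolated from other OS users.
sudo -u mindinsight -i mindinsight start --port 8080 --summary-base-dir /home/mindinsight/summary
```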
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
SCRIPT_BASEDIR=$(
cd "$(dirname "$0")" || exit
pwd
)
rename_wheel() {
    # Rename the platform-agnostic wheel with Python version, ABI and platform tags.
    VERSION="$1"
PACKAGE_LIST=$(ls mindinsight-*-any.whl) || exit
for PACKAGE_ORIG in ${PACKAGE_LIST}; do
MINDINSIGHT_VERSION=$(echo "${PACKAGE_ORIG}" | awk -F"-" '{print $2}')
PYTHON_VERSION_NUM=$(echo "${VERSION}" | awk -F"." '{print $1$2}')
PYTHON_VERSION_TAG="cp${PYTHON_VERSION_NUM}"
PYTHON_ABI_TAG="cp${PYTHON_VERSION_NUM}m"
OS_NAME=$(uname | tr '[:upper:]' '[:lower:]')
MACHINE_TAG="${OS_NAME}_$(uname -i)"
PACKAGE_NEW="mindinsight-${MINDINSIGHT_VERSION}-${PYTHON_VERSION_TAG}-${PYTHON_ABI_TAG}-${MACHINE_TAG}.whl"
mv "${PACKAGE_ORIG}" "${PACKAGE_NEW}"
done
}
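# Example of rename_wheel (values illustrative): mindinsight-0.1.0-py3-none-any.whl
# becomes mindinsight-0.1.0-cp37-cp37m-linux_x86_64.whl for Python 3.7 on linux x86_64.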
build_wheel() {
PROJECT_BASEDIR=$(cd "$(dirname "$SCRIPT_BASEDIR")" || exit; pwd)
cd "${PROJECT_BASEDIR}" || exit
if [ $# -gt 0 ]; then
if [ "$1" = "clean" ]; then
echo "start cleaning mindinsight"
clean_files
echo "clean mindinsight done"
else
echo "unknown command: $1"
fi
exit
fi
echo "start building mindinsight"
clean_files
PYTHON=$(command -v python3 || command -v python)
if [ -z "${PYTHON}" ]; then
echo "Could not find python3 or python command"
exit 1
fi
PYTHON_VERSION=$(${PYTHON} -c "import platform; print(platform.python_version())" | grep '^3.*')
if [ -z "${PYTHON_VERSION}" ]; then
echo "Could not find Python 3"
exit 1
fi
    # Recreate the output directory.
    rm -rf output
    mkdir output
    ${PYTHON} setup.py bdist_wheel
    if [ ! -d "dist" ]; then
echo "Build failed"
exit 1
fi
mv dist/mindinsight-*-any.whl output/
cd output || exit
rename_wheel "${PYTHON_VERSION}"
cd - >/dev/null 2>&1 || exit
clean_files
echo "Build success, output directory is: ${PROJECT_BASEDIR}/output"
}
clean_files() {
rm -rf third_party/build
rm -rf build/lib
rm -rf build/bdist.*
rm -rf mindinsight.egg-info
rm -rf dist
}
show_usage() {
echo "Build mindinsight"
echo ""
echo "usage: build.sh [-h] [clean]"
echo ""
echo "options:"
echo " -h show usage info"
echo " clean clean build files"
}
check_opts() {
while getopts ':h' OPT; do
case "$OPT" in
h)
show_usage
exit 0
;;
\?)
show_usage
exit 1
;;
esac
done
}
check_opts "$@"
cd "${SCRIPT_BASEDIR}" || exit
build_wheel "$@"
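# Example usage (assuming this script sits one directory below the project root):
#   bash build.sh          # build the wheel into <project root>/output
#   bash build.sh clean    # remove intermediate build files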
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
SCRIPT_BASEDIR=$(
cd "$(dirname "$0")" || exit
pwd
)
THIRD_PARTY_DIR=$(realpath "${SCRIPT_BASEDIR}/../../third_party")
SECUREC_SOURCE_DIR="${THIRD_PARTY_DIR}/securec"
build_securec() {
CMAKE=$(command -v cmake)
if [ -z "${CMAKE}" ]; then
echo "Could not find cmake command"
exit 1
fi
cd "${SECUREC_SOURCE_DIR}" || exit
rm -rf build
mkdir build
cd build || exit
${CMAKE} ..
make
cd - >/dev/null 2>&1 || exit
}
build_crc32() {
CPP=$(command -v c++)
if [ -z "${CPP}" ]; then
echo "Could not find c++ command"
exit 1
fi
PYTHON=$(command -v python3 || command -v python)
if [ -z "${PYTHON}" ]; then
echo "Could not find python3 or python command"
exit 1
fi
PYTHON_VERSION=$(${PYTHON} -c "import platform; print(platform.python_version())" | grep '^3.*')
if [ -z "${PYTHON_VERSION}" ]; then
echo "Could not find Python 3"
exit 1
fi
DATAVISUAL_DIR=$(realpath "${SCRIPT_BASEDIR}/../../mindinsight/datavisual")
CRC32_SOURCE_DIR="${DATAVISUAL_DIR}/utils/crc32"
CRC32_OUTPUT_DIR="${DATAVISUAL_DIR}/utils"
CRC32_SO_FILE="crc32$(python3-config --extension-suffix)"
rm -f "${CRC32_SOURCE_DIR}/${CRC32_SO_FILE}"
rm -f "${CRC32_OUTPUT_DIR}/${CRC32_SO_FILE}"
cd "${CRC32_SOURCE_DIR}" || exit
PYBIND11_INCLUDES=$(${PYTHON} -m pybind11 --includes)
if [ -z "${PYBIND11_INCLUDES}" ]; then
echo "Could not find pybind11 module"
exit 1
fi
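    # `python -m pybind11 --includes` prints -I flags for the Python and pybind11
    # include directories; strip the -I prefixes to reuse the paths below.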
PYTHON_INCLUDE=$(echo "${PYBIND11_INCLUDES}" | awk '{print $1}' | sed "s/^-I//g")
PYTHON_HEADERS=$(echo "${PYBIND11_INCLUDES}" | awk '{print $2}' | sed "s/^-I//g")
${CPP} -O2 -O3 -shared -std=c++11 -fPIC -fstack-protector-all -D_FORTIFY_SOURCE=2 \
-Wno-maybe-uninitialized -Wno-unused-parameter -Wall -Wl,-z,relro,-z,now,-z,noexecstack \
-I"${THIRD_PARTY_DIR}" -I"${DATAVISUAL_DIR}/utils" -I"${PYTHON_INCLUDE}" -I"${PYTHON_HEADERS}" \
-o "${CRC32_SO_FILE}" crc32.cc "${SECUREC_SOURCE_DIR}/build/src/libsecurec.a"
if [ ! -f "${CRC32_SO_FILE}" ]; then
echo "crc so file does not exist, build failed"
exit 1
fi
mv "${CRC32_SO_FILE}" "${CRC32_OUTPUT_DIR}"
}
cd "${SCRIPT_BASEDIR}" || exit
build_securec
cd "${SCRIPT_BASEDIR}" || exit
build_crc32
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
SCRIPT_BASEDIR=$(
cd "$(dirname "$0")" || exit
pwd
)
build_ui() {
NPM=$(command -v npm)
if [ -z "${NPM}" ]; then
echo "Could not find npm command"
exit 1
fi
UI_SOURCE_DIR=$(realpath "${SCRIPT_BASEDIR}/../../mindinsight/ui")
cd "${UI_SOURCE_DIR}" || exit
rm -rf dist
${NPM} config set strict-ssl false
${NPM} config set unsafe-perm true
${NPM} config set user 0
${NPM} install
${NPM} run build
if [ ! -f "dist/index.html" ]; then
echo "dist does not have file index.html, build failed"
exit 1
fi
rm -rf node_modules
}
cd "${SCRIPT_BASEDIR}" || exit
build_ui
# MindInsight Documentation
The MindInsight documentation is in the [MindSpore Docs](https://gitee.com/mindspore/docs) repository.
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Mindinsight init module."""
from mindinsight._version import VERSION
__version__ = VERSION
__version_info__ = tuple(VERSION.split('.'))
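# Note: the components are strings, e.g. VERSION '0.1.0' yields ('0', '1', '0').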
__all__ = [
'__version__',
'__version_info__'
]
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Mindinsight main module."""
from mindinsight.utils.command import main
main()
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Mindinsight version module."""
VERSION = '0.1.0'
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Web application module."""
import os
from importlib import import_module
from werkzeug.datastructures import Headers
from werkzeug.exceptions import HTTPException
from flask import Flask
from flask import request
from flask import Response
from flask_cors import CORS
from mindinsight.conf import settings
from mindinsight.utils.hook import HookUtils
from mindinsight.datavisual.common.log import logger
from mindinsight.datavisual.common.exceptions import RequestMethodNotAllowed
from mindinsight.datavisual.common import error_handler
from mindinsight.datavisual.utils.tools import find_app_package
from mindinsight.datavisual.utils.tools import get_img_mimetype
from mindinsight.utils.exceptions import MindInsightException
def get_security_headers():
"""Get security headers."""
domain_white_list = []
for hook in HookUtils.instance().hooks():
domain_white_list += hook.register_secure_domains()
content_security_policy = {
'img-src': ["'self'", 'data:'],
'style-src': ["'self'", "'unsafe-inline'"],
'frame-src': ["'self'"] + domain_white_list,
'frame-ancestors': ["'self'"] + domain_white_list,
'default-src': ["'self'"],
}
headers = {
'X-Frame-Options': 'SAMEORIGIN',
'X-XSS-Protection': '1; mode=block',
'X-Content-Type-Options': 'nosniff',
'Access-Control-Allow-Methods': ', '.join(settings.SUPPORT_REQUEST_METHODS),
'Content-Security-Policy': '; '.join([
f"{k} {' '.join(v)}" for k, v in content_security_policy.items()
]),
'X-Download-Options': 'noopen',
'Cache-Control': 'no-store',
'Pragma': 'no-cache'
}
return list(headers.items())
SECURITY_HEADERS = get_security_headers()
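# Example rendered policy (with no extra hook domains registered):
#   Content-Security-Policy: img-src 'self' data:; style-src 'self' 'unsafe-inline';
#   frame-src 'self'; frame-ancestors 'self'; default-src 'self'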
class CustomResponse(Response):
    """Define custom response."""
    def __init__(self, response=None, **kwargs):
        headers = kwargs.get("headers")
        # Work on a copy so that per-response additions (such as the
        # Content-Type for raw image bytes) do not accumulate in the
        # module-level SECURITY_HEADERS list across requests.
        security_headers = list(SECURITY_HEADERS)
        if isinstance(response, bytes):
            mimetype = get_img_mimetype(response)
            security_headers.append(('Content-Type', mimetype))
        if headers is None:
            headers = Headers(security_headers)
        else:
            for header in security_headers:
                headers.add(*header)
        kwargs['headers'] = headers
        super(CustomResponse, self).__init__(response, **kwargs)
def _init_app_module(app):
"""
Init app module.
Args:
app (Flask): An instance of Flask.
"""
packages = find_app_package()
for package in packages:
try:
app_module = import_module(package)
app_module.init_module(app)
        except AttributeError:
            logger.debug('[%s].init_module does not exist.', package)
def before_request():
"""A function to run before each request."""
if request.method not in settings.SUPPORT_REQUEST_METHODS:
raise RequestMethodNotAllowed()
def create_app():
"""Set flask APP config, and start the data manager."""
static_url_path = "/static"
static_folder_path = os.path.realpath(os.path.join(os.path.dirname(__file__), os.pardir, 'ui', 'dist', 'static'))
app = Flask(__name__, static_url_path=static_url_path, static_folder=static_folder_path)
if settings.ENABLE_CORS:
CORS(app, supports_credentials=True)
app.before_request(before_request)
app.register_error_handler(HTTPException, error_handler.handle_http_exception_error)
app.register_error_handler(MindInsightException, error_handler.handle_mindinsight_error)
app.register_error_handler(Exception, error_handler.handle_unknown_error)
app.response_class = CustomResponse
_init_app_module(app)
return app
APP = create_app()
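# Minimal local-debugging sketch (hypothetical; the production entry point is
# gunicorn, configured in mindinsight.backend.run):
#   from mindinsight.backend.application import APP
#   APP.run(host='127.0.0.1', port=8080)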
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Config."""
import os
WEB_CONFIG_DIR = os.path.dirname(__file__)
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Config file for gunicorn."""
import os
import threading
from importlib import import_module
import gunicorn
gunicorn.SERVER_SOFTWARE = 'unknown'
worker_class = 'sync'
workers = 1
threads = min(30, (os.cpu_count() or 1) * 2 + 1)  # os.cpu_count() may return None
worker_connections = 1000
timeout = 30
graceful_timeout = 30
daemon = True
captureoutput = True
# Write the gunicorn default log to the stream, and use the mindinsight logger to write gunicorn logs to file.
accesslog = '-'
def on_starting(server):
"""Hook function on starting gunicorn process."""
hook_module = import_module('mindinsight.utils.hook')
for hook in hook_module.HookUtils.instance().hooks():
threading.Thread(target=hook.on_startup, args=(server.log,)).start()
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Datavisual."""
from mindinsight.backend.datavisual.static_resource_api import init_module as static_init_module
from mindinsight.backend.datavisual.task_manager_api import init_module as task_init_module
from mindinsight.backend.datavisual.train_visual_api import init_module as train_init_module
from mindinsight.conf import settings
from mindinsight.datavisual.data_transform.data_manager import DATA_MANAGER
def init_module(app):
"""
Interface to init module.
Args:
app (Flask): An instance of Flask.
"""
static_init_module(app)
task_init_module(app)
train_init_module(app)
DATA_MANAGER.start_load_data(reload_interval=int(settings.RELOAD_INTERVAL),
max_threads_count=int(settings.MAX_THREADS_COUNT))
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Static resource api."""
import os
import sys
from flask import current_app
from flask import send_from_directory
from flask import Blueprint
APP_PATH = os.path.realpath(os.path.dirname(sys.argv[0]))
BLUEPRINT = Blueprint("static_resource", __name__)
@BLUEPRINT.route("/", methods=["GET"])
def index():
"""Interface to return static index.html."""
return send_from_directory(get_index_resource_dir(), "index.html")
def get_index_resource_dir():
"""Interface to return index.html resource directory."""
return os.path.realpath(os.path.join(APP_PATH, current_app.static_folder, os.pardir))
def init_module(app):
"""
Init module entry.
Args:
app: the application obj.
"""
app.register_blueprint(BLUEPRINT)
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Task manager api.
This module provides the interfaces to task management functions.
"""
import os
from flask import Blueprint
from flask import request
from flask import jsonify
from mindinsight.conf import settings
from mindinsight.datavisual.utils.tools import str_to_bool
from mindinsight.datavisual.utils.tools import get_train_id
from mindinsight.datavisual.processors.train_task_manager import TrainTaskManager
from mindinsight.datavisual.data_transform.summary_watcher import SummaryWatcher
from mindinsight.datavisual.data_transform.data_manager import DATA_MANAGER
BLUEPRINT = Blueprint("task_manager", __name__, url_prefix=settings.URL_PREFIX)
@BLUEPRINT.route("/datavisual/single-job", methods=["GET"])
def query_single_train_task():
"""Query single train task"""
plugin_name = request.args.get('plugin_name')
train_id = get_train_id(request)
processor = TrainTaskManager(DATA_MANAGER)
tasks = processor.get_single_train_task(train_id=train_id, plugin_name=plugin_name)
return jsonify(tasks)
@BLUEPRINT.route("/datavisual/plugins", methods=["GET"])
def query_plugins():
"""Query plugins."""
train_id = get_train_id(request)
manual_update = request.args.get('manual_update', default='false')
manual_update = str_to_bool(manual_update, "manual_update")
processor = TrainTaskManager(DATA_MANAGER)
plugins = processor.get_plugins(train_id, manual_update)
return jsonify(plugins)
@BLUEPRINT.route("/datavisual/train-jobs", methods=["GET"])
def query_train_jobs():
"""Query train jobs."""
offset = request.args.get("offset", default=0)
limit = request.args.get("limit", default=10)
summary_watcher = SummaryWatcher()
total, directories = summary_watcher.list_summary_directories_by_pagination(
settings.SUMMARY_BASE_DIR, offset, limit)
train_jobs = [{
'train_id': directory['relative_path'],
'relative_path': directory['relative_path'],
'create_time': directory['create_time'].strftime('%Y-%m-%d %H:%M:%S'),
'update_time': directory['update_time'].strftime('%Y-%m-%d %H:%M:%S'),
} for directory in directories]
return jsonify({
'name': os.path.basename(os.path.realpath(settings.SUMMARY_BASE_DIR)),
'total': total,
'train_jobs': train_jobs,
})
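# Example request/response (hypothetical host and values; the URL prefix is
# assumed to be /v1/mindinsight as in the lineage API docstrings):
#   GET http://127.0.0.1:8080/v1/mindinsight/datavisual/train-jobs?offset=0&limit=10
#   -> {"name": "summary", "total": 1,
#       "train_jobs": [{"train_id": "./run1", "relative_path": "./run1",
#                       "create_time": "2020-01-01 00:00:00",
#                       "update_time": "2020-01-01 00:00:00"}]}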
def init_module(app):
"""
Init module entry.
Args:
app: the application obj.
"""
app.register_blueprint(BLUEPRINT)
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Backend interface module.
This module provides the interfaces to the train processor functions.
"""
from flask import Blueprint
from flask import request
from flask import jsonify
from mindinsight.conf import settings
from mindinsight.datavisual.utils.tools import get_train_id
from mindinsight.datavisual.utils.tools import if_nan_inf_to_none
from mindinsight.datavisual.processors.images_processor import ImageProcessor
from mindinsight.datavisual.processors.scalars_processor import ScalarsProcessor
from mindinsight.datavisual.processors.graph_processor import GraphProcessor
from mindinsight.datavisual.data_transform.data_manager import DATA_MANAGER
BLUEPRINT = Blueprint("train_visual", __name__, url_prefix=settings.URL_PREFIX)
@BLUEPRINT.route("/datavisual/image/metadata", methods=["GET"])
def image_metadata():
"""
    Interface to fetch metadata about the images for the particular run, tag, and zero-indexed sample.
Returns:
Response, which contains a list in JSON containing image events, each
one of which is an object containing items wall_time, step, width,
height, and query.
"""
tag = request.args.get("tag")
train_id = get_train_id(request)
processor = ImageProcessor(DATA_MANAGER)
response = processor.get_metadata_list(train_id, tag)
return jsonify(response)
@BLUEPRINT.route("/datavisual/image/single-image", methods=["GET"])
def single_image():
"""
Interface to fetch raw image data for a particular image.
Returns:
Response, which contains a byte string of image.
"""
tag = request.args.get("tag")
step = request.args.get("step")
train_id = get_train_id(request)
processor = ImageProcessor(DATA_MANAGER)
img_data = processor.get_single_image(train_id, tag, step)
return img_data
@BLUEPRINT.route("/datavisual/scalar/metadata", methods=["GET"])
def scalar_metadata():
"""
Interface to fetch metadata about the scalars for the particular run and tag.
Returns:
        Response, which contains a list in JSON containing scalar events, each
        one of which is an object containing items wall_time, step and value.
"""
tag = request.args.get("tag")
train_id = request.args.get("train_id")
processor = ScalarsProcessor(DATA_MANAGER)
response = processor.get_metadata_list(train_id, tag)
metadatas = response['metadatas']
for metadata in metadatas:
value = metadata.get("value")
metadata["value"] = if_nan_inf_to_none('scalar_value', value)
return jsonify(response)
@BLUEPRINT.route("/datavisual/graphs/nodes", methods=["GET"])
def graph_nodes():
"""
Interface to get graph nodes.
Returns:
Response, which contains a JSON object.
"""
name = request.args.get('name', default=None)
node_type = request.args.get('type', default='name_scope')
tag = request.args.get("tag", default=None)
train_id = get_train_id(request)
graph_process = GraphProcessor(train_id, DATA_MANAGER, tag)
response = graph_process.get_nodes(name=name, node_type=node_type)
return jsonify(response)
@BLUEPRINT.route("/datavisual/graphs/nodes/names", methods=["GET"])
def graph_node_names():
"""
Interface to query node names.
Returns:
Response, which contains a JSON object.
"""
search_content = request.args.get("search")
offset = request.args.get("offset", default=0)
limit = request.args.get("limit", default=100)
tag = request.args.get("tag", default=None)
train_id = get_train_id(request)
graph_process = GraphProcessor(train_id, DATA_MANAGER, tag)
resp = graph_process.search_node_names(search_content, offset, limit)
return jsonify(resp)
@BLUEPRINT.route("/datavisual/graphs/single-node", methods=["GET"])
def graph_search_single_node():
"""
Interface to search single node.
Returns:
Response, which contains a JSON object.
"""
name = request.args.get("name")
tag = request.args.get("tag", default=None)
train_id = get_train_id(request)
graph_process = GraphProcessor(train_id, DATA_MANAGER, tag)
resp = graph_process.search_single_node(name)
return jsonify(resp)
def init_module(app):
"""
Init module entry.
Args:
app (Flask): The application obj.
"""
app.register_blueprint(BLUEPRINT)
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
module init file.
"""
from mindinsight.backend.lineagemgr.lineage_api import init_module as init_query_module
def init_module(app):
"""
Init module entry.
Args:
app: Flask. A Flask instance.
Returns:
"""
init_query_module(app)
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Lineage restful api."""
import json
import os
from flask import Blueprint, jsonify, request
from mindinsight.conf import settings
from mindinsight.datavisual.utils.tools import get_train_id
from mindinsight.lineagemgr import filter_summary_lineage, get_summary_lineage
from mindinsight.lineagemgr.common.validator.validate import validate_path
from mindinsight.utils.exceptions import MindInsightException, ParamValueError
BLUEPRINT = Blueprint("lineage", __name__, url_prefix=settings.URL_PREFIX.rstrip("/"))
@BLUEPRINT.route("/models/model_lineage", methods=["POST"])
def search_model():
"""
Get model lineage info.
    Get model info by summary base dir, and return a list of dicts containing
    each model's lineage parameters and the count of summary logs.
Returns:
str, the model lineage information.
Raises:
MindInsightException: If method fails to be called.
ParamValueError: If parsing json data search_condition fails.
Examples:
>>> POST http://xxxx/v1/mindinsight/models/model_lineage
"""
search_condition = request.stream.read()
try:
search_condition = json.loads(search_condition if search_condition else "{}")
except Exception:
raise ParamValueError("Json data parse failed.")
model_lineage_info = _get_lineage_info(
lineage_type="model",
search_condition=search_condition
)
return jsonify(model_lineage_info)
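# Hypothetical search_condition body for the POST above; the exact schema is
# defined by filter_summary_lineage, so the keys and operators are assumptions:
#   {"summary_dir": {"in": ["./run1"]}, "offset": 0, "limit": 10}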
@BLUEPRINT.route("/datasets/dataset_lineage", methods=["POST"])
def get_datasets_lineage():
"""
Get dataset lineage.
Returns:
str, the dataset lineage information.
Raises:
MindInsightException: If method fails to be called.
ParamValueError: If parsing json data search_condition fails.
Examples:
>>> POST http://xxxx/v1/minddata/datasets/dataset_lineage
"""
search_condition = request.stream.read()
try:
search_condition = json.loads(search_condition if search_condition else "{}")
except Exception:
raise ParamValueError("Json data parse failed.")
dataset_lineage_info = _get_lineage_info(
lineage_type="dataset",
search_condition=search_condition
)
return jsonify(dataset_lineage_info)
def _get_lineage_info(lineage_type, search_condition):
"""
Get lineage info for dataset or model.
Args:
lineage_type (str): Lineage type, 'dataset' or 'model'.
search_condition (dict): Search condition.
Returns:
dict, lineage info.
Raises:
MindInsightException: If method fails to be called.
"""
if 'lineage_type' in search_condition:
raise ParamValueError("Lineage type does not need to be assigned in a specific interface.")
if lineage_type == 'dataset':
search_condition.update({'lineage_type': 'dataset'})
summary_base_dir = str(settings.SUMMARY_BASE_DIR)
try:
lineage_info = filter_summary_lineage(
summary_base_dir, search_condition)
lineages = lineage_info['object']
summary_base_dir = os.path.realpath(summary_base_dir)
length = len(summary_base_dir)
for lineage in lineages:
summary_dir = lineage['summary_dir']
summary_dir = os.path.realpath(summary_dir)
if summary_base_dir == summary_dir:
relative_dir = './'
else:
relative_dir = os.path.join(os.curdir, summary_dir[length+1:])
lineage['summary_dir'] = relative_dir
except MindInsightException as exception:
raise MindInsightException(exception.error, exception.message, http_code=400)
return lineage_info
@BLUEPRINT.route("/datasets/dataset_graph", methods=["GET"])
def get_dataset_graph():
"""
Get dataset graph.
Returns:
str, the dataset graph information.
Raises:
MindInsightException: If method fails to be called.
ParamValueError: If summary_dir is invalid.
Examples:
>>> GET http://xxxx/v1/mindinsight/datasets/dataset_graph?train_id=xxx
"""
summary_base_dir = str(settings.SUMMARY_BASE_DIR)
summary_dir = get_train_id(request)
if summary_dir.startswith('/'):
validate_path(summary_dir)
elif summary_dir.startswith('./'):
summary_dir = os.path.join(summary_base_dir, summary_dir[2:])
summary_dir = validate_path(summary_dir)
else:
        raise ParamValueError(
            "Summary dir should be an absolute path or "
            "a relative path under the summary base dir."
        )
try:
dataset_graph = get_summary_lineage(
summary_dir=summary_dir,
keys=['dataset_graph']
)
except MindInsightException as exception:
raise MindInsightException(exception.error, exception.message, http_code=400)
if dataset_graph:
summary_dir_result = dataset_graph.get('summary_dir')
base_dir_len = len(summary_base_dir)
if summary_base_dir == summary_dir_result:
relative_dir = './'
else:
relative_dir = os.path.join(
os.curdir, summary_dir[base_dir_len + 1:]
)
dataset_graph['summary_dir'] = relative_dir
return jsonify(dataset_graph)
def init_module(app):
"""
Init module entry.
Args:
app (Flask): The application obj.
"""
app.register_blueprint(BLUEPRINT)
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Web service entrance."""
import os
import stat
import re
import subprocess
import time
import shlex
from gunicorn.glogging import Logger
from mindinsight.backend.config import gunicorn_conf
from mindinsight.backend.config import WEB_CONFIG_DIR
from mindinsight.conf import settings
from mindinsight.utils.log import setup_logger
MINDBOARD_APP_MODULE = "mindinsight.backend.application:APP"
GUNICORN_LOGGER = "mindinsight.backend.run.GunicornLogger"
_MIN_PORT = 1
_MAX_PORT = 65535
def _get_file_size(file_path):
"""
Get the file size.
Args:
file_path (str): The file path.
Returns:
        int, the file size. If the file does not exist, return 0.
"""
try:
file_size = os.path.getsize(file_path)
except FileNotFoundError:
file_size = 0
return file_size
def _is_match_one(sub_string_list, src_string):
"""
    Check whether any sub-string in the list occurs in the source string.
Args:
sub_string_list (list): The sub-string list.
src_string (str): The source string.
Returns:
bool, if matched return True, else return False.
"""
for match_info in sub_string_list:
if match_info in src_string:
return True
return False
def _check_stat_from_log(log_info):
"""
Determine the service startup status based on the log information.
Args:
log_info (str): The output log of service startup.
Returns:
str, the state value that is one of the follows: "unknown", "failed" and "success".
"""
server_state = "unknown"
match_success_info = "Listening at: http://%s:%d" % \
(settings.HOST, int(settings.PORT))
common_failed_info_list = [
"[ERROR] Retrying in 1 second",
"[INFO] Reason: App failed to load",
"[ERROR] Exception in worker process"
]
re_pattern = "\\[ERROR\\].+%s.+%d" % \
(settings.HOST, int(settings.PORT))
    # Match failed output logs by fuzzy matching.
if re.search(re_pattern, log_info) or \
_is_match_one(common_failed_info_list, log_info):
server_state = "failed"
if match_success_info in log_info:
server_state = "success"
return server_state
def _get_error_log_path():
"""
Get gunicorn error log path.
Returns:
str, the path of error log.
"""
path = os.path.join(settings.WORKSPACE, 'log/gunicorn/error.log')
errorlog_abspath = os.path.realpath(path)
return errorlog_abspath
def _get_access_log_path():
"""Get gunicorn access log path."""
access_log_path = os.path.join(settings.WORKSPACE, 'log/gunicorn/access.log')
access_log_path = os.path.realpath(access_log_path)
return access_log_path
def _check_state_from_log(log_abspath, start_pos=0):
"""
Check the service startup status based on the log file.
Args:
log_abspath (str): Absolute path of the log file.
start_pos (int): Offset position of the log file.
Returns:
dict, a dict with "state" and "prompt_message" key.
The value of the "state" key is as follows:"unknown", "failed" and "success".
The value of the "prompt_message" key is a list of prompt messages.
"""
server_is_start = False
state_result = {"state": "unknown", "prompt_message": []}
prompt_messages = []
match_start_log = "Starting gunicorn"
with open(log_abspath) as f_log:
f_log.seek(start_pos)
for line in f_log.readlines():
if match_start_log in line:
if server_is_start:
break
server_is_start = True
continue
if server_is_start:
log_result = _check_stat_from_log(line)
# ignore "unknown" result
if log_result != "unknown":
state_result["state"] = log_result
if log_result == "failed":
prompt_messages.append(line.strip())
prompt_messages.append(
"more failed details in log: %s" % log_abspath)
break
state_result["prompt_message"].append(
"service start state: %s" % state_result["state"])
for prompt_message in prompt_messages:
state_result["prompt_message"].append(prompt_message)
return state_result
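# Example return value (log path illustrative):
#   {"state": "failed",
#    "prompt_message": ["service start state: failed",
#                       "[ERROR] Exception in worker process",
#                       "more failed details in log: /path/to/error.log"]}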
def _check_server_start_stat(log_abspath, start_pos=None):
"""
    Check the server startup status.
Args:
log_abspath (str): The log file path.
start_pos (int): The log file start position.
Returns:
        dict, a dict object that contains the state and prompt_message fields.
The state values are as follows: "unknown", "failed" and "success".
"""
state_result = {"state": "unknown", "prompt_message": []}
    # Return "unknown" when the gunicorn error log file is not configured.
if not log_abspath:
return state_result
log_pos = _get_file_size(log_abspath) if start_pos is None else start_pos
try_cnt = 0
try_cnt_max = 2
while try_cnt < try_cnt_max:
try_cnt += 1
time.sleep(1)
if _get_file_size(log_abspath) > log_pos:
state_result.update(_check_state_from_log(log_abspath, log_pos))
break
return state_result
class GunicornLogger(Logger):
"""Rewrite gunicorn default logger."""
def __init__(self, cfg):
self.access_log = setup_logger('gunicorn', 'access')
self.error_log = setup_logger('gunicorn', 'error')
super(GunicornLogger, self).__init__(cfg)
access_log_path = _get_access_log_path()
error_log_path = _get_error_log_path()
os.chmod(access_log_path, stat.S_IREAD | stat.S_IWRITE)
os.chmod(error_log_path, stat.S_IREAD | stat.S_IWRITE)
def start():
"""Start web service."""
errorlog_abspath = _get_error_log_path()
gunicorn_conf_file = os.path.join(WEB_CONFIG_DIR, "gunicorn_conf.py")
cmd = "gunicorn " \
"-b {host}:{port} {app_module} " \
"-c {conf_file} " \
"--logger-class {logger_class} " \
"--access-logformat {log_format}"\
.format(host=settings.HOST,
port=settings.PORT,
conf_file=gunicorn_conf_file,
app_module=MINDBOARD_APP_MODULE,
logger_class=GUNICORN_LOGGER,
log_format=settings.GUNICORN_ACCESS_FORMAT
)
log_size = _get_file_size(errorlog_abspath)
# start server
process = subprocess.Popen(
shlex.split(cmd),
shell=False,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE
)
_, stderr = process.communicate()
if stderr:
print(stderr.decode())
    # Wait for the command to finish when gunicorn runs as a daemon.
if gunicorn_conf.daemon and process.wait() == 0:
state_result = _check_server_start_stat(errorlog_abspath, log_size)
# print gunicorn start state to stdout
print('Web address: http://{}:{}'.format(settings.HOST, settings.PORT))
for line in state_result["prompt_message"]:
print(line)
if __name__ == '__main__':
start()
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Datavisual hook."""
import argparse
import os
from mindinsight.conf import settings
from mindinsight.utils.hook import BaseHook
class ReloadIntervalAction(argparse.Action):
"""Reload interval action class definition."""
def __call__(self, parser, namespace, values, option_string=None):
"""
Inherited __call__ method from argparse.Action.
Args:
parser (ArgumentParser): Passed-in argument parser.
namespace (Namespace): Namespace object to hold arguments.
values (object): Argument values with type depending on argument definition.
option_string (str): Option string for specific argument name.
"""
reload_interval = values
if reload_interval < 0:
parser.error(f'{option_string} should be greater than or equal to 0')
setattr(namespace, self.dest, reload_interval)
class SummaryBaseDirAction(argparse.Action):
"""Summary base dir action class definition."""
def __call__(self, parser, namespace, values, option_string=None):
"""
Inherited __call__ method from argparse.Action.
Args:
parser (ArgumentParser): Passed-in argument parser.
namespace (Namespace): Namespace object to hold arguments.
values (object): Argument values with type depending on argument definition.
option_string (str): Option string for specific argument name.
"""
summary_base_dir = os.path.realpath(values)
setattr(namespace, self.dest, summary_base_dir)
class Hook(BaseHook):
"""Hook class definition."""
def register_startup_arguments(self, parser):
"""
Hook function to register startup arguments.
Args:
parser (ArgumentParser): Specify parser to which arguments are added.
"""
parser.add_argument(
'--reload-interval',
type=int,
action=ReloadIntervalAction,
help="""
data reload interval in seconds. It should be greater than or equal to 0.
If it equals 0, data is loaded only once. Default value is %s seconds.
""" % settings.RELOAD_INTERVAL)
parser.add_argument(
'--summary-base-dir',
type=str,
action=SummaryBaseDirAction,
help="""
directory where MindInsight walks through the direct subdirectories
and looks for summary files matching the regex 'summary.\\d+' or '\\.pb$'. Any direct
subdirectory containing summary files is treated as a summary file
directory. A summary file located directly in summary-base-dir means that
summary-base-dir itself is one of the summary file directories as well. Default
value is the current directory.""")
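# Usage sketch (illustrative, not part of this module): wiring the hook's
# startup arguments into a standalone parser. The prog name and argument
# values below are assumptions for demonstration only.
def _demo_hook_parser():
    parser = argparse.ArgumentParser(prog='mindinsight-hook-demo')
    Hook().register_startup_arguments(parser)
    # '--reload-interval -1' would trigger parser.error via ReloadIntervalAction;
    # SummaryBaseDirAction resolves the directory to an absolute real path.
    return parser.parse_args(['--reload-interval', '5', '--summary-base-dir', '.'])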
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Conf module."""
import os
import json
import types
from importlib import import_module
class Settings:
"""
Definition of Settings class.
Examples:
>>> from mindinsight.conf import settings
>>> print(settings.PORT)
"""
_prefix = 'MINDINSIGHT_'
_explicit_settings = set()
_default_settings = set()
def __init__(self):
"""Initialization of Settings."""
self.load_from_defaults()
self.load_from_constants()
self.refresh()
def refresh(self):
"""Refresh settings from config file and environment variables."""
self.update_from_file()
self.update_from_env()
def load_from_defaults(self):
"""Update settings from defaults module."""
default_settings = import_module('mindinsight.conf.defaults')
for setting in dir(default_settings):
if setting.isupper():
setattr(self, setting, getattr(default_settings, setting))
self._default_settings.add(setting)
def load_from_constants(self):
"""Update settings from constants module"""
constant_settings = import_module('mindinsight.conf.constants')
for setting in dir(constant_settings):
if setting.isupper():
setattr(self, setting, getattr(constant_settings, setting))
def update_from_file(self):
"""Update settings from config file."""
config_path = os.environ.get('MINDINSIGHT_CONFIG', '')
if not config_path:
return
config_module = None
# python:full.path.for.config.module
if config_path.startswith('python:'):
config_module = import_module(config_path[len('python:'):])
# file:full/path/for/config.py
elif config_path.startswith('file:'):
config_path = config_path[len('file:'):]
module_name = '__mindinsightconfig__'
config_module = types.ModuleType(module_name)
machinery = import_module('importlib.machinery')
loader = machinery.SourceFileLoader(module_name, config_path)
loader.exec_module(config_module)
if config_module is None:
return
for setting in dir(config_module):
if setting.isupper() and setting in self._default_settings:
setting_value = getattr(config_module, setting)
setattr(self, setting, setting_value)
self._explicit_settings.add(setting)
def update_from_env(self):
"""Update settings from environment variables."""
for key, value in os.environ.items():
if not key.startswith(self._prefix):
continue
setting = key[len(self._prefix):]
if setting not in self._default_settings:
continue
setting_value = getattr(self, setting)
if isinstance(setting_value, bool):
value = (value == 'True')
elif isinstance(setting_value, (int, float)):
value = type(setting_value)(value)
elif isinstance(setting_value, (list, dict)):
value = json.loads(value)
setattr(self, setting, value)
self._explicit_settings.add(setting)
def config_workspace(self, workspace):
"""
Config workspace value.
Args:
workspace (str): Path of workspace.
"""
setattr(self, 'WORKSPACE', workspace)
self._explicit_settings.add('WORKSPACE')
def is_overridden(self, setting_name):
"""
Check if specified setting is overridden.
Args:
setting_name (str): Setting name to be checked.
Returns:
bool, indicate whether given setting name is overridden.
"""
return setting_name in self._explicit_settings
def dump(self):
"""
Dump settings data.
Returns:
dict, json formatted data of settings.
"""
config = {}
for setting in dir(self):
if setting.isupper():
config[setting] = getattr(self, setting)
return config
settings = Settings()
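# Usage sketch (illustrative, not part of this module): overriding a default
# setting through a MINDINSIGHT_-prefixed environment variable. The port
# value below is an assumption for demonstration only.
def _demo_env_override():
    os.environ['MINDINSIGHT_PORT'] = '8888'
    settings.refresh()                      # re-applies config file and env overrides
    assert settings.PORT == 8888            # coerced to int, matching the default's type
    assert settings.is_overridden('PORT')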
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Constants module for mindinsight settings."""
import logging
####################################
# Global default settings.
####################################
LOG_FORMAT = '[%(levelname)s] MI(%(process)d:%(thread)d,%(processName)s):%(asctime)s ' \
'[%(filepath)s:%(lineno)d][%(sub_module)s] %(message)s'
GUNICORN_ACCESS_FORMAT = "'%(h)s <%(r)s> %(s)s %(b)s <%(f)s> <%(a)s> %(D)s'"
LOG_LEVEL = logging.INFO
# rotating max bytes, default is 50M
LOG_ROTATING_MAXBYTES = 52428800
# rotating backup count, default is 30
LOG_ROTATING_BACKUPCOUNT = 30
####################################
# Web default settings.
####################################
HOST = '127.0.0.1'
# Whether to enable cross-origin resource sharing (CORS). Disabled by default.
# If CORS is enabled, the 'OPTIONS' method should be added to `SUPPORT_REQUEST_METHODS`.
ENABLE_CORS = False
SUPPORT_REQUEST_METHODS = {'POST', 'GET', 'PUT', 'DELETE'}
# url prefix should not end with slash, correct format is /v1/url
URL_PREFIX = '/v1/mindinsight'
####################################
# Datavisual default settings.
####################################
MAX_THREADS_COUNT = 15
MAX_TAG_SIZE_PER_EVENTS_DATA = 300
DEFAULT_STEP_SIZES_PER_TAG = 500
MAX_GRAPH_TAG_SIZE = 10
MAX_IMAGE_STEP_SIZE_PER_TAG = 10
MAX_SCALAR_STEP_SIZE_PER_TAG = 1000
MAX_GRAPH_STEP_SIZE_PER_TAG = 1
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Defaults module for mindinsight settings."""
import os
####################################
# Global default settings.
####################################
WORKSPACE = os.path.join(os.environ['HOME'], 'mindinsight')
####################################
# Web default settings.
####################################
PORT = 8080
####################################
# Datavisual default settings.
####################################
RELOAD_INTERVAL = 3 # Seconds
SUMMARY_BASE_DIR = os.getcwd()
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Enums."""
from enum import Enum
class BaseEnum(Enum):
@classmethod
def list_members(cls):
"""List all members."""
return [member.value for member in cls]
class DataManagerStatus(BaseEnum):
"""Data manager status."""
INIT = 'INIT'
LOADING = 'LOADING'
DONE = 'DONE'
INVALID = 'INVALID'
class PluginNameEnum(BaseEnum):
"""Plugin Name Enum."""
IMAGE = 'image'
SCALAR = 'scalar'
GRAPH = 'graph'
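# Usage sketch (illustrative, not part of this module): list_members()
# returns the raw member values, which is convenient for membership checks
# such as plugin-name validation.
def _demo_list_members():
    assert PluginNameEnum.list_members() == ['image', 'scalar', 'graph']
    assert DataManagerStatus.LOADING.value in DataManagerStatus.list_members()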
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Handle custom error."""
from urllib.parse import quote
from werkzeug.exceptions import NotFound
from werkzeug.exceptions import MethodNotAllowed
from flask import request, jsonify
from mindinsight.datavisual.common.exceptions import RequestMethodNotAllowed
from mindinsight.datavisual.common.exceptions import RestfulApiNotExist
from mindinsight.datavisual.common.log import restful_logger as logger
from mindinsight.utils.exceptions import UnknownError
from mindinsight.utils.exceptions import FileSystemPermissionError
def handle_http_exception_error(ex):
"""Handle http exception error."""
logger.warning('%r %r, detail: %r', request.method, quote(request.path), str(ex))
if isinstance(ex, NotFound):
error = RestfulApiNotExist()
elif isinstance(ex, MethodNotAllowed):
error = RequestMethodNotAllowed()
else:
logger.exception(ex)
error = UnknownError('System error or http error.')
res_body = {"error_code": error.error_code, "error_msg": error.message}
return jsonify(res_body), error.http_code
def handle_mindinsight_error(ex):
"""Handle mindinsight error."""
if int(ex.http_code) < 500:
logger.warning('%r %r detail: %r', request.method, quote(request.path), ex.message)
else:
logger.error('%r %r detail: %r', request.method, quote(request.path), ex.message)
logger.exception(ex)
res_body = dict(error_code=ex.error_code, error_msg=ex.message)
return jsonify(res_body), ex.http_code
def handle_unknown_error(ex):
"""Handle unknown error."""
logger.error('%r %r detail: %r', request.method, quote(request.path), str(ex))
logger.exception(ex)
if isinstance(ex, PermissionError):
error = FileSystemPermissionError('File System Permission Error')
else:
error = UnknownError('System error.')
res_body = dict(error_code=error.error_code, error_msg=error.message)
return jsonify(res_body), error.http_code
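# Wiring sketch (illustrative, not part of this module): how these handlers
# might be registered on a Flask app. The application factory itself lives
# elsewhere; the app object passed in here is an assumption.
def _register_error_handlers(app):
    from werkzeug.exceptions import HTTPException
    from mindinsight.utils.exceptions import MindInsightException
    app.register_error_handler(HTTPException, handle_http_exception_error)
    app.register_error_handler(MindInsightException, handle_mindinsight_error)
    app.register_error_handler(Exception, handle_unknown_error)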
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Define custom exception."""
from mindinsight.utils.constant import DataVisualErrors
from mindinsight.utils.exceptions import MindInsightException
class RestfulApiNotExist(MindInsightException):
"""404 not found."""
def __init__(self):
error_msg = '404 Not Found.'
super(RestfulApiNotExist, self).__init__(DataVisualErrors.RESTFUL_API_NOT_EXIST,
error_msg,
http_code=404)
class RequestMethodNotAllowed(MindInsightException):
"""Request method not allowed."""
def __init__(self):
error_msg = '405 Method Not Allowed.'
super(RequestMethodNotAllowed, self).__init__(DataVisualErrors.REQUEST_METHOD_NOT_ALLOWED,
error_msg,
http_code=405)
class PathNotDirectoryError(MindInsightException):
"""Raised when specified path do not exist."""
def __init__(self, error_detail):
"""Initialize PathNotExistError"""
error_msg = 'Specified path is not a directory. Detail: {}'.format(error_detail)
super(PathNotDirectoryError, self).__init__(DataVisualErrors.PATH_NOT_DIRECTORY_ERROR,
error_msg,
http_code=400)
class SummaryLogPathInvalid(MindInsightException):
"""No valid log file in the path."""
def __init__(self):
error_msg = 'No valid summary log file in path'
super(SummaryLogPathInvalid, self).__init__(DataVisualErrors.SUMMARY_LOG_PATH_INVALID,
error_msg,
http_code=400)
class CRCFailedError(MindInsightException):
"""CRC fail, record corrupted."""
def __init__(self):
error_msg = 'CRC Failed.'
super(CRCFailedError, self).__init__(DataVisualErrors.CRC_FAILED,
error_msg,
http_code=400)
class SummaryLogIsLoading(MindInsightException):
"""Data is loading."""
def __init__(self, error_detail):
error_msg = "Data is loading. Detail: %s" % error_detail
super(SummaryLogIsLoading, self).__init__(DataVisualErrors.SUMMARY_LOG_IS_LOADING,
error_msg,
http_code=400)
class NodeNotInGraphError(MindInsightException):
"""Can not find node in graph error."""
def __init__(self):
error_msg = "Can not find node in graph by given node name."
super(NodeNotInGraphError, self).__init__(DataVisualErrors.NODE_NOT_IN_GRAPH_ERROR,
error_msg,
http_code=400)
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Create a logger."""
from mindinsight.utils.log import setup_logger
logger = setup_logger("datavisual", "datavisual")
restful_logger = setup_logger("restful_api", "restful_api")
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Define a validation class which contain all check methods of datavisual module."""
from numbers import Number
from mindinsight.utils.exceptions import ParamValueError
from mindinsight.utils.exceptions import ParamMissError
from mindinsight.datavisual.common.enums import PluginNameEnum
from mindinsight.datavisual.utils.tools import to_int
class Validation:
"""Validation class, define all check methods."""
@classmethod
def check_offset(cls, offset, default_value=0):
"""
Check the offset parameter; it must be greater than or equal to 0.
Args:
offset (Union[str, int]): Value can be string number or int.
default_value (int): Default value for checked offset. Default: 0.
Returns:
int, offset.
"""
if offset is None:
return default_value
offset = to_int(offset, 'offset')
if offset < 0:
raise ParamValueError("'offset' should be greater than or equal to 0.")
return offset
@classmethod
def check_limit(cls, limit, min_value=1, max_value=1000, default_value=100):
"""
Check the limit parameter; it should be between min_value and max_value.
Args:
limit (Union[str, int]): Value can be string number or int.
min_value (int): Limit should be greater than or equal to this value. Default: 1.
max_value (int): Limit should be less than or equal to this value. Default: 1000.
default_value (int): Default value for limit. Default: 100.
Returns:
int, limit.
"""
if limit is None:
return default_value
limit = to_int(limit, 'limit')
if limit < min_value or limit > max_value:
raise ParamValueError("'limit' should in [{}, {}].".format(min_value, max_value))
return limit
@classmethod
def check_param_empty(cls, **kwargs):
"""
Check that required parameters are not empty.
Args:
kwargs (Any): Parameters to check; non-numeric falsy values are treated as missing.
Raises:
ParamMissError: When param missing.
"""
for key, value in kwargs.items():
# When value is 0, 0.0 or False, it is not empty.
if isinstance(value, Number):
continue
if not value:
raise ParamMissError(key)
@classmethod
def check_plugin_name(cls, plugin_name):
"""
Check plugin name.
Args:
plugin_name (str): The plugin name.
Raises:
ParamValueError: When plugin name is not valid.
"""
plugin_name_list = PluginNameEnum.list_members()
if plugin_name not in plugin_name_list:
raise ParamValueError("'plugin_name' only can be one of {}"
"".format(plugin_name_list))
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Base file system."""
from abc import ABC, abstractmethod
from collections import namedtuple
StatInfo = namedtuple("StatInfo", ["size", "mtime"])
class BaseFileSystem(ABC):
"""Base class for file systems."""
@abstractmethod
def list_dir(self, path):
"""
Abstract method for listing directories by path.
Args:
path (str): Directory path or file path.
"""
@abstractmethod
def is_dir(self, path):
"""
Abstract method for determining if it is a directory.
Args:
path (str): Directory path or file path.
"""
@abstractmethod
def exists(self, path):
"""
Abstract method for determining if it exists.
Args:
path (str): Directory path or file path.
"""
@abstractmethod
def file_stat(self, file_path):
"""
Abstract method for getting file stat information.
Args:
file_path (str): File path.
"""
@abstractmethod
def join(self, path, *paths):
"""
Abstract method for combining paths.
Args:
path (str): Directory path.
*paths (str): Path or paths.
"""
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""File handler for file operations."""
from mindinsight.utils.exceptions import PathNotExistError
from mindinsight.datavisual.common.log import logger
from mindinsight.datavisual.utils.tools import to_str
from mindinsight.datavisual.data_access.local_file_system import LocalFileSystem
_DEFAULT_BUFFER_SIZE = 24 * 1024 * 1024
# _FILE_SYSTEMS, key: FileProtocolHead, value: FileSystem
_FILE_SYSTEMS = dict()
_FILE_SYSTEMS[""] = LocalFileSystem()
class FileHandler:
"""File handler."""
def __init__(self, file_path, mode='rb'):
"""
Init FileHandler.
Args:
file_path (str): File path.
mode (Literal['r', 'rb', 'br', 'w', 'wb', 'bw']): It must be
in ['r', 'rb', 'br', 'w', 'wb', 'bw'].
"""
logger.debug("The __init__ method enter, param: file_path=%s"
"mode=%s", file_path, mode)
if mode not in ('r', 'rb', 'br', 'w', 'wb', 'bw'):
raise ValueError("mode %s is not supported by FileHandler." % mode)
self._file_path = to_str(file_path)
self._file_system = self.get_file_system(self._file_path)
self._buff_chunk_size = _DEFAULT_BUFFER_SIZE
self._buff = None
self._buff_offset = 0
self._offset = 0
self._binary_mode = 'b' in mode
@staticmethod
def get_file_system(path):
"""
Get file system object from path.
Args:
path (str): Directory path or file path.
Returns:
BaseFileSystem, a file system object.
"""
path = to_str(path)
prefix_index = path.find("://")
prefix = path[:prefix_index] if prefix_index >= 0 else ""
file_system = _FILE_SYSTEMS.get(prefix, None)
if file_system is None:
raise ValueError("No filesystem can be found for prefix %s" % prefix)
return file_system
@staticmethod
def walk(node, forward=True, onerror=None):
"""
Traverse the directory and file tree rooted at the given path.
Args:
node (str): Current path.
forward (bool): If True, it will return the sub-directories and files in the top-level
directory first and then iterate the files in the sub-directories. Default: True.
onerror (Optional[Callable]): If None, it indicates that errors during file traversal
will be ignored. Default: None.
Yields:
Tuple, (node, sub_dirs, files).
"""
logger.debug("The walk method enter, param: node=%s, "
"forward=%s, onerror=%s.", node, forward, type(onerror))
file_system = FileHandler.get_file_system(node)
node = to_str(node)
dirs = []
try:
dirs = file_system.list_dir(node)
except PathNotExistError as err:
if onerror:
onerror(err)
else:
logger.warning("Get dir list error, dir_path=%s error=%s.", node, str(err))
return
sub_dirs, files = [], []
for item in dirs:
full_path = file_system.join(node, to_str(item))
if file_system.is_dir(full_path):
sub_dirs.append(item)
else:
files.append(item)
result = (node, sub_dirs, files)
if forward:
logger.debug("The walk method return, pre result=%s.", result)
yield result
for subdir in sub_dirs:
joined_subdir = file_system.join(node, to_str(subdir))
for sub_results in FileHandler.walk(joined_subdir, forward, onerror):
yield sub_results
if not forward:
logger.debug("The walk method return, post result=%s.", result)
yield result
def read(self, size=None):
"""
Read bytes from the buffer or the file by size.
Read from the buffer first. If there is not enough data in the buffer,
data will be read from the file system.
Args:
size (Union[None, int]): Number of bytes to read. If None, read the whole file. Default: None.
Returns:
Union[bytes, str], the content read (bytes in binary mode, str otherwise).
"""
if size is None:
result = self._file_system.read(self._file_path, self._binary_mode)
self._offset = len(result)
return result
result = None
if self._buff and len(self._buff) > self._buff_offset:
read_offset = self._buff_offset + size if size is not None else len(self._buff)
result = self._read_buffer_by_offset(read_offset)
if size is not None:
if len(result) == size:
return result
size -= len(result)
read_size = max(self._buff_chunk_size, size) if size is not None else None
self._buff = self._file_system.read(self._file_path, self._binary_mode,
read_size, self._offset)
self._buff_offset = 0
read_offset = size if size is not None else len(self._buff)
chunk = self._read_buffer_by_offset(read_offset)
result = result + chunk if result else chunk
return result
def _read_buffer_by_offset(self, new_buff_offset):
"""
Read buffer by offset.
Args:
new_buff_offset (int): Ending offset to read.
Returns:
str, bytes from old offset to new offset.
"""
old_buff_offset = self._buff_offset
read_size = min(len(self._buff), new_buff_offset) - old_buff_offset
self._offset += read_size
self._buff_offset += read_size
return self._buff[old_buff_offset:old_buff_offset + read_size]
def reset_offset(self, offset):
"""
Reset offset and buff_offset, clean buff.
Args:
offset (int): Offset.
"""
self._offset = offset
self._buff = None
self._buff_offset = 0
@staticmethod
def list_dir(path):
"""
List directories by path.
Args:
path (str): Directory path or file path.
Returns:
list[str], directories.
"""
file_system = FileHandler.get_file_system(path)
return file_system.list_dir(path)
@staticmethod
def is_dir(path):
"""
Determine if it is a directory.
Args:
path (str): Directory path or file path.
Returns:
bool, if it is a directory path, return True.
"""
file_system = FileHandler.get_file_system(path)
return file_system.is_dir(path)
@staticmethod
def is_file(path):
"""
Determine if it is a file.
Args:
path (str): Directory path or file path.
Returns:
bool, if it is a file path, return True.
"""
file_system = FileHandler.get_file_system(path)
return file_system.is_file(path)
@staticmethod
def exists(path):
"""
Determine if it exists.
Args:
path (str): Directory path or file path.
Returns:
bool, if it exists, return True.
"""
file_system = FileHandler.get_file_system(path)
return file_system.exists(path)
@staticmethod
def file_stat(file_path):
"""
Get file stat information.
Args:
file_path (str): File path.
Returns:
Namedtuple, the (size, mtime) of the file.
"""
file_system = FileHandler.get_file_system(file_path)
return file_system.file_stat(file_path)
@staticmethod
def join(path, *paths):
"""
Join paths.
Args:
path (str): Directory path.
paths (str): Path or paths.
Returns:
str, the joined path.
"""
file_system = FileHandler.get_file_system(path)
return file_system.join(path, *paths)
@property
def offset(self):
"""Get offset."""
return self._offset
@property
def file_path(self):
"""Get file path."""
return self._file_path
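# Usage sketch (illustrative, not part of this module): walking a summary
# directory tree with the static helpers above. The base path below is an
# assumption for demonstration only.
def _demo_walk(base_dir='./summaries'):
    for node, sub_dirs, files in FileHandler.walk(base_dir):
        # Each yield mirrors os.walk: the current path, its child
        # directories, and its child files.
        print(node, sub_dirs, files)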
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Local File System."""
import io
import os
from mindinsight.datavisual.common import exceptions
from mindinsight.datavisual.utils.tools import to_str
from mindinsight.datavisual.data_access.base_file_system import BaseFileSystem
from mindinsight.datavisual.data_access.base_file_system import StatInfo
from mindinsight.utils.exceptions import PathNotExistError
class LocalFileSystem(BaseFileSystem):
"""Local file system."""
def list_dir(self, path):
"""
List directories by path.
Args:
path (str): Directory path or file path.
Returns:
list[str], directories.
"""
path = to_str(path)
if not self.is_dir(path):
raise exceptions.PathNotDirectoryError("Path is %s." % path)
return os.listdir(path)
def is_dir(self, path):
"""
Determine if it is a directory.
Args:
path (str): Directory path or file path.
Returns:
bool, if it is a directory path, return True.
"""
return os.path.isdir(to_str(path))
def is_file(self, path):
"""
Determine if it is a file.
Args:
path (str): Directory path or file path.
Returns:
bool, if it is a file path, return True.
"""
return os.path.isfile(to_str(path))
def exists(self, path):
"""
Determine if it exists.
Args:
path (str): Directory path or file path.
Returns:
bool, if it exists, return True.
"""
return os.path.exists(to_str(path))
def file_stat(self, file_path):
"""
Get file stat information.
Args:
file_path (str): File path.
Returns:
Namedtuple, the (size, mtime) of the file.
"""
try:
file_info = os.stat(to_str(file_path))
except OSError:
raise PathNotExistError("File %s does not exist." % file_path)
return StatInfo(size=file_info.st_size, mtime=file_info.st_mtime)
@staticmethod
def read_access(file_path):
"""
Determine if it has read permission.
Args:
file_path (str): File path.
Returns:
bool, if it has read permission, return True.
"""
return os.access(to_str(file_path), os.R_OK)
def join(self, path, *paths):
"""
Join paths.
Args:
path (str): Directory path.
paths (str): Path or paths.
Returns:
str, the joined path.
"""
return os.path.join(path, *paths)
@staticmethod
def read(file_path, binary_mode=False, size=None, offset=None):
"""
Read file.
Args:
file_path (str): File path.
binary_mode (bool): If true, mode will be 'rb'. Else, 'r'.
size (int): Size of bytes to read.
offset (int): Offset of file to read.
Returns:
bytes, the content read.
"""
mode = "rb" if binary_mode else "r"
encoding = None if binary_mode else "utf8"
with io.open(file_path, mode, encoding=encoding) as file:
if offset is not None:
file.seek(offset)
if size is not None:
return file.read(size)
return file.read()
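# Usage sketch (illustrative, not part of this module): reading a file in
# fixed-size chunks via (offset, size), which is the access pattern that
# FileHandler builds its buffering on. The chunk size is an assumption.
def _demo_chunked_read(file_path, chunk_size=1024):
    file_system = LocalFileSystem()
    offset = 0
    while True:
        chunk = file_system.read(file_path, binary_mode=True, size=chunk_size, offset=offset)
        if not chunk:
            break
        offset += len(chunk)
        yield chunk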
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
DataLoader is an adapter for all other loaders.
This module can identify what loader should be used to load data.
"""
from mindinsight.datavisual.common.log import logger
from mindinsight.datavisual.data_transform.ms_data_loader import MSDataLoader
from mindinsight.datavisual.common import exceptions
class DataLoader:
"""
The adapter of all kinds of loaders.
Args:
summary_dir (str): A directory path.
"""
def __init__(self, summary_dir):
self._summary_dir = summary_dir
self._loader = None
def load(self):
"""Load the data when loader is exist."""
if self._loader is None:
ms_dataloader = MSDataLoader(self._summary_dir)
loaders = [ms_dataloader]
for loader in loaders:
if loader.filter_valid_files():
self._loader = loader
break
if self._loader is None:
logger.warning("No valid files can be loaded, summary_dir: %s.", self._summary_dir)
raise exceptions.SummaryLogPathInvalid()
self._loader.load()
def get_events_data(self):
"""
Get events data from log file.
Returns:
Optional[EventsData], None or events data.
"""
return self._loader.get_events_data()
def has_valid_files(self):
"""
Check the directory for valid files.
Returns:
bool, if the directory has valid files, return True.
"""
ms_dataloader = MSDataLoader(self._summary_dir)
return bool(ms_dataloader.filter_valid_files())
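# Usage sketch (illustrative, not part of this module): loading one summary
# directory. The directory path below is an assumption for demonstration.
def _demo_data_loader(summary_dir='./summaries/job_1'):
    loader = DataLoader(summary_dir)
    if not loader.has_valid_files():
        return None
    loader.load()                    # selects MSDataLoader for valid files
    return loader.get_events_data()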
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Management of all events data.
This module provides access to all loaders.
It can read events data through the DataLoader.
This module also acts as a thread pool manager.
"""
import threading
import time
from concurrent.futures import ThreadPoolExecutor, wait, ALL_COMPLETED
from mindinsight.conf import settings
from mindinsight.datavisual.common import exceptions
from mindinsight.datavisual.common.log import logger
from mindinsight.datavisual.common.enums import DataManagerStatus
from mindinsight.datavisual.common.enums import PluginNameEnum
from mindinsight.datavisual.data_transform.loader_generators.loader_generator import MAX_DATA_LOADER_SIZE
from mindinsight.datavisual.data_transform.loader_generators.data_loader_generator import DataLoaderGenerator
from mindinsight.utils.exceptions import MindInsightException
from mindinsight.utils.exceptions import ParamValueError
class DataManager:
"""
DataManager manages a pool of loaders which help access events data.
Each loader helps deal with the data of the events.
A loader corresponds to an events_data.
The DataManager builds a pool including all the data_loaders.
Each data_loader provides an extracting
method to get the information of events.
"""
def __init__(self, loader_generators):
"""
Initialize the pool of loader and the dict of name-to-path.
Args:
loader_generators (list[LoaderGenerator]): Loader generators help generate loaders.
self._status: Refer `datavisual.common.enums.DataManagerStatus`.
self._loader_pool: {'loader_id': <LoaderStruct>}.
"""
self._loader_pool = {}
self._deleted_id_list = []
self._status = DataManagerStatus.INIT.value
self._status_mutex = threading.Lock()
self._loader_pool_mutex = threading.Lock()
self._max_threads_count = 30
self._reload_interval = 3
self._loader_generators = loader_generators
def _add_loader(self, loader):
"""
Add a loader to load data.
Args:
loader (LoaderStruct): An object of `Loader`.
"""
if len(self._loader_pool) >= MAX_DATA_LOADER_SIZE:
delete_number = len(self._loader_pool) - MAX_DATA_LOADER_SIZE + 1
sorted_loaders = sorted(self._loader_pool.items(),
key=lambda loader: loader[1].latest_update_time)
for index in range(delete_number):
delete_loader_id = sorted_loaders[index][0]
self._delete_loader(delete_loader_id)
self._loader_pool.update({loader.loader_id: loader})
def _delete_loader(self, loader_id):
"""
Delete loader from loader pool by loader id.
Args:
loader_id (str): ID of loader.
"""
if self._loader_pool.get(loader_id) is not None:
logger.debug("delete loader %s", loader_id)
self._loader_pool.pop(loader_id)
def _execute_loader(self, loader_id):
"""
Load data from data_loader.
If something goes wrong during loading, log it and delete the loader.
Args:
loader_id (str): An ID for `Loader`.
"""
try:
with self._loader_pool_mutex:
loader = self._loader_pool.get(loader_id, None)
if loader is None:
logger.debug("Loader %r has been deleted, will not load data.", loader_id)
return
loader.data_loader.load()
except MindInsightException as ex:
logger.warning("Data loader %r load data failed. "
"Delete data_loader. Detail: %s", loader_id, ex)
with self._loader_pool_mutex:
self._delete_loader(loader_id)
def start_load_data(self,
reload_interval=settings.RELOAD_INTERVAL,
max_threads_count=MAX_DATA_LOADER_SIZE):
"""
Start threads for loading data.
Args:
reload_interval (int): Interval in seconds between data reloads; 0 means load only once.
max_threads_count (int): Max number of threads of execution.
"""
logger.info("Start to load data, reload_interval: %s, "
"max_threads_count: %s.", reload_interval, max_threads_count)
DataManager.check_reload_interval(reload_interval)
DataManager.check_max_threads_count(max_threads_count)
self._reload_interval = reload_interval
self._max_threads_count = max_threads_count
thread = threading.Thread(target=self._reload_data,
name='start_load_data_thread')
thread.daemon = True
thread.start()
def _reload_data(self):
"""This function periodically loads the data."""
# Let gunicorn load other modules first.
time.sleep(1)
while True:
self._load_data()
if not self._reload_interval:
break
time.sleep(self._reload_interval)
def reload_data(self):
"""
Reload the data once.
This function needs to be used after `start_load_data` function.
"""
logger.debug("start to reload data")
thread = threading.Thread(target=self._load_data,
name='reload_data_thread')
thread.daemon = False
thread.start()
def _load_data(self):
"""This function will load data once and ignore it if the status is loading."""
logger.info("Start to load data, reload interval: %r.", self._reload_interval)
with self._status_mutex:
if self.status == DataManagerStatus.LOADING.value:
logger.debug("Current status is %s , will ignore to load data.", self.status)
return
self.status = DataManagerStatus.LOADING.value
self._generate_loaders()
self._execute_load_data()
if not self._loader_pool:
self.status = DataManagerStatus.INVALID.value
else:
self.status = DataManagerStatus.DONE.value
logger.info("Load event data end, status: %r, and loader pool size is %r.",
self.status, len(self._loader_pool))
def _generate_loaders(self):
"""This function generates the loader from given path."""
loader_dict = {}
for generator in self._loader_generators:
loader_dict.update(generator.generate_loaders(self._loader_pool))
sorted_loaders = sorted(loader_dict.items(), key=lambda loader: loader[1].latest_update_time)
latest_loaders = sorted_loaders[-MAX_DATA_LOADER_SIZE:]
self._deal_loaders(latest_loaders)
def _deal_loaders(self, latest_loaders):
"""
This function determines which loaders are kept, removed, or added.
It is based on the given list of loaders.
Args:
latest_loaders (list[tuple]): A list of (loader_id, LoaderStruct) pairs.
"""
with self._loader_pool_mutex:
for loader_id, loader in latest_loaders:
if self._loader_pool.get(loader_id, None) is None:
self._add_loader(loader)
continue
# If this loader was updated manually before,
# its latest_update_time may be bigger than the update time in the summary.
if self._loader_pool[loader_id].latest_update_time < loader.latest_update_time:
self._update_loader_latest_update_time(loader_id, loader.latest_update_time)
def _execute_load_data(self):
"""Load data through multiple threads."""
threads_count = self._get_threads_count()
if not threads_count:
logger.info("Can not find any valid train log path to load, loader pool is empty.")
return
logger.info("Start to execute load data. threads_count: %s.", threads_count)
with ThreadPoolExecutor(max_workers=threads_count) as executor:
futures = []
loader_pool = self._get_snapshot_loader_pool()
for loader_id in loader_pool:
future = executor.submit(self._execute_loader, loader_id)
futures.append(future)
wait(futures, return_when=ALL_COMPLETED)
@staticmethod
def check_reload_interval(reload_interval):
"""
Check reload interval is valid.
Args:
reload_interval (int): Reload interval >= 0.
"""
if not isinstance(reload_interval, int):
raise ParamValueError("The value of reload interval should be integer.")
if reload_interval < 0:
raise ParamValueError("The value of reload interval should be >= 0.")
@staticmethod
def check_max_threads_count(max_threads_count):
"""
Threads count should be an integer greater than 0.
Args:
max_threads_count (int): Should be > 0.
"""
if not isinstance(max_threads_count, int):
raise ParamValueError("The value of max threads count should be integer.")
if max_threads_count <= 0:
raise ParamValueError("The value of max threads count should be > 0.")
def _get_threads_count(self):
"""
Compute the number of worker threads, capped by the current loader pool size.
Returns:
int, number of threads.
"""
threads_count = min(self._max_threads_count, len(self._loader_pool))
return threads_count
def get_train_job_by_plugin(self, train_id, plugin_name):
"""
Get a train job by train job id.
If the given train job does not have the given plugin data, the tag list will be empty.
Args:
train_id (str): Get train job info by the given id.
plugin_name (str): Get tags by given plugin.
Returns:
TypedDict('TrainJobEntity', {'id': str, 'name': str, 'tags': List[str]}),
a train job object.
"""
self._check_status_valid()
self._check_train_job_exist(train_id, self._loader_pool)
loader = self._get_loader(train_id)
if loader is None:
logger.warning("No valid summary log in train job %s, "
"or it is not in the cache.", train_id)
return None
name = loader.name
data_loader = loader.data_loader
tags = []
try:
events_data = data_loader.get_events_data()
tags = events_data.list_tags_by_plugin(plugin_name)
except KeyError:
logger.debug("Plugin name %r does not exist "
"in train job %r, and set tags to empty list.", plugin_name, name)
except AttributeError:
logger.debug("Train job %r has been deleted or it has not loaded data, "
"and set tags to empty list.", name)
result = dict(id=train_id, name=name, tags=tags)
return result
def delete_train_job(self, train_id):
"""
Delete train job with a train id.
Args:
train_id (str): ID for train job.
"""
with self._loader_pool_mutex:
self._delete_loader(train_id)
def list_tensors(self, train_id, tag):
"""
List tensors of the given train job and tag.
If no tensor can be found by the given tag, an exception will be raised.
Args:
train_id (str): ID for train job.
tag (str): The tag name.
Returns:
NamedTuple, the tuple format is `collections.namedtuple('_Tensor', ['wall_time', 'event_step', 'value'])`.
the value will contain the given tag data.
"""
self._check_status_valid()
loader_pool = self._get_snapshot_loader_pool()
if not self._is_loader_in_loader_pool(train_id, loader_pool):
raise ParamValueError("Can not find any data in loader pool about the train job.")
data_loader = loader_pool[train_id].data_loader
events_data = data_loader.get_events_data()
try:
tensors = events_data.tensors(tag)
except KeyError:
error_msg = "Can not find any data in this train job by given tag."
raise ParamValueError(error_msg)
return tensors
def _check_train_job_exist(self, train_id, loader_pool):
"""
Check that the train job exists; if it does not, an exception will be raised.
Args:
train_id (str): The given train job id.
loader_pool (dict[str, LoaderStruct]): Refer to self._loader_pool.
Raises:
ParamValueError: Can not find the train job in data manager.
"""
is_exist = False
if train_id in loader_pool:
return
for generator in self._loader_generators:
if generator.check_train_job_exist(train_id):
is_exist = True
break
if not is_exist:
raise ParamValueError("Can not find the train job in data manager.")
def _is_loader_in_loader_pool(self, train_id, loader_pool):
"""
Check whether the train job exists in the loader pool; return True if it does, otherwise False.
Args:
train_id (str): The given train job id.
loader_pool (dict): See self._loader_pool.
Returns:
bool, if loader in loader pool, return True.
"""
if train_id in loader_pool:
return True
return False
def _get_snapshot_loader_pool(self):
"""
Create a snapshot of data loader pool to avoid concurrent mutation and iteration issues.
Returns:
dict, a copy of `self._loader_pool`.
"""
with self._loader_pool_mutex:
return dict(self._loader_pool)
def _check_status_valid(self):
"""Check if the status is valid to load data."""
if self.status == DataManagerStatus.INIT.value:
raise exceptions.SummaryLogIsLoading("Data is being loaded, "
"current status: %s." % self._status)
def get_single_train_job(self, train_id, manual_update=False):
"""
Get train job by train ID.
Args:
train_id (str): Train ID for train job.
manual_update (bool): If manual update, True.
Returns:
dict, a single train job; if no data can be found, None is returned.
"""
self._check_status_valid()
self._check_train_job_exist(train_id, self._loader_pool)
loader = self._get_loader(train_id, manual_update)
if loader is None:
logger.warning("No valid summary log in train job %s, "
"or it is not in the cache.", train_id)
return None
train_job = loader.to_dict()
train_job.pop('data_loader')
plugin_data = {}
for plugin_name in PluginNameEnum.list_members():
job = self.get_train_job_by_plugin(train_id, plugin_name=plugin_name)
if job is None:
plugin_data[plugin_name] = []
else:
plugin_data[plugin_name] = job['tags']
train_job.update({'tag_mapping': plugin_data})
return train_job
def _get_loader(self, train_id, manual_update=False):
"""
Get loader by train id.
Args:
train_id (str): Train Id.
manual_update (bool): True if the update is triggered manually, otherwise False.
Returns:
LoaderStruct, the loader.
"""
loader = None
is_reload = False
with self._loader_pool_mutex:
if self._is_loader_in_loader_pool(train_id, self._loader_pool):
loader = self._loader_pool.get(train_id)
if manual_update and loader is None:
for generator in self._loader_generators:
tmp_loader = generator.generate_loader_by_train_id(train_id)
if loader and loader.latest_update_time > tmp_loader.latest_update_time:
continue
loader = tmp_loader
if loader is None:
return None
self._add_loader(loader)
is_reload = True
if manual_update:
self._update_loader_latest_update_time(loader.loader_id)
if is_reload:
self.reload_data()
return loader
def _update_loader_latest_update_time(self, loader_id, latest_update_time=None):
"""
Update loader with latest_update_time.
Args:
loader_id (str): ID of loader.
latest_update_time (float): Timestamp.
"""
if latest_update_time is None:
latest_update_time = time.time()
self._loader_pool[loader_id].latest_update_time = latest_update_time
@property
def status(self):
"""
Get the status of data manager.
Returns:
DataManagerStatus, the status of data manager.
"""
return self._status
@status.setter
def status(self, status):
"""Set data manger status."""
self._status = status
_loader_generators = [DataLoaderGenerator(settings.SUMMARY_BASE_DIR)]
DATA_MANAGER = DataManager(_loader_generators)
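# Usage sketch (illustrative, not part of this module): a typical startup
# sequence. The train id below is an assumption for demonstration; queries
# raise SummaryLogIsLoading while the manager status is still INIT.
def _demo_data_manager():
    DATA_MANAGER.start_load_data(reload_interval=settings.RELOAD_INTERVAL)
    time.sleep(5)                    # give the background thread time to load
    return DATA_MANAGER.get_single_train_job('./job_1', manual_update=False)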
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Takes a generator of values, and collects them for a frontend."""
import collections
import threading
from mindinsight.datavisual.common.enums import PluginNameEnum
from mindinsight.datavisual.data_transform import reservoir
from mindinsight.conf import settings
# Type of the tensor event from external component
_Tensor = collections.namedtuple('_Tensor', ['wall_time', 'step', 'value'])
TensorEvent = collections.namedtuple(
'TensorEvent', ['wall_time', 'step', 'tag', 'plugin_name', 'value'])
# config for `EventsData`
_DEFAULT_STEP_SIZES_PER_TAG = settings.DEFAULT_STEP_SIZES_PER_TAG
CONFIG = {
'max_total_tag_sizes': settings.MAX_TAG_SIZE_PER_EVENTS_DATA,
'max_tag_sizes_per_plugin':
{
PluginNameEnum.GRAPH.value: settings.MAX_GRAPH_TAG_SIZE,
},
'max_step_sizes_per_tag':
{
PluginNameEnum.SCALAR.value: settings.MAX_SCALAR_STEP_SIZE_PER_TAG,
PluginNameEnum.IMAGE.value: settings.MAX_IMAGE_STEP_SIZE_PER_TAG,
PluginNameEnum.GRAPH.value: settings.MAX_GRAPH_STEP_SIZE_PER_TAG,
}
}
class EventsData:
"""
EventsData is an event data manager.
It manages the log events generated during a training process.
The log event records information such as graph, tag, and tensor.
Data such as tensor can be retrieved based on its tag.
"""
def __init__(self):
self._config = CONFIG
self._max_step_sizes_per_tag = self._config['max_step_sizes_per_tag']
self._tags = list()
self._reservoir_by_tag = {}
self._reservoir_mutex_lock = threading.Lock()
self._tags_by_plugin = collections.defaultdict(list)
self._tags_by_plugin_mutex_lock = collections.defaultdict(threading.Lock)
def add_tensor_event(self, tensor_event):
"""
Add a new tensor event to the tensors_data.
Args:
tensor_event (TensorEvent): Refer to `TensorEvent` object.
"""
if not isinstance(tensor_event, TensorEvent):
raise TypeError('Expect to get data of type `TensorEvent`.')
tag = tensor_event.tag
plugin_name = tensor_event.plugin_name
if tag not in set(self._tags):
deleted_tag = self._check_tag_out_of_spec(plugin_name)
if deleted_tag is not None:
self.delete_tensor_event(deleted_tag)
self._tags.append(tag)
with self._tags_by_plugin_mutex_lock[plugin_name]:
if tag not in self._tags_by_plugin[plugin_name]:
self._tags_by_plugin[plugin_name].append(tag)
with self._reservoir_mutex_lock:
if tag not in self._reservoir_by_tag:
reservoir_size = self._get_reservoir_size(tensor_event.plugin_name)
self._reservoir_by_tag[tag] = reservoir.Reservoir(reservoir_size)
tensor = _Tensor(wall_time=tensor_event.wall_time,
step=tensor_event.step,
value=tensor_event.value)
if self._is_out_of_order_step(tensor_event.step, tensor_event.tag):
self.purge_reservoir_data(tensor_event.step, self._reservoir_by_tag[tag])
self._reservoir_by_tag[tag].add_sample(tensor)
def delete_tensor_event(self, tag):
"""
This function will delete tensor event by the given tag in memory record.
Args:
tag (str): The tag name.
"""
self._tags.remove(tag)
for plugin_name, lock in self._tags_by_plugin_mutex_lock.items():
with lock:
if tag in self._tags_by_plugin[plugin_name]:
self._tags_by_plugin[plugin_name].remove(tag)
break
with self._reservoir_mutex_lock:
if tag in self._reservoir_by_tag:
self._reservoir_by_tag.pop(tag)
def list_tags_by_plugin(self, plugin_name):
"""
Return all the tag names of the plugin.
Args:
plugin_name (str): The Plugin name.
Returns:
list[str], tags of the plugin.
Raises:
KeyError: when plugin name could not be found.
"""
if plugin_name not in self._tags_by_plugin:
raise KeyError('Plugin %r could not be found.' % plugin_name)
with self._tags_by_plugin_mutex_lock[plugin_name]:
# Return a snapshot to avoid concurrent mutation and iteration issues.
return list(self._tags_by_plugin[plugin_name])
def tensors(self, tag):
"""
Return all tensors of the tag.
Args:
tag (str): The tag name.
Returns:
list[_Tensor], the list of tensors to the tag.
"""
if tag not in self._reservoir_by_tag:
raise KeyError('TAG %r could not be found.' % tag)
return self._reservoir_by_tag[tag].samples()
def _is_out_of_order_step(self, step, tag):
"""
If the current step is not larger than the latest one, it is an out-of-order step.
Args:
step (int): Check if the given step out of order.
tag (str): The checked tensor of the given tag.
Returns:
bool, True if the given step is out of order.
"""
if self.tensors(tag):
tensors = self.tensors(tag)
last_step = tensors[-1].step
if step <= last_step:
return True
return False
@staticmethod
def purge_reservoir_data(start_step, tensor_reservoir):
"""
Purge all tensor events that have out-of-order steps relative to the given start step.
Args:
start_step (int): Purge start step. All previously seen events with
a step greater than or equal to start_step will be purged.
tensor_reservoir (Reservoir): A `Reservoir` object.
Returns:
int, the number of items removed.
"""
cnt_out_of_order = tensor_reservoir.remove_sample(lambda x: x.step < start_step)
return cnt_out_of_order
def _get_reservoir_size(self, plugin_name):
max_step_sizes_per_tag = self._config['max_step_sizes_per_tag']
return max_step_sizes_per_tag.get(plugin_name, _DEFAULT_STEP_SIZES_PER_TAG)
def _check_tag_out_of_spec(self, plugin_name):
"""
Check whether the tag is out of specification.
Args:
plugin_name (str): The given plugin name.
Returns:
Union[str, None], if out of specification, the first (oldest) tag is returned; otherwise None.
"""
tag_specifications = self._config['max_tag_sizes_per_plugin'].get(plugin_name)
if tag_specifications is not None and len(self._tags_by_plugin[plugin_name]) >= tag_specifications:
deleted_tag = self._tags_by_plugin[plugin_name][0]
return deleted_tag
if len(self._tags) >= self._config['max_total_tag_sizes']:
deleted_tag = self._tags[0]
return deleted_tag
return None
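# Usage sketch (illustrative, not part of this module): feeding one scalar
# event into EventsData and reading it back. The tag and values are
# assumptions for demonstration only.
def _demo_events_data():
    events = EventsData()
    event = TensorEvent(wall_time=0.0, step=1, tag='loss',
                        plugin_name=PluginNameEnum.SCALAR.value, value=0.5)
    events.add_tensor_event(event)
    return events.tensors('loss')    # -> [_Tensor(wall_time=0.0, step=1, value=0.5)]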
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""This file is used to define the graph."""
from .msgraph import MSGraph
from .node import NodeTypeEnum
__all__ = ['MSGraph', 'NodeTypeEnum']
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
This file is used to define the basic graph.
"""
import copy
import time
from mindinsight.datavisual.common.log import logger
from mindinsight.datavisual.common import exceptions
from .node import NodeTypeEnum
from .node import Node
class EdgeTypeEnum:
"""Node edge type enum."""
control = 'control'
data = 'data'
class DataTypeEnum:
"""Data type enum."""
DT_TENSOR = 13
class Graph:
"""The `Graph` object is used to describe a graph file."""
MIN_POLYMERIC_NODE_COUNT = 5
def __init__(self):
        # Store normal nodes: leaf nodes and name scope nodes, excluding polymeric nodes.
        self._normal_nodes = {}
        # Store polymeric nodes.
        self._polymeric_nodes = {}
        # Store all leaf nodes resolved from the file.
        self._leaf_nodes = {}
# The format of node groups is {'group_name': {'node_name': <Node>}}
self._node_groups = {}
def exist_node(self, name):
"""
        Check whether the node exists in the graph.
        Args:
            name (str): The node name.
        Returns:
            bool, True if the node exists.
        """
        return self._normal_nodes.get(name) is not None
def get_normal_nodes(self, namescope=None):
"""
Get nodes by namescope.
Args:
namescope (str): A namescope of nodes.
Returns:
            list[dict], a list of serialized `Node` objects.
"""
nodes = []
if namescope is None:
for name, node in self._normal_nodes.items():
if '/' not in name:
# Get first layer nodes
nodes.append(node.to_dict())
return nodes
namescope = namescope + '/'
for name, node in self._normal_nodes.items():
if name.startswith(namescope) and '/' not in name.split(namescope)[1]:
nodes.append(node.to_dict())
return nodes
def get_polymeric_nodes(self, polymeric_scope):
"""
Get polymeric nodes by polymeric scope.
Args:
polymeric_scope (str): The polymeric scope name of nodes.
Returns:
            list[dict], a list of serialized `Node` objects.
"""
nodes = []
for node in self._polymeric_nodes.values():
if node.polymeric_scope_name == polymeric_scope:
nodes.append(node.to_dict())
return nodes
def search_node_names(self, content, offset, limit):
"""
Search node names by content.
Args:
            content (Union[str, None]): The key content of the node to search for;
                if None, all node names are returned.
            offset (int): The page offset. For example, offset 0 means the first page.
            limit (int): The maximum number of names returned per page.
Returns:
list[str], a list of node names.
"""
all_names = []
all_names.extend(list(self._normal_nodes.keys()))
all_names.extend(list(self._polymeric_nodes.keys()))
if content is not None:
content = content.lower()
catch_names = [name for name in all_names if content in name.lower()]
else:
catch_names = all_names
catch_names = sorted(catch_names)
real_offset = offset * limit
return catch_names[real_offset:real_offset+limit]
def search_single_node(self, node_name):
"""
        Search a node, and return the nodes of every layer down to this node.
Args:
node_name (str): The name of node.
Returns:
            dict, a dict object whose format is:
item_object = {'nodes': [<Node object>],
'scope_name': '<Node scope>',
'children': {<item_object>}}
"""
if node_name and self._polymeric_nodes.get(node_name) is None \
and self._normal_nodes.get(node_name) is None:
raise exceptions.NodeNotInGraphError()
response = {}
nodes = self.get_normal_nodes()
response.update({
'nodes': nodes,
'scope_name': '',
'children': {}
})
names = node_name.split('/')
children = response['children']
for i in range(1, len(names)+1):
if i == len(names):
polymeric_node = self._polymeric_nodes.get(node_name)
if polymeric_node:
polymeric_scope = polymeric_node.polymeric_scope_name
nodes = self.get_polymeric_nodes(polymeric_scope)
children.update({'nodes': nodes,
'scope_name': polymeric_scope,
'children': {}})
break
name_scope = '/'.join(names[:i])
nodes = self.get_normal_nodes(name_scope)
children.update({
'nodes': nodes,
'scope_name': name_scope,
'children': {}
})
children = children['children']
return response
    def _build_polymeric_nodes(self):
        """Build polymeric nodes."""
logger.debug("Start to build polymeric nodes")
self._find_polymeric_nodes()
group_count_map = {}
for group_name, group in self._node_groups.items():
name = group_name.split('/')[-1]
count = group_count_map.get(name, 0)
count += 1
group_count_map[name] = count
polymeric_node_name = group_name + '_{}_[{}]'.format(count, len(group))
polymeric_node = Node(polymeric_node_name, node_id=polymeric_node_name)
polymeric_node.node_type = NodeTypeEnum.POLYMERIC_SCOPE.value
polymeric_node.name_scope = '/'.join(group_name.split('/')[:-1])
polymeric_node.subnode_count = len(group)
for name_tmp, node_tmp in group.items():
node_tmp.polymeric_scope_name = polymeric_node_name
self._polymeric_nodes.update({name_tmp: node_tmp})
polymeric_node.update_input(node_tmp.input)
polymeric_node.update_output(node_tmp.output)
self._normal_nodes.update({polymeric_node_name: polymeric_node})
self._update_input_output()
def _find_polymeric_nodes(self):
"""Find polymeric nodes from node groups."""
node_groups = copy.deepcopy(self._node_groups)
for group_name, group in node_groups.items():
if len(group) < self.MIN_POLYMERIC_NODE_COUNT:
self._normal_nodes.update(group)
self._node_groups.pop(group_name)
continue
move_node_names = []
is_move_group = False
for node_name, group_node in group.items():
node_list = []
is_in_group = False
for dst_name in group_node.output:
node_tmp = self._leaf_nodes[dst_name]
node_list.append(node_tmp)
start = time.time()
run_count = 0
visit_nodes = {}
while node_list:
                    # Iterate to find whether an output path of a node in the group loops
                    # back into the group. Example: group A contains node_a; if there is a
                    # path like A/node_a -> B/node_b -> A/node_b, we remove node_a from
                    # group A.
node_tmp = node_list[0]
node_list = node_list[1:]
visit_nodes.update({node_tmp.name: True})
if node_tmp in group.values():
is_in_group = True
break
for dst_name_tmp in node_tmp.output:
run_count += 1
node_tmp = self._leaf_nodes[dst_name_tmp]
if visit_nodes.get(dst_name_tmp):
continue
node_list.append(node_tmp)
logger.debug("Find group %s node end, is_in_group: %s, use time: %s, "
"run count: %s.", group_name, is_in_group,
time.time() - start, run_count)
if is_in_group:
move_node_names.append(node_name)
if (len(group) - len(move_node_names)) < self.MIN_POLYMERIC_NODE_COUNT:
is_move_group = True
break
if is_move_group:
self._normal_nodes.update(group)
self._node_groups.pop(group_name)
else:
for name_tmp in move_node_names:
node_tmp = self._node_groups[group_name].pop(name_tmp)
self._normal_nodes.update({name_tmp: node_tmp})
    def _update_input_output(self):
        """Update the input and output attributes after building polymeric nodes."""
for node in self._normal_nodes.values():
for src_name, input_attr in node.input.items():
if self._polymeric_nodes.get(src_name):
input_attr['scope'] = NodeTypeEnum.POLYMERIC_SCOPE.value
node.update_input({src_name: input_attr})
for dst_name, output_attr in node.output.items():
if self._polymeric_nodes.get(dst_name):
output_attr['scope'] = NodeTypeEnum.POLYMERIC_SCOPE.value
node.update_output({dst_name: output_attr})
for node in self._polymeric_nodes.values():
for src_name, input_attr in node.input.items():
if self._polymeric_nodes.get(src_name):
input_attr['scope'] = NodeTypeEnum.POLYMERIC_SCOPE.value
node.update_input({src_name: input_attr})
for dst_name, output_attr in node.output.items():
if self._polymeric_nodes.get(dst_name):
output_attr['scope'] = NodeTypeEnum.POLYMERIC_SCOPE.value
node.update_output({dst_name: output_attr})
    def _calc_polymeric_input_output(self):
        """Calc polymeric input and output after building polymeric nodes."""
for name, node in self._normal_nodes.items():
polymeric_input = {}
for src_name in node.input:
src_node = self._polymeric_nodes.get(src_name)
if node.node_type == NodeTypeEnum.POLYMERIC_SCOPE.value:
src_name = src_name if not src_node else src_node.polymeric_scope_name
output_name = self._calc_dummy_node_name(name, src_name)
polymeric_input.update({output_name: {'edge_type': EdgeTypeEnum.data}})
continue
if not src_node:
continue
if not node.name_scope and src_node.name_scope:
                    # If the current node is in the first layer and the src node is
                    # not, the src node will not be a polymeric input of the current node.
continue
if node.name_scope == src_node.name_scope \
or node.name_scope.startswith(src_node.name_scope):
polymeric_input.update(
{src_node.polymeric_scope_name: {'edge_type': EdgeTypeEnum.data}})
node.update_polymeric_input(polymeric_input)
polymeric_output = {}
for dst_name in node.output:
dst_node = self._polymeric_nodes.get(dst_name)
if node.node_type == NodeTypeEnum.POLYMERIC_SCOPE.value:
dst_name = dst_name if not dst_node else dst_node.polymeric_scope_name
output_name = self._calc_dummy_node_name(name, dst_name)
polymeric_output.update({output_name: {'edge_type': EdgeTypeEnum.data}})
continue
if not dst_node:
continue
if not node.name_scope and dst_node.name_scope:
continue
if node.name_scope == dst_node.name_scope \
or node.name_scope.startswith(dst_node.name_scope):
polymeric_output.update(
{dst_node.polymeric_scope_name: {'edge_type': EdgeTypeEnum.data}})
node.update_polymeric_output(polymeric_output)
for name, node in self._polymeric_nodes.items():
polymeric_input = {}
for src_name in node.input:
output_name = self._calc_dummy_node_name(name, src_name)
polymeric_input.update({output_name: {'edge_type': EdgeTypeEnum.data}})
node.update_polymeric_input(polymeric_input)
            polymeric_output = {}
            for dst_name in node.output:
                output_name = self._calc_dummy_node_name(name, dst_name)
                polymeric_output.update({output_name: {'edge_type': EdgeTypeEnum.data}})
            node.update_polymeric_output(polymeric_output)
def _calc_dummy_node_name(self, current_node_name, other_node_name):
"""
Calc dummy node name.
Args:
current_node_name (str): The name of current node.
            other_node_name (str): The name of the other node.
Returns:
str, the dummy node name.
"""
name_tmp = other_node_name
if self._polymeric_nodes.get(other_node_name):
name_tmp = self._polymeric_nodes[other_node_name].polymeric_scope_name
name_tmp_list = name_tmp.split('/')
current_name_list = current_node_name.split('/')
index = 0
min_len = min(len(name_tmp_list), len(current_name_list))
for i in range(min_len):
index = i
if name_tmp_list[index] != current_name_list[index]:
break
dummy_node_name = '/'.join(name_tmp_list[:index+1])
return dummy_node_name
    def _build_name_scope_nodes(self):
        """Build name scope nodes from every node name."""
normal_nodes = dict(self._normal_nodes)
rename_node_names = {}
for name, node in normal_nodes.items():
name_list = name.split('/')
for i in range(1, len(name_list)):
name_scope = '/'.join(name_list[:i])
name_scope_node = self._normal_nodes.get(name_scope)
if name_scope_node is None:
name_scope_node = Node(name_scope, node_id=name_scope)
name_scope_node.node_type = NodeTypeEnum.NAME_SCOPE.value
name_scope_node.name_scope = '/'.join(name_list[:i-1])
elif name_scope_node.node_type != NodeTypeEnum.NAME_SCOPE.value:
                    # The name of this node conflicts with a name scope, so rename this node.
old_name = name_scope_node.name
old_names = name_scope_node.name.split('/')
old_names[-1] = f'({old_names[-1]})'
new_name = '/'.join(old_names)
name_scope_node.name = new_name
self._normal_nodes.pop(old_name)
self._normal_nodes.update({new_name: name_scope_node})
rename_node_names.update({old_name: new_name})
# create new namescope
name_scope_node = Node(name_scope, node_id=name_scope)
name_scope_node.node_type = NodeTypeEnum.NAME_SCOPE.value
name_scope_node.name_scope = '/'.join(name_list[:i-1])
                # Copy the inputs and outputs of this node onto the name scope node.
name_scope_with_slash = name_scope + '/'
for src_name, input_attr in node.input.items():
if src_name.startswith(name_scope_with_slash):
continue
name_scope_node.update_input({src_name: input_attr})
for dst_name, output_attr in node.output.items():
if dst_name.startswith(name_scope_with_slash):
continue
name_scope_node.update_output({dst_name: output_attr})
self._normal_nodes.update({name_scope: name_scope_node})
if rename_node_names:
# If existing nodes are renamed, the inputs and outputs of all nodes need to be refreshed
nodes = []
nodes.extend(self._normal_nodes.values())
nodes.extend(self._polymeric_nodes.values())
for node in nodes:
attrs = ['input', 'output', 'polymeric_input', 'polymeric_output']
for item in attrs:
tmp_dict = dict(getattr(node, item))
for name, value in tmp_dict.items():
new_name = rename_node_names.get(name, False)
if new_name:
getattr(node, item).pop(name)
getattr(node, f'update_{item}')({new_name: value})
self._calc_subnode_count()
    def _calc_subnode_count(self):
        """Calc the subnode count of every scope node."""
name_scope_mapping = {}
for node in self._normal_nodes.values():
if node.name_scope:
count = name_scope_mapping.get(node.name_scope, 0)
name_scope_mapping[node.name_scope] = count + 1
for name_scope, count in name_scope_mapping.items():
node = self._normal_nodes[name_scope]
node.subnode_count = count
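For reference, a minimal standalone restatement of the common-prefix rule implemented by `_calc_dummy_node_name` above, with the polymeric-name substitution step omitted; the node names are hypothetical.

def dummy_node_name(current_node_name, other_node_name):
    # Keep the other node's path up to and including the first
    # component that differs from the current node's path.
    other = other_node_name.split('/')
    current = current_node_name.split('/')
    index = 0
    for i in range(min(len(other), len(current))):
        index = i
        if other[index] != current[index]:
            break
    return '/'.join(other[:index + 1])

assert dummy_node_name('A/B/node_x', 'A/C/node_y') == 'A/C'
assert dummy_node_name('A/B/node_x', 'D/node_y') == 'D'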
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""This file is used to define the MindSpore graph."""
import re
import copy
from mindinsight.datavisual.common.log import logger
from .node import Node
from .node import NodeTypeEnum
from .graph import Graph
from .graph import EdgeTypeEnum
from .graph import DataTypeEnum
class MSGraph(Graph):
    """The object describes the MindSpore graph, and it is defined in the anf_ir proto file."""
def build_graph(self, graph_proto):
"""
        Build graph by graph proto which refers to `anf_ir_pb2.GraphProto`, and set status to loading.
Args:
graph_proto (anf_ir_pb2.GraphProto): Refer to `anf_ir_pb2.GraphProto`.
"""
logger.info("Start to build graph.")
self._build_leaf_nodes(graph_proto)
self._build_polymeric_nodes()
self._build_name_scope_nodes()
self._calc_polymeric_input_output()
logger.info("Build graph end, normal node count: %s, polymeric node "
"count: %s.", len(self._normal_nodes), len(self._polymeric_nodes))
def _build_leaf_nodes(self, graph_proto):
"""
        Build leaf nodes from graph proto.
        Leaf nodes include operation nodes, parameter nodes and const nodes.
Args:
graph_proto (anf_ir_pb2.model_proto.graph): Refer to anf_ir_pb2.model_proto.graph.
"""
logger.info("Start to build leaf nodes.")
leaf_node_id_map_name = {}
const_nodes_map = {}
for node_def in graph_proto.node:
node = self._parse_graph_proto_node(node_def)
leaf_node_id_map_name.update({node.node_id: node.name})
for parameter in graph_proto.parameters:
node = self._parse_graph_proto_parameter(parameter)
const_nodes_map.update({node.name: node})
for i, const in enumerate(graph_proto.const_vals):
node_id = 'const_{}'.format(i)
node = self._parse_graph_proto_const(const, node_id)
const_nodes_map.update({const.key: node})
self._calc_input(leaf_node_id_map_name, graph_proto, const_nodes_map)
self._calc_output()
        logger.info("Build leaf nodes end, normal nodes count: %s, group count: %s, "
                    "leaf node count: %s.", len(self._normal_nodes), len(self._node_groups),
                    len(self._leaf_nodes))
def _calc_input(self, leaf_node_id_map_name, graph_proto, const_nodes_map):
"""
Calc input for every leaf node.
Args:
leaf_node_id_map_name (dict[str, str]): Format is {'node_id': 'node_name'}.
graph_proto (anf_ir_pb2.model_proto.graph): See anf_ir_pb2.model_proto.graph.
const_nodes_map (dict[str, Node]): Format is {'node name': <Const node>}.
"""
logger.debug("Start to calc input.")
for node_def in graph_proto.node:
node_name = leaf_node_id_map_name[node_def.name]
node = self._leaf_nodes[node_name]
for input_def in node_def.input:
edge_type = EdgeTypeEnum.data
if input_def.type == "CONTROL_EDGE":
edge_type = EdgeTypeEnum.control
if const_nodes_map.get(input_def.name):
const_node = copy.deepcopy(const_nodes_map[input_def.name])
src_name = '{}/{}'.format(node.name_scope, input_def.name)
if not self._normal_nodes.get(src_name):
const_node.name = src_name
const_node.name_scope = node.name_scope
self._normal_nodes.update({src_name: const_node})
self._leaf_nodes.update({src_name: const_node})
src_node = self._leaf_nodes.get(src_name)
else:
src_name = leaf_node_id_map_name.get(input_def.name)
if not src_name:
                        logger.warning("The input_def name '%s' in node '%s' is invalid "
                                       "and will be ignored.", input_def.name, node_name)
continue
src_node = self._leaf_nodes.get(src_name)
if src_node is None:
logger.warning("The input '%s' in node '%s' is not in "
"leaf nodes.", src_name, node_name)
continue
input_item = {
src_name: {
"shape": src_node.shape,
"edge_type": edge_type,
"scope": NodeTypeEnum.NAME_SCOPE.value
}
}
node.update_input(input_item)
if self._normal_nodes.get(node_name):
self._normal_nodes[node_name] = node
else:
group_name = self._create_group_name(node.name_scope, node.node_type, node.name)
self._node_groups[group_name][node.name] = node
def _calc_output(self):
"""Calc output of every node."""
logger.debug("Start to calc output.")
for name, node in self._leaf_nodes.items():
if node.node_type == NodeTypeEnum.CONST.value:
continue
for src_name, input_attr in node.input.items():
src_node = self._leaf_nodes[src_name]
if src_node.node_type == NodeTypeEnum.CONST.value:
continue
if self._normal_nodes.get(src_name):
self._normal_nodes[src_name].update_output({name: input_attr})
else:
group_name = self._create_group_name(src_node.name_scope,
src_node.node_type, src_node.name)
self._node_groups[group_name][src_name].update_output({name: input_attr})
def _parse_graph_proto_node(self, node_def):
"""
        Parse `anf_ir_pb2.model_proto.graph.node_def`, and create a node.
Args:
node_def (anf_ir_pb2.model_proto.graph.node_def): Refer to anf_ir_pb2.model_proto.graph.node_def.
Returns:
Node, a `Node` object.
"""
        node_name = '/'.join([node_def.scope, node_def.op_type]) + node_def.name
node = Node(name=node_name, node_id=node_def.name)
node.node_type = node_def.op_type
logger.debug("Foreach graph proto nodes, node id: %s, node name: %s, node def name: %s, "
"input count: %s", node.node_id, node.name, node_def.name, len(node_def.input))
for attr in node_def.attribute:
node.update_attr({attr.name: str(attr.value)})
node.output_i = node_def.output_i
node.name_scope = node_def.scope
output_type = node_def.output_type
shape = self._parse_type_proto(output_type)
node.shape = shape
self._leaf_nodes.update({node.name: node})
group_name = self._create_group_name(node.name_scope, node.node_type, node.name)
if group_name is not None:
node_dict = self._node_groups.get(group_name, {})
node_dict.update({node.name: node})
self._node_groups.update({group_name: node_dict})
else:
self._normal_nodes.update({node.name: node})
return node
def _parse_graph_proto_parameter(self, parameter):
"""
Parse anf_ir_pb2.model_proto.graph.parameter, and create a parameter node.
Args:
parameter (anf_ir_pb2.model_proto.graph.parameter): Refer to anf_ir_pb2.model_proto.graph.parameter.
Returns:
Node, a `Node` object.
"""
node = Node(name=parameter.name, node_id=parameter.name)
node.node_type = NodeTypeEnum.PARAMETER.value
node.shape = self._parse_type_proto(parameter.type)
logger.debug("Foreach graph proto parameters, node id: %s, node name: %s, "
"node def name: %s", node.node_id, node.name, parameter.name)
return node
def _parse_graph_proto_const(self, const, const_node_id):
"""
Parse anf_ir_pb2.model_proto.graph.const, and create a const node.
Args:
const (anf_ir_pb2.model_proto.graph.const): Refer to anf_ir_pb2.model_proto.graph.const
const_node_id (str): The id of the new const node, it should be unique in graph.
Returns:
Node, a `Node` object.
"""
node = Node(name=const.key, node_id=const_node_id)
node.node_type = NodeTypeEnum.CONST.value
node.update_attr({const.key: str(const.value)})
if const.value.dtype == DataTypeEnum.DT_TENSOR:
shape = []
for dim in const.value.tensor_val.dims:
shape.append(dim)
node.shape = shape
return node
def _parse_type_proto(self, type_proto):
"""
Parse proto's `message TypeProto` to get shape information.
Args:
type_proto (anf_ir_pb2.TypeProto): Refer to anf_ir_pb2.TypeProto.
Returns:
list, a list of shape.
"""
shapes = []
if type_proto.HasField('tensor_type'):
tensor_type = type_proto.tensor_type
tensor_shape_proto = tensor_type.shape
for dim in tensor_shape_proto.dim:
shapes.append(dim.size)
if type_proto.HasField('sequence_type'):
for elem_type in type_proto.sequence_type.elem_types:
shapes.append(self._parse_type_proto(elem_type))
return shapes
def _create_group_name(self, name_scope, node_type, node_name):
"""
Create group name by node name, name scope, node type.
Only nodes that conform to the rules are aggregated.
Args:
name_scope (str): The node name scope.
node_type (str): The node type.
node_name (str): The node name.
Returns:
            Optional[str], a group name if the node matches the rules, else None.
"""
group_types = ['Reshape', 'Variable']
pattern_names = r'.*?/Cast-op\d+'
if node_type in group_types:
group_name = name_scope + '/' + node_type if name_scope else node_type
return group_name
if node_type == 'FrameworkOp' and re.search(pattern_names, node_name):
group_name = name_scope + '/' + 'Cast-op' if name_scope else 'Cast-op'
return group_name
return None
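As a reading aid, a standalone restatement of the grouping rules in `_create_group_name` above; the node names used in the checks are hypothetical.

import re

def group_name_of(name_scope, node_type, node_name):
    # Reshape and Variable nodes group by type under their scope; FrameworkOp
    # nodes whose name matches '<scope>/Cast-op<N>' group as Cast-op.
    group_types = ['Reshape', 'Variable']
    pattern_names = r'.*?/Cast-op\d+'
    if node_type in group_types:
        return name_scope + '/' + node_type if name_scope else node_type
    if node_type == 'FrameworkOp' and re.search(pattern_names, node_name):
        return name_scope + '/' + 'Cast-op' if name_scope else 'Cast-op'
    return None

assert group_name_of('net/layer1', 'Reshape', 'net/layer1/Reshape-op1') == 'net/layer1/Reshape'
assert group_name_of('net', 'FrameworkOp', 'net/Cast-op42') == 'net/Cast-op'
assert group_name_of('net', 'Conv2D', 'net/Conv2D-op7') is None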
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
This file is used to define the node of graph and associated base types.
"""
from enum import Enum
class NodeTypeEnum(Enum):
    """Node type enum. The following types are custom-defined."""
NAME_SCOPE = 'name_scope'
POLYMERIC_SCOPE = 'polymeric_scope'
PARAMETER = 'Parameter'
CONST = 'Const'
class Node:
"""
Define a node object.
Args:
name (str): Name of new node.
node_id (str): The id of this node, and node id is unique in graph.
"""
def __init__(self, name, node_id):
self._node_id = node_id
self._name = name
self._type = ""
self._attr = dict()
self._input = dict()
self._output_i = -1
self._output = {}
self._polymeric_input = {}
self._polymeric_output = {}
self._polymeric_scope_name = ""
self._subnode_count = 0
self._name_scope = ""
self.shape = []
def to_dict(self):
"""Converts the node object to dictionary format."""
return {
'name': self._name,
'type': self._type,
'attr': self._attr,
'input': self._input,
'output_i': self._output_i,
'output': self._output,
'polymeric_input': self._polymeric_input,
'polymeric_output': self._polymeric_output,
'subnode_count': self._subnode_count,
'polymeric_scope_name': self._polymeric_scope_name
}
@property
def node_id(self):
"""The id of this node, and id is unique in graph."""
return self._node_id
@property
def name(self):
"""Get node name."""
return self._name
@name.setter
def name(self, name):
"""Set node name."""
self._name = name
@property
def node_type(self):
"""Get node type."""
return self._type
@node_type.setter
def node_type(self, node_type):
"""Set node type."""
self._type = node_type
@property
def attr(self):
"""Get node attr."""
return self._attr
def update_attr(self, attr_dict):
"""
Update node attr.
Args:
attr_dict (dict[str, str]): Format is {'<key>': '<value>'}.
"""
self._attr.update(attr_dict)
@property
def input(self):
"""
Get all input of current node.
Returns:
            dict[str, dict], format is {'<src_name>': {'shape': [], 'edge_type': '<value>', 'scope': '<value>'}}.
"""
return self._input
def update_input(self, input_dict):
"""
Update input.
Args:
            input_dict (dict[str, dict]): Format is {'<src_name>': {'shape': [], 'edge_type': '<value>', 'scope': '<value>'}}.
"""
self._input.update(input_dict)
@property
    def output_i(self):
        """The memory address of this node at run time."""
return self._output_i
@output_i.setter
def output_i(self, output_i):
"""Set memory address."""
self._output_i = output_i
@property
def polymeric_input(self):
"""
The polymeric input is the input of the polymeric nodes.
Returns:
dict[str, dict], format is {'<src_name>': {'edge_type': '<value>'}}.
"""
return self._polymeric_input
def update_polymeric_input(self, polymeric_input):
"""The polymeric input is the input of the polymeric nodes."""
self._polymeric_input.update(polymeric_input)
@property
def output(self):
"""The output node of this node."""
return self._output
def update_output(self, output):
"""
Update output node.
Args:
output (dict[str, TypedDict('NodeType', {'type': str})]): Format
is {"<node_name>": {"type": "<node type>"}}.
"""
self._output.update(output)
@property
def polymeric_output(self):
"""Get polymeric output."""
return self._polymeric_output
def update_polymeric_output(self, polymeric_output):
"""
Update polymeric output.
Args:
            polymeric_output (dict[str, dict]): Format is
                {'<dst_node.polymeric_scope_name>': {'edge_type': EdgeTypeEnum.data}}.
"""
self._polymeric_output.update(polymeric_output)
@property
def polymeric_scope_name(self):
"""Get polymeric scope name."""
return self._polymeric_scope_name
@polymeric_scope_name.setter
def polymeric_scope_name(self, name):
"""Set polymeric scope name."""
self._polymeric_scope_name = name
@property
    def subnode_count(self):
        """The subnode count of this node; if this node is a scope node, this count will not be zero."""
return self._subnode_count
@subnode_count.setter
def subnode_count(self, count):
"""Set sub node count."""
self._subnode_count = count
@property
def name_scope(self):
"""Get name scope of this node."""
return self._name_scope
@name_scope.setter
def name_scope(self, name_scope):
"""Set name scope."""
self._name_scope = name_scope
def __str__(self):
return f'<Node, name: {self._name}, type: {self._type}>'
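A brief usage sketch of the Node class above, showing how a name scope node might be populated and serialized; the node and input names are hypothetical.

node = Node(name='net/conv1', node_id='1')           # hypothetical names
node.node_type = NodeTypeEnum.NAME_SCOPE.value
node.name_scope = 'net'
node.subnode_count = 2
node.update_input({'net/data': {'shape': [1, 3, 224, 224],
                                'edge_type': 'data',
                                'scope': NodeTypeEnum.NAME_SCOPE.value}})
serialized = node.to_dict()
assert serialized['name'] == 'net/conv1'
assert serialized['subnode_count'] == 2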
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Base loader generator."""
from abc import abstractmethod
MAX_DATA_LOADER_SIZE = 15
class LoaderGenerator:
    """Base class for loader generators."""
@abstractmethod
def generate_loaders(self, loader_pool):
"""
Abstract method for generating loaders.
Args:
loader_pool (dict[str, LoaderStruct]): Current loader pool in data_manager.
Returns:
dict[str, LoaderStruct], a dict of `Loader`.
"""
@abstractmethod
def check_train_job_exist(self, train_id):
"""
Abstract method for checking if train job exists.
Args:
train_id (str): Train ID.
Returns:
            bool, True if the train job exists.
"""
@abstractmethod
def generate_loader_by_train_id(self, train_id):
"""
Abstract method for generating loader by train id.
Args:
train_id (str): Train ID.
Returns:
dict[str, LoaderStruct], a dict of `Loader`.
"""
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Loader struct."""
class LoaderStruct:
"""
    Loader to save summary info.
    LoaderStruct contains: loader_id, name, path, latest_update_time, data_loader.
"""
def __init__(self, loader_id, name, path, latest_update_time, data_loader):
self._loader_id = loader_id
self._name = name
self._path = path
self._latest_update_time = latest_update_time
self._data_loader = data_loader
@property
def loader_id(self):
"""Get loader ID."""
return self._loader_id
@property
def name(self):
"""Get loader name."""
return self._name
@property
def latest_update_time(self):
"""Get the latest update time of loader."""
return self._latest_update_time
@property
def data_loader(self):
"""Get data loader."""
return self._data_loader
@latest_update_time.setter
def latest_update_time(self, latest_update_time):
"""Set the latest update time of loader."""
self._latest_update_time = latest_update_time
def to_dict(self):
"""Transform LoaderStruct to dict."""
return dict(
loader_id=self._loader_id,
name=self._name,
path=self._path,
latest_update_time=self._latest_update_time,
data_loader=self._data_loader
)
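A hedged sketch tying the two classes above together: a minimal, hypothetical in-memory generator honoring the LoaderGenerator contract and serving LoaderStruct entries. The real generator in this repository scans summary directories instead; the names and values below are made up for illustration.

class DictLoaderGenerator(LoaderGenerator):
    """Serve loaders from a prebuilt dict (illustration only)."""
    def __init__(self, loaders):
        self._loaders = loaders  # dict[str, LoaderStruct]

    def generate_loaders(self, loader_pool):
        # Only return loaders that are not already in the pool.
        return {train_id: loader for train_id, loader in self._loaders.items()
                if train_id not in loader_pool}

    def check_train_job_exist(self, train_id):
        return train_id in self._loaders

    def generate_loader_by_train_id(self, train_id):
        return {train_id: self._loaders[train_id]}

# Hypothetical values; data_loader would normally be a DataLoader instance.
loader = LoaderStruct(loader_id='./run1', name='run1', path='/tmp/run1',
                      latest_update_time=1577808000, data_loader=object())
generator = DictLoaderGenerator({'./run1': loader})
assert generator.check_train_job_exist('./run1')
assert generator.generate_loaders(loader_pool={}) == {'./run1': loader}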
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""A reservoir sampling on the values."""
import random
import threading
from mindinsight.utils.exceptions import ParamValueError
class Reservoir:
"""
A container based on Reservoir Sampling algorithm.
The newly added sample will be preserved. If the container is full, an old
sample will be replaced randomly. The probability of each sample being
replaced is the same.
"""
def __init__(self, size):
"""
        A container constructor which creates a new Reservoir.
        Args:
            size (int): Container size. If the size is 0, the container is not limited.
        Raises:
            ParamValueError: If size is a negative integer.
"""
if not isinstance(size, (int,)) or size < 0:
raise ParamValueError('size must be nonnegative integer, was %s' % size)
self._samples_max_size = size
self._samples = []
self._sample_counter = 0
self._sample_selector = random.Random(0)
self._mutex = threading.Lock()
def samples(self):
"""Return all stored samples."""
with self._mutex:
return list(self._samples)
def add_sample(self, sample):
"""
Add a sample to Reservoir.
Replace the old sample when the capacity is full.
        Newly added samples are guaranteed to be added to the reservoir.
Args:
sample (Any): The sample to add to the Reservoir.
"""
with self._mutex:
if len(self._samples) < self._samples_max_size or self._samples_max_size == 0:
self._samples.append(sample)
else:
# Use the Reservoir Sampling algorithm to replace the old sample.
rand_int = self._sample_selector.randint(
0, self._sample_counter)
if rand_int < self._samples_max_size:
self._samples.pop(rand_int)
self._samples.append(sample)
else:
self._samples[-1] = sample
self._sample_counter += 1
def remove_sample(self, filter_fun):
"""
        Remove samples from the Reservoir that do not meet the filter criteria.
        Args:
            filter_fun (Callable[..., Any]): Returns True for samples that should be
                kept; samples for which it returns False are removed.
Returns:
int, the number of samples removed.
"""
remove_size = 0
with self._mutex:
before_remove_size = len(self._samples)
if before_remove_size > 0:
                # Keep only the samples that meet the filter criteria.
self._samples = list(filter(filter_fun, self._samples))
after_remove_size = len(self._samples)
remove_size = before_remove_size - after_remove_size
if remove_size > 0:
                    # Update _sample_counter when samples have been removed.
sample_remaining_rate = float(
after_remove_size) / before_remove_size
self._sample_counter = int(
round(self._sample_counter * sample_remaining_rate))
return remove_size
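A short demonstration of the Reservoir above: the container never grows beyond its size, the newest sample is always retained (it replaces either a random old sample or the last slot), and remove_sample() keeps only samples that pass the filter.

reservoir = Reservoir(size=3)
for i in range(100):
    reservoir.add_sample(i)
kept = reservoir.samples()
assert len(kept) == 3                      # capped at the container size
assert kept[-1] == 99                      # the newest sample is always kept
reservoir.remove_sample(lambda x: x % 2 == 0)   # keep only even samples
assert all(x % 2 == 0 for x in reservoir.samples())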