未验证 提交 642bfc47 编写于 作者: T Thuan Nguyen 提交者: GitHub

Cpp python style check (#215)

* Add C++ and python style check to travis.  
* Update all C++/python code that violate coding standards.
上级 13f89d70
#!/bin/bash #!/bin/bash
set -e set -e
readonly VERSION="3.8" readonly SUPPORTED_VERSION="3.8"
version=$(clang-format -version) version=$(clang-format -version)
if ! [[ $version == *"$VERSION"* ]]; then if ! [[ $version == *"$SUPPORTED_VERSION"* ]]; then
echo "clang-format version check failed." echo "clang-format version check failed."
echo "a version contains '$VERSION' is needed, but get '$version'" echo "a version contains '$SUPPORTED_VERSION' is needed, but get '$version'"
echo "you can install the right version, and make an soft-link to '\$PATH' env" echo "you can install the right version, and make an soft-link to '\$PATH' env"
exit -1 exit -1
fi fi
......
[flake8]
max-line-length = 120
\ No newline at end of file
...@@ -22,4 +22,14 @@ ...@@ -22,4 +22,14 @@
description: Format files with ClangFormat. description: Format files with ClangFormat.
entry: bash ./.clang_format.hook -i entry: bash ./.clang_format.hook -i
language: system language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto)$ files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx)$
- repo: local
hooks:
- id: python-format-checker
name: python-format-checker
description: Format python files using PEP8 standard
entry: flake8
language: system
files: \.(py)$
...@@ -13,6 +13,10 @@ os: ...@@ -13,6 +13,10 @@ os:
# TODO(ChunweiYan) support osx in the future # TODO(ChunweiYan) support osx in the future
#- osx #- osx
env:
- JOB=check_style
- JOB=test
addons: addons:
apt: apt:
packages: packages:
...@@ -29,12 +33,14 @@ addons: ...@@ -29,12 +33,14 @@ addons:
- nodejs - nodejs
before_install: before_install:
- if [[ "$JOB" == "check_style" ]]; then sudo ln -s /usr/bin/clang-format-3.8 /usr/bin/clang-format; sudo pip install pre-commit flake8; fi
- if [[ "$TRAVIS_OS_NAME" == "osx" ]]; then brew update; fi - if [[ "$TRAVIS_OS_NAME" == "osx" ]]; then brew update; fi
- if [[ "$TRAVIS_OS_NAME" == "osx" ]]; then brew upgrade python; fi - if [[ "$TRAVIS_OS_NAME" == "osx" ]]; then brew upgrade python; fi
- if [[ "$TRAVIS_OS_NAME" == "osx" ]]; then brew install brew-pip; fi - if [[ "$TRAVIS_OS_NAME" == "osx" ]]; then brew install brew-pip; fi
script: script:
/bin/bash ./tests.sh all - if [[ "$JOB" == "check_style" ]]; then ./travis/check_style.sh; fi
- if [[ "$JOB" == "test" ]]; then /bin/bash ./tests.sh all; fi
notifications: notifications:
email: email:
......
# VisualDL demos # VisualDL demos
VisualDL supports Python and C++ based DL frameworks, VisualDL supports Python and C++ based DL frameworks,
there are several demos for different platforms. there are several demos for different platforms.
## PaddlePaddle ## PaddlePaddle
Locates in `./paddle`. Locates in `./paddle`.
This is a visualization for `resnet` on `cifar10` dataset, we visualize the CONV parameters, This is a visualization for `resnet` on `cifar10` dataset, we visualize the CONV parameters,
and there are some interesting patterns. and there are some interesting patterns.
## PyTorch GAN ## PyTorch GAN
Locates in `./pytorch-CycleGAN-and-pix2pix`. Locates in `./pytorch-CycleGAN-and-pix2pix`.
This submodule is forked from [pytorch-CycleGAN-and-pix2pix]( This submodule is forked from [pytorch-CycleGAN-and-pix2pix](
https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix), https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix),
great model and the generated fake images are really funny. great model and the generated fake images are really funny.
This demo only works with CycleGAN mode, read [CycleGAN train doc](https://github.com/Superjomn/pytorch-CycleGAN-and-pix2pix#cyclegan-traintest) and [changes to the original code](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/compare/master...Superjomn:master) for more information. This demo only works with CycleGAN mode, read [CycleGAN train doc](https://github.com/Superjomn/pytorch-CycleGAN-and-pix2pix#cyclegan-traintest) and [changes to the original code](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/compare/master...Superjomn:master) for more information.
...@@ -21,7 +21,7 @@ This demo only works with CycleGAN mode, read [CycleGAN train doc](https://githu ...@@ -21,7 +21,7 @@ This demo only works with CycleGAN mode, read [CycleGAN train doc](https://githu
## MxNet Mnist ## MxNet Mnist
Locates in `./mxnet_demo`. Locates in `./mxnet_demo`.
By adding VisualDL as callbacks to `model.fit`, By adding VisualDL as callbacks to `model.fit`,
we can use the Python SDK in MxNet, we can use the Python SDK in MxNet,
but it seems that only the outside program can only retrieve parameters in epoch callbacks, but it seems that only the outside program can only retrieve parameters in epoch callbacks,
that limits the number of steps for visualization. that limits the number of steps for visualization.
import numpy as np
import mxnet as mx
import logging import logging
import mxnet as mx import mxnet as mx
...@@ -10,7 +8,6 @@ from visualdl import LogWriter ...@@ -10,7 +8,6 @@ from visualdl import LogWriter
mnist = mx.test_utils.get_mnist() mnist = mx.test_utils.get_mnist()
batch_size = 100 batch_size = 100
# Provide a folder to store data for log, model, image, etc. VisualDL's visualization will be # Provide a folder to store data for log, model, image, etc. VisualDL's visualization will be
# based on this folder. # based on this folder.
logdir = "./tmp" logdir = "./tmp"
...@@ -44,8 +41,10 @@ def add_scalar(): ...@@ -44,8 +41,10 @@ def add_scalar():
for name, value in name_value: for name, value in name_value:
scalar0.add_record(cnt_step, value) scalar0.add_record(cnt_step, value)
cnt_step += 1 cnt_step += 1
return _callback return _callback
def add_image_histogram(): def add_image_histogram():
def _callback(iter_no, sym, arg, aux): def _callback(iter_no, sym, arg, aux):
image0.start_sampling() image0.start_sampling()
...@@ -57,6 +56,7 @@ def add_image_histogram(): ...@@ -57,6 +56,7 @@ def add_image_histogram():
histogram0.add_record(iter_no, list(data)) histogram0.add_record(iter_no, list(data))
image0.finish_sampling() image0.finish_sampling()
return _callback return _callback
...@@ -65,18 +65,22 @@ def add_image_histogram(): ...@@ -65,18 +65,22 @@ def add_image_histogram():
logging.getLogger().setLevel(logging.DEBUG) # logging to stdout logging.getLogger().setLevel(logging.DEBUG) # logging to stdout
train_iter = mx.io.NDArrayIter(mnist['train_data'], mnist['train_label'], batch_size, shuffle=True) train_iter = mx.io.NDArrayIter(
val_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size) mnist['train_data'], mnist['train_label'], batch_size, shuffle=True)
val_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'],
batch_size)
data = mx.sym.var('data') data = mx.sym.var('data')
# first conv layer # first conv layer
conv1 = mx.sym.Convolution(data=data, kernel=(5, 5), num_filter=20) conv1 = mx.sym.Convolution(data=data, kernel=(5, 5), num_filter=20)
tanh1 = mx.sym.Activation(data=conv1, act_type="tanh") tanh1 = mx.sym.Activation(data=conv1, act_type="tanh")
pool1 = mx.sym.Pooling(data=tanh1, pool_type="max", kernel=(2, 2), stride=(2, 2)) pool1 = mx.sym.Pooling(
data=tanh1, pool_type="max", kernel=(2, 2), stride=(2, 2))
# second conv layer # second conv layer
conv2 = mx.sym.Convolution(data=pool1, kernel=(5, 5), num_filter=50) conv2 = mx.sym.Convolution(data=pool1, kernel=(5, 5), num_filter=50)
tanh2 = mx.sym.Activation(data=conv2, act_type="tanh") tanh2 = mx.sym.Activation(data=conv2, act_type="tanh")
pool2 = mx.sym.Pooling(data=tanh2, pool_type="max", kernel=(2, 2), stride=(2, 2)) pool2 = mx.sym.Pooling(
data=tanh2, pool_type="max", kernel=(2, 2), stride=(2, 2))
# first fullc layer # first fullc layer
flatten = mx.sym.flatten(data=pool2) flatten = mx.sym.flatten(data=pool2)
fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=500) fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=500)
...@@ -89,21 +93,22 @@ lenet = mx.sym.SoftmaxOutput(data=fc2, name='softmax') ...@@ -89,21 +93,22 @@ lenet = mx.sym.SoftmaxOutput(data=fc2, name='softmax')
# create a trainable module on CPU # create a trainable module on CPU
lenet_model = mx.mod.Module(symbol=lenet, context=mx.cpu()) lenet_model = mx.mod.Module(symbol=lenet, context=mx.cpu())
# train with the same # train with the same
lenet_model.fit(train_iter, lenet_model.fit(
eval_data=val_iter, train_iter,
optimizer='sgd', eval_data=val_iter,
optimizer_params={'learning_rate': 0.1}, optimizer='sgd',
eval_metric='acc', optimizer_params={'learning_rate': 0.1},
# integrate our customized callback method eval_metric='acc',
batch_end_callback=[add_scalar()], # integrate our customized callback method
epoch_end_callback=[add_image_histogram()], batch_end_callback=[add_scalar()],
num_epoch=5) epoch_end_callback=[add_image_histogram()],
num_epoch=5)
test_iter = mx.io.NDArrayIter(mnist['test_data'], None, batch_size) test_iter = mx.io.NDArrayIter(mnist['test_data'], None, batch_size)
prob = lenet_model.predict(test_iter) prob = lenet_model.predict(test_iter)
test_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size) test_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'],
batch_size)
# predict accuracy for lenet # predict accuracy for lenet
acc = mx.metric.Accuracy() acc = mx.metric.Accuracy()
......
...@@ -117,8 +117,11 @@ elif net_type == "resnet": ...@@ -117,8 +117,11 @@ elif net_type == "resnet":
else: else:
raise ValueError("%s network is not supported" % net_type) raise ValueError("%s network is not supported" % net_type)
predict = fluid.layers.fc(input=net, size=classdim, act='softmax', predict = fluid.layers.fc(
param_attr=ParamAttr(name="param1", initializer=NormalInitializer())) input=net,
size=classdim,
act='softmax',
param_attr=ParamAttr(name="param1", initializer=NormalInitializer()))
cost = fluid.layers.cross_entropy(input=predict, label=label) cost = fluid.layers.cross_entropy(input=predict, label=label)
avg_cost = fluid.layers.mean(x=cost) avg_cost = fluid.layers.mean(x=cost)
...@@ -131,8 +134,7 @@ BATCH_SIZE = 16 ...@@ -131,8 +134,7 @@ BATCH_SIZE = 16
PASS_NUM = 1 PASS_NUM = 1
train_reader = paddle.batch( train_reader = paddle.batch(
paddle.reader.shuffle( paddle.reader.shuffle(paddle.dataset.cifar.train10(), buf_size=128 * 10),
paddle.dataset.cifar.train10(), buf_size=128 * 10),
batch_size=BATCH_SIZE) batch_size=BATCH_SIZE)
place = fluid.CPUPlace() place = fluid.CPUPlace()
...@@ -150,9 +152,10 @@ param1_var = start_up_program.global_block().var("param1") ...@@ -150,9 +152,10 @@ param1_var = start_up_program.global_block().var("param1")
for pass_id in range(PASS_NUM): for pass_id in range(PASS_NUM):
accuracy.reset(exe) accuracy.reset(exe)
for data in train_reader(): for data in train_reader():
loss, conv1_out, param1, acc = exe.run(fluid.default_main_program(), loss, conv1_out, param1, acc = exe.run(
feed=feeder.feed(data), fluid.default_main_program(),
fetch_list=[avg_cost, conv1, param1_var] + accuracy.metrics) feed=feeder.feed(data),
fetch_list=[avg_cost, conv1, param1_var] + accuracy.metrics)
pass_acc = accuracy.eval(exe) pass_acc = accuracy.eval(exe)
if sample_num == 0: if sample_num == 0:
...@@ -165,11 +168,14 @@ for pass_id in range(PASS_NUM): ...@@ -165,11 +168,14 @@ for pass_id in range(PASS_NUM):
idx = idx1 idx = idx1
if idx != -1: if idx != -1:
image_data = data[0][0] image_data = data[0][0]
input_image_data = np.transpose(image_data.reshape(data_shape), axes=[1, 2, 0]) input_image_data = np.transpose(
input_image.set_sample(idx, input_image_data.shape, input_image_data.flatten()) image_data.reshape(data_shape), axes=[1, 2, 0])
input_image.set_sample(idx, input_image_data.shape,
input_image_data.flatten())
conv_image_data = conv1_out[0][0] conv_image_data = conv1_out[0][0]
conv_image.set_sample(idx, conv_image_data.shape, conv_image_data.flatten()) conv_image.set_sample(idx, conv_image_data.shape,
conv_image_data.flatten())
sample_num += 1 sample_num += 1
if sample_num % num_samples == 0: if sample_num % num_samples == 0:
......
#!/user/bin/env python #!/user/bin/env python
import math
import os import os
import random import random
import subprocess
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from scipy.stats import norm
from visualdl import ROOT, LogWriter from visualdl import ROOT, LogWriter
from visualdl.server.log import logger as log from visualdl.server.log import logger as log
...@@ -44,10 +40,10 @@ with logw.mode('train') as logger: ...@@ -44,10 +40,10 @@ with logw.mode('train') as logger:
for step in range(1, 50): for step in range(1, 50):
histogram0.add_record(step, histogram0.add_record(step,
np.random.normal( np.random.normal(
0.1 + step * 0.003, 0.1 + step * 0.003,
200. / (120 + step), 200. / (120 + step),
size=1000)) size=1000))
# create image # create image
with logw.mode("train") as logger: with logw.mode("train") as logger:
image = logger.image("scratch/dog", 4) # randomly sample 4 images one pass image = logger.image("scratch/dog", 4) # randomly sample 4 images one pass
...@@ -70,11 +66,10 @@ with logw.mode("train") as logger: ...@@ -70,11 +66,10 @@ with logw.mode("train") as logger:
# a more efficient way to sample images # a more efficient way to sample images
# check whether this image will be taken by reservoir sampling # check whether this image will be taken by reservoir sampling
idx = image.is_sample_taken() idx = image.is_sample_taken()
if idx >= 0: if idx >= 0:
data = np.array( data = np.array(
dog_jpg.crop((left_x, left_y, right_x, dog_jpg.crop((left_x, left_y, right_x, right_y))).flatten()
right_y))).flatten()
# add this image to log # add this image to log
image.set_sample(idx, target_shape, data) image.set_sample(idx, target_shape, data)
# you can also just write followig codes, it is more clear, but need to # you can also just write followig codes, it is more clear, but need to
...@@ -95,6 +90,7 @@ with logw.mode("train") as logger: ...@@ -95,6 +90,7 @@ with logw.mode("train") as logger:
image0.add_sample(shape, list(data)) image0.add_sample(shape, list(data))
image0.finish_sampling() image0.finish_sampling()
def download_graph_image(): def download_graph_image():
''' '''
This is a scratch demo, it do not generate a ONNX proto, but just download an image This is a scratch demo, it do not generate a ONNX proto, but just download an image
...@@ -110,4 +106,5 @@ def download_graph_image(): ...@@ -110,4 +106,5 @@ def download_graph_image():
f.write(graph_image) f.write(graph_image)
log.warning('graph ready!') log.warning('graph ready!')
download_graph_image() download_graph_image()
...@@ -9,7 +9,7 @@ Most of the DNN platforms are using Python. VisualDL supports Python out of the ...@@ -9,7 +9,7 @@ Most of the DNN platforms are using Python. VisualDL supports Python out of the
By just adding a few lines of configuration to the code, VisualDL can provide a rich visual support for the training process. By just adding a few lines of configuration to the code, VisualDL can provide a rich visual support for the training process.
In addition to Python SDK, the underlying VisualDL is written in C++, and its exposed C++ SDK can be integrated into other platforms. In addition to Python SDK, the underlying VisualDL is written in C++, and its exposed C++ SDK can be integrated into other platforms.
Users can access the original features and monitor customized matrix. Users can access the original features and monitor customized matrix.
## Components ## Components
VisualDL supports four componments: VisualDL supports four componments:
...@@ -27,7 +27,7 @@ Compatible with ONNX (Open Neural Network Exchange) [https://github.com/onnx/onn ...@@ -27,7 +27,7 @@ Compatible with ONNX (Open Neural Network Exchange) [https://github.com/onnx/onn
</p> </p>
### scalar ### scalar
Show the error trend throughout the training. Show the error trend throughout the training.
<p align="center"> <p align="center">
<img src="./introduction/scalar.png" width="60%"/> <img src="./introduction/scalar.png" width="60%"/>
...@@ -64,7 +64,7 @@ logger = LogWriter(dir, sync_cycle=10) ...@@ -64,7 +64,7 @@ logger = LogWriter(dir, sync_cycle=10)
with logger.mode("train"): with logger.mode("train"):
# create a scalar component called 'scalars/scalar0' # create a scalar component called 'scalars/scalar0'
scalar0 = logger.scalar("scalars/scalar0") scalar0 = logger.scalar("scalars/scalar0")
# add some records during DL model running, lets start from another block. # add some records during DL model running, lets start from another block.
with logger.mode("train"): with logger.mode("train"):
...@@ -86,12 +86,12 @@ namespace cp = visualdl::components; ...@@ -86,12 +86,12 @@ namespace cp = visualdl::components;
int main() { int main() {
const std::string dir = "./tmp"; const std::string dir = "./tmp";
vs::LogWriter logger(dir, 10); vs::LogWriter logger(dir, 10);
logger.SetMode("train"); logger.SetMode("train");
auto tablet = logger.AddTablet("scalars/scalar0"); auto tablet = logger.AddTablet("scalars/scalar0");
cp::Scalar<float> scalar0(tablet); cp::Scalar<float> scalar0(tablet);
for (int step = 0; step < 1000; step++) { for (int step = 0; step < 1000; step++) {
float v = (float)std::rand() / RAND_MAX; float v = (float)std::rand() / RAND_MAX;
scalar0.AddRecord(step, v); scalar0.AddRecord(step, v);
......
...@@ -8,7 +8,7 @@ Facebook has an open-source project called [ONNX](http://onnx.ai/)(Open Neural N ...@@ -8,7 +8,7 @@ Facebook has an open-source project called [ONNX](http://onnx.ai/)(Open Neural N
## IR of ONNX ## IR of ONNX
The description of ONNX IR can be found [here](https://github.com/onnx/onnx/blob/master/docs/IR.md). The most important part is the definition of [Graph](https://github.com/onnx/onnx/blob/master/docs/IR.md#graphs). The description of ONNX IR can be found [here](https://github.com/onnx/onnx/blob/master/docs/IR.md). The most important part is the definition of [Graph](https://github.com/onnx/onnx/blob/master/docs/IR.md#graphs).
Each computation data flow graph is structured as a list of nodes that form the graph. Each node is called an operator. Nodes have zero or more inputs, one or more outputs, and zero or more attribute-value pairs. Each computation data flow graph is structured as a list of nodes that form the graph. Each node is called an operator. Nodes have zero or more inputs, one or more outputs, and zero or more attribute-value pairs.
## Rest API data format ## Rest API data format
Frontend uses rest API to get data from the server. The data format will be JSON. The data structure of a Graph is as below. Each Graph has three vectors: Frontend uses rest API to get data from the server. The data format will be JSON. The data structure of a Graph is as below. Each Graph has three vectors:
...@@ -112,6 +112,3 @@ Frontend uses rest API to get data from the server. The data format will be JSON ...@@ -112,6 +112,3 @@ Frontend uses rest API to get data from the server. The data format will be JSON
] ]
} }
``` ```
## Visual DL ## Visual DL
`Visual DL`: makes your deep learning jobs more alive via visualization. `Visual DL`: makes your deep learning jobs more alive via visualization.
At present, most deep learning frameworks are using Python. The status of training process is recorded At present, most deep learning frameworks are using Python. The status of training process is recorded
by logs. A sample log is as follow: by logs. A sample log is as follow:
...@@ -22,7 +22,7 @@ Visual DL can help you visualize the whole training process and construct plots ...@@ -22,7 +22,7 @@ Visual DL can help you visualize the whole training process and construct plots
The above is just one of Visual DL's many features. Visual DL has the following advantages: The above is just one of Visual DL's many features. Visual DL has the following advantages:
### Comprehensive Usability ### Comprehensive Usability
1. Scalar: support scalar line/dot data visualization, like the figure above. 1. Scalar: support scalar line/dot data visualization, like the figure above.
- can show metrics such as loss, accuracy, etc via lines and dots and let user see trends easily - can show metrics such as loss, accuracy, etc via lines and dots and let user see trends easily
...@@ -40,22 +40,22 @@ The above is just one of Visual DL's many features. Visual DL has the following ...@@ -40,22 +40,22 @@ The above is just one of Visual DL's many features. Visual DL has the following
<img src="image-gan.png" height="300" width="300"/> <img src="image-gan.png" height="300" width="300"/>
</p> </p>
3. Histogram: display of parameter distribution, easy to check distribution curves in each tensor, 3. Histogram: display of parameter distribution, easy to check distribution curves in each tensor,
show the trend of parameter distribution. show the trend of parameter distribution.
- help users understand the training process and the underneath reason for the change from one parameter distribution to another - help users understand the training process and the underneath reason for the change from one parameter distribution to another
- help users judge if the training is on the track. For example, if parameter change rate becomes close to 0 or grows rapidly, - help users judge if the training is on the track. For example, if parameter change rate becomes close to 0 or grows rapidly,
then exploding and vanishing gradients might happen then exploding and vanishing gradients might happen
<p align="center"> <p align="center">
<img src="histogram.png" /> <img src="histogram.png" />
</p> </p>
4. Graph: visualize the model structure of deep learning networks. 4. Graph: visualize the model structure of deep learning networks.
- Graph supports the preview of [ONNX](http://onnx.ai/) model. Since models of MXNet, Caffe2, PyTorch and CNTK can be converted to ONNX models easily, - Graph supports the preview of [ONNX](http://onnx.ai/) model. Since models of MXNet, Caffe2, PyTorch and CNTK can be converted to ONNX models easily,
Visual DL can also support these models indirectly Visual DL can also support these models indirectly
- easy to see wrong configuration of a network - easy to see wrong configuration of a network
- help understand network structure - help understand network structure
<p align="center"> <p align="center">
<img src="graph.png" height="250" width="400"/> <img src="graph.png" height="250" width="400"/>
...@@ -63,8 +63,8 @@ show the trend of parameter distribution. ...@@ -63,8 +63,8 @@ show the trend of parameter distribution.
### Easy to Integrate ### Easy to Integrate
Visual DL provides independent Python SDK. If the training task is based on Python, user can simply Visual DL provides independent Python SDK. If the training task is based on Python, user can simply
use Visual DL by installing the Visual DL wheel package and importing it into her/his own project. use Visual DL by installing the Visual DL wheel package and importing it into her/his own project.
a. Install Visual DL package. a. Install Visual DL package.
...@@ -100,7 +100,7 @@ visualDL --logdir ./log --port 8080 ...@@ -100,7 +100,7 @@ visualDL --logdir ./log --port 8080
``` ```
### Purely Open Source ### Purely Open Source
As a deep learning visualization tool, Visual DL support most deep learning frameworks. On the SDK perspective, As a deep learning visualization tool, Visual DL support most deep learning frameworks. On the SDK perspective,
it is easy to integrate into Python and C++ projects. Through ONNX, Visual DL's Graph component can support it is easy to integrate into Python and C++ projects. Through ONNX, Visual DL's Graph component can support
many popular frameworks such as PaddlePaddle, MXNet, PyTorch and Caffe2. many popular frameworks such as PaddlePaddle, MXNet, PyTorch and Caffe2.
......
...@@ -10,7 +10,7 @@ VisualDL 是一个面向深度学习任务的可视化工具,可用于训练 ...@@ -10,7 +10,7 @@ VisualDL 是一个面向深度学习任务的可视化工具,可用于训练
VisualDL提供原生的Python和C++ SDK,可以支持多种深度学习平台。用户可以在特定深度学习平台上利用Python SDK进行简单配置来支持可视化,也可以利用 C++ SDK深入嵌入到平台底层。 VisualDL提供原生的Python和C++ SDK,可以支持多种深度学习平台。用户可以在特定深度学习平台上利用Python SDK进行简单配置来支持可视化,也可以利用 C++ SDK深入嵌入到平台底层。
## 一个简单的Scalar的Python使用示例 ## 一个简单的Scalar的Python使用示例
为了简单,我们先尝试使用Python SDK。 为了简单,我们先尝试使用Python SDK。
使用VisualDL的第一步是创建一个 `LogWriter` 来存储用于可视化的数据 使用VisualDL的第一步是创建一个 `LogWriter` 来存储用于可视化的数据
......
...@@ -8,8 +8,8 @@ Currently, VisualDL supports visualization features as follows: ...@@ -8,8 +8,8 @@ Currently, VisualDL supports visualization features as follows:
- Histogram: can be used to show parameter distribution and trend. - Histogram: can be used to show parameter distribution and trend.
- Graph: can be used to visualize model structure. - Graph: can be used to visualize model structure.
VisualDL provides both Python SDK and C++ SDK in nature. It can support various frameworks. VisualDL provides both Python SDK and C++ SDK in nature. It can support various frameworks.
Users can retrieve visualization data by simply adding a few lines of code using Pythong SDK. Users can retrieve visualization data by simply adding a few lines of code using Pythong SDK.
In addition, users can also have a deep integration by using the C++ SDK at a lower level. In addition, users can also have a deep integration by using the C++ SDK at a lower level.
## A Simple Python Demo on Scalar ## A Simple Python Demo on Scalar
...@@ -25,8 +25,8 @@ from random import random ...@@ -25,8 +25,8 @@ from random import random
logw = LogWriter("./random_log", sync_cycle=30) logw = LogWriter("./random_log", sync_cycle=30)
``` ```
The first parameter points to a folder; the second parameter `sync_cycle` specifies out of how memory operations should be The first parameter points to a folder; the second parameter `sync_cycle` specifies out of how memory operations should be
store the data into hard drive. store the data into hard drive.
There are different modes for model training, such as training, validating and testing. All these correspond to `mode' in VisualDL. There are different modes for model training, such as training, validating and testing. All these correspond to `mode' in VisualDL.
We can use the following pattern to specify mode: We can use the following pattern to specify mode:
...@@ -84,7 +84,7 @@ VisualDL's C++ SDK is very similar to its Python SDK. The Python demo above can ...@@ -84,7 +84,7 @@ VisualDL's C++ SDK is very similar to its Python SDK. The Python demo above can
``` ```
## Visualization Based on ONNX Model Structure ## Visualization Based on ONNX Model Structure
VisualDL supports the visualization for the format in [ONNX](https://github.com/onnx/onnx). VisualDL supports the visualization for the format in [ONNX](https://github.com/onnx/onnx).
Currently, ONNX supports format conversion among various deep learning frameworks such as `MXNet`, `PyTorch`, `Caffe2`, `Caffe`. Currently, ONNX supports format conversion among various deep learning frameworks such as `MXNet`, `PyTorch`, `Caffe2`, `Caffe`.
``` ```
...@@ -96,5 +96,3 @@ For example, for the MNIST dataset, Graph component can render model graph as be ...@@ -96,5 +96,3 @@ For example, for the MNIST dataset, Graph component can render model graph as be
<p align=center> <p align=center>
<img width="70%" src="https://github.com/PaddlePaddle/VisualDL/blob/develop/demo/mxnet/mxnet_graph.gif" /> <img width="70%" src="https://github.com/PaddlePaddle/VisualDL/blob/develop/demo/mxnet/mxnet_graph.gif" />
</p> </p>
...@@ -24,4 +24,4 @@ response: ...@@ -24,4 +24,4 @@ response:
## data/plugins_listing ## data/plugins_listing
url: data/plugins_listing url: data/plugins_listing
\ No newline at end of file
...@@ -37,5 +37,3 @@ module.exports = function (path, queryParam, postParam) { ...@@ -37,5 +37,3 @@ module.exports = function (path, queryParam, postParam) {
} }
}; };
}; };
...@@ -174,6 +174,3 @@ export default { ...@@ -174,6 +174,3 @@ export default {
| onClose | Function | | on-close callback | | onClose | Function | | on-close callback |
| duration | Number | 3000 | duration | | duration | Number | 3000 | duration |
| type | String | | include success,error,warning,info, others not use | | type | String | | include success,error,warning,info, others not use |
...@@ -13,5 +13,3 @@ router.add({ ...@@ -13,5 +13,3 @@ router.add({
rule: '/scalars', rule: '/scalars',
Component: Scalar Component: Scalar
}); });
...@@ -15,4 +15,3 @@ export const getPluginHistogramsTags = makeService('/data/plugin/histograms/tags ...@@ -15,4 +15,3 @@ export const getPluginHistogramsTags = makeService('/data/plugin/histograms/tags
export const getPluginHistogramsHistograms = makeService('/data/plugin/histograms/histograms'); export const getPluginHistogramsHistograms = makeService('/data/plugin/histograms/histograms');
export const getPluginGraphsGraph = makeService('/data/plugin/graphs/graph'); export const getPluginGraphsGraph = makeService('/data/plugin/graphs/graph');
...@@ -3,10 +3,10 @@ from __future__ import absolute_import ...@@ -3,10 +3,10 @@ from __future__ import absolute_import
import os import os
import sys import sys
from distutils.spawn import find_executable from distutils.spawn import find_executable
from distutils import sysconfig, dep_util, log from distutils import log
import setuptools.command.build_py import setuptools.command.build_py
import setuptools import setuptools
from setuptools import setup, find_packages, Distribution, Extension from setuptools import setup, Extension
import subprocess import subprocess
TOP_DIR = os.path.realpath(os.path.dirname(__file__)) TOP_DIR = os.path.realpath(os.path.dirname(__file__))
......
#!/bin/bash
function abort(){
echo "Your change doesn't follow VisualDL's code style." 1>&2
echo "Please use pre-commit to check what is wrong." 1>&2
exit 1
}
trap 'abort' 0
set -e
cd $TRAVIS_BUILD_DIR
export PATH=/usr/bin:$PATH
pre-commit install
clang-format --version
flake8 --version
if ! pre-commit run -a ; then
git diff
exit 1
fi
trap : 0
...@@ -2,6 +2,6 @@ from __future__ import absolute_import ...@@ -2,6 +2,6 @@ from __future__ import absolute_import
import os import os
from .python.storage import * from .python.storage import * # noqa: F401,F403
ROOT = os.path.dirname(__file__) ROOT = os.path.dirname(__file__)
...@@ -15,10 +15,10 @@ limitations under the License. */ ...@@ -15,10 +15,10 @@ limitations under the License. */
#ifndef VISUALDL_LOGIC_HISTOGRAM_H #ifndef VISUALDL_LOGIC_HISTOGRAM_H
#define VISUALDL_LOGIC_HISTOGRAM_H #define VISUALDL_LOGIC_HISTOGRAM_H
#include "visualdl/utils/logging.h"
#include <cstdlib> #include <cstdlib>
#include <limits> #include <limits>
#include <vector> #include <vector>
#include "visualdl/utils/logging.h"
namespace visualdl { namespace visualdl {
......
...@@ -32,7 +32,8 @@ std::string g_log_dir; ...@@ -32,7 +32,8 @@ std::string g_log_dir;
LogWriter LogWriter::AsMode(const std::string& mode) { LogWriter LogWriter::AsMode(const std::string& mode) {
for (auto ch : "%/") { for (auto ch : "%/") {
CHECK(mode.find(ch) == std::string::npos) CHECK(mode.find(ch) == std::string::npos)
<< "character "<< ch << " is a reserved word, it is not allowed in mode."; << "character " << ch
<< " is a reserved word, it is not allowed in mode.";
} }
LogWriter writer = *this; LogWriter writer = *this;
......
...@@ -12,9 +12,11 @@ class MemCache(object): ...@@ -12,9 +12,11 @@ class MemCache(object):
def expired(self, timeout): def expired(self, timeout):
return timeout > 0 and time.time() - self.time >= timeout return timeout > 0 and time.time() - self.time >= timeout
''' '''
A global dict to help cache some temporary data. A global dict to help cache some temporary data.
''' '''
def __init__(self, timeout=-1): def __init__(self, timeout=-1):
self._timeout = timeout self._timeout = timeout
self._data = {} self._data = {}
...@@ -24,13 +26,15 @@ class MemCache(object): ...@@ -24,13 +26,15 @@ class MemCache(object):
def get(self, key): def get(self, key):
rcd = self._data.get(key, None) rcd = self._data.get(key, None)
if not rcd: return None if not rcd:
return None
# do not delete the key to accelerate speed # do not delete the key to accelerate speed
if rcd.expired(self._timeout): if rcd.expired(self._timeout):
rcd.clear() rcd.clear()
return None return None
return rcd.value return rcd.value
if __name__ == '__main__': if __name__ == '__main__':
import unittest import unittest
......
...@@ -4,13 +4,16 @@ from visualdl import core ...@@ -4,13 +4,16 @@ from visualdl import core
dtypes = ("float", "double", "int32", "int64") dtypes = ("float", "double", "int32", "int64")
def check_tag_name_valid(tag): def check_tag_name_valid(tag):
assert '%' not in tag, "character % is a reserved word, it is not allowed in tag." assert '%' not in tag, "character % is a reserved word, it is not allowed in tag."
def check_mode_name_valid(tag): def check_mode_name_valid(tag):
for char in ['%', '/']: for char in ['%', '/']:
assert char not in tag, "character %s is a reserved word, it is not allowed in mode." % char assert char not in tag, "character %s is a reserved word, it is not allowed in mode." % char
class LogReader(object): class LogReader(object):
"""LogReader is a Python wrapper to read and analysis the data that """LogReader is a Python wrapper to read and analysis the data that
saved with data format defined in storage.proto. user can get saved with data format defined in storage.proto. user can get
...@@ -125,7 +128,8 @@ class LogWriter(object): ...@@ -125,7 +128,8 @@ class LogWriter(object):
create a new LogWriter with mode and return it. create a new LogWriter with mode and return it.
""" """
check_mode_name_valid(mode) check_mode_name_valid(mode)
LogWriter.cur_mode = LogWriter(self.dir, self.sync_cycle, self.writer.as_mode(mode)) LogWriter.cur_mode = LogWriter(self.dir, self.sync_cycle,
self.writer.as_mode(mode))
return LogWriter.cur_mode return LogWriter.cur_mode
def scalar(self, tag, type='float'): def scalar(self, tag, type='float'):
......
...@@ -9,7 +9,6 @@ from visualdl import LogReader, LogWriter ...@@ -9,7 +9,6 @@ from visualdl import LogReader, LogWriter
pprint.pprint(sys.path) pprint.pprint(sys.path)
class StorageTest(unittest.TestCase): class StorageTest(unittest.TestCase):
def setUp(self): def setUp(self):
self.dir = "./tmp/storage_test" self.dir = "./tmp/storage_test"
......
...@@ -9,8 +9,8 @@ import onnx ...@@ -9,8 +9,8 @@ import onnx
def debug_print(json_obj): def debug_print(json_obj):
print(json.dumps( print(
json_obj, sort_keys=True, indent=4, separators=(',', ': '))) json.dumps(json_obj, sort_keys=True, indent=4, separators=(',', ': ')))
def reorganize_inout(json_obj, key): def reorganize_inout(json_obj, key):
...@@ -54,8 +54,8 @@ def rename_model(model_json): ...@@ -54,8 +54,8 @@ def rename_model(model_json):
for variable in variables: for variable in variables:
old_name = variable['name'] old_name = variable['name']
new_shape = [int(dim) for dim in variable['shape']] new_shape = [int(dim) for dim in variable['shape']]
new_name = old_name + '\ndata_type=' + str(variable['data_type']) \ new_name = old_name + '\ndata_type=' + str(
+ '\nshape=' + str(new_shape) variable['data_type']) + '\nshape=' + str(new_shape)
variable['name'] = new_name variable['name'] = new_name
rename_edge(model, old_name, new_name) rename_edge(model, old_name, new_name)
...@@ -234,8 +234,8 @@ def get_level_to_all(node_links, model_json): ...@@ -234,8 +234,8 @@ def get_level_to_all(node_links, model_json):
if out_level not in output_to_level: if out_level not in output_to_level:
output_to_level[out_idx] = out_level output_to_level[out_idx] = out_level
else: else:
raise Exception( raise Exception("output " + out_name +
"output " + out_name + "have multiple source") "have multiple source")
level_to_outputs = dict() level_to_outputs = dict()
for out_idx in output_to_level: for out_idx in output_to_level:
level = output_to_level[out_idx] level = output_to_level[out_idx]
...@@ -353,6 +353,7 @@ class GraphPreviewGenerator(object): ...@@ -353,6 +353,7 @@ class GraphPreviewGenerator(object):
''' '''
Generate a graph image for ONNX proto. Generate a graph image for ONNX proto.
''' '''
def __init__(self, model_json): def __init__(self, model_json):
self.model = model_json self.model = model_json
# init graphviz graph # init graphviz graph
...@@ -360,8 +361,7 @@ class GraphPreviewGenerator(object): ...@@ -360,8 +361,7 @@ class GraphPreviewGenerator(object):
self.model['name'], self.model['name'],
layout="dot", layout="dot",
concentrate="true", concentrate="true",
rankdir="TB", rankdir="TB", )
)
self.op_rank = self.graph.rank_group('same', 2) self.op_rank = self.graph.rank_group('same', 2)
self.param_rank = self.graph.rank_group('same', 1) self.param_rank = self.graph.rank_group('same', 1)
...@@ -396,10 +396,9 @@ class GraphPreviewGenerator(object): ...@@ -396,10 +396,9 @@ class GraphPreviewGenerator(object):
self.args.add(target) self.args.add(target)
if source in self.args or target in self.args: if source in self.args or target in self.args:
edge = self.add_edge( self.add_edge(style="dashed,bold", color="#aaaaaa", **item)
style="dashed,bold", color="#aaaaaa", **item)
else: else:
edge = self.add_edge(style="bold", color="#aaaaaa", **item) self.add_edge(style="bold", color="#aaaaaa", **item)
if not show: if not show:
self.graph.display(path) self.graph.display(path)
...@@ -448,8 +447,7 @@ class GraphPreviewGenerator(object): ...@@ -448,8 +447,7 @@ class GraphPreviewGenerator(object):
fontname="Arial", fontname="Arial",
fontcolor="#ffffff", fontcolor="#ffffff",
width="1.3", width="1.3",
height="0.84", height="0.84", )
)
def add_arg(self, name): def add_arg(self, name):
return self.graph.node( return self.graph.node(
...@@ -483,17 +481,16 @@ def draw_graph(model_pb_path, image_dir): ...@@ -483,17 +481,16 @@ def draw_graph(model_pb_path, image_dir):
if min_width is None or im.size[0] < min_width: if min_width is None or im.size[0] < min_width:
min_width = im.size min_width = im.size
best_image = image_path best_image = image_path
except: except Exception:
pass pass
return best_image return best_image
if __name__ == '__main__': if __name__ == '__main__':
import os
import sys import sys
current_path = os.path.abspath(os.path.dirname(sys.argv[0])) current_path = os.path.abspath(os.path.dirname(sys.argv[0]))
json_str = load_model(current_path + "/mock/inception_v1_model.pb") json_str = load_model(current_path + "/mock/inception_v1_model.pb")
#json_str = load_model(current_path + "/mock/squeezenet_model.pb") # json_str = load_model(current_path + "/mock/squeezenet_model.pb")
# json_str = load_model('./mock/shufflenet/model.pb') # json_str = load_model('./mock/shufflenet/model.pb')
debug_print(json_str) debug_print(json_str)
assert json_str assert json_str
......
...@@ -28,7 +28,8 @@ class GraphTest(unittest.TestCase): ...@@ -28,7 +28,8 @@ class GraphTest(unittest.TestCase):
# label_100: (in-edge) # label_100: (in-edge)
# {u'source': u'fire6/squeeze1x1_1', u'target': u'node_34', u'label': u'label_100'} # {u'source': u'fire6/squeeze1x1_1', u'target': u'node_34', u'label': u'label_100'}
self.assertEqual(json_obj['edges'][100]['source'], 'fire6/squeeze1x1_1') self.assertEqual(json_obj['edges'][100]['source'],
'fire6/squeeze1x1_1')
self.assertEqual(json_obj['edges'][100]['target'], 'node_34') self.assertEqual(json_obj['edges'][100]['target'], 'node_34')
self.assertEqual(json_obj['edges'][100]['label'], 'label_100') self.assertEqual(json_obj['edges'][100]['label'], 'label_100')
......
...@@ -9,4 +9,3 @@ cd .. ...@@ -9,4 +9,3 @@ cd ..
python graph_test.py python graph_test.py
rm ./mock/*.pb rm ./mock/*.pb
import os
import random import random
import subprocess import subprocess
import sys
import tempfile
def crepr(v): def crepr(v):
...@@ -28,7 +25,7 @@ class Rank(object): ...@@ -28,7 +25,7 @@ class Rank(object):
return '' return ''
return '{' + 'rank={};'.format(self.kind) + \ return '{' + 'rank={};'.format(self.kind) + \
','.join([node.name for node in self.nodes]) + '}' ','.join([node.name for node in self.nodes]) + '}'
# the python package graphviz is too poor. # the python package graphviz is too poor.
......
...@@ -106,7 +106,8 @@ def get_image_tag_steps(storage, mode, tag): ...@@ -106,7 +106,8 @@ def get_image_tag_steps(storage, mode, tag):
record = image.record(step_index, sample_index) record = image.record(step_index, sample_index)
shape = record.shape() shape = record.shape()
# TODO(ChunweiYan) remove this trick, some shape will be empty # TODO(ChunweiYan) remove this trick, some shape will be empty
if not shape: continue if not shape:
continue
try: try:
query = urllib.urlencode({ query = urllib.urlencode({
'sample': 0, 'sample': 0,
...@@ -121,7 +122,7 @@ def get_image_tag_steps(storage, mode, tag): ...@@ -121,7 +122,7 @@ def get_image_tag_steps(storage, mode, tag):
'wall_time': image.timestamp(step_index), 'wall_time': image.timestamp(step_index),
'query': query, 'query': query,
}) })
except: except Exception:
logger.error("image sample out of range") logger.error("image sample out of range")
return res return res
...@@ -164,7 +165,7 @@ def get_histogram(storage, mode, tag, num_samples=100): ...@@ -164,7 +165,7 @@ def get_histogram(storage, mode, tag, num_samples=100):
try: try:
# some bug with protobuf, some times may overflow # some bug with protobuf, some times may overflow
record = histogram.record(i) record = histogram.record(i)
except: except Exception:
continue continue
res.append([]) res.append([])
...@@ -177,9 +178,7 @@ def get_histogram(storage, mode, tag, num_samples=100): ...@@ -177,9 +178,7 @@ def get_histogram(storage, mode, tag, num_samples=100):
for j in xrange(record.num_instances()): for j in xrange(record.num_instances()):
instance = record.instance(j) instance = record.instance(j)
data.append( data.append(
[instance.left(), [instance.left(), instance.right(), instance.frequency()])
instance.right(),
instance.frequency()])
if len(res) < num_samples: if len(res) < num_samples:
return res return res
...@@ -206,11 +205,12 @@ def retry(ntimes, function, time2sleep, *args, **kwargs): ...@@ -206,11 +205,12 @@ def retry(ntimes, function, time2sleep, *args, **kwargs):
for i in xrange(ntimes): for i in xrange(ntimes):
try: try:
return function(*args, **kwargs) return function(*args, **kwargs)
except: except Exception:
error_info = '\n'.join(map(str, sys.exc_info())) error_info = '\n'.join(map(str, sys.exc_info()))
logger.error("Unexpected error: %s" % error_info) logger.error("Unexpected error: %s" % error_info)
time.sleep(time2sleep) time.sleep(time2sleep)
def cache_get(cache): def cache_get(cache):
def _handler(key, func, *args, **kwargs): def _handler(key, func, *args, **kwargs):
data = cache.get(key) data = cache.get(key)
...@@ -220,4 +220,5 @@ def cache_get(cache): ...@@ -220,4 +220,5 @@ def cache_get(cache):
cache.set(key, data) cache.set(key, data)
return data return data
return data return data
return _handler return _handler
...@@ -8,6 +8,7 @@ from storage_mock import add_histogram, add_image, add_scalar ...@@ -8,6 +8,7 @@ from storage_mock import add_histogram, add_image, add_scalar
_retry_counter = 0 _retry_counter = 0
class LibTest(unittest.TestCase): class LibTest(unittest.TestCase):
def setUp(self): def setUp(self):
dir = "./tmp/mock" dir = "./tmp/mock"
......
...@@ -102,7 +102,7 @@ def sequence_data(): ...@@ -102,7 +102,7 @@ def sequence_data():
def graph_data(): def graph_data():
return """{ return """{
"title": { "title": {
"text": "MLP" "text": "MLP"
}, },
......
...@@ -6,4 +6,3 @@ cp squeezenet/model.pb squeezenet_model.pb ...@@ -6,4 +6,3 @@ cp squeezenet/model.pb squeezenet_model.pb
rm -rf squeezenet rm -rf squeezenet
rm squeezenet.tar.gz rm squeezenet.tar.gz
from setuptools import setup from setuptools import setup
packages = ['visualdl', packages = [
'visualdl.onnx', 'visualdl', 'visualdl.onnx', 'visualdl.mock', 'visualdl.frontend.dist'
'visualdl.mock', ]
'visualdl.frontend.dist']
setup( setup(
name="visualdl", name="visualdl",
......
import random import random
import time
import unittest
import numpy as np import numpy as np
...@@ -36,4 +34,5 @@ def add_histogram(writer, mode, tag, num_buckets): ...@@ -36,4 +34,5 @@ def add_histogram(writer, mode, tag, num_buckets):
with writer.mode(mode) as writer: with writer.mode(mode) as writer:
histogram = writer.histogram(tag, num_buckets) histogram = writer.histogram(tag, num_buckets)
for i in range(10): for i in range(10):
histogram.add_record(i, np.random.normal(0.1 + i * 0.01, size=1000)) histogram.add_record(i, np.random.normal(
0.1 + i * 0.01, size=1000))
...@@ -15,7 +15,6 @@ limitations under the License. */ ...@@ -15,7 +15,6 @@ limitations under the License. */
#ifndef VISUALDL_STORAGE_STORAGE_H #ifndef VISUALDL_STORAGE_STORAGE_H
#define VISUALDL_STORAGE_STORAGE_H #define VISUALDL_STORAGE_STORAGE_H
#include "visualdl/utils/logging.h"
#include <algorithm> #include <algorithm>
#include <set> #include <set>
#include <vector> #include <vector>
...@@ -25,6 +24,7 @@ limitations under the License. */ ...@@ -25,6 +24,7 @@ limitations under the License. */
#include "visualdl/storage/tablet.h" #include "visualdl/storage/tablet.h"
#include "visualdl/utils/filesystem.h" #include "visualdl/utils/filesystem.h"
#include "visualdl/utils/guard.h" #include "visualdl/utils/guard.h"
#include "visualdl/utils/logging.h"
namespace visualdl { namespace visualdl {
static const std::string meta_file_name = "storage.meta"; static const std::string meta_file_name = "storage.meta";
......
...@@ -15,11 +15,10 @@ limitations under the License. */ ...@@ -15,11 +15,10 @@ limitations under the License. */
#ifndef VISUALDL_TABLET_H #ifndef VISUALDL_TABLET_H
#define VISUALDL_TABLET_H #define VISUALDL_TABLET_H
#include "visualdl/utils/logging.h"
#include "visualdl/logic/im.h" #include "visualdl/logic/im.h"
#include "visualdl/storage/record.h" #include "visualdl/storage/record.h"
#include "visualdl/storage/storage.pb.h" #include "visualdl/storage/storage.pb.h"
#include "visualdl/utils/logging.h"
#include "visualdl/utils/string.h" #include "visualdl/utils/string.h"
namespace visualdl { namespace visualdl {
......
import sys import sys
import unittest import unittest
import numpy as np import numpy as np
sys.path.append('../../build') sys.path.append('../../build') # noqa: E402
import core import core
im = core.im() im = core.im()
...@@ -62,7 +63,7 @@ class TabletTester(unittest.TestCase): ...@@ -62,7 +63,7 @@ class TabletTester(unittest.TestCase):
class ImTester(unittest.TestCase): class ImTester(unittest.TestCase):
def test_persist(self): def test_persist(self):
im.clear_tablets() im.clear_tablets()
tablet = im.add_tablet("tab0", 111) im.add_tablet("tab0", 111)
self.assertEqual(im.storage().tablets_size(), 1) self.assertEqual(im.storage().tablets_size(), 1)
im.storage().set_dir("./1") im.storage().set_dir("./1")
im.persist_to_disk() im.persist_to_disk()
......
...@@ -15,11 +15,11 @@ limitations under the License. */ ...@@ -15,11 +15,11 @@ limitations under the License. */
#ifndef VISUALDL_UTILS_CONCURRENCY_H #ifndef VISUALDL_UTILS_CONCURRENCY_H
#define VISUALDL_UTILS_CONCURRENCY_H #define VISUALDL_UTILS_CONCURRENCY_H
#include "visualdl/utils/logging.h"
#include <chrono> #include <chrono>
#include <memory> #include <memory>
#include <thread> #include <thread>
#include <vector> #include <vector>
#include "visualdl/utils/logging.h"
namespace visualdl { namespace visualdl {
namespace cc { namespace cc {
......
...@@ -14,8 +14,8 @@ limitations under the License. */ ...@@ -14,8 +14,8 @@ limitations under the License. */
#include "visualdl/utils/concurrency.h" #include "visualdl/utils/concurrency.h"
#include "visualdl/utils/logging.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "visualdl/utils/logging.h"
namespace visualdl { namespace visualdl {
......
...@@ -30,4 +30,3 @@ TEST(image, NormalizeImage) { ...@@ -30,4 +30,3 @@ TEST(image, NormalizeImage) {
NormalizeImage(&image, arr, 3, 128); NormalizeImage(&image, arr, 3, 128);
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册