Unverified · Commit 642bfc47 authored by Thuan Nguyen, committed by GitHub

Cpp python style check (#215)

* Add C++ and Python style checks to Travis CI.
* Update all C++/Python code that violates the coding standards.
Parent 13f89d70
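This change wires clang-format (for C++) and flake8 (for Python) into the Travis build through pre-commit. A minimal sketch of reproducing the same check locally, mirroring the commands this commit adds to `.travis.yml` and `travis/check_style.sh` below (the clang-format-3.8 path is taken from the Travis setup and may differ on other systems):

```bash
# Install the tools the CI job uses (adjust the clang-format-3.8 location for your system).
sudo ln -s /usr/bin/clang-format-3.8 /usr/bin/clang-format
pip install pre-commit flake8

pre-commit install   # register the git hook in this clone
pre-commit run -a    # run the clang-format and flake8 hooks on all files
```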
#!/bin/bash
set -e

-readonly VERSION="3.8"
+readonly SUPPORTED_VERSION="3.8"

version=$(clang-format -version)

-if ! [[ $version == *"$VERSION"* ]]; then
+if ! [[ $version == *"$SUPPORTED_VERSION"* ]]; then
    echo "clang-format version check failed."
-    echo "a version contains '$VERSION' is needed, but get '$version'"
+    echo "a version contains '$SUPPORTED_VERSION' is needed, but get '$version'"
    echo "you can install the right version, and make an soft-link to '\$PATH' env"
    exit -1
fi
......
+[flake8]
+max-line-length = 120
\ No newline at end of file
@@ -22,4 +22,14 @@
    description: Format files with ClangFormat.
    entry: bash ./.clang_format.hook -i
    language: system
-    files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto)$
+    files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx)$
+- repo: local
+  hooks:
+  - id: python-format-checker
+    name: python-format-checker
+    description: Format python files using PEP8 standard
+    entry: flake8
+    language: system
+    files: \.(py)$
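For reference, the new `python-format-checker` hook can also be exercised on its own through pre-commit's standard CLI; a hedged sketch (the hook id comes from the config above):

```bash
# Run only the flake8 hook against the whole tree, or call flake8 on the current directory.
pre-commit run python-format-checker --all-files
flake8 .
```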
@@ -13,6 +13,10 @@ os:
  # TODO(ChunweiYan) support osx in the future
  #- osx

+env:
+  - JOB=check_style
+  - JOB=test
+
addons:
  apt:
    packages:
@@ -29,12 +33,14 @@ addons:
      - nodejs

before_install:
+  - if [[ "$JOB" == "check_style" ]]; then sudo ln -s /usr/bin/clang-format-3.8 /usr/bin/clang-format; sudo pip install pre-commit flake8; fi
  - if [[ "$TRAVIS_OS_NAME" == "osx" ]]; then brew update; fi
  - if [[ "$TRAVIS_OS_NAME" == "osx" ]]; then brew upgrade python; fi
  - if [[ "$TRAVIS_OS_NAME" == "osx" ]]; then brew install brew-pip; fi

script:
-  /bin/bash ./tests.sh all
+  - if [[ "$JOB" == "check_style" ]]; then ./travis/check_style.sh; fi
+  - if [[ "$JOB" == "test" ]]; then /bin/bash ./tests.sh all; fi

notifications:
  email:
......
-import numpy as np
-import mxnet as mx
import logging
import mxnet as mx
@@ -10,7 +8,6 @@ from visualdl import LogWriter
mnist = mx.test_utils.get_mnist()
batch_size = 100

# Provide a folder to store data for log, model, image, etc. VisualDL's visualization will be
# based on this folder.
logdir = "./tmp"
@@ -44,8 +41,10 @@ def add_scalar():
        for name, value in name_value:
            scalar0.add_record(cnt_step, value)
            cnt_step += 1

    return _callback


def add_image_histogram():
    def _callback(iter_no, sym, arg, aux):
        image0.start_sampling()
@@ -57,6 +56,7 @@ def add_image_histogram():
        histogram0.add_record(iter_no, list(data))
        image0.finish_sampling()

    return _callback
@@ -65,18 +65,22 @@ def add_image_histogram():
logging.getLogger().setLevel(logging.DEBUG)  # logging to stdout

-train_iter = mx.io.NDArrayIter(mnist['train_data'], mnist['train_label'], batch_size, shuffle=True)
-val_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size)
+train_iter = mx.io.NDArrayIter(
+    mnist['train_data'], mnist['train_label'], batch_size, shuffle=True)
+val_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'],
+                             batch_size)

data = mx.sym.var('data')
# first conv layer
conv1 = mx.sym.Convolution(data=data, kernel=(5, 5), num_filter=20)
tanh1 = mx.sym.Activation(data=conv1, act_type="tanh")
-pool1 = mx.sym.Pooling(data=tanh1, pool_type="max", kernel=(2, 2), stride=(2, 2))
+pool1 = mx.sym.Pooling(
+    data=tanh1, pool_type="max", kernel=(2, 2), stride=(2, 2))
# second conv layer
conv2 = mx.sym.Convolution(data=pool1, kernel=(5, 5), num_filter=50)
tanh2 = mx.sym.Activation(data=conv2, act_type="tanh")
-pool2 = mx.sym.Pooling(data=tanh2, pool_type="max", kernel=(2, 2), stride=(2, 2))
+pool2 = mx.sym.Pooling(
+    data=tanh2, pool_type="max", kernel=(2, 2), stride=(2, 2))
# first fullc layer
flatten = mx.sym.flatten(data=pool2)
fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=500)
@@ -89,9 +93,9 @@ lenet = mx.sym.SoftmaxOutput(data=fc2, name='softmax')
# create a trainable module on CPU
lenet_model = mx.mod.Module(symbol=lenet, context=mx.cpu())
# train with the same
-lenet_model.fit(train_iter,
+lenet_model.fit(
+    train_iter,
    eval_data=val_iter,
    optimizer='sgd',
    optimizer_params={'learning_rate': 0.1},
@@ -103,7 +107,8 @@ lenet_model.fit(train_iter,
test_iter = mx.io.NDArrayIter(mnist['test_data'], None, batch_size)
prob = lenet_model.predict(test_iter)
-test_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size)
+test_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'],
+                              batch_size)
# predict accuracy for lenet
acc = mx.metric.Accuracy()
......
@@ -117,7 +117,10 @@ elif net_type == "resnet":
else:
    raise ValueError("%s network is not supported" % net_type)

-predict = fluid.layers.fc(input=net, size=classdim, act='softmax',
+predict = fluid.layers.fc(
+    input=net,
+    size=classdim,
+    act='softmax',
    param_attr=ParamAttr(name="param1", initializer=NormalInitializer()))
cost = fluid.layers.cross_entropy(input=predict, label=label)
avg_cost = fluid.layers.mean(x=cost)
@@ -131,8 +134,7 @@ BATCH_SIZE = 16
PASS_NUM = 1

train_reader = paddle.batch(
-    paddle.reader.shuffle(
-        paddle.dataset.cifar.train10(), buf_size=128 * 10),
+    paddle.reader.shuffle(paddle.dataset.cifar.train10(), buf_size=128 * 10),
    batch_size=BATCH_SIZE)

place = fluid.CPUPlace()
@@ -150,7 +152,8 @@ param1_var = start_up_program.global_block().var("param1")
for pass_id in range(PASS_NUM):
    accuracy.reset(exe)
    for data in train_reader():
-        loss, conv1_out, param1, acc = exe.run(fluid.default_main_program(),
+        loss, conv1_out, param1, acc = exe.run(
+            fluid.default_main_program(),
            feed=feeder.feed(data),
            fetch_list=[avg_cost, conv1, param1_var] + accuracy.metrics)
        pass_acc = accuracy.eval(exe)
@@ -165,11 +168,14 @@ for pass_id in range(PASS_NUM):
                idx = idx1
            if idx != -1:
                image_data = data[0][0]
-                input_image_data = np.transpose(image_data.reshape(data_shape), axes=[1, 2, 0])
-                input_image.set_sample(idx, input_image_data.shape, input_image_data.flatten())
+                input_image_data = np.transpose(
+                    image_data.reshape(data_shape), axes=[1, 2, 0])
+                input_image.set_sample(idx, input_image_data.shape,
+                                       input_image_data.flatten())
                conv_image_data = conv1_out[0][0]
-                conv_image.set_sample(idx, conv_image_data.shape, conv_image_data.flatten())
+                conv_image.set_sample(idx, conv_image_data.shape,
+                                      conv_image_data.flatten())

                sample_num += 1
                if sample_num % num_samples == 0:
......
#!/user/bin/env python
-import math
import os
import random
-import subprocess

import numpy as np
from PIL import Image
-from scipy.stats import norm

from visualdl import ROOT, LogWriter
from visualdl.server.log import logger as log
@@ -73,8 +69,7 @@ with logw.mode("train") as logger:
            idx = image.is_sample_taken()
            if idx >= 0:
                data = np.array(
-                    dog_jpg.crop((left_x, left_y, right_x,
-                                  right_y))).flatten()
+                    dog_jpg.crop((left_x, left_y, right_x, right_y))).flatten()
                # add this image to log
                image.set_sample(idx, target_shape, data)
                # you can also just write followig codes, it is more clear, but need to
@@ -95,6 +90,7 @@ with logw.mode("train") as logger:
        image0.add_sample(shape, list(data))
    image0.finish_sampling()


def download_graph_image():
    '''
    This is a scratch demo, it do not generate a ONNX proto, but just download an image
@@ -110,4 +106,5 @@ def download_graph_image():
        f.write(graph_image)
        log.warning('graph ready!')


download_graph_image()
@@ -112,6 +112,3 @@ Frontend uses rest API to get data from the server. The data format will be JSON
  ]
}
```
@@ -96,5 +96,3 @@ For example, for the MNIST dataset, Graph component can render model graph as be
<p align=center>
<img width="70%" src="https://github.com/PaddlePaddle/VisualDL/blob/develop/demo/mxnet/mxnet_graph.gif" />
</p>
@@ -37,5 +37,3 @@ module.exports = function (path, queryParam, postParam) {
        }
    };
};
@@ -174,6 +174,3 @@ export default {
| onClose | Function | | on-close callback |
| duration | Number | 3000 | duration |
| type | String | | include success,error,warning,info, others not use |
@@ -13,5 +13,3 @@ router.add({
    rule: '/scalars',
    Component: Scalar
});
@@ -15,4 +15,3 @@ export const getPluginHistogramsTags = makeService('/data/plugin/histograms/tags
export const getPluginHistogramsHistograms = makeService('/data/plugin/histograms/histograms');
export const getPluginGraphsGraph = makeService('/data/plugin/graphs/graph');
@@ -3,10 +3,10 @@ from __future__ import absolute_import
import os
import sys
from distutils.spawn import find_executable
-from distutils import sysconfig, dep_util, log
+from distutils import log
import setuptools.command.build_py
import setuptools
-from setuptools import setup, find_packages, Distribution, Extension
+from setuptools import setup, Extension
import subprocess

TOP_DIR = os.path.realpath(os.path.dirname(__file__))
......
+#!/bin/bash
+function abort(){
+    echo "Your change doesn't follow VisualDL's code style." 1>&2
+    echo "Please use pre-commit to check what is wrong." 1>&2
+    exit 1
+}
+
+trap 'abort' 0
+set -e
+
+cd $TRAVIS_BUILD_DIR
+export PATH=/usr/bin:$PATH
+pre-commit install
+clang-format --version
+flake8 --version
+
+if ! pre-commit run -a ; then
+    git diff
+    exit 1
+fi
+
+trap : 0
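The new `travis/check_style.sh` relies on bash's EXIT trap: `trap 'abort' 0` arms a failure message that fires whenever the shell exits, and `trap : 0` replaces it with a no-op once every check has passed. A minimal sketch of the same pattern in isolation (the `run_some_check` command is hypothetical):

```bash
#!/bin/bash
trap 'echo "style check failed" 1>&2' 0   # runs on any exit while the trap is armed
set -e                                    # a failing command aborts the script

run_some_check                            # hypothetical check; a failure triggers the trap

trap : 0                                  # success path: disarm the trap before exiting
```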
@@ -2,6 +2,6 @@ from __future__ import absolute_import
import os

-from .python.storage import *
+from .python.storage import *  # noqa: F401,F403

ROOT = os.path.dirname(__file__)
@@ -15,10 +15,10 @@ limitations under the License. */
#ifndef VISUALDL_LOGIC_HISTOGRAM_H
#define VISUALDL_LOGIC_HISTOGRAM_H

-#include "visualdl/utils/logging.h"
#include <cstdlib>
#include <limits>
#include <vector>
+#include "visualdl/utils/logging.h"

namespace visualdl {
......
@@ -32,7 +32,8 @@ std::string g_log_dir;
LogWriter LogWriter::AsMode(const std::string& mode) {
  for (auto ch : "%/") {
    CHECK(mode.find(ch) == std::string::npos)
-        << "character "<< ch << " is a reserved word, it is not allowed in mode.";
+        << "character " << ch
+        << " is a reserved word, it is not allowed in mode.";
  }
  LogWriter writer = *this;
......
@@ -12,9 +12,11 @@ class MemCache(object):
        def expired(self, timeout):
            return timeout > 0 and time.time() - self.time >= timeout

    '''
    A global dict to help cache some temporary data.
    '''

    def __init__(self, timeout=-1):
        self._timeout = timeout
        self._data = {}
@@ -24,13 +26,15 @@ class MemCache(object):
    def get(self, key):
        rcd = self._data.get(key, None)
-        if not rcd: return None
+        if not rcd:
+            return None
        # do not delete the key to accelerate speed
        if rcd.expired(self._timeout):
            rcd.clear()
            return None
        return rcd.value


if __name__ == '__main__':
    import unittest
......
@@ -4,13 +4,16 @@ from visualdl import core
dtypes = ("float", "double", "int32", "int64")


def check_tag_name_valid(tag):
    assert '%' not in tag, "character % is a reserved word, it is not allowed in tag."


def check_mode_name_valid(tag):
    for char in ['%', '/']:
        assert char not in tag, "character %s is a reserved word, it is not allowed in mode." % char


class LogReader(object):
    """LogReader is a Python wrapper to read and analysis the data that
    saved with data format defined in storage.proto. user can get
@@ -125,7 +128,8 @@ class LogWriter(object):
        create a new LogWriter with mode and return it.
        """
        check_mode_name_valid(mode)
-        LogWriter.cur_mode = LogWriter(self.dir, self.sync_cycle, self.writer.as_mode(mode))
+        LogWriter.cur_mode = LogWriter(self.dir, self.sync_cycle,
+                                       self.writer.as_mode(mode))
        return LogWriter.cur_mode

    def scalar(self, tag, type='float'):
......
@@ -9,7 +9,6 @@ from visualdl import LogReader, LogWriter
pprint.pprint(sys.path)


class StorageTest(unittest.TestCase):
    def setUp(self):
        self.dir = "./tmp/storage_test"
......
@@ -9,8 +9,8 @@ import onnx

def debug_print(json_obj):
-    print(json.dumps(
-        json_obj, sort_keys=True, indent=4, separators=(',', ': ')))
+    print(
+        json.dumps(json_obj, sort_keys=True, indent=4, separators=(',', ': ')))


def reorganize_inout(json_obj, key):
@@ -54,8 +54,8 @@ def rename_model(model_json):
        for variable in variables:
            old_name = variable['name']
            new_shape = [int(dim) for dim in variable['shape']]
-            new_name = old_name + '\ndata_type=' + str(variable['data_type']) \
-                + '\nshape=' + str(new_shape)
+            new_name = old_name + '\ndata_type=' + str(
+                variable['data_type']) + '\nshape=' + str(new_shape)
            variable['name'] = new_name
            rename_edge(model, old_name, new_name)
@@ -234,8 +234,8 @@ def get_level_to_all(node_links, model_json):
            if out_level not in output_to_level:
                output_to_level[out_idx] = out_level
            else:
-                raise Exception(
-                    "output " + out_name + "have multiple source")
+                raise Exception("output " + out_name +
+                                "have multiple source")
    level_to_outputs = dict()
    for out_idx in output_to_level:
        level = output_to_level[out_idx]
@@ -353,6 +353,7 @@ class GraphPreviewGenerator(object):
    '''
    Generate a graph image for ONNX proto.
    '''

    def __init__(self, model_json):
        self.model = model_json
        # init graphviz graph
@@ -360,8 +361,7 @@ class GraphPreviewGenerator(object):
            self.model['name'],
            layout="dot",
            concentrate="true",
-            rankdir="TB",
-        )
+            rankdir="TB", )

        self.op_rank = self.graph.rank_group('same', 2)
        self.param_rank = self.graph.rank_group('same', 1)
@@ -396,10 +396,9 @@ class GraphPreviewGenerator(object):
                self.args.add(target)

            if source in self.args or target in self.args:
-                edge = self.add_edge(
-                    style="dashed,bold", color="#aaaaaa", **item)
+                self.add_edge(style="dashed,bold", color="#aaaaaa", **item)
            else:
-                edge = self.add_edge(style="bold", color="#aaaaaa", **item)
+                self.add_edge(style="bold", color="#aaaaaa", **item)

        if not show:
            self.graph.display(path)
@@ -448,8 +447,7 @@ class GraphPreviewGenerator(object):
            fontname="Arial",
            fontcolor="#ffffff",
            width="1.3",
-            height="0.84",
-        )
+            height="0.84", )

    def add_arg(self, name):
        return self.graph.node(
@@ -483,17 +481,16 @@ def draw_graph(model_pb_path, image_dir):
            if min_width is None or im.size[0] < min_width:
                min_width = im.size
                best_image = image_path
-        except:
+        except Exception:
            pass
    return best_image


if __name__ == '__main__':
-    import os
    import sys
    current_path = os.path.abspath(os.path.dirname(sys.argv[0]))
    json_str = load_model(current_path + "/mock/inception_v1_model.pb")
-    #json_str = load_model(current_path + "/mock/squeezenet_model.pb")
+    # json_str = load_model(current_path + "/mock/squeezenet_model.pb")
    # json_str = load_model('./mock/shufflenet/model.pb')
    debug_print(json_str)
    assert json_str
......
@@ -28,7 +28,8 @@ class GraphTest(unittest.TestCase):
        # label_100: (in-edge)
        # {u'source': u'fire6/squeeze1x1_1', u'target': u'node_34', u'label': u'label_100'}
-        self.assertEqual(json_obj['edges'][100]['source'], 'fire6/squeeze1x1_1')
+        self.assertEqual(json_obj['edges'][100]['source'],
+                         'fire6/squeeze1x1_1')
        self.assertEqual(json_obj['edges'][100]['target'], 'node_34')
        self.assertEqual(json_obj['edges'][100]['label'], 'label_100')
......
@@ -9,4 +9,3 @@ cd ..
python graph_test.py
rm ./mock/*.pb
-import os
import random
import subprocess
-import sys
-import tempfile


def crepr(v):
......
@@ -106,7 +106,8 @@ def get_image_tag_steps(storage, mode, tag):
            record = image.record(step_index, sample_index)
            shape = record.shape()
            # TODO(ChunweiYan) remove this trick, some shape will be empty
-            if not shape: continue
+            if not shape:
+                continue
            try:
                query = urllib.urlencode({
                    'sample': 0,
@@ -121,7 +122,7 @@ def get_image_tag_steps(storage, mode, tag):
                    'wall_time': image.timestamp(step_index),
                    'query': query,
                })
-            except:
+            except Exception:
                logger.error("image sample out of range")
    return res
@@ -164,7 +165,7 @@ def get_histogram(storage, mode, tag, num_samples=100):
        try:
            # some bug with protobuf, some times may overflow
            record = histogram.record(i)
-        except:
+        except Exception:
            continue
        res.append([])
@@ -177,9 +178,7 @@ def get_histogram(storage, mode, tag, num_samples=100):
        for j in xrange(record.num_instances()):
            instance = record.instance(j)
            data.append(
-                [instance.left(),
-                 instance.right(),
-                 instance.frequency()])
+                [instance.left(), instance.right(), instance.frequency()])
    if len(res) < num_samples:
        return res
@@ -206,11 +205,12 @@ def retry(ntimes, function, time2sleep, *args, **kwargs):
    for i in xrange(ntimes):
        try:
            return function(*args, **kwargs)
-        except:
+        except Exception:
            error_info = '\n'.join(map(str, sys.exc_info()))
            logger.error("Unexpected error: %s" % error_info)
            time.sleep(time2sleep)


def cache_get(cache):
    def _handler(key, func, *args, **kwargs):
        data = cache.get(key)
@@ -220,4 +220,5 @@ def cache_get(cache):
            cache.set(key, data)
            return data
        return data

    return _handler
@@ -8,6 +8,7 @@ from storage_mock import add_histogram, add_image, add_scalar
_retry_counter = 0


class LibTest(unittest.TestCase):
    def setUp(self):
        dir = "./tmp/mock"
......
@@ -6,4 +6,3 @@ cp squeezenet/model.pb squeezenet_model.pb
rm -rf squeezenet
rm squeezenet.tar.gz
from setuptools import setup

-packages = ['visualdl',
-            'visualdl.onnx',
-            'visualdl.mock',
-            'visualdl.frontend.dist']
+packages = [
+    'visualdl', 'visualdl.onnx', 'visualdl.mock', 'visualdl.frontend.dist'
+]

setup(
    name="visualdl",
......
import random
-import time
-import unittest

import numpy as np
@@ -36,4 +34,5 @@ def add_histogram(writer, mode, tag, num_buckets):
    with writer.mode(mode) as writer:
        histogram = writer.histogram(tag, num_buckets)
        for i in range(10):
-            histogram.add_record(i, np.random.normal(0.1 + i * 0.01, size=1000))
+            histogram.add_record(i, np.random.normal(
+                0.1 + i * 0.01, size=1000))
@@ -15,7 +15,6 @@ limitations under the License. */
#ifndef VISUALDL_STORAGE_STORAGE_H
#define VISUALDL_STORAGE_STORAGE_H

-#include "visualdl/utils/logging.h"
#include <algorithm>
#include <set>
#include <vector>
@@ -25,6 +24,7 @@ limitations under the License. */
#include "visualdl/storage/tablet.h"
#include "visualdl/utils/filesystem.h"
#include "visualdl/utils/guard.h"
+#include "visualdl/utils/logging.h"

namespace visualdl {

static const std::string meta_file_name = "storage.meta";
......
@@ -15,11 +15,10 @@ limitations under the License. */
#ifndef VISUALDL_TABLET_H
#define VISUALDL_TABLET_H

-#include "visualdl/utils/logging.h"
#include "visualdl/logic/im.h"
#include "visualdl/storage/record.h"
#include "visualdl/storage/storage.pb.h"
+#include "visualdl/utils/logging.h"
#include "visualdl/utils/string.h"

namespace visualdl {
......
import sys
import unittest

import numpy as np

-sys.path.append('../../build')
+sys.path.append('../../build')  # noqa: E402
import core

im = core.im()
@@ -62,7 +63,7 @@ class TabletTester(unittest.TestCase):
class ImTester(unittest.TestCase):
    def test_persist(self):
        im.clear_tablets()
-        tablet = im.add_tablet("tab0", 111)
+        im.add_tablet("tab0", 111)
        self.assertEqual(im.storage().tablets_size(), 1)
        im.storage().set_dir("./1")
        im.persist_to_disk()
......
@@ -15,11 +15,11 @@ limitations under the License. */
#ifndef VISUALDL_UTILS_CONCURRENCY_H
#define VISUALDL_UTILS_CONCURRENCY_H

-#include "visualdl/utils/logging.h"
#include <chrono>
#include <memory>
#include <thread>
#include <vector>
+#include "visualdl/utils/logging.h"

namespace visualdl {
namespace cc {
......
@@ -14,8 +14,8 @@ limitations under the License. */
#include "visualdl/utils/concurrency.h"

-#include "visualdl/utils/logging.h"
#include <gtest/gtest.h>
+#include "visualdl/utils/logging.h"

namespace visualdl {
......
@@ -30,4 +30,3 @@ TEST(image, NormalizeImage) {
  NormalizeImage(&image, arr, 3, 128);
}