提交 4afaaa4b 编写于 作者: Y Yu Yang

Autoformat all files

上级 94538798
# External dependency to Google protobuf.
http_archive(
name = "protobuf",
url = "http://github.com/google/protobuf/archive/v3.1.0.tar.gz",
sha256 = "0a0ae63cbffc274efb573bdde9a253e3f32e458c41261df51c5dbc5ad541e8f7",
strip_prefix = "protobuf-3.1.0",
)
name="protobuf",
url="http://github.com/google/protobuf/archive/v3.1.0.tar.gz",
sha256="0a0ae63cbffc274efb573bdde9a253e3f32e458c41261df51c5dbc5ad541e8f7",
strip_prefix="protobuf-3.1.0", )
# External dependency to gtest 1.7.0. This method comes from
# https://www.bazel.io/versions/master/docs/tutorial/cpp.html.
new_http_archive(
name = "gtest",
url = "https://github.com/google/googletest/archive/release-1.7.0.zip",
sha256 = "b58cb7547a28b2c718d1e38aee18a3659c9e3ff52440297e965f5edffe34b6d0",
build_file = "third_party/gtest.BUILD",
strip_prefix = "googletest-release-1.7.0",
)
name="gtest",
url="https://github.com/google/googletest/archive/release-1.7.0.zip",
sha256="b58cb7547a28b2c718d1e38aee18a3659c9e3ff52440297e965f5edffe34b6d0",
build_file="third_party/gtest.BUILD",
strip_prefix="googletest-release-1.7.0", )
......@@ -71,9 +71,7 @@ class SentimentPrediction():
transform word into integer index according to the dictionary.
"""
words = data.strip().split()
word_slot = [
self.word_dict[w] for w in words if w in self.word_dict
]
word_slot = [self.word_dict[w] for w in words if w in self.word_dict]
return word_slot
def batch_predict(self, data_batch):
......@@ -85,8 +83,8 @@ class SentimentPrediction():
if self.label is None:
print("predicting label is %d" % (lab[0]))
else:
print("predicting label is %s" %
(self.label[lab[0]]))
print("predicting label is %s" % (self.label[lab[0]]))
def option_parser():
usage = "python predict.py -n config -w model_dir -d dictionary -i input_file "
......@@ -143,9 +141,10 @@ def main():
batch.append([predict.get_index(line)])
if len(batch) == batch_size:
predict.batch_predict(batch)
batch=[]
batch = []
if len(batch) > 0:
predict.batch_predict(batch)
if __name__ == '__main__':
main()
......@@ -3364,7 +3364,10 @@ def my_fatal(s):
logger.critical(s)
raise Exception()
_parse_config_hooks = set()
def register_parse_config_hook(f):
"""
Register a hook function for parse_config. parse_config will invoke the hook
......@@ -3373,6 +3376,7 @@ def register_parse_config_hook(f):
"""
_parse_config_hooks.add(f)
def parse_config(config_file, config_arg_str):
'''
@param config_arg_str: a string of the form var1=val1,var2=val2. It will be
......
......@@ -84,12 +84,15 @@ class DefaultNameFactory(object):
_name_factories = []
def reset_hook():
for factory in _name_factories:
factory.reset()
register_parse_config_hook(reset_hook)
def wrap_name_default(name_prefix=None):
"""
Decorator to set "name" arguments default to "{name_prefix}_{invoke_count}".
......
......@@ -17,6 +17,7 @@ import sys
import re
import getopt
def main(print_whole_config, globals, locals):
'''
this test will all test_config.py
......@@ -38,7 +39,8 @@ def main(print_whole_config, globals, locals):
else:
cmdstr = cmdstr + """print parse_config(configs, "").model_config"""
exec(cmdstr, globals, locals)
exec (cmdstr, globals, locals)
if __name__ == '__main__':
whole = False
......
......@@ -14,13 +14,13 @@
import unittest
from paddle.trainer.config_parser import parse_config
class TestParse(unittest.TestCase):
class TestParse(unittest.TestCase):
def test_parse(self):
a = parse_config(
'trainer_config_helpers/tests/layers_test_config.py', '')
b = parse_config(
'trainer_config_helpers/tests/layers_test_config.py', '')
a = parse_config('trainer_config_helpers/tests/layers_test_config.py',
'')
b = parse_config('trainer_config_helpers/tests/layers_test_config.py',
'')
self.assertEqual(a, b)
......
cc_library(
name = "main",
srcs = glob(
["src/*.cc"],
exclude = ["src/gtest-all.cc"]
),
hdrs = glob([
"include/**/*.h",
"src/*.h"
]),
copts = ["-Iexternal/gtest/include"],
linkopts = ["-pthread"],
visibility = ["//visibility:public"],
)
name="main",
srcs=glob(
["src/*.cc"], exclude=["src/gtest-all.cc"]),
hdrs=glob(["include/**/*.h", "src/*.h"]),
copts=["-Iexternal/gtest/include"],
linkopts=["-pthread"],
visibility=["//visibility:public"], )
......@@ -3,25 +3,22 @@ licenses(["notice"]) # Apache 2.0
load("@protobuf//:protobuf.bzl", "cc_proto_library")
cc_proto_library(
name = "example_proto",
srcs = ["example.proto"],
protoc = "@protobuf//:protoc",
default_runtime = "@protobuf//:protobuf",
)
name="example_proto",
srcs=["example.proto"],
protoc="@protobuf//:protoc",
default_runtime="@protobuf//:protobuf", )
cc_library(
name = "example_lib",
srcs = ["example_lib.cc"],
hdrs = ["example_lib.h"],
deps = [":example_proto"],
)
name="example_lib",
srcs=["example_lib.cc"],
hdrs=["example_lib.h"],
deps=[":example_proto"], )
cc_test(
name = "example_lib_test",
srcs = ["example_lib_test.cc"],
copts = ["-Iexternal/gtest/include"],
deps =[
name="example_lib_test",
srcs=["example_lib_test.cc"],
copts=["-Iexternal/gtest/include"],
deps=[
"@gtest//:main",
":example_lib",
],
)
], )
......@@ -3,9 +3,7 @@
namespace third_party {
namespace protobuf_test {
std::string get_greet(const Greeting& who) {
return "Hello " + who.name();
}
std::string get_greet(const Greeting& who) { return "Hello " + who.name(); }
} // namespace protobuf_test
} // namespace thrid_party
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册