提交 4afaaa4b 编写于 作者: Y Yu Yang

Autoformat all files

上级 94538798
# External dependency to Google protobuf. # External dependency to Google protobuf.
http_archive( http_archive(
name = "protobuf", name="protobuf",
url = "http://github.com/google/protobuf/archive/v3.1.0.tar.gz", url="http://github.com/google/protobuf/archive/v3.1.0.tar.gz",
sha256 = "0a0ae63cbffc274efb573bdde9a253e3f32e458c41261df51c5dbc5ad541e8f7", sha256="0a0ae63cbffc274efb573bdde9a253e3f32e458c41261df51c5dbc5ad541e8f7",
strip_prefix = "protobuf-3.1.0", strip_prefix="protobuf-3.1.0", )
)
# External dependency to gtest 1.7.0. This method comes from # External dependency to gtest 1.7.0. This method comes from
# https://www.bazel.io/versions/master/docs/tutorial/cpp.html. # https://www.bazel.io/versions/master/docs/tutorial/cpp.html.
new_http_archive( new_http_archive(
name = "gtest", name="gtest",
url = "https://github.com/google/googletest/archive/release-1.7.0.zip", url="https://github.com/google/googletest/archive/release-1.7.0.zip",
sha256 = "b58cb7547a28b2c718d1e38aee18a3659c9e3ff52440297e965f5edffe34b6d0", sha256="b58cb7547a28b2c718d1e38aee18a3659c9e3ff52440297e965f5edffe34b6d0",
build_file = "third_party/gtest.BUILD", build_file="third_party/gtest.BUILD",
strip_prefix = "googletest-release-1.7.0", strip_prefix="googletest-release-1.7.0", )
)
...@@ -71,9 +71,7 @@ class SentimentPrediction(): ...@@ -71,9 +71,7 @@ class SentimentPrediction():
transform word into integer index according to the dictionary. transform word into integer index according to the dictionary.
""" """
words = data.strip().split() words = data.strip().split()
word_slot = [ word_slot = [self.word_dict[w] for w in words if w in self.word_dict]
self.word_dict[w] for w in words if w in self.word_dict
]
return word_slot return word_slot
def batch_predict(self, data_batch): def batch_predict(self, data_batch):
...@@ -85,8 +83,8 @@ class SentimentPrediction(): ...@@ -85,8 +83,8 @@ class SentimentPrediction():
if self.label is None: if self.label is None:
print("predicting label is %d" % (lab[0])) print("predicting label is %d" % (lab[0]))
else: else:
print("predicting label is %s" % print("predicting label is %s" % (self.label[lab[0]]))
(self.label[lab[0]]))
def option_parser(): def option_parser():
usage = "python predict.py -n config -w model_dir -d dictionary -i input_file " usage = "python predict.py -n config -w model_dir -d dictionary -i input_file "
...@@ -143,9 +141,10 @@ def main(): ...@@ -143,9 +141,10 @@ def main():
batch.append([predict.get_index(line)]) batch.append([predict.get_index(line)])
if len(batch) == batch_size: if len(batch) == batch_size:
predict.batch_predict(batch) predict.batch_predict(batch)
batch=[] batch = []
if len(batch) > 0: if len(batch) > 0:
predict.batch_predict(batch) predict.batch_predict(batch)
if __name__ == '__main__': if __name__ == '__main__':
main() main()
...@@ -3364,7 +3364,10 @@ def my_fatal(s): ...@@ -3364,7 +3364,10 @@ def my_fatal(s):
logger.critical(s) logger.critical(s)
raise Exception() raise Exception()
_parse_config_hooks = set() _parse_config_hooks = set()
def register_parse_config_hook(f): def register_parse_config_hook(f):
""" """
Register a hook function for parse_config. parse_config will invoke the hook Register a hook function for parse_config. parse_config will invoke the hook
...@@ -3373,6 +3376,7 @@ def register_parse_config_hook(f): ...@@ -3373,6 +3376,7 @@ def register_parse_config_hook(f):
""" """
_parse_config_hooks.add(f) _parse_config_hooks.add(f)
def parse_config(config_file, config_arg_str): def parse_config(config_file, config_arg_str):
''' '''
@param config_arg_str: a string of the form var1=val1,var2=val2. It will be @param config_arg_str: a string of the form var1=val1,var2=val2. It will be
......
...@@ -84,12 +84,15 @@ class DefaultNameFactory(object): ...@@ -84,12 +84,15 @@ class DefaultNameFactory(object):
_name_factories = [] _name_factories = []
def reset_hook(): def reset_hook():
for factory in _name_factories: for factory in _name_factories:
factory.reset() factory.reset()
register_parse_config_hook(reset_hook) register_parse_config_hook(reset_hook)
def wrap_name_default(name_prefix=None): def wrap_name_default(name_prefix=None):
""" """
Decorator to set "name" arguments default to "{name_prefix}_{invoke_count}". Decorator to set "name" arguments default to "{name_prefix}_{invoke_count}".
......
...@@ -17,33 +17,35 @@ import sys ...@@ -17,33 +17,35 @@ import sys
import re import re
import getopt import getopt
def main(print_whole_config, globals, locals): def main(print_whole_config, globals, locals):
''' '''
this test will all test_config.py this test will all test_config.py
''' '''
cmdstr = """from paddle.trainer.config_parser import parse_config\n""" cmdstr = """from paddle.trainer.config_parser import parse_config\n"""
importstr = "" importstr = ""
functionstr = "" functionstr = ""
for line in sys.stdin:
if re.match("^import", line) or re.match("^from.*import", line):
importstr = importstr + line
else:
functionstr = functionstr + " " + line
for line in sys.stdin: cmdstr = cmdstr + importstr + """def configs():\n""" + functionstr
if re.match("^import", line) or re.match("^from.*import", line): #cmdstr = cmdstr + """def configs():\n""" + importstr + functionstr
importstr = importstr + line if print_whole_config:
cmdstr = cmdstr + """print parse_config(configs, "")"""
else: else:
functionstr = functionstr + " " + line cmdstr = cmdstr + """print parse_config(configs, "").model_config"""
cmdstr = cmdstr + importstr + """def configs():\n""" + functionstr exec (cmdstr, globals, locals)
#cmdstr = cmdstr + """def configs():\n""" + importstr + functionstr
if print_whole_config:
cmdstr = cmdstr + """print parse_config(configs, "")"""
else:
cmdstr = cmdstr + """print parse_config(configs, "").model_config"""
exec(cmdstr, globals, locals)
if __name__ == '__main__': if __name__ == '__main__':
whole = False whole = False
opts, args = getopt.getopt(sys.argv[1:], "", ["whole"]) opts, args = getopt.getopt(sys.argv[1:], "", ["whole"])
for op, value in opts: for op, value in opts:
if op == "--whole": if op == "--whole":
whole = True whole = True
main(whole, globals(), locals()) main(whole, globals(), locals())
...@@ -14,13 +14,13 @@ ...@@ -14,13 +14,13 @@
import unittest import unittest
from paddle.trainer.config_parser import parse_config from paddle.trainer.config_parser import parse_config
class TestParse(unittest.TestCase):
class TestParse(unittest.TestCase):
def test_parse(self): def test_parse(self):
a = parse_config( a = parse_config('trainer_config_helpers/tests/layers_test_config.py',
'trainer_config_helpers/tests/layers_test_config.py', '') '')
b = parse_config( b = parse_config('trainer_config_helpers/tests/layers_test_config.py',
'trainer_config_helpers/tests/layers_test_config.py', '') '')
self.assertEqual(a, b) self.assertEqual(a, b)
......
cc_library( cc_library(
name = "main", name="main",
srcs = glob( srcs=glob(
["src/*.cc"], ["src/*.cc"], exclude=["src/gtest-all.cc"]),
exclude = ["src/gtest-all.cc"] hdrs=glob(["include/**/*.h", "src/*.h"]),
), copts=["-Iexternal/gtest/include"],
hdrs = glob([ linkopts=["-pthread"],
"include/**/*.h", visibility=["//visibility:public"], )
"src/*.h"
]),
copts = ["-Iexternal/gtest/include"],
linkopts = ["-pthread"],
visibility = ["//visibility:public"],
)
...@@ -3,25 +3,22 @@ licenses(["notice"]) # Apache 2.0 ...@@ -3,25 +3,22 @@ licenses(["notice"]) # Apache 2.0
load("@protobuf//:protobuf.bzl", "cc_proto_library") load("@protobuf//:protobuf.bzl", "cc_proto_library")
cc_proto_library( cc_proto_library(
name = "example_proto", name="example_proto",
srcs = ["example.proto"], srcs=["example.proto"],
protoc = "@protobuf//:protoc", protoc="@protobuf//:protoc",
default_runtime = "@protobuf//:protobuf", default_runtime="@protobuf//:protobuf", )
)
cc_library( cc_library(
name = "example_lib", name="example_lib",
srcs = ["example_lib.cc"], srcs=["example_lib.cc"],
hdrs = ["example_lib.h"], hdrs=["example_lib.h"],
deps = [":example_proto"], deps=[":example_proto"], )
)
cc_test( cc_test(
name = "example_lib_test", name="example_lib_test",
srcs = ["example_lib_test.cc"], srcs=["example_lib_test.cc"],
copts = ["-Iexternal/gtest/include"], copts=["-Iexternal/gtest/include"],
deps =[ deps=[
"@gtest//:main", "@gtest//:main",
":example_lib", ":example_lib",
], ], )
)
...@@ -3,9 +3,7 @@ ...@@ -3,9 +3,7 @@
namespace third_party { namespace third_party {
namespace protobuf_test { namespace protobuf_test {
std::string get_greet(const Greeting& who) { std::string get_greet(const Greeting& who) { return "Hello " + who.name(); }
return "Hello " + who.name();
}
} // namespace protobuf_test } // namespace protobuf_test
} // namespace thrid_party } // namespace thrid_party
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册