diff --git a/cmake/external/eigen.cmake b/cmake/external/eigen.cmake index 253d436bcc04d8e0db78f6a4a2c67a050f456bba..45f44f617dcb46062355df4e35d537086215a46d 100644 --- a/cmake/external/eigen.cmake +++ b/cmake/external/eigen.cmake @@ -7,8 +7,17 @@ INCLUDE_DIRECTORIES(${EIGEN_SOURCE_DIR}/src/eigen3) ExternalProject_Add( eigen3 ${EXTERNAL_PROJECT_LOG_ARGS} - URL "https://bitbucket.org/eigen/eigen/get/3.3.4.tar.gz" - URL_MD5 "1a47e78efe365a97de0c022d127607c3" + # for latest version, please get from official website + # URL "https://bitbucket.org/eigen/eigen/get/3.3.4.tar.gz" + # URL_MD5 "1a47e78efe365a97de0c022d127607c3" + + # for no-ssl http support, please get from bazel's mirror + # URL "http://mirror.bazel.build/bitbucket.org/eigen/eigen/get/f3a22f35b044.tar.gz" + # URL_MD5 "4645c66075982da6fa0bcf6b20f3e8f7" + + # get from github mirror + GIT_REPOSITORY "https://github.com/RLovelett/eigen.git" + GIT_TAG "a46d2e7337c4656f00abe54a8115f6d76153a048" PREFIX ${EIGEN_SOURCE_DIR} UPDATE_COMMAND "" CONFIGURE_COMMAND "" diff --git a/cmake/external/protobuf.cmake b/cmake/external/protobuf.cmake index d43badc1da50723d5d3dbd1f19f0bd4ef4d24737..2f267adc203f3da80615318f168de9798c537080 100644 --- a/cmake/external/protobuf.cmake +++ b/cmake/external/protobuf.cmake @@ -13,6 +13,10 @@ # limitations under the License. INCLUDE(ExternalProject) +# Always invoke `FIND_PACKAGE(Protobuf)` for importing function protobuf_generate_cpp +FIND_PACKAGE(Protobuf QUIET) +SET(PROTOBUF_FOUND "OFF") + # Print and set the protobuf library information, # finish this cmake process and exit from this file. @@ -39,12 +43,19 @@ macro(PROMPT_PROTOBUF_LIB) ADD_LIBRARY(protobuf_lite ${protobuf_LIBTYPE} IMPORTED GLOBAL) SET_PROPERTY(TARGET protobuf_lite PROPERTY IMPORTED_LOCATION ${PROTOBUF_LITE_LIBRARY}) - ADD_LIBRARY(protoc ${protobuf_LIBTYPE} IMPORTED GLOBAL) - SET_PROPERTY(TARGET protoc PROPERTY IMPORTED_LOCATION ${PROTOC_LIBRARY}) + ADD_LIBRARY(libprotoc ${protobuf_LIBTYPE} IMPORTED GLOBAL) + SET_PROPERTY(TARGET libprotoc PROPERTY IMPORTED_LOCATION ${PROTOC_LIBRARY}) + + ADD_EXECUTABLE(protoc IMPORTED GLOBAL) + SET_PROPERTY(TARGET protoc PROPERTY IMPORTED_LOCATION ${PROTOBUF_PROTOC_EXECUTABLE}) + # FIND_Protobuf.cmake uses `Protobuf_PROTOC_EXECUTABLE`. + # make `protobuf_generate_cpp` happy. + SET(Protobuf_PROTOC_EXECUTABLE ${PROTOBUF_PROTOC_EXECUTABLE}) FOREACH(dep ${protobuf_DEPS}) ADD_DEPENDENCIES(protobuf ${dep}) ADD_DEPENDENCIES(protobuf_lite ${dep}) + ADD_DEPENDENCIES(libprotoc ${dep}) ADD_DEPENDENCIES(protoc ${dep}) ENDFOREACH() diff --git a/cmake/generic.cmake b/cmake/generic.cmake index 11c1f677ae5b308558b54bf49caf168cf6023444..61353a4a2622257eddb05578c5085c44c1719b98 100644 --- a/cmake/generic.cmake +++ b/cmake/generic.cmake @@ -87,6 +87,9 @@ # go_library(example SHARED) # +# including binary directory for generated headers. +include_directories(${CMAKE_BINARY_DIR}) + if(NOT APPLE) find_package(Threads REQUIRED) link_libraries(${CMAKE_THREAD_LIBS_INIT}) @@ -331,3 +334,13 @@ function(go_test TARGET_NAME) add_custom_target(${TARGET_NAME} ALL DEPENDS ${TARGET_NAME}_timestamp ${go_test_DEPS}) add_test(${TARGET_NAME} ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}) endfunction(go_test) + +function(proto_library TARGET_NAME) + set(oneValueArgs "") + set(multiValueArgs SRCS) + cmake_parse_arguments(proto_library "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + set(proto_srcs) + set(proto_hdrs) + protobuf_generate_cpp(proto_srcs proto_hdrs ${proto_library_SRCS}) + cc_library(${TARGET_NAME} SRCS ${proto_srcs} DEPS protobuf) +endfunction() diff --git a/doc/CMakeLists.txt b/doc/CMakeLists.txt index 6fa42fd0c71e78cc2fa6b0fe2cb970baf4ac89ed..94dd3457fb5b513441c4c8e339e1862de9092517 100644 --- a/doc/CMakeLists.txt +++ b/doc/CMakeLists.txt @@ -27,10 +27,6 @@ sphinx_add_target(paddle_docs ${CMAKE_CURRENT_SOURCE_DIR} ${SPHINX_HTML_DIR_EN}) -add_dependencies(paddle_docs - gen_proto_py) - - # configured documentation tools and intermediate build results set(BINARY_BUILD_DIR_CN "${CMAKE_CURRENT_BINARY_DIR}/cn/_build") @@ -51,6 +47,3 @@ sphinx_add_target(paddle_docs_cn ${SPHINX_CACHE_DIR_CN} ${CMAKE_CURRENT_SOURCE_DIR} ${SPHINX_HTML_DIR_CN}) - -add_dependencies(paddle_docs_cn - gen_proto_py) diff --git a/doc/design/scope.md b/doc/design/scope.md index 2ff416f06e8ada48b1d4922f8869a106f35799e2..afe6bc028cafc5ee24b0041905857af58d3f5790 100644 --- a/doc/design/scope.md +++ b/doc/design/scope.md @@ -41,7 +41,7 @@ class Scope { const Variable* GetVariable(const std::string& name) const; private: - std::unordered_map> vars_; + std::unordered_map> vars_; }; ``` @@ -59,9 +59,9 @@ class Scope { Scope(const std::shared_ptr& scope): parent_(scope) {} Variable* GetVariable(const std::string& name) const { - Variable* var = GetVarLocally(name); - if (var != nullptr) { - return var; + auto it = vars_.find(name); + if (it != vars_.end()) { + return it->second.get(); } else if (parent_ != nullptr) { return parent_->GetVariable(name); } else { @@ -97,8 +97,8 @@ class Scope { // return nullptr if not found. Variable* GetVariable(const std::string& name) const; - // return Error if already contains same name variable. - Error CreateVariable(const std::string& name); + // return if already contains same name variable. + Variable* CreateVariable(const std::string& name); private: std::shared_ptr parent_; diff --git a/go/master/c/client.go b/go/master/c/client.go index b186474dc33138aeb02a2ffe34418b379b7a2db0..9e35e986002c0ae3b7593150ece96dba29a1521b 100644 --- a/go/master/c/client.go +++ b/go/master/c/client.go @@ -13,10 +13,13 @@ typedef int paddle_master_client; import "C" import ( + "strings" "sync" + "time" "unsafe" "github.com/PaddlePaddle/Paddle/go/master" + "github.com/coreos/etcd/clientv3" log "github.com/sirupsen/logrus" ) @@ -48,16 +51,33 @@ func remove(client C.paddle_master_client) *master.Client { return h } -type addresser string - -func (a addresser) Address() string { - return string(a) +//export paddle_new_etcd_master_client +func paddle_new_etcd_master_client(etcdEndpoints *C.char, timeout int, bufSize int) C.paddle_master_client { + p := C.GoString(etcdEndpoints) + cli, err := clientv3.New(clientv3.Config{ + Endpoints: strings.Split(p, ","), + DialTimeout: time.Second * time.Duration(timeout), + }) + if err != nil { + panic(err) + } + ch := make(chan string, 1) + a, err := master.GetKey(cli, master.DefaultAddrPath, timeout) + if err != nil { + panic(err) + } + ch <- a + go master.WatchKey(cli, master.DefaultAddrPath, ch) + c := master.NewClient(ch, bufSize) + return add(c) } //export paddle_new_master_client func paddle_new_master_client(addr *C.char, bufSize int) C.paddle_master_client { a := C.GoString(addr) - c := master.NewClient(addresser(a), bufSize) + ch := make(chan string, 1) + ch <- a + c := master.NewClient(ch, bufSize) return add(c) } diff --git a/go/master/client.go b/go/master/client.go index 8451820c1963dd5a4eff0c3ab7763eb6a8e05ba4..d3bea49d0a8166420e83478076cc7bc81e48598d 100644 --- a/go/master/client.go +++ b/go/master/client.go @@ -2,18 +2,12 @@ package master import ( "os" - "time" "github.com/PaddlePaddle/Paddle/go/connection" "github.com/PaddlePaddle/recordio" log "github.com/sirupsen/logrus" ) -// Addresser provide the address of the master server. -type Addresser interface { - Address() string -} - // Client is the client of the master server. type Client struct { conn *connection.Conn @@ -24,11 +18,11 @@ type Client struct { // // bufSize is the record buffer size. NextRecord will read from this // buffer. -func NewClient(addr Addresser, bufSize int) *Client { +func NewClient(addrCh <-chan string, bufSize int) *Client { c := &Client{} c.conn = connection.New() c.ch = make(chan []byte, bufSize) - go c.monitorMaster(addr) + go c.monitorMaster(addrCh) go c.getRecords() return c } @@ -72,12 +66,10 @@ func (c *Client) getRecords() { } } -func (c *Client) monitorMaster(addr Addresser) { +func (c *Client) monitorMaster(addrCh <-chan string) { lastMaster := "" - monitor := func() { - // get the lastest address of the master server, + for curMaster := range addrCh { // connect to the new address once address changed. - curMaster := addr.Address() if curMaster != lastMaster { if curMaster == "" { err := c.conn.Close() @@ -94,18 +86,10 @@ func (c *Client) monitorMaster(addr Addresser) { // to retry next time. curMaster = lastMaster } - } } - lastMaster = curMaster } - - monitor() - ticker := time.NewTicker(10 * time.Second) - for _ = range ticker.C { - monitor() - } } // SetDataset set dataset for the master server to dispatch. diff --git a/go/master/client_internal_test.go b/go/master/client_internal_test.go index 251225780ae3077f90655b4e874d03b4f3794525..364dce7b58cf6366af711bde9107559a762563a4 100644 --- a/go/master/client_internal_test.go +++ b/go/master/client_internal_test.go @@ -26,12 +26,6 @@ func init() { log.SetLevel(log.ErrorLevel) } -type TestAddresser string - -func (a TestAddresser) Address() string { - return string(a) -} - func TestGetFinishTask(t *testing.T) { const path = "/tmp/master_client_test_0" @@ -45,7 +39,6 @@ func TestGetFinishTask(t *testing.T) { if err != nil { panic(err) } - go func(l net.Listener) { s, err := NewService(&InMemStore{}, chunkPerTask, time.Second, 1) if err != nil { @@ -82,9 +75,11 @@ func TestGetFinishTask(t *testing.T) { // Manually intialize client to avoid calling c.getRecords() c := &Client{} c.conn = connection.New() - go c.monitorMaster(TestAddresser(fmt.Sprintf(":%d", p))) + addr := fmt.Sprintf(":%d", p) + ch := make(chan string, 1) + ch <- addr + go c.monitorMaster(ch) c.SetDataset([]string{path}) - checkOnePass := func(i int) { var tasks []Task for idx := 0; idx < totalTask; idx++ { diff --git a/go/master/client_test.go b/go/master/client_test.go index 85a86761c2e5897e3e89cbebfd32f7666c4a9f7f..c00aeebfd5d1fef6de4a8c67bf7f998a42ee863b 100644 --- a/go/master/client_test.go +++ b/go/master/client_test.go @@ -20,7 +20,6 @@ func TestNextRecord(t *testing.T) { path = "/tmp/master_client_TestFull" total = 50 ) - l, err := net.Listen("tcp", ":0") if err != nil { panic(err) @@ -31,7 +30,6 @@ func TestNextRecord(t *testing.T) { if err != nil { panic(err) } - go func(l net.Listener) { s, err := master.NewService(&master.InMemStore{}, 10, time.Second, 1) if err != nil { @@ -63,10 +61,10 @@ func TestNextRecord(t *testing.T) { } w.Close() f.Close() - - c := master.NewClient(master.TestAddresser(fmt.Sprintf(":%d", p)), 10) + curAddr := make(chan string, 1) + curAddr <- fmt.Sprintf(":%d", p) + c := master.NewClient(curAddr, 10) c.SetDataset([]string{path}) - for pass := 0; pass < 50; pass++ { received := make(map[byte]bool) for i := 0; i < total; i++ { diff --git a/go/master/etcd_client.go b/go/master/etcd_client.go index f7b463857735070241611af98030c102d1907356..e27c014792f31ca27fe1a1636d69acccc4206ea3 100644 --- a/go/master/etcd_client.go +++ b/go/master/etcd_client.go @@ -142,3 +142,31 @@ func (e *EtcdClient) Load() ([]byte, error) { state := kvs[0].Value return state, nil } + +// GetKey gets the value by the specify key. +func GetKey(c *clientv3.Client, key string, timeout int) (string, error) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(timeout)) + resp, err := c.Get(ctx, key) + cancel() + if err != nil { + return "", err + } + kvs := resp.Kvs + if len(kvs) == 0 { + return "", nil + } + v := kvs[0].Value + return string(v), nil +} + +// WatchKey watches the specify key and send to valChan if there is some event. +func WatchKey(c *clientv3.Client, key string, valChan chan<- string) { + rch := c.Watch(context.Background(), key) + for wresp := range rch { + for _, ev := range wresp.Events { + // if received event is DELETE, the value will be an empty string + log.Infof("received event %s, %q : %q\n", ev.Type, ev.Kv.Key, ev.Kv.Value) + valChan <- string(ev.Kv.Value) + } + } +} diff --git a/go/pserver/client.go b/go/pserver/client.go index dda915977282d4880ddcc8c18ef6fd80ede9e01b..6938b9d5ce6f6d73c05bd6e3154777023965c319 100644 --- a/go/pserver/client.go +++ b/go/pserver/client.go @@ -1,6 +1,7 @@ package pserver import ( + "errors" "hash/fnv" "sort" "time" @@ -123,6 +124,9 @@ func (c *Client) FinishInitParams() error { // SendGrads sends gradients to parameter servers for updating // parameters. func (c *Client) SendGrads(grads []Gradient) error { + if len(grads) == 0 { + return errors.New("no gradient received") + } errCh := make(chan error, len(grads)) for _, g := range grads { go func(g Gradient) { diff --git a/paddle/api/CMakeLists.txt b/paddle/api/CMakeLists.txt index f2315e31cc06d8b5fea7a9fd203a697bac603a90..39d8aa075bc072d37dc8df67746f0d2b503418a6 100644 --- a/paddle/api/CMakeLists.txt +++ b/paddle/api/CMakeLists.txt @@ -16,7 +16,7 @@ set(API_HEADER Internal.h) add_library(paddle_api STATIC ${API_SOURCES}) -add_dependencies(paddle_api gen_proto_cpp paddle_trainer_lib) +add_dependencies(paddle_api paddle_proto paddle_trainer_lib) INCLUDE(${SWIG_USE_FILE}) INCLUDE_DIRECTORIES(${PROJ_ROOT}/paddle) diff --git a/paddle/capi/CMakeLists.txt b/paddle/capi/CMakeLists.txt index 206f512563466d40e9ad1db0ddb4753ffb6bf55a..11022d17541476c97a2b29be8eb8fecce7e39435 100644 --- a/paddle/capi/CMakeLists.txt +++ b/paddle/capi/CMakeLists.txt @@ -26,7 +26,7 @@ target_include_directories(paddle_capi PUBLIC ${CMAKE_CURRENT_BINARY_DIR}) add_style_check_target(paddle_capi ${CAPI_SOURCES} ${CAPI_HEADER} ${CAPI_PRIVATE_HEADER}) -add_dependencies(paddle_capi gen_proto_cpp) +add_dependencies(paddle_capi paddle_proto) # combine all paddle static libraries together, into libpaddle_capi_whole.a diff --git a/paddle/cuda/CMakeLists.txt b/paddle/cuda/CMakeLists.txt index f9061e96deb659dcf7bfb88b46e6509af0425199..73ffa690d9d91b673079fc0ecf91f17cbabfdb1e 100755 --- a/paddle/cuda/CMakeLists.txt +++ b/paddle/cuda/CMakeLists.txt @@ -83,7 +83,7 @@ else() ${CUDA_CXX_SOURCES}) endif() -add_dependencies(paddle_cuda ${external_project_dependencies}) +add_dependencies(paddle_cuda paddle_proto ${external_project_dependencies}) add_style_check_target(paddle_cuda ${CUDA_SOURCES} diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index e3c3155aa902c941058ea1b15488360df6c06175..6aa6b9bc2db6a223dd8562b76ba9d777206bfd40 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -1,6 +1,7 @@ +# ddim lib cc_library(ddim SRCS ddim.cc) cc_test(ddim_test SRCS ddim_test.cc DEPS ddim) - nv_test(dim_test SRCS dim_test.cu DEPS ddim) - cc_test(variable_test SRCS variable_test.cc) +cc_test(scope_test SRCS scope_test.cc) +cc_test(enforce_test SRCS enforce_test.cc) diff --git a/paddle/framework/enforce.h b/paddle/framework/enforce.h new file mode 100644 index 0000000000000000000000000000000000000000..56cb7f95647e81efef58b156002d0d378ee22820 --- /dev/null +++ b/paddle/framework/enforce.h @@ -0,0 +1,69 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include +#include + +namespace paddle { +namespace framework { + +/** + * @brief Enforce exception. Inherits std::exception + * + * All enforce condition not met, will throw an EnforceNotMet exception. + */ +class EnforceNotMet : public std::exception { + public: + EnforceNotMet(const std::string& msg, const char* file, int fileline) { + std::ostringstream sout; + sout << msg << " at [" << file << ":" << fileline << "];"; + all_msg_ = sout.str(); + } + + const char* what() const noexcept override { return all_msg_.c_str(); } + + private: + std::string all_msg_; +}; + +// From https://stackoverflow.com/questions/30130930/ +// __buildin_expect is in C++ 11 standard. Since the condition which enforced +// should be true in most situation, it will make the compiler generate faster +// code by adding `UNLIKELY` macro. +#define UNLIKELY(condition) __builtin_expect(static_cast(condition), 0) + +/** + * @brief Throw a EnforceNotMet exception, automatically filled __FILE__ & + * __LINE__ + * + * This macro take __VA_ARGS__, user can pass any type if that type can + * serialize to std::ostream + */ +#define PADDLE_THROW(...) \ + do { \ + throw ::paddle::framework::EnforceNotMet( \ + ::paddle::string::Sprintf(__VA_ARGS__), __FILE__, __LINE__); \ + } while (0) + +/** + * @brief Enforce a condition, otherwise throw an EnforceNotMet + */ +#define PADDLE_ENFORCE(condition, ...) \ + do { \ + if (UNLIKELY(!(condition))) { \ + PADDLE_THROW(__VA_ARGS__); \ + } \ + } while (0) + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/enforce_test.cc b/paddle/framework/enforce_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..f8da1a192f63a54324d80725c9d2f156fb11a481 --- /dev/null +++ b/paddle/framework/enforce_test.cc @@ -0,0 +1,35 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include + +TEST(ENFORCE, OK) { + PADDLE_ENFORCE(true, "Enforce is ok %d now %f", 123, 0.345); + size_t val = 1; + const size_t limit = 10; + PADDLE_ENFORCE(val < limit, "Enforce is OK too"); +} + +TEST(ENFORCE, FAILED) { + bool in_catch = false; + try { + PADDLE_ENFORCE(false, "Enforce is not ok %d at all", 123); + } catch (paddle::framework::EnforceNotMet err) { + in_catch = true; + std::string msg = "Enforce is not ok 123 at all"; + const char* what = err.what(); + for (size_t i = 0; i < msg.length(); ++i) { + ASSERT_EQ(what[i], msg[i]); + } + } + ASSERT_TRUE(in_catch); +} \ No newline at end of file diff --git a/paddle/framework/scope.h b/paddle/framework/scope.h new file mode 100644 index 0000000000000000000000000000000000000000..a4470f726fb0d59a82db29b3239c111ea1569c55 --- /dev/null +++ b/paddle/framework/scope.h @@ -0,0 +1,95 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include + +#include "paddle/framework/variable.h" + +namespace paddle { +namespace framework { + +/** + * @brief Scope that manage all variables. + * + * Scope is an association of a name to Variable. All variables belong to + * Scope. You need to specify a scope to run a Net, i.e., `net.Run(&scope)`. + * One net can run in different scopes and update different variable in the + * scope. + */ +class Scope { + public: + /** + * @brief Initialize s Scope without parent. + */ + Scope() {} + + /** + * @brief Initialize a Scope with parent. + */ + explicit Scope(const std::shared_ptr& parent) : parent_(parent) {} + + /** + * @brief Create Variable + * + * Create Variable in this Scope. Return the exist one if Variable already + * been created. + */ + Variable* CreateVariable(const std::string& name) { + auto var = GetVariable(name); + if (var) { + return var; + } else { + vars_[name] = std::unique_ptr(new Variable()); + return GetVariable(name); + } + } + + /** + * @brief Get Variable. + * + * Get Variable from this Scope, this function will recursive find Variable + * from it's parent scope. Return nullptr if not found. + */ + Variable* GetVariable(const std::string& name) const { + auto it = vars_.find(name); + if (it != vars_.end()) { + return it->second.get(); + } else if (parent_ != nullptr) { + return parent_->GetVariable(name); + } else { + return nullptr; + } + } + + /** + * @brief If this scope has a Var named name. + * + * Find if there is a Variable in this scope and it's parent scope + */ + bool HasVariable(const std::string& name) const { + return (vars_.find(name) != vars_.end() || + (parent_ && parent_->HasVariable(name))); + } + + private: + std::unordered_map> vars_; + std::shared_ptr parent_{nullptr}; +}; + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/scope_test.cc b/paddle/framework/scope_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..df1afb200ce9e75c5b1e40f2da56667487ae3576 --- /dev/null +++ b/paddle/framework/scope_test.cc @@ -0,0 +1,58 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/framework/scope.h" +#include "gtest/gtest.h" + +TEST(Scope, Create) { + using paddle::framework::Scope; + using paddle::framework::Variable; + + auto scope = std::make_shared(); + + Variable* var0 = scope->CreateVariable(""); + EXPECT_NE(var0, nullptr); + + /// GetVariable will return nullptr if not exist. + Variable* var1 = scope->GetVariable("a"); + EXPECT_EQ(var1, nullptr); + + /// CreateVariable will return one. + Variable* var2 = scope->CreateVariable("a"); + EXPECT_NE(var2, nullptr); + + /// Get the created variable. + Variable* var3 = scope->GetVariable("a"); + EXPECT_EQ(var2, var3); + + /// CreateVariable will just return the variable if it's + /// already exist. + Variable* var4 = scope->CreateVariable("a"); + EXPECT_EQ(var4, var2); +} + +TEST(Scope, Parent) { + using paddle::framework::Scope; + using paddle::framework::Variable; + + auto parent_scope = std::make_shared(); + auto scope = std::make_shared(parent_scope); + + Variable* var0 = parent_scope->CreateVariable("a"); + EXPECT_NE(var0, nullptr); + + /// GetVariable will get Variable from parent scope if exist. + Variable* var1 = scope->GetVariable("a"); + EXPECT_EQ(var0, var1); +} diff --git a/paddle/function/CMakeLists.txt b/paddle/function/CMakeLists.txt index 5e170714cf5b183fcf6e76d34746333397e6b060..1c39ced3c9e3da4079a66e29c00be9cc18411b68 100644 --- a/paddle/function/CMakeLists.txt +++ b/paddle/function/CMakeLists.txt @@ -12,7 +12,7 @@ endif() add_library(paddle_function STATIC ${cpp_files} ${cu_objs}) add_dependencies(paddle_function ${external_project_dependencies}) -add_dependencies(paddle_function gen_proto_cpp) +add_dependencies(paddle_function paddle_proto) if(WITH_TESTING) if(WITH_GPU) diff --git a/paddle/gserver/CMakeLists.txt b/paddle/gserver/CMakeLists.txt index 93a6a99848aa13bb36c9c5c7091fbaa891fc9823..0012636b8f618a1b45cfc801c04781e67694956f 100644 --- a/paddle/gserver/CMakeLists.txt +++ b/paddle/gserver/CMakeLists.txt @@ -58,7 +58,7 @@ endif() add_style_check_target(paddle_gserver ${GSERVER_SOURCES}) add_style_check_target(paddle_gserver ${GSERVER_HEADER}) -add_dependencies(paddle_gserver gen_proto_cpp) +add_dependencies(paddle_gserver paddle_proto ${external_project_dependencies}) if(WITH_TESTING) add_subdirectory(tests) endif() diff --git a/paddle/math/CMakeLists.txt b/paddle/math/CMakeLists.txt index f5657c4690ca71200346efd4e2c5244c02c92eb1..9981de61606bda6baac103592125b929d4c12a3d 100644 --- a/paddle/math/CMakeLists.txt +++ b/paddle/math/CMakeLists.txt @@ -33,7 +33,7 @@ endif() add_style_check_target(paddle_math ${MATH_SOURCES}) add_style_check_target(paddle_math ${MATH_HEADERS}) -add_dependencies(paddle_math gen_proto_cpp) # depends +add_dependencies(paddle_math paddle_proto ${external_project_dependencies}) # depends if(WITH_TESTING) add_subdirectory(tests) endif() diff --git a/paddle/optimizer/CMakeLists.txt b/paddle/optimizer/CMakeLists.txt index 4536f62ec7c2c3423d91e309dee993d4212160fe..9996d01d18b1185e9b01f8b1e4aab325eb28c894 100644 --- a/paddle/optimizer/CMakeLists.txt +++ b/paddle/optimizer/CMakeLists.txt @@ -10,7 +10,7 @@ set(OPITMIZER_SRCS ) add_library(paddle_optimizer STATIC ${OPITMIZER_SRCS}) -add_dependencies(paddle_optimizer gen_proto_cpp) +add_dependencies(paddle_optimizer paddle_proto ${external_project_dependencies}) if(WITH_TESTING) add_simple_unittest(serialization_test) diff --git a/paddle/parameter/CMakeLists.txt b/paddle/parameter/CMakeLists.txt index a35e46997fb04e9378e106bf428a629b286c2e8c..d2ae1c16c6b7316f1a6facdef4b933693d6ba818 100644 --- a/paddle/parameter/CMakeLists.txt +++ b/paddle/parameter/CMakeLists.txt @@ -7,7 +7,7 @@ add_library(paddle_parameter STATIC ${PARAMETERS_SOURCES}) add_style_check_target(paddle_parameter ${PARAMETERS_SOURCES}) add_style_check_target(paddle_parameter ${PARAMETERS_HEADERS}) -add_dependencies(paddle_parameter gen_proto_cpp) +add_dependencies(paddle_parameter paddle_proto ${external_project_dependencies}) if(WITH_TESTING) add_subdirectory(tests) endif() diff --git a/paddle/pserver/CMakeLists.txt b/paddle/pserver/CMakeLists.txt index b7f85ea1a6dfda2a37c315ba15c6ca1979cf4131..2245c7d88ca74922f9919db91977dfa6cb3ca468 100644 --- a/paddle/pserver/CMakeLists.txt +++ b/paddle/pserver/CMakeLists.txt @@ -17,7 +17,7 @@ add_library(paddle_network STATIC add_style_check_target(paddle_network ${NETWORK_SOURCES}) add_style_check_target(paddle_network ${NETWORK_HEADERS}) -add_dependencies(paddle_network gen_proto_cpp) +add_dependencies(paddle_network paddle_proto ${external_project_dependencies}) ################### paddle_pserver ###################### set(PSERVER_SOURCES @@ -40,7 +40,7 @@ add_library(paddle_pserver STATIC add_style_check_target(paddle_pserver ${PSERVER_SOURCES}) add_style_check_target(paddle_pserver ${PSERVER_HEADERS}) -add_dependencies(paddle_pserver gen_proto_cpp) +add_dependencies(paddle_pserver paddle_proto ${external_project_dependencies}) set(PSERVER_MAIN_SOURCES ParameterServer2Main.cpp) diff --git a/paddle/py_paddle/dataprovider_converter.py b/paddle/py_paddle/dataprovider_converter.py index 218cb5ec560ed0717d96a50e9560492ee55b9f70..43614b9779d21795f1f274589ea93639e923ce75 100644 --- a/paddle/py_paddle/dataprovider_converter.py +++ b/paddle/py_paddle/dataprovider_converter.py @@ -144,7 +144,7 @@ class DenseScanner(IScanner): if len(self.__shape__) > 1: # The last-two dimenstions are the frame height and width. # For example, the layout is CHW for 3-D feature of image. - # The H and W are the fram height and width. + # The H and W are the frame height and width. h, w = self.__shape__[-2:] argument.setSlotFrameHeight(self.pos, h) argument.setSlotFrameWidth(self.pos, w) diff --git a/paddle/string/CMakeLists.txt b/paddle/string/CMakeLists.txt index 0f39660a90aa6d202badde31ae7a9210d0256aed..5becf62672d0c606c98ea1a1a4383df97088ab05 100644 --- a/paddle/string/CMakeLists.txt +++ b/paddle/string/CMakeLists.txt @@ -1,2 +1,4 @@ cc_library(stringpiece SRCS piece.cc) cc_test(stringpiece_test SRCS piece_test.cc DEPS stringpiece glog gflags) + +cc_test(stringprintf_test SRCS printf_test.cc DEPS glog gflags) diff --git a/paddle/string/printf.h b/paddle/string/printf.h new file mode 100644 index 0000000000000000000000000000000000000000..8b5ce63a8e8dfe77962ff1e7415911d381a28aac --- /dev/null +++ b/paddle/string/printf.h @@ -0,0 +1,99 @@ +/* + Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +// Compared with std::stringstream, there are primary purpose of +// string::Printf: +// +// 1. Type-safe printing, with why and how explained in +// http://www.drdobbs.com/stringprintf-a-typesafe-printf-family-fo/184401999. +// Implementation includes +// +// https://github.com/c42f/tinyformat +// boost::format +// std::stringstream +// +// std::stringstream is not convenient enough in many cases. For example: +// +// std::cout << std::setprecision(2) << std::fixed << 1.23456 << "\n"; +// +// boost::format is the most convenient one. We can have +// +// std::cout << format("%2% %1%") % 36 % 77; +// +// or +// +// format fmter("%2% %1%"); +// fmter % 36; fmter % 77; +// std::cout << fmter.c_str(); +// +// But the overloading of % might be overkilling and it would be +// more efficient if it can write to std::cout directly. +// +// tinyformat has an interface compatible with the C-printf style, +// and it can writes to a stream or returns a std::string: +// +// std::cout << tfm::printf( +// "%s, %s %d, %.2d:%.2d\n", +// weekday, month, day, hour, min); +// +// or +// +// tfm::format(std::cout, +// "%s, %s %d, %.2d:%.2d\n", +// weekday, month, day, hour, min); +// +// 2. High-performance -- most printed strings are not too long and +// doens't need dynamic memory allocation. Many StringPrintf +// implementations doesn't enforce type-safe, but are +// high-performance, including +// +// https://developers.google.com/optimization/reference/base/stringprintf/ +// https://github.com/adobe/chromium/blob/master/base/stringprintf.h +// https://github.com/google/protobuf/blob/master/src/google/protobuf/stubs/stringprintf.h +// +// According to +// https://github.com/c42f/tinyformat#compile-time-and-code-bloat, +// boost::format runs too slow and results in large executable binary +// files. So here we port tinyformat. + +#pragma once + +#include +#include +#include "paddle/string/tinyformat/tinyformat.h" // https://github.com/c42f/tinyformat + +namespace paddle { +namespace string { + +template +void Fprintf(std::ostream& out, const char* fmt, const Args&... args) { + tinyformat::vformat(out, fmt, tinyformat::makeFormatList(args...)); +} + +template +std::string Sprintf(const char* fmt, const Args&... args) { + std::ostringstream oss; + Fprintf(oss, fmt, args...); + return oss.str(); +} + +template +void Printf(const char* fmt, const Args&... args) { + Fprintf(std::cout, fmt, args...); +} + +} // namespace string +} // namespace paddle diff --git a/paddle/string/printf_test.cc b/paddle/string/printf_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..d8f2454165d741b3937f908dcfd87501940750d5 --- /dev/null +++ b/paddle/string/printf_test.cc @@ -0,0 +1,16 @@ +#include "paddle/string/printf.h" + +#include + +#include "gtest/gtest.h" + +TEST(StringPrintf, StringPrintf) { + std::string weekday = "Wednesday"; + const char* month = "July"; + size_t day = 27; + long hour = 14; + int min = 44; + EXPECT_EQ(std::string("Wednesday, July 27, 14:44"), + paddle::string::Sprintf( + "%s, %s %d, %.2d:%.2d", weekday, month, day, hour, min)); +} diff --git a/paddle/string/tinyformat/tinyformat.h b/paddle/string/tinyformat/tinyformat.h new file mode 100644 index 0000000000000000000000000000000000000000..f0e5e0160fb018b813c1dade727da2861a295147 --- /dev/null +++ b/paddle/string/tinyformat/tinyformat.h @@ -0,0 +1,902 @@ +// tinyformat.h +// Copyright (C) 2011, Chris Foster [chris42f (at) gmail (d0t) com] +// +// Boost Software License - Version 1.0 +// +// Permission is hereby granted, free of charge, to any person or organization +// obtaining a copy of the software and accompanying documentation covered by +// this license (the "Software") to use, reproduce, display, distribute, +// execute, and transmit the Software, and to prepare derivative works of the +// Software, and to permit third-parties to whom the Software is furnished to +// do so, all subject to the following: +// +// The copyright notices in the Software and this entire statement, including +// the above license grant, this restriction and the following disclaimer, +// must be included in all copies of the Software, in whole or in part, and +// all derivative works of the Software, unless such copies or derivative +// works are solely in the form of machine-executable object code generated by +// a source language processor. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT +// SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE +// FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, +// ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +//------------------------------------------------------------------------------ +// Tinyformat: A minimal type safe printf replacement +// +// tinyformat.h is a type safe printf replacement library in a single C++ +// header file. Design goals include: +// +// * Type safety and extensibility for user defined types. +// * C99 printf() compatibility, to the extent possible using std::ostream +// * Simplicity and minimalism. A single header file to include and distribute +// with your projects. +// * Augment rather than replace the standard stream formatting mechanism +// * C++98 support, with optional C++11 niceties +// +// +// Main interface example usage +// ---------------------------- +// +// To print a date to std::cout: +// +// std::string weekday = "Wednesday"; +// const char* month = "July"; +// size_t day = 27; +// long hour = 14; +// int min = 44; +// +// tfm::printf("%s, %s %d, %.2d:%.2d\n", weekday, month, day, hour, min); +// +// The strange types here emphasize the type safety of the interface; it is +// possible to print a std::string using the "%s" conversion, and a +// size_t using the "%d" conversion. A similar result could be achieved +// using either of the tfm::format() functions. One prints on a user provided +// stream: +// +// tfm::format(std::cerr, "%s, %s %d, %.2d:%.2d\n", +// weekday, month, day, hour, min); +// +// The other returns a std::string: +// +// std::string date = tfm::format("%s, %s %d, %.2d:%.2d\n", +// weekday, month, day, hour, min); +// std::cout << date; +// +// These are the three primary interface functions. There is also a +// convenience function printfln() which appends a newline to the usual result +// of printf() for super simple logging. +// +// +// User defined format functions +// ----------------------------- +// +// Simulating variadic templates in C++98 is pretty painful since it requires +// writing out the same function for each desired number of arguments. To make +// this bearable tinyformat comes with a set of macros which are used +// internally to generate the API, but which may also be used in user code. +// +// The three macros TINYFORMAT_ARGTYPES(n), TINYFORMAT_VARARGS(n) and +// TINYFORMAT_PASSARGS(n) will generate a list of n argument types, +// type/name pairs and argument names respectively when called with an integer +// n between 1 and 16. We can use these to define a macro which generates the +// desired user defined function with n arguments. To generate all 16 user +// defined function bodies, use the macro TINYFORMAT_FOREACH_ARGNUM. For an +// example, see the implementation of printf() at the end of the source file. +// +// Sometimes it's useful to be able to pass a list of format arguments through +// to a non-template function. The FormatList class is provided as a way to do +// this by storing the argument list in a type-opaque way. Continuing the +// example from above, we construct a FormatList using makeFormatList(): +// +// FormatListRef formatList = tfm::makeFormatList(weekday, month, day, hour, +// min); +// +// The format list can now be passed into any non-template function and used +// via a call to the vformat() function: +// +// tfm::vformat(std::cout, "%s, %s %d, %.2d:%.2d\n", formatList); +// +// +// Additional API information +// -------------------------- +// +// Error handling: Define TINYFORMAT_ERROR to customize the error handling for +// format strings which are unsupported or have the wrong number of format +// specifiers (calls assert() by default). +// +// User defined types: Uses operator<< for user defined types by default. +// Overload formatValue() for more control. + +#pragma once + +#include +#include +#include +#include + +namespace paddle { +namespace string { +namespace tinyformat { + +#ifndef TINYFORMAT_ERROR +#define TINYFORMAT_ERROR(reason) assert(0 && reason) +#endif + +//------------------------------------------------------------------------------ +namespace detail { + +// Test whether type T1 is convertible to type T2 +template +struct is_convertible { +private: + // two types of different size + struct fail { + char dummy[2]; + }; + struct succeed { + char dummy; + }; + // Try to convert a T1 to a T2 by plugging into tryConvert + static fail tryConvert(...); + static succeed tryConvert(const T2 &); + static const T1 &makeT1(); + +public: + // Standard trick: the (...) version of tryConvert will be chosen from + // the overload set only if the version taking a T2 doesn't match. + // Then we compare the sizes of the return types to check which + // function matched. Very neat, in a disgusting kind of way :) + static const bool value = sizeof(tryConvert(makeT1())) == sizeof(succeed); +}; + +// Format the value by casting to type fmtT. This default implementation +// should never be called. +template ::value> +struct formatValueAsType { + static void invoke(std::ostream & /*out*/, const T & /*value*/) { assert(0); } +}; +// Specialized version for types that can actually be converted to fmtT, as +// indicated by the "convertible" template parameter. +template +struct formatValueAsType { + static void invoke(std::ostream &out, const T &value) { + out << static_cast(value); + } +}; + +// Convert an arbitrary type to integer. The version with convertible=false +// throws an error. +template ::value> +struct convertToInt { + static int invoke(const T & /*value*/) { + TINYFORMAT_ERROR( + "tinyformat: Cannot convert from argument type to " + "integer for use as variable width or precision"); + return 0; + } +}; +// Specialization for convertToInt when conversion is possible +template +struct convertToInt { + static int invoke(const T &value) { return static_cast(value); } +}; + +// Format at most ntrunc characters to the given stream. +template +inline void formatTruncated(std::ostream &out, const T &value, int ntrunc) { + std::ostringstream tmp; + tmp << value; + std::string result = tmp.str(); + out.write(result.c_str(), + (std::min)(ntrunc, static_cast(result.size()))); +} +#define TINYFORMAT_DEFINE_FORMAT_TRUNCATED_CSTR(type) \ + inline void formatTruncated(std::ostream &out, type *value, int ntrunc) { \ + std::streamsize len = 0; \ + while (len < ntrunc && value[len] != 0) ++len; \ + out.write(value, len); \ + } +// Overload for const char* and char*. Could overload for signed & unsigned +// char too, but these are technically unneeded for printf compatibility. +TINYFORMAT_DEFINE_FORMAT_TRUNCATED_CSTR(const char) +TINYFORMAT_DEFINE_FORMAT_TRUNCATED_CSTR(char) +#undef TINYFORMAT_DEFINE_FORMAT_TRUNCATED_CSTR + +} // namespace detail + +//------------------------------------------------------------------------------ +// Variable formatting functions. May be overridden for user-defined types if +// desired. + +/// Format a value into a stream, delegating to operator<< by default. +/// +/// Users may override this for their own types. When this function is called, +/// the stream flags will have been modified according to the format string. +/// The format specification is provided in the range [fmtBegin, fmtEnd). For +/// truncating conversions, ntrunc is set to the desired maximum number of +/// characters, for example "%.7s" calls formatValue with ntrunc = 7. +/// +/// By default, formatValue() uses the usual stream insertion operator +/// operator<< to format the type T, with special cases for the %c and %p +/// conversions. +template +inline void formatValue(std::ostream &out, + const char * /*fmtBegin*/, + const char *fmtEnd, + int ntrunc, + const T &value) { + // The mess here is to support the %c and %p conversions: if these + // conversions are active we try to convert the type to a char or const + // void* respectively and format that instead of the value itself. For the + // %p conversion it's important to avoid dereferencing the pointer, which + // could otherwise lead to a crash when printing a dangling (const char*). + const bool canConvertToChar = detail::is_convertible::value; + const bool canConvertToVoidPtr = + detail::is_convertible::value; + if (canConvertToChar && *(fmtEnd - 1) == 'c') + detail::formatValueAsType::invoke(out, value); + else if (canConvertToVoidPtr && *(fmtEnd - 1) == 'p') + detail::formatValueAsType::invoke(out, value); + else if (ntrunc >= 0) { + // Take care not to overread C strings in truncating conversions like + // "%.4s" where at most 4 characters may be read. + detail::formatTruncated(out, value, ntrunc); + } else + out << value; +} + +// Overloaded version for char types to support printing as an integer +#define TINYFORMAT_DEFINE_FORMATVALUE_CHAR(charType) \ + inline void formatValue(std::ostream &out, \ + const char * /*fmtBegin*/, \ + const char *fmtEnd, \ + int /**/, \ + charType value) { \ + switch (*(fmtEnd - 1)) { \ + case 'u': \ + case 'd': \ + case 'i': \ + case 'o': \ + case 'X': \ + case 'x': \ + out << static_cast(value); \ + break; \ + default: \ + out << value; \ + break; \ + } \ + } +// per 3.9.1: char, signed char and unsigned char are all distinct types +TINYFORMAT_DEFINE_FORMATVALUE_CHAR(char) +TINYFORMAT_DEFINE_FORMATVALUE_CHAR(signed char) +TINYFORMAT_DEFINE_FORMATVALUE_CHAR(unsigned char) +#undef TINYFORMAT_DEFINE_FORMATVALUE_CHAR + +//------------------------------------------------------------------------------ +// Tools for emulating variadic templates in C++98. The basic idea here is +// stolen from the boost preprocessor metaprogramming library and cut down to +// be just general enough for what we need. + +#define TINYFORMAT_ARGTYPES(n) TINYFORMAT_ARGTYPES_##n +#define TINYFORMAT_VARARGS(n) TINYFORMAT_VARARGS_##n +#define TINYFORMAT_PASSARGS(n) TINYFORMAT_PASSARGS_##n +#define TINYFORMAT_PASSARGS_TAIL(n) TINYFORMAT_PASSARGS_TAIL_##n + +// To keep it as transparent as possible, the macros below have been generated +// using python via the excellent cog.py code generation script. This avoids +// the need for a bunch of complex (but more general) preprocessor tricks as +// used in boost.preprocessor. +// +// To rerun the code generation in place, use `cog.py -r tinyformat.h` +// (see http://nedbatchelder.com/code/cog). Alternatively you can just create +// extra versions by hand. + +/*[[[cog +maxParams = 16 + +def makeCommaSepLists(lineTemplate, elemTemplate, startInd=1): + for j in range(startInd,maxParams+1): + list = ', '.join([elemTemplate % {'i':i} for i in range(startInd,j+1)]) + cog.outl(lineTemplate % {'j':j, 'list':list}) + +makeCommaSepLists('#define TINYFORMAT_ARGTYPES_%(j)d %(list)s', + 'class T%(i)d') + +cog.outl() +makeCommaSepLists('#define TINYFORMAT_VARARGS_%(j)d %(list)s', + 'const T%(i)d& v%(i)d') + +cog.outl() +makeCommaSepLists('#define TINYFORMAT_PASSARGS_%(j)d %(list)s', 'v%(i)d') + +cog.outl() +cog.outl('#define TINYFORMAT_PASSARGS_TAIL_1') +makeCommaSepLists('#define TINYFORMAT_PASSARGS_TAIL_%(j)d , %(list)s', + 'v%(i)d', startInd = 2) + +cog.outl() +cog.outl('#define TINYFORMAT_FOREACH_ARGNUM(m) \\\n ' + + ' '.join(['m(%d)' % (j,) for j in range(1,maxParams+1)])) +]]]*/ +#define TINYFORMAT_ARGTYPES_1 class T1 +#define TINYFORMAT_ARGTYPES_2 class T1, class T2 +#define TINYFORMAT_ARGTYPES_3 class T1, class T2, class T3 +#define TINYFORMAT_ARGTYPES_4 class T1, class T2, class T3, class T4 +#define TINYFORMAT_ARGTYPES_5 class T1, class T2, class T3, class T4, class T5 +#define TINYFORMAT_ARGTYPES_6 \ + class T1, class T2, class T3, class T4, class T5, class T6 +#define TINYFORMAT_ARGTYPES_7 \ + class T1, class T2, class T3, class T4, class T5, class T6, class T7 +#define TINYFORMAT_ARGTYPES_8 \ + class T1, class T2, class T3, class T4, class T5, class T6, class T7, class T8 +#define TINYFORMAT_ARGTYPES_9 \ + class T1, class T2, class T3, class T4, class T5, class T6, class T7, \ + class T8, class T9 +#define TINYFORMAT_ARGTYPES_10 \ + class T1, class T2, class T3, class T4, class T5, class T6, class T7, \ + class T8, class T9, class T10 +#define TINYFORMAT_ARGTYPES_11 \ + class T1, class T2, class T3, class T4, class T5, class T6, class T7, \ + class T8, class T9, class T10, class T11 +#define TINYFORMAT_ARGTYPES_12 \ + class T1, class T2, class T3, class T4, class T5, class T6, class T7, \ + class T8, class T9, class T10, class T11, class T12 +#define TINYFORMAT_ARGTYPES_13 \ + class T1, class T2, class T3, class T4, class T5, class T6, class T7, \ + class T8, class T9, class T10, class T11, class T12, class T13 +#define TINYFORMAT_ARGTYPES_14 \ + class T1, class T2, class T3, class T4, class T5, class T6, class T7, \ + class T8, class T9, class T10, class T11, class T12, class T13, \ + class T14 +#define TINYFORMAT_ARGTYPES_15 \ + class T1, class T2, class T3, class T4, class T5, class T6, class T7, \ + class T8, class T9, class T10, class T11, class T12, class T13, \ + class T14, class T15 +#define TINYFORMAT_ARGTYPES_16 \ + class T1, class T2, class T3, class T4, class T5, class T6, class T7, \ + class T8, class T9, class T10, class T11, class T12, class T13, \ + class T14, class T15, class T16 + +#define TINYFORMAT_VARARGS_1 const T1 &v1 +#define TINYFORMAT_VARARGS_2 const T1 &v1, const T2 &v2 +#define TINYFORMAT_VARARGS_3 const T1 &v1, const T2 &v2, const T3 &v3 +#define TINYFORMAT_VARARGS_4 \ + const T1 &v1, const T2 &v2, const T3 &v3, const T4 &v4 +#define TINYFORMAT_VARARGS_5 \ + const T1 &v1, const T2 &v2, const T3 &v3, const T4 &v4, const T5 &v5 +#define TINYFORMAT_VARARGS_6 \ + const T1 &v1, const T2 &v2, const T3 &v3, const T4 &v4, const T5 &v5, \ + const T6 &v6 +#define TINYFORMAT_VARARGS_7 \ + const T1 &v1, const T2 &v2, const T3 &v3, const T4 &v4, const T5 &v5, \ + const T6 &v6, const T7 &v7 +#define TINYFORMAT_VARARGS_8 \ + const T1 &v1, const T2 &v2, const T3 &v3, const T4 &v4, const T5 &v5, \ + const T6 &v6, const T7 &v7, const T8 &v8 +#define TINYFORMAT_VARARGS_9 \ + const T1 &v1, const T2 &v2, const T3 &v3, const T4 &v4, const T5 &v5, \ + const T6 &v6, const T7 &v7, const T8 &v8, const T9 &v9 +#define TINYFORMAT_VARARGS_10 \ + const T1 &v1, const T2 &v2, const T3 &v3, const T4 &v4, const T5 &v5, \ + const T6 &v6, const T7 &v7, const T8 &v8, const T9 &v9, const T10 &v10 +#define TINYFORMAT_VARARGS_11 \ + const T1 &v1, const T2 &v2, const T3 &v3, const T4 &v4, const T5 &v5, \ + const T6 &v6, const T7 &v7, const T8 &v8, const T9 &v9, const T10 &v10, \ + const T11 &v11 +#define TINYFORMAT_VARARGS_12 \ + const T1 &v1, const T2 &v2, const T3 &v3, const T4 &v4, const T5 &v5, \ + const T6 &v6, const T7 &v7, const T8 &v8, const T9 &v9, const T10 &v10, \ + const T11 &v11, const T12 &v12 +#define TINYFORMAT_VARARGS_13 \ + const T1 &v1, const T2 &v2, const T3 &v3, const T4 &v4, const T5 &v5, \ + const T6 &v6, const T7 &v7, const T8 &v8, const T9 &v9, const T10 &v10, \ + const T11 &v11, const T12 &v12, const T13 &v13 +#define TINYFORMAT_VARARGS_14 \ + const T1 &v1, const T2 &v2, const T3 &v3, const T4 &v4, const T5 &v5, \ + const T6 &v6, const T7 &v7, const T8 &v8, const T9 &v9, const T10 &v10, \ + const T11 &v11, const T12 &v12, const T13 &v13, const T14 &v14 +#define TINYFORMAT_VARARGS_15 \ + const T1 &v1, const T2 &v2, const T3 &v3, const T4 &v4, const T5 &v5, \ + const T6 &v6, const T7 &v7, const T8 &v8, const T9 &v9, const T10 &v10, \ + const T11 &v11, const T12 &v12, const T13 &v13, const T14 &v14, \ + const T15 &v15 +#define TINYFORMAT_VARARGS_16 \ + const T1 &v1, const T2 &v2, const T3 &v3, const T4 &v4, const T5 &v5, \ + const T6 &v6, const T7 &v7, const T8 &v8, const T9 &v9, const T10 &v10, \ + const T11 &v11, const T12 &v12, const T13 &v13, const T14 &v14, \ + const T15 &v15, const T16 &v16 + +#define TINYFORMAT_PASSARGS_1 v1 +#define TINYFORMAT_PASSARGS_2 v1, v2 +#define TINYFORMAT_PASSARGS_3 v1, v2, v3 +#define TINYFORMAT_PASSARGS_4 v1, v2, v3, v4 +#define TINYFORMAT_PASSARGS_5 v1, v2, v3, v4, v5 +#define TINYFORMAT_PASSARGS_6 v1, v2, v3, v4, v5, v6 +#define TINYFORMAT_PASSARGS_7 v1, v2, v3, v4, v5, v6, v7 +#define TINYFORMAT_PASSARGS_8 v1, v2, v3, v4, v5, v6, v7, v8 +#define TINYFORMAT_PASSARGS_9 v1, v2, v3, v4, v5, v6, v7, v8, v9 +#define TINYFORMAT_PASSARGS_10 v1, v2, v3, v4, v5, v6, v7, v8, v9, v10 +#define TINYFORMAT_PASSARGS_11 v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11 +#define TINYFORMAT_PASSARGS_12 v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12 +#define TINYFORMAT_PASSARGS_13 \ + v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13 +#define TINYFORMAT_PASSARGS_14 \ + v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14 +#define TINYFORMAT_PASSARGS_15 \ + v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15 +#define TINYFORMAT_PASSARGS_16 \ + v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16 + +#define TINYFORMAT_PASSARGS_TAIL_1 +#define TINYFORMAT_PASSARGS_TAIL_2 , v2 +#define TINYFORMAT_PASSARGS_TAIL_3 , v2, v3 +#define TINYFORMAT_PASSARGS_TAIL_4 , v2, v3, v4 +#define TINYFORMAT_PASSARGS_TAIL_5 , v2, v3, v4, v5 +#define TINYFORMAT_PASSARGS_TAIL_6 , v2, v3, v4, v5, v6 +#define TINYFORMAT_PASSARGS_TAIL_7 , v2, v3, v4, v5, v6, v7 +#define TINYFORMAT_PASSARGS_TAIL_8 , v2, v3, v4, v5, v6, v7, v8 +#define TINYFORMAT_PASSARGS_TAIL_9 , v2, v3, v4, v5, v6, v7, v8, v9 +#define TINYFORMAT_PASSARGS_TAIL_10 , v2, v3, v4, v5, v6, v7, v8, v9, v10 +#define TINYFORMAT_PASSARGS_TAIL_11 , v2, v3, v4, v5, v6, v7, v8, v9, v10, v11 +#define TINYFORMAT_PASSARGS_TAIL_12 \ + , v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12 +#define TINYFORMAT_PASSARGS_TAIL_13 \ + , v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13 +#define TINYFORMAT_PASSARGS_TAIL_14 \ + , v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14 +#define TINYFORMAT_PASSARGS_TAIL_15 \ + , v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15 +#define TINYFORMAT_PASSARGS_TAIL_16 \ + , v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16 + +#define TINYFORMAT_FOREACH_ARGNUM(m) \ + m(1) m(2) m(3) m(4) m(5) m(6) m(7) m(8) m(9) m(10) m(11) m(12) m(13) m(14) \ + m(15) m(16) +//[[[end]]] + +namespace detail { + +// Type-opaque holder for an argument to format(), with associated actions on +// the type held as explicit function pointers. This allows FormatArg's for +// each argument to be allocated as a homogenous array inside FormatList +// whereas a naive implementation based on inheritance does not. +class FormatArg { +public: + FormatArg() {} + + template + FormatArg(const T &value) + : m_value(static_cast(&value)), + m_formatImpl(&formatImpl), + m_toIntImpl(&toIntImpl) {} + + void format(std::ostream &out, + const char *fmtBegin, + const char *fmtEnd, + int ntrunc) const { + m_formatImpl(out, fmtBegin, fmtEnd, ntrunc, m_value); + } + + int toInt() const { return m_toIntImpl(m_value); } + +private: + template + static void formatImpl(std::ostream &out, + const char *fmtBegin, + const char *fmtEnd, + int ntrunc, + const void *value) { + formatValue(out, fmtBegin, fmtEnd, ntrunc, *static_cast(value)); + } + + template + static int toIntImpl(const void *value) { + return convertToInt::invoke(*static_cast(value)); + } + + const void *m_value; + void (*m_formatImpl)(std::ostream &out, + const char *fmtBegin, + const char *fmtEnd, + int ntrunc, + const void *value); + int (*m_toIntImpl)(const void *value); +}; + +// Parse and return an integer from the string c, as atoi() +// On return, c is set to one past the end of the integer. +inline int parseIntAndAdvance(const char *&c) { + int i = 0; + for (; *c >= '0' && *c <= '9'; ++c) i = 10 * i + (*c - '0'); + return i; +} + +// Print literal part of format string and return next format spec +// position. +// +// Skips over any occurrences of '%%', printing a literal '%' to the +// output. The position of the first % character of the next +// nontrivial format spec is returned, or the end of string. +inline const char *printFormatStringLiteral(std::ostream &out, + const char *fmt) { + const char *c = fmt; + for (;; ++c) { + switch (*c) { + case '\0': + out.write(fmt, c - fmt); + return c; + case '%': + out.write(fmt, c - fmt); + if (*(c + 1) != '%') return c; + // for "%%", tack trailing % onto next literal section. + fmt = ++c; + break; + default: + break; + } + } +} + +// Parse a format string and set the stream state accordingly. +// +// The format mini-language recognized here is meant to be the one from C99, +// with the form "%[flags][width][.precision][length]type". +// +// Formatting options which can't be natively represented using the ostream +// state are returned in spacePadPositive (for space padded positive numbers) +// and ntrunc (for truncating conversions). argIndex is incremented if +// necessary to pull out variable width and precision . The function returns a +// pointer to the character after the end of the current format spec. +inline const char *streamStateFromFormat(std::ostream &out, + bool &spacePadPositive, + int &ntrunc, + const char *fmtStart, + const detail::FormatArg *formatters, + int &argIndex, + int numFormatters) { + if (*fmtStart != '%') { + TINYFORMAT_ERROR( + "tinyformat: Not enough conversion specifiers in format string"); + return fmtStart; + } + // Reset stream state to defaults. + out.width(0); + out.precision(6); + out.fill(' '); + // Reset most flags; ignore irrelevant unitbuf & skipws. + out.unsetf(std::ios::adjustfield | std::ios::basefield | + std::ios::floatfield | std::ios::showbase | std::ios::boolalpha | + std::ios::showpoint | std::ios::showpos | std::ios::uppercase); + bool precisionSet = false; + bool widthSet = false; + int widthExtra = 0; + const char *c = fmtStart + 1; + // 1) Parse flags + for (;; ++c) { + switch (*c) { + case '#': + out.setf(std::ios::showpoint | std::ios::showbase); + continue; + case '0': + // overridden by left alignment ('-' flag) + if (!(out.flags() & std::ios::left)) { + // Use internal padding so that numeric values are + // formatted correctly, eg -00010 rather than 000-10 + out.fill('0'); + out.setf(std::ios::internal, std::ios::adjustfield); + } + continue; + case '-': + out.fill(' '); + out.setf(std::ios::left, std::ios::adjustfield); + continue; + case ' ': + // overridden by show positive sign, '+' flag. + if (!(out.flags() & std::ios::showpos)) spacePadPositive = true; + continue; + case '+': + out.setf(std::ios::showpos); + spacePadPositive = false; + widthExtra = 1; + continue; + default: + break; + } + break; + } + // 2) Parse width + if (*c >= '0' && *c <= '9') { + widthSet = true; + out.width(parseIntAndAdvance(c)); + } + if (*c == '*') { + widthSet = true; + int width = 0; + if (argIndex < numFormatters) + width = formatters[argIndex++].toInt(); + else + TINYFORMAT_ERROR( + "tinyformat: Not enough arguments to read variable width"); + if (width < 0) { + // negative widths correspond to '-' flag set + out.fill(' '); + out.setf(std::ios::left, std::ios::adjustfield); + width = -width; + } + out.width(width); + ++c; + } + // 3) Parse precision + if (*c == '.') { + ++c; + int precision = 0; + if (*c == '*') { + ++c; + if (argIndex < numFormatters) + precision = formatters[argIndex++].toInt(); + else + TINYFORMAT_ERROR( + "tinyformat: Not enough arguments to read variable precision"); + } else { + if (*c >= '0' && *c <= '9') + precision = parseIntAndAdvance(c); + else if (*c == '-') // negative precisions ignored, treated as zero. + parseIntAndAdvance(++c); + } + out.precision(precision); + precisionSet = true; + } + // 4) Ignore any C99 length modifier + while (*c == 'l' || *c == 'h' || *c == 'L' || *c == 'j' || *c == 'z' || + *c == 't') + ++c; + // 5) We're up to the conversion specifier character. + // Set stream flags based on conversion specifier (thanks to the + // boost::format class for forging the way here). + bool intConversion = false; + switch (*c) { + case 'u': + case 'd': + case 'i': + out.setf(std::ios::dec, std::ios::basefield); + intConversion = true; + break; + case 'o': + out.setf(std::ios::oct, std::ios::basefield); + intConversion = true; + break; + case 'X': + out.setf(std::ios::uppercase); + case 'x': + case 'p': + out.setf(std::ios::hex, std::ios::basefield); + intConversion = true; + break; + case 'E': + out.setf(std::ios::uppercase); + case 'e': + out.setf(std::ios::scientific, std::ios::floatfield); + out.setf(std::ios::dec, std::ios::basefield); + break; + case 'F': + out.setf(std::ios::uppercase); + case 'f': + out.setf(std::ios::fixed, std::ios::floatfield); + break; + case 'G': + out.setf(std::ios::uppercase); + case 'g': + out.setf(std::ios::dec, std::ios::basefield); + // As in boost::format, let stream decide float format. + out.flags(out.flags() & ~std::ios::floatfield); + break; + case 'a': + case 'A': + TINYFORMAT_ERROR( + "tinyformat: the %a and %A conversion specs " + "are not supported"); + break; + case 'c': + // Handled as special case inside formatValue() + break; + case 's': + if (precisionSet) ntrunc = static_cast(out.precision()); + // Make %s print booleans as "true" and "false" + out.setf(std::ios::boolalpha); + break; + case 'n': + // Not supported - will cause problems! + TINYFORMAT_ERROR("tinyformat: %n conversion spec not supported"); + break; + case '\0': + TINYFORMAT_ERROR( + "tinyformat: Conversion spec incorrectly " + "terminated by end of string"); + return c; + default: + break; + } + if (intConversion && precisionSet && !widthSet) { + // "precision" for integers gives the minimum number of digits (to be + // padded with zeros on the left). This isn't really supported by the + // iostreams, but we can approximately simulate it with the width if + // the width isn't otherwise used. + out.width(out.precision() + widthExtra); + out.setf(std::ios::internal, std::ios::adjustfield); + out.fill('0'); + } + return c + 1; +} + +//------------------------------------------------------------------------------ +inline void formatImpl(std::ostream &out, + const char *fmt, + const detail::FormatArg *formatters, + int numFormatters) { + // Saved stream state + std::streamsize origWidth = out.width(); + std::streamsize origPrecision = out.precision(); + std::ios::fmtflags origFlags = out.flags(); + char origFill = out.fill(); + + for (int argIndex = 0; argIndex < numFormatters; ++argIndex) { + // Parse the format string + fmt = printFormatStringLiteral(out, fmt); + bool spacePadPositive = false; + int ntrunc = -1; + const char *fmtEnd = streamStateFromFormat(out, + spacePadPositive, + ntrunc, + fmt, + formatters, + argIndex, + numFormatters); + if (argIndex >= numFormatters) { + // Check args remain after reading any variable width/precision + TINYFORMAT_ERROR("tinyformat: Not enough format arguments"); + return; + } + const FormatArg &arg = formatters[argIndex]; + // Format the arg into the stream. + if (!spacePadPositive) + arg.format(out, fmt, fmtEnd, ntrunc); + else { + // The following is a special case with no direct correspondence + // between stream formatting and the printf() behaviour. Simulate + // it crudely by formatting into a temporary string stream and + // munging the resulting string. + std::ostringstream tmpStream; + tmpStream.copyfmt(out); + tmpStream.setf(std::ios::showpos); + arg.format(tmpStream, fmt, fmtEnd, ntrunc); + std::string result = tmpStream.str(); // allocates... yuck. + for (size_t i = 0, iend = result.size(); i < iend; ++i) + if (result[i] == '+') result[i] = ' '; + out << result; + } + fmt = fmtEnd; + } + + // Print remaining part of format string. + fmt = printFormatStringLiteral(out, fmt); + if (*fmt != '\0') + TINYFORMAT_ERROR( + "tinyformat: Too many conversion specifiers in format string"); + + // Restore stream state + out.width(origWidth); + out.precision(origPrecision); + out.flags(origFlags); + out.fill(origFill); +} + +} // namespace detail + +/// List of template arguments format(), held in a type-opaque way. +/// +/// A const reference to FormatList (typedef'd as FormatListRef) may be +/// conveniently used to pass arguments to non-template functions: All type +/// information has been stripped from the arguments, leaving just enough of a +/// common interface to perform formatting as required. +class FormatList { +public: + FormatList(detail::FormatArg *formatters, int N) + : m_formatters(formatters), m_N(N) {} + + friend void vformat(std::ostream &out, + const char *fmt, + const FormatList &list); + +private: + const detail::FormatArg *m_formatters; + int m_N; +}; + +/// Reference to type-opaque format list for passing to vformat() +typedef const FormatList &FormatListRef; + +namespace detail { + +// Format list subclass with fixed storage to avoid dynamic allocation +template +class FormatListN : public FormatList { +public: + template + FormatListN(const Args &... args) + : FormatList(&m_formatterStore[0], N), + m_formatterStore{FormatArg(args)...} { + static_assert(sizeof...(args) == N, "Number of args must be N"); + } + +private: + FormatArg m_formatterStore[N]; +}; + +// Special 0-arg version - MSVC says zero-sized C array in struct is nonstandard +template <> +class FormatListN<0> : public FormatList { +public: + FormatListN() : FormatList(0, 0) {} +}; + +} // namespace detail + +//------------------------------------------------------------------------------ +// Primary API functions + +/// Make type-agnostic format list from list of template arguments. +/// +/// The exact return type of this function is an implementation detail and +/// shouldn't be relied upon. Instead it should be stored as a FormatListRef: +/// +/// FormatListRef formatList = makeFormatList( /*...*/ ); +template +detail::FormatListN makeFormatList(const Args &... args) { + return detail::FormatListN(args...); +} + +/// Format list of arguments to the stream according to the given format string. +/// +/// The name vformat() is chosen for the semantic similarity to vprintf(): the +/// list of format arguments is held in a single function argument. +inline void vformat(std::ostream &out, const char *fmt, FormatListRef list) { + detail::formatImpl(out, fmt, list.m_formatters, list.m_N); +} + +/// Format list of arguments to the stream according to given format string. +template +void format(std::ostream &out, const char *fmt, const Args &... args) { + vformat(out, fmt, makeFormatList(args...)); +} + +/// Format list of arguments according to the given format string and return +/// the result as a string. +template +std::string format(const char *fmt, const Args &... args) { + std::ostringstream oss; + format(oss, fmt, args...); + return oss.str(); +} + +/// Format list of arguments to std::cout, according to the given format string +template +void printf(const char *fmt, const Args &... args) { + format(std::cout, fmt, args...); +} + +template +void printfln(const char *fmt, const Args &... args) { + format(std::cout, fmt, args...); + std::cout << '\n'; +} + +} // namespace tinyformat +} // namespace string +} // namespace paddle diff --git a/paddle/testing/CMakeLists.txt b/paddle/testing/CMakeLists.txt index c47add04b081cbdf78b5a5d3bca3a71025b3d9ac..4245df5ab72bf0fd67261818b307f0babdb5d685 100644 --- a/paddle/testing/CMakeLists.txt +++ b/paddle/testing/CMakeLists.txt @@ -2,7 +2,7 @@ if(WITH_TESTING) add_library(paddle_test_main STATIC TestMain.cpp) - add_dependencies(paddle_test_main gen_proto_cpp) + add_dependencies(paddle_test_main paddle_proto ${external_project_dependencies}) add_library(paddle_test_util STATIC TestUtil.cpp) - add_dependencies(paddle_test_util gen_proto_cpp) + add_dependencies(paddle_test_util paddle_proto ${external_project_dependencies}) endif() diff --git a/paddle/trainer/CMakeLists.txt b/paddle/trainer/CMakeLists.txt index f34d53ae99f913a8aed8767b7271a538efce4778..6414c399561575c13074c41598184a78f84373ee 100644 --- a/paddle/trainer/CMakeLists.txt +++ b/paddle/trainer/CMakeLists.txt @@ -41,7 +41,8 @@ add_style_check_target(paddle_trainer_lib add_style_check_target(paddle_trainer_lib ${TRAINER_HEADERS}) add_dependencies(paddle_trainer_lib - gen_proto_cpp) + paddle_proto + ${external_project_dependencies}) macro(add_paddle_exe TARGET_NAME) add_executable(${TARGET_NAME} ${ARGN}) diff --git a/paddle/utils/CMakeLists.txt b/paddle/utils/CMakeLists.txt index af59951752d1799c95e293d3eae233e6aa26e5f3..7a4977935ede4878c07f4fb6ba0dd76bf50acd42 100644 --- a/paddle/utils/CMakeLists.txt +++ b/paddle/utils/CMakeLists.txt @@ -17,7 +17,7 @@ add_library(paddle_utils STATIC add_style_check_target(paddle_utils ${UTIL_HEADERS}) add_style_check_target(paddle_utils ${UTIL_SOURCES} ${UTIL_ARCH_SOURCES}) -add_dependencies(paddle_utils gen_proto_cpp) +add_dependencies(paddle_utils paddle_proto ${external_project_dependencies}) if(WITH_TESTING) add_subdirectory(tests) endif() diff --git a/proto/CMakeLists.txt b/proto/CMakeLists.txt index c942620990765832f21c887d30f85a2d211a5f32..18584cafe7971bad281b498908c54780250791b7 100644 --- a/proto/CMakeLists.txt +++ b/proto/CMakeLists.txt @@ -1,43 +1,23 @@ -set(proto_filenames - DataConfig.proto - DataFormat.proto - ModelConfig.proto - ParameterConfig.proto - ParameterService.proto - TrainerConfig.proto - OptimizerConfig.proto - ParameterServerConfig.proto) +file(GLOB proto_filenames . *.proto) +include_directories(${CMAKE_CURRENT_BINARY_DIR}) +proto_library(paddle_proto SRCS ${proto_filenames}) set(PROTO_GEN) set(PROTO_GEN_PY) foreach(filename ${proto_filenames}) - get_filename_component(base_filename ${filename} NAME_WE) - set(CUR_PROTO_GEN - ${CMAKE_CURRENT_BINARY_DIR}/${base_filename}.pb.h - ${CMAKE_CURRENT_BINARY_DIR}/${base_filename}.pb.cc) - set(PROTO_GEN - ${PROTO_GEN} - ${CUR_PROTO_GEN}) - add_custom_command(OUTPUT ${CUR_PROTO_GEN} - COMMAND env ${py_env} ${PROTOBUF_PROTOC_EXECUTABLE} - --cpp_out ${CMAKE_CURRENT_BINARY_DIR} - --proto_path ${PROJ_ROOT}/proto ${PROJ_ROOT}/proto/${filename} - DEPENDS ${filename} ${external_project_dependencies}) - + get_filename_component(ABS_FIL ${filename} ABSOLUTE) + get_filename_component(FIL_WE ${filename} NAME_WE) set(CUR_PROTO_GEN_PY - ${PROJ_ROOT}/paddle/python/paddle/proto/${base_filename}_pb2.py) + ${PROJ_ROOT}/paddle/python/paddle/proto/${FIL_WE}_pb2.py) set(PROTO_GEN_PY - ${CUR_PROTO_GEN_PY} - ${PROTO_GEN_PY}) + ${CUR_PROTO_GEN_PY} + ${PROTO_GEN_PY}) add_custom_command(OUTPUT ${CUR_PROTO_GEN_PY} - COMMAND env ${py_env} ${PROTOBUF_PROTOC_EXECUTABLE} --python_out ${PROJ_ROOT}/python/paddle/proto - --proto_path ${PROJ_ROOT}/proto ${PROJ_ROOT}/proto/${filename} - DEPENDS ${filename} ${external_project_dependencies}) + COMMAND ${PROTOBUF_PROTOC_EXECUTABLE} + ARGS "--python_out=${PROJ_ROOT}/python/paddle/proto" + "-I" ${CMAKE_CURRENT_SOURCE_DIR} ${ABS_FIL} + DEPENDS ${ABS_FIL} ${external_project_dependencies}) endforeach() -add_custom_target(gen_proto_cpp ALL DEPENDS ${PROTO_GEN}) add_custom_target(gen_proto_py ALL DEPENDS ${PROTO_GEN_PY}) - -add_library(paddle_proto STATIC ${PROTO_GEN}) -target_include_directories(paddle_proto PUBLIC ${CMAKE_CURRENT_BINARY_DIR}) diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index 84ed160773065da15fc26bfb5c5882b068874f1c..a601d5c84ad222785e68b9fa81c51b1e120b4f29 100755 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -1149,10 +1149,10 @@ def pooling_layer(input, @layer_support(DROPOUT) def lstmemory(input, name=None, + size=None, reverse=False, act=None, gate_act=None, - size=None, state_act=None, bias_attr=None, param_attr=None, @@ -1194,6 +1194,8 @@ def lstmemory(input, :param name: The lstmemory layer name. :type name: basestring + :param size: DEPRECATED. size of the lstm cell + :type size: int :param input: input layer name. :type input: LayerOutput :param reverse: is sequence process reversed or not. @@ -1220,15 +1222,15 @@ def lstmemory(input, assert state_act.support_hppl assert act.support_hppl assert input.size is not None and input.size % 4 == 0 + if size is not None: if input.size / 4 == size: plog = logger.warning else: plog = logger.fatal - - plog("NOTE: The lstmemory layer[%s]'s size is set by previous input " - "layer. The lstm size should be equal with input layer size/4. The" - " size which is set explicitly will be ignored." % name) + plog("size of lstmemory layer: %s is automatically set to " + "size of input layer / 4. The parameter size passing to " + "this layer is ignored." % (name)) Layer( name=name, @@ -1255,11 +1257,11 @@ def lstmemory(input, @wrap_name_default("gru") @layer_support(DROPOUT) def grumemory(input, + size=None, name=None, reverse=False, act=None, gate_act=None, - size=None, bias_attr=None, param_attr=None, layer_attr=None): @@ -1318,6 +1320,8 @@ def grumemory(input, :type name: None|basestring :param input: input layer. :type input: LayerOutput. + :param size: DEPRECATED. size of the gru cell + :type size: int :param reverse: Whether sequence process is reversed or not. :type reverse: bool :param act: activation type, TanhActivation by default. This activation @@ -1334,9 +1338,6 @@ def grumemory(input, :type param_attr: ParameterAttribute|None|False :param layer_attr: Extra Layer attribute :type layer_attr: ExtraLayerAttribute|None - :param size: Stub parameter of size, but actually not used. If set this size - will get a warning. - :type size: None :return: LayerOutput object. :rtype: LayerOutput """ @@ -1348,9 +1349,9 @@ def grumemory(input, plog = logger.warning else: plog = logger.fatal - plog("NOTE: the gru memory layer's size is set by previous input layer," - " and should be input size / 3. Set size explicitly will be " - "ignored.") + plog("size of grumemory layer: %s is automatically set to " + "size of input layer / 3. The parameter size passing to this " + "layer is ignored." % (name)) Layer( name=name, @@ -2524,8 +2525,8 @@ def img_cmrnorm_layer(input, @wrap_bias_attr_default() -@wrap_param_attr_default(default_factory=lambda _: ParamAttr(initial_mean=1.0, - initial_std=0.)) +@wrap_param_attr_default( + default_factory=lambda _: ParamAttr(initial_mean=1.0, initial_std=0.)) @wrap_act_default(act=ReluActivation()) @wrap_name_default("batch_norm") @layer_support(DROPOUT) @@ -3013,25 +3014,25 @@ def lstm_step_layer(input, bias_attr=None, layer_attr=None): """ - LSTM Step Layer. It used in recurrent_group. The lstm equations are shown - as follow. + LSTM Step Layer. This function is used only in recurrent_group. + The lstm equations are shown as follows. .. math:: - i_t & = \\sigma(W_{xi}x_{t} + W_{hi}h_{t-1} + W_{ci}c_{t-1} + b_i) + i_t & = \\sigma(W_{x_i}x_{t} + W_{h_i}h_{t-1} + W_{c_i}c_{t-1} + b_i) - f_t & = \\sigma(W_{xf}x_{t} + W_{hf}h_{t-1} + W_{cf}c_{t-1} + b_f) + f_t & = \\sigma(W_{x_f}x_{t} + W_{h_f}h_{t-1} + W_{c_f}c_{t-1} + b_f) - c_t & = f_tc_{t-1} + i_t tanh (W_{xc}x_t+W_{hc}h_{t-1} + b_c) + c_t & = f_tc_{t-1} + i_t tanh (W_{x_c}x_t+W_{h_c}h_{t-1} + b_c) - o_t & = \\sigma(W_{xo}x_{t} + W_{ho}h_{t-1} + W_{co}c_t + b_o) + o_t & = \\sigma(W_{x_o}x_{t} + W_{h_o}h_{t-1} + W_{c_o}c_t + b_o) h_t & = o_t tanh(c_t) The input of lstm step is :math:`Wx_t + Wh_{t-1}`, and user should use :code:`mixed_layer` and :code:`full_matrix_projection` to calculate these - input vector. + input vectors. The state of lstm step is :math:`c_{t-1}`. And lstm step layer will do @@ -3042,14 +3043,14 @@ def lstm_step_layer(input, ... - This layer contains two outputs. Default output is :math:`h_t`. The other - output is :math:`o_t`, which name is 'state' and can use + This layer has two outputs. Default output is :math:`h_t`. The other + output is :math:`o_t`, whose name is 'state' and can use :code:`get_output_layer` to extract this output. :param name: Layer's name. :type name: basestring - :param size: Layer's size. NOTE: lstm layer's size, should be equal as - :code:`input.size/4`, and should be equal as + :param size: Layer's size. NOTE: lstm layer's size, should be equal to + :code:`input.size/4`, and should be equal to :code:`state.size`. :type size: int :param input: input layer. :math:`Wx_t + Wh_{t-1}` diff --git a/python/paddle/trainer_config_helpers/networks.py b/python/paddle/trainer_config_helpers/networks.py index 67154a8d7d366bd983b4426da87e0b33307fced4..b77932ce5f09470329a97cc0a6273942a9155c6a 100755 --- a/python/paddle/trainer_config_helpers/networks.py +++ b/python/paddle/trainer_config_helpers/networks.py @@ -614,6 +614,7 @@ def simple_lstm(input, @wrap_name_default('lstm_unit') def lstmemory_unit(input, + memory_boot=None, name=None, size=None, param_attr=None, @@ -626,9 +627,9 @@ def lstmemory_unit(input, lstm_layer_attr=None, get_output_layer_attr=None): """ - Define calculations that a LSTM unit performs in a single time step. - This function itself is not a recurrent layer, so that it can not be - directly applied to sequence input. This function is always used in + Define calculations that a LSTM unit performs during a single time step. + This function itself is not a recurrent layer, so it can not be + directly used to process sequence inputs. This function is always used in recurrent_group (see layers.py for more details) to implement attention mechanism. @@ -638,13 +639,13 @@ def lstmemory_unit(input, .. math:: - i_t & = \\sigma(W_{xi}x_{t} + W_{hi}h_{t-1} + W_{ci}c_{t-1} + b_i) + i_t & = \\sigma(W_{x_i}x_{t} + W_{h_i}h_{t-1} + W_{c_i}c_{t-1} + b_i) - f_t & = \\sigma(W_{xf}x_{t} + W_{hf}h_{t-1} + W_{cf}c_{t-1} + b_f) + f_t & = \\sigma(W_{x_f}x_{t} + W_{h_f}h_{t-1} + W_{c_f}c_{t-1} + b_f) - c_t & = f_tc_{t-1} + i_t tanh (W_{xc}x_t+W_{hc}h_{t-1} + b_c) + c_t & = f_tc_{t-1} + i_t tanh (W_{x_c}x_t+W_{h_c}h_{t-1} + b_c) - o_t & = \\sigma(W_{xo}x_{t} + W_{ho}h_{t-1} + W_{co}c_t + b_o) + o_t & = \\sigma(W_{x_o}x_{t} + W_{h_o}h_{t-1} + W_{c_o}c_t + b_o) h_t & = o_t tanh(c_t) @@ -661,6 +662,8 @@ def lstmemory_unit(input, :param input: input layer name. :type input: LayerOutput + :param memory_boot: the initialization state of the LSTM cell. + :type memory_boot: LayerOutput | None :param name: lstmemory unit name. :type name: basestring :param size: lstmemory unit size. @@ -692,7 +695,8 @@ def lstmemory_unit(input, assert input.size % 4 == 0 size = input.size / 4 out_mem = memory(name=name, size=size) - state_mem = memory(name="%s_state" % name, size=size) + state_mem = memory( + name="%s_state" % name, size=size, boot_layer=memory_boot) with mixed_layer( name="%s_input_recurrent" % name, @@ -726,6 +730,7 @@ def lstmemory_unit(input, def lstmemory_group(input, size=None, name=None, + memory_boot=None, reverse=False, param_attr=None, act=None, @@ -737,7 +742,7 @@ def lstmemory_group(input, lstm_layer_attr=None, get_output_layer_attr=None): """ - lstm_group is a recurrent layer group version of Long Short Term Memory. It + lstm_group is a recurrent_group version of Long Short Term Memory. It does exactly the same calculation as the lstmemory layer (see lstmemory in layers.py for the maths) does. A promising benefit is that LSTM memory cell states, or hidden states in every time step are accessible to the @@ -748,8 +753,8 @@ def lstmemory_group(input, NOTE: In PaddlePaddle's implementation, the following input-to-hidden multiplications: - :math:`W_{xi}x_{t}` , :math:`W_{xf}x_{t}`, - :math:`W_{xc}x_t`, :math:`W_{xo}x_{t}` are not done in lstmemory_unit to + :math:`W_{x_i}x_{t}` , :math:`W_{x_f}x_{t}`, + :math:`W_{x_c}x_t`, :math:`W_{x_o}x_{t}` are not done in lstmemory_unit to speed up the calculations. Consequently, an additional mixed_layer with full_matrix_projection must be included before lstmemory_unit is called. @@ -765,10 +770,12 @@ def lstmemory_group(input, :param input: input layer name. :type input: LayerOutput - :param name: lstmemory group name. - :type name: basestring :param size: lstmemory group size. :type size: int + :param name: name of the lstmemory group. + :type name: basestring + :param memory_boot: the initialization state of LSTM cell. + :type memory_boot: LayerOutput | None :param reverse: is lstm reversed :type reverse: bool :param param_attr: Parameter config, None if use default. @@ -798,6 +805,7 @@ def lstmemory_group(input, def __lstm_step__(ipt): return lstmemory_unit( input=ipt, + memory_boot=memory_boot, name=name, size=size, mixed_bias_attr=mixed_bias_attr, @@ -819,6 +827,7 @@ def lstmemory_group(input, @wrap_name_default('gru_unit') def gru_unit(input, + memory_boot=None, size=None, name=None, gru_bias_attr=None, @@ -829,8 +838,8 @@ def gru_unit(input, naive=False): """ Define calculations that a gated recurrent unit performs in a single time - step. This function itself is not a recurrent layer, so that it can not be - directly applied to sequence input. This function is almost always used in + step. This function itself is not a recurrent layer, so it can not be + directly used to process sequence inputs. This function is always used in the recurrent_group (see layers.py for more details) to implement attention mechanism. @@ -838,6 +847,8 @@ def gru_unit(input, :param input: input layer name. :type input: LayerOutput + :param memory_boot: the initialization state of the LSTM cell. + :type memory_boot: LayerOutput | None :param name: name of the gru group. :type name: basestring :param size: hidden size of the gru. @@ -856,7 +867,7 @@ def gru_unit(input, if size is None: size = input.size / 3 - out_mem = memory(name=name, size=size) + out_mem = memory(name=name, size=size, boot_layer=memory_boot) if naive: __step__ = gru_step_naive_layer @@ -878,6 +889,7 @@ def gru_unit(input, @wrap_name_default('gru_group') def gru_group(input, + memory_boot=None, size=None, name=None, reverse=False, @@ -888,7 +900,7 @@ def gru_group(input, gru_layer_attr=None, naive=False): """ - gru_group is a recurrent layer group version of Gated Recurrent Unit. It + gru_group is a recurrent_group version of Gated Recurrent Unit. It does exactly the same calculation as the grumemory layer does. A promising benefit is that gru hidden states are accessible to the user. This is especially useful in attention model. If you do not need to access @@ -908,6 +920,8 @@ def gru_group(input, :param input: input layer name. :type input: LayerOutput + :param memory_boot: the initialization state of the LSTM cell. + :type memory_boot: LayerOutput | None :param name: name of the gru group. :type name: basestring :param size: hidden size of the gru. @@ -929,6 +943,7 @@ def gru_group(input, def __gru_step__(ipt): return gru_unit( input=ipt, + memory_boot=memory_boot, name=name, size=size, gru_bias_attr=gru_bias_attr, @@ -1083,7 +1098,6 @@ def simple_gru2(input, return grumemory( name=name, - size=size, input=m, reverse=reverse, bias_attr=gru_bias_attr, diff --git a/python/paddle/v2/dataset/__init__.py b/python/paddle/v2/dataset/__init__.py index 26252d5bbd77ddb70b4f03843679e4737e2f96d3..2e4beb6882789249db09705f3f4d6c5c19e492cd 100644 --- a/python/paddle/v2/dataset/__init__.py +++ b/python/paddle/v2/dataset/__init__.py @@ -25,8 +25,9 @@ import uci_housing import sentiment import wmt14 import mq2007 +import flowers __all__ = [ 'mnist', 'imikolov', 'imdb', 'cifar', 'movielens', 'conll05', 'sentiment' - 'uci_housing', 'wmt14', 'mq2007' + 'uci_housing', 'wmt14', 'mq2007', 'flowers' ] diff --git a/python/paddle/v2/dataset/flowers.py b/python/paddle/v2/dataset/flowers.py index 07c13cf719ae0c864c23fef51f0bd7d47f265759..158cfe158c4f1c8d82d157301adcfbe0351c55df 100644 --- a/python/paddle/v2/dataset/flowers.py +++ b/python/paddle/v2/dataset/flowers.py @@ -13,18 +13,18 @@ # limitations under the License. """ This module will download dataset from -http://www.robots.ox.ac.uk/~vgg/data/flowers/102/index.html +http://www.robots.ox.ac.uk/~vgg/data/flowers/102/index.html and parse train/test set intopaddle reader creators. -This set contains images of flowers belonging to 102 different categories. +This set contains images of flowers belonging to 102 different categories. The images were acquired by searching the web and taking pictures. There are a minimum of 40 images for each category. The database was used in: Nilsback, M-E. and Zisserman, A. Automated flower classification over a large - number of classes.Proceedings of the Indian Conference on Computer Vision, -Graphics and Image Processing (2008) + number of classes.Proceedings of the Indian Conference on Computer Vision, +Graphics and Image Processing (2008) http://www.robots.ox.ac.uk/~vgg/publications/papers/nilsback08.{pdf,ps.gz}. """ @@ -34,9 +34,9 @@ from common import download import tarfile import scipy.io as scio from paddle.v2.image import * +from paddle.v2.reader import * import os import numpy as np -import paddle.v2 as paddle from multiprocessing import cpu_count __all__ = ['train', 'test', 'valid'] @@ -46,6 +46,12 @@ SETID_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/setid.mat' DATA_MD5 = '52808999861908f626f3c1f4e79d11fa' LABEL_MD5 = 'e0620be6f572b9609742df49c70aed4d' SETID_MD5 = 'a5357ecc9cb78c4bef273ce3793fc85c' +# In official 'readme', tstid is the flag of test data +# and trnid is the flag of train data. But test data is more than train data. +# So we exchange the train data and test data. +TRAIN_FLAG = 'tstid' +TEST_FLAG = 'trnid' +VALID_FLAG = 'valid' def default_mapper(sample): @@ -53,8 +59,8 @@ def default_mapper(sample): map image bytes data to type needed by model input layer ''' img, label = sample - img = paddle.image.load_image_bytes(img) - img = paddle.image.simple_transform(img, 256, 224, True) + img = load_image_bytes(img) + img = simple_transform(img, 256, 224, True) return img.flatten().astype('float32'), label @@ -63,22 +69,23 @@ def reader_creator(data_file, setid_file, dataset_name, mapper=default_mapper, - buffered_size=1024): + buffered_size=1024, + use_xmap=True): ''' - 1. read images from tar file and + 1. read images from tar file and merge images into batch files in 102flowers.tgz_batch/ 2. get a reader to read sample from batch file - - :param data_file: downloaded data file + + :param data_file: downloaded data file :type data_file: string - :param label_file: downloaded label file + :param label_file: downloaded label file :type label_file: string :param setid_file: downloaded setid file containing information about how to split dataset :type setid_file: string :param dataset_name: data set name (tstid|trnid|valid) :type dataset_name: string - :param mapper: a function to map image bytes data to type + :param mapper: a function to map image bytes data to type needed by model input layer :type mapper: callable :param buffered_size: the size of buffer used to process images @@ -105,15 +112,17 @@ def reader_creator(data_file, for sample, label in itertools.izip(data, batch['label']): yield sample, int(label) - return paddle.reader.xmap_readers(mapper, reader, - cpu_count(), buffered_size) + if use_xmap: + return xmap_readers(mapper, reader, cpu_count(), buffered_size) + else: + return map_readers(mapper, reader) -def train(mapper=default_mapper, buffered_size=1024): +def train(mapper=default_mapper, buffered_size=1024, use_xmap=True): ''' - Create flowers training set reader. - It returns a reader, each sample in the reader is - image pixels in [0, 1] and label in [1, 102] + Create flowers training set reader. + It returns a reader, each sample in the reader is + image pixels in [0, 1] and label in [1, 102] translated from original color image by steps: 1. resize to 256*256 2. random crop to 224*224 @@ -128,15 +137,15 @@ def train(mapper=default_mapper, buffered_size=1024): return reader_creator( download(DATA_URL, 'flowers', DATA_MD5), download(LABEL_URL, 'flowers', LABEL_MD5), - download(SETID_URL, 'flowers', SETID_MD5), 'trnid', mapper, - buffered_size) + download(SETID_URL, 'flowers', SETID_MD5), TRAIN_FLAG, mapper, + buffered_size, use_xmap) -def test(mapper=default_mapper, buffered_size=1024): +def test(mapper=default_mapper, buffered_size=1024, use_xmap=True): ''' - Create flowers test set reader. - It returns a reader, each sample in the reader is - image pixels in [0, 1] and label in [1, 102] + Create flowers test set reader. + It returns a reader, each sample in the reader is + image pixels in [0, 1] and label in [1, 102] translated from original color image by steps: 1. resize to 256*256 2. random crop to 224*224 @@ -151,15 +160,15 @@ def test(mapper=default_mapper, buffered_size=1024): return reader_creator( download(DATA_URL, 'flowers', DATA_MD5), download(LABEL_URL, 'flowers', LABEL_MD5), - download(SETID_URL, 'flowers', SETID_MD5), 'tstid', mapper, - buffered_size) + download(SETID_URL, 'flowers', SETID_MD5), TEST_FLAG, mapper, + buffered_size, use_xmap) -def valid(mapper=default_mapper, buffered_size=1024): +def valid(mapper=default_mapper, buffered_size=1024, use_xmap=True): ''' - Create flowers validation set reader. - It returns a reader, each sample in the reader is - image pixels in [0, 1] and label in [1, 102] + Create flowers validation set reader. + It returns a reader, each sample in the reader is + image pixels in [0, 1] and label in [1, 102] translated from original color image by steps: 1. resize to 256*256 2. random crop to 224*224 @@ -174,8 +183,8 @@ def valid(mapper=default_mapper, buffered_size=1024): return reader_creator( download(DATA_URL, 'flowers', DATA_MD5), download(LABEL_URL, 'flowers', LABEL_MD5), - download(SETID_URL, 'flowers', SETID_MD5), 'valid', mapper, - buffered_size) + download(SETID_URL, 'flowers', SETID_MD5), VALID_FLAG, mapper, + buffered_size, use_xmap) def fetch(): diff --git a/python/paddle/v2/dataset/tests/flowers_test.py b/python/paddle/v2/dataset/tests/flowers_test.py index cc0626f4feae287d18dfb227cc69a4174da055da..a8ae9a07acc22eb9d3c0cc5ebb07f8f11ed21233 100644 --- a/python/paddle/v2/dataset/tests/flowers_test.py +++ b/python/paddle/v2/dataset/tests/flowers_test.py @@ -31,13 +31,13 @@ class TestFlowers(unittest.TestCase): def test_train(self): instances, max_label_value = self.check_reader( paddle.v2.dataset.flowers.train()) - self.assertEqual(instances, 1020) + self.assertEqual(instances, 6149) self.assertEqual(max_label_value, 102) def test_test(self): instances, max_label_value = self.check_reader( paddle.v2.dataset.flowers.test()) - self.assertEqual(instances, 6149) + self.assertEqual(instances, 1020) self.assertEqual(max_label_value, 102) def test_valid(self): diff --git a/python/paddle/v2/parameters.py b/python/paddle/v2/parameters.py index ad20241b98302f136326ae491c6723a6c12ae284..bbaf8bfa979fbbf460561ebf7077b75b9c41a11a 100644 --- a/python/paddle/v2/parameters.py +++ b/python/paddle/v2/parameters.py @@ -51,7 +51,7 @@ class Parameters(object): def __init__(self): self.__param_conf__ = dict() self.__gradient_machines__ = [] - self.__tmp_params__ = [] + self.__tmp_params__ = dict() def __append_config__(self, param_conf): """ @@ -128,13 +128,10 @@ class Parameters(object): if len(self.__gradient_machines__) == 0: # create new parameter in python numpy. - if len(self.__tmp_params__) != 0: - ret_list = [ - mat for name, mat in self.__tmp_params__ if name == key - ] - if len(ret_list) == 1: - return ret_list[0] - return np.ndarray(shape=shape, dtype=np.float32) + if key in self.__tmp_params__: + return self.__tmp_params__[key] + else: + return np.ndarray(shape=shape, dtype=np.float32) else: for each_gradient_machine in self.__gradient_machines__: param = __get_parameter_in_gradient_machine__( @@ -187,7 +184,7 @@ class Parameters(object): (shape, value.shape)) if len(self.__gradient_machines__) == 0: - self.__tmp_params__.append((key, value)) + self.__tmp_params__[key] = value else: for each_gradient_machine in self.__gradient_machines__: __copy_parameter_to_gradient_machine__(each_gradient_machine, @@ -231,7 +228,7 @@ class Parameters(object): raise ValueError("gradient_machine should be api.GradientMachine") if len(self.__tmp_params__) != 0: - for name, val in self.__tmp_params__: + for name, val in self.__tmp_params__.iteritems(): try: __copy_parameter_to_gradient_machine__(gradient_machine, name, val) @@ -287,6 +284,18 @@ class Parameters(object): @staticmethod def from_tar(f): + """ + Create a `Parameters` object from the given file. And + the `Parameters` only contains the parameters in this + file. It is adapted the parameters are same in the + defined network and the given file. For example, it + can be used in the inference. + + :param f: the initialized model file. + :type f: tar file + :return: A Parameters object. + :rtype: Parameters. + """ params = Parameters() tar = tarfile.TarFile(fileobj=f, mode='r') for finfo in tar: @@ -302,6 +311,21 @@ class Parameters(object): params.deserialize(param_name, f) return params + def init_from_tar(self, f): + """ + Different from `from_tar`, this interface can be used to + init partial network parameters from another saved model. + + :param f: the initialized model file. + :type f: tar file + :return: Nothing. + """ + + tar_param = Parameters.from_tar(f) + for pname in tar_param.names(): + if pname in self.names(): + self.set(pname, tar_param.get(pname)) + def __get_parameter_in_gradient_machine__(gradient_machine, name): """ diff --git a/python/paddle/v2/reader/decorator.py b/python/paddle/v2/reader/decorator.py index e432003129d2b8dea60138d08f13ec5e9d29a7ad..45a4288751e37b99dd1005ec78f30a98044926ff 100644 --- a/python/paddle/v2/reader/decorator.py +++ b/python/paddle/v2/reader/decorator.py @@ -166,12 +166,12 @@ def buffered(reader, size): The buffered data reader will read and save data entries into a buffer. Reading from the buffered data reader will proceed as long as the buffer is not empty. - + :param reader: the data reader to read from. :type reader: callable :param size: max buffer size. :type size: int - + :returns: the buffered data reader. """ @@ -238,7 +238,7 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False): :type mapper: callable :param reader: the data reader to read from :type reader: callable - :param process_num: process number to handle original sample + :param process_num: process number to handle original sample :type process_num: int :param buffer_size: max buffer size :type buffer_size: int @@ -248,9 +248,6 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False): :rtype: callable """ end = XmapEndSignal() - in_queue = Queue(buffer_size) - out_queue = Queue(buffer_size) - out_order = [0] # define a worker to read samples from reader to in_queue def read_worker(reader, in_queue): @@ -266,12 +263,6 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False): in_order += 1 in_queue.put(end) - # start a read worker in a thread - target = order_read_worker if order else read_worker - t = Thread(target=target, args=(reader, in_queue)) - t.daemon = True - t.start() - # define a worker to handle samples from in_queue by mapper # and put mapped samples into out_queue def handle_worker(in_queue, out_queue, mapper): @@ -298,19 +289,27 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False): in_queue.put(end) out_queue.put(end) - # start several handle_workers - target = order_handle_worker if order else handle_worker - args = (in_queue, out_queue, mapper, out_order) if order else ( - in_queue, out_queue, mapper) - workers = [] - for i in xrange(process_num): - worker = Thread(target=target, args=args) - worker.daemon = True - workers.append(worker) - for w in workers: - w.start() - def xreader(): + in_queue = Queue(buffer_size) + out_queue = Queue(buffer_size) + out_order = [0] + # start a read worker in a thread + target = order_read_worker if order else read_worker + t = Thread(target=target, args=(reader, in_queue)) + t.daemon = True + t.start() + # start several handle_workers + target = order_handle_worker if order else handle_worker + args = (in_queue, out_queue, mapper, out_order) if order else ( + in_queue, out_queue, mapper) + workers = [] + for i in xrange(process_num): + worker = Thread(target=target, args=args) + worker.daemon = True + workers.append(worker) + for w in workers: + w.start() + sample = out_queue.get() while not isinstance(sample, XmapEndSignal): yield sample diff --git a/python/paddle/v2/reader/tests/decorator_test.py b/python/paddle/v2/reader/tests/decorator_test.py index bb3c5d220b9ce1552d2fc429abb1863930cd4d17..5a92951b100fa51ab6df7039d9c6b54d1f9d963e 100644 --- a/python/paddle/v2/reader/tests/decorator_test.py +++ b/python/paddle/v2/reader/tests/decorator_test.py @@ -132,15 +132,17 @@ class TestXmap(unittest.TestCase): for order in orders: for tNum in thread_nums: for size in buffered_size: - result = [] - for i in paddle.v2.reader.xmap_readers(mapper, + reader = paddle.v2.reader.xmap_readers(mapper, reader_creator_10(0), - tNum, size, order)(): - result.append(i) - if not order: - result.sort() - for idx, e in enumerate(result): - self.assertEqual(e, mapper(idx)) + tNum, size, order) + for n in xrange(3): + result = [] + for i in reader(): + result.append(i) + if not order: + result.sort() + for idx, e in enumerate(result): + self.assertEqual(e, mapper(idx)) if __name__ == '__main__': diff --git a/python/paddle/v2/tests/test_parameters.py b/python/paddle/v2/tests/test_parameters.py index 45372e7dd0ec7cbdd6a2eb5c0397ef7e74284cd0..7ba8a939fbd1a949d61a007b40c054e7543c0cbc 100644 --- a/python/paddle/v2/tests/test_parameters.py +++ b/python/paddle/v2/tests/test_parameters.py @@ -20,14 +20,17 @@ import cStringIO import numpy -def __rand_param_config__(name): +def __rand_param_config__(name, psize=None): conf = ParameterConfig() conf.name = name size = 1 - for i in xrange(2): - dim = random.randint(1, 1000) - conf.dims.append(dim) - size *= dim + if psize is None: + for i in xrange(2): + dim = random.randint(1, 1000) + conf.dims.append(dim) + size *= dim + else: + size = psize conf.size = size assert conf.IsInitialized() return conf @@ -77,6 +80,50 @@ class TestParameters(unittest.TestCase): expected = numpy.array([[1, 1], [1, 2], [1, 1]], numpy.float32) assert numpy.logical_and.reduce(numpy.reshape(val == expected, 6)) + def test_init_from_tar(self): + def get_param(names, size): + p = parameters.Parameters() + for k, v in zip(names, size): + p.__append_config__(__rand_param_config__(k, v)) + for name in p.names(): + param = p.get(name) + param[:] = numpy.random.uniform( + -1.0, 1.0, size=p.get_shape(name)) + p.set(name, param) + return p + + def get_parames(): + name1 = ['param_0', 'param_1'] + size1 = [128, 256] + p1 = get_param(name1, size1) + file1 = cStringIO.StringIO() + p1.to_tar(file1) + file1.seek(0) + + name2 = ['param_0', 'param_1', 'param_2'] + size2 = [128, 256, 288] + p2 = get_param(name2, size2) + file2 = cStringIO.StringIO() + p2.to_tar(file2) + file2.seek(0) + return p1, file1, p2, file2 + + p1, file1, p2, file2 = get_parames() + p2.init_from_tar(file1) + for name in p1.names(): + self.assertEqual(p1.get_shape(name), p2.get_shape(name)) + v1 = p1.get(name) + v2 = p2.get(name) + self.assertTrue(numpy.isclose(v1, v2).all()) + + p1, file1, p2, file2 = get_parames() + p1.init_from_tar(file2) + for name in p1.names(): + self.assertEqual(p1.get_shape(name), p2.get_shape(name)) + v1 = p1.get(name) + v2 = p2.get(name) + self.assertTrue(numpy.isclose(v1, v2).all()) + if __name__ == '__main__': unittest.main() diff --git a/python/setup.py.in b/python/setup.py.in index 86fc0fc5c0318b03659bf84f8ad9e2a114467c74..aa6771709cad0bb4dd4ce39c81de7e6ab1ad4c73 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -15,7 +15,8 @@ setup_requires=["requests", "protobuf==3.1", "recordio", "matplotlib", - "rarfile"] + "rarfile", + "scipy>=0.19.0"] if '${CMAKE_SYSTEM_PROCESSOR}' not in ['arm', 'armv7-a', 'aarch64']: setup_requires+=["opencv-python"]