提交 5b2f9939 编写于 作者: G gongweibao

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into fixdownloadbug

...@@ -7,8 +7,17 @@ INCLUDE_DIRECTORIES(${EIGEN_SOURCE_DIR}/src/eigen3) ...@@ -7,8 +7,17 @@ INCLUDE_DIRECTORIES(${EIGEN_SOURCE_DIR}/src/eigen3)
ExternalProject_Add( ExternalProject_Add(
eigen3 eigen3
${EXTERNAL_PROJECT_LOG_ARGS} ${EXTERNAL_PROJECT_LOG_ARGS}
URL "https://bitbucket.org/eigen/eigen/get/3.3.4.tar.gz" # for latest version, please get from official website
URL_MD5 "1a47e78efe365a97de0c022d127607c3" # URL "https://bitbucket.org/eigen/eigen/get/3.3.4.tar.gz"
# URL_MD5 "1a47e78efe365a97de0c022d127607c3"
# for no-ssl http support, please get from bazel's mirror
# URL "http://mirror.bazel.build/bitbucket.org/eigen/eigen/get/f3a22f35b044.tar.gz"
# URL_MD5 "4645c66075982da6fa0bcf6b20f3e8f7"
# get from github mirror
GIT_REPOSITORY "https://github.com/RLovelett/eigen.git"
GIT_TAG "a46d2e7337c4656f00abe54a8115f6d76153a048"
PREFIX ${EIGEN_SOURCE_DIR} PREFIX ${EIGEN_SOURCE_DIR}
UPDATE_COMMAND "" UPDATE_COMMAND ""
CONFIGURE_COMMAND "" CONFIGURE_COMMAND ""
......
...@@ -13,6 +13,10 @@ ...@@ -13,6 +13,10 @@
# limitations under the License. # limitations under the License.
INCLUDE(ExternalProject) INCLUDE(ExternalProject)
# Always invoke `FIND_PACKAGE(Protobuf)` for importing function protobuf_generate_cpp
FIND_PACKAGE(Protobuf QUIET)
SET(PROTOBUF_FOUND "OFF")
# Print and set the protobuf library information, # Print and set the protobuf library information,
# finish this cmake process and exit from this file. # finish this cmake process and exit from this file.
...@@ -39,12 +43,19 @@ macro(PROMPT_PROTOBUF_LIB) ...@@ -39,12 +43,19 @@ macro(PROMPT_PROTOBUF_LIB)
ADD_LIBRARY(protobuf_lite ${protobuf_LIBTYPE} IMPORTED GLOBAL) ADD_LIBRARY(protobuf_lite ${protobuf_LIBTYPE} IMPORTED GLOBAL)
SET_PROPERTY(TARGET protobuf_lite PROPERTY IMPORTED_LOCATION ${PROTOBUF_LITE_LIBRARY}) SET_PROPERTY(TARGET protobuf_lite PROPERTY IMPORTED_LOCATION ${PROTOBUF_LITE_LIBRARY})
ADD_LIBRARY(protoc ${protobuf_LIBTYPE} IMPORTED GLOBAL) ADD_LIBRARY(libprotoc ${protobuf_LIBTYPE} IMPORTED GLOBAL)
SET_PROPERTY(TARGET protoc PROPERTY IMPORTED_LOCATION ${PROTOC_LIBRARY}) SET_PROPERTY(TARGET libprotoc PROPERTY IMPORTED_LOCATION ${PROTOC_LIBRARY})
ADD_EXECUTABLE(protoc IMPORTED GLOBAL)
SET_PROPERTY(TARGET protoc PROPERTY IMPORTED_LOCATION ${PROTOBUF_PROTOC_EXECUTABLE})
# FIND_Protobuf.cmake uses `Protobuf_PROTOC_EXECUTABLE`.
# make `protobuf_generate_cpp` happy.
SET(Protobuf_PROTOC_EXECUTABLE ${PROTOBUF_PROTOC_EXECUTABLE})
FOREACH(dep ${protobuf_DEPS}) FOREACH(dep ${protobuf_DEPS})
ADD_DEPENDENCIES(protobuf ${dep}) ADD_DEPENDENCIES(protobuf ${dep})
ADD_DEPENDENCIES(protobuf_lite ${dep}) ADD_DEPENDENCIES(protobuf_lite ${dep})
ADD_DEPENDENCIES(libprotoc ${dep})
ADD_DEPENDENCIES(protoc ${dep}) ADD_DEPENDENCIES(protoc ${dep})
ENDFOREACH() ENDFOREACH()
......
...@@ -87,6 +87,9 @@ ...@@ -87,6 +87,9 @@
# go_library(example SHARED) # go_library(example SHARED)
# #
# including binary directory for generated headers.
include_directories(${CMAKE_BINARY_DIR})
if(NOT APPLE) if(NOT APPLE)
find_package(Threads REQUIRED) find_package(Threads REQUIRED)
link_libraries(${CMAKE_THREAD_LIBS_INIT}) link_libraries(${CMAKE_THREAD_LIBS_INIT})
...@@ -331,3 +334,13 @@ function(go_test TARGET_NAME) ...@@ -331,3 +334,13 @@ function(go_test TARGET_NAME)
add_custom_target(${TARGET_NAME} ALL DEPENDS ${TARGET_NAME}_timestamp ${go_test_DEPS}) add_custom_target(${TARGET_NAME} ALL DEPENDS ${TARGET_NAME}_timestamp ${go_test_DEPS})
add_test(${TARGET_NAME} ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}) add_test(${TARGET_NAME} ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME})
endfunction(go_test) endfunction(go_test)
function(proto_library TARGET_NAME)
set(oneValueArgs "")
set(multiValueArgs SRCS)
cmake_parse_arguments(proto_library "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
set(proto_srcs)
set(proto_hdrs)
protobuf_generate_cpp(proto_srcs proto_hdrs ${proto_library_SRCS})
cc_library(${TARGET_NAME} SRCS ${proto_srcs} DEPS protobuf)
endfunction()
...@@ -27,10 +27,6 @@ sphinx_add_target(paddle_docs ...@@ -27,10 +27,6 @@ sphinx_add_target(paddle_docs
${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}
${SPHINX_HTML_DIR_EN}) ${SPHINX_HTML_DIR_EN})
add_dependencies(paddle_docs
gen_proto_py)
# configured documentation tools and intermediate build results # configured documentation tools and intermediate build results
set(BINARY_BUILD_DIR_CN "${CMAKE_CURRENT_BINARY_DIR}/cn/_build") set(BINARY_BUILD_DIR_CN "${CMAKE_CURRENT_BINARY_DIR}/cn/_build")
...@@ -51,6 +47,3 @@ sphinx_add_target(paddle_docs_cn ...@@ -51,6 +47,3 @@ sphinx_add_target(paddle_docs_cn
${SPHINX_CACHE_DIR_CN} ${SPHINX_CACHE_DIR_CN}
${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}
${SPHINX_HTML_DIR_CN}) ${SPHINX_HTML_DIR_CN})
add_dependencies(paddle_docs_cn
gen_proto_py)
...@@ -41,7 +41,7 @@ class Scope { ...@@ -41,7 +41,7 @@ class Scope {
const Variable* GetVariable(const std::string& name) const; const Variable* GetVariable(const std::string& name) const;
private: private:
std::unordered_map<std::string, std::unique_ptr<Vairable>> vars_; std::unordered_map<std::string, std::unique_ptr<Variable>> vars_;
}; };
``` ```
...@@ -59,9 +59,9 @@ class Scope { ...@@ -59,9 +59,9 @@ class Scope {
Scope(const std::shared_ptr<Scope>& scope): parent_(scope) {} Scope(const std::shared_ptr<Scope>& scope): parent_(scope) {}
Variable* GetVariable(const std::string& name) const { Variable* GetVariable(const std::string& name) const {
Variable* var = GetVarLocally(name); auto it = vars_.find(name);
if (var != nullptr) { if (it != vars_.end()) {
return var; return it->second.get();
} else if (parent_ != nullptr) { } else if (parent_ != nullptr) {
return parent_->GetVariable(name); return parent_->GetVariable(name);
} else { } else {
...@@ -97,8 +97,8 @@ class Scope { ...@@ -97,8 +97,8 @@ class Scope {
// return nullptr if not found. // return nullptr if not found.
Variable* GetVariable(const std::string& name) const; Variable* GetVariable(const std::string& name) const;
// return Error if already contains same name variable. // return if already contains same name variable.
Error CreateVariable(const std::string& name); Variable* CreateVariable(const std::string& name);
private: private:
std::shared_ptr<Scope> parent_; std::shared_ptr<Scope> parent_;
......
...@@ -13,10 +13,13 @@ typedef int paddle_master_client; ...@@ -13,10 +13,13 @@ typedef int paddle_master_client;
import "C" import "C"
import ( import (
"strings"
"sync" "sync"
"time"
"unsafe" "unsafe"
"github.com/PaddlePaddle/Paddle/go/master" "github.com/PaddlePaddle/Paddle/go/master"
"github.com/coreos/etcd/clientv3"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
...@@ -48,16 +51,33 @@ func remove(client C.paddle_master_client) *master.Client { ...@@ -48,16 +51,33 @@ func remove(client C.paddle_master_client) *master.Client {
return h return h
} }
type addresser string //export paddle_new_etcd_master_client
func paddle_new_etcd_master_client(etcdEndpoints *C.char, timeout int, bufSize int) C.paddle_master_client {
func (a addresser) Address() string { p := C.GoString(etcdEndpoints)
return string(a) cli, err := clientv3.New(clientv3.Config{
Endpoints: strings.Split(p, ","),
DialTimeout: time.Second * time.Duration(timeout),
})
if err != nil {
panic(err)
}
ch := make(chan string, 1)
a, err := master.GetKey(cli, master.DefaultAddrPath, timeout)
if err != nil {
panic(err)
}
ch <- a
go master.WatchKey(cli, master.DefaultAddrPath, ch)
c := master.NewClient(ch, bufSize)
return add(c)
} }
//export paddle_new_master_client //export paddle_new_master_client
func paddle_new_master_client(addr *C.char, bufSize int) C.paddle_master_client { func paddle_new_master_client(addr *C.char, bufSize int) C.paddle_master_client {
a := C.GoString(addr) a := C.GoString(addr)
c := master.NewClient(addresser(a), bufSize) ch := make(chan string, 1)
ch <- a
c := master.NewClient(ch, bufSize)
return add(c) return add(c)
} }
......
...@@ -2,18 +2,12 @@ package master ...@@ -2,18 +2,12 @@ package master
import ( import (
"os" "os"
"time"
"github.com/PaddlePaddle/Paddle/go/connection" "github.com/PaddlePaddle/Paddle/go/connection"
"github.com/PaddlePaddle/recordio" "github.com/PaddlePaddle/recordio"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
// Addresser provide the address of the master server.
type Addresser interface {
Address() string
}
// Client is the client of the master server. // Client is the client of the master server.
type Client struct { type Client struct {
conn *connection.Conn conn *connection.Conn
...@@ -24,11 +18,11 @@ type Client struct { ...@@ -24,11 +18,11 @@ type Client struct {
// //
// bufSize is the record buffer size. NextRecord will read from this // bufSize is the record buffer size. NextRecord will read from this
// buffer. // buffer.
func NewClient(addr Addresser, bufSize int) *Client { func NewClient(addrCh <-chan string, bufSize int) *Client {
c := &Client{} c := &Client{}
c.conn = connection.New() c.conn = connection.New()
c.ch = make(chan []byte, bufSize) c.ch = make(chan []byte, bufSize)
go c.monitorMaster(addr) go c.monitorMaster(addrCh)
go c.getRecords() go c.getRecords()
return c return c
} }
...@@ -72,12 +66,10 @@ func (c *Client) getRecords() { ...@@ -72,12 +66,10 @@ func (c *Client) getRecords() {
} }
} }
func (c *Client) monitorMaster(addr Addresser) { func (c *Client) monitorMaster(addrCh <-chan string) {
lastMaster := "" lastMaster := ""
monitor := func() { for curMaster := range addrCh {
// get the lastest address of the master server,
// connect to the new address once address changed. // connect to the new address once address changed.
curMaster := addr.Address()
if curMaster != lastMaster { if curMaster != lastMaster {
if curMaster == "" { if curMaster == "" {
err := c.conn.Close() err := c.conn.Close()
...@@ -94,18 +86,10 @@ func (c *Client) monitorMaster(addr Addresser) { ...@@ -94,18 +86,10 @@ func (c *Client) monitorMaster(addr Addresser) {
// to retry next time. // to retry next time.
curMaster = lastMaster curMaster = lastMaster
} }
} }
} }
lastMaster = curMaster lastMaster = curMaster
} }
monitor()
ticker := time.NewTicker(10 * time.Second)
for _ = range ticker.C {
monitor()
}
} }
// SetDataset set dataset for the master server to dispatch. // SetDataset set dataset for the master server to dispatch.
......
...@@ -26,12 +26,6 @@ func init() { ...@@ -26,12 +26,6 @@ func init() {
log.SetLevel(log.ErrorLevel) log.SetLevel(log.ErrorLevel)
} }
type TestAddresser string
func (a TestAddresser) Address() string {
return string(a)
}
func TestGetFinishTask(t *testing.T) { func TestGetFinishTask(t *testing.T) {
const path = "/tmp/master_client_test_0" const path = "/tmp/master_client_test_0"
...@@ -45,7 +39,6 @@ func TestGetFinishTask(t *testing.T) { ...@@ -45,7 +39,6 @@ func TestGetFinishTask(t *testing.T) {
if err != nil { if err != nil {
panic(err) panic(err)
} }
go func(l net.Listener) { go func(l net.Listener) {
s, err := NewService(&InMemStore{}, chunkPerTask, time.Second, 1) s, err := NewService(&InMemStore{}, chunkPerTask, time.Second, 1)
if err != nil { if err != nil {
...@@ -82,9 +75,11 @@ func TestGetFinishTask(t *testing.T) { ...@@ -82,9 +75,11 @@ func TestGetFinishTask(t *testing.T) {
// Manually intialize client to avoid calling c.getRecords() // Manually intialize client to avoid calling c.getRecords()
c := &Client{} c := &Client{}
c.conn = connection.New() c.conn = connection.New()
go c.monitorMaster(TestAddresser(fmt.Sprintf(":%d", p))) addr := fmt.Sprintf(":%d", p)
ch := make(chan string, 1)
ch <- addr
go c.monitorMaster(ch)
c.SetDataset([]string{path}) c.SetDataset([]string{path})
checkOnePass := func(i int) { checkOnePass := func(i int) {
var tasks []Task var tasks []Task
for idx := 0; idx < totalTask; idx++ { for idx := 0; idx < totalTask; idx++ {
......
...@@ -20,7 +20,6 @@ func TestNextRecord(t *testing.T) { ...@@ -20,7 +20,6 @@ func TestNextRecord(t *testing.T) {
path = "/tmp/master_client_TestFull" path = "/tmp/master_client_TestFull"
total = 50 total = 50
) )
l, err := net.Listen("tcp", ":0") l, err := net.Listen("tcp", ":0")
if err != nil { if err != nil {
panic(err) panic(err)
...@@ -31,7 +30,6 @@ func TestNextRecord(t *testing.T) { ...@@ -31,7 +30,6 @@ func TestNextRecord(t *testing.T) {
if err != nil { if err != nil {
panic(err) panic(err)
} }
go func(l net.Listener) { go func(l net.Listener) {
s, err := master.NewService(&master.InMemStore{}, 10, time.Second, 1) s, err := master.NewService(&master.InMemStore{}, 10, time.Second, 1)
if err != nil { if err != nil {
...@@ -63,10 +61,10 @@ func TestNextRecord(t *testing.T) { ...@@ -63,10 +61,10 @@ func TestNextRecord(t *testing.T) {
} }
w.Close() w.Close()
f.Close() f.Close()
curAddr := make(chan string, 1)
c := master.NewClient(master.TestAddresser(fmt.Sprintf(":%d", p)), 10) curAddr <- fmt.Sprintf(":%d", p)
c := master.NewClient(curAddr, 10)
c.SetDataset([]string{path}) c.SetDataset([]string{path})
for pass := 0; pass < 50; pass++ { for pass := 0; pass < 50; pass++ {
received := make(map[byte]bool) received := make(map[byte]bool)
for i := 0; i < total; i++ { for i := 0; i < total; i++ {
......
...@@ -142,3 +142,31 @@ func (e *EtcdClient) Load() ([]byte, error) { ...@@ -142,3 +142,31 @@ func (e *EtcdClient) Load() ([]byte, error) {
state := kvs[0].Value state := kvs[0].Value
return state, nil return state, nil
} }
// GetKey gets the value by the specify key.
func GetKey(c *clientv3.Client, key string, timeout int) (string, error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(timeout))
resp, err := c.Get(ctx, key)
cancel()
if err != nil {
return "", err
}
kvs := resp.Kvs
if len(kvs) == 0 {
return "", nil
}
v := kvs[0].Value
return string(v), nil
}
// WatchKey watches the specify key and send to valChan if there is some event.
func WatchKey(c *clientv3.Client, key string, valChan chan<- string) {
rch := c.Watch(context.Background(), key)
for wresp := range rch {
for _, ev := range wresp.Events {
// if received event is DELETE, the value will be an empty string
log.Infof("received event %s, %q : %q\n", ev.Type, ev.Kv.Key, ev.Kv.Value)
valChan <- string(ev.Kv.Value)
}
}
}
package pserver package pserver
import ( import (
"errors"
"hash/fnv" "hash/fnv"
"sort" "sort"
"time" "time"
...@@ -123,6 +124,9 @@ func (c *Client) FinishInitParams() error { ...@@ -123,6 +124,9 @@ func (c *Client) FinishInitParams() error {
// SendGrads sends gradients to parameter servers for updating // SendGrads sends gradients to parameter servers for updating
// parameters. // parameters.
func (c *Client) SendGrads(grads []Gradient) error { func (c *Client) SendGrads(grads []Gradient) error {
if len(grads) == 0 {
return errors.New("no gradient received")
}
errCh := make(chan error, len(grads)) errCh := make(chan error, len(grads))
for _, g := range grads { for _, g := range grads {
go func(g Gradient) { go func(g Gradient) {
......
...@@ -16,7 +16,7 @@ set(API_HEADER ...@@ -16,7 +16,7 @@ set(API_HEADER
Internal.h) Internal.h)
add_library(paddle_api STATIC ${API_SOURCES}) add_library(paddle_api STATIC ${API_SOURCES})
add_dependencies(paddle_api gen_proto_cpp paddle_trainer_lib) add_dependencies(paddle_api paddle_proto paddle_trainer_lib)
INCLUDE(${SWIG_USE_FILE}) INCLUDE(${SWIG_USE_FILE})
INCLUDE_DIRECTORIES(${PROJ_ROOT}/paddle) INCLUDE_DIRECTORIES(${PROJ_ROOT}/paddle)
......
...@@ -26,7 +26,7 @@ target_include_directories(paddle_capi PUBLIC ${CMAKE_CURRENT_BINARY_DIR}) ...@@ -26,7 +26,7 @@ target_include_directories(paddle_capi PUBLIC ${CMAKE_CURRENT_BINARY_DIR})
add_style_check_target(paddle_capi ${CAPI_SOURCES} ${CAPI_HEADER} add_style_check_target(paddle_capi ${CAPI_SOURCES} ${CAPI_HEADER}
${CAPI_PRIVATE_HEADER}) ${CAPI_PRIVATE_HEADER})
add_dependencies(paddle_capi gen_proto_cpp) add_dependencies(paddle_capi paddle_proto)
# combine all paddle static libraries together, into libpaddle_capi_whole.a # combine all paddle static libraries together, into libpaddle_capi_whole.a
......
...@@ -83,7 +83,7 @@ else() ...@@ -83,7 +83,7 @@ else()
${CUDA_CXX_SOURCES}) ${CUDA_CXX_SOURCES})
endif() endif()
add_dependencies(paddle_cuda ${external_project_dependencies}) add_dependencies(paddle_cuda paddle_proto ${external_project_dependencies})
add_style_check_target(paddle_cuda add_style_check_target(paddle_cuda
${CUDA_SOURCES} ${CUDA_SOURCES}
......
# ddim lib
cc_library(ddim SRCS ddim.cc) cc_library(ddim SRCS ddim.cc)
cc_test(ddim_test SRCS ddim_test.cc DEPS ddim) cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
nv_test(dim_test SRCS dim_test.cu DEPS ddim) nv_test(dim_test SRCS dim_test.cu DEPS ddim)
cc_test(variable_test SRCS variable_test.cc) cc_test(variable_test SRCS variable_test.cc)
cc_test(scope_test SRCS scope_test.cc)
cc_test(enforce_test SRCS enforce_test.cc)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <paddle/string/printf.h>
#include <exception>
#include <sstream>
namespace paddle {
namespace framework {
/**
* @brief Enforce exception. Inherits std::exception
*
* All enforce condition not met, will throw an EnforceNotMet exception.
*/
class EnforceNotMet : public std::exception {
public:
EnforceNotMet(const std::string& msg, const char* file, int fileline) {
std::ostringstream sout;
sout << msg << " at [" << file << ":" << fileline << "];";
all_msg_ = sout.str();
}
const char* what() const noexcept override { return all_msg_.c_str(); }
private:
std::string all_msg_;
};
// From https://stackoverflow.com/questions/30130930/
// __buildin_expect is in C++ 11 standard. Since the condition which enforced
// should be true in most situation, it will make the compiler generate faster
// code by adding `UNLIKELY` macro.
#define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)
/**
* @brief Throw a EnforceNotMet exception, automatically filled __FILE__ &
* __LINE__
*
* This macro take __VA_ARGS__, user can pass any type if that type can
* serialize to std::ostream
*/
#define PADDLE_THROW(...) \
do { \
throw ::paddle::framework::EnforceNotMet( \
::paddle::string::Sprintf(__VA_ARGS__), __FILE__, __LINE__); \
} while (0)
/**
* @brief Enforce a condition, otherwise throw an EnforceNotMet
*/
#define PADDLE_ENFORCE(condition, ...) \
do { \
if (UNLIKELY(!(condition))) { \
PADDLE_THROW(__VA_ARGS__); \
} \
} while (0)
} // namespace framework
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include <paddle/framework/enforce.h>
TEST(ENFORCE, OK) {
PADDLE_ENFORCE(true, "Enforce is ok %d now %f", 123, 0.345);
size_t val = 1;
const size_t limit = 10;
PADDLE_ENFORCE(val < limit, "Enforce is OK too");
}
TEST(ENFORCE, FAILED) {
bool in_catch = false;
try {
PADDLE_ENFORCE(false, "Enforce is not ok %d at all", 123);
} catch (paddle::framework::EnforceNotMet err) {
in_catch = true;
std::string msg = "Enforce is not ok 123 at all";
const char* what = err.what();
for (size_t i = 0; i < msg.length(); ++i) {
ASSERT_EQ(what[i], msg[i]);
}
}
ASSERT_TRUE(in_catch);
}
\ No newline at end of file
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/framework/variable.h"
namespace paddle {
namespace framework {
/**
* @brief Scope that manage all variables.
*
* Scope is an association of a name to Variable. All variables belong to
* Scope. You need to specify a scope to run a Net, i.e., `net.Run(&scope)`.
* One net can run in different scopes and update different variable in the
* scope.
*/
class Scope {
public:
/**
* @brief Initialize s Scope without parent.
*/
Scope() {}
/**
* @brief Initialize a Scope with parent.
*/
explicit Scope(const std::shared_ptr<Scope>& parent) : parent_(parent) {}
/**
* @brief Create Variable
*
* Create Variable in this Scope. Return the exist one if Variable already
* been created.
*/
Variable* CreateVariable(const std::string& name) {
auto var = GetVariable(name);
if (var) {
return var;
} else {
vars_[name] = std::unique_ptr<Variable>(new Variable());
return GetVariable(name);
}
}
/**
* @brief Get Variable.
*
* Get Variable from this Scope, this function will recursive find Variable
* from it's parent scope. Return nullptr if not found.
*/
Variable* GetVariable(const std::string& name) const {
auto it = vars_.find(name);
if (it != vars_.end()) {
return it->second.get();
} else if (parent_ != nullptr) {
return parent_->GetVariable(name);
} else {
return nullptr;
}
}
/**
* @brief If this scope has a Var named name.
*
* Find if there is a Variable in this scope and it's parent scope
*/
bool HasVariable(const std::string& name) const {
return (vars_.find(name) != vars_.end() ||
(parent_ && parent_->HasVariable(name)));
}
private:
std::unordered_map<std::string, std::unique_ptr<Variable>> vars_;
std::shared_ptr<Scope> parent_{nullptr};
};
} // namespace framework
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/scope.h"
#include "gtest/gtest.h"
TEST(Scope, Create) {
using paddle::framework::Scope;
using paddle::framework::Variable;
auto scope = std::make_shared<Scope>();
Variable* var0 = scope->CreateVariable("");
EXPECT_NE(var0, nullptr);
/// GetVariable will return nullptr if not exist.
Variable* var1 = scope->GetVariable("a");
EXPECT_EQ(var1, nullptr);
/// CreateVariable will return one.
Variable* var2 = scope->CreateVariable("a");
EXPECT_NE(var2, nullptr);
/// Get the created variable.
Variable* var3 = scope->GetVariable("a");
EXPECT_EQ(var2, var3);
/// CreateVariable will just return the variable if it's
/// already exist.
Variable* var4 = scope->CreateVariable("a");
EXPECT_EQ(var4, var2);
}
TEST(Scope, Parent) {
using paddle::framework::Scope;
using paddle::framework::Variable;
auto parent_scope = std::make_shared<Scope>();
auto scope = std::make_shared<Scope>(parent_scope);
Variable* var0 = parent_scope->CreateVariable("a");
EXPECT_NE(var0, nullptr);
/// GetVariable will get Variable from parent scope if exist.
Variable* var1 = scope->GetVariable("a");
EXPECT_EQ(var0, var1);
}
...@@ -12,7 +12,7 @@ endif() ...@@ -12,7 +12,7 @@ endif()
add_library(paddle_function STATIC ${cpp_files} ${cu_objs}) add_library(paddle_function STATIC ${cpp_files} ${cu_objs})
add_dependencies(paddle_function ${external_project_dependencies}) add_dependencies(paddle_function ${external_project_dependencies})
add_dependencies(paddle_function gen_proto_cpp) add_dependencies(paddle_function paddle_proto)
if(WITH_TESTING) if(WITH_TESTING)
if(WITH_GPU) if(WITH_GPU)
......
...@@ -58,7 +58,7 @@ endif() ...@@ -58,7 +58,7 @@ endif()
add_style_check_target(paddle_gserver ${GSERVER_SOURCES}) add_style_check_target(paddle_gserver ${GSERVER_SOURCES})
add_style_check_target(paddle_gserver ${GSERVER_HEADER}) add_style_check_target(paddle_gserver ${GSERVER_HEADER})
add_dependencies(paddle_gserver gen_proto_cpp) add_dependencies(paddle_gserver paddle_proto ${external_project_dependencies})
if(WITH_TESTING) if(WITH_TESTING)
add_subdirectory(tests) add_subdirectory(tests)
endif() endif()
...@@ -33,7 +33,7 @@ endif() ...@@ -33,7 +33,7 @@ endif()
add_style_check_target(paddle_math ${MATH_SOURCES}) add_style_check_target(paddle_math ${MATH_SOURCES})
add_style_check_target(paddle_math ${MATH_HEADERS}) add_style_check_target(paddle_math ${MATH_HEADERS})
add_dependencies(paddle_math gen_proto_cpp) # depends add_dependencies(paddle_math paddle_proto ${external_project_dependencies}) # depends
if(WITH_TESTING) if(WITH_TESTING)
add_subdirectory(tests) add_subdirectory(tests)
endif() endif()
...@@ -10,7 +10,7 @@ set(OPITMIZER_SRCS ...@@ -10,7 +10,7 @@ set(OPITMIZER_SRCS
) )
add_library(paddle_optimizer STATIC ${OPITMIZER_SRCS}) add_library(paddle_optimizer STATIC ${OPITMIZER_SRCS})
add_dependencies(paddle_optimizer gen_proto_cpp) add_dependencies(paddle_optimizer paddle_proto ${external_project_dependencies})
if(WITH_TESTING) if(WITH_TESTING)
add_simple_unittest(serialization_test) add_simple_unittest(serialization_test)
......
...@@ -7,7 +7,7 @@ add_library(paddle_parameter STATIC ...@@ -7,7 +7,7 @@ add_library(paddle_parameter STATIC
${PARAMETERS_SOURCES}) ${PARAMETERS_SOURCES})
add_style_check_target(paddle_parameter ${PARAMETERS_SOURCES}) add_style_check_target(paddle_parameter ${PARAMETERS_SOURCES})
add_style_check_target(paddle_parameter ${PARAMETERS_HEADERS}) add_style_check_target(paddle_parameter ${PARAMETERS_HEADERS})
add_dependencies(paddle_parameter gen_proto_cpp) add_dependencies(paddle_parameter paddle_proto ${external_project_dependencies})
if(WITH_TESTING) if(WITH_TESTING)
add_subdirectory(tests) add_subdirectory(tests)
endif() endif()
...@@ -17,7 +17,7 @@ add_library(paddle_network STATIC ...@@ -17,7 +17,7 @@ add_library(paddle_network STATIC
add_style_check_target(paddle_network ${NETWORK_SOURCES}) add_style_check_target(paddle_network ${NETWORK_SOURCES})
add_style_check_target(paddle_network ${NETWORK_HEADERS}) add_style_check_target(paddle_network ${NETWORK_HEADERS})
add_dependencies(paddle_network gen_proto_cpp) add_dependencies(paddle_network paddle_proto ${external_project_dependencies})
################### paddle_pserver ###################### ################### paddle_pserver ######################
set(PSERVER_SOURCES set(PSERVER_SOURCES
...@@ -40,7 +40,7 @@ add_library(paddle_pserver STATIC ...@@ -40,7 +40,7 @@ add_library(paddle_pserver STATIC
add_style_check_target(paddle_pserver ${PSERVER_SOURCES}) add_style_check_target(paddle_pserver ${PSERVER_SOURCES})
add_style_check_target(paddle_pserver ${PSERVER_HEADERS}) add_style_check_target(paddle_pserver ${PSERVER_HEADERS})
add_dependencies(paddle_pserver gen_proto_cpp) add_dependencies(paddle_pserver paddle_proto ${external_project_dependencies})
set(PSERVER_MAIN_SOURCES set(PSERVER_MAIN_SOURCES
ParameterServer2Main.cpp) ParameterServer2Main.cpp)
......
...@@ -144,7 +144,7 @@ class DenseScanner(IScanner): ...@@ -144,7 +144,7 @@ class DenseScanner(IScanner):
if len(self.__shape__) > 1: if len(self.__shape__) > 1:
# The last-two dimenstions are the frame height and width. # The last-two dimenstions are the frame height and width.
# For example, the layout is CHW for 3-D feature of image. # For example, the layout is CHW for 3-D feature of image.
# The H and W are the fram height and width. # The H and W are the frame height and width.
h, w = self.__shape__[-2:] h, w = self.__shape__[-2:]
argument.setSlotFrameHeight(self.pos, h) argument.setSlotFrameHeight(self.pos, h)
argument.setSlotFrameWidth(self.pos, w) argument.setSlotFrameWidth(self.pos, w)
......
cc_library(stringpiece SRCS piece.cc) cc_library(stringpiece SRCS piece.cc)
cc_test(stringpiece_test SRCS piece_test.cc DEPS stringpiece glog gflags) cc_test(stringpiece_test SRCS piece_test.cc DEPS stringpiece glog gflags)
cc_test(stringprintf_test SRCS printf_test.cc DEPS glog gflags)
/*
Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Compared with std::stringstream, there are primary purpose of
// string::Printf:
//
// 1. Type-safe printing, with why and how explained in
// http://www.drdobbs.com/stringprintf-a-typesafe-printf-family-fo/184401999.
// Implementation includes
//
// https://github.com/c42f/tinyformat
// boost::format
// std::stringstream
//
// std::stringstream is not convenient enough in many cases. For example:
//
// std::cout << std::setprecision(2) << std::fixed << 1.23456 << "\n";
//
// boost::format is the most convenient one. We can have
//
// std::cout << format("%2% %1%") % 36 % 77;
//
// or
//
// format fmter("%2% %1%");
// fmter % 36; fmter % 77;
// std::cout << fmter.c_str();
//
// But the overloading of % might be overkilling and it would be
// more efficient if it can write to std::cout directly.
//
// tinyformat has an interface compatible with the C-printf style,
// and it can writes to a stream or returns a std::string:
//
// std::cout << tfm::printf(
// "%s, %s %d, %.2d:%.2d\n",
// weekday, month, day, hour, min);
//
// or
//
// tfm::format(std::cout,
// "%s, %s %d, %.2d:%.2d\n",
// weekday, month, day, hour, min);
//
// 2. High-performance -- most printed strings are not too long and
// doens't need dynamic memory allocation. Many StringPrintf
// implementations doesn't enforce type-safe, but are
// high-performance, including
//
// https://developers.google.com/optimization/reference/base/stringprintf/
// https://github.com/adobe/chromium/blob/master/base/stringprintf.h
// https://github.com/google/protobuf/blob/master/src/google/protobuf/stubs/stringprintf.h
//
// According to
// https://github.com/c42f/tinyformat#compile-time-and-code-bloat,
// boost::format runs too slow and results in large executable binary
// files. So here we port tinyformat.
#pragma once
#include <iostream>
#include <sstream>
#include "paddle/string/tinyformat/tinyformat.h" // https://github.com/c42f/tinyformat
namespace paddle {
namespace string {
template <typename... Args>
void Fprintf(std::ostream& out, const char* fmt, const Args&... args) {
tinyformat::vformat(out, fmt, tinyformat::makeFormatList(args...));
}
template <typename... Args>
std::string Sprintf(const char* fmt, const Args&... args) {
std::ostringstream oss;
Fprintf(oss, fmt, args...);
return oss.str();
}
template <typename... Args>
void Printf(const char* fmt, const Args&... args) {
Fprintf(std::cout, fmt, args...);
}
} // namespace string
} // namespace paddle
#include "paddle/string/printf.h"
#include <string>
#include "gtest/gtest.h"
TEST(StringPrintf, StringPrintf) {
std::string weekday = "Wednesday";
const char* month = "July";
size_t day = 27;
long hour = 14;
int min = 44;
EXPECT_EQ(std::string("Wednesday, July 27, 14:44"),
paddle::string::Sprintf(
"%s, %s %d, %.2d:%.2d", weekday, month, day, hour, min));
}
// tinyformat.h
// Copyright (C) 2011, Chris Foster [chris42f (at) gmail (d0t) com]
//
// Boost Software License - Version 1.0
//
// Permission is hereby granted, free of charge, to any person or organization
// obtaining a copy of the software and accompanying documentation covered by
// this license (the "Software") to use, reproduce, display, distribute,
// execute, and transmit the Software, and to prepare derivative works of the
// Software, and to permit third-parties to whom the Software is furnished to
// do so, all subject to the following:
//
// The copyright notices in the Software and this entire statement, including
// the above license grant, this restriction and the following disclaimer,
// must be included in all copies of the Software, in whole or in part, and
// all derivative works of the Software, unless such copies or derivative
// works are solely in the form of machine-executable object code generated by
// a source language processor.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT
// SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE
// FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE,
// ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
//------------------------------------------------------------------------------
// Tinyformat: A minimal type safe printf replacement
//
// tinyformat.h is a type safe printf replacement library in a single C++
// header file. Design goals include:
//
// * Type safety and extensibility for user defined types.
// * C99 printf() compatibility, to the extent possible using std::ostream
// * Simplicity and minimalism. A single header file to include and distribute
// with your projects.
// * Augment rather than replace the standard stream formatting mechanism
// * C++98 support, with optional C++11 niceties
//
//
// Main interface example usage
// ----------------------------
//
// To print a date to std::cout:
//
// std::string weekday = "Wednesday";
// const char* month = "July";
// size_t day = 27;
// long hour = 14;
// int min = 44;
//
// tfm::printf("%s, %s %d, %.2d:%.2d\n", weekday, month, day, hour, min);
//
// The strange types here emphasize the type safety of the interface; it is
// possible to print a std::string using the "%s" conversion, and a
// size_t using the "%d" conversion. A similar result could be achieved
// using either of the tfm::format() functions. One prints on a user provided
// stream:
//
// tfm::format(std::cerr, "%s, %s %d, %.2d:%.2d\n",
// weekday, month, day, hour, min);
//
// The other returns a std::string:
//
// std::string date = tfm::format("%s, %s %d, %.2d:%.2d\n",
// weekday, month, day, hour, min);
// std::cout << date;
//
// These are the three primary interface functions. There is also a
// convenience function printfln() which appends a newline to the usual result
// of printf() for super simple logging.
//
//
// User defined format functions
// -----------------------------
//
// Simulating variadic templates in C++98 is pretty painful since it requires
// writing out the same function for each desired number of arguments. To make
// this bearable tinyformat comes with a set of macros which are used
// internally to generate the API, but which may also be used in user code.
//
// The three macros TINYFORMAT_ARGTYPES(n), TINYFORMAT_VARARGS(n) and
// TINYFORMAT_PASSARGS(n) will generate a list of n argument types,
// type/name pairs and argument names respectively when called with an integer
// n between 1 and 16. We can use these to define a macro which generates the
// desired user defined function with n arguments. To generate all 16 user
// defined function bodies, use the macro TINYFORMAT_FOREACH_ARGNUM. For an
// example, see the implementation of printf() at the end of the source file.
//
// Sometimes it's useful to be able to pass a list of format arguments through
// to a non-template function. The FormatList class is provided as a way to do
// this by storing the argument list in a type-opaque way. Continuing the
// example from above, we construct a FormatList using makeFormatList():
//
// FormatListRef formatList = tfm::makeFormatList(weekday, month, day, hour,
// min);
//
// The format list can now be passed into any non-template function and used
// via a call to the vformat() function:
//
// tfm::vformat(std::cout, "%s, %s %d, %.2d:%.2d\n", formatList);
//
//
// Additional API information
// --------------------------
//
// Error handling: Define TINYFORMAT_ERROR to customize the error handling for
// format strings which are unsupported or have the wrong number of format
// specifiers (calls assert() by default).
//
// User defined types: Uses operator<< for user defined types by default.
// Overload formatValue() for more control.
#pragma once
#include <algorithm>
#include <cassert>
#include <iostream>
#include <sstream>
namespace paddle {
namespace string {
namespace tinyformat {
#ifndef TINYFORMAT_ERROR
#define TINYFORMAT_ERROR(reason) assert(0 && reason)
#endif
//------------------------------------------------------------------------------
namespace detail {
// Test whether type T1 is convertible to type T2
template <typename T1, typename T2>
struct is_convertible {
private:
// two types of different size
struct fail {
char dummy[2];
};
struct succeed {
char dummy;
};
// Try to convert a T1 to a T2 by plugging into tryConvert
static fail tryConvert(...);
static succeed tryConvert(const T2 &);
static const T1 &makeT1();
public:
// Standard trick: the (...) version of tryConvert will be chosen from
// the overload set only if the version taking a T2 doesn't match.
// Then we compare the sizes of the return types to check which
// function matched. Very neat, in a disgusting kind of way :)
static const bool value = sizeof(tryConvert(makeT1())) == sizeof(succeed);
};
// Format the value by casting to type fmtT. This default implementation
// should never be called.
template <typename T,
typename fmtT,
bool convertible = is_convertible<T, fmtT>::value>
struct formatValueAsType {
static void invoke(std::ostream & /*out*/, const T & /*value*/) { assert(0); }
};
// Specialized version for types that can actually be converted to fmtT, as
// indicated by the "convertible" template parameter.
template <typename T, typename fmtT>
struct formatValueAsType<T, fmtT, true> {
static void invoke(std::ostream &out, const T &value) {
out << static_cast<fmtT>(value);
}
};
// Convert an arbitrary type to integer. The version with convertible=false
// throws an error.
template <typename T, bool convertible = is_convertible<T, int>::value>
struct convertToInt {
static int invoke(const T & /*value*/) {
TINYFORMAT_ERROR(
"tinyformat: Cannot convert from argument type to "
"integer for use as variable width or precision");
return 0;
}
};
// Specialization for convertToInt when conversion is possible
template <typename T>
struct convertToInt<T, true> {
static int invoke(const T &value) { return static_cast<int>(value); }
};
// Format at most ntrunc characters to the given stream.
template <typename T>
inline void formatTruncated(std::ostream &out, const T &value, int ntrunc) {
std::ostringstream tmp;
tmp << value;
std::string result = tmp.str();
out.write(result.c_str(),
(std::min)(ntrunc, static_cast<int>(result.size())));
}
#define TINYFORMAT_DEFINE_FORMAT_TRUNCATED_CSTR(type) \
inline void formatTruncated(std::ostream &out, type *value, int ntrunc) { \
std::streamsize len = 0; \
while (len < ntrunc && value[len] != 0) ++len; \
out.write(value, len); \
}
// Overload for const char* and char*. Could overload for signed & unsigned
// char too, but these are technically unneeded for printf compatibility.
TINYFORMAT_DEFINE_FORMAT_TRUNCATED_CSTR(const char)
TINYFORMAT_DEFINE_FORMAT_TRUNCATED_CSTR(char)
#undef TINYFORMAT_DEFINE_FORMAT_TRUNCATED_CSTR
} // namespace detail
//------------------------------------------------------------------------------
// Variable formatting functions. May be overridden for user-defined types if
// desired.
/// Format a value into a stream, delegating to operator<< by default.
///
/// Users may override this for their own types. When this function is called,
/// the stream flags will have been modified according to the format string.
/// The format specification is provided in the range [fmtBegin, fmtEnd). For
/// truncating conversions, ntrunc is set to the desired maximum number of
/// characters, for example "%.7s" calls formatValue with ntrunc = 7.
///
/// By default, formatValue() uses the usual stream insertion operator
/// operator<< to format the type T, with special cases for the %c and %p
/// conversions.
template <typename T>
inline void formatValue(std::ostream &out,
const char * /*fmtBegin*/,
const char *fmtEnd,
int ntrunc,
const T &value) {
// The mess here is to support the %c and %p conversions: if these
// conversions are active we try to convert the type to a char or const
// void* respectively and format that instead of the value itself. For the
// %p conversion it's important to avoid dereferencing the pointer, which
// could otherwise lead to a crash when printing a dangling (const char*).
const bool canConvertToChar = detail::is_convertible<T, char>::value;
const bool canConvertToVoidPtr =
detail::is_convertible<T, const void *>::value;
if (canConvertToChar && *(fmtEnd - 1) == 'c')
detail::formatValueAsType<T, char>::invoke(out, value);
else if (canConvertToVoidPtr && *(fmtEnd - 1) == 'p')
detail::formatValueAsType<T, const void *>::invoke(out, value);
else if (ntrunc >= 0) {
// Take care not to overread C strings in truncating conversions like
// "%.4s" where at most 4 characters may be read.
detail::formatTruncated(out, value, ntrunc);
} else
out << value;
}
// Overloaded version for char types to support printing as an integer
#define TINYFORMAT_DEFINE_FORMATVALUE_CHAR(charType) \
inline void formatValue(std::ostream &out, \
const char * /*fmtBegin*/, \
const char *fmtEnd, \
int /**/, \
charType value) { \
switch (*(fmtEnd - 1)) { \
case 'u': \
case 'd': \
case 'i': \
case 'o': \
case 'X': \
case 'x': \
out << static_cast<int>(value); \
break; \
default: \
out << value; \
break; \
} \
}
// per 3.9.1: char, signed char and unsigned char are all distinct types
TINYFORMAT_DEFINE_FORMATVALUE_CHAR(char)
TINYFORMAT_DEFINE_FORMATVALUE_CHAR(signed char)
TINYFORMAT_DEFINE_FORMATVALUE_CHAR(unsigned char)
#undef TINYFORMAT_DEFINE_FORMATVALUE_CHAR
//------------------------------------------------------------------------------
// Tools for emulating variadic templates in C++98. The basic idea here is
// stolen from the boost preprocessor metaprogramming library and cut down to
// be just general enough for what we need.
#define TINYFORMAT_ARGTYPES(n) TINYFORMAT_ARGTYPES_##n
#define TINYFORMAT_VARARGS(n) TINYFORMAT_VARARGS_##n
#define TINYFORMAT_PASSARGS(n) TINYFORMAT_PASSARGS_##n
#define TINYFORMAT_PASSARGS_TAIL(n) TINYFORMAT_PASSARGS_TAIL_##n
// To keep it as transparent as possible, the macros below have been generated
// using python via the excellent cog.py code generation script. This avoids
// the need for a bunch of complex (but more general) preprocessor tricks as
// used in boost.preprocessor.
//
// To rerun the code generation in place, use `cog.py -r tinyformat.h`
// (see http://nedbatchelder.com/code/cog). Alternatively you can just create
// extra versions by hand.
/*[[[cog
maxParams = 16
def makeCommaSepLists(lineTemplate, elemTemplate, startInd=1):
for j in range(startInd,maxParams+1):
list = ', '.join([elemTemplate % {'i':i} for i in range(startInd,j+1)])
cog.outl(lineTemplate % {'j':j, 'list':list})
makeCommaSepLists('#define TINYFORMAT_ARGTYPES_%(j)d %(list)s',
'class T%(i)d')
cog.outl()
makeCommaSepLists('#define TINYFORMAT_VARARGS_%(j)d %(list)s',
'const T%(i)d& v%(i)d')
cog.outl()
makeCommaSepLists('#define TINYFORMAT_PASSARGS_%(j)d %(list)s', 'v%(i)d')
cog.outl()
cog.outl('#define TINYFORMAT_PASSARGS_TAIL_1')
makeCommaSepLists('#define TINYFORMAT_PASSARGS_TAIL_%(j)d , %(list)s',
'v%(i)d', startInd = 2)
cog.outl()
cog.outl('#define TINYFORMAT_FOREACH_ARGNUM(m) \\\n ' +
' '.join(['m(%d)' % (j,) for j in range(1,maxParams+1)]))
]]]*/
#define TINYFORMAT_ARGTYPES_1 class T1
#define TINYFORMAT_ARGTYPES_2 class T1, class T2
#define TINYFORMAT_ARGTYPES_3 class T1, class T2, class T3
#define TINYFORMAT_ARGTYPES_4 class T1, class T2, class T3, class T4
#define TINYFORMAT_ARGTYPES_5 class T1, class T2, class T3, class T4, class T5
#define TINYFORMAT_ARGTYPES_6 \
class T1, class T2, class T3, class T4, class T5, class T6
#define TINYFORMAT_ARGTYPES_7 \
class T1, class T2, class T3, class T4, class T5, class T6, class T7
#define TINYFORMAT_ARGTYPES_8 \
class T1, class T2, class T3, class T4, class T5, class T6, class T7, class T8
#define TINYFORMAT_ARGTYPES_9 \
class T1, class T2, class T3, class T4, class T5, class T6, class T7, \
class T8, class T9
#define TINYFORMAT_ARGTYPES_10 \
class T1, class T2, class T3, class T4, class T5, class T6, class T7, \
class T8, class T9, class T10
#define TINYFORMAT_ARGTYPES_11 \
class T1, class T2, class T3, class T4, class T5, class T6, class T7, \
class T8, class T9, class T10, class T11
#define TINYFORMAT_ARGTYPES_12 \
class T1, class T2, class T3, class T4, class T5, class T6, class T7, \
class T8, class T9, class T10, class T11, class T12
#define TINYFORMAT_ARGTYPES_13 \
class T1, class T2, class T3, class T4, class T5, class T6, class T7, \
class T8, class T9, class T10, class T11, class T12, class T13
#define TINYFORMAT_ARGTYPES_14 \
class T1, class T2, class T3, class T4, class T5, class T6, class T7, \
class T8, class T9, class T10, class T11, class T12, class T13, \
class T14
#define TINYFORMAT_ARGTYPES_15 \
class T1, class T2, class T3, class T4, class T5, class T6, class T7, \
class T8, class T9, class T10, class T11, class T12, class T13, \
class T14, class T15
#define TINYFORMAT_ARGTYPES_16 \
class T1, class T2, class T3, class T4, class T5, class T6, class T7, \
class T8, class T9, class T10, class T11, class T12, class T13, \
class T14, class T15, class T16
#define TINYFORMAT_VARARGS_1 const T1 &v1
#define TINYFORMAT_VARARGS_2 const T1 &v1, const T2 &v2
#define TINYFORMAT_VARARGS_3 const T1 &v1, const T2 &v2, const T3 &v3
#define TINYFORMAT_VARARGS_4 \
const T1 &v1, const T2 &v2, const T3 &v3, const T4 &v4
#define TINYFORMAT_VARARGS_5 \
const T1 &v1, const T2 &v2, const T3 &v3, const T4 &v4, const T5 &v5
#define TINYFORMAT_VARARGS_6 \
const T1 &v1, const T2 &v2, const T3 &v3, const T4 &v4, const T5 &v5, \
const T6 &v6
#define TINYFORMAT_VARARGS_7 \
const T1 &v1, const T2 &v2, const T3 &v3, const T4 &v4, const T5 &v5, \
const T6 &v6, const T7 &v7
#define TINYFORMAT_VARARGS_8 \
const T1 &v1, const T2 &v2, const T3 &v3, const T4 &v4, const T5 &v5, \
const T6 &v6, const T7 &v7, const T8 &v8
#define TINYFORMAT_VARARGS_9 \
const T1 &v1, const T2 &v2, const T3 &v3, const T4 &v4, const T5 &v5, \
const T6 &v6, const T7 &v7, const T8 &v8, const T9 &v9
#define TINYFORMAT_VARARGS_10 \
const T1 &v1, const T2 &v2, const T3 &v3, const T4 &v4, const T5 &v5, \
const T6 &v6, const T7 &v7, const T8 &v8, const T9 &v9, const T10 &v10
#define TINYFORMAT_VARARGS_11 \
const T1 &v1, const T2 &v2, const T3 &v3, const T4 &v4, const T5 &v5, \
const T6 &v6, const T7 &v7, const T8 &v8, const T9 &v9, const T10 &v10, \
const T11 &v11
#define TINYFORMAT_VARARGS_12 \
const T1 &v1, const T2 &v2, const T3 &v3, const T4 &v4, const T5 &v5, \
const T6 &v6, const T7 &v7, const T8 &v8, const T9 &v9, const T10 &v10, \
const T11 &v11, const T12 &v12
#define TINYFORMAT_VARARGS_13 \
const T1 &v1, const T2 &v2, const T3 &v3, const T4 &v4, const T5 &v5, \
const T6 &v6, const T7 &v7, const T8 &v8, const T9 &v9, const T10 &v10, \
const T11 &v11, const T12 &v12, const T13 &v13
#define TINYFORMAT_VARARGS_14 \
const T1 &v1, const T2 &v2, const T3 &v3, const T4 &v4, const T5 &v5, \
const T6 &v6, const T7 &v7, const T8 &v8, const T9 &v9, const T10 &v10, \
const T11 &v11, const T12 &v12, const T13 &v13, const T14 &v14
#define TINYFORMAT_VARARGS_15 \
const T1 &v1, const T2 &v2, const T3 &v3, const T4 &v4, const T5 &v5, \
const T6 &v6, const T7 &v7, const T8 &v8, const T9 &v9, const T10 &v10, \
const T11 &v11, const T12 &v12, const T13 &v13, const T14 &v14, \
const T15 &v15
#define TINYFORMAT_VARARGS_16 \
const T1 &v1, const T2 &v2, const T3 &v3, const T4 &v4, const T5 &v5, \
const T6 &v6, const T7 &v7, const T8 &v8, const T9 &v9, const T10 &v10, \
const T11 &v11, const T12 &v12, const T13 &v13, const T14 &v14, \
const T15 &v15, const T16 &v16
#define TINYFORMAT_PASSARGS_1 v1
#define TINYFORMAT_PASSARGS_2 v1, v2
#define TINYFORMAT_PASSARGS_3 v1, v2, v3
#define TINYFORMAT_PASSARGS_4 v1, v2, v3, v4
#define TINYFORMAT_PASSARGS_5 v1, v2, v3, v4, v5
#define TINYFORMAT_PASSARGS_6 v1, v2, v3, v4, v5, v6
#define TINYFORMAT_PASSARGS_7 v1, v2, v3, v4, v5, v6, v7
#define TINYFORMAT_PASSARGS_8 v1, v2, v3, v4, v5, v6, v7, v8
#define TINYFORMAT_PASSARGS_9 v1, v2, v3, v4, v5, v6, v7, v8, v9
#define TINYFORMAT_PASSARGS_10 v1, v2, v3, v4, v5, v6, v7, v8, v9, v10
#define TINYFORMAT_PASSARGS_11 v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11
#define TINYFORMAT_PASSARGS_12 v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12
#define TINYFORMAT_PASSARGS_13 \
v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13
#define TINYFORMAT_PASSARGS_14 \
v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14
#define TINYFORMAT_PASSARGS_15 \
v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15
#define TINYFORMAT_PASSARGS_16 \
v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16
#define TINYFORMAT_PASSARGS_TAIL_1
#define TINYFORMAT_PASSARGS_TAIL_2 , v2
#define TINYFORMAT_PASSARGS_TAIL_3 , v2, v3
#define TINYFORMAT_PASSARGS_TAIL_4 , v2, v3, v4
#define TINYFORMAT_PASSARGS_TAIL_5 , v2, v3, v4, v5
#define TINYFORMAT_PASSARGS_TAIL_6 , v2, v3, v4, v5, v6
#define TINYFORMAT_PASSARGS_TAIL_7 , v2, v3, v4, v5, v6, v7
#define TINYFORMAT_PASSARGS_TAIL_8 , v2, v3, v4, v5, v6, v7, v8
#define TINYFORMAT_PASSARGS_TAIL_9 , v2, v3, v4, v5, v6, v7, v8, v9
#define TINYFORMAT_PASSARGS_TAIL_10 , v2, v3, v4, v5, v6, v7, v8, v9, v10
#define TINYFORMAT_PASSARGS_TAIL_11 , v2, v3, v4, v5, v6, v7, v8, v9, v10, v11
#define TINYFORMAT_PASSARGS_TAIL_12 \
, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12
#define TINYFORMAT_PASSARGS_TAIL_13 \
, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13
#define TINYFORMAT_PASSARGS_TAIL_14 \
, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14
#define TINYFORMAT_PASSARGS_TAIL_15 \
, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15
#define TINYFORMAT_PASSARGS_TAIL_16 \
, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16
#define TINYFORMAT_FOREACH_ARGNUM(m) \
m(1) m(2) m(3) m(4) m(5) m(6) m(7) m(8) m(9) m(10) m(11) m(12) m(13) m(14) \
m(15) m(16)
//[[[end]]]
namespace detail {
// Type-opaque holder for an argument to format(), with associated actions on
// the type held as explicit function pointers. This allows FormatArg's for
// each argument to be allocated as a homogenous array inside FormatList
// whereas a naive implementation based on inheritance does not.
class FormatArg {
public:
FormatArg() {}
template <typename T>
FormatArg(const T &value)
: m_value(static_cast<const void *>(&value)),
m_formatImpl(&formatImpl<T>),
m_toIntImpl(&toIntImpl<T>) {}
void format(std::ostream &out,
const char *fmtBegin,
const char *fmtEnd,
int ntrunc) const {
m_formatImpl(out, fmtBegin, fmtEnd, ntrunc, m_value);
}
int toInt() const { return m_toIntImpl(m_value); }
private:
template <typename T>
static void formatImpl(std::ostream &out,
const char *fmtBegin,
const char *fmtEnd,
int ntrunc,
const void *value) {
formatValue(out, fmtBegin, fmtEnd, ntrunc, *static_cast<const T *>(value));
}
template <typename T>
static int toIntImpl(const void *value) {
return convertToInt<T>::invoke(*static_cast<const T *>(value));
}
const void *m_value;
void (*m_formatImpl)(std::ostream &out,
const char *fmtBegin,
const char *fmtEnd,
int ntrunc,
const void *value);
int (*m_toIntImpl)(const void *value);
};
// Parse and return an integer from the string c, as atoi()
// On return, c is set to one past the end of the integer.
inline int parseIntAndAdvance(const char *&c) {
int i = 0;
for (; *c >= '0' && *c <= '9'; ++c) i = 10 * i + (*c - '0');
return i;
}
// Print literal part of format string and return next format spec
// position.
//
// Skips over any occurrences of '%%', printing a literal '%' to the
// output. The position of the first % character of the next
// nontrivial format spec is returned, or the end of string.
inline const char *printFormatStringLiteral(std::ostream &out,
const char *fmt) {
const char *c = fmt;
for (;; ++c) {
switch (*c) {
case '\0':
out.write(fmt, c - fmt);
return c;
case '%':
out.write(fmt, c - fmt);
if (*(c + 1) != '%') return c;
// for "%%", tack trailing % onto next literal section.
fmt = ++c;
break;
default:
break;
}
}
}
// Parse a format string and set the stream state accordingly.
//
// The format mini-language recognized here is meant to be the one from C99,
// with the form "%[flags][width][.precision][length]type".
//
// Formatting options which can't be natively represented using the ostream
// state are returned in spacePadPositive (for space padded positive numbers)
// and ntrunc (for truncating conversions). argIndex is incremented if
// necessary to pull out variable width and precision . The function returns a
// pointer to the character after the end of the current format spec.
inline const char *streamStateFromFormat(std::ostream &out,
bool &spacePadPositive,
int &ntrunc,
const char *fmtStart,
const detail::FormatArg *formatters,
int &argIndex,
int numFormatters) {
if (*fmtStart != '%') {
TINYFORMAT_ERROR(
"tinyformat: Not enough conversion specifiers in format string");
return fmtStart;
}
// Reset stream state to defaults.
out.width(0);
out.precision(6);
out.fill(' ');
// Reset most flags; ignore irrelevant unitbuf & skipws.
out.unsetf(std::ios::adjustfield | std::ios::basefield |
std::ios::floatfield | std::ios::showbase | std::ios::boolalpha |
std::ios::showpoint | std::ios::showpos | std::ios::uppercase);
bool precisionSet = false;
bool widthSet = false;
int widthExtra = 0;
const char *c = fmtStart + 1;
// 1) Parse flags
for (;; ++c) {
switch (*c) {
case '#':
out.setf(std::ios::showpoint | std::ios::showbase);
continue;
case '0':
// overridden by left alignment ('-' flag)
if (!(out.flags() & std::ios::left)) {
// Use internal padding so that numeric values are
// formatted correctly, eg -00010 rather than 000-10
out.fill('0');
out.setf(std::ios::internal, std::ios::adjustfield);
}
continue;
case '-':
out.fill(' ');
out.setf(std::ios::left, std::ios::adjustfield);
continue;
case ' ':
// overridden by show positive sign, '+' flag.
if (!(out.flags() & std::ios::showpos)) spacePadPositive = true;
continue;
case '+':
out.setf(std::ios::showpos);
spacePadPositive = false;
widthExtra = 1;
continue;
default:
break;
}
break;
}
// 2) Parse width
if (*c >= '0' && *c <= '9') {
widthSet = true;
out.width(parseIntAndAdvance(c));
}
if (*c == '*') {
widthSet = true;
int width = 0;
if (argIndex < numFormatters)
width = formatters[argIndex++].toInt();
else
TINYFORMAT_ERROR(
"tinyformat: Not enough arguments to read variable width");
if (width < 0) {
// negative widths correspond to '-' flag set
out.fill(' ');
out.setf(std::ios::left, std::ios::adjustfield);
width = -width;
}
out.width(width);
++c;
}
// 3) Parse precision
if (*c == '.') {
++c;
int precision = 0;
if (*c == '*') {
++c;
if (argIndex < numFormatters)
precision = formatters[argIndex++].toInt();
else
TINYFORMAT_ERROR(
"tinyformat: Not enough arguments to read variable precision");
} else {
if (*c >= '0' && *c <= '9')
precision = parseIntAndAdvance(c);
else if (*c == '-') // negative precisions ignored, treated as zero.
parseIntAndAdvance(++c);
}
out.precision(precision);
precisionSet = true;
}
// 4) Ignore any C99 length modifier
while (*c == 'l' || *c == 'h' || *c == 'L' || *c == 'j' || *c == 'z' ||
*c == 't')
++c;
// 5) We're up to the conversion specifier character.
// Set stream flags based on conversion specifier (thanks to the
// boost::format class for forging the way here).
bool intConversion = false;
switch (*c) {
case 'u':
case 'd':
case 'i':
out.setf(std::ios::dec, std::ios::basefield);
intConversion = true;
break;
case 'o':
out.setf(std::ios::oct, std::ios::basefield);
intConversion = true;
break;
case 'X':
out.setf(std::ios::uppercase);
case 'x':
case 'p':
out.setf(std::ios::hex, std::ios::basefield);
intConversion = true;
break;
case 'E':
out.setf(std::ios::uppercase);
case 'e':
out.setf(std::ios::scientific, std::ios::floatfield);
out.setf(std::ios::dec, std::ios::basefield);
break;
case 'F':
out.setf(std::ios::uppercase);
case 'f':
out.setf(std::ios::fixed, std::ios::floatfield);
break;
case 'G':
out.setf(std::ios::uppercase);
case 'g':
out.setf(std::ios::dec, std::ios::basefield);
// As in boost::format, let stream decide float format.
out.flags(out.flags() & ~std::ios::floatfield);
break;
case 'a':
case 'A':
TINYFORMAT_ERROR(
"tinyformat: the %a and %A conversion specs "
"are not supported");
break;
case 'c':
// Handled as special case inside formatValue()
break;
case 's':
if (precisionSet) ntrunc = static_cast<int>(out.precision());
// Make %s print booleans as "true" and "false"
out.setf(std::ios::boolalpha);
break;
case 'n':
// Not supported - will cause problems!
TINYFORMAT_ERROR("tinyformat: %n conversion spec not supported");
break;
case '\0':
TINYFORMAT_ERROR(
"tinyformat: Conversion spec incorrectly "
"terminated by end of string");
return c;
default:
break;
}
if (intConversion && precisionSet && !widthSet) {
// "precision" for integers gives the minimum number of digits (to be
// padded with zeros on the left). This isn't really supported by the
// iostreams, but we can approximately simulate it with the width if
// the width isn't otherwise used.
out.width(out.precision() + widthExtra);
out.setf(std::ios::internal, std::ios::adjustfield);
out.fill('0');
}
return c + 1;
}
//------------------------------------------------------------------------------
inline void formatImpl(std::ostream &out,
const char *fmt,
const detail::FormatArg *formatters,
int numFormatters) {
// Saved stream state
std::streamsize origWidth = out.width();
std::streamsize origPrecision = out.precision();
std::ios::fmtflags origFlags = out.flags();
char origFill = out.fill();
for (int argIndex = 0; argIndex < numFormatters; ++argIndex) {
// Parse the format string
fmt = printFormatStringLiteral(out, fmt);
bool spacePadPositive = false;
int ntrunc = -1;
const char *fmtEnd = streamStateFromFormat(out,
spacePadPositive,
ntrunc,
fmt,
formatters,
argIndex,
numFormatters);
if (argIndex >= numFormatters) {
// Check args remain after reading any variable width/precision
TINYFORMAT_ERROR("tinyformat: Not enough format arguments");
return;
}
const FormatArg &arg = formatters[argIndex];
// Format the arg into the stream.
if (!spacePadPositive)
arg.format(out, fmt, fmtEnd, ntrunc);
else {
// The following is a special case with no direct correspondence
// between stream formatting and the printf() behaviour. Simulate
// it crudely by formatting into a temporary string stream and
// munging the resulting string.
std::ostringstream tmpStream;
tmpStream.copyfmt(out);
tmpStream.setf(std::ios::showpos);
arg.format(tmpStream, fmt, fmtEnd, ntrunc);
std::string result = tmpStream.str(); // allocates... yuck.
for (size_t i = 0, iend = result.size(); i < iend; ++i)
if (result[i] == '+') result[i] = ' ';
out << result;
}
fmt = fmtEnd;
}
// Print remaining part of format string.
fmt = printFormatStringLiteral(out, fmt);
if (*fmt != '\0')
TINYFORMAT_ERROR(
"tinyformat: Too many conversion specifiers in format string");
// Restore stream state
out.width(origWidth);
out.precision(origPrecision);
out.flags(origFlags);
out.fill(origFill);
}
} // namespace detail
/// List of template arguments format(), held in a type-opaque way.
///
/// A const reference to FormatList (typedef'd as FormatListRef) may be
/// conveniently used to pass arguments to non-template functions: All type
/// information has been stripped from the arguments, leaving just enough of a
/// common interface to perform formatting as required.
class FormatList {
public:
FormatList(detail::FormatArg *formatters, int N)
: m_formatters(formatters), m_N(N) {}
friend void vformat(std::ostream &out,
const char *fmt,
const FormatList &list);
private:
const detail::FormatArg *m_formatters;
int m_N;
};
/// Reference to type-opaque format list for passing to vformat()
typedef const FormatList &FormatListRef;
namespace detail {
// Format list subclass with fixed storage to avoid dynamic allocation
template <int N>
class FormatListN : public FormatList {
public:
template <typename... Args>
FormatListN(const Args &... args)
: FormatList(&m_formatterStore[0], N),
m_formatterStore{FormatArg(args)...} {
static_assert(sizeof...(args) == N, "Number of args must be N");
}
private:
FormatArg m_formatterStore[N];
};
// Special 0-arg version - MSVC says zero-sized C array in struct is nonstandard
template <>
class FormatListN<0> : public FormatList {
public:
FormatListN() : FormatList(0, 0) {}
};
} // namespace detail
//------------------------------------------------------------------------------
// Primary API functions
/// Make type-agnostic format list from list of template arguments.
///
/// The exact return type of this function is an implementation detail and
/// shouldn't be relied upon. Instead it should be stored as a FormatListRef:
///
/// FormatListRef formatList = makeFormatList( /*...*/ );
template <typename... Args>
detail::FormatListN<sizeof...(Args)> makeFormatList(const Args &... args) {
return detail::FormatListN<sizeof...(args)>(args...);
}
/// Format list of arguments to the stream according to the given format string.
///
/// The name vformat() is chosen for the semantic similarity to vprintf(): the
/// list of format arguments is held in a single function argument.
inline void vformat(std::ostream &out, const char *fmt, FormatListRef list) {
detail::formatImpl(out, fmt, list.m_formatters, list.m_N);
}
/// Format list of arguments to the stream according to given format string.
template <typename... Args>
void format(std::ostream &out, const char *fmt, const Args &... args) {
vformat(out, fmt, makeFormatList(args...));
}
/// Format list of arguments according to the given format string and return
/// the result as a string.
template <typename... Args>
std::string format(const char *fmt, const Args &... args) {
std::ostringstream oss;
format(oss, fmt, args...);
return oss.str();
}
/// Format list of arguments to std::cout, according to the given format string
template <typename... Args>
void printf(const char *fmt, const Args &... args) {
format(std::cout, fmt, args...);
}
template <typename... Args>
void printfln(const char *fmt, const Args &... args) {
format(std::cout, fmt, args...);
std::cout << '\n';
}
} // namespace tinyformat
} // namespace string
} // namespace paddle
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
if(WITH_TESTING) if(WITH_TESTING)
add_library(paddle_test_main STATIC TestMain.cpp) add_library(paddle_test_main STATIC TestMain.cpp)
add_dependencies(paddle_test_main gen_proto_cpp) add_dependencies(paddle_test_main paddle_proto ${external_project_dependencies})
add_library(paddle_test_util STATIC TestUtil.cpp) add_library(paddle_test_util STATIC TestUtil.cpp)
add_dependencies(paddle_test_util gen_proto_cpp) add_dependencies(paddle_test_util paddle_proto ${external_project_dependencies})
endif() endif()
...@@ -41,7 +41,8 @@ add_style_check_target(paddle_trainer_lib ...@@ -41,7 +41,8 @@ add_style_check_target(paddle_trainer_lib
add_style_check_target(paddle_trainer_lib add_style_check_target(paddle_trainer_lib
${TRAINER_HEADERS}) ${TRAINER_HEADERS})
add_dependencies(paddle_trainer_lib add_dependencies(paddle_trainer_lib
gen_proto_cpp) paddle_proto
${external_project_dependencies})
macro(add_paddle_exe TARGET_NAME) macro(add_paddle_exe TARGET_NAME)
add_executable(${TARGET_NAME} ${ARGN}) add_executable(${TARGET_NAME} ${ARGN})
......
...@@ -17,7 +17,7 @@ add_library(paddle_utils STATIC ...@@ -17,7 +17,7 @@ add_library(paddle_utils STATIC
add_style_check_target(paddle_utils ${UTIL_HEADERS}) add_style_check_target(paddle_utils ${UTIL_HEADERS})
add_style_check_target(paddle_utils ${UTIL_SOURCES} add_style_check_target(paddle_utils ${UTIL_SOURCES}
${UTIL_ARCH_SOURCES}) ${UTIL_ARCH_SOURCES})
add_dependencies(paddle_utils gen_proto_cpp) add_dependencies(paddle_utils paddle_proto ${external_project_dependencies})
if(WITH_TESTING) if(WITH_TESTING)
add_subdirectory(tests) add_subdirectory(tests)
endif() endif()
set(proto_filenames file(GLOB proto_filenames . *.proto)
DataConfig.proto include_directories(${CMAKE_CURRENT_BINARY_DIR})
DataFormat.proto proto_library(paddle_proto SRCS ${proto_filenames})
ModelConfig.proto
ParameterConfig.proto
ParameterService.proto
TrainerConfig.proto
OptimizerConfig.proto
ParameterServerConfig.proto)
set(PROTO_GEN) set(PROTO_GEN)
set(PROTO_GEN_PY) set(PROTO_GEN_PY)
foreach(filename ${proto_filenames}) foreach(filename ${proto_filenames})
get_filename_component(base_filename ${filename} NAME_WE) get_filename_component(ABS_FIL ${filename} ABSOLUTE)
set(CUR_PROTO_GEN get_filename_component(FIL_WE ${filename} NAME_WE)
${CMAKE_CURRENT_BINARY_DIR}/${base_filename}.pb.h
${CMAKE_CURRENT_BINARY_DIR}/${base_filename}.pb.cc)
set(PROTO_GEN
${PROTO_GEN}
${CUR_PROTO_GEN})
add_custom_command(OUTPUT ${CUR_PROTO_GEN}
COMMAND env ${py_env} ${PROTOBUF_PROTOC_EXECUTABLE}
--cpp_out ${CMAKE_CURRENT_BINARY_DIR}
--proto_path ${PROJ_ROOT}/proto ${PROJ_ROOT}/proto/${filename}
DEPENDS ${filename} ${external_project_dependencies})
set(CUR_PROTO_GEN_PY set(CUR_PROTO_GEN_PY
${PROJ_ROOT}/paddle/python/paddle/proto/${base_filename}_pb2.py) ${PROJ_ROOT}/paddle/python/paddle/proto/${FIL_WE}_pb2.py)
set(PROTO_GEN_PY set(PROTO_GEN_PY
${CUR_PROTO_GEN_PY} ${CUR_PROTO_GEN_PY}
${PROTO_GEN_PY}) ${PROTO_GEN_PY})
add_custom_command(OUTPUT ${CUR_PROTO_GEN_PY} add_custom_command(OUTPUT ${CUR_PROTO_GEN_PY}
COMMAND env ${py_env} ${PROTOBUF_PROTOC_EXECUTABLE} --python_out ${PROJ_ROOT}/python/paddle/proto COMMAND ${PROTOBUF_PROTOC_EXECUTABLE}
--proto_path ${PROJ_ROOT}/proto ${PROJ_ROOT}/proto/${filename} ARGS "--python_out=${PROJ_ROOT}/python/paddle/proto"
DEPENDS ${filename} ${external_project_dependencies}) "-I" ${CMAKE_CURRENT_SOURCE_DIR} ${ABS_FIL}
DEPENDS ${ABS_FIL} ${external_project_dependencies})
endforeach() endforeach()
add_custom_target(gen_proto_cpp ALL DEPENDS ${PROTO_GEN})
add_custom_target(gen_proto_py ALL DEPENDS ${PROTO_GEN_PY}) add_custom_target(gen_proto_py ALL DEPENDS ${PROTO_GEN_PY})
add_library(paddle_proto STATIC ${PROTO_GEN})
target_include_directories(paddle_proto PUBLIC ${CMAKE_CURRENT_BINARY_DIR})
...@@ -1149,10 +1149,10 @@ def pooling_layer(input, ...@@ -1149,10 +1149,10 @@ def pooling_layer(input,
@layer_support(DROPOUT) @layer_support(DROPOUT)
def lstmemory(input, def lstmemory(input,
name=None, name=None,
size=None,
reverse=False, reverse=False,
act=None, act=None,
gate_act=None, gate_act=None,
size=None,
state_act=None, state_act=None,
bias_attr=None, bias_attr=None,
param_attr=None, param_attr=None,
...@@ -1194,6 +1194,8 @@ def lstmemory(input, ...@@ -1194,6 +1194,8 @@ def lstmemory(input,
:param name: The lstmemory layer name. :param name: The lstmemory layer name.
:type name: basestring :type name: basestring
:param size: DEPRECATED. size of the lstm cell
:type size: int
:param input: input layer name. :param input: input layer name.
:type input: LayerOutput :type input: LayerOutput
:param reverse: is sequence process reversed or not. :param reverse: is sequence process reversed or not.
...@@ -1220,15 +1222,15 @@ def lstmemory(input, ...@@ -1220,15 +1222,15 @@ def lstmemory(input,
assert state_act.support_hppl assert state_act.support_hppl
assert act.support_hppl assert act.support_hppl
assert input.size is not None and input.size % 4 == 0 assert input.size is not None and input.size % 4 == 0
if size is not None: if size is not None:
if input.size / 4 == size: if input.size / 4 == size:
plog = logger.warning plog = logger.warning
else: else:
plog = logger.fatal plog = logger.fatal
plog("size of lstmemory layer: %s is automatically set to "
plog("NOTE: The lstmemory layer[%s]'s size is set by previous input " "size of input layer / 4. The parameter size passing to "
"layer. The lstm size should be equal with input layer size/4. The" "this layer is ignored." % (name))
" size which is set explicitly will be ignored." % name)
Layer( Layer(
name=name, name=name,
...@@ -1255,11 +1257,11 @@ def lstmemory(input, ...@@ -1255,11 +1257,11 @@ def lstmemory(input,
@wrap_name_default("gru") @wrap_name_default("gru")
@layer_support(DROPOUT) @layer_support(DROPOUT)
def grumemory(input, def grumemory(input,
size=None,
name=None, name=None,
reverse=False, reverse=False,
act=None, act=None,
gate_act=None, gate_act=None,
size=None,
bias_attr=None, bias_attr=None,
param_attr=None, param_attr=None,
layer_attr=None): layer_attr=None):
...@@ -1318,6 +1320,8 @@ def grumemory(input, ...@@ -1318,6 +1320,8 @@ def grumemory(input,
:type name: None|basestring :type name: None|basestring
:param input: input layer. :param input: input layer.
:type input: LayerOutput. :type input: LayerOutput.
:param size: DEPRECATED. size of the gru cell
:type size: int
:param reverse: Whether sequence process is reversed or not. :param reverse: Whether sequence process is reversed or not.
:type reverse: bool :type reverse: bool
:param act: activation type, TanhActivation by default. This activation :param act: activation type, TanhActivation by default. This activation
...@@ -1334,9 +1338,6 @@ def grumemory(input, ...@@ -1334,9 +1338,6 @@ def grumemory(input,
:type param_attr: ParameterAttribute|None|False :type param_attr: ParameterAttribute|None|False
:param layer_attr: Extra Layer attribute :param layer_attr: Extra Layer attribute
:type layer_attr: ExtraLayerAttribute|None :type layer_attr: ExtraLayerAttribute|None
:param size: Stub parameter of size, but actually not used. If set this size
will get a warning.
:type size: None
:return: LayerOutput object. :return: LayerOutput object.
:rtype: LayerOutput :rtype: LayerOutput
""" """
...@@ -1348,9 +1349,9 @@ def grumemory(input, ...@@ -1348,9 +1349,9 @@ def grumemory(input,
plog = logger.warning plog = logger.warning
else: else:
plog = logger.fatal plog = logger.fatal
plog("NOTE: the gru memory layer's size is set by previous input layer," plog("size of grumemory layer: %s is automatically set to "
" and should be input size / 3. Set size explicitly will be " "size of input layer / 3. The parameter size passing to this "
"ignored.") "layer is ignored." % (name))
Layer( Layer(
name=name, name=name,
...@@ -2524,8 +2525,8 @@ def img_cmrnorm_layer(input, ...@@ -2524,8 +2525,8 @@ def img_cmrnorm_layer(input,
@wrap_bias_attr_default() @wrap_bias_attr_default()
@wrap_param_attr_default(default_factory=lambda _: ParamAttr(initial_mean=1.0, @wrap_param_attr_default(
initial_std=0.)) default_factory=lambda _: ParamAttr(initial_mean=1.0, initial_std=0.))
@wrap_act_default(act=ReluActivation()) @wrap_act_default(act=ReluActivation())
@wrap_name_default("batch_norm") @wrap_name_default("batch_norm")
@layer_support(DROPOUT) @layer_support(DROPOUT)
...@@ -3013,25 +3014,25 @@ def lstm_step_layer(input, ...@@ -3013,25 +3014,25 @@ def lstm_step_layer(input,
bias_attr=None, bias_attr=None,
layer_attr=None): layer_attr=None):
""" """
LSTM Step Layer. It used in recurrent_group. The lstm equations are shown LSTM Step Layer. This function is used only in recurrent_group.
as follow. The lstm equations are shown as follows.
.. math:: .. math::
i_t & = \\sigma(W_{xi}x_{t} + W_{hi}h_{t-1} + W_{ci}c_{t-1} + b_i) i_t & = \\sigma(W_{x_i}x_{t} + W_{h_i}h_{t-1} + W_{c_i}c_{t-1} + b_i)
f_t & = \\sigma(W_{xf}x_{t} + W_{hf}h_{t-1} + W_{cf}c_{t-1} + b_f) f_t & = \\sigma(W_{x_f}x_{t} + W_{h_f}h_{t-1} + W_{c_f}c_{t-1} + b_f)
c_t & = f_tc_{t-1} + i_t tanh (W_{xc}x_t+W_{hc}h_{t-1} + b_c) c_t & = f_tc_{t-1} + i_t tanh (W_{x_c}x_t+W_{h_c}h_{t-1} + b_c)
o_t & = \\sigma(W_{xo}x_{t} + W_{ho}h_{t-1} + W_{co}c_t + b_o) o_t & = \\sigma(W_{x_o}x_{t} + W_{h_o}h_{t-1} + W_{c_o}c_t + b_o)
h_t & = o_t tanh(c_t) h_t & = o_t tanh(c_t)
The input of lstm step is :math:`Wx_t + Wh_{t-1}`, and user should use The input of lstm step is :math:`Wx_t + Wh_{t-1}`, and user should use
:code:`mixed_layer` and :code:`full_matrix_projection` to calculate these :code:`mixed_layer` and :code:`full_matrix_projection` to calculate these
input vector. input vectors.
The state of lstm step is :math:`c_{t-1}`. And lstm step layer will do The state of lstm step is :math:`c_{t-1}`. And lstm step layer will do
...@@ -3042,14 +3043,14 @@ def lstm_step_layer(input, ...@@ -3042,14 +3043,14 @@ def lstm_step_layer(input,
... ...
This layer contains two outputs. Default output is :math:`h_t`. The other This layer has two outputs. Default output is :math:`h_t`. The other
output is :math:`o_t`, which name is 'state' and can use output is :math:`o_t`, whose name is 'state' and can use
:code:`get_output_layer` to extract this output. :code:`get_output_layer` to extract this output.
:param name: Layer's name. :param name: Layer's name.
:type name: basestring :type name: basestring
:param size: Layer's size. NOTE: lstm layer's size, should be equal as :param size: Layer's size. NOTE: lstm layer's size, should be equal to
:code:`input.size/4`, and should be equal as :code:`input.size/4`, and should be equal to
:code:`state.size`. :code:`state.size`.
:type size: int :type size: int
:param input: input layer. :math:`Wx_t + Wh_{t-1}` :param input: input layer. :math:`Wx_t + Wh_{t-1}`
......
...@@ -614,6 +614,7 @@ def simple_lstm(input, ...@@ -614,6 +614,7 @@ def simple_lstm(input,
@wrap_name_default('lstm_unit') @wrap_name_default('lstm_unit')
def lstmemory_unit(input, def lstmemory_unit(input,
memory_boot=None,
name=None, name=None,
size=None, size=None,
param_attr=None, param_attr=None,
...@@ -626,9 +627,9 @@ def lstmemory_unit(input, ...@@ -626,9 +627,9 @@ def lstmemory_unit(input,
lstm_layer_attr=None, lstm_layer_attr=None,
get_output_layer_attr=None): get_output_layer_attr=None):
""" """
Define calculations that a LSTM unit performs in a single time step. Define calculations that a LSTM unit performs during a single time step.
This function itself is not a recurrent layer, so that it can not be This function itself is not a recurrent layer, so it can not be
directly applied to sequence input. This function is always used in directly used to process sequence inputs. This function is always used in
recurrent_group (see layers.py for more details) to implement attention recurrent_group (see layers.py for more details) to implement attention
mechanism. mechanism.
...@@ -638,13 +639,13 @@ def lstmemory_unit(input, ...@@ -638,13 +639,13 @@ def lstmemory_unit(input,
.. math:: .. math::
i_t & = \\sigma(W_{xi}x_{t} + W_{hi}h_{t-1} + W_{ci}c_{t-1} + b_i) i_t & = \\sigma(W_{x_i}x_{t} + W_{h_i}h_{t-1} + W_{c_i}c_{t-1} + b_i)
f_t & = \\sigma(W_{xf}x_{t} + W_{hf}h_{t-1} + W_{cf}c_{t-1} + b_f) f_t & = \\sigma(W_{x_f}x_{t} + W_{h_f}h_{t-1} + W_{c_f}c_{t-1} + b_f)
c_t & = f_tc_{t-1} + i_t tanh (W_{xc}x_t+W_{hc}h_{t-1} + b_c) c_t & = f_tc_{t-1} + i_t tanh (W_{x_c}x_t+W_{h_c}h_{t-1} + b_c)
o_t & = \\sigma(W_{xo}x_{t} + W_{ho}h_{t-1} + W_{co}c_t + b_o) o_t & = \\sigma(W_{x_o}x_{t} + W_{h_o}h_{t-1} + W_{c_o}c_t + b_o)
h_t & = o_t tanh(c_t) h_t & = o_t tanh(c_t)
...@@ -661,6 +662,8 @@ def lstmemory_unit(input, ...@@ -661,6 +662,8 @@ def lstmemory_unit(input,
:param input: input layer name. :param input: input layer name.
:type input: LayerOutput :type input: LayerOutput
:param memory_boot: the initialization state of the LSTM cell.
:type memory_boot: LayerOutput | None
:param name: lstmemory unit name. :param name: lstmemory unit name.
:type name: basestring :type name: basestring
:param size: lstmemory unit size. :param size: lstmemory unit size.
...@@ -692,7 +695,8 @@ def lstmemory_unit(input, ...@@ -692,7 +695,8 @@ def lstmemory_unit(input,
assert input.size % 4 == 0 assert input.size % 4 == 0
size = input.size / 4 size = input.size / 4
out_mem = memory(name=name, size=size) out_mem = memory(name=name, size=size)
state_mem = memory(name="%s_state" % name, size=size) state_mem = memory(
name="%s_state" % name, size=size, boot_layer=memory_boot)
with mixed_layer( with mixed_layer(
name="%s_input_recurrent" % name, name="%s_input_recurrent" % name,
...@@ -726,6 +730,7 @@ def lstmemory_unit(input, ...@@ -726,6 +730,7 @@ def lstmemory_unit(input,
def lstmemory_group(input, def lstmemory_group(input,
size=None, size=None,
name=None, name=None,
memory_boot=None,
reverse=False, reverse=False,
param_attr=None, param_attr=None,
act=None, act=None,
...@@ -737,7 +742,7 @@ def lstmemory_group(input, ...@@ -737,7 +742,7 @@ def lstmemory_group(input,
lstm_layer_attr=None, lstm_layer_attr=None,
get_output_layer_attr=None): get_output_layer_attr=None):
""" """
lstm_group is a recurrent layer group version of Long Short Term Memory. It lstm_group is a recurrent_group version of Long Short Term Memory. It
does exactly the same calculation as the lstmemory layer (see lstmemory in does exactly the same calculation as the lstmemory layer (see lstmemory in
layers.py for the maths) does. A promising benefit is that LSTM memory layers.py for the maths) does. A promising benefit is that LSTM memory
cell states, or hidden states in every time step are accessible to the cell states, or hidden states in every time step are accessible to the
...@@ -748,8 +753,8 @@ def lstmemory_group(input, ...@@ -748,8 +753,8 @@ def lstmemory_group(input,
NOTE: In PaddlePaddle's implementation, the following input-to-hidden NOTE: In PaddlePaddle's implementation, the following input-to-hidden
multiplications: multiplications:
:math:`W_{xi}x_{t}` , :math:`W_{xf}x_{t}`, :math:`W_{x_i}x_{t}` , :math:`W_{x_f}x_{t}`,
:math:`W_{xc}x_t`, :math:`W_{xo}x_{t}` are not done in lstmemory_unit to :math:`W_{x_c}x_t`, :math:`W_{x_o}x_{t}` are not done in lstmemory_unit to
speed up the calculations. Consequently, an additional mixed_layer with speed up the calculations. Consequently, an additional mixed_layer with
full_matrix_projection must be included before lstmemory_unit is called. full_matrix_projection must be included before lstmemory_unit is called.
...@@ -765,10 +770,12 @@ def lstmemory_group(input, ...@@ -765,10 +770,12 @@ def lstmemory_group(input,
:param input: input layer name. :param input: input layer name.
:type input: LayerOutput :type input: LayerOutput
:param name: lstmemory group name.
:type name: basestring
:param size: lstmemory group size. :param size: lstmemory group size.
:type size: int :type size: int
:param name: name of the lstmemory group.
:type name: basestring
:param memory_boot: the initialization state of LSTM cell.
:type memory_boot: LayerOutput | None
:param reverse: is lstm reversed :param reverse: is lstm reversed
:type reverse: bool :type reverse: bool
:param param_attr: Parameter config, None if use default. :param param_attr: Parameter config, None if use default.
...@@ -798,6 +805,7 @@ def lstmemory_group(input, ...@@ -798,6 +805,7 @@ def lstmemory_group(input,
def __lstm_step__(ipt): def __lstm_step__(ipt):
return lstmemory_unit( return lstmemory_unit(
input=ipt, input=ipt,
memory_boot=memory_boot,
name=name, name=name,
size=size, size=size,
mixed_bias_attr=mixed_bias_attr, mixed_bias_attr=mixed_bias_attr,
...@@ -819,6 +827,7 @@ def lstmemory_group(input, ...@@ -819,6 +827,7 @@ def lstmemory_group(input,
@wrap_name_default('gru_unit') @wrap_name_default('gru_unit')
def gru_unit(input, def gru_unit(input,
memory_boot=None,
size=None, size=None,
name=None, name=None,
gru_bias_attr=None, gru_bias_attr=None,
...@@ -829,8 +838,8 @@ def gru_unit(input, ...@@ -829,8 +838,8 @@ def gru_unit(input,
naive=False): naive=False):
""" """
Define calculations that a gated recurrent unit performs in a single time Define calculations that a gated recurrent unit performs in a single time
step. This function itself is not a recurrent layer, so that it can not be step. This function itself is not a recurrent layer, so it can not be
directly applied to sequence input. This function is almost always used in directly used to process sequence inputs. This function is always used in
the recurrent_group (see layers.py for more details) to implement attention the recurrent_group (see layers.py for more details) to implement attention
mechanism. mechanism.
...@@ -838,6 +847,8 @@ def gru_unit(input, ...@@ -838,6 +847,8 @@ def gru_unit(input,
:param input: input layer name. :param input: input layer name.
:type input: LayerOutput :type input: LayerOutput
:param memory_boot: the initialization state of the LSTM cell.
:type memory_boot: LayerOutput | None
:param name: name of the gru group. :param name: name of the gru group.
:type name: basestring :type name: basestring
:param size: hidden size of the gru. :param size: hidden size of the gru.
...@@ -856,7 +867,7 @@ def gru_unit(input, ...@@ -856,7 +867,7 @@ def gru_unit(input,
if size is None: if size is None:
size = input.size / 3 size = input.size / 3
out_mem = memory(name=name, size=size) out_mem = memory(name=name, size=size, boot_layer=memory_boot)
if naive: if naive:
__step__ = gru_step_naive_layer __step__ = gru_step_naive_layer
...@@ -878,6 +889,7 @@ def gru_unit(input, ...@@ -878,6 +889,7 @@ def gru_unit(input,
@wrap_name_default('gru_group') @wrap_name_default('gru_group')
def gru_group(input, def gru_group(input,
memory_boot=None,
size=None, size=None,
name=None, name=None,
reverse=False, reverse=False,
...@@ -888,7 +900,7 @@ def gru_group(input, ...@@ -888,7 +900,7 @@ def gru_group(input,
gru_layer_attr=None, gru_layer_attr=None,
naive=False): naive=False):
""" """
gru_group is a recurrent layer group version of Gated Recurrent Unit. It gru_group is a recurrent_group version of Gated Recurrent Unit. It
does exactly the same calculation as the grumemory layer does. A promising does exactly the same calculation as the grumemory layer does. A promising
benefit is that gru hidden states are accessible to the user. This is benefit is that gru hidden states are accessible to the user. This is
especially useful in attention model. If you do not need to access especially useful in attention model. If you do not need to access
...@@ -908,6 +920,8 @@ def gru_group(input, ...@@ -908,6 +920,8 @@ def gru_group(input,
:param input: input layer name. :param input: input layer name.
:type input: LayerOutput :type input: LayerOutput
:param memory_boot: the initialization state of the LSTM cell.
:type memory_boot: LayerOutput | None
:param name: name of the gru group. :param name: name of the gru group.
:type name: basestring :type name: basestring
:param size: hidden size of the gru. :param size: hidden size of the gru.
...@@ -929,6 +943,7 @@ def gru_group(input, ...@@ -929,6 +943,7 @@ def gru_group(input,
def __gru_step__(ipt): def __gru_step__(ipt):
return gru_unit( return gru_unit(
input=ipt, input=ipt,
memory_boot=memory_boot,
name=name, name=name,
size=size, size=size,
gru_bias_attr=gru_bias_attr, gru_bias_attr=gru_bias_attr,
...@@ -1083,7 +1098,6 @@ def simple_gru2(input, ...@@ -1083,7 +1098,6 @@ def simple_gru2(input,
return grumemory( return grumemory(
name=name, name=name,
size=size,
input=m, input=m,
reverse=reverse, reverse=reverse,
bias_attr=gru_bias_attr, bias_attr=gru_bias_attr,
......
...@@ -25,8 +25,9 @@ import uci_housing ...@@ -25,8 +25,9 @@ import uci_housing
import sentiment import sentiment
import wmt14 import wmt14
import mq2007 import mq2007
import flowers
__all__ = [ __all__ = [
'mnist', 'imikolov', 'imdb', 'cifar', 'movielens', 'conll05', 'sentiment' 'mnist', 'imikolov', 'imdb', 'cifar', 'movielens', 'conll05', 'sentiment'
'uci_housing', 'wmt14', 'mq2007' 'uci_housing', 'wmt14', 'mq2007', 'flowers'
] ]
...@@ -34,9 +34,9 @@ from common import download ...@@ -34,9 +34,9 @@ from common import download
import tarfile import tarfile
import scipy.io as scio import scipy.io as scio
from paddle.v2.image import * from paddle.v2.image import *
from paddle.v2.reader import *
import os import os
import numpy as np import numpy as np
import paddle.v2 as paddle
from multiprocessing import cpu_count from multiprocessing import cpu_count
__all__ = ['train', 'test', 'valid'] __all__ = ['train', 'test', 'valid']
...@@ -46,6 +46,12 @@ SETID_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/setid.mat' ...@@ -46,6 +46,12 @@ SETID_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/setid.mat'
DATA_MD5 = '52808999861908f626f3c1f4e79d11fa' DATA_MD5 = '52808999861908f626f3c1f4e79d11fa'
LABEL_MD5 = 'e0620be6f572b9609742df49c70aed4d' LABEL_MD5 = 'e0620be6f572b9609742df49c70aed4d'
SETID_MD5 = 'a5357ecc9cb78c4bef273ce3793fc85c' SETID_MD5 = 'a5357ecc9cb78c4bef273ce3793fc85c'
# In official 'readme', tstid is the flag of test data
# and trnid is the flag of train data. But test data is more than train data.
# So we exchange the train data and test data.
TRAIN_FLAG = 'tstid'
TEST_FLAG = 'trnid'
VALID_FLAG = 'valid'
def default_mapper(sample): def default_mapper(sample):
...@@ -53,8 +59,8 @@ def default_mapper(sample): ...@@ -53,8 +59,8 @@ def default_mapper(sample):
map image bytes data to type needed by model input layer map image bytes data to type needed by model input layer
''' '''
img, label = sample img, label = sample
img = paddle.image.load_image_bytes(img) img = load_image_bytes(img)
img = paddle.image.simple_transform(img, 256, 224, True) img = simple_transform(img, 256, 224, True)
return img.flatten().astype('float32'), label return img.flatten().astype('float32'), label
...@@ -63,7 +69,8 @@ def reader_creator(data_file, ...@@ -63,7 +69,8 @@ def reader_creator(data_file,
setid_file, setid_file,
dataset_name, dataset_name,
mapper=default_mapper, mapper=default_mapper,
buffered_size=1024): buffered_size=1024,
use_xmap=True):
''' '''
1. read images from tar file and 1. read images from tar file and
merge images into batch files in 102flowers.tgz_batch/ merge images into batch files in 102flowers.tgz_batch/
...@@ -105,11 +112,13 @@ def reader_creator(data_file, ...@@ -105,11 +112,13 @@ def reader_creator(data_file,
for sample, label in itertools.izip(data, batch['label']): for sample, label in itertools.izip(data, batch['label']):
yield sample, int(label) yield sample, int(label)
return paddle.reader.xmap_readers(mapper, reader, if use_xmap:
cpu_count(), buffered_size) return xmap_readers(mapper, reader, cpu_count(), buffered_size)
else:
return map_readers(mapper, reader)
def train(mapper=default_mapper, buffered_size=1024): def train(mapper=default_mapper, buffered_size=1024, use_xmap=True):
''' '''
Create flowers training set reader. Create flowers training set reader.
It returns a reader, each sample in the reader is It returns a reader, each sample in the reader is
...@@ -128,11 +137,11 @@ def train(mapper=default_mapper, buffered_size=1024): ...@@ -128,11 +137,11 @@ def train(mapper=default_mapper, buffered_size=1024):
return reader_creator( return reader_creator(
download(DATA_URL, 'flowers', DATA_MD5), download(DATA_URL, 'flowers', DATA_MD5),
download(LABEL_URL, 'flowers', LABEL_MD5), download(LABEL_URL, 'flowers', LABEL_MD5),
download(SETID_URL, 'flowers', SETID_MD5), 'trnid', mapper, download(SETID_URL, 'flowers', SETID_MD5), TRAIN_FLAG, mapper,
buffered_size) buffered_size, use_xmap)
def test(mapper=default_mapper, buffered_size=1024): def test(mapper=default_mapper, buffered_size=1024, use_xmap=True):
''' '''
Create flowers test set reader. Create flowers test set reader.
It returns a reader, each sample in the reader is It returns a reader, each sample in the reader is
...@@ -151,11 +160,11 @@ def test(mapper=default_mapper, buffered_size=1024): ...@@ -151,11 +160,11 @@ def test(mapper=default_mapper, buffered_size=1024):
return reader_creator( return reader_creator(
download(DATA_URL, 'flowers', DATA_MD5), download(DATA_URL, 'flowers', DATA_MD5),
download(LABEL_URL, 'flowers', LABEL_MD5), download(LABEL_URL, 'flowers', LABEL_MD5),
download(SETID_URL, 'flowers', SETID_MD5), 'tstid', mapper, download(SETID_URL, 'flowers', SETID_MD5), TEST_FLAG, mapper,
buffered_size) buffered_size, use_xmap)
def valid(mapper=default_mapper, buffered_size=1024): def valid(mapper=default_mapper, buffered_size=1024, use_xmap=True):
''' '''
Create flowers validation set reader. Create flowers validation set reader.
It returns a reader, each sample in the reader is It returns a reader, each sample in the reader is
...@@ -174,8 +183,8 @@ def valid(mapper=default_mapper, buffered_size=1024): ...@@ -174,8 +183,8 @@ def valid(mapper=default_mapper, buffered_size=1024):
return reader_creator( return reader_creator(
download(DATA_URL, 'flowers', DATA_MD5), download(DATA_URL, 'flowers', DATA_MD5),
download(LABEL_URL, 'flowers', LABEL_MD5), download(LABEL_URL, 'flowers', LABEL_MD5),
download(SETID_URL, 'flowers', SETID_MD5), 'valid', mapper, download(SETID_URL, 'flowers', SETID_MD5), VALID_FLAG, mapper,
buffered_size) buffered_size, use_xmap)
def fetch(): def fetch():
......
...@@ -31,13 +31,13 @@ class TestFlowers(unittest.TestCase): ...@@ -31,13 +31,13 @@ class TestFlowers(unittest.TestCase):
def test_train(self): def test_train(self):
instances, max_label_value = self.check_reader( instances, max_label_value = self.check_reader(
paddle.v2.dataset.flowers.train()) paddle.v2.dataset.flowers.train())
self.assertEqual(instances, 1020) self.assertEqual(instances, 6149)
self.assertEqual(max_label_value, 102) self.assertEqual(max_label_value, 102)
def test_test(self): def test_test(self):
instances, max_label_value = self.check_reader( instances, max_label_value = self.check_reader(
paddle.v2.dataset.flowers.test()) paddle.v2.dataset.flowers.test())
self.assertEqual(instances, 6149) self.assertEqual(instances, 1020)
self.assertEqual(max_label_value, 102) self.assertEqual(max_label_value, 102)
def test_valid(self): def test_valid(self):
......
...@@ -51,7 +51,7 @@ class Parameters(object): ...@@ -51,7 +51,7 @@ class Parameters(object):
def __init__(self): def __init__(self):
self.__param_conf__ = dict() self.__param_conf__ = dict()
self.__gradient_machines__ = [] self.__gradient_machines__ = []
self.__tmp_params__ = [] self.__tmp_params__ = dict()
def __append_config__(self, param_conf): def __append_config__(self, param_conf):
""" """
...@@ -128,12 +128,9 @@ class Parameters(object): ...@@ -128,12 +128,9 @@ class Parameters(object):
if len(self.__gradient_machines__) == 0: if len(self.__gradient_machines__) == 0:
# create new parameter in python numpy. # create new parameter in python numpy.
if len(self.__tmp_params__) != 0: if key in self.__tmp_params__:
ret_list = [ return self.__tmp_params__[key]
mat for name, mat in self.__tmp_params__ if name == key else:
]
if len(ret_list) == 1:
return ret_list[0]
return np.ndarray(shape=shape, dtype=np.float32) return np.ndarray(shape=shape, dtype=np.float32)
else: else:
for each_gradient_machine in self.__gradient_machines__: for each_gradient_machine in self.__gradient_machines__:
...@@ -187,7 +184,7 @@ class Parameters(object): ...@@ -187,7 +184,7 @@ class Parameters(object):
(shape, value.shape)) (shape, value.shape))
if len(self.__gradient_machines__) == 0: if len(self.__gradient_machines__) == 0:
self.__tmp_params__.append((key, value)) self.__tmp_params__[key] = value
else: else:
for each_gradient_machine in self.__gradient_machines__: for each_gradient_machine in self.__gradient_machines__:
__copy_parameter_to_gradient_machine__(each_gradient_machine, __copy_parameter_to_gradient_machine__(each_gradient_machine,
...@@ -231,7 +228,7 @@ class Parameters(object): ...@@ -231,7 +228,7 @@ class Parameters(object):
raise ValueError("gradient_machine should be api.GradientMachine") raise ValueError("gradient_machine should be api.GradientMachine")
if len(self.__tmp_params__) != 0: if len(self.__tmp_params__) != 0:
for name, val in self.__tmp_params__: for name, val in self.__tmp_params__.iteritems():
try: try:
__copy_parameter_to_gradient_machine__(gradient_machine, __copy_parameter_to_gradient_machine__(gradient_machine,
name, val) name, val)
...@@ -287,6 +284,18 @@ class Parameters(object): ...@@ -287,6 +284,18 @@ class Parameters(object):
@staticmethod @staticmethod
def from_tar(f): def from_tar(f):
"""
Create a `Parameters` object from the given file. And
the `Parameters` only contains the parameters in this
file. It is adapted the parameters are same in the
defined network and the given file. For example, it
can be used in the inference.
:param f: the initialized model file.
:type f: tar file
:return: A Parameters object.
:rtype: Parameters.
"""
params = Parameters() params = Parameters()
tar = tarfile.TarFile(fileobj=f, mode='r') tar = tarfile.TarFile(fileobj=f, mode='r')
for finfo in tar: for finfo in tar:
...@@ -302,6 +311,21 @@ class Parameters(object): ...@@ -302,6 +311,21 @@ class Parameters(object):
params.deserialize(param_name, f) params.deserialize(param_name, f)
return params return params
def init_from_tar(self, f):
"""
Different from `from_tar`, this interface can be used to
init partial network parameters from another saved model.
:param f: the initialized model file.
:type f: tar file
:return: Nothing.
"""
tar_param = Parameters.from_tar(f)
for pname in tar_param.names():
if pname in self.names():
self.set(pname, tar_param.get(pname))
def __get_parameter_in_gradient_machine__(gradient_machine, name): def __get_parameter_in_gradient_machine__(gradient_machine, name):
""" """
......
...@@ -248,9 +248,6 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False): ...@@ -248,9 +248,6 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False):
:rtype: callable :rtype: callable
""" """
end = XmapEndSignal() end = XmapEndSignal()
in_queue = Queue(buffer_size)
out_queue = Queue(buffer_size)
out_order = [0]
# define a worker to read samples from reader to in_queue # define a worker to read samples from reader to in_queue
def read_worker(reader, in_queue): def read_worker(reader, in_queue):
...@@ -266,12 +263,6 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False): ...@@ -266,12 +263,6 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False):
in_order += 1 in_order += 1
in_queue.put(end) in_queue.put(end)
# start a read worker in a thread
target = order_read_worker if order else read_worker
t = Thread(target=target, args=(reader, in_queue))
t.daemon = True
t.start()
# define a worker to handle samples from in_queue by mapper # define a worker to handle samples from in_queue by mapper
# and put mapped samples into out_queue # and put mapped samples into out_queue
def handle_worker(in_queue, out_queue, mapper): def handle_worker(in_queue, out_queue, mapper):
...@@ -298,6 +289,15 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False): ...@@ -298,6 +289,15 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False):
in_queue.put(end) in_queue.put(end)
out_queue.put(end) out_queue.put(end)
def xreader():
in_queue = Queue(buffer_size)
out_queue = Queue(buffer_size)
out_order = [0]
# start a read worker in a thread
target = order_read_worker if order else read_worker
t = Thread(target=target, args=(reader, in_queue))
t.daemon = True
t.start()
# start several handle_workers # start several handle_workers
target = order_handle_worker if order else handle_worker target = order_handle_worker if order else handle_worker
args = (in_queue, out_queue, mapper, out_order) if order else ( args = (in_queue, out_queue, mapper, out_order) if order else (
...@@ -310,7 +310,6 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False): ...@@ -310,7 +310,6 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False):
for w in workers: for w in workers:
w.start() w.start()
def xreader():
sample = out_queue.get() sample = out_queue.get()
while not isinstance(sample, XmapEndSignal): while not isinstance(sample, XmapEndSignal):
yield sample yield sample
......
...@@ -132,10 +132,12 @@ class TestXmap(unittest.TestCase): ...@@ -132,10 +132,12 @@ class TestXmap(unittest.TestCase):
for order in orders: for order in orders:
for tNum in thread_nums: for tNum in thread_nums:
for size in buffered_size: for size in buffered_size:
result = [] reader = paddle.v2.reader.xmap_readers(mapper,
for i in paddle.v2.reader.xmap_readers(mapper,
reader_creator_10(0), reader_creator_10(0),
tNum, size, order)(): tNum, size, order)
for n in xrange(3):
result = []
for i in reader():
result.append(i) result.append(i)
if not order: if not order:
result.sort() result.sort()
......
...@@ -20,14 +20,17 @@ import cStringIO ...@@ -20,14 +20,17 @@ import cStringIO
import numpy import numpy
def __rand_param_config__(name): def __rand_param_config__(name, psize=None):
conf = ParameterConfig() conf = ParameterConfig()
conf.name = name conf.name = name
size = 1 size = 1
if psize is None:
for i in xrange(2): for i in xrange(2):
dim = random.randint(1, 1000) dim = random.randint(1, 1000)
conf.dims.append(dim) conf.dims.append(dim)
size *= dim size *= dim
else:
size = psize
conf.size = size conf.size = size
assert conf.IsInitialized() assert conf.IsInitialized()
return conf return conf
...@@ -77,6 +80,50 @@ class TestParameters(unittest.TestCase): ...@@ -77,6 +80,50 @@ class TestParameters(unittest.TestCase):
expected = numpy.array([[1, 1], [1, 2], [1, 1]], numpy.float32) expected = numpy.array([[1, 1], [1, 2], [1, 1]], numpy.float32)
assert numpy.logical_and.reduce(numpy.reshape(val == expected, 6)) assert numpy.logical_and.reduce(numpy.reshape(val == expected, 6))
def test_init_from_tar(self):
def get_param(names, size):
p = parameters.Parameters()
for k, v in zip(names, size):
p.__append_config__(__rand_param_config__(k, v))
for name in p.names():
param = p.get(name)
param[:] = numpy.random.uniform(
-1.0, 1.0, size=p.get_shape(name))
p.set(name, param)
return p
def get_parames():
name1 = ['param_0', 'param_1']
size1 = [128, 256]
p1 = get_param(name1, size1)
file1 = cStringIO.StringIO()
p1.to_tar(file1)
file1.seek(0)
name2 = ['param_0', 'param_1', 'param_2']
size2 = [128, 256, 288]
p2 = get_param(name2, size2)
file2 = cStringIO.StringIO()
p2.to_tar(file2)
file2.seek(0)
return p1, file1, p2, file2
p1, file1, p2, file2 = get_parames()
p2.init_from_tar(file1)
for name in p1.names():
self.assertEqual(p1.get_shape(name), p2.get_shape(name))
v1 = p1.get(name)
v2 = p2.get(name)
self.assertTrue(numpy.isclose(v1, v2).all())
p1, file1, p2, file2 = get_parames()
p1.init_from_tar(file2)
for name in p1.names():
self.assertEqual(p1.get_shape(name), p2.get_shape(name))
v1 = p1.get(name)
v2 = p2.get(name)
self.assertTrue(numpy.isclose(v1, v2).all())
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -15,7 +15,8 @@ setup_requires=["requests", ...@@ -15,7 +15,8 @@ setup_requires=["requests",
"protobuf==3.1", "protobuf==3.1",
"recordio", "recordio",
"matplotlib", "matplotlib",
"rarfile"] "rarfile",
"scipy>=0.19.0"]
if '${CMAKE_SYSTEM_PROCESSOR}' not in ['arm', 'armv7-a', 'aarch64']: if '${CMAKE_SYSTEM_PROCESSOR}' not in ['arm', 'armv7-a', 'aarch64']:
setup_requires+=["opencv-python"] setup_requires+=["opencv-python"]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册