提交 c3705602 编写于 作者: L Liu Yiqun

Merge branch 'develop' into support_scalar_kernels

......@@ -70,7 +70,7 @@ before looking into the
We provide [English](http://www.paddlepaddle.org/develop/doc/) and
[Chinese](http://www.paddlepaddle.org/doc_cn/) documentation.
- [Deep Learning 101](http://book.paddlepaddle.org/index.en.html)
- [Deep Learning 101](http://book.paddlepaddle.org/index.html)
You might want to start from the this online interactive book that can run in Jupyter Notebook.
......
......@@ -11,11 +11,23 @@ find_path(CUDNN_INCLUDE_DIR cudnn.h
get_filename_component(__libpath_hist ${CUDA_CUDART_LIBRARY} PATH)
if(NOT ${CMAKE_HOST_SYSTEM_PROCESSOR})
execute_process(
COMMAND uname -m COMMAND tr -d '\n'
OUTPUT_VARIABLE HOST_ARCH
RESULT_VARIABLE UNAME_RESULT)
if(${UNAME_RESULT})
set(HOST_ARCH "x86_64")
endif(${UNAME_RESULT})
else(NOT ${CMAKE_HOST_SYSTEM_PROCESSOR})
set(HOST_ARCH ${CMAKE_HOST_SYSTEM_PROCESSOR})
endif(NOT ${CMAKE_HOST_SYSTEM_PROCESSOR})
list(APPEND CUDNN_CHECK_LIBRARY_DIRS
${CUDNN_ROOT}
${CUDNN_ROOT}/lib64
${CUDNN_ROOT}/lib
${CUDNN_ROOT}/lib/x86_64-linux-gnu
${CUDNN_ROOT}/lib/${HOST_ARCH}-linux-gnu
$ENV{CUDNN_ROOT}
$ENV{CUDNN_ROOT}/lib64
$ENV{CUDNN_ROOT}/lib
......
......@@ -37,7 +37,7 @@ IF(NOT ${CBLAS_FOUND})
SET(OPTIONAL_ARGS HOSTCC=${HOST_C_COMPILER} TARGET=ARMV7 USE_THREAD=0 libs)
ELSE()
SET(OPENBLAS_COMMIT "v0.2.19")
SET(OPENBLAS_ARGS DYNAMIC_ARCH=1 libs)
SET(OPTIONAL_ARGS DYNAMIC_ARCH=1 libs NUM_THREADS=64)
ENDIF()
ExternalProject_Add(
......
......@@ -136,6 +136,9 @@ int paddle_send_grads(paddle_pserver_client* client, const paddle_gradient* grad
/**
* @brief paddle_get_params gets parameters from parameter servers.
*
* paddle_get_params will block until parameters are initialized on
* the parameter servers.
*
* @param names the array of names of the parameters to get.
* @param dst the destination array of parameters to save to.
* @param len the length of the names array and the paddle_parameter
......
......@@ -38,7 +38,7 @@ endif()
mark_as_advanced(CMAKE_Go_COMPILER)
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/cmake/CMakeGoCompiler.cmake.in
configure_file(${CMAKE_MODULE_PATH}/CMakeGoCompiler.cmake.in
${CMAKE_PLATFORM_INFO_DIR}/CMakeGoCompiler.cmake @ONLY)
set(CMAKE_Go_COMPILER_ENV_VAR "GO_COMPILER")
set(GOPATH "${CMAKE_CURRENT_BINARY_DIR}/go")
file(MAKE_DIRECTORY ${GOPATH})
set(PADDLE_IN_GOPATH "${GOPATH}/src/github.com/PaddlePaddle")
file(MAKE_DIRECTORY ${PADDLE_IN_GOPATH})
function(ExternalGoProject_Add TARG)
add_custom_target(${TARG} env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} get ${ARGN})
endfunction(ExternalGoProject_Add)
function(add_go_executable NAME)
file(GLOB GO_SOURCE RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.go")
add_custom_command(OUTPUT ${OUTPUT_DIR}/.timestamp
COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build
-o "${CMAKE_CURRENT_BINARY_DIR}/${NAME}"
${CMAKE_GO_FLAGS} ${GO_SOURCE}
WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR})
add_custom_target(${NAME} ALL DEPENDS ${OUTPUT_DIR}/.timestamp ${ARGN})
install(PROGRAMS ${CMAKE_CURRENT_BINARY_DIR}/${NAME} DESTINATION bin)
endfunction(add_go_executable)
function(ADD_GO_LIBRARY NAME BUILD_TYPE)
function(GO_LIBRARY NAME BUILD_TYPE)
if(BUILD_TYPE STREQUAL "STATIC")
set(BUILD_MODE -buildmode=c-archive)
set(LIB_NAME "lib${NAME}.a")
......@@ -32,6 +17,24 @@ function(ADD_GO_LIBRARY NAME BUILD_TYPE)
endif()
file(GLOB GO_SOURCE RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.go")
file(RELATIVE_PATH rel ${CMAKE_BINARY_DIR} ${CMAKE_CURRENT_SOURCE_DIR})
# find Paddle directory.
get_filename_component(PARENT_DIR ${CMAKE_CURRENT_SOURCE_DIR} DIRECTORY)
get_filename_component(PARENT_DIR ${PARENT_DIR} DIRECTORY)
get_filename_component(PADDLE_DIR ${PARENT_DIR} DIRECTORY)
# automatically get all dependencies specified in the source code
# for given target.
add_custom_target(goGet env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} get -d ${rel}/...)
# make a symlink that references Paddle inside $GOPATH, so go get
# will use the local changes in Paddle rather than checkout Paddle
# in github.
add_custom_target(copyPaddle
COMMAND ln -sf ${PADDLE_DIR} ${PADDLE_IN_GOPATH})
add_dependencies(goGet copyPaddle)
add_custom_command(OUTPUT ${OUTPUT_DIR}/.timestamp
COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build ${BUILD_MODE}
-o "${CMAKE_CURRENT_BINARY_DIR}/${LIB_NAME}"
......@@ -39,8 +42,9 @@ function(ADD_GO_LIBRARY NAME BUILD_TYPE)
WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR})
add_custom_target(${NAME} ALL DEPENDS ${OUTPUT_DIR}/.timestamp ${ARGN})
add_dependencies(${NAME} goGet)
if(NOT BUILD_TYPE STREQUAL "STATIC")
install(PROGRAMS ${CMAKE_CURRENT_BINARY_DIR}/${LIB_NAME} DESTINATION bin)
endif()
endfunction(ADD_GO_LIBRARY)
endfunction(GO_LIBRARY)
......@@ -13,8 +13,8 @@ import (
"github.com/namsral/flag"
"github.com/PaddlePaddle/Paddle/paddle/go/master"
"github.com/PaddlePaddle/Paddle/paddle/go/recordio"
"github.com/PaddlePaddle/Paddle/go/master"
"github.com/PaddlePaddle/Paddle/go/recordio"
)
func main() {
......
......@@ -8,7 +8,7 @@ import (
"github.com/namsral/flag"
"github.com/PaddlePaddle/Paddle/paddle/go/pserver"
"github.com/PaddlePaddle/Paddle/go/pserver"
)
func main() {
......
......@@ -6,7 +6,7 @@ import (
"sync"
"time"
"github.com/PaddlePaddle/Paddle/paddle/go/recordio"
"github.com/PaddlePaddle/Paddle/go/recordio"
)
const (
......
cmake_minimum_required(VERSION 3.0)
get_filename_component(PARENT_DIR ${CMAKE_CURRENT_SOURCE_DIR} DIRECTORY)
get_filename_component(PARENT_DIR ${PARENT_DIR} DIRECTORY)
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${PARENT_DIR}/cmake")
project(cxx_go C Go)
include(golang)
include(flags)
go_library(client STATIC)
add_subdirectory(test)
......@@ -39,10 +39,11 @@ import "C"
import (
"log"
"strings"
"sync"
"unsafe"
"github.com/PaddlePaddle/Paddle/paddle/go/pserver"
"github.com/PaddlePaddle/Paddle/go/pserver"
)
var nullPtr = unsafe.Pointer(uintptr(0))
......@@ -78,34 +79,54 @@ func cArrayToSlice(p unsafe.Pointer, len int) []byte {
return nil
}
// create a Go clice backed by a C array,
// reference: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
// create a Go clice backed by a C array, reference:
// https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
//
// Go garbage collector will not interact with this data, need
// to be freed properly.
return (*[1 << 30]byte)(p)[:len:len]
}
type selector bool
func (s selector) Select() bool {
return bool(s)
}
type lister []pserver.Server
func (l lister) List() []pserver.Server {
return l
}
//export paddle_new_pserver_client
func paddle_new_pserver_client(addr *C.char) C.client {
c := pserver.NewClient(C.GoString(addr))
func paddle_new_pserver_client(addrs *C.char, selected int) C.client {
a := C.GoString(addrs)
as := strings.Split(a, ",")
servers := make([]pserver.Server, len(as))
for i := range as {
servers[i].Index = i
servers[i].Addr = as[i]
}
c := pserver.NewClient(lister(servers), len(as), selector(selected != 0))
return add(c)
}
//export paddle_new_etcd_pserver_client
func paddle_new_etcd_pserver_client(etcd_addr *C.char) C.client {
// TODO(helin): fault tolerant pserver client using etcd.
panic("not implemented.")
}
//export paddle_pserver_client_release
func paddle_pserver_client_release(client C.client) {
c := remove(client)
c.Cleanup()
remove(client)
}
//export paddle_begin_init_params
func paddle_begin_init_params(client C.client, pserver_config unsafe.Pointer, config_len C.int) C.int {
func paddle_begin_init_params(client C.client) C.int {
c := get(client)
b := cArrayToSlice(pserver_config, int(config_len))
selected, err := c.BeginInitParams(b)
if err != nil {
log.Println(err)
return -1
}
if selected {
if selected := c.BeginInitParams(); selected {
return 1
}
return 0
......@@ -227,7 +248,7 @@ func paddle_get_params(client C.client, names **C.char, dst **C.paddle_parameter
func paddle_save_model(client C.client, path *C.char) C.int {
p := C.GoString(path)
c := get(client)
err := c.SaveModel(p)
err := c.Save(p)
if err != nil {
log.Println(err)
return -1
......
#include "libclient.h"
#include <stdio.h>
//#include "gtest/gtest.h"
#include "libclient.h"
void panic() {
void fail() {
// TODO(helin): fix: gtest using cmake is not working, using this
// hacky way for now.
*(void*)0;
printf("test failed.\n");
exit(-1);
}
int main() {
char addr[] = "localhost:3000";
client c = paddle_new_pserver_client(addr);
client c = paddle_new_pserver_client(addr, 1);
retry:
if (paddle_begin_init_params(c, NULL, 0)) {
if (paddle_begin_init_params(c)) {
paddle_parameter param;
char name_a[] = "param_a";
char name_b[] = "param_b";
......@@ -35,7 +36,7 @@ retry:
goto retry;
}
} else {
panic();
fail();
}
char content[] = {0x00, 0x11, 0x22};
......@@ -44,25 +45,25 @@ retry:
{"param_b", PADDLE_ELEMENT_TYPE_FLOAT32, content, 3}};
if (!paddle_send_grads(c, grads, 2)) {
panic();
fail();
}
paddle_parameter* params[2] = {NULL, NULL};
char* names[] = {"param_a", "param_b"};
if (!paddle_get_params(c, names, params, 2)) {
panic();
fail();
}
// get parameters again by reusing the allocated parameter buffers.
if (!paddle_get_params(c, names, params, 2)) {
panic();
fail();
}
paddle_release_param(params[0]);
paddle_release_param(params[1]);
if (!paddle_save_model(c, "/tmp/")) {
panic();
fail();
}
return 0;
......
package pserver
import (
"hash/fnv"
"log"
"sort"
"time"
"github.com/PaddlePaddle/Paddle/go/pserver/internal/connection"
)
// TODO(helin): add RPC call retry logic
// Selector selects if the client should initialize parameter servers.
type Selector interface {
Select() bool
}
// Server is the identification of a parameter Server.
type Server struct {
Index int
Addr string
}
// Lister lists currently available parameter servers.
type Lister interface {
List() []Server
}
// Client is the client to parameter servers.
type Client struct {
sel Selector
pservers []*connection.Conn
}
// NewClient creates a new client.
func NewClient(l Lister, pserverNum int, sel Selector) *Client {
c := &Client{sel: sel}
c.pservers = make([]*connection.Conn, pserverNum)
for i := 0; i < pserverNum; i++ {
c.pservers[i] = connection.New()
}
go c.monitorPservers(l, pserverNum)
return c
}
// monitorPservers monitors pserver addresses, and updates connection
// when the address changes.
func (c *Client) monitorPservers(l Lister, pserverNum int) {
knownServers := make([]Server, pserverNum)
ticker := time.NewTicker(10 * time.Second)
monitor := func() {
curServers := make([]Server, pserverNum)
list := l.List()
for _, l := range list {
curServers[l.Index] = l
}
for i := range knownServers {
if knownServers[i].Addr != curServers[i].Addr {
err := c.pservers[i].Connect(curServers[i].Addr)
if err != nil {
log.Println(err)
// connect to addr failed, set
// to last known addr in order
// to retry next time.
curServers[i].Addr = knownServers[i].Addr
}
}
}
knownServers = curServers
}
monitor()
for _ = range ticker.C {
monitor()
}
}
// BeginInitParams begins to initialize parameters on parameter
// servers.
//
// BeginInitParams will be called from multiple trainers, only one
// trainer will be selected to initialize the parameters on parameter
// servers. Other trainers will be blocked until the initialization is
// done, and they need to get the initialized parameters from
// parameter servers using GetParams.
func (c *Client) BeginInitParams() bool {
return c.sel.Select()
}
// InitParam initializes the parameter on parameter servers.
func (c *Client) InitParam(paramWithConfigs ParameterWithConfig) error {
var dummy int
return c.pservers[c.partition(paramWithConfigs.Param.Name)].Call("Service.InitParam", paramWithConfigs, &dummy)
}
// FinishInitParams tells parameter servers client has sent all
// parameters to parameter servers as initialization.
func (c *Client) FinishInitParams() error {
for _, p := range c.pservers {
var dummy int
err := p.Call("Service.FinishInitParams", dummy, &dummy)
if err != nil {
return err
}
}
return nil
}
// SendGrads sends gradients to parameter servers for updating
// parameters.
func (c *Client) SendGrads(grads []Gradient) error {
errCh := make(chan error, len(grads))
for _, g := range grads {
go func(g Gradient) {
var dummy int
err := c.pservers[c.partition(g.Name)].Call("Service.SendGrad", g, &dummy)
errCh <- err
}(g)
}
recv := 0
for err := range errCh {
if err != nil {
return err
}
recv++
if recv == len(grads) {
break
}
}
return nil
}
type result struct {
idx int
param Parameter
err error
}
type results []result
func (r results) Len() int {
return len(r)
}
func (r results) Less(i int, j int) bool {
return r[i].idx < r[j].idx
}
func (r results) Swap(i int, j int) {
r[i], r[j] = r[j], r[i]
}
// GetParams gets parameters from parameter servers.
func (c *Client) GetParams(names []string) ([]Parameter, error) {
rCh := make(chan result, len(names))
for idx, name := range names {
go func(name string, idx int) {
var parameter Parameter
err := c.pservers[c.partition(name)].Call("Service.GetParam", name, &parameter)
rCh <- result{idx: idx, param: parameter, err: err}
}(name, idx)
}
var rs results
recv := 0
for r := range rCh {
if r.err != nil {
return nil, r.err
}
rs = append(rs, r)
recv++
if recv == len(names) {
break
}
}
sort.Sort(rs)
ps := make([]Parameter, len(rs))
for i := range rs {
ps[i] = rs[i].param
}
return ps, nil
}
// Save indicates parameters to save the parameter to the given path.
func (c *Client) Save(path string) error {
errCh := make(chan error, len(c.pservers))
for _, p := range c.pservers {
var dummy int
err := p.Call("Service.Save", path, &dummy)
errCh <- err
}
recv := 0
for err := range errCh {
if err != nil {
return err
}
recv++
if recv == len(c.pservers) {
break
}
}
// TODO(helin): there will be many files under path, need to
// merge them into a single file.
return nil
}
func strHash(s string) uint32 {
h := fnv.New32a()
h.Write([]byte(s))
return h.Sum32()
}
// TODO(helin): now partition only select which parameter server to
// send the entire parameter. We need to partition a parameter into
// small blocks and send to different parameter servers.
func (c *Client) partition(key string) int {
return int(strHash(key) % uint32(len(c.pservers)))
}
package pserver_test
import (
"net"
"net/http"
"net/rpc"
"strconv"
"strings"
"testing"
"github.com/PaddlePaddle/Paddle/go/pserver"
)
const numPserver = 10
var port [numPserver]int
func init() {
for i := 0; i < numPserver; i++ {
l, err := net.Listen("tcp", ":0")
if err != nil {
panic(err)
}
ss := strings.Split(l.Addr().String(), ":")
p, err := strconv.Atoi(ss[len(ss)-1])
if err != nil {
panic(err)
}
port[i] = p
go func(l net.Listener) {
s := pserver.NewService()
server := rpc.NewServer()
err := server.Register(s)
if err != nil {
panic(err)
}
mux := http.NewServeMux()
mux.Handle(rpc.DefaultRPCPath, server)
err = http.Serve(l, mux)
if err != nil {
panic(err)
}
}(l)
}
}
type selector bool
func (s selector) Select() bool {
return bool(s)
}
type lister []pserver.Server
func (l lister) List() []pserver.Server {
return l
}
func TestClientFull(t *testing.T) {
servers := make([]pserver.Server, numPserver)
for i := 0; i < numPserver; i++ {
servers[i] = pserver.Server{Index: i, Addr: ":" + strconv.Itoa(port[i])}
}
c := pserver.NewClient(lister(servers), len(servers), selector(true))
selected := c.BeginInitParams()
if !selected {
t.Fatal("should be selected.")
}
const numParameter = 100
for i := 0; i < numParameter; i++ {
var p pserver.Parameter
p.Name = "p_" + strconv.Itoa(i)
p.ElementType = pserver.Float32
p.Content = make([]byte, (i+1)*100)
err := c.InitParam(pserver.ParameterWithConfig{Param: p})
if err != nil {
t.Fatal(err)
}
}
err := c.FinishInitParams()
if err != nil {
t.Fatal(err)
}
var grads []pserver.Gradient
for i := 0; i < numParameter/2; i++ {
var g pserver.Gradient
g.Name = "p_" + strconv.Itoa(i)
g.ElementType = pserver.Float32
g.Content = make([]byte, (i+1)*100)
grads = append(grads, g)
}
err = c.SendGrads(grads)
if err != nil {
t.Fatal(err)
}
names := make([]string, numParameter)
for i := 0; i < numParameter; i++ {
names[i] = "p_" + strconv.Itoa(i)
}
params, err := c.GetParams(names)
if err != nil {
t.Fatal(err)
}
if len(names) != len(params) {
t.Fatalf("parameter size not match, need: %d, have: %d", len(names), len(params))
}
for i := range params {
if names[i] != params[i].Name {
t.Fatalf("order of returned parameter does not required: parameter name: %s, required name: %s", names[i], params[i])
}
}
}
package connection
import (
"errors"
"net/rpc"
"sync"
)
// TODO(helin): add TCP re-connect logic
// Conn is a connection to a parameter server
type Conn struct {
mu sync.Mutex
client *rpc.Client
waitConn chan struct{}
}
// New creates a new connection.
func New() *Conn {
c := &Conn{}
return c
}
// Connect connects the connection to a address.
func (c *Conn) Connect(addr string) error {
c.mu.Lock()
if c.client != nil {
err := c.client.Close()
if err != nil {
c.mu.Unlock()
return err
}
c.client = nil
}
c.mu.Unlock()
client, err := rpc.DialHTTP("tcp", addr)
if err != nil {
return err
}
c.mu.Lock()
defer c.mu.Unlock()
if c.client == nil {
c.client = client
if c.waitConn != nil {
close(c.waitConn)
c.waitConn = nil
}
} else {
return errors.New("client already set from a concurrent goroutine")
}
return nil
}
// Call make a RPC call.
//
// Call will be blocked until the connection to remote RPC service
// being established.
func (c *Conn) Call(serviceMethod string, args interface{}, reply interface{}) error {
c.mu.Lock()
client := c.client
var waitCh chan struct{}
if client == nil {
if c.waitConn != nil {
waitCh = c.waitConn
} else {
waitCh = make(chan struct{})
c.waitConn = waitCh
}
}
c.mu.Unlock()
if waitCh != nil {
// wait until new connection being established
<-waitCh
return c.Call(serviceMethod, args, reply)
}
return client.Call(serviceMethod, args, reply)
}
......@@ -29,11 +29,11 @@ func newOptimizer(t optimizerType, learning_rate float64) *optimizer {
func (o *optimizer) UpdateParameter(p Parameter, g Gradient) error {
if len(p.Content) != len(g.Content) {
return fmt.Errorf("parameter and gradient length not match, parameter: %d, gradient: %d", len(p.Content), len(g.Content))
return fmt.Errorf("Name: %s, parameter and gradient length not match, parameter: %d, gradient: %d", p.Name, len(p.Content), len(g.Content))
}
if p.ElementType != g.ElementType {
return fmt.Errorf("parameter and gradient element type not match, parameter: %v, gradient: %v", p.ElementType, g.ElementType)
return fmt.Errorf("Name: %s, parameter and gradient element type not match, parameter: %v, gradient: %v", p.Name, p.ElementType, g.ElementType)
}
r := C.paddle_update_parameter(o.opt, unsafe.Pointer(&p.Content[0]), C.paddle_element_type(p.ElementType), unsafe.Pointer(&g.Content[0]), C.int(len(g.Content)))
......
......@@ -49,33 +49,12 @@ type Service struct {
// NewService creates a new service.
func NewService() *Service {
s := &Service{}
s := &Service{opt: newOptimizer(sgd, 0.01)}
s.paramMap = make(map[string]Parameter)
s.initialized = make(chan struct{})
return s
}
// BeginInitParams tells the parameter server that the parameter
// initialization has begun.
func (s *Service) BeginInitParams(config []byte, dummy *int) error {
select {
case <-s.initialized:
return ErrAlreadyInitialized
default:
}
s.mu.Lock()
defer s.mu.Unlock()
if s.opt != nil {
s.opt.Cleanup()
}
// TODO(helin): parse learning rate from config
s.opt = newOptimizer(sgd, 0.01)
return nil
}
// InitParam initializes a parameter.
func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) error {
select {
......@@ -109,65 +88,37 @@ func (s *Service) FinishInitParams(dummy0 int, dummy1 *int) error {
return nil
}
// SendGrads sends gradients to parameter servers for parameter
// SendGrad sends gradient to parameter servers for parameter
// optimization.
func (s *Service) SendGrads(grads []Gradient, dummy *int) error {
func (s *Service) SendGrad(g Gradient, dummy *int) error {
select {
case <-s.initialized:
default:
return ErrUninitialized
}
count := len(grads)
if count == 0 {
return nil
}
s.mu.Lock()
defer s.mu.Unlock()
for _, g := range grads {
if _, ok := s.paramMap[g.Name]; !ok {
p, ok := s.paramMap[g.Name]
if !ok {
return fmt.Errorf("parameter: %s does not exist", g.Name)
}
}
errCh := make(chan error, count)
for _, g := range grads {
go func(p Parameter, g Gradient) {
err := s.opt.UpdateParameter(p, g)
errCh <- err
}(s.paramMap[g.Name], g)
}
recv := 0
for err := range errCh {
if err != nil {
return err
}
recv++
if recv == count {
break
}
}
return nil
return s.opt.UpdateParameter(p, g)
}
// GetParams gets parameters from the parameter server.
func (s *Service) GetParams(names []string, parameters *[]Parameter) error {
// GetParam gets parameters from the parameter server.
func (s *Service) GetParam(name string, parameter *Parameter) error {
<-s.initialized
s.mu.Lock()
defer s.mu.Unlock()
for _, n := range names {
if _, ok := s.paramMap[n]; !ok {
return fmt.Errorf("parameter: %s does not exist", n)
}
p, ok := s.paramMap[name]
if !ok {
return fmt.Errorf("parameter: %s does not exist", name)
}
*parameters = make([]Parameter, len(names))
for i, n := range names {
// The parameter content (a byte slice) may change
// during RPC serialization due to write from other
// goroutine, we allow it since mini-batch based deep
......@@ -175,9 +126,7 @@ func (s *Service) GetParams(names []string, parameters *[]Parameter) error {
// nature. This race condition is allowed deliberately
// to save the program from making a copy of the
// paramter content.
(*parameters)[i] = s.paramMap[n]
}
*parameter = p
return nil
}
......
......@@ -4,23 +4,19 @@ import (
"reflect"
"sync"
"testing"
"time"
"github.com/PaddlePaddle/Paddle/paddle/go/pserver"
"github.com/PaddlePaddle/Paddle/go/pserver"
)
func TestFull(t *testing.T) {
s := pserver.NewService()
var dummy int
err := s.BeginInitParams(nil, &dummy)
if err != nil {
t.FailNow()
}
var p pserver.Parameter
p.Name = "param_a"
p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
p.ElementType = pserver.Int32
err = s.InitParam(pserver.ParameterWithConfig{p, nil}, &dummy)
var dummy int
err := s.InitParam(pserver.ParameterWithConfig{p, nil}, &dummy)
if err != nil {
t.FailNow()
}
......@@ -39,40 +35,39 @@ func TestFull(t *testing.T) {
t.FailNow()
}
var params []pserver.Parameter
err = s.GetParams([]string{"param_b", "param_a"}, &params)
var param pserver.Parameter
err = s.GetParam("param_b", &param)
if err != nil {
t.FailNow()
}
if len(params) != 2 || !reflect.DeepEqual(params[0], p1) || !reflect.DeepEqual(params[0], p1) {
if !reflect.DeepEqual(param, p1) {
t.FailNow()
}
grads := []pserver.Gradient{pserver.Gradient(p1), pserver.Gradient(p)}
err = s.SendGrads(grads, &dummy)
g1, g2 := pserver.Gradient(p1), pserver.Gradient(p)
err = s.SendGrad(g1, &dummy)
if err != nil {
t.FailNow()
}
err = s.SendGrad(g2, &dummy)
var params1 []pserver.Parameter
err = s.GetParams([]string{"param_b", "param_a"}, &params1)
if err != nil {
t.FailNow()
}
if len(params) != 2 {
var param1 pserver.Parameter
err = s.GetParam("param_a", &param1)
if err != nil {
t.FailNow()
}
// don't compare content, since it's already changed by
// gradient update.
params1[0].Content = nil
params1[0].Content = nil
param1.Content = nil
p.Content = nil
p1.Content = nil
if !reflect.DeepEqual(params1[0], p1) || !reflect.DeepEqual(params1[0], p1) {
if !reflect.DeepEqual(param1, p) {
t.FailNow()
}
}
......@@ -80,19 +75,7 @@ func TestFull(t *testing.T) {
func TestMultipleInit(t *testing.T) {
s := pserver.NewService()
var dummy int
err := s.BeginInitParams(nil, &dummy)
if err != nil {
t.FailNow()
}
// this is fine, it's possible for client to call init
// multiple times.
err = s.BeginInitParams(nil, &dummy)
if err != nil {
t.FailNow()
}
err = s.FinishInitParams(0, &dummy)
err := s.FinishInitParams(0, &dummy)
if err != nil {
t.FailNow()
}
......@@ -101,17 +84,12 @@ func TestMultipleInit(t *testing.T) {
if err != pserver.ErrAlreadyInitialized {
t.FailNow()
}
err = s.BeginInitParams(nil, &dummy)
if err != pserver.ErrAlreadyInitialized {
t.FailNow()
}
}
func TestUninitialized(t *testing.T) {
s := pserver.NewService()
var dummy int
err := s.SendGrads(nil, &dummy)
err := s.SendGrad(pserver.Gradient{}, &dummy)
if err != pserver.ErrUninitialized {
t.FailNow()
}
......@@ -123,8 +101,8 @@ func TestBlockUntilInitialized(t *testing.T) {
var wg sync.WaitGroup
wg.Add(1)
go func() {
var params []pserver.Parameter
err := s.GetParams(nil, &params)
var param pserver.Parameter
err := s.GetParam("param_a", &param)
if err != nil {
t.FailNow()
}
......@@ -143,11 +121,7 @@ func TestBlockUntilInitialized(t *testing.T) {
ch <- struct{}{}
}()
var dummy int
err := s.BeginInitParams(nil, &dummy)
if err != nil {
t.FailNow()
}
time.Sleep(50 * time.Millisecond)
select {
case <-ch:
......@@ -156,6 +130,16 @@ func TestBlockUntilInitialized(t *testing.T) {
default:
}
var p pserver.Parameter
p.Name = "param_a"
p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
p.ElementType = pserver.Int32
var dummy int
err := s.InitParam(pserver.ParameterWithConfig{p, nil}, &dummy)
if err != nil {
t.FailNow()
}
err = s.FinishInitParams(0, &dummy)
if err != nil {
t.FailNow()
......
......@@ -8,6 +8,7 @@ w := recordio.NewWriter(f)
w.Write([]byte("Hello"))
w.Write([]byte("World!"))
w.Close()
f.Close()
```
## Read
......@@ -18,6 +19,7 @@ w.Close()
f, e := os.Open("a_file.recordio")
idx, e := recordio.LoadIndex(f)
fmt.Println("Total records: ", idx.Len())
f.Close()
```
2. Create one or more scanner to read a range of records. The
......@@ -30,7 +32,8 @@ w.Close()
for s.Scan() {
fmt.Println(string(s.Record()))
}
if s.Err() != nil && s.Err() != io.EOF {
if s.Err() != nil {
log.Fatalf("Something wrong with scanning: %v", e)
}
f.Close()
```
cmake_minimum_required(VERSION 3.0)
get_filename_component(PARENT_DIR ${CMAKE_CURRENT_SOURCE_DIR} DIRECTORY)
get_filename_component(PARENT_DIR ${PARENT_DIR} DIRECTORY)
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${PARENT_DIR}/cmake")
project(cxx_go C Go)
include(golang)
include(flags)
go_library(recordio STATIC)
add_subdirectory(test)
package main
/*
#include <string.h>
typedef int reader;
typedef int writer;
*/
import "C"
import (
"log"
"os"
"strings"
"unsafe"
"github.com/PaddlePaddle/Paddle/go/recordio"
)
var nullPtr = unsafe.Pointer(uintptr(0))
type writer struct {
w *recordio.Writer
f *os.File
}
type reader struct {
scanner *recordio.Scanner
}
func cArrayToSlice(p unsafe.Pointer, len int) []byte {
if p == nullPtr {
return nil
}
// create a Go clice backed by a C array, reference:
// https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
//
// Go garbage collector will not interact with this data, need
// to be freed properly.
return (*[1 << 30]byte)(p)[:len:len]
}
//export create_recordio_writer
func create_recordio_writer(path *C.char) C.writer {
p := C.GoString(path)
f, err := os.Create(p)
if err != nil {
log.Println(err)
return -1
}
w := recordio.NewWriter(f, -1, -1)
writer := &writer{f: f, w: w}
return addWriter(writer)
}
//export recordio_write
func recordio_write(writer C.writer, buf *C.uchar, size C.int) C.int {
w := getWriter(writer)
b := cArrayToSlice(unsafe.Pointer(buf), int(size))
c, err := w.w.Write(b)
if err != nil {
log.Println(err)
return -1
}
return C.int(c)
}
//export release_recordio_writer
func release_recordio_writer(writer C.writer) {
w := removeWriter(writer)
w.w.Close()
w.f.Close()
}
//export create_recordio_reader
func create_recordio_reader(path *C.char) C.reader {
p := C.GoString(path)
s, err := recordio.NewScanner(strings.Split(p, ",")...)
if err != nil {
log.Println(err)
return -1
}
r := &reader{scanner: s}
return addReader(r)
}
//export recordio_read
func recordio_read(reader C.reader, record **C.uchar) C.int {
r := getReader(reader)
if r.scanner.Scan() {
buf := r.scanner.Record()
if len(buf) == 0 {
*record = (*C.uchar)(nullPtr)
return 0
}
size := C.int(len(buf))
*record = (*C.uchar)(C.malloc(C.size_t(len(buf))))
C.memcpy(unsafe.Pointer(*record), unsafe.Pointer(&buf[0]), C.size_t(len(buf)))
return size
}
return -1
}
//export release_recordio_reader
func release_recordio_reader(reader C.reader) {
r := removeReader(reader)
r.scanner.Close()
}
func main() {} // Required but ignored
package main
/*
typedef int reader;
typedef int writer;
*/
import "C"
import "sync"
var mu sync.Mutex
var handleMap = make(map[C.reader]*reader)
var curHandle C.reader
var writerMap = make(map[C.writer]*writer)
var curWriterHandle C.writer
func addReader(r *reader) C.reader {
mu.Lock()
defer mu.Unlock()
reader := curHandle
curHandle++
handleMap[reader] = r
return reader
}
func getReader(reader C.reader) *reader {
mu.Lock()
defer mu.Unlock()
return handleMap[reader]
}
func removeReader(reader C.reader) *reader {
mu.Lock()
defer mu.Unlock()
r := handleMap[reader]
delete(handleMap, reader)
return r
}
func addWriter(w *writer) C.writer {
mu.Lock()
defer mu.Unlock()
writer := curWriterHandle
curWriterHandle++
writerMap[writer] = w
return writer
}
func getWriter(writer C.writer) *writer {
mu.Lock()
defer mu.Unlock()
return writerMap[writer]
}
func removeWriter(writer C.writer) *writer {
mu.Lock()
defer mu.Unlock()
w := writerMap[writer]
delete(writerMap, writer)
return w
}
cmake_minimum_required(VERSION 3.0)
include_directories(${CMAKE_BINARY_DIR})
add_executable(recordio_test test.c)
add_dependencies(recordio_test recordio)
set (CMAKE_EXE_LINKER_FLAGS "-pthread")
target_link_libraries(recordio_test ${CMAKE_BINARY_DIR}/librecordio.a)
#include <stdio.h>
#include <stdlib.h>
#include "librecordio.h"
void fail() {
// TODO(helin): fix: gtest using cmake is not working, using this
// hacky way for now.
printf("test failed.\n");
exit(-1);
}
int main() {
writer w = create_recordio_writer("/tmp/test_recordio_0");
recordio_write(w, "hello", 6);
recordio_write(w, "hi", 3);
release_recordio_writer(w);
w = create_recordio_writer("/tmp/test_recordio_1");
recordio_write(w, "dog", 4);
recordio_write(w, "cat", 4);
release_recordio_writer(w);
reader r = create_recordio_reader("/tmp/test_recordio_*");
unsigned char* item = NULL;
int size = recordio_read(r, &item);
if (strcmp(item, "hello") || size != 6) {
fail();
}
free(item);
size = recordio_read(r, &item);
if (strcmp(item, "hi") || size != 3) {
fail();
}
free(item);
size = recordio_read(r, &item);
if (strcmp(item, "dog") || size != 4) {
fail();
}
free(item);
size = recordio_read(r, &item);
if (strcmp(item, "cat") || size != 4) {
fail();
}
free(item);
size = recordio_read(r, &item);
if (size != -1) {
fail();
}
release_recordio_reader(r);
}
......@@ -74,8 +74,8 @@ func (r *Index) Locate(recordIndex int) (int, int) {
return -1, -1
}
// Scanner scans records in a specified range within [0, numRecords).
type Scanner struct {
// RangeScanner scans records in a specified range within [0, numRecords).
type RangeScanner struct {
reader io.ReadSeeker
index *Index
start, end, cur int
......@@ -84,10 +84,10 @@ type Scanner struct {
err error
}
// NewScanner creates a scanner that sequencially reads records in the
// NewRangeScanner creates a scanner that sequencially reads records in the
// range [start, start+len). If start < 0, it scans from the
// beginning. If len < 0, it scans till the end of file.
func NewScanner(r io.ReadSeeker, index *Index, start, len int) *Scanner {
func NewRangeScanner(r io.ReadSeeker, index *Index, start, len int) *RangeScanner {
if start < 0 {
start = 0
}
......@@ -95,7 +95,7 @@ func NewScanner(r io.ReadSeeker, index *Index, start, len int) *Scanner {
len = index.NumRecords() - start
}
return &Scanner{
return &RangeScanner{
reader: r,
index: index,
start: start,
......@@ -108,7 +108,7 @@ func NewScanner(r io.ReadSeeker, index *Index, start, len int) *Scanner {
// Scan moves the cursor forward for one record and loads the chunk
// containing the record if not yet.
func (s *Scanner) Scan() bool {
func (s *RangeScanner) Scan() bool {
s.cur++
if s.cur >= s.end {
......@@ -124,12 +124,17 @@ func (s *Scanner) Scan() bool {
}
// Record returns the record under the current cursor.
func (s *Scanner) Record() []byte {
func (s *RangeScanner) Record() []byte {
_, ri := s.index.Locate(s.cur)
return s.chunk.records[ri]
}
// Error returns the error that stopped Scan.
func (s *Scanner) Error() error {
// Err returns the first non-EOF error that was encountered by the
// Scanner.
func (s *RangeScanner) Err() error {
if s.err == io.EOF {
return nil
}
return s.err
}
......@@ -68,7 +68,7 @@ func TestWriteAndRead(t *testing.T) {
2*4)}, // two record legnths
idx.chunkOffsets)
s := NewScanner(bytes.NewReader(buf.Bytes()), idx, -1, -1)
s := NewRangeScanner(bytes.NewReader(buf.Bytes()), idx, -1, -1)
i := 0
for s.Scan() {
assert.Equal(data[i], string(s.Record()))
......
......@@ -5,7 +5,7 @@ import (
"reflect"
"testing"
"github.com/PaddlePaddle/Paddle/paddle/go/recordio"
"github.com/PaddlePaddle/Paddle/go/recordio"
)
func TestWriteRead(t *testing.T) {
......@@ -29,7 +29,7 @@ func TestWriteRead(t *testing.T) {
t.Fatal("num record does not match:", idx.NumRecords(), total)
}
s := recordio.NewScanner(bytes.NewReader(buf.Bytes()), idx, -1, -1)
s := recordio.NewRangeScanner(bytes.NewReader(buf.Bytes()), idx, -1, -1)
i := 0
for s.Scan() {
if !reflect.DeepEqual(s.Record(), make([]byte, i)) {
......@@ -66,7 +66,7 @@ func TestChunkIndex(t *testing.T) {
for i := 0; i < total; i++ {
newIdx := idx.ChunkIndex(i)
s := recordio.NewScanner(bytes.NewReader(buf.Bytes()), newIdx, -1, -1)
s := recordio.NewRangeScanner(bytes.NewReader(buf.Bytes()), newIdx, -1, -1)
j := 0
for s.Scan() {
if !reflect.DeepEqual(s.Record(), make([]byte, i)) {
......
package recordio
import (
"fmt"
"os"
"path/filepath"
)
// Scanner is a scanner for multiple recordio files.
type Scanner struct {
paths []string
curFile *os.File
curScanner *RangeScanner
pathIdx int
end bool
err error
}
// NewScanner creates a new Scanner.
func NewScanner(paths ...string) (*Scanner, error) {
var ps []string
for _, s := range paths {
match, err := filepath.Glob(s)
if err != nil {
return nil, err
}
ps = append(ps, match...)
}
if len(ps) == 0 {
return nil, fmt.Errorf("no valid path provided: %v", paths)
}
return &Scanner{paths: ps}, nil
}
// Scan moves the cursor forward for one record and loads the chunk
// containing the record if not yet.
func (s *Scanner) Scan() bool {
if s.err != nil {
return false
}
if s.end {
return false
}
if s.curScanner == nil {
more, err := s.nextFile()
if err != nil {
s.err = err
return false
}
if !more {
s.end = true
return false
}
}
curMore := s.curScanner.Scan()
s.err = s.curScanner.Err()
if s.err != nil {
return curMore
}
if !curMore {
err := s.curFile.Close()
if err != nil {
s.err = err
return false
}
s.curFile = nil
more, err := s.nextFile()
if err != nil {
s.err = err
return false
}
if !more {
s.end = true
return false
}
return s.Scan()
}
return true
}
// Err returns the first non-EOF error that was encountered by the
// Scanner.
func (s *Scanner) Err() error {
return s.err
}
// Record returns the record under the current cursor.
func (s *Scanner) Record() []byte {
if s.curScanner == nil {
return nil
}
return s.curScanner.Record()
}
// Close release the resources.
func (s *Scanner) Close() error {
s.curScanner = nil
if s.curFile != nil {
err := s.curFile.Close()
s.curFile = nil
return err
}
return nil
}
func (s *Scanner) nextFile() (bool, error) {
if s.pathIdx >= len(s.paths) {
return false, nil
}
path := s.paths[s.pathIdx]
s.pathIdx++
f, err := os.Open(path)
if err != nil {
return false, err
}
idx, err := LoadIndex(f)
if err != nil {
f.Close()
return false, err
}
s.curFile = f
s.curScanner = NewRangeScanner(f, idx, 0, -1)
return true, nil
}
......@@ -9,9 +9,10 @@ add_subdirectory(pserver)
add_subdirectory(trainer)
add_subdirectory(scripts)
if(CMAKE_Go_COMPILER)
add_subdirectory(go)
endif()
# Do not build go directory until go cmake is working smoothly.
# if(CMAKE_Go_COMPILER)
# add_subdirectory(go)
# endif()
find_package(Boost QUIET)
......
include_directories(${CMAKE_CURRENT_BINARY_DIR})
go_library(adder SRCS adder.go)
if (WITH_TESTING)
cc_test(cgo_test
SRCS
cgo_test.cc
DEPS
adder)
endif()
package main
import "C"
//export GoAdder
func GoAdder(x, y int) int {
return x + y
}
func main() {} // Required but ignored
cmake_minimum_required(VERSION 3.0)
if(GTEST_INCLUDE_DIR AND GTEST_LIBRARIES)
message("-- Found gtest (include: ${GTEST_INCLUDE_DIR}, library: ${GTEST_LIBRARIES})")
else()
# find cmake directory modules
get_filename_component(PARENT_DIR ${CMAKE_CURRENT_SOURCE_DIR} DIRECTORY)
get_filename_component(PARENT_DIR ${PARENT_DIR} DIRECTORY)
get_filename_component(PARENT_DIR ${PARENT_DIR} DIRECTORY)
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${PARENT_DIR}/cmake")
# enable c++11
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
# enable gtest
set(THIRD_PARTY_PATH ./third_party)
set(WITH_TESTING ON)
include(external/gtest)
endif()
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
project(cxx_go CXX C Go)
include(cmake/golang.cmake)
include(cmake/flags.cmake)
ExternalGoProject_Add(pserver github.com/PaddlePaddle/Paddle/paddle/go/pserver)
add_go_library(client STATIC pserver)
add_subdirectory(test)
#include <iostream>
#include "gtest/gtest.h"
#include "libadder.h"
TEST(Cgo, Invoke) { EXPECT_EQ(GoAdder(30, 12), 42); }
package pserver
// Client is the client to parameter servers.
type Client struct {
}
// NewClient creates a new client.
func NewClient(addr string) *Client {
return &Client{}
}
// BeginInitParams begins to initialize parameters on parameter
// servers.
//
// BeginInitParams will be called from multiple trainers, only one
// trainer will be selected to initialize the parameters on parameter
// servers. Other trainers will be blocked until the initialization is
// done, and they need to get the initialized parameters from
// parameter servers using GetParams.
func (c *Client) BeginInitParams(pserverConfigProto []byte) (selected bool, err error) {
return true, nil
}
// InitParam initializes the parameter on parameter servers.
func (c *Client) InitParam(paramWithConfigs ParameterWithConfig) error {
return nil
}
// FinishInitParams tells parameter servers client has sent all
// parameters to parameter servers as initialization.
func (c *Client) FinishInitParams() error {
return nil
}
// SendGrads sends gradients to parameter servers for updating
// parameters.
func (c *Client) SendGrads(grads []Gradient) error {
return nil
}
// GetParams gets parameters from parameter servers.
func (c *Client) GetParams(names []string) ([]Parameter, error) {
return nil, nil
}
// SaveModel indicates parameters to save the parameter to the given
// path.
func (c *Client) SaveModel(path string) error {
return nil
}
// Cleanup cleans up the client states.
func (c *Client) Cleanup() {
}
......@@ -324,6 +324,7 @@ protected:
std::vector<std::shared_ptr<IParameterUpdaterHook>> updaterHooks_;
public:
void setSharedCount(int cnt) { sharedCount_ = cnt; }
int getSharedCount() { return sharedCount_; }
bool isSparse() { return config_.is_sparse(); }
......
......@@ -21,7 +21,7 @@ limitations under the License. */
#else
#if !defined(__arm__)
#if !defined(__arm__) && !defined(__aarch64__)
#include <cpuid.h>
/// for GCC/Clang
#define CPUID(info, x) __cpuid_count(x, 0, info[0], info[1], info[2], info[3])
......@@ -32,7 +32,7 @@ limitations under the License. */
namespace paddle {
SIMDFlags::SIMDFlags() {
#if defined(__arm__)
#if defined(__arm__) || defined(__aarch64__)
simd_flags_ = SIMD_NEON;
#else
unsigned int cpuInfo[4];
......
......@@ -19,7 +19,7 @@ using namespace paddle; // NOLINT
TEST(SIMDFlags, gccTest) {
#if (defined(__GNUC__) || defined(__GNUG__)) && !(defined(__clang__)) && \
!defined(__arm__)
!defined(__arm__) && !defined(__aarch64__)
// clang-format off
CHECK(!__builtin_cpu_supports("sse") != HAS_SSE);
CHECK(!__builtin_cpu_supports("sse2") != HAS_SSE2);
......
......@@ -3371,7 +3371,7 @@ def make_importer(config_dir, config_args):
return Import
settings = dict(
DEFAULT_SETTING = dict(
batch_size=None,
mini_batch_size=None,
algorithm='async_sgd',
......@@ -3404,6 +3404,8 @@ settings = dict(
adam_beta2=0.999,
adam_epsilon=1e-8, )
settings = copy.deepcopy(DEFAULT_SETTING)
settings_deprecated = dict(usage_ratio=1., )
trainer_settings = dict(
......@@ -3544,10 +3546,8 @@ def update_g_config():
return g_config
def parse_config(trainer_config, config_arg_str):
def begin_parse(config_arg_str=''):
'''
@param trainer_config: can be a string of config file name or a function name
with config logic
@param config_arg_str: a string of the form var1=val1,var2=val2. It will be
passed to config script as a dictionary CONFIG_ARGS
'''
......@@ -3555,12 +3555,23 @@ def parse_config(trainer_config, config_arg_str):
for hook in _parse_config_hooks:
hook()
config_args = {}
logger.findCaller = find_caller
logger.fatal = my_fatal
g_config.model_config.type = "nn"
global g_current_submodel, g_root_submodel
g_root_submodel = g_config.model_config.sub_models.add()
g_root_submodel.name = 'root'
g_root_submodel.is_recurrent_layer_group = False
g_current_submodel = g_root_submodel
def parse_config(trainer_config, config_arg_str):
begin_parse(config_arg_str)
config_args = {}
if config_arg_str:
config_args = dict([f.split('=') for f in config_arg_str.split(',')])
......@@ -3573,14 +3584,6 @@ def parse_config(trainer_config, config_arg_str):
extension_module = importlib(extension_module_name)
g_extended_config_funcs = extension_module.get_config_funcs(g_config)
g_config.model_config.type = 'nn'
global g_current_submodel, g_root_submodel
g_root_submodel = g_config.model_config.sub_models.add()
g_root_submodel.name = 'root'
g_root_submodel.is_recurrent_layer_group = False
g_current_submodel = g_root_submodel
if hasattr(trainer_config, '__call__'):
trainer_config.func_globals.update(
make_config_environment("", config_args))
......
......@@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import paddle.trainer.config_parser as config_parser
from paddle.proto.TrainerConfig_pb2 import OptimizationConfig
'''
This file is a wrapper of formal config_parser. The main idea of this file is to
separete different config logic into different function, such as network configuration
......@@ -20,7 +22,8 @@ separete different config logic into different function, such as network configu
'''
__all__ = [
"parse_trainer_config", "parse_network_config", "parse_optimizer_config"
"parse_trainer_config", "parse_network_config", "parse_optimizer_config",
"reset_parser"
]
......@@ -34,5 +37,15 @@ def parse_network_config(network_conf, config_arg_str=''):
def parse_optimizer_config(optimizer_conf, config_arg_str=''):
config = config_parser.parse_config(optimizer_conf, config_arg_str)
return config.opt_config
config_parser.settings = copy.deepcopy(config_parser.DEFAULT_SETTING)
optimizer_conf()
opt_config = OptimizationConfig()
for k, v in config_parser.settings.iteritems():
if v is None:
continue
opt_config.__setattr__(k, v)
return opt_config
def reset_parser():
config_parser.begin_parse()
......@@ -119,6 +119,7 @@ __all__ = [
'eos_layer',
'smooth_l1_cost',
'layer_support',
'multiplex_layer',
]
......@@ -185,6 +186,7 @@ class LayerType(object):
MAXOUT = "maxout"
SPP_LAYER = "spp"
PAD_LAYER = "pad"
MULTIPLEX_LAYER = "multiplex"
PRINT_LAYER = "print"
PRIORBOX_LAYER = "priorbox"
......@@ -285,6 +287,7 @@ class LayerOutput(object):
assert size is not None
assert LayerType.is_layer_type(layer_type)
self.name = name
self.full_name = MakeLayerNameInSubmodel(name)
self.layer_type = layer_type
if parents is not None and type(parents) != list:
parents = [parents]
......@@ -3489,6 +3492,11 @@ def recurrent_group(step,
RecurrentLayerGroupEnd(name=name)
for layer_out in layer_outs:
# Thee previous full_name is the name is the rnn group
# We need a full_name outside the rnn group
layer_out.full_name = MakeLayerNameInSubmodel(layer_out.name)
if len(layer_outs) == 1:
return layer_outs[0]
else:
......@@ -5465,3 +5473,54 @@ def smooth_l1_cost(input, label, name=None, coeff=1.0, layer_attr=None):
**ExtraLayerAttribute.to_kwargs(layer_attr))
return LayerOutput(
name, LayerType.SMOOTH_L1, parents=[input, label], size=1)
@wrap_name_default()
def multiplex_layer(input, name=None, layer_attr=None):
"""
This layer multiplex multiple layers according to the index,
which is provided by the first input layer.
inputs[0]: the index of the layer to output of size batchSize.
inputs[1:N]; the candidate output data.
For each index i from 0 to batchSize -1, the output is the i-th row of the
(index[i] + 1)-th layer.
For each i-th row of output:
.. math::
y[i][j] = x_{x_{0}[i] + 1}[i][j], j = 0,1, ... , (x_{1}.width - 1)
where, y is output. :math:`x_{k}` is the k-th input layer and
:math:`k = x_{0}[i] + 1`.
.. code-block:: python
maxid = multiplex_layer(input=layers)
:param input: Input layers.
:type input: list of LayerOutput
:param name: Layer name.
:type name: basestring
:param layer_attr: extra layer attributes.
:type layer_attr: ExtraLayerAttribute.
:return: LayerOutput object.
:rtype: LayerOutput
"""
assert isinstance(input, collections.Sequence)
assert len(input) > 2, 'multiplex_layer should have more than 2 inputs'
for i in range(1, len(input)):
assert isinstance(input[i], LayerOutput)
assert input[i].size == input[1].size, \
"All the input layers except the first one should have the same size"
l = Layer(
name=name,
type='multiplex',
inputs=[x.name for x in input],
size=input[1].size,
**ExtraLayerAttribute.to_kwargs(layer_attr))
return LayerOutput(
name=name,
layer_type=LayerType.MULTIPLEX_LAYER,
parents=input,
size=l.config.size)
......@@ -5,6 +5,6 @@ last_first_seq test_expand_layer test_ntm_layers test_hsigmoid
img_layers img_trans_layers util_layers simple_rnn_layers unused_layers test_cost_layers
test_rnn_group shared_fc shared_lstm shared_gru test_cost_layers_with_weight
test_spp_layer test_bilinear_interp test_maxout test_bi_grumemory math_ops
test_seq_concat_reshape test_pad test_smooth_l1)
test_seq_concat_reshape test_pad test_smooth_l1 test_multiplex_layer)
export whole_configs=(test_split_datasource)
type: "nn"
layers {
name: "index"
type: "data"
size: 1
active_type: ""
}
layers {
name: "data1"
type: "data"
size: 30
active_type: ""
}
layers {
name: "data2"
type: "data"
size: 30
active_type: ""
}
layers {
name: "data3"
type: "data"
size: 30
active_type: ""
}
layers {
name: "__multiplex_layer_0__"
type: "multiplex"
size: 30
active_type: ""
inputs {
input_layer_name: "index"
}
inputs {
input_layer_name: "data1"
}
inputs {
input_layer_name: "data2"
}
inputs {
input_layer_name: "data3"
}
}
input_layer_names: "index"
input_layer_names: "data1"
input_layer_names: "data2"
input_layer_names: "data3"
output_layer_names: "__multiplex_layer_0__"
sub_models {
name: "root"
layer_names: "index"
layer_names: "data1"
layer_names: "data2"
layer_names: "data3"
layer_names: "__multiplex_layer_0__"
input_layer_names: "index"
input_layer_names: "data1"
input_layer_names: "data2"
input_layer_names: "data3"
output_layer_names: "__multiplex_layer_0__"
is_recurrent_layer_group: false
}
from paddle.trainer_config_helpers import *
settings(batch_size=1000, learning_rate=1e-5)
index = data_layer(name='index', size=1)
din1 = data_layer(name='data1', size=30)
din2 = data_layer(name='data2', size=30)
din3 = data_layer(name='data3', size=30)
dout = multiplex_layer([index, din1, din2, din3])
outputs(dout)
......@@ -27,6 +27,7 @@ from . import dataset
from . import reader
from . import plot
import attr
import op
import pooling
import inference
import networks
......
......@@ -14,206 +14,55 @@
import collections
import re
from paddle.trainer_config_helpers.default_decorators import wrap_name_default
import paddle.trainer_config_helpers as conf_helps
from topology import Topology
class LayerType(type):
def __new__(cls, name, bases, attrs):
method_name = attrs.get('METHOD_NAME', None)
if method_name is not None:
method = getattr(conf_helps, method_name)
if method.__doc__ is not None:
mapper = attrs.get("__map_docstr__", None)
if mapper is not None:
attrs['__doc__'] = LayerType.__map_docstr__(
mapper(method.__doc__),
method_name=method_name,
name=name)
else:
attrs['__doc__'] = LayerType.__map_docstr__(
method.__doc__, method_name=method_name, name=name)
return super(LayerType, cls).__new__(cls, name, bases, attrs)
@staticmethod
def __map_docstr__(doc, name, method_name):
__layer_map__ = {}
def __map_docstr__(doc, name):
if doc is None:
return doc
assert isinstance(doc, basestring)
# replace LayerOutput to paddle.v2.config_base.Layer
doc = doc.replace("LayerOutput", "paddle.v2.config_base.Layer")
doc = doc.replace('ParameterAttribute',
'paddle.v2.attr.ParameterAttribute')
doc = doc.replace('ParameterAttribute', 'paddle.v2.attr.ParameterAttribute')
doc = re.sub(r'ExtraLayerAttribute[^\s]?',
'paddle.v2.attr.ExtraAttribute', doc)
doc = re.sub(r'ExtraLayerAttribute[^\s]?', 'paddle.v2.attr.ExtraAttribute',
doc)
# xxx_layer to xxx
doc = re.sub(r"(?P<name>[a-z]+)_layer", r"\g<name>", doc)
# XxxxActivation to paddle.v2.Activation.Xxxx
# XxxxActivation to paddle.v2.activation.Xxxx
doc = re.sub(r"(?P<name>[A-Z][a-zA-Z]+)Activation",
r"paddle.v2.Activation.\g<name>", doc)
r"paddle.v2.activation.\g<name>", doc)
# xxx_evaluator to paddle.v2.evaluator.xxx
doc = re.sub(r"(?P<name>[a-z]+)_evaluator", r"evaluator.\g<name>", doc)
# TODO(yuyang18): Add more rules if needed.
return doc
class Layer(object):
__metaclass__ = LayerType
def __init__(self, name=None, parent_layers=None):
assert isinstance(parent_layers, dict)
self.name = name
self.__context__ = {}
self.__parent_layers__ = parent_layers
# some layer may have some extra parent layer
self.__extra_parent__ = []
# used for evaluator.
self.__children_layers__ = []
def extra_parent(self):
return self.__extra_parent__
def append_extra_parent(self, parent):
self.__extra_parent__.append(parent)
def append_child(self, layer, parent_names):
self.__children_layers__.append((layer, parent_names))
def to_proto(self, context):
"""
function to set proto attribute
"""
self.__context__ = context
# STEP: short cut if this layer is parsed before.
if self.context_name() in context:
if self.use_context_name():
return context[self.context_name()]
else:
return context[self.name]
# STEP: parse extra_parent that is not used by this layer but must
# be parsed before this layer.
for p in self.__extra_parent__:
p.to_proto(context=context)
# STEP: parse parent that is used by this layer, get the result and
# insert into kwargs of the next layer's to_proto_impl method.
kwargs = dict()
for layer_name in self.__parent_layers__:
if not isinstance(self.__parent_layers__[layer_name],
collections.Sequence):
v1_layer = self.__parent_layers__[layer_name].to_proto(
context=context)
else:
v1_layer = map(lambda x: x.to_proto(context=context),
self.__parent_layers__[layer_name])
kwargs[layer_name] = v1_layer
# STEP: parse myself and add myself into context.
ret_val = self.to_proto_impl(**kwargs)
if self.context_name() is not None \
and self.context_name() not in context:
context[self.context_name()] = ret_val
# STEP: parse children that should be pased after this layer.
for layer, pnames in self.__children_layers__:
drop = False
# child will only be parsed if all parents are in context.
for pname in pnames:
if pname not in context:
drop = True
break
if drop:
continue
layer.to_proto(context=context)
# STEP: return v1 layer result
if self.context_name() is None:
return ret_val
elif self.use_context_name():
return context[self.context_name()]
else:
return context[self.name]
def to_proto_impl(self, **kwargs):
raise NotImplementedError()
def context_name(self):
"""
Context name means the context which stores `to_proto_impl` result.
If multiple layer share same context_name, the `to_proto_impl` of them
will be invoked only once.
"""
return self.name
def use_context_name(self):
return False
def calculate_size(self):
"""
lazy calculate size of the layer, should be called when to_proto_impl of
this layer is called.
:return:
"""
return self.__context__[self.context_name()].size
def attr(self):
topo = Topology(self)
return topo.get_layer_proto(self.name)
def __convert_to_v2__(method_name,
parent_names,
is_default_name=True,
attach_parent=False):
if is_default_name:
wrapper = wrap_name_default(name_prefix=method_name)
else:
wrapper = None
class V2LayerImpl(Layer):
METHOD_NAME = method_name
def __init__(self, **kwargs):
parent_layers = dict()
other_kwargs = dict()
for pname in parent_names:
if pname in kwargs:
parent_layers[pname] = kwargs[pname]
if attach_parent:
pnames = [x.context_name() for x in parent_layers.values()]
for pname in parent_layers:
layers = kwargs[pname]
if not isinstance(layers, collections.Sequence):
layers = [layers]
for layer in layers:
layer.append_child(self, pnames)
for key in kwargs.keys():
if key not in parent_names:
other_kwargs[key] = kwargs[key]
name = kwargs.get('name', None)
super(V2LayerImpl, self).__init__(name, parent_layers)
self.__other_kwargs__ = other_kwargs
if wrapper is not None:
__init__ = wrapper(__init__)
def to_proto_impl(self, **kwargs):
args = dict()
for each in kwargs:
args[each] = kwargs[each]
for each in self.__other_kwargs__:
args[each] = self.__other_kwargs__[each]
return getattr(conf_helps, method_name)(**args)
return V2LayerImpl
def __convert_to_v2__(f, name, module):
def wrapped(*args, **xargs):
out = f(*args, **xargs)
outs = out
if not isinstance(out, collections.Sequence):
outs = [out]
for l in outs:
if isinstance(l, conf_helps.LayerOutput):
__layer_map__[l.full_name] = l
return out
wrapped.__doc__ = __map_docstr__(f.__doc__, name)
wrapped.__name__ = name
wrapped.__module__ = module
return wrapped
Layer = conf_helps.LayerOutput
......@@ -13,8 +13,8 @@
# limitations under the License.
import paddle.trainer_config_helpers.evaluators as evs
import inspect
from config_base import __convert_to_v2__
import inspect
__all__ = []
......@@ -25,21 +25,10 @@ def initialize():
for __ev_name__ in filter(lambda x: x.endswith('_evaluator'), evs.__all__):
__ev__ = getattr(evs, __ev_name__)
if hasattr(__ev__, 'argspec'):
argspec = __ev__.argspec
else:
argspec = inspect.getargspec(__ev__)
parent_names = filter(lambda x: x in ['input', 'label', 'weight'],
argspec.args)
v2_ev = __convert_to_v2__(
__ev_name__,
parent_names=parent_names,
is_default_name='name' in argspec.args,
attach_parent=True)
__new_name__ = convert_to_new_name(__ev_name__)
globals()[__new_name__] = v2_ev
globals()[__new_name__] = __convert_to_v2__(__ev__, __new_name__,
__name__)
globals()[__new_name__].__name__ = __new_name__
__all__.append(__new_name__)
......
此差异已折叠。
......@@ -24,20 +24,7 @@ def __initialize__():
if each_subnetwork in ['inputs', 'outputs']:
continue
func = getattr(conf_nw, each_subnetwork)
if hasattr(func, 'argspec'):
argspec = func.argspec
else:
argspec = inspect.getargspec(func)
if each_subnetwork == 'simple_attention':
parents = ['encoded_sequence', 'encoded_proj', 'decoder_state']
else:
parents = filter(lambda x: x.startswith('input'), argspec.args)
assert len(parents) != 0, each_subnetwork
v2_subnet = __convert_to_v2__(
each_subnetwork,
parent_names=parents,
is_default_name='name' in argspec.args)
globals()[each_subnetwork] = v2_subnet
globals()[each_subnetwork] = func
globals()[each_subnetwork].__name__ = each_subnetwork
global __all__
__all__.append(each_subnetwork)
......
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import layer
import activation as act
from config_base import Layer
from paddle.trainer_config_helpers.attrs import is_compatible_with
from paddle.trainer_config_helpers.default_decorators import wrap_name_default
__all__ = []
def __register_unary_math_op__(op_name, act):
def op(input, name=None):
return layer.mixed(
input=[layer.identity_projection(input=input)], name=name, act=act)
op = wrap_name_default(op_name)(op)
op.__doc__ = type(act).__doc__
globals()[op_name] = op
__all__.append(op_name)
__register_unary_math_op__('exp', act.Exp())
__register_unary_math_op__('log', act.Log())
__register_unary_math_op__('abs', act.Abs())
__register_unary_math_op__('sigmoid', act.Sigmoid())
__register_unary_math_op__('tanh', act.Tanh())
__register_unary_math_op__('square', act.Square())
__register_unary_math_op__('relu', act.Relu())
__register_unary_math_op__('sqrt', act.Sqrt())
__register_unary_math_op__('reciprocal', act.Reciprocal())
__register_unary_math_op__('softmax', act.Softmax())
def __add__(layeroutput, other):
if is_compatible_with(other, float):
return layer.slope_intercept(input=layeroutput, intercept=other)
if not isinstance(other, Layer):
raise TypeError("Layer can only be added with"
" another Layer or a number")
if layeroutput.size == other.size:
return layer.mixed(input=[
layer.identity_projection(input=layeroutput),
layer.identity_projection(input=other)
])
if other.size != 1 and layeroutput.size != 1:
raise TypeError("Two Layer can be added only if they have equal size"
" or one of their sizes is 1. sizes are %s and %s" %
(layeroutput.size, other.size))
elif layeroutput.size == 1:
tmp = layeroutput
layeroutput = other
other = tmp
other = layer.repeat(other, layeroutput.size)
return layer.mixed(input=[
layer.identity_projection(input=layeroutput),
layer.identity_projection(input=other)
])
Layer.__radd__ = __add__
Layer.__add__ = __add__
def __neg__(layeroutput):
return layer.slope_intercept(input=layeroutput, slope=-1.0)
Layer.__neg__ = __neg__
def __sub__(layeroutput, other):
if is_compatible_with(other, float):
return layer.slope_intercept(input=layeroutput, intercept=other)
if not isinstance(other, Layer):
raise TypeError("Layer can only be subtracted with"
" another Layeroutput or a number")
return __add__(layeroutput, -other)
Layer.__sub__ = __sub__
def __rsub__(layeroutput, other):
neg = layer.slope_intercept(input=layeroutput, slope=-1.0)
return __add__(neg, other)
Layer.__rsub__ = __rsub__
def __mul__(layeroutput, other):
if is_compatible_with(other, float):
return layer.slope_intercept(input=layeroutput, slope=other)
if not isinstance(other, Layer):
raise TypeError("Layer can only be multiplied with"
" another Layer or a number")
elif layeroutput.size == 1:
return layer.scaling(input=other, weight=layeroutput)
elif other.size == 1:
return layer.scaling(input=layeroutput, weight=other)
else:
raise TypeError("At least one of the operand of '*' must be a number"
" or a Layer with size=1")
Layer.__mul__ = __mul__
Layer.__rmul__ = __mul__
add_python_test(test_v2_api test_data_feeder.py test_parameters.py
add_python_test(test_v2_api test_data_feeder.py test_op.py test_parameters.py
test_layer.py test_rnn_layer.py test_topology.py test_image.py)
......@@ -173,9 +173,9 @@ class OtherLayerTest(unittest.TestCase):
class ProjOpTest(unittest.TestCase):
def test_projection(self):
input = layer.data(name='data', type=data_type.dense_vector(784))
input = layer.data(name='data2', type=data_type.dense_vector(784))
word = layer.data(
name='word', type=data_type.integer_value_sequence(10000))
name='word2', type=data_type.integer_value_sequence(10000))
fc0 = layer.fc(input=input, size=100, act=activation.Sigmoid())
fc1 = layer.fc(input=input, size=200, act=activation.Sigmoid())
mixed0 = layer.mixed(
......@@ -204,8 +204,8 @@ class ProjOpTest(unittest.TestCase):
dotmul1 += dotmul
context = layer.context_projection(input=fc0, context_len=5)
context0 = layer.mixed(size=100, input=context)
with layer.mixed(size=100) as context1:
context0 = layer.mixed(size=500, input=context)
with layer.mixed(size=500) as context1:
context1 += context
conv = layer.conv_projection(
......@@ -231,8 +231,8 @@ class ProjOpTest(unittest.TestCase):
print layer.parse_network(conv1)
def test_operator(self):
ipt0 = layer.data(name='data', type=data_type.dense_vector(784))
ipt1 = layer.data(name='word', type=data_type.dense_vector(128))
ipt0 = layer.data(name='data1', type=data_type.dense_vector(784))
ipt1 = layer.data(name='word1', type=data_type.dense_vector(128))
fc0 = layer.fc(input=ipt0, size=100, act=activation.Sigmoid())
fc1 = layer.fc(input=ipt0, size=100, act=activation.Sigmoid())
......@@ -261,7 +261,7 @@ class ProjOpTest(unittest.TestCase):
class NetworkTests(unittest.TestCase):
def test_vgg(self):
img = layer.data(name='pixel', type=data_type.dense_vector(784))
img = layer.data(name='pixel1', type=data_type.dense_vector(784))
vgg_out = networks.small_vgg(
input_image=img, num_channels=1, num_classes=2)
print layer.parse_network(vgg_out)
......@@ -269,12 +269,12 @@ class NetworkTests(unittest.TestCase):
class EvaluatorTest(unittest.TestCase):
def test_evaluator(self):
img = layer.data(name='pixel', type=data_type.dense_vector(784))
img = layer.data(name='pixel2', type=data_type.dense_vector(784))
output = layer.fc(input=img,
size=10,
act=activation.Softmax(),
name='fc_here')
lbl = layer.data(name='label', type=data_type.integer_value(10))
lbl = layer.data(name='label2', type=data_type.integer_value(10))
cost = layer.cross_entropy_cost(input=output, label=lbl)
evaluator.classification_error(input=output, label=lbl)
......
# Copyright PaddlePaddle contributors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle.v2.data_type as data_type
import paddle.v2.layer as layer
import paddle.v2.op as op
class OpTest(unittest.TestCase):
def test_op(self):
x = layer.data(name='data', type=data_type.dense_vector(128))
x = op.exp(x)
x = op.sqrt(x)
x = op.reciprocal(x)
x = op.log(x)
x = op.abs(x)
x = op.sigmoid(x)
x = op.tanh(x)
x = op.square(x)
x = op.relu(x)
y = 1 + x
y = y + 1
y = x + y
y = y - x
y = y - 2
y = 2 - y
y = 2 * y
y = y * 3
z = layer.data(name='data_2', type=data_type.dense_vector(1))
y = y * z
y = z * y
y = y + z
y = z + y
print layer.parse_network(y)
if __name__ == '__main__':
unittest.main()
......@@ -20,6 +20,8 @@ import paddle.v2.data_type as data_type
import paddle.v2.layer as layer
from paddle.trainer_config_helpers.config_parser_utils import \
parse_network_config as parse_network
from paddle.trainer_config_helpers.config_parser_utils import \
reset_parser
class RNNTest(unittest.TestCase):
......@@ -29,6 +31,8 @@ class RNNTest(unittest.TestCase):
hidden_dim = 8
def parse_old_rnn():
reset_parser()
def step(y):
mem = conf_helps.memory(name="rnn_state", size=hidden_dim)
out = conf_helps.fc_layer(
......@@ -48,6 +52,8 @@ class RNNTest(unittest.TestCase):
return str(parse_network(test))
def parse_new_rnn():
reset_parser()
def new_step(y):
mem = layer.memory(name="rnn_state", size=hidden_dim)
out = layer.fc(input=[y, mem],
......@@ -75,6 +81,8 @@ class RNNTest(unittest.TestCase):
label_dim = 3
def parse_old_rnn():
reset_parser()
def test():
data = conf_helps.data_layer(name="word", size=dict_dim)
label = conf_helps.data_layer(name="label", size=label_dim)
......@@ -114,6 +122,7 @@ class RNNTest(unittest.TestCase):
return str(parse_network(test))
def parse_new_rnn():
reset_parser()
data = layer.data(
name="word", type=data_type.dense_vector(dict_dim))
label = layer.data(
......
......@@ -46,8 +46,8 @@ class TestTopology(unittest.TestCase):
self.assertEqual(label_data_type[1].dim, 10)
def test_get_layer(self):
pixel = layer.data(name='pixel', type=data_type.dense_vector(784))
label = layer.data(name='label', type=data_type.integer_value(10))
pixel = layer.data(name='pixel2', type=data_type.dense_vector(784))
label = layer.data(name='label2', type=data_type.integer_value(10))
hidden = layer.fc(input=pixel,
size=100,
act=conf_helps.SigmoidActivation())
......@@ -56,14 +56,14 @@ class TestTopology(unittest.TestCase):
act=conf_helps.SoftmaxActivation())
cost = layer.classification_cost(input=inference, label=label)
topo = topology.Topology(cost)
pixel_layer = topo.get_layer("pixel")
label_layer = topo.get_layer("label")
pixel_layer = topo.get_layer("pixel2")
label_layer = topo.get_layer("label2")
self.assertEqual(pixel_layer, pixel)
self.assertEqual(label_layer, label)
def test_parse(self):
pixel = layer.data(name='pixel', type=data_type.dense_vector(784))
label = layer.data(name='label', type=data_type.integer_value(10))
pixel = layer.data(name='pixel3', type=data_type.dense_vector(784))
label = layer.data(name='label3', type=data_type.integer_value(10))
hidden = layer.fc(input=pixel,
size=100,
act=conf_helps.SigmoidActivation())
......
......@@ -15,36 +15,13 @@
import collections
from paddle.proto.ModelConfig_pb2 import ModelConfig
import paddle.trainer_config_helpers as conf_helps
import layer as v2_layer
import config_base
__all__ = ['Topology']
def __flatten__(lis):
"""
Given a list, possibly nested to any level, return it flattened.
"""
new_lis = []
for item in lis:
if isinstance(item, collections.Sequence):
new_lis.extend(__flatten__(item))
else:
new_lis.append(item)
return new_lis
def __bfs_travel__(callback, *layers):
layers = __flatten__(layers)
for each_layer in layers:
__break__ = callback(each_layer)
if __break__:
return
__layers__ = each_layer.__parent_layers__.values() + \
each_layer.extra_parent()
__bfs_travel__(callback, *__layers__)
class Topology(object):
"""
Topology is used to store the information about all layers
......@@ -94,31 +71,18 @@ class Topology(object):
:param name:
:return:
"""
result_layer = [None]
def __impl__(l):
if l.name == name:
result_layer[0] = l
return True # break
return False
__bfs_travel__(__impl__, *self.layers)
if result_layer[0] is None:
raise ValueError("No such layer %s" % name)
return result_layer[0]
return v2_layer.get_layer(name)
def data_layers(self):
"""
get all data layer
:return:
"""
data_layers = dict()
def __impl__(l):
if isinstance(l, v2_layer.DataLayerV2):
data_layers[l.name] = l
__bfs_travel__(__impl__, *self.layers)
data_layers = {}
for layer in self.proto().layers:
l = v2_layer.get_layer(layer.name)
if l and l.layer_type == conf_helps.LayerType.DATA:
data_layers[layer.name] = l
return data_layers
def data_type(self):
......@@ -127,7 +91,7 @@ class Topology(object):
[('image', dense_vector(768)), ('label', integer_value(10))]
"""
data_layers = self.data_layers()
return [(nm, data_layers[nm].type)
return [(nm, data_layers[nm].data_type)
for nm in self.proto().input_layer_names]
def get_layer_proto(self, name):
......@@ -138,5 +102,5 @@ class Topology(object):
def __check_layer_type__(layer):
if not isinstance(layer, v2_layer.LayerV2):
raise ValueError('layer should have type paddle.layer.Layer')
if not isinstance(layer, config_base.Layer):
raise ValueError('layer should have type paddle.v2.config_base.Layer')
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册