Commit 0467cd2d authored by Yu Yang

Merge branch 'develop' into feature/middle_level_net_api

...@@ -75,10 +75,11 @@ snapshot to a model will be a TODO for future.

### Trainer Election

One trainer will be elected as the one to save the model. When using
-etcd, trainer ID is a randomly generated UUID, we will utilize etcd to
-elect one trainer. When not using etcd, unique trainer IDs will be
-given by the administrator, the trainer whose ID is "0" is elected to
-save the model.
+etcd, the trainer ID is a randomly generated UUID; the trainer will
+contact the master server requesting to save the model, and find out
+if it is elected. When the master server is not used, unique trainer
+IDs will be given by the administrator, and the trainer whose ID is
+"0" is elected to save the model.

### Model Save Path
......
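A minimal, hypothetical Go sketch of the trainer-side flow described above (RequestSaveModel mirrors the client method added in this commit; the master interface and alwaysApprove stub are placeholders, not Paddle types):

package main

import (
	"fmt"
	"time"
)

// master is a stand-in for the master server RPC surface used here.
type master interface {
	RequestSaveModel(trainerID string, block time.Duration) (bool, error)
}

// alwaysApprove is a toy master that approves every request; it only
// illustrates the trainer-side control flow.
type alwaysApprove struct{}

func (alwaysApprove) RequestSaveModel(string, time.Duration) (bool, error) { return true, nil }

// endOfPass is what every trainer would run after a pass: ask the
// master, and only save if this trainer is the one elected.
func endOfPass(m master, trainerID, path string) error {
	elected, err := m.RequestSaveModel(trainerID, 5*time.Second)
	if err != nil {
		return err
	}
	if !elected {
		return nil // another trainer is saving this pass
	}
	fmt.Printf("trainer %s saving model to %s\n", trainerID, path)
	// ... serialize parameters to path ...
	return nil
}

func main() {
	_ = endOfPass(alwaysApprove{}, "trainer-uuid-0", "/models/pass-10.tar")
}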
...@@ -49,6 +49,7 @@ message AttrProto {

message VarProto {
  required string name = 1;
  required string comment = 2;
  required bool is_tensor = 3;
};

message OpProto {
......
...@@ -311,3 +311,13 @@ Paddle's binary catches floating-point exceptions at runtime; whenever a floating-point
* The training data is problematic, causing the parameters to converge to singular values, or the input data is on too large a scale; when some features reach values in the millions, a matrix multiplication can overflow the floating-point range.
The main remedies are to reduce the learning rate or to normalize the data.

15. After building and installing, executing import paddle.v2 as paddle reports ImportError: No module named v2
------------------------------------------------------------------------
First check whether a paddle v1 version was installed before; if so, uninstall it first:
pip uninstall py_paddle paddle
Then install paddle's Python packages by running the following in the build directory:
pip install python/dist/paddle*.whl && pip install ../paddle/dist/py_paddle*.whl
...@@ -59,9 +59,13 @@ func main() {
		cp, err = pserver.NewCheckpointFromFile(*checkpointPath, idx, e)
		if err != nil {
			if err == pserver.ErrCheckpointNotFound {
				log.Infof("Could not find the pserver checkpoint.")
			} else {
				log.Errorf("Fetch checkpoint failed, %s", err)
			}
		}
	}

	s, err := pserver.NewService(idx, *checkpointInterval, *checkpointPath, e, cp)
	candy.Must(err)
......
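The change above treats a missing checkpoint as a normal, non-fatal startup case. A self-contained sketch of the same sentinel-error pattern (errCheckpointNotFound and loadCheckpoint here are illustrative stand-ins for pserver.ErrCheckpointNotFound and pserver.NewCheckpointFromFile):

package main

import (
	"errors"
	"log"
)

// errCheckpointNotFound mirrors pserver.ErrCheckpointNotFound: a sentinel
// meaning "nothing saved yet", as opposed to a real I/O failure.
var errCheckpointNotFound = errors.New("checkpoint not found")

func loadCheckpoint(path string) ([]byte, error) {
	// Stand-in for the real loader: pretend no checkpoint was ever written.
	return nil, errCheckpointNotFound
}

func main() {
	cp, err := loadCheckpoint("/pserver/checkpoint/0")
	if err != nil {
		if err == errCheckpointNotFound {
			// Expected on a fresh start: log and continue with empty state.
			log.Println("could not find the pserver checkpoint, starting fresh")
		} else {
			// Anything else is a genuine failure worth surfacing.
			log.Printf("fetch checkpoint failed: %v", err)
		}
	}
	_ = cp // the service would be constructed with cp (possibly nil) here
}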
...@@ -22,6 +22,9 @@ package main
#define PADDLE_MASTER_OK 0
#define PADDLE_MASTER_ERROR -1

#define PADDLE_SAVE_MODEL_OK 1
#define PADDLE_SAVE_MODEL_SKIP 0

typedef int paddle_master_client;
*/
import "C"
...@@ -33,7 +36,6 @@ import (
	"unsafe"

	"github.com/PaddlePaddle/Paddle/go/master"
	"github.com/coreos/etcd/clientv3"
	log "github.com/sirupsen/logrus"
)
...@@ -65,32 +67,32 @@ func remove(client C.paddle_master_client) *master.Client {
}

//export paddle_new_etcd_master_client
+//
+// bufSize is the record buffer size.
func paddle_new_etcd_master_client(etcdEndpoints *C.char, timeout int, bufSize int) C.paddle_master_client {
	p := C.GoString(etcdEndpoints)
-	cli, err := clientv3.New(clientv3.Config{
-		Endpoints:   strings.Split(p, ","),
-		DialTimeout: time.Second * time.Duration(timeout),
-	})
+	endpoints := strings.Split(p, ",")
+	c, err := master.NewClient(
+		master.WithEtcd(endpoints, time.Duration(timeout)*time.Second),
+		master.WithBuffer(bufSize),
+	)
	if err != nil {
		panic(err)
	}
-	ch := make(chan string, 1)
-	a, err := master.GetKey(cli, master.DefaultAddrPath, timeout)
-	if err != nil {
-		panic(err)
-	}
-	ch <- a
-	go master.WatchKey(cli, master.DefaultAddrPath, ch)
-	c := master.NewClient(ch, bufSize)
	return add(c)
}

//export paddle_new_master_client
+//
+// bufSize is the record buffer size.
func paddle_new_master_client(addr *C.char, bufSize int) C.paddle_master_client {
	a := C.GoString(addr)
-	ch := make(chan string, 1)
-	ch <- a
-	c := master.NewClient(ch, bufSize)
+	c, err := master.NewClient(master.WithAddr(a), master.WithBuffer(bufSize))
+	if err != nil {
+		panic(err)
+	}
	return add(c)
}
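Both C entry points now delegate to master.NewClient; a brief sketch of the two Go-side construction paths they wrap (endpoints, timeouts, and buffer sizes below are placeholders):

package main

import (
	"time"

	"github.com/PaddlePaddle/Paddle/go/master"
)

func main() {
	// Discovery through etcd: the client watches the registered master
	// address and buffers up to 32 records.
	etcdClient, err := master.NewClient(
		master.WithEtcd([]string{"127.0.0.1:2379"}, 5*time.Second),
		master.WithBuffer(32),
	)
	if err != nil {
		panic(err)
	}
	_ = etcdClient

	// Fixed address, e.g. for tests or a single-node setup.
	fixedClient, err := master.NewClient(master.WithAddr("127.0.0.1:8080"), master.WithBuffer(32))
	if err != nil {
		panic(err)
	}
	_ = fixedClient
}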
...@@ -117,9 +119,10 @@ func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int
	return C.PADDLE_MASTER_OK
}

-// return value:
-// 0:ok
-// -1:error
+// paddle_next_record gets the next training record.
+//
+// returns the number of bytes of the record on success, -1 on failure.
+//
//export paddle_next_record
func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int {
	c := get(client)
...@@ -143,6 +146,29 @@ func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int {
	return C.int(size)
}
// paddle_request_save_model requests the master server to approve the
// caller to save the model.
//
// returns 1 if the save-model request is approved, 0 if the request
// is rejected because another trainer is saving the model, and -1 if
// an error occurred.
//
//export paddle_request_save_model
func paddle_request_save_model(client C.paddle_master_client, trainerID string, blockMS int) C.int {
c := get(client)
need, err := c.RequestSaveModel(trainerID, time.Duration(blockMS)*time.Millisecond)
if err != nil {
log.Errorln(err)
return C.PADDLE_MASTER_ERROR
}
if need {
return C.PADDLE_SAVE_MODEL_OK
}
return C.PADDLE_SAVE_MODEL_SKIP
}
//export mem_free
func mem_free(p unsafe.Pointer) {
	// "free" may be a better name for this function, but doing so
......
...@@ -16,10 +16,12 @@ package master
import (
	"os"
	"sync"
	"time"

	"github.com/PaddlePaddle/Paddle/go/connection"
	"github.com/PaddlePaddle/recordio"
	"github.com/coreos/etcd/clientv3"
	log "github.com/sirupsen/logrus"
)
...@@ -27,6 +29,7 @@ import (
type Client struct {
	conn       *connection.Conn
	ch         chan record
	initChOnce sync.Once
}

type record struct {
...@@ -34,24 +37,83 @@ type record struct {
	err error
}
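The new initChOnce field lets the record channel be created exactly once, whichever of WithBuffer or NextRecord touches it first. A small standalone sketch of that sync.Once pattern (the lazyQueue type is illustrative, not a Paddle type):

package main

import (
	"fmt"
	"sync"
)

type lazyQueue struct {
	once sync.Once
	ch   chan string
}

// init creates the channel and starts the producer only on first use,
// no matter how many callers race to invoke it.
func (q *lazyQueue) init(size int) {
	q.once.Do(func() {
		q.ch = make(chan string, size)
		go func() { q.ch <- "first record" }()
	})
}

func (q *lazyQueue) next() string {
	q.init(0) // safe even if init(size) already ran elsewhere
	return <-q.ch
}

func main() {
	q := &lazyQueue{}
	q.init(10)
	fmt.Println(q.next())
}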
-// NewClient creates a new Client.
+// WithBuffer sets the client to buffer the training record.
//
// bufSize is the record buffer size. NextRecord will read from this
// buffer.
-func NewClient(addrCh <-chan string, bufSize int) *Client {
-	c := &Client{}
-	c.conn = connection.New()
-	c.ch = make(chan record, bufSize)
-	go c.monitorMaster(addrCh)
-	go c.getRecords()
-	return c
+func WithBuffer(bufSize int) func(*Client) error {
+	return func(c *Client) error {
+		if bufSize <= 0 {
+			return nil
+		}
+
+		c.initChOnce.Do(func() {
+			c.ch = make(chan record, bufSize)
+			go c.getRecords()
+		})
+		return nil
+	}
}
// WithAddr sets the client to use fixed master address.
func WithAddr(addr string) func(c *Client) error {
return func(c *Client) error {
ch := make(chan string, 1)
ch <- addr
go c.monitorMaster(ch)
return nil
}
}
// WithEtcd sets the client to use etcd for master discovery.
func WithEtcd(endpoints []string, timeout time.Duration) func(*Client) error {
return func(c *Client) error {
cli, err := clientv3.New(clientv3.Config{
Endpoints: endpoints,
DialTimeout: timeout,
})
if err != nil {
return err
}
ch := make(chan string, 1)
a, err := GetKey(cli, DefaultAddrPath, timeout)
if err != nil {
return err
}
if a != "" {
// Master is registered, send to the master address
// channel.
ch <- a
}
go watchKey(cli, DefaultAddrPath, ch)
go c.monitorMaster(ch)
return nil
}
}
// NewClient creates a new Client.
func NewClient(opts ...func(*Client) error) (*Client, error) {
c := &Client{}
c.conn = connection.New()
for _, opt := range opts {
err := opt(c)
if err != nil {
return nil, err
}
}
return c, nil
}

func (c *Client) getRecords() {
	for {
		t, err := c.getTask()
		if err != nil {
			// getTask call.
			log.Errorf("Get task failed, sleep 3 seconds and continue, %s", err)
			time.Sleep(3 * time.Second)
			continue
...@@ -146,6 +208,20 @@ func (c *Client) taskFailed(meta TaskMeta) error {
// NextRecord will block until the next record is available. It is
// thread-safe.
func (c *Client) NextRecord() ([]byte, error) {
	c.initChOnce.Do(func() {
		// initialize the channel here in case WithBuffer was not used.
		c.ch = make(chan record, 0)
		go c.getRecords()
	})

	r := <-c.ch
	return r.r, r.err
}
// RequestSaveModel requests the master server to approve the caller
// to save the model.
func (c *Client) RequestSaveModel(trainerID string, blockDur time.Duration) (bool, error) {
var need bool
err := c.conn.Call("Service.RequestSaveModel", SaveModelRequest{TrainerID: trainerID, BlockDur: blockDur}, &need)
return need, err
}
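Taken together, the options and methods above would be used by a trainer roughly as follows (a hedged sketch: it assumes a master RPC server is already listening on 127.0.0.1:8080, that the dataset glob exists, and that an empty record marks the end of a pass):

package main

import (
	"log"
	"time"

	"github.com/PaddlePaddle/Paddle/go/master"
)

func main() {
	// Fixed address plus a 10-record buffer; WithEtcd could be used
	// instead when the master registers itself in etcd.
	c, err := master.NewClient(master.WithAddr("127.0.0.1:8080"), master.WithBuffer(10))
	if err != nil {
		log.Fatal(err)
	}

	if err := c.SetDataset([]string{"/data/train-*.recordio"}); err != nil {
		log.Fatal(err)
	}

	for {
		r, err := c.NextRecord()
		if err != nil {
			log.Printf("next record: %v", err)
			break
		}
		if len(r) == 0 {
			break // assuming an empty record marks the end of the pass
		}
		// ... feed r to the trainer ...
	}

	// Ask the master whether this trainer should be the one to save.
	elected, err := c.RequestSaveModel("my-trainer-uuid", 5*time.Second)
	if err != nil {
		log.Fatal(err)
	}
	if elected {
		// ... save the model ...
	}
}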
...@@ -87,9 +87,11 @@ func TestNextRecord(t *testing.T) {
		panic(err)
	}

-	curAddr := make(chan string, 1)
-	curAddr <- fmt.Sprintf(":%d", p)
-	c := master.NewClient(curAddr, 10)
+	c, err := master.NewClient(master.WithAddr(fmt.Sprintf(":%d", p)), master.WithBuffer(10))
+	if err != nil {
+		panic(err)
+	}

	err = c.SetDataset([]string{path})
	if err != nil {
		panic(err)
......
...@@ -158,8 +158,8 @@ func (e *EtcdClient) Load() ([]byte, error) {
}

// GetKey gets the value of the specified key.
-func GetKey(c *clientv3.Client, key string, timeout int) (string, error) {
-	ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(timeout))
+func GetKey(c *clientv3.Client, key string, timeout time.Duration) (string, error) {
+	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	resp, err := c.Get(ctx, key)
	cancel()
	if err != nil {
...@@ -173,8 +173,8 @@ func GetKey(c *clientv3.Client, key string, timeout int) (string, error) {
	return string(v), nil
}

-// WatchKey watches the specified key and sends to valChan when there is an event.
-func WatchKey(c *clientv3.Client, key string, valChan chan<- string) {
+// watchKey watches the specified key and sends to valChan when there is an event.
+func watchKey(c *clientv3.Client, key string, valChan chan<- string) {
	rch := c.Watch(context.Background(), key)
	for wresp := range rch {
		for _, ev := range wresp.Events {
......
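GetKey now takes a time.Duration directly. A brief usage sketch against a local etcd (the endpoint is an assumption; GetKey and DefaultAddrPath are the exported names used elsewhere in this commit):

package main

import (
	"fmt"
	"time"

	"github.com/PaddlePaddle/Paddle/go/master"
	"github.com/coreos/etcd/clientv3"
)

func main() {
	cli, err := clientv3.New(clientv3.Config{
		Endpoints:   []string{"127.0.0.1:2379"},
		DialTimeout: 3 * time.Second,
	})
	if err != nil {
		panic(err)
	}
	defer cli.Close()

	// Read the registered master address, giving the lookup 3 seconds.
	addr, err := master.GetKey(cli, master.DefaultAddrPath, 3*time.Second)
	if err != nil {
		panic(err)
	}
	fmt.Println("master is at:", addr)
}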
...@@ -81,6 +81,7 @@ type Service struct {
	mu            sync.Mutex
	initDone      bool
	taskQueues    taskQueues
	savingTrainer string
}

func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
...@@ -246,7 +247,7 @@ func readChunks(globPaths []string) ([]Chunk, error) {
//
// SetDataset can be called multiple times. But only the first call will
// be honored.
-func (s *Service) SetDataset(globPaths []string, dummy *int) error {
+func (s *Service) SetDataset(globPaths []string, _ *int) error {
	if len(globPaths) == 0 {
		return errors.New("no dataset specified")
	}
...@@ -330,7 +331,7 @@ func (s *Service) logFields() log.Fields {
}

// GetTask gets a new task from the service.
-func (s *Service) GetTask(dummy int, task *Task) error {
+func (s *Service) GetTask(_ int, task *Task) error {
	select {
	case <-s.ready:
	}
...@@ -380,7 +381,7 @@ func (s *Service) GetTask(dummy int, task *Task) error {
}

// TaskFinished tells the service that a task is finished.
-func (s *Service) TaskFinished(taskID int, dummy *int) error {
+func (s *Service) TaskFinished(taskID int, _ *int) error {
	select {
	case <-s.ready:
	}
...@@ -415,7 +416,7 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
}

// TaskFailed tells the service that a task failed.
-func (s *Service) TaskFailed(meta TaskMeta, dummy *int) error {
+func (s *Service) TaskFailed(meta TaskMeta, _ *int) error {
	select {
	case <-s.ready:
	}
...@@ -432,3 +433,42 @@ func (s *Service) TaskFailed(meta TaskMeta, dummy *int) error {
	s.processFailedTask(t, meta.Epoch)
	return nil
}
// SaveModelRequest is the request for saving model
type SaveModelRequest struct {
TrainerID string
BlockDur time.Duration
}
// RequestSaveModel requests the master server to approve the caller
// to save the model.
func (s *Service) RequestSaveModel(req SaveModelRequest, need *bool) error {
s.mu.Lock()
defer s.mu.Unlock()
if req.TrainerID == "" {
return errors.New("trainer id is empty")
}
if s.savingTrainer == "" {
*need = true
} else {
if req.TrainerID == s.savingTrainer {
// save trainer asked to save model again
*need = true
} else {
*need = false
}
}
if *need {
s.savingTrainer = req.TrainerID
time.AfterFunc(req.BlockDur, func() {
s.mu.Lock()
s.savingTrainer = ""
s.mu.Unlock()
})
}
return nil
}
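RequestSaveModel above is essentially a small lease: the first trainer to ask holds the save slot for BlockDur, after which the slot clears. A self-contained sketch of just that arbitration logic, without the RPC plumbing (names here are illustrative, not the Service type):

package main

import (
	"fmt"
	"sync"
	"time"
)

type saveLease struct {
	mu     sync.Mutex
	holder string
}

// request reports whether id may save, granting the lease for dur if so.
func (l *saveLease) request(id string, dur time.Duration) bool {
	l.mu.Lock()
	defer l.mu.Unlock()

	if l.holder != "" && l.holder != id {
		return false // someone else is saving right now
	}
	l.holder = id
	time.AfterFunc(dur, func() {
		l.mu.Lock()
		l.holder = ""
		l.mu.Unlock()
	})
	return true
}

func main() {
	l := &saveLease{}
	fmt.Println(l.request("trainer-a", 100*time.Millisecond)) // true: elected
	fmt.Println(l.request("trainer-b", 100*time.Millisecond)) // false: rejected
	time.Sleep(150 * time.Millisecond)
	fmt.Println(l.request("trainer-b", 100*time.Millisecond)) // true: lease expired
}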
...@@ -127,13 +127,19 @@ func paddle_pserver_client_release(client C.paddle_pserver_client) {
	remove(client)
}

// paddle_begin_init_params tells the trainer whether it needs to init the
// parameters.
//
// returns 1 if the trainer needs to init the parameters, 0 if the
// trainer does not need to init the parameters.
//
//export paddle_begin_init_params
func paddle_begin_init_params(client C.paddle_pserver_client) C.int {
	c := get(client)
	if selected := c.BeginInitParams(); selected {
		return 1
	}
-	return C.PSERVER_OK
+	return 0
}

//export paddle_init_param
...@@ -256,17 +262,4 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter,
	return C.PSERVER_OK
}
//export paddle_save_model
func paddle_save_model(client C.paddle_pserver_client, path *C.char) C.int {
p := C.GoString(path)
c := get(client)
err := c.Save(p)
if err != nil {
log.Errorln(err)
return C.PSERVER_ERROR
}
return C.PSERVER_OK
}
func main() {} // Required but ignored func main() {} // Required but ignored
...@@ -111,9 +111,5 @@ retry: ...@@ -111,9 +111,5 @@ retry:
getParams(c); getParams(c);
} }
if (paddle_save_model(c, "/tmp/")) {
fail();
}
	return 0;
}
...@@ -219,32 +219,6 @@ func (c *Client) GetParams(names []string) ([]pserver.Parameter, error) {
	return ps, nil
}
// Save indicates parameters to save the parameter to the given path.
func (c *Client) Save(path string) error {
errCh := make(chan error, len(c.pservers))
for _, p := range c.pservers {
err := p.Call("Service.Save", path, nil)
errCh <- err
}
recv := 0
for err := range errCh {
if err != nil {
return err
}
recv++
if recv == len(c.pservers) {
break
}
}
// TODO(helin): there will be many files under path, need to
// merge them into a single file.
return nil
}
func strHash(s string) uint32 {
	h := fnv.New32a()
	_, _ = h.Write([]byte(s))
......
...@@ -36,6 +36,10 @@ import (
// ElementType is the type of elements of a Parameter.
type ElementType int

// ErrCheckpointNotFound indicates that the pserver checkpoint could
// not be found.
var ErrCheckpointNotFound = errors.New("checkpoint not found")

// RPC error message.
const (
	AlreadyInitialized = "pserver already initialized"
...@@ -103,6 +107,10 @@ func NewCheckpointFromFile(cpPath string, idx int, e *EtcdClient) (Checkpoint, e
		return nil, err
	}

	if len(v) == 0 {
		return nil, ErrCheckpointNotFound
	}

	var cpMeta checkpointMeta
	if err = json.Unmarshal(v, &cpMeta); err != nil {
		return nil, err
...@@ -156,7 +164,7 @@ func NewService(idx int, interval time.Duration, path string, client *EtcdClient
}

// InitParam initializes a parameter.
-func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) error {
+func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, _ *int) error {
	select {
	case <-s.initialized:
		return errors.New(AlreadyInitialized)
...@@ -177,7 +185,7 @@ func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) er
// FinishInitParams tells the parameter server that the parameter
// initialization has finished.
-func (s *Service) FinishInitParams(dummy0 int, dummy1 *int) error {
+func (s *Service) FinishInitParams(_ int, _ *int) error {
	select {
	case <-s.initialized:
		return errors.New(AlreadyInitialized)
...@@ -190,7 +198,7 @@ func (s *Service) FinishInitParams(dummy0 int, dummy1 *int) error {
// SendGrad sends gradient to parameter servers for parameter
// optimization.
-func (s *Service) SendGrad(g Gradient, dummy *int) error {
+func (s *Service) SendGrad(g Gradient, _ *int) error {
	select {
	case <-s.initialized:
	default:
......
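The dummy to _ renames above follow from Go's net/rpc, which dictates the method shape (two exported arguments, the second a pointer, returning error) even when one argument carries nothing. A minimal standard-library sketch of such a method (the Counter type is illustrative):

package main

import (
	"fmt"
	"net/rpc"
)

type Counter struct{ n int }

// Bump satisfies net/rpc's required shape: (args T1, reply *T2) error.
// The reply is unused here, so it is conventionally named _.
func (c *Counter) Bump(delta int, _ *int) error {
	c.n += delta
	return nil
}

func main() {
	c := &Counter{}
	if err := rpc.Register(c); err != nil {
		fmt.Println("register:", err)
		return
	}
	_ = c.Bump(3, nil)
	fmt.Println("count:", c.n)
}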
...@@ -19,8 +19,10 @@ cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf)
cc_library(operator SRCS operator.cc DEPS op_desc device_context tensor)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
-cc_library(op_registry SRCS op_registry.cc DEPS op_proto op_desc)
-cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry operator)
+cc_library(grad_op_creator SRCS grad_op_creator.cc DEPS op_proto operator)
+cc_library(op_registry SRCS op_registry.cc DEPS op_desc grad_op_creator)
+cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
+cc_test(grad_op_creator_test SRCS grad_op_creator_test.cc DEPS grad_op_creator op_registry add_op)

py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto)
# Generate an empty __init__.py to make framework_py_proto as a valid python module.
...@@ -28,5 +30,6 @@ add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch
add_dependencies(framework_py_proto framework_py_proto_init)

proto_library(net_proto SRCS net_proto.proto DEPS op_proto)
+# cc_library(net SRCS net.cc DEPS operator net_proto op_registry fc_op)
cc_library(net SRCS net.cc DEPS operator net_proto op_registry)
-cc_test(net_op_test SRCS net_op_test.cc DEPS net)
+cc_test(net_op_test SRCS net_op_test.cc DEPS net add_op mul_op sigmoid_op softmax_op fc_op)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/grad_op_creator.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace framework {
OperatorBase* GradOpCreator::Create() {
BuildOpInOutArgList();
OperatorBase* grad_op = OpRegistry::grad_creators().at(op_->type_)();
CompleteGradOp(grad_op);
return grad_op;
}
OpInOutArg* GradOpCreator::BuildArg(const VarProto& var,
const VarIndexMap& var_map,
const std::vector<int>& format,
InOutType type) {
int idx = var_map.at(var.name());
int begin_idx = format.empty() ? idx : format.at(idx);
int end_idx = format.empty() ? idx + 1 : format.at(idx + 1);
return new OpInOutArg(var.name(), type, !var.ignore_gradient(), begin_idx,
end_idx);
}
void GradOpCreator::BuildOpInOutArgList() {
const OpProto& op_proto = OpRegistry::protos().at(op_->type_);
const auto& var_map = *(OpRegistry::VarIndexMaps().at(op_->type_));
const std::vector<int>& in_format =
op_->attrs_.count("input_format")
? op_->GetAttr<std::vector<int>>("input_format")
: std::vector<int>();
const std::vector<int>& out_format =
op_->attrs_.count("output_format")
? op_->GetAttr<std::vector<int>>("output_format")
: std::vector<int>();
for (const auto& var : op_proto.inputs()) {
arg_list_.emplace_back(
std::shared_ptr<OpInOutArg>(BuildArg(var, var_map, in_format, IN)));
}
for (const auto& var : op_proto.outputs()) {
arg_list_.emplace_back(
std::shared_ptr<OpInOutArg>(BuildArg(var, var_map, out_format, OUT)));
}
}
void GradOpCreator::AddArgIntoGradOp(const OpInOutArg* arg,
std::vector<std::string>& in_out,
std::vector<int>& format,
VarIndexMap* varmap, int& idx,
bool is_grad) const {
std::string var_name = arg->proto_name_;
if (is_grad) {
var_name += OperatorBase::GRAD_VAR_SUFFIX();
}
(*varmap)[var_name] = idx++;
size_t pre_sz = in_out.size();
auto base_it =
arg->type_ == IN ? op_->inputs_.begin() : op_->outputs_.begin();
std::copy(base_it + arg->begin_idx_, base_it + arg->end_idx_,
std::back_inserter(in_out));
if (is_grad) {
for (size_t i = pre_sz; i < in_out.size(); ++i) {
in_out[i] += OperatorBase::GRAD_VAR_SUFFIX();
}
}
format.push_back(in_out.size());
}
void GradOpCreator::CompleteGradOp(OperatorBase* grad_op) const {
grad_op->type_ = op_->type_ + "@GRAD"; // not necessary
grad_op->attrs_ = op_->attrs_;
grad_op->attrs_.erase("input_format");
grad_op->attrs_.erase("output_format");
VarIndexMap* grad_varmap = new VarIndexMap();
int in_idx = 0;
int out_idx = 0;
std::vector<int> in_format({0});
std::vector<int> out_format({0});
for (const auto& arg : arg_list_) {
// op_'s inputs_ and outputs_
if (arg->needed_in_grad_) {
AddArgIntoGradOp(arg.get(), grad_op->inputs_, in_format, grad_varmap,
in_idx, false);
}
if (arg->type_ == IN) {
// gradients of op_'s inputs_
AddArgIntoGradOp(arg.get(), grad_op->outputs_, out_format, grad_varmap,
out_idx, true);
} else {
// gradients of op_'s outputs_
AddArgIntoGradOp(arg.get(), grad_op->inputs_, in_format, grad_varmap,
in_idx, true);
}
}
grad_op->attrs_["input_format"] = in_format;
grad_op->attrs_["output_format"] = out_format;
grad_op->in_out_idxs_.reset(grad_varmap);
}
} // namespace framework
} // namespace paddle
#pragma once
#include "paddle/framework/op_proto.pb.h"
#include "paddle/framework/operator.h"
namespace paddle {
namespace framework {
class OpRegistry;
enum InOutType { IN, OUT };
struct OpInOutArg {
OpInOutArg(const std::string& proto_name, const InOutType& type,
bool needed_in_grad, size_t begin_idx, size_t end_idx)
: proto_name_(proto_name),
type_(type),
needed_in_grad_(needed_in_grad),
begin_idx_(begin_idx),
end_idx_(end_idx) {}
std::string proto_name_;
InOutType type_;
bool needed_in_grad_;
size_t begin_idx_;
size_t end_idx_;
};
class GradOpCreator {
using VarIndexMap = std::unordered_map<std::string, int>;
public:
GradOpCreator(const OperatorBase* op) : op_(op) {}
OperatorBase* Create();
private:
OpInOutArg* BuildArg(const VarProto& var, const VarIndexMap& var_map,
const std::vector<int>& format, InOutType type);
void BuildOpInOutArgList();
void AddArgIntoGradOp(const OpInOutArg* arg, std::vector<std::string>& in_out,
std::vector<int>& format, VarIndexMap* varmap, int& idx,
bool is_grad) const;
void CompleteGradOp(OperatorBase* grad_op) const;
const OperatorBase* op_;
std::vector<std::shared_ptr<OpInOutArg>> arg_list_;
};
} // namespace framework
} // namespace paddle
#include "paddle/framework/grad_op_creator.h"
#include <gtest/gtest.h>
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
USE_OP(add_two);
namespace paddle {
namespace framework {
TEST(GradOpCreator, AddTwo) {
std::shared_ptr<OperatorBase> add_op(
OpRegistry::CreateOp("add_two", {"x", "y"}, {"out"}, {}));
std::shared_ptr<OperatorBase> grad_add_op = OpRegistry::CreateGradOp(add_op);
EXPECT_EQ(static_cast<int>(grad_add_op->inputs_.size()), 4);
EXPECT_EQ(static_cast<int>(grad_add_op->outputs_.size()), 2);
EXPECT_EQ(grad_add_op->Input("X"), "x");
EXPECT_EQ(grad_add_op->Input("Y"), "y");
EXPECT_EQ(grad_add_op->Input("Out"), "out");
EXPECT_EQ(grad_add_op->Input("Out@GRAD"), "out@GRAD");
EXPECT_EQ(grad_add_op->Output("X@GRAD"), "x@GRAD");
EXPECT_EQ(grad_add_op->Output("Y@GRAD"), "y@GRAD");
}
} // namespace framework
} // namespace paddle
\ No newline at end of file
...@@ -15,14 +15,24 @@
*/

#include "paddle/framework/net.h"
#include "paddle/framework/op_registry.h"

namespace paddle {
namespace framework {

std::shared_ptr<PlainNet> AddBackwardOp(std::shared_ptr<PlainNet> ForwardOps) {
  auto grad_ops = std::make_shared<PlainNet>();
  for (auto& op : ForwardOps->ops_) {
    auto op_grad = OpRegistry::CreateGradOp(op);
    grad_ops->AddOp(op_grad);
  }
  grad_ops->CompleteAddOp();
  return grad_ops;
}

void PlainNet::CompleteAddOp(bool calc) {
  add_op_done_ = true;
  if (!calc) return;

  std::unordered_set<std::string> input_set;
  std::unordered_set<std::string> output_set;
  std::unordered_set<std::string> temp_output;
......
...@@ -100,5 +100,7 @@ class PlainNet : public Net {
  }
};

std::shared_ptr<PlainNet> AddBackwardOp(std::shared_ptr<PlainNet> ForwardOps);

}  // namespace framework
}  // namespace paddle
...@@ -3,17 +3,24 @@
#include <paddle/framework/op_registry.h>
#include <paddle/framework/operator.h>

-namespace pd = paddle::framework;
+USE_OP(add_two);
+USE_OP(mul);
+USE_OP(sigmoid);
+USE_OP(softmax);
+
+namespace paddle {
+namespace framework {

static int infer_shape_cnt = 0;
static int run_cnt = 0;

-class TestOp : public pd::OperatorBase {
+class TestOp : public OperatorBase {
 public:
-  void InferShape(const std::shared_ptr<pd::Scope>& scope) const override {
+  void InferShape(
+      const std::shared_ptr<framework::Scope>& scope) const override {
    ++infer_shape_cnt;
  }
-  void Run(const std::shared_ptr<pd::Scope>& scope,
+  void Run(const std::shared_ptr<framework::Scope>& scope,
           const paddle::platform::DeviceContext& dev_ctx) const override {
    ++run_cnt;
  }
...@@ -33,7 +40,7 @@ void AssertSameVectorWithoutOrder(const std::vector<T>& expected,
}

TEST(OpKernel, all) {
-  auto net = std::make_shared<paddle::framework::PlainNet>();
+  auto net = std::make_shared<PlainNet>();
  ASSERT_NE(net, nullptr);

  auto op1 = std::make_shared<TestOp>();
...@@ -55,13 +62,37 @@ TEST(OpKernel, all) {
  ASSERT_EQ(1UL, tmp_idx.size());
  ASSERT_EQ("y", net->outputs_[tmp_idx[0]]);

-  auto scope = std::make_shared<pd::Scope>();
-  paddle::platform::CPUDeviceContext dev_ctx;
+  auto scope = std::make_shared<Scope>();
+  platform::CPUDeviceContext dev_ctx;

  net->InferShape(scope);
  net->Run(scope, dev_ctx);
  ASSERT_EQ(2, infer_shape_cnt);
  ASSERT_EQ(2, run_cnt);
  ASSERT_THROW(net->AddOp(op2), std::runtime_error);
}
TEST(AddBackwardOp, TestGradOp) {
auto net = std::make_shared<PlainNet>();
ASSERT_NE(net, nullptr);
net->AddOp(framework::OpRegistry::CreateOp("mul", {"X", "Y"}, {"Out"}, {}));
net->AddOp(
framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {"Out"}, {}));
net->AddOp(framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {""}, {}));
auto grad_ops = AddBackwardOp(net);
for (auto& op : grad_ops->ops_) {
op->DebugString();
}
}
// TODO(zhihong): add fc grad without registering.
// TEST(AddBackwardOp, TestNoGradOp) {
// auto net = std::make_shared<PlainNet>();
// ASSERT_NE(net, nullptr);
// net->AddOp(framework::OpRegistry::CreateOp("fc", {"X", "W", "b"}, {"Y"},
// {})); auto grad_ops = AddBackwardOp(net); for (auto& op : grad_ops->ops_) {
// op->DebugString();
// }
// }
} // namespace framework
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/net.h"
#include "paddle/framework/op_registry.h"
#include <gtest/gtest.h>
namespace paddle {
namespace framework {
class FakeFC : public Operator {}
} // namespace framework
} // namespace paddle
...@@ -84,6 +84,11 @@ message VarProto {
  //     "temporary_index": [1]
  //   }
  optional bool temporary = 4 [default=false];

  // The gradient of this operator can be ignored immediately.
  // e.g. for operator AddOp, y = x1 + x2, the gradients dy/dx1 and dy/dx2
  // can be ignored, allowing future optimization of the graph.
  optional bool ignore_gradient = 6;
}

// Op protocol message for 3rd-party language binding.
...@@ -105,4 +110,5 @@ message OpProto {
  // The type of that Op.
  required string type = 5;
}
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once

#include <algorithm>
...@@ -6,9 +20,9 @@
#include <unordered_map>
#include <unordered_set>
#include "paddle/framework/attr_checker.h"
+#include "paddle/framework/grad_op_creator.h"
#include "paddle/framework/op_desc.pb.h"
-#include "paddle/framework/op_proto.pb.h"
-#include "paddle/framework/operator.h"
+#include "paddle/framework/scope.h"

namespace paddle {
namespace framework {
...@@ -73,25 +87,29 @@ class OpProtoAndCheckerMaker {
 protected:
  void AddInput(const std::string& name, const std::string& comment,
-                bool multiple = false) {
+                bool multiple = false, bool ignore_gradient = false) {
    auto input = proto_->mutable_inputs()->Add();
    *input->mutable_name() = name;
    *input->mutable_comment() = comment;
+    input->set_ignore_gradient(ignore_gradient);
    input->set_multiple(multiple);
    if (multiple) {
      SetHasMultipleInput();
    }
  }

-  void AddInputs(const std::string& name, const std::string& comment) {
-    AddInput(name, comment, true);
+  void AddInputs(const std::string& name, const std::string& comment,
+                 bool ignore_gradient = false) {
+    AddInput(name, comment, true, ignore_gradient);
  }

  void AddOutput(const std::string& name, const std::string& comment,
-                 bool temporary = false, bool multiple = false) {
+                 bool temporary = false, bool multiple = false,
+                 bool ignore_gradient = false) {
    auto output = proto_->mutable_outputs()->Add();
    *output->mutable_name() = name;
    *output->mutable_comment() = comment;
+    output->set_ignore_gradient(ignore_gradient);
    output->set_multiple(multiple);
    if (multiple) {
      SetHasMultipleOutput();
...@@ -103,8 +121,8 @@ class OpProtoAndCheckerMaker {
  }

  void AddOutputs(const std::string& name, const std::string& comment,
-                  bool temporary = false) {
-    AddOutput(name, comment, temporary, true);
+                  bool temporary = false, bool ignore_gradient = false) {
+    AddOutput(name, comment, temporary, true, ignore_gradient);
  }

  template <typename T>
...@@ -205,8 +223,8 @@ class OpRegistry {
  template <typename OpType, typename ProtoMakerType>
  static void RegisterOp(const std::string& op_type) {
    creators()[op_type] = [] { return new OpType; };
-    OpProto& op_proto = protos()[op_type];
    OpAttrChecker& op_checker = op_checkers()[op_type];
+    OpProto& op_proto = protos()[op_type];
    auto maker = ProtoMakerType(&op_proto, &op_checker);
    maker.Validate();
    *op_proto.mutable_type() = op_type;
...@@ -227,18 +245,24 @@ class OpRegistry {
    }
  }
template <typename OpType>
static void RegisterGradOp(const std::string& op_type) {
grad_creators()[op_type] = [] { return new OpType; };
}
  static std::shared_ptr<OperatorBase> CreateOp(const std::string& type,
                                                const VarNameList& inputs,
                                                const VarNameList& outputs,
                                                const AttributeMap& attrs) {
    auto op_create_it = creators().find(type);
    PADDLE_ENFORCE(op_create_it != creators().end(),
-                   "Operator %s cannot be found", type);
+                   "Operator %s cannot be found.", type);
    auto op = op_create_it->second();
    op->type_ = type;
    op->inputs_ = inputs;
    op->outputs_ = outputs;
    op->attrs_ = attrs;
    op_checkers().at(type).Check(op->attrs_);
...@@ -274,18 +298,41 @@ class OpRegistry {
    return CreateOp(op_desc.type(), inputs, outputs, attrs);
  }
static std::shared_ptr<OperatorBase> CreateGradOp(
std::shared_ptr<OperatorBase> op) {
GradOpCreator creator(op.get());
std::shared_ptr<OperatorBase> grad_op(creator.Create());
grad_op->Init();
return grad_op;
}
  static std::unordered_map<std::string, OpProto>& protos() {
    static std::unordered_map<std::string, OpProto> protos_;
    return protos_;
  };

-private:
+  static std::unordered_map<std::string, OpCreator>& grad_creators() {
+    static std::unordered_map<std::string, OpCreator> grad_creators_;
+    return grad_creators_;
+  }

  static std::unordered_map<std::string, std::shared_ptr<VarIndexMap>>&
  VarIndexMaps() {
    static std::unordered_map<std::string, std::shared_ptr<VarIndexMap>> maps_;
    return maps_;
  }
private:
static std::unordered_map<std::string, OpCreator>& creators() {
static std::unordered_map<std::string, OpCreator> creators_;
return creators_;
}
static std::unordered_map<std::string, OpAttrChecker>& op_checkers() {
static std::unordered_map<std::string, OpAttrChecker> op_checkers_;
return op_checkers_;
};
  static void GenerateTempVariableName(OperatorBase* op) {
    static std::atomic<size_t> gUniqId(0UL);
    for (auto& outname : op->outputs_) {
...@@ -296,16 +343,6 @@ class OpRegistry {
      }
    }
  }
static std::unordered_map<std::string, OpCreator>& creators() {
static std::unordered_map<std::string, OpCreator> creators_;
return creators_;
}
static std::unordered_map<std::string, OpAttrChecker>& op_checkers() {
static std::unordered_map<std::string, OpAttrChecker> op_checkers_;
return op_checkers_;
};
};

template <typename OpType, typename ProtoMakerType>
...@@ -316,6 +353,14 @@ class OpRegisterHelper {
  }
};
template <typename OpType>
class GradOpRegisterHelper {
public:
GradOpRegisterHelper(const char* op_type) {
OpRegistry::RegisterGradOp<OpType>(op_type);
}
};
/**
 * check if MACRO is used in GLOBAL NAMESPACE.
 */
...@@ -335,6 +380,17 @@ class OpRegisterHelper {
      __op_register_##__op_type##__(#__op_type);                            \
  int __op_register_##__op_type##_handle__() { return 0; }
/**
* Macro to Register Gradient Operator.
*/
#define REGISTER_GRADIENT_OP(__op_type, __op_class) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_gradient_op__##__op_type, \
"REGISTER_GRADIENT_OP must be in global namespace"); \
static ::paddle::framework::GradOpRegisterHelper<__op_class> \
__op_gradient_register_##__op_type##__(#__op_type); \
int __op_gradient_register_##__op_type##_handle__() { return 0; }
/**
 * Macro to Register OperatorKernel.
 */
......
...@@ -62,6 +62,11 @@ class OperatorBase {
  /// but it will be converted to a unique name in scope after OpCreator.
  static std::string TMP_VAR_NAME() { return "@TEMP@"; }

  /// If a variable's name has a certain suffix, it means that the
  /// variable is the gradient of another variable.
  /// e.g. Variable "x@GRAD" is the gradient of variable "x".
  static std::string GRAD_VAR_SUFFIX() { return "@GRAD"; }

  virtual ~OperatorBase() {}

  template <typename T>
......
...@@ -49,9 +49,22 @@ The equation is: Out = X + Y
)DOC");
  }
};
class AddOpGrad : public framework::OperatorWithKernel {
protected:
void InferShape(
const std::vector<const framework::Tensor *> &inputs,
const std::vector<framework::Tensor *> &outputs) const override {}
std::string DebugString() const override {
LOG(INFO) << "AddOpGrad";
return "";
}
};
}  // namespace operators
}  // namespace paddle

REGISTER_OP(add_two, paddle::operators::AddOp, paddle::operators::AddOpMaker);
REGISTER_GRADIENT_OP(add_two, paddle::operators::AddOpGrad);
REGISTER_OP_CPU_KERNEL(
    add_two, paddle::operators::AddKernel<paddle::platform::CPUPlace, float>);
...@@ -16,8 +16,13 @@ limitations under the License. */
#define private public
#include <paddle/framework/op_registry.h>

USE_OP(add_two);
// USE_OP(add_two_grad);

TEST(AddOp, GetOpProto) {
  auto& protos = paddle::framework::OpRegistry::protos();
  auto it = protos.find("add_two");
  ASSERT_NE(it, protos.end());
  auto& grad_creators = paddle::framework::OpRegistry::grad_creators();
  auto it1 = grad_creators.find("add_two");
  ASSERT_NE(it1, grad_creators.end());
}
...@@ -52,9 +52,22 @@ The equation is: Out = X * Y
  }
};
class MulOpGrad : public framework::OperatorWithKernel {
protected:
void InferShape(
const std::vector<const framework::Tensor *> &inputs,
const std::vector<framework::Tensor *> &outputs) const override {}
std::string DebugString() const override {
LOG(INFO) << "MulGrad";
return "";
}
};
}  // namespace operators
}  // namespace paddle

REGISTER_OP(mul, paddle::operators::MulOp, paddle::operators::MulOpMaker);
REGISTER_GRADIENT_OP(mul, paddle::operators::MulOpGrad);
REGISTER_OP_CPU_KERNEL(
    mul, paddle::operators::MulKernel<paddle::platform::CPUPlace, float>);
...@@ -39,12 +39,25 @@ public:
  }
};
class SigmoidOpGrad : public framework::OperatorWithKernel {
protected:
void InferShape(
const std::vector<const framework::Tensor *> &inputs,
const std::vector<framework::Tensor *> &outputs) const override {}
std::string DebugString() const override {
LOG(INFO) << "SigmoidGrad";
return "";
}
};
}  // namespace operators
}  // namespace paddle

REGISTER_OP(sigmoid,
            paddle::operators::SigmoidOp,
            paddle::operators::SigmoidOpMaker);
REGISTER_GRADIENT_OP(sigmoid, paddle::operators::SigmoidOpGrad);
REGISTER_OP_CPU_KERNEL(
    sigmoid,
    paddle::operators::SigmoidKernel<paddle::platform::CPUPlace, float>);
...@@ -42,11 +42,23 @@ public:
  }
};
class SoftmaxOpGrad : public framework::OperatorWithKernel {
protected:
void InferShape(
const std::vector<const framework::Tensor *> &inputs,
const std::vector<framework::Tensor *> &outputs) const override {}
std::string DebugString() const override {
LOG(INFO) << "SoftmaxOpGrad";
return "";
}
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OP(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker);
REGISTER_GRADIENT_OP(softmax, paddle::operators::SoftmaxOpGrad);
REGISTER_OP_CPU_KERNEL(softmax,
                       ops::SoftmaxKernel<paddle::platform::CPUPlace, float>);
...@@ -33,6 +33,7 @@ import networks
import minibatch
import plot
import image
import model

__all__ = [
    'optimizer',
...@@ -54,6 +55,7 @@ __all__ = [
    'evaluator',
    'image',
    'master',
    'model',
]
......
...@@ -10,11 +10,31 @@ class client(object):
    client is a client to the master server.
    """

-    def __init__(self, etcd_endpoints, timeout, buf_size):
-        self.c = lib.paddle_new_etcd_master_client(etcd_endpoints, timeout,
+    def __init__(self, etcd_endpoints, timeout_sec, buf_size=0):
+        self.c = lib.paddle_new_etcd_master_client(etcd_endpoints, timeout_sec,
                                                    buf_size)

-    def close(self):
+    def request_save_model(self, trainer_id, block_ms):
"""request to save model
Conventionally the 0-th trainer will save model. But in
distributed training, any trainer could be killed. This
function asks the master server if the trainer should proceed
with saving model.
:param trainer_id: trainer id.
:param block_ms: number of milliseconds for which other save-model
requests will be blocked if this save-model request succeeds.

Returns:
int: 1 if the save-model request is approved, 0 if the request
is rejected because another trainer is saving the model, -1 if
an error happened.
"""
return lib.paddle_request_save_model(self.c, trainer_id, block_ms)
def release(self):
        lib.paddle_release_master_client(self.c)
        self.c = None
...@@ -27,10 +47,13 @@ class client(object):
            holder[idx] = c_ptr
        lib.paddle_set_dataset(self.c, holder, len(paths))

-    # return format: (record, errno)
-    # errno = 0: ok
-    #       < 0: error
    def next_record(self):
"""gets next record for training
Returns:
string: the record.
int: error code, 0 if successful, < 0 otherwise.
"""
        p = ctypes.c_char_p()
        ret = ctypes.pointer(p)
        size = lib.paddle_next_record(self.c, ret)
......
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import errno
import uuid
import paddle.v2.master
__all__ = ["save_model", "load_model"]
trainer_id = str(uuid.uuid4())
def mkdir_p(path):
try:
os.makedirs(path)
except OSError as exc:
if exc.errno == errno.EEXIST and os.path.isdir(path):
pass
else:
raise
def save_model(parameters, path):
need_request = "KUBERNETES_SERVICE_HOST" in os.environ.keys()
if need_request:
# TODO(helin): figure out how MPI trains, since MPI only save
# model when trainer_id == "0", we can consolidate the logic
# here.
# TODO(helin): change this environment variable name from
# MASTER_IP to ETCD_IP
etcd_name = "MASTER_IP"
if etcd_name not in os.environ.keys():
raise Exception('not find ' + etcd_name +
' in environment variable.')
etcd_ip = os.environ.get(etcd_name)
client = master.client("http://" + etcd_ip + ":2379", 5, 0)
r = client.request_save_model(trainer_id, 5000)
if r == 0:
# do not need to save
return
elif r < 0:
# error
return
else:
# save model
path = os.path.join(path, trainer_id)
path = os.path.join(path, "model.tar")
mkdir_p(path)
with open(path, 'wb') as f:
parameters.to_tar(f)
def load_model(parameters, path):
with open(path, 'rb') as f:
parameters.from_tar(f)
...@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
-Creator package contains some simple reader creator, which could be used in user
-program.
+Creator package contains some simple reader creators, which could
+be used in user programs.
"""
__all__ = ['np_array', 'text_file', "recordio"]
...@@ -96,7 +96,7 @@ def recordio(paths, buf_size=100):
    host_name = "MASTER_SERVICE_HOST"
    if host_name not in os.environ.keys():
-        raise Exception('could not find ' + host_name + ' in environ.')
+        raise Exception('could not find ' + host_name + ' in environment variables.')
    addr = os.environ(host)
...@@ -110,6 +110,6 @@ def recordio(paths, buf_size=100):
            break
        yield r

-    c.close()
+    c.release()
    return reader