提交 c7e8c1aa 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #6 from reyoung/feature/refactorize_framework_proto

Feature/refactorize framework proto
...@@ -28,7 +28,7 @@ RUN apt-get update && \ ...@@ -28,7 +28,7 @@ RUN apt-get update && \
wget unzip unrar tar xz-utils bzip2 gzip coreutils ntp \ wget unzip unrar tar xz-utils bzip2 gzip coreutils ntp \
curl sed grep graphviz libjpeg-dev zlib1g-dev \ curl sed grep graphviz libjpeg-dev zlib1g-dev \
python-matplotlib gcc-4.8 g++-4.8 \ python-matplotlib gcc-4.8 g++-4.8 \
automake locales clang-format-3.8 swig doxygen cmake \ automake locales clang-format swig doxygen cmake \
liblapack-dev liblapacke-dev libboost-dev \ liblapack-dev liblapacke-dev libboost-dev \
clang-3.8 llvm-3.8 libclang-3.8-dev \ clang-3.8 llvm-3.8 libclang-3.8-dev \
net-tools && \ net-tools && \
......
hash: 1b9b07408ca7fac27a374dc2ccd2433e4bff090484008a037df967284949a582 hash: 1b9b07408ca7fac27a374dc2ccd2433e4bff090484008a037df967284949a582
updated: 2017-08-03T21:46:51.744995189Z updated: 2017-08-07T23:37:48.867469328Z
imports: imports:
- name: github.com/beorn7/perks - name: github.com/beorn7/perks
version: 4c0e84591b9aa9e6dcfdf3e020114cd81f89d5f9 version: 4c0e84591b9aa9e6dcfdf3e020114cd81f89d5f9
...@@ -10,7 +10,7 @@ imports: ...@@ -10,7 +10,7 @@ imports:
- name: github.com/cockroachdb/cmux - name: github.com/cockroachdb/cmux
version: 112f0506e7743d64a6eb8fedbcff13d9979bbf92 version: 112f0506e7743d64a6eb8fedbcff13d9979bbf92
- name: github.com/coreos/etcd - name: github.com/coreos/etcd
version: c31bec0f29facff13f7c3e3d948e55dd6689ed42 version: d0d1a87aa96ae14914751d42264262cb69eda170
subpackages: subpackages:
- alarm - alarm
- auth - auth
...@@ -24,6 +24,7 @@ imports: ...@@ -24,6 +24,7 @@ imports:
- error - error
- etcdserver - etcdserver
- etcdserver/api - etcdserver/api
- etcdserver/api/etcdhttp
- etcdserver/api/v2http - etcdserver/api/v2http
- etcdserver/api/v2http/httptypes - etcdserver/api/v2http/httptypes
- etcdserver/api/v3client - etcdserver/api/v3client
...@@ -210,11 +211,6 @@ testImports: ...@@ -210,11 +211,6 @@ testImports:
version: 04cdfd42973bb9c8589fd6a731800cf222fde1a9 version: 04cdfd42973bb9c8589fd6a731800cf222fde1a9
subpackages: subpackages:
- spew - spew
- name: github.com/docker/docker
version: b6d164e6c46d8115b146e4c3ac93784e9ef8b49e
subpackages:
- pkg/ioutils
- pkg/longpath
- name: github.com/pmezard/go-difflib - name: github.com/pmezard/go-difflib
version: d8ed2627bdf02c080bf22230dbb337003b7aba2d version: d8ed2627bdf02c080bf22230dbb337003b7aba2d
subpackages: subpackages:
......
package master_test package master_test
import ( import (
"io/ioutil"
"net/url"
"os" "os"
"strings"
"testing" "testing"
"time" "time"
"github.com/PaddlePaddle/Paddle/go/master" "github.com/PaddlePaddle/Paddle/go/master"
"github.com/coreos/etcd/clientv3" "github.com/coreos/etcd/clientv3"
"github.com/coreos/etcd/embed" "github.com/coreos/etcd/embed"
"github.com/docker/docker/pkg/ioutils"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestNewServiceWithEtcd(t *testing.T) { func TestNewServiceWithEtcd(t *testing.T) {
// setup an embed etcd server // setup an embed etcd server
etcdDir, err := ioutils.TempDir("", "") etcdDir, err := ioutil.TempDir("", "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
cfg := embed.NewConfig() cfg := embed.NewConfig()
lpurl, _ := url.Parse("http://localhost:0")
lcurl, _ := url.Parse("http://localhost:0")
cfg.LPUrls = []url.URL{*lpurl}
cfg.LCUrls = []url.URL{*lcurl}
cfg.Dir = etcdDir cfg.Dir = etcdDir
e, err := embed.StartEtcd(cfg) e, err := embed.StartEtcd(cfg)
if err != nil { if err != nil {
...@@ -30,15 +36,13 @@ func TestNewServiceWithEtcd(t *testing.T) { ...@@ -30,15 +36,13 @@ func TestNewServiceWithEtcd(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
}() }()
select {
case <-e.Server.ReadyNotify():
t.Log("Server is ready!")
case <-time.After(60 * time.Second):
e.Server.Stop() // trigger a shutdown
t.Fatal("Server took too long to start!")
}
ep := []string{"127.0.0.1:2379"} <-e.Server.ReadyNotify()
port := strings.Split(e.Clients[0].Addr().String(), ":")[1]
endpoint := "127.0.0.1:" + port
ep := []string{endpoint}
masterAddr := "127.0.0.1:3306" masterAddr := "127.0.0.1:3306"
store, err := master.NewEtcdClient(ep, masterAddr, master.DefaultLockPath, master.DefaultAddrPath, master.DefaultStatePath, 30) store, err := master.NewEtcdClient(ep, masterAddr, master.DefaultLockPath, master.DefaultAddrPath, master.DefaultStatePath, 30)
if err != nil { if err != nil {
......
...@@ -90,8 +90,12 @@ func cArrayToSlice(p unsafe.Pointer, len int) []byte { ...@@ -90,8 +90,12 @@ func cArrayToSlice(p unsafe.Pointer, len int) []byte {
type selector bool type selector bool
func (s selector) Select() bool { func (s selector) Select() (bool, error) {
return bool(s) return bool(s), nil
}
func (s selector) Done() error {
return nil
} }
type lister []client.Server type lister []client.Server
...@@ -114,11 +118,10 @@ func paddle_new_pserver_client(addrs *C.char, selected int) C.paddle_pserver_cli ...@@ -114,11 +118,10 @@ func paddle_new_pserver_client(addrs *C.char, selected int) C.paddle_pserver_cli
} }
//export paddle_new_etcd_pserver_client //export paddle_new_etcd_pserver_client
func paddle_new_etcd_pserver_client(etcdEndpoints *C.char, selected int) C.paddle_pserver_client { func paddle_new_etcd_pserver_client(etcdEndpoints *C.char) C.paddle_pserver_client {
// TODO(Longfei: use etcd lock to decide which trainer to initialize the parameters)
addr := C.GoString(etcdEndpoints) addr := C.GoString(etcdEndpoints)
etcdClient := client.NewEtcd(addr) etcdClient := client.NewEtcd(addr)
c := client.NewClient(etcdClient, etcdClient.Desired(), selector(selected != 0)) c := client.NewClient(etcdClient, etcdClient.Desired(), etcdClient)
return add(c) return add(c)
} }
...@@ -136,7 +139,12 @@ func paddle_pserver_client_release(client C.paddle_pserver_client) { ...@@ -136,7 +139,12 @@ func paddle_pserver_client_release(client C.paddle_pserver_client) {
//export paddle_begin_init_params //export paddle_begin_init_params
func paddle_begin_init_params(client C.paddle_pserver_client) C.int { func paddle_begin_init_params(client C.paddle_pserver_client) C.int {
c := get(client) c := get(client)
if selected := c.BeginInitParams(); selected { selected, err := c.BeginInitParams()
if err != nil {
panic(err)
}
if selected {
return 1 return 1
} }
return 0 return 0
......
...@@ -27,9 +27,13 @@ import ( ...@@ -27,9 +27,13 @@ import (
// TODO(helin): add RPC call retry logic // TODO(helin): add RPC call retry logic
// Selector selects if the client should initialize parameter servers. // Selector selects if the client should initialize parameters and
// reports the initialization process done.
type Selector interface { type Selector interface {
Select() bool // Select selects if the client should initialize parameter servers.
Select() (bool, error)
// Done indicates the initialization process is done.
Done() error
} }
// Server is the identification of a parameter Server. // Server is the identification of a parameter Server.
...@@ -115,7 +119,7 @@ func (c *Client) monitorPservers(l Lister, pserverNum int) { ...@@ -115,7 +119,7 @@ func (c *Client) monitorPservers(l Lister, pserverNum int) {
// servers. Other trainers will be blocked until the initialization is // servers. Other trainers will be blocked until the initialization is
// done, and they need to get the initialized parameters from // done, and they need to get the initialized parameters from
// parameter servers using GetParams. // parameter servers using GetParams.
func (c *Client) BeginInitParams() bool { func (c *Client) BeginInitParams() (bool, error) {
return c.sel.Select() return c.sel.Select()
} }
......
...@@ -124,8 +124,12 @@ func initEtcdClient() { ...@@ -124,8 +124,12 @@ func initEtcdClient() {
type selector bool type selector bool
func (s selector) Select() bool { func (s selector) Select() (bool, error) {
return bool(s) return bool(s), nil
}
func (s selector) Done() error {
return nil
} }
type lister []client.Server type lister []client.Server
...@@ -135,7 +139,11 @@ func (l lister) List() []client.Server { ...@@ -135,7 +139,11 @@ func (l lister) List() []client.Server {
} }
func testClient(t *testing.T, c *client.Client) { func testClient(t *testing.T, c *client.Client) {
selected := c.BeginInitParams() selected, err := c.BeginInitParams()
if err != nil {
t.Fatal(err)
}
if !selected { if !selected {
t.Fatal("should be selected.") t.Fatal("should be selected.")
} }
......
...@@ -16,53 +16,60 @@ package client ...@@ -16,53 +16,60 @@ package client
import ( import (
"context" "context"
"errors"
"fmt"
"strconv" "strconv"
"strings" "strings"
"time" "time"
"github.com/PaddlePaddle/Paddle/go/pserver" "github.com/PaddlePaddle/Paddle/go/pserver"
"github.com/coreos/etcd/clientv3" "github.com/coreos/etcd/clientv3"
"github.com/coreos/etcd/clientv3/concurrency"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
const ( const (
defaultEtcdTimeout time.Duration = 5 * time.Second defaultEtcdTimeout time.Duration = 5 * time.Second
initLockPath = "/init_ps/lock"
initDonePath = "/init_ps/done"
initDoneVal = "1"
) )
// EtcdClient is used by pserver client that is a part of trainer process. // Etcd is used by pserver client that is a part of trainer process.
// TODO: // TODO:
// 1. add watcher to watch the change state of pservers) // 1. add watcher to watch the change state of pservers.
// 1. add etcd lock) type Etcd struct {
type EtcdClient struct {
client *clientv3.Client client *clientv3.Client
timeout time.Duration timeout time.Duration
endpoints []string endpoints []string
lock *concurrency.Mutex
} }
// Desired read ps desired number from etcd. // Desired read ps desired number from etcd.
func (p *EtcdClient) Desired() int { func (e *Etcd) Desired() int {
var psDesired int var psDesired int
for { for {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout) ctx, cancel := context.WithTimeout(context.Background(), e.timeout)
resp, err := p.client.Get(ctx, pserver.PsDesired) resp, err := e.client.Get(ctx, pserver.PsDesired)
cancel() cancel()
if err != nil { if err != nil {
log.Errorf("Get ps dresire number failed! recnnectiong..., %v", err) log.Errorf("Get ps dresire number failed! recnnectiong..., %v", err)
time.Sleep(p.timeout) time.Sleep(e.timeout)
continue continue
} }
kvs := resp.Kvs kvs := resp.Kvs
if len(kvs) == 0 { if len(kvs) == 0 {
log.Infoln("Waiting for ps desired registered ...") log.Infoln("Waiting for ps desired registered ...")
time.Sleep(p.timeout) time.Sleep(e.timeout)
continue continue
} }
psDesired, err = strconv.Atoi(string(resp.Kvs[0].Value)) psDesired, err = strconv.Atoi(string(resp.Kvs[0].Value))
if err != nil { if err != nil {
log.Errorf("psDesired %d invalid %v", psDesired, err) log.Errorf("psDesired %d invalid %v", psDesired, err)
time.Sleep(p.timeout) time.Sleep(e.timeout)
continue continue
} }
...@@ -73,26 +80,26 @@ func (p *EtcdClient) Desired() int { ...@@ -73,26 +80,26 @@ func (p *EtcdClient) Desired() int {
} }
// List return the pserver list read from etcd. // List return the pserver list read from etcd.
func (p *EtcdClient) List() []Server { func (e *Etcd) List() []Server {
psDesired := p.Desired() psDesired := e.Desired()
servers := make([]Server, psDesired) servers := make([]Server, psDesired)
for { for {
for i := 0; i < psDesired; i++ { for i := 0; i < psDesired; i++ {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout) ctx, cancel := context.WithTimeout(context.Background(), e.timeout)
psKey := pserver.PsPath + strconv.Itoa(i) psKey := pserver.PsPath + strconv.Itoa(i)
log.Debugf("checking %s", psKey) log.Debugf("checking %s", psKey)
resp, err := p.client.Get(ctx, psKey) resp, err := e.client.Get(ctx, psKey)
cancel() cancel()
if err != nil { if err != nil {
log.Infof("Get psKey= %s error, %v", psKey, err) log.Infof("Get psKey= %s error, %v", psKey, err)
time.Sleep(p.timeout) time.Sleep(e.timeout)
continue continue
} }
kvs := resp.Kvs kvs := resp.Kvs
if len(kvs) == 0 { if len(kvs) == 0 {
log.Infof("Waiting for ps addr registered ...") log.Infof("Waiting for ps addr registered ...")
time.Sleep(p.timeout) time.Sleep(e.timeout)
continue continue
} }
...@@ -100,7 +107,7 @@ func (p *EtcdClient) List() []Server { ...@@ -100,7 +107,7 @@ func (p *EtcdClient) List() []Server {
// TODO(Longfei) check the ps address // TODO(Longfei) check the ps address
if psAddr == "" { if psAddr == "" {
log.Infof("Get psKey = %s, psAddr is empty", psKey) log.Infof("Get psKey = %s, psAddr is empty", psKey)
time.Sleep(p.timeout) time.Sleep(e.timeout)
continue continue
} }
log.Debugf("got value (%s) for key: %s", psAddr, psKey) log.Debugf("got value (%s) for key: %s", psAddr, psKey)
...@@ -113,7 +120,7 @@ func (p *EtcdClient) List() []Server { ...@@ -113,7 +120,7 @@ func (p *EtcdClient) List() []Server {
} }
// NewEtcd create a etcd client to return the state of pserver on etcd. // NewEtcd create a etcd client to return the state of pserver on etcd.
func NewEtcd(endpoints string) *EtcdClient { func NewEtcd(endpoints string) *Etcd {
ep := strings.Split(endpoints, ",") ep := strings.Split(endpoints, ",")
var cli *clientv3.Client var cli *clientv3.Client
var err error var err error
...@@ -130,10 +137,118 @@ func NewEtcd(endpoints string) *EtcdClient { ...@@ -130,10 +137,118 @@ func NewEtcd(endpoints string) *EtcdClient {
break break
} }
log.Infof("Connected to etcd: %s\n", endpoints) log.Infof("Connected to etcd: %s\n", endpoints)
client := &EtcdClient{ client := &Etcd{
client: cli, client: cli,
timeout: defaultEtcdTimeout, timeout: defaultEtcdTimeout,
endpoints: ep, endpoints: ep,
} }
return client return client
} }
// Select indicates if the current trainer is selected to initialize
// the pserver parameters.
func (e *Etcd) Select() (bool, error) {
sess, err := concurrency.NewSession(e.client, concurrency.WithTTL(5))
if err != nil {
return false, err
}
lock := concurrency.NewMutex(sess, initLockPath)
log.Infof("Trying to acquire lock at %s.", initLockPath)
// Do not use timeout context here, since we don't know how
// long does it take for other trainers to initialize the
// parameters.
err = lock.Lock(context.Background())
if err != nil {
return false, err
}
log.Infof("Successfully acquired lock at %s.", initLockPath)
get := clientv3.OpGet(initDonePath)
ctx, cancel := context.WithTimeout(context.Background(), e.timeout)
tresp, err := e.client.Txn(ctx).If(lock.IsOwner()).Then(get).Commit()
cancel()
if err != nil {
return false, err
}
if !tresp.Succeeded {
return false, errors.New("no longer the owner of the lock")
}
resp := tresp.Responses[0].GetResponseRange()
if len(resp.Kvs) == 0 {
// Key value not set, select current trainer.
e.lock = lock
log.Infoln("Trainer selected.")
return true, nil
}
if string(resp.Kvs[0].Value) == initDoneVal {
log.Infoln("Initialization is already done.")
ctx, cancel = context.WithTimeout(context.Background(), e.timeout)
err = lock.Unlock(ctx)
cancel()
if err != nil {
log.Errorln(err)
}
return false, nil
}
return false, fmt.Errorf("key %s have unexpected value: %v", initDonePath, resp.Kvs[0].Value)
}
// Done indicates the parameter initialization process is done.
func (e *Etcd) Done() error {
if e.lock == nil {
return errors.New("lock is nil, Done called unexpectedly")
}
put := clientv3.OpPut(initDonePath, initDoneVal)
ctx, cancel := context.WithTimeout(context.Background(), e.timeout)
tresp, err := e.client.Txn(ctx).If(e.lock.IsOwner()).Then(put).Commit()
cancel()
if err != nil {
return err
}
if !tresp.Succeeded {
return errors.New("no longer the owner of the lock")
}
ctx, cancel = context.WithTimeout(context.Background(), e.timeout)
err = e.lock.Unlock(ctx)
cancel()
if err != nil {
log.Errorln(err)
} else {
e.lock = nil
}
return nil
}
// Close closes the etcd client.
func (e *Etcd) Close() error {
var err error
if e.lock != nil {
ctx, cancel := context.WithTimeout(context.Background(), e.timeout)
err = e.lock.Unlock(ctx)
cancel()
if err == nil {
e.lock = nil
}
}
cErr := e.client.Close()
if cErr != nil {
if err != nil {
log.Errorln(cErr)
return err
}
return cErr
}
return err
}
package client_test
import (
"io/ioutil"
"net/url"
"os"
"strings"
"sync"
"testing"
"github.com/PaddlePaddle/Paddle/go/pserver/client"
"github.com/coreos/etcd/embed"
)
func TestSelector(t *testing.T) {
etcdDir, err := ioutil.TempDir("", "")
if err != nil {
t.Fatal(err)
}
cfg := embed.NewConfig()
lpurl, _ := url.Parse("http://localhost:0")
lcurl, _ := url.Parse("http://localhost:0")
cfg.LPUrls = []url.URL{*lpurl}
cfg.LCUrls = []url.URL{*lcurl}
cfg.Dir = etcdDir
e, err := embed.StartEtcd(cfg)
if err != nil {
t.Fatal(err)
}
defer func() {
e.Close()
if err := os.RemoveAll(etcdDir); err != nil {
t.Fatal(err)
}
}()
<-e.Server.ReadyNotify()
port := strings.Split(e.Clients[0].Addr().String(), ":")[1]
endpoint := "127.0.0.1:" + port
var mu sync.Mutex
selectedCount := 0
var wg sync.WaitGroup
selectAndDone := func(c *client.Etcd) {
defer wg.Done()
selected, err := c.Select()
if err != nil {
panic(err)
}
if selected {
mu.Lock()
selectedCount++
mu.Unlock()
err = c.Done()
if err != nil {
t.Fatal(err)
}
}
}
c0 := client.NewEtcd(endpoint)
c1 := client.NewEtcd(endpoint)
c2 := client.NewEtcd(endpoint)
c3 := client.NewEtcd(endpoint)
wg.Add(3)
go selectAndDone(c0)
go selectAndDone(c1)
go selectAndDone(c2)
wg.Wait()
// simulate trainer crashed and restarted after the
// initialization process.
wg.Add(1)
go selectAndDone(c3)
wg.Wait()
mu.Lock()
if selectedCount != 1 {
t.Fatal("selected count wrong:", selectedCount)
}
mu.Unlock()
err = c0.Close()
if err != nil {
t.Fatal(err)
}
err = c1.Close()
if err != nil {
t.Fatal(err)
}
err = c2.Close()
if err != nil {
t.Fatal(err)
}
err = c3.Close()
if err != nil {
t.Fatal(err)
}
}
...@@ -7,6 +7,9 @@ cc_library(tensor SRCS tensor.cc DEPS ddim place paddle_memory device_context) ...@@ -7,6 +7,9 @@ cc_library(tensor SRCS tensor.cc DEPS ddim place paddle_memory device_context)
cc_test(tensor_test SRCS tensor_test.cc DEPS tensor) cc_test(tensor_test SRCS tensor_test.cc DEPS tensor)
cc_test(eigen_test SRCS eigen_test.cc DEPS tensor) cc_test(eigen_test SRCS eigen_test.cc DEPS tensor)
cc_library(lod_tensor SRCS lod_tensor.cc details/lod_tensor.cc DEPS ddim place tensor)
cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor)
cc_test(variable_test SRCS variable_test.cc) cc_test(variable_test SRCS variable_test.cc)
cc_library(scope SRCS scope.cc) cc_library(scope SRCS scope.cc)
......
...@@ -147,8 +147,9 @@ std::shared_ptr<OperatorBase> BackwardRecursive( ...@@ -147,8 +147,9 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
ForEachVarName(grad_op->inputs_, [&no_grad_names, ForEachVarName(grad_op->inputs_, [&no_grad_names,
&net](std::string& grad_input) { &net](std::string& grad_input) {
if (no_grad_names.count(grad_input)) { if (no_grad_names.count(grad_input)) {
std::string prefix = // +1 for \0
grad_input.substr(0, grad_input.size() - kGradVarSuffix.size()); std::string prefix = grad_input.substr(
0, grad_input.size() - sizeof(kGradVarSuffix) / sizeof(char) + 1);
grad_input = prefix + kZeroVarSuffix; grad_input = prefix + kZeroVarSuffix;
// If part of input gradient of that operator is not calculated, fill // If part of input gradient of that operator is not calculated, fill
...@@ -184,7 +185,7 @@ std::shared_ptr<OperatorBase> Backward( ...@@ -184,7 +185,7 @@ std::shared_ptr<OperatorBase> Backward(
std::unordered_set<std::string> no_grad_names; std::unordered_set<std::string> no_grad_names;
no_grad_names.reserve(no_grad_vars.size()); no_grad_names.reserve(no_grad_vars.size());
no_grad_names.insert(kEmptyVarName + kGradVarSuffix); no_grad_names.insert(std::string(kEmptyVarName) + kGradVarSuffix);
for (auto& name : no_grad_vars) { for (auto& name : no_grad_vars) {
no_grad_names.insert(name + kGradVarSuffix); no_grad_names.insert(name + kGradVarSuffix);
......
...@@ -166,21 +166,17 @@ REGISTER_OP(fc, f::FcOp, f::FcOpMaker); ...@@ -166,21 +166,17 @@ REGISTER_OP(fc, f::FcOp, f::FcOpMaker);
REGISTER_OP(many_output_op, f::EmptyOp, f::ManyOutputOpMaker); REGISTER_OP(many_output_op, f::EmptyOp, f::ManyOutputOpMaker);
REGISTER_GRADIENT_OP(many_output_op, many_output_op_grad, f::EmptyOp); REGISTER_GRADIENT_OP(many_output_op, many_output_op_grad, f::EmptyOp);
TEST(Backward, need_to_be_removed) {}
//
// TEST(Backward, simple_op_grad) { // TEST(Backward, simple_op_grad) {
// auto fwd = f::OpRegistry::CreateOp( // auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
// "rowwise_add", {{"X", {"X"}}, {"b", {"b"}}}, {{"Out", {"Out"}}}, {});
// ASSERT_NE(fwd, nullptr); // ASSERT_NE(fwd, nullptr);
// auto gop = f::OpRegistry::CreateGradOp(*fwd); // auto gop = f::OpRegistry::CreateGradOp(*fwd);
// ASSERT_EQ(4UL, gop->inputs_.size()); // ASSERT_EQ(4UL, gop->inputs_.size());
// ASSERT_EQ(f::kEmptyVarName, gop->inputs_[0]); // ASSERT_EQ(f::kEmptyVarName, gop->inputs_[0]);
// ASSERT_EQ("rowwise_add_grad", gop->type_); // ASSERT_EQ("rowwise_add_grad", gop->type_);
// ASSERT_EQ("X" + f::kGradVarSuffix, gop->outputs_[0]); // ASSERT_EQ(f::GradVarName("X"), gop->outputs_[0]);
// ASSERT_EQ("b" + f::kGradVarSuffix, gop->outputs_[1]); // ASSERT_EQ(f::GradVarName("b"), gop->outputs_[1]);
// //
// ASSERT_EQ("X" + f::kGradVarSuffix, gop->Output("X" + f::kGradVarSuffix)); // ASSERT_EQ(f::GradVarName("X"), gop->Output(f::GradVarName("X")));
//} //}
// //
// TEST(Backward, simple_op_not_need_grad) { // TEST(Backward, simple_op_not_need_grad) {
...@@ -188,7 +184,7 @@ TEST(Backward, need_to_be_removed) {} ...@@ -188,7 +184,7 @@ TEST(Backward, need_to_be_removed) {}
// ASSERT_NE(fwd, nullptr); // ASSERT_NE(fwd, nullptr);
// auto gop = f::Backward(*fwd, {"X"}); // auto gop = f::Backward(*fwd, {"X"});
// ASSERT_EQ(std::find(gop->outputs_.begin(), gop->outputs_.end(), // ASSERT_EQ(std::find(gop->outputs_.begin(), gop->outputs_.end(),
// "X" + f::kGradVarSuffix), // f::GradVarName("X")),
// gop->outputs_.end()); // gop->outputs_.end());
// //
// auto no_input_gop = f::Backward(*fwd, {"X", "b"}); // auto no_input_gop = f::Backward(*fwd, {"X", "b"});
...@@ -259,18 +255,18 @@ TEST(Backward, need_to_be_removed) {} ...@@ -259,18 +255,18 @@ TEST(Backward, need_to_be_removed) {}
// all_output.erase(f::kEmptyVarName); // all_output.erase(f::kEmptyVarName);
// //
// for (auto &out : {"W1", "b1", "hidden0", "W2", "b2"}) { // for (auto &out : {"W1", "b1", "hidden0", "W2", "b2"}) {
// ASSERT_NE(all_output.find(out + f::kGradVarSuffix), all_output.end()); // ASSERT_NE(all_output.find(f::GradVarName(out)), all_output.end());
// } // }
// //
// // Not Generated X // // Not Generated X
// ASSERT_EQ(all_output.find("X" + f::kGradVarSuffix), all_output.end()); // ASSERT_EQ(all_output.find(f::GradVarName("X")), all_output.end());
// //
// ASSERT_EQ(2UL, bwd_net->ops_.size()); // ASSERT_EQ(2UL, bwd_net->ops_.size());
// ASSERT_TRUE(bwd_net->ops_[1]->IsNetOp()); // ASSERT_TRUE(bwd_net->ops_[1]->IsNetOp());
// auto first_fc_grad = static_cast<ops::NetOp *>(bwd_net->ops_[1].get()); // auto first_fc_grad = static_cast<ops::NetOp *>(bwd_net->ops_[1].get());
// ASSERT_EQ(3UL, first_fc_grad->ops_.size()); // ASSERT_EQ(3UL, first_fc_grad->ops_.size());
// ASSERT_EQ(f::kEmptyVarName, // ASSERT_EQ(f::kEmptyVarName,
// first_fc_grad->ops_[2]->Output("A" + f::kGradVarSuffix)); // first_fc_grad->ops_[2]->Output(f::GradVarName("A")));
//} //}
// //
// TEST(Backward, net_shared_weight) { // TEST(Backward, net_shared_weight) {
...@@ -322,17 +318,15 @@ TEST(Backward, need_to_be_removed) {} ...@@ -322,17 +318,15 @@ TEST(Backward, need_to_be_removed) {}
// ASSERT_EQ(1UL, fill_zero.inputs_.size()); // ASSERT_EQ(1UL, fill_zero.inputs_.size());
// ASSERT_EQ("Z", fill_zero.inputs_[0]); // ASSERT_EQ("Z", fill_zero.inputs_[0]);
// ASSERT_EQ(1UL, fill_zero.outputs_.size()); // ASSERT_EQ(1UL, fill_zero.outputs_.size());
// ASSERT_EQ("Z" + f::kZeroVarSuffix, fill_zero.outputs_[0]); // ASSERT_EQ(std::string("Z") + f::kZeroVarSuffix, fill_zero.outputs_[0]);
// //
// auto &d_many_out = *net->ops_[1]; // auto &d_many_out = *net->ops_[1];
// ASSERT_EQ("many_output_op_grad", d_many_out.type_); // ASSERT_EQ("many_output_op_grad", d_many_out.type_);
// ASSERT_EQ(1UL + 2UL + 2UL, d_many_out.inputs_.size()); // I/O/OG // ASSERT_EQ(1UL + 2UL + 2UL, d_many_out.inputs_.size()); // I/O/OG
// ASSERT_EQ("Z" + f::kZeroVarSuffix, d_many_out.Input("z" + // ASSERT_EQ(std::string("Z") + f::kZeroVarSuffix,
// f::kGradVarSuffix)); // d_many_out.Input(f::GradVarName("z")));
// ASSERT_EQ("Y" + f::kGradVarSuffix, d_many_out.Input("y" + // ASSERT_EQ(f::GradVarName("Y"), d_many_out.Input(f::GradVarName("y")));
// f::kGradVarSuffix)); // ASSERT_EQ(f::GradVarName("X"), d_many_out.Output(f::GradVarName("x")));
// ASSERT_EQ("X" + f::kGradVarSuffix,
// d_many_out.Output("x" + f::kGradVarSuffix));
//} //}
// //
// TEST(Backward, op_part_of_input_are_not_need) { // TEST(Backward, op_part_of_input_are_not_need) {
...@@ -342,11 +336,9 @@ TEST(Backward, need_to_be_removed) {} ...@@ -342,11 +336,9 @@ TEST(Backward, need_to_be_removed) {}
// ASSERT_EQ(grad_mul.type_, "mul_grad"); // ASSERT_EQ(grad_mul.type_, "mul_grad");
// ASSERT_EQ(grad_mul.inputs_.size(), 2UL + 1UL + 1UL); // ASSERT_EQ(grad_mul.inputs_.size(), 2UL + 1UL + 1UL);
// ASSERT_EQ(grad_mul.outputs_.size(), 2UL); // ASSERT_EQ(grad_mul.outputs_.size(), 2UL);
// ASSERT_EQ(grad_mul.Output("A" + f::kGradVarSuffix), f::kEmptyVarName); // ASSERT_EQ(grad_mul.Output(f::GradVarName("A")), f::kEmptyVarName);
// ASSERT_EQ(grad_mul.Output("B" + f::kGradVarSuffix), "b" + // ASSERT_EQ(grad_mul.Output(f::GradVarName("B")), f::GradVarName("b"));
// f::kGradVarSuffix); // ASSERT_EQ(grad_mul.Input(f::GradVarName("Out")), f::GradVarName("out"));
// ASSERT_EQ(grad_mul.Input("Out" + f::kGradVarSuffix),
// "out" + f::kGradVarSuffix);
// ASSERT_EQ(grad_mul.Input("A"), "a"); // ASSERT_EQ(grad_mul.Input("A"), "a");
// ASSERT_EQ(grad_mul.Input("B"), "b"); // ASSERT_EQ(grad_mul.Input("B"), "b");
// ASSERT_EQ(grad_mul.Input("Out"), "out"); // ASSERT_EQ(grad_mul.Input("Out"), "out");
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/lod_tensor.h"
#include <memory>
namespace paddle {
namespace framework {
namespace details {
using LOD = LODTensor::LOD;
std::shared_ptr<LOD> SliceLOD(const LOD &lod, size_t level_begin,
size_t level_end) {
auto new_lod = std::make_shared<LOD>();
new_lod->reserve(level_end - level_begin);
for (size_t i = level_begin; i < level_end; i++) {
new_lod->emplace_back(lod[i]);
}
return new_lod;
}
std::shared_ptr<LOD> SliceLOD(const LOD &lod, size_t level, size_t elem_begin,
size_t elem_end, bool tensor_shared) {
// slice the lod.
auto new_lod = std::make_shared<LOD>();
new_lod->reserve(lod.size() - level);
auto start = lod.at(level)[elem_begin];
auto end = lod.at(level)[elem_end];
for (auto it = lod.begin() + level; it != lod.end(); it++) {
auto it_begin = std::find(it->begin(), it->end(), start);
auto it_end = std::find(it_begin, it->end(), end);
PADDLE_ENFORCE(it_begin != it->end(), "error in parsing lod info");
PADDLE_ENFORCE(it_end != it->end(), "error in parsing lod info");
new_lod->emplace_back(it_begin, it_end + 1);
if (!tensor_shared) {
// reset offset if tensor is copyed and sliced.
std::transform(new_lod->back().begin(), new_lod->back().end(),
new_lod->back().begin(),
[start](int v) { return v - start; });
PADDLE_ENFORCE(new_lod->back().front() == 0, "error in slice LOD");
}
}
return new_lod;
}
} // namespace details
} // namespace framework
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
namespace paddle {
namespace framework {
namespace details {
/*
* Slice levels from LOD.
*
* @lod: LOD to slice.
* @level_begin: level to begin slice.
* @level_end: level to end slice.
*/
std::shared_ptr<LODTensor::LOD> SliceLOD(const LODTensor::LOD &lod,
size_t level_begin, size_t level_end);
/*
* Slice elements from a level of LOD.
*
* @lod: LOD to slice.
* @level: which level to slice.
* @elem_begin: element's index to begin slice.
* @elem_end: element's index to end slice.
*/
std::shared_ptr<LODTensor::LOD> SliceLOD(const LODTensor::LOD &lod,
size_t level, size_t elem_begin,
size_t elem_end, bool tensor_shared);
} // namespace details
} // namespace framework
} // namespace paddle
...@@ -40,8 +40,8 @@ message OpDesc { ...@@ -40,8 +40,8 @@ message OpDesc {
}; };
message Var { message Var {
required string op_proto_name = 1; required string parameter = 1;
repeated string var_names = 2; repeated string arguments = 2;
}; };
required string type = 3; required string type = 3;
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/lod_tensor.h"
#include <glog/logging.h>
namespace paddle {
namespace framework {
LODTensor LODTensor::SliceShared(size_t level_begin, size_t level_end) const {
PADDLE_ENFORCE(HasLOD(), "has no LOD info, can't be sliced.");
auto new_lod = details::SliceLOD(*lod_start_pos_, level_begin, level_end);
// slice levels just need to update LOD info, each level will contains the
// whole tensor_, so no need to modify tensor_.
return LODTensor(tensor_, new_lod);
}
LODTensor LODTensor::SliceShared(size_t level, size_t elem_begin,
size_t elem_end) const {
PADDLE_ENFORCE(HasLOD(), "has no LOD info, can't be sliced.");
PADDLE_ENFORCE(level < NumLevels(), "level [%d] out of range [%d]", level,
NumLevels());
PADDLE_ENFORCE(elem_begin < NumElements(level),
"element begin [%d] out of range [%d]", elem_begin,
NumElements(level));
PADDLE_ENFORCE(elem_end < NumElements(level) + 1,
"element end [%d] out of range [%d]", elem_end,
NumElements(level));
auto new_lod = details::SliceLOD(*lod_start_pos_, level, elem_begin, elem_end,
true /*tensor_shared*/);
// slice elements just need to update LOD info, because offsets are not
// changed, so the original tensor_ can be reused.
return LODTensor(tensor_, new_lod);
}
} // namespace framework
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#if (!PADDLE_ONLY_CPU)
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#endif
#include "paddle/framework/ddim.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/enforce.h"
namespace paddle {
namespace framework {
/*
* LODTensor (Level of details Tensor)
* see https://en.wikipedia.org/wiki/Level_of_details for reference.
*/
class LODTensor {
public:
// Level save offsets of each unit.
#ifdef PADDLE_ONLY_CPU
using Level = std::vector<size_t>;
#else
using Level = thrust::device_vector<size_t>;
#endif
// LOD stores offsets of each level of units, the largest units level first,
// then the smaller units level. Each Level stores the offsets of units in
// Tesor.
typedef std::vector<Level> LOD;
LODTensor() {}
LODTensor(const std::shared_ptr<Tensor> &tensor,
const std::shared_ptr<LOD> &lod) {
Reset(tensor, lod);
}
void Reset(const std::shared_ptr<Tensor> &tensor,
const std::shared_ptr<LOD> &lod) {
tensor_ = tensor;
lod_start_pos_ = lod;
}
/*
* Get a element from LOD.
*/
size_t lod_element(size_t level, size_t elem) const {
PADDLE_ENFORCE(level < NumLevels(), "level [%d] out of range [%d]", level,
NumLevels());
PADDLE_ENFORCE(elem < NumElements(level),
"element begin [%d] out of range [%d]", elem,
NumElements(level));
return (*lod_start_pos_)[level][elem];
}
/*
* Number of LODTensor's levels, each level has units of data, for example,
* in the sentence's view, article, paragraph, sentence are 3 levels.
*/
size_t NumLevels() const {
return lod_start_pos_ ? lod_start_pos_->size() : 0UL;
}
/*
* Number of elements in a level.
*/
size_t NumElements(size_t level = 0) const {
PADDLE_ENFORCE(level < NumLevels(), "level [%d] out of range [%d]", level,
NumLevels());
// the last offset is the end of last element
return lod_start_pos_->at(level).size() - 1;
}
/*
* Slice of levels[level_begin:level_end], with tensor copied.
*/
template <typename T>
LODTensor SliceCopied(size_t level_begin, size_t level_end,
const platform::Place &dst_place) const;
/*
* Slice of levels[level_begin:level_end], with tensor shared.
*/
LODTensor SliceShared(size_t level_begin, size_t level_end) const;
/*
* Slice of elements of a level, [elem_begin: elem_end], with tensor copied.
* @note: low performance in slice lod_start_pos_.
*/
template <typename T>
LODTensor SliceCopied(size_t level, size_t elem_begin, size_t elem_end,
const platform::Place &dst_place) const;
/*
* Slice of elements of a level, [elem_begin: elem_end], with tensor shared.
* @note: low performance in slice lod_start_pos_.
*/
LODTensor SliceShared(size_t level, size_t elem_begin, size_t elem_end) const;
/*
* Copy other's lod_start_pos_, to share LOD info.
* @note: the LOD info should not be changed.
*/
void ShareLOD(const LODTensor &other) {
lod_start_pos_ = other.lod_start_pos_;
}
/*
* Copy other's lod_start_pos_'s content, free to mutate.
*/
void CopyLOD(const LODTensor &other) {
lod_start_pos_ = std::make_shared<LOD>(*other.lod_start_pos_);
}
/*
* Determine whether LODTensor has a valid LOD info.
*/
bool HasLOD() const { return bool(lod_start_pos_); }
LOD *lod() const { return lod_start_pos_.get(); }
std::shared_ptr<Tensor> &tensor() { return tensor_; }
Tensor *raw_tensor() { return tensor_.get(); }
private:
std::shared_ptr<LOD> lod_start_pos_;
std::shared_ptr<Tensor> tensor_;
};
} // namespace framework
} // namespace paddle
#include "paddle/framework/lod_tensor_impl.h"
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/details/lod_tensor.h"
namespace paddle {
namespace framework {
template <typename T>
LODTensor LODTensor::SliceCopied(size_t level_begin, size_t level_end,
const platform::Place &dst_place) const {
PADDLE_ENFORCE(HasLOD(), "has no LOD info, can't be sliced.");
auto new_lod = details::SliceLOD(*lod_start_pos_, level_begin, level_end);
auto new_tensor = std::make_shared<Tensor>();
new_tensor->CopyFrom<T>(*tensor_, dst_place);
return LODTensor(new_tensor, new_lod);
}
template <typename T>
LODTensor LODTensor::SliceCopied(size_t level, size_t elem_begin,
size_t elem_end,
const platform::Place &dst_place) const {
PADDLE_ENFORCE(HasLOD(), "has no LOD info, can't be sliced.");
PADDLE_ENFORCE(level < NumLevels(), "level [%d] out of range [%d]", level,
NumLevels());
PADDLE_ENFORCE(elem_begin < NumElements(level),
"element begin [%d] out of range [%d]", elem_begin,
NumElements(level));
PADDLE_ENFORCE(elem_end < NumElements(level) + 1,
"element end [%d] out of range [%d]", elem_end,
NumElements(level));
auto new_lod = details::SliceLOD(*lod_start_pos_, level, elem_begin, elem_end,
false /*tensor_shared*/);
auto start_idx = new_lod->front().front();
auto end_idx = new_lod->front().back() - 1 /*the next element's start*/;
auto sliced_tensor = tensor_->Slice<T>(start_idx, end_idx);
auto new_tensor = std::make_shared<Tensor>();
new_tensor->CopyFrom<T>(sliced_tensor, dst_place);
return LODTensor(new_tensor, new_lod);
}
} // namespace framework
} // namespace paddle
/*
Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "paddle/framework/lod_tensor.h"
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <memory>
namespace paddle {
namespace framework {
class LODTensorTester : public ::testing::Test {
public:
virtual void SetUp() override {
lod_tensor.reset(new LODTensor);
// tensor's batch_size: 30
// 3 levels
// 0 10 20
// 0 5 10 15 20
// 0 2 5 7 10 12 15 20
auto lod = std::make_shared<LODTensor::LOD>();
lod->push_back(std::vector<size_t>{0, 10, 20});
lod->push_back(std::vector<size_t>{0, 5, 10, 15, 20});
lod->push_back(std::vector<size_t>{0, 2, 5, 7, 10, 12, 15, 17, 20});
auto tensor = std::make_shared<Tensor>();
tensor->Resize({20 /*batch size*/, 128 /*dim*/});
// malloc memory
tensor->mutable_data<float>(place);
lod_tensor->Reset(tensor, lod);
}
protected:
std::unique_ptr<LODTensor> lod_tensor;
platform::CPUPlace place;
};
TEST_F(LODTensorTester, NumLevels) { ASSERT_EQ(lod_tensor->NumLevels(), 3UL); }
TEST_F(LODTensorTester, NumElements) {
ASSERT_EQ(lod_tensor->NumElements(0), 2UL);
ASSERT_EQ(lod_tensor->NumElements(1), 4UL);
ASSERT_EQ(lod_tensor->NumElements(2), 8UL);
}
TEST_F(LODTensorTester, SliceShared_Level) {
// slice 1 level
for (size_t level = 0; level < 3UL; ++level) {
auto new_lod_tensor = lod_tensor->SliceShared(level, level + 1);
ASSERT_EQ(new_lod_tensor.NumLevels(), 1UL);
ASSERT_EQ(new_lod_tensor.NumElements(0UL), lod_tensor->NumElements(level));
ASSERT_EQ(new_lod_tensor.tensor(), lod_tensor->tensor());
}
// slice 2 level
for (size_t level = 0; level < 2UL; ++level) {
auto new_lod_tensor = lod_tensor->SliceShared(level, level + 2);
ASSERT_EQ(new_lod_tensor.NumLevels(), 2UL);
ASSERT_EQ(new_lod_tensor.NumElements(0), lod_tensor->NumElements(level));
ASSERT_EQ(new_lod_tensor.NumElements(1),
lod_tensor->NumElements(level + 1));
ASSERT_EQ(new_lod_tensor.tensor(), lod_tensor->tensor());
}
}
TEST_F(LODTensorTester, SliceCopied_Level) {
// slice 1 level
for (size_t level = 0; level < 3UL; ++level) {
auto new_lod_tensor =
lod_tensor->SliceCopied<float>(level, level + 1, place);
ASSERT_EQ(new_lod_tensor.NumLevels(), 1UL);
ASSERT_EQ(new_lod_tensor.NumElements(0UL), lod_tensor->NumElements(level));
// ASSERT_EQ(new_lod_tensor.tensor(), lod_tensor->tensor());
// TODO(superjom) add tensor comparation here.
}
// slice 2 level
for (size_t level = 0; level < 2UL; ++level) {
auto new_lod_tensor =
lod_tensor->SliceCopied<float>(level, level + 2, place);
ASSERT_EQ(new_lod_tensor.NumLevels(), 2UL);
ASSERT_EQ(new_lod_tensor.NumElements(0), lod_tensor->NumElements(level));
ASSERT_EQ(new_lod_tensor.NumElements(1),
lod_tensor->NumElements(level + 1));
// ASSERT_EQ(new_lod_tensor.tensor(), lod_tensor->tensor());
// TODO(superjom) add tensor comparation here.
}
}
TEST_F(LODTensorTester, SliceShared_Element) {
size_t level = 0;
auto new_lod_tensor = lod_tensor->SliceShared(level, 0, 2);
ASSERT_EQ(new_lod_tensor.NumLevels(), 3UL);
ASSERT_EQ(new_lod_tensor.NumElements(0), 2UL);
ASSERT_EQ(new_lod_tensor.NumElements(1), 4UL);
ASSERT_EQ(new_lod_tensor.NumElements(2), 8UL);
ASSERT_EQ(new_lod_tensor.raw_tensor(), lod_tensor->raw_tensor());
level = 1;
new_lod_tensor = lod_tensor->SliceShared(level, 0, 2);
ASSERT_EQ(new_lod_tensor.NumLevels(), 2UL);
ASSERT_EQ(new_lod_tensor.NumElements(0), 2UL);
ASSERT_EQ(new_lod_tensor.NumElements(1), 4UL);
ASSERT_EQ(new_lod_tensor.raw_tensor(), lod_tensor->raw_tensor());
}
TEST_F(LODTensorTester, SliceCopied_Element) {
size_t level = 0;
auto new_lod_tensor = lod_tensor->SliceCopied<float>(level, 0, 2, place);
ASSERT_EQ(new_lod_tensor.NumLevels(), 3UL);
ASSERT_EQ(new_lod_tensor.NumElements(0), 2UL);
ASSERT_EQ(new_lod_tensor.NumElements(1), 4UL);
ASSERT_EQ(new_lod_tensor.NumElements(2), 8UL);
ASSERT_NE(new_lod_tensor.raw_tensor(), lod_tensor->raw_tensor());
level = 1;
new_lod_tensor = lod_tensor->SliceCopied<float>(level, 0, 2, place);
ASSERT_EQ(new_lod_tensor.NumLevels(), 2UL);
ASSERT_EQ(new_lod_tensor.NumElements(0), 2UL);
ASSERT_EQ(new_lod_tensor.NumElements(1), 4UL);
ASSERT_NE(new_lod_tensor.raw_tensor(), lod_tensor->raw_tensor());
level = 1;
// LOD is
// 0 5 10
// 0 2 5 7 10
new_lod_tensor = lod_tensor->SliceCopied<float>(level, 1, 3, place);
ASSERT_EQ(new_lod_tensor.NumLevels(), 2UL);
ASSERT_EQ(new_lod_tensor.NumElements(0), 2UL);
ASSERT_EQ(new_lod_tensor.NumElements(1), 4UL);
ASSERT_EQ(new_lod_tensor.lod_element(0, 0), 0UL);
ASSERT_EQ(new_lod_tensor.lod_element(0, 1), 5UL);
ASSERT_EQ(new_lod_tensor.lod_element(1, 0), 0UL);
ASSERT_EQ(new_lod_tensor.lod_element(1, 1), 2UL);
ASSERT_EQ(new_lod_tensor.lod_element(1, 2), 5UL);
ASSERT_EQ(new_lod_tensor.lod_element(1, 3), 7UL);
// TODO(superjom) compare the content of these tensors
}
TEST_F(LODTensorTester, ShareLOD) {
LODTensor new_lod_tensor;
new_lod_tensor.ShareLOD(*lod_tensor);
ASSERT_EQ(new_lod_tensor.lod(), lod_tensor->lod());
}
TEST_F(LODTensorTester, CopyLOD) {
LODTensor new_lod_tensor;
new_lod_tensor.CopyLOD(*lod_tensor);
ASSERT_NE(new_lod_tensor.lod(), lod_tensor->lod());
}
} // namespace framework
} // namespace paddle
...@@ -22,6 +22,7 @@ limitations under the License. */ ...@@ -22,6 +22,7 @@ limitations under the License. */
#include "paddle/framework/attribute.h" #include "paddle/framework/attribute.h"
#include "paddle/framework/framework.pb.h" #include "paddle/framework/framework.pb.h"
#include "paddle/framework/grad_op_builder.h" #include "paddle/framework/grad_op_builder.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/scope.h" #include "paddle/framework/scope.h"
namespace paddle { namespace paddle {
...@@ -127,7 +128,7 @@ class OpRegistry { ...@@ -127,7 +128,7 @@ class OpRegistry {
static void RegisterOp(const std::string& op_type) { static void RegisterOp(const std::string& op_type) {
op_creators()[op_type] = [] { return new OpType; }; op_creators()[op_type] = [] { return new OpType; };
OpAttrChecker& op_checker = op_checkers()[op_type]; OpAttrChecker& op_checker = op_checkers()[op_type];
OpProto& op_proto = protos()[op_type]; OpProto& op_proto = OpProtos()[op_type];
auto maker = ProtoMakerType(&op_proto, &op_checker); auto maker = ProtoMakerType(&op_proto, &op_checker);
maker.Validate(); maker.Validate();
*op_proto.mutable_type() = op_type; *op_proto.mutable_type() = op_type;
...@@ -135,17 +136,6 @@ class OpRegistry { ...@@ -135,17 +136,6 @@ class OpRegistry {
op_proto.IsInitialized(), op_proto.IsInitialized(),
"Fail to initialize %s's OpProto, because %s is not initialized", "Fail to initialize %s's OpProto, because %s is not initialized",
op_type, op_proto.InitializationErrorString()); op_type, op_proto.InitializationErrorString());
VarIndexMaps()[op_type].reset(new VarIndexMap());
auto& varmap = *VarIndexMaps()[op_type];
int idx = 0;
for (auto& var : op_proto.inputs()) {
varmap[var.name()] = idx++;
}
idx = 0;
for (auto& var : op_proto.outputs()) {
varmap[var.name()] = idx++;
}
} }
template <typename GradOpType> template <typename GradOpType>
...@@ -180,8 +170,8 @@ class OpRegistry { ...@@ -180,8 +170,8 @@ class OpRegistry {
static std::shared_ptr<OperatorBase> CreateOp(const OpDesc& op_desc) { static std::shared_ptr<OperatorBase> CreateOp(const OpDesc& op_desc) {
VarNameMap inputs; VarNameMap inputs;
for (auto& input : op_desc.inputs()) { for (auto& input : op_desc.inputs()) {
auto& var_names = inputs[input.op_proto_name()]; auto& var_names = inputs[input.parameter()];
auto& var_names_in_proto = input.var_names(); auto& var_names_in_proto = input.arguments();
var_names.reserve(static_cast<size_t>(var_names_in_proto.size())); var_names.reserve(static_cast<size_t>(var_names_in_proto.size()));
std::copy(var_names_in_proto.begin(), var_names_in_proto.end(), std::copy(var_names_in_proto.begin(), var_names_in_proto.end(),
std::back_inserter(var_names)); std::back_inserter(var_names));
...@@ -189,8 +179,8 @@ class OpRegistry { ...@@ -189,8 +179,8 @@ class OpRegistry {
VarNameMap outputs; VarNameMap outputs;
for (auto& output : op_desc.outputs()) { for (auto& output : op_desc.outputs()) {
auto& var_names = outputs[output.op_proto_name()]; auto& var_names = outputs[output.parameter()];
auto& var_names_in_proto = output.var_names(); auto& var_names_in_proto = output.arguments();
var_names.reserve(static_cast<size_t>(var_names_in_proto.size())); var_names.reserve(static_cast<size_t>(var_names_in_proto.size()));
std::copy(var_names_in_proto.begin(), var_names_in_proto.end(), std::copy(var_names_in_proto.begin(), var_names_in_proto.end(),
std::back_inserter(var_names)); std::back_inserter(var_names));
...@@ -212,22 +202,11 @@ class OpRegistry { ...@@ -212,22 +202,11 @@ class OpRegistry {
return grad_op; return grad_op;
} }
static std::unordered_map<std::string, OpProto>& protos() {
static std::unordered_map<std::string, OpProto> protos_;
return protos_;
}
static std::unordered_map<std::string, std::string>& grad_ops() { static std::unordered_map<std::string, std::string>& grad_ops() {
static std::unordered_map<std::string, std::string> grad_ops_; static std::unordered_map<std::string, std::string> grad_ops_;
return grad_ops_; return grad_ops_;
} }
static std::unordered_map<std::string, std::shared_ptr<VarIndexMap>>&
VarIndexMaps() {
static std::unordered_map<std::string, std::shared_ptr<VarIndexMap>> maps_;
return maps_;
}
static std::unordered_map<std::string, OpCreator>& op_creators() { static std::unordered_map<std::string, OpCreator>& op_creators() {
static std::unordered_map<std::string, OpCreator> op_creators_; static std::unordered_map<std::string, OpCreator> op_creators_;
return op_creators_; return op_creators_;
......
...@@ -58,12 +58,12 @@ TEST(OpRegistry, CreateOp) { ...@@ -58,12 +58,12 @@ TEST(OpRegistry, CreateOp) {
paddle::framework::OpDesc op_desc; paddle::framework::OpDesc op_desc;
op_desc.set_type("cos_sim"); op_desc.set_type("cos_sim");
auto input = op_desc.add_inputs(); auto input = op_desc.add_inputs();
input->set_op_proto_name("input"); input->set_parameter("input");
*input->mutable_var_names()->Add() = "aa"; *input->mutable_arguments()->Add() = "aa";
auto output = op_desc.add_outputs(); auto output = op_desc.add_outputs();
output->set_op_proto_name("output"); output->set_parameter("output");
*output->mutable_var_names()->Add() = "bb"; *output->mutable_arguments()->Add() = "bb";
float scale = 3.3; float scale = 3.3;
auto attr = op_desc.mutable_attrs()->Add(); auto attr = op_desc.mutable_attrs()->Add();
...@@ -84,12 +84,12 @@ TEST(OpRegistry, IllegalAttr) { ...@@ -84,12 +84,12 @@ TEST(OpRegistry, IllegalAttr) {
paddle::framework::OpDesc op_desc; paddle::framework::OpDesc op_desc;
op_desc.set_type("cos_sim"); op_desc.set_type("cos_sim");
auto input = op_desc.add_inputs(); auto input = op_desc.add_inputs();
input->set_op_proto_name("input"); input->set_parameter("input");
*input->mutable_var_names()->Add() = "aa"; *input->mutable_arguments()->Add() = "aa";
auto output = op_desc.add_outputs(); auto output = op_desc.add_outputs();
output->set_op_proto_name("output"); output->set_parameter("output");
*output->mutable_var_names()->Add() = "bb"; *output->mutable_arguments()->Add() = "bb";
auto attr = op_desc.mutable_attrs()->Add(); auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale"); attr->set_name("scale");
...@@ -114,12 +114,12 @@ TEST(OpRegistry, DefaultValue) { ...@@ -114,12 +114,12 @@ TEST(OpRegistry, DefaultValue) {
paddle::framework::OpDesc op_desc; paddle::framework::OpDesc op_desc;
op_desc.set_type("cos_sim"); op_desc.set_type("cos_sim");
auto input = op_desc.add_inputs(); auto input = op_desc.add_inputs();
input->set_op_proto_name("input"); input->set_parameter("input");
*input->mutable_var_names()->Add() = "aa"; *input->mutable_arguments()->Add() = "aa";
auto output = op_desc.add_outputs(); auto output = op_desc.add_outputs();
output->set_op_proto_name("output"); output->set_parameter("output");
*output->mutable_var_names()->Add() = "bb"; *output->mutable_arguments()->Add() = "bb";
ASSERT_TRUE(op_desc.IsInitialized()); ASSERT_TRUE(op_desc.IsInitialized());
...@@ -135,12 +135,12 @@ TEST(OpRegistry, CustomChecker) { ...@@ -135,12 +135,12 @@ TEST(OpRegistry, CustomChecker) {
paddle::framework::OpDesc op_desc; paddle::framework::OpDesc op_desc;
op_desc.set_type("my_test_op"); op_desc.set_type("my_test_op");
auto input = op_desc.add_inputs(); auto input = op_desc.add_inputs();
input->set_op_proto_name("input"); input->set_parameter("input");
*input->mutable_var_names()->Add() = "ii"; *input->mutable_arguments()->Add() = "ii";
auto output = op_desc.add_outputs(); auto output = op_desc.add_outputs();
output->set_op_proto_name("output"); output->set_parameter("output");
*output->mutable_var_names()->Add() = "oo"; *output->mutable_arguments()->Add() = "oo";
// attr 'test_attr' is not set // attr 'test_attr' is not set
bool caught = false; bool caught = false;
......
...@@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <algorithm>
#include "paddle/framework/operator.h" #include "paddle/framework/operator.h"
#include <algorithm>
#include "paddle/framework/op_registry.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -33,6 +33,14 @@ ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const { ...@@ -33,6 +33,14 @@ ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
} }
#endif #endif
static std::unordered_map<std::string, OpProto>* g_op_protos = nullptr;
std::unordered_map<std::string, OpProto>& OpProtos() {
if (g_op_protos == nullptr) {
g_op_protos = new std::unordered_map<std::string, OpProto>();
}
return *g_op_protos;
}
const std::string& OperatorBase::Input(const std::string& name) const { const std::string& OperatorBase::Input(const std::string& name) const {
auto it = inputs_.find(name); auto it = inputs_.find(name);
PADDLE_ENFORCE(it != inputs_.end(), "Op %s does not have output %s", type_, PADDLE_ENFORCE(it != inputs_.end(), "Op %s does not have output %s", type_,
......
...@@ -32,24 +32,26 @@ namespace paddle { ...@@ -32,24 +32,26 @@ namespace paddle {
namespace framework { namespace framework {
/// If a variable is a empty variable, that name will be used. /// If a variable is a empty variable, that name will be used.
const std::string kEmptyVarName = "@EMPTY@"; constexpr char kEmptyVarName[] = "@EMPTY@";
/// If a variable is a temporary variable, that name will be set in Python, /// If a variable is a temporary variable, that name will be set in Python,
/// but it will be convert to a unique name in scope after OpCreator. /// but it will be convert to a unique name in scope after OpCreator.
const std::string kTempVarName = "@TEMP@"; constexpr char kTempVarName[] = "@TEMP@";
/// If a variable's name has a certain suffix, it means that the /// If a variable's name has a certain suffix, it means that the
/// variable is the gradient of another varibale. /// variable is the gradient of another varibale.
/// e.g. Variable "x@GRAD" is the gradient of varibale "x". /// e.g. Variable "x@GRAD" is the gradient of varibale "x".
const std::string kGradVarSuffix = "@GRAD"; constexpr char kGradVarSuffix[] = "@GRAD";
/// Variables with this suffix are supposed to be filled up with zeros. /// Variables with this suffix are supposed to be filled up with zeros.
const std::string kZeroVarSuffix = "@ZERO"; constexpr char kZeroVarSuffix[] = "@ZERO";
inline std::string GradVarName(const std::string& var_name) { inline std::string GradVarName(const std::string& var_name) {
return var_name + kGradVarSuffix; return var_name + kGradVarSuffix;
} }
extern std::unordered_map<std::string, OpProto>& OpProtos();
class OperatorBase; class OperatorBase;
class InferShapeContext; class InferShapeContext;
class ExecutionContext; class ExecutionContext;
...@@ -103,6 +105,35 @@ class OperatorBase { ...@@ -103,6 +105,35 @@ class OperatorBase {
//! TODO add a vector_view to prevent memory copy. //! TODO add a vector_view to prevent memory copy.
const std::vector<std::string>& Outputs(const std::string& name) const; const std::vector<std::string>& Outputs(const std::string& name) const;
virtual std::vector<std::string> OutputVars(bool has_intermediate) const {
std::vector<std::string> ret_val;
if (has_intermediate) {
// push all outputs into ret_val
for (auto& o : outputs_) {
ret_val.reserve(ret_val.size() + o.second.size());
ret_val.insert(ret_val.end(), o.second.begin(), o.second.end());
}
return ret_val;
}
auto it = OpProtos().find(type_);
PADDLE_ENFORCE(
it != OpProtos().end(),
"Operator %s not registered, cannot figure out intermediate outputs",
type_);
// get all OpProto::Var for outputs
for (auto& o : it->second.outputs()) {
// ignore all intermediate output
if (o.intermediate()) continue;
auto out = outputs_.find(o.name());
if (out != outputs_.end()) {
ret_val.reserve(ret_val.size() + out->second.size());
ret_val.insert(ret_val.end(), out->second.begin(), out->second.end());
}
}
return ret_val;
}
public: public:
std::string type_; std::string type_;
// NOTE: in case of OpGrad, inputs_ contains: // NOTE: in case of OpGrad, inputs_ contains:
...@@ -117,10 +148,10 @@ class OperatorBase { ...@@ -117,10 +148,10 @@ class OperatorBase {
AttributeMap attrs_; AttributeMap attrs_;
}; };
class OperatorContext { class InferShapeContext {
public: public:
OperatorContext(const OperatorBase* op, const Scope& scope) InferShapeContext(const OperatorBase& op, const Scope& scope)
: op_(*op), scope_(scope) {} : op_(op), scope_(scope) {}
size_t InputSize(const std::string& name) const { size_t InputSize(const std::string& name) const {
return op_.inputs_.at(name).size(); return op_.inputs_.at(name).size();
...@@ -209,12 +240,6 @@ class OperatorContext { ...@@ -209,12 +240,6 @@ class OperatorContext {
const Scope& scope_; const Scope& scope_;
}; };
class InferShapeContext : public OperatorContext {
public:
InferShapeContext(const OperatorBase* op, const Scope& scope)
: OperatorContext(op, scope) {}
};
template <typename T> template <typename T>
struct EigenDeviceConverter; struct EigenDeviceConverter;
...@@ -230,11 +255,11 @@ struct EigenDeviceConverter<platform::GPUPlace> { ...@@ -230,11 +255,11 @@ struct EigenDeviceConverter<platform::GPUPlace> {
}; };
#endif #endif
class ExecutionContext : public OperatorContext { class ExecutionContext : public InferShapeContext {
public: public:
ExecutionContext(const OperatorBase* op, const Scope& scope, ExecutionContext(const OperatorBase& op, const Scope& scope,
const platform::DeviceContext* device_context) const platform::DeviceContext* device_context)
: OperatorContext(op, scope), device_context_(device_context) {} : InferShapeContext(op, scope), device_context_(device_context) {}
template <typename PlaceType, template <typename PlaceType,
typename DeviceType = typename DeviceType =
...@@ -286,13 +311,13 @@ class OperatorWithKernel : public OperatorBase { ...@@ -286,13 +311,13 @@ class OperatorWithKernel : public OperatorBase {
std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>; std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>;
void InferShape(const Scope& scope) const override { void InferShape(const Scope& scope) const override {
InferShape(InferShapeContext(this, scope)); InferShape(InferShapeContext(*this, scope));
} }
void Run(const Scope& scope, void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const final { const platform::DeviceContext& dev_ctx) const final {
auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx)); auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx));
opKernel->Compute(ExecutionContext(this, scope, &dev_ctx)); opKernel->Compute(ExecutionContext(*this, scope, &dev_ctx));
} }
static std::unordered_map<std::string /* op_type */, OpKernelMap>& static std::unordered_map<std::string /* op_type */, OpKernelMap>&
......
...@@ -61,12 +61,12 @@ TEST(OperatorBase, all) { ...@@ -61,12 +61,12 @@ TEST(OperatorBase, all) {
paddle::framework::OpDesc op_desc; paddle::framework::OpDesc op_desc;
op_desc.set_type("test_operator"); op_desc.set_type("test_operator");
auto* ipt = op_desc.mutable_inputs()->Add(); auto* ipt = op_desc.mutable_inputs()->Add();
*ipt->mutable_var_names()->Add() = "IN1"; *ipt->mutable_arguments()->Add() = "IN1";
ipt->set_op_proto_name("input"); ipt->set_parameter("input");
auto* output = op_desc.mutable_outputs()->Add(); auto* output = op_desc.mutable_outputs()->Add();
*output->mutable_var_names()->Add() = "OUT1"; *output->mutable_arguments()->Add() = "OUT1";
output->set_op_proto_name("output"); output->set_parameter("output");
auto attr = op_desc.mutable_attrs()->Add(); auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale"); attr->set_name("scale");
attr->set_type(paddle::framework::AttrType::FLOAT); attr->set_type(paddle::framework::AttrType::FLOAT);
...@@ -184,12 +184,12 @@ TEST(OpKernel, all) { ...@@ -184,12 +184,12 @@ TEST(OpKernel, all) {
paddle::framework::OpDesc op_desc; paddle::framework::OpDesc op_desc;
op_desc.set_type("op_with_kernel"); op_desc.set_type("op_with_kernel");
auto* ipt = op_desc.mutable_inputs()->Add(); auto* ipt = op_desc.mutable_inputs()->Add();
*ipt->mutable_var_names()->Add() = "IN1"; *ipt->mutable_arguments()->Add() = "IN1";
ipt->set_op_proto_name("x"); ipt->set_parameter("x");
auto* output = op_desc.mutable_outputs()->Add(); auto* output = op_desc.mutable_outputs()->Add();
*output->mutable_var_names()->Add() = "OUT1"; *output->mutable_arguments()->Add() = "OUT1";
output->set_op_proto_name("y"); output->set_parameter("y");
auto attr = op_desc.mutable_attrs()->Add(); auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale"); attr->set_name("scale");
...@@ -217,17 +217,17 @@ TEST(OpKernel, multi_inputs) { ...@@ -217,17 +217,17 @@ TEST(OpKernel, multi_inputs) {
OpDesc op_desc; OpDesc op_desc;
op_desc.set_type("op_multi_inputs_with_kernel"); op_desc.set_type("op_multi_inputs_with_kernel");
auto x = op_desc.mutable_inputs()->Add(); auto x = op_desc.mutable_inputs()->Add();
x->set_op_proto_name("xs"); x->set_parameter("xs");
*x->mutable_var_names()->Add() = "x0"; *x->mutable_arguments()->Add() = "x0";
*x->mutable_var_names()->Add() = "x1"; *x->mutable_arguments()->Add() = "x1";
*x->mutable_var_names()->Add() = "x2"; *x->mutable_arguments()->Add() = "x2";
auto k = op_desc.mutable_inputs()->Add(); auto k = op_desc.mutable_inputs()->Add();
k->set_op_proto_name("k"); k->set_parameter("k");
*k->mutable_var_names()->Add() = "k0"; *k->mutable_arguments()->Add() = "k0";
auto y = op_desc.mutable_outputs()->Add(); auto y = op_desc.mutable_outputs()->Add();
y->set_op_proto_name("ys"); y->set_parameter("ys");
*y->mutable_var_names()->Add() = "y0"; *y->mutable_arguments()->Add() = "y0";
*y->mutable_var_names()->Add() = "y1"; *y->mutable_arguments()->Add() = "y1";
auto attr = op_desc.mutable_attrs()->Add(); auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale"); attr->set_name("scale");
......
...@@ -22,6 +22,7 @@ limitations under the License. */ ...@@ -22,6 +22,7 @@ limitations under the License. */
#include "paddle/operators/net_op.h" #include "paddle/operators/net_op.h"
#include "paddle/platform/enforce.h" #include "paddle/platform/enforce.h"
#include "paddle/platform/place.h" #include "paddle/platform/place.h"
#include "paddle/string/to_string.h"
#include "pybind11/numpy.h" #include "pybind11/numpy.h"
#include "pybind11/pybind11.h" #include "pybind11/pybind11.h"
#include "pybind11/stl.h" #include "pybind11/stl.h"
...@@ -186,9 +187,13 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -186,9 +187,13 @@ All parameter, weight, gradient are variables in Paddle.
}); });
// clang-format on // clang-format on
py::class_<paddle::platform::GPUPlace>(m, "GPUPlace").def(py::init<int>()); py::class_<platform::GPUPlace>(m, "GPUPlace")
.def(py::init<int>())
.def("__str__", string::to_string<const platform::GPUPlace &>);
py::class_<paddle::platform::CPUPlace>(m, "CPUPlace").def(py::init<>()); py::class_<paddle::platform::CPUPlace>(m, "CPUPlace")
.def(py::init<>())
.def("__str__", string::to_string<const platform::CPUPlace &>);
py::class_<OperatorBase, std::shared_ptr<OperatorBase>> operator_base( py::class_<OperatorBase, std::shared_ptr<OperatorBase>> operator_base(
m, "Operator"); m, "Operator");
......
...@@ -18,6 +18,8 @@ limitations under the License. */ ...@@ -18,6 +18,8 @@ limitations under the License. */
#include <cstring> #include <cstring>
#include <memory> #include <memory>
#include <typeindex> #include <typeindex>
#include <vector>
#include "paddle/framework/ddim.h" #include "paddle/framework/ddim.h"
#include "paddle/memory/memory.h" #include "paddle/memory/memory.h"
#include "paddle/platform/device_context.h" #include "paddle/platform/device_context.h"
......
...@@ -19,7 +19,7 @@ TEST(Tensor, Dims) { ...@@ -19,7 +19,7 @@ TEST(Tensor, Dims) {
using namespace paddle::framework; using namespace paddle::framework;
using namespace paddle::platform; using namespace paddle::platform;
Tensor tt; Tensor tt;
tt.Resize(make_ddim({2, 3, 4})); tt.Resize({2, 3, 4});
DDim dims = tt.dims(); DDim dims = tt.dims();
ASSERT_EQ(arity(dims), 3); ASSERT_EQ(arity(dims), 3);
for (int i = 0; i < 3; ++i) { for (int i = 0; i < 3; ++i) {
......
...@@ -93,8 +93,8 @@ TEST(Arguments, Matrix) { ...@@ -93,8 +93,8 @@ TEST(Arguments, Matrix) {
MatrixPtr matrix = Matrix::create(100, 200); MatrixPtr matrix = Matrix::create(100, 200);
CheckBufferArg check = [=](const BufferArg& arg) { CheckBufferArg check = [=](const BufferArg& arg) {
EXPECT_EQ(arg.shape().ndims(), 2U); EXPECT_EQ(arg.shape().ndims(), 2U);
EXPECT_EQ(arg.shape()[0], 100); EXPECT_EQ(arg.shape()[0], 100U);
EXPECT_EQ(arg.shape()[1], 200); EXPECT_EQ(arg.shape()[1], 200U);
EXPECT_EQ(arg.data(), matrix->getData()); EXPECT_EQ(arg.data(), matrix->getData());
EXPECT_EQ(arg.matrix<DEVICE_TYPE_CPU>().getHeight(), matrix->getHeight()); EXPECT_EQ(arg.matrix<DEVICE_TYPE_CPU>().getHeight(), matrix->getHeight());
...@@ -112,8 +112,8 @@ TEST(Arguments, Matrix) { ...@@ -112,8 +112,8 @@ TEST(Arguments, Matrix) {
TEST(Arguments, Vector) { TEST(Arguments, Vector) {
VectorPtr vector = Vector::create(100, false); VectorPtr vector = Vector::create(100, false);
CheckBufferArg check = [=](const BufferArg& arg) { CheckBufferArg check = [=](const BufferArg& arg) {
EXPECT_EQ(arg.shape().ndims(), 1); EXPECT_EQ(arg.shape().ndims(), 1U);
EXPECT_EQ(arg.shape()[0], 100); EXPECT_EQ(arg.shape()[0], 100U);
EXPECT_EQ(arg.data(), vector->getData()); EXPECT_EQ(arg.data(), vector->getData());
CpuVector inVector = arg.vector<real, DEVICE_TYPE_CPU>(); CpuVector inVector = arg.vector<real, DEVICE_TYPE_CPU>();
...@@ -131,9 +131,9 @@ TEST(Arguments, Vector) { ...@@ -131,9 +131,9 @@ TEST(Arguments, Vector) {
TEST(Arguments, CpuSparseMatrix) { TEST(Arguments, CpuSparseMatrix) {
CpuSparseMatrix sparse(200, 300, 50); CpuSparseMatrix sparse(200, 300, 50);
CheckBufferArg check = [=](const BufferArg& arg) { CheckBufferArg check = [=](const BufferArg& arg) {
EXPECT_EQ(arg.shape().ndims(), 2); EXPECT_EQ(arg.shape().ndims(), 2U);
EXPECT_EQ(arg.shape()[0], 200); EXPECT_EQ(arg.shape()[0], 200U);
EXPECT_EQ(arg.shape()[1], 300); EXPECT_EQ(arg.shape()[1], 300U);
EXPECT_EQ(arg.data(), sparse.getData()); EXPECT_EQ(arg.data(), sparse.getData());
// CHECK_EQ(arg.sparse().nnz(), 50); // CHECK_EQ(arg.sparse().nnz(), 50);
// CHECK_EQ(arg.sparse().dataFormat(), SPARSE_CSR_FORMAT); // CHECK_EQ(arg.sparse().dataFormat(), SPARSE_CSR_FORMAT);
...@@ -152,10 +152,10 @@ TEST(Arguments, CpuSparseMatrix) { ...@@ -152,10 +152,10 @@ TEST(Arguments, CpuSparseMatrix) {
TEST(Arguments, BufferArg) { TEST(Arguments, BufferArg) {
BufferArg arg(nullptr, VALUE_TYPE_FLOAT, {1, 2, 3}); BufferArg arg(nullptr, VALUE_TYPE_FLOAT, {1, 2, 3});
CheckBufferArg check = [=](const BufferArg& arg) { CheckBufferArg check = [=](const BufferArg& arg) {
EXPECT_EQ(arg.shape().ndims(), 3); EXPECT_EQ(arg.shape().ndims(), 3U);
EXPECT_EQ(arg.shape()[0], 1); EXPECT_EQ(arg.shape()[0], 1U);
EXPECT_EQ(arg.shape()[1], 2); EXPECT_EQ(arg.shape()[1], 2U);
EXPECT_EQ(arg.shape()[2], 3); EXPECT_EQ(arg.shape()[2], 3U);
}; };
BufferArgs argments; BufferArgs argments;
......
...@@ -44,7 +44,7 @@ TEST(TensorShape, GetAndSet) { ...@@ -44,7 +44,7 @@ TEST(TensorShape, GetAndSet) {
EXPECT_EQ(t.ndims(), 3U); EXPECT_EQ(t.ndims(), 3U);
EXPECT_EQ(t.getElements(), 6U); EXPECT_EQ(t.getElements(), 6U);
EXPECT_EQ(t[1], 2); EXPECT_EQ(t[1], 2U);
t.setDim(1, 100); t.setDim(1, 100);
EXPECT_EQ(t.getElements(), 300U); EXPECT_EQ(t.getElements(), 300U);
EXPECT_EQ(t[1], 100U); EXPECT_EQ(t[1], 100U);
......
...@@ -96,7 +96,7 @@ void SubNestedSequenceLayer::calSelectedCols( ...@@ -96,7 +96,7 @@ void SubNestedSequenceLayer::calSelectedCols(
for (size_t i = 0; i < seqNum; ++i) { for (size_t i = 0; i < seqNum; ++i) {
for (size_t j = 0; j < beamSize; ++j) { for (size_t j = 0; j < beamSize; ++j) {
if (selectedIndices->getElement(i, j) == -1.) break; if (selectedIndices->getElement(i, j) == -1.) break;
int selSubSeqIdx = selectedIndices->getElement(i, j); size_t selSubSeqIdx = selectedIndices->getElement(i, j);
CHECK_GT(inputSeqInfoVec_[i].size() - 1, selSubSeqIdx); CHECK_GT(inputSeqInfoVec_[i].size() - 1, selSubSeqIdx);
size_t subSeqLen = inputSeqInfoVec_[i][selSubSeqIdx + 1] - size_t subSeqLen = inputSeqInfoVec_[i][selSubSeqIdx + 1] -
...@@ -135,7 +135,7 @@ void SubNestedSequenceLayer::forward(PassType passType) { ...@@ -135,7 +135,7 @@ void SubNestedSequenceLayer::forward(PassType passType) {
CHECK(inputSeq.hasSubseq()) << "The first input of SubNestSequence layer " CHECK(inputSeq.hasSubseq()) << "The first input of SubNestSequence layer "
<< "must be a nested sequence."; << "must be a nested sequence.";
const MatrixPtr selectedIndices = getInputValue(1); const MatrixPtr selectedIndices = getInputValue(1);
CHECK_EQ(inputSeq.getNumSequences(), selectedIndices->getHeight()); CHECK_EQ(size_t(inputSeq.getNumSequences()), selectedIndices->getHeight());
if (dynamic_cast<GpuMatrix*>(selectedIndices.get())) { if (dynamic_cast<GpuMatrix*>(selectedIndices.get())) {
/* /*
......
...@@ -88,7 +88,7 @@ void checkLayerOut(vector<vector<int>> groundTruth, ...@@ -88,7 +88,7 @@ void checkLayerOut(vector<vector<int>> groundTruth,
TEST(Layer, kmaxSeqScoreLayer) { TEST(Layer, kmaxSeqScoreLayer) {
const size_t maxBeamSize = 100; const size_t maxBeamSize = 100;
int beamSize = 1 + (rand() % maxBeamSize); size_t beamSize = 1 + (rand() % maxBeamSize);
vector<int> seqStartPosition; vector<int> seqStartPosition;
vector<int> subSeqStartPosition; vector<int> subSeqStartPosition;
......
...@@ -45,10 +45,8 @@ cc_library(net_op SRCS net_op.cc DEPS op_registry) ...@@ -45,10 +45,8 @@ cc_library(net_op SRCS net_op.cc DEPS op_registry)
cc_test(net_op_test SRCS net_op_test.cc DEPS net_op) cc_test(net_op_test SRCS net_op_test.cc DEPS net_op)
op_library(add_op SRCS add_op.cc add_op.cu) op_library(add_op SRCS add_op.cc add_op.cu)
cc_test(add_op_test SRCS add_op_test.cc DEPS add_op)
op_library(mean_op SRCS mean_op.cc mean_op.cu) op_library(mean_op SRCS mean_op.cc mean_op.cu)
cc_test(mean_op_test SRCS mean_op_test.cc DEPS mean_op)
op_library(mul_op SRCS mul_op.cc mul_op.cu) op_library(mul_op SRCS mul_op.cc mul_op.cu)
op_library(rowwise_add_op SRCS rowwise_add_op.cu rowwise_add_op.cc) op_library(rowwise_add_op SRCS rowwise_add_op.cu rowwise_add_op.cc)
...@@ -59,7 +57,6 @@ op_library(cross_entropy_op SRCS cross_entropy_op.cc cross_entropy_op.cu) ...@@ -59,7 +57,6 @@ op_library(cross_entropy_op SRCS cross_entropy_op.cc cross_entropy_op.cu)
op_library(fill_zeros_like_op SRCS fill_zeros_like_op.cc fill_zeros_like_op.cu) op_library(fill_zeros_like_op SRCS fill_zeros_like_op.cc fill_zeros_like_op.cu)
op_library(sgd_op SRCS sgd_op.cc sgd_op.cu) op_library(sgd_op SRCS sgd_op.cc sgd_op.cu)
cc_test(sgd_op_test SRCS sgd_op_test.cc DEPS sgd_op)
op_library(fc_op op_library(fc_op
SRCS fc_op.cc SRCS fc_op.cc
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#define private public
#include "paddle/framework/op_registry.h"
USE_OP(add_two);
TEST(AddOp, GetOpProto) {
auto& protos = paddle::framework::OpRegistry::protos();
auto it = protos.find("add_two");
ASSERT_NE(it, protos.end());
auto& op_creators = paddle::framework::OpRegistry::op_creators();
auto it1 = op_creators.find("add_two_grad");
ASSERT_NE(it1, op_creators.end());
}
...@@ -39,7 +39,7 @@ class MeanOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -39,7 +39,7 @@ class MeanOpMaker : public framework::OpProtoAndCheckerMaker {
class MeanGradOp : public framework::OperatorWithKernel { class MeanGradOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
ctx.Output<Tensor>("X" + framework::kGradVarSuffix) ctx.Output<Tensor>(framework::GradVarName("X"))
->Resize(ctx.Input<Tensor>("X")->dims()); ->Resize(ctx.Input<Tensor>("X")->dims());
} }
}; };
......
...@@ -48,10 +48,10 @@ template <typename Place, typename T> ...@@ -48,10 +48,10 @@ template <typename Place, typename T>
class MeanGradKernel : public framework::OpKernel { class MeanGradKernel : public framework::OpKernel {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto OG = context.Input<Tensor>("Out" + framework::kGradVarSuffix); auto OG = context.Input<Tensor>(framework::GradVarName("Out"));
PADDLE_ENFORCE(framework::product(OG->dims()) == 1, PADDLE_ENFORCE(framework::product(OG->dims()) == 1,
"Mean Gradient should be scalar"); "Mean Gradient should be scalar");
auto IG = context.Output<Tensor>("X" + framework::kGradVarSuffix); auto IG = context.Output<Tensor>(framework::GradVarName("X"));
IG->mutable_data<T>(context.GetPlace()); IG->mutable_data<T>(context.GetPlace());
T ig_size = (T)framework::product(IG->dims()); T ig_size = (T)framework::product(IG->dims());
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include <paddle/framework/op_registry.h>
USE_OP(mean);
TEST(MeanOp, GetOpProto) {
auto& protos = paddle::framework::OpRegistry::protos();
auto it = protos.find("mean");
ASSERT_NE(it, protos.end());
}
...@@ -21,19 +21,20 @@ ...@@ -21,19 +21,20 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
const char NetOp::kAll[] = "all";
void NetOp::CompleteAddOp(bool calc) { void NetOp::CompleteAddOp(bool calc) {
add_op_done_ = true; add_op_done_ = true;
if (!calc) return; if (!calc) return;
std::set<std::string> input_set; std::set<std::string> input_set;
std::set<std::string> output_set; std::set<std::string> output_set;
std::set<std::string> temp_output;
for (auto& op : ops_) { for (auto& op : ops_) {
for (auto& ipt : op->inputs_) { for (auto& ipt : op->inputs_) {
for (auto& var_name : ipt.second) { for (auto& var_name : ipt.second) {
if (!Contains(output_set, var_name)) { // Not other op's output if (!Contains(output_set, var_name)) { // Not other op's output
input_set.insert(var_name); input_set.insert(var_name);
} else { } else {
temp_output.insert(var_name); intermediate_outputs_.insert(var_name);
} }
} }
} }
...@@ -44,24 +45,12 @@ void NetOp::CompleteAddOp(bool calc) { ...@@ -44,24 +45,12 @@ void NetOp::CompleteAddOp(bool calc) {
} }
} }
} }
auto& inputs = inputs_["all"]; auto& inputs = inputs_[kAll];
inputs.reserve(input_set.size()); inputs.reserve(input_set.size());
std::copy(input_set.begin(), input_set.end(), std::back_inserter(inputs)); std::copy(input_set.begin(), input_set.end(), std::back_inserter(inputs));
auto& outputs = outputs_["all"]; auto& outputs = outputs_[kAll];
outputs.reserve(output_set.size()); outputs.reserve(output_set.size());
std::copy(output_set.begin(), output_set.end(), std::back_inserter(outputs)); std::copy(output_set.begin(), output_set.end(), std::back_inserter(outputs));
//! TODO figure out how to generate temporary_index in Network.
std::vector<int> tmp_index;
tmp_index.reserve(temp_output.size());
int output_len = static_cast<int>(outputs.size());
for (int i = 0; i < output_len; ++i) {
if (Contains(temp_output, outputs[i])) {
tmp_index.push_back(i);
}
}
attrs_["temporary_index"] = tmp_index;
} }
std::string NetOp::DebugString() const { std::string NetOp::DebugString() const {
...@@ -78,5 +67,19 @@ std::string NetOp::DebugString() const { ...@@ -78,5 +67,19 @@ std::string NetOp::DebugString() const {
bool NetOp::IsNetOp() const { return true; } bool NetOp::IsNetOp() const { return true; }
std::vector<std::string> NetOp::OutputVars(bool has_intermediate) const {
if (has_intermediate) {
return this->outputs_.at(kAll);
}
auto& all = this->outputs_.at(kAll);
std::vector<std::string> ret_val;
for (auto& each : all) {
if (!Contains(intermediate_outputs_, each)) {
ret_val.push_back(each);
}
}
return ret_val;
}
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -36,6 +36,8 @@ namespace operators { ...@@ -36,6 +36,8 @@ namespace operators {
*/ */
class NetOp : public framework::OperatorBase { class NetOp : public framework::OperatorBase {
public: public:
static const char kAll[];
/** /**
* Infer all the operators' input and output variables' shapes, will be called * Infer all the operators' input and output variables' shapes, will be called
* before every mini-batch * before every mini-batch
...@@ -91,11 +93,13 @@ class NetOp : public framework::OperatorBase { ...@@ -91,11 +93,13 @@ class NetOp : public framework::OperatorBase {
std::string DebugString() const override; std::string DebugString() const override;
bool IsNetOp() const override; bool IsNetOp() const override;
std::vector<std::string> OutputVars(bool has_intermediate) const override;
std::vector<std::shared_ptr<OperatorBase>> ops_; std::vector<std::shared_ptr<OperatorBase>> ops_;
private: private:
bool add_op_done_{false}; bool add_op_done_{false};
std::set<std::string> intermediate_outputs_;
template <typename T, typename KeyType> template <typename T, typename KeyType>
static bool Contains(T container, KeyType key) { static bool Contains(T container, KeyType key) {
......
...@@ -54,22 +54,13 @@ TEST(OpKernel, all) { ...@@ -54,22 +54,13 @@ TEST(OpKernel, all) {
net->CompleteAddOp(); net->CompleteAddOp();
AssertSameVectorWithoutOrder({"x", "w1", "b1", "w2", "b2"}, AssertSameVectorWithoutOrder({"x", "w1", "b1", "w2", "b2"},
net->inputs_.at("__all__")); net->inputs_.at(NetOp::kAll));
AssertSameVectorWithoutOrder({"y", "z"}, net->outputs_.at("__all__")); AssertSameVectorWithoutOrder({"y", "z"}, net->outputs_.at(NetOp::kAll));
auto tmp_idx_iter = net->attrs_.find("temporary_index");
ASSERT_NE(net->attrs_.end(), tmp_idx_iter);
auto& tmp_idx = boost::get<std::vector<int>>(tmp_idx_iter->second);
ASSERT_EQ(1UL, tmp_idx.size());
ASSERT_EQ("y", net->outputs_.at("__all__")[tmp_idx[0]]);
Scope scope; auto final_outs = net->OutputVars(false);
platform::CPUDeviceContext dev_ctx;
net->InferShape(scope); ASSERT_EQ(final_outs.size(), 1UL);
net->Run(scope, dev_ctx); ASSERT_EQ(final_outs[0], "z");
ASSERT_EQ(2, infer_shape_cnt);
ASSERT_EQ(2, run_cnt);
ASSERT_THROW(net->AddOp(op2), platform::EnforceNotMet);
} }
TEST(NetOp, insert_op) { TEST(NetOp, insert_op) {
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include <paddle/framework/op_registry.h>
USE_OP(sgd);
TEST(SGDOp, GetOpProto) {
auto& protos = paddle::framework::OpRegistry::protos();
auto it = protos.find("sgd");
ASSERT_NE(it, protos.end());
}
...@@ -8,7 +8,7 @@ cc_test(place_test SRCS place_test.cc DEPS place glog gflags) ...@@ -8,7 +8,7 @@ cc_test(place_test SRCS place_test.cc DEPS place glog gflags)
add_subdirectory(dynload) add_subdirectory(dynload)
cc_test(enforce_test SRCS enforce_test.cc) cc_test(enforce_test SRCS enforce_test.cc DEPS stringpiece)
IF(WITH_GPU) IF(WITH_GPU)
set(GPU_CTX_DEPS dynload_cuda dynamic_loader) set(GPU_CTX_DEPS dynload_cuda dynamic_loader)
......
...@@ -15,11 +15,12 @@ limitations under the License. */ ...@@ -15,11 +15,12 @@ limitations under the License. */
#pragma once #pragma once
#include <execinfo.h> #include <execinfo.h>
#include <paddle/string/printf.h>
#include <iomanip> #include <iomanip>
#include <sstream> #include <sstream>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include "paddle/string/printf.h"
#include "paddle/string/to_string.h"
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
...@@ -191,27 +192,11 @@ inline void throw_on_error(T e) { ...@@ -191,27 +192,11 @@ inline void throw_on_error(T e) {
PADDLE_ENFORCE(nullptr != (__VAL), #__VAL " should not be null\n%s", \ PADDLE_ENFORCE(nullptr != (__VAL), #__VAL " should not be null\n%s", \
paddle::string::Sprintf("" __VA_ARGS__)); paddle::string::Sprintf("" __VA_ARGS__));
template <typename T>
inline std::string enforce_to_string(const T& val) {
std::ostringstream sout;
sout << val;
return sout.str();
}
template <>
inline std::string enforce_to_string(const std::string& val) {
return val;
}
template <>
inline std::string enforce_to_string(const char* const& val) {
return std::string(val);
}
#define __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, __CMP, __INV_CMP, ...) \ #define __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, __CMP, __INV_CMP, ...) \
PADDLE_ENFORCE(__VAL0 __CMP __VAL1, \ PADDLE_ENFORCE(__VAL0 __CMP __VAL1, \
"enforce %s " #__CMP " %s failed, %s " #__INV_CMP " %s\n%s", \ "enforce %s " #__CMP " %s failed, %s " #__INV_CMP " %s\n%s", \
#__VAL0, #__VAL1, \ #__VAL0, #__VAL1, paddle::string::to_string(__VAL0), \
paddle::platform::enforce_to_string(__VAL0), \ paddle::string::to_string(__VAL1), \
paddle::platform::enforce_to_string(__VAL1), \
paddle::string::Sprintf("" __VA_ARGS__)); paddle::string::Sprintf("" __VA_ARGS__));
} // namespace platform } // namespace platform
......
...@@ -9,10 +9,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -9,10 +9,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <array>
#include <iostream>
#include <memory> #include <memory>
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/platform/enforce.h" #include "paddle/platform/enforce.h"
#include "paddle/string/piece.h"
using StringPiece = paddle::string::Piece;
using paddle::string::HasPrefix;
TEST(ENFORCE, OK) { TEST(ENFORCE, OK) {
PADDLE_ENFORCE(true, "Enforce is ok %d now %f", 123, 0.345); PADDLE_ENFORCE(true, "Enforce is ok %d now %f", 123, 0.345);
...@@ -22,19 +28,15 @@ TEST(ENFORCE, OK) { ...@@ -22,19 +28,15 @@ TEST(ENFORCE, OK) {
} }
TEST(ENFORCE, FAILED) { TEST(ENFORCE, FAILED) {
bool in_catch = false; bool caught_exception = false;
try { try {
PADDLE_ENFORCE(false, "Enforce is not ok %d at all", 123); PADDLE_ENFORCE(false, "Enforce is not ok %d at all", 123);
} catch (paddle::platform::EnforceNotMet error) { } catch (paddle::platform::EnforceNotMet error) {
// your error handling code here caught_exception = true;
in_catch = true; EXPECT_TRUE(
std::string msg = "Enforce is not ok 123 at all"; HasPrefix(StringPiece(error.what()), "Enforce is not ok 123 at all"));
const char* what = error.what();
for (size_t i = 0; i < msg.length(); ++i) {
ASSERT_EQ(what[i], msg[i]);
}
} }
ASSERT_TRUE(in_catch); EXPECT_TRUE(caught_exception);
} }
TEST(ENFORCE, NO_ARG_OK) { TEST(ENFORCE, NO_ARG_OK) {
...@@ -47,41 +49,27 @@ TEST(ENFORCE, NO_ARG_OK) { ...@@ -47,41 +49,27 @@ TEST(ENFORCE, NO_ARG_OK) {
TEST(ENFORCE_EQ, NO_EXTRA_MSG_FAIL) { TEST(ENFORCE_EQ, NO_EXTRA_MSG_FAIL) {
int a = 2; int a = 2;
bool in_catch = false; bool caught_exception = false;
try { try {
PADDLE_ENFORCE_EQ(a, 1 + 3); PADDLE_ENFORCE_EQ(a, 1 + 3);
} catch (paddle::platform::EnforceNotMet error) { } catch (paddle::platform::EnforceNotMet error) {
in_catch = true; caught_exception = true;
const std::string msg = "enforce a == 1 + 3 failed, 2 != 4"; HasPrefix(StringPiece(error.what()), "enforce a == 1 + 3 failed, 2 != 4");
const char* what = error.what();
for (size_t i = 0; i < msg.length(); ++i) {
ASSERT_EQ(what[i], msg[i]);
}
} }
EXPECT_TRUE(caught_exception);
ASSERT_TRUE(in_catch);
} }
TEST(ENFORCE_EQ, EXTRA_MSG_FAIL) { TEST(ENFORCE_EQ, EXTRA_MSG_FAIL) {
int a = 2; int a = 2;
bool in_catch = false; bool caught_exception = false;
try { try {
PADDLE_ENFORCE_EQ(a, 1 + 3, "%s size not match", "their"); PADDLE_ENFORCE_EQ(a, 1 + 3, "%s size not match", "their");
} catch (paddle::platform::EnforceNotMet error) { } catch (paddle::platform::EnforceNotMet error) {
in_catch = true; caught_exception = true;
const std::string msg = HasPrefix(StringPiece(error.what()),
"enforce a == 1 + 3 failed, 2 != 4\ntheir size not match"; "enforce a == 1 + 3 failed, 2 != 4\ntheir size not match");
const char* what = error.what();
for (size_t i = 0; i < msg.length(); ++i) {
ASSERT_EQ(what[i], msg[i]);
}
} }
EXPECT_TRUE(caught_exception);
ASSERT_TRUE(in_catch);
} }
TEST(ENFORCE_NE, OK) { TEST(ENFORCE_NE, OK) {
...@@ -89,42 +77,32 @@ TEST(ENFORCE_NE, OK) { ...@@ -89,42 +77,32 @@ TEST(ENFORCE_NE, OK) {
PADDLE_ENFORCE_NE(1.0, 2UL); PADDLE_ENFORCE_NE(1.0, 2UL);
} }
TEST(ENFORCE_NE, FAIL) { TEST(ENFORCE_NE, FAIL) {
bool in_catch = false; bool caught_exception = false;
try { try {
// 2UL here to check data type compatible // 2UL here to check data type compatible
PADDLE_ENFORCE_NE(1.0, 1UL); PADDLE_ENFORCE_NE(1.0, 1UL);
} catch (paddle::platform::EnforceNotMet error) { } catch (paddle::platform::EnforceNotMet error) {
in_catch = true; caught_exception = true;
const std::string msg = "enforce 1.0 != 1UL failed, 1.000000 == 1"; EXPECT_TRUE(HasPrefix(StringPiece(error.what()),
const char* what = error.what(); "enforce 1.0 != 1UL failed, 1 == 1"))
for (size_t i = 0; i < msg.length(); ++i) { << error.what() << " does not have expected prefix";
ASSERT_EQ(what[i], msg[i]);
}
} }
EXPECT_TRUE(caught_exception);
ASSERT_TRUE(in_catch);
} }
TEST(ENFORCE_GT, OK) { PADDLE_ENFORCE_GT(2, 1); } TEST(ENFORCE_GT, OK) { PADDLE_ENFORCE_GT(2, 1); }
TEST(ENFORCE_GT, FAIL) { TEST(ENFORCE_GT, FAIL) {
bool in_catch = false; bool caught_exception = false;
try { try {
// 2UL here to check data type compatible
PADDLE_ENFORCE_GT(1, 2UL); PADDLE_ENFORCE_GT(1, 2UL);
} catch (paddle::platform::EnforceNotMet error) { } catch (paddle::platform::EnforceNotMet error) {
in_catch = true; caught_exception = true;
const std::string msg = "enforce 1 > 2UL failed, 1 <= 2"; EXPECT_TRUE(
const char* what = error.what(); HasPrefix(StringPiece(error.what()), "enforce 1 > 2UL failed, 1 <= 2"));
for (size_t i = 0; i < msg.length(); ++i) {
ASSERT_EQ(what[i], msg[i]);
}
} }
EXPECT_TRUE(caught_exception);
ASSERT_TRUE(in_catch);
} }
TEST(ENFORCE_GE, OK) { TEST(ENFORCE_GE, OK) {
...@@ -134,21 +112,16 @@ TEST(ENFORCE_GE, OK) { ...@@ -134,21 +112,16 @@ TEST(ENFORCE_GE, OK) {
PADDLE_ENFORCE_GE(3.21, 2UL); PADDLE_ENFORCE_GE(3.21, 2UL);
} }
TEST(ENFORCE_GE, FAIL) { TEST(ENFORCE_GE, FAIL) {
bool in_catch = false; bool caught_exception = false;
try { try {
PADDLE_ENFORCE_GE(1, 2UL); PADDLE_ENFORCE_GE(1, 2UL);
} catch (paddle::platform::EnforceNotMet error) { } catch (paddle::platform::EnforceNotMet error) {
in_catch = true; caught_exception = true;
const std::string msg = "enforce 1 >= 2UL failed, 1 < 2"; EXPECT_TRUE(
const char* what = error.what(); HasPrefix(StringPiece(error.what()), "enforce 1 >= 2UL failed, 1 < 2"));
for (size_t i = 0; i < msg.length(); ++i) {
ASSERT_EQ(what[i], msg[i]);
}
} }
EXPECT_TRUE(caught_exception);
ASSERT_TRUE(in_catch);
} }
TEST(ENFORCE_LE, OK) { TEST(ENFORCE_LE, OK) {
...@@ -159,21 +132,16 @@ TEST(ENFORCE_LE, OK) { ...@@ -159,21 +132,16 @@ TEST(ENFORCE_LE, OK) {
PADDLE_ENFORCE_LE(2UL, 3.2); PADDLE_ENFORCE_LE(2UL, 3.2);
} }
TEST(ENFORCE_LE, FAIL) { TEST(ENFORCE_LE, FAIL) {
bool in_catch = false; bool caught_exception = false;
try { try {
PADDLE_ENFORCE_GT(1, 2UL); PADDLE_ENFORCE_GT(1, 2UL);
} catch (paddle::platform::EnforceNotMet error) { } catch (paddle::platform::EnforceNotMet error) {
in_catch = true; caught_exception = true;
const std::string msg = "enforce 1 > 2UL failed, 1 <= 2"; EXPECT_TRUE(
const char* what = error.what(); HasPrefix(StringPiece(error.what()), "enforce 1 > 2UL failed, 1 <= 2"));
for (size_t i = 0; i < msg.length(); ++i) {
ASSERT_EQ(what[i], msg[i]);
}
} }
EXPECT_TRUE(caught_exception);
ASSERT_TRUE(in_catch);
} }
TEST(ENFORCE_LT, OK) { TEST(ENFORCE_LT, OK) {
...@@ -182,21 +150,15 @@ TEST(ENFORCE_LT, OK) { ...@@ -182,21 +150,15 @@ TEST(ENFORCE_LT, OK) {
PADDLE_ENFORCE_LT(2UL, 3); PADDLE_ENFORCE_LT(2UL, 3);
} }
TEST(ENFORCE_LT, FAIL) { TEST(ENFORCE_LT, FAIL) {
bool in_catch = false; bool caught_exception = false;
try { try {
PADDLE_ENFORCE_LT(1UL, 0.12); PADDLE_ENFORCE_LT(1UL, 0.12);
} catch (paddle::platform::EnforceNotMet error) { } catch (paddle::platform::EnforceNotMet error) {
in_catch = true; caught_exception = true;
const std::string msg = "enforce 1UL < 0.12 failed, 1 >= 0.12"; EXPECT_TRUE(HasPrefix(StringPiece(error.what()),
const char* what = error.what(); "enforce 1UL < 0.12 failed, 1 >= 0.12"));
for (size_t i = 0; i < msg.length(); ++i) {
ASSERT_EQ(what[i], msg[i]);
}
} }
EXPECT_TRUE(caught_exception);
ASSERT_TRUE(in_catch);
} }
TEST(ENFORCE_NOT_NULL, OK) { TEST(ENFORCE_NOT_NULL, OK) {
...@@ -205,20 +167,50 @@ TEST(ENFORCE_NOT_NULL, OK) { ...@@ -205,20 +167,50 @@ TEST(ENFORCE_NOT_NULL, OK) {
delete a; delete a;
} }
TEST(ENFORCE_NOT_NULL, FAIL) { TEST(ENFORCE_NOT_NULL, FAIL) {
bool in_catch = false; bool caught_exception = false;
int* a{nullptr};
try { try {
int* a = nullptr;
PADDLE_ENFORCE_NOT_NULL(a); PADDLE_ENFORCE_NOT_NULL(a);
} catch (paddle::platform::EnforceNotMet error) { } catch (paddle::platform::EnforceNotMet error) {
in_catch = true; caught_exception = true;
const std::string msg = "a should not be null"; EXPECT_TRUE(HasPrefix(StringPiece(error.what()), "a should not be null"));
const char* what = error.what(); }
for (size_t i = 0; i < msg.length(); ++i) { EXPECT_TRUE(caught_exception);
ASSERT_EQ(what[i], msg[i]); }
struct Dims {
size_t dims_[4];
bool operator==(const Dims& o) const {
for (size_t i = 0; i < 4; ++i) {
if (dims_[i] != o.dims_[i]) return false;
} }
return true;
} }
};
ASSERT_TRUE(in_catch); std::ostream& operator<<(std::ostream& os, const Dims& d) {
for (size_t i = 0; i < 4; ++i) {
if (i == 0) {
os << "[";
}
os << d.dims_[i];
if (i == 4 - 1) {
os << "]";
} else {
os << ", ";
}
}
return os;
} }
TEST(ENFORCE_USER_DEFINED_CLASS, EQ) {
Dims a{{1, 2, 3, 4}}, b{{1, 2, 3, 4}};
PADDLE_ENFORCE_EQ(a, b);
}
TEST(ENFORCE_USER_DEFINED_CLASS, NE) {
Dims a{{1, 2, 3, 4}}, b{{5, 6, 7, 8}};
ASSERT_THROW(PADDLE_ENFORCE_EQ(a, b), paddle::platform::EnforceNotMet);
}
\ No newline at end of file
...@@ -2,3 +2,4 @@ cc_library(stringpiece SRCS piece.cc) ...@@ -2,3 +2,4 @@ cc_library(stringpiece SRCS piece.cc)
cc_test(stringpiece_test SRCS piece_test.cc DEPS stringpiece glog gflags) cc_test(stringpiece_test SRCS piece_test.cc DEPS stringpiece glog gflags)
cc_test(stringprintf_test SRCS printf_test.cc DEPS glog gflags) cc_test(stringprintf_test SRCS printf_test.cc DEPS glog gflags)
cc_test(to_string_test SRCS to_string_test.cc)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <sstream>
#include <string>
namespace paddle {
namespace string {
template <typename T>
inline std::string to_string(T v) {
std::ostringstream sout;
sout << v;
return sout.str();
}
// Faster std::string/const char* type
template <>
inline std::string to_string(std::string v) {
return v;
}
template <>
inline std::string to_string(const char* v) {
return std::string(v);
}
} // namespace string
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/string/to_string.h"
#include <gtest/gtest.h>
constexpr char kOutputString[] = "User Defined Output";
class UserDefinedClass {
public:
};
std::ostream& operator<<(std::ostream& s, const UserDefinedClass& ins) {
s << kOutputString;
return s;
}
TEST(to_string, normal) {
using namespace paddle::string;
ASSERT_EQ("10", to_string(10));
ASSERT_EQ("abc", to_string("abc"));
ASSERT_EQ("1.2", to_string(1.2));
}
TEST(to_string, user_defined) {
using namespace paddle::string;
UserDefinedClass instance;
ASSERT_EQ(kOutputString, to_string(instance));
}
\ No newline at end of file
...@@ -50,8 +50,8 @@ void NewRemoteParameterUpdater::init( ...@@ -50,8 +50,8 @@ void NewRemoteParameterUpdater::init(
// create parameter server client. // create parameter server client.
if (useEtcd_) { if (useEtcd_) {
parameterClient_ = paddle_new_etcd_pserver_client( parameterClient_ =
(char *)pserverSpec_.c_str(), FLAGS_trainer_id == 0); paddle_new_etcd_pserver_client((char *)pserverSpec_.c_str());
} else { } else {
parameterClient_ = paddle_new_pserver_client((char *)pserverSpec_.c_str(), parameterClient_ = paddle_new_pserver_client((char *)pserverSpec_.c_str(),
FLAGS_trainer_id == 0); FLAGS_trainer_id == 0);
......
...@@ -92,15 +92,27 @@ def get_numeric_gradient(op, ...@@ -92,15 +92,27 @@ def get_numeric_gradient(op,
class GradientChecker(unittest.TestCase): class GradientChecker(unittest.TestCase):
def __is_close(self, numeric_grads, scope, max_relative_error): def assert_is_close(self, numeric_grads, scope, max_relative_error,
msg_prefix):
for name in numeric_grads: for name in numeric_grads:
op_grad = numpy.array( b = numpy.array(scope.find_var(grad_var_name(name)).get_tensor())
scope.find_var(grad_var_name(name)).get_tensor()) a = numeric_grads[name]
is_close = numpy.allclose(
numeric_grads[name], op_grad, rtol=max_relative_error, atol=100) abs_a = numpy.abs(a)
if not is_close: # if abs_a is nearly zero, then use abs error for a, not relative
return False # error.
return True abs_a[abs_a < 1e-3] = 1
diff_mat = numpy.abs(a - b) / abs_a
max_diff = numpy.max(diff_mat)
def err_msg():
offset = numpy.argmax(diff_mat > max_relative_error)
return "%s Variable %s max gradient diff %f over limit %f, the first " \
"error element is %d" % (
msg_prefix, name, max_diff, max_relative_error, offset)
self.assertLessEqual(max_diff, max_relative_error, err_msg())
def check_grad(self, def check_grad(self,
forward_op, forward_op,
...@@ -145,7 +157,8 @@ class GradientChecker(unittest.TestCase): ...@@ -145,7 +157,8 @@ class GradientChecker(unittest.TestCase):
# get numeric gradient # get numeric gradient
for check_name in inputs_to_check: for check_name in inputs_to_check:
numeric_grad[check_name] = \ numeric_grad[check_name] = \
get_numeric_gradient(forward_op, input_vars, output_name, check_name) get_numeric_gradient(forward_op, input_vars, output_name,
check_name)
# get operator gradient according to different device # get operator gradient according to different device
for place in places: for place in places:
...@@ -187,15 +200,8 @@ class GradientChecker(unittest.TestCase): ...@@ -187,15 +200,8 @@ class GradientChecker(unittest.TestCase):
backward_op.infer_shape(scope) backward_op.infer_shape(scope)
backward_op.run(scope, ctx) backward_op.run(scope, ctx)
if isinstance(place, core.CPUPlace): self.assert_is_close(numeric_grad, scope, max_relative_error,
msg = "CPU kernel gradient is not close to numeric gradient" "Gradient Check On %s" % str(place))
else:
if isinstance(place, core.GPUPlace):
msg = "GPU kernel gradient is not close to numeric gradient"
else:
raise ValueError("unknown place " + type(place))
self.assertTrue(
self.__is_close(numeric_grad, scope, max_relative_error), msg)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册