提交 a0e195a1 编写于 作者: D dongzhihong

Merge remote-tracking branch 'origin/develop' into add_op_gradient

...@@ -28,7 +28,7 @@ RUN apt-get update && \ ...@@ -28,7 +28,7 @@ RUN apt-get update && \
wget unzip unrar tar xz-utils bzip2 gzip coreutils ntp \ wget unzip unrar tar xz-utils bzip2 gzip coreutils ntp \
curl sed grep graphviz libjpeg-dev zlib1g-dev \ curl sed grep graphviz libjpeg-dev zlib1g-dev \
python-matplotlib gcc-4.8 g++-4.8 \ python-matplotlib gcc-4.8 g++-4.8 \
automake locales clang-format-3.8 swig doxygen cmake \ automake locales clang-format swig doxygen cmake \
liblapack-dev liblapacke-dev libboost-dev \ liblapack-dev liblapacke-dev libboost-dev \
clang-3.8 llvm-3.8 libclang-3.8-dev \ clang-3.8 llvm-3.8 libclang-3.8-dev \
net-tools && \ net-tools && \
......
hash: 1b9b07408ca7fac27a374dc2ccd2433e4bff090484008a037df967284949a582 hash: 1b9b07408ca7fac27a374dc2ccd2433e4bff090484008a037df967284949a582
updated: 2017-08-03T21:46:51.744995189Z updated: 2017-08-07T23:37:48.867469328Z
imports: imports:
- name: github.com/beorn7/perks - name: github.com/beorn7/perks
version: 4c0e84591b9aa9e6dcfdf3e020114cd81f89d5f9 version: 4c0e84591b9aa9e6dcfdf3e020114cd81f89d5f9
...@@ -10,7 +10,7 @@ imports: ...@@ -10,7 +10,7 @@ imports:
- name: github.com/cockroachdb/cmux - name: github.com/cockroachdb/cmux
version: 112f0506e7743d64a6eb8fedbcff13d9979bbf92 version: 112f0506e7743d64a6eb8fedbcff13d9979bbf92
- name: github.com/coreos/etcd - name: github.com/coreos/etcd
version: c31bec0f29facff13f7c3e3d948e55dd6689ed42 version: d0d1a87aa96ae14914751d42264262cb69eda170
subpackages: subpackages:
- alarm - alarm
- auth - auth
...@@ -24,6 +24,7 @@ imports: ...@@ -24,6 +24,7 @@ imports:
- error - error
- etcdserver - etcdserver
- etcdserver/api - etcdserver/api
- etcdserver/api/etcdhttp
- etcdserver/api/v2http - etcdserver/api/v2http
- etcdserver/api/v2http/httptypes - etcdserver/api/v2http/httptypes
- etcdserver/api/v3client - etcdserver/api/v3client
...@@ -210,11 +211,6 @@ testImports: ...@@ -210,11 +211,6 @@ testImports:
version: 04cdfd42973bb9c8589fd6a731800cf222fde1a9 version: 04cdfd42973bb9c8589fd6a731800cf222fde1a9
subpackages: subpackages:
- spew - spew
- name: github.com/docker/docker
version: b6d164e6c46d8115b146e4c3ac93784e9ef8b49e
subpackages:
- pkg/ioutils
- pkg/longpath
- name: github.com/pmezard/go-difflib - name: github.com/pmezard/go-difflib
version: d8ed2627bdf02c080bf22230dbb337003b7aba2d version: d8ed2627bdf02c080bf22230dbb337003b7aba2d
subpackages: subpackages:
......
package master_test package master_test
import ( import (
"io/ioutil"
"net/url"
"os" "os"
"strings"
"testing" "testing"
"time" "time"
"github.com/PaddlePaddle/Paddle/go/master" "github.com/PaddlePaddle/Paddle/go/master"
"github.com/coreos/etcd/clientv3" "github.com/coreos/etcd/clientv3"
"github.com/coreos/etcd/embed" "github.com/coreos/etcd/embed"
"github.com/docker/docker/pkg/ioutils"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestNewServiceWithEtcd(t *testing.T) { func TestNewServiceWithEtcd(t *testing.T) {
// setup an embed etcd server // setup an embed etcd server
etcdDir, err := ioutils.TempDir("", "") etcdDir, err := ioutil.TempDir("", "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
cfg := embed.NewConfig() cfg := embed.NewConfig()
lpurl, _ := url.Parse("http://localhost:0")
lcurl, _ := url.Parse("http://localhost:0")
cfg.LPUrls = []url.URL{*lpurl}
cfg.LCUrls = []url.URL{*lcurl}
cfg.Dir = etcdDir cfg.Dir = etcdDir
e, err := embed.StartEtcd(cfg) e, err := embed.StartEtcd(cfg)
if err != nil { if err != nil {
...@@ -30,15 +36,13 @@ func TestNewServiceWithEtcd(t *testing.T) { ...@@ -30,15 +36,13 @@ func TestNewServiceWithEtcd(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
}() }()
select {
case <-e.Server.ReadyNotify():
t.Log("Server is ready!")
case <-time.After(60 * time.Second):
e.Server.Stop() // trigger a shutdown
t.Fatal("Server took too long to start!")
}
ep := []string{"127.0.0.1:2379"} <-e.Server.ReadyNotify()
port := strings.Split(e.Clients[0].Addr().String(), ":")[1]
endpoint := "127.0.0.1:" + port
ep := []string{endpoint}
masterAddr := "127.0.0.1:3306" masterAddr := "127.0.0.1:3306"
store, err := master.NewEtcdClient(ep, masterAddr, master.DefaultLockPath, master.DefaultAddrPath, master.DefaultStatePath, 30) store, err := master.NewEtcdClient(ep, masterAddr, master.DefaultLockPath, master.DefaultAddrPath, master.DefaultStatePath, 30)
if err != nil { if err != nil {
......
...@@ -90,8 +90,12 @@ func cArrayToSlice(p unsafe.Pointer, len int) []byte { ...@@ -90,8 +90,12 @@ func cArrayToSlice(p unsafe.Pointer, len int) []byte {
type selector bool type selector bool
func (s selector) Select() bool { func (s selector) Select() (bool, error) {
return bool(s) return bool(s), nil
}
func (s selector) Done() error {
return nil
} }
type lister []client.Server type lister []client.Server
...@@ -114,11 +118,10 @@ func paddle_new_pserver_client(addrs *C.char, selected int) C.paddle_pserver_cli ...@@ -114,11 +118,10 @@ func paddle_new_pserver_client(addrs *C.char, selected int) C.paddle_pserver_cli
} }
//export paddle_new_etcd_pserver_client //export paddle_new_etcd_pserver_client
func paddle_new_etcd_pserver_client(etcdEndpoints *C.char, selected int) C.paddle_pserver_client { func paddle_new_etcd_pserver_client(etcdEndpoints *C.char) C.paddle_pserver_client {
// TODO(Longfei: use etcd lock to decide which trainer to initialize the parameters)
addr := C.GoString(etcdEndpoints) addr := C.GoString(etcdEndpoints)
etcdClient := client.NewEtcd(addr) etcdClient := client.NewEtcd(addr)
c := client.NewClient(etcdClient, etcdClient.Desired(), selector(selected != 0)) c := client.NewClient(etcdClient, etcdClient.Desired(), etcdClient)
return add(c) return add(c)
} }
...@@ -136,7 +139,12 @@ func paddle_pserver_client_release(client C.paddle_pserver_client) { ...@@ -136,7 +139,12 @@ func paddle_pserver_client_release(client C.paddle_pserver_client) {
//export paddle_begin_init_params //export paddle_begin_init_params
func paddle_begin_init_params(client C.paddle_pserver_client) C.int { func paddle_begin_init_params(client C.paddle_pserver_client) C.int {
c := get(client) c := get(client)
if selected := c.BeginInitParams(); selected { selected, err := c.BeginInitParams()
if err != nil {
panic(err)
}
if selected {
return 1 return 1
} }
return 0 return 0
......
...@@ -27,9 +27,13 @@ import ( ...@@ -27,9 +27,13 @@ import (
// TODO(helin): add RPC call retry logic // TODO(helin): add RPC call retry logic
// Selector selects if the client should initialize parameter servers. // Selector selects if the client should initialize parameters and
// reports the initialization process done.
type Selector interface { type Selector interface {
Select() bool // Select selects if the client should initialize parameter servers.
Select() (bool, error)
// Done indicates the initialization process is done.
Done() error
} }
// Server is the identification of a parameter Server. // Server is the identification of a parameter Server.
...@@ -115,7 +119,7 @@ func (c *Client) monitorPservers(l Lister, pserverNum int) { ...@@ -115,7 +119,7 @@ func (c *Client) monitorPservers(l Lister, pserverNum int) {
// servers. Other trainers will be blocked until the initialization is // servers. Other trainers will be blocked until the initialization is
// done, and they need to get the initialized parameters from // done, and they need to get the initialized parameters from
// parameter servers using GetParams. // parameter servers using GetParams.
func (c *Client) BeginInitParams() bool { func (c *Client) BeginInitParams() (bool, error) {
return c.sel.Select() return c.sel.Select()
} }
......
...@@ -124,8 +124,12 @@ func initEtcdClient() { ...@@ -124,8 +124,12 @@ func initEtcdClient() {
type selector bool type selector bool
func (s selector) Select() bool { func (s selector) Select() (bool, error) {
return bool(s) return bool(s), nil
}
func (s selector) Done() error {
return nil
} }
type lister []client.Server type lister []client.Server
...@@ -135,7 +139,11 @@ func (l lister) List() []client.Server { ...@@ -135,7 +139,11 @@ func (l lister) List() []client.Server {
} }
func testClient(t *testing.T, c *client.Client) { func testClient(t *testing.T, c *client.Client) {
selected := c.BeginInitParams() selected, err := c.BeginInitParams()
if err != nil {
t.Fatal(err)
}
if !selected { if !selected {
t.Fatal("should be selected.") t.Fatal("should be selected.")
} }
......
...@@ -16,53 +16,60 @@ package client ...@@ -16,53 +16,60 @@ package client
import ( import (
"context" "context"
"errors"
"fmt"
"strconv" "strconv"
"strings" "strings"
"time" "time"
"github.com/PaddlePaddle/Paddle/go/pserver" "github.com/PaddlePaddle/Paddle/go/pserver"
"github.com/coreos/etcd/clientv3" "github.com/coreos/etcd/clientv3"
"github.com/coreos/etcd/clientv3/concurrency"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
const ( const (
defaultEtcdTimeout time.Duration = 5 * time.Second defaultEtcdTimeout time.Duration = 5 * time.Second
initLockPath = "/init_ps/lock"
initDonePath = "/init_ps/done"
initDoneVal = "1"
) )
// EtcdClient is used by pserver client that is a part of trainer process. // Etcd is used by pserver client that is a part of trainer process.
// TODO: // TODO:
// 1. add watcher to watch the change state of pservers) // 1. add watcher to watch the change state of pservers.
// 1. add etcd lock) type Etcd struct {
type EtcdClient struct {
client *clientv3.Client client *clientv3.Client
timeout time.Duration timeout time.Duration
endpoints []string endpoints []string
lock *concurrency.Mutex
} }
// Desired read ps desired number from etcd. // Desired read ps desired number from etcd.
func (p *EtcdClient) Desired() int { func (e *Etcd) Desired() int {
var psDesired int var psDesired int
for { for {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout) ctx, cancel := context.WithTimeout(context.Background(), e.timeout)
resp, err := p.client.Get(ctx, pserver.PsDesired) resp, err := e.client.Get(ctx, pserver.PsDesired)
cancel() cancel()
if err != nil { if err != nil {
log.Errorf("Get ps dresire number failed! recnnectiong..., %v", err) log.Errorf("Get ps dresire number failed! recnnectiong..., %v", err)
time.Sleep(p.timeout) time.Sleep(e.timeout)
continue continue
} }
kvs := resp.Kvs kvs := resp.Kvs
if len(kvs) == 0 { if len(kvs) == 0 {
log.Infoln("Waiting for ps desired registered ...") log.Infoln("Waiting for ps desired registered ...")
time.Sleep(p.timeout) time.Sleep(e.timeout)
continue continue
} }
psDesired, err = strconv.Atoi(string(resp.Kvs[0].Value)) psDesired, err = strconv.Atoi(string(resp.Kvs[0].Value))
if err != nil { if err != nil {
log.Errorf("psDesired %d invalid %v", psDesired, err) log.Errorf("psDesired %d invalid %v", psDesired, err)
time.Sleep(p.timeout) time.Sleep(e.timeout)
continue continue
} }
...@@ -73,26 +80,26 @@ func (p *EtcdClient) Desired() int { ...@@ -73,26 +80,26 @@ func (p *EtcdClient) Desired() int {
} }
// List return the pserver list read from etcd. // List return the pserver list read from etcd.
func (p *EtcdClient) List() []Server { func (e *Etcd) List() []Server {
psDesired := p.Desired() psDesired := e.Desired()
servers := make([]Server, psDesired) servers := make([]Server, psDesired)
for { for {
for i := 0; i < psDesired; i++ { for i := 0; i < psDesired; i++ {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout) ctx, cancel := context.WithTimeout(context.Background(), e.timeout)
psKey := pserver.PsPath + strconv.Itoa(i) psKey := pserver.PsPath + strconv.Itoa(i)
log.Debugf("checking %s", psKey) log.Debugf("checking %s", psKey)
resp, err := p.client.Get(ctx, psKey) resp, err := e.client.Get(ctx, psKey)
cancel() cancel()
if err != nil { if err != nil {
log.Infof("Get psKey= %s error, %v", psKey, err) log.Infof("Get psKey= %s error, %v", psKey, err)
time.Sleep(p.timeout) time.Sleep(e.timeout)
continue continue
} }
kvs := resp.Kvs kvs := resp.Kvs
if len(kvs) == 0 { if len(kvs) == 0 {
log.Infof("Waiting for ps addr registered ...") log.Infof("Waiting for ps addr registered ...")
time.Sleep(p.timeout) time.Sleep(e.timeout)
continue continue
} }
...@@ -100,7 +107,7 @@ func (p *EtcdClient) List() []Server { ...@@ -100,7 +107,7 @@ func (p *EtcdClient) List() []Server {
// TODO(Longfei) check the ps address // TODO(Longfei) check the ps address
if psAddr == "" { if psAddr == "" {
log.Infof("Get psKey = %s, psAddr is empty", psKey) log.Infof("Get psKey = %s, psAddr is empty", psKey)
time.Sleep(p.timeout) time.Sleep(e.timeout)
continue continue
} }
log.Debugf("got value (%s) for key: %s", psAddr, psKey) log.Debugf("got value (%s) for key: %s", psAddr, psKey)
...@@ -113,7 +120,7 @@ func (p *EtcdClient) List() []Server { ...@@ -113,7 +120,7 @@ func (p *EtcdClient) List() []Server {
} }
// NewEtcd create a etcd client to return the state of pserver on etcd. // NewEtcd create a etcd client to return the state of pserver on etcd.
func NewEtcd(endpoints string) *EtcdClient { func NewEtcd(endpoints string) *Etcd {
ep := strings.Split(endpoints, ",") ep := strings.Split(endpoints, ",")
var cli *clientv3.Client var cli *clientv3.Client
var err error var err error
...@@ -130,10 +137,118 @@ func NewEtcd(endpoints string) *EtcdClient { ...@@ -130,10 +137,118 @@ func NewEtcd(endpoints string) *EtcdClient {
break break
} }
log.Infof("Connected to etcd: %s\n", endpoints) log.Infof("Connected to etcd: %s\n", endpoints)
client := &EtcdClient{ client := &Etcd{
client: cli, client: cli,
timeout: defaultEtcdTimeout, timeout: defaultEtcdTimeout,
endpoints: ep, endpoints: ep,
} }
return client return client
} }
// Select indicates if the current trainer is selected to initialize
// the pserver parameters.
func (e *Etcd) Select() (bool, error) {
sess, err := concurrency.NewSession(e.client, concurrency.WithTTL(5))
if err != nil {
return false, err
}
lock := concurrency.NewMutex(sess, initLockPath)
log.Infof("Trying to acquire lock at %s.", initLockPath)
// Do not use timeout context here, since we don't know how
// long does it take for other trainers to initialize the
// parameters.
err = lock.Lock(context.Background())
if err != nil {
return false, err
}
log.Infof("Successfully acquired lock at %s.", initLockPath)
get := clientv3.OpGet(initDonePath)
ctx, cancel := context.WithTimeout(context.Background(), e.timeout)
tresp, err := e.client.Txn(ctx).If(lock.IsOwner()).Then(get).Commit()
cancel()
if err != nil {
return false, err
}
if !tresp.Succeeded {
return false, errors.New("no longer the owner of the lock")
}
resp := tresp.Responses[0].GetResponseRange()
if len(resp.Kvs) == 0 {
// Key value not set, select current trainer.
e.lock = lock
log.Infoln("Trainer selected.")
return true, nil
}
if string(resp.Kvs[0].Value) == initDoneVal {
log.Infoln("Initialization is already done.")
ctx, cancel = context.WithTimeout(context.Background(), e.timeout)
err = lock.Unlock(ctx)
cancel()
if err != nil {
log.Errorln(err)
}
return false, nil
}
return false, fmt.Errorf("key %s have unexpected value: %v", initDonePath, resp.Kvs[0].Value)
}
// Done indicates the parameter initialization process is done.
func (e *Etcd) Done() error {
if e.lock == nil {
return errors.New("lock is nil, Done called unexpectedly")
}
put := clientv3.OpPut(initDonePath, initDoneVal)
ctx, cancel := context.WithTimeout(context.Background(), e.timeout)
tresp, err := e.client.Txn(ctx).If(e.lock.IsOwner()).Then(put).Commit()
cancel()
if err != nil {
return err
}
if !tresp.Succeeded {
return errors.New("no longer the owner of the lock")
}
ctx, cancel = context.WithTimeout(context.Background(), e.timeout)
err = e.lock.Unlock(ctx)
cancel()
if err != nil {
log.Errorln(err)
} else {
e.lock = nil
}
return nil
}
// Close closes the etcd client.
func (e *Etcd) Close() error {
var err error
if e.lock != nil {
ctx, cancel := context.WithTimeout(context.Background(), e.timeout)
err = e.lock.Unlock(ctx)
cancel()
if err == nil {
e.lock = nil
}
}
cErr := e.client.Close()
if cErr != nil {
if err != nil {
log.Errorln(cErr)
return err
}
return cErr
}
return err
}
package client_test
import (
"io/ioutil"
"net/url"
"os"
"strings"
"sync"
"testing"
"github.com/PaddlePaddle/Paddle/go/pserver/client"
"github.com/coreos/etcd/embed"
)
func TestSelector(t *testing.T) {
etcdDir, err := ioutil.TempDir("", "")
if err != nil {
t.Fatal(err)
}
cfg := embed.NewConfig()
lpurl, _ := url.Parse("http://localhost:0")
lcurl, _ := url.Parse("http://localhost:0")
cfg.LPUrls = []url.URL{*lpurl}
cfg.LCUrls = []url.URL{*lcurl}
cfg.Dir = etcdDir
e, err := embed.StartEtcd(cfg)
if err != nil {
t.Fatal(err)
}
defer func() {
e.Close()
if err := os.RemoveAll(etcdDir); err != nil {
t.Fatal(err)
}
}()
<-e.Server.ReadyNotify()
port := strings.Split(e.Clients[0].Addr().String(), ":")[1]
endpoint := "127.0.0.1:" + port
var mu sync.Mutex
selectedCount := 0
var wg sync.WaitGroup
selectAndDone := func(c *client.Etcd) {
defer wg.Done()
selected, err := c.Select()
if err != nil {
panic(err)
}
if selected {
mu.Lock()
selectedCount++
mu.Unlock()
err = c.Done()
if err != nil {
t.Fatal(err)
}
}
}
c0 := client.NewEtcd(endpoint)
c1 := client.NewEtcd(endpoint)
c2 := client.NewEtcd(endpoint)
c3 := client.NewEtcd(endpoint)
wg.Add(3)
go selectAndDone(c0)
go selectAndDone(c1)
go selectAndDone(c2)
wg.Wait()
// simulate trainer crashed and restarted after the
// initialization process.
wg.Add(1)
go selectAndDone(c3)
wg.Wait()
mu.Lock()
if selectedCount != 1 {
t.Fatal("selected count wrong:", selectedCount)
}
mu.Unlock()
err = c0.Close()
if err != nil {
t.Fatal(err)
}
err = c1.Close()
if err != nil {
t.Fatal(err)
}
err = c2.Close()
if err != nil {
t.Fatal(err)
}
err = c3.Close()
if err != nil {
t.Fatal(err)
}
}
...@@ -7,6 +7,9 @@ cc_library(tensor SRCS tensor.cc DEPS ddim place paddle_memory device_context) ...@@ -7,6 +7,9 @@ cc_library(tensor SRCS tensor.cc DEPS ddim place paddle_memory device_context)
cc_test(tensor_test SRCS tensor_test.cc DEPS tensor) cc_test(tensor_test SRCS tensor_test.cc DEPS tensor)
cc_test(eigen_test SRCS eigen_test.cc DEPS tensor) cc_test(eigen_test SRCS eigen_test.cc DEPS tensor)
cc_library(lod_tensor SRCS lod_tensor.cc details/lod_tensor.cc DEPS ddim place tensor)
cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor)
cc_test(variable_test SRCS variable_test.cc) cc_test(variable_test SRCS variable_test.cc)
cc_library(scope SRCS scope.cc) cc_library(scope SRCS scope.cc)
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
limitations under the License. */ limitations under the License. */
#include "paddle/framework/backward.h" #include "paddle/framework/backward.h"
#include <list> #include <list>
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/operators/net_op.h" #include "paddle/operators/net_op.h"
......
...@@ -17,16 +17,21 @@ ...@@ -17,16 +17,21 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/operators/net_op.h" #include "paddle/operators/net_op.h"
#include "paddle/operators/type_alias.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
using OperatorBase = framework::OperatorBase;
using OpProtoAndCheckerMaker = framework::OpProtoAndCheckerMaker;
using OpProto = framework::OpProto;
using OpAttrChecker = framework::OpAttrChecker;
using Scope = framework::Scope;
using DeviceContext = platform::DeviceContext;
class EmptyOp : public OperatorBase { class EmptyOp : public OperatorBase {
public: public:
void InferShape(const Scope &scope) const override {} void InferShape(const Scope &scope) const override {}
void Run(const Scope &scope, void Run(const Scope &scope, const DeviceContext &dev_ctx) const override {}
const platform::DeviceContext &dev_ctx) const override {}
}; };
class RowWiseAddOpMaker : public OpProtoAndCheckerMaker { class RowWiseAddOpMaker : public OpProtoAndCheckerMaker {
...@@ -71,7 +76,7 @@ class NoGradOpMaker : public OpProtoAndCheckerMaker { ...@@ -71,7 +76,7 @@ class NoGradOpMaker : public OpProtoAndCheckerMaker {
} }
}; };
class FcOp : public ops::NetOp { class FcOp : public operators::NetOp {
public: public:
void Init() override { void Init() override {
AddOp(OpRegistry::CreateOp("mul", {Input("X"), Input("W")}, AddOp(OpRegistry::CreateOp("mul", {Input("X"), Input("W")},
...@@ -143,6 +148,7 @@ class AddOpMaker : public OpProtoAndCheckerMaker { ...@@ -143,6 +148,7 @@ class AddOpMaker : public OpProtoAndCheckerMaker {
} // namespace paddle } // namespace paddle
namespace f = paddle::framework; namespace f = paddle::framework;
namespace ops = paddle::operators;
using EnforceNotMet = paddle::platform::EnforceNotMet; using EnforceNotMet = paddle::platform::EnforceNotMet;
REGISTER_OP(rowwise_add, f::EmptyOp, f::RowWiseAddOpMaker); REGISTER_OP(rowwise_add, f::EmptyOp, f::RowWiseAddOpMaker);
REGISTER_GRADIENT_OP(rowwise_add, rowwise_add_grad, f::EmptyOp); REGISTER_GRADIENT_OP(rowwise_add, rowwise_add_grad, f::EmptyOp);
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/lod_tensor.h"
#include <memory>
namespace paddle {
namespace framework {
namespace details {
using LOD = LODTensor::LOD;
std::shared_ptr<LOD> SliceLOD(const LOD &lod, size_t level_begin,
size_t level_end) {
auto new_lod = std::make_shared<LOD>();
new_lod->reserve(level_end - level_begin);
for (size_t i = level_begin; i < level_end; i++) {
new_lod->emplace_back(lod[i]);
}
return new_lod;
}
std::shared_ptr<LOD> SliceLOD(const LOD &lod, size_t level, size_t elem_begin,
size_t elem_end, bool tensor_shared) {
// slice the lod.
auto new_lod = std::make_shared<LOD>();
new_lod->reserve(lod.size() - level);
auto start = lod.at(level)[elem_begin];
auto end = lod.at(level)[elem_end];
for (auto it = lod.begin() + level; it != lod.end(); it++) {
auto it_begin = std::find(it->begin(), it->end(), start);
auto it_end = std::find(it_begin, it->end(), end);
PADDLE_ENFORCE(it_begin != it->end(), "error in parsing lod info");
PADDLE_ENFORCE(it_end != it->end(), "error in parsing lod info");
new_lod->emplace_back(it_begin, it_end + 1);
if (!tensor_shared) {
// reset offset if tensor is copyed and sliced.
std::transform(new_lod->back().begin(), new_lod->back().end(),
new_lod->back().begin(),
[start](int v) { return v - start; });
PADDLE_ENFORCE(new_lod->back().front() == 0, "error in slice LOD");
}
}
return new_lod;
}
} // namespace details
} // namespace framework
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
namespace paddle {
namespace framework {
namespace details {
/*
* Slice levels from LOD.
*
* @lod: LOD to slice.
* @level_begin: level to begin slice.
* @level_end: level to end slice.
*/
std::shared_ptr<LODTensor::LOD> SliceLOD(const LODTensor::LOD &lod,
size_t level_begin, size_t level_end);
/*
* Slice elements from a level of LOD.
*
* @lod: LOD to slice.
* @level: which level to slice.
* @elem_begin: element's index to begin slice.
* @elem_end: element's index to end slice.
*/
std::shared_ptr<LODTensor::LOD> SliceLOD(const LODTensor::LOD &lod,
size_t level, size_t elem_begin,
size_t elem_end, bool tensor_shared);
} // namespace details
} // namespace framework
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/lod_tensor.h"
#include <glog/logging.h>
namespace paddle {
namespace framework {
LODTensor LODTensor::SliceShared(size_t level_begin, size_t level_end) const {
PADDLE_ENFORCE(HasLOD(), "has no LOD info, can't be sliced.");
auto new_lod = details::SliceLOD(*lod_start_pos_, level_begin, level_end);
// slice levels just need to update LOD info, each level will contains the
// whole tensor_, so no need to modify tensor_.
return LODTensor(tensor_, new_lod);
}
LODTensor LODTensor::SliceShared(size_t level, size_t elem_begin,
size_t elem_end) const {
PADDLE_ENFORCE(HasLOD(), "has no LOD info, can't be sliced.");
PADDLE_ENFORCE(level < NumLevels(), "level [%d] out of range [%d]", level,
NumLevels());
PADDLE_ENFORCE(elem_begin < NumElements(level),
"element begin [%d] out of range [%d]", elem_begin,
NumElements(level));
PADDLE_ENFORCE(elem_end < NumElements(level) + 1,
"element end [%d] out of range [%d]", elem_end,
NumElements(level));
auto new_lod = details::SliceLOD(*lod_start_pos_, level, elem_begin, elem_end,
true /*tensor_shared*/);
// slice elements just need to update LOD info, because offsets are not
// changed, so the original tensor_ can be reused.
return LODTensor(tensor_, new_lod);
}
} // namespace framework
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#if (!PADDLE_ONLY_CPU)
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#endif
#include "paddle/framework/ddim.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/enforce.h"
namespace paddle {
namespace framework {
/*
* LODTensor (Level of details Tensor)
* see https://en.wikipedia.org/wiki/Level_of_details for reference.
*/
class LODTensor {
public:
// Level save offsets of each unit.
#ifdef PADDLE_ONLY_CPU
using Level = std::vector<size_t>;
#else
using Level = thrust::device_vector<size_t>;
#endif
// LOD stores offsets of each level of units, the largest units level first,
// then the smaller units level. Each Level stores the offsets of units in
// Tesor.
typedef std::vector<Level> LOD;
LODTensor() {}
LODTensor(const std::shared_ptr<Tensor> &tensor,
const std::shared_ptr<LOD> &lod) {
Reset(tensor, lod);
}
void Reset(const std::shared_ptr<Tensor> &tensor,
const std::shared_ptr<LOD> &lod) {
tensor_ = tensor;
lod_start_pos_ = lod;
}
/*
* Get a element from LOD.
*/
size_t lod_element(size_t level, size_t elem) const {
PADDLE_ENFORCE(level < NumLevels(), "level [%d] out of range [%d]", level,
NumLevels());
PADDLE_ENFORCE(elem < NumElements(level),
"element begin [%d] out of range [%d]", elem,
NumElements(level));
return (*lod_start_pos_)[level][elem];
}
/*
* Number of LODTensor's levels, each level has units of data, for example,
* in the sentence's view, article, paragraph, sentence are 3 levels.
*/
size_t NumLevels() const {
return lod_start_pos_ ? lod_start_pos_->size() : 0UL;
}
/*
* Number of elements in a level.
*/
size_t NumElements(size_t level = 0) const {
PADDLE_ENFORCE(level < NumLevels(), "level [%d] out of range [%d]", level,
NumLevels());
// the last offset is the end of last element
return lod_start_pos_->at(level).size() - 1;
}
/*
* Slice of levels[level_begin:level_end], with tensor copied.
*/
template <typename T>
LODTensor SliceCopied(size_t level_begin, size_t level_end,
const platform::Place &dst_place) const;
/*
* Slice of levels[level_begin:level_end], with tensor shared.
*/
LODTensor SliceShared(size_t level_begin, size_t level_end) const;
/*
* Slice of elements of a level, [elem_begin: elem_end], with tensor copied.
* @note: low performance in slice lod_start_pos_.
*/
template <typename T>
LODTensor SliceCopied(size_t level, size_t elem_begin, size_t elem_end,
const platform::Place &dst_place) const;
/*
* Slice of elements of a level, [elem_begin: elem_end], with tensor shared.
* @note: low performance in slice lod_start_pos_.
*/
LODTensor SliceShared(size_t level, size_t elem_begin, size_t elem_end) const;
/*
* Copy other's lod_start_pos_, to share LOD info.
* @note: the LOD info should not be changed.
*/
void ShareLOD(const LODTensor &other) {
lod_start_pos_ = other.lod_start_pos_;
}
/*
* Copy other's lod_start_pos_'s content, free to mutate.
*/
void CopyLOD(const LODTensor &other) {
lod_start_pos_ = std::make_shared<LOD>(*other.lod_start_pos_);
}
/*
* Determine whether LODTensor has a valid LOD info.
*/
bool HasLOD() const { return bool(lod_start_pos_); }
LOD *lod() const { return lod_start_pos_.get(); }
std::shared_ptr<Tensor> &tensor() { return tensor_; }
Tensor *raw_tensor() { return tensor_.get(); }
private:
std::shared_ptr<LOD> lod_start_pos_;
std::shared_ptr<Tensor> tensor_;
};
} // namespace framework
} // namespace paddle
#include "paddle/framework/lod_tensor_impl.h"
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/details/lod_tensor.h"
namespace paddle {
namespace framework {
template <typename T>
LODTensor LODTensor::SliceCopied(size_t level_begin, size_t level_end,
const platform::Place &dst_place) const {
PADDLE_ENFORCE(HasLOD(), "has no LOD info, can't be sliced.");
auto new_lod = details::SliceLOD(*lod_start_pos_, level_begin, level_end);
auto new_tensor = std::make_shared<Tensor>();
new_tensor->CopyFrom<T>(*tensor_, dst_place);
return LODTensor(new_tensor, new_lod);
}
template <typename T>
LODTensor LODTensor::SliceCopied(size_t level, size_t elem_begin,
size_t elem_end,
const platform::Place &dst_place) const {
PADDLE_ENFORCE(HasLOD(), "has no LOD info, can't be sliced.");
PADDLE_ENFORCE(level < NumLevels(), "level [%d] out of range [%d]", level,
NumLevels());
PADDLE_ENFORCE(elem_begin < NumElements(level),
"element begin [%d] out of range [%d]", elem_begin,
NumElements(level));
PADDLE_ENFORCE(elem_end < NumElements(level) + 1,
"element end [%d] out of range [%d]", elem_end,
NumElements(level));
auto new_lod = details::SliceLOD(*lod_start_pos_, level, elem_begin, elem_end,
false /*tensor_shared*/);
auto start_idx = new_lod->front().front();
auto end_idx = new_lod->front().back() - 1 /*the next element's start*/;
auto sliced_tensor = tensor_->Slice<T>(start_idx, end_idx);
auto new_tensor = std::make_shared<Tensor>();
new_tensor->CopyFrom<T>(sliced_tensor, dst_place);
return LODTensor(new_tensor, new_lod);
}
} // namespace framework
} // namespace paddle
/*
Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "paddle/framework/lod_tensor.h"
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <memory>
namespace paddle {
namespace framework {
class LODTensorTester : public ::testing::Test {
public:
virtual void SetUp() override {
lod_tensor.reset(new LODTensor);
// tensor's batch_size: 30
// 3 levels
// 0 10 20
// 0 5 10 15 20
// 0 2 5 7 10 12 15 20
auto lod = std::make_shared<LODTensor::LOD>();
lod->push_back(std::vector<size_t>{0, 10, 20});
lod->push_back(std::vector<size_t>{0, 5, 10, 15, 20});
lod->push_back(std::vector<size_t>{0, 2, 5, 7, 10, 12, 15, 17, 20});
auto tensor = std::make_shared<Tensor>();
tensor->Resize({20 /*batch size*/, 128 /*dim*/});
// malloc memory
tensor->mutable_data<float>(place);
lod_tensor->Reset(tensor, lod);
}
protected:
std::unique_ptr<LODTensor> lod_tensor;
platform::CPUPlace place;
};
TEST_F(LODTensorTester, NumLevels) { ASSERT_EQ(lod_tensor->NumLevels(), 3UL); }
TEST_F(LODTensorTester, NumElements) {
ASSERT_EQ(lod_tensor->NumElements(0), 2UL);
ASSERT_EQ(lod_tensor->NumElements(1), 4UL);
ASSERT_EQ(lod_tensor->NumElements(2), 8UL);
}
TEST_F(LODTensorTester, SliceShared_Level) {
// slice 1 level
for (size_t level = 0; level < 3UL; ++level) {
auto new_lod_tensor = lod_tensor->SliceShared(level, level + 1);
ASSERT_EQ(new_lod_tensor.NumLevels(), 1UL);
ASSERT_EQ(new_lod_tensor.NumElements(0UL), lod_tensor->NumElements(level));
ASSERT_EQ(new_lod_tensor.tensor(), lod_tensor->tensor());
}
// slice 2 level
for (size_t level = 0; level < 2UL; ++level) {
auto new_lod_tensor = lod_tensor->SliceShared(level, level + 2);
ASSERT_EQ(new_lod_tensor.NumLevels(), 2UL);
ASSERT_EQ(new_lod_tensor.NumElements(0), lod_tensor->NumElements(level));
ASSERT_EQ(new_lod_tensor.NumElements(1),
lod_tensor->NumElements(level + 1));
ASSERT_EQ(new_lod_tensor.tensor(), lod_tensor->tensor());
}
}
TEST_F(LODTensorTester, SliceCopied_Level) {
// slice 1 level
for (size_t level = 0; level < 3UL; ++level) {
auto new_lod_tensor =
lod_tensor->SliceCopied<float>(level, level + 1, place);
ASSERT_EQ(new_lod_tensor.NumLevels(), 1UL);
ASSERT_EQ(new_lod_tensor.NumElements(0UL), lod_tensor->NumElements(level));
// ASSERT_EQ(new_lod_tensor.tensor(), lod_tensor->tensor());
// TODO(superjom) add tensor comparation here.
}
// slice 2 level
for (size_t level = 0; level < 2UL; ++level) {
auto new_lod_tensor =
lod_tensor->SliceCopied<float>(level, level + 2, place);
ASSERT_EQ(new_lod_tensor.NumLevels(), 2UL);
ASSERT_EQ(new_lod_tensor.NumElements(0), lod_tensor->NumElements(level));
ASSERT_EQ(new_lod_tensor.NumElements(1),
lod_tensor->NumElements(level + 1));
// ASSERT_EQ(new_lod_tensor.tensor(), lod_tensor->tensor());
// TODO(superjom) add tensor comparation here.
}
}
TEST_F(LODTensorTester, SliceShared_Element) {
size_t level = 0;
auto new_lod_tensor = lod_tensor->SliceShared(level, 0, 2);
ASSERT_EQ(new_lod_tensor.NumLevels(), 3UL);
ASSERT_EQ(new_lod_tensor.NumElements(0), 2UL);
ASSERT_EQ(new_lod_tensor.NumElements(1), 4UL);
ASSERT_EQ(new_lod_tensor.NumElements(2), 8UL);
ASSERT_EQ(new_lod_tensor.raw_tensor(), lod_tensor->raw_tensor());
level = 1;
new_lod_tensor = lod_tensor->SliceShared(level, 0, 2);
ASSERT_EQ(new_lod_tensor.NumLevels(), 2UL);
ASSERT_EQ(new_lod_tensor.NumElements(0), 2UL);
ASSERT_EQ(new_lod_tensor.NumElements(1), 4UL);
ASSERT_EQ(new_lod_tensor.raw_tensor(), lod_tensor->raw_tensor());
}
TEST_F(LODTensorTester, SliceCopied_Element) {
size_t level = 0;
auto new_lod_tensor = lod_tensor->SliceCopied<float>(level, 0, 2, place);
ASSERT_EQ(new_lod_tensor.NumLevels(), 3UL);
ASSERT_EQ(new_lod_tensor.NumElements(0), 2UL);
ASSERT_EQ(new_lod_tensor.NumElements(1), 4UL);
ASSERT_EQ(new_lod_tensor.NumElements(2), 8UL);
ASSERT_NE(new_lod_tensor.raw_tensor(), lod_tensor->raw_tensor());
level = 1;
new_lod_tensor = lod_tensor->SliceCopied<float>(level, 0, 2, place);
ASSERT_EQ(new_lod_tensor.NumLevels(), 2UL);
ASSERT_EQ(new_lod_tensor.NumElements(0), 2UL);
ASSERT_EQ(new_lod_tensor.NumElements(1), 4UL);
ASSERT_NE(new_lod_tensor.raw_tensor(), lod_tensor->raw_tensor());
level = 1;
// LOD is
// 0 5 10
// 0 2 5 7 10
new_lod_tensor = lod_tensor->SliceCopied<float>(level, 1, 3, place);
ASSERT_EQ(new_lod_tensor.NumLevels(), 2UL);
ASSERT_EQ(new_lod_tensor.NumElements(0), 2UL);
ASSERT_EQ(new_lod_tensor.NumElements(1), 4UL);
ASSERT_EQ(new_lod_tensor.lod_element(0, 0), 0UL);
ASSERT_EQ(new_lod_tensor.lod_element(0, 1), 5UL);
ASSERT_EQ(new_lod_tensor.lod_element(1, 0), 0UL);
ASSERT_EQ(new_lod_tensor.lod_element(1, 1), 2UL);
ASSERT_EQ(new_lod_tensor.lod_element(1, 2), 5UL);
ASSERT_EQ(new_lod_tensor.lod_element(1, 3), 7UL);
// TODO(superjom) compare the content of these tensors
}
TEST_F(LODTensorTester, ShareLOD) {
LODTensor new_lod_tensor;
new_lod_tensor.ShareLOD(*lod_tensor);
ASSERT_EQ(new_lod_tensor.lod(), lod_tensor->lod());
}
TEST_F(LODTensorTester, CopyLOD) {
LODTensor new_lod_tensor;
new_lod_tensor.CopyLOD(*lod_tensor);
ASSERT_NE(new_lod_tensor.lod(), lod_tensor->lod());
}
} // namespace framework
} // namespace paddle
...@@ -18,11 +18,8 @@ limitations under the License. */ ...@@ -18,11 +18,8 @@ limitations under the License. */
#include "paddle/framework/backward.h" #include "paddle/framework/backward.h"
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/scope.h"
#include "paddle/framework/tensor_py.h" #include "paddle/framework/tensor_py.h"
#include "paddle/operators/net_op.h" #include "paddle/operators/net_op.h"
#include "paddle/operators/type_alias.h"
#include "paddle/platform/enforce.h" #include "paddle/platform/enforce.h"
#include "paddle/platform/place.h" #include "paddle/platform/place.h"
#include "pybind11/numpy.h" #include "pybind11/numpy.h"
...@@ -45,6 +42,9 @@ USE_OP_WITHOUT_KERNEL(recurrent_op); ...@@ -45,6 +42,9 @@ USE_OP_WITHOUT_KERNEL(recurrent_op);
USE_OP(uniform_random); USE_OP(uniform_random);
namespace paddle { namespace paddle {
namespace framework { namespace framework {
using Tensor = framework::Tensor;
template <typename ClassType> template <typename ClassType>
void ExposeOperator(ClassType &m) { void ExposeOperator(ClassType &m) {
m.def("infer_shape", &ClassType::type::InferShape) m.def("infer_shape", &ClassType::type::InferShape)
...@@ -150,8 +150,8 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -150,8 +150,8 @@ All parameter, weight, gradient are variables in Paddle.
[](Variable &self) -> Tensor * { return self.GetMutable<Tensor>(); }, [](Variable &self) -> Tensor * { return self.GetMutable<Tensor>(); },
py::return_value_policy::reference) py::return_value_policy::reference)
.def("get_net", .def("get_net",
[](Variable &self) -> ops::NetOp * { [](Variable &self) -> operators::NetOp * {
return self.GetMutable<ops::NetOp>(); return self.GetMutable<operators::NetOp>();
}, },
py::return_value_policy::reference); py::return_value_policy::reference);
...@@ -230,23 +230,24 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -230,23 +230,24 @@ All parameter, weight, gradient are variables in Paddle.
ExposeOperator(operator_base); ExposeOperator(operator_base);
py::class_<ops::NetOp, std::shared_ptr<ops::NetOp>> net(m, "Net"); py::class_<operators::NetOp, std::shared_ptr<operators::NetOp>> net(m, "Net");
net.def_static("create", net.def_static("create",
[]() -> std::shared_ptr<ops::NetOp> { []() -> std::shared_ptr<operators::NetOp> {
auto retv = std::make_shared<ops::NetOp>(); auto retv = std::make_shared<operators::NetOp>();
retv->type_ = "plain_net"; retv->type_ = "plain_net";
return retv; return retv;
}) })
.def("add_op", &ops::NetOp::AddOp) .def("add_op", &operators::NetOp::AddOp)
.def( .def("add_op",
"add_op", [](operators::NetOp &self,
[](ops::NetOp &self, const std::shared_ptr<ops::NetOp> &net) -> void { const std::shared_ptr<operators::NetOp> &net) -> void {
self.AddOp(std::static_pointer_cast<OperatorBase>(net)); self.AddOp(std::static_pointer_cast<OperatorBase>(net));
}) })
.def("complete_add_op", &ops::NetOp::CompleteAddOp) .def("complete_add_op", &operators::NetOp::CompleteAddOp)
.def("complete_add_op", .def("complete_add_op", [](std::shared_ptr<operators::NetOp> &self) {
[](std::shared_ptr<ops::NetOp> &self) { self->CompleteAddOp(); }); self->CompleteAddOp();
});
ExposeOperator(net); ExposeOperator(net);
......
...@@ -18,6 +18,8 @@ limitations under the License. */ ...@@ -18,6 +18,8 @@ limitations under the License. */
#include <cstring> #include <cstring>
#include <memory> #include <memory>
#include <typeindex> #include <typeindex>
#include <vector>
#include "paddle/framework/ddim.h" #include "paddle/framework/ddim.h"
#include "paddle/memory/memory.h" #include "paddle/memory/memory.h"
#include "paddle/platform/device_context.h" #include "paddle/platform/device_context.h"
......
...@@ -19,7 +19,7 @@ TEST(Tensor, Dims) { ...@@ -19,7 +19,7 @@ TEST(Tensor, Dims) {
using namespace paddle::framework; using namespace paddle::framework;
using namespace paddle::platform; using namespace paddle::platform;
Tensor tt; Tensor tt;
tt.Resize(make_ddim({2, 3, 4})); tt.Resize({2, 3, 4});
DDim dims = tt.dims(); DDim dims = tt.dims();
ASSERT_EQ(arity(dims), 3); ASSERT_EQ(arity(dims), 3);
for (int i = 0; i < 3; ++i) { for (int i = 0; i < 3; ++i) {
......
...@@ -93,8 +93,8 @@ TEST(Arguments, Matrix) { ...@@ -93,8 +93,8 @@ TEST(Arguments, Matrix) {
MatrixPtr matrix = Matrix::create(100, 200); MatrixPtr matrix = Matrix::create(100, 200);
CheckBufferArg check = [=](const BufferArg& arg) { CheckBufferArg check = [=](const BufferArg& arg) {
EXPECT_EQ(arg.shape().ndims(), 2U); EXPECT_EQ(arg.shape().ndims(), 2U);
EXPECT_EQ(arg.shape()[0], 100); EXPECT_EQ(arg.shape()[0], 100U);
EXPECT_EQ(arg.shape()[1], 200); EXPECT_EQ(arg.shape()[1], 200U);
EXPECT_EQ(arg.data(), matrix->getData()); EXPECT_EQ(arg.data(), matrix->getData());
EXPECT_EQ(arg.matrix<DEVICE_TYPE_CPU>().getHeight(), matrix->getHeight()); EXPECT_EQ(arg.matrix<DEVICE_TYPE_CPU>().getHeight(), matrix->getHeight());
...@@ -112,8 +112,8 @@ TEST(Arguments, Matrix) { ...@@ -112,8 +112,8 @@ TEST(Arguments, Matrix) {
TEST(Arguments, Vector) { TEST(Arguments, Vector) {
VectorPtr vector = Vector::create(100, false); VectorPtr vector = Vector::create(100, false);
CheckBufferArg check = [=](const BufferArg& arg) { CheckBufferArg check = [=](const BufferArg& arg) {
EXPECT_EQ(arg.shape().ndims(), 1); EXPECT_EQ(arg.shape().ndims(), 1U);
EXPECT_EQ(arg.shape()[0], 100); EXPECT_EQ(arg.shape()[0], 100U);
EXPECT_EQ(arg.data(), vector->getData()); EXPECT_EQ(arg.data(), vector->getData());
CpuVector inVector = arg.vector<real, DEVICE_TYPE_CPU>(); CpuVector inVector = arg.vector<real, DEVICE_TYPE_CPU>();
...@@ -131,9 +131,9 @@ TEST(Arguments, Vector) { ...@@ -131,9 +131,9 @@ TEST(Arguments, Vector) {
TEST(Arguments, CpuSparseMatrix) { TEST(Arguments, CpuSparseMatrix) {
CpuSparseMatrix sparse(200, 300, 50); CpuSparseMatrix sparse(200, 300, 50);
CheckBufferArg check = [=](const BufferArg& arg) { CheckBufferArg check = [=](const BufferArg& arg) {
EXPECT_EQ(arg.shape().ndims(), 2); EXPECT_EQ(arg.shape().ndims(), 2U);
EXPECT_EQ(arg.shape()[0], 200); EXPECT_EQ(arg.shape()[0], 200U);
EXPECT_EQ(arg.shape()[1], 300); EXPECT_EQ(arg.shape()[1], 300U);
EXPECT_EQ(arg.data(), sparse.getData()); EXPECT_EQ(arg.data(), sparse.getData());
// CHECK_EQ(arg.sparse().nnz(), 50); // CHECK_EQ(arg.sparse().nnz(), 50);
// CHECK_EQ(arg.sparse().dataFormat(), SPARSE_CSR_FORMAT); // CHECK_EQ(arg.sparse().dataFormat(), SPARSE_CSR_FORMAT);
...@@ -152,10 +152,10 @@ TEST(Arguments, CpuSparseMatrix) { ...@@ -152,10 +152,10 @@ TEST(Arguments, CpuSparseMatrix) {
TEST(Arguments, BufferArg) { TEST(Arguments, BufferArg) {
BufferArg arg(nullptr, VALUE_TYPE_FLOAT, {1, 2, 3}); BufferArg arg(nullptr, VALUE_TYPE_FLOAT, {1, 2, 3});
CheckBufferArg check = [=](const BufferArg& arg) { CheckBufferArg check = [=](const BufferArg& arg) {
EXPECT_EQ(arg.shape().ndims(), 3); EXPECT_EQ(arg.shape().ndims(), 3U);
EXPECT_EQ(arg.shape()[0], 1); EXPECT_EQ(arg.shape()[0], 1U);
EXPECT_EQ(arg.shape()[1], 2); EXPECT_EQ(arg.shape()[1], 2U);
EXPECT_EQ(arg.shape()[2], 3); EXPECT_EQ(arg.shape()[2], 3U);
}; };
BufferArgs argments; BufferArgs argments;
......
...@@ -44,7 +44,7 @@ TEST(TensorShape, GetAndSet) { ...@@ -44,7 +44,7 @@ TEST(TensorShape, GetAndSet) {
EXPECT_EQ(t.ndims(), 3U); EXPECT_EQ(t.ndims(), 3U);
EXPECT_EQ(t.getElements(), 6U); EXPECT_EQ(t.getElements(), 6U);
EXPECT_EQ(t[1], 2); EXPECT_EQ(t[1], 2U);
t.setDim(1, 100); t.setDim(1, 100);
EXPECT_EQ(t.getElements(), 300U); EXPECT_EQ(t.getElements(), 300U);
EXPECT_EQ(t[1], 100U); EXPECT_EQ(t[1], 100U);
......
...@@ -96,7 +96,7 @@ void SubNestedSequenceLayer::calSelectedCols( ...@@ -96,7 +96,7 @@ void SubNestedSequenceLayer::calSelectedCols(
for (size_t i = 0; i < seqNum; ++i) { for (size_t i = 0; i < seqNum; ++i) {
for (size_t j = 0; j < beamSize; ++j) { for (size_t j = 0; j < beamSize; ++j) {
if (selectedIndices->getElement(i, j) == -1.) break; if (selectedIndices->getElement(i, j) == -1.) break;
int selSubSeqIdx = selectedIndices->getElement(i, j); size_t selSubSeqIdx = selectedIndices->getElement(i, j);
CHECK_GT(inputSeqInfoVec_[i].size() - 1, selSubSeqIdx); CHECK_GT(inputSeqInfoVec_[i].size() - 1, selSubSeqIdx);
size_t subSeqLen = inputSeqInfoVec_[i][selSubSeqIdx + 1] - size_t subSeqLen = inputSeqInfoVec_[i][selSubSeqIdx + 1] -
...@@ -135,7 +135,7 @@ void SubNestedSequenceLayer::forward(PassType passType) { ...@@ -135,7 +135,7 @@ void SubNestedSequenceLayer::forward(PassType passType) {
CHECK(inputSeq.hasSubseq()) << "The first input of SubNestSequence layer " CHECK(inputSeq.hasSubseq()) << "The first input of SubNestSequence layer "
<< "must be a nested sequence."; << "must be a nested sequence.";
const MatrixPtr selectedIndices = getInputValue(1); const MatrixPtr selectedIndices = getInputValue(1);
CHECK_EQ(inputSeq.getNumSequences(), selectedIndices->getHeight()); CHECK_EQ(size_t(inputSeq.getNumSequences()), selectedIndices->getHeight());
if (dynamic_cast<GpuMatrix*>(selectedIndices.get())) { if (dynamic_cast<GpuMatrix*>(selectedIndices.get())) {
/* /*
......
...@@ -88,7 +88,7 @@ void checkLayerOut(vector<vector<int>> groundTruth, ...@@ -88,7 +88,7 @@ void checkLayerOut(vector<vector<int>> groundTruth,
TEST(Layer, kmaxSeqScoreLayer) { TEST(Layer, kmaxSeqScoreLayer) {
const size_t maxBeamSize = 100; const size_t maxBeamSize = 100;
int beamSize = 1 + (rand() % maxBeamSize); size_t beamSize = 1 + (rand() % maxBeamSize);
vector<int> seqStartPosition; vector<int> seqStartPosition;
vector<int> subSeqStartPosition; vector<int> subSeqStartPosition;
......
...@@ -59,6 +59,7 @@ op_library(cross_entropy_op SRCS cross_entropy_op.cc cross_entropy_op.cu) ...@@ -59,6 +59,7 @@ op_library(cross_entropy_op SRCS cross_entropy_op.cc cross_entropy_op.cu)
op_library(fill_zeros_like_op SRCS fill_zeros_like_op.cc fill_zeros_like_op.cu) op_library(fill_zeros_like_op SRCS fill_zeros_like_op.cc fill_zeros_like_op.cu)
op_library(sgd_op SRCS sgd_op.cc sgd_op.cu) op_library(sgd_op SRCS sgd_op.cc sgd_op.cu)
cc_test(sgd_op_test SRCS sgd_op_test.cc DEPS sgd_op)
op_library(fc_op op_library(fc_op
SRCS fc_op.cc SRCS fc_op.cc
......
...@@ -17,9 +17,9 @@ limitations under the License. */ ...@@ -17,9 +17,9 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
class AddOp : public OperatorWithKernel { class AddOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape(const InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_EQ(ctx.InputSize(), 2); PADDLE_ENFORCE_EQ(ctx.InputSize(), 2);
PADDLE_ENFORCE_EQ(ctx.OutputSize(), 1); PADDLE_ENFORCE_EQ(ctx.OutputSize(), 1);
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(0), "Inputs of AddOp must all be set"); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(0), "Inputs of AddOp must all be set");
...@@ -31,9 +31,9 @@ class AddOp : public OperatorWithKernel { ...@@ -31,9 +31,9 @@ class AddOp : public OperatorWithKernel {
} }
}; };
class AddOpMaker : public OpProtoAndCheckerMaker { class AddOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
AddOpMaker(OpProto *proto, OpAttrChecker *op_checker) AddOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The first input of add op"); AddInput("X", "The first input of add op");
AddInput("Y", "The second input of add op"); AddInput("Y", "The second input of add op");
...@@ -46,14 +46,17 @@ The equation is: Out = X + Y ...@@ -46,14 +46,17 @@ The equation is: Out = X + Y
} }
}; };
class AddOpGrad : public OperatorWithKernel { class AddOpGrad : public framework::OperatorWithKernel {
protected: protected:
void InferShape(const InferShapeContext &ctx) const override {} void InferShape(const framework::InferShapeContext &ctx) const override {}
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(add_two, ops::AddOp, ops::AddOpMaker); REGISTER_OP(add_two, ops::AddOp, ops::AddOpMaker);
REGISTER_GRADIENT_OP(add_two, add_two_grad, ops::AddOpGrad); REGISTER_GRADIENT_OP(add_two, add_two_grad, ops::AddOpGrad);
REGISTER_OP_CPU_KERNEL(add_two, ops::AddKernel<ops::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(add_two,
ops::AddKernel<paddle::platform::CPUPlace, float>);
...@@ -16,4 +16,6 @@ ...@@ -16,4 +16,6 @@
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/operators/add_op.h" #include "paddle/operators/add_op.h"
REGISTER_OP_GPU_KERNEL(add_two, ops::AddKernel<ops::GPUPlace, float>); namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(add_two,
ops::AddKernel<paddle::platform::GPUPlace, float>);
...@@ -13,15 +13,21 @@ See the License for the specific language governing permissions and ...@@ -13,15 +13,21 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include "paddle/operators/type_alias.h" #include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename Place, typename T> template <typename Place, typename T>
class AddKernel : public OpKernel { class AddKernel : public framework::OpKernel {
public: public:
void Compute(const ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto input0 = context.Input<Tensor>(0); auto input0 = context.Input<Tensor>(0);
auto input1 = context.Input<Tensor>(1); auto input1 = context.Input<Tensor>(1);
auto output = context.Output<Tensor>(0); auto output = context.Output<Tensor>(0);
......
...@@ -14,9 +14,9 @@ limitations under the License. */ ...@@ -14,9 +14,9 @@ limitations under the License. */
#include <gtest/gtest.h> #include <gtest/gtest.h>
#define private public #define private public
#include <paddle/framework/op_registry.h> #include "paddle/framework/op_registry.h"
USE_OP(add_two); USE_OP(add_two);
// USE_OP(add_two_grad);
TEST(AddOp, GetOpProto) { TEST(AddOp, GetOpProto) {
auto& protos = paddle::framework::OpRegistry::protos(); auto& protos = paddle::framework::OpRegistry::protos();
......
...@@ -17,9 +17,9 @@ limitations under the License. */ ...@@ -17,9 +17,9 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
class OnehotCrossEntropyOp : public OperatorWithKernel { class OnehotCrossEntropyOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape(const InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_EQ(ctx.InputSize(), 2, PADDLE_ENFORCE_EQ(ctx.InputSize(), 2,
"Input size of OnehotCrossEntropyOp must be two"); "Input size of OnehotCrossEntropyOp must be two");
PADDLE_ENFORCE_EQ(ctx.OutputSize(), 1, PADDLE_ENFORCE_EQ(ctx.OutputSize(), 1,
...@@ -37,9 +37,9 @@ class OnehotCrossEntropyOp : public OperatorWithKernel { ...@@ -37,9 +37,9 @@ class OnehotCrossEntropyOp : public OperatorWithKernel {
} }
}; };
class OnehotCrossEntropyGradientOp : public OperatorWithKernel { class OnehotCrossEntropyGradientOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape(const InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
auto X_grad = ctx.Output<Tensor>(framework::GradVarName("X")); auto X_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto X = ctx.Input<Tensor>("X"); auto X = ctx.Input<Tensor>("X");
...@@ -48,9 +48,10 @@ class OnehotCrossEntropyGradientOp : public OperatorWithKernel { ...@@ -48,9 +48,10 @@ class OnehotCrossEntropyGradientOp : public OperatorWithKernel {
} }
}; };
class OnehotCrossEntropyOpMaker : public OpProtoAndCheckerMaker { class OnehotCrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
OnehotCrossEntropyOpMaker(OpProto *proto, OpAttrChecker *op_checker) OnehotCrossEntropyOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The first input of OnehotCrossEntropyOp"); AddInput("X", "The first input of OnehotCrossEntropyOp");
AddInput("label", "The second input of OnehotCrossEntropyOp"); AddInput("label", "The second input of OnehotCrossEntropyOp");
...@@ -66,12 +67,14 @@ OnehotCrossEntropy Operator. ...@@ -66,12 +67,14 @@ OnehotCrossEntropy Operator.
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(onehot_cross_entropy, ops::OnehotCrossEntropyOp, REGISTER_OP(onehot_cross_entropy, ops::OnehotCrossEntropyOp,
ops::OnehotCrossEntropyOpMaker); ops::OnehotCrossEntropyOpMaker);
REGISTER_OP_CPU_KERNEL(onehot_cross_entropy, REGISTER_OP_CPU_KERNEL(
ops::OnehotCrossEntropyOpKernel<ops::CPUPlace, float>); onehot_cross_entropy,
ops::OnehotCrossEntropyOpKernel<paddle::platform::CPUPlace, float>);
REGISTER_GRADIENT_OP(onehot_cross_entropy, onehot_cross_entropy_grad, REGISTER_GRADIENT_OP(onehot_cross_entropy, onehot_cross_entropy_grad,
ops::OnehotCrossEntropyGradientOp); ops::OnehotCrossEntropyGradientOp);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
onehot_cross_entropy_grad, onehot_cross_entropy_grad,
ops::OnehotCrossEntropyGradientOpKernel<ops::CPUPlace, float>); ops::OnehotCrossEntropyGradientOpKernel<paddle::platform::CPUPlace, float>);
...@@ -14,3 +14,8 @@ ...@@ -14,3 +14,8 @@
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#include "paddle/operators/cross_entropy_op.h" #include "paddle/operators/cross_entropy_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
onehot_cross_entropy,
ops::OnehotCrossEntropyOpKernel<paddle::platform::GPUPlace, float>);
...@@ -13,11 +13,13 @@ See the License for the specific language governing permissions and ...@@ -13,11 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include "paddle/operators/type_alias.h" #include "paddle/framework/op_registry.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor;
template <typename T> template <typename T>
T tolerable_value(T x) { T tolerable_value(T x) {
static_assert(std::is_floating_point<T>::value, static_assert(std::is_floating_point<T>::value,
...@@ -38,9 +40,9 @@ T tolerable_value(T x) { ...@@ -38,9 +40,9 @@ T tolerable_value(T x) {
} }
template <typename Place, typename T> template <typename Place, typename T>
class OnehotCrossEntropyOpKernel : public OpKernel { class OnehotCrossEntropyOpKernel : public framework::OpKernel {
public: public:
void Compute(const ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto X = ctx.Input<Tensor>("X"); auto X = ctx.Input<Tensor>("X");
const T* Xdata = X->data<T>(); const T* Xdata = X->data<T>();
const int* label_data = ctx.Input<Tensor>(1)->data<int>(); const int* label_data = ctx.Input<Tensor>(1)->data<int>();
...@@ -61,9 +63,9 @@ class OnehotCrossEntropyOpKernel : public OpKernel { ...@@ -61,9 +63,9 @@ class OnehotCrossEntropyOpKernel : public OpKernel {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class OnehotCrossEntropyGradientOpKernel : public OpKernel { class OnehotCrossEntropyGradientOpKernel : public framework::OpKernel {
public: public:
void Compute(const ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto X = ctx.Input<Tensor>("X"); auto X = ctx.Input<Tensor>("X");
auto dX = ctx.Output<Tensor>(framework::GradVarName("X")); auto dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto dY = ctx.Input<Tensor>(framework::GradVarName("Y")); auto dY = ctx.Input<Tensor>(framework::GradVarName("Y"));
......
...@@ -12,11 +12,16 @@ ...@@ -12,11 +12,16 @@
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "type_alias.h" #include "paddle/operators/net_op.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using OpRegistry = framework::OpRegistry;
class FullyConnectedOp : public NetOp { class FullyConnectedOp : public NetOp {
public: public:
void Init() override { void Init() override {
...@@ -39,9 +44,10 @@ class FullyConnectedOp : public NetOp { ...@@ -39,9 +44,10 @@ class FullyConnectedOp : public NetOp {
} }
}; };
class FullyConnectedOpMaker : public OpProtoAndCheckerMaker { class FullyConnectedOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
FullyConnectedOpMaker(OpProto *proto, OpAttrChecker *op_checker) FullyConnectedOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "the input of fc operator"); AddInput("X", "the input of fc operator");
AddInput("W", "the weight of fc operator"); AddInput("W", "the weight of fc operator");
...@@ -66,4 +72,5 @@ USE_OP(rowwise_add); ...@@ -66,4 +72,5 @@ USE_OP(rowwise_add);
USE_OP(sigmoid); USE_OP(sigmoid);
USE_OP(softmax); USE_OP(softmax);
namespace ops = paddle::operators;
REGISTER_OP(fc, ops::FullyConnectedOp, ops::FullyConnectedOpMaker); REGISTER_OP(fc, ops::FullyConnectedOp, ops::FullyConnectedOpMaker);
...@@ -50,8 +50,8 @@ The output will have the same size with input. ...@@ -50,8 +50,8 @@ The output will have the same size with input.
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
REGISTER_OP(fill_zeros_like, paddle::operators::FillZerosLikeOp, namespace ops = paddle::operators;
paddle::operators::FillZerosLikeOpMaker); REGISTER_OP(fill_zeros_like, ops::FillZerosLikeOp, ops::FillZerosLikeOpMaker);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
fill_zeros_like, fill_zeros_like,
paddle::operators::FillZerosLikeKernel<paddle::platform::CPUPlace, float>); ops::FillZerosLikeKernel<paddle::platform::CPUPlace, float>);
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/operators/fill_zeros_like_op.h" #include "paddle/operators/fill_zeros_like_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL( REGISTER_OP_GPU_KERNEL(
fill_zeros_like, fill_zeros_like,
paddle::operators::FillZerosLikeKernel<paddle::platform::GPUPlace, float>); ops::FillZerosLikeKernel<paddle::platform::GPUPlace, float>);
...@@ -13,7 +13,8 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include "paddle/operators/type_alias.h" #include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -17,9 +17,9 @@ limitations under the License. */ ...@@ -17,9 +17,9 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
class MeanOp : public OperatorWithKernel { class MeanOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape(const InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_EQ(ctx.InputSize(), 1, "Input size of AddOp must be one"); PADDLE_ENFORCE_EQ(ctx.InputSize(), 1, "Input size of AddOp must be one");
PADDLE_ENFORCE_EQ(ctx.OutputSize(), 1, "Output size of AddOp must be one"); PADDLE_ENFORCE_EQ(ctx.OutputSize(), 1, "Output size of AddOp must be one");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(0), "input should be set"); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(0), "input should be set");
...@@ -28,9 +28,9 @@ class MeanOp : public OperatorWithKernel { ...@@ -28,9 +28,9 @@ class MeanOp : public OperatorWithKernel {
} }
}; };
class MeanOpMaker : public OpProtoAndCheckerMaker { class MeanOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
MeanOpMaker(OpProto *proto, OpAttrChecker *op_checker) MeanOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of mean op"); AddInput("X", "The input of mean op");
AddOutput("Out", "The output of mean op").IgnoreGradient(); AddOutput("Out", "The output of mean op").IgnoreGradient();
...@@ -38,9 +38,9 @@ class MeanOpMaker : public OpProtoAndCheckerMaker { ...@@ -38,9 +38,9 @@ class MeanOpMaker : public OpProtoAndCheckerMaker {
} }
}; };
class MeanGradOp : public OperatorWithKernel { class MeanGradOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape(const InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
ctx.Output<Tensor>("X" + framework::kGradVarSuffix) ctx.Output<Tensor>("X" + framework::kGradVarSuffix)
->Resize(ctx.Input<Tensor>("X")->dims()); ->Resize(ctx.Input<Tensor>("X")->dims());
} }
...@@ -49,7 +49,10 @@ class MeanGradOp : public OperatorWithKernel { ...@@ -49,7 +49,10 @@ class MeanGradOp : public OperatorWithKernel {
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(mean, ops::MeanOp, ops::MeanOpMaker); REGISTER_OP(mean, ops::MeanOp, ops::MeanOpMaker);
REGISTER_OP_CPU_KERNEL(mean, ops::MeanKernel<ops::CPUPlace, float>); REGISTER_OP_CPU_KERNEL(mean,
ops::MeanKernel<paddle::platform::CPUPlace, float>);
REGISTER_GRADIENT_OP(mean, mean_grad, ops::MeanGradOp); REGISTER_GRADIENT_OP(mean, mean_grad, ops::MeanGradOp);
REGISTER_OP_CPU_KERNEL(mean_grad, ops::MeanGradKernel<ops::CPUPlace, float>); REGISTER_OP_CPU_KERNEL(mean_grad,
ops::MeanGradKernel<paddle::platform::CPUPlace, float>);
...@@ -16,5 +16,8 @@ ...@@ -16,5 +16,8 @@
#include "paddle/operators/mean_op.h" #include "paddle/operators/mean_op.h"
REGISTER_OP_GPU_KERNEL(mean, ops::MeanKernel<ops::GPUPlace, float>); namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(mean_grad, ops::MeanGradKernel<ops::GPUPlace, float>); REGISTER_OP_GPU_KERNEL(mean,
ops::MeanKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(mean_grad,
ops::MeanGradKernel<paddle::platform::GPUPlace, float>);
...@@ -13,15 +13,24 @@ See the License for the specific language governing permissions and ...@@ -13,15 +13,24 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include "paddle/operators/type_alias.h" #include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenScalar = framework::EigenScalar<T, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename Place, typename T> template <typename Place, typename T>
class MeanKernel : public OpKernel { class MeanKernel : public framework::OpKernel {
public: public:
void Compute(const ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto input = context.Input<Tensor>(0); auto input = context.Input<Tensor>(0);
auto output = context.Output<Tensor>(0); auto output = context.Output<Tensor>(0);
...@@ -36,9 +45,9 @@ class MeanKernel : public OpKernel { ...@@ -36,9 +45,9 @@ class MeanKernel : public OpKernel {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class MeanGradKernel : public OpKernel { class MeanGradKernel : public framework::OpKernel {
public: public:
void Compute(const ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto OG = context.Input<Tensor>("Out" + framework::kGradVarSuffix); auto OG = context.Input<Tensor>("Out" + framework::kGradVarSuffix);
PADDLE_ENFORCE(framework::product(OG->dims()) == 1, PADDLE_ENFORCE(framework::product(OG->dims()) == 1,
"Mean Gradient should be scalar"); "Mean Gradient should be scalar");
......
...@@ -17,9 +17,9 @@ ...@@ -17,9 +17,9 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
class MulOp : public OperatorWithKernel { class MulOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape(const InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 2, "The mul op must take two inputs"); PADDLE_ENFORCE(ctx.InputSize() == 2, "The mul op must take two inputs");
auto dim0 = ctx.Input<Tensor>(0)->dims(); auto dim0 = ctx.Input<Tensor>(0)->dims();
auto dim1 = ctx.Input<Tensor>(1)->dims(); auto dim1 = ctx.Input<Tensor>(1)->dims();
...@@ -37,9 +37,9 @@ class MulOp : public OperatorWithKernel { ...@@ -37,9 +37,9 @@ class MulOp : public OperatorWithKernel {
} }
}; };
class MulOpMaker : public OpProtoAndCheckerMaker { class MulOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
MulOpMaker(OpProto *proto, OpAttrChecker *op_checker) MulOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The first input of mul op"); AddInput("X", "The first input of mul op");
AddInput("Y", "The second input of mul op"); AddInput("Y", "The second input of mul op");
...@@ -52,9 +52,9 @@ The equation is: Out = X * Y ...@@ -52,9 +52,9 @@ The equation is: Out = X * Y
} }
}; };
class MulOpGrad : public OperatorWithKernel { class MulOpGrad : public framework::OperatorWithKernel {
protected: protected:
void InferShape(const InferShapeContext &ctx) const override {} void InferShape(const framework::InferShapeContext &ctx) const override {}
std::string DebugString() const override { std::string DebugString() const override {
LOG(INFO) << "MulGrad"; LOG(INFO) << "MulGrad";
return ""; return "";
...@@ -64,7 +64,8 @@ class MulOpGrad : public OperatorWithKernel { ...@@ -64,7 +64,8 @@ class MulOpGrad : public OperatorWithKernel {
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker); REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker);
REGISTER_GRADIENT_OP(mul, mul_grad, ops::MulOpGrad); REGISTER_GRADIENT_OP(mul, mul_grad, ops::MulOpGrad);
REGISTER_OP_CPU_KERNEL(mul, ops::MulKernel<ops::CPUPlace, float>); REGISTER_OP_CPU_KERNEL(mul, ops::MulKernel<paddle::platform::CPUPlace, float>);
...@@ -15,4 +15,6 @@ ...@@ -15,4 +15,6 @@
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#include "paddle/operators/mul_op.h" #include "paddle/operators/mul_op.h"
REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel<ops::GPUPlace, float>); namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel<paddle::platform::GPUPlace, float>);
...@@ -13,16 +13,21 @@ ...@@ -13,16 +13,21 @@
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include "paddle/framework/eigen.h"
#include "paddle/operators/type_alias.h" #include "paddle/framework/op_registry.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename Place, typename T> template <typename Place, typename T>
class MulKernel : public OpKernel { class MulKernel : public framework::OpKernel {
public: public:
void Compute(const ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair = { Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair = {
{Eigen::IndexPair<Eigen::DenseIndex>(1, 0)}}; {Eigen::IndexPair<Eigen::DenseIndex>(1, 0)}};
...@@ -40,5 +45,6 @@ class MulKernel : public OpKernel { ...@@ -40,5 +45,6 @@ class MulKernel : public OpKernel {
Z.device(place) = X.contract(Y, dim_pair); Z.device(place) = X.contract(Y, dim_pair);
} }
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
*/ */
#include "paddle/operators/net_op.h" #include "paddle/operators/net_op.h"
#include "paddle/framework/op_registry.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -14,13 +14,7 @@ limitations under the License. */ ...@@ -14,13 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/framework/op_desc.pb.h"
#include "paddle/framework/op_proto.pb.h"
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/scope.h"
#include "paddle/operators/type_alias.h"
#include "paddle/platform/device_context.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -2,31 +2,27 @@ ...@@ -2,31 +2,27 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Scope = framework::Scope;
using DeviceContext = platform::DeviceContext;
static int infer_shape_cnt = 0; static int infer_shape_cnt = 0;
static int run_cnt = 0; static int run_cnt = 0;
class TestOp : public OperatorBase { class TestOp : public framework::OperatorBase {
public: public:
void InferShape(const framework::Scope& scope) const override { void InferShape(const Scope& scope) const override { ++infer_shape_cnt; }
++infer_shape_cnt; void Run(const Scope& scope,
} const platform::DeviceContext& dev_ctx) const override {
void Run(const framework::Scope& scope,
const paddle::platform::DeviceContext& dev_ctx) const override {
++run_cnt; ++run_cnt;
} }
}; };
class EmptyOp : public OperatorBase { class EmptyOp : public framework::OperatorBase {
public: public:
void InferShape(const Scope& scope) const override {} void InferShape(const Scope& scope) const override {}
void Run(const Scope& scope, void Run(const Scope& scope, const DeviceContext& dev_ctx) const override {}
const platform::DeviceContext& dev_ctx) const override {}
}; };
template <typename T> template <typename T>
...@@ -72,7 +68,7 @@ TEST(OpKernel, all) { ...@@ -72,7 +68,7 @@ TEST(OpKernel, all) {
net->Run(scope, dev_ctx); net->Run(scope, dev_ctx);
ASSERT_EQ(2, infer_shape_cnt); ASSERT_EQ(2, infer_shape_cnt);
ASSERT_EQ(2, run_cnt); ASSERT_EQ(2, run_cnt);
ASSERT_THROW(net->AddOp(op2), paddle::platform::EnforceNotMet); ASSERT_THROW(net->AddOp(op2), platform::EnforceNotMet);
} }
TEST(NetOp, insert_op) { TEST(NetOp, insert_op) {
......
...@@ -14,17 +14,19 @@ ...@@ -14,17 +14,19 @@
#include "paddle/operators/recurrent_op.h" #include "paddle/operators/recurrent_op.h"
#include <glog/logging.h>
#include <cstring> #include <cstring>
#include <sstream> #include <sstream>
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/operators/net_op.h" #include "paddle/operators/net_op.h"
#include "paddle/platform/enforce.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Scope = framework::Scope;
using Variable = framework::Variable;
using Tensor = framework::Tensor;
void RecurrentAlgorithm::InferShape(const Scope& scope) const { void RecurrentAlgorithm::InferShape(const Scope& scope) const {
seq_len_ = scope.FindVar((arg_->inlinks[0]).external) seq_len_ = scope.FindVar((arg_->inlinks[0]).external)
->GetMutable<Tensor>() ->GetMutable<Tensor>()
...@@ -135,10 +137,11 @@ void RecurrentOp::Init() { ...@@ -135,10 +137,11 @@ void RecurrentOp::Init() {
alg_.Init(std::move(arg)); alg_.Init(std::move(arg));
} }
class RecurrentAlgorithmProtoAndCheckerMaker : public OpProtoAndCheckerMaker { class RecurrentAlgorithmProtoAndCheckerMaker
: public framework::OpProtoAndCheckerMaker {
public: public:
RecurrentAlgorithmProtoAndCheckerMaker(OpProto* proto, RecurrentAlgorithmProtoAndCheckerMaker(framework::OpProto* proto,
OpAttrChecker* op_checker) framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
const auto& name = RecurrentOp::kArgName; const auto& name = RecurrentOp::kArgName;
// inputs and outputs stored in proto // inputs and outputs stored in proto
......
...@@ -27,6 +27,10 @@ namespace operators { ...@@ -27,6 +27,10 @@ namespace operators {
using framework::make_ddim; using framework::make_ddim;
using framework::DDim; using framework::DDim;
using framework::Tensor;
using framework::Variable;
using framework::Scope;
using framework::OpRegistry;
class RecurrentOpTest : public ::testing::Test { class RecurrentOpTest : public ::testing::Test {
protected: protected:
...@@ -164,7 +168,7 @@ class RecurrentOpTest : public ::testing::Test { ...@@ -164,7 +168,7 @@ class RecurrentOpTest : public ::testing::Test {
// father scope // father scope
Scope scope_; Scope scope_;
std::shared_ptr<OperatorBase> rnn_op_; std::shared_ptr<framework::OperatorBase> rnn_op_;
}; };
TEST_F(RecurrentOpTest, Run) { TEST_F(RecurrentOpTest, Run) {
......
...@@ -18,7 +18,9 @@ namespace paddle { ...@@ -18,7 +18,9 @@ namespace paddle {
namespace operators { namespace operators {
namespace rnn { namespace rnn {
namespace fmw = paddle::framework; namespace f = paddle::framework;
using Tensor = framework::Tensor;
void SegmentInputs(const std::vector<Scope*>& step_scopes, void SegmentInputs(const std::vector<Scope*>& step_scopes,
const std::vector<Link>& inlinks, const size_t seq_len, const std::vector<Link>& inlinks, const size_t seq_len,
...@@ -30,10 +32,10 @@ void SegmentInputs(const std::vector<Scope*>& step_scopes, ...@@ -30,10 +32,10 @@ void SegmentInputs(const std::vector<Scope*>& step_scopes,
inlinks[i].external); inlinks[i].external);
Tensor* input = input_var->GetMutable<Tensor>(); Tensor* input = input_var->GetMutable<Tensor>();
fmw::DDim dims = input->dims(); f::DDim dims = input->dims();
PADDLE_ENFORCE(static_cast<size_t>(dims[0]) == seq_len, PADDLE_ENFORCE(static_cast<size_t>(dims[0]) == seq_len,
"all the inlinks must have same length"); "all the inlinks must have same length");
fmw::DDim step_dims = slice_ddim(dims, 1, dims.size()); f::DDim step_dims = slice_ddim(dims, 1, dims.size());
for (size_t j = 0; j < seq_len; j++) { for (size_t j = 0; j < seq_len; j++) {
Tensor* step_input = Tensor* step_input =
step_scopes[j]->NewVar(inlinks[i].internal)->GetMutable<Tensor>(); step_scopes[j]->NewVar(inlinks[i].internal)->GetMutable<Tensor>();
...@@ -58,11 +60,10 @@ void ConcatOutputs(const std::vector<Scope*>& step_scopes, ...@@ -58,11 +60,10 @@ void ConcatOutputs(const std::vector<Scope*>& step_scopes,
auto step_scope_var = step_scopes[0]->FindVar(outlinks[i].internal); auto step_scope_var = step_scopes[0]->FindVar(outlinks[i].internal);
PADDLE_ENFORCE(step_scope_var != nullptr, "%s not in scope", PADDLE_ENFORCE(step_scope_var != nullptr, "%s not in scope",
outlinks[i].internal); outlinks[i].internal);
fmw::DDim step_dims = f::DDim step_dims = step_scope_var->template GetMutable<Tensor>()->dims();
step_scope_var->template GetMutable<Tensor>()->dims();
std::vector<int> dims_vec = vectorize(step_dims); std::vector<int> dims_vec = vectorize(step_dims);
dims_vec.insert(dims_vec.begin(), seq_len); dims_vec.insert(dims_vec.begin(), seq_len);
output->Resize(fmw::make_ddim(dims_vec)); output->Resize(f::make_ddim(dims_vec));
} else { } else {
output->mutable_data<float>(platform::CPUPlace()); output->mutable_data<float>(platform::CPUPlace());
for (size_t j = 0; j < seq_len; j++) { for (size_t j = 0; j < seq_len; j++) {
...@@ -104,7 +105,7 @@ void LinkMemories(const std::vector<Scope*>& scopes, ...@@ -104,7 +105,7 @@ void LinkMemories(const std::vector<Scope*>& scopes,
} }
void InitArgument(const ArgumentName& name, Argument* arg, void InitArgument(const ArgumentName& name, Argument* arg,
const OperatorBase& op) { const framework::OperatorBase& op) {
arg->step_net = op.Input(name.step_net); arg->step_net = op.Input(name.step_net);
arg->step_scopes = op.Output(name.step_scopes); arg->step_scopes = op.Output(name.step_scopes);
......
...@@ -17,12 +17,13 @@ ...@@ -17,12 +17,13 @@
#include <string> #include <string>
#include "paddle/framework/operator.h" #include "paddle/framework/operator.h"
#include "paddle/operators/type_alias.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace rnn { namespace rnn {
using Scope = framework::Scope;
/** /**
* Memory of a RNN (same as the role of `Momory` in PaddlePaddle). * Memory of a RNN (same as the role of `Momory` in PaddlePaddle).
* *
...@@ -86,7 +87,7 @@ void LinkMemories(const std::vector<Scope*>& step_scopes, ...@@ -86,7 +87,7 @@ void LinkMemories(const std::vector<Scope*>& step_scopes,
const int offset, bool infer_shape_mode); const int offset, bool infer_shape_mode);
void InitArgument(const ArgumentName& name, Argument* arg, void InitArgument(const ArgumentName& name, Argument* arg,
const OperatorBase& op); const framework::OperatorBase& op);
} // namespace rnn } // namespace rnn
} // namespace operators } // namespace operators
......
...@@ -13,12 +13,13 @@ ...@@ -13,12 +13,13 @@
limitations under the License. */ limitations under the License. */
#include "paddle/operators/rowwise_add_op.h" #include "paddle/operators/rowwise_add_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
class RowwiseAddOp : public OperatorWithKernel { class RowwiseAddOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape(const InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 2UL, PADDLE_ENFORCE(ctx.InputSize() == 2UL,
"Two inputs is needed by rowwise add"); "Two inputs is needed by rowwise add");
auto dim0 = ctx.Input<Tensor>(0)->dims(); auto dim0 = ctx.Input<Tensor>(0)->dims();
...@@ -32,9 +33,10 @@ class RowwiseAddOp : public OperatorWithKernel { ...@@ -32,9 +33,10 @@ class RowwiseAddOp : public OperatorWithKernel {
} }
}; };
class RowwiseAddOpMaker : public OpProtoAndCheckerMaker { class RowwiseAddOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
RowwiseAddOpMaker(OpProto *proto, OpAttrChecker *op_checker) RowWiseAddOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The left input of row-wise add op, must be matrix"); AddInput("X", "The left input of row-wise add op, must be matrix");
AddInput("b", "The right input of row-wise add op, must be vector"); AddInput("b", "The right input of row-wise add op, must be vector");
...@@ -61,10 +63,12 @@ class RowwiseAddGradOp : public OperatorWithKernel { ...@@ -61,10 +63,12 @@ class RowwiseAddGradOp : public OperatorWithKernel {
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(rowwise_add, ops::RowwiseAddOp, ops::RowwiseAddOpMaker); REGISTER_OP(rowwise_add, ops::RowwiseAddOp, ops::RowwiseAddOpMaker);
REGISTER_OP_CPU_KERNEL(rowwise_add, REGISTER_OP_CPU_KERNEL(
ops::RowwiseAddKernel<ops::CPUPlace, float>); rowwise_add, ops::RowwiseAddKernel<paddle::platform::CPUPlace, float>);
REGISTER_GRADIENT_OP(rowwise_add, rowwise_add_grad, ops::RowwiseAddGradOp); REGISTER_GRADIENT_OP(rowwise_add, rowwise_add_grad, ops::RowwiseAddGradOp);
REGISTER_OP_CPU_KERNEL(rowwise_add_grad, REGISTER_OP_CPU_KERNEL(
ops::RowwiseAddGradKernel<ops::CPUPlace, float>); rowwise_add_grad,
ops::RowwiseAddGradKernel<paddle::platform::CPUPlace, float>);
...@@ -15,5 +15,6 @@ ...@@ -15,5 +15,6 @@
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#include "paddle/operators/rowwise_add_op.h" #include "paddle/operators/rowwise_add_op.h"
REGISTER_OP_GPU_KERNEL(rowwise_add, namespace ops = paddle::operators;
ops::RowwiseAddKernel<ops::GPUPlace, float>); REGISTER_OP_GPU_KERNEL(
rowwise_add, ops::RowwiseAddKernel<paddle::platform::GPUPlace, float>);
...@@ -13,15 +13,24 @@ ...@@ -13,15 +13,24 @@
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include "paddle/operators/type_alias.h" #include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename Place, typename T> template <typename Place, typename T>
class RowwiseAddKernel : public OpKernel { class RowwiseAddKernel : public OpKernel {
public: public:
void Compute(const ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto out = context.Output<Tensor>(0); auto out = context.Output<Tensor>(0);
out->mutable_data<T>(context.GetPlace()); out->mutable_data<T>(context.GetPlace());
...@@ -39,9 +48,9 @@ class RowwiseAddKernel : public OpKernel { ...@@ -39,9 +48,9 @@ class RowwiseAddKernel : public OpKernel {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class RowwiseAddGradKernel : public OpKernel { class RowwiseAddGradKernel : public framework::OpKernel {
public: public:
void Compute(const ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* XGrad = context.Output<Tensor>(0); auto* XGrad = context.Output<Tensor>(0);
auto* bGrad = context.Output<Tensor>(1); auto* bGrad = context.Output<Tensor>(1);
XGrad->mutable_data<T>(context.GetPlace()); XGrad->mutable_data<T>(context.GetPlace());
......
...@@ -17,9 +17,9 @@ limitations under the License. */ ...@@ -17,9 +17,9 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
class SGDOp : public OperatorWithKernel { class SGDOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape(const InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_EQ(ctx.InputSize(), 2, "Input size of SGDOp must be two"); PADDLE_ENFORCE_EQ(ctx.InputSize(), 2, "Input size of SGDOp must be two");
PADDLE_ENFORCE_EQ(ctx.OutputSize(), 1, "Output size of SGDOp must be one"); PADDLE_ENFORCE_EQ(ctx.OutputSize(), 1, "Output size of SGDOp must be one");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(0), "inputs[0] mast be set"); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(0), "inputs[0] mast be set");
...@@ -31,9 +31,9 @@ class SGDOp : public OperatorWithKernel { ...@@ -31,9 +31,9 @@ class SGDOp : public OperatorWithKernel {
} }
}; };
class SGDOpMaker : public OpProtoAndCheckerMaker { class SGDOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
SGDOpMaker(OpProto *proto, OpAttrChecker *op_checker) SGDOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("param", "input parameter"); AddInput("param", "input parameter");
AddInput("grad", "input gradient"); AddInput("grad", "input gradient");
...@@ -51,5 +51,7 @@ param_out = param - learning_rate * grad; ...@@ -51,5 +51,7 @@ param_out = param - learning_rate * grad;
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(sgd, ops::SGDOp, ops::SGDOpMaker); REGISTER_OP(sgd, ops::SGDOp, ops::SGDOpMaker);
REGISTER_OP_CPU_KERNEL(sgd, ops::SGDOpKernel<ops::CPUPlace, float>); REGISTER_OP_CPU_KERNEL(sgd,
ops::SGDOpKernel<paddle::platform::CPUPlace, float>);
...@@ -15,4 +15,6 @@ ...@@ -15,4 +15,6 @@
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#include "paddle/operators/sgd_op.h" #include "paddle/operators/sgd_op.h"
REGISTER_OP_GPU_KERNEL(sgd, ops::SGDOpKernel<ops::GPUPlace, float>); namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(sgd,
ops::SGDOpKernel<paddle::platform::GPUPlace, float>);
...@@ -13,15 +13,21 @@ See the License for the specific language governing permissions and ...@@ -13,15 +13,21 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include "paddle/operators/type_alias.h" #include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename Place, typename T> template <typename Place, typename T>
class SGDOpKernel : public OpKernel { class SGDOpKernel : public framework::OpKernel {
public: public:
void Compute(const ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto param = ctx.Input<Tensor>("param"); auto param = ctx.Input<Tensor>("param");
auto grad = ctx.Input<Tensor>("grad"); auto grad = ctx.Input<Tensor>("grad");
auto param_out = ctx.Output<Tensor>(0); auto param_out = ctx.Output<Tensor>(0);
......
...@@ -13,21 +13,23 @@ ...@@ -13,21 +13,23 @@
limitations under the License. */ limitations under the License. */
#include "paddle/operators/sigmoid_op.h" #include "paddle/operators/sigmoid_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
class SigmoidOp : public OperatorWithKernel { class SigmoidOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape(const InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 1, "Sigmoid Op only have one input"); PADDLE_ENFORCE(ctx.InputSize() == 1, "Sigmoid Op only have one input");
PADDLE_ENFORCE(ctx.OutputSize() == 1, "Sigmoid Op only have one output"); PADDLE_ENFORCE(ctx.OutputSize() == 1, "Sigmoid Op only have one output");
ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims()); ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims());
} }
}; };
class SigmoidOpMaker : public OpProtoAndCheckerMaker { class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
SigmoidOpMaker(OpProto *proto, OpAttrChecker *op_checker) SigmoidOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "sigmoid input"); AddInput("X", "sigmoid input");
AddOutput("Y", "sigmoid output"); AddOutput("Y", "sigmoid output");
...@@ -35,9 +37,9 @@ class SigmoidOpMaker : public OpProtoAndCheckerMaker { ...@@ -35,9 +37,9 @@ class SigmoidOpMaker : public OpProtoAndCheckerMaker {
} }
}; };
class SigmoidOpGrad : public OperatorWithKernel { class SigmoidOpGrad : public framework::OperatorWithKernel {
protected: protected:
void InferShape(const InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims()); ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims());
} }
}; };
...@@ -45,9 +47,11 @@ class SigmoidOpGrad : public OperatorWithKernel { ...@@ -45,9 +47,11 @@ class SigmoidOpGrad : public OperatorWithKernel {
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(sigmoid, ops::SigmoidOp, ops::SigmoidOpMaker); REGISTER_OP(sigmoid, ops::SigmoidOp, ops::SigmoidOpMaker);
REGISTER_GRADIENT_OP(sigmoid, sigmoid_grad, ops::SigmoidOpGrad); REGISTER_GRADIENT_OP(sigmoid, sigmoid_grad, ops::SigmoidOpGrad);
REGISTER_OP_CPU_KERNEL(sigmoid, ops::SigmoidKernel<ops::CPUPlace, float>); REGISTER_OP_CPU_KERNEL(sigmoid,
REGISTER_OP_CPU_KERNEL(sigmoid_grad, ops::SigmoidKernel<paddle::platform::CPUPlace, float>);
ops::SigmoidGradKernel<ops::CPUPlace, float>); REGISTER_OP_CPU_KERNEL(
sigmoid_grad, ops::SigmoidGradKernel<paddle::platform::CPUPlace, float>);
...@@ -15,6 +15,9 @@ ...@@ -15,6 +15,9 @@
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#include "paddle/operators/sigmoid_op.h" #include "paddle/operators/sigmoid_op.h"
REGISTER_OP_GPU_KERNEL(sigmoid, ops::SigmoidKernel<ops::GPUPlace, float>); namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(sigmoid_grad,
ops::SigmoidGradKernel<ops::GPUPlace, float>); REGISTER_OP_GPU_KERNEL(sigmoid,
ops::SigmoidKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
sigmoid_grad, ops::SigmoidGradKernel<paddle::platform::GPUPlace, float>);
...@@ -13,16 +13,21 @@ ...@@ -13,16 +13,21 @@
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include "paddle/framework/eigen.h"
#include "paddle/operators/type_alias.h" #include "paddle/framework/op_registry.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename Place, typename T> template <typename Place, typename T>
class SigmoidKernel : public OpKernel { class SigmoidKernel : public framework::OpKernel {
public: public:
void Compute(const ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto input = context.Input<Tensor>(0); auto input = context.Input<Tensor>(0);
auto output = context.Output<Tensor>(0); auto output = context.Output<Tensor>(0);
output->mutable_data<T>(context.GetPlace()); output->mutable_data<T>(context.GetPlace());
...@@ -37,9 +42,9 @@ class SigmoidKernel : public OpKernel { ...@@ -37,9 +42,9 @@ class SigmoidKernel : public OpKernel {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class SigmoidGradKernel : public OpKernel { class SigmoidGradKernel : public framework::OpKernel {
public: public:
void Compute(const ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto Y_t = context.Input<Tensor>("Y"); auto Y_t = context.Input<Tensor>("Y");
auto dY_t = context.Input<Tensor>(framework::GradVarName("Y")); auto dY_t = context.Input<Tensor>(framework::GradVarName("Y"));
auto dX_t = context.Output<Tensor>(framework::GradVarName("X")); auto dX_t = context.Output<Tensor>(framework::GradVarName("X"));
......
...@@ -17,9 +17,9 @@ limitations under the License. */ ...@@ -17,9 +17,9 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
class SoftmaxOp : public OperatorWithKernel { class SoftmaxOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape(const InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_EQ(ctx.InputSize(), 1UL, PADDLE_ENFORCE_EQ(ctx.InputSize(), 1UL,
"Only one input is need for softmax"); "Only one input is need for softmax");
PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("X")->dims().size(), 2UL, PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("X")->dims().size(), 2UL,
...@@ -30,9 +30,10 @@ class SoftmaxOp : public OperatorWithKernel { ...@@ -30,9 +30,10 @@ class SoftmaxOp : public OperatorWithKernel {
} }
}; };
class SoftmaxOpMaker : public OpProtoAndCheckerMaker { class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
SoftmaxOpMaker(OpProto *proto, OpAttrChecker *op_checker) SoftmaxOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "input of softmax"); AddInput("X", "input of softmax");
AddOutput("Y", "output of softmax"); AddOutput("Y", "output of softmax");
...@@ -40,9 +41,9 @@ class SoftmaxOpMaker : public OpProtoAndCheckerMaker { ...@@ -40,9 +41,9 @@ class SoftmaxOpMaker : public OpProtoAndCheckerMaker {
} }
}; };
class SoftmaxOpGrad : public OperatorWithKernel { class SoftmaxOpGrad : public framework::OperatorWithKernel {
protected: protected:
void InferShape(const InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_EQ(ctx.InputSize(), 3UL, PADDLE_ENFORCE_EQ(ctx.InputSize(), 3UL,
"Input of SoftmaxOpGrad should be 3, X, Y, YG"); "Input of SoftmaxOpGrad should be 3, X, Y, YG");
PADDLE_ENFORCE_EQ(ctx.OutputSize(), 1UL, PADDLE_ENFORCE_EQ(ctx.OutputSize(), 1UL,
...@@ -61,8 +62,11 @@ class SoftmaxOpGrad : public OperatorWithKernel { ...@@ -61,8 +62,11 @@ class SoftmaxOpGrad : public OperatorWithKernel {
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker); REGISTER_OP(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker);
REGISTER_OP_CPU_KERNEL(softmax, ops::SoftmaxKernel<ops::CPUPlace, float>); REGISTER_OP_CPU_KERNEL(softmax,
ops::SoftmaxKernel<paddle::platform::CPUPlace, float>);
REGISTER_GRADIENT_OP(softmax, softmax_grad, ops::SoftmaxOpGrad); REGISTER_GRADIENT_OP(softmax, softmax_grad, ops::SoftmaxOpGrad);
REGISTER_OP_CPU_KERNEL(softmax_grad, REGISTER_OP_CPU_KERNEL(
ops::SoftmaxGradKernel<ops::CPUPlace, float>); softmax_grad, ops::SoftmaxGradKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. /* Copyright (c) 2016 PaddlePaddle Authors All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
...@@ -13,9 +13,11 @@ ...@@ -13,9 +13,11 @@
limitations under the License. */ limitations under the License. */
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#include "paddle/framework/op_registry.h"
#include "paddle/operators/softmax_op.h" #include "paddle/operators/softmax_op.h"
REGISTER_OP_GPU_KERNEL(softmax, ops::SoftmaxKernel<ops::GPUPlace, float>); namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(softmax_grad,
ops::SoftmaxGradKernel<ops::GPUPlace, float>); REGISTER_OP_GPU_KERNEL(softmax,
ops::SoftmaxKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
softmax_grad, ops::SoftmaxGradKernel<paddle::platform::GPUPlace, float>);
...@@ -13,19 +13,21 @@ See the License for the specific language governing permissions and ...@@ -13,19 +13,21 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/ddim.h" #include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/tensor.h"
#include "paddle/operators/type_alias.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename Place, typename T> template <typename Place, typename T>
class SoftmaxKernel : public OpKernel { class SoftmaxKernel : public framework::OpKernel {
public: public:
void Compute(const ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto input = context.Input<Tensor>("X"); auto input = context.Input<Tensor>("X");
auto output = context.Output<Tensor>("Y"); auto output = context.Output<Tensor>("Y");
output->mutable_data<T>(context.GetPlace()); output->mutable_data<T>(context.GetPlace());
...@@ -62,9 +64,9 @@ class SoftmaxKernel : public OpKernel { ...@@ -62,9 +64,9 @@ class SoftmaxKernel : public OpKernel {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class SoftmaxGradKernel : public OpKernel { class SoftmaxGradKernel : public framework::OpKernel {
public: public:
void Compute(const ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
std::shared_ptr<Tensor> scale_ = std::make_shared<Tensor>(); std::shared_ptr<Tensor> scale_ = std::make_shared<Tensor>();
auto Y = context.Input<Tensor>("Y"); auto Y = context.Input<Tensor>("Y");
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/net_op.h"
namespace paddle {
namespace operators {
using OpKernel = framework::OpKernel;
using OperatorBase = framework::OperatorBase;
using InferShapeContext = framework::InferShapeContext;
using ExecutionContext = framework::ExecutionContext;
using Variable = framework::Variable;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenScalar = framework::EigenScalar<T, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
using Tensor = framework::Tensor;
using Scope = framework::Scope;
using OperatorWithKernel = framework::OperatorWithKernel;
using OperatorBase = framework::OperatorBase;
using OpProtoAndCheckerMaker = framework::OpProtoAndCheckerMaker;
using OpProto = framework::OpProto;
using OpAttrChecker = framework::OpAttrChecker;
using CPUPlace = platform::CPUPlace;
using GPUPlace = platform::GPUPlace;
using OpRegistry = framework::OpRegistry;
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
...@@ -8,7 +8,7 @@ cc_test(place_test SRCS place_test.cc DEPS place glog gflags) ...@@ -8,7 +8,7 @@ cc_test(place_test SRCS place_test.cc DEPS place glog gflags)
add_subdirectory(dynload) add_subdirectory(dynload)
cc_test(enforce_test SRCS enforce_test.cc) cc_test(enforce_test SRCS enforce_test.cc DEPS stringpiece)
IF(WITH_GPU) IF(WITH_GPU)
set(GPU_CTX_DEPS dynload_cuda dynamic_loader) set(GPU_CTX_DEPS dynload_cuda dynamic_loader)
......
...@@ -13,6 +13,10 @@ limitations under the License. */ ...@@ -13,6 +13,10 @@ limitations under the License. */
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/platform/enforce.h" #include "paddle/platform/enforce.h"
#include "paddle/string/piece.h"
using StringPiece = paddle::string::Piece;
using paddle::string::HasPrefix;
TEST(ENFORCE, OK) { TEST(ENFORCE, OK) {
PADDLE_ENFORCE(true, "Enforce is ok %d now %f", 123, 0.345); PADDLE_ENFORCE(true, "Enforce is ok %d now %f", 123, 0.345);
...@@ -22,19 +26,15 @@ TEST(ENFORCE, OK) { ...@@ -22,19 +26,15 @@ TEST(ENFORCE, OK) {
} }
TEST(ENFORCE, FAILED) { TEST(ENFORCE, FAILED) {
bool in_catch = false; bool caught_exception = false;
try { try {
PADDLE_ENFORCE(false, "Enforce is not ok %d at all", 123); PADDLE_ENFORCE(false, "Enforce is not ok %d at all", 123);
} catch (paddle::platform::EnforceNotMet error) { } catch (paddle::platform::EnforceNotMet error) {
// your error handling code here caught_exception = true;
in_catch = true; EXPECT_TRUE(
std::string msg = "Enforce is not ok 123 at all"; HasPrefix(StringPiece(error.what()), "Enforce is not ok 123 at all"));
const char* what = error.what();
for (size_t i = 0; i < msg.length(); ++i) {
ASSERT_EQ(what[i], msg[i]);
}
} }
ASSERT_TRUE(in_catch); EXPECT_TRUE(caught_exception);
} }
TEST(ENFORCE, NO_ARG_OK) { TEST(ENFORCE, NO_ARG_OK) {
...@@ -47,41 +47,27 @@ TEST(ENFORCE, NO_ARG_OK) { ...@@ -47,41 +47,27 @@ TEST(ENFORCE, NO_ARG_OK) {
TEST(ENFORCE_EQ, NO_EXTRA_MSG_FAIL) { TEST(ENFORCE_EQ, NO_EXTRA_MSG_FAIL) {
int a = 2; int a = 2;
bool in_catch = false; bool caught_exception = false;
try { try {
PADDLE_ENFORCE_EQ(a, 1 + 3); PADDLE_ENFORCE_EQ(a, 1 + 3);
} catch (paddle::platform::EnforceNotMet error) { } catch (paddle::platform::EnforceNotMet error) {
in_catch = true; caught_exception = true;
const std::string msg = "enforce a == 1 + 3 failed, 2 != 4"; HasPrefix(StringPiece(error.what()), "enforce a == 1 + 3 failed, 2 != 4");
const char* what = error.what();
for (size_t i = 0; i < msg.length(); ++i) {
ASSERT_EQ(what[i], msg[i]);
}
} }
EXPECT_TRUE(caught_exception);
ASSERT_TRUE(in_catch);
} }
TEST(ENFORCE_EQ, EXTRA_MSG_FAIL) { TEST(ENFORCE_EQ, EXTRA_MSG_FAIL) {
int a = 2; int a = 2;
bool in_catch = false; bool caught_exception = false;
try { try {
PADDLE_ENFORCE_EQ(a, 1 + 3, "%s size not match", "their"); PADDLE_ENFORCE_EQ(a, 1 + 3, "%s size not match", "their");
} catch (paddle::platform::EnforceNotMet error) { } catch (paddle::platform::EnforceNotMet error) {
in_catch = true; caught_exception = true;
const std::string msg = HasPrefix(StringPiece(error.what()),
"enforce a == 1 + 3 failed, 2 != 4\ntheir size not match"; "enforce a == 1 + 3 failed, 2 != 4\ntheir size not match");
const char* what = error.what();
for (size_t i = 0; i < msg.length(); ++i) {
ASSERT_EQ(what[i], msg[i]);
}
} }
EXPECT_TRUE(caught_exception);
ASSERT_TRUE(in_catch);
} }
TEST(ENFORCE_NE, OK) { TEST(ENFORCE_NE, OK) {
...@@ -89,42 +75,32 @@ TEST(ENFORCE_NE, OK) { ...@@ -89,42 +75,32 @@ TEST(ENFORCE_NE, OK) {
PADDLE_ENFORCE_NE(1.0, 2UL); PADDLE_ENFORCE_NE(1.0, 2UL);
} }
TEST(ENFORCE_NE, FAIL) { TEST(ENFORCE_NE, FAIL) {
bool in_catch = false; bool caught_exception = false;
try { try {
// 2UL here to check data type compatible // 2UL here to check data type compatible
PADDLE_ENFORCE_NE(1.0, 1UL); PADDLE_ENFORCE_NE(1.0, 1UL);
} catch (paddle::platform::EnforceNotMet error) { } catch (paddle::platform::EnforceNotMet error) {
in_catch = true; caught_exception = true;
const std::string msg = "enforce 1.0 != 1UL failed, 1.000000 == 1"; EXPECT_TRUE(HasPrefix(StringPiece(error.what()),
const char* what = error.what(); "enforce 1.0 != 1UL failed, 1.000000 == 1"))
for (size_t i = 0; i < msg.length(); ++i) { << error.what() << " does not have expected prefix";
ASSERT_EQ(what[i], msg[i]);
}
} }
EXPECT_TRUE(caught_exception);
ASSERT_TRUE(in_catch);
} }
TEST(ENFORCE_GT, OK) { PADDLE_ENFORCE_GT(2, 1); } TEST(ENFORCE_GT, OK) { PADDLE_ENFORCE_GT(2, 1); }
TEST(ENFORCE_GT, FAIL) { TEST(ENFORCE_GT, FAIL) {
bool in_catch = false; bool caught_exception = false;
try { try {
// 2UL here to check data type compatible
PADDLE_ENFORCE_GT(1, 2UL); PADDLE_ENFORCE_GT(1, 2UL);
} catch (paddle::platform::EnforceNotMet error) { } catch (paddle::platform::EnforceNotMet error) {
in_catch = true; caught_exception = true;
const std::string msg = "enforce 1 > 2UL failed, 1 <= 2"; EXPECT_TRUE(
const char* what = error.what(); HasPrefix(StringPiece(error.what()), "enforce 1 > 2UL failed, 1 <= 2"));
for (size_t i = 0; i < msg.length(); ++i) {
ASSERT_EQ(what[i], msg[i]);
}
} }
EXPECT_TRUE(caught_exception);
ASSERT_TRUE(in_catch);
} }
TEST(ENFORCE_GE, OK) { TEST(ENFORCE_GE, OK) {
...@@ -134,21 +110,16 @@ TEST(ENFORCE_GE, OK) { ...@@ -134,21 +110,16 @@ TEST(ENFORCE_GE, OK) {
PADDLE_ENFORCE_GE(3.21, 2UL); PADDLE_ENFORCE_GE(3.21, 2UL);
} }
TEST(ENFORCE_GE, FAIL) { TEST(ENFORCE_GE, FAIL) {
bool in_catch = false; bool caught_exception = false;
try { try {
PADDLE_ENFORCE_GE(1, 2UL); PADDLE_ENFORCE_GE(1, 2UL);
} catch (paddle::platform::EnforceNotMet error) { } catch (paddle::platform::EnforceNotMet error) {
in_catch = true; caught_exception = true;
const std::string msg = "enforce 1 >= 2UL failed, 1 < 2"; EXPECT_TRUE(
const char* what = error.what(); HasPrefix(StringPiece(error.what()), "enforce 1 >= 2UL failed, 1 < 2"));
for (size_t i = 0; i < msg.length(); ++i) {
ASSERT_EQ(what[i], msg[i]);
}
} }
EXPECT_TRUE(caught_exception);
ASSERT_TRUE(in_catch);
} }
TEST(ENFORCE_LE, OK) { TEST(ENFORCE_LE, OK) {
...@@ -159,21 +130,16 @@ TEST(ENFORCE_LE, OK) { ...@@ -159,21 +130,16 @@ TEST(ENFORCE_LE, OK) {
PADDLE_ENFORCE_LE(2UL, 3.2); PADDLE_ENFORCE_LE(2UL, 3.2);
} }
TEST(ENFORCE_LE, FAIL) { TEST(ENFORCE_LE, FAIL) {
bool in_catch = false; bool caught_exception = false;
try { try {
PADDLE_ENFORCE_GT(1, 2UL); PADDLE_ENFORCE_GT(1, 2UL);
} catch (paddle::platform::EnforceNotMet error) { } catch (paddle::platform::EnforceNotMet error) {
in_catch = true; caught_exception = true;
const std::string msg = "enforce 1 > 2UL failed, 1 <= 2"; EXPECT_TRUE(
const char* what = error.what(); HasPrefix(StringPiece(error.what()), "enforce 1 > 2UL failed, 1 <= 2"));
for (size_t i = 0; i < msg.length(); ++i) {
ASSERT_EQ(what[i], msg[i]);
}
} }
EXPECT_TRUE(caught_exception);
ASSERT_TRUE(in_catch);
} }
TEST(ENFORCE_LT, OK) { TEST(ENFORCE_LT, OK) {
...@@ -182,21 +148,15 @@ TEST(ENFORCE_LT, OK) { ...@@ -182,21 +148,15 @@ TEST(ENFORCE_LT, OK) {
PADDLE_ENFORCE_LT(2UL, 3); PADDLE_ENFORCE_LT(2UL, 3);
} }
TEST(ENFORCE_LT, FAIL) { TEST(ENFORCE_LT, FAIL) {
bool in_catch = false; bool caught_exception = false;
try { try {
PADDLE_ENFORCE_LT(1UL, 0.12); PADDLE_ENFORCE_LT(1UL, 0.12);
} catch (paddle::platform::EnforceNotMet error) { } catch (paddle::platform::EnforceNotMet error) {
in_catch = true; caught_exception = true;
const std::string msg = "enforce 1UL < 0.12 failed, 1 >= 0.12"; EXPECT_TRUE(HasPrefix(StringPiece(error.what()),
const char* what = error.what(); "enforce 1UL < 0.12 failed, 1 >= 0.12"));
for (size_t i = 0; i < msg.length(); ++i) {
ASSERT_EQ(what[i], msg[i]);
}
} }
EXPECT_TRUE(caught_exception);
ASSERT_TRUE(in_catch);
} }
TEST(ENFORCE_NOT_NULL, OK) { TEST(ENFORCE_NOT_NULL, OK) {
...@@ -205,20 +165,14 @@ TEST(ENFORCE_NOT_NULL, OK) { ...@@ -205,20 +165,14 @@ TEST(ENFORCE_NOT_NULL, OK) {
delete a; delete a;
} }
TEST(ENFORCE_NOT_NULL, FAIL) { TEST(ENFORCE_NOT_NULL, FAIL) {
bool in_catch = false; bool caught_exception = false;
int* a{nullptr};
try { try {
int* a = nullptr;
PADDLE_ENFORCE_NOT_NULL(a); PADDLE_ENFORCE_NOT_NULL(a);
} catch (paddle::platform::EnforceNotMet error) { } catch (paddle::platform::EnforceNotMet error) {
in_catch = true; caught_exception = true;
const std::string msg = "a should not be null"; EXPECT_TRUE(HasPrefix(StringPiece(error.what()), "a should not be null"));
const char* what = error.what();
for (size_t i = 0; i < msg.length(); ++i) {
ASSERT_EQ(what[i], msg[i]);
}
} }
EXPECT_TRUE(caught_exception);
ASSERT_TRUE(in_catch);
} }
cc_library(paddle_pybind SHARED
SRCS pybind.cc
DEPS pybind python backward
fc_op
sgd_op
add_op
mean_op
cross_entropy_op
recurrent_op
fill_zeros_like_op)
...@@ -50,8 +50,8 @@ void NewRemoteParameterUpdater::init( ...@@ -50,8 +50,8 @@ void NewRemoteParameterUpdater::init(
// create parameter server client. // create parameter server client.
if (useEtcd_) { if (useEtcd_) {
parameterClient_ = paddle_new_etcd_pserver_client( parameterClient_ =
(char *)pserverSpec_.c_str(), FLAGS_trainer_id == 0); paddle_new_etcd_pserver_client((char *)pserverSpec_.c_str());
} else { } else {
parameterClient_ = paddle_new_pserver_client((char *)pserverSpec_.c_str(), parameterClient_ = paddle_new_pserver_client((char *)pserverSpec_.c_str(),
FLAGS_trainer_id == 0); FLAGS_trainer_id == 0);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册