Commit c7e8c1aa authored by Yu Yang, committed by GitHub

Merge pull request #6 from reyoung/feature/refactorize_framework_proto

Feature/refactorize framework proto
......@@ -28,7 +28,7 @@ RUN apt-get update && \
wget unzip unrar tar xz-utils bzip2 gzip coreutils ntp \
curl sed grep graphviz libjpeg-dev zlib1g-dev \
python-matplotlib gcc-4.8 g++-4.8 \
automake locales clang-format-3.8 swig doxygen cmake \
automake locales clang-format swig doxygen cmake \
liblapack-dev liblapacke-dev libboost-dev \
clang-3.8 llvm-3.8 libclang-3.8-dev \
net-tools && \
......
hash: 1b9b07408ca7fac27a374dc2ccd2433e4bff090484008a037df967284949a582
updated: 2017-08-03T21:46:51.744995189Z
updated: 2017-08-07T23:37:48.867469328Z
imports:
- name: github.com/beorn7/perks
version: 4c0e84591b9aa9e6dcfdf3e020114cd81f89d5f9
......@@ -10,7 +10,7 @@ imports:
- name: github.com/cockroachdb/cmux
version: 112f0506e7743d64a6eb8fedbcff13d9979bbf92
- name: github.com/coreos/etcd
version: c31bec0f29facff13f7c3e3d948e55dd6689ed42
version: d0d1a87aa96ae14914751d42264262cb69eda170
subpackages:
- alarm
- auth
......@@ -24,6 +24,7 @@ imports:
- error
- etcdserver
- etcdserver/api
- etcdserver/api/etcdhttp
- etcdserver/api/v2http
- etcdserver/api/v2http/httptypes
- etcdserver/api/v3client
......@@ -210,11 +211,6 @@ testImports:
version: 04cdfd42973bb9c8589fd6a731800cf222fde1a9
subpackages:
- spew
- name: github.com/docker/docker
version: b6d164e6c46d8115b146e4c3ac93784e9ef8b49e
subpackages:
- pkg/ioutils
- pkg/longpath
- name: github.com/pmezard/go-difflib
version: d8ed2627bdf02c080bf22230dbb337003b7aba2d
subpackages:
......
package master_test
import (
"io/ioutil"
"net/url"
"os"
"strings"
"testing"
"time"
"github.com/PaddlePaddle/Paddle/go/master"
"github.com/coreos/etcd/clientv3"
"github.com/coreos/etcd/embed"
"github.com/docker/docker/pkg/ioutils"
"github.com/stretchr/testify/assert"
)
func TestNewServiceWithEtcd(t *testing.T) {
// set up an embedded etcd server
etcdDir, err := ioutils.TempDir("", "")
etcdDir, err := ioutil.TempDir("", "")
if err != nil {
t.Fatal(err)
}
cfg := embed.NewConfig()
lpurl, _ := url.Parse("http://localhost:0")
lcurl, _ := url.Parse("http://localhost:0")
cfg.LPUrls = []url.URL{*lpurl}
cfg.LCUrls = []url.URL{*lcurl}
cfg.Dir = etcdDir
e, err := embed.StartEtcd(cfg)
if err != nil {
......@@ -30,15 +36,13 @@ func TestNewServiceWithEtcd(t *testing.T) {
t.Fatal(err)
}
}()
select {
case <-e.Server.ReadyNotify():
t.Log("Server is ready!")
case <-time.After(60 * time.Second):
e.Server.Stop() // trigger a shutdown
t.Fatal("Server took too long to start!")
}
ep := []string{"127.0.0.1:2379"}
<-e.Server.ReadyNotify()
port := strings.Split(e.Clients[0].Addr().String(), ":")[1]
endpoint := "127.0.0.1:" + port
ep := []string{endpoint}
masterAddr := "127.0.0.1:3306"
store, err := master.NewEtcdClient(ep, masterAddr, master.DefaultLockPath, master.DefaultAddrPath, master.DefaultStatePath, 30)
if err != nil {
......
......@@ -90,8 +90,12 @@ func cArrayToSlice(p unsafe.Pointer, len int) []byte {
type selector bool
func (s selector) Select() bool {
return bool(s)
func (s selector) Select() (bool, error) {
return bool(s), nil
}
func (s selector) Done() error {
return nil
}
type lister []client.Server
......@@ -114,11 +118,10 @@ func paddle_new_pserver_client(addrs *C.char, selected int) C.paddle_pserver_cli
}
//export paddle_new_etcd_pserver_client
func paddle_new_etcd_pserver_client(etcdEndpoints *C.char, selected int) C.paddle_pserver_client {
// TODO(Longfei): use etcd lock to decide which trainer initializes the parameters
func paddle_new_etcd_pserver_client(etcdEndpoints *C.char) C.paddle_pserver_client {
addr := C.GoString(etcdEndpoints)
etcdClient := client.NewEtcd(addr)
c := client.NewClient(etcdClient, etcdClient.Desired(), selector(selected != 0))
c := client.NewClient(etcdClient, etcdClient.Desired(), etcdClient)
return add(c)
}
......@@ -136,7 +139,12 @@ func paddle_pserver_client_release(client C.paddle_pserver_client) {
//export paddle_begin_init_params
func paddle_begin_init_params(client C.paddle_pserver_client) C.int {
c := get(client)
if selected := c.BeginInitParams(); selected {
selected, err := c.BeginInitParams()
if err != nil {
panic(err)
}
if selected {
return 1
}
return 0
......
......@@ -27,9 +27,13 @@ import (
// TODO(helin): add RPC call retry logic
// Selector selects if the client should initialize parameter servers.
// Selector selects if the client should initialize the parameters and
// reports when the initialization process is done.
type Selector interface {
Select() bool
// Select selects if the client should initialize parameter servers.
Select() (bool, error)
// Done indicates the initialization process is done.
Done() error
}
// Server is the identification of a parameter Server.
......@@ -115,7 +119,7 @@ func (c *Client) monitorPservers(l Lister, pserverNum int) {
// servers. Other trainers will be blocked until the initialization is
// done, and they need to get the initialized parameters from
// parameter servers using GetParams.
func (c *Client) BeginInitParams() bool {
func (c *Client) BeginInitParams() (bool, error) {
return c.sel.Select()
}
......
......@@ -124,8 +124,12 @@ func initEtcdClient() {
type selector bool
func (s selector) Select() bool {
return bool(s)
func (s selector) Select() (bool, error) {
return bool(s), nil
}
func (s selector) Done() error {
return nil
}
type lister []client.Server
......@@ -135,7 +139,11 @@ func (l lister) List() []client.Server {
}
func testClient(t *testing.T, c *client.Client) {
selected := c.BeginInitParams()
selected, err := c.BeginInitParams()
if err != nil {
t.Fatal(err)
}
if !selected {
t.Fatal("should be selected.")
}
......
......@@ -16,53 +16,60 @@ package client
import (
"context"
"errors"
"fmt"
"strconv"
"strings"
"time"
"github.com/PaddlePaddle/Paddle/go/pserver"
"github.com/coreos/etcd/clientv3"
"github.com/coreos/etcd/clientv3/concurrency"
log "github.com/sirupsen/logrus"
)
const (
defaultEtcdTimeout time.Duration = 5 * time.Second
initLockPath = "/init_ps/lock"
initDonePath = "/init_ps/done"
initDoneVal = "1"
)
// EtcdClient is used by pserver client that is a part of trainer process.
// Etcd is used by the pserver client, which is part of the trainer process.
// TODO:
// 1. add watcher to watch the change state of pservers)
// 1. add etcd lock)
type EtcdClient struct {
// 1. add watcher to watch the change state of pservers.
type Etcd struct {
client *clientv3.Client
timeout time.Duration
endpoints []string
lock *concurrency.Mutex
}
// Desired reads the desired number of pservers from etcd.
func (p *EtcdClient) Desired() int {
func (e *Etcd) Desired() int {
var psDesired int
for {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
resp, err := p.client.Get(ctx, pserver.PsDesired)
ctx, cancel := context.WithTimeout(context.Background(), e.timeout)
resp, err := e.client.Get(ctx, pserver.PsDesired)
cancel()
if err != nil {
log.Errorf("Get ps dresire number failed! recnnectiong..., %v", err)
time.Sleep(p.timeout)
time.Sleep(e.timeout)
continue
}
kvs := resp.Kvs
if len(kvs) == 0 {
log.Infoln("Waiting for ps desired registered ...")
time.Sleep(p.timeout)
time.Sleep(e.timeout)
continue
}
psDesired, err = strconv.Atoi(string(resp.Kvs[0].Value))
if err != nil {
log.Errorf("psDesired %d invalid %v", psDesired, err)
time.Sleep(p.timeout)
time.Sleep(e.timeout)
continue
}
......@@ -73,26 +80,26 @@ func (p *EtcdClient) Desired() int {
}
// List returns the pserver list read from etcd.
func (p *EtcdClient) List() []Server {
psDesired := p.Desired()
func (e *Etcd) List() []Server {
psDesired := e.Desired()
servers := make([]Server, psDesired)
for {
for i := 0; i < psDesired; i++ {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
ctx, cancel := context.WithTimeout(context.Background(), e.timeout)
psKey := pserver.PsPath + strconv.Itoa(i)
log.Debugf("checking %s", psKey)
resp, err := p.client.Get(ctx, psKey)
resp, err := e.client.Get(ctx, psKey)
cancel()
if err != nil {
log.Infof("Get psKey= %s error, %v", psKey, err)
time.Sleep(p.timeout)
time.Sleep(e.timeout)
continue
}
kvs := resp.Kvs
if len(kvs) == 0 {
log.Infof("Waiting for ps addr registered ...")
time.Sleep(p.timeout)
time.Sleep(e.timeout)
continue
}
......@@ -100,7 +107,7 @@ func (p *EtcdClient) List() []Server {
// TODO(Longfei) check the ps address
if psAddr == "" {
log.Infof("Get psKey = %s, psAddr is empty", psKey)
time.Sleep(p.timeout)
time.Sleep(e.timeout)
continue
}
log.Debugf("got value (%s) for key: %s", psAddr, psKey)
......@@ -113,7 +120,7 @@ func (p *EtcdClient) List() []Server {
}
// NewEtcd creates an etcd client to return the state of pservers on etcd.
func NewEtcd(endpoints string) *EtcdClient {
func NewEtcd(endpoints string) *Etcd {
ep := strings.Split(endpoints, ",")
var cli *clientv3.Client
var err error
......@@ -130,10 +137,118 @@ func NewEtcd(endpoints string) *EtcdClient {
break
}
log.Infof("Connected to etcd: %s\n", endpoints)
client := &EtcdClient{
client := &Etcd{
client: cli,
timeout: defaultEtcdTimeout,
endpoints: ep,
}
return client
}
// Select indicates if the current trainer is selected to initialize
// the pserver parameters.
func (e *Etcd) Select() (bool, error) {
sess, err := concurrency.NewSession(e.client, concurrency.WithTTL(5))
if err != nil {
return false, err
}
lock := concurrency.NewMutex(sess, initLockPath)
log.Infof("Trying to acquire lock at %s.", initLockPath)
// Do not use a timeout context here, since we don't know how
// long it takes for other trainers to initialize the
// parameters.
err = lock.Lock(context.Background())
if err != nil {
return false, err
}
log.Infof("Successfully acquired lock at %s.", initLockPath)
get := clientv3.OpGet(initDonePath)
ctx, cancel := context.WithTimeout(context.Background(), e.timeout)
tresp, err := e.client.Txn(ctx).If(lock.IsOwner()).Then(get).Commit()
cancel()
if err != nil {
return false, err
}
if !tresp.Succeeded {
return false, errors.New("no longer the owner of the lock")
}
resp := tresp.Responses[0].GetResponseRange()
if len(resp.Kvs) == 0 {
// Key value not set, select current trainer.
e.lock = lock
log.Infoln("Trainer selected.")
return true, nil
}
if string(resp.Kvs[0].Value) == initDoneVal {
log.Infoln("Initialization is already done.")
ctx, cancel = context.WithTimeout(context.Background(), e.timeout)
err = lock.Unlock(ctx)
cancel()
if err != nil {
log.Errorln(err)
}
return false, nil
}
return false, fmt.Errorf("key %s have unexpected value: %v", initDonePath, resp.Kvs[0].Value)
}
// Done indicates the parameter initialization process is done.
func (e *Etcd) Done() error {
if e.lock == nil {
return errors.New("lock is nil, Done called unexpectedly")
}
put := clientv3.OpPut(initDonePath, initDoneVal)
ctx, cancel := context.WithTimeout(context.Background(), e.timeout)
tresp, err := e.client.Txn(ctx).If(e.lock.IsOwner()).Then(put).Commit()
cancel()
if err != nil {
return err
}
if !tresp.Succeeded {
return errors.New("no longer the owner of the lock")
}
ctx, cancel = context.WithTimeout(context.Background(), e.timeout)
err = e.lock.Unlock(ctx)
cancel()
if err != nil {
log.Errorln(err)
} else {
e.lock = nil
}
return nil
}
// Close closes the etcd client.
func (e *Etcd) Close() error {
var err error
if e.lock != nil {
ctx, cancel := context.WithTimeout(context.Background(), e.timeout)
err = e.lock.Unlock(ctx)
cancel()
if err == nil {
e.lock = nil
}
}
cErr := e.client.Close()
if cErr != nil {
if err != nil {
log.Errorln(cErr)
return err
}
return cErr
}
return err
}
package client_test
import (
"io/ioutil"
"net/url"
"os"
"strings"
"sync"
"testing"
"github.com/PaddlePaddle/Paddle/go/pserver/client"
"github.com/coreos/etcd/embed"
)
func TestSelector(t *testing.T) {
etcdDir, err := ioutil.TempDir("", "")
if err != nil {
t.Fatal(err)
}
cfg := embed.NewConfig()
lpurl, _ := url.Parse("http://localhost:0")
lcurl, _ := url.Parse("http://localhost:0")
cfg.LPUrls = []url.URL{*lpurl}
cfg.LCUrls = []url.URL{*lcurl}
cfg.Dir = etcdDir
e, err := embed.StartEtcd(cfg)
if err != nil {
t.Fatal(err)
}
defer func() {
e.Close()
if err := os.RemoveAll(etcdDir); err != nil {
t.Fatal(err)
}
}()
<-e.Server.ReadyNotify()
port := strings.Split(e.Clients[0].Addr().String(), ":")[1]
endpoint := "127.0.0.1:" + port
var mu sync.Mutex
selectedCount := 0
var wg sync.WaitGroup
selectAndDone := func(c *client.Etcd) {
defer wg.Done()
selected, err := c.Select()
if err != nil {
panic(err)
}
if selected {
mu.Lock()
selectedCount++
mu.Unlock()
err = c.Done()
if err != nil {
t.Fatal(err)
}
}
}
c0 := client.NewEtcd(endpoint)
c1 := client.NewEtcd(endpoint)
c2 := client.NewEtcd(endpoint)
c3 := client.NewEtcd(endpoint)
wg.Add(3)
go selectAndDone(c0)
go selectAndDone(c1)
go selectAndDone(c2)
wg.Wait()
// simulate a trainer that crashed and restarted after the
// initialization process.
wg.Add(1)
go selectAndDone(c3)
wg.Wait()
mu.Lock()
if selectedCount != 1 {
t.Fatal("selected count wrong:", selectedCount)
}
mu.Unlock()
err = c0.Close()
if err != nil {
t.Fatal(err)
}
err = c1.Close()
if err != nil {
t.Fatal(err)
}
err = c2.Close()
if err != nil {
t.Fatal(err)
}
err = c3.Close()
if err != nil {
t.Fatal(err)
}
}
......@@ -7,6 +7,9 @@ cc_library(tensor SRCS tensor.cc DEPS ddim place paddle_memory device_context)
cc_test(tensor_test SRCS tensor_test.cc DEPS tensor)
cc_test(eigen_test SRCS eigen_test.cc DEPS tensor)
cc_library(lod_tensor SRCS lod_tensor.cc details/lod_tensor.cc DEPS ddim place tensor)
cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor)
cc_test(variable_test SRCS variable_test.cc)
cc_library(scope SRCS scope.cc)
......
......@@ -147,8 +147,9 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
ForEachVarName(grad_op->inputs_, [&no_grad_names,
&net](std::string& grad_input) {
if (no_grad_names.count(grad_input)) {
std::string prefix =
grad_input.substr(0, grad_input.size() - kGradVarSuffix.size());
// +1 for \0
std::string prefix = grad_input.substr(
0, grad_input.size() - sizeof(kGradVarSuffix) / sizeof(char) + 1);
grad_input = prefix + kZeroVarSuffix;
// If part of input gradient of that operator is not calculated, fill
......@@ -184,7 +185,7 @@ std::shared_ptr<OperatorBase> Backward(
std::unordered_set<std::string> no_grad_names;
no_grad_names.reserve(no_grad_vars.size());
no_grad_names.insert(kEmptyVarName + kGradVarSuffix);
no_grad_names.insert(std::string(kEmptyVarName) + kGradVarSuffix);
for (auto& name : no_grad_vars) {
no_grad_names.insert(name + kGradVarSuffix);
......
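A standalone sketch (not part of this patch; the helper name is hypothetical) of the length arithmetic in the hunk above: for a constexpr char array, sizeof includes the trailing '\0', so subtracting sizeof(kGradVarSuffix) and adding 1 removes exactly the "@GRAD" suffix.
#include <string>
constexpr char kGradVarSuffix[] = "@GRAD";  // sizeof == 6: five characters plus the trailing '\0'
// Hypothetical helper: strip the gradient suffix from a gradient variable name.
std::string StripGradSuffix(const std::string& grad_name) {
  // grad_name.size() - (sizeof(kGradVarSuffix) - 1) keeps everything before "@GRAD",
  // e.g. StripGradSuffix("X@GRAD") == "X".
  return grad_name.substr(0, grad_name.size() - sizeof(kGradVarSuffix) + 1);
}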
......@@ -166,21 +166,17 @@ REGISTER_OP(fc, f::FcOp, f::FcOpMaker);
REGISTER_OP(many_output_op, f::EmptyOp, f::ManyOutputOpMaker);
REGISTER_GRADIENT_OP(many_output_op, many_output_op_grad, f::EmptyOp);
TEST(Backward, need_to_be_removed) {}
//
// TEST(Backward, simple_op_grad) {
// auto fwd = f::OpRegistry::CreateOp(
// "rowwise_add", {{"X", {"X"}}, {"b", {"b"}}}, {{"Out", {"Out"}}}, {});
// auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
// ASSERT_NE(fwd, nullptr);
// auto gop = f::OpRegistry::CreateGradOp(*fwd);
// ASSERT_EQ(4UL, gop->inputs_.size());
// ASSERT_EQ(f::kEmptyVarName, gop->inputs_[0]);
// ASSERT_EQ("rowwise_add_grad", gop->type_);
// ASSERT_EQ("X" + f::kGradVarSuffix, gop->outputs_[0]);
// ASSERT_EQ("b" + f::kGradVarSuffix, gop->outputs_[1]);
// ASSERT_EQ(f::GradVarName("X"), gop->outputs_[0]);
// ASSERT_EQ(f::GradVarName("b"), gop->outputs_[1]);
//
// ASSERT_EQ("X" + f::kGradVarSuffix, gop->Output("X" + f::kGradVarSuffix));
// ASSERT_EQ(f::GradVarName("X"), gop->Output(f::GradVarName("X")));
//}
//
// TEST(Backward, simple_op_not_need_grad) {
......@@ -188,7 +184,7 @@ TEST(Backward, need_to_be_removed) {}
// ASSERT_NE(fwd, nullptr);
// auto gop = f::Backward(*fwd, {"X"});
// ASSERT_EQ(std::find(gop->outputs_.begin(), gop->outputs_.end(),
// "X" + f::kGradVarSuffix),
// f::GradVarName("X")),
// gop->outputs_.end());
//
// auto no_input_gop = f::Backward(*fwd, {"X", "b"});
......@@ -259,18 +255,18 @@ TEST(Backward, need_to_be_removed) {}
// all_output.erase(f::kEmptyVarName);
//
// for (auto &out : {"W1", "b1", "hidden0", "W2", "b2"}) {
// ASSERT_NE(all_output.find(out + f::kGradVarSuffix), all_output.end());
// ASSERT_NE(all_output.find(f::GradVarName(out)), all_output.end());
// }
//
// // Not Generated X
// ASSERT_EQ(all_output.find("X" + f::kGradVarSuffix), all_output.end());
// ASSERT_EQ(all_output.find(f::GradVarName("X")), all_output.end());
//
// ASSERT_EQ(2UL, bwd_net->ops_.size());
// ASSERT_TRUE(bwd_net->ops_[1]->IsNetOp());
// auto first_fc_grad = static_cast<ops::NetOp *>(bwd_net->ops_[1].get());
// ASSERT_EQ(3UL, first_fc_grad->ops_.size());
// ASSERT_EQ(f::kEmptyVarName,
// first_fc_grad->ops_[2]->Output("A" + f::kGradVarSuffix));
// first_fc_grad->ops_[2]->Output(f::GradVarName("A")));
//}
//
// TEST(Backward, net_shared_weight) {
......@@ -322,17 +318,15 @@ TEST(Backward, need_to_be_removed) {}
// ASSERT_EQ(1UL, fill_zero.inputs_.size());
// ASSERT_EQ("Z", fill_zero.inputs_[0]);
// ASSERT_EQ(1UL, fill_zero.outputs_.size());
// ASSERT_EQ("Z" + f::kZeroVarSuffix, fill_zero.outputs_[0]);
// ASSERT_EQ(std::string("Z") + f::kZeroVarSuffix, fill_zero.outputs_[0]);
//
// auto &d_many_out = *net->ops_[1];
// ASSERT_EQ("many_output_op_grad", d_many_out.type_);
// ASSERT_EQ(1UL + 2UL + 2UL, d_many_out.inputs_.size()); // I/O/OG
// ASSERT_EQ("Z" + f::kZeroVarSuffix, d_many_out.Input("z" +
// f::kGradVarSuffix));
// ASSERT_EQ("Y" + f::kGradVarSuffix, d_many_out.Input("y" +
// f::kGradVarSuffix));
// ASSERT_EQ("X" + f::kGradVarSuffix,
// d_many_out.Output("x" + f::kGradVarSuffix));
// ASSERT_EQ(std::string("Z") + f::kZeroVarSuffix,
// d_many_out.Input(f::GradVarName("z")));
// ASSERT_EQ(f::GradVarName("Y"), d_many_out.Input(f::GradVarName("y")));
// ASSERT_EQ(f::GradVarName("X"), d_many_out.Output(f::GradVarName("x")));
//}
//
// TEST(Backward, op_part_of_input_are_not_need) {
......@@ -342,11 +336,9 @@ TEST(Backward, need_to_be_removed) {}
// ASSERT_EQ(grad_mul.type_, "mul_grad");
// ASSERT_EQ(grad_mul.inputs_.size(), 2UL + 1UL + 1UL);
// ASSERT_EQ(grad_mul.outputs_.size(), 2UL);
// ASSERT_EQ(grad_mul.Output("A" + f::kGradVarSuffix), f::kEmptyVarName);
// ASSERT_EQ(grad_mul.Output("B" + f::kGradVarSuffix), "b" +
// f::kGradVarSuffix);
// ASSERT_EQ(grad_mul.Input("Out" + f::kGradVarSuffix),
// "out" + f::kGradVarSuffix);
// ASSERT_EQ(grad_mul.Output(f::GradVarName("A")), f::kEmptyVarName);
// ASSERT_EQ(grad_mul.Output(f::GradVarName("B")), f::GradVarName("b"));
// ASSERT_EQ(grad_mul.Input(f::GradVarName("Out")), f::GradVarName("out"));
// ASSERT_EQ(grad_mul.Input("A"), "a");
// ASSERT_EQ(grad_mul.Input("B"), "b");
// ASSERT_EQ(grad_mul.Input("Out"), "out");
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/lod_tensor.h"
#include <memory>
namespace paddle {
namespace framework {
namespace details {
using LOD = LODTensor::LOD;
std::shared_ptr<LOD> SliceLOD(const LOD &lod, size_t level_begin,
size_t level_end) {
auto new_lod = std::make_shared<LOD>();
new_lod->reserve(level_end - level_begin);
for (size_t i = level_begin; i < level_end; i++) {
new_lod->emplace_back(lod[i]);
}
return new_lod;
}
std::shared_ptr<LOD> SliceLOD(const LOD &lod, size_t level, size_t elem_begin,
size_t elem_end, bool tensor_shared) {
// slice the lod.
auto new_lod = std::make_shared<LOD>();
new_lod->reserve(lod.size() - level);
auto start = lod.at(level)[elem_begin];
auto end = lod.at(level)[elem_end];
for (auto it = lod.begin() + level; it != lod.end(); it++) {
auto it_begin = std::find(it->begin(), it->end(), start);
auto it_end = std::find(it_begin, it->end(), end);
PADDLE_ENFORCE(it_begin != it->end(), "error in parsing lod info");
PADDLE_ENFORCE(it_end != it->end(), "error in parsing lod info");
new_lod->emplace_back(it_begin, it_end + 1);
if (!tensor_shared) {
// reset offsets if the tensor is copied and sliced.
std::transform(new_lod->back().begin(), new_lod->back().end(),
new_lod->back().begin(),
[start](int v) { return v - start; });
PADDLE_ENFORCE(new_lod->back().front() == 0, "error in slice LOD");
}
}
return new_lod;
}
} // namespace details
} // namespace framework
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
namespace paddle {
namespace framework {
namespace details {
/*
* Slice levels from LOD.
*
* @lod: LOD to slice.
* @level_begin: level to begin slice.
* @level_end: level to end slice.
*/
std::shared_ptr<LODTensor::LOD> SliceLOD(const LODTensor::LOD &lod,
size_t level_begin, size_t level_end);
/*
* Slice elements from a level of LOD.
*
* @lod: LOD to slice.
* @level: which level to slice.
* @elem_begin: element's index to begin slice.
* @elem_end: element's index to end slice.
*/
std::shared_ptr<LODTensor::LOD> SliceLOD(const LODTensor::LOD &lod,
size_t level, size_t elem_begin,
size_t elem_end, bool tensor_shared);
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -40,8 +40,8 @@ message OpDesc {
};
message Var {
required string op_proto_name = 1;
repeated string var_names = 2;
required string parameter = 1;
repeated string arguments = 2;
};
required string type = 3;
......
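A hedged sketch of how the renamed Var fields are populated from C++, using the same generated protobuf setters the tests below rely on: each Var entry binds one parameter name declared in the OpProto to the list of argument variable names.
#include "paddle/framework/framework.pb.h"
void BuildOpDescExample(paddle::framework::OpDesc* op_desc) {
  op_desc->set_type("cos_sim");
  auto* input = op_desc->add_inputs();
  input->set_parameter("input");   // parameter name, formerly op_proto_name
  input->add_arguments("aa");      // argument variable names, formerly var_names
  auto* output = op_desc->add_outputs();
  output->set_parameter("output");
  output->add_arguments("bb");
}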
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/lod_tensor.h"
#include <glog/logging.h>
namespace paddle {
namespace framework {
LODTensor LODTensor::SliceShared(size_t level_begin, size_t level_end) const {
PADDLE_ENFORCE(HasLOD(), "has no LOD info, can't be sliced.");
auto new_lod = details::SliceLOD(*lod_start_pos_, level_begin, level_end);
// slicing levels only needs to update the LOD info; each level still contains the
// whole tensor_, so there is no need to modify tensor_.
return LODTensor(tensor_, new_lod);
}
LODTensor LODTensor::SliceShared(size_t level, size_t elem_begin,
size_t elem_end) const {
PADDLE_ENFORCE(HasLOD(), "has no LOD info, can't be sliced.");
PADDLE_ENFORCE(level < NumLevels(), "level [%d] out of range [%d]", level,
NumLevels());
PADDLE_ENFORCE(elem_begin < NumElements(level),
"element begin [%d] out of range [%d]", elem_begin,
NumElements(level));
PADDLE_ENFORCE(elem_end < NumElements(level) + 1,
"element end [%d] out of range [%d]", elem_end,
NumElements(level));
auto new_lod = details::SliceLOD(*lod_start_pos_, level, elem_begin, elem_end,
true /*tensor_shared*/);
// slicing elements only needs to update the LOD info; the offsets are not
// changed, so the original tensor_ can be reused.
return LODTensor(tensor_, new_lod);
}
} // namespace framework
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#if (!PADDLE_ONLY_CPU)
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#endif
#include "paddle/framework/ddim.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/enforce.h"
namespace paddle {
namespace framework {
/*
* LODTensor (Level of details Tensor)
* see https://en.wikipedia.org/wiki/Level_of_details for reference.
*/
class LODTensor {
public:
// Level saves the offsets of each unit.
#ifdef PADDLE_ONLY_CPU
using Level = std::vector<size_t>;
#else
using Level = thrust::device_vector<size_t>;
#endif
// LOD stores the offsets of each level of units, the largest units level first,
// then the smaller units levels. Each Level stores the offsets of units in the
// Tensor.
typedef std::vector<Level> LOD;
LODTensor() {}
LODTensor(const std::shared_ptr<Tensor> &tensor,
const std::shared_ptr<LOD> &lod) {
Reset(tensor, lod);
}
void Reset(const std::shared_ptr<Tensor> &tensor,
const std::shared_ptr<LOD> &lod) {
tensor_ = tensor;
lod_start_pos_ = lod;
}
/*
* Get an element from LOD.
*/
size_t lod_element(size_t level, size_t elem) const {
PADDLE_ENFORCE(level < NumLevels(), "level [%d] out of range [%d]", level,
NumLevels());
PADDLE_ENFORCE(elem < NumElements(level),
"element begin [%d] out of range [%d]", elem,
NumElements(level));
return (*lod_start_pos_)[level][elem];
}
/*
* Number of the LODTensor's levels; each level holds units of data. For example,
* article, paragraph, and sentence form 3 levels.
*/
size_t NumLevels() const {
return lod_start_pos_ ? lod_start_pos_->size() : 0UL;
}
/*
* Number of elements in a level.
*/
size_t NumElements(size_t level = 0) const {
PADDLE_ENFORCE(level < NumLevels(), "level [%d] out of range [%d]", level,
NumLevels());
// the last offset is the end of last element
return lod_start_pos_->at(level).size() - 1;
}
/*
* Slice of levels[level_begin:level_end], with tensor copied.
*/
template <typename T>
LODTensor SliceCopied(size_t level_begin, size_t level_end,
const platform::Place &dst_place) const;
/*
* Slice of levels[level_begin:level_end], with tensor shared.
*/
LODTensor SliceShared(size_t level_begin, size_t level_end) const;
/*
* Slice of elements of a level, [elem_begin: elem_end], with tensor copied.
* @note: slicing lod_start_pos_ has low performance.
*/
template <typename T>
LODTensor SliceCopied(size_t level, size_t elem_begin, size_t elem_end,
const platform::Place &dst_place) const;
/*
* Slice of elements of a level, [elem_begin: elem_end], with tensor shared.
* @note: slicing lod_start_pos_ has low performance.
*/
LODTensor SliceShared(size_t level, size_t elem_begin, size_t elem_end) const;
/*
* Copy other's lod_start_pos_, to share LOD info.
* @note: the LOD info should not be changed.
*/
void ShareLOD(const LODTensor &other) {
lod_start_pos_ = other.lod_start_pos_;
}
/*
* Copy other's lod_start_pos_'s content, free to mutate.
*/
void CopyLOD(const LODTensor &other) {
lod_start_pos_ = std::make_shared<LOD>(*other.lod_start_pos_);
}
/*
* Determine whether LODTensor has a valid LOD info.
*/
bool HasLOD() const { return bool(lod_start_pos_); }
LOD *lod() const { return lod_start_pos_.get(); }
std::shared_ptr<Tensor> &tensor() { return tensor_; }
Tensor *raw_tensor() { return tensor_.get(); }
private:
std::shared_ptr<LOD> lod_start_pos_;
std::shared_ptr<Tensor> tensor_;
};
} // namespace framework
} // namespace paddle
#include "paddle/framework/lod_tensor_impl.h"
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/details/lod_tensor.h"
namespace paddle {
namespace framework {
template <typename T>
LODTensor LODTensor::SliceCopied(size_t level_begin, size_t level_end,
const platform::Place &dst_place) const {
PADDLE_ENFORCE(HasLOD(), "has no LOD info, can't be sliced.");
auto new_lod = details::SliceLOD(*lod_start_pos_, level_begin, level_end);
auto new_tensor = std::make_shared<Tensor>();
new_tensor->CopyFrom<T>(*tensor_, dst_place);
return LODTensor(new_tensor, new_lod);
}
template <typename T>
LODTensor LODTensor::SliceCopied(size_t level, size_t elem_begin,
size_t elem_end,
const platform::Place &dst_place) const {
PADDLE_ENFORCE(HasLOD(), "has no LOD info, can't be sliced.");
PADDLE_ENFORCE(level < NumLevels(), "level [%d] out of range [%d]", level,
NumLevels());
PADDLE_ENFORCE(elem_begin < NumElements(level),
"element begin [%d] out of range [%d]", elem_begin,
NumElements(level));
PADDLE_ENFORCE(elem_end < NumElements(level) + 1,
"element end [%d] out of range [%d]", elem_end,
NumElements(level));
auto new_lod = details::SliceLOD(*lod_start_pos_, level, elem_begin, elem_end,
false /*tensor_shared*/);
auto start_idx = new_lod->front().front();
auto end_idx = new_lod->front().back() - 1 /*the next element's start*/;
auto sliced_tensor = tensor_->Slice<T>(start_idx, end_idx);
auto new_tensor = std::make_shared<Tensor>();
new_tensor->CopyFrom<T>(sliced_tensor, dst_place);
return LODTensor(new_tensor, new_lod);
}
} // namespace framework
} // namespace paddle
/*
Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "paddle/framework/lod_tensor.h"
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <memory>
namespace paddle {
namespace framework {
class LODTensorTester : public ::testing::Test {
public:
virtual void SetUp() override {
lod_tensor.reset(new LODTensor);
// tensor's batch_size: 20
// 3 levels
// 0 10 20
// 0 5 10 15 20
// 0 2 5 7 10 12 15 17 20
auto lod = std::make_shared<LODTensor::LOD>();
lod->push_back(std::vector<size_t>{0, 10, 20});
lod->push_back(std::vector<size_t>{0, 5, 10, 15, 20});
lod->push_back(std::vector<size_t>{0, 2, 5, 7, 10, 12, 15, 17, 20});
auto tensor = std::make_shared<Tensor>();
tensor->Resize({20 /*batch size*/, 128 /*dim*/});
// malloc memory
tensor->mutable_data<float>(place);
lod_tensor->Reset(tensor, lod);
}
protected:
std::unique_ptr<LODTensor> lod_tensor;
platform::CPUPlace place;
};
TEST_F(LODTensorTester, NumLevels) { ASSERT_EQ(lod_tensor->NumLevels(), 3UL); }
TEST_F(LODTensorTester, NumElements) {
ASSERT_EQ(lod_tensor->NumElements(0), 2UL);
ASSERT_EQ(lod_tensor->NumElements(1), 4UL);
ASSERT_EQ(lod_tensor->NumElements(2), 8UL);
}
TEST_F(LODTensorTester, SliceShared_Level) {
// slice 1 level
for (size_t level = 0; level < 3UL; ++level) {
auto new_lod_tensor = lod_tensor->SliceShared(level, level + 1);
ASSERT_EQ(new_lod_tensor.NumLevels(), 1UL);
ASSERT_EQ(new_lod_tensor.NumElements(0UL), lod_tensor->NumElements(level));
ASSERT_EQ(new_lod_tensor.tensor(), lod_tensor->tensor());
}
// slice 2 level
for (size_t level = 0; level < 2UL; ++level) {
auto new_lod_tensor = lod_tensor->SliceShared(level, level + 2);
ASSERT_EQ(new_lod_tensor.NumLevels(), 2UL);
ASSERT_EQ(new_lod_tensor.NumElements(0), lod_tensor->NumElements(level));
ASSERT_EQ(new_lod_tensor.NumElements(1),
lod_tensor->NumElements(level + 1));
ASSERT_EQ(new_lod_tensor.tensor(), lod_tensor->tensor());
}
}
TEST_F(LODTensorTester, SliceCopied_Level) {
// slice 1 level
for (size_t level = 0; level < 3UL; ++level) {
auto new_lod_tensor =
lod_tensor->SliceCopied<float>(level, level + 1, place);
ASSERT_EQ(new_lod_tensor.NumLevels(), 1UL);
ASSERT_EQ(new_lod_tensor.NumElements(0UL), lod_tensor->NumElements(level));
// ASSERT_EQ(new_lod_tensor.tensor(), lod_tensor->tensor());
// TODO(superjom) add tensor comparison here.
}
// slice 2 level
for (size_t level = 0; level < 2UL; ++level) {
auto new_lod_tensor =
lod_tensor->SliceCopied<float>(level, level + 2, place);
ASSERT_EQ(new_lod_tensor.NumLevels(), 2UL);
ASSERT_EQ(new_lod_tensor.NumElements(0), lod_tensor->NumElements(level));
ASSERT_EQ(new_lod_tensor.NumElements(1),
lod_tensor->NumElements(level + 1));
// ASSERT_EQ(new_lod_tensor.tensor(), lod_tensor->tensor());
// TODO(superjom) add tensor comparison here.
}
}
TEST_F(LODTensorTester, SliceShared_Element) {
size_t level = 0;
auto new_lod_tensor = lod_tensor->SliceShared(level, 0, 2);
ASSERT_EQ(new_lod_tensor.NumLevels(), 3UL);
ASSERT_EQ(new_lod_tensor.NumElements(0), 2UL);
ASSERT_EQ(new_lod_tensor.NumElements(1), 4UL);
ASSERT_EQ(new_lod_tensor.NumElements(2), 8UL);
ASSERT_EQ(new_lod_tensor.raw_tensor(), lod_tensor->raw_tensor());
level = 1;
new_lod_tensor = lod_tensor->SliceShared(level, 0, 2);
ASSERT_EQ(new_lod_tensor.NumLevels(), 2UL);
ASSERT_EQ(new_lod_tensor.NumElements(0), 2UL);
ASSERT_EQ(new_lod_tensor.NumElements(1), 4UL);
ASSERT_EQ(new_lod_tensor.raw_tensor(), lod_tensor->raw_tensor());
}
TEST_F(LODTensorTester, SliceCopied_Element) {
size_t level = 0;
auto new_lod_tensor = lod_tensor->SliceCopied<float>(level, 0, 2, place);
ASSERT_EQ(new_lod_tensor.NumLevels(), 3UL);
ASSERT_EQ(new_lod_tensor.NumElements(0), 2UL);
ASSERT_EQ(new_lod_tensor.NumElements(1), 4UL);
ASSERT_EQ(new_lod_tensor.NumElements(2), 8UL);
ASSERT_NE(new_lod_tensor.raw_tensor(), lod_tensor->raw_tensor());
level = 1;
new_lod_tensor = lod_tensor->SliceCopied<float>(level, 0, 2, place);
ASSERT_EQ(new_lod_tensor.NumLevels(), 2UL);
ASSERT_EQ(new_lod_tensor.NumElements(0), 2UL);
ASSERT_EQ(new_lod_tensor.NumElements(1), 4UL);
ASSERT_NE(new_lod_tensor.raw_tensor(), lod_tensor->raw_tensor());
level = 1;
// LOD is
// 0 5 10
// 0 2 5 7 10
new_lod_tensor = lod_tensor->SliceCopied<float>(level, 1, 3, place);
ASSERT_EQ(new_lod_tensor.NumLevels(), 2UL);
ASSERT_EQ(new_lod_tensor.NumElements(0), 2UL);
ASSERT_EQ(new_lod_tensor.NumElements(1), 4UL);
ASSERT_EQ(new_lod_tensor.lod_element(0, 0), 0UL);
ASSERT_EQ(new_lod_tensor.lod_element(0, 1), 5UL);
ASSERT_EQ(new_lod_tensor.lod_element(1, 0), 0UL);
ASSERT_EQ(new_lod_tensor.lod_element(1, 1), 2UL);
ASSERT_EQ(new_lod_tensor.lod_element(1, 2), 5UL);
ASSERT_EQ(new_lod_tensor.lod_element(1, 3), 7UL);
// TODO(superjom) compare the content of these tensors
}
TEST_F(LODTensorTester, ShareLOD) {
LODTensor new_lod_tensor;
new_lod_tensor.ShareLOD(*lod_tensor);
ASSERT_EQ(new_lod_tensor.lod(), lod_tensor->lod());
}
TEST_F(LODTensorTester, CopyLOD) {
LODTensor new_lod_tensor;
new_lod_tensor.CopyLOD(*lod_tensor);
ASSERT_NE(new_lod_tensor.lod(), lod_tensor->lod());
}
} // namespace framework
} // namespace paddle
......@@ -22,6 +22,7 @@ limitations under the License. */
#include "paddle/framework/attribute.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/grad_op_builder.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/scope.h"
namespace paddle {
......@@ -127,7 +128,7 @@ class OpRegistry {
static void RegisterOp(const std::string& op_type) {
op_creators()[op_type] = [] { return new OpType; };
OpAttrChecker& op_checker = op_checkers()[op_type];
OpProto& op_proto = protos()[op_type];
OpProto& op_proto = OpProtos()[op_type];
auto maker = ProtoMakerType(&op_proto, &op_checker);
maker.Validate();
*op_proto.mutable_type() = op_type;
......@@ -135,17 +136,6 @@ class OpRegistry {
op_proto.IsInitialized(),
"Fail to initialize %s's OpProto, because %s is not initialized",
op_type, op_proto.InitializationErrorString());
VarIndexMaps()[op_type].reset(new VarIndexMap());
auto& varmap = *VarIndexMaps()[op_type];
int idx = 0;
for (auto& var : op_proto.inputs()) {
varmap[var.name()] = idx++;
}
idx = 0;
for (auto& var : op_proto.outputs()) {
varmap[var.name()] = idx++;
}
}
template <typename GradOpType>
......@@ -180,8 +170,8 @@ class OpRegistry {
static std::shared_ptr<OperatorBase> CreateOp(const OpDesc& op_desc) {
VarNameMap inputs;
for (auto& input : op_desc.inputs()) {
auto& var_names = inputs[input.op_proto_name()];
auto& var_names_in_proto = input.var_names();
auto& var_names = inputs[input.parameter()];
auto& var_names_in_proto = input.arguments();
var_names.reserve(static_cast<size_t>(var_names_in_proto.size()));
std::copy(var_names_in_proto.begin(), var_names_in_proto.end(),
std::back_inserter(var_names));
......@@ -189,8 +179,8 @@ class OpRegistry {
VarNameMap outputs;
for (auto& output : op_desc.outputs()) {
auto& var_names = outputs[output.op_proto_name()];
auto& var_names_in_proto = output.var_names();
auto& var_names = outputs[output.parameter()];
auto& var_names_in_proto = output.arguments();
var_names.reserve(static_cast<size_t>(var_names_in_proto.size()));
std::copy(var_names_in_proto.begin(), var_names_in_proto.end(),
std::back_inserter(var_names));
......@@ -212,22 +202,11 @@ class OpRegistry {
return grad_op;
}
static std::unordered_map<std::string, OpProto>& protos() {
static std::unordered_map<std::string, OpProto> protos_;
return protos_;
}
static std::unordered_map<std::string, std::string>& grad_ops() {
static std::unordered_map<std::string, std::string> grad_ops_;
return grad_ops_;
}
static std::unordered_map<std::string, std::shared_ptr<VarIndexMap>>&
VarIndexMaps() {
static std::unordered_map<std::string, std::shared_ptr<VarIndexMap>> maps_;
return maps_;
}
static std::unordered_map<std::string, OpCreator>& op_creators() {
static std::unordered_map<std::string, OpCreator> op_creators_;
return op_creators_;
......
......@@ -58,12 +58,12 @@ TEST(OpRegistry, CreateOp) {
paddle::framework::OpDesc op_desc;
op_desc.set_type("cos_sim");
auto input = op_desc.add_inputs();
input->set_op_proto_name("input");
*input->mutable_var_names()->Add() = "aa";
input->set_parameter("input");
*input->mutable_arguments()->Add() = "aa";
auto output = op_desc.add_outputs();
output->set_op_proto_name("output");
*output->mutable_var_names()->Add() = "bb";
output->set_parameter("output");
*output->mutable_arguments()->Add() = "bb";
float scale = 3.3;
auto attr = op_desc.mutable_attrs()->Add();
......@@ -84,12 +84,12 @@ TEST(OpRegistry, IllegalAttr) {
paddle::framework::OpDesc op_desc;
op_desc.set_type("cos_sim");
auto input = op_desc.add_inputs();
input->set_op_proto_name("input");
*input->mutable_var_names()->Add() = "aa";
input->set_parameter("input");
*input->mutable_arguments()->Add() = "aa";
auto output = op_desc.add_outputs();
output->set_op_proto_name("output");
*output->mutable_var_names()->Add() = "bb";
output->set_parameter("output");
*output->mutable_arguments()->Add() = "bb";
auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale");
......@@ -114,12 +114,12 @@ TEST(OpRegistry, DefaultValue) {
paddle::framework::OpDesc op_desc;
op_desc.set_type("cos_sim");
auto input = op_desc.add_inputs();
input->set_op_proto_name("input");
*input->mutable_var_names()->Add() = "aa";
input->set_parameter("input");
*input->mutable_arguments()->Add() = "aa";
auto output = op_desc.add_outputs();
output->set_op_proto_name("output");
*output->mutable_var_names()->Add() = "bb";
output->set_parameter("output");
*output->mutable_arguments()->Add() = "bb";
ASSERT_TRUE(op_desc.IsInitialized());
......@@ -135,12 +135,12 @@ TEST(OpRegistry, CustomChecker) {
paddle::framework::OpDesc op_desc;
op_desc.set_type("my_test_op");
auto input = op_desc.add_inputs();
input->set_op_proto_name("input");
*input->mutable_var_names()->Add() = "ii";
input->set_parameter("input");
*input->mutable_arguments()->Add() = "ii";
auto output = op_desc.add_outputs();
output->set_op_proto_name("output");
*output->mutable_var_names()->Add() = "oo";
output->set_parameter("output");
*output->mutable_arguments()->Add() = "oo";
// attr 'test_attr' is not set
bool caught = false;
......
......@@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include "paddle/framework/operator.h"
#include <algorithm>
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace framework {
......@@ -33,6 +33,14 @@ ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
}
#endif
static std::unordered_map<std::string, OpProto>* g_op_protos = nullptr;
std::unordered_map<std::string, OpProto>& OpProtos() {
if (g_op_protos == nullptr) {
g_op_protos = new std::unordered_map<std::string, OpProto>();
}
return *g_op_protos;
}
const std::string& OperatorBase::Input(const std::string& name) const {
auto it = inputs_.find(name);
PADDLE_ENFORCE(it != inputs_.end(), "Op %s does not have input %s", type_,
......
......@@ -32,24 +32,26 @@ namespace paddle {
namespace framework {
/// If a variable is an empty variable, that name will be used.
const std::string kEmptyVarName = "@EMPTY@";
constexpr char kEmptyVarName[] = "@EMPTY@";
/// If a variable is a temporary variable, that name will be set in Python,
/// but it will be converted to a unique name in the scope after OpCreator.
const std::string kTempVarName = "@TEMP@";
constexpr char kTempVarName[] = "@TEMP@";
/// If a variable's name has a certain suffix, it means that the
/// variable is the gradient of another variable.
/// e.g. Variable "x@GRAD" is the gradient of variable "x".
const std::string kGradVarSuffix = "@GRAD";
constexpr char kGradVarSuffix[] = "@GRAD";
/// Variables with this suffix are supposed to be filled up with zeros.
const std::string kZeroVarSuffix = "@ZERO";
constexpr char kZeroVarSuffix[] = "@ZERO";
inline std::string GradVarName(const std::string& var_name) {
return var_name + kGradVarSuffix;
}
extern std::unordered_map<std::string, OpProto>& OpProtos();
class OperatorBase;
class InferShapeContext;
class ExecutionContext;
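A hedged sketch (using only the declarations above) of why call sites move from raw string concatenation to GradVarName: with constexpr char arrays, an expression like "X" + kGradVarSuffix no longer compiles because both operands decay to pointers, so one side must be wrapped in std::string or GradVarName must be used.
#include <string>
#include "paddle/framework/operator.h"  // declares kEmptyVarName, kGradVarSuffix, GradVarName (above)
void GradNameExamples() {
  namespace f = paddle::framework;
  // std::string bad = "X" + f::kGradVarSuffix;   // ill-formed: cannot add two pointers
  std::string g1 = f::GradVarName("X");                                 // "X@GRAD"
  std::string g2 = std::string(f::kEmptyVarName) + f::kGradVarSuffix;   // "@EMPTY@@GRAD"
}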
......@@ -103,6 +105,35 @@ class OperatorBase {
//! TODO add a vector_view to prevent memory copy.
const std::vector<std::string>& Outputs(const std::string& name) const;
virtual std::vector<std::string> OutputVars(bool has_intermediate) const {
std::vector<std::string> ret_val;
if (has_intermediate) {
// push all outputs into ret_val
for (auto& o : outputs_) {
ret_val.reserve(ret_val.size() + o.second.size());
ret_val.insert(ret_val.end(), o.second.begin(), o.second.end());
}
return ret_val;
}
auto it = OpProtos().find(type_);
PADDLE_ENFORCE(
it != OpProtos().end(),
"Operator %s not registered, cannot figure out intermediate outputs",
type_);
// get all OpProto::Var for outputs
for (auto& o : it->second.outputs()) {
// ignore all intermediate output
if (o.intermediate()) continue;
auto out = outputs_.find(o.name());
if (out != outputs_.end()) {
ret_val.reserve(ret_val.size() + out->second.size());
ret_val.insert(ret_val.end(), out->second.begin(), out->second.end());
}
}
return ret_val;
}
public:
std::string type_;
// NOTE: in case of OpGrad, inputs_ contains:
......@@ -117,10 +148,10 @@ class OperatorBase {
AttributeMap attrs_;
};
class OperatorContext {
class InferShapeContext {
public:
OperatorContext(const OperatorBase* op, const Scope& scope)
: op_(*op), scope_(scope) {}
InferShapeContext(const OperatorBase& op, const Scope& scope)
: op_(op), scope_(scope) {}
size_t InputSize(const std::string& name) const {
return op_.inputs_.at(name).size();
......@@ -209,12 +240,6 @@ class OperatorContext {
const Scope& scope_;
};
class InferShapeContext : public OperatorContext {
public:
InferShapeContext(const OperatorBase* op, const Scope& scope)
: OperatorContext(op, scope) {}
};
template <typename T>
struct EigenDeviceConverter;
......@@ -230,11 +255,11 @@ struct EigenDeviceConverter<platform::GPUPlace> {
};
#endif
class ExecutionContext : public OperatorContext {
class ExecutionContext : public InferShapeContext {
public:
ExecutionContext(const OperatorBase* op, const Scope& scope,
ExecutionContext(const OperatorBase& op, const Scope& scope,
const platform::DeviceContext* device_context)
: OperatorContext(op, scope), device_context_(device_context) {}
: InferShapeContext(op, scope), device_context_(device_context) {}
template <typename PlaceType,
typename DeviceType =
......@@ -286,13 +311,13 @@ class OperatorWithKernel : public OperatorBase {
std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>;
void InferShape(const Scope& scope) const override {
InferShape(InferShapeContext(this, scope));
InferShape(InferShapeContext(*this, scope));
}
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const final {
auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx));
opKernel->Compute(ExecutionContext(this, scope, &dev_ctx));
opKernel->Compute(ExecutionContext(*this, scope, &dev_ctx));
}
static std::unordered_map<std::string /* op_type */, OpKernelMap>&
......
......@@ -61,12 +61,12 @@ TEST(OperatorBase, all) {
paddle::framework::OpDesc op_desc;
op_desc.set_type("test_operator");
auto* ipt = op_desc.mutable_inputs()->Add();
*ipt->mutable_var_names()->Add() = "IN1";
ipt->set_op_proto_name("input");
*ipt->mutable_arguments()->Add() = "IN1";
ipt->set_parameter("input");
auto* output = op_desc.mutable_outputs()->Add();
*output->mutable_var_names()->Add() = "OUT1";
output->set_op_proto_name("output");
*output->mutable_arguments()->Add() = "OUT1";
output->set_parameter("output");
auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale");
attr->set_type(paddle::framework::AttrType::FLOAT);
......@@ -184,12 +184,12 @@ TEST(OpKernel, all) {
paddle::framework::OpDesc op_desc;
op_desc.set_type("op_with_kernel");
auto* ipt = op_desc.mutable_inputs()->Add();
*ipt->mutable_var_names()->Add() = "IN1";
ipt->set_op_proto_name("x");
*ipt->mutable_arguments()->Add() = "IN1";
ipt->set_parameter("x");
auto* output = op_desc.mutable_outputs()->Add();
*output->mutable_var_names()->Add() = "OUT1";
output->set_op_proto_name("y");
*output->mutable_arguments()->Add() = "OUT1";
output->set_parameter("y");
auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale");
......@@ -217,17 +217,17 @@ TEST(OpKernel, multi_inputs) {
OpDesc op_desc;
op_desc.set_type("op_multi_inputs_with_kernel");
auto x = op_desc.mutable_inputs()->Add();
x->set_op_proto_name("xs");
*x->mutable_var_names()->Add() = "x0";
*x->mutable_var_names()->Add() = "x1";
*x->mutable_var_names()->Add() = "x2";
x->set_parameter("xs");
*x->mutable_arguments()->Add() = "x0";
*x->mutable_arguments()->Add() = "x1";
*x->mutable_arguments()->Add() = "x2";
auto k = op_desc.mutable_inputs()->Add();
k->set_op_proto_name("k");
*k->mutable_var_names()->Add() = "k0";
k->set_parameter("k");
*k->mutable_arguments()->Add() = "k0";
auto y = op_desc.mutable_outputs()->Add();
y->set_op_proto_name("ys");
*y->mutable_var_names()->Add() = "y0";
*y->mutable_var_names()->Add() = "y1";
y->set_parameter("ys");
*y->mutable_arguments()->Add() = "y0";
*y->mutable_arguments()->Add() = "y1";
auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale");
......
......@@ -22,6 +22,7 @@ limitations under the License. */
#include "paddle/operators/net_op.h"
#include "paddle/platform/enforce.h"
#include "paddle/platform/place.h"
#include "paddle/string/to_string.h"
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
......@@ -186,9 +187,13 @@ All parameter, weight, gradient are variables in Paddle.
});
// clang-format on
py::class_<paddle::platform::GPUPlace>(m, "GPUPlace").def(py::init<int>());
py::class_<platform::GPUPlace>(m, "GPUPlace")
.def(py::init<int>())
.def("__str__", string::to_string<const platform::GPUPlace &>);
py::class_<paddle::platform::CPUPlace>(m, "CPUPlace").def(py::init<>());
py::class_<paddle::platform::CPUPlace>(m, "CPUPlace")
.def(py::init<>())
.def("__str__", string::to_string<const platform::CPUPlace &>);
py::class_<OperatorBase, std::shared_ptr<OperatorBase>> operator_base(
m, "Operator");
......
......@@ -18,6 +18,8 @@ limitations under the License. */
#include <cstring>
#include <memory>
#include <typeindex>
#include <vector>
#include "paddle/framework/ddim.h"
#include "paddle/memory/memory.h"
#include "paddle/platform/device_context.h"
......
......@@ -19,7 +19,7 @@ TEST(Tensor, Dims) {
using namespace paddle::framework;
using namespace paddle::platform;
Tensor tt;
tt.Resize(make_ddim({2, 3, 4}));
tt.Resize({2, 3, 4});
DDim dims = tt.dims();
ASSERT_EQ(arity(dims), 3);
for (int i = 0; i < 3; ++i) {
......
......@@ -93,8 +93,8 @@ TEST(Arguments, Matrix) {
MatrixPtr matrix = Matrix::create(100, 200);
CheckBufferArg check = [=](const BufferArg& arg) {
EXPECT_EQ(arg.shape().ndims(), 2U);
EXPECT_EQ(arg.shape()[0], 100);
EXPECT_EQ(arg.shape()[1], 200);
EXPECT_EQ(arg.shape()[0], 100U);
EXPECT_EQ(arg.shape()[1], 200U);
EXPECT_EQ(arg.data(), matrix->getData());
EXPECT_EQ(arg.matrix<DEVICE_TYPE_CPU>().getHeight(), matrix->getHeight());
......@@ -112,8 +112,8 @@ TEST(Arguments, Matrix) {
TEST(Arguments, Vector) {
VectorPtr vector = Vector::create(100, false);
CheckBufferArg check = [=](const BufferArg& arg) {
EXPECT_EQ(arg.shape().ndims(), 1);
EXPECT_EQ(arg.shape()[0], 100);
EXPECT_EQ(arg.shape().ndims(), 1U);
EXPECT_EQ(arg.shape()[0], 100U);
EXPECT_EQ(arg.data(), vector->getData());
CpuVector inVector = arg.vector<real, DEVICE_TYPE_CPU>();
......@@ -131,9 +131,9 @@ TEST(Arguments, Vector) {
TEST(Arguments, CpuSparseMatrix) {
CpuSparseMatrix sparse(200, 300, 50);
CheckBufferArg check = [=](const BufferArg& arg) {
EXPECT_EQ(arg.shape().ndims(), 2);
EXPECT_EQ(arg.shape()[0], 200);
EXPECT_EQ(arg.shape()[1], 300);
EXPECT_EQ(arg.shape().ndims(), 2U);
EXPECT_EQ(arg.shape()[0], 200U);
EXPECT_EQ(arg.shape()[1], 300U);
EXPECT_EQ(arg.data(), sparse.getData());
// CHECK_EQ(arg.sparse().nnz(), 50);
// CHECK_EQ(arg.sparse().dataFormat(), SPARSE_CSR_FORMAT);
......@@ -152,10 +152,10 @@ TEST(Arguments, CpuSparseMatrix) {
TEST(Arguments, BufferArg) {
BufferArg arg(nullptr, VALUE_TYPE_FLOAT, {1, 2, 3});
CheckBufferArg check = [=](const BufferArg& arg) {
EXPECT_EQ(arg.shape().ndims(), 3);
EXPECT_EQ(arg.shape()[0], 1);
EXPECT_EQ(arg.shape()[1], 2);
EXPECT_EQ(arg.shape()[2], 3);
EXPECT_EQ(arg.shape().ndims(), 3U);
EXPECT_EQ(arg.shape()[0], 1U);
EXPECT_EQ(arg.shape()[1], 2U);
EXPECT_EQ(arg.shape()[2], 3U);
};
BufferArgs argments;
......
......@@ -44,7 +44,7 @@ TEST(TensorShape, GetAndSet) {
EXPECT_EQ(t.ndims(), 3U);
EXPECT_EQ(t.getElements(), 6U);
EXPECT_EQ(t[1], 2);
EXPECT_EQ(t[1], 2U);
t.setDim(1, 100);
EXPECT_EQ(t.getElements(), 300U);
EXPECT_EQ(t[1], 100U);
......
......@@ -96,7 +96,7 @@ void SubNestedSequenceLayer::calSelectedCols(
for (size_t i = 0; i < seqNum; ++i) {
for (size_t j = 0; j < beamSize; ++j) {
if (selectedIndices->getElement(i, j) == -1.) break;
int selSubSeqIdx = selectedIndices->getElement(i, j);
size_t selSubSeqIdx = selectedIndices->getElement(i, j);
CHECK_GT(inputSeqInfoVec_[i].size() - 1, selSubSeqIdx);
size_t subSeqLen = inputSeqInfoVec_[i][selSubSeqIdx + 1] -
......@@ -135,7 +135,7 @@ void SubNestedSequenceLayer::forward(PassType passType) {
CHECK(inputSeq.hasSubseq()) << "The first input of SubNestSequence layer "
<< "must be a nested sequence.";
const MatrixPtr selectedIndices = getInputValue(1);
CHECK_EQ(inputSeq.getNumSequences(), selectedIndices->getHeight());
CHECK_EQ(size_t(inputSeq.getNumSequences()), selectedIndices->getHeight());
if (dynamic_cast<GpuMatrix*>(selectedIndices.get())) {
/*
......
......@@ -88,7 +88,7 @@ void checkLayerOut(vector<vector<int>> groundTruth,
TEST(Layer, kmaxSeqScoreLayer) {
const size_t maxBeamSize = 100;
int beamSize = 1 + (rand() % maxBeamSize);
size_t beamSize = 1 + (rand() % maxBeamSize);
vector<int> seqStartPosition;
vector<int> subSeqStartPosition;
......
......@@ -45,10 +45,8 @@ cc_library(net_op SRCS net_op.cc DEPS op_registry)
cc_test(net_op_test SRCS net_op_test.cc DEPS net_op)
op_library(add_op SRCS add_op.cc add_op.cu)
cc_test(add_op_test SRCS add_op_test.cc DEPS add_op)
op_library(mean_op SRCS mean_op.cc mean_op.cu)
cc_test(mean_op_test SRCS mean_op_test.cc DEPS mean_op)
op_library(mul_op SRCS mul_op.cc mul_op.cu)
op_library(rowwise_add_op SRCS rowwise_add_op.cu rowwise_add_op.cc)
......@@ -59,7 +57,6 @@ op_library(cross_entropy_op SRCS cross_entropy_op.cc cross_entropy_op.cu)
op_library(fill_zeros_like_op SRCS fill_zeros_like_op.cc fill_zeros_like_op.cu)
op_library(sgd_op SRCS sgd_op.cc sgd_op.cu)
cc_test(sgd_op_test SRCS sgd_op_test.cc DEPS sgd_op)
op_library(fc_op
SRCS fc_op.cc
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#define private public
#include "paddle/framework/op_registry.h"
USE_OP(add_two);
TEST(AddOp, GetOpProto) {
auto& protos = paddle::framework::OpRegistry::protos();
auto it = protos.find("add_two");
ASSERT_NE(it, protos.end());
auto& op_creators = paddle::framework::OpRegistry::op_creators();
auto it1 = op_creators.find("add_two_grad");
ASSERT_NE(it1, op_creators.end());
}
......@@ -39,7 +39,7 @@ class MeanOpMaker : public framework::OpProtoAndCheckerMaker {
class MeanGradOp : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
ctx.Output<Tensor>("X" + framework::kGradVarSuffix)
ctx.Output<Tensor>(framework::GradVarName("X"))
->Resize(ctx.Input<Tensor>("X")->dims());
}
};
......
......@@ -48,10 +48,10 @@ template <typename Place, typename T>
class MeanGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto OG = context.Input<Tensor>("Out" + framework::kGradVarSuffix);
auto OG = context.Input<Tensor>(framework::GradVarName("Out"));
PADDLE_ENFORCE(framework::product(OG->dims()) == 1,
"Mean Gradient should be scalar");
auto IG = context.Output<Tensor>("X" + framework::kGradVarSuffix);
auto IG = context.Output<Tensor>(framework::GradVarName("X"));
IG->mutable_data<T>(context.GetPlace());
T ig_size = (T)framework::product(IG->dims());
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include <paddle/framework/op_registry.h>
USE_OP(mean);
TEST(MeanOp, GetOpProto) {
auto& protos = paddle::framework::OpRegistry::protos();
auto it = protos.find("mean");
ASSERT_NE(it, protos.end());
}
......@@ -21,19 +21,20 @@
namespace paddle {
namespace operators {
const char NetOp::kAll[] = "all";
void NetOp::CompleteAddOp(bool calc) {
add_op_done_ = true;
if (!calc) return;
std::set<std::string> input_set;
std::set<std::string> output_set;
std::set<std::string> temp_output;
for (auto& op : ops_) {
for (auto& ipt : op->inputs_) {
for (auto& var_name : ipt.second) {
if (!Contains(output_set, var_name)) { // Not other op's output
input_set.insert(var_name);
} else {
temp_output.insert(var_name);
intermediate_outputs_.insert(var_name);
}
}
}
......@@ -44,24 +45,12 @@ void NetOp::CompleteAddOp(bool calc) {
}
}
}
auto& inputs = inputs_["all"];
auto& inputs = inputs_[kAll];
inputs.reserve(input_set.size());
std::copy(input_set.begin(), input_set.end(), std::back_inserter(inputs));
auto& outputs = outputs_["all"];
auto& outputs = outputs_[kAll];
outputs.reserve(output_set.size());
std::copy(output_set.begin(), output_set.end(), std::back_inserter(outputs));
//! TODO figure out how to generate temporary_index in Network.
std::vector<int> tmp_index;
tmp_index.reserve(temp_output.size());
int output_len = static_cast<int>(outputs.size());
for (int i = 0; i < output_len; ++i) {
if (Contains(temp_output, outputs[i])) {
tmp_index.push_back(i);
}
}
attrs_["temporary_index"] = tmp_index;
}
std::string NetOp::DebugString() const {
......@@ -78,5 +67,19 @@ std::string NetOp::DebugString() const {
bool NetOp::IsNetOp() const { return true; }
std::vector<std::string> NetOp::OutputVars(bool has_intermediate) const {
if (has_intermediate) {
return this->outputs_.at(kAll);
}
auto& all = this->outputs_.at(kAll);
std::vector<std::string> ret_val;
for (auto& each : all) {
if (!Contains(intermediate_outputs_, each)) {
ret_val.push_back(each);
}
}
return ret_val;
}
} // namespace operators
} // namespace paddle
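A minimal sketch of the bookkeeping CompleteAddOp performs above: a variable produced by one sub-op and consumed by a later one is recorded as an intermediate output, so OutputVars(false) reports only the net's final outputs. The SimpleOp struct and the two-op chain below are made-up stand-ins for illustration, not the real framework types.
#include <iostream>
#include <set>
#include <string>
#include <vector>
// Hypothetical stand-in for an operator: just named inputs and outputs.
struct SimpleOp {
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};
int main() {
  // op1: x -> y, op2: y -> z. "y" is produced and then consumed inside the
  // net, so it is classified as an intermediate output.
  std::vector<SimpleOp> ops = {{{"x"}, {"y"}}, {{"y"}, {"z"}}};
  std::set<std::string> outputs, intermediates;
  for (const auto& op : ops) {
    for (const auto& in : op.inputs) {
      if (outputs.count(in)) intermediates.insert(in);  // another op's output
    }
    for (const auto& out : op.outputs) outputs.insert(out);
  }
  // Equivalent of OutputVars(false): skip intermediate outputs.
  for (const auto& out : outputs) {
    if (!intermediates.count(out)) std::cout << out << "\n";  // prints "z"
  }
  return 0;
}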
......@@ -36,6 +36,8 @@ namespace operators {
*/
class NetOp : public framework::OperatorBase {
public:
static const char kAll[];
/**
* Infer all the operators' input and output variables' shapes; this will be
* called before every mini-batch.
......@@ -91,11 +93,13 @@ class NetOp : public framework::OperatorBase {
std::string DebugString() const override;
bool IsNetOp() const override;
std::vector<std::string> OutputVars(bool has_intermediate) const override;
std::vector<std::shared_ptr<OperatorBase>> ops_;
private:
bool add_op_done_{false};
std::set<std::string> intermediate_outputs_;
template <typename T, typename KeyType>
static bool Contains(T container, KeyType key) {
......
......@@ -54,22 +54,13 @@ TEST(OpKernel, all) {
net->CompleteAddOp();
AssertSameVectorWithoutOrder({"x", "w1", "b1", "w2", "b2"},
net->inputs_.at("__all__"));
AssertSameVectorWithoutOrder({"y", "z"}, net->outputs_.at("__all__"));
auto tmp_idx_iter = net->attrs_.find("temporary_index");
ASSERT_NE(net->attrs_.end(), tmp_idx_iter);
auto& tmp_idx = boost::get<std::vector<int>>(tmp_idx_iter->second);
ASSERT_EQ(1UL, tmp_idx.size());
ASSERT_EQ("y", net->outputs_.at("__all__")[tmp_idx[0]]);
net->inputs_.at(NetOp::kAll));
AssertSameVectorWithoutOrder({"y", "z"}, net->outputs_.at(NetOp::kAll));
Scope scope;
platform::CPUDeviceContext dev_ctx;
auto final_outs = net->OutputVars(false);
net->InferShape(scope);
net->Run(scope, dev_ctx);
ASSERT_EQ(2, infer_shape_cnt);
ASSERT_EQ(2, run_cnt);
ASSERT_THROW(net->AddOp(op2), platform::EnforceNotMet);
ASSERT_EQ(final_outs.size(), 1UL);
ASSERT_EQ(final_outs[0], "z");
}
TEST(NetOp, insert_op) {
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include <paddle/framework/op_registry.h>
USE_OP(sgd);
TEST(SGDOp, GetOpProto) {
auto& protos = paddle::framework::OpRegistry::protos();
auto it = protos.find("sgd");
ASSERT_NE(it, protos.end());
}
......@@ -8,7 +8,7 @@ cc_test(place_test SRCS place_test.cc DEPS place glog gflags)
add_subdirectory(dynload)
cc_test(enforce_test SRCS enforce_test.cc)
cc_test(enforce_test SRCS enforce_test.cc DEPS stringpiece)
IF(WITH_GPU)
set(GPU_CTX_DEPS dynload_cuda dynamic_loader)
......
......@@ -15,11 +15,12 @@ limitations under the License. */
#pragma once
#include <execinfo.h>
#include <paddle/string/printf.h>
#include <iomanip>
#include <sstream>
#include <stdexcept>
#include <string>
#include "paddle/string/printf.h"
#include "paddle/string/to_string.h"
#ifndef PADDLE_ONLY_CPU
......@@ -191,27 +192,11 @@ inline void throw_on_error(T e) {
PADDLE_ENFORCE(nullptr != (__VAL), #__VAL " should not be null\n%s", \
paddle::string::Sprintf("" __VA_ARGS__));
template <typename T>
inline std::string enforce_to_string(const T& val) {
std::ostringstream sout;
sout << val;
return sout.str();
}
template <>
inline std::string enforce_to_string(const std::string& val) {
return val;
}
template <>
inline std::string enforce_to_string(const char* const& val) {
return std::string(val);
}
#define __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, __CMP, __INV_CMP, ...) \
PADDLE_ENFORCE(__VAL0 __CMP __VAL1, \
"enforce %s " #__CMP " %s failed, %s " #__INV_CMP " %s\n%s", \
#__VAL0, #__VAL1, \
paddle::platform::enforce_to_string(__VAL0), \
paddle::platform::enforce_to_string(__VAL1), \
#__VAL0, #__VAL1, paddle::string::to_string(__VAL0), \
paddle::string::to_string(__VAL1), \
paddle::string::Sprintf("" __VA_ARGS__));
} // namespace platform
......
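For reference, the message a failing PADDLE_ENFORCE_EQ(a, 1 + 3) now produces can be approximated outside the framework as below; the local to_string shim is a stand-in for paddle/string/to_string.h, and the string layout simply mirrors the macro's "enforce %s == %s failed, %s != %s" pattern.
#include <iostream>
#include <sstream>
#include <string>
// Stand-in for paddle::string::to_string, for illustration only.
template <typename T>
std::string to_string(const T& v) {
  std::ostringstream sout;
  sout << v;
  return sout.str();
}
int main() {
  int a = 2;
  // What __PADDLE_BINARY_COMPARE assembles for PADDLE_ENFORCE_EQ(a, 1 + 3):
  std::string msg =
      "enforce a == 1 + 3 failed, " + to_string(a) + " != " + to_string(1 + 3);
  std::cout << msg << "\n";  // enforce a == 1 + 3 failed, 2 != 4
  return 0;
}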
......@@ -9,10 +9,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <array>
#include <iostream>
#include <memory>
#include "gtest/gtest.h"
#include "paddle/platform/enforce.h"
#include "paddle/string/piece.h"
using StringPiece = paddle::string::Piece;
using paddle::string::HasPrefix;
TEST(ENFORCE, OK) {
PADDLE_ENFORCE(true, "Enforce is ok %d now %f", 123, 0.345);
......@@ -22,19 +28,15 @@ TEST(ENFORCE, OK) {
}
TEST(ENFORCE, FAILED) {
bool in_catch = false;
bool caught_exception = false;
try {
PADDLE_ENFORCE(false, "Enforce is not ok %d at all", 123);
} catch (paddle::platform::EnforceNotMet error) {
// your error handling code here
in_catch = true;
std::string msg = "Enforce is not ok 123 at all";
const char* what = error.what();
for (size_t i = 0; i < msg.length(); ++i) {
ASSERT_EQ(what[i], msg[i]);
}
caught_exception = true;
EXPECT_TRUE(
HasPrefix(StringPiece(error.what()), "Enforce is not ok 123 at all"));
}
ASSERT_TRUE(in_catch);
EXPECT_TRUE(caught_exception);
}
TEST(ENFORCE, NO_ARG_OK) {
......@@ -47,41 +49,27 @@ TEST(ENFORCE, NO_ARG_OK) {
TEST(ENFORCE_EQ, NO_EXTRA_MSG_FAIL) {
int a = 2;
bool in_catch = false;
bool caught_exception = false;
try {
PADDLE_ENFORCE_EQ(a, 1 + 3);
} catch (paddle::platform::EnforceNotMet error) {
in_catch = true;
const std::string msg = "enforce a == 1 + 3 failed, 2 != 4";
const char* what = error.what();
for (size_t i = 0; i < msg.length(); ++i) {
ASSERT_EQ(what[i], msg[i]);
}
caught_exception = true;
EXPECT_TRUE(
HasPrefix(StringPiece(error.what()), "enforce a == 1 + 3 failed, 2 != 4"));
}
ASSERT_TRUE(in_catch);
EXPECT_TRUE(caught_exception);
}
TEST(ENFORCE_EQ, EXTRA_MSG_FAIL) {
int a = 2;
bool in_catch = false;
bool caught_exception = false;
try {
PADDLE_ENFORCE_EQ(a, 1 + 3, "%s size not match", "their");
} catch (paddle::platform::EnforceNotMet error) {
in_catch = true;
const std::string msg =
"enforce a == 1 + 3 failed, 2 != 4\ntheir size not match";
const char* what = error.what();
for (size_t i = 0; i < msg.length(); ++i) {
ASSERT_EQ(what[i], msg[i]);
}
caught_exception = true;
EXPECT_TRUE(HasPrefix(StringPiece(error.what()),
"enforce a == 1 + 3 failed, 2 != 4\ntheir size not match"));
}
ASSERT_TRUE(in_catch);
EXPECT_TRUE(caught_exception);
}
TEST(ENFORCE_NE, OK) {
......@@ -89,42 +77,32 @@ TEST(ENFORCE_NE, OK) {
PADDLE_ENFORCE_NE(1.0, 2UL);
}
TEST(ENFORCE_NE, FAIL) {
bool in_catch = false;
bool caught_exception = false;
try {
// 1UL here to check data type compatibility
PADDLE_ENFORCE_NE(1.0, 1UL);
} catch (paddle::platform::EnforceNotMet error) {
in_catch = true;
const std::string msg = "enforce 1.0 != 1UL failed, 1.000000 == 1";
const char* what = error.what();
for (size_t i = 0; i < msg.length(); ++i) {
ASSERT_EQ(what[i], msg[i]);
}
caught_exception = true;
EXPECT_TRUE(HasPrefix(StringPiece(error.what()),
"enforce 1.0 != 1UL failed, 1 == 1"))
<< error.what() << " does not have expected prefix";
}
ASSERT_TRUE(in_catch);
EXPECT_TRUE(caught_exception);
}
TEST(ENFORCE_GT, OK) { PADDLE_ENFORCE_GT(2, 1); }
TEST(ENFORCE_GT, FAIL) {
bool in_catch = false;
bool caught_exception = false;
try {
// 2UL here to check data type compatibility
PADDLE_ENFORCE_GT(1, 2UL);
} catch (paddle::platform::EnforceNotMet error) {
in_catch = true;
const std::string msg = "enforce 1 > 2UL failed, 1 <= 2";
const char* what = error.what();
for (size_t i = 0; i < msg.length(); ++i) {
ASSERT_EQ(what[i], msg[i]);
}
caught_exception = true;
EXPECT_TRUE(
HasPrefix(StringPiece(error.what()), "enforce 1 > 2UL failed, 1 <= 2"));
}
ASSERT_TRUE(in_catch);
EXPECT_TRUE(caught_exception);
}
TEST(ENFORCE_GE, OK) {
......@@ -134,21 +112,16 @@ TEST(ENFORCE_GE, OK) {
PADDLE_ENFORCE_GE(3.21, 2UL);
}
TEST(ENFORCE_GE, FAIL) {
bool in_catch = false;
bool caught_exception = false;
try {
PADDLE_ENFORCE_GE(1, 2UL);
} catch (paddle::platform::EnforceNotMet error) {
in_catch = true;
const std::string msg = "enforce 1 >= 2UL failed, 1 < 2";
const char* what = error.what();
for (size_t i = 0; i < msg.length(); ++i) {
ASSERT_EQ(what[i], msg[i]);
}
caught_exception = true;
EXPECT_TRUE(
HasPrefix(StringPiece(error.what()), "enforce 1 >= 2UL failed, 1 < 2"));
}
ASSERT_TRUE(in_catch);
EXPECT_TRUE(caught_exception);
}
TEST(ENFORCE_LE, OK) {
......@@ -159,21 +132,16 @@ TEST(ENFORCE_LE, OK) {
PADDLE_ENFORCE_LE(2UL, 3.2);
}
TEST(ENFORCE_LE, FAIL) {
bool in_catch = false;
bool caught_exception = false;
try {
PADDLE_ENFORCE_GT(1, 2UL);
} catch (paddle::platform::EnforceNotMet error) {
in_catch = true;
const std::string msg = "enforce 1 > 2UL failed, 1 <= 2";
const char* what = error.what();
for (size_t i = 0; i < msg.length(); ++i) {
ASSERT_EQ(what[i], msg[i]);
}
caught_exception = true;
EXPECT_TRUE(
HasPrefix(StringPiece(error.what()), "enforce 1 > 2UL failed, 1 <= 2"));
}
ASSERT_TRUE(in_catch);
EXPECT_TRUE(caught_exception);
}
TEST(ENFORCE_LT, OK) {
......@@ -182,21 +150,15 @@ TEST(ENFORCE_LT, OK) {
PADDLE_ENFORCE_LT(2UL, 3);
}
TEST(ENFORCE_LT, FAIL) {
bool in_catch = false;
bool caught_exception = false;
try {
PADDLE_ENFORCE_LT(1UL, 0.12);
} catch (paddle::platform::EnforceNotMet error) {
in_catch = true;
const std::string msg = "enforce 1UL < 0.12 failed, 1 >= 0.12";
const char* what = error.what();
for (size_t i = 0; i < msg.length(); ++i) {
ASSERT_EQ(what[i], msg[i]);
}
caught_exception = true;
EXPECT_TRUE(HasPrefix(StringPiece(error.what()),
"enforce 1UL < 0.12 failed, 1 >= 0.12"));
}
ASSERT_TRUE(in_catch);
EXPECT_TRUE(caught_exception);
}
TEST(ENFORCE_NOT_NULL, OK) {
......@@ -205,20 +167,50 @@ TEST(ENFORCE_NOT_NULL, OK) {
delete a;
}
TEST(ENFORCE_NOT_NULL, FAIL) {
bool in_catch = false;
int* a{nullptr};
bool caught_exception = false;
try {
int* a = nullptr;
PADDLE_ENFORCE_NOT_NULL(a);
} catch (paddle::platform::EnforceNotMet error) {
in_catch = true;
const std::string msg = "a should not be null";
const char* what = error.what();
for (size_t i = 0; i < msg.length(); ++i) {
ASSERT_EQ(what[i], msg[i]);
caught_exception = true;
EXPECT_TRUE(HasPrefix(StringPiece(error.what()), "a should not be null"));
}
EXPECT_TRUE(caught_exception);
}
struct Dims {
size_t dims_[4];
bool operator==(const Dims& o) const {
for (size_t i = 0; i < 4; ++i) {
if (dims_[i] != o.dims_[i]) return false;
}
return true;
}
};
ASSERT_TRUE(in_catch);
std::ostream& operator<<(std::ostream& os, const Dims& d) {
for (size_t i = 0; i < 4; ++i) {
if (i == 0) {
os << "[";
}
os << d.dims_[i];
if (i == 4 - 1) {
os << "]";
} else {
os << ", ";
}
}
return os;
}
TEST(ENFORCE_USER_DEFINED_CLASS, EQ) {
Dims a{{1, 2, 3, 4}}, b{{1, 2, 3, 4}};
PADDLE_ENFORCE_EQ(a, b);
}
TEST(ENFORCE_USER_DEFINED_CLASS, NE) {
Dims a{{1, 2, 3, 4}}, b{{5, 6, 7, 8}};
ASSERT_THROW(PADDLE_ENFORCE_EQ(a, b), paddle::platform::EnforceNotMet);
}
\ No newline at end of file
......@@ -2,3 +2,4 @@ cc_library(stringpiece SRCS piece.cc)
cc_test(stringpiece_test SRCS piece_test.cc DEPS stringpiece glog gflags)
cc_test(stringprintf_test SRCS printf_test.cc DEPS glog gflags)
cc_test(to_string_test SRCS to_string_test.cc)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <sstream>
#include <string>
namespace paddle {
namespace string {
template <typename T>
inline std::string to_string(T v) {
std::ostringstream sout;
sout << v;
return sout.str();
}
// Faster specializations for std::string and const char*.
template <>
inline std::string to_string(std::string v) {
return v;
}
template <>
inline std::string to_string(const char* v) {
return std::string(v);
}
} // namespace string
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/string/to_string.h"
#include <gtest/gtest.h>
constexpr char kOutputString[] = "User Defined Output";
class UserDefinedClass {
public:
};
std::ostream& operator<<(std::ostream& s, const UserDefinedClass& ins) {
s << kOutputString;
return s;
}
TEST(to_string, normal) {
using namespace paddle::string;
ASSERT_EQ("10", to_string(10));
ASSERT_EQ("abc", to_string("abc"));
ASSERT_EQ("1.2", to_string(1.2));
}
TEST(to_string, user_defined) {
using namespace paddle::string;
UserDefinedClass instance;
ASSERT_EQ(kOutputString, to_string(instance));
}
\ No newline at end of file
......@@ -50,8 +50,8 @@ void NewRemoteParameterUpdater::init(
// create parameter server client.
if (useEtcd_) {
parameterClient_ = paddle_new_etcd_pserver_client(
(char *)pserverSpec_.c_str(), FLAGS_trainer_id == 0);
parameterClient_ =
paddle_new_etcd_pserver_client((char *)pserverSpec_.c_str());
} else {
parameterClient_ = paddle_new_pserver_client((char *)pserverSpec_.c_str(),
FLAGS_trainer_id == 0);
......
......@@ -92,15 +92,27 @@ def get_numeric_gradient(op,
class GradientChecker(unittest.TestCase):
def __is_close(self, numeric_grads, scope, max_relative_error):
def assert_is_close(self, numeric_grads, scope, max_relative_error,
msg_prefix):
for name in numeric_grads:
op_grad = numpy.array(
scope.find_var(grad_var_name(name)).get_tensor())
is_close = numpy.allclose(
numeric_grads[name], op_grad, rtol=max_relative_error, atol=100)
if not is_close:
return False
return True
b = numpy.array(scope.find_var(grad_var_name(name)).get_tensor())
a = numeric_grads[name]
abs_a = numpy.abs(a)
# if abs_a is nearly zero, then use abs error for a, not relative
# error.
abs_a[abs_a < 1e-3] = 1
diff_mat = numpy.abs(a - b) / abs_a
max_diff = numpy.max(diff_mat)
def err_msg():
offset = numpy.argmax(diff_mat > max_relative_error)
return "%s Variable %s max gradient diff %f over limit %f, the first " \
"error element is %d" % (
msg_prefix, name, max_diff, max_relative_error, offset)
self.assertLessEqual(max_diff, max_relative_error, err_msg())
def check_grad(self,
forward_op,
......@@ -145,7 +157,8 @@ class GradientChecker(unittest.TestCase):
# get numeric gradient
for check_name in inputs_to_check:
numeric_grad[check_name] = \
get_numeric_gradient(forward_op, input_vars, output_name, check_name)
get_numeric_gradient(forward_op, input_vars, output_name,
check_name)
# get operator gradient according to different device
for place in places:
......@@ -187,15 +200,8 @@ class GradientChecker(unittest.TestCase):
backward_op.infer_shape(scope)
backward_op.run(scope, ctx)
if isinstance(place, core.CPUPlace):
msg = "CPU kernel gradient is not close to numeric gradient"
else:
if isinstance(place, core.GPUPlace):
msg = "GPU kernel gradient is not close to numeric gradient"
else:
raise ValueError("unknown place " + type(place))
self.assertTrue(
self.__is_close(numeric_grad, scope, max_relative_error), msg)
self.assert_is_close(numeric_grad, scope, max_relative_error,
"Gradient Check On %s" % str(place))
if __name__ == '__main__':
......
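The tolerance rule assert_is_close applies above reduces to simple arithmetic: compare each element with relative error, but fall back to absolute error when the reference (numeric) gradient is nearly zero. The sketch below restates that rule in C++ with made-up gradient values and a hypothetical 0.06 threshold; it is an illustration, not the checker's NumPy code.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>
int main() {
  std::vector<double> numeric = {0.5, 1e-5, 2.0};    // numeric gradient (reference)
  std::vector<double> op_grad = {0.501, 2e-5, 2.1};  // operator gradient
  const double max_relative_error = 0.06;            // hypothetical threshold
  double max_diff = 0.0;
  for (std::size_t i = 0; i < numeric.size(); ++i) {
    double denom = std::abs(numeric[i]);
    if (denom < 1e-3) denom = 1.0;  // near-zero reference: use absolute error
    max_diff = std::max(max_diff, std::abs(numeric[i] - op_grad[i]) / denom);
  }
  std::printf("max diff %f: %s\n", max_diff,
              max_diff <= max_relative_error ? "pass" : "fail");
  return 0;
}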