Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
6dc567a5
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6dc567a5
编写于
7月 17, 2017
作者:
Q
qijun
浏览文件
操作
浏览文件
下载
差异文件
merge baidu/develop
上级
5017b154
a4eaf2d3
变更
31
显示空白变更内容
内联
并排
Showing
31 changed file
with
428 addition
and
154 deletion
+428
-154
Dockerfile.android
Dockerfile.android
+11
-0
cmake/cross_compiling/android.cmake
cmake/cross_compiling/android.cmake
+8
-3
go/cmd/master/master.go
go/cmd/master/master.go
+11
-3
go/cmd/pserver/pserver.go
go/cmd/pserver/pserver.go
+1
-1
go/master/client.go
go/master/client.go
+3
-2
go/master/service.go
go/master/service.go
+1
-0
go/pserver/client/c/test/test_train.py
go/pserver/client/c/test/test_train.py
+23
-5
go/pserver/client/etcd_client.go
go/pserver/client/etcd_client.go
+3
-2
go/pserver/etcd_client.go
go/pserver/etcd_client.go
+6
-5
paddle/api/PaddleAPI.h
paddle/api/PaddleAPI.h
+2
-1
paddle/api/ParameterUpdater.cpp
paddle/api/ParameterUpdater.cpp
+3
-2
paddle/framework/CMakeLists.txt
paddle/framework/CMakeLists.txt
+7
-4
paddle/framework/ddim.cc
paddle/framework/ddim.cc
+49
-5
paddle/framework/ddim.h
paddle/framework/ddim.h
+9
-0
paddle/framework/ddim_test.cc
paddle/framework/ddim_test.cc
+20
-0
paddle/framework/dim_test.cu
paddle/framework/dim_test.cu
+82
-81
paddle/framework/enforce.cc
paddle/framework/enforce.cc
+15
-0
paddle/framework/enforce.h
paddle/framework/enforce.h
+6
-0
paddle/framework/op_registry.h
paddle/framework/op_registry.h
+23
-8
paddle/framework/op_registry_test.cc
paddle/framework/op_registry_test.cc
+34
-2
paddle/framework/tensor.cc
paddle/framework/tensor.cc
+19
-0
paddle/operators/CMakeLists.txt
paddle/operators/CMakeLists.txt
+43
-5
paddle/scripts/docker/build.sh
paddle/scripts/docker/build.sh
+2
-1
paddle/scripts/docker/build_android.sh
paddle/scripts/docker/build_android.sh
+3
-6
paddle/trainer/NewRemoteParameterUpdater.cpp
paddle/trainer/NewRemoteParameterUpdater.cpp
+18
-2
paddle/trainer/NewRemoteParameterUpdater.h
paddle/trainer/NewRemoteParameterUpdater.h
+5
-0
python/paddle/v2/dataset/common.py
python/paddle/v2/dataset/common.py
+6
-4
python/paddle/v2/dataset/mq2007.py
python/paddle/v2/dataset/mq2007.py
+3
-3
python/paddle/v2/master/client.py
python/paddle/v2/master/client.py
+3
-2
python/paddle/v2/optimizer.py
python/paddle/v2/optimizer.py
+5
-5
python/paddle/v2/trainer.py
python/paddle/v2/trainer.py
+4
-2
未找到文件。
Dockerfile.android
浏览文件 @
6dc567a5
...
@@ -14,6 +14,17 @@ RUN apt-get update && \
...
@@ -14,6 +14,17 @@ RUN apt-get update && \
wget curl tar unzip gcc g++ locales clang-format-3.8 swig cmake && \
wget curl tar unzip gcc g++ locales clang-format-3.8 swig cmake && \
apt-get clean -y
apt-get clean -y
# Install Go and glide
RUN wget -O go.tgz https://storage.googleapis.com/golang/go1.8.1.linux-amd64.tar.gz && \
tar -C /usr/local -xzf go.tgz && \
mkdir /root/gopath && \
mkdir /root/gopath/bin && \
mkdir /root/gopath/src && \
rm go.tgz
ENV GOROOT=/usr/local/go GOPATH=/root/gopath
# should not be in the same line with GOROOT definition, otherwise docker build could not find GOROOT.
ENV PATH=${PATH}:${GOROOT}/bin:${GOPATH}/bin
# git credential to skip password typing
# git credential to skip password typing
RUN git config --global credential.helper store
RUN git config --global credential.helper store
...
...
cmake/cross_compiling/android.cmake
浏览文件 @
6dc567a5
...
@@ -108,6 +108,7 @@ IF("${CMAKE_VERSION}" VERSION_LESS "3.7.0")
...
@@ -108,6 +108,7 @@ IF("${CMAKE_VERSION}" VERSION_LESS "3.7.0")
ENDIF
()
ENDIF
()
IF
(
ANDROID_ABI STREQUAL
"arm64-v8a"
)
IF
(
ANDROID_ABI STREQUAL
"arm64-v8a"
)
SET
(
ANDROID_TOOLCHAIN_NAME aarch64-linux-android
)
SET
(
ANDROID_TOOLCHAIN_NAME aarch64-linux-android
)
SET
(
CMAKE_SYSTEM_PROCESSOR aarch64
)
ENDIF
()
ENDIF
()
SET
(
ANDROID_TOOLCHAIN_PREFIX
"
${
ANDROID_TOOLCHAIN_ROOT
}
/bin/
${
ANDROID_TOOLCHAIN_NAME
}
-"
)
SET
(
ANDROID_TOOLCHAIN_PREFIX
"
${
ANDROID_TOOLCHAIN_ROOT
}
/bin/
${
ANDROID_TOOLCHAIN_NAME
}
-"
)
ENDIF
()
ENDIF
()
...
@@ -193,6 +194,10 @@ ELSE()
...
@@ -193,6 +194,10 @@ ELSE()
SET
(
CMAKE_ANDROID_STANDALONE_TOOLCHAIN
${
ANDROID_STANDALONE_TOOLCHAIN
}
)
SET
(
CMAKE_ANDROID_STANDALONE_TOOLCHAIN
${
ANDROID_STANDALONE_TOOLCHAIN
}
)
ENDIF
()
ENDIF
()
SET
(
CMAKE_ANDROID_ARCH_ABI
${
ANDROID_ABI
}
)
SET
(
CMAKE_ANDROID_ARCH_ABI
${
ANDROID_ABI
}
)
IF
(
ANDROID_ABI MATCHES
"^armeabi(-v7a)?$"
)
SET
(
CMAKE_ANDROID_ARM_MODE
${
ANDROID_ARM_MODE
}
)
SET
(
CMAKE_ANDROID_ARM_MODE
${
ANDROID_ARM_MODE
}
)
IF
(
ANDROID_ABI STREQUAL
"armeabi-v7a"
)
SET
(
CMAKE_ANDROID_ARM_NEON
${
ANDROID_ARM_NEON
}
)
SET
(
CMAKE_ANDROID_ARM_NEON
${
ANDROID_ARM_NEON
}
)
ENDIF
()
ENDIF
()
ENDIF
()
ENDIF
()
go/cmd/master/master.go
浏览文件 @
6dc567a5
...
@@ -11,6 +11,7 @@ import (
...
@@ -11,6 +11,7 @@ import (
"github.com/namsral/flag"
"github.com/namsral/flag"
log
"github.com/sirupsen/logrus"
log
"github.com/sirupsen/logrus"
"github.com/topicai/candy"
"github.com/PaddlePaddle/Paddle/go/master"
"github.com/PaddlePaddle/Paddle/go/master"
"github.com/PaddlePaddle/Paddle/go/utils/networkhelper"
"github.com/PaddlePaddle/Paddle/go/utils/networkhelper"
...
@@ -20,11 +21,18 @@ func main() {
...
@@ -20,11 +21,18 @@ func main() {
port
:=
flag
.
Int
(
"port"
,
8080
,
"port of the master server."
)
port
:=
flag
.
Int
(
"port"
,
8080
,
"port of the master server."
)
ttlSec
:=
flag
.
Int
(
"ttl"
,
60
,
"etcd lease TTL in seconds."
)
ttlSec
:=
flag
.
Int
(
"ttl"
,
60
,
"etcd lease TTL in seconds."
)
endpoints
:=
flag
.
String
(
"endpoints"
,
"http://127.0.0.1:2379"
,
"comma separated etcd endpoints. If empty, fault tolerance will not be enabled."
)
endpoints
:=
flag
.
String
(
"endpoints"
,
"http://127.0.0.1:2379"
,
"comma separated etcd endpoints. If empty, fault tolerance will not be enabled."
)
taskTimeoutDur
:=
flag
.
Duration
(
"task_timout_dur"
,
20
*
time
.
Minute
,
"task timout duration."
)
taskTimeoutDur
:=
flag
.
Duration
(
"task-timout-dur"
,
20
*
time
.
Minute
,
"task timout duration."
)
taskTimeoutMax
:=
flag
.
Int
(
"task_timeout_max"
,
3
,
"max timtout count for each task before it being declared failed task."
)
taskTimeoutMax
:=
flag
.
Int
(
"task-timeout-max"
,
3
,
"max timtout count for each task before it being declared failed task."
)
chunkPerTask
:=
flag
.
Int
(
"chunk_per_task"
,
10
,
"chunk per task."
)
chunkPerTask
:=
flag
.
Int
(
"chunk-per-task"
,
10
,
"chunk per task."
)
logLevel
:=
flag
.
String
(
"log-level"
,
"info"
,
"log level, possible values: debug, info, warning, error, fatal, panic"
)
flag
.
Parse
()
flag
.
Parse
()
level
,
e
:=
log
.
ParseLevel
(
*
logLevel
)
candy
.
Must
(
e
)
log
.
SetLevel
(
level
)
if
*
endpoints
==
""
{
if
*
endpoints
==
""
{
log
.
Warningln
(
"-endpoints not set, fault tolerance not be enabled."
)
log
.
Warningln
(
"-endpoints not set, fault tolerance not be enabled."
)
}
}
...
...
go/cmd/pserver/pserver.go
浏览文件 @
6dc567a5
...
@@ -40,7 +40,7 @@ func main() {
...
@@ -40,7 +40,7 @@ func main() {
idx
=
*
index
idx
=
*
index
}
else
{
}
else
{
e
=
pserver
.
NewEtcdClient
(
*
etcdEndpoint
,
*
numPservers
,
*
etcdTimeout
)
e
=
pserver
.
NewEtcdClient
(
*
etcdEndpoint
,
*
numPservers
,
*
etcdTimeout
)
idx
,
err
=
e
.
Register
()
idx
,
err
=
e
.
Register
(
*
port
)
candy
.
Must
(
err
)
candy
.
Must
(
err
)
cp
,
err
=
pserver
.
NewCheckpointFromFile
(
*
checkpointPath
,
idx
,
e
)
cp
,
err
=
pserver
.
NewCheckpointFromFile
(
*
checkpointPath
,
idx
,
e
)
...
...
go/master/client.go
浏览文件 @
6dc567a5
...
@@ -2,6 +2,7 @@ package master
...
@@ -2,6 +2,7 @@ package master
import
(
import
(
"os"
"os"
"time"
"github.com/PaddlePaddle/Paddle/go/connection"
"github.com/PaddlePaddle/Paddle/go/connection"
"github.com/PaddlePaddle/recordio"
"github.com/PaddlePaddle/recordio"
...
@@ -36,9 +37,9 @@ func (c *Client) getRecords() {
...
@@ -36,9 +37,9 @@ func (c *Client) getRecords() {
for
{
for
{
t
,
err
:=
c
.
getTask
()
t
,
err
:=
c
.
getTask
()
if
err
!=
nil
{
if
err
!=
nil
{
// TODO(helin): wait before move on with next
// getTask call.
// getTask call.
log
.
Errorln
(
err
)
log
.
Errorf
(
"Get task failed, sleep 3 seconds and continue, %s"
,
err
)
time
.
Sleep
(
3
*
time
.
Second
)
continue
continue
}
}
...
...
go/master/service.go
浏览文件 @
6dc567a5
...
@@ -215,6 +215,7 @@ func readChunks(globPaths []string) ([]Chunk, error) {
...
@@ -215,6 +215,7 @@ func readChunks(globPaths []string) ([]Chunk, error) {
}
}
count
:=
index
.
NumChunks
()
count
:=
index
.
NumChunks
()
log
.
Infof
(
"readChunks: file %s has %d chunks"
,
path
,
count
)
for
i
:=
0
;
i
<
count
;
i
++
{
for
i
:=
0
;
i
<
count
;
i
++
{
chunk
:=
Chunk
{
chunk
:=
Chunk
{
Path
:
path
,
Path
:
path
,
...
...
go/pserver/client/c/test/test_train.py
浏览文件 @
6dc567a5
import
paddle.v2
as
paddle
import
paddle.v2
as
paddle
import
paddle.v2.dataset.uci_housing
as
uci_housing
import
paddle.v2.dataset.uci_housing
as
uci_housing
import
paddle.v2.master
as
master
import
os
import
cPickle
as
pickle
etcd_ip
=
os
.
getenv
(
"MASTER_IP"
,
"127.0.0.1"
)
etcd_endpoint
=
"http://"
+
etcd_ip
+
":2379"
def
cloud_reader
():
print
"connecting to master, etcd endpoints: "
,
etcd_endpoint
master_client
=
master
.
client
(
etcd_endpoint
,
5
,
64
)
master_client
.
set_dataset
(
[
"/pfs/dlnel/public/dataset/uci_housing/uci_housing-*-of-*"
])
while
1
:
r
,
e
=
master_client
.
next_record
()
if
not
r
:
break
yield
pickle
.
loads
(
r
)
def
main
():
def
main
():
...
@@ -22,13 +40,13 @@ def main():
...
@@ -22,13 +40,13 @@ def main():
# create optimizer of new remote updater to pserver
# create optimizer of new remote updater to pserver
optimizer
=
paddle
.
optimizer
.
Momentum
(
momentum
=
0
)
optimizer
=
paddle
.
optimizer
.
Momentum
(
momentum
=
0
)
#TODO(zhihong) : replace optimizer with new OptimizerConfig
print
"etcd endoint: "
,
etcd_endpoint
trainer
=
paddle
.
trainer
.
SGD
(
cost
=
cost
,
trainer
=
paddle
.
trainer
.
SGD
(
cost
=
cost
,
parameters
=
parameters
,
parameters
=
parameters
,
update_equation
=
optimizer
,
update_equation
=
optimizer
,
is_local
=
False
,
is_local
=
False
,
pserver_spec
=
"localhost:3000"
)
pserver_spec
=
etcd_endpoint
,
use_etcd
=
True
)
# event_handler to print training and testing info
# event_handler to print training and testing info
def
event_handler
(
event
):
def
event_handler
(
event
):
...
@@ -47,11 +65,11 @@ def main():
...
@@ -47,11 +65,11 @@ def main():
print
"Test %d, %.2f"
%
(
event
.
pass_id
,
result
.
cost
)
print
"Test %d, %.2f"
%
(
event
.
pass_id
,
result
.
cost
)
# training
# training
# NOTE: use uci_housing.train() as reader for non-paddlecloud training
trainer
.
train
(
trainer
.
train
(
reader
=
paddle
.
batch
(
reader
=
paddle
.
batch
(
paddle
.
reader
.
shuffle
(
paddle
.
reader
.
shuffle
(
uci_housing
.
train
(),
buf_size
=
500
),
cloud_reader
,
buf_size
=
500
),
batch_size
=
2
),
batch_size
=
2
),
feeding
=
{
'x'
:
0
,
feeding
=
{
'x'
:
0
,
'y'
:
1
},
'y'
:
1
},
event_handler
=
event_handler
,
event_handler
=
event_handler
,
...
...
go/pserver/client/etcd_client.go
浏览文件 @
6dc567a5
...
@@ -12,6 +12,7 @@ import (
...
@@ -12,6 +12,7 @@ import (
)
)
const
(
const
(
// DefaultEtcdTimeout is the default etcd timeout
DefaultEtcdTimeout
time
.
Duration
=
5
*
time
.
Second
DefaultEtcdTimeout
time
.
Duration
=
5
*
time
.
Second
)
)
...
@@ -66,12 +67,12 @@ func (p *EtcdClient) List() []Server {
...
@@ -66,12 +67,12 @@ func (p *EtcdClient) List() []Server {
for
{
for
{
for
i
:=
0
;
i
<
psDesired
;
i
++
{
for
i
:=
0
;
i
<
psDesired
;
i
++
{
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
p
.
timeout
)
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
p
.
timeout
)
cancel
()
psKey
:=
pserver
.
PsPath
+
strconv
.
Itoa
(
i
)
psKey
:=
pserver
.
PsPath
+
strconv
.
Itoa
(
i
)
log
.
Debugf
(
"checking %s"
,
psKey
)
log
.
Debugf
(
"checking %s"
,
psKey
)
resp
,
err
:=
p
.
client
.
Get
(
ctx
,
psKey
)
resp
,
err
:=
p
.
client
.
Get
(
ctx
,
psKey
)
cancel
()
if
err
!=
nil
{
if
err
!=
nil
{
log
.
Infof
(
"Get psKey=
%s error, %v"
,
psKey
,
err
)
log
.
Infof
(
"Get psKey=%s error, %v"
,
psKey
,
err
)
time
.
Sleep
(
p
.
timeout
)
time
.
Sleep
(
p
.
timeout
)
continue
continue
}
}
...
...
go/pserver/etcd_client.go
浏览文件 @
6dc567a5
...
@@ -49,7 +49,7 @@ func NewEtcdClient(endpoints string, numPservers int, timeout time.Duration) *Et
...
@@ -49,7 +49,7 @@ func NewEtcdClient(endpoints string, numPservers int, timeout time.Duration) *Et
// Register registers the pserver on etcd
// Register registers the pserver on etcd
//
//
// Register returns the index of the current pserver.
// Register returns the index of the current pserver.
func
(
e
*
EtcdClient
)
Register
()
(
int
,
error
)
{
func
(
e
*
EtcdClient
)
Register
(
port
int
)
(
int
,
error
)
{
var
err
error
var
err
error
e
.
externalIP
,
err
=
networkhelper
.
GetExternalIP
()
e
.
externalIP
,
err
=
networkhelper
.
GetExternalIP
()
...
@@ -116,7 +116,7 @@ func (e *EtcdClient) Register() (int, error) {
...
@@ -116,7 +116,7 @@ func (e *EtcdClient) Register() (int, error) {
for
{
for
{
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
time
.
Second
)
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
time
.
Second
)
var
err
error
var
err
error
pserverIdx
,
err
=
e
.
registerPserverEtcd
(
ctx
)
pserverIdx
,
err
=
e
.
registerPserverEtcd
(
ctx
,
port
)
cancel
()
cancel
()
if
err
!=
nil
{
if
err
!=
nil
{
log
.
Warn
(
err
)
log
.
Warn
(
err
)
...
@@ -140,7 +140,7 @@ func (e *EtcdClient) initDesiredPservers(ctx context.Context, numPservers int) (
...
@@ -140,7 +140,7 @@ func (e *EtcdClient) initDesiredPservers(ctx context.Context, numPservers int) (
}
}
// registerPserverEtcd registers pserver node on etcd using transaction.
// registerPserverEtcd registers pserver node on etcd using transaction.
func
(
e
*
EtcdClient
)
registerPserverEtcd
(
ctx
context
.
Context
)
(
int
,
error
)
{
func
(
e
*
EtcdClient
)
registerPserverEtcd
(
ctx
context
.
Context
,
port
int
)
(
int
,
error
)
{
var
idx
int
var
idx
int
_
,
err
:=
concurrency
.
NewSTM
(
e
.
etcdClient
,
func
(
c
concurrency
.
STM
)
error
{
_
,
err
:=
concurrency
.
NewSTM
(
e
.
etcdClient
,
func
(
c
concurrency
.
STM
)
error
{
registered
:=
false
registered
:=
false
...
@@ -156,8 +156,9 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) {
...
@@ -156,8 +156,9 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) {
log
.
Fatal
(
err
)
log
.
Fatal
(
err
)
}
}
// find the first id and write info
// find the first id and write info
c
.
Put
(
psKey
,
e
.
externalIP
,
clientv3
.
WithLease
(
resp
.
ID
))
pserverAddr
:=
e
.
externalIP
+
":"
+
strconv
.
Itoa
(
port
)
log
.
Debugf
(
"set pserver node %s with value %s"
,
psKey
,
e
.
externalIP
)
c
.
Put
(
psKey
,
pserverAddr
,
clientv3
.
WithLease
(
resp
.
ID
))
log
.
Debugf
(
"set pserver node %s with value %s"
,
psKey
,
pserverAddr
)
ch
,
kaerr
:=
e
.
etcdClient
.
KeepAlive
(
context
.
TODO
(),
resp
.
ID
)
ch
,
kaerr
:=
e
.
etcdClient
.
KeepAlive
(
context
.
TODO
(),
resp
.
ID
)
if
kaerr
!=
nil
{
if
kaerr
!=
nil
{
log
.
Errorf
(
"keepalive etcd node error: %v"
,
kaerr
)
log
.
Errorf
(
"keepalive etcd node error: %v"
,
kaerr
)
...
...
paddle/api/PaddleAPI.h
浏览文件 @
6dc567a5
...
@@ -843,7 +843,8 @@ public:
...
@@ -843,7 +843,8 @@ public:
bool
useSparseUpdater
);
bool
useSparseUpdater
);
static
ParameterUpdater
*
createNewRemoteUpdater
(
static
ParameterUpdater
*
createNewRemoteUpdater
(
OptimizationConfig
*
config
,
OptimizationConfig
*
config
,
const
std
::
string
pserverSpec
)
throw
(
UnsupportError
);
const
std
::
string
pserverSpec
,
const
bool
useEtcd
)
throw
(
UnsupportError
);
~
ParameterUpdater
();
~
ParameterUpdater
();
/**
/**
...
...
paddle/api/ParameterUpdater.cpp
浏览文件 @
6dc567a5
...
@@ -33,11 +33,12 @@ ParameterUpdater *ParameterUpdater::createLocalUpdater(
...
@@ -33,11 +33,12 @@ ParameterUpdater *ParameterUpdater::createLocalUpdater(
ParameterUpdater
*
ParameterUpdater
::
createNewRemoteUpdater
(
ParameterUpdater
*
ParameterUpdater
::
createNewRemoteUpdater
(
OptimizationConfig
*
config
,
OptimizationConfig
*
config
,
const
std
::
string
pserverSpec
)
throw
(
UnsupportError
)
{
const
std
::
string
pserverSpec
,
const
bool
useEtcd
)
throw
(
UnsupportError
)
{
#ifndef PADDLE_WITHOUT_GOLANG
#ifndef PADDLE_WITHOUT_GOLANG
auto
updater
=
new
ParameterUpdater
();
auto
updater
=
new
ParameterUpdater
();
updater
->
m
->
updater
.
reset
(
new
paddle
::
NewRemoteParameterUpdater
(
updater
->
m
->
updater
.
reset
(
new
paddle
::
NewRemoteParameterUpdater
(
config
->
m
->
getConfig
(),
pserverSpec
));
config
->
m
->
getConfig
(),
pserverSpec
,
useEtcd
));
return
updater
;
return
updater
;
#else
#else
throw
UnsupportError
();
throw
UnsupportError
();
...
...
paddle/framework/CMakeLists.txt
浏览文件 @
6dc567a5
# ddim lib
# ddim lib
cc_library
(
enforce SRCS enforce.cc DEPS glog
)
cc_test
(
enforce_test SRCS enforce_test.cc DEPS enforce
)
cc_library
(
ddim SRCS ddim.cc DEPS eigen3
)
cc_library
(
ddim SRCS ddim.cc DEPS eigen3
)
cc_test
(
ddim_test SRCS ddim_test.cc DEPS ddim
)
cc_test
(
ddim_test SRCS ddim_test.cc DEPS ddim
)
nv_test
(
dim_test SRCS dim_test.cu DEPS ddim
)
nv_test
(
dim_test SRCS dim_test.cu DEPS ddim
)
cc_test
(
tensor_test SRCS tensor_test.cc DEPS ddim
)
cc_library
(
tensor SRCS tensor.cc DEPS ddim place enforce paddle_memory
)
cc_test
(
tensor_test SRCS tensor_test.cc DEPS tensor
)
cc_test
(
variable_test SRCS variable_test.cc
)
cc_test
(
variable_test SRCS variable_test.cc
)
cc_test
(
scope_test SRCS scope_test.cc
)
cc_test
(
scope_test SRCS scope_test.cc
)
cc_test
(
enforce_test SRCS enforce_test.cc
)
proto_library
(
attr_type SRCS attr_type.proto
)
proto_library
(
attr_type SRCS attr_type.proto
)
proto_library
(
op_proto SRCS op_proto.proto DEPS attr_type
)
proto_library
(
op_proto SRCS op_proto.proto DEPS attr_type
)
cc_test
(
op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf
)
cc_test
(
op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf
)
proto_library
(
op_desc SRCS op_desc.proto DEPS attr_type
)
proto_library
(
op_desc SRCS op_desc.proto DEPS attr_type
)
cc_test
(
op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf
)
cc_test
(
op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf
)
cc_library
(
operator SRCS operator.cc DEPS op_desc device_context
)
cc_library
(
operator SRCS operator.cc DEPS op_desc device_context
tensor
)
cc_test
(
operator_test SRCS operator_test.cc DEPS operator op_registry
)
cc_test
(
operator_test SRCS operator_test.cc DEPS operator op_registry
)
cc_library
(
op_registry SRCS op_registry.cc DEPS op_proto op_desc
)
cc_library
(
op_registry SRCS op_registry.cc DEPS op_proto op_desc
enforce
)
cc_test
(
op_registry_test SRCS op_registry_test.cc DEPS op_registry operator
)
cc_test
(
op_registry_test SRCS op_registry_test.cc DEPS op_registry operator
)
py_proto_compile
(
framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto
)
py_proto_compile
(
framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto
)
# Generate an empty __init__.py to make framework_py_proto as a valid python module.
# Generate an empty __init__.py to make framework_py_proto as a valid python module.
add_custom_target
(
framework_py_proto_init ALL COMMAND
${
CMAKE_COMMAND
}
-E touch __init__.py
)
add_custom_target
(
framework_py_proto_init ALL COMMAND
${
CMAKE_COMMAND
}
-E touch __init__.py
)
...
...
paddle/framework/ddim.cc
浏览文件 @
6dc567a5
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/framework/ddim.h"
#include "paddle/framework/ddim.h"
#include "paddle/framework/enforce.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -192,13 +193,56 @@ std::vector<int> vectorize(const DDim& ddim) {
...
@@ -192,13 +193,56 @@ std::vector<int> vectorize(const DDim& ddim) {
return
result
;
return
result
;
}
}
struct
ProductVisitor
:
public
boost
::
static_visitor
<
ssize_t
>
{
template
<
int
D
>
ssize_t
operator
()(
const
Dim
<
D
>&
dim
)
{
return
product
(
dim
);
}
};
ssize_t
product
(
const
DDim
&
ddim
)
{
ssize_t
product
(
const
DDim
&
ddim
)
{
ssize_t
result
=
1
;
ProductVisitor
visitor
;
std
::
vector
<
int
>
v
=
vectorize
(
ddim
);
return
boost
::
apply_visitor
(
visitor
,
ddim
);
for
(
auto
i
:
v
)
{
}
result
*=
i
;
struct
SliceVectorizeVisitor
:
public
boost
::
static_visitor
<>
{
std
::
vector
<
int
>&
vector
;
int
begin
;
int
end
;
SliceVectorizeVisitor
(
std
::
vector
<
int
>&
v
,
int
b
,
int
e
)
:
vector
(
v
),
begin
(
b
),
end
(
e
)
{
PADDLE_ENFORCE
(
begin
<
end
,
"Begin index must be less than end index in ddim slice."
);
PADDLE_ENFORCE
(
begin
>=
0
,
"Begin index can't be less than zero in ddim slice."
);
}
}
return
result
;
template
<
int
S
>
void
operator
()(
const
Dim
<
S
>&
dim
)
{
if
(
begin
==
0
)
{
vector
.
push_back
(
dim
.
head
);
}
else
{
--
begin
;
}
--
end
;
if
(
end
>
0
)
{
this
->
operator
()(
dim
.
tail
);
}
}
void
operator
()(
const
Dim
<
1
>&
dim
)
{
PADDLE_ENFORCE
(
end
==
1
,
"End index in ddim slice is out of bound."
);
vector
.
push_back
(
dim
.
head
);
}
};
DDim
slice_ddim
(
const
DDim
&
dim
,
int
begin
,
int
end
)
{
std
::
vector
<
int
>
vec
;
vec
.
reserve
(
end
-
begin
);
SliceVectorizeVisitor
visitor
(
vec
,
begin
,
end
);
boost
::
apply_visitor
(
visitor
,
dim
);
return
make_ddim
(
vec
);
}
}
/// \cond HIDDEN
/// \cond HIDDEN
...
...
paddle/framework/ddim.h
浏览文件 @
6dc567a5
...
@@ -96,6 +96,15 @@ std::vector<int> vectorize(const DDim& ddim);
...
@@ -96,6 +96,15 @@ std::vector<int> vectorize(const DDim& ddim);
ssize_t
product
(
const
DDim
&
ddim
);
ssize_t
product
(
const
DDim
&
ddim
);
/**
* \brief Slice a ddim
*
* Slice dim with [begin, end).
* e.g. DDim d = make_ddim({1,2,3,4,5});
* slice_ddim(d, 1, 3); ====> {2,3}
*/
DDim
slice_ddim
(
const
DDim
&
dim
,
int
begin
,
int
end
);
/**
/**
* \brief What is the length of this dimension?
* \brief What is the length of this dimension?
*
*
...
...
paddle/framework/ddim_test.cc
浏览文件 @
6dc567a5
...
@@ -52,6 +52,26 @@ TEST(DDim, Equality) {
...
@@ -52,6 +52,26 @@ TEST(DDim, Equality) {
// product of a DDim
// product of a DDim
EXPECT_EQ
(
paddle
::
framework
::
product
(
vddim
),
45
);
EXPECT_EQ
(
paddle
::
framework
::
product
(
vddim
),
45
);
EXPECT_EQ
(
paddle
::
framework
::
product
(
paddle
::
framework
::
make_ddim
({
3
,
2
,
5
,
3
})),
90
);
// slice a DDim
paddle
::
framework
::
DDim
ddim2
=
paddle
::
framework
::
make_ddim
({
1
,
2
,
3
,
4
,
5
,
6
});
paddle
::
framework
::
DDim
ss
=
paddle
::
framework
::
slice_ddim
(
ddim2
,
2
,
5
);
EXPECT_EQ
(
arity
(
ss
),
3
);
EXPECT_EQ
(
ss
[
0
],
3
);
EXPECT_EQ
(
ss
[
1
],
4
);
EXPECT_EQ
(
ss
[
2
],
5
);
paddle
::
framework
::
DDim
ss2
=
paddle
::
framework
::
slice_ddim
(
ddim2
,
0
,
6
);
EXPECT_EQ
(
arity
(
ss2
),
6
);
EXPECT_EQ
(
ss2
[
0
],
1
);
EXPECT_EQ
(
ss2
[
1
],
2
);
EXPECT_EQ
(
ss2
[
2
],
3
);
EXPECT_EQ
(
ss2
[
3
],
4
);
EXPECT_EQ
(
ss2
[
4
],
5
);
EXPECT_EQ
(
ss2
[
5
],
6
);
}
}
TEST
(
DDim
,
Print
)
{
TEST
(
DDim
,
Print
)
{
...
...
paddle/framework/dim_test.cu
浏览文件 @
6dc567a5
#include <thrust/device_vector.h>
#include <thrust/device_vector.h>
#include <sstream>
#include <sstream>
#include "paddle/framework/dim.h"
#include "gtest/gtest.h"
#include "gtest/gtest.h"
#include "paddle/framework/dim.h"
__global__
void
test
(
paddle
::
framework
::
Dim
<
2
>*
o
)
{
__global__
void
test
(
paddle
::
framework
::
Dim
<
2
>*
o
)
{
o
[
0
]
=
paddle
::
framework
::
make_dim
(
5
,
6
);
o
[
0
]
=
paddle
::
framework
::
make_dim
(
5
,
6
);
...
@@ -21,7 +21,7 @@ TEST(Dim, Equality) {
...
@@ -21,7 +21,7 @@ TEST(Dim, Equality) {
// construct a Dim on the GPU
// construct a Dim on the GPU
thrust
::
device_vector
<
paddle
::
framework
::
Dim
<
2
>>
t
(
2
);
thrust
::
device_vector
<
paddle
::
framework
::
Dim
<
2
>>
t
(
2
);
test
<<<
1
,
1
>>>
(
thrust
::
raw_pointer_cast
(
t
.
data
()));
test
<<<
1
,
1
>>>
(
thrust
::
raw_pointer_cast
(
t
.
data
()));
a
=
t
[
0
];
a
=
t
[
0
];
EXPECT_EQ
(
paddle
::
framework
::
get
<
0
>
(
a
),
5
);
EXPECT_EQ
(
paddle
::
framework
::
get
<
0
>
(
a
),
5
);
EXPECT_EQ
(
paddle
::
framework
::
get
<
1
>
(
a
),
6
);
EXPECT_EQ
(
paddle
::
framework
::
get
<
1
>
(
a
),
6
);
...
@@ -48,12 +48,13 @@ TEST(Dim, Equality) {
...
@@ -48,12 +48,13 @@ TEST(Dim, Equality) {
// dynamic access on GPU
// dynamic access on GPU
thrust
::
device_vector
<
int
>
r
(
1
);
thrust
::
device_vector
<
int
>
r
(
1
);
dyn_idx_gpu
<<<
1
,
1
>>>
(
thrust
::
raw_pointer_cast
(
r
.
data
()));
dyn_idx_gpu
<<<
1
,
1
>>>
(
thrust
::
raw_pointer_cast
(
r
.
data
()));
int
res
=
r
[
0
];
int
res
=
r
[
0
];
EXPECT_EQ
(
res
,
6
);
EXPECT_EQ
(
res
,
6
);
// ex_prefix_mul
// ex_prefix_mul
paddle
::
framework
::
Dim
<
3
>
c
=
paddle
::
framework
::
ex_prefix_mul
(
paddle
::
framework
::
Dim
<
3
>
(
3
,
4
,
5
));
paddle
::
framework
::
Dim
<
3
>
c
=
paddle
::
framework
::
ex_prefix_mul
(
paddle
::
framework
::
Dim
<
3
>
(
3
,
4
,
5
));
EXPECT_EQ
(
paddle
::
framework
::
get
<
0
>
(
c
),
1
);
EXPECT_EQ
(
paddle
::
framework
::
get
<
0
>
(
c
),
1
);
EXPECT_EQ
(
paddle
::
framework
::
get
<
1
>
(
c
),
3
);
EXPECT_EQ
(
paddle
::
framework
::
get
<
1
>
(
c
),
3
);
EXPECT_EQ
(
paddle
::
framework
::
get
<
2
>
(
c
),
12
);
EXPECT_EQ
(
paddle
::
framework
::
get
<
2
>
(
c
),
12
);
...
...
paddle/framework/enforce.cc
0 → 100644
浏览文件 @
6dc567a5
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/enforce.h"
paddle/framework/enforce.h
浏览文件 @
6dc567a5
...
@@ -10,6 +10,7 @@ See the License for the specific language governing permissions and
...
@@ -10,6 +10,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#pragma once
#pragma once
#include <glog/logging.h>
#include <paddle/string/printf.h>
#include <paddle/string/printf.h>
#include <exception>
#include <exception>
#include <sstream>
#include <sstream>
...
@@ -58,12 +59,17 @@ class EnforceNotMet : public std::exception {
...
@@ -58,12 +59,17 @@ class EnforceNotMet : public std::exception {
/**
/**
* @brief Enforce a condition, otherwise throw an EnforceNotMet
* @brief Enforce a condition, otherwise throw an EnforceNotMet
*/
*/
#ifdef NDEBUG
#define PADDLE_ENFORCE(condition, ...) \
#define PADDLE_ENFORCE(condition, ...) \
do { \
do { \
if (UNLIKELY(!(condition))) { \
if (UNLIKELY(!(condition))) { \
PADDLE_THROW(__VA_ARGS__); \
PADDLE_THROW(__VA_ARGS__); \
} \
} \
} while (0)
} while (0)
#else
#define PADDLE_ENFORCE(condition, ...) \
CHECK(condition) << ::paddle::string::Sprintf(__VA_ARGS__);
#endif
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/framework/op_registry.h
浏览文件 @
6dc567a5
...
@@ -61,7 +61,14 @@ class OpProtoAndCheckerMaker {
...
@@ -61,7 +61,14 @@ class OpProtoAndCheckerMaker {
OpProtoAndCheckerMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
OpProtoAndCheckerMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
proto_
(
proto
),
op_checker_
(
op_checker
)
{}
:
proto_
(
proto
),
op_checker_
(
op_checker
)
{}
~
OpProtoAndCheckerMaker
()
{
CheckNoDuplicatedAttrs
();
}
~
OpProtoAndCheckerMaker
()
{
PADDLE_ENFORCE
(
validated_
,
"should call Validate after build"
);
}
void
Validate
()
{
validated_
=
true
;
CheckNoDuplicatedInOutAttrs
();
}
protected:
protected:
void
AddInput
(
const
std
::
string
&
name
,
const
std
::
string
&
comment
,
void
AddInput
(
const
std
::
string
&
name
,
const
std
::
string
&
comment
,
...
@@ -163,19 +170,26 @@ Add a mark to which output is temporary is helpful for future optimization.
...
@@ -163,19 +170,26 @@ Add a mark to which output is temporary is helpful for future optimization.
}
}
}
}
void
CheckNoDuplicatedAttrs
()
{
void
CheckNoDuplicated
InOut
Attrs
()
{
std
::
unordered_set
<
std
::
string
>
names
;
std
::
unordered_set
<
std
::
string
>
names
;
size_t
cnt
=
0
;
auto
checker
=
[
&
](
const
std
::
string
&
name
)
{
PADDLE_ENFORCE
(
!
names
.
count
(
name
),
"[%s] is duplicated"
,
name
);
names
.
insert
(
name
);
};
for
(
auto
&
attr
:
proto_
->
attrs
())
{
for
(
auto
&
attr
:
proto_
->
attrs
())
{
names
.
insert
(
attr
.
name
());
checker
(
attr
.
name
());
++
cnt
;
}
for
(
auto
&
input
:
proto_
->
inputs
())
{
checker
(
input
.
name
());
}
for
(
auto
&
output
:
proto_
->
outputs
())
{
checker
(
output
.
name
());
}
}
PADDLE_ENFORCE
(
names
.
size
()
==
cnt
,
"Cannot register two attribute in same name!"
);
}
}
OpProto
*
proto_
;
OpProto
*
proto_
;
OpAttrChecker
*
op_checker_
;
OpAttrChecker
*
op_checker_
;
bool
validated_
{
false
};
bool
has_multiple_input_
{
false
};
bool
has_multiple_input_
{
false
};
bool
has_multiple_output_
{
false
};
bool
has_multiple_output_
{
false
};
bool
has_temporary_output_
{
false
};
bool
has_temporary_output_
{
false
};
...
@@ -190,7 +204,8 @@ class OpRegistry {
...
@@ -190,7 +204,8 @@ class OpRegistry {
creators
()[
op_type
]
=
[]
{
return
new
OpType
;
};
creators
()[
op_type
]
=
[]
{
return
new
OpType
;
};
OpProto
&
op_proto
=
protos
()[
op_type
];
OpProto
&
op_proto
=
protos
()[
op_type
];
OpAttrChecker
&
op_checker
=
op_checkers
()[
op_type
];
OpAttrChecker
&
op_checker
=
op_checkers
()[
op_type
];
ProtoMakerType
(
&
op_proto
,
&
op_checker
);
auto
maker
=
ProtoMakerType
(
&
op_proto
,
&
op_checker
);
maker
.
Validate
();
*
op_proto
.
mutable_type
()
=
op_type
;
*
op_proto
.
mutable_type
()
=
op_type
;
PADDLE_ENFORCE
(
PADDLE_ENFORCE
(
op_proto
.
IsInitialized
(),
op_proto
.
IsInitialized
(),
...
...
paddle/framework/op_registry_test.cc
浏览文件 @
6dc567a5
#include "paddle/framework/op_registry.h"
#include "paddle/framework/op_registry.h"
#include <gtest/gtest.h>
#include <gtest/gtest.h>
namespace
pd
=
paddle
::
framework
;
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
class
CosineOp
:
public
OperatorBase
{
class
CosineOp
:
public
OperatorBase
{
...
@@ -28,8 +30,6 @@ class MyTestOp : public OperatorBase {
...
@@ -28,8 +30,6 @@ class MyTestOp : public OperatorBase {
void
InferShape
(
const
ScopePtr
&
scope
)
const
override
{}
void
InferShape
(
const
ScopePtr
&
scope
)
const
override
{}
void
Run
(
const
ScopePtr
&
scope
,
void
Run
(
const
ScopePtr
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{}
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{}
public:
};
};
class
MyTestOpProtoAndCheckerMaker
:
public
OpProtoAndCheckerMaker
{
class
MyTestOpProtoAndCheckerMaker
:
public
OpProtoAndCheckerMaker
{
...
@@ -182,3 +182,35 @@ TEST(OpRegistry, CustomChecker) {
...
@@ -182,3 +182,35 @@ TEST(OpRegistry, CustomChecker) {
int
test_attr
=
op
->
GetAttr
<
int
>
(
"test_attr"
);
int
test_attr
=
op
->
GetAttr
<
int
>
(
"test_attr"
);
ASSERT_EQ
(
test_attr
,
4
);
ASSERT_EQ
(
test_attr
,
4
);
}
}
class
TestAttrProtoMaker
:
public
pd
::
OpProtoAndCheckerMaker
{
public:
TestAttrProtoMaker
(
pd
::
OpProto
*
proto
,
pd
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddAttr
<
float
>
(
"scale"
,
"scale of test op"
);
AddAttr
<
float
>
(
"scale"
,
"scale of test op"
);
}
};
TEST
(
ProtoMaker
,
DuplicatedAttr
)
{
pd
::
OpProto
op_proto
;
pd
::
OpAttrChecker
op_checker
;
auto
proto_maker
=
TestAttrProtoMaker
(
&
op_proto
,
&
op_checker
);
ASSERT_THROW
(
proto_maker
.
Validate
(),
paddle
::
framework
::
EnforceNotMet
);
}
class
TestInOutProtoMaker
:
public
pd
::
OpProtoAndCheckerMaker
{
public:
TestInOutProtoMaker
(
pd
::
OpProto
*
proto
,
pd
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"input"
,
"input of test op"
);
AddInput
(
"input"
,
"input of test op"
);
}
};
TEST
(
ProtoMaker
,
DuplicatedInOut
)
{
pd
::
OpProto
op_proto
;
pd
::
OpAttrChecker
op_checker
;
auto
proto_maker
=
TestInOutProtoMaker
(
&
op_proto
,
&
op_checker
);
ASSERT_THROW
(
proto_maker
.
Validate
(),
paddle
::
framework
::
EnforceNotMet
);
}
paddle/framework/tensor.cc
0 → 100644
浏览文件 @
6dc567a5
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <paddle/framework/tensor.h>
namespace
paddle
{
namespace
framework
{}
}
// namespace paddle
paddle/operators/CMakeLists.txt
浏览文件 @
6dc567a5
if
(
WITH_GPU
)
function
(
op_library TARGET
)
nv_library
(
add_op SRCS add_op.cc add_op.cu DEPS operator op_registry ddim glog paddle_memory
)
# op_library is a function to create op library. The interface is same as
else
()
# cc_library. But it handle split GPU/CPU code and link some common library
cc_library
(
add_op SRCS add_op.cc DEPS operator op_registry ddim glog paddle_memory
)
# for ops.
endif
()
set
(
cc_srcs
)
set
(
cu_srcs
)
set
(
op_common_deps operator op_registry
)
set
(
options
""
)
set
(
oneValueArgs
""
)
set
(
multiValueArgs SRCS DEPS
)
cmake_parse_arguments
(
op_library
"
${
options
}
"
"
${
oneValueArgs
}
"
"
${
multiValueArgs
}
"
${
ARGN
}
)
foreach
(
src
${
op_library_SRCS
}
)
if
(
${
src
}
MATCHES
".*
\\
.cu$"
)
list
(
APPEND cu_srcs
${
src
}
)
elseif
(
${
src
}
MATCHES
".*
\\
.cc$"
)
list
(
APPEND cc_srcs
${
src
}
)
else
()
message
(
FATAL_ERROR
"
${
TARGET
}
Source file
${
src
}
should only be .cc or .cu"
)
endif
()
endforeach
()
list
(
LENGTH cc_srcs cc_srcs_len
)
if
(
${
cc_srcs_len
}
EQUAL 0
)
message
(
FATAL_ERROR
"The op library
${
TARGET
}
should contains at least one .cc file"
)
endif
()
list
(
LENGTH cu_srcs cu_srcs_len
)
if
(
${
cu_srcs_len
}
EQUAL 0
)
message
(
WARNING
"The op library
${
TARGET
}
not support GPU!"
)
endif
()
if
(
WITH_GPU
)
nv_library
(
${
TARGET
}
SRCS
${
cc_srcs
}
${
cu_srcs
}
DEPS
${
op_library_DEPS
}
${
op_common_deps
}
)
else
()
cc_library
(
${
TARGET
}
SRCS
${
cc_srcs
}
DEPS
${
op_library_DEPS
}
${
op_common_deps
}
)
endif
()
endfunction
()
op_library
(
add_op SRCS add_op.cc add_op.cu
)
cc_test
(
add_op_test SRCS add_op_test.cc DEPS add_op
)
cc_test
(
add_op_test SRCS add_op_test.cc DEPS add_op
)
paddle/scripts/docker/build.sh
浏览文件 @
6dc567a5
...
@@ -155,7 +155,8 @@ RUN apt-get update &&\
...
@@ -155,7 +155,8 @@ RUN apt-get update &&\
paddle version
paddle version
${
DOCKERFILE_CUDNN_DSO
}
${
DOCKERFILE_CUDNN_DSO
}
${
DOCKERFILE_GPU_ENV
}
${
DOCKERFILE_GPU_ENV
}
ADD go/cmd/pserver/pserver /usr/bin/
ADD go/cmd/master/master /usr/bin/
# default command shows the paddle version and exit
# default command shows the paddle version and exit
CMD ["paddle", "version"]
CMD ["paddle", "version"]
EOF
EOF
paddle/scripts/docker/build_android.sh
浏览文件 @
6dc567a5
...
@@ -2,9 +2,9 @@
...
@@ -2,9 +2,9 @@
set
-xe
set
-xe
mkdir
-p
/paddle/build
mkdir
-p
/paddle/build
_android
cd
/paddle/build
cd
/paddle/build
_android
rm
-f
/paddle/install 2>/dev/null
||
true
rm
-
r
f
/paddle/install 2>/dev/null
||
true
cmake
-DCMAKE_SYSTEM_NAME
=
Android
\
cmake
-DCMAKE_SYSTEM_NAME
=
Android
\
-DANDROID_STANDALONE_TOOLCHAIN
=
$ANDROID_STANDALONE_TOOLCHAIN
\
-DANDROID_STANDALONE_TOOLCHAIN
=
$ANDROID_STANDALONE_TOOLCHAIN
\
-DANDROID_ABI
=
armeabi-v7a
\
-DANDROID_ABI
=
armeabi-v7a
\
...
@@ -21,6 +21,3 @@ cmake -DCMAKE_SYSTEM_NAME=Android \
...
@@ -21,6 +21,3 @@ cmake -DCMAKE_SYSTEM_NAME=Android \
..
..
make
-j
`
nproc
`
make
-j
`
nproc
`
make
install
make
install
export
PATH
=
/paddle/install/bin:/paddle/install/opt/paddle/bin:
$PATH
paddle version
paddle/trainer/NewRemoteParameterUpdater.cpp
浏览文件 @
6dc567a5
...
@@ -28,6 +28,17 @@ NewRemoteParameterUpdater::NewRemoteParameterUpdater(
...
@@ -28,6 +28,17 @@ NewRemoteParameterUpdater::NewRemoteParameterUpdater(
newGradients_
(
nullptr
),
newGradients_
(
nullptr
),
pserverSpec_
(
pserverSpec
)
{}
pserverSpec_
(
pserverSpec
)
{}
NewRemoteParameterUpdater
::
NewRemoteParameterUpdater
(
const
OptimizationConfig
&
config
,
const
std
::
string
pserverSpec
,
const
bool
useEtcd
)
:
trainerConfig_
(
config
),
parameterClient_
(
-
1
),
newParameters_
(
nullptr
),
newGradients_
(
nullptr
),
pserverSpec_
(
pserverSpec
),
useEtcd_
(
useEtcd
)
{}
void
NewRemoteParameterUpdater
::
init
(
void
NewRemoteParameterUpdater
::
init
(
const
std
::
vector
<
ParameterPtr
>
&
parameters
)
{
const
std
::
vector
<
ParameterPtr
>
&
parameters
)
{
ParameterUpdater
::
init
(
parameters
);
ParameterUpdater
::
init
(
parameters
);
...
@@ -38,8 +49,13 @@ void NewRemoteParameterUpdater::init(
...
@@ -38,8 +49,13 @@ void NewRemoteParameterUpdater::init(
}
}
// create parameter server client.
// create parameter server client.
if
(
useEtcd_
)
{
parameterClient_
=
paddle_new_etcd_pserver_client
(
(
char
*
)
pserverSpec_
.
c_str
(),
FLAGS_trainer_id
==
0
);
}
else
{
parameterClient_
=
paddle_new_pserver_client
((
char
*
)
pserverSpec_
.
c_str
(),
parameterClient_
=
paddle_new_pserver_client
((
char
*
)
pserverSpec_
.
c_str
(),
FLAGS_trainer_id
==
0
);
FLAGS_trainer_id
==
0
);
}
// init new parameter and gradient.
// init new parameter and gradient.
newParameters_
=
initNewParameter
(
PARAMETER_VALUE
);
newParameters_
=
initNewParameter
(
PARAMETER_VALUE
);
...
...
paddle/trainer/NewRemoteParameterUpdater.h
浏览文件 @
6dc567a5
...
@@ -32,6 +32,9 @@ class NewRemoteParameterUpdater : public ParameterUpdater {
...
@@ -32,6 +32,9 @@ class NewRemoteParameterUpdater : public ParameterUpdater {
public:
public:
NewRemoteParameterUpdater
(
const
OptimizationConfig
&
config
,
NewRemoteParameterUpdater
(
const
OptimizationConfig
&
config
,
const
std
::
string
pserverSpec
);
const
std
::
string
pserverSpec
);
NewRemoteParameterUpdater
(
const
OptimizationConfig
&
config
,
const
std
::
string
pserverSpec
,
const
bool
useEtcd
);
~
NewRemoteParameterUpdater
()
{
~
NewRemoteParameterUpdater
()
{
releaseNewParameter
(
newParameters_
);
releaseNewParameter
(
newParameters_
);
releaseNewParameter
(
newGradients_
);
releaseNewParameter
(
newGradients_
);
...
@@ -111,6 +114,8 @@ protected:
...
@@ -111,6 +114,8 @@ protected:
paddle_parameter
**
newGradients_
;
paddle_parameter
**
newGradients_
;
/// the specification of parameter server "host1:port,host1:port"
/// the specification of parameter server "host1:port,host1:port"
std
::
string
pserverSpec_
;
std
::
string
pserverSpec_
;
/// true if pserverSpec_ is etcd endpoint, else pserverSpec_ is pserver addr
bool
useEtcd_
;
};
};
}
// namespace paddle
}
// namespace paddle
python/paddle/v2/dataset/common.py
浏览文件 @
6dc567a5
...
@@ -22,6 +22,8 @@ import importlib
...
@@ -22,6 +22,8 @@ import importlib
import
paddle.v2.dataset
import
paddle.v2.dataset
import
cPickle
import
cPickle
import
glob
import
glob
import
cPickle
as
pickle
import
random
__all__
=
[
__all__
=
[
'DATA_HOME'
,
'download'
,
'md5file'
,
'split'
,
'cluster_files_reader'
,
'DATA_HOME'
,
'download'
,
'md5file'
,
'split'
,
'cluster_files_reader'
,
...
@@ -170,8 +172,6 @@ def convert(output_path,
...
@@ -170,8 +172,6 @@ def convert(output_path,
name_prefix
,
name_prefix
,
max_lines_to_shuffle
=
1000
):
max_lines_to_shuffle
=
1000
):
import
recordio
import
recordio
import
cPickle
as
pickle
import
random
"""
"""
Convert data from reader to recordio format files.
Convert data from reader to recordio format files.
...
@@ -201,8 +201,10 @@ def convert(output_path,
...
@@ -201,8 +201,10 @@ def convert(output_path,
def
write_data
(
w
,
lines
):
def
write_data
(
w
,
lines
):
random
.
shuffle
(
lines
)
random
.
shuffle
(
lines
)
for
i
,
d
in
enumerate
(
lines
):
for
i
,
d
in
enumerate
(
lines
):
d
=
pickle
.
dumps
(
d
,
pickle
.
HIGHEST_PROTOCOL
)
# FIXME(Yancey1989):
w
[
i
%
num_shards
].
write
(
d
)
# dumps with protocol: pickle.HIGHEST_PROTOCOL
o
=
pickle
.
dumps
(
d
)
w
[
i
%
num_shards
].
write
(
o
)
w
=
open_writers
()
w
=
open_writers
()
lines
=
[]
lines
=
[]
...
...
python/paddle/v2/dataset/mq2007.py
浏览文件 @
6dc567a5
...
@@ -212,19 +212,19 @@ def gen_pair(querylist, partial_order="full"):
...
@@ -212,19 +212,19 @@ def gen_pair(querylist, partial_order="full"):
for
j
in
range
(
i
+
1
,
len
(
querylist
)):
for
j
in
range
(
i
+
1
,
len
(
querylist
)):
query_right
=
querylist
[
j
]
query_right
=
querylist
[
j
]
if
query_left
.
relevance_score
>
query_right
.
relevance_score
:
if
query_left
.
relevance_score
>
query_right
.
relevance_score
:
labels
.
append
(
1
)
labels
.
append
(
[
1
]
)
docpairs
.
append
([
docpairs
.
append
([
np
.
array
(
query_left
.
feature_vector
),
np
.
array
(
query_left
.
feature_vector
),
np
.
array
(
query_right
.
feature_vector
)
np
.
array
(
query_right
.
feature_vector
)
])
])
elif
query_left
.
relevance_score
<
query_right
.
relevance_score
:
elif
query_left
.
relevance_score
<
query_right
.
relevance_score
:
labels
.
append
(
1
)
labels
.
append
(
[
1
]
)
docpairs
.
append
([
docpairs
.
append
([
np
.
array
(
query_right
.
feature_vector
),
np
.
array
(
query_right
.
feature_vector
),
np
.
array
(
query_left
.
feature_vector
)
np
.
array
(
query_left
.
feature_vector
)
])
])
for
label
,
pair
in
zip
(
labels
,
docpairs
):
for
label
,
pair
in
zip
(
labels
,
docpairs
):
yield
label
,
pair
[
0
],
pair
[
1
]
yield
np
.
array
(
label
)
,
pair
[
0
],
pair
[
1
]
def
gen_list
(
querylist
):
def
gen_list
(
querylist
):
...
...
python/paddle/v2/master/client.py
浏览文件 @
6dc567a5
...
@@ -10,8 +10,9 @@ class client(object):
...
@@ -10,8 +10,9 @@ class client(object):
client is a client to the master server.
client is a client to the master server.
"""
"""
def
__init__
(
self
,
addr
,
buf_size
):
def
__init__
(
self
,
etcd_endpoints
,
timeout
,
buf_size
):
self
.
c
=
lib
.
paddle_new_master_client
(
addr
,
buf_size
)
self
.
c
=
lib
.
paddle_new_etcd_master_client
(
etcd_endpoints
,
timeout
,
buf_size
)
def
close
(
self
):
def
close
(
self
):
lib
.
paddle_release_master_client
(
self
.
c
)
lib
.
paddle_release_master_client
(
self
.
c
)
...
...
python/paddle/v2/optimizer.py
浏览文件 @
6dc567a5
import
py_paddle.swig_paddle
as
swig_api
import
paddle.trainer_config_helpers.config_parser_utils
as
config_parser_utils
import
paddle.trainer_config_helpers.config_parser_utils
as
config_parser_utils
import
paddle.trainer_config_helpers.optimizers
as
v1_optimizers
import
paddle.trainer_config_helpers.optimizers
as
v1_optimizers
"""
"""
...
@@ -16,7 +17,6 @@ __all__ = [
...
@@ -16,7 +17,6 @@ __all__ = [
class
Optimizer
(
object
):
class
Optimizer
(
object
):
def
__init__
(
self
,
**
kwargs
):
def
__init__
(
self
,
**
kwargs
):
import
py_paddle.swig_paddle
as
swig_api
if
'batch_size'
in
kwargs
:
if
'batch_size'
in
kwargs
:
del
kwargs
[
'batch_size'
]
# not important for python library.
del
kwargs
[
'batch_size'
]
# not important for python library.
...
@@ -46,12 +46,12 @@ class Optimizer(object):
...
@@ -46,12 +46,12 @@ class Optimizer(object):
return
swig_api
.
ParameterUpdater
.
createRemoteUpdater
(
return
swig_api
.
ParameterUpdater
.
createRemoteUpdater
(
self
.
__opt_conf__
,
pass_num
,
use_sparse_updater
)
self
.
__opt_conf__
,
pass_num
,
use_sparse_updater
)
def
__create_new_remote_updater__
(
self
,
pserver_spec
):
def
__create_new_remote_updater__
(
self
,
pserver_spec
,
use_etcd
):
return
swig_api
.
ParameterUpdater
.
createNewRemoteUpdater
(
return
swig_api
.
ParameterUpdater
.
createNewRemoteUpdater
(
self
.
__opt_conf__
,
pserver_spec
)
self
.
__opt_conf__
,
pserver_spec
,
use_etcd
)
def
create_updater
(
self
,
is_local
,
num_passes
,
use_sparse_updater
,
def
create_updater
(
self
,
is_local
,
num_passes
,
use_sparse_updater
,
pserver_spec
):
pserver_spec
,
use_etcd
):
"""
"""
create proper parameter_updater by configuration.
create proper parameter_updater by configuration.
:param is_local: create local or remote parameter updater
:param is_local: create local or remote parameter updater
...
@@ -77,7 +77,7 @@ class Optimizer(object):
...
@@ -77,7 +77,7 @@ class Optimizer(object):
num_passes
,
use_sparse_updater
)
num_passes
,
use_sparse_updater
)
else
:
else
:
parameter_updater
=
self
.
__create_new_remote_updater__
(
parameter_updater
=
self
.
__create_new_remote_updater__
(
pserver_spec
)
pserver_spec
,
use_etcd
)
return
parameter_updater
return
parameter_updater
...
...
python/paddle/v2/trainer.py
浏览文件 @
6dc567a5
...
@@ -45,7 +45,8 @@ class SGD(object):
...
@@ -45,7 +45,8 @@ class SGD(object):
update_equation
,
update_equation
,
extra_layers
=
None
,
extra_layers
=
None
,
is_local
=
True
,
is_local
=
True
,
pserver_spec
=
None
):
pserver_spec
=
None
,
use_etcd
=
True
):
if
not
isinstance
(
parameters
,
v2_parameters
.
Parameters
):
if
not
isinstance
(
parameters
,
v2_parameters
.
Parameters
):
raise
TypeError
(
'parameters should be parameters'
)
raise
TypeError
(
'parameters should be parameters'
)
...
@@ -61,6 +62,7 @@ class SGD(object):
...
@@ -61,6 +62,7 @@ class SGD(object):
self
.
__topology_in_proto__
=
topology
.
proto
()
self
.
__topology_in_proto__
=
topology
.
proto
()
self
.
__is_local__
=
is_local
self
.
__is_local__
=
is_local
self
.
__pserver_spec__
=
pserver_spec
self
.
__pserver_spec__
=
pserver_spec
self
.
__use_etcd__
=
use_etcd
self
.
__use_sparse_updater__
=
self
.
__topology__
.
use_sparse_updater
()
self
.
__use_sparse_updater__
=
self
.
__topology__
.
use_sparse_updater
()
# # In local mode, disable sparse_remote_update.
# # In local mode, disable sparse_remote_update.
...
@@ -127,7 +129,7 @@ class SGD(object):
...
@@ -127,7 +129,7 @@ class SGD(object):
self
.
__parameter_updater__
=
self
.
__optimizer__
.
create_updater
(
self
.
__parameter_updater__
=
self
.
__optimizer__
.
create_updater
(
self
.
__is_local__
,
num_passes
,
self
.
__use_sparse_updater__
,
self
.
__is_local__
,
num_passes
,
self
.
__use_sparse_updater__
,
self
.
__pserver_spec__
)
self
.
__pserver_spec__
,
self
.
__use_etcd__
)
self
.
__parameter_updater__
.
init
(
self
.
__gradient_machine__
)
self
.
__parameter_updater__
.
init
(
self
.
__gradient_machine__
)
self
.
__gradient_machine__
.
start
()
self
.
__gradient_machine__
.
start
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录