Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
9f9058ac
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
9f9058ac
编写于
6月 23, 2017
作者:
G
gongweibao
浏览文件
操作
浏览文件
下载
差异文件
Merge remote-tracking branch 'upstream/develop' into develop
上级
8d476901
8b86624b
变更
24
显示空白变更内容
内联
并排
Showing
24 changed file
with
808 addition
and
201 deletion
+808
-201
.travis.yml
.travis.yml
+4
-20
Dockerfile
Dockerfile
+1
-1
doc/api/v2/config/evaluators.rst
doc/api/v2/config/evaluators.rst
+9
-0
go/cmd/pserver/pserver.go
go/cmd/pserver/pserver.go
+21
-2
go/pserver/client_test.go
go/pserver/client_test.go
+6
-2
go/pserver/service.go
go/pserver/service.go
+122
-3
go/pserver/service_test.go
go/pserver/service_test.go
+17
-8
go/utils/networkhelper/helper.go
go/utils/networkhelper/helper.go
+45
-0
go/utils/networkhelper/helper_test.go
go/utils/networkhelper/helper_test.go
+10
-0
paddle/gserver/evaluators/DetectionMAPEvaluator.cpp
paddle/gserver/evaluators/DetectionMAPEvaluator.cpp
+308
-0
paddle/gserver/tests/test_Evaluator.cpp
paddle/gserver/tests/test_Evaluator.cpp
+17
-0
paddle/parameter/ParameterUpdaterHook.cpp
paddle/parameter/ParameterUpdaterHook.cpp
+56
-84
paddle/scripts/travis/build_and_test.sh
paddle/scripts/travis/build_and_test.sh
+0
-12
paddle/scripts/travis/build_doc.sh
paddle/scripts/travis/build_doc.sh
+8
-5
paddle/scripts/travis/check_style.sh
paddle/scripts/travis/check_style.sh
+4
-4
paddle/scripts/travis/common.sh
paddle/scripts/travis/common.sh
+0
-6
paddle/scripts/travis/main.sh
paddle/scripts/travis/main.sh
+0
-13
proto/ModelConfig.proto
proto/ModelConfig.proto
+9
-0
proto/ParameterConfig.proto
proto/ParameterConfig.proto
+3
-1
python/paddle/trainer/config_parser.py
python/paddle/trainer/config_parser.py
+34
-19
python/paddle/trainer_config_helpers/attrs.py
python/paddle/trainer_config_helpers/attrs.py
+41
-1
python/paddle/trainer_config_helpers/evaluators.py
python/paddle/trainer_config_helpers/evaluators.py
+86
-19
python/paddle/trainer_config_helpers/layers.py
python/paddle/trainer_config_helpers/layers.py
+5
-1
python/paddle/v2/attr.py
python/paddle/v2/attr.py
+2
-0
未找到文件。
.travis.yml
浏览文件 @
9f9058ac
...
@@ -2,7 +2,6 @@ group: deprecated-2017Q2
...
@@ -2,7 +2,6 @@ group: deprecated-2017Q2
language
:
cpp
language
:
cpp
cache
:
cache
:
directories
:
directories
:
-
$HOME/third_party
-
$HOME/.ccache
-
$HOME/.ccache
-
$HOME/.cache/pip
-
$HOME/.cache/pip
sudo
:
required
sudo
:
required
...
@@ -10,15 +9,13 @@ dist: trusty
...
@@ -10,15 +9,13 @@ dist: trusty
os
:
os
:
-
linux
-
linux
env
:
env
:
-
JOB=DOCS
-
JOB=build_doc
-
JOB=BUILD_AND_TEST
-
JOB=check_style
-
JOB=PRE_COMMIT
addons
:
addons
:
apt
:
apt
:
packages
:
packages
:
-
gcc-4.8
-
gcc-4.8
-
g++-4.8
-
g++-4.8
-
gfortran-4.8
-
git
-
git
-
build-essential
-
build-essential
-
python
-
python
...
@@ -35,18 +32,7 @@ addons:
...
@@ -35,18 +32,7 @@ addons:
-
libtool
-
libtool
-
ccache
-
ccache
before_install
:
before_install
:
-
|
-
if [[ "$JOB" == "check_style" ]]; then sudo ln -s /usr/bin/clang-format-3.8 /usr/bin/clang-format; fi
if [ ${JOB} == "BUILD_AND_TEST" ]; then
local change_list=`git diff --name-only $TRAVIS_COMMIT_RANGE`
if [ $? -eq 0 ]; then # if git diff return no zero, then rerun unit test.
if ! echo ${change_list} | grep -qvE '(\.md$)|(\.rst$)|(\.jpg$)|(\.png$)'
then
echo "Only markdown docs were updated, stopping build process."
exit
fi
fi
fi
-
if [[ "$JOB" == "PRE_COMMIT" ]]; then sudo ln -s /usr/bin/clang-format-3.8 /usr/bin/clang-format; fi
# Paddle is using protobuf 3.1 currently. Protobuf 3.2 breaks the compatibility. So we specify the python
# Paddle is using protobuf 3.1 currently. Protobuf 3.2 breaks the compatibility. So we specify the python
# protobuf version.
# protobuf version.
-
pip install numpy wheel 'protobuf==3.1' sphinx==1.5.6 recommonmark sphinx-rtd-theme==0.1.9 virtualenv pre-commit requests==2.9.2 LinkChecker
-
pip install numpy wheel 'protobuf==3.1' sphinx==1.5.6 recommonmark sphinx-rtd-theme==0.1.9 virtualenv pre-commit requests==2.9.2 LinkChecker
...
@@ -55,9 +41,7 @@ before_install:
...
@@ -55,9 +41,7 @@ before_install:
-
|
-
|
function timeout() { perl -e 'alarm shift; exec @ARGV' "$@"; }
function timeout() { perl -e 'alarm shift; exec @ARGV' "$@"; }
script
:
script
:
-
|
-
paddle/scripts/travis/$JOB.sh
timeout 2580 paddle/scripts/travis/main.sh # 43min timeout
RESULT=$?; if [ $RESULT -eq 0 ] || [ $RESULT -eq 142 ]; then true; else false; fi;
notifications
:
notifications
:
email
:
email
:
on_success
:
change
on_success
:
change
...
...
Dockerfile
浏览文件 @
9f9058ac
...
@@ -25,7 +25,7 @@ COPY ./paddle/scripts/docker/root/ /root/
...
@@ -25,7 +25,7 @@ COPY ./paddle/scripts/docker/root/ /root/
RUN
apt-get update
&&
\
RUN
apt-get update
&&
\
apt-get
install
-y
\
apt-get
install
-y
\
git python-pip python-dev openssh-server bison
\
git python-pip python-dev openssh-server bison
\
wget unzip
tar
xz-utils bzip2
gzip
coreutils
\
wget unzip
tar
xz-utils bzip2
gzip
coreutils
ntp
\
curl
sed grep
graphviz libjpeg-dev zlib1g-dev
\
curl
sed grep
graphviz libjpeg-dev zlib1g-dev
\
python-numpy python-matplotlib gcc g++
\
python-numpy python-matplotlib gcc g++
\
automake locales clang-format-3.8 swig doxygen cmake
\
automake locales clang-format-3.8 swig doxygen cmake
\
...
...
doc/api/v2/config/evaluators.rst
浏览文件 @
9f9058ac
...
@@ -99,3 +99,12 @@ value_printer
...
@@ -99,3 +99,12 @@ value_printer
.. automodule:: paddle.v2.evaluator
.. automodule:: paddle.v2.evaluator
:members: value_printer
:members: value_printer
:noindex:
:noindex:
Detection
=====
detection_map
-------------
.. automodule:: paddle.v2.evaluator
:members: detection_map
:noindex:
go/cmd/pserver/pserver.go
浏览文件 @
9f9058ac
...
@@ -5,18 +5,35 @@ import (
...
@@ -5,18 +5,35 @@ import (
"net/http"
"net/http"
"net/rpc"
"net/rpc"
"strconv"
"strconv"
"time"
"github.com/namsral/flag"
"github.com/namsral/flag"
"github.com/PaddlePaddle/Paddle/go/pserver"
"github.com/PaddlePaddle/Paddle/go/pserver"
log
"github.com/sirupsen/logrus"
)
)
func
main
()
{
func
main
()
{
port
:=
flag
.
Int
(
"port"
,
0
,
"port of the pserver"
)
port
:=
flag
.
Int
(
"port"
,
0
,
"port of the pserver"
)
etcdEndpoint
:=
flag
.
String
(
"etcd-endpoint"
,
"http://127.0.0.1:2379"
,
"comma separated endpoint string for pserver to connect to etcd"
)
etcdTimeout
:=
flag
.
Int
(
"etcd-timeout"
,
5
,
"timeout for etcd calls"
)
logLevel
:=
flag
.
String
(
"log-level"
,
"info"
,
"log level, possible values: debug, info, warning, error, fatal, panic"
)
flag
.
Parse
()
flag
.
Parse
()
s
:=
pserver
.
NewService
()
level
,
err
:=
log
.
ParseLevel
(
*
logLevel
)
err
:=
rpc
.
Register
(
s
)
if
err
!=
nil
{
panic
(
err
)
}
log
.
SetLevel
(
level
)
timeout
:=
time
.
Second
*
time
.
Duration
((
*
etcdTimeout
))
s
,
err
:=
pserver
.
NewService
(
*
etcdEndpoint
,
timeout
)
if
err
!=
nil
{
panic
(
err
)
}
err
=
rpc
.
Register
(
s
)
if
err
!=
nil
{
if
err
!=
nil
{
panic
(
err
)
panic
(
err
)
}
}
...
@@ -27,7 +44,9 @@ func main() {
...
@@ -27,7 +44,9 @@ func main() {
panic
(
err
)
panic
(
err
)
}
}
log
.
Infof
(
"start pserver at port %d"
,
*
port
)
err
=
http
.
Serve
(
l
,
nil
)
err
=
http
.
Serve
(
l
,
nil
)
if
err
!=
nil
{
if
err
!=
nil
{
panic
(
err
)
panic
(
err
)
}
}
...
...
go/pserver/client_test.go
浏览文件 @
9f9058ac
...
@@ -7,6 +7,7 @@ import (
...
@@ -7,6 +7,7 @@ import (
"strconv"
"strconv"
"strings"
"strings"
"testing"
"testing"
"time"
"github.com/PaddlePaddle/Paddle/go/pserver"
"github.com/PaddlePaddle/Paddle/go/pserver"
)
)
...
@@ -30,9 +31,12 @@ func init() {
...
@@ -30,9 +31,12 @@ func init() {
port
[
i
]
=
p
port
[
i
]
=
p
go
func
(
l
net
.
Listener
)
{
go
func
(
l
net
.
Listener
)
{
s
:=
pserver
.
NewService
()
s
,
err
:=
pserver
.
NewService
(
""
,
time
.
Second
*
5
)
if
err
!=
nil
{
panic
(
err
)
}
server
:=
rpc
.
NewServer
()
server
:=
rpc
.
NewServer
()
err
:
=
server
.
Register
(
s
)
err
=
server
.
Register
(
s
)
if
err
!=
nil
{
if
err
!=
nil
{
panic
(
err
)
panic
(
err
)
}
}
...
...
go/pserver/service.go
浏览文件 @
9f9058ac
package
pserver
package
pserver
import
(
import
(
"context"
"errors"
"errors"
"fmt"
"fmt"
"strconv"
"strings"
"sync"
"sync"
"time"
"github.com/PaddlePaddle/Paddle/go/utils/networkhelper"
"github.com/coreos/etcd/clientv3"
"github.com/coreos/etcd/clientv3/concurrency"
log
"github.com/sirupsen/logrus"
)
)
// ElementType is the type of elements of a Parameter.
// ElementType is the type of elements of a Parameter.
...
@@ -24,6 +33,9 @@ const (
...
@@ -24,6 +33,9 @@ const (
Float64
Float64
)
)
// PsDesired is etcd path for store desired pserver count
const
PsDesired
=
"/ps_desired"
// Parameter is a piece of data to sync with the parameter server.
// Parameter is a piece of data to sync with the parameter server.
type
Parameter
struct
{
type
Parameter
struct
{
Name
string
Name
string
...
@@ -47,14 +59,121 @@ type Service struct {
...
@@ -47,14 +59,121 @@ type Service struct {
mu
sync
.
Mutex
mu
sync
.
Mutex
opt
*
optimizer
opt
*
optimizer
paramMap
map
[
string
]
Parameter
paramMap
map
[
string
]
Parameter
etcdEndpoints
string
etcdClient
*
clientv3
.
Client
// etcdTimeout is also used as retry intervals.
etcdTimeout
time
.
Duration
// desired number of pservers in the job.
// assume desired will not change during one training job.
desired
int
// FIXME: ensure GetExternalIP gets the correct ip for trainers to connect.
externalIP
string
}
}
// NewService creates a new service.
// NewService creates a new service, will bypass etcd registration if no
func
NewService
()
*
Service
{
// endpoints specified.
func
NewService
(
endpoints
string
,
timeout
time
.
Duration
)
(
*
Service
,
error
)
{
s
:=
&
Service
{
opt
:
newOptimizer
(
sgd
,
0.005
)}
s
:=
&
Service
{
opt
:
newOptimizer
(
sgd
,
0.005
)}
s
.
paramMap
=
make
(
map
[
string
]
Parameter
)
s
.
paramMap
=
make
(
map
[
string
]
Parameter
)
s
.
initialized
=
make
(
chan
struct
{})
s
.
initialized
=
make
(
chan
struct
{})
return
s
s
.
etcdEndpoints
=
endpoints
s
.
etcdTimeout
=
timeout
var
err
error
s
.
externalIP
,
err
=
networkhelper
.
GetExternalIP
()
if
err
!=
nil
{
return
nil
,
err
}
if
endpoints
!=
""
{
// initialize connection to etcd, try
ep
:=
strings
.
Split
(
s
.
etcdEndpoints
,
","
)
for
{
cli
,
err
:=
clientv3
.
New
(
clientv3
.
Config
{
Endpoints
:
ep
,
DialTimeout
:
s
.
etcdTimeout
,
})
if
err
!=
nil
{
log
.
Errorf
(
"connect to etcd error: %v"
,
err
)
time
.
Sleep
(
s
.
etcdTimeout
)
continue
}
s
.
etcdClient
=
cli
log
.
Debugf
(
"inited client to %s"
,
s
.
etcdEndpoints
)
break
}
// wait and set s.desired init value
for
{
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
time
.
Second
)
resp
,
err
:=
s
.
etcdClient
.
Get
(
ctx
,
PsDesired
)
cancel
()
if
err
!=
nil
{
log
.
Errorf
(
"getting %s error: %v"
,
PsDesired
,
err
)
time
.
Sleep
(
s
.
etcdTimeout
)
continue
}
if
len
(
resp
.
Kvs
)
!=
0
{
s
.
desired
,
err
=
strconv
.
Atoi
(
string
(
resp
.
Kvs
[
0
]
.
Value
))
if
err
!=
nil
{
log
.
Errorf
(
"value of %s invalid %v
\n
"
,
PsDesired
,
err
)
time
.
Sleep
(
s
.
etcdTimeout
)
// NOTE: wait util ps_desired value change
continue
}
break
}
}
// try register pserver node on etcd
for
{
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
time
.
Second
)
_
,
err
:=
s
.
registerPserverEtcd
(
ctx
)
cancel
()
if
err
!=
nil
{
log
.
Warn
(
err
)
time
.
Sleep
(
s
.
etcdTimeout
)
continue
}
break
}
}
// if endpoints != ""
// Bypass etcd registration if no endpoints specified
return
s
,
nil
}
// registerPserverEtcd registers pserver node on etcd using transaction.
func
(
s
*
Service
)
registerPserverEtcd
(
ctx
context
.
Context
)
(
*
clientv3
.
TxnResponse
,
error
)
{
return
concurrency
.
NewSTM
(
s
.
etcdClient
,
func
(
c
concurrency
.
STM
)
error
{
registered
:=
false
for
i
:=
0
;
i
<
s
.
desired
;
i
++
{
psKey
:=
"/ps/"
+
strconv
.
Itoa
(
i
)
log
.
Debugf
(
"checking %s"
,
psKey
)
ps
:=
c
.
Get
(
psKey
)
log
.
Debugf
(
"got value (%s) for key: %s"
,
ps
,
psKey
)
if
ps
==
""
{
resp
,
err
:=
s
.
etcdClient
.
Grant
(
context
.
TODO
(),
5
)
if
err
!=
nil
{
log
.
Fatal
(
err
)
}
// find the first id and write info
c
.
Put
(
psKey
,
s
.
externalIP
,
clientv3
.
WithLease
(
resp
.
ID
))
log
.
Debugf
(
"set pserver node %s with value %s"
,
psKey
,
s
.
externalIP
)
_
,
kaerr
:=
s
.
etcdClient
.
KeepAlive
(
context
.
TODO
(),
resp
.
ID
)
if
kaerr
!=
nil
{
log
.
Errorf
(
"keepalive etcd node error: %v"
,
kaerr
)
return
kaerr
}
log
.
Debug
(
"register finished"
)
registered
=
true
break
}
}
if
registered
==
true
{
return
nil
}
return
errors
.
New
(
"not registerd, may due to already have enough pservers"
)
},
concurrency
.
WithAbortContext
(
ctx
),
concurrency
.
WithIsolation
(
concurrency
.
RepeatableReads
))
}
}
// InitParam initializes a parameter.
// InitParam initializes a parameter.
...
...
go/pserver/service_test.go
浏览文件 @
9f9058ac
...
@@ -10,12 +10,15 @@ import (
...
@@ -10,12 +10,15 @@ import (
)
)
func
TestFull
(
t
*
testing
.
T
)
{
func
TestFull
(
t
*
testing
.
T
)
{
s
:=
pserver
.
NewService
()
s
,
err
:=
pserver
.
NewService
(
""
,
time
.
Second
*
5
)
if
err
!=
nil
{
t
.
Error
(
err
)
}
var
p
pserver
.
Parameter
var
p
pserver
.
Parameter
p
.
Name
=
"param_a"
p
.
Name
=
"param_a"
p
.
Content
=
[]
byte
{
1
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
3
,
0
,
0
,
0
}
p
.
Content
=
[]
byte
{
1
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
3
,
0
,
0
,
0
}
p
.
ElementType
=
pserver
.
Int32
p
.
ElementType
=
pserver
.
Int32
err
:
=
s
.
InitParam
(
pserver
.
ParameterWithConfig
{
Param
:
p
,
Config
:
nil
},
nil
)
err
=
s
.
InitParam
(
pserver
.
ParameterWithConfig
{
Param
:
p
,
Config
:
nil
},
nil
)
if
err
!=
nil
{
if
err
!=
nil
{
t
.
FailNow
()
t
.
FailNow
()
}
}
...
@@ -72,8 +75,11 @@ func TestFull(t *testing.T) {
...
@@ -72,8 +75,11 @@ func TestFull(t *testing.T) {
}
}
func
TestMultipleInit
(
t
*
testing
.
T
)
{
func
TestMultipleInit
(
t
*
testing
.
T
)
{
s
:=
pserver
.
NewService
()
s
,
err
:=
pserver
.
NewService
(
""
,
time
.
Second
*
5
)
err
:=
s
.
FinishInitParams
(
0
,
nil
)
if
err
!=
nil
{
t
.
Error
(
err
)
}
err
=
s
.
FinishInitParams
(
0
,
nil
)
if
err
!=
nil
{
if
err
!=
nil
{
t
.
FailNow
()
t
.
FailNow
()
}
}
...
@@ -85,15 +91,18 @@ func TestMultipleInit(t *testing.T) {
...
@@ -85,15 +91,18 @@ func TestMultipleInit(t *testing.T) {
}
}
func
TestUninitialized
(
t
*
testing
.
T
)
{
func
TestUninitialized
(
t
*
testing
.
T
)
{
s
:=
pserver
.
NewService
(
)
s
,
err
:=
pserver
.
NewService
(
""
,
time
.
Second
*
5
)
err
:
=
s
.
SendGrad
(
pserver
.
Gradient
{},
nil
)
err
=
s
.
SendGrad
(
pserver
.
Gradient
{},
nil
)
if
err
.
Error
()
!=
pserver
.
Uninitialized
{
if
err
.
Error
()
!=
pserver
.
Uninitialized
{
t
.
FailNow
()
t
.
FailNow
()
}
}
}
}
func
TestBlockUntilInitialized
(
t
*
testing
.
T
)
{
func
TestBlockUntilInitialized
(
t
*
testing
.
T
)
{
s
:=
pserver
.
NewService
()
s
,
err
:=
pserver
.
NewService
(
""
,
time
.
Second
*
5
)
if
err
!=
nil
{
t
.
Error
(
err
)
}
ch
:=
make
(
chan
struct
{},
2
)
ch
:=
make
(
chan
struct
{},
2
)
errCh
:=
make
(
chan
error
,
2
)
errCh
:=
make
(
chan
error
,
2
)
var
wg
sync
.
WaitGroup
var
wg
sync
.
WaitGroup
...
@@ -133,7 +142,7 @@ func TestBlockUntilInitialized(t *testing.T) {
...
@@ -133,7 +142,7 @@ func TestBlockUntilInitialized(t *testing.T) {
p
.
Name
=
"param_a"
p
.
Name
=
"param_a"
p
.
Content
=
[]
byte
{
1
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
3
,
0
,
0
,
0
}
p
.
Content
=
[]
byte
{
1
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
3
,
0
,
0
,
0
}
p
.
ElementType
=
pserver
.
Int32
p
.
ElementType
=
pserver
.
Int32
err
:
=
s
.
InitParam
(
pserver
.
ParameterWithConfig
{
Param
:
p
,
Config
:
nil
},
nil
)
err
=
s
.
InitParam
(
pserver
.
ParameterWithConfig
{
Param
:
p
,
Config
:
nil
},
nil
)
if
err
!=
nil
{
if
err
!=
nil
{
t
.
FailNow
()
t
.
FailNow
()
}
}
...
...
go/utils/networkhelper/helper.go
0 → 100644
浏览文件 @
9f9058ac
package
networkhelper
import
(
"errors"
"net"
)
// GetExternalIP returns the ip address of local network interface, not the
// loopback device.
func
GetExternalIP
()
(
string
,
error
)
{
ifaces
,
err
:=
net
.
Interfaces
()
if
err
!=
nil
{
return
""
,
err
}
for
_
,
iface
:=
range
ifaces
{
if
iface
.
Flags
&
net
.
FlagUp
==
0
{
continue
// interface down
}
if
iface
.
Flags
&
net
.
FlagLoopback
!=
0
{
continue
// loopback interface
}
addrs
,
err
:=
iface
.
Addrs
()
if
err
!=
nil
{
return
""
,
err
}
for
_
,
addr
:=
range
addrs
{
var
ip
net
.
IP
switch
v
:=
addr
.
(
type
)
{
case
*
net
.
IPNet
:
ip
=
v
.
IP
case
*
net
.
IPAddr
:
ip
=
v
.
IP
}
if
ip
==
nil
||
ip
.
IsLoopback
()
{
continue
}
ip
=
ip
.
To4
()
if
ip
==
nil
{
continue
// not an ipv4 address
}
return
ip
.
String
(),
nil
}
}
return
""
,
errors
.
New
(
"are you connected to the network?"
)
}
go/utils/networkhelper/helper_test.go
0 → 100644
浏览文件 @
9f9058ac
package
networkhelper
import
"testing"
func
TestGetIP
(
t
*
testing
.
T
)
{
_
,
err
:=
GetExternalIP
()
if
err
!=
nil
{
t
.
Errorf
(
"GetExternalIP returns error : %v
\n
"
,
err
)
}
}
paddle/gserver/evaluators/DetectionMAPEvaluator.cpp
0 → 100644
浏览文件 @
9f9058ac
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "Evaluator.h"
#include "paddle/gserver/layers/DetectionUtil.h"
using
std
::
map
;
using
std
::
vector
;
using
std
::
pair
;
using
std
::
make_pair
;
namespace
paddle
{
/**
* @brief detection map Evaluator
*
* The config file api is detection_map_evaluator.
*/
class
DetectionMAPEvaluator
:
public
Evaluator
{
public:
DetectionMAPEvaluator
()
:
evaluateDifficult_
(
false
),
cpuOutput_
(
nullptr
),
cpuLabel_
(
nullptr
)
{}
virtual
void
start
()
{
Evaluator
::
start
();
allTruePos_
.
clear
();
allFalsePos_
.
clear
();
numPos_
.
clear
();
}
virtual
real
evalImp
(
std
::
vector
<
Argument
>&
arguments
)
{
overlapThreshold_
=
config_
.
overlap_threshold
();
backgroundId_
=
config_
.
background_id
();
evaluateDifficult_
=
config_
.
evaluate_difficult
();
apType_
=
config_
.
ap_type
();
MatrixPtr
detectTmpValue
=
arguments
[
0
].
value
;
Matrix
::
resizeOrCreate
(
cpuOutput_
,
detectTmpValue
->
getHeight
(),
detectTmpValue
->
getWidth
(),
false
,
false
);
MatrixPtr
labelTmpValue
=
arguments
[
1
].
value
;
Matrix
::
resizeOrCreate
(
cpuLabel_
,
labelTmpValue
->
getHeight
(),
labelTmpValue
->
getWidth
(),
false
,
false
);
cpuOutput_
->
copyFrom
(
*
detectTmpValue
);
cpuLabel_
->
copyFrom
(
*
labelTmpValue
);
Argument
label
=
arguments
[
1
];
const
int
*
labelIndex
=
label
.
sequenceStartPositions
->
getData
(
false
);
size_t
batchSize
=
label
.
getNumSequences
();
vector
<
map
<
size_t
,
vector
<
NormalizedBBox
>>>
allGTBBoxes
;
vector
<
map
<
size_t
,
vector
<
pair
<
real
,
NormalizedBBox
>>>>
allDetectBBoxes
;
for
(
size_t
n
=
0
;
n
<
batchSize
;
++
n
)
{
map
<
size_t
,
vector
<
NormalizedBBox
>>
bboxes
;
for
(
int
i
=
labelIndex
[
n
];
i
<
labelIndex
[
n
+
1
];
++
i
)
{
vector
<
NormalizedBBox
>
bbox
;
getBBoxFromLabelData
(
cpuLabel_
->
getData
()
+
i
*
6
,
1
,
bbox
);
int
c
=
cpuLabel_
->
getData
()[
i
*
6
];
bboxes
[
c
].
push_back
(
bbox
[
0
]);
}
allGTBBoxes
.
push_back
(
bboxes
);
}
size_t
n
=
0
;
const
real
*
cpuOutputData
=
cpuOutput_
->
getData
();
for
(
size_t
imgId
=
0
;
imgId
<
batchSize
;
++
imgId
)
{
map
<
size_t
,
vector
<
pair
<
real
,
NormalizedBBox
>>>
bboxes
;
size_t
curImgId
=
static_cast
<
size_t
>
((
cpuOutputData
+
n
*
7
)[
0
]);
while
(
curImgId
==
imgId
&&
n
<
cpuOutput_
->
getHeight
())
{
vector
<
real
>
label
;
vector
<
real
>
score
;
vector
<
NormalizedBBox
>
bbox
;
getBBoxFromDetectData
(
cpuOutputData
+
n
*
7
,
1
,
label
,
score
,
bbox
);
bboxes
[
label
[
0
]].
push_back
(
make_pair
(
score
[
0
],
bbox
[
0
]));
++
n
;
curImgId
=
static_cast
<
size_t
>
((
cpuOutputData
+
n
*
7
)[
0
]);
}
allDetectBBoxes
.
push_back
(
bboxes
);
}
for
(
size_t
n
=
0
;
n
<
batchSize
;
++
n
)
{
for
(
map
<
size_t
,
vector
<
NormalizedBBox
>>::
iterator
it
=
allGTBBoxes
[
n
].
begin
();
it
!=
allGTBBoxes
[
n
].
end
();
++
it
)
{
size_t
count
=
0
;
if
(
evaluateDifficult_
)
{
count
=
it
->
second
.
size
();
}
else
{
for
(
size_t
i
=
0
;
i
<
it
->
second
.
size
();
++
i
)
if
(
!
(
it
->
second
[
i
].
isDifficult
))
++
count
;
}
if
(
numPos_
.
find
(
it
->
first
)
==
numPos_
.
end
()
&&
count
!=
0
)
{
numPos_
[
it
->
first
]
=
count
;
}
else
{
numPos_
[
it
->
first
]
+=
count
;
}
}
}
// calcTFPos
calcTFPos
(
batchSize
,
allGTBBoxes
,
allDetectBBoxes
);
return
0
;
}
virtual
void
printStats
(
std
::
ostream
&
os
)
const
{
real
mAP
=
calcMAP
();
os
<<
"Detection mAP="
<<
mAP
;
}
virtual
void
distributeEval
(
ParameterClient2
*
client
)
{
LOG
(
FATAL
)
<<
"Distribute detection evaluation not implemented."
;
}
protected:
void
calcTFPos
(
const
size_t
batchSize
,
const
vector
<
map
<
size_t
,
vector
<
NormalizedBBox
>>>&
allGTBBoxes
,
const
vector
<
map
<
size_t
,
vector
<
pair
<
real
,
NormalizedBBox
>>>>&
allDetectBBoxes
)
{
for
(
size_t
n
=
0
;
n
<
allDetectBBoxes
.
size
();
++
n
)
{
if
(
allGTBBoxes
[
n
].
size
()
==
0
)
{
for
(
map
<
size_t
,
vector
<
pair
<
real
,
NormalizedBBox
>>>::
const_iterator
it
=
allDetectBBoxes
[
n
].
begin
();
it
!=
allDetectBBoxes
[
n
].
end
();
++
it
)
{
size_t
label
=
it
->
first
;
for
(
size_t
i
=
0
;
i
<
it
->
second
.
size
();
++
i
)
{
allTruePos_
[
label
].
push_back
(
make_pair
(
it
->
second
[
i
].
first
,
0
));
allFalsePos_
[
label
].
push_back
(
make_pair
(
it
->
second
[
i
].
first
,
1
));
}
}
}
else
{
for
(
map
<
size_t
,
vector
<
pair
<
real
,
NormalizedBBox
>>>::
const_iterator
it
=
allDetectBBoxes
[
n
].
begin
();
it
!=
allDetectBBoxes
[
n
].
end
();
++
it
)
{
size_t
label
=
it
->
first
;
vector
<
pair
<
real
,
NormalizedBBox
>>
predBBoxes
=
it
->
second
;
if
(
allGTBBoxes
[
n
].
find
(
label
)
==
allGTBBoxes
[
n
].
end
())
{
for
(
size_t
i
=
0
;
i
<
predBBoxes
.
size
();
++
i
)
{
allTruePos_
[
label
].
push_back
(
make_pair
(
predBBoxes
[
i
].
first
,
0
));
allFalsePos_
[
label
].
push_back
(
make_pair
(
predBBoxes
[
i
].
first
,
1
));
}
}
else
{
vector
<
NormalizedBBox
>
gtBBoxes
=
allGTBBoxes
[
n
].
find
(
label
)
->
second
;
vector
<
bool
>
visited
(
gtBBoxes
.
size
(),
false
);
// Sort detections in descend order based on scores
std
::
sort
(
predBBoxes
.
begin
(),
predBBoxes
.
end
(),
sortScorePairDescend
<
NormalizedBBox
>
);
for
(
size_t
i
=
0
;
i
<
predBBoxes
.
size
();
++
i
)
{
real
maxOverlap
=
-
1.0
;
size_t
maxIdx
=
0
;
for
(
size_t
j
=
0
;
j
<
gtBBoxes
.
size
();
++
j
)
{
real
overlap
=
jaccardOverlap
(
predBBoxes
[
i
].
second
,
gtBBoxes
[
j
]);
if
(
overlap
>
maxOverlap
)
{
maxOverlap
=
overlap
;
maxIdx
=
j
;
}
}
if
(
maxOverlap
>
overlapThreshold_
)
{
if
(
evaluateDifficult_
||
(
!
evaluateDifficult_
&&
!
gtBBoxes
[
maxIdx
].
isDifficult
))
{
if
(
!
visited
[
maxIdx
])
{
allTruePos_
[
label
].
push_back
(
make_pair
(
predBBoxes
[
i
].
first
,
1
));
allFalsePos_
[
label
].
push_back
(
make_pair
(
predBBoxes
[
i
].
first
,
0
));
visited
[
maxIdx
]
=
true
;
}
else
{
allTruePos_
[
label
].
push_back
(
make_pair
(
predBBoxes
[
i
].
first
,
0
));
allFalsePos_
[
label
].
push_back
(
make_pair
(
predBBoxes
[
i
].
first
,
1
));
}
}
}
else
{
allTruePos_
[
label
].
push_back
(
make_pair
(
predBBoxes
[
i
].
first
,
0
));
allFalsePos_
[
label
].
push_back
(
make_pair
(
predBBoxes
[
i
].
first
,
1
));
}
}
}
}
}
}
}
real
calcMAP
()
const
{
real
mAP
=
0.0
;
size_t
count
=
0
;
for
(
map
<
size_t
,
size_t
>::
const_iterator
it
=
numPos_
.
begin
();
it
!=
numPos_
.
end
();
++
it
)
{
size_t
label
=
it
->
first
;
size_t
labelNumPos
=
it
->
second
;
if
(
labelNumPos
==
0
||
allTruePos_
.
find
(
label
)
==
allTruePos_
.
end
())
continue
;
vector
<
pair
<
real
,
size_t
>>
labelTruePos
=
allTruePos_
.
find
(
label
)
->
second
;
vector
<
pair
<
real
,
size_t
>>
labelFalsePos
=
allFalsePos_
.
find
(
label
)
->
second
;
// Compute average precision.
vector
<
size_t
>
tpCumSum
;
getAccumulation
(
labelTruePos
,
&
tpCumSum
);
vector
<
size_t
>
fpCumSum
;
getAccumulation
(
labelFalsePos
,
&
fpCumSum
);
std
::
vector
<
real
>
precision
,
recall
;
size_t
num
=
tpCumSum
.
size
();
// Compute Precision.
for
(
size_t
i
=
0
;
i
<
num
;
++
i
)
{
CHECK_LE
(
tpCumSum
[
i
],
labelNumPos
);
precision
.
push_back
(
static_cast
<
real
>
(
tpCumSum
[
i
])
/
static_cast
<
real
>
(
tpCumSum
[
i
]
+
fpCumSum
[
i
]));
recall
.
push_back
(
static_cast
<
real
>
(
tpCumSum
[
i
])
/
labelNumPos
);
}
// VOC2007 style
if
(
apType_
==
"11point"
)
{
vector
<
real
>
maxPrecisions
(
11
,
0.0
);
int
startIdx
=
num
-
1
;
for
(
int
j
=
10
;
j
>=
0
;
--
j
)
for
(
int
i
=
startIdx
;
i
>=
0
;
--
i
)
{
if
(
recall
[
i
]
<
j
/
10.
)
{
startIdx
=
i
;
if
(
j
>
0
)
maxPrecisions
[
j
-
1
]
=
maxPrecisions
[
j
];
break
;
}
else
{
if
(
maxPrecisions
[
j
]
<
precision
[
i
])
maxPrecisions
[
j
]
=
precision
[
i
];
}
}
for
(
int
j
=
10
;
j
>=
0
;
--
j
)
mAP
+=
maxPrecisions
[
j
]
/
11
;
++
count
;
}
else
if
(
apType_
==
"Integral"
)
{
// Nature integral
real
averagePrecisions
=
0.
;
real
prevRecall
=
0.
;
for
(
size_t
i
=
0
;
i
<
num
;
++
i
)
{
if
(
fabs
(
recall
[
i
]
-
prevRecall
)
>
1e-6
)
averagePrecisions
+=
precision
[
i
]
*
fabs
(
recall
[
i
]
-
prevRecall
);
prevRecall
=
recall
[
i
];
}
mAP
+=
averagePrecisions
;
++
count
;
}
else
{
LOG
(
FATAL
)
<<
"Unkown ap version: "
<<
apType_
;
}
}
if
(
count
!=
0
)
mAP
/=
count
;
return
mAP
*
100
;
}
void
getAccumulation
(
vector
<
pair
<
real
,
size_t
>>
inPairs
,
vector
<
size_t
>*
accuVec
)
const
{
std
::
stable_sort
(
inPairs
.
begin
(),
inPairs
.
end
(),
sortScorePairDescend
<
size_t
>
);
accuVec
->
clear
();
size_t
sum
=
0
;
for
(
size_t
i
=
0
;
i
<
inPairs
.
size
();
++
i
)
{
sum
+=
inPairs
[
i
].
second
;
accuVec
->
push_back
(
sum
);
}
}
std
::
string
getTypeImpl
()
const
{
return
"detection_map"
;
}
real
getValueImpl
()
const
{
return
calcMAP
();
}
private:
real
overlapThreshold_
;
// overlap threshold when determining whether matched
bool
evaluateDifficult_
;
// whether evaluate difficult ground truth
size_t
backgroundId_
;
// class index of background
std
::
string
apType_
;
// how to calculate mAP (Integral or 11point)
MatrixPtr
cpuOutput_
;
MatrixPtr
cpuLabel_
;
map
<
size_t
,
size_t
>
numPos_
;
// counts of true objects each classification
map
<
size_t
,
vector
<
pair
<
real
,
size_t
>>>
allTruePos_
;
// true positive prediction
map
<
size_t
,
vector
<
pair
<
real
,
size_t
>>>
allFalsePos_
;
// false positive prediction
};
REGISTER_EVALUATOR
(
detection_map
,
DetectionMAPEvaluator
);
}
// namespace paddle
paddle/gserver/tests/test_Evaluator.cpp
浏览文件 @
9f9058ac
...
@@ -138,6 +138,23 @@ void testEvaluatorAll(TestConfig testConf,
...
@@ -138,6 +138,23 @@ void testEvaluatorAll(TestConfig testConf,
testEvaluator
(
testConf
,
testEvaluatorName
,
batchSize
,
false
);
testEvaluator
(
testConf
,
testEvaluatorName
,
batchSize
,
false
);
}
}
TEST
(
Evaluator
,
detection_map
)
{
TestConfig
config
;
config
.
evaluatorConfig
.
set_type
(
"detection_map"
);
config
.
evaluatorConfig
.
set_overlap_threshold
(
0.5
);
config
.
evaluatorConfig
.
set_background_id
(
0
);
config
.
evaluatorConfig
.
set_ap_type
(
"Integral"
);
config
.
evaluatorConfig
.
set_evaluate_difficult
(
0
);
config
.
inputDefs
.
push_back
({
INPUT_DATA
,
"output"
,
7
});
config
.
inputDefs
.
push_back
({
INPUT_SEQUENCE_DATA
,
"label"
,
6
});
config
.
evaluatorConfig
.
set_evaluate_difficult
(
false
);
testEvaluatorAll
(
config
,
"detection_map"
,
100
);
config
.
evaluatorConfig
.
set_evaluate_difficult
(
true
);
testEvaluatorAll
(
config
,
"detection_map"
,
100
);
}
TEST
(
Evaluator
,
classification_error
)
{
TEST
(
Evaluator
,
classification_error
)
{
TestConfig
config
;
TestConfig
config
;
config
.
evaluatorConfig
.
set_type
(
"classification_error"
);
config
.
evaluatorConfig
.
set_type
(
"classification_error"
);
...
...
paddle/parameter/ParameterUpdaterHook.cpp
浏览文件 @
9f9058ac
...
@@ -14,11 +14,13 @@ limitations under the License. */
...
@@ -14,11 +14,13 @@ limitations under the License. */
#include "ParameterUpdaterHook.h"
#include "ParameterUpdaterHook.h"
#include <algorithm>
#include <atomic>
#include <atomic>
#include <fstream>
#include <fstream>
#include <mutex>
#include <mutex>
#include <thread>
#include <thread>
#include <unordered_map>
#include <unordered_map>
#include <vector>
#include "paddle/math/Vector.h"
#include "paddle/math/Vector.h"
#include "paddle/parameter/Parameter.h"
#include "paddle/parameter/Parameter.h"
...
@@ -29,106 +31,76 @@ namespace paddle {
...
@@ -29,106 +31,76 @@ namespace paddle {
/**
/**
* The static pruning hook
* The static pruning hook
*
*
Static means user specify a sparsity_ratio before training started, and the
*
Static means user load a mask map before training started. This map will
*
network will prune the parameters based on the sparsity_ratio. More details
*
define which link/weight between neural is disabled
.
*
can be found https://arxiv.org/pdf/1506.02626.pdf
.
*/
*/
class
StaticPruningHook
:
public
IParameterUpdaterHook
{
class
StaticPruningHook
:
public
IParameterUpdaterHook
{
public:
public:
/**
explicit
StaticPruningHook
(
const
ParameterUpdaterHookConfig
&
hookConfig
)
* The Mask Map Header.
:
initCount_
(
0
)
{
* The map file started with this header.
sparsityRatio_
=
hookConfig
.
sparsity_ratio
();
*
* In Version 0, reset file will be:
* contains header.size bit, each bit means such weight is enabled or not.
* if bit is 1, then such weight is enabled.
* at end, the file will round to byte, and the low bits of end byte will be
* filled by zero.
*
*/
struct
StaticMaskHeader
{
uint32_t
version
;
size_t
size
;
}
__attribute__
((
__packed__
));
explicit
StaticPruningHook
(
const
std
::
string
&
mask_filename
)
:
initCount_
(
0
)
{
bool
ok
=
this
->
loadMaskFile
(
mask_filename
);
if
(
!
ok
)
{
LOG
(
WARNING
)
<<
"Fail to load mask file "
<<
mask_filename
<<
" in current directory, searching in init_model_path"
;
std
::
string
combineMaskFilename
=
path
::
join
(
FLAGS_init_model_path
,
mask_filename
);
CHECK
(
this
->
loadMaskFile
(
combineMaskFilename
))
<<
"Cannot load "
<<
mask_filename
<<
" in ./"
<<
mask_filename
<<
" and "
<<
combineMaskFilename
;
}
}
VLOG
(
3
)
<<
mask_filename
<<
" mask size = "
<<
this
->
mask_
.
size
();
static
bool
sortPairAscend
(
const
std
::
pair
<
real
,
size_t
>
&
pair1
,
const
std
::
pair
<
real
,
size_t
>
&
pair2
)
{
return
pair1
.
first
>
pair2
.
first
;
}
}
void
update
(
Parameter
*
para
)
{
void
update
(
Parameter
*
para
)
{
updateThreadChecker_
.
check
();
updateThreadChecker_
.
check
();
auto
&
vec
=
para
->
getBuf
(
PARAMETER_GRADIENT
);
auto
&
vec
=
para
->
getBuf
(
PARAMETER_GRADIENT
);
if
(
vec
)
{
if
(
vec
)
{
vec
->
dotMul
(
*
maskVec_
);
vec
->
dotMul
(
*
maskVec_
);
}
}
}
}
void
init
(
Parameter
*
para
)
{
void
generateMask
(
Parameter
*
para
)
{
size_t
initCount
=
this
->
initCount_
.
fetch_add
(
1
);
VectorPtr
maskTemp
=
Vector
::
create
(
para
->
getSize
(),
false
);
CHECK_EQ
(
initCount
,
0UL
)
<<
"Currently the StaticPruningHook must invoke "
maskTemp
->
zeroMem
();
"in same ParamterUpdater"
;
real
*
maskTempData
=
maskTemp
->
getData
();
VLOG
(
3
)
<<
"Initialize Parameter "
<<
para
;
size_t
nonZeroNum
=
para
->
getSize
()
*
(
1
-
sparsityRatio_
);
SetDevice
device
(
para
->
getDeviceId
());
auto
maskVec
=
Vector
::
create
(
this
->
mask_
.
size
(),
false
);
VectorPtr
paraVec
=
para
->
getBuf
(
PARAMETER_VALUE
);
{
// Initialize maskVec with float mask vector
VectorPtr
paraCpuCopy
=
Vector
::
create
(
para
->
getSize
(),
false
);
real
*
dataPtr
=
maskVec
->
getData
();
size_t
i
=
0
;
paraCpuCopy
->
copyFrom
(
*
paraVec
);
for
(
bool
m
:
mask_
)
{
std
::
vector
<
std
::
pair
<
real
,
size_t
>>
param
;
dataPtr
[
i
++
]
=
m
?
1.0
:
0.0
;
}
for
(
size_t
i
=
0
;
i
<
para
->
getSize
();
i
++
)
}
param
.
push_back
(
std
::
make_pair
(
fabs
(
paraCpuCopy
->
getData
()[
i
]),
i
));
std
::
partial_sort
(
param
.
begin
(),
param
.
begin
()
+
nonZeroNum
,
param
.
end
(),
sortPairAscend
);
for
(
size_t
i
=
0
;
i
<
nonZeroNum
;
i
++
)
maskTempData
[
param
[
i
].
second
]
=
1.0
;
// Currently just use a mask vector for hack.
// Currently just use a mask vector for hack.
// @TODO(yuyang18): Implemented the mask operation in vector.
if
(
para
->
useGpu
())
{
if
(
para
->
useGpu
())
{
maskVec_
=
Vector
::
create
(
this
->
mask_
.
s
ize
(),
para
->
useGpu
());
maskVec_
=
Vector
::
create
(
para
->
getS
ize
(),
para
->
useGpu
());
maskVec_
->
copyFrom
(
*
mask
Vec
);
maskVec_
->
copyFrom
(
*
mask
Temp
);
}
else
{
}
else
{
maskVec_
=
mask
Vec
;
maskVec_
=
mask
Temp
;
}
}
auto
&
vec
=
para
->
getBuf
(
PARAMETER_VALUE
);
vec
->
dotMul
(
*
maskVec_
);
}
}
private:
void
init
(
Parameter
*
para
)
{
bool
loadMaskFile
(
const
std
::
string
&
mask_filename
)
{
generateMask
(
para
);
std
::
ifstream
fin
;
size_t
initCount
=
this
->
initCount_
.
fetch_add
(
1
);
fin
.
open
(
mask_filename
);
CHECK_EQ
(
initCount
,
0UL
)
<<
"Currently the StaticPruningHook must invoke "
if
(
fin
.
is_open
())
{
"in same ParamterUpdater"
;
StaticMaskHeader
header
;
VLOG
(
3
)
<<
"Initialize Parameter "
<<
para
;
fin
.
read
(
reinterpret_cast
<
char
*>
(
&
header
),
sizeof
(
StaticMaskHeader
));
SetDevice
device
(
para
->
getDeviceId
());
CHECK_EQ
(
header
.
version
,
0UL
);
mask_
.
resize
(
header
.
size
);
auto
&
paraVec
=
para
->
getBuf
(
PARAMETER_VALUE
);
uint8_t
buf
;
paraVec
->
dotMul
(
*
maskVec_
);
for
(
size_t
i
=
0
;
i
<
header
.
size
;
++
i
,
buf
<<=
1
)
{
if
(
i
%
8
==
0
)
{
fin
.
read
(
reinterpret_cast
<
char
*>
(
&
buf
),
sizeof
(
uint8_t
));
}
mask_
[
i
]
=
buf
&
0x80
;
}
fin
.
close
();
return
true
;
}
else
{
return
false
;
}
}
}
private:
SameThreadChecker
updateThreadChecker_
;
SameThreadChecker
updateThreadChecker_
;
std
::
atomic
<
size_t
>
initCount_
;
std
::
atomic
<
size_t
>
initCount_
;
VectorPtr
maskVec_
;
VectorPtr
maskVec_
;
std
::
vector
<
bool
>
mask
_
;
real
sparsityRatio
_
;
};
};
IParameterUpdaterHook
::
IParameterUpdaterHook
()
{}
IParameterUpdaterHook
::
IParameterUpdaterHook
()
{}
...
@@ -145,7 +117,7 @@ IParameterUpdaterHook::~IParameterUpdaterHook() {}
...
@@ -145,7 +117,7 @@ IParameterUpdaterHook::~IParameterUpdaterHook() {}
*/
*/
class
StringIntPairHasher
{
class
StringIntPairHasher
{
public:
public:
size_t
operator
()(
const
std
::
pair
<
std
::
string
,
int
>
&
k
)
const
{
size_t
operator
()(
const
std
::
pair
<
std
::
string
,
int
>
&
k
)
const
{
return
intHasher_
(
strHasher_
(
k
.
first
)
+
k
.
second
);
return
intHasher_
(
strHasher_
(
k
.
first
)
+
k
.
second
);
}
}
...
@@ -162,19 +134,19 @@ static WeakKVCache<std::pair<std::string, int>,
...
@@ -162,19 +134,19 @@ static WeakKVCache<std::pair<std::string, int>,
/**
/**
* ParameterUpdaterHook actually factory method.
* ParameterUpdaterHook actually factory method.
*/
*/
static
IParameterUpdaterHook
*
createImpl
(
static
IParameterUpdaterHook
*
createImpl
(
const
ParameterUpdaterHookConfig
&
config
)
{
const
ParameterUpdaterHookConfig
&
config
)
{
auto
&
type
=
config
.
type
();
auto
&
type
=
config
.
type
();
if
(
type
==
"pruning"
)
{
if
(
type
==
"pruning"
)
{
if
(
config
.
has_purning_mask_filename
())
{
return
new
StaticPruningHook
(
config
);
return
new
StaticPruningHook
(
config
.
purning_mask_filename
());
}
}
}
LOG
(
FATAL
)
<<
"Unknown Hook type: "
<<
type
;
return
nullptr
;
return
nullptr
;
}
}
std
::
shared_ptr
<
IParameterUpdaterHook
>
IParameterUpdaterHook
::
create
(
std
::
shared_ptr
<
IParameterUpdaterHook
>
IParameterUpdaterHook
::
create
(
const
ParameterConfig
&
paramConfig
,
int
idx
)
{
const
ParameterConfig
&
paramConfig
,
int
idx
)
{
std
::
pair
<
std
::
string
,
int
>
key
=
{
paramConfig
.
name
(),
idx
};
std
::
pair
<
std
::
string
,
int
>
key
=
{
paramConfig
.
name
(),
idx
};
return
g_hookCache_
.
get
(
return
g_hookCache_
.
get
(
key
,
[
&
]
{
return
createImpl
(
paramConfig
.
update_hooks
(
idx
));
});
key
,
[
&
]
{
return
createImpl
(
paramConfig
.
update_hooks
(
idx
));
});
...
...
paddle/scripts/travis/build_and_test.sh
已删除
100755 → 0
浏览文件 @
8d476901
#!/bin/bash
source
./common.sh
NPROC
=
1
export
PYTHONPATH
=
/opt/python/2.7.12/lib/python2.7/site-packages
export
PYTHONHOME
=
/opt/python/2.7.12
export
PATH
=
/opt/python/2.7.12/bin:
${
PATH
}
cmake ..
-DCMAKE_Fortran_COMPILER
=
/usr/bin/gfortran-4.8
-DON_TRAVIS
=
ON
-DWITH_COVERAGE
=
ON
-DCOVERALLS_UPLOAD
=
ON
${
EXTRA_CMAKE_OPTS
}
NRPOC
=
`
nproc
`
make
-j
$NPROC
make coveralls
sudo
make
install
paddle/scripts/travis/
docs
.sh
→
paddle/scripts/travis/
build_doc
.sh
浏览文件 @
9f9058ac
#!/bin/bash
#!/bin/bash
set
-e
# Create the build directory for CMake.
mkdir
-p
$TRAVIS_BUILD_DIR
/build
cd
$TRAVIS_BUILD_DIR
/build
# Add set -e, cd to directory.
source
./common.sh
# Compile Documentation only.
# Compile Documentation only.
cmake ..
-DCMAKE_BUILD_TYPE
=
Debug
-D
CMAKE_Fortran_COMPILER
=
/usr/bin/gfortran-4.8
-DWITH_GPU
=
OFF
-DWITH_DOC
=
OFF
-DWITH_STYLE_CHECK
=
OFF
${
EXTRA_CMAKE_OPTS
}
cmake ..
-DCMAKE_BUILD_TYPE
=
Debug
-D
WITH_GPU
=
OFF
-DWITH_DOC
=
OFF
-DWITH_STYLE_CHECK
=
OFF
mkdir
output
mkdir
output
make
-j
`
nproc
`
make
-j
`
nproc
`
find ..
-name
'*whl'
| xargs pip
install
# install all wheels.
find ..
-name
'*whl'
| xargs pip
install
# install all wheels.
rm
-rf
*
rm
-rf
*
cmake ..
-DCMAKE_BUILD_TYPE
=
Debug
-D
CMAKE_Fortran_COMPILER
=
/usr/bin/gfortran-4.8
-DWITH_GPU
=
OFF
-DWITH_DOC
=
ON
${
EXTRA_CMAKE_OPTS
}
cmake ..
-DCMAKE_BUILD_TYPE
=
Debug
-D
WITH_GPU
=
OFF
-DWITH_DOC
=
ON
make paddle_docs paddle_docs_cn
make
-j
`
nproc
`
paddle_docs paddle_docs_cn
# check websites for broken links
# check websites for broken links
linkchecker doc/en/html/index.html
linkchecker doc/en/html/index.html
...
...
paddle/scripts/travis/
precommit
.sh
→
paddle/scripts/travis/
check_style
.sh
浏览文件 @
9f9058ac
#!/bin/bash
#!/bin/bash
function
abort
(){
function
abort
(){
echo
"Your c
ommit not fit PaddlePaddle code style
"
1>&2
echo
"Your c
hange doesn't follow PaddlePaddle's code style.
"
1>&2
echo
"Please use pre-commit
scripts to auto-format your code
"
1>&2
echo
"Please use pre-commit
to reformat your code and git push again.
"
1>&2
exit
1
exit
1
}
}
trap
'abort'
0
trap
'abort'
0
set
-e
set
-e
source
common.sh
cd
..
cd
$TRAVIS_BUILD_DIR
export
PATH
=
/usr/bin:
$PATH
export
PATH
=
/usr/bin:
$PATH
pre-commit
install
pre-commit
install
clang-format
--version
clang-format
--version
...
...
paddle/scripts/travis/common.sh
已删除
100755 → 0
浏览文件 @
8d476901
#!/bin/bash
set
-e
mkdir
-p
../../../build
cd
../../../build
mkdir
-p
$HOME
/third_party
EXTRA_CMAKE_OPTS
=
"-DTHIRD_PARTY_PATH=
${
HOME
}
/third_party"
paddle/scripts/travis/main.sh
已删除
100755 → 0
浏览文件 @
8d476901
#!/bin/bash
cd
`
dirname
$0
`
if
[
${
JOB
}
==
"BUILD_AND_TEST"
]
;
then
./build_and_test.sh
elif
[
${
JOB
}
==
"DOCS"
]
;
then
./docs.sh
elif
[
${
JOB
}
==
"PRE_COMMIT"
]
;
then
./precommit.sh
else
echo
Unknown job
${
JOB
}
exit
1
fi
proto/ModelConfig.proto
浏览文件 @
9f9058ac
...
@@ -489,6 +489,15 @@ message EvaluatorConfig {
...
@@ -489,6 +489,15 @@ message EvaluatorConfig {
// Used by ClassificationErrorEvaluator
// Used by ClassificationErrorEvaluator
// top # classification error
// top # classification error
optional
int32
top_k
=
13
[
default
=
1
];
optional
int32
top_k
=
13
[
default
=
1
];
// Used by DetectionMAPEvaluator
optional
double
overlap_threshold
=
14
[
default
=
0.5
];
optional
int32
background_id
=
15
[
default
=
0
];
optional
bool
evaluate_difficult
=
16
[
default
=
false
];
optional
string
ap_type
=
17
[
default
=
"11point"
];
}
}
message
LinkConfig
{
message
LinkConfig
{
...
...
proto/ParameterConfig.proto
浏览文件 @
9f9058ac
...
@@ -25,8 +25,10 @@ enum ParameterInitStrategy {
...
@@ -25,8 +25,10 @@ enum ParameterInitStrategy {
}
}
message
ParameterUpdaterHookConfig
{
message
ParameterUpdaterHookConfig
{
// hook type such as 'pruning'
required
string
type
=
1
;
required
string
type
=
1
;
optional
string
purning_mask_filename
=
2
;
// this represents the ratio of zero element to be set by the Parameter
optional
double
sparsity_ratio
=
2
[
default
=
0.6
];
}
}
message
ParameterConfig
{
message
ParameterConfig
{
...
...
python/paddle/trainer/config_parser.py
浏览文件 @
9f9058ac
...
@@ -1280,8 +1280,7 @@ def parse_maxout(maxout, input_layer_name, maxout_conf):
...
@@ -1280,8 +1280,7 @@ def parse_maxout(maxout, input_layer_name, maxout_conf):
# Define an evaluator
# Define an evaluator
@
config_func
@
config_func
def
Evaluator
(
def
Evaluator
(
name
,
name
,
type
,
type
,
inputs
,
inputs
,
chunk_scheme
=
None
,
chunk_scheme
=
None
,
...
@@ -1293,7 +1292,11 @@ def Evaluator(
...
@@ -1293,7 +1292,11 @@ def Evaluator(
num_results
=
None
,
num_results
=
None
,
top_k
=
None
,
top_k
=
None
,
delimited
=
None
,
delimited
=
None
,
excluded_chunk_types
=
None
,
):
excluded_chunk_types
=
None
,
overlap_threshold
=
None
,
background_id
=
None
,
evaluate_difficult
=
None
,
ap_type
=
None
):
evaluator
=
g_config
.
model_config
.
evaluators
.
add
()
evaluator
=
g_config
.
model_config
.
evaluators
.
add
()
evaluator
.
type
=
type
evaluator
.
type
=
type
evaluator
.
name
=
MakeLayerNameInSubmodel
(
name
)
evaluator
.
name
=
MakeLayerNameInSubmodel
(
name
)
...
@@ -1327,6 +1330,18 @@ def Evaluator(
...
@@ -1327,6 +1330,18 @@ def Evaluator(
if
excluded_chunk_types
:
if
excluded_chunk_types
:
evaluator
.
excluded_chunk_types
.
extend
(
excluded_chunk_types
)
evaluator
.
excluded_chunk_types
.
extend
(
excluded_chunk_types
)
if
overlap_threshold
is
not
None
:
evaluator
.
overlap_threshold
=
overlap_threshold
if
background_id
is
not
None
:
evaluator
.
background_id
=
background_id
if
evaluate_difficult
is
not
None
:
evaluator
.
evaluate_difficult
=
evaluate_difficult
if
ap_type
is
not
None
:
evaluator
.
ap_type
=
ap_type
class
LayerBase
(
object
):
class
LayerBase
(
object
):
def
__init__
(
def
__init__
(
...
@@ -3124,11 +3139,11 @@ def Layer(name, type, **xargs):
...
@@ -3124,11 +3139,11 @@ def Layer(name, type, **xargs):
@
config_func
@
config_func
def
ParameterHook
(
type
,
**
kwargs
):
def
ParameterHook
(
type
,
**
kwargs
):
if
type
==
'pruning'
:
if
type
==
'pruning'
:
mask_filename
=
kwargs
.
get
(
'mask_filename'
,
None
)
assert
mask_filename
is
not
None
hook
=
ParameterUpdaterHookConfig
()
hook
=
ParameterUpdaterHookConfig
()
hook
.
type
=
type
hook
.
type
=
type
hook
.
purning_mask_filename
=
mask_filename
sparsity_ratio
=
kwargs
.
get
(
'sparsity_ratio'
,
None
)
if
sparsity_ratio
is
not
None
:
hook
.
sparsity_ratio
=
sparsity_ratio
return
hook
return
hook
else
:
else
:
return
None
return
None
...
@@ -3236,13 +3251,13 @@ def Parameter(name,
...
@@ -3236,13 +3251,13 @@ def Parameter(name,
if
update_hooks
is
not
None
:
if
update_hooks
is
not
None
:
if
hasattr
(
update_hooks
,
'__call__'
):
if
hasattr
(
update_hooks
,
'__call__'
):
update_hooks
=
update_hooks
(
para
.
name
)
update_hooks
=
update_hooks
()
if
isinstance
(
update_hooks
,
list
):
if
isinstance
(
update_hooks
,
list
):
for
hook
in
update_hooks
:
for
hook
in
update_hooks
:
para
.
update_hooks
.
extend
([
hook
])
para
.
update_hooks
.
extend
([
hook
])
else
:
else
:
para
.
update_hooks
.
extend
(
update_hooks
)
para
.
update_hooks
.
extend
(
[
update_hooks
]
)
g_parameter_map
[
name
]
=
para
g_parameter_map
[
name
]
=
para
if
initializer
is
not
None
:
if
initializer
is
not
None
:
...
...
python/paddle/trainer_config_helpers/attrs.py
浏览文件 @
9f9058ac
...
@@ -14,7 +14,8 @@
...
@@ -14,7 +14,8 @@
from
paddle.trainer.config_parser
import
*
from
paddle.trainer.config_parser
import
*
__all__
=
[
__all__
=
[
'ParamAttr'
,
'ExtraAttr'
,
'ParameterAttribute'
,
'ExtraLayerAttribute'
'HookAttr'
,
'ParamAttr'
,
'ExtraAttr'
,
'ParameterAttribute'
,
'ExtraLayerAttribute'
]
]
...
@@ -55,6 +56,40 @@ def is_compatible_with(x, Type):
...
@@ -55,6 +56,40 @@ def is_compatible_with(x, Type):
return
False
return
False
class
HookAttribute
(
object
):
"""
Hook Attribute object. As a member of ParameterAttribute class, the hook is an auxiliary operation that occurs
during training process of a layer with parameters, such as img_conv layer, fc layer.
:param type: Hook type, currently supported types:
'pruning' : user specify a sparsity_ratio before training started, and the
network will prune the parameters based on the sparsity_ratio.
eg: The definition of Hook object can be hk = HookAttribute('pruning', 0.6)
The specific usage can be paddle.layer.img_conv(input=img, filter_size=3,
num_channels=3, num_filters=64,
param_attr=ParameterAttribute(update_hooks=hk) )
The pruning details can be found https://arxiv.org/pdf/1506.02626.pdf
:type type: string
:param sparsity_ratio: Must be specified if hook type is 'pruning',
it represents the ratio of the zero elements to be set by the Parameter.
:type sparsity_ratio: float or None
"""
def
__init__
(
self
,
type
,
sparsity_ratio
=
None
):
self
.
type
=
type
self
.
sparsity_ratio
=
sparsity_ratio
if
self
.
sparsity_ratio
is
not
None
:
assert
is_compatible_with
(
self
.
sparsity_ratio
,
float
),
'sparisity_ratio must be float type'
assert
self
.
sparsity_ratio
<=
1
and
self
.
sparsity_ratio
>=
0
,
'sparsity_ratio must be a float between [0, 1] '
def
__call__
(
self
):
return
ParameterHook
(
self
.
type
,
sparsity_ratio
=
self
.
sparsity_ratio
)
class
ParameterAttribute
(
object
):
class
ParameterAttribute
(
object
):
"""
"""
Parameter Attributes object. To fine-tuning network training process, user
Parameter Attributes object. To fine-tuning network training process, user
...
@@ -114,6 +149,7 @@ class ParameterAttribute(object):
...
@@ -114,6 +149,7 @@ class ParameterAttribute(object):
momentum
=
None
,
momentum
=
None
,
gradient_clipping_threshold
=
None
,
gradient_clipping_threshold
=
None
,
sparse_update
=
False
,
sparse_update
=
False
,
update_hooks
=
None
,
initializer
=
None
):
initializer
=
None
):
self
.
attr
=
{}
self
.
attr
=
{}
...
@@ -169,6 +205,9 @@ class ParameterAttribute(object):
...
@@ -169,6 +205,9 @@ class ParameterAttribute(object):
if
initializer
is
not
None
:
if
initializer
is
not
None
:
self
.
attr
[
'initializer'
]
=
initializer
self
.
attr
[
'initializer'
]
=
initializer
if
update_hooks
:
self
.
attr
[
'update_hooks'
]
=
update_hooks
def
set_default_parameter_name
(
self
,
name
):
def
set_default_parameter_name
(
self
,
name
):
"""
"""
Set default parameter name. If parameter not set, then will use default
Set default parameter name. If parameter not set, then will use default
...
@@ -244,5 +283,6 @@ class ExtraLayerAttribute(object):
...
@@ -244,5 +283,6 @@ class ExtraLayerAttribute(object):
return
attr
.
attr
return
attr
.
attr
HookAttr
=
HookAttribute
ParamAttr
=
ParameterAttribute
ParamAttr
=
ParameterAttribute
ExtraAttr
=
ExtraLayerAttribute
ExtraAttr
=
ExtraLayerAttribute
python/paddle/trainer_config_helpers/evaluators.py
浏览文件 @
9f9058ac
...
@@ -21,7 +21,8 @@ __all__ = [
...
@@ -21,7 +21,8 @@ __all__ = [
"chunk_evaluator"
,
"sum_evaluator"
,
"column_sum_evaluator"
,
"chunk_evaluator"
,
"sum_evaluator"
,
"column_sum_evaluator"
,
"value_printer_evaluator"
,
"gradient_printer_evaluator"
,
"value_printer_evaluator"
,
"gradient_printer_evaluator"
,
"maxid_printer_evaluator"
,
"maxframe_printer_evaluator"
,
"maxid_printer_evaluator"
,
"maxframe_printer_evaluator"
,
"seqtext_printer_evaluator"
,
"classification_error_printer_evaluator"
"seqtext_printer_evaluator"
,
"classification_error_printer_evaluator"
,
"detection_map_evaluator"
]
]
...
@@ -31,10 +32,11 @@ class EvaluatorAttribute(object):
...
@@ -31,10 +32,11 @@ class EvaluatorAttribute(object):
FOR_RANK
=
1
<<
2
FOR_RANK
=
1
<<
2
FOR_PRINT
=
1
<<
3
FOR_PRINT
=
1
<<
3
FOR_UTILS
=
1
<<
4
FOR_UTILS
=
1
<<
4
FOR_DETECTION
=
1
<<
5
KEYS
=
[
KEYS
=
[
"for_classification"
,
"for_regression"
,
"for_rank"
,
"for_print"
,
"for_classification"
,
"for_regression"
,
"for_rank"
,
"for_print"
,
"for_utils"
"for_utils"
,
"for_detection"
]
]
@
staticmethod
@
staticmethod
...
@@ -57,8 +59,7 @@ def evaluator(*attrs):
...
@@ -57,8 +59,7 @@ def evaluator(*attrs):
return
impl
return
impl
def
evaluator_base
(
def
evaluator_base
(
input
,
input
,
type
,
type
,
label
=
None
,
label
=
None
,
weight
=
None
,
weight
=
None
,
...
@@ -72,7 +73,11 @@ def evaluator_base(
...
@@ -72,7 +73,11 @@ def evaluator_base(
num_results
=
None
,
num_results
=
None
,
delimited
=
None
,
delimited
=
None
,
top_k
=
None
,
top_k
=
None
,
excluded_chunk_types
=
None
,
):
excluded_chunk_types
=
None
,
overlap_threshold
=
None
,
background_id
=
None
,
evaluate_difficult
=
None
,
ap_type
=
None
):
"""
"""
Evaluator will evaluate the network status while training/testing.
Evaluator will evaluate the network status while training/testing.
...
@@ -107,6 +112,14 @@ def evaluator_base(
...
@@ -107,6 +112,14 @@ def evaluator_base(
:type weight: LayerOutput.
:type weight: LayerOutput.
:param top_k: number k in top-k error rate
:param top_k: number k in top-k error rate
:type top_k: int
:type top_k: int
:param overlap_threshold: In detection tasks to filter detection results
:type overlap_threshold: float
:param background_id: Identifier of background class
:type background_id: int
:param evaluate_difficult: Whether to evaluate difficult objects
:type evaluate_difficult: bool
:param ap_type: How to calculate average persicion
:type ap_type: str
"""
"""
# inputs type assertions.
# inputs type assertions.
assert
classification_threshold
is
None
or
isinstance
(
assert
classification_threshold
is
None
or
isinstance
(
...
@@ -136,7 +149,61 @@ def evaluator_base(
...
@@ -136,7 +149,61 @@ def evaluator_base(
delimited
=
delimited
,
delimited
=
delimited
,
num_results
=
num_results
,
num_results
=
num_results
,
top_k
=
top_k
,
top_k
=
top_k
,
excluded_chunk_types
=
excluded_chunk_types
,
)
excluded_chunk_types
=
excluded_chunk_types
,
overlap_threshold
=
overlap_threshold
,
background_id
=
background_id
,
evaluate_difficult
=
evaluate_difficult
,
ap_type
=
ap_type
)
@
evaluator
(
EvaluatorAttribute
.
FOR_DETECTION
)
@
wrap_name_default
()
def
detection_map_evaluator
(
input
,
label
,
overlap_threshold
=
0.5
,
background_id
=
0
,
evaluate_difficult
=
False
,
ap_type
=
"11point"
,
name
=
None
):
"""
Detection mAP Evaluator. It will print mean Average Precision (mAP) for detection.
The detection mAP Evaluator based on the output of detection_output layer counts
the true positive and the false positive bbox and integral them to get the
mAP.
The simple usage is:
.. code-block:: python
eval = detection_map_evaluator(input=det_output,label=lbl)
:param input: Input layer.
:type input: LayerOutput
:param label: Label layer.
:type label: LayerOutput
:param overlap_threshold: The bbox overlap threshold of a true positive.
:type overlap_threshold: float
:param background_id: The background class index.
:type background_id: int
:param evaluate_difficult: Whether evaluate a difficult ground truth.
:type evaluate_difficult: bool
"""
if
not
isinstance
(
input
,
list
):
input
=
[
input
]
if
label
:
input
.
append
(
label
)
evaluator_base
(
name
=
name
,
type
=
"detection_map"
,
input
=
input
,
label
=
label
,
overlap_threshold
=
overlap_threshold
,
background_id
=
background_id
,
evaluate_difficult
=
evaluate_difficult
,
ap_type
=
ap_type
)
@
evaluator
(
EvaluatorAttribute
.
FOR_CLASSIFICATION
)
@
evaluator
(
EvaluatorAttribute
.
FOR_CLASSIFICATION
)
...
...
python/paddle/trainer_config_helpers/layers.py
浏览文件 @
9f9058ac
...
@@ -3839,7 +3839,8 @@ def classification_cost(input,
...
@@ -3839,7 +3839,8 @@ def classification_cost(input,
weight
=
None
,
weight
=
None
,
name
=
None
,
name
=
None
,
evaluator
=
classification_error_evaluator
,
evaluator
=
classification_error_evaluator
,
layer_attr
=
None
):
layer_attr
=
None
,
coeff
=
1.
):
"""
"""
classification cost Layer.
classification cost Layer.
...
@@ -3855,6 +3856,8 @@ def classification_cost(input,
...
@@ -3855,6 +3856,8 @@ def classification_cost(input,
:param evaluator: Evaluator method.
:param evaluator: Evaluator method.
:param layer_attr: layer's extra attribute.
:param layer_attr: layer's extra attribute.
:type layer_attr: ExtraLayerAttribute
:type layer_attr: ExtraLayerAttribute
:param coeff: The coefficient affects the gradient in the backward.
:type coeff: float
:return: LayerOutput object.
:return: LayerOutput object.
:rtype: LayerOutput
:rtype: LayerOutput
"""
"""
...
@@ -3868,6 +3871,7 @@ def classification_cost(input,
...
@@ -3868,6 +3871,7 @@ def classification_cost(input,
name
=
name
,
name
=
name
,
type
=
"multi-class-cross-entropy"
,
type
=
"multi-class-cross-entropy"
,
inputs
=
ipts
,
inputs
=
ipts
,
coeff
=
coeff
,
**
ExtraLayerAttribute
.
to_kwargs
(
layer_attr
))
**
ExtraLayerAttribute
.
to_kwargs
(
layer_attr
))
def
__add_evaluator__
(
e
):
def
__add_evaluator__
(
e
):
...
...
python/paddle/v2/attr.py
浏览文件 @
9f9058ac
...
@@ -17,10 +17,12 @@ import paddle.trainer_config_helpers.attrs
...
@@ -17,10 +17,12 @@ import paddle.trainer_config_helpers.attrs
__all__
=
[
__all__
=
[
"Param"
,
"Param"
,
"Extra"
,
"Extra"
,
"Hook"
,
]
]
Param
=
paddle
.
trainer_config_helpers
.
attrs
.
ParameterAttribute
Param
=
paddle
.
trainer_config_helpers
.
attrs
.
ParameterAttribute
Extra
=
paddle
.
trainer_config_helpers
.
attrs
.
ExtraLayerAttribute
Extra
=
paddle
.
trainer_config_helpers
.
attrs
.
ExtraLayerAttribute
Hook
=
paddle
.
trainer_config_helpers
.
attrs
.
HookAttribute
for
each
in
paddle
.
trainer_config_helpers
.
attrs
.
__all__
:
for
each
in
paddle
.
trainer_config_helpers
.
attrs
.
__all__
:
globals
()[
each
]
=
getattr
(
paddle
.
trainer_config_helpers
.
attrs
,
each
)
globals
()[
each
]
=
getattr
(
paddle
.
trainer_config_helpers
.
attrs
,
each
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录