Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
475dd708
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 2 年 前同步成功
通知
708
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
475dd708
编写于
7月 11, 2017
作者:
W
wanghaoshuang
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' of
https://github.com/PaddlePaddle/Paddle
into pixel_softmax_layer
上级
0152d97e
15f021a9
变更
11
显示空白变更内容
内联
并排
Showing
11 changed file
with
344 addition
and
35 deletion
+344
-35
go/cmd/pserver/pserver.go
go/cmd/pserver/pserver.go
+6
-2
go/pserver/etcd_client.go
go/pserver/etcd_client.go
+13
-0
go/pserver/optimizer.go
go/pserver/optimizer.go
+15
-3
go/pserver/optimizer_test.go
go/pserver/optimizer_test.go
+1
-1
go/pserver/service.go
go/pserver/service.go
+99
-11
go/pserver/service_test.go
go/pserver/service_test.go
+12
-14
paddle/platform/CMakeLists.txt
paddle/platform/CMakeLists.txt
+2
-0
paddle/platform/device_context.h
paddle/platform/device_context.h
+159
-0
paddle/platform/device_context_test.cc
paddle/platform/device_context_test.cc
+33
-0
python/paddle/trainer_config_helpers/networks.py
python/paddle/trainer_config_helpers/networks.py
+2
-2
python/paddle/v2/dataset/wmt14.py
python/paddle/v2/dataset/wmt14.py
+2
-2
未找到文件。
go/cmd/pserver/pserver.go
浏览文件 @
475dd708
...
...
@@ -20,6 +20,8 @@ func main() {
"comma separated endpoint string for pserver to connect to etcd"
)
etcdTimeout
:=
flag
.
Int
(
"etcd-timeout"
,
5
,
"timeout for etcd calls"
)
numPservers
:=
flag
.
Int
(
"num-pservers"
,
1
,
"total pserver count in a training job"
)
checkpointPath
:=
flag
.
String
(
"checkpoint-path"
,
"/checkpoints/"
,
"save checkpoint path"
)
checkpointInterval
:=
flag
.
Int
(
"checkpoint-interval"
,
600
,
"save checkpoint per interval seconds"
)
logLevel
:=
flag
.
String
(
"log-level"
,
"info"
,
"log level, possible values: debug, info, warning, error, fatal, panic"
)
flag
.
Parse
()
...
...
@@ -31,18 +33,20 @@ func main() {
log
.
SetLevel
(
level
)
var
idx
int
var
cp
pserver
.
Checkpoint
var
e
*
pserver
.
EtcdClient
if
*
index
>=
0
{
idx
=
*
index
}
else
{
timeout
:=
time
.
Second
*
time
.
Duration
((
*
etcdTimeout
))
e
:
=
pserver
.
NewEtcdClient
(
*
etcdEndpoint
,
*
numPservers
,
timeout
)
e
=
pserver
.
NewEtcdClient
(
*
etcdEndpoint
,
*
numPservers
,
timeout
)
idx
,
err
=
e
.
Register
()
if
err
!=
nil
{
panic
(
err
)
}
}
s
,
err
:=
pserver
.
NewService
(
idx
)
s
,
err
:=
pserver
.
NewService
(
idx
,
*
checkpointInterval
,
*
checkpointPath
,
e
,
cp
)
if
err
!=
nil
{
panic
(
err
)
}
...
...
go/pserver/etcd_client.go
浏览文件 @
475dd708
...
...
@@ -18,6 +18,8 @@ const (
PsDesired
=
"/ps_desired"
// PsAddr is the base dir for pserver to store their addr
PsPath
=
"/ps/"
// PsCheckpoint is the etcd path for store checkpoints information
PsCheckpoint
=
"/checkpoints/"
)
// EtcdClient is the etcd client that the pserver uses for fault
...
...
@@ -186,3 +188,14 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) {
return
idx
,
nil
}
// PutKey put into etcd with value by key specified
func
(
e
*
EtcdClient
)
PutKey
(
key
string
,
value
[]
byte
,
timeout
int
)
error
{
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
time
.
Second
*
time
.
Duration
(
timeout
))
_
,
err
:=
e
.
etcdClient
.
Put
(
ctx
,
key
,
string
(
value
))
cancel
()
if
err
!=
nil
{
return
err
}
return
nil
}
go/pserver/optimizer.go
浏览文件 @
475dd708
...
...
@@ -35,22 +35,28 @@ func cArrayToSlice(p unsafe.Pointer, len int) []byte {
return
(
*
[
1
<<
30
]
byte
)(
p
)[
:
len
:
len
]
}
func
newOptimizer
(
paramWithConfigs
ParameterWithConfig
)
*
optimizer
{
func
newOptimizer
(
paramWithConfigs
ParameterWithConfig
,
State
[]
byte
)
*
optimizer
{
o
:=
&
optimizer
{}
o
.
elementType
=
paramWithConfigs
.
Param
.
ElementType
p
:=
paramWithConfigs
.
Param
c
:=
paramWithConfigs
.
Config
s
:=
State
log
.
WithFields
(
log
.
Fields
{
"ElementType"
:
p
.
ElementType
,
"ParamSize"
:
len
(
p
.
Content
),
"ConfigSize"
:
len
(
c
),
"StateSize"
:
len
(
s
),
})
.
Info
(
"New Optimizer Created with config:"
)
var
cbuffer
unsafe
.
Pointer
cbuffer
=
C
.
malloc
(
C
.
size_t
(
len
(
p
.
Content
)))
C
.
memcpy
(
cbuffer
,
unsafe
.
Pointer
(
&
p
.
Content
[
0
]),
C
.
size_t
(
len
(
p
.
Content
)))
var
cstate
unsafe
.
Pointer
if
len
(
s
)
!=
0
{
cstate
=
unsafe
.
Pointer
(
&
s
[
0
])
}
o
.
opt
=
C
.
paddle_create_optimizer
((
*
C
.
uchar
)(
&
c
[
0
]),
C
.
int
(
len
(
c
)),
C
.
paddle_element_type
(
p
.
ElementType
),
cbuffer
,
C
.
int
(
len
(
p
.
Content
)
/
C
.
sizeof_float
),
(
*
C
.
char
)(
nullPtr
),
0
)
C
.
paddle_element_type
(
p
.
ElementType
),
cbuffer
,
C
.
int
(
len
(
p
.
Content
)
/
C
.
sizeof_float
),
(
*
C
.
char
)(
cstate
),
C
.
int
(
len
(
s
)))
return
o
}
...
...
@@ -60,6 +66,12 @@ func (o *optimizer) GetWeights() []byte {
return
cArrayToSlice
(
buffer
,
int
(
bufferLen
)
*
C
.
sizeof_float
)
}
func
(
o
*
optimizer
)
GetStates
()
[]
byte
{
var
cbuffer
*
C
.
char
cbuffer_len
:=
C
.
paddle_optimizer_get_state
(
o
.
opt
,
&
cbuffer
)
return
cArrayToSlice
(
unsafe
.
Pointer
(
cbuffer
),
int
(
cbuffer_len
))
}
func
(
o
*
optimizer
)
UpdateParameter
(
g
Gradient
)
error
{
if
o
.
elementType
!=
g
.
ElementType
{
return
fmt
.
Errorf
(
"Name: %s, parameter and gradient element type not match, parameter: %v, gradient: %v"
,
g
.
Name
,
o
.
elementType
,
g
.
ElementType
)
...
...
go/pserver/optimizer_test.go
浏览文件 @
475dd708
...
...
@@ -19,6 +19,6 @@ func TestOptimizerCreateRelease(t *testing.T) {
Param
:
p
,
Config
:
config
,
}
o
:=
newOptimizer
(
param
)
o
:=
newOptimizer
(
param
,
nil
)
o
.
Cleanup
()
}
go/pserver/service.go
浏览文件 @
475dd708
package
pserver
import
(
"bufio"
"bytes"
"crypto/md5"
"encoding/gob"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"strconv"
"sync"
"time"
log
"github.com/sirupsen/logrus"
)
// ElementType is the type of elements of a Parameter.
...
...
@@ -39,6 +51,22 @@ type ParameterWithConfig struct {
Config
[]
byte
// parameter configuration in Proto Buffer format
}
// ParameterCheckpoint is Parameter and State checkpoint
type
ParameterCheckpoint
struct
{
ParamConfig
ParameterWithConfig
State
[]
byte
}
// checkpoint signature
type
checkpointMeta
struct
{
UUID
string
`json:"uuid"`
Md5sum
string
`json:"md5sum"`
Timestamp
string
`json:"timestamp"`
}
// Checkpoint is the pserver shard persist in file
type
Checkpoint
[]
ParameterCheckpoint
// Gradient is the gradient of the parameter.
type
Gradient
Parameter
...
...
@@ -46,19 +74,32 @@ type Gradient Parameter
type
Service
struct
{
initialized
chan
struct
{}
idx
int
checkpointInterval
time
.
Duration
checkpointPath
string
client
*
EtcdClient
mu
sync
.
Mutex
optMap
map
[
string
]
*
optimizer
}
// NewService creates a new service, will bypass etcd registration if no
// endpoints specified.
func
NewService
(
idx
int
)
(
*
Service
,
error
)
{
func
NewService
(
idx
int
,
seconds
int
,
path
string
,
client
*
EtcdClient
,
cp
Checkpoint
)
(
*
Service
,
error
)
{
s
:=
&
Service
{
idx
:
idx
,
checkpointInterval
:
time
.
Second
*
time
.
Duration
(
seconds
),
checkpointPath
:
path
,
client
:
client
,
}
s
.
optMap
=
make
(
map
[
string
]
*
optimizer
)
s
.
initialized
=
make
(
chan
struct
{})
if
cp
!=
nil
{
for
_
,
item
:=
range
cp
{
p
:=
item
.
ParamConfig
st
:=
item
.
State
s
.
optMap
[
p
.
Param
.
Name
]
=
newOptimizer
(
p
,
st
)
}
}
return
s
,
nil
}
...
...
@@ -78,7 +119,7 @@ func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) er
// TODO(helin): check if paramWithConfigs.Param.Content is
// properly memory aligned, if not, make copy to a memory
// aligned region.
s
.
optMap
[
paramWithConfigs
.
Param
.
Name
]
=
newOptimizer
(
paramWithConfigs
)
s
.
optMap
[
paramWithConfigs
.
Param
.
Name
]
=
newOptimizer
(
paramWithConfigs
,
nil
)
return
nil
}
...
...
@@ -139,10 +180,57 @@ func (s *Service) GetParam(name string, parameter *Parameter) error {
return
nil
}
//
Save tells the parameter server to save parameters.
func
(
s
*
Service
)
Save
(
path
string
,
dummy
*
int
)
error
{
//
pserver save checkpoint
func
(
s
*
Service
)
doCheckpoint
(
)
error
{
<-
s
.
initialized
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
cp
:=
make
([]
ParameterCheckpoint
,
0
,
len
(
s
.
optMap
))
index
:=
0
for
name
,
opt
:=
range
s
.
optMap
{
var
pc
ParameterCheckpoint
pc
.
ParamConfig
.
Param
.
Name
=
name
pc
.
ParamConfig
.
Param
.
ElementType
=
opt
.
elementType
pc
.
ParamConfig
.
Param
.
Content
=
opt
.
GetWeights
()
pc
.
State
=
opt
.
GetStates
()
cp
[
index
]
=
pc
index
++
}
var
buf
bytes
.
Buffer
encoder
:=
gob
.
NewEncoder
(
&
buf
)
err
:=
encoder
.
Encode
(
cp
)
if
err
!=
nil
{
return
err
}
// TODO
cpMeta
:=
checkpointMeta
{}
cpMeta
.
UUID
=
s
.
checkpointPath
+
strconv
.
Itoa
(
s
.
idx
)
cpMeta
.
Timestamp
=
time
.
Now
()
.
String
()
h
:=
md5
.
New
()
cpMeta
.
Md5sum
=
hex
.
EncodeToString
(
h
.
Sum
(
buf
.
Bytes
()))
cpMetajson
,
_
:=
json
.
Marshal
(
cpMeta
)
err
=
s
.
client
.
PutKey
(
filepath
.
Join
(
PsCheckpoint
,
strconv
.
Itoa
(
s
.
idx
)),
cpMetajson
,
3
)
if
err
!=
nil
{
return
err
}
if
_
,
err
=
os
.
Stat
(
cpMeta
.
UUID
);
os
.
IsNotExist
(
err
)
{
log
.
Info
(
"checkpoint does not exists."
)
}
else
{
err
=
os
.
Remove
(
cpMeta
.
UUID
)
log
.
Infof
(
"checkpoint %s already exsits, removing "
,
cpMeta
.
UUID
)
}
f
,
err
:=
os
.
Create
(
cpMeta
.
UUID
)
defer
f
.
Close
()
if
err
!=
nil
{
return
err
}
writer
:=
bufio
.
NewWriter
(
f
)
_
,
err
=
writer
.
Write
(
buf
.
Bytes
())
writer
.
Flush
()
if
err
!=
nil
{
return
err
}
return
nil
}
go/pserver/service_test.go
浏览文件 @
475dd708
...
...
@@ -15,7 +15,8 @@ const (
)
func
TestServiceFull
(
t
*
testing
.
T
)
{
s
,
err
:=
pserver
.
NewService
(
0
)
var
cp
pserver
.
Checkpoint
s
,
err
:=
pserver
.
NewService
(
0
,
1
,
""
,
nil
,
cp
)
if
err
!=
nil
{
t
.
Error
(
err
)
}
...
...
@@ -86,7 +87,8 @@ func TestServiceFull(t *testing.T) {
}
func
TestMultipleInit
(
t
*
testing
.
T
)
{
s
,
err
:=
pserver
.
NewService
(
0
)
var
cp
pserver
.
Checkpoint
s
,
err
:=
pserver
.
NewService
(
0
,
1
,
""
,
nil
,
cp
)
if
err
!=
nil
{
t
.
Error
(
err
)
}
...
...
@@ -102,7 +104,8 @@ func TestMultipleInit(t *testing.T) {
}
func
TestUninitialized
(
t
*
testing
.
T
)
{
s
,
err
:=
pserver
.
NewService
(
0
)
var
cp
pserver
.
Checkpoint
s
,
err
:=
pserver
.
NewService
(
0
,
1
,
""
,
nil
,
cp
)
err
=
s
.
SendGrad
(
pserver
.
Gradient
{},
nil
)
if
err
.
Error
()
!=
pserver
.
Uninitialized
{
t
.
FailNow
()
...
...
@@ -110,7 +113,8 @@ func TestUninitialized(t *testing.T) {
}
func
TestBlockUntilInitialized
(
t
*
testing
.
T
)
{
s
,
err
:=
pserver
.
NewService
(
0
)
var
cp
pserver
.
Checkpoint
s
,
err
:=
pserver
.
NewService
(
0
,
1
,
""
,
nil
,
cp
)
if
err
!=
nil
{
t
.
Error
(
err
)
}
...
...
@@ -128,16 +132,6 @@ func TestBlockUntilInitialized(t *testing.T) {
ch
<-
struct
{}{}
}()
wg
.
Add
(
1
)
go
func
()
{
err
:=
s
.
Save
(
""
,
nil
)
if
err
!=
nil
{
errCh
<-
err
}
wg
.
Done
()
ch
<-
struct
{}{}
}()
time
.
Sleep
(
50
*
time
.
Millisecond
)
select
{
...
...
@@ -170,3 +164,7 @@ func TestBlockUntilInitialized(t *testing.T) {
wg
.
Wait
()
}
func
TestCheckpointSpeed
(
t
*
testing
.
T
)
{
//TODO(zhihong): test speed
}
paddle/platform/CMakeLists.txt
浏览文件 @
475dd708
...
...
@@ -4,3 +4,5 @@ nv_test(cuda_test SRCS cuda_test.cu)
cc_library
(
place SRCS place.cc
)
cc_test
(
place_test SRCS place_test.cc DEPS place glog gflags
)
nv_test
(
device_context_test SRCS device_context_test.cc DEPS dynamic_loader place eigen3 glog gflags
)
paddle/platform/device_context.h
0 → 100644
浏览文件 @
475dd708
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/enforce.h"
#ifndef PADDLE_ONLY_CPU
#include "paddle/platform/cuda.h"
#include "paddle/platform/dynload/cublas.h"
#include "paddle/platform/dynload/cudnn.h"
#include "paddle/platform/dynload/curand.h"
#define EIGEN_USE_GPU
#endif
#include "paddle/platform/place.h"
#include "unsupported/Eigen/CXX11/Tensor"
namespace
paddle
{
namespace
platform
{
class
DeviceContext
{
public:
virtual
~
DeviceContext
()
{}
};
class
CPUDeviceContext
:
public
DeviceContext
{};
#ifndef PADDLE_ONLY_CPU
class
GPUPlaceGuard
{
public:
explicit
GPUPlaceGuard
(
GPUPlace
new_place
)
:
previous_
(
GetCurrentDeviceId
())
{
if
(
previous_
!=
new_place
)
{
paddle
::
platform
::
SetDeviceId
(
new_place
.
device
);
}
}
~
GPUPlaceGuard
()
{
paddle
::
platform
::
SetDeviceId
(
previous_
.
device
);
}
private:
GPUPlace
previous_
;
};
class
CUDADeviceContext
:
public
DeviceContext
{
public:
explicit
CUDADeviceContext
(
const
GPUPlace
gpu_place
)
:
gpu_place_
(
gpu_place
)
{
GPUPlaceGuard
guard
(
gpu_place_
);
paddle
::
platform
::
throw_on_error
(
cudaStreamCreate
(
&
stream_
),
"cudaStreamCreate failed"
);
eigen_stream_
=
new
Eigen
::
CudaStreamDevice
(
&
stream_
);
eigen_device_
=
new
Eigen
::
GpuDevice
(
eigen_stream_
);
}
void
Wait
()
{
paddle
::
platform
::
throw_on_error
(
cudaStreamSynchronize
(
stream_
),
"cudaStreamSynchronize failed"
);
}
cudaStream_t
stream
()
{
return
stream_
;
}
Eigen
::
GpuDevice
eigen_device
()
{
return
*
eigen_device_
;
}
cublasHandle_t
cublas_handle
()
{
if
(
!
blas_handle_
)
{
GPUPlaceGuard
guard
(
gpu_place_
);
PADDLE_ENFORCE
(
paddle
::
platform
::
dynload
::
cublasCreate
(
&
blas_handle_
)
==
CUBLAS_STATUS_SUCCESS
,
"cublasCreate failed"
);
PADDLE_ENFORCE
(
paddle
::
platform
::
dynload
::
cublasSetStream
(
blas_handle_
,
stream_
)
==
CUBLAS_STATUS_SUCCESS
,
"cublasSetStream failed"
);
}
return
blas_handle_
;
}
cudnnHandle_t
cudnn_handle
()
{
if
(
!
dnn_handle_
)
{
GPUPlaceGuard
guard
(
gpu_place_
);
PADDLE_ENFORCE
(
paddle
::
platform
::
dynload
::
cudnnCreate
(
&
dnn_handle_
)
==
CUDNN_STATUS_SUCCESS
,
"cudnnCreate failed"
);
PADDLE_ENFORCE
(
paddle
::
platform
::
dynload
::
cudnnSetStream
(
dnn_handle_
,
stream_
)
==
CUDNN_STATUS_SUCCESS
,
"cudnnSetStream failed"
);
}
return
dnn_handle_
;
}
curandGenerator_t
curand_generator
()
{
if
(
!
rand_generator_
)
{
GPUPlaceGuard
guard
(
gpu_place_
);
PADDLE_ENFORCE
(
paddle
::
platform
::
dynload
::
curandCreateGenerator
(
&
rand_generator_
,
CURAND_RNG_PSEUDO_DEFAULT
)
==
CURAND_STATUS_SUCCESS
,
"curandCreateGenerator failed"
);
PADDLE_ENFORCE
(
paddle
::
platform
::
dynload
::
curandSetPseudoRandomGeneratorSeed
(
rand_generator_
,
random_seed_
)
==
CURAND_STATUS_SUCCESS
,
"curandSetPseudoRandomGeneratorSeed failed"
);
PADDLE_ENFORCE
(
paddle
::
platform
::
dynload
::
curandSetStream
(
rand_generator_
,
stream_
)
==
CURAND_STATUS_SUCCESS
,
"curandSetStream failed"
);
}
return
rand_generator_
;
}
~
CUDADeviceContext
()
{
Wait
();
if
(
blas_handle_
)
{
PADDLE_ENFORCE
(
paddle
::
platform
::
dynload
::
cublasDestroy
(
blas_handle_
)
==
CUBLAS_STATUS_SUCCESS
,
"cublasDestroy failed"
);
}
if
(
dnn_handle_
)
{
PADDLE_ENFORCE
(
paddle
::
platform
::
dynload
::
cudnnDestroy
(
dnn_handle_
)
==
CUDNN_STATUS_SUCCESS
,
"cudnnDestroy failed"
);
}
if
(
rand_generator_
)
{
PADDLE_ENFORCE
(
paddle
::
platform
::
dynload
::
curandDestroyGenerator
(
rand_generator_
)
==
CURAND_STATUS_SUCCESS
,
"curandDestroyGenerator failed"
);
}
delete
eigen_stream_
;
delete
eigen_device_
;
paddle
::
platform
::
throw_on_error
(
cudaStreamDestroy
(
stream_
),
"cudaStreamDestroy failed"
);
}
private:
GPUPlace
gpu_place_
;
cudaStream_t
stream_
;
Eigen
::
CudaStreamDevice
*
eigen_stream_
;
Eigen
::
GpuDevice
*
eigen_device_
;
cublasHandle_t
blas_handle_
{
nullptr
};
cudnnHandle_t
dnn_handle_
{
nullptr
};
int
random_seed_
;
curandGenerator_t
rand_generator_
{
nullptr
};
};
#endif
}
// namespace platform
}
// namespace paddle
paddle/platform/device_context_test.cc
0 → 100644
浏览文件 @
475dd708
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/platform/device_context.h"
#include "gtest/gtest.h"
TEST
(
CUDADeviceContext
,
Init
)
{
int
count
=
paddle
::
platform
::
GetDeviceCount
();
for
(
int
i
=
0
;
i
<
count
;
i
++
)
{
paddle
::
platform
::
CUDADeviceContext
*
device_context
=
new
paddle
::
platform
::
CUDADeviceContext
(
i
);
Eigen
::
GpuDevice
gpu_device
=
device_context
->
eigen_device
();
ASSERT_NE
(
nullptr
,
gpu_device
.
stream
());
cudnnHandle_t
cudnn_handle
=
device_context
->
cudnn_handle
();
ASSERT_NE
(
nullptr
,
cudnn_handle
);
cublasHandle_t
cublas_handle
=
device_context
->
cublas_handle
();
ASSERT_NE
(
nullptr
,
cublas_handle
);
curandGenerator_t
curand_handle
=
device_context
->
curand_generator
();
ASSERT_NE
(
nullptr
,
curand_handle
);
delete
device_context
;
}
}
python/paddle/trainer_config_helpers/networks.py
浏览文件 @
475dd708
...
...
@@ -1395,7 +1395,7 @@ def inputs(layers, *args):
if
len
(
args
)
!=
0
:
layers
.
extend
(
args
)
Inputs
(
*
[
l
.
name
for
l
in
layers
])
Inputs
(
*
[
l
.
name
for
l
in
layers
])
def
outputs
(
layers
,
*
args
):
...
...
@@ -1438,7 +1438,7 @@ def outputs(layers, *args):
assert
len
(
layers
)
>
0
if
HasInputsSet
():
# input already set
Outputs
(
*
[
l
.
name
for
l
in
layers
])
Outputs
(
*
[
l
.
name
for
l
in
layers
])
return
# just return outputs.
if
len
(
layers
)
!=
1
:
...
...
python/paddle/v2/dataset/wmt14.py
浏览文件 @
475dd708
...
...
@@ -32,9 +32,9 @@ MD5_DEV_TEST = '7d7897317ddd8ba0ae5c5fa7248d3ff5'
# this is a small set of data for test. The original data is too large and will be add later.
URL_TRAIN
=
'http://paddlepaddle.cdn.bcebos.com/demo/wmt_shrinked_data/wmt14.tgz'
MD5_TRAIN
=
'0791583d57d5beb693b9414c5b36798c'
#
this is the pretrained model, whose bleu =
26.92
#
BLEU of this trained model is
26.92
URL_MODEL
=
'http://paddlepaddle.bj.bcebos.com/demo/wmt_14/wmt14_model.tar.gz'
MD5_MODEL
=
'
4ce14a26607fb8a1cc23bcdedb1895e4
'
MD5_MODEL
=
'
0cb4a5366189b6acba876491c8724fa3
'
START
=
"<s>"
END
=
"<e>"
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录