Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
7443b2e4
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
7443b2e4
编写于
8月 22, 2017
作者:
Q
QI JUN
提交者:
GitHub
8月 22, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #3596 from QiJune/implement_random_function
refine random related operators
上级
6eab5638
aff90d8e
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
55 addition
and
80 deletion
+55
-80
paddle/operators/gaussian_random_op.cc
paddle/operators/gaussian_random_op.cc
+13
-16
paddle/operators/gaussian_random_op.cu
paddle/operators/gaussian_random_op.cu
+35
-23
paddle/operators/mul_op.cc
paddle/operators/mul_op.cc
+0
-1
paddle/operators/uniform_random_op.cc
paddle/operators/uniform_random_op.cc
+3
-6
paddle/operators/uniform_random_op.cu
paddle/operators/uniform_random_op.cu
+1
-4
paddle/platform/device_context.cc
paddle/platform/device_context.cc
+0
-16
paddle/platform/device_context.h
paddle/platform/device_context.h
+2
-11
paddle/platform/device_context_test.cc
paddle/platform/device_context_test.cc
+0
-2
python/paddle/v2/framework/tests/CMakeLists.txt
python/paddle/v2/framework/tests/CMakeLists.txt
+1
-1
未找到文件。
paddle/operators/gaussian_random_op.cc
浏览文件 @
7443b2e4
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -19,25 +16,25 @@ namespace paddle {
...
@@ -19,25 +16,25 @@ namespace paddle {
namespace
operators
{
namespace
operators
{
template
<
typename
T
>
template
<
typename
T
>
class
GaussianRandomKernel
:
public
framework
::
OpKernel
{
class
CPU
GaussianRandomKernel
:
public
framework
::
OpKernel
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
float
mean
=
context
.
op_
.
GetAttr
<
float
>
(
"mean"
);
float
mean
=
context
.
op_
.
GetAttr
<
float
>
(
"mean"
);
float
std
=
context
.
op_
.
GetAttr
<
float
>
(
"std"
);
float
std
=
context
.
op_
.
GetAttr
<
float
>
(
"std"
);
auto
*
tensor
=
context
.
Output
<
framework
::
Tensor
>
(
0
);
auto
*
tensor
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
T
*
data
=
tensor
->
mutable_data
<
T
>
(
context
.
GetPlace
());
T
*
data
=
tensor
->
mutable_data
<
T
>
(
context
.
GetPlace
());
// TODO(dzh): attribute does not support unsigned int.
unsigned
int
seed
=
// And we need a global random seed configuration.
static_cast
<
unsigned
int
>
(
context
.
op_
.
GetAttr
<
int
>
(
"seed"
));
int
seed
=
context
.
op_
.
GetAttr
<
int
>
(
"seed"
)
;
std
::
minstd_rand
engine
;
if
(
seed
==
0
)
{
if
(
seed
==
0
)
{
seed
=
std
::
random_device
()();
seed
=
std
::
random_device
()();
}
}
std
::
mt19937
g
(
seed
);
engine
.
seed
(
seed
);
std
::
normal_distribution
<
T
>
dist
ribution
(
mean
,
std
);
std
::
normal_distribution
<
T
>
dist
(
mean
,
std
);
ssize_t
size
=
framework
::
product
(
tensor
->
dims
());
ssize_t
size
=
framework
::
product
(
tensor
->
dims
());
for
(
in
t
i
=
0
;
i
<
size
;
++
i
)
{
for
(
ssize_
t
i
=
0
;
i
<
size
;
++
i
)
{
data
[
i
]
=
dist
ribution
(
g
);
data
[
i
]
=
dist
(
engine
);
}
}
}
}
};
};
...
@@ -48,7 +45,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
...
@@ -48,7 +45,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
protected:
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
context
)
const
override
{
void
InferShape
(
const
framework
::
InferShapeContext
&
context
)
const
override
{
auto
*
tensor
=
context
.
Output
<
framework
::
Tensor
>
(
0
);
auto
*
tensor
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
dims
=
GetAttr
<
std
::
vector
<
int
>>
(
"dims"
);
auto
dims
=
GetAttr
<
std
::
vector
<
int
>>
(
"dims"
);
PADDLE_ENFORCE
(
dims
.
size
()
>
0UL
,
PADDLE_ENFORCE
(
dims
.
size
()
>
0UL
,
"dims can be one int or array. dims must be set."
);
"dims can be one int or array. dims must be set."
);
...
@@ -68,8 +65,8 @@ Use to initialize tensor with gaussian random generator.
...
@@ -68,8 +65,8 @@ Use to initialize tensor with gaussian random generator.
)DOC"
);
)DOC"
);
AddAttr
<
std
::
vector
<
int
>>
(
"dims"
,
"The dimension of random tensor."
);
AddAttr
<
std
::
vector
<
int
>>
(
"dims"
,
"The dimension of random tensor."
);
AddAttr
<
float
>
(
"mean"
,
"mean
value of random
."
).
SetDefault
(
.0
f
);
AddAttr
<
float
>
(
"mean"
,
"mean
of random tensor
."
).
SetDefault
(
.0
f
);
AddAttr
<
float
>
(
"std"
,
"
minimum value of random value
."
).
SetDefault
(
1.0
f
);
AddAttr
<
float
>
(
"std"
,
"
std of random tensor
."
).
SetDefault
(
1.0
f
);
AddAttr
<
int
>
(
"seed"
,
AddAttr
<
int
>
(
"seed"
,
"Random seed of generator."
"Random seed of generator."
"0 means use system wide seed"
)
"0 means use system wide seed"
)
...
@@ -83,4 +80,4 @@ Use to initialize tensor with gaussian random generator.
...
@@ -83,4 +80,4 @@ Use to initialize tensor with gaussian random generator.
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_WITHOUT_GRADIENT
(
gaussian_random
,
ops
::
GaussianRandomOp
,
REGISTER_OP_WITHOUT_GRADIENT
(
gaussian_random
,
ops
::
GaussianRandomOp
,
ops
::
GaussianRandomOpMaker
);
ops
::
GaussianRandomOpMaker
);
REGISTER_OP_CPU_KERNEL
(
gaussian_random
,
ops
::
GaussianRandomKernel
<
float
>
);
REGISTER_OP_CPU_KERNEL
(
gaussian_random
,
ops
::
CPU
GaussianRandomKernel
<
float
>
);
\ No newline at end of file
paddle/operators/gaussian_random_op.cu
浏览文件 @
7443b2e4
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include <memory>
#include <thrust/device_ptr.h>
#include <random>
#include <thrust/iterator/counting_iterator.h>
#include "paddle/platform/dynload/curand.h"
#include <thrust/random.h>
#include "paddle/platform/gpu_info.h"
#include <thrust/transform.h>
#include "paddle/framework/op_registry.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
template
<
typename
T
>
template
<
typename
T
>
class
GaussianRandomKernel
:
public
framework
::
OpKernel
{
struct
GaussianGenerator
{
T
mean_
,
std_
;
unsigned
int
seed_
;
__host__
__device__
GaussianGenerator
(
T
mean
,
T
std
,
int
seed
)
:
mean_
(
mean
),
std_
(
std
),
seed_
(
seed
)
{}
__host__
__device__
T
operator
()(
const
unsigned
int
n
)
const
{
thrust
::
minstd_rand
rng
;
rng
.
seed
(
seed_
);
thrust
::
normal_distribution
<
T
>
dist
(
mean_
,
std_
);
rng
.
discard
(
n
);
return
dist
(
rng
);
}
};
template
<
typename
T
>
class
GPUGaussianRandomKernel
:
public
framework
::
OpKernel
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
float
mean
=
context
.
op_
.
GetAttr
<
float
>
(
"mean"
);
auto
*
tensor
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
float
std
=
context
.
op_
.
GetAttr
<
float
>
(
"std"
);
auto
*
tensor
=
context
.
Output
<
framework
::
Tensor
>
(
0
);
T
*
data
=
tensor
->
mutable_data
<
T
>
(
context
.
GetPlace
());
T
*
data
=
tensor
->
mutable_data
<
T
>
(
context
.
GetPlace
());
unsigned
int
seed
=
int
seed
=
context
.
op_
.
GetAttr
<
int
>
(
"seed"
);
static_cast
<
unsigned
int
>
(
context
.
op_
.
GetAttr
<
int
>
(
"seed"
)
);
if
(
seed
==
0
)
{
if
(
seed
==
0
)
{
std
::
random_device
rd
;
std
::
random_device
rd
;
seed
=
rd
();
seed
=
rd
();
}
}
curandGenerator_t
g
;
T
mean
=
static_cast
<
T
>
(
context
.
op_
.
GetAttr
<
float
>
(
"mean"
))
;
PADDLE_ENFORCE
(
platform
::
dynload
::
curandCreateGenerator
(
T
std
=
static_cast
<
T
>
(
context
.
op_
.
GetAttr
<
float
>
(
"std"
));
&
g
,
CURAND_RNG_PSEUDO_DEFAULT
)
);
thrust
::
counting_iterator
<
unsigned
int
>
index_sequence_begin
(
0
);
PADDLE_ENFORCE
(
ssize_t
N
=
framework
::
product
(
tensor
->
dims
());
platform
::
dynload
::
curandSetPseudoRandomGeneratorSeed
(
g
,
seed
));
thrust
::
transform
(
index_sequence_begin
,
index_sequence_begin
+
N
,
platform
::
dynload
::
curandGenerateNormal
(
thrust
::
device_ptr
<
T
>
(
data
),
g
,
data
,
framework
::
product
(
tensor
->
dims
()),
mean
,
std
);
GaussianGenerator
<
T
>
(
mean
,
std
,
seed
)
);
}
}
};
};
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_GPU_KERNEL
(
gaussian_random
,
REGISTER_OP_GPU_KERNEL
(
gaussian_random
,
ops
::
GaussianRandomKernel
<
float
>
);
paddle
::
operators
::
GPU
GaussianRandomKernel
<
float
>
);
paddle/operators/mul_op.cc
浏览文件 @
7443b2e4
...
@@ -13,7 +13,6 @@
...
@@ -13,7 +13,6 @@
limitations under the License. */
limitations under the License. */
#include "paddle/operators/mul_op.h"
#include "paddle/operators/mul_op.h"
#include "paddle/operators/math/math_function.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
...
paddle/operators/uniform_random_op.cc
浏览文件 @
7443b2e4
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -39,7 +36,8 @@ class CPUUniformRandomKernel : public framework::OpKernel {
...
@@ -39,7 +36,8 @@ class CPUUniformRandomKernel : public framework::OpKernel {
std
::
uniform_real_distribution
<
T
>
dist
(
std
::
uniform_real_distribution
<
T
>
dist
(
static_cast
<
T
>
(
context
.
op_
.
GetAttr
<
float
>
(
"min"
)),
static_cast
<
T
>
(
context
.
op_
.
GetAttr
<
float
>
(
"min"
)),
static_cast
<
T
>
(
context
.
op_
.
GetAttr
<
float
>
(
"max"
)));
static_cast
<
T
>
(
context
.
op_
.
GetAttr
<
float
>
(
"max"
)));
for
(
ssize_t
i
=
0
;
i
<
framework
::
product
(
tensor
->
dims
());
++
i
)
{
ssize_t
size
=
framework
::
product
(
tensor
->
dims
());
for
(
ssize_t
i
=
0
;
i
<
size
;
++
i
)
{
data
[
i
]
=
dist
(
engine
);
data
[
i
]
=
dist
(
engine
);
}
}
}
}
...
@@ -66,7 +64,6 @@ class UniformRandomOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -66,7 +64,6 @@ class UniformRandomOpMaker : public framework::OpProtoAndCheckerMaker {
:
framework
::
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
:
framework
::
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddOutput
(
"Out"
,
"The output tensor of uniform random op"
);
AddOutput
(
"Out"
,
"The output tensor of uniform random op"
);
AddComment
(
R"DOC(Uniform random operator.
AddComment
(
R"DOC(Uniform random operator.
Used to initialize tensor with uniform random generator.
Used to initialize tensor with uniform random generator.
)DOC"
);
)DOC"
);
AddAttr
<
std
::
vector
<
int
>>
(
"dims"
,
"the dimension of random tensor"
);
AddAttr
<
std
::
vector
<
int
>>
(
"dims"
,
"the dimension of random tensor"
);
...
@@ -84,4 +81,4 @@ Used to initialize tensor with uniform random generator.
...
@@ -84,4 +81,4 @@ Used to initialize tensor with uniform random generator.
REGISTER_OP_WITHOUT_GRADIENT
(
uniform_random
,
paddle
::
operators
::
UniformRandomOp
,
REGISTER_OP_WITHOUT_GRADIENT
(
uniform_random
,
paddle
::
operators
::
UniformRandomOp
,
paddle
::
operators
::
UniformRandomOpMaker
);
paddle
::
operators
::
UniformRandomOpMaker
);
REGISTER_OP_CPU_KERNEL
(
uniform_random
,
REGISTER_OP_CPU_KERNEL
(
uniform_random
,
paddle
::
operators
::
CPUUniformRandomKernel
<
float
>
);
paddle
::
operators
::
CPUUniformRandomKernel
<
float
>
);
\ No newline at end of file
paddle/operators/uniform_random_op.cu
浏览文件 @
7443b2e4
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -68,4 +65,4 @@ class GPUUniformRandomKernel : public framework::OpKernel {
...
@@ -68,4 +65,4 @@ class GPUUniformRandomKernel : public framework::OpKernel {
}
// namespace paddle
}
// namespace paddle
REGISTER_OP_GPU_KERNEL
(
uniform_random
,
REGISTER_OP_GPU_KERNEL
(
uniform_random
,
paddle
::
operators
::
GPUUniformRandomKernel
<
float
>
);
paddle
::
operators
::
GPUUniformRandomKernel
<
float
>
);
\ No newline at end of file
paddle/platform/device_context.cc
浏览文件 @
7443b2e4
...
@@ -114,9 +114,6 @@ CUDADeviceContext::~CUDADeviceContext() {
...
@@ -114,9 +114,6 @@ CUDADeviceContext::~CUDADeviceContext() {
PADDLE_ENFORCE
(
dynload
::
cudnnDestroy
(
cudnn_handle_
));
PADDLE_ENFORCE
(
dynload
::
cudnnDestroy
(
cudnn_handle_
));
}
}
if
(
curand_generator_
)
{
PADDLE_ENFORCE
(
dynload
::
curandDestroyGenerator
(
curand_generator_
));
}
eigen_stream_
.
reset
();
eigen_stream_
.
reset
();
eigen_device_
.
reset
();
eigen_device_
.
reset
();
PADDLE_ENFORCE
(
cudaStreamDestroy
(
stream_
));
PADDLE_ENFORCE
(
cudaStreamDestroy
(
stream_
));
...
@@ -152,19 +149,6 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() {
...
@@ -152,19 +149,6 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() {
cudaStream_t
CUDADeviceContext
::
stream
()
{
return
stream_
;
}
cudaStream_t
CUDADeviceContext
::
stream
()
{
return
stream_
;
}
curandGenerator_t
CUDADeviceContext
::
curand_generator
()
{
if
(
!
curand_generator_
)
{
SetDeviceId
(
place_
.
device
);
PADDLE_ENFORCE
(
dynload
::
curandCreateGenerator
(
&
curand_generator_
,
CURAND_RNG_PSEUDO_DEFAULT
));
PADDLE_ENFORCE
(
dynload
::
curandSetPseudoRandomGeneratorSeed
(
curand_generator_
,
seed_
));
PADDLE_ENFORCE
(
dynload
::
curandSetStream
(
curand_generator_
,
stream_
));
}
return
curand_generator_
;
}
#endif // PADDLE_ONLY_CPU
#endif // PADDLE_ONLY_CPU
}
// namespace platform
}
// namespace platform
...
...
paddle/platform/device_context.h
浏览文件 @
7443b2e4
...
@@ -17,7 +17,6 @@ limitations under the License. */
...
@@ -17,7 +17,6 @@ limitations under the License. */
#ifndef PADDLE_ONLY_CPU
#ifndef PADDLE_ONLY_CPU
#include "paddle/platform/dynload/cublas.h"
#include "paddle/platform/dynload/cublas.h"
#include "paddle/platform/dynload/cudnn.h"
#include "paddle/platform/dynload/cudnn.h"
#include "paddle/platform/dynload/curand.h"
#include "paddle/platform/gpu_info.h"
#include "paddle/platform/gpu_info.h"
#define EIGEN_USE_GPU
#define EIGEN_USE_GPU
#endif
#endif
...
@@ -40,7 +39,7 @@ class DeviceContext {
...
@@ -40,7 +39,7 @@ class DeviceContext {
class
CPUDeviceContext
:
public
DeviceContext
{
class
CPUDeviceContext
:
public
DeviceContext
{
public:
public:
CPUDeviceContext
();
CPUDeviceContext
();
explicit
CPUDeviceContext
(
CPUPlace
);
explicit
CPUDeviceContext
(
CPUPlace
place
);
virtual
~
CPUDeviceContext
()
{}
virtual
~
CPUDeviceContext
()
{}
Eigen
::
DefaultDevice
*
eigen_device
()
const
;
Eigen
::
DefaultDevice
*
eigen_device
()
const
;
...
@@ -56,7 +55,7 @@ class EigenCudaStreamDevice;
...
@@ -56,7 +55,7 @@ class EigenCudaStreamDevice;
class
CUDADeviceContext
:
public
DeviceContext
{
class
CUDADeviceContext
:
public
DeviceContext
{
public:
public:
explicit
CUDADeviceContext
(
GPUPlace
);
explicit
CUDADeviceContext
(
GPUPlace
place
);
virtual
~
CUDADeviceContext
();
virtual
~
CUDADeviceContext
();
/*! \brief Wait for all operations completion in the stream. */
/*! \brief Wait for all operations completion in the stream. */
...
@@ -75,9 +74,6 @@ class CUDADeviceContext : public DeviceContext {
...
@@ -75,9 +74,6 @@ class CUDADeviceContext : public DeviceContext {
/*! \brief Return cudnn handle in the device context. */
/*! \brief Return cudnn handle in the device context. */
cudnnHandle_t
cudnn_handle
();
cudnnHandle_t
cudnn_handle
();
/*! \brief Return curand handle in the device context. */
curandGenerator_t
curand_generator
();
/*! \brief Return cuda stream in the device context. */
/*! \brief Return cuda stream in the device context. */
cudaStream_t
stream
();
cudaStream_t
stream
();
// clang-format on
// clang-format on
...
@@ -85,18 +81,13 @@ class CUDADeviceContext : public DeviceContext {
...
@@ -85,18 +81,13 @@ class CUDADeviceContext : public DeviceContext {
private:
private:
GPUPlace
place_
;
GPUPlace
place_
;
private:
std
::
unique_ptr
<
Eigen
::
GpuDevice
>
eigen_device_
;
std
::
unique_ptr
<
Eigen
::
GpuDevice
>
eigen_device_
;
std
::
unique_ptr
<
EigenCudaStreamDevice
>
eigen_stream_
;
std
::
unique_ptr
<
EigenCudaStreamDevice
>
eigen_stream_
;
private:
uint64_t
seed_
;
// clang-format off
// clang-format off
cudaStream_t
stream_
{
nullptr
};
cudaStream_t
stream_
{
nullptr
};
cudnnHandle_t
cudnn_handle_
{
nullptr
};
cudnnHandle_t
cudnn_handle_
{
nullptr
};
cublasHandle_t
cublas_handle_
{
nullptr
};
cublasHandle_t
cublas_handle_
{
nullptr
};
curandGenerator_t
curand_generator_
{
nullptr
};
// clang-format on
// clang-format on
};
};
...
...
paddle/platform/device_context_test.cc
浏览文件 @
7443b2e4
...
@@ -43,8 +43,6 @@ TEST(Device, CUDADeviceContext) {
...
@@ -43,8 +43,6 @@ TEST(Device, CUDADeviceContext) {
ASSERT_NE
(
nullptr
,
cudnn_handle
);
ASSERT_NE
(
nullptr
,
cudnn_handle
);
cublasHandle_t
cublas_handle
=
device_context
->
cublas_handle
();
cublasHandle_t
cublas_handle
=
device_context
->
cublas_handle
();
ASSERT_NE
(
nullptr
,
cublas_handle
);
ASSERT_NE
(
nullptr
,
cublas_handle
);
curandGenerator_t
curand_handle
=
device_context
->
curand_generator
();
ASSERT_NE
(
nullptr
,
curand_handle
);
ASSERT_NE
(
nullptr
,
device_context
->
stream
());
ASSERT_NE
(
nullptr
,
device_context
->
stream
());
delete
device_context
;
delete
device_context
;
}
}
...
...
python/paddle/v2/framework/tests/CMakeLists.txt
浏览文件 @
7443b2e4
...
@@ -22,7 +22,7 @@ py_test(test_rowwise_add_op SRCS test_rowwise_add_op.py)
...
@@ -22,7 +22,7 @@ py_test(test_rowwise_add_op SRCS test_rowwise_add_op.py)
py_test
(
test_default_scope_funcs SRCS test_default_scope_funcs.py
)
py_test
(
test_default_scope_funcs SRCS test_default_scope_funcs.py
)
py_test
(
test_operator SRCS test_operator.py
)
py_test
(
test_operator SRCS test_operator.py
)
#
py_test(test_gaussian_random_op SRCS test_gaussian_random_op.py)
py_test
(
test_gaussian_random_op SRCS test_gaussian_random_op.py
)
py_test
(
test_uniform_random_op SRCS test_uniform_random_op.py
)
py_test
(
test_uniform_random_op SRCS test_uniform_random_op.py
)
py_test
(
test_recurrent_op SRCS test_recurrent_op.py
)
py_test
(
test_recurrent_op SRCS test_recurrent_op.py
)
py_test
(
test_sgd_op SRCS test_sgd_op.py
)
py_test
(
test_sgd_op SRCS test_sgd_op.py
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录