Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
36e8e725
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
36e8e725
编写于
8月 22, 2017
作者:
Q
qijun
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
expose random seed to users
上级
b054392e
变更
14
显示空白变更内容
内联
并排
Showing
14 changed file
with
196 addition
and
228 deletion
+196
-228
paddle/operators/CMakeLists.txt
paddle/operators/CMakeLists.txt
+2
-2
paddle/operators/gaussian_random_op.cc
paddle/operators/gaussian_random_op.cc
+33
-9
paddle/operators/gaussian_random_op.cu
paddle/operators/gaussian_random_op.cu
+53
-8
paddle/operators/gaussian_random_op.h
paddle/operators/gaussian_random_op.h
+0
-38
paddle/operators/math/math_function.cc
paddle/operators/math/math_function.cc
+0
-22
paddle/operators/math/math_function.cu
paddle/operators/math/math_function.cu
+0
-48
paddle/operators/math/math_function.h
paddle/operators/math/math_function.h
+0
-8
paddle/operators/uniform_random_op.cc
paddle/operators/uniform_random_op.cc
+36
-8
paddle/operators/uniform_random_op.cu
paddle/operators/uniform_random_op.cu
+56
-7
paddle/operators/uniform_random_op.h
paddle/operators/uniform_random_op.h
+0
-38
paddle/platform/device_context.cc
paddle/platform/device_context.cc
+2
-25
paddle/platform/device_context.h
paddle/platform/device_context.h
+2
-13
python/paddle/v2/framework/tests/test_gaussian_random_op.py
python/paddle/v2/framework/tests/test_gaussian_random_op.py
+6
-1
python/paddle/v2/framework/tests/test_uniform_random_op.py
python/paddle/v2/framework/tests/test_uniform_random_op.py
+6
-1
未找到文件。
paddle/operators/CMakeLists.txt
浏览文件 @
36e8e725
...
...
@@ -58,7 +58,7 @@ op_library(rowwise_add_op SRCS rowwise_add_op.cu rowwise_add_op.cc)
op_library
(
sigmoid_op SRCS sigmoid_op.cc sigmoid_op.cu
)
op_library
(
softmax_op SRCS softmax_op.cc softmax_op.cu
)
op_library
(
gaussian_random_op SRCS gaussian_random_op.cc gaussian_random_op.cu
DEPS math_function
)
op_library
(
gaussian_random_op SRCS gaussian_random_op.cc gaussian_random_op.cu
)
op_library
(
cross_entropy_op SRCS cross_entropy_op.cc cross_entropy_op.cu
)
op_library
(
fill_zeros_like_op SRCS fill_zeros_like_op.cc fill_zeros_like_op.cu
)
...
...
@@ -67,4 +67,4 @@ op_library(sgd_op SRCS sgd_op.cc sgd_op.cu)
op_library
(
recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
DEPS framework_proto tensor op_registry operator net_op
)
op_library
(
uniform_random_op
SRCS uniform_random_op.cc uniform_random_op.cu
DEPS math_function
)
SRCS uniform_random_op.cc uniform_random_op.cu
)
paddle/operators/gaussian_random_op.cc
浏览文件 @
36e8e725
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/gaussian_random_op.h"
#include <random>
#include "paddle/framework/op_registry.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
class
CPUGaussianRandomKernel
:
public
framework
::
OpKernel
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
float
mean
=
context
.
op_
.
GetAttr
<
float
>
(
"mean"
);
float
std
=
context
.
op_
.
GetAttr
<
float
>
(
"std"
);
auto
*
tensor
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
T
*
data
=
tensor
->
mutable_data
<
T
>
(
context
.
GetPlace
());
unsigned
int
seed
=
static_cast
<
unsigned
int
>
(
context
.
op_
.
GetAttr
<
int
>
(
"seed"
));
std
::
minstd_rand
engine
;
if
(
seed
==
0
)
{
seed
=
std
::
random_device
()();
}
engine
.
seed
(
seed
);
std
::
normal_distribution
<
T
>
dist
(
mean
,
std
);
ssize_t
size
=
framework
::
product
(
tensor
->
dims
());
for
(
ssize_t
i
=
0
;
i
<
size
;
++
i
)
{
data
[
i
]
=
dist
(
engine
);
}
}
};
class
GaussianRandomOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
...
...
@@ -43,8 +65,12 @@ Use to initialize tensor with gaussian random generator.
)DOC"
);
AddAttr
<
std
::
vector
<
int
>>
(
"dims"
,
"The dimension of random tensor."
);
AddAttr
<
float
>
(
"mean"
,
"mean value of random."
).
SetDefault
(
.0
f
);
AddAttr
<
float
>
(
"std"
,
"minimum value of random value."
).
SetDefault
(
1.0
f
);
AddAttr
<
float
>
(
"mean"
,
"mean of random tensor."
).
SetDefault
(
.0
f
);
AddAttr
<
float
>
(
"std"
,
"std of random tensor."
).
SetDefault
(
1.0
f
);
AddAttr
<
int
>
(
"seed"
,
"Random seed of generator."
"0 means use system wide seed"
)
.
SetDefault
(
0
);
}
};
...
...
@@ -54,6 +80,4 @@ Use to initialize tensor with gaussian random generator.
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_WITHOUT_GRADIENT
(
gaussian_random
,
ops
::
GaussianRandomOp
,
ops
::
GaussianRandomOpMaker
);
REGISTER_OP_CPU_KERNEL
(
gaussian_random
,
ops
::
GaussianRandomKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
REGISTER_OP_CPU_KERNEL
(
gaussian_random
,
ops
::
CPUGaussianRandomKernel
<
float
>
);
\ No newline at end of file
paddle/operators/gaussian_random_op.cu
浏览文件 @
36e8e725
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/gaussian_random_op.h"
#include <thrust/device_ptr.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/random.h>
#include <thrust/transform.h>
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
struct
GaussianGenerator
{
T
mean_
,
std_
;
unsigned
int
seed_
;
__host__
__device__
GaussianGenerator
(
T
mean
,
T
std
,
int
seed
)
:
mean_
(
mean
),
std_
(
std
),
seed_
(
seed
)
{}
__host__
__device__
T
operator
()(
const
unsigned
int
n
)
const
{
thrust
::
minstd_rand
rng
;
rng
.
seed
(
seed_
);
thrust
::
normal_distribution
<
T
>
dist
(
min_
,
max_
);
rng
.
discard
(
n
);
return
dist
(
rng
);
}
};
template
<
typename
T
>
class
GPUGaussianRandomKernel
:
public
framework
::
OpKernel
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
tensor
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
T
*
data
=
tensor
->
mutable_data
<
T
>
(
context
.
GetPlace
());
unsigned
int
seed
=
static_cast
<
unsigned
int
>
(
context
.
op_
.
GetAttr
<
int
>
(
"seed"
));
if
(
seed
==
0
)
{
std
::
random_device
rd
;
seed
=
rd
();
}
T
mean
=
static_cast
<
T
>
(
context
.
op_
.
GetAttr
<
float
>
(
"mean"
));
T
std
=
static_cast
<
T
>
(
context
.
op_
.
GetAttr
<
float
>
(
"std"
));
thrust
::
counting_iterator
<
unsigned
int
>
index_sequence_begin
(
0
);
ssize_t
N
=
framework
::
product
(
tensor
->
dims
());
thrust
::
transform
(
index_sequence_begin
,
index_sequence_begin
+
N
,
thrust
::
device_ptr
<
T
>
(
data
),
GaussianGenerator
<
T
>
(
mean
,
std
,
seed
));
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_GPU_KERNEL
(
gaussian_random
,
ops
::
GaussianRandomKernel
<
paddle
::
platform
::
GPUPlace
,
float
>
);
REGISTER_OP_GPU_KERNEL
(
gaussian_random
,
paddle
::
operators
::
GPUGaussianRandomKernel
<
float
>
);
\ No newline at end of file
paddle/operators/gaussian_random_op.h
已删除
100644 → 0
浏览文件 @
b054392e
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
Place
,
typename
T
>
class
GaussianRandomKernel
:
public
framework
::
OpKernel
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
tensor
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
T
*
data
=
tensor
->
mutable_data
<
T
>
(
context
.
GetPlace
());
T
mean
=
static_cast
<
T
>
(
context
.
op_
.
GetAttr
<
float
>
(
"mean"
));
T
std
=
static_cast
<
T
>
(
context
.
op_
.
GetAttr
<
float
>
(
"std"
));
auto
n
=
framework
::
product
(
tensor
->
dims
());
auto
*
device_context
=
const_cast
<
platform
::
DeviceContext
*>
(
context
.
device_context_
);
math
::
RandGaussian
<
Place
,
T
>
(
n
,
mean
,
std
,
data
,
device_context
);
}
};
}
// namespace operators
}
// namespace paddle
paddle/operators/math/math_function.cc
浏览文件 @
36e8e725
...
...
@@ -118,28 +118,6 @@ void Set<platform::CPUPlace, float>(const int n, const float alpha,
out
.
device
(
*
(
cpu_context
->
eigen_device
()))
=
out
.
constant
(
float
(
alpha
));
}
template
<
>
void
RandUniform
<
platform
::
CPUPlace
,
float
>
(
const
int
n
,
const
float
min
,
const
float
max
,
float
*
output
,
platform
::
DeviceContext
*
context
)
{
auto
*
cpu_context
=
reinterpret_cast
<
platform
::
CPUDeviceContext
*>
(
context
);
std
::
uniform_real_distribution
<
float
>
distribution
(
min
,
max
);
for
(
int
i
=
0
;
i
<
n
;
i
++
)
{
output
[
i
]
=
distribution
(
cpu_context
->
rand_engine
());
}
}
template
<
>
void
RandGaussian
<
platform
::
CPUPlace
,
float
>
(
const
int
n
,
const
float
mean
,
const
float
std
,
float
*
output
,
platform
::
DeviceContext
*
context
)
{
auto
*
cpu_context
=
reinterpret_cast
<
platform
::
CPUDeviceContext
*>
(
context
);
std
::
normal_distribution
<
float
>
distribution
(
mean
,
std
);
for
(
int
i
=
0
;
i
<
n
;
i
++
)
{
output
[
i
]
=
distribution
(
cpu_context
->
rand_engine
());
}
}
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/operators/math/math_function.cu
浏览文件 @
36e8e725
...
...
@@ -135,54 +135,6 @@ void Set<platform::GPUPlace, float>(const int n, const float alpha,
out
.
device
(
*
(
cuda_context
->
eigen_device
()))
=
out
.
constant
(
float
(
alpha
));
}
template
<
typename
T
>
__global__
void
UniformShift
(
const
int
n
,
const
T
min
,
const
T
max
,
T
*
x
)
{
float
scale
=
max
-
min
;
for
(
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
n
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
x
[
i
]
=
x
[
i
]
*
scale
+
min
;
}
}
template
<
>
void
RandUniform
<
platform
::
GPUPlace
,
float
>
(
const
int
n
,
const
float
min
,
const
float
max
,
float
*
output
,
platform
::
DeviceContext
*
context
)
{
auto
*
cuda_context
=
reinterpret_cast
<
platform
::
CUDADeviceContext
*>
(
context
);
PADDLE_ENFORCE
(
platform
::
dynload
::
curandGenerateUniform
(
cuda_context
->
curand_generator
(),
output
,
n
));
int
block
=
512
;
int
grid
=
(
n
+
block
-
1
)
/
block
;
UniformShift
<
float
><<<
grid
,
block
,
0
,
cuda_context
->
stream
()
>>>
(
n
,
min
,
max
,
output
);
}
template
<
typename
T
>
int
HandleOddLengthRandGaussian
(
const
int
n
,
const
T
mean
,
const
T
std
,
T
*
output
,
platform
::
CUDADeviceContext
*
context
)
{
if
(
n
%
2
==
1
)
{
std
::
default_random_engine
generator
;
std
::
normal_distribution
<
T
>
distribution
(
mean
,
std
);
const
T
random_value
=
distribution
(
generator
);
Set
<
platform
::
GPUPlace
,
T
>
(
1
,
random_value
,
output
+
(
n
-
1
),
context
);
return
n
-
1
;
}
return
n
;
}
template
<
>
void
RandGaussian
<
platform
::
GPUPlace
,
float
>
(
const
int
n
,
const
float
mean
,
const
float
std
,
float
*
output
,
platform
::
DeviceContext
*
context
)
{
auto
*
cuda_context
=
reinterpret_cast
<
platform
::
CUDADeviceContext
*>
(
context
);
const
int
even_n
=
HandleOddLengthRandGaussian
<
float
>
(
n
,
mean
,
std
,
output
,
cuda_context
);
PADDLE_ENFORCE
(
platform
::
dynload
::
curandGenerateNormal
(
cuda_context
->
curand_generator
(),
output
,
even_n
,
mean
,
std
));
}
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/operators/math/math_function.h
浏览文件 @
36e8e725
...
...
@@ -82,14 +82,6 @@ template <typename Place, typename T>
void
Set
(
const
int
n
,
const
T
alpha
,
T
*
output
,
platform
::
DeviceContext
*
context
);
template
<
typename
Place
,
typename
T
>
void
RandUniform
(
const
int
n
,
const
T
min
,
const
T
max
,
T
*
output
,
platform
::
DeviceContext
*
context
);
template
<
typename
Place
,
typename
T
>
void
RandGaussian
(
const
int
n
,
const
T
mean
,
const
T
std
,
T
*
output
,
platform
::
DeviceContext
*
context
);
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/operators/uniform_random_op.cc
浏览文件 @
36e8e725
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/uniform_random_op.h"
#include <random>
#include <type_traits>
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
namespace
paddle
{
namespace
operators
{
// It seems that Eigen::Tensor::random in GPU will SEGFAULT.
// Use std::random and thrust::random(thrust is a std library in CUDA) to
// implement uniform random.
template
<
typename
T
>
class
CPUUniformRandomKernel
:
public
framework
::
OpKernel
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
tensor
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
T
*
data
=
tensor
->
mutable_data
<
T
>
(
context
.
GetPlace
());
unsigned
int
seed
=
static_cast
<
unsigned
int
>
(
context
.
op_
.
GetAttr
<
int
>
(
"seed"
));
std
::
minstd_rand
engine
;
if
(
seed
==
0
)
{
seed
=
std
::
random_device
()();
}
engine
.
seed
(
seed
);
std
::
uniform_real_distribution
<
T
>
dist
(
static_cast
<
T
>
(
context
.
op_
.
GetAttr
<
float
>
(
"min"
)),
static_cast
<
T
>
(
context
.
op_
.
GetAttr
<
float
>
(
"max"
)));
ssize_t
size
=
framework
::
product
(
tensor
->
dims
());
for
(
ssize_t
i
=
0
;
i
<
size
;
++
i
)
{
data
[
i
]
=
dist
(
engine
);
}
}
};
class
UniformRandomOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
...
...
@@ -38,12 +64,15 @@ class UniformRandomOpMaker : public framework::OpProtoAndCheckerMaker {
:
framework
::
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddOutput
(
"Out"
,
"The output tensor of uniform random op"
);
AddComment
(
R"DOC(Uniform random operator.
Used to initialize tensor with uniform random generator.
)DOC"
);
AddAttr
<
std
::
vector
<
int
>>
(
"dims"
,
"the dimension of random tensor"
);
AddAttr
<
float
>
(
"min"
,
"Minimum value of uniform random"
).
SetDefault
(
-
1.0
f
);
AddAttr
<
float
>
(
"max"
,
"Maximun value of uniform random"
).
SetDefault
(
1.0
f
);
AddAttr
<
int
>
(
"seed"
,
"Random seed of uniform random. "
"0 means generate a seed by system"
)
.
SetDefault
(
0
);
}
};
}
// namespace operators
...
...
@@ -51,6 +80,5 @@ Used to initialize tensor with uniform random generator.
REGISTER_OP_WITHOUT_GRADIENT
(
uniform_random
,
paddle
::
operators
::
UniformRandomOp
,
paddle
::
operators
::
UniformRandomOpMaker
);
REGISTER_OP_CPU_KERNEL
(
uniform_random
,
paddle
::
operators
::
UniformRandomKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
REGISTER_OP_CPU_KERNEL
(
uniform_random
,
paddle
::
operators
::
CPUUniformRandomKernel
<
float
>
);
\ No newline at end of file
paddle/operators/uniform_random_op.cu
浏览文件 @
36e8e725
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/uniform_random_op.h"
#include <thrust/device_ptr.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/random.h>
#include <thrust/transform.h>
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
struct
UniformGenerator
{
T
min_
,
max_
;
unsigned
int
seed_
;
__host__
__device__
UniformGenerator
(
T
min
,
T
max
,
int
seed
)
:
min_
(
min
),
max_
(
max
),
seed_
(
seed
)
{}
__host__
__device__
T
operator
()(
const
unsigned
int
n
)
const
{
thrust
::
minstd_rand
rng
;
rng
.
seed
(
seed_
);
thrust
::
uniform_real_distribution
<
T
>
dist
(
min_
,
max_
);
rng
.
discard
(
n
);
return
dist
(
rng
);
}
};
// It seems that Eigen::Tensor::random in GPU will SEGFAULT.
// Use std::random and thrust::random(thrust is a std library in CUDA) to
// implement uniform random.
template
<
typename
T
>
class
GPUUniformRandomKernel
:
public
framework
::
OpKernel
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
tensor
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
T
*
data
=
tensor
->
mutable_data
<
T
>
(
context
.
GetPlace
());
unsigned
int
seed
=
static_cast
<
unsigned
int
>
(
context
.
op_
.
GetAttr
<
int
>
(
"seed"
));
if
(
seed
==
0
)
{
std
::
random_device
rd
;
seed
=
rd
();
}
T
min
=
static_cast
<
T
>
(
context
.
op_
.
GetAttr
<
float
>
(
"min"
));
T
max
=
static_cast
<
T
>
(
context
.
op_
.
GetAttr
<
float
>
(
"max"
));
thrust
::
counting_iterator
<
unsigned
int
>
index_sequence_begin
(
0
);
ssize_t
N
=
framework
::
product
(
tensor
->
dims
());
thrust
::
transform
(
index_sequence_begin
,
index_sequence_begin
+
N
,
thrust
::
device_ptr
<
T
>
(
data
),
UniformGenerator
<
T
>
(
min
,
max
,
seed
));
}
};
}
// namespace operators
}
// namespace paddle
REGISTER_OP_GPU_KERNEL
(
uniform_random
,
paddle
::
operators
::
UniformRandomKernel
<
paddle
::
platform
::
GPUPlace
,
float
>
);
REGISTER_OP_GPU_KERNEL
(
uniform_random
,
paddle
::
operators
::
GPUUniformRandomKernel
<
float
>
);
\ No newline at end of file
paddle/operators/uniform_random_op.h
已删除
100644 → 0
浏览文件 @
b054392e
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
Place
,
typename
T
>
class
UniformRandomKernel
:
public
framework
::
OpKernel
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
tensor
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
T
*
data
=
tensor
->
mutable_data
<
T
>
(
context
.
GetPlace
());
T
min
=
static_cast
<
T
>
(
context
.
op_
.
GetAttr
<
float
>
(
"min"
));
T
max
=
static_cast
<
T
>
(
context
.
op_
.
GetAttr
<
float
>
(
"max"
));
auto
n
=
framework
::
product
(
tensor
->
dims
());
auto
*
device_context
=
const_cast
<
platform
::
DeviceContext
*>
(
context
.
device_context_
);
math
::
RandUniform
<
Place
,
T
>
(
n
,
min
,
max
,
data
,
device_context
);
}
};
}
// namespace operators
}
// namespace paddle
paddle/platform/device_context.cc
浏览文件 @
36e8e725
...
...
@@ -25,17 +25,8 @@ CPUDeviceContext::CPUDeviceContext() {
eigen_device_
.
reset
(
new
Eigen
::
DefaultDevice
());
}
CPUDeviceContext
::
CPUDeviceContext
(
CPUPlace
place
,
int
seed
)
{
CPUDeviceContext
::
CPUDeviceContext
(
CPUPlace
place
)
{
eigen_device_
.
reset
(
new
Eigen
::
DefaultDevice
());
rand_seed_
=
seed
;
}
std
::
minstd_rand
&
CPUDeviceContext
::
rand_engine
()
{
if
(
!
rand_engine_
)
{
rand_engine_
.
reset
(
new
std
::
minstd_rand
());
rand_engine_
->
seed
(
rand_seed_
);
}
return
*
(
rand_engine_
.
get
());
}
Eigen
::
DefaultDevice
*
CPUDeviceContext
::
eigen_device
()
const
{
...
...
@@ -104,8 +95,7 @@ Eigen::GpuDevice* DeviceContext::get_eigen_device<Eigen::GpuDevice>() const {
return
reinterpret_cast
<
const
CUDADeviceContext
*>
(
this
)
->
eigen_device
();
}
CUDADeviceContext
::
CUDADeviceContext
(
GPUPlace
place
,
uint64_t
seed
)
:
place_
(
place
),
rand_seed_
(
seed
)
{
CUDADeviceContext
::
CUDADeviceContext
(
GPUPlace
place
)
:
place_
(
place
)
{
SetDeviceId
(
place_
.
device
);
PADDLE_ENFORCE
(
cudaStreamCreate
(
&
stream_
));
eigen_stream_
.
reset
(
new
EigenCudaStreamDevice
());
...
...
@@ -157,19 +147,6 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() {
return
cudnn_handle_
;
}
curandGenerator_t
CUDADeviceContext
::
curand_generator
()
{
if
(
!
curand_generator_
)
{
SetDeviceId
(
place_
.
device
);
PADDLE_ENFORCE
(
dynload
::
curandCreateGenerator
(
&
curand_generator_
,
CURAND_RNG_PSEUDO_DEFAULT
));
PADDLE_ENFORCE
(
dynload
::
curandSetPseudoRandomGeneratorSeed
(
curand_generator_
,
rand_seed_
));
PADDLE_ENFORCE
(
dynload
::
curandSetStream
(
curand_generator_
,
stream_
));
}
return
curand_generator_
;
}
cudaStream_t
CUDADeviceContext
::
stream
()
{
return
stream_
;
}
#endif // PADDLE_ONLY_CPU
...
...
paddle/platform/device_context.h
浏览文件 @
36e8e725
...
...
@@ -17,7 +17,6 @@ limitations under the License. */
#ifndef PADDLE_ONLY_CPU
#include "paddle/platform/dynload/cublas.h"
#include "paddle/platform/dynload/cudnn.h"
#include "paddle/platform/dynload/curand.h"
#include "paddle/platform/gpu_info.h"
#define EIGEN_USE_GPU
#endif
...
...
@@ -40,18 +39,14 @@ class DeviceContext {
class
CPUDeviceContext
:
public
DeviceContext
{
public:
CPUDeviceContext
();
explicit
CPUDeviceContext
(
CPUPlace
place
,
int
seed
=
0
);
explicit
CPUDeviceContext
(
CPUPlace
place
);
virtual
~
CPUDeviceContext
()
{}
Eigen
::
DefaultDevice
*
eigen_device
()
const
;
std
::
minstd_rand
&
rand_engine
();
Place
GetPlace
()
const
override
;
private:
int
rand_seed_
;
std
::
unique_ptr
<
std
::
minstd_rand
>
rand_engine_
;
std
::
unique_ptr
<
Eigen
::
DefaultDevice
>
eigen_device_
;
};
...
...
@@ -60,7 +55,7 @@ class EigenCudaStreamDevice;
class
CUDADeviceContext
:
public
DeviceContext
{
public:
explicit
CUDADeviceContext
(
GPUPlace
place
,
uint64_t
seed
=
0
);
explicit
CUDADeviceContext
(
GPUPlace
place
);
virtual
~
CUDADeviceContext
();
/*! \brief Wait for all operations completion in the stream. */
...
...
@@ -79,9 +74,6 @@ class CUDADeviceContext : public DeviceContext {
/*! \brief Return cudnn handle in the device context. */
cudnnHandle_t
cudnn_handle
();
/*! \brief Return curand handle in the device context. */
curandGenerator_t
curand_generator
();
/*! \brief Return cuda stream in the device context. */
cudaStream_t
stream
();
// clang-format on
...
...
@@ -92,13 +84,10 @@ class CUDADeviceContext : public DeviceContext {
std
::
unique_ptr
<
Eigen
::
GpuDevice
>
eigen_device_
;
std
::
unique_ptr
<
EigenCudaStreamDevice
>
eigen_stream_
;
uint64_t
rand_seed_
;
// clang-format off
cudaStream_t
stream_
{
nullptr
};
cudnnHandle_t
cudnn_handle_
{
nullptr
};
cublasHandle_t
cublas_handle_
{
nullptr
};
curandGenerator_t
curand_generator_
{
nullptr
};
// clang-format on
};
...
...
python/paddle/v2/framework/tests/test_gaussian_random_op.py
浏览文件 @
36e8e725
...
...
@@ -17,7 +17,12 @@ class GaussianRandomTest(unittest.TestCase):
scope
.
new_var
(
"Out"
).
get_tensor
()
op
=
Operator
(
"gaussian_random"
,
Out
=
"Out"
,
dims
=
[
1000
,
784
],
mean
=
.
0
,
std
=
1.
)
"gaussian_random"
,
Out
=
"Out"
,
dims
=
[
1000
,
784
],
mean
=
.
0
,
std
=
1.
,
seed
=
10
)
op
.
infer_shape
(
scope
)
context
=
core
.
DeviceContext
.
create
(
place
)
...
...
python/paddle/v2/framework/tests/test_uniform_random_op.py
浏览文件 @
36e8e725
...
...
@@ -17,7 +17,12 @@ class UniformRandomTest(unittest.TestCase):
scope
.
new_var
(
"X"
).
get_tensor
()
op
=
Operator
(
"uniform_random"
,
Out
=
"X"
,
dims
=
[
1000
,
784
],
min
=-
5.0
,
max
=
10.0
)
"uniform_random"
,
Out
=
"X"
,
dims
=
[
1000
,
784
],
min
=-
5.0
,
max
=
10.0
,
seed
=
10
)
op
.
infer_shape
(
scope
)
ctx
=
core
.
DeviceContext
.
create
(
place
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录