Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
57213340
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
57213340
编写于
7月 30, 2017
作者:
D
dongzhihong
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
"update the compute kernel"
上级
a22567eb
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
77 addition
and
79 deletion
+77
-79
paddle/framework/operator.h
paddle/framework/operator.h
+4
-4
paddle/operators/random_op.cc
paddle/operators/random_op.cc
+35
-12
paddle/operators/random_op.cu
paddle/operators/random_op.cu
+24
-1
paddle/operators/random_op.h
paddle/operators/random_op.h
+2
-55
paddle/platform/device_context.h
paddle/platform/device_context.h
+12
-7
未找到文件。
paddle/framework/operator.h
浏览文件 @
57213340
...
...
@@ -88,7 +88,7 @@ class OperatorBase {
/// Net will call this function to Run an op.
virtual
void
Run
(
const
std
::
shared_ptr
<
Scope
>&
scope
,
platform
::
DeviceContext
&
dev_ctx
)
const
=
0
;
const
platform
::
DeviceContext
&
dev_ctx
)
const
=
0
;
// Get a input with argument's name described in `op_proto`
const
std
::
string
&
Input
(
const
std
::
string
&
name
)
const
;
...
...
@@ -113,7 +113,7 @@ class OperatorBase {
class
KernelContext
{
public:
KernelContext
(
const
OperatorBase
*
op
,
const
std
::
shared_ptr
<
Scope
>&
scope
,
platform
::
DeviceContext
&
device_context
)
const
platform
::
DeviceContext
&
device_context
)
:
op_
(
*
op
),
scope_
(
scope
),
device_context_
(
&
device_context
)
{}
const
Variable
*
Input
(
int
index
)
const
{
...
...
@@ -159,7 +159,7 @@ class KernelContext {
const
OperatorBase
&
op_
;
const
std
::
shared_ptr
<
Scope
>
scope_
;
platform
::
DeviceContext
*
device_context_
;
const
platform
::
DeviceContext
*
device_context_
;
};
class
OpKernel
{
...
...
@@ -213,7 +213,7 @@ class OperatorWithKernel : public OperatorBase {
std
::
unordered_map
<
OpKernelKey
,
std
::
unique_ptr
<
OpKernel
>
,
OpKernelHash
>
;
void
Run
(
const
std
::
shared_ptr
<
Scope
>&
scope
,
platform
::
DeviceContext
&
dev_ctx
)
const
final
{
const
platform
::
DeviceContext
&
dev_ctx
)
const
final
{
auto
&
opKernel
=
AllOpKernels
().
at
(
type_
).
at
(
OpKernelKey
(
dev_ctx
));
opKernel
->
Compute
(
KernelContext
(
this
,
scope
,
dev_ctx
));
}
...
...
paddle/operators/random_op.cc
浏览文件 @
57213340
...
...
@@ -19,7 +19,28 @@
namespace
paddle
{
namespace
operators
{
class
RandomOp
:
public
framework
::
OperatorWithKernel
{
template
<
typename
T
>
class
GaussianRandomOpKernel
<
platform
::
CPUPlace
,
T
>
:
public
framework
::
OpKernel
{
public:
void
Compute
(
const
framework
::
KernelContext
&
context
)
const
override
{
auto
mean
=
context
.
op_
.
GetAttr
<
T
>
(
"mean"
);
auto
std
=
context
.
op_
.
GetAttr
<
T
>
(
"std"
);
// auto seed = context.op_.GetAttr<T>("seed");
auto
*
output
=
context
.
Output
(
0
)
->
GetMutable
<
framework
::
Tensor
>
();
T
*
r
=
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
ctx
=
static_cast
<
const
platform
::
CPUDeviceContext
*>
(
context
.
device_context_
);
// generator need to modify context
auto
g
=
const_cast
<
platform
::
CPUDeviceContext
*>
(
ctx
)
->
RandGenerator
();
std
::
normal_distribution
<
T
>
distribution
(
mean
,
std
);
for
(
int
i
=
0
;
i
<
framework
::
product
(
output
->
dims
());
++
i
)
{
r
[
i
]
=
distribution
(
g
);
}
}
};
class
GaussianRandomOp
:
public
framework
::
OperatorWithKernel
{
protected:
void
InferShape
(
const
std
::
vector
<
const
framework
::
Tensor
*>&
inputs
,
...
...
@@ -33,20 +54,21 @@ protected:
}
};
class
RandomOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
class
Gaussian
RandomOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
RandomOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
GaussianRandomOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
:
framework
::
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddAttr
<
std
::
vector
<
int
>>
(
"shape"
,
"The shape of matrix to be randomized"
);
AddAttr
<
float
>
(
"seed"
,
"random seed generator."
).
SetDefault
(
1337
);
//
AddAttr<float>("seed", "random seed generator.").SetDefault(1337);
AddAttr
<
float
>
(
"mean"
,
"mean value of random."
).
SetDefault
(
.0
);
AddAttr
<
float
>
(
"std"
,
"minimum value of random value"
)
.
SetDefault
(
1.0
)
.
LargerThan
(
.0
);
AddOutput
(
"Out"
,
"output matrix of random op"
);
AddComment
(
R"DOC(
Random Operator fill a matrix in normal distribution.
The eqution : Out = Random(Shape=(d0, d1, ...), Dtype, mean, std)
Gaussian
Random Operator fill a matrix in normal distribution.
The eqution : Out =
Gaussian
Random(Shape=(d0, d1, ...), Dtype, mean, std)
)DOC"
);
}
};
...
...
@@ -54,10 +76,11 @@ The eqution : Out = Random(Shape=(d0, d1, ...), Dtype, mean, std)
}
// namespace operators
}
// namespace paddle
REGISTER_OP
(
random
,
paddle
::
operators
::
RandomOp
,
paddle
::
operators
::
RandomOpMaker
);
REGISTER_OP
(
gaussian_
random
,
paddle
::
operators
::
Gaussian
RandomOp
,
paddle
::
operators
::
Gaussian
RandomOpMaker
);
typedef
paddle
::
operators
::
RandomOpKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
RandomOpKernel_CPU_float
;
REGISTER_OP_CPU_KERNEL
(
random
,
RandomOpKernel_CPU_float
);
typedef
paddle
::
operators
::
GaussianRandomOpKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
GaussianRandomOpKernel_CPU_float
;
REGISTER_OP_CPU_KERNEL
(
gaussian_random
,
GaussianRandomOpKernel_CPU_float
);
paddle/operators/random_op.cu
浏览文件 @
57213340
#include "paddle/operators/random_op.h"
#include "paddle/framework/op_registry.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
class
GaussianRandomOpKernel
<
platform
::
GPUPlace
,
T
>
:
public
framework
::
OpKernel
{
public:
void
Compute
(
const
framework
::
KernelContext
&
context
)
const
override
{
auto
mean
=
context
.
op_
.
GetAttr
<
T
>
(
"mean"
);
auto
std
=
context
.
op_
.
GetAttr
<
T
>
(
"std"
);
auto
*
output
=
context
.
Output
(
0
)
->
GetMutable
<
framework
::
Tensor
>
();
T
*
r
=
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
ctx
=
static_cast
<
const
platform
::
GPUDeviceContext
*>
(
context
.
device_context_
);
// generator need to modify context
auto
g
=
const_cast
<
platform
::
GPUDeviceContext
*>
(
ctx
)
->
RandGenerator
();
curandGenerateNormal
(
g
,
r
,
framework
::
product
(
output
->
dims
()),
mean
,
std
);
typedef
paddle
::
operators
::
RandomOpKernel
<
paddle
::
platform
::
GPUPlace
,
float
>
}
};
}
// namespace operators
}
// namespace paddle
typedef
paddle
::
operators
::
GaussianRandomOpKernel
<
paddle
::
platform
::
GPUPlace
,
float
>
RandomOpKernel_GPU_float
;
REGISTER_OP_GPU_KERNEL
(
random
,
RandomOpKernel_GPU_float
);
\ No newline at end of file
paddle/operators/random_op.h
浏览文件 @
57213340
...
...
@@ -7,63 +7,10 @@
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
bool
Gaussian
(
platform
::
CPUDeviceContext
*
ctx
,
T
*
output
,
const
int
size
,
const
T
&
mean
,
const
T
&
std
,
const
T
&
seed
)
{
auto
g
=
ctx
->
RandGenerator
(
seed
);
std
::
normal_distribution
<
T
>
distribution
(
mean
,
std
);
for
(
int
i
=
0
;
i
<
size
;
++
i
)
{
output
[
i
]
=
distribution
(
g
);
}
return
true
;
}
#ifndef PADDLE_ONLY_CPU
template
<
typename
T
>
bool
Gaussian
(
platform
::
CUDADeviceContext
*
ctx
,
T
*
output
,
const
int
size
,
const
T
&
mean
,
const
T
&
std
,
const
T
&
seed
)
{
auto
g
=
ctx
->
RandGenerator
(
seed
);
return
curandGenerateNormal
(
g
,
output
,
size
,
mean
,
std
);
}
#endif
template
<
typename
Place
,
typename
T
>
class
RandomOpKernel
:
public
framework
::
OpKernel
{
class
Gaussian
RandomOpKernel
:
public
framework
::
OpKernel
{
public:
void
Compute
(
const
framework
::
KernelContext
&
context
)
const
override
{
auto
mean
=
context
.
op_
.
GetAttr
<
T
>
(
"mean"
);
auto
std
=
context
.
op_
.
GetAttr
<
T
>
(
"std"
);
auto
seed
=
context
.
op_
.
GetAttr
<
T
>
(
"seed"
);
auto
*
output
=
context
.
Output
(
0
)
->
GetMutable
<
framework
::
Tensor
>
();
auto
place
=
context
.
GetPlace
();
if
(
platform
::
is_cpu_place
(
place
))
{
Gaussian
(
dynamic_cast
<
platform
::
CPUDeviceContext
*>
(
context
.
device_context_
),
output
->
mutable_data
<
T
>
(
context
.
GetPlace
()),
framework
::
product
(
output
->
dims
()),
mean
,
std
,
seed
);
}
else
{
#ifndef PADDLE_ONLY_CPU
Gaussian
(
dynamic_cast
<
platform
::
CUDADeviceContext
*>
(
context
.
device_context_
),
output
->
mutable_data
<
T
>
(
context
.
GetPlace
()),
framework
::
product
(
output
->
dims
()),
mean
,
std
,
seed
);
#endif
}
}
void
Compute
(
const
framework
::
KernelContext
&
context
)
const
override
{}
};
}
// namespace operators
...
...
paddle/platform/device_context.h
浏览文件 @
57213340
...
...
@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/platform/gpu_info.h"
#define EIGEN_USE_GPU
#endif
#include <chrono>
#include <memory>
#include "paddle/platform/place.h"
#include "unsupported/Eigen/CXX11/Tensor"
...
...
@@ -40,7 +41,10 @@ class DeviceContext {
class
CPUDeviceContext
:
public
DeviceContext
{
public:
typedef
std
::
mt19937
random_generator_type
;
CPUDeviceContext
()
{
eigen_device_
.
reset
(
new
Eigen
::
DefaultDevice
());
}
CPUDeviceContext
()
{
random_seed_
=
std
::
chrono
::
system_clock
::
now
().
time_since_epoch
().
count
();
eigen_device_
.
reset
(
new
Eigen
::
DefaultDevice
());
}
Eigen
::
DefaultDevice
*
eigen_device
()
const
{
return
eigen_device_
.
get
();
}
...
...
@@ -49,16 +53,15 @@ class CPUDeviceContext : public DeviceContext {
return
retv
;
}
random_generator_type
&
RandGenerator
(
const
int
seed
)
{
random_generator_type
&
RandGenerator
()
{
if
(
!
rand_generator_
)
{
random_seed_
=
seed
;
rand_generator_
.
reset
(
new
random_generator_type
(
random_seed_
));
}
return
*
rand_generator_
.
get
();
}
private:
int
random_seed_
;
unsigned
random_seed_
;
std
::
unique_ptr
<
random_generator_type
>
rand_generator_
;
std
::
unique_ptr
<
Eigen
::
DefaultDevice
>
eigen_device_
;
};
...
...
@@ -81,6 +84,9 @@ class GPUPlaceGuard {
class
CUDADeviceContext
:
public
DeviceContext
{
public:
CUDADeviceContext
()
{
random_seed_
=
std
::
chrono
::
system_clock
::
now
().
time_since_epoch
().
count
();
}
explicit
CUDADeviceContext
(
const
GPUPlace
gpu_place
)
:
gpu_place_
(
gpu_place
)
{
GPUPlaceGuard
guard
(
gpu_place_
);
PADDLE_ENFORCE
(
cudaStreamCreate
(
&
stream_
),
"cudaStreamCreate failed"
);
...
...
@@ -98,9 +104,8 @@ class CUDADeviceContext : public DeviceContext {
"cudaStreamSynchronize failed"
);
}
curandGenerator_t
RandGenerator
(
const
int
seed
)
{
curandGenerator_t
RandGenerator
()
{
if
(
!
rand_generator_
)
{
random_seed_
=
seed
;
GPUPlaceGuard
guard
(
gpu_place_
);
PADDLE_ENFORCE
(
paddle
::
platform
::
dynload
::
curandCreateGenerator
(
&
rand_generator_
,
CURAND_RNG_PSEUDO_DEFAULT
),
...
...
@@ -177,7 +182,7 @@ class CUDADeviceContext : public DeviceContext {
cudnnHandle_t
dnn_handle_
{
nullptr
};
int
random_seed_
;
unsigned
random_seed_
;
curandGenerator_t
rand_generator_
{
nullptr
};
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录