PaddleDetection (s920243400/PaddleDetection, forked from PaddlePaddle/PaddleDetection)

Commit 2cde56c5
Authored on Sep 20, 2017 by wanghaoshuang
Parent: 743dfd82

Use Transform instead of eigen
Showing 3 changed files with 58 additions and 115 deletions (+58, −115)
paddle/operators/clip_op.cc   +1   −2
paddle/operators/clip_op.cu   +2   −56
paddle/operators/clip_op.h    +55  −57
paddle/operators/clip_op.cc
@@ -80,6 +80,5 @@ class ClipOpGrad : public framework::OperatorWithKernel {
 namespace ops = paddle::operators;
 REGISTER_OP(clip, ops::ClipOp, ops::ClipOpMaker<float>, clip_grad,
             ops::ClipOpGrad);
-REGISTER_OP_CPU_KERNEL(clip,
-                       ops::ClipKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(clip, ops::ClipKernel<float>);
 REGISTER_OP_CPU_KERNEL(clip_grad, ops::ClipGradKernel<float>);
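The CPU registration now names only the element type; the Place parameter is gone because device dispatch happens inside the kernel through platform::Transform. As a minimal sketch of that functor-plus-transform pattern, assuming only the C++ standard library (std::transform stands in for Paddle's Transform, and the values are hypothetical):

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    // Minimal sketch: clip expressed as a stateless functor, applied
    // elementwise by a generic transform. std::transform plays the role
    // platform::Transform plays in the commit (minus GPU dispatch).
    template <typename T>
    class ClipFunctor {
     public:
      ClipFunctor(T min, T max) : min_(min), max_(max) {}
      T operator()(const T& x) const {
        return x < min_ ? min_ : (x > max_ ? max_ : x);
      }

     private:
      T min_;
      T max_;
    };

    int main() {
      std::vector<float> x = {-2.0f, -0.5f, 0.3f, 1.7f};  // hypothetical inputs
      std::vector<float> out(x.size());
      std::transform(x.begin(), x.end(), out.begin(),
                     ClipFunctor<float>(-1.0f, 1.0f));
      for (float v : out) std::printf("%g\n", v);  // -1 -0.5 0.3 1
      return 0;
    }

Expressing clip as a stateless functor is what lets the same code run unchanged under a device-aware transform later in this diff.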
paddle/operators/clip_op.cu
@@ -14,60 +14,6 @@

 #include "paddle/operators/clip_op.h"

-#define CUDA_1D_KERNEL_LOOP(i, n)                              \
-  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
-       i += blockDim.x * gridDim.x)
-
-namespace paddle {
-namespace operators {
-
-using framework::LoDTensor;
-
-template <typename T>
-__global__ void ClipGradientKernel(const int N, const T min, const T max,
-                                   const T* Y, const T* dY, T* dX) {
-  CUDA_1D_KERNEL_LOOP(i, N) {
-    if (Y[i] > min && Y[i] < max) {
-      dX[i] = dY[i];
-    } else {
-      dX[i] = 0;
-    }
-  }
-}
-
-template <typename T>
-class ClipGradientOpCUDAKernel : public framework::OpKernel {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    auto max = context.Attr<float>("max");
-    auto min = context.Attr<float>("min");
-    auto* d_out = context.Input<LoDTensor>(framework::GradVarName("Out"));
-    auto* d_x = context.Output<LoDTensor>(framework::GradVarName("X"));
-    if (d_x != nullptr) {
-      auto* x = context.Input<LoDTensor>("X");
-      auto dims = d_x->dims();
-      int64_t count = d_out->numel();
-      auto d_x_data = d_x->mutable_data<T>(context.GetPlace());
-      auto d_out_data = d_out->data<T>();
-      auto x_data = x->data<T>();
-      int N = d_x->dims()[0];
-      int D = d_x->dims()[1];
-      int block = 512;
-      int grid = (N * D + block - 1) / block;
-      ClipGradientKernel<T><<<
-          grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
-                              context.device_context())
-                              .stream()>>>(count, min, max, x_data, d_out_data,
-                                           d_x_data);
-    }
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
 namespace ops = paddle::operators;
-REGISTER_OP_GPU_KERNEL(clip,
-                       ops::ClipKernel<paddle::platform::GPUPlace, float>);
-REGISTER_OP_GPU_KERNEL(clip_grad, ops::ClipGradientOpCUDAKernel<float>);
+REGISTER_OP_GPU_KERNEL(clip, ops::ClipKernel<float>);
+REGISTER_OP_GPU_KERNEL(clip_grad, ops::ClipGradKernel<float>);
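The hand-written ClipGradientKernel and its launch arithmetic are deleted outright; the GPU now registers the same header kernels as the CPU. A minimal sketch of the gradient path as a binary transform, again assuming only the standard library (hypothetical values; Paddle's Transform additionally takes a device context and can dispatch to the GPU):

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    // Sketch of the gradient as a binary elementwise op: the upstream
    // gradient passes through only where the forward input was strictly
    // inside (min, max), exactly what the deleted CUDA kernel computed.
    template <typename T>
    class ClipGradFunctor {
     public:
      ClipGradFunctor(T min, T max) : min_(min), max_(max) {}
      // x is the upstream gradient dOut[i]; y is the forward input X[i].
      T operator()(const T& x, const T& y) const {
        return (y > min_ && y < max_) ? x : static_cast<T>(0);
      }

     private:
      T min_;
      T max_;
    };

    int main() {
      std::vector<float> d_out = {1.0f, 1.0f, 1.0f};  // hypothetical upstream grads
      std::vector<float> x = {-2.0f, 0.0f, 2.0f};     // forward inputs
      std::vector<float> d_x(x.size());
      std::transform(d_out.begin(), d_out.end(), x.begin(), d_x.begin(),
                     ClipGradFunctor<float>(-1.0f, 1.0f));
      for (float v : d_x) std::printf("%g\n", v);  // 0 1 0
      return 0;
    }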
paddle/operators/clip_op.h
@@ -16,57 +16,61 @@

 #include "paddle/framework/eigen.h"
 #include "paddle/framework/op_registry.h"
+#include "paddle/platform/transform.h"

 namespace paddle {
 namespace operators {

-using framework::LoDTensor;
+using framework::Tensor;
+using platform::Transform;

-template <typename T, size_t D, int MajorType = Eigen::RowMajor,
-          typename IndexType = Eigen::DenseIndex>
-using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
+template <typename T>
+class ClipFunctor {
+ public:
+  explicit ClipFunctor(const T min, const T max) : min_(min), max_(max) {}
+  HOSTDEVICE T operator()(const T& x) const {
+    if (x < min_)
+      return min_;
+    else if (x > max_)
+      return max_;
+    else
+      return x;
+  }
+
+ private:
+  T min_;
+  T max_;
+};

-template <typename Place, typename T, size_t D>
-void ClipFunction(const framework::ExecutionContext& context) {
-  auto max = context.op().Attr<float>("max");
-  auto min = context.op().Attr<float>("min");
-  auto* x = context.Input<LoDTensor>("X");
-  auto* out = context.Output<LoDTensor>("Out");
-  out->mutable_data<T>(context.GetPlace());
-  auto x_tensor = EigenTensor<T, D>::From(*x);
-  auto out_tensor = EigenTensor<T, D>::From(*out);
-  auto place = context.GetEigenDevice<Place>();
-  out_tensor.device(place) = x_tensor.cwiseMin(max).cwiseMax(min);
-}
+template <typename T>
+class ClipGradFunctor {
+ public:
+  explicit ClipGradFunctor(const T min, const T max) : min_(min), max_(max) {}
+  HOSTDEVICE T operator()(const T& x, const T& y) const {
+    if (y > min_ && y < max_)
+      return x;
+    else
+      return 0;
+  }
+
+ private:
+  T min_;
+  T max_;
+};

-template <typename Place, typename T>
+template <typename T>
 class ClipKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    int rank = context.Input<LoDTensor>("X")->dims().size();
-    switch (rank) {
-      case 1:
-        ClipFunction<Place, T, 1>(context);
-        break;
-      case 2:
-        ClipFunction<Place, T, 2>(context);
-        break;
-      case 3:
-        ClipFunction<Place, T, 3>(context);
-        break;
-      case 4:
-        ClipFunction<Place, T, 4>(context);
-        break;
-      case 5:
-        ClipFunction<Place, T, 5>(context);
-        break;
-      case 6:
-        ClipFunction<Place, T, 6>(context);
-        break;
-      default:
-        PADDLE_THROW(
-            "PadOp only support tensors with no more than 6 dimensions.");
-    }
+    auto max = context.Attr<T>("max");
+    auto min = context.Attr<T>("min");
+    auto* x = context.Input<Tensor>("X");
+    auto* out = context.Output<Tensor>("Out");
+    T* out_data = out->mutable_data<T>(context.GetPlace());
+    const T* x_data = x->data<T>();
+    int64_t numel = x->numel();
+    Transform(context.device_context(), x_data, x_data + numel, out_data,
+              ClipFunctor<T>(min, max));
   }
 };
@@ -74,24 +78,18 @@ template <typename T>
 class ClipGradKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto max = context.op().Attr<float>("max");
-    auto min = context.op().Attr<float>("min");
-    auto* d_out = context.Input<LoDTensor>(framework::GradVarName("Out"));
-    auto* d_x = context.Output<LoDTensor>(framework::GradVarName("X"));
+    auto max = context.Attr<T>("max");
+    auto min = context.Attr<T>("min");
+    auto* d_out = context.Input<Tensor>(framework::GradVarName("Out"));
+    auto* d_x = context.Output<Tensor>(framework::GradVarName("X"));
     if (d_x != nullptr) {
-      auto* x = context.Input<LoDTensor>("X");
-      auto dims = d_x->dims();
-      int64_t count = d_out->numel();
+      auto* x = context.Input<Tensor>("X");
+      int64_t numel = d_out->numel();
       auto d_x_data = d_x->mutable_data<T>(context.GetPlace());
-      auto d_out_data = d_out->data<T>();
-      auto x_data = x->data<T>();
-      for (int i = 0; i < count; ++i) {
-        if (x_data[i] > min && x_data[i] < max) {
-          d_x_data[i] = d_out_data[i];
-        } else {
-          d_x_data[i] = 0;
-        }
-      }
+      const T* d_out_data = d_out->data<T>();
+      const T* x_data = x->data<T>();
+      Transform(context.device_context(), d_out_data, d_out_data + numel,
+                x_data, d_x_data, ClipGradFunctor<T>(min, max));
     }
   }
 };
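One consequence worth noting: the old Eigen path had to instantiate ClipFunction<Place, T, D> per rank and switch on dims().size(), while the Transform path treats any tensor as a flat run of numel() elements, so the rank switch (and its stray "PadOp" error message) could be deleted. A small standalone sketch of that flattening, assuming only the standard library and a hypothetical rank-3 shape:

    #include <algorithm>
    #include <cstdio>
    #include <functional>
    #include <numeric>
    #include <vector>

    int main() {
      // A rank-3 shape; any rank works the same way since clip is elementwise.
      std::vector<int64_t> dims = {2, 3, 4};
      int64_t numel = std::accumulate(dims.begin(), dims.end(), int64_t{1},
                                      std::multiplies<int64_t>());
      std::vector<float> data(numel, 5.0f);  // hypothetical contents
      // One flat pass over numel() values replaces the per-rank
      // ClipFunction<Place, T, D> dispatch that the old switch performed.
      std::transform(data.begin(), data.end(), data.begin(),
                     [](float v) { return std::min(std::max(v, -1.0f), 1.0f); });
      std::printf("numel = %lld, data[0] = %g\n",
                  static_cast<long long>(numel), data[0]);  // numel = 24, 1
      return 0;
    }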