Commit f9311406
Authored Aug 22, 2017 by qingqing01; committed by GitHub on Aug 22, 2017.
Merge pull request #3501 from qingqing01/cross_entropy
Implement GPU kernel for cross entropy operator.
Parents: ce723af0, a8863a8d
Showing 6 changed files with 140 additions and 23 deletions (+140, −23).
paddle/framework/pybind.cc: +1, −1
paddle/operators/cross_entropy_op.cc: +6, −9
paddle/operators/cross_entropy_op.cu: +117, −5
paddle/operators/cross_entropy_op.h: +11, −3
python/paddle/v2/framework/tests/op_test_util.py: +2, −1
python/paddle/v2/framework/tests/test_cross_entropy_op.py: +3, −4
paddle/framework/pybind.cc

@@ -31,7 +31,7 @@ limitations under the License. */
 namespace py = pybind11;
 USE_OP(add_two);
-USE_CPU_ONLY_OP(onehot_cross_entropy);
+USE_OP(onehot_cross_entropy);
 USE_OP(sgd);
 USE_OP(mul);
 USE_OP(mean);
paddle/operators/cross_entropy_op.cc

@@ -39,11 +39,10 @@ class OnehotCrossEntropyGradientOp : public framework::OperatorWithKernel {
  protected:
   void InferShape(const framework::InferShapeContext &ctx) const override {
-    auto X_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto dX = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto X = ctx.Input<Tensor>("X");
 
-    // TODO(superjom) add enforce here after helper functions ready
-    X_grad->Resize(X->dims());
+    dX->Resize(X->dims());
   }
 };
@@ -70,9 +69,7 @@ namespace ops = paddle::operators;
 REGISTER_OP(onehot_cross_entropy, ops::OnehotCrossEntropyOp,
             ops::OnehotCrossEntropyOpMaker, onehot_cross_entropy_grad,
             ops::OnehotCrossEntropyGradientOp);
-REGISTER_OP_CPU_KERNEL(
-    onehot_cross_entropy,
-    ops::OnehotCrossEntropyOpKernel<paddle::platform::CPUPlace, float>);
-REGISTER_OP_CPU_KERNEL(
-    onehot_cross_entropy_grad,
-    ops::OnehotCrossEntropyGradientOpKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(onehot_cross_entropy,
+                       ops::OnehotCrossEntropyOpKernel<float>);
+REGISTER_OP_CPU_KERNEL(onehot_cross_entropy_grad,
+                       ops::OnehotCrossEntropyGradientOpKernel<float>);
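
The registration change above reflects a design shift: the CPU kernel class is no longer templated on a Place, because each device now registers its own kernel type for the op and the framework picks one by the execution place. A toy sketch of that dispatch pattern (names and registry structure are mine, not Paddle's actual implementation):

    #include <functional>
    #include <iostream>
    #include <map>
    #include <string>
    #include <utility>

    enum class Place { kCPU, kGPU };

    // Toy registry keyed by (op name, place) -- analogous in spirit to what
    // REGISTER_OP_CPU_KERNEL / REGISTER_OP_GPU_KERNEL populate.
    using Kernel = std::function<void()>;
    std::map<std::pair<std::string, Place>, Kernel>& Registry() {
      static std::map<std::pair<std::string, Place>, Kernel> registry;
      return registry;
    }

    int main() {
      Registry()[{"onehot_cross_entropy", Place::kCPU}] = [] {
        std::cout << "run CPU kernel\n";
      };
      Registry()[{"onehot_cross_entropy", Place::kGPU}] = [] {
        std::cout << "run CUDA kernel\n";
      };
      // Dispatch by the place of the execution context, so the kernel
      // classes themselves no longer need a Place template parameter.
      Registry()[{"onehot_cross_entropy", Place::kGPU}]();
      return 0;
    }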
paddle/operators/cross_entropy_op.cu

@@ -12,10 +12,122 @@
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#define EIGEN_USE_GPU
+#include "paddle/framework/op_registry.h"
 #include "paddle/operators/cross_entropy_op.h"
+#include "paddle/platform/assert.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename T>
+__host__ __device__ T clipping_log(const T x) {
+  PADDLE_ASSERT(std::is_floating_point<T>::value);
+  const T kApproInf = 1e20;
+  T v = log(x);
+  if (v == INFINITY) {
+    return kApproInf;
+  }
+  if (v == -INFINITY) {
+    return -kApproInf;
+  }
+  return v;
+}
+
+template <typename T>
+__global__ void CrossEntropyKernel(T* Y, const T* X, const int* label,
+                                   const int N, const int D) {
+  // TOOD(qingqing) define CUDA_1D_KERNEL_LOOP macro in a common file.
+  // CUDA_1D_KERNEL_LOOP(i, N) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
+       i += blockDim.x * gridDim.x) {
+    PADDLE_ASSERT(label[i] >= 0 && label[i] < D);
+    Y[i] = -clipping_log(X[i * D + label[i]]);
+  }
+}
+
+// TODO(qingqing): make zero setting an common function.
+template <typename T>
+__global__ void zero(T* X, const int N) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
+       i += blockDim.x * gridDim.x) {
+    X[i] = 0.0;
+  }
+}
+
+template <typename T>
+__global__ void CrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
+                                           const int* label, const int N,
+                                           const int D) {
+  // TOOD(qingqing) define CUDA_1D_KERNEL_LOOP macro in a common file.
+  // CUDA_1D_KERNEL_LOOP(i, N) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
+       i += blockDim.x * gridDim.x) {
+    int idx = i * D + label[i];
+    dX[idx] = -dY[i] / X[idx];
+  }
+}
+
+template <typename T>
+class OnehotCrossEntropyOpCUDAKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
+                   "It must use GPUPlace.");
+
+    auto X = ctx.Input<Tensor>("X");
+    const T* Xdata = X->data<T>();
+    const int* label_data = ctx.Input<Tensor>("label")->data<int>();
+    auto Y = ctx.Output<Tensor>("Y");
+
+    Y->mutable_data<T>(ctx.GetPlace());
+    T* Ydata = Y->data<T>();
+
+    int N = X->dims()[0];
+    int D = X->dims()[1];
+    int block = 512;
+    int grid = (N + block - 1) / block;
+    // TODO(qingqing) launch kernel on specified stream
+    // base on ExecutionContext.
+    CrossEntropyKernel<T><<<grid, block>>>(Ydata, Xdata, label_data, N, D);
+  }
+};
+
+template <typename T>
+class OnehotCrossEntropyGradientOpCUDAKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
+                   "It must use GPUPlace.");
+
+    auto X = ctx.Input<Tensor>("X");
+    auto dX = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto dY = ctx.Input<Tensor>(framework::GradVarName("Y"));
+    auto label = ctx.Input<Tensor>("label");
+
+    auto* dXdata = dX->template mutable_data<T>(ctx.GetPlace());
+    auto* dYdata = dY->template data<T>();
+    auto* Xdata = X->template data<T>();
+    auto* label_data = label->data<int>();
+
+    int N = X->dims()[0];
+    int D = X->dims()[1];
+    int block = 512;
+    int grid = (N * D + block - 1) / block;
+    zero<T><<<grid, block>>>(dXdata, N * D);
+
+    grid = (N + block - 1) / block;
+    // TODO(qingqing): launch kernel on specified stream
+    // base on ExecutionContext.
+    CrossEntropyGradientKernel<T><<<grid, block>>>(dXdata, dYdata, Xdata,
+                                                   label_data, N, D);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
 
 namespace ops = paddle::operators;
 REGISTER_OP_GPU_KERNEL(onehot_cross_entropy,
-                       ops::OnehotCrossEntropyOpKernel<paddle::platform::GPUPlace, float>);
+                       ops::OnehotCrossEntropyOpCUDAKernel<float>);
+REGISTER_OP_GPU_KERNEL(onehot_cross_entropy_grad,
+                       ops::OnehotCrossEntropyGradientOpCUDAKernel<float>);
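
All three kernels above use the same grid-stride loop and ceiling-division launch arithmetic. A self-contained toy (mine, not part of the commit) that exercises the same pattern on the forward pass, without Paddle's clipping or asserts:

    #include <cstdio>
    #include <cuda_runtime.h>

    // Grid-stride loop: each thread starts at its global index and strides
    // by the total thread count, so any N is covered by any grid size.
    __global__ void NegLogLikelihood(float* Y, const float* X, const int* label,
                                     const int N, const int D) {
      for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
           i += blockDim.x * gridDim.x) {
        Y[i] = -logf(X[i * D + label[i]]);  // pick the true-class probability
      }
    }

    int main() {
      const int N = 4, D = 3;  // batch size, class count
      const float hX[N * D] = {0.1f, 0.7f, 0.2f,   0.3f, 0.3f,  0.4f,
                               0.5f, 0.25f, 0.25f, 0.9f, 0.05f, 0.05f};
      const int hLabel[N] = {1, 2, 0, 0};
      float hY[N];

      float *dX, *dY;
      int* dLabel;
      cudaMalloc(&dX, sizeof(hX));
      cudaMalloc(&dY, sizeof(hY));
      cudaMalloc(&dLabel, sizeof(hLabel));
      cudaMemcpy(dX, hX, sizeof(hX), cudaMemcpyHostToDevice);
      cudaMemcpy(dLabel, hLabel, sizeof(hLabel), cudaMemcpyHostToDevice);

      const int block = 512;
      const int grid = (N + block - 1) / block;  // ceil(N / block), as above
      NegLogLikelihood<<<grid, block>>>(dY, dX, dLabel, N, D);
      cudaMemcpy(hY, dY, sizeof(hY), cudaMemcpyDeviceToHost);

      for (int i = 0; i < N; ++i) printf("Y[%d] = %f\n", i, hY[i]);
      cudaFree(dX);
      cudaFree(dY);
      cudaFree(dLabel);
      return 0;
    }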
paddle/operators/cross_entropy_op.h

@@ -21,7 +21,7 @@ namespace operators {
 using Tensor = framework::Tensor;
 
 template <typename T>
-T tolerable_value(T x) {
+inline T tolerable_value(const T x) {
   static_assert(std::is_floating_point<T>::value,
                 "tolerable_value works only on float, "
                 "double and double double.");
@@ -39,10 +39,13 @@ T tolerable_value(T x) {
   return x;
 }
 
-template <typename Place, typename T>
+template <typename T>
 class OnehotCrossEntropyOpKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
+    PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
+                   "It must use CPUPlace.");
+
     auto X = ctx.Input<Tensor>("X");
     const T* Xdata = X->data<T>();
     const int* label_data = ctx.Input<Tensor>("label")->data<int>();
@@ -62,10 +65,13 @@ class OnehotCrossEntropyOpKernel : public framework::OpKernel {
   }
 };
 
-template <typename Place, typename T>
+template <typename T>
 class OnehotCrossEntropyGradientOpKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
+    PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
+                   "It must use CPUPlace.");
+
     auto X = ctx.Input<Tensor>("X");
     auto dX = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto dY = ctx.Input<Tensor>(framework::GradVarName("Y"));
@@ -79,6 +85,8 @@ class OnehotCrossEntropyGradientOpKernel : public framework::OpKernel {
     const int batch_size = X->dims()[0];
     const int class_num = X->dims()[1];
 
+    // TODO(qingqing): make zero setting an common function.
+    memset(dXdata, 0, sizeof(T) * batch_size * class_num);
+
     for (int i = 0; i < batch_size; ++i) {
       int index = i * class_num + label_data[i];
       dXdata[index] = -tolerable_value(dYdata[i] / Xdata[index]);
...
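
The gradient that both this CPU path and the CUDA kernel compute follows directly from the forward definition; a short derivation in my own notation, writing c_i = label[i]:

\[
Y_i = -\log X_{i,c_i}
\quad\Longrightarrow\quad
\frac{\partial Y_i}{\partial X_{i,j}} =
\begin{cases}
  -\dfrac{1}{X_{i,c_i}} & \text{if } j = c_i,\\[6pt]
  0 & \text{otherwise,}
\end{cases}
\qquad
dX_{i,j} = dY_i \,\frac{\partial Y_i}{\partial X_{i,j}} =
\begin{cases}
  -\dfrac{dY_i}{X_{i,c_i}} & \text{if } j = c_i,\\[6pt]
  0 & \text{otherwise.}
\end{cases}
\]

Only the true-class column receives a nonzero gradient, which is why dX is cleared first (the memset here, the zero kernel on the GPU) and only dX[i * D + label[i]] is written.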
python/paddle/v2/framework/tests/op_test_util.py

@@ -64,7 +64,8 @@ class OpTestMeta(type):
             actual = numpy.array(scope.find_var(out_name).get_tensor())
             expect = self.outputs[out_name]
             self.assertTrue(
-                numpy.allclose(actual, expect),
+                numpy.allclose(
+                    actual, expect, atol=1e-05),
                 "output name: " + out_name + "has diff")
         obj.test_all = test_all
...
python/paddle/v2/framework/tests/test_cross_entropy_op.py

@@ -8,9 +8,8 @@ class TestCrossEntropy(unittest.TestCase):
     __metaclass__ = OpTestMeta
 
     def setUp(self):
-        # TODO this unit test is not passed
         self.type = "onehot_cross_entropy"
-        batch_size = 100
+        batch_size = 30
         class_num = 10
         X = numpy.random.random((batch_size, class_num)).astype("float32")
         label = 5 * numpy.ones(batch_size).astype("int32")
@@ -22,9 +21,9 @@ class TestCrossEntropy(unittest.TestCase):
 class CrossEntropyGradOpTest(GradientChecker):
-    def test_softmax_grad(self):
+    def test_check_grad(self):
         op = create_op("onehot_cross_entropy")
-        batch_size = 100
+        batch_size = 30
         class_num = 10
         inputs = {
             "X": numpy.random.uniform(
...
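
With the op now runnable on both CPUPlace and GPUPlace, float32 results can differ in the last bits between devices, which is presumably why op_test_util.py switches from a bare numpy.allclose to an explicit atol=1e-05. The check it performs is roughly the following (my C++ rendering of numpy's |a - b| <= atol + rtol * |b| rule; the sample values are made up):

    #include <cmath>
    #include <cstdio>

    // Mirror of numpy.allclose semantics: elementwise
    // |a - b| <= atol + rtol * |b|.
    bool allclose(const float* a, const float* b, int n,
                  float atol = 1e-5f, float rtol = 1e-5f) {
      for (int i = 0; i < n; ++i) {
        if (std::fabs(a[i] - b[i]) > atol + rtol * std::fabs(b[i])) return false;
      }
      return true;
    }

    int main() {
      // Hypothetical CPU vs. GPU outputs for -log(p): equal up to
      // float32 rounding noise.
      const float cpu[3] = {0.5108256f, 1.2039728f, 0.9162907f};
      const float gpu[3] = {0.5108259f, 1.2039726f, 0.9162909f};
      std::printf("%s\n", allclose(cpu, gpu, 3) ? "close" : "has diff");
      return 0;
    }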