Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
f1d5fb3b
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
f1d5fb3b
编写于
9月 21, 2017
作者:
C
caoying03
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support soft labels.
上级
a2a0d6f8
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
74 addition
and
87 deletion
+74
-87
paddle/operators/math/CMakeLists.txt
paddle/operators/math/CMakeLists.txt
+2
-2
paddle/operators/math/softmax.cc
paddle/operators/math/softmax.cc
+1
-1
paddle/operators/math/softmax.cu
paddle/operators/math/softmax.cu
+1
-1
paddle/operators/math/softmax.h
paddle/operators/math/softmax.h
+0
-0
paddle/operators/math/utils.h
paddle/operators/math/utils.h
+0
-42
paddle/operators/softmax_op.h
paddle/operators/softmax_op.h
+1
-1
paddle/operators/softmax_with_cross_entropy_op.cc
paddle/operators/softmax_with_cross_entropy_op.cc
+51
-24
paddle/operators/softmax_with_cross_entropy_op.cu
paddle/operators/softmax_with_cross_entropy_op.cu
+12
-10
paddle/operators/softmax_with_cross_entropy_op.h
paddle/operators/softmax_with_cross_entropy_op.h
+4
-4
python/paddle/v2/framework/tests/test_softmax_with_cross_entropy_op.py
.../v2/framework/tests/test_softmax_with_cross_entropy_op.py
+2
-2
未找到文件。
paddle/operators/math/CMakeLists.txt
浏览文件 @
f1d5fb3b
if
(
WITH_GPU
)
if
(
WITH_GPU
)
nv_library
(
math_function SRCS math_function.cc math_function.cu im2col.cc
nv_library
(
math_function SRCS math_function.cc math_function.cu im2col.cc
im2col.cu DEPS cblas device_context operator
)
im2col.cu DEPS cblas device_context operator
)
nv_library
(
softmax_function SRCS softmax
_function.cc softmax_function
.cu
nv_library
(
softmax_function SRCS softmax
.cc softmax
.cu
DEPS operator
)
DEPS operator
)
else
()
else
()
cc_library
(
math_function SRCS math_function.cc im2col.cc
cc_library
(
math_function SRCS math_function.cc im2col.cc
DEPS cblas device_context operator
)
DEPS cblas device_context operator
)
cc_library
(
softmax_function SRCS softmax
_function
.cc DEPS operator
)
cc_library
(
softmax_function SRCS softmax.cc DEPS operator
)
endif
()
endif
()
nv_test
(
math_function_test SRCS math_function_test.cc DEPS math_function tensor
)
nv_test
(
math_function_test SRCS math_function_test.cc DEPS math_function tensor
)
...
...
paddle/operators/math/softmax
_function
.cc
→
paddle/operators/math/softmax.cc
浏览文件 @
f1d5fb3b
...
@@ -12,7 +12,7 @@
...
@@ -12,7 +12,7 @@
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/operators/math/softmax
_function
.h"
#include "paddle/operators/math/softmax.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
...
paddle/operators/math/softmax
_function
.cu
→
paddle/operators/math/softmax.cu
浏览文件 @
f1d5fb3b
...
@@ -14,7 +14,7 @@
...
@@ -14,7 +14,7 @@
#define EIGEN_USE_GPU
#define EIGEN_USE_GPU
#include "paddle/operators/math/softmax
_function
.h"
#include "paddle/operators/math/softmax.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
...
paddle/operators/math/softmax
_function
.h
→
paddle/operators/math/softmax.h
浏览文件 @
f1d5fb3b
文件已移动
paddle/operators/math/utils.h
已删除
100644 → 0
浏览文件 @
a2a0d6f8
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/platform/assert.h"
#include "paddle/platform/hostdevice.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
template
<
typename
T
>
T
HOSTDEVICE
tolerable_value
(
const
T
x
)
{
PADDLE_ASSERT
(
std
::
is_floating_point
<
T
>::
value
);
const
T
kApproInf
=
1e20
;
if
(
x
==
INFINITY
)
{
return
kApproInf
;
}
if
(
x
==
-
INFINITY
)
{
return
-
kApproInf
;
}
return
x
;
}
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/operators/softmax_op.h
浏览文件 @
f1d5fb3b
...
@@ -15,7 +15,7 @@ limitations under the License. */
...
@@ -15,7 +15,7 @@ limitations under the License. */
#pragma once
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/softmax
_function
.h"
#include "paddle/operators/math/softmax.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
...
paddle/operators/softmax_with_cross_entropy_op.cc
浏览文件 @
f1d5fb3b
...
@@ -23,16 +23,32 @@ class SoftmaxWithCrossEntropyOpMaker
...
@@ -23,16 +23,32 @@ class SoftmaxWithCrossEntropyOpMaker
SoftmaxWithCrossEntropyOpMaker
(
framework
::
OpProto
*
proto
,
SoftmaxWithCrossEntropyOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
framework
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
//(TODO caoying) replace int with boolean
AddAttr
<
int
>
(
"soft_label"
,
"(int, default 0), A flag to indicate whether to interpretate "
"the given labels as soft labels."
)
.
SetDefault
(
0
);
AddInput
(
"Logits"
,
AddInput
(
"Logits"
,
"The unscaled log probabilities which is a 2-D tensor<float> with"
"(Tensor, default Tensor<float>), The unscaled log probabilities "
"shape [N x K]. N is the batch_size, and K is the class number."
)
"which is a 2-D tensor with shape [N x K]. N is the batch_size, "
"and K is the class number."
)
.
NotInGradient
();
.
NotInGradient
();
AddInput
(
"Label"
,
"The ground truth. A 1-D tensor<int> with shape N."
);
AddInput
(
AddOutput
(
"Softmax"
,
"Label"
,
"Store the outputs of softmax function, "
"(Tensor, default Tensor<int>), The ground truth which is "
"which will be used in backward calculation."
)
"a 1-D or 2-D tensor. "
"If soft_label is set to 0, Label is a Tensor<int> with shape [N x 1]. "
"If soft_label is set to 1, Label is a Tensor<float/double> "
"with shape [N x K]."
);
AddOutput
(
"Softmax"
,
"(Tensor, default Tensor<float>), A 2-D tensor with shape [N x K]. "
"The outputs value of softmax activation by given the input batch, "
"which will be used in backward calculation."
)
.
AsIntermediate
();
.
AsIntermediate
();
AddOutput
(
"Out"
,
"A 1-D tensor<float> with shape N."
);
AddOutput
(
"Loss"
,
"(Tensor, default Tensor<float>), A 1-D tensor. The cross "
"entropy loss with shape [N x 1]."
);
AddComment
(
R"DOC(
AddComment
(
R"DOC(
Cross entropy loss with softmax are used as the output layer extensively. This
Cross entropy loss with softmax are used as the output layer extensively. This
operator computes the softmax normalized values for each row of the input
operator computes the softmax normalized values for each row of the input
...
@@ -46,25 +62,18 @@ which will produce incorrect results.
...
@@ -46,25 +62,18 @@ which will produce incorrect results.
This operators expects mutually exclusive hard labels, each sample in a batch
This operators expects mutually exclusive hard labels, each sample in a batch
is in exactly one class with probabilities 1. Each sample in the batch with one
is in exactly one class with probabilities 1. Each sample in the batch with one
and only one label.
and only one label.
)DOC"
);
}
};
class
SoftmaxWithCrossEntropyOpGrad
:
public
framework
::
OperatorWithKernel
{
Equation:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
1) hard label (one-hot label)
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
framework
::
GradVarName
(
"Out"
)),
"Input(Out@Grad) should not be null"
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
"Softmax"
),
"Input(Softmax) should be not null."
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
"Label"
),
"Input(Lable) should be not null."
);
ctx
.
Output
<
framework
::
LoDTensor
>
(
framework
::
GradVarName
(
"Logits"
))
Loss_j = -\text{Logit}_{Label_j} + \log\left(\sum_{i=0}^{K}\exp(\text{Logit}_i)\right), j = 1, ..., K
->
Resize
(
ctx
.
Input
<
Tensor
>
(
"Softmax"
)
->
dims
());
2) soft label (a distribution over all classes)
Loss_j = -\sum_{i=0}^{K}\text{Label}_i\left(\text{Logit}_i-\log\left(\sum_{i=0}^{K}\exp(\text{Logit}_i)\right)\right), j = 1,...,K
)DOC"
);
}
}
};
};
...
@@ -82,7 +91,25 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
...
@@ -82,7 +91,25 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
"The label should be a 1-d tensor."
);
"The label should be a 1-d tensor."
);
ctx
.
Output
<
framework
::
LoDTensor
>
(
"Softmax"
)
->
Resize
(
logits
->
dims
());
ctx
.
Output
<
framework
::
LoDTensor
>
(
"Softmax"
)
->
Resize
(
logits
->
dims
());
ctx
.
Output
<
framework
::
LoDTensor
>
(
"Out"
)
->
Resize
({
logits
->
dims
()[
0
],
1
});
ctx
.
Output
<
framework
::
LoDTensor
>
(
"Loss"
)
->
Resize
({
logits
->
dims
()[
0
],
1
});
}
};
class
SoftmaxWithCrossEntropyOpGrad
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
framework
::
GradVarName
(
"Loss"
)),
"Input(Loss@Grad) should not be null"
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
"Softmax"
),
"Input(Softmax) should be not null."
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
"Label"
),
"Input(Lable) should be not null."
);
ctx
.
Output
<
framework
::
LoDTensor
>
(
framework
::
GradVarName
(
"Logits"
))
->
Resize
(
ctx
.
Input
<
Tensor
>
(
"Softmax"
)
->
dims
());
}
}
};
};
...
...
paddle/operators/softmax_with_cross_entropy_op.cu
浏览文件 @
f1d5fb3b
...
@@ -13,9 +13,10 @@
...
@@ -13,9 +13,10 @@
limitations under the License. */
limitations under the License. */
#define EIGEN_USE_GPU
#define EIGEN_USE_GPU
#include "paddle/framework/op_registry.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/
math/softmax_function
.h"
#include "paddle/operators/
cross_entropy_op
.h"
#include "paddle/operators/math/
utils
.h"
#include "paddle/operators/math/
softmax
.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -27,9 +28,10 @@ __global__ void CrossEntropyKernel(T* out, const T* softmax_out,
...
@@ -27,9 +28,10 @@ __global__ void CrossEntropyKernel(T* out, const T* softmax_out,
const
int
*
label
,
const
int
batch_size
,
const
int
*
label
,
const
int
batch_size
,
const
int
class_num
)
{
const
int
class_num
)
{
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
i
>=
batch_size
)
return
;
if
(
i
<
batch_size
)
{
PADDLE_ASSERT
(
label
[
i
]
>=
0
&&
label
[
i
]
<
class_num
);
PADDLE_ASSERT
(
label
[
i
]
>=
0
&&
label
[
i
]
<
class_num
);
out
[
i
]
=
-
math
::
tolerable_value
(
log
(
softmax_out
[
i
*
class_num
+
label
[
i
]]));
out
[
i
]
=
-
tolerable_value
(
std
::
log
(
softmax_out
[
i
*
class_num
+
label
[
i
]]));
}
}
}
template
<
typename
T
>
template
<
typename
T
>
...
@@ -38,10 +40,10 @@ __global__ void CrossEntropyWithSoftmaxGradKernel(T* softmax_out,
...
@@ -38,10 +40,10 @@ __global__ void CrossEntropyWithSoftmaxGradKernel(T* softmax_out,
const
int
batch_size
,
const
int
batch_size
,
const
int
class_num
)
{
const
int
class_num
)
{
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
i
>=
batch_size
)
return
;
if
(
i
<
batch_size
)
{
PADDLE_ASSERT
(
label
[
i
]
>=
0
&&
label
[
i
]
<
class_num
);
PADDLE_ASSERT
(
label
[
i
]
>=
0
&&
label
[
i
]
<
class_num
)
;
softmax_out
[
i
*
class_num
+
label
[
i
]]
-=
1.
;
softmax_out
[
i
*
class_num
+
label
[
i
]]
-=
1.
;
}
}
}
template
<
typename
T
>
template
<
typename
T
>
...
@@ -60,7 +62,7 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel {
...
@@ -60,7 +62,7 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel {
// Calculate the cross entropy loss based on hard labels.
// Calculate the cross entropy loss based on hard labels.
const
int
*
label_data
=
context
.
Input
<
Tensor
>
(
"Label"
)
->
data
<
int
>
();
const
int
*
label_data
=
context
.
Input
<
Tensor
>
(
"Label"
)
->
data
<
int
>
();
Tensor
*
loss
=
context
.
Output
<
Tensor
>
(
"
Out
"
);
Tensor
*
loss
=
context
.
Output
<
Tensor
>
(
"
Loss
"
);
loss
->
mutable_data
<
T
>
(
context
.
GetPlace
());
loss
->
mutable_data
<
T
>
(
context
.
GetPlace
());
T
*
loss_data
=
loss
->
data
<
T
>
();
T
*
loss_data
=
loss
->
data
<
T
>
();
...
...
paddle/operators/softmax_with_cross_entropy_op.h
浏览文件 @
f1d5fb3b
...
@@ -15,8 +15,8 @@
...
@@ -15,8 +15,8 @@
#pragma once
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/
math/softmax_function
.h"
#include "paddle/operators/
cross_entropy_op
.h"
#include "paddle/operators/math/
utils
.h"
#include "paddle/operators/math/
softmax
.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -44,7 +44,7 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel {
...
@@ -44,7 +44,7 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel {
T
*
softmax_out
=
softmax
->
data
<
T
>
();
T
*
softmax_out
=
softmax
->
data
<
T
>
();
const
int
*
label_data
=
context
.
Input
<
Tensor
>
(
"Label"
)
->
data
<
int
>
();
const
int
*
label_data
=
context
.
Input
<
Tensor
>
(
"Label"
)
->
data
<
int
>
();
Tensor
*
loss
=
context
.
Output
<
Tensor
>
(
"
Out
"
);
Tensor
*
loss
=
context
.
Output
<
Tensor
>
(
"
Loss
"
);
loss
->
mutable_data
<
T
>
(
context
.
GetPlace
());
loss
->
mutable_data
<
T
>
(
context
.
GetPlace
());
T
*
loss_data
=
loss
->
data
<
T
>
();
T
*
loss_data
=
loss
->
data
<
T
>
();
...
@@ -53,7 +53,7 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel {
...
@@ -53,7 +53,7 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel {
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
{
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
{
int
index
=
i
*
class_num
+
label_data
[
i
];
int
index
=
i
*
class_num
+
label_data
[
i
];
loss_data
[
i
]
=
-
math
::
tolerable_value
(
std
::
log
(
softmax_out
[
index
]));
loss_data
[
i
]
=
-
tolerable_value
(
std
::
log
(
softmax_out
[
index
]));
}
}
}
}
};
};
...
...
python/paddle/v2/framework/tests/test_softmax_with_cross_entropy_op.py
浏览文件 @
f1d5fb3b
...
@@ -25,13 +25,13 @@ class TestSoftmaxWithCrossEntropyOp(OpTest):
...
@@ -25,13 +25,13 @@ class TestSoftmaxWithCrossEntropyOp(OpTest):
dtype
=
"float32"
)
dtype
=
"float32"
)
self
.
inputs
=
{
"Logits"
:
logits
,
"Label"
:
labels
}
self
.
inputs
=
{
"Logits"
:
logits
,
"Label"
:
labels
}
self
.
outputs
=
{
"Softmax"
:
softmax
,
"
Out
"
:
cross_entropy
}
self
.
outputs
=
{
"Softmax"
:
softmax
,
"
Loss
"
:
cross_entropy
}
def
test_check_output
(
self
):
def
test_check_output
(
self
):
self
.
check_output
()
self
.
check_output
()
def
test_check_grad
(
self
):
def
test_check_grad
(
self
):
self
.
check_grad
([
"Logits"
],
"
Out
"
,
max_relative_error
=
0.05
)
self
.
check_grad
([
"Logits"
],
"
Loss
"
,
max_relative_error
=
0.05
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录