Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
b26e9bd2
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
1 年多 前同步成功
通知
696
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b26e9bd2
编写于
3月 12, 2019
作者:
S
sneaxiy
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refine code
test=develop
上级
cfd012e2
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
259 addition
and
339 deletion
+259
-339
paddle/fluid/operators/cross_entropy2_op.cc
paddle/fluid/operators/cross_entropy2_op.cc
+18
-99
paddle/fluid/operators/cross_entropy2_op.h
paddle/fluid/operators/cross_entropy2_op.h
+13
-91
paddle/fluid/operators/cross_entropy_op.cc
paddle/fluid/operators/cross_entropy_op.cc
+11
-126
paddle/fluid/operators/cross_entropy_op_base.h
paddle/fluid/operators/cross_entropy_op_base.h
+169
-0
paddle/fluid/operators/expand_op.cc
paddle/fluid/operators/expand_op.cc
+1
-0
paddle/fluid/operators/math.h
paddle/fluid/operators/math.h
+42
-0
paddle/fluid/operators/math/cross_entropy.cu
paddle/fluid/operators/math/cross_entropy.cu
+1
-12
paddle/fluid/operators/selu_op.h
paddle/fluid/operators/selu_op.h
+2
-3
paddle/fluid/operators/sequence_ops/sequence_softmax_op.cu
paddle/fluid/operators/sequence_ops/sequence_softmax_op.cu
+1
-3
paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cu
...e/fluid/operators/sigmoid_cross_entropy_with_logits_op.cu
+1
-5
未找到文件。
paddle/fluid/operators/cross_entropy2_op.cc
浏览文件 @
b26e9bd2
...
...
@@ -16,46 +16,22 @@ limitations under the License. */
#include <memory>
#include <string>
#include <unordered_map>
#include "paddle/fluid/operators/cross_entropy_op_base.h"
namespace
paddle
{
namespace
operators
{
class
CrossEntropyOp2
:
public
framework
::
OperatorWithKernel
{
class
CrossEntropyOp2
:
public
CrossEntropyOpBase
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
CrossEntropyOpBase
::
CrossEntropyOpBase
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
"Input(Label) should be not null."
);
CrossEntropyOpBase
::
InferShape
(
ctx
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Y"
),
"Output(Y) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"XShape"
),
"Output(XShape) should be not null."
);
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
label_dims
=
ctx
->
GetInputDim
(
"Label"
);
int
rank
=
x_dims
.
size
();
PADDLE_ENFORCE_EQ
(
rank
,
label_dims
.
size
(),
"Input(X) and Input(Label) shall have the same rank."
);
bool
check
=
true
;
if
((
!
ctx
->
IsRuntime
())
&&
(
framework
::
product
(
x_dims
)
<=
0
||
framework
::
product
(
label_dims
)
<=
0
))
{
check
=
false
;
}
if
(
check
)
{
PADDLE_ENFORCE_EQ
(
framework
::
slice_ddim
(
x_dims
,
0
,
rank
-
1
),
framework
::
slice_ddim
(
label_dims
,
0
,
rank
-
1
),
"Input(X) and Input(Label) shall have the same shape "
"except the last dimension."
);
}
PADDLE_ENFORCE_EQ
(
label_dims
[
rank
-
1
],
1UL
,
"Last dimension of Input(Label) should be 1."
);
auto
y_dims
=
x_dims
;
y_dims
[
rank
-
1
]
=
1
;
ctx
->
SetOutputDim
(
"Y"
,
y_dims
);
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Y"
);
auto
x_dims_vec
=
framework
::
vectorize
(
x_dims
);
x_dims_vec
.
push_back
(
0
);
ctx
->
SetOutputDim
(
"XShape"
,
framework
::
make_ddim
(
x_dims_vec
));
...
...
@@ -63,73 +39,25 @@ class CrossEntropyOp2 : public framework::OperatorWithKernel {
}
protected:
// Explicitly set that the data type of computation kernel of cross_entropy
// is determined by its input "X".
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
(),
ctx
.
device_context
());
bool
IsSoftLabel
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
return
false
;
}
};
class
CrossEntropyGradientOp2
:
public
framework
::
OperatorWithKernel
{
class
CrossEntropyGradientOp2
:
public
CrossEntropyGradientOpBase
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
CrossEntropyGradientOpBase
::
CrossEntropyGradientOpBase
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
"Input(Label) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"XShape"
),
"Input(XShape) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Input(Y) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Y"
)),
"Input(Y@GRAD) shoudl be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"X"
)),
"Output(X@GRAD) should be not null."
);
auto
x_shapes
=
ctx
->
GetInputDim
(
"XShape"
);
framework
::
DDim
x_dims
(
x_shapes
.
Get
(),
x_shapes
.
size
()
-
1
);
auto
label_dims
=
ctx
->
GetInputDim
(
"Label"
);
auto
dy_dims
=
ctx
->
GetInputDim
(
framework
::
GradVarName
(
"Y"
));
int
rank
=
x_dims
.
size
();
PADDLE_ENFORCE_EQ
(
dy_dims
.
size
(),
rank
,
"Input(Y@Grad) and Input(X) should have the same rank."
);
PADDLE_ENFORCE_EQ
(
label_dims
.
size
(),
rank
,
"Input(Label) and Input(X) should have the same rank."
);
bool
check
=
true
;
if
((
!
ctx
->
IsRuntime
())
&&
(
framework
::
product
(
x_dims
)
<=
0
||
framework
::
product
(
label_dims
)
<=
0
))
{
check
=
false
;
}
if
(
check
)
{
PADDLE_ENFORCE_EQ
(
framework
::
slice_ddim
(
x_dims
,
0
,
rank
-
1
),
framework
::
slice_ddim
(
label_dims
,
0
,
rank
-
1
),
"The Input(X) and Input(Label) should have the same "
"shape except the last dimension."
);
PADDLE_ENFORCE_EQ
(
framework
::
slice_ddim
(
x_dims
,
0
,
rank
-
1
),
framework
::
slice_ddim
(
dy_dims
,
0
,
rank
-
1
),
"The Input(X) and Input(Y@Grad) should have the same "
"shape except the last dimension."
);
}
PADDLE_ENFORCE_EQ
(
dy_dims
[
rank
-
1
],
1
,
"The last dimension of Input(Y@Grad) should be 1."
);
PADDLE_ENFORCE_EQ
(
label_dims
[
rank
-
1
],
1
,
"Last dimension of Input(Label) should be 1."
);
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
x_dims
);
ctx
->
ShareLoD
(
"XShape"
,
framework
::
GradVarName
(
"X"
));
protected:
virtual
framework
::
DDim
GetXDim
(
framework
::
InferShapeContext
*
ctx
)
const
{
auto
x_shape
=
ctx
->
GetInputDim
(
"XShape"
);
return
framework
::
DDim
(
x_shape
.
Get
(),
x_shape
.
size
()
-
1
);
}
protected:
// Explicitly set that the data type of computation kernel of cross_entropy
// is determined by its input "X".
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
))
->
type
(),
ctx
.
device_context
());
virtual
const
char
*
VarNameWithXLoD
()
const
{
return
"XShape"
;
}
virtual
bool
IsSoftLabel
(
framework
::
InferShapeContext
*
ctx
)
const
{
return
false
;
}
};
...
...
@@ -156,7 +84,7 @@ class CrossEntropyOpMaker2 : public framework::OpProtoAndCheckerMaker {
"Only valid if soft_label is set to False"
)
.
SetDefault
(
-
100
);
AddComment
(
R"DOC(
CrossEntropy Operator.
Hard-label
CrossEntropy Operator.
The input 'X' and 'Label' will first be logically flattened to 2-D matrixs.
The matrix's second dimension(row length) is as same as the original last
...
...
@@ -173,15 +101,6 @@ or not. But the output only shares the LoD information with input X.
}
};
class
CrossEntropyOpInferVarType2
:
public
framework
::
PassInDtypeAndVarTypeToOutput
{
protected:
std
::
unordered_map
<
std
::
string
,
std
::
string
>
GetInputOutputWithSameType
()
const
override
{
return
std
::
unordered_map
<
std
::
string
,
std
::
string
>
{{
"X"
,
/*->*/
"Y"
}};
}
};
class
CrossEntropyGradOpMaker2
:
public
framework
::
SingleGradOpDescMaker
{
public:
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
...
...
@@ -207,7 +126,7 @@ namespace ops = paddle::operators;
using
CPUCtx
=
paddle
::
platform
::
CPUDeviceContext
;
REGISTER_OPERATOR
(
cross_entropy2
,
ops
::
CrossEntropyOp2
,
ops
::
CrossEntropyOpMaker2
,
ops
::
CrossEntropyOpInferVarType
2
,
ops
::
CrossEntropyOpMaker2
,
ops
::
CrossEntropyOpInferVarType
,
ops
::
CrossEntropyGradOpMaker2
);
REGISTER_OPERATOR
(
cross_entropy_grad2
,
ops
::
CrossEntropyGradientOp2
);
REGISTER_OP_CPU_KERNEL
(
cross_entropy2
,
...
...
paddle/fluid/operators/cross_entropy2_op.h
浏览文件 @
b26e9bd2
...
...
@@ -17,6 +17,7 @@ limitations under the License. */
#include <cmath>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math.h"
#include "paddle/fluid/operators/math/cross_entropy.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/for_range.h"
...
...
@@ -26,81 +27,6 @@ namespace operators {
using
Tensor
=
framework
::
Tensor
;
HOSTDEVICE
inline
platform
::
float16
RealLog
(
platform
::
float16
x
)
{
#ifdef __NVCC__
return
static_cast
<
platform
::
float16
>
(
logf
(
static_cast
<
float
>
(
x
)));
#else
return
static_cast
<
platform
::
float16
>
(
std
::
log
(
static_cast
<
float
>
(
x
)));
#endif
}
HOSTDEVICE
inline
float
RealLog
(
float
x
)
{
#ifdef __NVCC__
return
logf
(
x
);
#else
return
std
::
log
(
x
);
#endif
}
HOSTDEVICE
inline
double
RealLog
(
double
x
)
{
#ifdef __NVCC__
return
log
(
x
);
#else
return
std
::
log
(
x
);
#endif
}
HOSTDEVICE
inline
platform
::
float16
RealExp
(
platform
::
float16
x
)
{
#ifdef __NVCC__
return
static_cast
<
platform
::
float16
>
(
expf
(
static_cast
<
float
>
(
x
)));
#else
return
static_cast
<
platform
::
float16
>
(
std
::
exp
(
static_cast
<
float
>
(
x
)));
#endif
}
HOSTDEVICE
inline
float
RealExp
(
float
x
)
{
#ifdef __NVCC__
return
expf
(
x
);
#else
return
std
::
exp
(
x
);
#endif
}
HOSTDEVICE
inline
double
RealExp
(
double
x
)
{
#ifdef __NVCC__
return
exp
(
x
);
#else
return
std
::
exp
(
x
);
#endif
}
template
<
typename
T
>
struct
CrossEntropyForwardFunctor
{
CrossEntropyForwardFunctor
(
const
T
*
x
,
T
*
y
,
const
int64_t
*
label
,
int64_t
ignore_index
,
int64_t
feature_size
)
:
x_
(
x
),
y_
(
y
),
label_
(
label
),
ignore_index_
(
ignore_index
),
feature_size_
(
feature_size
)
{}
HOSTDEVICE
void
operator
()(
int64_t
row_idx
)
const
{
auto
col_idx
=
label_
[
row_idx
];
if
(
col_idx
!=
ignore_index_
)
{
y_
[
row_idx
]
=
-
math
::
TolerableValue
<
T
>
()(
RealLog
(
x_
[
row_idx
*
feature_size_
+
col_idx
]));
}
else
{
y_
[
row_idx
]
=
0
;
}
}
const
T
*
x_
;
T
*
y_
;
const
int64_t
*
label_
;
int64_t
ignore_index_
;
int64_t
feature_size_
;
};
template
<
typename
T
>
struct
CrossEntropyBackwardFunctor
{
CrossEntropyBackwardFunctor
(
T
*
dx
,
const
T
*
y
,
const
T
*
dy
,
...
...
@@ -118,7 +44,7 @@ struct CrossEntropyBackwardFunctor {
auto
col_idx
=
idx
%
feature_size_
;
auto
label
=
label_
[
row_idx
];
if
(
label
==
col_idx
&&
label
!=
ignore_index_
)
{
dx_
[
idx
]
=
-
dy_
[
row_idx
]
*
RealE
xp
(
y_
[
row_idx
]);
dx_
[
idx
]
=
-
dy_
[
row_idx
]
*
real_e
xp
(
y_
[
row_idx
]);
}
else
{
dx_
[
idx
]
=
0
;
}
...
...
@@ -136,24 +62,20 @@ template <typename DeviceContext, typename T>
class
CrossEntropyOpKernel2
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
label
=
ctx
.
Input
<
Tensor
>
(
"Label"
);
auto
*
y
=
ctx
.
Output
<
Tensor
>
(
"Y"
);
auto
*
x_original
=
ctx
.
Input
<
Tensor
>
(
"X"
);
int
rank
=
x_original
->
dims
().
size
();
auto
*
p_y
=
y
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
p_x
=
x
->
data
<
T
>
();
auto
*
p_label
=
label
->
data
<
int64_t
>
();
auto
x
=
framework
::
ReshapeToMatrix
(
*
x_original
,
rank
-
1
);
auto
label
=
framework
::
ReshapeToMatrix
(
*
ctx
.
Input
<
Tensor
>
(
"Label"
),
rank
-
1
);
auto
*
y
=
ctx
.
Output
<
Tensor
>
(
"Y"
);
y
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
int
rank
=
x
->
dims
().
size
();
int64_t
feature_size
=
x
->
dims
()[
rank
-
1
];
int64_t
batch_size
=
framework
::
product
(
x
->
dims
())
/
feature_size
;
auto
ignore_index
=
ctx
.
Attr
<
int
>
(
"ignore_index"
);
int64_t
ignore_index
=
ctx
.
Attr
<
int
>
(
"ignore_index"
);
platform
::
ForRange
<
DeviceContext
>
for_range
(
ctx
.
template
device_context
<
DeviceContext
>(),
batch_size
);
for_range
(
CrossEntropyForwardFunctor
<
T
>
(
p_x
,
p_y
,
p_label
,
ignore_index
,
feature_size
));
math
::
CrossEntropyFunctor
<
DeviceContext
,
T
>
()(
ctx
.
template
device_context
<
DeviceContext
>(),
y
,
&
x
,
&
label
,
false
,
ignore_index
);
}
};
...
...
paddle/fluid/operators/cross_entropy_op.cc
浏览文件 @
b26e9bd2
...
...
@@ -14,128 +14,11 @@ limitations under the License. */
#include "paddle/fluid/operators/cross_entropy_op.h"
#include <string>
#include "paddle/fluid/operators/cross_entropy_op_base.h"
namespace
paddle
{
namespace
operators
{
class
CrossEntropyOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
"Input(Label) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Y"
),
"Output(Y) should be not null."
);
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
label_dims
=
ctx
->
GetInputDim
(
"Label"
);
int
rank
=
x_dims
.
size
();
PADDLE_ENFORCE_EQ
(
rank
,
label_dims
.
size
(),
"Input(X) and Input(Label) shall have the same rank."
);
bool
check
=
true
;
if
((
!
ctx
->
IsRuntime
())
&&
(
framework
::
product
(
x_dims
)
<=
0
||
framework
::
product
(
label_dims
)
<=
0
))
{
check
=
false
;
}
if
(
check
)
{
PADDLE_ENFORCE_EQ
(
framework
::
slice_ddim
(
x_dims
,
0
,
rank
-
1
),
framework
::
slice_ddim
(
label_dims
,
0
,
rank
-
1
),
"Input(X) and Input(Label) shall have the same shape "
"except the last dimension."
);
}
if
(
ctx
->
Attrs
().
Get
<
bool
>
(
"soft_label"
))
{
if
(
check
)
{
PADDLE_ENFORCE_EQ
(
x_dims
[
rank
-
1
],
label_dims
[
rank
-
1
],
"If Attr(soft_label) == true, the last dimension of "
"Input(X) and Input(Label) should be equal."
);
}
}
else
{
PADDLE_ENFORCE_EQ
(
label_dims
[
rank
-
1
],
1UL
,
"If Attr(softLabel) == false, the last dimension of "
"Input(Label) should be 1."
);
}
auto
y_dims
=
x_dims
;
y_dims
[
rank
-
1
]
=
1
;
ctx
->
SetOutputDim
(
"Y"
,
y_dims
);
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Y"
);
}
protected:
// Explicitly set that the data type of computation kernel of cross_entropy
// is determined by its input "X".
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
(),
ctx
.
device_context
());
}
};
class
CrossEntropyGradientOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
"Input(Label) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Y"
)),
"Input(Y@GRAD) shoudl be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"X"
)),
"Output(X@GRAD) should be not null."
);
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
label_dims
=
ctx
->
GetInputDim
(
"Label"
);
auto
dy_dims
=
ctx
->
GetInputDim
(
framework
::
GradVarName
(
"Y"
));
int
rank
=
x_dims
.
size
();
PADDLE_ENFORCE_EQ
(
dy_dims
.
size
(),
rank
,
"Input(Y@Grad) and Input(X) should have the same rank."
);
PADDLE_ENFORCE_EQ
(
label_dims
.
size
(),
rank
,
"Input(Label) and Input(X) should have the same rank."
);
bool
check
=
true
;
if
((
!
ctx
->
IsRuntime
())
&&
(
framework
::
product
(
x_dims
)
<=
0
||
framework
::
product
(
label_dims
)
<=
0
))
{
check
=
false
;
}
if
(
check
)
{
PADDLE_ENFORCE_EQ
(
framework
::
slice_ddim
(
x_dims
,
0
,
rank
-
1
),
framework
::
slice_ddim
(
label_dims
,
0
,
rank
-
1
),
"The Input(X) and Input(Label) should have the same "
"shape except the last dimension."
);
PADDLE_ENFORCE_EQ
(
framework
::
slice_ddim
(
x_dims
,
0
,
rank
-
1
),
framework
::
slice_ddim
(
dy_dims
,
0
,
rank
-
1
),
"The Input(X) and Input(Y@Grad) should have the same "
"shape except the last dimension."
);
}
PADDLE_ENFORCE_EQ
(
dy_dims
[
rank
-
1
],
1
,
"The last dimension of Input(Y@Grad) should be 1."
);
if
(
ctx
->
Attrs
().
Get
<
bool
>
(
"soft_label"
))
{
if
(
check
)
{
PADDLE_ENFORCE_EQ
(
x_dims
[
rank
-
1
],
label_dims
[
rank
-
1
],
"When Attr(soft_label) == true, the last dimension of "
"Input(X) and Input(Label) should be equal."
);
}
}
else
{
PADDLE_ENFORCE_EQ
(
label_dims
[
rank
-
1
],
1
,
"When Attr(soft_label) == false, the last dimension of "
"Input(Label) should be 1."
);
}
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
x_dims
);
ctx
->
ShareLoD
(
"X"
,
framework
::
GradVarName
(
"X"
));
}
protected:
// Explicitly set that the data type of computation kernel of cross_entropy
// is determined by its input "X".
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
(),
ctx
.
device_context
());
}
};
class
CrossEntropyOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
...
...
@@ -200,22 +83,24 @@ or not. But the output only shares the LoD information with input X.
}
};
class
CrossEntropyOpInferVarType
:
public
framework
::
PassInDtypeAndVarTypeToOutput
{
protected:
std
::
unordered_map
<
std
::
string
,
std
::
string
>
GetInputOutputWithSameType
()
const
override
{
return
std
::
unordered_map
<
std
::
string
,
std
::
string
>
{{
"X"
,
/*->*/
"Y"
}};
class
CrossEntropyGradientOp
:
public
CrossEntropyGradientOpBase
{
public:
using
CrossEntropyGradientOpBase
::
CrossEntropyGradientOpBase
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should be not null."
);
CrossEntropyGradientOpBase
::
InferShape
(
ctx
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
using
CPUCtx
=
paddle
::
platform
::
CPUDeviceContext
;
REGISTER_OPERATOR
(
cross_entropy
,
ops
::
CrossEntropyOp
,
ops
::
CrossEntropyOpMaker
,
ops
::
CrossEntropyOpInferVarType
,
REGISTER_OPERATOR
(
cross_entropy
,
ops
::
CrossEntropyOp
Base
,
ops
::
CrossEntropyOp
Maker
,
ops
::
CrossEntropyOp
InferVarType
,
paddle
::
framework
::
DefaultGradOpDescMaker
<
true
>
);
REGISTER_OPERATOR
(
cross_entropy_grad
,
ops
::
CrossEntropyGradientOp
);
REGISTER_OP_CPU_KERNEL
(
cross_entropy
,
ops
::
CrossEntropyOpKernel
<
CPUCtx
,
float
>
,
...
...
paddle/fluid/operators/cross_entropy_op_base.h
0 → 100644
浏览文件 @
b26e9bd2
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/op_registry.h"
namespace
paddle
{
namespace
operators
{
class
CrossEntropyOpBase
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
"Input(Label) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Y"
),
"Output(Y) should be not null."
);
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
label_dims
=
ctx
->
GetInputDim
(
"Label"
);
int
rank
=
x_dims
.
size
();
PADDLE_ENFORCE_EQ
(
rank
,
label_dims
.
size
(),
"Input(X) and Input(Label) shall have the same rank."
);
bool
check
=
true
;
if
((
!
ctx
->
IsRuntime
())
&&
(
framework
::
product
(
x_dims
)
<=
0
||
framework
::
product
(
label_dims
)
<=
0
))
{
check
=
false
;
}
if
(
check
)
{
PADDLE_ENFORCE_EQ
(
framework
::
slice_ddim
(
x_dims
,
0
,
rank
-
1
),
framework
::
slice_ddim
(
label_dims
,
0
,
rank
-
1
),
"Input(X) and Input(Label) shall have the same shape "
"except the last dimension."
);
}
if
(
IsSoftLabel
(
ctx
))
{
if
(
check
)
{
PADDLE_ENFORCE_EQ
(
x_dims
[
rank
-
1
],
label_dims
[
rank
-
1
],
"If Attr(soft_label) == true, the last dimension of "
"Input(X) and Input(Label) should be equal."
);
}
}
else
{
PADDLE_ENFORCE_EQ
(
label_dims
[
rank
-
1
],
1UL
,
"If Attr(softLabel) == false, the last dimension of "
"Input(Label) should be 1."
);
}
auto
y_dims
=
x_dims
;
y_dims
[
rank
-
1
]
=
1
;
ctx
->
SetOutputDim
(
"Y"
,
y_dims
);
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Y"
);
}
protected:
// Explicitly set that the data type of computation kernel of cross_entropy
// is determined by its input "X".
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
(),
ctx
.
device_context
());
}
virtual
bool
IsSoftLabel
(
framework
::
InferShapeContext
*
ctx
)
const
{
return
ctx
->
Attrs
().
Get
<
bool
>
(
"soft_label"
);
}
};
class
CrossEntropyOpInferVarType
:
public
framework
::
PassInDtypeAndVarTypeToOutput
{
protected:
std
::
unordered_map
<
std
::
string
,
std
::
string
>
GetInputOutputWithSameType
()
const
override
{
return
std
::
unordered_map
<
std
::
string
,
std
::
string
>
{{
"X"
,
/*->*/
"Y"
}};
}
};
class
CrossEntropyGradientOpBase
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
"Input(Label) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Y"
)),
"Input(Y@GRAD) shoudl be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"X"
)),
"Output(X@GRAD) should be not null."
);
auto
x_dims
=
GetXDim
(
ctx
);
auto
label_dims
=
ctx
->
GetInputDim
(
"Label"
);
auto
dy_dims
=
ctx
->
GetInputDim
(
framework
::
GradVarName
(
"Y"
));
int
rank
=
x_dims
.
size
();
PADDLE_ENFORCE_EQ
(
dy_dims
.
size
(),
rank
,
"Input(Y@Grad) and Input(X) should have the same rank."
);
PADDLE_ENFORCE_EQ
(
label_dims
.
size
(),
rank
,
"Input(Label) and Input(X) should have the same rank."
);
bool
check
=
true
;
if
((
!
ctx
->
IsRuntime
())
&&
(
framework
::
product
(
x_dims
)
<=
0
||
framework
::
product
(
label_dims
)
<=
0
))
{
check
=
false
;
}
if
(
check
)
{
PADDLE_ENFORCE_EQ
(
framework
::
slice_ddim
(
x_dims
,
0
,
rank
-
1
),
framework
::
slice_ddim
(
label_dims
,
0
,
rank
-
1
),
"The Input(X) and Input(Label) should have the same "
"shape except the last dimension."
);
PADDLE_ENFORCE_EQ
(
framework
::
slice_ddim
(
x_dims
,
0
,
rank
-
1
),
framework
::
slice_ddim
(
dy_dims
,
0
,
rank
-
1
),
"The Input(X) and Input(Y@Grad) should have the same "
"shape except the last dimension."
);
}
if
(
IsSoftLabel
(
ctx
))
{
if
(
check
)
{
PADDLE_ENFORCE_EQ
(
x_dims
[
rank
-
1
],
label_dims
[
rank
-
1
],
"When Attr(soft_label) == true, the last dimension of "
"Input(X) and Input(Label) should be equal."
);
}
}
else
{
PADDLE_ENFORCE_EQ
(
label_dims
[
rank
-
1
],
1
,
"When Attr(soft_label) == false, the last dimension of "
"Input(Label) should be 1."
);
}
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
x_dims
);
PADDLE_ENFORCE_EQ
(
dy_dims
[
rank
-
1
],
1
,
"The last dimension of Input(Y@Grad) should be 1."
);
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
x_dims
);
ctx
->
ShareLoD
(
VarNameWithXLoD
(),
framework
::
GradVarName
(
"X"
));
}
protected:
// Explicitly set that the data type of computation kernel of cross_entropy
// is determined by its input "X".
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
))
->
type
(),
ctx
.
device_context
());
}
virtual
framework
::
DDim
GetXDim
(
framework
::
InferShapeContext
*
ctx
)
const
{
return
ctx
->
GetInputDim
(
"X"
);
}
virtual
const
char
*
VarNameWithXLoD
()
const
{
return
"X"
;
}
virtual
bool
IsSoftLabel
(
framework
::
InferShapeContext
*
ctx
)
const
{
return
ctx
->
Attrs
().
Get
<
bool
>
(
"soft_label"
);
}
};
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/expand_op.cc
浏览文件 @
b26e9bd2
...
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/expand_op.h"
#include <memory>
#include <vector>
namespace
paddle
{
...
...
paddle/fluid/operators/math.h
0 → 100644
浏览文件 @
b26e9bd2
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/hostdevice.h"
#include "math.h" // NOLINT
namespace
paddle
{
namespace
operators
{
inline
HOSTDEVICE
platform
::
float16
real_exp
(
platform
::
float16
x
)
{
return
static_cast
<
platform
::
float16
>
(
::
expf
(
static_cast
<
float
>
(
x
)));
}
inline
HOSTDEVICE
float
real_exp
(
float
x
)
{
return
::
expf
(
x
);
}
inline
HOSTDEVICE
double
real_exp
(
double
x
)
{
return
::
exp
(
x
);
}
inline
HOSTDEVICE
platform
::
float16
real_log
(
platform
::
float16
x
)
{
return
static_cast
<
platform
::
float16
>
(
::
logf
(
static_cast
<
float
>
(
x
)));
}
inline
HOSTDEVICE
float
real_log
(
float
x
)
{
return
::
logf
(
x
);
}
inline
HOSTDEVICE
double
real_log
(
double
x
)
{
return
::
log
(
x
);
}
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/math/cross_entropy.cu
浏览文件 @
b26e9bd2
...
...
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math.h"
#include "paddle/fluid/operators/math/cross_entropy.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"
...
...
@@ -20,17 +21,6 @@ namespace paddle {
namespace
operators
{
namespace
math
{
namespace
{
__device__
__forceinline__
float
real_log
(
float
x
)
{
return
logf
(
x
);
}
__device__
__forceinline__
double
real_log
(
double
x
)
{
return
log
(
x
);
}
__device__
__forceinline__
platform
::
float16
real_log
(
const
platform
::
float16
&
val
)
{
return
static_cast
<
platform
::
float16
>
(
logf
(
static_cast
<
float
>
(
val
)));
}
template
<
typename
T
>
__global__
void
CrossEntropyKernel
(
T
*
Y
,
const
T
*
X
,
const
int64_t
*
label
,
const
int
N
,
const
int
D
,
...
...
@@ -61,7 +51,6 @@ __global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
Y
[
blockIdx
.
x
]
=
-
val
;
}
}
}
// namespace
template
<
typename
T
>
class
CrossEntropyFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
...
...
paddle/fluid/operators/selu_op.h
浏览文件 @
b26e9bd2
...
...
@@ -15,13 +15,12 @@ limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math.h"
#include "paddle/fluid/platform/for_range.h"
namespace
paddle
{
namespace
operators
{
static
HOSTDEVICE
float
real_exp
(
float
x
)
{
return
expf
(
x
);
}
static
HOSTDEVICE
float
real_exp
(
double
x
)
{
return
exp
(
x
);
}
template
<
typename
T
>
struct
SeluFunctor
{
SeluFunctor
(
const
T
*
x_data_ptr
,
float
alpha
,
float
scale
,
T
*
y_data_ptr
)
...
...
paddle/fluid/operators/sequence_ops/sequence_softmax_op.cu
浏览文件 @
b26e9bd2
...
...
@@ -14,6 +14,7 @@ limitations under the License. */
#include <algorithm>
#include <cub/cub.cuh> // NOLINT
#include "paddle/fluid/operators/math.h"
#include "paddle/fluid/operators/sequence_ops/sequence_softmax_op.h"
namespace
paddle
{
...
...
@@ -21,9 +22,6 @@ namespace operators {
using
LoDTensor
=
framework
::
LoDTensor
;
__device__
__forceinline__
float
real_exp
(
float
x
)
{
return
expf
(
x
);
}
__device__
__forceinline__
double
real_exp
(
double
x
)
{
return
exp
(
x
);
}
template
<
typename
T
,
int
BlockDim
>
using
BlockReduce
=
cub
::
BlockReduce
<
T
,
BlockDim
>
;
...
...
paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cu
浏览文件 @
b26e9bd2
...
...
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "cub/cub.cuh"
#include "paddle/fluid/operators/math.h"
#include "paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/hostdevice.h"
...
...
@@ -21,11 +22,6 @@ namespace operators {
using
Tensor
=
framework
::
Tensor
;
static
HOSTDEVICE
float
real_exp
(
float
x
)
{
return
expf
(
x
);
}
static
HOSTDEVICE
float
real_exp
(
double
x
)
{
return
exp
(
x
);
}
static
HOSTDEVICE
float
real_log
(
float
x
)
{
return
logf
(
x
);
}
static
HOSTDEVICE
float
real_log
(
double
x
)
{
return
log
(
x
);
}
static
constexpr
int
kNumCUDAThreads
=
512
;
static
constexpr
int
kNumMaxinumNumBlocks
=
4096
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录