Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
570d89ec
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
570d89ec
编写于
12月 06, 2018
作者:
F
frankwhzhang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add bpr_loss operator , test=develop
上级
400cf19f
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
357 addition
and
0 deletion
+357
-0
paddle/fluid/operators/bpr_loss_op.cc
paddle/fluid/operators/bpr_loss_op.cc
+149
-0
paddle/fluid/operators/bpr_loss_op.h
paddle/fluid/operators/bpr_loss_op.h
+142
-0
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+13
-0
python/paddle/fluid/tests/unittests/test_bpr_loss_op.py
python/paddle/fluid/tests/unittests/test_bpr_loss_op.py
+53
-0
未找到文件。
paddle/fluid/operators/bpr_loss_op.cc
0 → 100644
浏览文件 @
570d89ec
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/bpr_loss_op.h"
namespace
paddle
{
namespace
operators
{
class
BprLossOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label_Pos"
),
"Input(Label_Pos) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Y"
),
"Output(Y) should be not null."
);
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
label_Pos_dims
=
ctx
->
GetInputDim
(
"Label_Pos"
);
int
rank
=
x_dims
.
size
();
PADDLE_ENFORCE_EQ
(
rank
,
label_Pos_dims
.
size
(),
"Input(X) and Input(Label_Pos) shall have the same rank."
);
PADDLE_ENFORCE_EQ
(
framework
::
slice_ddim
(
x_dims
,
0
,
rank
-
1
),
framework
::
slice_ddim
(
label_Pos_dims
,
0
,
rank
-
1
),
"Input(X) and Input(Label_Pos) shall have the same shape "
"except the last dimension."
);
auto
y_dims
=
x_dims
;
y_dims
[
rank
-
1
]
=
1
;
ctx
->
SetOutputDim
(
"Y"
,
y_dims
);
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Y"
);
}
protected:
// Explicitly set that the data type of computation kernel of Seq-bpr
// is determined by its input "X".
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
()),
platform
::
CPUPlace
());
}
};
class
BprLossGradientOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label_Pos"
),
"Input(Label_Pos) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Y"
)),
"Input(Y@GRAD) shoudl be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"X"
)),
"Output(X@GRAD) should be not null."
);
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
label_pos_dims
=
ctx
->
GetInputDim
(
"Label_Pos"
);
auto
dy_dims
=
ctx
->
GetInputDim
(
framework
::
GradVarName
(
"Y"
));
int
rank
=
x_dims
.
size
();
PADDLE_ENFORCE_EQ
(
dy_dims
.
size
(),
rank
,
"Input(Y@Grad) and Input(X) should have the same rank."
);
PADDLE_ENFORCE_EQ
(
label_pos_dims
.
size
(),
rank
,
"Input(Label_Pos) and Input(X) should have the same rank."
);
PADDLE_ENFORCE_EQ
(
framework
::
slice_ddim
(
x_dims
,
0
,
rank
-
1
),
framework
::
slice_ddim
(
label_pos_dims
,
0
,
rank
-
1
),
"The Input(X) and Input(Label_Pos) should have the same "
"shape except the last dimension."
);
PADDLE_ENFORCE_EQ
(
framework
::
slice_ddim
(
x_dims
,
0
,
rank
-
1
),
framework
::
slice_ddim
(
dy_dims
,
0
,
rank
-
1
),
"The Input(X) and Input(Y@Grad) should have the same "
"shape except the last dimension."
);
PADDLE_ENFORCE_EQ
(
dy_dims
[
rank
-
1
],
1
,
"The last dimension of Input(Y@Grad) should be 1."
);
PADDLE_ENFORCE_EQ
(
label_pos_dims
[
rank
-
1
],
1
,
" the last dimension of Input(Label_Pos) should be 1."
);
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
x_dims
);
ctx
->
ShareLoD
(
"X"
,
framework
::
GradVarName
(
"X"
));
}
protected:
// Explicitly set that the data type of computation kernel of cross_entropy
// is determined by its input "X".
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
()),
platform
::
CPUPlace
());
}
};
class
BprLossOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"X"
,
"(Tensor, default Tensor<float>), a tensor whose last dimension "
"size is equal to the number of classes. This input is a "
"real number."
);
AddInput
(
"Label_Pos"
,
"(Tensor), the tensor which represents the ground truth. It has the "
"same shape with 'X' except the last dimension. the last dimension "
"size is 1."
);
AddOutput
(
"Y"
,
"(Tensor, default Tensor<float>), a tensor whose shape is same "
"with 'X' except that the last dimension size is 1. It "
"represents the sequence bpr loss."
);
AddComment
(
R"DOC(
BprLoss Operator.
This operator belongs to pairwise ranking loss. Label_pos is the desired item.
The loss at a given point in one seesion is defined as:
$Y[i] = -\frac{1}{N_{i}} * \sum_{j=0}^{N_{i}}\log(\sigma(X[i, Label[i]]-X[i, j]))$
Learn more details by reading paper <session-based recommendations with recurrent
neural networks>.
)DOC"
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
using
CPUCtx
=
paddle
::
platform
::
CPUDeviceContext
;
REGISTER_OPERATOR
(
bpr_loss
,
ops
::
BprLossOp
,
ops
::
BprLossOpMaker
,
paddle
::
framework
::
DefaultGradOpDescMaker
<
true
>
);
REGISTER_OPERATOR
(
bpr_loss_grad
,
ops
::
BprLossGradientOp
);
REGISTER_OP_CPU_KERNEL
(
bpr_loss
,
ops
::
BprLossOpKernel
<
CPUCtx
,
float
>
,
ops
::
BprLossOpKernel
<
CPUCtx
,
double
>
);
REGISTER_OP_CPU_KERNEL
(
bpr_loss_grad
,
ops
::
BprLossGradientOpKernel
<
CPUCtx
,
float
>
,
ops
::
BprLossGradientOpKernel
<
CPUCtx
,
double
>
);
paddle/fluid/operators/bpr_loss_op.h
0 → 100644
浏览文件 @
570d89ec
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/for_range.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
>
struct
TolerableValue
{
HOSTDEVICE
T
operator
()(
const
T
&
x
)
const
{
PADDLE_ASSERT
(
std
::
is_floating_point
<
T
>::
value
);
const
T
kApproInf
=
1e20
;
if
(
x
==
INFINITY
)
return
kApproInf
;
if
(
x
==
-
INFINITY
)
return
-
kApproInf
;
return
x
;
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
BprLossOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
labels_Pos
=
ctx
.
Input
<
Tensor
>
(
"Label_Pos"
);
auto
*
y
=
ctx
.
Output
<
Tensor
>
(
"Y"
);
y
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
int
rank
=
x
->
dims
().
size
();
Tensor
x_2d
=
framework
::
ReshapeToMatrix
(
*
x
,
rank
-
1
);
Tensor
labels_Pos_2d
=
framework
::
ReshapeToMatrix
(
*
labels_Pos
,
rank
-
1
);
Tensor
y_2d
=
framework
::
ReshapeToMatrix
(
*
y
,
rank
-
1
);
const
framework
::
Tensor
*
prob
=
&
x_2d
;
const
framework
::
Tensor
*
labels_pos
=
&
labels_Pos_2d
;
framework
::
Tensor
*
out
=
&
y_2d
;
const
int
step_size
=
prob
->
dims
()[
0
];
const
int
class_num
=
prob
->
dims
()[
1
];
const
T
*
prob_data
=
prob
->
data
<
T
>
();
T
*
loss_data
=
out
->
data
<
T
>
();
const
int64_t
*
label_pos_data
=
labels_pos
->
data
<
int64_t
>
();
for
(
int
i
=
0
;
i
<
step_size
;
++
i
)
{
int
lbl_pos
=
label_pos_data
[
i
];
PADDLE_ENFORCE_GE
(
lbl_pos
,
0
);
PADDLE_ENFORCE_LT
(
lbl_pos
,
class_num
);
int
index_pos
=
i
*
class_num
+
lbl_pos
;
T
sum
=
static_cast
<
T
>
(
0
);
for
(
int
j
=
0
;
j
<
class_num
;
j
++
)
{
if
(
j
==
lbl_pos
)
continue
;
int
index_neg
=
i
*
class_num
+
j
;
sum
+=
TolerableValue
<
T
>
()(
-
std
::
log
(
1.0
f
+
TolerableValue
<
T
>
()(
std
::
exp
(
prob_data
[
index_neg
]
-
prob_data
[
index_pos
]))));
}
loss_data
[
i
]
=
-
sum
/
(
class_num
-
1
);
}
}
};
template
<
typename
T
>
class
XeGradFunctor
{
public:
XeGradFunctor
(
T
*
dx
,
const
T
*
dy
,
// NOLINT
const
T
*
x
,
// NOLINT
const
int64_t
*
label_pos
,
// NOLINT
size_t
num_classes
)
:
dx_
(
dx
),
dy_
(
dy
),
x_
(
x
),
label_pos_
(
label_pos
),
num_classes_
(
num_classes
)
{}
HOSTDEVICE
void
operator
()(
size_t
sample_id
)
{
for
(
size_t
x_offset
=
sample_id
*
num_classes_
;
x_offset
<
(
sample_id
+
1
)
*
num_classes_
;
++
x_offset
)
{
dx_
[
x_offset
]
=
static_cast
<
T
>
(
0
);
}
auto
p_index
=
sample_id
*
num_classes_
+
label_pos_
[
sample_id
];
for
(
size_t
ni
=
0
;
ni
<
num_classes_
;
ni
++
)
{
if
(
label_pos_
[
sample_id
]
==
ni
)
continue
;
auto
n_index
=
sample_id
*
num_classes_
+
ni
;
auto
grad_
=
-
dy_
[
sample_id
]
/
((
num_classes_
-
1
)
*
(
1.0
f
+
TolerableValue
<
T
>
()(
std
::
exp
(
x_
[
p_index
]
-
x_
[
n_index
]))));
dx_
[
p_index
]
+=
grad_
;
dx_
[
n_index
]
-=
grad_
;
}
}
private:
T
*
dx_
;
const
T
*
dy_
;
const
T
*
x_
;
const
int64_t
*
label_pos_
;
size_t
num_classes_
;
};
template
<
typename
DeviceContext
,
typename
T
>
class
BprLossGradientOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
dy
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
auto
*
label_pos
=
ctx
.
Input
<
Tensor
>
(
"Label_Pos"
);
auto
*
dx
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
T
*
dx_data
=
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
int
rank
=
x
->
dims
().
size
();
int64_t
class_num
=
x
->
dims
()[
rank
-
1
];
XeGradFunctor
<
T
>
functor
(
dx_data
,
dy
->
data
<
T
>
(),
x
->
data
<
T
>
(),
label_pos
->
data
<
int64_t
>
(),
static_cast
<
size_t
>
(
class_num
));
platform
::
ForRange
<
DeviceContext
>
for_range
(
ctx
.
template
device_context
<
DeviceContext
>(),
static_cast
<
size_t
>
(
dy
->
numel
()));
for_range
(
functor
);
}
};
}
// namespace operators
}
// namespace paddle
python/paddle/fluid/layers/nn.py
浏览文件 @
570d89ec
...
...
@@ -41,6 +41,7 @@ __all__ = [
'crf_decoding'
,
'cos_sim'
,
'cross_entropy'
,
'bpr_loss'
,
'square_error_cost'
,
'chunk_eval'
,
'sequence_conv'
,
...
...
@@ -1175,6 +1176,18 @@ def cross_entropy(input, label, soft_label=False, ignore_index=-100):
return
out
def
bpr_loss
(
input
,
label_pos
):
helper
=
LayerHelper
(
'bpr_loss'
,
**
locals
())
out
=
helper
.
create_variable_for_type_inference
(
dtype
=
input
.
dtype
)
helper
.
append_op
(
type
=
'bpr_loss'
,
inputs
=
{
'X'
:
[
input
],
'Label_Pos'
:
[
label_pos
]},
outputs
=
{
'Y'
:
[
out
]})
return
out
def
square_error_cost
(
input
,
label
):
"""
**Square error cost layer**
...
...
python/paddle/fluid/tests/unittests/test_bpr_loss_op.py
0 → 100644
浏览文件 @
570d89ec
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
numpy
as
np
from
op_test
import
OpTest
,
randomize_probability
class
TestBprLossOp1
(
OpTest
):
"""Test BprLoss with discrete one-hot labels.
"""
def
setUp
(
self
):
self
.
op_type
=
"bpr_loss"
batch_size
=
3
class_num
=
5
X
=
randomize_probability
(
batch_size
,
class_num
,
dtype
=
'float64'
)
label_pos
=
np
.
random
.
randint
(
0
,
class_num
,
(
batch_size
,
1
),
dtype
=
"int64"
)
bpr_loss_result
=
[]
for
i
in
range
(
batch_size
):
sum
=
0.0
for
j
in
range
(
class_num
):
if
j
==
label_pos
[
i
][
0
]:
continue
sum
+=
(
-
np
.
log
(
1.0
+
np
.
exp
(
X
[
i
][
j
]
-
X
[
i
][
label_pos
[
i
][
0
]])))
bpr_loss_result
.
append
(
-
sum
/
(
class_num
-
1
))
bpr_loss
=
np
.
asmatrix
([[
x
]
for
x
in
bpr_loss_result
],
dtype
=
"float64"
)
self
.
inputs
=
{
"X"
:
X
,
"Label_Pos"
:
label_pos
}
self
.
outputs
=
{
"Y"
:
bpr_loss
}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
"X"
],
"Y"
,
numeric_grad_delta
=
0.001
)
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录