Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
da3f7668
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
da3f7668
编写于
7月 12, 2018
作者:
G
Guo Sheng
提交者:
GitHub
7月 12, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #12088 from guoshengCS/complete-hsigmoid
Complete hsigmoid_op
上级
3c4f04b7
4ee069fd
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
802 addition
and
1 deletion
+802
-1
paddle/fluid/operators/CMakeLists.txt
paddle/fluid/operators/CMakeLists.txt
+1
-0
paddle/fluid/operators/hierarchical_sigmoid_op.cc
paddle/fluid/operators/hierarchical_sigmoid_op.cc
+167
-0
paddle/fluid/operators/hierarchical_sigmoid_op.h
paddle/fluid/operators/hierarchical_sigmoid_op.h
+135
-0
paddle/fluid/operators/math/CMakeLists.txt
paddle/fluid/operators/math/CMakeLists.txt
+1
-0
paddle/fluid/operators/math/math_function_impl.h
paddle/fluid/operators/math/math_function_impl.h
+1
-1
paddle/fluid/operators/math/matrix_bit_code.cc
paddle/fluid/operators/math/matrix_bit_code.cc
+176
-0
paddle/fluid/operators/math/matrix_bit_code.h
paddle/fluid/operators/math/matrix_bit_code.h
+143
-0
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+69
-0
python/paddle/fluid/tests/unittests/test_hsigmoid_op.py
python/paddle/fluid/tests/unittests/test_hsigmoid_op.py
+99
-0
python/paddle/fluid/tests/unittests/test_layers.py
python/paddle/fluid/tests/unittests/test_layers.py
+10
-0
未找到文件。
paddle/fluid/operators/CMakeLists.txt
浏览文件 @
da3f7668
...
...
@@ -259,6 +259,7 @@ op_library(max_sequence_len_op DEPS lod_rank_table)
op_library
(
sequence_conv_op DEPS context_project
)
op_library
(
sequence_pool_op DEPS sequence_pooling
)
op_library
(
lstm_op DEPS sequence2batch lstm_compute
)
op_library
(
hierarchical_sigmoid_op DEPS matrix_bit_code
)
op_library
(
lstmp_op DEPS sequence2batch lstm_compute
)
op_library
(
gru_op DEPS sequence2batch gru_compute
)
op_library
(
recurrent_op DEPS executor
)
...
...
paddle/fluid/operators/hierarchical_sigmoid_op.cc
0 → 100644
浏览文件 @
da3f7668
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/hierarchical_sigmoid_op.h"
#include <vector>
namespace
paddle
{
namespace
operators
{
/**
* Organize the classes into a binary tree. At each node, a sigmoid function
* is used to calculate the probability of belonging to the right branch.
* This idea is from "F. Morin, Y. Bengio (AISTATS 05):
* Hierarchical Probabilistic Neural Network Language Model."
*
* Here we uses a simple way of making the binary tree.
* Assuming the number of classes C = 6,
* The classes are organized as a binary tree in the following way:
*
* @code{.py}
* *-*-*- 2
* | | |- 3
* | |
* | |-*- 4
* | |- 5
* |
* |-*- 0
* |- 1
* @endcode
*
* where * indicates an internal node, and each leaf node represents a class.
* - Node 0 ... C-2 are internal nodes.
* - Node C-1 ... 2C-2 are leaf nodes.
* - Class c is represented by leaf node \f$c+C-1\f$.
*
* We assign an id for each node:
* - the id of root be 0.
* - the left child of a node i is 2*i+1.
* - the right child of a node i is 2*i+2.
*
* It's easy to see that:
* - the parent of node i is \f$\left\lfloor(i-1)/2\right\rfloor\f$.
* - the j-th level ancestor of node i is
* \f$\left\lfloor(i+1)/2^{j+1}\right\rfloor - 1\f$.
* - A node i is a left child of its parent if \f$(i-1)\%2==0\f$.
*
*/
class
HierarchicalSigmoidOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
"Input(Label) should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"W"
),
"Input(W) should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"PreOut"
),
"Output(PreOut) should not be null."
);
const
int64_t
batch_size
=
ctx
->
GetInputDim
(
"X"
)[
0
];
std
::
vector
<
int64_t
>
output_shape
({
batch_size
,
1
});
ctx
->
SetOutputDim
(
"Out"
,
framework
::
make_ddim
(
output_shape
));
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
)
->
type
()),
ctx
.
GetPlace
());
}
};
template
<
typename
AttrType
>
class
HierarchicalSigmoidOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"X"
,
"(Tensor, required) The input tensor with shape [N, D], "
"where N is the size of mini-batch, and D is the feature size."
);
AddInput
(
"W"
,
"(Tensor, required), The parameters of hierarchical "
"sigmoid operator, each of them is a 2-D tensor, the shape is"
"[num_classes - 1, D]."
);
AddInput
(
"Label"
,
"(Tensor, required), The labels of training data. It's a"
"tensor with shape [N, 1]."
);
AddInput
(
"Bias"
,
"(Tensor, optional), The bias is a tensor with shape"
"[1, num_classes - 1]."
);
AddOutput
(
"Out"
,
"(Tensor, required) The output of hierarchical sigmoid operator."
"The shape is [N, 1]."
);
AddOutput
(
"PreOut"
,
"(Tensor, required) A intermedia 2-D tensor with shape "
"[batch_size, code_length], where code_length represents the "
"maximum path length from root to leaf nodes."
)
.
AsIntermediate
();
AddAttr
<
AttrType
>
(
"num_classes"
,
"(int, required), The number of classes"
)
.
SetDefault
(
2
);
AddComment
(
R"DOC(
The hierarchical sigmoid operator organize the classes into a binary tree.
At each node, a sigmoid function is used to calculate the probability of
belonging to the right branch. This idea is from
"F. Morin, Y. Bengio (AISTATS 05):
Hierarchical Probabilistic Neural Network Language Model."
)DOC"
);
}
};
class
HierarchicalSigmoidGradOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"W"
),
"Input(W) should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
"Input(Label) should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"PreOut"
),
"Input(Preout) should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"W"
)),
"Output(W@Grad should not be null.)"
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"X"
)));
if
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"Bias"
)))
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Bias"
),
ctx
->
GetInputDim
(
"Bias"
));
}
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"W"
),
ctx
->
GetInputDim
(
"W"
));
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"X"
));
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
)
->
type
()),
ctx
.
GetPlace
());
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
hierarchical_sigmoid
,
ops
::
HierarchicalSigmoidOp
,
ops
::
HierarchicalSigmoidOpMaker
<
int
>
,
paddle
::
framework
::
DefaultGradOpDescMaker
<
true
>
);
REGISTER_OPERATOR
(
hierarchical_sigmoid_grad
,
ops
::
HierarchicalSigmoidGradOp
);
REGISTER_OP_CPU_KERNEL
(
hierarchical_sigmoid
,
ops
::
HierarchicalSigmoidOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
HierarchicalSigmoidOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
REGISTER_OP_CPU_KERNEL
(
hierarchical_sigmoid_grad
,
ops
::
HierarchicalSigmoidGradOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
HierarchicalSigmoidGradOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
paddle/fluid/operators/hierarchical_sigmoid_op.h
0 → 100644
浏览文件 @
da3f7668
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <iostream>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/clip_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/matrix_bit_code.h"
#include "paddle/fluid/platform/transform.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
T
,
int
MajorType
=
Eigen
::
RowMajor
,
typename
IndexType
=
Eigen
::
DenseIndex
>
using
EigenMatrix
=
framework
::
EigenMatrix
<
T
,
MajorType
,
IndexType
>
;
using
platform
::
Transform
;
template
<
typename
DeviceContext
,
typename
T
>
class
HierarchicalSigmoidOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
in
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
*
w
=
ctx
.
Input
<
framework
::
Tensor
>
(
"W"
);
auto
*
label
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Label"
);
auto
*
bias
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Bias"
);
auto
*
out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
*
pre_out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"PreOut"
);
size_t
num_classes
=
static_cast
<
size_t
>
(
ctx
.
Attr
<
int
>
(
"num_classes"
));
int64_t
code_length
=
math
::
FindLastSet
(
num_classes
-
1
);
int64_t
batch_size
=
in
->
dims
()[
0
];
framework
::
Tensor
sum
;
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
auto
*
pre_out_data
=
pre_out
->
mutable_data
<
T
>
(
framework
::
make_ddim
({
batch_size
,
code_length
}),
ctx
.
GetPlace
());
auto
pre_out_mat
=
EigenMatrix
<
T
>::
From
(
*
pre_out
);
// Not all class(leaf) nodes' path lengths equal code_length, thus init as
// 0s can avoid out of path's loss.
math
::
SetConstant
<
DeviceContext
,
T
>
zero
;
zero
(
dev_ctx
,
pre_out
,
static_cast
<
T
>
(
0.0
));
auto
&
place
=
*
ctx
.
template
device_context
<
DeviceContext
>().
eigen_device
();
math
::
RowwiseSum
<
DeviceContext
,
T
>
row_sum
;
math
::
MatrixBitCodeFunctor
<
T
>
bit_code
(
num_classes
,
label
->
data
<
int64_t
>
());
std
::
vector
<
int64_t
>
sum_dims
({
batch_size
,
1UL
});
sum
.
mutable_data
<
T
>
(
framework
::
make_ddim
(
sum_dims
),
ctx
.
GetPlace
());
auto
sum_mat
=
EigenMatrix
<
T
>::
From
(
sum
);
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
out_mat
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
out
);
if
(
bias
)
{
bit_code
.
Add
(
pre_out
,
*
bias
);
}
bit_code
.
Mul
(
pre_out
,
*
w
,
*
in
);
// clip to [-40, 40]
Transform
<
DeviceContext
>
trans
;
trans
(
ctx
.
template
device_context
<
DeviceContext
>(),
pre_out_data
,
pre_out_data
+
pre_out
->
numel
(),
pre_out_data
,
ClipFunctor
<
T
>
(
static_cast
<
T
>
(
-
40.0
),
static_cast
<
T
>
(
40.0
)));
bit_code
.
Sum
(
*
pre_out
,
out
,
static_cast
<
T
>
(
-
1
));
// use softrelu to calculate cross entropy
pre_out_mat
.
device
(
place
)
=
(
static_cast
<
T
>
(
1.0
)
+
pre_out_mat
.
exp
()).
log
();
row_sum
(
dev_ctx
,
*
pre_out
,
&
sum
);
// TODO(guosheng): Subtract the out of path's loss, since not all
// class(leaf) nodes' path lengths equal code_length. But it won't break the
// gradient check since both have the out of path's loss and will cancel out
// each other.
out_mat
.
device
(
place
)
=
sum_mat
+
out_mat
;
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
HierarchicalSigmoidGradOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
in
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
*
w
=
ctx
.
Input
<
framework
::
Tensor
>
(
"W"
);
auto
*
in_grad
=
ctx
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
w_grad
=
ctx
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"W"
));
auto
*
bias_grad
=
ctx
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Bias"
));
auto
*
label
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Label"
);
auto
*
pre_out
=
ctx
.
Input
<
framework
::
Tensor
>
(
"PreOut"
);
auto
*
out_grad
=
ctx
.
Input
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
framework
::
Tensor
pre_out_grad
;
pre_out_grad
.
mutable_data
<
T
>
(
pre_out
->
dims
(),
ctx
.
GetPlace
());
in_grad
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
w_grad
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
math
::
SetConstant
<
DeviceContext
,
T
>
zero
;
zero
(
dev_ctx
,
in_grad
,
static_cast
<
T
>
(
0.0
));
zero
(
dev_ctx
,
w_grad
,
static_cast
<
T
>
(
0.0
));
size_t
num_classes
=
static_cast
<
size_t
>
(
ctx
.
Attr
<
int
>
(
"num_classes"
));
math
::
MatrixBitCodeFunctor
<
T
>
bit_code
(
num_classes
,
label
->
data
<
int64_t
>
());
auto
&
place
=
*
ctx
.
template
device_context
<
DeviceContext
>().
eigen_device
();
auto
pre_out_mat
=
EigenMatrix
<
T
>::
From
(
*
pre_out
);
auto
pre_out_grad_mat
=
EigenMatrix
<
T
>::
From
(
pre_out_grad
);
auto
out_grad_mat
=
EigenMatrix
<
T
>::
From
(
*
out_grad
);
Eigen
::
array
<
int
,
2
>
bcast
({{
1
,
static_cast
<
int
>
(
pre_out_grad
.
dims
()[
1
])}});
// softrelu derivative
pre_out_grad_mat
.
device
(
place
)
=
static_cast
<
T
>
(
1.0
)
-
static_cast
<
T
>
(
1.0
)
/
pre_out_mat
.
exp
();
bit_code
.
Sub
(
&
pre_out_grad
);
// the gradient of clip(w * x + b)
pre_out_grad_mat
.
device
(
place
)
=
pre_out_grad_mat
*
out_grad_mat
.
broadcast
(
bcast
);
// TODO(guosheng): multiply pre_out_grad with subgradient of clipping to
// be consistent with the clipping in forward.
if
(
bias_grad
)
{
bias_grad
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
zero
(
dev_ctx
,
bias_grad
,
static_cast
<
T
>
(
0.0
));
bit_code
.
AddGrad
(
pre_out_grad
,
bias_grad
);
}
bit_code
.
MulGradWeight
(
pre_out_grad
,
w_grad
,
*
in
);
bit_code
.
MulGradError
(
pre_out_grad
,
*
w
,
in_grad
);
}
};
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/math/CMakeLists.txt
浏览文件 @
da3f7668
...
...
@@ -51,6 +51,7 @@ math_library(sequence_padding)
math_library
(
sequence_pooling DEPS math_function
)
math_library
(
sequence_scale
)
math_library
(
softmax DEPS math_function
)
math_library
(
matrix_bit_code
)
math_library
(
unpooling
)
math_library
(
vol2col
)
...
...
paddle/fluid/operators/math/math_function_impl.h
浏览文件 @
da3f7668
...
...
@@ -155,7 +155,7 @@ class RowwiseSum<platform::CPUDeviceContext, T> {
PADDLE_ENFORCE_EQ
(
in_dims
.
size
(),
2U
);
auto
height
=
in_dims
[
0
];
auto
size
=
in_dims
[
1
];
PADDLE_ENFORCE_EQ
(
out
->
numel
(),
size
);
PADDLE_ENFORCE_EQ
(
out
->
numel
(),
height
);
T
*
out_buf
=
out
->
mutable_data
<
T
>
(
out
->
place
());
const
T
*
in_buf
=
input
.
data
<
T
>
();
...
...
paddle/fluid/operators/math/matrix_bit_code.cc
0 → 100644
浏览文件 @
da3f7668
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/matrix_bit_code.h"
#include <iostream>
namespace
paddle
{
namespace
operators
{
namespace
math
{
template
<
typename
T
>
void
MatrixBitCodeFunctor
<
T
>::
Add
(
framework
::
Tensor
*
tmat
,
const
framework
::
Tensor
&
vec
)
{
SimpleCodeTable
code_table
(
num_classes_
);
size_t
batch_size
=
tmat
->
dims
()[
0
];
size_t
width
=
tmat
->
dims
()[
1
];
for
(
size_t
i
=
0
;
i
<
batch_size
;
++
i
)
{
auto
code
=
code_table
(
static_cast
<
size_t
>
(
ids_
[
i
]));
int
code_length
=
code
.
get_length
();
for
(
int
j
=
0
;
j
<
code_length
;
++
j
)
{
size_t
index
=
code
.
calc_index
(
j
);
tmat
->
data
<
T
>
()[
i
*
width
+
j
]
+=
vec
.
data
<
T
>
()[
index
];
}
}
}
template
<
typename
T
>
void
MatrixBitCodeFunctor
<
T
>::
AddGrad
(
const
framework
::
Tensor
&
tmat
,
framework
::
Tensor
*
vec
)
{
SimpleCodeTable
code_table
(
num_classes_
);
size_t
batch_size
=
tmat
.
dims
()[
0
];
size_t
width
=
tmat
.
dims
()[
1
];
for
(
size_t
i
=
0
;
i
<
batch_size
;
++
i
)
{
auto
code
=
code_table
(
static_cast
<
size_t
>
(
ids_
[
i
]));
int
code_length
=
code
.
get_length
();
for
(
int
j
=
0
;
j
<
code_length
;
++
j
)
{
size_t
index
=
code
.
calc_index
(
j
);
vec
->
data
<
T
>
()[
index
]
+=
tmat
.
data
<
T
>
()[
i
*
width
+
j
];
}
}
}
template
<
typename
T
>
void
MatrixBitCodeFunctor
<
T
>::
Sum
(
const
framework
::
Tensor
&
tmat
,
framework
::
Tensor
*
sum
,
T
scale_sum
)
{
SimpleCodeTable
code_table
(
num_classes_
);
size_t
num_samples
=
tmat
.
dims
()[
0
];
size_t
o_width
=
tmat
.
dims
()[
1
];
for
(
size_t
i
=
0
;
i
<
num_samples
;
++
i
)
{
T
sm
=
static_cast
<
T
>
(
0.0
);
auto
code
=
code_table
(
static_cast
<
size_t
>
(
ids_
[
i
]));
int
code_length
=
code
.
get_length
();
for
(
int
j
=
0
;
j
<
code_length
;
++
j
)
{
if
(
code
.
calc_bit
(
j
))
{
// calc_bit starts from right most bit, while data in tmat[i] is in the
// reverse order.
sm
+=
tmat
.
data
<
T
>
()[
i
*
o_width
+
j
];
}
}
sum
->
data
<
T
>
()[
i
]
=
scale_sum
*
sm
;
}
}
template
<
typename
T
>
void
MatrixBitCodeFunctor
<
T
>::
Mul
(
framework
::
Tensor
*
tmat
,
const
framework
::
Tensor
&
weight
,
const
framework
::
Tensor
&
input
)
{
SimpleCodeTable
code_table
(
num_classes_
);
size_t
num_samples
=
tmat
->
dims
()[
0
];
size_t
tmat_width
=
tmat
->
dims
()[
1
];
size_t
input_width
=
input
.
dims
()[
1
];
size_t
weight_width
=
weight
.
dims
()[
1
];
auto
tmat_value
=
tmat
->
data
<
T
>
();
auto
weight_value
=
weight
.
data
<
T
>
();
auto
input_value
=
input
.
data
<
T
>
();
for
(
size_t
i
=
0
;
i
<
num_samples
;
++
i
)
{
auto
code
=
code_table
(
static_cast
<
size_t
>
(
ids_
[
i
]));
int
code_length
=
code
.
get_length
();
for
(
int
j
=
0
;
j
<
code_length
;
++
j
)
{
size_t
index
=
code
.
calc_index
(
j
);
T
sum
=
static_cast
<
T
>
(
0.0
);
for
(
size_t
k
=
0
;
k
<
input_width
;
++
k
)
{
sum
+=
weight_value
[
weight_width
*
index
+
k
]
*
input_value
[
input_width
*
i
+
k
];
}
tmat_value
[
i
*
tmat_width
+
j
]
+=
sum
;
}
}
}
template
<
typename
T
>
void
MatrixBitCodeFunctor
<
T
>::
MulGradWeight
(
const
framework
::
Tensor
&
tmat
,
framework
::
Tensor
*
weight
,
const
framework
::
Tensor
&
input
)
{
SimpleCodeTable
code_table
(
num_classes_
);
size_t
num_samples
=
tmat
.
dims
()[
0
];
size_t
input_width
=
input
.
dims
()[
1
];
size_t
tmat_width
=
tmat
.
dims
()[
1
];
size_t
weight_width
=
weight
->
dims
()[
1
];
auto
tmat_value
=
tmat
.
data
<
T
>
();
auto
weight_value
=
weight
->
data
<
T
>
();
auto
input_value
=
input
.
data
<
T
>
();
for
(
size_t
i
=
0
;
i
<
num_samples
;
++
i
)
{
auto
code
=
code_table
(
static_cast
<
size_t
>
(
ids_
[
i
]));
int
code_length
=
code
.
get_length
();
for
(
int
j
=
0
;
j
<
code_length
;
++
j
)
{
size_t
index
=
code
.
calc_index
(
j
);
for
(
size_t
k
=
0
;
k
<
input_width
;
++
k
)
{
weight_value
[
weight_width
*
index
+
k
]
+=
tmat_value
[
i
*
tmat_width
+
j
]
*
input_value
[
input_width
*
i
+
k
];
}
}
}
}
template
<
typename
T
>
void
MatrixBitCodeFunctor
<
T
>::
MulGradError
(
const
framework
::
Tensor
&
tmat
,
const
framework
::
Tensor
&
weight
,
framework
::
Tensor
*
input
)
{
SimpleCodeTable
code_table
(
num_classes_
);
size_t
num_samples
=
tmat
.
dims
()[
0
];
size_t
tmat_width
=
tmat
.
dims
()[
1
];
size_t
input_width
=
input
->
dims
()[
1
];
size_t
weight_width
=
weight
.
dims
()[
1
];
auto
tmat_value
=
tmat
.
data
<
T
>
();
auto
weight_value
=
weight
.
data
<
T
>
();
auto
input_value
=
input
->
data
<
T
>
();
for
(
size_t
i
=
0
;
i
<
num_samples
;
++
i
)
{
auto
code
=
code_table
(
static_cast
<
size_t
>
(
ids_
[
i
]));
int
code_length
=
code
.
get_length
();
for
(
int
j
=
0
;
j
<
code_length
;
++
j
)
{
size_t
index
=
code
.
calc_index
(
j
);
for
(
size_t
k
=
0
;
k
<
input_width
;
++
k
)
{
input_value
[
input_width
*
i
+
k
]
+=
tmat_value
[
i
*
tmat_width
+
j
]
*
weight_value
[
weight_width
*
index
+
k
];
}
}
}
}
template
<
typename
T
>
void
MatrixBitCodeFunctor
<
T
>::
Sub
(
framework
::
Tensor
*
tmat
)
{
SimpleCodeTable
code_table
(
num_classes_
);
size_t
num_samples
=
tmat
->
dims
()[
0
];
size_t
o_width
=
tmat
->
dims
()[
1
];
for
(
size_t
i
=
0
;
i
<
num_samples
;
++
i
)
{
auto
code
=
code_table
(
static_cast
<
size_t
>
(
ids_
[
i
]));
int
code_length
=
code
.
get_length
();
for
(
int
j
=
0
;
j
<
code_length
;
++
j
)
{
if
(
code
.
calc_bit
(
j
))
{
tmat
->
data
<
T
>
()[
i
*
o_width
+
j
]
-=
1
;
}
}
}
}
template
class
MatrixBitCodeFunctor
<
float
>;
template
class
MatrixBitCodeFunctor
<
double
>;
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/math/matrix_bit_code.h
0 → 100644
浏览文件 @
da3f7668
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
/**
* SimpleCodeTable class should support 3 functions:
*
* size_t size()
* return the number of ids
*
* int get_max_code_length()
* return the maximal code length
*
* SimpleCode operator()(size_t i)
* return the i-th code. Code class is descriebed below.
*
* SimpleCode class should support 3 functions:
*
* int get_length()
* return the length of the code
*
* size_t cal_index(int bit)
* bit ranges from 0 to get_length() - 1
* return the index for the (1+bit) level parent
*
* bool calc_bit(int bit)
* return true if the bit level parent is the right child of (1+bit) level
* parent
*
*/
/**
* return the 1-based index of the highest bit set
*
* for x > 0:
* \f[
* FindLastSet(x) = 1 + \floor*{\log_{2}x}
* \f]
*/
inline
constexpr
size_t
FindLastSet
(
size_t
x
)
{
return
std
::
is_same
<
size_t
,
unsigned
int
>::
value
?
(
x
?
8
*
sizeof
(
x
)
-
__builtin_clz
(
x
)
:
0
)
:
(
std
::
is_same
<
size_t
,
unsigned
long
>::
value
// NOLINT
?
(
x
?
8
*
sizeof
(
x
)
-
__builtin_clzl
(
x
)
:
0
)
:
(
x
?
8
*
sizeof
(
x
)
-
__builtin_clzll
(
x
)
:
0
));
}
struct
SimpleCode
{
SimpleCode
(
size_t
code
,
size_t
num_classes
)
:
c_
(
code
+
num_classes
)
{}
/**
* Here the id of root shoud be 1 rather than 0, thus the encoding of class c
* is `c + num_classes` and all siblings can get the same weight indice using
* prefixes.
* Weight index is the prefixes of encoding, thus leave out the right most
* bit in calc_index.
* Binary classification path is the suffixes of encoding, thus leave out the
* left most bit in calc_bit.
*/
inline
size_t
calc_index
(
int
bit
)
const
{
return
(
c_
>>
(
bit
+
1
))
-
1
;
}
inline
bool
calc_bit
(
int
bit
)
const
{
return
c_
&
(
1
<<
bit
);
}
inline
int
get_length
()
const
{
return
FindLastSet
(
c_
)
-
1
;
}
private:
size_t
c_
;
};
struct
SimpleCodeTable
{
explicit
SimpleCodeTable
(
size_t
num_classes
)
:
num_classes_
(
num_classes
)
{}
SimpleCode
operator
()(
size_t
code
)
const
{
return
SimpleCode
(
code
,
num_classes_
);
}
size_t
size
()
const
{
return
num_classes_
;
}
int
get_max_code_length
()
const
{
return
FindLastSet
(
num_classes_
-
1
);
}
private:
size_t
num_classes_
;
};
template
<
typename
T
>
class
MatrixBitCodeFunctor
{
public:
explicit
MatrixBitCodeFunctor
(
size_t
num_classes
,
const
int64_t
*
ids
)
:
num_classes_
(
num_classes
),
ids_
(
ids
)
{}
/* For j < code_length
tmat(i, j) += vec(0, index(i, j))
*/
void
Add
(
framework
::
Tensor
*
tmat
,
const
framework
::
Tensor
&
vec
);
/* For j < code_length
vec(0, index(i, j)) += tmat(i, j)
*/
void
AddGrad
(
const
framework
::
Tensor
&
tmat
,
framework
::
Tensor
*
vec
);
/* For j < code_length
sum(i, 0) = \sum_j bit(i, j) * tmat(i, j)
*/
void
Sum
(
const
framework
::
Tensor
&
tmat
,
framework
::
Tensor
*
sum
,
T
scale_sum
);
/* For j < code_length
tmat(i, j) -= bit(i, j)
*/
void
Sub
(
framework
::
Tensor
*
tmat
);
/* For j < code_length
input.row(i) += tmat(i, j) * weight.row(index(i, j))
*/
void
Mul
(
framework
::
Tensor
*
tmat
,
const
framework
::
Tensor
&
weight
,
const
framework
::
Tensor
&
input
);
/* For index(i, j) >= 0:
weight.row(index(i, j)) += tmat(i, j) * input.row(i)
*/
void
MulGradWeight
(
const
framework
::
Tensor
&
tmat
,
framework
::
Tensor
*
weight
,
const
framework
::
Tensor
&
input
);
/* For j < code_length
input.row(i) += tmat(i, j) * weight.row(index(i, j))
*/
void
MulGradError
(
const
framework
::
Tensor
&
tmat
,
const
framework
::
Tensor
&
weight
,
framework
::
Tensor
*
input
);
size_t
num_classes_
;
const
int64_t
*
ids_
;
};
}
// namespace math
}
// namespace operators
}
// namespace paddle
python/paddle/fluid/layers/nn.py
浏览文件 @
da3f7668
...
...
@@ -85,6 +85,7 @@ __all__ = [
'transpose'
,
'im2sequence'
,
'nce'
,
'hsigmoid'
,
'beam_search'
,
'row_conv'
,
'multiplex'
,
...
...
@@ -3871,6 +3872,74 @@ def nce(input,
return
cost
/
(
num_neg_samples
+
1
)
def
hsigmoid
(
input
,
label
,
num_classes
,
param_attr
=
None
,
bias_attr
=
None
):
"""
The hierarchical sigmoid operator is used to accelerate the training
process of language model. This operator organizes the classes into a
complete binary tree, each leaf node represents a class(a word) and each
internal node acts as a binary classifier. For each word there's a unique
path from root to it's leaf node, hsigmoid calculate the cost for each
internal node on the path, and sum them to get a total cost. hsigmoid can
achive a acceleration from :math:`O(N)` to :math:`O(logN)`, where :math:`N`
represents the size of word dict.
Refer to `Hierarchical Probabilistic Neural Network Language Model
<http://www.iro.umontreal.ca/~lisa/pointeurs/hierarchical-nnlm-aistats05.pdf>`_
Args:
input (Variable): The input tensor variable with shape
:math:`[N
\\
times D]`, where :math:`N` is the size of mini-batch,
and :math:`D` is the feature size.
label (Variable): The tensor variable contains labels of training data.
It's a tensor with shape is :math:`[N
\\
times 1]`.
num_classes: (int), The number of classes, must not be less than 2.
param_attr (ParamAttr|list of ParamAttr, default None): The parameter
attribute for learnable parameters/weights of this layer.
bias_attr (ParamAttr|list of ParamAttr, default None): The parameter
attribute for the bias of this layer. If it is set to False, no
bias will be applied.
Returns:
Out: (Tensor) The cost of hierarchical sigmoid operator. the shape is [N, 1]
Examples:
.. code-block:: python
x = fluid.layers.data(name='x', shape=[2], dtype='float32')
y = fluid.layers.data(name='y', shape=[1], dtype='int64')
out = fluid.layers.hsigmoid(input=x, label=y, num_classes=6)
"""
helper
=
LayerHelper
(
'hierarchical_sigmoid'
,
**
locals
())
dtype
=
helper
.
input_dtype
()
out
=
helper
.
create_tmp_variable
(
dtype
)
pre_out
=
helper
.
create_tmp_variable
(
dtype
)
dim
=
input
.
shape
[
1
]
if
num_classes
<
2
:
raise
ValueError
(
"num_classes must not be less than 2."
)
weights
=
helper
.
create_parameter
(
attr
=
helper
.
param_attr
,
shape
=
[
num_classes
-
1
,
dim
],
is_bias
=
False
,
dtype
=
input
.
dtype
)
inputs
=
{
"X"
:
input
,
"W"
:
weights
,
"Label"
:
label
}
if
helper
.
bias_attr
:
bias
=
helper
.
create_parameter
(
attr
=
helper
.
bias_attr
,
shape
=
[
1
,
num_classes
-
1
],
is_bias
=
True
,
dtype
=
input
.
dtype
)
inputs
[
'Bias'
]
=
bias
helper
.
append_op
(
type
=
"hierarchical_sigmoid"
,
inputs
=
inputs
,
outputs
=
{
"Out"
:
out
,
"PreOut"
:
pre_out
},
attrs
=
{
"num_classes"
:
num_classes
})
return
out
def
transpose
(
x
,
perm
,
name
=
None
):
"""
Permute the dimensions of `input` according to `perm`.
...
...
python/paddle/fluid/tests/unittests/test_hsigmoid_op.py
0 → 100644
浏览文件 @
da3f7668
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
numpy
as
np
import
math
from
op_test
import
OpTest
def
find_latest_set
(
num
):
return
1
+
int
(
math
.
floor
(
math
.
log
(
num
,
2
)))
class
CodeTable
(
object
):
def
__init__
(
self
,
num_classes
,
code
):
self
.
c
=
num_classes
+
code
def
cal_index
(
self
,
bit
):
return
(
self
.
c
>>
(
bit
+
1
))
-
1
def
get_length
(
self
):
return
find_latest_set
(
self
.
c
)
-
1
def
cal_bit
(
self
,
bit
):
return
self
.
c
&
(
1
<<
bit
)
def
hsigmoid
(
x
,
w
,
label
,
bias
,
num_classes
):
batch_size
=
x
.
shape
[
0
]
code_length
=
find_latest_set
(
num_classes
-
1
)
code_table
=
[
0
for
_
in
range
(
code_length
)]
pre_output
=
np
.
zeros
((
batch_size
,
code_length
))
pre_sum
=
np
.
zeros
((
batch_size
,
1
))
out
=
np
.
zeros
((
batch_size
,
1
)).
astype
(
"float32"
)
for
i
in
range
(
batch_size
):
code_table
=
CodeTable
(
num_classes
,
label
[
i
])
length
=
code_table
.
get_length
()
for
j
in
range
(
length
):
idx
=
code_table
.
cal_index
(
j
)
pre_output
[
i
][
j
]
+=
bias
[
0
][
idx
]
for
i
in
range
(
batch_size
):
code_table
=
CodeTable
(
num_classes
,
label
[
i
])
length
=
code_table
.
get_length
()
for
j
in
range
(
length
):
idx
=
code_table
.
cal_index
(
j
)
pre_output
[
i
][
j
]
+=
np
.
dot
(
w
[
idx
],
x
[
i
])
# clip[-40.0, 40.0]
pre_output
=
np
.
clip
(
pre_output
,
-
40.0
,
40.0
)
# out(i, 0) = \sum_j bit(i, j) * preout(i, j)
for
i
in
range
(
batch_size
):
code_table
=
CodeTable
(
num_classes
,
label
[
i
])
length
=
code_table
.
get_length
()
sum
=
0.0
for
j
in
range
(
length
):
if
code_table
.
cal_bit
(
j
):
sum
+=
pre_output
[
i
][
j
]
out
[
i
]
=
-
1.0
*
sum
# soft relu
pre_output
=
np
.
log
(
1
+
np
.
exp
(
pre_output
))
pre_sum
=
pre_output
.
sum
(
1
).
reshape
((
batch_size
,
1
))
out
+=
pre_sum
return
pre_output
,
out
class
TestHSigmoidOp
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"hierarchical_sigmoid"
num_classes
=
6
feature_size
=
8
batch_size
=
4
x
=
np
.
random
.
random
((
batch_size
,
feature_size
)).
astype
(
"float32"
)
w
=
np
.
random
.
random
((
num_classes
-
1
,
feature_size
)).
astype
(
"float32"
)
label
=
np
.
random
.
randint
(
0
,
num_classes
,
(
batch_size
,
1
))
bias
=
np
.
random
.
random
((
1
,
num_classes
-
1
)).
astype
(
"float32"
)
self
.
attrs
=
{
'num_classes'
:
num_classes
}
self
.
inputs
=
{
'X'
:
x
,
'W'
:
w
,
'Label'
:
label
,
'Bias'
:
bias
}
pre_output
,
out
=
hsigmoid
(
x
,
w
,
label
,
bias
,
num_classes
)
self
.
outputs
=
{
'PreOut'
:
pre_output
,
'Out'
:
out
}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
'Bias'
,
'X'
,
'W'
],
[
'Out'
],
no_grad_set
=
set
(
'Label'
))
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_layers.py
浏览文件 @
da3f7668
...
...
@@ -174,6 +174,16 @@ class TestBook(unittest.TestCase):
x
=
dat
,
label
=
lbl
))
print
(
str
(
program
))
def
test_hsigmoid
(
self
):
program
=
Program
()
with
program_guard
(
program
):
x
=
layers
.
data
(
name
=
'x'
,
shape
=
[
2
],
dtype
=
'float32'
)
y
=
layers
.
data
(
name
=
'y'
,
shape
=
[
2
],
dtype
=
'int64'
)
self
.
assertIsNotNone
(
layers
.
hsigmoid
(
input
=
x
,
label
=
y
,
num_classes
=
2
))
print
(
str
(
program
))
def
test_sequence_expand
(
self
):
program
=
Program
()
with
program_guard
(
program
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录