Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
695b1037
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
695b1037
编写于
11月 15, 2017
作者:
Y
Yancey1989
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
port hsigmoid layer
上级
9f289256
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
305 addition
and
0 deletion
+305
-0
paddle/operators/hierarchical_sigmoid_op.cc
paddle/operators/hierarchical_sigmoid_op.cc
+121
-0
paddle/operators/hierarchical_sigmoid_op.h
paddle/operators/hierarchical_sigmoid_op.h
+35
-0
paddle/operators/math/CMakeLists.txt
paddle/operators/math/CMakeLists.txt
+1
-0
paddle/operators/math/matrix_bit_code.cc
paddle/operators/math/matrix_bit_code.cc
+84
-0
paddle/operators/math/matrix_bit_code.h
paddle/operators/math/matrix_bit_code.h
+64
-0
未找到文件。
paddle/operators/hierarchical_sigmoid_op.cc
0 → 100644
浏览文件 @
695b1037
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "hierarchical_sigmoid_op.h"
namespace
paddle
{
namespace
operators
{
/**
* Organize the classes into a binary tree. At each node, a sigmoid function
* is used to calculate the probability of belonging to the right branch.
* This idea is from "F. Morin, Y. Bengio (AISTATS 05):
* Hierarchical Probabilistic Neural Network Language Model."
*
* Here we uses a simple way of making the binary tree.
* Assuming the number of classes C = 6,
* The classes are organized as a binary tree in the following way:
*
* @code{.py}
* *-*-*- 2
* | | |- 3
* | |
* | |-*- 4
* | |- 5
* |
* |-*- 0
* |- 1
* @endcode
*
* where * indicates an internal node, and each leaf node represents a class.
* - Node 0 ... C-2 are internal nodes.
* - Node C-1 ... 2C-2 are leaf nodes.
* - Class c is represented by leaf node \f$c+C-1\f$.
*
* We assign an id for each node:
* - the id of root be 0.
* - the left child of a node i is 2*i+1.
* - the right child of a node i is 2*i+2.
*
* It's easy to see that:
* - the parent of node i is \f$\left\lfloor(i-1)/2\right\rfloor\f$.
* - the j-th level ancestor of node i is
* \f$\left\lfloor(i+1)/2^{j+1}\right\rfloor - 1\f$.
* - A node i is a left child of its parent if \f$(i-1)\%2==0\f$.
*
*/
class
HierarchicalSigmoidOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInputs
(
"X"
),
"Inputs(X) should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
"Input(Label) should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) should not be null."
);
const
int64_t
batch_size
=
ctx
->
GetInputsDim
(
"X"
)[
0
][
0
];
const
int64_t
size
=
ctx
->
GetInputsDim
(
"X"
).
size
();
std
::
vector
<
int64_t
>
output_shape
({
batch_size
,
size
});
ctx
->
SetOutputDim
(
"Out"
,
framework
::
make_ddim
(
output_shape
));
}
};
class
HierarchicalSigmoidGradOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{}
};
class
HierarchicalSigmoidOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
HierarchicalSigmoidOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"(TensorArray, required) The input array. Each Tensor has the "
"same shape with [N * D]."
.
AsDuplicable
();
AddInput
(
"Label"
,
"(Tensor, required), The labels of training data. It's a"
"1-D tensor."
);
AddInput
(
"Bias"
,
"(Tensor, optional), The bias is a 1-D tensor, "
"which is applied to the output"
);
AddOutput
(
"Out"
,
"(Tensor, required) The output of hierarchical sigmoid operator."
);
AddAttr
<
int
>
(
"num_classes"
,
"(int, required)"
,
"The number of classes"
);
AddComment
(
R"DOC(
The hierarchical sigmoid operator organize the classes into a binary tree.
At each node, a sigmoid function is used to caculate the probability of
belonging to the right branch. This idea is from
"F. Morin, Y. Bengio (AISTATS 05):
Hierarchical Probabilistic Neural Network Language Model."
)DOC"
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP
(
hierarchical_sigmoid
,
ops
::
HierarchicalSigmoidOp
,
ops
::
HierarchicalSigmoidOpMaker
,
hierarchical_sigmoid_grad
,
ops
::
HierarchicalSigmoidGradOp
);
REGISTER_OP_CPU_KERNEL
(
hierarchical_sigmoid
,
ops
::
HierarchicalSigmoidOpKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
REGISTER_OP_CPU_KERNEL
(
hierarchical_sigmoid_grad
,
ops
::
HierarchicalSigmoidGradOpKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
paddle/operators/hierarchical_sigmoid_op.h
0 → 100644
浏览文件 @
695b1037
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/matrix_bit_code.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
Place
,
typename
T
>
class
HierarchicalSigmoidOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{}
};
template
<
typename
Place
,
typename
T
>
class
HierarchicalSigmoidGradOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{}
};
}
// namespace operators
}
// namespace paddle
paddle/operators/math/CMakeLists.txt
浏览文件 @
695b1037
...
@@ -26,6 +26,7 @@ else()
...
@@ -26,6 +26,7 @@ else()
cc_library
(
sequence2batch SRCS sequence2batch.cc DEPS device_context
)
cc_library
(
sequence2batch SRCS sequence2batch.cc DEPS device_context
)
cc_library
(
lstm_compute SRCS lstm_compute.cc DEPS device_context activation_functions
)
cc_library
(
lstm_compute SRCS lstm_compute.cc DEPS device_context activation_functions
)
cc_library
(
gru_compute SRCS gru_compute.cc DEPS device_context activation_functions math_function
)
cc_library
(
gru_compute SRCS gru_compute.cc DEPS device_context activation_functions math_function
)
cc_library
(
matrix_bit_code SRCS matrix_bit_code.cc
)
endif
()
endif
()
cc_test
(
math_function_test SRCS math_function_test.cc DEPS math_function tensor
)
cc_test
(
math_function_test SRCS math_function_test.cc DEPS math_function tensor
)
...
...
paddle/operators/math/matrix_bit_code.cc
0 → 100644
浏览文件 @
695b1037
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "matrix_bit_code.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
/**
* CodeTable class should support 3 functions:
*
* size_t size()
* return the number of codes
*
* int getMaxCodeLength()
* return the maximal code length
*
* Code operator()(size_t i)
* return the i-th code. Code class is descriebed below.
*
* Code class should support 3 functions:
*
* int getLength()
* return the length of the code
*
* bool calcIndex(int bit)
* bit ranges from 0 to getLength() - 1
* return the index for the (1+bit) level parent
*
* bool calcBit(int bit)
* return true if the bit level parent is the right child of (1+bit) level
* parent
*
*/
/*
for i:
for j < codeLength:
op(a(i, j), b(0, index(i, j)))
*/
template
<
class
CodeTable
,
class
Op
,
typename
T
,
typename
Place
>
static
void
AddByBitCodeT
(
Op
op
,
CodeTable
code_table
,
const
framework
::
Tensor
&
codes
,
framework
::
Tensor
&
a
,
framework
::
Tensor
&
b
)
{
size_t
num_classes
=
code_table
.
size
();
size_t
max_code_length
=
code_table
.
get_max_code_length
();
size_t
num_sample
=
a
.
dims
()[
0
].
size
();
size_t
width
=
a
.
dims
()[
1
].
size
();
for
(
size_t
i
=
0
;
i
<
num_sample
;
++
i
)
{
auto
code
=
code_table
(
codes
.
data
<
T
>
()[
i
])
int
code_length
=
code
.
get_length
();
for
(
int
j
=
0
;
j
<
code_length
;
+
j
)
{
size_t
index
=
code
.
calc_index
(
j
);
op
(
a
<
T
>
.
data
()[
i
*
width
+
j
],
b
<
T
>
.
data
()[
index
]);
}
}
}
/* For j < codeLength:
a(i, j) += b(0, index(i, j))
*/
template
<
typename
T
,
typename
Place
>
void
AddByBitCode
(
size_t
num_classes
,
const
framework
::
Tensor
&
codes
,
framework
::
Tensor
&
a
,
const
framework
::
Tensor
&
b
)
{
auto
op
=
[](
T
&
t
,
T
&
v
)
{
t
+=
v
;
};
AddByBitCodeT
<
T
,
Place
>
(
op
,
SimpleCodeTable
(
num_classes
),
codes
,
a
,
b
);
}
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/operators/math/matrix_bit_code.h
0 → 100644
浏览文件 @
695b1037
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/tensor.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
/**
* return the 1-based index of the highest bit set
*
* for x > 0:
* \f[
* findLastSet(x) = 1 + \floor*{\log_{2}x}
* \f]
*/
inline
constexpr
size_t
FindLastSet
(
size_t
x
)
{
return
std
::
is_same
<
size_t
,
unsigned
int
>::
value
?
(
x
?
8
*
sizeof
(
x
)
-
__builtin_clz
(
x
)
:
0
)
:
(
std
::
is_same
<
size_t
,
unsigned
long
>::
value
// NOLINT
?
(
x
?
8
*
sizeof
(
x
)
-
__builtin_clzl
(
x
)
:
0
)
:
(
x
?
8
*
sizeof
(
x
)
-
__builtin_clzll
(
x
)
:
0
));
}
struct
SimpleCode
{
SimpleCode
(
size_t
code
,
size_t
num_classes
)
:
c_
(
code
+
num_classes
)
{}
inline
size_t
calc_index
(
int
bit
)
const
{
return
(
c_
>>
(
bit
+
1
))
-
1
;
}
inline
bool
calc_bit
(
int
bit
)
const
{
return
c_
&
(
1
<<
bit
);
}
inline
int
get_length
()
const
{
return
FindLastSet
(
c_
)
-
1
;
}
private:
size_t
c_
;
};
struct
SimpleCodeTable
{
explicit
SimpleCodeTable
(
size_t
num_classes
)
:
num_classes_
(
num_classes
)
{}
SimpleCode
operator
()(
size_t
code
)
const
{
return
SimpleCode
(
code
,
num_classes_
);
}
size_t
size
()
const
{
return
num_classes_
;
}
int
get_max_code_length
()
const
{
return
FindLastSet
(
num_classes_
-
1
);
}
private:
size_t
num_classes_
;
int
max_code_length_
;
};
}
// namespace math
}
// namespace operators
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录