Commit ee13b396 authored by weixing02

fix some errors

Parent 8bd148dc
......@@ -62,7 +62,7 @@ class HierarchicalSigmoidOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Ids"), "Input(Ids) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("W"), "Input(W) should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("PreOut"),
......@@ -87,19 +87,18 @@ class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
void Make() override {
AddInput("X",
"(Tensor, required) The input Tensor, which the shape is"
"[N * D], which N is the size of mini-batch,"
"[N, D], which N is the size of mini-batch,"
"D is the embded size");
AddInput("W",
"(Tensor, required), The parameters of hierarchical "
"sigmoid operator, each of them is s a 3-D tensor, the shape is"
"sigmoid operator, each of them is s a 2-D tensor, the shape is"
"[num_classes - 1, D]");
AddInput("Ids",
AddInput("Label",
"(Tensor, required), The labels of training data. It's a"
"1-D tensor, which the shape is [1, N]");
AddInput("Bias",
"(Tensor, optional), The bias is a 1-D tensor, "
"which is applied to the output, the shape is"
"[1, num_classes -1]");
"(Tensor, optional), The bias is a tensor with shape"
"[1, num_classes - 1]");
AddOutput("Out",
"(Tensor, required) The output of hierarchical sigmoid operator."
"the shape is [N, 1]");
......@@ -111,7 +110,7 @@ class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
.SetDefault(2);
AddComment(R"DOC(
The hierarchical sigmoid operator organizes the classes into a binary tree.
At each node, a sigmoid function is used to caculate the probability of
At each node, a sigmoid function is used to calculate the probability of
belonging to the right branch. This idea is from
"F. Morin, Y. Bengio (AISTATS 05):
Hierarchical Probabilistic Neural Network Language Model."
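As a quick illustration of the tree formulation described in the comment above, here is a minimal NumPy sketch of how a single class probability is assembled from per-node sigmoids. This is illustrative only (the operator itself accumulates softrelu terms into PreOut and sums them into Out), and the helper names below are hypothetical:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def class_probability(x, w, bias, path_nodes, path_goes_right):
    """Hypothetical sketch: probability of one class under hierarchical sigmoid.

    x               : (D,) input feature vector (one row of X)
    w               : (num_classes - 1, D) internal-node weights (the W input)
    bias            : (num_classes - 1,) per-node bias (the optional Bias input)
    path_nodes      : internal-node indices on the path from the root to the class
    path_goes_right : for each node, True if the path takes the right child
    """
    prob = 1.0
    for idx, go_right in zip(path_nodes, path_goes_right):
        p_right = sigmoid(np.dot(w[idx], x) + bias[idx])   # sigmoid at this tree node
        prob *= p_right if go_right else (1.0 - p_right)   # probability of the branch taken
    return prob
```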
......@@ -124,7 +123,7 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("W"), "Input(W) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Ids"), "Input(Ids) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("PreOut"),
"Input(Preout) should not be null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("W")),
......@@ -155,9 +154,14 @@ REGISTER_OPERATOR(hierarchical_sigmoid, ops::HierarchicalSigmoidOp,
ops::HierarchicalSigmoidOpMaker<int>,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(hierarchical_sigmoid_grad, ops::HierarchicalSigmoidGradOp);
REGISTER_OP_CPU_KERNEL(hierarchical_sigmoid,
ops::HierarchicalSigmoidOpKernel<
paddle::platform::CPUDeviceContext, float>);
REGISTER_OP_CPU_KERNEL(hierarchical_sigmoid_grad,
ops::HierarchicalSigmoidGradOpKernel<
paddle::platform::CPUDeviceContext, float>);
REGISTER_OP_CPU_KERNEL(
hierarchical_sigmoid,
ops::HierarchicalSigmoidOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::HierarchicalSigmoidOpKernel<paddle::platform::CPUDeviceContext,
double>);
REGISTER_OP_CPU_KERNEL(
hierarchical_sigmoid_grad,
ops::HierarchicalSigmoidGradOpKernel<paddle::platform::CPUDeviceContext,
float>,
ops::HierarchicalSigmoidGradOpKernel<paddle::platform::CPUDeviceContext,
double>);
......@@ -34,7 +34,7 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<framework::Tensor>("X");
auto* w = ctx.Input<framework::Tensor>("W");
auto* ids = ctx.Input<framework::Tensor>("Ids");
auto* label = ctx.Input<framework::Tensor>("Label");
auto* bias = ctx.Input<framework::Tensor>("Bias");
auto* out = ctx.Output<framework::Tensor>("Out");
auto* pre_out = ctx.Output<framework::Tensor>("PreOut");
......@@ -50,7 +50,7 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
zero(dev_ctx, pre_out, static_cast<T>(0.0));
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
math::RowwiseSum<DeviceContext, T> row_sum;
math::MatrixBitCodeFunctor<T> bit_code(num_classes, ids->data<int64_t>());
math::MatrixBitCodeFunctor<T> bit_code(num_classes, label->data<int64_t>());
std::vector<int64_t> sum_dims({batch_size, 1UL});
sum.mutable_data<T>(framework::make_ddim(sum_dims), ctx.GetPlace());
......@@ -87,7 +87,7 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
auto* w_grad = ctx.Output<framework::Tensor>(framework::GradVarName("W"));
auto* bias_grad =
ctx.Output<framework::Tensor>(framework::GradVarName("Bias"));
auto* ids = ctx.Input<framework::Tensor>("Ids");
auto* label = ctx.Input<framework::Tensor>("Label");
auto* pre_out = ctx.Input<framework::Tensor>("PreOut");
auto* out_grad =
ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
......@@ -101,9 +101,11 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
auto pre_out_mat = EigenMatrix<T>::From(*pre_out);
auto pre_out_grad_mat = EigenMatrix<T>::From(pre_out_grad);
math::MatrixBitCodeFunctor<T> bit_code(num_classes, ids->data<int64_t>());
math::MatrixBitCodeFunctor<T> bit_code(num_classes, label->data<int64_t>());
// softrelu derivative
bit_code.OutGrad(&pre_out_grad, *out_grad);
Eigen::array<int, 2> bcast({1, static_cast<int>(pre_out_grad.dims()[1])});
auto out_grad_mat = EigenMatrix<T>::From(*out_grad);
pre_out_grad_mat = out_grad_mat.broadcast(bcast);
pre_out_grad_mat.device(place) =
pre_out_grad_mat *
(static_cast<T>(1.0) - static_cast<T>(1.0) / pre_out_mat.exp());
......
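The "softrelu derivative" step above uses the fact that PreOut holds softrelu(z) = log(1 + exp(z)) from the forward pass, so d softrelu / dz = sigmoid(z) = 1 - exp(-PreOut), which is exactly the factor (1 - 1 / exp(pre_out_mat)) multiplied into the broadcast output gradient. A small NumPy check of that identity (illustrative, not part of the operator):

```python
import numpy as np

z = np.random.randn(4, 3)
pre_out = np.log1p(np.exp(z))            # softrelu(z), what PreOut stores
sigmoid_z = 1.0 / (1.0 + np.exp(-z))     # d softrelu(z) / dz
# sigmoid(z) == 1 - exp(-pre_out) == 1 - 1 / exp(pre_out)
assert np.allclose(sigmoid_z, 1.0 - 1.0 / np.exp(pre_out))
```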
......@@ -18,32 +18,6 @@ namespace paddle {
namespace operators {
namespace math {
/**
* CodeTable class should support 3 functions:
*
* size_t size()
* return the number of ids
*
* int getMaxCodeLength()
* return the maximal code length
*
* Code operator()(size_t i)
* return the i-th code. Code class is descriebed below.
*
* Code class should support 3 functions:
*
* int getLength()
* return the length of the code
*
* bool calcIndex(int bit)
* bit ranges from 0 to getLength() - 1
* return the index for the (1+bit) level parent
*
* bool calcBit(int bit)
* return true if the bit level parent is the right child of (1+bit) level
* parent
*
*/
template <typename T>
void MatrixBitCodeFunctor<T>::Add(framework::Tensor* tmat,
const framework::Tensor& vec) {
......@@ -192,17 +166,6 @@ void MatrixBitCodeFunctor<T>::Sub(framework::Tensor* tmat) {
}
}
template <typename T>
void MatrixBitCodeFunctor<T>::OutGrad(framework::Tensor* tmat,
const framework::Tensor& input) {
size_t num_samples = tmat->dims()[0];
size_t code_length = tmat->dims()[1];
for (size_t i = 0; i < num_samples; ++i)
for (size_t j = 0; j < code_length; ++j) {
tmat->data<T>()[i * code_length + j] = input.data<T>()[i];
}
}
template class MatrixBitCodeFunctor<float>;
template class MatrixBitCodeFunctor<double>;
......
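The OutGrad helper removed above copied each sample's scalar Out gradient across all code_length columns of the target tensor; the Eigen broadcast added in the gradient kernel (see the hierarchical_sigmoid_op.h hunk earlier) does the same thing without a hand-written loop. A NumPy sketch of the equivalence (shapes are illustrative):

```python
import numpy as np

batch_size, code_length = 4, 3
out_grad = np.random.randn(batch_size, 1)          # d(loss)/d(Out), one value per sample

# What the removed OutGrad did: tmat(i, j) = input(i)
pre_out_grad_loop = np.empty((batch_size, code_length))
for i in range(batch_size):
    for j in range(code_length):
        pre_out_grad_loop[i, j] = out_grad[i, 0]

# The broadcast replacement
pre_out_grad_bcast = np.broadcast_to(out_grad, (batch_size, code_length))
assert np.allclose(pre_out_grad_loop, pre_out_grad_bcast)
```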
......@@ -20,13 +20,39 @@ limitations under the License. */
namespace paddle {
namespace operators {
namespace math {
/**
* SimpleCodeTable class should support 3 functions:
*
* size_t size()
* return the number of ids
*
* int get_max_code_length()
* return the maximal code length
*
* SimpleCode operator()(size_t i)
* return the i-th code. Code class is described below.
*
* SimpleCode class should support 3 functions:
*
* int get_length()
* return the length of the code
*
* size_t cal_index(int bit)
* bit ranges from 0 to get_length() - 1
* return the index for the (1+bit) level parent
*
* bool calc_bit(int bit)
* return true if the bit level parent is the right child of (1+bit) level
* parent
*
*/
/**
* return the 1-based index of the highest bit set
*
* for x > 0:
* \f[
* findLastSet(x) = 1 + \floor*{\log_{2}x}
* FindLastSet(x) = 1 + \floor*{\log_{2}x}
* \f]
*/
inline constexpr size_t FindLastSet(size_t x) {
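For reference, FindLastSet as documented above (the 1-based index of the highest set bit, i.e. 1 + floor(log2 x) for x > 0) corresponds directly to Python's int.bit_length; a tiny illustrative check:

```python
def find_last_set(x):
    # 1-based index of the highest set bit; equals 1 + floor(log2(x)) for x > 0
    return x.bit_length()

assert find_last_set(1) == 1   # 0b1
assert find_last_set(5) == 3   # 0b101
assert find_last_set(8) == 4   # 0b1000
```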
......@@ -100,10 +126,6 @@ class MatrixBitCodeFunctor {
*/
void MulGradError(const framework::Tensor& tmat,
const framework::Tensor& weight, framework::Tensor* input);
/* For j < code_length
tmat(i, j) == input(i)
*/
void OutGrad(framework::Tensor* tmat, const framework::Tensor& input);
size_t num_classes_;
const int64_t* ids_;
......
......@@ -3571,18 +3571,17 @@ def hsigmoid(input, label, num_classes=2, param_attr=None, bias_attr=None):
shape=[num_classes - 1, dim],
is_bias=False,
dtype=input.dtype)
bias = helper.create_parameter(
attr=helper.bias_attr,
shape=[1, num_classes - 1],
is_bias=True,
dtype=input.dtype)
inputs = {"X": input, "W": weights, "Label": label}
if helper.bias_attr:
bias = helper.create_parameter(
attr=helper.bias_attr,
shape=[1, num_classes - 1],
is_bias=True,
dtype=input.dtype)
inputs['Bias'] = bias
helper.append_op(
type="hierarchical_sigmoid",
inputs={"X": input,
"W": weights,
"Ids": label,
"Bias": bias},
inputs=inputs,
outputs={"Out": out,
"PreOut": pre_out},
attrs={"num_classes": num_classes})
......
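With this change the Python wrapper passes the label tensor to the op as "Label" and attaches the Bias input only when helper.bias_attr evaluates truthy. A usage sketch mirroring the updated test in test_layers.py (assumes a fluid build that contains this commit; shapes and names are illustrative):

```python
import paddle.fluid.layers as layers
from paddle.fluid.framework import Program, program_guard

program = Program()
with program_guard(program):
    x = layers.data(name='x', shape=[2], dtype='float32')    # per-sample features
    y = layers.data(name='y', shape=[2], dtype='int64')      # labels
    out = layers.hsigmoid(input=x, label=y, num_classes=2)   # Bias is added only when bias_attr is truthy
```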
......@@ -36,7 +36,7 @@ class CodeTable(object):
return self.c & (1 << bit)
def hsigmoid(x, w, ids, bias, num_classes):
def hsigmoid(x, w, label, bias, num_classes):
global pre_output
batch_size = x.shape[0]
code_length = find_latest_set(num_classes - 1)
......@@ -45,13 +45,13 @@ def hsigmoid(x, w, ids, bias, num_classes):
pre_sum = np.zeros((batch_size, 1))
out = np.zeros((batch_size, 1)).astype("float32")
for i in range(batch_size):
code_table = CodeTable(num_classes, ids[i])
code_table = CodeTable(num_classes, label[i])
length = code_table.get_length()
for j in range(length):
idx = code_table.cal_index(j)
pre_output[i][j] += bias[0][idx]
for j in range(batch_size):
code_table = CodeTable(num_classes, ids[j])
code_table = CodeTable(num_classes, label[j])
length = code_table.get_length()
for k in range(length):
idx = code_table.cal_index(k)
......@@ -60,10 +60,10 @@ def hsigmoid(x, w, ids, bias, num_classes):
sum += w[idx][l] * x[j][l]
pre_output[j][k] += sum
# clip[-40.0, 40.0]
np.clip(pre_output, -40.0, 40.0)
pre_output = np.clip(pre_output, -40.0, 40.0)
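# np.clip returns a new array rather than clipping in place, so the result must be assigned back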
# out(i, 0) = \sum_j bit(i, j) * preout(i, j)
for i in range(batch_size):
code_table = CodeTable(num_classes, ids[i])
code_table = CodeTable(num_classes, label[i])
length = code_table.get_length()
sum = 0.0
for j in range(length):
......@@ -86,18 +86,18 @@ class TestHSigmoidOp(OpTest):
batch_size = 1
x = np.random.random((batch_size, embded_size)).astype("float32")
w = np.random.random((num_classes - 1, embded_size)).astype("float32")
ids = np.random.randint(0, num_classes, batch_size)
label = np.random.randint(0, num_classes, batch_size)
bias = np.random.random((1, num_classes - 1)).astype("float32")
self.attrs = {'num_classes': num_classes}
self.inputs = {'X': x, 'W': w, 'Ids': ids, 'Bias': bias}
out = hsigmoid(x, w, ids, bias, num_classes)
self.inputs = {'X': x, 'W': w, 'Label': label, 'Bias': bias}
out = hsigmoid(x, w, label, bias, num_classes)
self.outputs = {'PreOut': pre_output, 'Out': out}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['Bias', 'X', 'W'], 'Out', no_grad_set=set('Ids'))
self.check_grad(['Bias', 'X', 'W'], 'Out', no_grad_set=set('Label'))
if __name__ == '__main__':
......
......@@ -176,8 +176,8 @@ class TestBook(unittest.TestCase):
def test_hsigmoid(self):
program = Program()
with program_guard(program):
x = layers.data(name='x', shape=[2, 2], dtype='float32')
y = layers.data(name='y', shape=[1, 2], dtype='int64')
x = layers.data(name='x', shape=[2], dtype='float32')
y = layers.data(name='y', shape=[2], dtype='int64')
self.assertIsNotNone(
layers.hsigmoid(
input=x, label=y, num_classes=2))
......