提交 695b1037 编写于 作者: Y Yancey1989

port hsigmoid layer

上级 9f289256
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "hierarchical_sigmoid_op.h"
namespace paddle {
namespace operators {
/**
* Organize the classes into a binary tree. At each node, a sigmoid function
* is used to calculate the probability of belonging to the right branch.
* This idea is from "F. Morin, Y. Bengio (AISTATS 05):
* Hierarchical Probabilistic Neural Network Language Model."
*
* Here we uses a simple way of making the binary tree.
* Assuming the number of classes C = 6,
* The classes are organized as a binary tree in the following way:
*
* @code{.py}
* *-*-*- 2
* | | |- 3
* | |
* | |-*- 4
* | |- 5
* |
* |-*- 0
* |- 1
* @endcode
*
* where * indicates an internal node, and each leaf node represents a class.
* - Node 0 ... C-2 are internal nodes.
* - Node C-1 ... 2C-2 are leaf nodes.
* - Class c is represented by leaf node \f$c+C-1\f$.
*
* We assign an id for each node:
* - the id of root be 0.
* - the left child of a node i is 2*i+1.
* - the right child of a node i is 2*i+2.
*
* It's easy to see that:
* - the parent of node i is \f$\left\lfloor(i-1)/2\right\rfloor\f$.
* - the j-th level ancestor of node i is
* \f$\left\lfloor(i+1)/2^{j+1}\right\rfloor - 1\f$.
* - A node i is a left child of its parent if \f$(i-1)\%2==0\f$.
*
*/
class HierarchicalSigmoidOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInputs("X"), "Inputs(X) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null.");
const int64_t batch_size = ctx->GetInputsDim("X")[0][0];
const int64_t size = ctx->GetInputsDim("X").size();
std::vector<int64_t> output_shape({batch_size, size});
ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
}
};
class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {}
};
class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
public:
HierarchicalSigmoidOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"(TensorArray, required) The input array. Each Tensor has the "
"same shape with [N * D]."
.AsDuplicable();
AddInput("Label",
"(Tensor, required), The labels of training data. It's a"
"1-D tensor.");
AddInput("Bias",
"(Tensor, optional), The bias is a 1-D tensor, "
"which is applied to the output");
AddOutput("Out",
"(Tensor, required) The output of hierarchical sigmoid operator.");
AddAttr<int>("num_classes",
"(int, required)",
"The number of classes");
AddComment(R"DOC(
The hierarchical sigmoid operator organize the classes into a binary tree.
At each node, a sigmoid function is used to caculate the probability of
belonging to the right branch. This idea is from
"F. Morin, Y. Bengio (AISTATS 05):
Hierarchical Probabilistic Neural Network Language Model."
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(hierarchical_sigmoid, ops::HierarchicalSigmoidOp,
ops::HierarchicalSigmoidOpMaker, hierarchical_sigmoid_grad,
ops::HierarchicalSigmoidGradOp);
REGISTER_OP_CPU_KERNEL(
hierarchical_sigmoid,
ops::HierarchicalSigmoidOpKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
hierarchical_sigmoid_grad,
ops::HierarchicalSigmoidGradOpKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/matrix_bit_code.h"
namespace paddle {
namespace operators {
template <typename Place, typename T>
class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {}
};
template <typename Place, typename T>
class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {}
};
} // namespace operators
} // namespace paddle
...@@ -26,6 +26,7 @@ else() ...@@ -26,6 +26,7 @@ else()
cc_library(sequence2batch SRCS sequence2batch.cc DEPS device_context) cc_library(sequence2batch SRCS sequence2batch.cc DEPS device_context)
cc_library(lstm_compute SRCS lstm_compute.cc DEPS device_context activation_functions) cc_library(lstm_compute SRCS lstm_compute.cc DEPS device_context activation_functions)
cc_library(gru_compute SRCS gru_compute.cc DEPS device_context activation_functions math_function) cc_library(gru_compute SRCS gru_compute.cc DEPS device_context activation_functions math_function)
cc_library(matrix_bit_code SRCS matrix_bit_code.cc)
endif() endif()
cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
......
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "matrix_bit_code.h"
namespace paddle {
namespace operators {
namespace math {
/**
* CodeTable class should support 3 functions:
*
* size_t size()
* return the number of codes
*
* int getMaxCodeLength()
* return the maximal code length
*
* Code operator()(size_t i)
* return the i-th code. Code class is descriebed below.
*
* Code class should support 3 functions:
*
* int getLength()
* return the length of the code
*
* bool calcIndex(int bit)
* bit ranges from 0 to getLength() - 1
* return the index for the (1+bit) level parent
*
* bool calcBit(int bit)
* return true if the bit level parent is the right child of (1+bit) level
* parent
*
*/
/*
for i:
for j < codeLength:
op(a(i, j), b(0, index(i, j)))
*/
template <class CodeTable, class Op, typename T, typename Place>
static void AddByBitCodeT(Op op, CodeTable code_table,
const framework::Tensor& codes, framework::Tensor& a,
framework::Tensor& b) {
size_t num_classes = code_table.size();
size_t max_code_length = code_table.get_max_code_length();
size_t num_sample = a.dims()[0].size();
size_t width = a.dims()[1].size();
for (size_t i = 0; i < num_sample; ++i) {
auto code = code_table(codes.data<T>()[i]) int code_length =
code.get_length();
for (int j = 0; j < code_length; + j) {
size_t index = code.calc_index(j);
op(a<T>.data()[i * width + j], b<T>.data()[index]);
}
}
}
/* For j < codeLength:
a(i, j) += b(0, index(i, j))
*/
template <typename T, typename Place>
void AddByBitCode(size_t num_classes, const framework::Tensor& codes,
framework::Tensor& a, const framework::Tensor& b) {
auto op = [](T& t, T& v) { t += v; };
AddByBitCodeT<T, Place>(op, SimpleCodeTable(num_classes), codes, a, b);
}
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/tensor.h"
namespace paddle {
namespace operators {
namespace math {
/**
* return the 1-based index of the highest bit set
*
* for x > 0:
* \f[
* findLastSet(x) = 1 + \floor*{\log_{2}x}
* \f]
*/
inline constexpr size_t FindLastSet(size_t x) {
return std::is_same<size_t, unsigned int>::value
? (x ? 8 * sizeof(x) - __builtin_clz(x) : 0)
: (std::is_same<size_t, unsigned long>::value // NOLINT
? (x ? 8 * sizeof(x) - __builtin_clzl(x) : 0)
: (x ? 8 * sizeof(x) - __builtin_clzll(x) : 0));
}
struct SimpleCode {
SimpleCode(size_t code, size_t num_classes) : c_(code + num_classes) {}
inline size_t calc_index(int bit) const { return (c_ >> (bit + 1)) - 1; }
inline bool calc_bit(int bit) const { return c_ & (1 << bit); }
inline int get_length() const { return FindLastSet(c_) - 1; }
private:
size_t c_;
};
struct SimpleCodeTable {
explicit SimpleCodeTable(size_t num_classes) : num_classes_(num_classes) {}
SimpleCode operator()(size_t code) const {
return SimpleCode(code, num_classes_);
}
size_t size() const { return num_classes_; }
int get_max_code_length() const { return FindLastSet(num_classes_ - 1); }
private:
size_t num_classes_;
int max_code_length_;
};
} // namespace math
} // namespace operators
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册