Commit a25c3aeb — authored by Yancey1989

add forward

Parent: 695b1037
......@@ -83,19 +83,18 @@ class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"(TensorArray, required) The input array. Each Tensor has the "
"same shape with [N * D]."
.AsDuplicable();
"same shape with [N * D].")
.AsDuplicable();
AddInput("Label",
"(Tensor, required), The labels of training data. It's a"
"1-D tensor.");
AddInput("Bias",
"(Tensor, optional), The bias is a 1-D tensor, "
"which is applied to the output");
AddOutput("Out",
"(Tensor, required) The output of hierarchical sigmoid operator.");
AddAttr<int>("num_classes",
"(int, required)",
"The number of classes");
AddOutput(
"Out",
"(Tensor, required) The output of hierarchical sigmoid operator.");
AddAttr<int>("num_classes", "(int, required)", "The number of classes");
AddComment(R"DOC(
The hierarchical sigmoid operator organize the classes into a binary tree.
At each node, a sigmoid function is used to caculate the probability of
......
......@@ -22,7 +22,21 @@ template <typename Place, typename T>
template <typename Place, typename T>
class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
 public:
  // Forward pass of the hierarchical sigmoid operator.
  //
  // Reads the duplicable input array "X", the "Label" tensor and the
  // optional "Bias" tensor from the execution context, then allocates an
  // intermediate `pre_out` tensor of shape [batch_size, ins.size()]
  // (batch_size taken from the first input's leading dimension). When a
  // bias is provided, bias values selected by each label's bit code are
  // accumulated into `pre_out` via math::AddByBitCode.
  //
  // NOTE(review): the merged diff contained a second, empty Compute()
  // definition (pre-change version); it is removed here since two
  // definitions of the same member cannot coexist.
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto ins = ctx.MultiInput<framework::Tensor>("X");
    auto* label = ctx.Input<framework::Tensor>("Label");
    auto* bias = ctx.Input<framework::Tensor>("Bias");
    size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes"));

    int64_t batch_size = ins[0]->dims()[0];
    int64_t size = ins.size();
    framework::Tensor pre_out;
    std::vector<int64_t> pre_out_dims({batch_size, size});
    pre_out.mutable_data<T>(framework::make_ddim(pre_out_dims), ctx.GetPlace());

    // "Bias" is an optional input, so a null pointer is a legal state;
    // prefer nullptr over NULL in C++.
    if (bias != nullptr) {
      math::AddByBitCode<T>(num_classes, *label, pre_out, *bias);
    }
  }
};
template <typename Place, typename T>
......
......@@ -50,7 +50,7 @@ namespace math {
for j < codeLength:
op(a(i, j), b(0, index(i, j)))
*/
template <class CodeTable, class Op, typename T, typename Place>
template <class CodeTable, class Op, typename T>
static void AddByBitCodeT(Op op, CodeTable code_table,
const framework::Tensor& codes, framework::Tensor& a,
framework::Tensor& b) {
......@@ -72,11 +72,11 @@ static void AddByBitCodeT(Op op, CodeTable code_table,
/* For j < codeLength:
a(i, j) += b(0, index(i, j))
*/
template <typename T, typename Place>
template <typename T>
void AddByBitCode(size_t num_classes, const framework::Tensor& codes,
framework::Tensor& a, const framework::Tensor& b) {
auto op = [](T& t, T& v) { t += v; };
AddByBitCodeT<T, Place>(op, SimpleCodeTable(num_classes), codes, a, b);
AddByBitCodeT<T>(op, SimpleCodeTable(num_classes), codes, a, b);
}
} // namespace math
......
......@@ -59,6 +59,10 @@ struct SimpleCodeTable {
int max_code_length_;
};
/* For each row i and code bit j: a(i, j) += b(0, index(i, j)).
   `num_classes` sizes the SimpleCodeTable used to decode `codes`;
   the definition lives in the corresponding .cc/.h implementation. */
template <typename T>
void AddByBitCode(size_t num_classes, const framework::Tensor& codes,
                  framework::Tensor& a, const framework::Tensor& b);
} // namespace math
} // namespace operators
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册