/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <iostream>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/clip_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/matrix_bit_code.h"
#include "paddle/fluid/platform/transform.h"
namespace paddle {
namespace operators {

template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
using platform::Transform;
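
// Hierarchical sigmoid (hsigmoid) treats the num_classes output classes as the
// leaves of a complete binary tree. Every class is addressed by a binary code
// of at most FindLastSet(num_classes - 1) bits, and each row of W (and each
// entry of Bias) belongs to one non-leaf node of that tree. A sample's loss is
// the sum of binary cross entropies of the nodes on its label's path, so the
// cost per sample is O(log(num_classes)) rather than O(num_classes).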

template <typename DeviceContext, typename T>
class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in = ctx.Input<framework::Tensor>("X");
    auto* w = ctx.Input<framework::Tensor>("W");
    auto* label = ctx.Input<framework::Tensor>("Label");
    auto* bias = ctx.Input<framework::Tensor>("Bias");
    auto* out = ctx.Output<framework::Tensor>("Out");
    auto* pre_out = ctx.Output<framework::Tensor>("PreOut");
    size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes"));
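    // FindLastSet(num_classes - 1) equals ceil(log2(num_classes)): the depth
    // of the class tree and hence the maximum code (path) length of any label.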
    int64_t code_length = math::FindLastSet(num_classes - 1);
    int64_t batch_size = in->dims()[0];
    framework::Tensor sum;
    math::SetConstant<DeviceContext, T> zero;
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    auto* pre_out_data = pre_out->mutable_data<T>(
        framework::make_ddim({batch_size, code_length}), ctx.GetPlace());
    auto pre_out_mat = EigenMatrix<T>::From(*pre_out);
    // Not every class (leaf) node's path is as long as code_length, so
    // initializing pre_out with zeros keeps positions beyond a sample's path
    // from contributing to the loss.
    zero(dev_ctx, pre_out, static_cast<T>(0.0));
    auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
    math::RowwiseSum<DeviceContext, T> row_sum;
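    // bit_code maps each label to the node indices and branch bits along its
    // path in the class tree; all path-restricted operations below use it.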
    math::MatrixBitCodeFunctor<T> bit_code(num_classes, label->data<int64_t>());

    std::vector<int64_t> sum_dims({batch_size, 1UL});
    sum.mutable_data<T>(framework::make_ddim(sum_dims), ctx.GetPlace());
    auto sum_mat = EigenMatrix<T>::From(sum);
    out->mutable_data<T>(ctx.GetPlace());
    auto out_mat = framework::EigenVector<T>::Flatten(*out);
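    // For the j-th node on sample i's path, accumulate the logit
    // pre_out(i, j) = Bias(node_j) + dot(W(node_j, :), X(i, :)).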
    if (bias) {
      bit_code.Add(pre_out, *bias);
    }
    bit_code.Mul(pre_out, *w, *in);
    // Clip pre_out to [-40, 40] so the exp() below stays in a safe range.
    Transform<DeviceContext> trans;
    trans(ctx.template device_context<DeviceContext>(), pre_out_data,
          pre_out_data + pre_out->numel(), pre_out_data,
          ClipFunctor<T>(static_cast<T>(-40.0), static_cast<T>(40.0)));
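    // out(i) starts as -sum of pre_out(i, j) over the path positions whose
    // branch bit is 1; adding the row-wise softrelu sum below gives
    // out(i) = sum_j [log(1 + exp(pre_out_ij)) - bit_ij * pre_out_ij].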
    bit_code.Sum(*pre_out, out, static_cast<T>(-1));
    // Use softrelu, log(1 + exp(x)), to compute each path node's cross-entropy
    // term.
    pre_out_mat.device(place) = (static_cast<T>(1.0) + pre_out_mat.exp()).log();
    row_sum(dev_ctx, *pre_out, &sum);
    out_mat.device(place) = sum_mat + out_mat;
  }
};

template <typename DeviceContext, typename T>
class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in = ctx.Input<framework::Tensor>("X");
    auto* w = ctx.Input<framework::Tensor>("W");
    auto* in_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
    auto* w_grad = ctx.Output<framework::Tensor>(framework::GradVarName("W"));
    auto* bias_grad =
        ctx.Output<framework::Tensor>(framework::GradVarName("Bias"));
    auto* label = ctx.Input<framework::Tensor>("Label");
    auto* pre_out = ctx.Input<framework::Tensor>("PreOut");
    auto* out_grad =
        ctx.Input<framework::Tensor>(framework::GradVarName("Out"));

    size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes"));
    int64_t code_length = math::FindLastSet(num_classes - 1);
    int64_t batch_size = in->dims()[0];
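    // pre_out_grad will hold dLoss/dPreOut for every node on a sample's path.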
    framework::Tensor pre_out_grad;
    pre_out_grad.mutable_data<T>(
        framework::make_ddim({batch_size, code_length}), ctx.GetPlace());
    auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
    auto pre_out_mat = EigenMatrix<T>::From(*pre_out);
    auto pre_out_grad_mat = EigenMatrix<T>::From(pre_out_grad);
    math::MatrixBitCodeFunctor<T> bit_code(num_classes, label->data<int64_t>());
    Eigen::array<int, 2> bcast({{1, static_cast<int>(pre_out_grad.dims()[1])}});
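    // Broadcast dLoss/dOut (batch_size x 1) across all code_length path
    // positions of each sample.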
    auto out_grad_mat = EigenMatrix<T>::From(*out_grad);
    pre_out_grad_mat = out_grad_mat.broadcast(bcast);
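    // The forward pass stored softrelu(x) in PreOut, so 1 - 1 / exp(PreOut)
    // equals sigmoid(x), the derivative of softrelu at the original logit.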
    pre_out_grad_mat.device(place) =
        pre_out_grad_mat *
        (static_cast<T>(1.0) -
         static_cast<T>(1.0) / pre_out_mat.exp());  // softrelu derivative
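    // Subtract the branch bit (1 where the path takes the "positive" child) so
    // that dLoss/dPreOut = sigmoid(logit) - bit at every path position.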
    bit_code.Sub(&pre_out_grad);
    // TODO(guosheng): multiply pre_out_grad with subgradient of clipping to
    // be consistent with the clipping in forward.
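    // dLoss/dBias accumulates pre_out_grad over each sample's path nodes.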
    if (bias_grad) {
      bias_grad->mutable_data<T>(ctx.GetPlace());
      bit_code.AddGrad(pre_out_grad, bias_grad);
    }
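    // Back-propagate through pre_out = X * W_path^T: only rows of W that lie
    // on some sample's path receive gradient, and dLoss/dX follows per sample.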
    in_grad->mutable_data<T>(ctx.GetPlace());
    w_grad->mutable_data<T>(ctx.GetPlace());
    bit_code.MulGradWeight(pre_out_grad, w_grad, *in);
    bit_code.MulGradError(pre_out_grad, *w, in_grad);
  }
};

}  // namespace operators
}  // namespace paddle