hierarchical_sigmoid_op.h
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <iostream>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/clip_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/matrix_bit_code.h"
#include "paddle/fluid/platform/transform.h"
namespace paddle {
namespace operators {

template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
using platform::Transform;

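// Forward kernel of hierarchical sigmoid. Each sample's label is encoded as a
// path of binary decisions over num_classes leaves. PreOut holds the clipped
// per-node pre-activations z_j = clip(W.row(node_j) . x + bias_j) along that
// path, and Out holds the path's summed binary cross entropy,
// sum_j (softrelu(z_j) - code_j * z_j).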
template <typename DeviceContext, typename T>
class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in = ctx.Input<framework::Tensor>("X");
    auto* w = ctx.Input<framework::Tensor>("W");
    auto* label = ctx.Input<framework::Tensor>("Label");
    auto* bias = ctx.Input<framework::Tensor>("Bias");
    auto* out = ctx.Output<framework::Tensor>("Out");
    auto* pre_out = ctx.Output<framework::Tensor>("PreOut");
    size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes"));
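    // code_length is the maximum code (path) length: the number of bits
    // needed to represent num_classes - 1 in the binary-tree coding.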
    int64_t code_length = math::FindLastSet(num_classes - 1);
    int64_t batch_size = in->dims()[0];
    framework::Tensor sum;
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    auto* pre_out_data = pre_out->mutable_data<T>(
        framework::make_ddim({batch_size, code_length}), ctx.GetPlace());
    auto pre_out_mat = EigenMatrix<T>::From(*pre_out);
    // Not all class (leaf) nodes' path lengths equal code_length, so pre_out
    // is initialized to zeros to keep positions beyond a sample's path from
    // contributing to the loss.
    math::SetConstant<DeviceContext, T> zero;
    zero(dev_ctx, pre_out, static_cast<T>(0.0));
    auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
    math::RowwiseSum<DeviceContext, T> row_sum;
    math::MatrixBitCodeFunctor<T> bit_code(num_classes, label->data<int64_t>());

    std::vector<int64_t> sum_dims({batch_size, 1UL});
    sum.mutable_data<T>(framework::make_ddim(sum_dims), ctx.GetPlace());
    auto sum_mat = EigenMatrix<T>::From(sum);
    out->mutable_data<T>(ctx.GetPlace());
    auto out_mat = framework::EigenVector<T>::Flatten(*out);
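    // If a bias is given, add each path node's bias into the corresponding
    // column of pre_out.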
    if (bias) {
      bit_code.Add(pre_out, *bias);
    }
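    // For node j on sample i's code path, accumulate W.row(node_j) . x_i into
    // pre_out(i, j).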
    bit_code.Mul(pre_out, *w, *in);
    // Clip the pre-activations to [-40, 40] for numerical stability.
    Transform<DeviceContext> trans;
    trans(ctx.template device_context<DeviceContext>(), pre_out_data,
          pre_out_data + pre_out->numel(), pre_out_data,
          ClipFunctor<T>(static_cast<T>(-40.0), static_cast<T>(40.0)));
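    // Write -sum_j code_j * z_j for each sample into Out; the softrelu terms
    // are added below via row_sum.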
    bit_code.Sum(*pre_out, out, static_cast<T>(-1));
    // Use softrelu, log(1 + exp(z)), to compute the cross-entropy terms.
    pre_out_mat.device(place) = (static_cast<T>(1.0) + pre_out_mat.exp()).log();
    row_sum(dev_ctx, *pre_out, &sum);
    // TODO(guosheng): Subtract the out-of-path loss, since not all class
    // (leaf) nodes' path lengths equal code_length. It does not break the
    // gradient check, because both sides of the check include the out-of-path
    // loss and it cancels out.
    out_mat.device(place) = sum_mat + out_mat;
  }
};

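// Backward kernel of hierarchical sigmoid. PreOut from the forward pass stores
// softrelu(z_j), so the gradient of the per-node loss w.r.t. z_j is
// sigmoid(z_j) - code_j; it is scaled by the gradient of Out and propagated
// back to X, W and, if present, Bias along each sample's code path.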
template <typename DeviceContext, typename T>
class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in = ctx.Input<framework::Tensor>("X");
    auto* w = ctx.Input<framework::Tensor>("W");
    auto* in_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
    auto* w_grad = ctx.Output<framework::Tensor>(framework::GradVarName("W"));
    auto* bias_grad =
        ctx.Output<framework::Tensor>(framework::GradVarName("Bias"));
    auto* label = ctx.Input<framework::Tensor>("Label");
    auto* pre_out = ctx.Input<framework::Tensor>("PreOut");
    auto* out_grad =
        ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
    framework::Tensor pre_out_grad;

    pre_out_grad.mutable_data<T>(pre_out->dims(), ctx.GetPlace());
    in_grad->mutable_data<T>(ctx.GetPlace());
    w_grad->mutable_data<T>(ctx.GetPlace());
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    math::SetConstant<DeviceContext, T> zero;
    zero(dev_ctx, in_grad, static_cast<T>(0.0));
    zero(dev_ctx, w_grad, static_cast<T>(0.0));

    size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes"));
    math::MatrixBitCodeFunctor<T> bit_code(num_classes, label->data<int64_t>());

    auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
    auto pre_out_mat = EigenMatrix<T>::From(*pre_out);
    auto pre_out_grad_mat = EigenMatrix<T>::From(pre_out_grad);
    auto out_grad_mat = EigenMatrix<T>::From(*out_grad);
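    // Out has shape [batch_size, 1]; broadcast its gradient across the
    // code_length columns of pre_out_grad.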
    Eigen::array<int, 2> bcast({{1, static_cast<int>(pre_out_grad.dims()[1])}});

    // Derivative of softrelu: pre_out stores softrelu(z) = log(1 + exp(z))
    // from the forward pass, so 1 - 1 / exp(pre_out) equals sigmoid(z),
    // i.e. d softrelu(z) / d z.
    pre_out_grad_mat.device(place) =
        static_cast<T>(1.0) - static_cast<T>(1.0) / pre_out_mat.exp();
    // Subtracting the code bits yields sigmoid(z) - code_bit, the gradient of
    // the per-node loss w.r.t. z = clip(W * x + b).
    bit_code.Sub(&pre_out_grad);
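    // Chain rule: scale the per-node gradients by the incoming gradient of Out.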
    pre_out_grad_mat.device(place) =
        pre_out_grad_mat * out_grad_mat.broadcast(bcast);
    // TODO(guosheng): Multiply pre_out_grad by the subgradient of the clipping
    // to be consistent with the clipping in the forward pass.
    if (bias_grad) {
      bias_grad->mutable_data<T>(ctx.GetPlace());
      zero(dev_ctx, bias_grad, static_cast<T>(0.0));
      bit_code.AddGrad(pre_out_grad, bias_grad);
    }
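    // Accumulate the path-node gradients into the corresponding rows of W's
    // gradient and into the gradient of X.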
    bit_code.MulGradWeight(pre_out_grad, w_grad, *in);
    bit_code.MulGradError(pre_out_grad, *w, in_grad);
  }
};

}  // namespace operators
}  // namespace paddle