/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <iostream>
#include <iterator>
#include <memory>
#include <set>
#include <string>
#include <vector>

#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/clip_op.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/matrix_bit_code.h"
#include "paddle/fluid/platform/transform.h"

#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/fluid/operators/distributed/parameter_prefetch.h"
#endif

namespace paddle {
namespace operators {

template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
using platform::Transform;

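// Collect the distinct non-negative node ids referenced by a custom path
// table; negative ids (padding for samples with shorter paths) are skipped.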
static std::vector<int64_t> PathToRows(const framework::LoDTensor& path) {
  std::set<int64_t> rows;
  const int64_t* paths = path.data<int64_t>();
  for (int64_t i = 0; i < path.numel(); ++i) {
    int64_t row = paths[i];
    if (row < 0) {
      continue;
    }
    rows.emplace(row);
  }
  return std::vector<int64_t>(rows.begin(), rows.end());
}
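// Forward kernel of hierarchical sigmoid. For every sample i and every node j
// on its path through the tree:
//   pre_out[i][j] = clip(W[node(i, j)] . X[i] + Bias[node(i, j)], -40, 40)
//   Out[i] = sum_j ( log(1 + exp(pre_out[i][j])) - code[i][j] * pre_out[i][j] )
// i.e. the sum of binary logistic losses along the path (see the TODOs below
// about the extra contribution of out-of-path positions).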
template <typename DeviceContext, typename T>
class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto& in = detail::Ref(ctx.Input<framework::LoDTensor>("X"));
    auto& w = detail::Ref(ctx.Input<framework::LoDTensor>("W"));
    auto* path = ctx.Input<framework::LoDTensor>("PathTable");
    auto* code = ctx.Input<framework::LoDTensor>("PathCode");
    auto& label = detail::Ref(ctx.Input<framework::LoDTensor>("Label"));
    auto* bias = ctx.Input<framework::LoDTensor>("Bias");
    auto* out = ctx.Output<framework::LoDTensor>("Out");
    auto* pre_out = ctx.Output<framework::LoDTensor>("PreOut");
    size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes"));
    // for remote prefetch

    auto epmap = ctx.Attr<std::vector<std::string>>("epmap");
    if (!epmap.empty()) {
      // If epmap is not empty, the parameter is fetched from the remote
      // parameter server.
      auto height_sections = ctx.Attr<std::vector<int64_t>>("height_sections");
      auto table_names = ctx.Attr<std::vector<std::string>>("table_names");
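      // Only the weight rows that appear in this batch's path table are
      // needed, so prefetch just those rows from the parameter server.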
      std::vector<int64_t> real_rows = PathToRows(*path);
      framework::Scope& local_scope = ctx.scope().NewScope();
      auto* ids = local_scope.Var("Ids@Prefetch");
      auto* x_tensor = ids->GetMutable<framework::LoDTensor>();

      x_tensor->mutable_data<int64_t>(
          framework::make_ddim({static_cast<int64_t>(real_rows.size()), 1}),
          ctx.GetPlace());
      // Copy the selected row ids into the prefetch Ids tensor.
      std::memcpy(x_tensor->data<int64_t>(), real_rows.data(),
                  real_rows.size() * sizeof(int64_t));

      // Build a placeholder LoDTensor that will receive the prefetched rows of
      // W, one row per distinct node id in real_rows.
      framework::DDim w_dims = ctx.Input<framework::LoDTensor>("W")->dims();
      w_dims[0] = x_tensor->dims()[0];
      auto* w_tensor =
          local_scope.Var("W@Prefetch")->GetMutable<framework::LoDTensor>();
      w_tensor->Resize(w_dims);

#ifdef PADDLE_WITH_DISTRIBUTE
      // W_Out is only used by prefetch; do not modify it in other cases.
      auto* w_out = ctx.Output<framework::LoDTensor>("W_Out");
      operators::distributed::prefetch_with_reconstruct<T>(
          "Ids@Prefetch", "W@Prefetch", table_names, epmap, height_sections,
          ctx, local_scope, w_out);
#else
      PADDLE_THROW(
          "paddle is not compiled with distribute support, can not do "
          "parameter prefetch!");
#endif
    }

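    // Two ways to describe the tree: the default mode derives each class's
    // binary code from num_classes (effectively a fixed binary tree over the
    // classes), while the custom mode reads the tree from the user-provided
    // PathTable and PathCode inputs.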
    bool is_custom = false;
    if (path) {
      is_custom = true;
    }
    int64_t code_length =
        path ? path->dims()[1] : math::FindLastSet(num_classes - 1);
    int64_t batch_size = in.dims()[0];
    framework::LoDTensor sum;
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    auto* pre_out_data = pre_out->mutable_data<T>(
        framework::make_ddim({batch_size, code_length}), ctx.GetPlace());
    auto pre_out_mat = EigenMatrix<T>::From(*pre_out);
    // Not all class (leaf) nodes' path lengths equal code_length, so
    // initializing pre_out with zeros keeps the out-of-path positions from
    // adding spurious loss.
    math::SetConstant<DeviceContext, T> zero;
    zero(dev_ctx, pre_out, static_cast<T>(0.0));
    auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
    math::RowwiseSum<DeviceContext, T> row_sum;

    std::unique_ptr<math::MatrixBitCodeFunctor<T>> bit_code;
    if (!is_custom) {
      bit_code.reset(new math::MatrixBitCodeFunctor<T>(num_classes,
                                                       label.data<int64_t>()));
    } else {
      bit_code.reset(new math::MatrixBitCodeFunctor<T>(*path, *code,
                                                       label.data<int64_t>()));
    }

    std::vector<int64_t> sum_dims({batch_size, 1UL});
    sum.mutable_data<T>(framework::make_ddim(sum_dims), ctx.GetPlace());
    auto sum_mat = EigenMatrix<T>::From(sum);
    out->mutable_data<T>(ctx.GetPlace());
    auto out_mat = framework::EigenMatrix<T>::From(*out);
    if (bias) {
      bit_code->Add(*bias, pre_out);
    }
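    // Accumulate W[node(i, j)] . X[i] into pre_out[i][j] for every node j on
    // sample i's path (the bias, if present, was already added above).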
    bit_code->Mul(pre_out, w, in);
    // clip to [-40, 40] to keep the following exp() numerically stable
    Transform<DeviceContext> trans;
    trans(ctx.template device_context<DeviceContext>(), pre_out_data,
          pre_out_data + pre_out->numel(), pre_out_data,
          ClipFunctor<T>(static_cast<T>(-40.0), static_cast<T>(40.0)));
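    // Accumulate -sum_j code[i][j] * pre_out[i][j] into Out; the softrelu part
    // of the loss is added after the row-wise sum below.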
    bit_code->Sum(*pre_out, out, static_cast<T>(-1));
    // use softrelu, i.e. log(1 + exp(x)), to calculate the cross entropy
    pre_out_mat.device(place) = (static_cast<T>(1.0) + pre_out_mat.exp()).log();
    row_sum(dev_ctx, *pre_out, &sum);
    // TODO(guosheng): Subtract the out-of-path loss, since not all class
    // (leaf) nodes' path lengths equal code_length. It does not break the
    // gradient check, because both sides include the same out-of-path loss and
    // it cancels out.
    out_mat.device(place) = sum_mat + out_mat;
  }
};

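// Gradient kernel of hierarchical sigmoid: computes the gradients of X, W and
// Bias from the saved PreOut and the incoming gradient of Out. When is_sparse
// is true, W's gradient is emitted as a SelectedRows holding only the rows
// (tree nodes) that appear on some sample's path.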
template <typename DeviceContext, typename T>
class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto& in = detail::Ref(ctx.Input<framework::LoDTensor>("X"));
    auto& w = detail::Ref(ctx.Input<framework::LoDTensor>("W"));
    auto* path = ctx.Input<framework::LoDTensor>("PathTable");
    auto* code = ctx.Input<framework::LoDTensor>("PathCode");
    auto* in_grad =
        ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
    bool is_sparse = ctx.Attr<bool>("is_sparse");
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    math::SetConstant<DeviceContext, T> zero;
    auto& label = detail::Ref(ctx.Input<framework::LoDTensor>("Label"));
    auto& pre_out = detail::Ref(ctx.Input<framework::LoDTensor>("PreOut"));
    auto& out_grad = detail::Ref(
        ctx.Input<framework::LoDTensor>(framework::GradVarName("Out")));
    framework::LoDTensor pre_out_grad;

    pre_out_grad.mutable_data<T>(pre_out.dims(), ctx.GetPlace());
    in_grad->mutable_data<T>(ctx.GetPlace());
    zero(dev_ctx, in_grad, static_cast<T>(0.0));

    size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes"));

    bool is_custom = false;
    if (path) {
      is_custom = true;
    }

    std::unique_ptr<math::MatrixBitCodeFunctor<T>> bit_code;
    if (!is_custom) {
      bit_code.reset(new math::MatrixBitCodeFunctor<T>(num_classes,
                                                       label.data<int64_t>()));
    } else {
      bit_code.reset(new math::MatrixBitCodeFunctor<T>(*path, *code,
                                                       label.data<int64_t>()));
    }

    // softrelu derivative: PreOut holds softrelu(z) = log(1 + exp(z)), so
    // 1 - 1 / exp(PreOut) = exp(z) / (1 + exp(z)) = sigmoid(z), the derivative
    // of softrelu with respect to the logit z.

    auto blas = math::GetBlas<DeviceContext, T>(ctx);

    auto* pre_out_grad_data = pre_out_grad.data<T>();
    auto* pre_out_data = pre_out.data<T>();
    auto n = pre_out.numel();
    blas.VEXP(n, pre_out_data, pre_out_grad_data);
    blas.VINV(n, pre_out_grad_data, pre_out_grad_data);
    for (int64_t i = 0; i < n; ++i) {
      pre_out_grad_data[i] = 1.0 - pre_out_grad_data[i];
    }
    bit_code->Sub(&pre_out_grad);
    // pre_out_grad now holds sigmoid(z) - code, the gradient of the loss
    // w.r.t. z = clip(w * x + b).
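    // Chain rule: scale every row of pre_out_grad by the gradient flowing in
    // from Out for the corresponding sample.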
    auto* out_grad_data = out_grad.data<T>();

    int64_t dim0 = pre_out_grad.dims()[0];
    int64_t dim1 = pre_out_grad.dims()[1];
    for (int64_t i = 0; i < dim0; ++i) {
      T tmp = out_grad_data[i];
      blas.SCAL(dim1, tmp, pre_out_grad_data + i * dim1);
    }
    // TODO(guosheng): multiply pre_out_grad with subgradient of clipping to
    // be consistent with the clipping in forward.
    auto* bias_grad =
        ctx.Output<framework::LoDTensor>(framework::GradVarName("Bias"));
    if (bias_grad) {
      bias_grad->mutable_data<T>(ctx.GetPlace());
      zero(dev_ctx, bias_grad, static_cast<T>(0.0));
      bit_code->AddGrad(pre_out_grad, bias_grad);
    }
    if (!is_sparse) {
      auto* w_grad =
          ctx.Output<framework::LoDTensor>(framework::GradVarName("W"));
      w_grad->mutable_data<T>(ctx.GetPlace());
      zero(dev_ctx, w_grad, static_cast<T>(0.0));
      bit_code->MulGradWeight(pre_out_grad, w_grad, in);
    } else {
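      // Sparse path: only the weight rows touched by some sample's path get a
      // gradient, so emit a SelectedRows that contains just those rows.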
      framework::Vector<int64_t> real_rows = PathToRows(*path);
      auto* w_grad =
          ctx.Output<framework::SelectedRows>(framework::GradVarName("W"));
      w_grad->set_rows(real_rows);
      // Build a map of id -> row_index to speed up finding the index of one id
      w_grad->set_height(w.dims()[0]);
      auto* w_grad_value = w_grad->mutable_value();
      framework::DDim temp_dim(w.dims());
      set(temp_dim, 0, real_rows.size());

      w_grad_value->mutable_data<T>(temp_dim, ctx.GetPlace());
      zero(dev_ctx, w_grad_value, static_cast<T>(0.0));
      bit_code->MulGradWeight(pre_out_grad, w_grad, in);
    }
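    // Finally propagate pre_out_grad back to the input X through the selected
    // weight rows.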
    bit_code->MulGradError(pre_out_grad, w, in_grad);
  }
};

}  // namespace operators
}  // namespace paddle