hierarchical_sigmoid_op.h 10.3 KB
Newer Older
Y
Yancey1989 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
Q
Qiao Longfei 已提交
16

W
weixing02 已提交
17
#include <iostream>
18
#include <iterator>
Q
Qiao Longfei 已提交
19
#include <memory>
J
JiabinYang 已提交
20
#include <set>
21
#include <string>
W
weixing02 已提交
22
#include <vector>
Q
Qiao Longfei 已提交
23

J
JiabinYang 已提交
24
#include "paddle/fluid/framework/mixed_vector.h"
W
weixing02 已提交
25 26
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/clip_op.h"
J
JiabinYang 已提交
27
#include "paddle/fluid/operators/detail/safe_ref.h"
W
weixing02 已提交
28 29 30
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/matrix_bit_code.h"
#include "paddle/fluid/platform/transform.h"
J
JiabinYang 已提交
31

32 33 34 35
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/fluid/operators/distributed/parameter_prefetch.h"
#endif

Y
Yancey1989 已提交
36 37 38
namespace paddle {
namespace operators {

Y
Yancey1989 已提交
39 40 41
// Shorthand for framework::EigenMatrix with the same template parameters;
// defaults to row-major storage and Eigen::DenseIndex indexing.
template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
Y
Yancey1989 已提交
42
// Bring platform::Transform into this namespace; the forward kernel below uses
// it to apply the clip functor over the PreOut buffer.
using platform::Transform;
Y
Yancey1989 已提交
43

J
JiabinYang 已提交
44 45
// Returns the distinct non-negative entries of `path`, sorted ascending.
//
// Every element of the tensor is read as an int64_t. Negative entries are
// dropped (presumably they are padding for paths shorter than the table
// width -- confirm against the operator definition).
static std::vector<int64_t> PathToRows(const framework::LoDTensor& path) {
  const int64_t* const first = path.data<int64_t>();
  const int64_t* const last = first + path.numel();

  // std::set both de-duplicates and keeps the ids ordered.
  std::set<int64_t> unique_rows;
  for (const int64_t* it = first; it < last; ++it) {
    if (*it >= 0) {
      unique_rows.emplace(*it);
    }
  }
  return std::vector<int64_t>(unique_rows.begin(), unique_rows.end());
}
Y
Yancey1989 已提交
56
template <typename DeviceContext, typename T>
Y
Yancey1989 已提交
57 58
class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
 public:
Y
Yancey1989 已提交
59
  void Compute(const framework::ExecutionContext& ctx) const override {
60 61
    auto& in = detail::Ref(ctx.Input<framework::LoDTensor>("X"));
    auto& w = detail::Ref(ctx.Input<framework::LoDTensor>("W"));
62
    auto* path = ctx.Input<framework::LoDTensor>("PathTable");
J
JiabinYang 已提交
63
    auto* code = ctx.Input<framework::LoDTensor>("PathCode");
64
    auto& label = detail::Ref(ctx.Input<framework::LoDTensor>("Label"));
J
JiabinYang 已提交
65 66 67
    auto* bias = ctx.Input<framework::LoDTensor>("Bias");
    auto* out = ctx.Output<framework::LoDTensor>("Out");
    auto* pre_out = ctx.Output<framework::LoDTensor>("PreOut");
Y
Yancey1989 已提交
68
    size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes"));
69 70
    // for remote prefetch

71
    auto remote_prefetch = ctx.Attr<bool>("remote_prefetch");
72
    auto epmap = ctx.Attr<std::vector<std::string>>("epmap");
73
    if (remote_prefetch && !epmap.empty()) {
74 75 76
      // if epmap is not empty, then the parameter will be fetched from remote
      // parameter
      // server
Q
Qiao Longfei 已提交
77
      auto height_sections = ctx.Attr<std::vector<int64_t>>("height_sections");
78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110
      auto table_names = ctx.Attr<std::vector<std::string>>("table_names");
      std::vector<int64_t> real_rows = PathToRows(*path);
      framework::Scope& local_scope = ctx.scope().NewScope();
      auto* ids = local_scope.Var("Ids@Prefetch");
      auto* x_tensor = ids->GetMutable<framework::LoDTensor>();

      x_tensor->mutable_data<int64_t>(
          framework::make_ddim({static_cast<int64_t>(real_rows.size()), 1}),
          ctx.GetPlace());
      // copy.

      std::memcpy(x_tensor->data<int64_t>(), real_rows.data(),
                  real_rows.size() * sizeof(int64_t));

      framework::DDim w_dims = ctx.Input<Tensor>("W")->dims();
      w_dims[0] = x_tensor->dims()[0];
      auto* w_tensor =
          local_scope.Var("W@Prefetch")->GetMutable<framework::LoDTensor>();
      w_tensor->Resize(w_dims);

#ifdef PADDLE_WITH_DISTRIBUTE
      // w_Out is set to used by prefetch, never change it in other cases
      auto* w_out = ctx.Output<framework::LoDTensor>("W_Out");
      operators::distributed::prefetch_with_reconstruct<T>(
          "Ids@Prefetch", "W@Prefetch", table_names, epmap, height_sections,
          ctx, local_scope, w_out);
#else
      PADDLE_THROW(
          "paddle is not compiled with distribute support, can not do "
          "parameter prefetch!");
#endif
    }

111 112 113 114 115 116
    bool is_custom = false;
    if (path) {
      is_custom = true;
    }
    int64_t code_length =
        path ? path->dims()[1] : math::FindLastSet(num_classes - 1);
J
JiabinYang 已提交
117
    int64_t batch_size = in.dims()[0];
J
JiabinYang 已提交
118
    framework::LoDTensor sum;
W
weixing02 已提交
119
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
G
guosheng 已提交
120
    auto* pre_out_data = pre_out->mutable_data<T>(
Y
Yancey1989 已提交
121
        framework::make_ddim({batch_size, code_length}), ctx.GetPlace());
W
weixing02 已提交
122
    auto pre_out_mat = EigenMatrix<T>::From(*pre_out);
G
guosheng 已提交
123 124
    // Not all class(leaf) nodes' path lengths equal code_length, thus init as
    // 0s can avoid out of path's loss.
125
    math::SetConstant<DeviceContext, T> zero;
W
weixing02 已提交
126
    zero(dev_ctx, pre_out, static_cast<T>(0.0));
Y
Yancey1989 已提交
127 128
    auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
    math::RowwiseSum<DeviceContext, T> row_sum;
129 130 131 132

    std::unique_ptr<math::MatrixBitCodeFunctor<T>> bit_code;
    if (!is_custom) {
      bit_code.reset(new math::MatrixBitCodeFunctor<T>(num_classes,
J
JiabinYang 已提交
133
                                                       label.data<int64_t>()));
134
    } else {
J
JiabinYang 已提交
135 136
      bit_code.reset(new math::MatrixBitCodeFunctor<T>(*path, *code,
                                                       label.data<int64_t>()));
137
    }
Y
Yancey1989 已提交
138

Y
Yancey1989 已提交
139 140
    std::vector<int64_t> sum_dims({batch_size, 1UL});
    sum.mutable_data<T>(framework::make_ddim(sum_dims), ctx.GetPlace());
Y
Yancey1989 已提交
141
    auto sum_mat = EigenMatrix<T>::From(sum);
Y
Yancey1989 已提交
142
    out->mutable_data<T>(ctx.GetPlace());
143
    auto out_mat = framework::EigenMatrix<T>::From(*out);
Y
Yancey1989 已提交
144
    if (bias) {
145
      bit_code->Add(*bias, pre_out);
Y
Yancey1989 已提交
146
    }
J
JiabinYang 已提交
147
    bit_code->Mul(pre_out, w, in);
G
guosheng 已提交
148
    // clip to [-40, 40]
Y
Yancey1989 已提交
149 150
    Transform<DeviceContext> trans;
    trans(ctx.template device_context<DeviceContext>(), pre_out_data,
W
weixing02 已提交
151
          pre_out_data + pre_out->numel(), pre_out_data,
Y
Yancey1989 已提交
152
          ClipFunctor<T>(static_cast<T>(-40.0), static_cast<T>(40.0)));
153
    bit_code->Sum(*pre_out, out, static_cast<T>(-1));
G
guosheng 已提交
154
    // use softrelu to calculate cross entropy
Y
Yancey1989 已提交
155
    pre_out_mat.device(place) = (static_cast<T>(1.0) + pre_out_mat.exp()).log();
W
weixing02 已提交
156
    row_sum(dev_ctx, *pre_out, &sum);
157 158 159 160
    // TODO(guosheng): Subtract the out of path's loss, since not all
    // class(leaf) nodes' path lengths equal code_length. But it won't break the
    // gradient check since both have the out of path's loss and will cancel out
    // each other.
Y
Yancey1989 已提交
161
    out_mat.device(place) = sum_mat + out_mat;
Y
Yancey1989 已提交
162
  }
Y
Yancey1989 已提交
163 164
};

Y
Yancey1989 已提交
165
template <typename DeviceContext, typename T>
Y
Yancey1989 已提交
166 167
class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
 public:
Y
Yancey1989 已提交
168
  void Compute(const framework::ExecutionContext& ctx) const override {
169 170
    auto& in = detail::Ref(ctx.Input<framework::LoDTensor>("X"));
    auto& w = detail::Ref(ctx.Input<framework::LoDTensor>("W"));
171
    auto* path = ctx.Input<framework::LoDTensor>("PathTable");
J
JiabinYang 已提交
172
    auto* code = ctx.Input<framework::LoDTensor>("PathCode");
J
JiabinYang 已提交
173 174 175 176 177
    auto* in_grad =
        ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
    bool is_sparse = ctx.Attr<bool>("is_sparse");
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    math::SetConstant<DeviceContext, T> zero;
178 179 180
    auto& label = detail::Ref(ctx.Input<framework::LoDTensor>("Label"));
    auto& pre_out = detail::Ref(ctx.Input<framework::LoDTensor>("PreOut"));
    auto& out_grad = detail::Ref(
J
JiabinYang 已提交
181
        ctx.Input<framework::LoDTensor>(framework::GradVarName("Out")));
J
JiabinYang 已提交
182
    framework::LoDTensor pre_out_grad;
183

J
JiabinYang 已提交
184
    pre_out_grad.mutable_data<T>(pre_out.dims(), ctx.GetPlace());
185 186
    in_grad->mutable_data<T>(ctx.GetPlace());
    zero(dev_ctx, in_grad, static_cast<T>(0.0));
W
weixing02 已提交
187

Y
Yancey1989 已提交
188
    size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes"));
189 190 191 192 193 194 195 196 197

    bool is_custom = false;
    if (path) {
      is_custom = true;
    }

    std::unique_ptr<math::MatrixBitCodeFunctor<T>> bit_code;
    if (!is_custom) {
      bit_code.reset(new math::MatrixBitCodeFunctor<T>(num_classes,
J
JiabinYang 已提交
198
                                                       label.data<int64_t>()));
199
    } else {
J
JiabinYang 已提交
200 201
      bit_code.reset(new math::MatrixBitCodeFunctor<T>(*path, *code,
                                                       label.data<int64_t>()));
202
    }
203

Y
Use mkl  
Yu Yang 已提交
204
    // softrelu derivative
J
JiabinYang 已提交
205

Y
Use mkl  
Yu Yang 已提交
206
    auto blas = math::GetBlas<DeviceContext, T>(ctx);
207

Y
Use mkl  
Yu Yang 已提交
208 209 210 211 212 213 214 215
    auto* pre_out_grad_data = pre_out_grad.data<T>();
    auto* pre_out_data = pre_out.data<T>();
    auto n = pre_out.numel();
    blas.VEXP(n, pre_out_data, pre_out_grad_data);
    blas.VINV(n, pre_out_grad_data, pre_out_grad_data);
    for (int64_t i = 0; i < n; ++i) {
      pre_out_grad_data[i] = 1.0 - pre_out_grad_data[i];
    }
216
    bit_code->Sub(&pre_out_grad);  // the gradient of clip(w * x + b)
Y
Use mkl  
Yu Yang 已提交
217 218 219 220 221 222 223 224
    auto* out_grad_data = out_grad.data<T>();

    int64_t dim0 = pre_out_grad.dims()[0];
    int64_t dim1 = pre_out_grad.dims()[1];
    for (int64_t i = 0; i < dim0; ++i) {
      T tmp = out_grad_data[i];
      blas.SCAL(dim1, tmp, pre_out_grad_data + i * dim1);
    }
G
guosheng 已提交
225 226
    // TODO(guosheng): multiply pre_out_grad with subgradient of clipping to
    // be consistent with the clipping in forward.
227 228 229 230 231 232 233
    auto* bias_grad =
        ctx.Output<framework::LoDTensor>(framework::GradVarName("Bias"));
    if (bias_grad) {
      bias_grad->mutable_data<T>(ctx.GetPlace());
      zero(dev_ctx, bias_grad, static_cast<T>(0.0));
      bit_code->AddGrad(pre_out_grad, bias_grad);
    }
J
JiabinYang 已提交
234 235 236 237 238
    if (!is_sparse) {
      auto* w_grad =
          ctx.Output<framework::LoDTensor>(framework::GradVarName("W"));
      w_grad->mutable_data<T>(ctx.GetPlace());
      zero(dev_ctx, w_grad, static_cast<T>(0.0));
J
JiabinYang 已提交
239
      bit_code->MulGradWeight(pre_out_grad, w_grad, in);
J
JiabinYang 已提交
240
    } else {
J
JiabinYang 已提交
241
      framework::Vector<int64_t> real_rows = PathToRows(*path);
J
JiabinYang 已提交
242 243 244
      auto* w_grad =
          ctx.Output<framework::SelectedRows>(framework::GradVarName("W"));
      w_grad->set_rows(real_rows);
245
      // Build a map of id -> row_index to speed up finding the index of one id
J
JiabinYang 已提交
246
      w_grad->set_height(w.dims()[0]);
J
JiabinYang 已提交
247
      auto* w_grad_value = w_grad->mutable_value();
J
JiabinYang 已提交
248
      framework::DDim temp_dim(w.dims());
J
JiabinYang 已提交
249 250 251 252
      set(temp_dim, 0, real_rows.size());

      w_grad_value->mutable_data<T>(temp_dim, ctx.GetPlace());
      zero(dev_ctx, w_grad_value, static_cast<T>(0.0));
J
JiabinYang 已提交
253
      bit_code->MulGradWeight(pre_out_grad, w_grad, in);
J
JiabinYang 已提交
254
    }
J
JiabinYang 已提交
255
    bit_code->MulGradError(pre_out_grad, w, in_grad);
Y
Yancey1989 已提交
256
  }
Y
Yancey1989 已提交
257 258 259 260
};

}  // namespace operators
}  // namespace paddle