/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "ContextProjectionOp.h"
#include "paddle/math/Matrix.h"
#include "paddle/math/Vector.h"

namespace paddle {
/**
 * Context Projection Forward with CPU Matrix Device.
 *
 */
template <>
void ContextProjectionForward<DEVICE_TYPE_CPU>(CpuMatrix& out_mat,
                                               const CpuMatrix& input_mat,
                                               const CpuMatrix& weight_mat,
                                               const CpuIVector& seq_vec,
                                               size_t context_length,
                                               int context_start,
                                               size_t begin_pad) {
  const int* starts = seq_vec.getData();
  const size_t num_sequences = seq_vec.getSize() - 1;
  for (size_t i = 0; i < num_sequences; ++i) {
    for (size_t j = 0; j < context_length; ++j) {
      int begin = starts[i] + context_start + j;
      int end = starts[i + 1] + context_start + j;
      int dst_begin = starts[i];
      int dst_end = starts[i + 1];
      if (begin < starts[i]) {
        int64_t pad_size =
            std::min(starts[i] - begin, starts[i + 1] - starts[i]);
        MatrixPtr mat = out_mat.subMatrix(starts[i], pad_size);
        if (weight_mat) {
          MatrixPtr sub =
              const_cast<CpuMatrix&>(weight_mat).subMatrix(j, pad_size);
          mat->addAtOffset(*sub, j * input_mat.getWidth());
        }
        dst_begin = starts[i] + pad_size;
        begin = starts[i];
      }
      if (end > starts[i + 1]) {
        int64_t pad_size =
            std::min(end - starts[i + 1], starts[i + 1] - starts[i]);
        MatrixPtr mat = out_mat.subMatrix(starts[i + 1] - pad_size, pad_size);
        if (weight_mat) {
          MatrixPtr sub =
              const_cast<CpuMatrix&>(weight_mat)
                  .subMatrix(begin_pad + context_start + j - pad_size,
                             pad_size);
          mat->addAtOffset(*sub, j * input_mat.getWidth());
        }
        dst_end = starts[i + 1] - pad_size;
        end = starts[i + 1];
      }
      if (end <= begin) continue;
      MatrixPtr src =
          const_cast<CpuMatrix&>(input_mat).subMatrix(begin, end - begin);
      MatrixPtr dst = out_mat.subMatrix(dst_begin, dst_end - dst_begin);
      dst->addAtOffset(*src, j * input_mat.getWidth());
    }
  }
}
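
/**
 * A minimal standalone sketch of the index arithmetic above, assuming a
 * single sequence, zero padding (no weight matrix), and row-major storage.
 * Each context position j is clamped to the sequence range and added into
 * column block j of the output:
 *
 * @code
 *  const int rows = 4, dim = 2, context_length = 3, context_start = -1;
 *  float x[rows][dim] = {{1, 2}, {3, 4}, {5, 6}, {7, 8}};
 *  float y[rows][context_length * dim] = {};  // zero-initialized output
 *  for (int j = 0; j < context_length; ++j) {
 *    for (int r = 0; r < rows; ++r) {
 *      int src = r + context_start + j;       // source row for position j
 *      if (src < 0 || src >= rows) continue;  // out of range -> stays zero
 *      for (int d = 0; d < dim; ++d) {
 *        y[r][j * dim + d] += x[src][d];
 *      }
 *    }
 *  }
 *  // y matches the zero-padded example documented below.
 * @endcode
 */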

/**
 * Paddle Function for Context Projection Forward.
 * Calculate the output sequence after context projection.
 *
 * What is Context Projection?
 * For example, assume the input (x) has 4 words and the dimension of each
 * word representation is 2. If we pad with zeros instead of a learned
 * weight, and the context_length is 3, the output (y) is:
 *
 * @code
 *  x = [a1, a2;
 *       b1, b2;
 *       c1, c2;
 *       d1, d2]
 *  y = [0,  0,  a1, a2, b1, b2;
 *       a1, a2, b1, b2, c1, c2;
 *       b1, b2, c1, c2, d1, d2;
 *       c1, c2, d1, d2, 0,  0]
 * @endcode
 *
 * \param outputs[0].matrix   output value, n * (d * l)
 * \param outputs[0].vector   input sequence, n * 1
 * \param inputs[0].matrix    input value, n * d
 * \param inputs[0].vector    input sequence, n * 1
 * \param inputs[1].matrix    input weight, pad * d
 * \param inputs[1].vector    input sequence, n * 1
 */
template <DeviceType Device>
class ContextProjectionForwardFunc : public FunctionBase {
public:
  void init(const FuncConfig& config) override {
    context_length_ = config.get<size_t>("context_length");
    context_start_ = config.get<int>("context_start");
    begin_pad_ = config.get<size_t>("begin_pad");
  }

  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
    CHECK(1 == inputs.size() || 2 == inputs.size());
    CHECK_EQ((size_t)1, outputs.size());

    const auto val_seqs = dynamic_cast<const SequenceArg&>(inputs[0]);
    const auto w_seqs = inputs.size() <= 1
                            ? nullptr
                            : dynamic_cast<const SequenceArg*>(&inputs[1]);
    auto out_seqs = dynamic_cast<const SequenceArg&>(outputs[0]);

    CHECK(out_seqs.data() && val_seqs.data() &&
          val_seqs.getSequenceIds().data());
    CHECK_EQ(out_seqs.shape().ndims(), (size_t)2);
    CHECK_EQ(val_seqs.shape().ndims(), (size_t)2);
    CHECK_EQ(val_seqs.getSequenceIds().shape().ndims(), (size_t)1);
    if (w_seqs) {
      CHECK_EQ(w_seqs->shape().ndims(), (size_t)2);
      CHECK_EQ(w_seqs->getSequenceIds().shape().ndims(), (size_t)1);
    }
    /// dim of output = dim of input * context_length
    CHECK_EQ(out_seqs.shape()[1], val_seqs.shape()[1] * context_length_);
    /// input and output have the same batch_size
    CHECK_EQ(val_seqs.shape()[0], out_seqs.shape()[0]);
    /// dim of input == dim of weight
    if (w_seqs) {
      CHECK_EQ(val_seqs.shape()[1], w_seqs->shape()[1]);
    }

    CHECK_EQ(out_seqs.getArgType(), ADD_TO);
    auto out_mat = out_seqs.matrix<Device>();
    const auto in_mat = val_seqs.matrix<Device>();
    const auto w_mat =
        w_seqs ? w_seqs->matrix<Device>()
               : typename Tensor<real, Device>::Matrix(nullptr, 0, 0);
    const auto seq_vec = val_seqs.getSequenceIds().vector<int, Device>();
    ContextProjectionForward<Device>(out_mat,
                                     in_mat,
                                     w_mat,
                                     seq_vec,
                                     context_length_,
                                     context_start_,
                                     begin_pad_);
  }

private:
  size_t context_length_;
  int context_start_;
  size_t begin_pad_;
};
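
/**
 * A rough usage sketch, following the "FunctionName-DeviceType" registry
 * convention used by REGISTER_TYPED_FUNC at the bottom of this file; the
 * exact registry helper names may differ elsewhere in the codebase:
 *
 * @code
 *  FunctionBase* func = FunctionBase::funcRegistrar_.createByType(
 *      "ContextProjectionForward-CPU");
 *  func->init(FuncConfig()
 *                 .set("context_length", (size_t)3)
 *                 .set("context_start", -1)
 *                 .set("begin_pad", (size_t)1));
 *  // then call func->calc(inputs, outputs) with SequenceArg buffers whose
 *  // shapes satisfy the CHECKs in calc(); the output must be ADD_TO.
 * @endcode
 */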

/**
 * Context Projection Backward with CPU Matrix Device.
 *
 */
template <>
void ContextProjectionBackward<DEVICE_TYPE_CPU>(const CpuMatrix& out_grad_mat,
                                                CpuMatrix& in_grad_mat,
                                                CpuMatrix& w_grad_mat,
                                                const CpuIVector& seq_vec,
                                                size_t context_length,
                                                int context_start,
                                                size_t begin_pad,
                                                bool is_padding,
                                                size_t total_pad) {
  size_t input_dim = in_grad_mat ? in_grad_mat.getWidth()
                                 : w_grad_mat ? w_grad_mat.getWidth() : 0;
  const int* starts = seq_vec.getData();
  size_t num_sequences = seq_vec.getSize() - 1;
  for (size_t i = 0; i < num_sequences; ++i) {
    for (size_t j = 0; j < context_length; ++j) {
      int begin = starts[i] + context_start + j;
      int end = starts[i + 1] + context_start + j;
      int dst_begin = starts[i];
      int dst_end = starts[i + 1];
      if (begin < starts[i]) {
        int64_t pad_size =
            std::min(starts[i] - begin, starts[i + 1] - starts[i]);
        if (is_padding && w_grad_mat) {
          MatrixPtr mat = const_cast<CpuMatrix&>(out_grad_mat)
                              .subMatrix(starts[i], pad_size);
          MatrixPtr sub = w_grad_mat.subMatrix(j, pad_size);
          sub->addAtOffset(*mat, j * input_dim);
        }
        dst_begin = starts[i] + pad_size;
        begin = starts[i];
      }
      if (end > starts[i + 1]) {
        int64_t pad_size =
            std::min(end - starts[i + 1], starts[i + 1] - starts[i]);
        if (is_padding && w_grad_mat) {
          MatrixPtr mat = const_cast<CpuMatrix&>(out_grad_mat)
                              .subMatrix(starts[i + 1] - pad_size, pad_size);
          MatrixPtr sub = w_grad_mat.subMatrix(
              begin_pad + context_start + j - pad_size, pad_size);
          sub->addAtOffset(*mat, j * input_dim);
        }
        dst_end = starts[i + 1] - pad_size;
        end = starts[i + 1];
      }
      if (end <= begin) continue;
      if (!in_grad_mat) continue;
      MatrixPtr src = in_grad_mat.subMatrix(begin, end - begin);
      MatrixPtr dst = const_cast<CpuMatrix&>(out_grad_mat)
                          .subMatrix(dst_begin, dst_end - dst_begin);
      src->addAtOffset(*dst, j * input_dim);
    }
  }
}
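
/**
 * The backward scatter is the transpose of the forward gather: each output
 * gradient column block j is accumulated back into the input rows it was
 * read from. A minimal sketch under the same assumptions as the forward
 * example (one sequence, zero padding, so no weight gradient):
 *
 * @code
 *  const int rows = 4, dim = 2, context_length = 3, context_start = -1;
 *  float out_grad[rows][context_length * dim] = {};  // assume filled with dL/dy
 *  float in_grad[rows][dim] = {};                    // accumulates dL/dx
 *  for (int j = 0; j < context_length; ++j) {
 *    for (int r = 0; r < rows; ++r) {
 *      int src = r + context_start + j;
 *      if (src < 0 || src >= rows) continue;  // padded region: no grad to x
 *      for (int d = 0; d < dim; ++d) {
 *        in_grad[src][d] += out_grad[r][j * dim + d];
 *      }
 *    }
 *  }
 * @endcode
 */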

/**
 * Context Projection Backward Function.
 * Update the weight gradient and the input layer gradient with backprop.
 *
 * \param inputs[0].seq          input sequence.
 * \param inputs[0].matrix       output layer grad.
 * \param outputs[0]             input layer grad.
 * \param outputs[1]             weight grad.
 */
template <DeviceType Device>
class ContextProjectionBackwardFunc : public FunctionBase {
public:
  void init(const FuncConfig& config) override {
    context_length_ = config.get<size_t>("context_length");
    context_start_ = config.get<int>("context_start");
    begin_pad_ = config.get<size_t>("begin_pad");
    is_padding_ = config.get<bool>("is_padding");
    total_pad_ = config.get<size_t>("total_pad");
  }

  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
    CHECK_EQ((size_t)1, inputs.size());
    CHECK_EQ((size_t)2, outputs.size());

    const auto seq_arg = dynamic_cast<const SequenceArg&>(inputs[0]);
    CHECK(seq_arg.data() && inputs[0].data());
    CHECK_EQ(seq_arg.shape().ndims(), (size_t)2);
    CHECK_EQ(seq_arg.getSequenceIds().shape().ndims(), (size_t)1);
    CHECK_EQ(outputs[0].shape().ndims(), (size_t)2);
    CHECK_EQ(outputs[1].shape().ndims(), (size_t)2);

    /// dim of input grad == dim of weight
    CHECK_EQ(outputs[0].shape()[1], outputs[1].shape()[1]);
    /// input and output grads have the same batch_size
    CHECK_EQ(outputs[0].shape()[0], seq_arg.shape()[0]);
    /// dim of output val = dim of input grad * context_length
    CHECK_EQ(seq_arg.shape()[1], outputs[0].shape()[1] * context_length_);

    CHECK_EQ(outputs[0].getArgType(), ADD_TO);
    CHECK_EQ(outputs[1].getArgType(), ADD_TO);

    const auto seq_vec = seq_arg.getSequenceIds().vector<int, Device>();
    const auto out_grad_mat = seq_arg.matrix<Device>();
    auto in_grad_mat =
        !outputs[0].data()
            ? typename Tensor<real, Device>::Matrix(nullptr, 0, 0)
            : outputs[0].matrix<Device>();
    auto w_grad_mat = !outputs[1].data()
                          ? typename Tensor<real, Device>::Matrix(nullptr, 0, 0)
                          : outputs[1].matrix<Device>();
    ContextProjectionBackward<Device>(out_grad_mat,
                                      in_grad_mat,
                                      w_grad_mat,
                                      seq_vec,
                                      context_length_,
                                      context_start_,
                                      begin_pad_,
                                      is_padding_,
                                      total_pad_);
  }

private:
  size_t context_length_;
  int context_start_;
  size_t begin_pad_;
  bool is_padding_;
  size_t total_pad_;
};
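
/**
 * Configuration sketch for the backward function (keys taken from init()
 * above; the values are illustrative: context_start = -1 with
 * context_length = 3 gives begin_pad = 1 and total_pad = 2, i.e. one
 * padding row at the head and one at the tail of each sequence):
 *
 * @code
 *  FunctionBase* func = FunctionBase::funcRegistrar_.createByType(
 *      "ContextProjectionBackward-CPU");
 *  func->init(FuncConfig()
 *                 .set("context_length", (size_t)3)
 *                 .set("context_start", -1)
 *                 .set("begin_pad", (size_t)1)
 *                 .set("is_padding", true)
 *                 .set("total_pad", (size_t)2));
 * @endcode
 */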

/**
 * \param inputs[0].matrix      input grad, n * d
 * \param inputs[0].vector      input sequence, n * 1
 * \param outputs[0]            output grad, n * (d * l)
 */
template <DeviceType Device>
class ContextProjectionBackwardDataFunc : public FunctionBase {
public:
  void init(const FuncConfig& config) override {
    context_length_ = config.get<size_t>("context_length");
    context_start_ = config.get<int>("context_start");
  }

  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
    CHECK_EQ(1, static_cast<int>(inputs.size()));
    CHECK_EQ(1, static_cast<int>(outputs.size()));
    const auto in_seqs = dynamic_cast<const SequenceArg&>(inputs[0]);
    CHECK(in_seqs.data() && outputs[0].data() &&
          in_seqs.getSequenceIds().data());
    CHECK_EQ(static_cast<int>(outputs[0].shape().ndims()), 2);
    CHECK_EQ(static_cast<int>(in_seqs.shape().ndims()), 2);
    CHECK_EQ(static_cast<int>(in_seqs.getSequenceIds().shape().ndims()), 1);
    /// dim of output grad == dim of input grad * context_length
    CHECK_EQ(outputs[0].shape()[1], in_seqs.shape()[1] * context_length_);
    /// input and output have the same batch_size
    CHECK_EQ(in_seqs.shape()[0], outputs[0].shape()[0]);
    const auto out_grad_mat = outputs[0].matrix<Device>();
    auto in_grad_mat = in_seqs.matrix<Device>();
    const auto seq_vec = in_seqs.getSequenceIds().vector<int, Device>();

    ContextProjectionBackwardData<Device>(
        out_grad_mat, in_grad_mat, seq_vec, context_length_, context_start_);
  }

private:
  size_t context_length_;
  int context_start_;
};

/**
 * \param inputs[0].matrix    weight grad, pad * d
 * \param inputs[0].vector    input sequence, n * 1
 * \param outputs[0]          output grad, n * (d * l)
 */
template <DeviceType Device>
class ContextProjectionBackwardWeightFunc : public FunctionBase {
public:
  void init(const FuncConfig& config) override {
    context_length_ = config.get<size_t>("context_length");
    context_start_ = config.get<int>("context_start");
    begin_pad_ = config.get<size_t>("begin_pad");
    total_pad_ = config.get<size_t>("total_pad");
  }

  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
    CHECK_EQ(1, static_cast<int>(inputs.size()));
    CHECK_EQ(1, static_cast<int>(outputs.size()));

    const auto in_seqs = dynamic_cast<const SequenceArg&>(inputs[0]);
    CHECK(in_seqs.data() && in_seqs.getSequenceIds().data() &&
          outputs[0].data());
    CHECK_EQ(static_cast<int>(outputs[0].shape().ndims()), 2);
    CHECK_EQ(static_cast<int>(in_seqs.shape().ndims()), 2);
    CHECK_EQ(static_cast<int>(in_seqs.getSequenceIds().shape().ndims()), 1);
    CHECK_EQ(in_seqs.shape()[0], outputs[0].shape()[0]);
    CHECK_EQ(outputs[0].shape()[1], in_seqs.shape()[1] * context_length_);
    const auto out_grad_mat = outputs[0].matrix<Device>();
    auto w_grad_mat = inputs[0].matrix<Device>();
    const auto seq_vec = in_seqs.getSequenceIds().vector<int, Device>();
    ContextProjectionBackwardWeight<Device>(out_grad_mat,
                                            w_grad_mat,
                                            seq_vec,
                                            context_length_,
                                            context_start_,
                                            total_pad_,
                                            begin_pad_);
  }

private:
  size_t context_length_;
  int context_start_;
  size_t begin_pad_;
  size_t total_pad_;
};
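
/**
 * For intuition about the begin_pad/total_pad indexing, under the same
 * illustrative configuration as above (context_start = -1,
 * context_length = 3, so begin_pad = 1, total_pad = 2): weight row 0 pads
 * the head of each sequence and therefore receives gradient only from
 * output column block j = 0, while weight row 1 pads the tail and receives
 * gradient only from block j = context_length - 1.
 */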

REGISTER_TYPED_FUNC(ContextProjectionForward,
                    CPU,
                    ContextProjectionForwardFunc);
REGISTER_TYPED_FUNC(ContextProjectionBackward,
                    CPU,
                    ContextProjectionBackwardFunc);
#ifndef PADDLE_ONLY_CPU
REGISTER_TYPED_FUNC(ContextProjectionForward,
                    GPU,
                    ContextProjectionForwardFunc);
REGISTER_TYPED_FUNC(ContextProjectionBackward,
                    GPU,
                    ContextProjectionBackwardFunc);
REGISTER_TYPED_FUNC(ContextProjectionBackwardData,
                    GPU,
                    ContextProjectionBackwardDataFunc);
REGISTER_TYPED_FUNC(ContextProjectionBackwardWeight,
                    GPU,
                    ContextProjectionBackwardWeightFunc);
#endif
}  // namespace paddle