// File: sequence_pool_grad.cc
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>
#include <vector>
#include "lite/backends/arm/math/funcs.h"
#include "lite/backends/arm/math/sequence_pool_grad.h"
#include "lite/core/op_registry.h"
#include "lite/core/tensor.h"
#include "lite/core/type_system.h"

namespace paddle {
namespace lite {
namespace arm {
namespace math {

template <>
C
chenjiaoAngel 已提交
31
void seq_pool_sum_grad<float>(const float* din,
C
chenjiaoAngel 已提交
32 33 34 35
                              const float* din_grad,
                              float* dout,
                              const std::vector<uint64_t> lod,
                              int64_t width) {
C
chenjiaoAngel 已提交
36
  for (int i = 0; i < static_cast<int>(lod.size()) - 1; i++) {
C
chenjiaoAngel 已提交
37 38
    int64_t height = static_cast<int64_t>(lod[i + 1] - lod[i]);
    const float* din_ptr = din + lod[i] * width;
C
chenjiaoAngel 已提交
39
    const float* din_grad_ptr = din_grad + i * width;
C
chenjiaoAngel 已提交
40
    float* dout_ptr = dout + lod[i] * width;
C
chenjiaoAngel 已提交
41 42 43
    if (height > 0) {
      if (width == 1) {
        for (int h = 0; h < height; ++h) {
C
chenjiaoAngel 已提交
44
          dout_ptr[h] = din_grad_ptr[h];
C
chenjiaoAngel 已提交
45 46 47 48
        }
      } else {
        for (int w = 0; w < width; w++) {
          for (int h = 0; h < height; h++) {
C
chenjiaoAngel 已提交
49 50
            dout_ptr[h] = *din_grad_ptr;
            dout_ptr += width;
C
chenjiaoAngel 已提交
51
          }
C
chenjiaoAngel 已提交
52 53
          din_grad_ptr++;
        }
C
chenjiaoAngel 已提交
54 55 56 57 58 59
      }
    }
  }
}

template <>
C
chenjiaoAngel 已提交
60
void seq_pool_average_grad<float>(const float* din,
C
chenjiaoAngel 已提交
61 62 63 64
                                  const float* din_grad,
                                  float* dout,
                                  const std::vector<uint64_t> lod,
                                  int64_t width) {
C
chenjiaoAngel 已提交
65 66 67
  for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
    int64_t height = static_cast<int64_t>(lod[i + 1] - lod[i]);
    const float* din_ptr = din + lod[i] * width;
C
chenjiaoAngel 已提交
68
    const float* din_grad_ptr = din_grad + i * width;
C
chenjiaoAngel 已提交
69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89
    float* dout_ptr = dout + lod[i] * width;
    float alpha = 1.0 / height;
    if (height > 0) {
      if (width == 1) {
        float sum = 0.f;
        for (int h = 0; h < height; ++h) {
          dout_ptr[h] = alpha * din_grad_ptr[h];
        }
      } else {
        for (int w = 0; w < width; w++) {
          for (int h = 0; h < height; h++) {
            dout_ptr[h] = alpha * din_grad_ptr[w];
            dout_ptr += width;
          }
        }
      }
    }
  }
}

template <>
C
chenjiaoAngel 已提交
90
void seq_pool_sqrt_grad<float>(const float* din,
C
chenjiaoAngel 已提交
91 92 93 94
                               const float* din_grad,
                               float* dout,
                               const std::vector<uint64_t> lod,
                               int64_t width) {
C
chenjiaoAngel 已提交
95 96 97
  for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
    int64_t height = static_cast<int64_t>(lod[i + 1] - lod[i]);
    const float* din_ptr = din + lod[i] * width;
C
chenjiaoAngel 已提交
98
    const float* din_grad_ptr = din_grad + i * width;
C
chenjiaoAngel 已提交
99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119
    float* dout_ptr = dout + lod[i] * width;
    float alpha = 1.0 / sqrtf(height);
    if (height > 0) {
      if (width == 1) {
        float sum = 0.f;
        for (int h = 0; h < height; ++h) {
          dout_ptr[h] = alpha * din_grad_ptr[h];
        }
      } else {
        for (int w = 0; w < width; w++) {
          for (int h = 0; h < height; h++) {
            dout_ptr[h] = alpha * din_grad_ptr[w];
            dout_ptr += width;
          }
        }
      }
    }
  }
}

template <>
C
chenjiaoAngel 已提交
120
void seq_pool_first_grad<float>(const float* din,
C
chenjiaoAngel 已提交
121 122 123 124
                                const float* din_grad,
                                float* dout,
                                const std::vector<uint64_t> lod,
                                int64_t width) {
C
chenjiaoAngel 已提交
125 126 127 128 129 130 131
  for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
    int64_t height = lod[i + 1] - lod[i];
    const float* din_ptr = din + width * lod[i];
    const float* din_grad_ptr = din + i * width;
    float* dout_ptr = dout + lod[i] * width;
    if (height > 0) {
      for (int w = 0; w < width; w++) {
C
chenjiaoAngel 已提交
132
        dout_ptr[w] = din_grad_ptr[w];
C
chenjiaoAngel 已提交
133 134 135 136 137 138 139 140 141
      }
    }
  }
}

}  // namespace math
}  // namespace arm
}  // namespace lite
}  // namespace paddle