// math_function.cc
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/pten/kernels/funcs/math_function.h"

#ifdef PADDLE_WITH_MKLML
#include "paddle/fluid/platform/dynload/mklml.h"
#endif

#ifdef PADDLE_USE_OPENBLAS
#include <cblas.h>
#endif

#include <memory>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/kernels/funcs/eigen/common.h"
#include "paddle/pten/kernels/funcs/math_function_impl.h"
#include "unsupported/Eigen/CXX11/Tensor"

namespace pten {
namespace funcs {

using float16 = paddle::platform::float16;

// Explicitly instantiates SetConstant<CTX, T> for every element type this
// file supports. Every device context below uses the same 11 element types.
// The caller supplies the terminating semicolon.
#define PTEN_INSTANTIATE_SET_CONSTANT(CTX)                             \
  template struct SetConstant<CTX, float16>;                           \
  template struct SetConstant<CTX, paddle::platform::bfloat16>;        \
  template struct SetConstant<CTX, float>;                             \
  template struct SetConstant<CTX, double>;                            \
  template struct SetConstant<CTX, int16_t>;                           \
  template struct SetConstant<CTX, int>;                               \
  template struct SetConstant<CTX, int64_t>;                           \
  template struct SetConstant<CTX, bool>;                              \
  template struct SetConstant<CTX, uint8_t>;                           \
  template struct SetConstant<CTX, paddle::platform::complex<float>>;  \
  template struct SetConstant<CTX, paddle::platform::complex<double>>

// Legacy fluid CPU context and the pten CPU context.
PTEN_INSTANTIATE_SET_CONSTANT(paddle::platform::CPUDeviceContext);
PTEN_INSTANTIATE_SET_CONSTANT(pten::CPUContext);

#ifdef PADDLE_WITH_XPU
PTEN_INSTANTIATE_SET_CONSTANT(paddle::platform::XPUDeviceContext);
#endif

#undef PTEN_INSTANTIATE_SET_CONSTANT

// Explicitly instantiates Transpose<CPUDeviceContext, TYPE, RANK> for a single
// element type. The caller supplies the terminating semicolon.
#define PTEN_CPU_TRANSPOSE_INSTANCE(TYPE, RANK) \
  template struct Transpose<paddle::platform::CPUDeviceContext, TYPE, RANK>

// Instantiates the CPU Transpose functor for every supported element type at
// the given RANK. Invoke as `DEFINE_CPU_TRANS(n);`.
#define DEFINE_CPU_TRANS(RANK)                                          \
  PTEN_CPU_TRANSPOSE_INSTANCE(paddle::platform::float16, RANK);         \
  PTEN_CPU_TRANSPOSE_INSTANCE(paddle::platform::bfloat16, RANK);        \
  PTEN_CPU_TRANSPOSE_INSTANCE(float, RANK);                             \
  PTEN_CPU_TRANSPOSE_INSTANCE(double, RANK);                            \
  PTEN_CPU_TRANSPOSE_INSTANCE(int, RANK);                               \
  PTEN_CPU_TRANSPOSE_INSTANCE(int64_t, RANK);                           \
  PTEN_CPU_TRANSPOSE_INSTANCE(bool, RANK);                              \
  PTEN_CPU_TRANSPOSE_INSTANCE(int16_t, RANK);                           \
  PTEN_CPU_TRANSPOSE_INSTANCE(uint8_t, RANK);                           \
  PTEN_CPU_TRANSPOSE_INSTANCE(int8_t, RANK);                            \
  PTEN_CPU_TRANSPOSE_INSTANCE(paddle::platform::complex<float>, RANK);  \
  PTEN_CPU_TRANSPOSE_INSTANCE(paddle::platform::complex<double>, RANK)

// Ranks 1 through 6 are instantiated.
DEFINE_CPU_TRANS(1);
DEFINE_CPU_TRANS(2);
DEFINE_CPU_TRANS(3);
DEFINE_CPU_TRANS(4);
DEFINE_CPU_TRANS(5);
DEFINE_CPU_TRANS(6);

// CPU transpose for an arbitrary permutation `axis`.
// Every flat output index is visited in increasing order. It is decomposed
// into per-dimension coordinates using out's strides, and the element is read
// from `in` at the offset formed by those coordinates and in's strides
// permuted through `axis`. `context` is not used.
// NOTE(review): assumes axis.size() matches the rank of `in` and `out` and
// that `out` already has storage; neither is checked here.
template <typename T>
struct TransposeNormal<paddle::platform::CPUDeviceContext, T> {
  void operator()(const paddle::platform::CPUDeviceContext& context,
                  const paddle::framework::Tensor& in,
                  paddle::framework::Tensor* out,
                  const std::vector<int>& axis) {
    const int ndim = static_cast<int>(axis.size());
    auto src_strides = paddle::framework::stride(in.dims());
    auto dst_strides = paddle::framework::stride(out->dims());
    const T* src = in.data<T>();
    T* dst = out->data<T>();
    const int64_t total = out->numel();

    for (int64_t dst_idx = 0; dst_idx < total; ++dst_idx) {
      // Peel off one output coordinate per dimension and accumulate the
      // matching input offset.
      int64_t remainder = dst_idx;
      int64_t src_idx = 0;
      for (int d = 0; d < ndim; ++d) {
        const int64_t coord = remainder / dst_strides[d];
        remainder -= coord * dst_strides[d];
        src_idx += coord * src_strides[axis[d]];
      }
      dst[dst_idx] = src[src_idx];
    }
  }
};

// Explicitly instantiates TransposeNormal<CPUDeviceContext, TYPE>. The caller
// supplies the terminating semicolon.
#define DEFINE_CPU_TRANS_NORMAL(TYPE) \
  template struct TransposeNormal<paddle::platform::CPUDeviceContext, TYPE>

// Floating-point element types.
DEFINE_CPU_TRANS_NORMAL(paddle::platform::float16);
DEFINE_CPU_TRANS_NORMAL(paddle::platform::bfloat16);
DEFINE_CPU_TRANS_NORMAL(float);
DEFINE_CPU_TRANS_NORMAL(double);
// Boolean and integral element types.
DEFINE_CPU_TRANS_NORMAL(bool);
DEFINE_CPU_TRANS_NORMAL(int8_t);
DEFINE_CPU_TRANS_NORMAL(uint8_t);
DEFINE_CPU_TRANS_NORMAL(int16_t);
DEFINE_CPU_TRANS_NORMAL(int);
DEFINE_CPU_TRANS_NORMAL(int64_t);
// Complex element types.
DEFINE_CPU_TRANS_NORMAL(paddle::platform::complex<float>);
DEFINE_CPU_TRANS_NORMAL(paddle::platform::complex<double>);

// Data-type visitor that fills `tensor_` with `value_`.
// For the element type T it is applied with, apply<T>() requests CPUPlace
// storage through mutable_data<T> and writes static_cast<T>(value_) into
// every one of the tensor's numel() elements.
struct TensorSetConstantCPU {
  TensorSetConstantCPU(paddle::framework::Tensor* tensor, float value)
      : tensor_(tensor), value_(value) {}

  template <typename T>
  void apply() const {
    T* data = tensor_->mutable_data<T>(paddle::platform::CPUPlace());
    const T fill_value = static_cast<T>(value_);
    std::fill_n(data, tensor_->numel(), fill_value);
  }

  paddle::framework::Tensor* tensor_;  // Not owned.
  float value_;
};

// set_constant_with_place specializations for places that have no fill
// implementation in this file. Each one unconditionally throws an
// Unimplemented error naming the place; the arguments are not used.

template <>
void set_constant_with_place<paddle::platform::XPUPlace>(
    const paddle::platform::DeviceContext& context,
    paddle::framework::Tensor* tensor,
    float value) {
  PADDLE_THROW(
      paddle::platform::errors::Unimplemented("XPUPlace is not supported"));
}

template <>
void set_constant_with_place<paddle::platform::NPUPlace>(
    const paddle::platform::DeviceContext& context,
    paddle::framework::Tensor* tensor,
    float value) {
  PADDLE_THROW(
      paddle::platform::errors::Unimplemented("NPUPlace is not supported"));
}

template <>
void set_constant_with_place<paddle::platform::NPUPinnedPlace>(
    const paddle::platform::DeviceContext& context,
    paddle::framework::Tensor* tensor,
    float value) {
  PADDLE_THROW(paddle::platform::errors::Unimplemented(
      "NPUPinnedPlace is not supported"));
}

template <>
void set_constant_with_place<paddle::platform::IPUPlace>(
    const paddle::platform::DeviceContext& context,
    paddle::framework::Tensor* tensor,
    float value) {
  PADDLE_THROW(
      paddle::platform::errors::Unimplemented("IPUPlace is not supported"));
}

template <>
void set_constant_with_place<paddle::platform::CustomPlace>(
    const paddle::platform::DeviceContext& context,
    paddle::framework::Tensor* tensor,
    float value) {
  PADDLE_THROW(
      paddle::platform::errors::Unimplemented("CustomPlace is not supported"));
}

227 228 229 230 231
template <>
void set_constant_with_place<paddle::platform::CPUPlace>(
    const paddle::platform::DeviceContext& context,
    paddle::framework::Tensor* tensor,
    float value) {
232
  pten::VisitDataType(tensor->dtype(), TensorSetConstantCPU(tensor, value));
233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248
}

// No fill implementation exists here for MLU tensors: always throws an
// Unimplemented error. The arguments are not used.
template <>
void set_constant_with_place<paddle::platform::MLUPlace>(
    const paddle::platform::DeviceContext& context,
    paddle::framework::Tensor* tensor,
    float value) {
  PADDLE_THROW(paddle::platform::errors::Unimplemented(
      "MLUPlace is not supported"));
}

// Fills a CUDA-pinned tensor with `value` through the same host-side path as
// CPUPlace: VisitDataType selects the element type from tensor->dtype() and
// runs TensorSetConstantCPU for it. `context` is not used.
// NOTE(review): TensorSetConstantCPU requests storage via
// mutable_data<T>(CPUPlace()); confirm this keeps the existing pinned
// allocation rather than reallocating on plain CPU memory.
template <>
void set_constant_with_place<paddle::platform::CUDAPinnedPlace>(
    const paddle::platform::DeviceContext& context,
    paddle::framework::Tensor* tensor,
    float value) {
  pten::VisitDataType(tensor->dtype(), TensorSetConstantCPU(tensor, value));
}

// boost::static_visitor that routes a fill request (context, tensor, value) to
// the set_constant_with_place specialization for the visited place type.
// Only the static type of the place is used; its value is ignored.
struct TensorSetConstantWithPlace : public boost::static_visitor<void> {
  TensorSetConstantWithPlace(const paddle::platform::DeviceContext& context,
                             paddle::framework::Tensor* tensor,
                             float value)
      : context_(context), tensor_(tensor), value_(value) {}

  template <typename PlaceT>
  void operator()(PlaceT /*place*/) const {
    set_constant_with_place<PlaceT>(context_, tensor_, value_);
  }

  // Held by reference: the context must outlive this visitor.
  const paddle::platform::DeviceContext& context_;
  paddle::framework::Tensor* tensor_;  // Not owned.
  float value_;
};

// Fills `tensor` with `value`.
// In CUDA/HIP builds the specialization is chosen from tensor->place() via
// VisitPlace. In every other build the CPUPlace specialization is used
// regardless of the tensor's place.
void set_constant(const paddle::platform::DeviceContext& context,
                  paddle::framework::Tensor* tensor,
                  float value) {
  TensorSetConstantWithPlace setter(context, tensor, value);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  paddle::platform::VisitPlace(tensor->place(), setter);
#else
  setter(paddle::platform::CPUPlace());
#endif
}

// CPU RowwiseAdd: output[i, :] = input[i, :] + vector for every row i.
// The row length is input.numel() / in_dims[0]; `vector` must have exactly
// that many elements and `output` must have the same dims as `input`
// (both are enforced). `context` is not used.
// NOTE(review): input/output are viewed through EigenMatrix::From, which
// presumably requires 2-D tensors; confirm with callers.
template <typename T>
struct RowwiseAdd<paddle::platform::CPUDeviceContext, T> {
  void operator()(const paddle::platform::CPUDeviceContext& context,
                  const paddle::framework::Tensor& input,
                  const paddle::framework::Tensor& vector,
                  paddle::framework::Tensor* output) {
    auto in_dims = input.dims();
    auto out_dims = output->dims();
    auto size = input.numel() / in_dims[0];
    PADDLE_ENFORCE_EQ(
        vector.numel(),
        size,
        paddle::platform::errors::InvalidArgument(
            "The input vector size"
            " should be equal to the size of each row of input tensor."
            " Expected vector size=%d, but received %d",
            size,
            vector.numel()));
    // The shape strings are built inside the error expression. to_str()
    // returns a temporary std::string, so a c_str() pointer stored in a local
    // would dangle before the message is formatted.
    PADDLE_ENFORCE_EQ(out_dims,
                      in_dims,
                      paddle::platform::errors::InvalidArgument(
                          "The output tensor shape should be same as the input"
                          " tensor shape. Expected output tensor shape: %s,"
                          " but received %s",
                          in_dims.to_str().c_str(),
                          out_dims.to_str().c_str()));

    auto in = paddle::framework::EigenMatrix<T>::From(input);
    auto vec = paddle::framework::EigenVector<T>::Flatten(vector);
    auto out = paddle::framework::EigenMatrix<T>::From(*output);

    const int64_t rows = in_dims[0];
    for (int64_t i = 0; i < rows; ++i) {
      out.chip(i, 0) = in.chip(i, 0) + vec;
    }
  }
};

// Explicit CPU instantiations of RowwiseAdd, ColwiseSum, RowwiseSum and
// RowwiseMean, grouped by element type.
template struct RowwiseAdd<paddle::platform::CPUDeviceContext, float>;
template struct ColwiseSum<paddle::platform::CPUDeviceContext, float>;
template struct RowwiseSum<paddle::platform::CPUDeviceContext, float>;
template struct RowwiseMean<paddle::platform::CPUDeviceContext, float>;

template struct RowwiseAdd<paddle::platform::CPUDeviceContext, double>;
template struct ColwiseSum<paddle::platform::CPUDeviceContext, double>;
template struct RowwiseSum<paddle::platform::CPUDeviceContext, double>;
template struct RowwiseMean<paddle::platform::CPUDeviceContext, double>;

// Only ColwiseSum is additionally provided for integer element types.
template struct ColwiseSum<paddle::platform::CPUDeviceContext, int>;
template struct ColwiseSum<paddle::platform::CPUDeviceContext, int64_t>;

// CPU ElementwiseAddTo: accumulates `src` into `*dst` element-wise. Both tensors
// are flattened to 1-D Eigen vectors, and the sum is evaluated on ctx's Eigen
// device.
// NOTE(review): assumes src and *dst have the same numel(); not checked here.
template <typename T>
struct ElementwiseAddTo<paddle::platform::CPUDeviceContext, T> {
  void operator()(paddle::platform::CPUDeviceContext* ctx,
                  const paddle::framework::Tensor& src,
                  paddle::framework::Tensor* dst) {
    auto src_vec = paddle::framework::EigenVector<T>::Flatten(src);
    auto dst_vec = paddle::framework::EigenVector<T>::Flatten(*dst);
    auto& eigen_dev = *ctx->eigen_device();
    dst_vec.device(eigen_dev) = dst_vec + src_vec;
  }
};

// Only the float16 CPU variant is instantiated here.
template struct ElementwiseAddTo<paddle::platform::CPUDeviceContext,
                                 paddle::platform::float16>;

}  // namespace funcs
}  // namespace pten