/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <cstddef>
#include <stdint.h>
#include <type_traits>
#include "paddle/utils/TypeDefs.h"
#include "paddle/utils/Logging.h"
#include "hl_tensor_ops.h"

namespace paddle {

// Forward declarations of the expression-node templates. TensorExpression
// (below) names them in its return types before they are defined later in
// this header.
template <class OP, typename ExprType, class T>
class TensorConstant;
template <class OP, typename ExprType, class T>
class TensorUnaryOp;
template <class OP, typename LhsType, typename RhsType, class T>
class TensorBinaryOp;
template <typename ExprType1, typename ExprType2, typename ExprType3, class T>
class TensorTernaryOp;

// Declared here, defined elsewhere (not in this header). It is the return
// type of TensorExpression::lazyAssign().
template <typename LhsType, typename RhsType, class T>
class TensorAssignOp;

/**
 * \brief Tensor base class.
H
hedaoyuan 已提交
38
 *
H
hedaoyuan 已提交
39 40
 * This is the base class of all Tensor and Expression class.
 */
H
hedaoyuan 已提交
41
template <typename Derived, class T>
H
hedaoyuan 已提交
42 43 44 45 46
class TensorExpression {
public:
  /**
   * Element wise unary expression.
   */
H
hedaoyuan 已提交
47 48 49
  template <typename UnaryOp>
  const TensorUnaryOp<UnaryOp, const Derived, T> unaryExpression(
      const UnaryOp& op) const {
H
hedaoyuan 已提交
50 51 52
    return TensorUnaryOp<UnaryOp, const Derived, T>(op, derived());
  }

H
hedaoyuan 已提交
53 54
  const TensorUnaryOp<hppl::unary::add_scale<T>, const Derived, T> operator+(
      T p) const {
H
hedaoyuan 已提交
55 56 57
    return unaryExpression(hppl::unary::add_scale<T>(p));
  }

H
hedaoyuan 已提交
58 59
  const TensorUnaryOp<hppl::unary::sub_scale<T>, const Derived, T> operator-(
      T p) const {
H
hedaoyuan 已提交
60 61 62
    return unaryExpression(hppl::unary::sub_scale<T>(p));
  }

H
hedaoyuan 已提交
63 64
  const TensorUnaryOp<hppl::unary::mul_scale<T>, const Derived, T> operator*(
      T p) const {
H
hedaoyuan 已提交
65 66 67
    return unaryExpression(hppl::unary::mul_scale<T>(p));
  }

H
hedaoyuan 已提交
68 69
  const TensorUnaryOp<hppl::unary::div_scale<T>, const Derived, T> operator/(
      T p) const {
H
hedaoyuan 已提交
70 71 72
    return unaryExpression(hppl::unary::div_scale<T>(p));
  }

H
hedaoyuan 已提交
73
  const TensorUnaryOp<hppl::unary::neg<T>, const Derived, T> operator-() const {
H
hedaoyuan 已提交
74 75 76
    return unaryExpression(hppl::unary::neg<T>());
  }

H
hedaoyuan 已提交
77
  const TensorUnaryOp<hppl::unary::exp_op<T>, const Derived, T> exp() const {
H
hedaoyuan 已提交
78 79 80
    return unaryExpression(hppl::unary::exp_op<T>());
  }

H
hedaoyuan 已提交
81
  const TensorUnaryOp<hppl::unary::log_op<T>, const Derived, T> log() const {
H
hedaoyuan 已提交
82 83 84
    return unaryExpression(hppl::unary::log_op<T>());
  }

H
hedaoyuan 已提交
85
  const TensorUnaryOp<hppl::unary::sqrt_op<T>, const Derived, T> sqrt() const {
H
hedaoyuan 已提交
86 87 88
    return unaryExpression(hppl::unary::sqrt_op<T>());
  }

H
hedaoyuan 已提交
89
  const TensorUnaryOp<hppl::unary::square<T>, const Derived, T> square() const {
H
hedaoyuan 已提交
90 91 92
    return unaryExpression(hppl::unary::square<T>());
  }

H
hedaoyuan 已提交
93 94
  const TensorUnaryOp<hppl::unary::reciprocal<T>, const Derived, T> reciprocal()
      const {
H
hedaoyuan 已提交
95 96 97
    return unaryExpression(hppl::unary::reciprocal<T>());
  }

H
hedaoyuan 已提交
98
  const TensorUnaryOp<hppl::unary::abs<T>, const Derived, T> abs() const {
H
hedaoyuan 已提交
99 100 101
    return unaryExpression(hppl::unary::abs<T>());
  }

H
hedaoyuan 已提交
102
  const TensorUnaryOp<hppl::unary::sign<T>, const Derived, T> sign() const {
H
hedaoyuan 已提交
103 104 105
    return unaryExpression(hppl::unary::sign<T>());
  }

H
hedaoyuan 已提交
106
  const TensorUnaryOp<hppl::unary::pow_op<T>, const Derived, T> pow(T p) const {
H
hedaoyuan 已提交
107 108 109
    return unaryExpression(hppl::unary::pow_op<T>(p));
  }

H
hedaoyuan 已提交
110
  const TensorUnaryOp<hppl::unary::min<T>, const Derived, T> min(T p) const {
H
hedaoyuan 已提交
111 112 113
    return unaryExpression(hppl::unary::min<T>(p));
  }

H
hedaoyuan 已提交
114
  const TensorUnaryOp<hppl::unary::max<T>, const Derived, T> max(T p) const {
H
hedaoyuan 已提交
115 116 117
    return unaryExpression(hppl::unary::max<T>(p));
  }

H
hedaoyuan 已提交
118 119
  const TensorUnaryOp<hppl::unary::cmp_eq<T>, const Derived, T> operator==(
      T p) const {
H
hedaoyuan 已提交
120 121 122
    return unaryExpression(hppl::unary::cmp_eq<T>(p));
  }

H
hedaoyuan 已提交
123 124
  const TensorUnaryOp<hppl::unary::cmp_ne<T>, const Derived, T> operator!=(
      T p) const {
H
hedaoyuan 已提交
125 126 127
    return unaryExpression(hppl::unary::cmp_ne<T>(p));
  }

H
hedaoyuan 已提交
128 129
  const TensorUnaryOp<hppl::unary::cmp_le<T>, const Derived, T> operator<=(
      T p) const {
H
hedaoyuan 已提交
130 131 132
    return unaryExpression(hppl::unary::cmp_le<T>(p));
  }

H
hedaoyuan 已提交
133 134
  const TensorUnaryOp<hppl::unary::cmp_lt<T>, const Derived, T> operator<(
      T p) const {
H
hedaoyuan 已提交
135 136 137
    return unaryExpression(hppl::unary::cmp_lt<T>(p));
  }

H
hedaoyuan 已提交
138 139
  const TensorUnaryOp<hppl::unary::cmp_ge<T>, const Derived, T> operator>=(
      T p) const {
H
hedaoyuan 已提交
140 141 142
    return unaryExpression(hppl::unary::cmp_ge<T>(p));
  }

H
hedaoyuan 已提交
143 144
  const TensorUnaryOp<hppl::unary::cmp_gt<T>, const Derived, T> operator>(
      T p) const {
H
hedaoyuan 已提交
145 146 147
    return unaryExpression(hppl::unary::cmp_gt<T>(p));
  }

H
hedaoyuan 已提交
148 149
  const TensorUnaryOp<hppl::unary::and_op<T>, const Derived, T> operator&&(
      T p) const {
H
hedaoyuan 已提交
150 151 152
    return unaryExpression(hppl::unary::and_op<T>(p));
  }

H
hedaoyuan 已提交
153 154
  const TensorUnaryOp<hppl::unary::or_op<T>, const Derived, T> operator||(
      T p) const {
H
hedaoyuan 已提交
155 156 157 158 159 160
    return unaryExpression(hppl::unary::or_op<T>(p));
  }

  /**
   * Element wise binary expression.
   */
H
hedaoyuan 已提交
161
  template <typename BinaryOp, typename ExpressionType>
H
hedaoyuan 已提交
162 163 164
  const TensorBinaryOp<BinaryOp, const Derived, const ExpressionType, T>
  binaryExpression(const BinaryOp& op, const ExpressionType& expr) const {
    return TensorBinaryOp<BinaryOp, const Derived, const ExpressionType, T>(
H
hedaoyuan 已提交
165
        op, derived(), expr);
H
hedaoyuan 已提交
166 167
  }

H
hedaoyuan 已提交
168 169 170 171 172
  template <typename ExpressionType>
  const TensorBinaryOp<hppl::binary::cmp_eq<T>,
                       const Derived,
                       const ExpressionType,
                       T>
H
hedaoyuan 已提交
173 174 175 176
  operator==(const ExpressionType& expr) const {
    return binaryExpression(hppl::binary::cmp_eq<T>(), expr);
  }

H
hedaoyuan 已提交
177 178 179 180 181
  template <typename ExpressionType>
  const TensorBinaryOp<hppl::binary::cmp_ne<T>,
                       const Derived,
                       const ExpressionType,
                       T>
H
hedaoyuan 已提交
182 183 184 185
  operator!=(const ExpressionType& expr) const {
    return binaryExpression(hppl::binary::cmp_ne<T>(), expr);
  }

H
hedaoyuan 已提交
186 187 188 189 190
  template <typename ExpressionType>
  const TensorBinaryOp<hppl::binary::cmp_le<T>,
                       const Derived,
                       const ExpressionType,
                       T>
H
hedaoyuan 已提交
191 192 193 194
  operator<=(const ExpressionType& expr) const {
    return binaryExpression(hppl::binary::cmp_le<T>(), expr);
  }

H
hedaoyuan 已提交
195 196 197 198 199
  template <typename ExpressionType>
  const TensorBinaryOp<hppl::binary::cmp_lt<T>,
                       const Derived,
                       const ExpressionType,
                       T>
H
hedaoyuan 已提交
200 201 202 203
  operator<(const ExpressionType& expr) const {
    return binaryExpression(hppl::binary::cmp_lt<T>(), expr);
  }

H
hedaoyuan 已提交
204 205 206 207 208
  template <typename ExpressionType>
  const TensorBinaryOp<hppl::binary::cmp_ge<T>,
                       const Derived,
                       const ExpressionType,
                       T>
H
hedaoyuan 已提交
209 210 211 212
  operator>=(const ExpressionType& expr) const {
    return binaryExpression(hppl::binary::cmp_ge<T>(), expr);
  }

H
hedaoyuan 已提交
213 214 215 216 217
  template <typename ExpressionType>
  const TensorBinaryOp<hppl::binary::cmp_gt<T>,
                       const Derived,
                       const ExpressionType,
                       T>
H
hedaoyuan 已提交
218 219 220 221
  operator>(const ExpressionType& expr) const {
    return binaryExpression(hppl::binary::cmp_gt<T>(), expr);
  }

H
hedaoyuan 已提交
222 223 224 225 226
  template <typename ExpressionType>
  const TensorBinaryOp<hppl::binary::and_op<T>,
                       const Derived,
                       const ExpressionType,
                       T>
H
hedaoyuan 已提交
227 228 229 230
  operator&&(const ExpressionType& expr) const {
    return binaryExpression(hppl::binary::and_op<T>(), expr);
  }

H
hedaoyuan 已提交
231 232 233 234 235
  template <typename ExpressionType>
  const TensorBinaryOp<hppl::binary::or_op<T>,
                       const Derived,
                       const ExpressionType,
                       T>
H
hedaoyuan 已提交
236 237 238 239
  operator||(const ExpressionType& expr) const {
    return binaryExpression(hppl::binary::or_op<T>(), expr);
  }

H
hedaoyuan 已提交
240 241 242 243 244
  template <typename ExpressionType>
  const TensorBinaryOp<hppl::binary::add<T>,
                       const Derived,
                       const ExpressionType,
                       T>
H
hedaoyuan 已提交
245 246 247 248
  operator+(const ExpressionType& expr) const {
    return binaryExpression(hppl::binary::add<T>(), expr);
  }

H
hedaoyuan 已提交
249 250 251 252 253
  template <typename ExpressionType>
  const TensorBinaryOp<hppl::binary::sub<T>,
                       const Derived,
                       const ExpressionType,
                       T>
H
hedaoyuan 已提交
254 255 256 257
  operator-(const ExpressionType& expr) const {
    return binaryExpression(hppl::binary::sub<T>(), expr);
  }

H
hedaoyuan 已提交
258 259 260 261 262
  template <typename ExpressionType>
  const TensorBinaryOp<hppl::binary::mul<T>,
                       const Derived,
                       const ExpressionType,
                       T>
H
hedaoyuan 已提交
263 264 265 266
  operator*(const ExpressionType& expr) const {
    return binaryExpression(hppl::binary::mul<T>(), expr);
  }

H
hedaoyuan 已提交
267 268 269 270 271
  template <typename ExpressionType>
  const TensorBinaryOp<hppl::binary::div<T>,
                       const Derived,
                       const ExpressionType,
                       T>
H
hedaoyuan 已提交
272 273 274 275
  operator/(const ExpressionType& expr) const {
    return binaryExpression(hppl::binary::div<T>(), expr);
  }

H
hedaoyuan 已提交
276 277 278 279 280
  template <typename ExpressionType>
  const TensorBinaryOp<hppl::binary::min<T>,
                       const Derived,
                       const ExpressionType,
                       T>
H
hedaoyuan 已提交
281 282 283 284
  min(const ExpressionType& expr) const {
    return binaryExpression(hppl::binary::min<T>(), expr);
  }

H
hedaoyuan 已提交
285 286 287 288 289
  template <typename ExpressionType>
  const TensorBinaryOp<hppl::binary::max<T>,
                       const Derived,
                       const ExpressionType,
                       T>
H
hedaoyuan 已提交
290 291 292 293 294 295 296 297 298 299 300 301 302
  max(const ExpressionType& expr) const {
    return binaryExpression(hppl::binary::max<T>(), expr);
  }

  /**
   * Element wise ternary expression.
   *
   * ternary conditional operator(?: operator).
   * The conditional expression returns one of two values depending on
   * the result of derived expression.
   * If derived expression evaluates to true, then expression1 is evaluated.
   * If derived expression evaluates to false, then expression2 is evaluated.
   */
H
hedaoyuan 已提交
303
  template <typename ExprType1, typename ExprType2>
H
hedaoyuan 已提交
304 305
  const TensorTernaryOp<const Derived, const ExprType1, const ExprType2, T>
  condition(const ExprType1& expr1, const ExprType2& expr2) const {
H
hedaoyuan 已提交
306 307
    return TensorTernaryOp<const Derived, const ExprType1, const ExprType2, T>(
        derived(), expr1, expr2);
H
hedaoyuan 已提交
308 309
  }

H
hedaoyuan 已提交
310
  template <typename ExprType>
H
hedaoyuan 已提交
311
  const TensorTernaryOp<
H
hedaoyuan 已提交
312 313 314 315
      const Derived,
      const TensorConstant<hppl::unary::constant<T>, const Derived, T>,
      const ExprType,
      T>
H
hedaoyuan 已提交
316 317 318 319
  condition(T p, const ExprType& expr) const {
    return condition(constant(p), expr);
  }

H
hedaoyuan 已提交
320
  template <typename ExprType>
H
hedaoyuan 已提交
321
  const TensorTernaryOp<
H
hedaoyuan 已提交
322 323 324 325
      const Derived,
      const ExprType,
      const TensorConstant<hppl::unary::constant<T>, const Derived, T>,
      T>
H
hedaoyuan 已提交
326 327 328 329 330
  condition(const ExprType& expr, T p) const {
    return condition(expr, constant(p));
  }

  const TensorTernaryOp<
H
hedaoyuan 已提交
331 332 333 334
      const Derived,
      const TensorConstant<hppl::unary::constant<T>, const Derived, T>,
      const TensorConstant<hppl::unary::constant<T>, const Derived, T>,
      T>
H
hedaoyuan 已提交
335 336 337 338
  condition(T p1, T p2) const {
    return condition(constant(p1), constant(p2));
  }

H
hedaoyuan 已提交
339 340 341
  /**
   * return a TensorConstant. A TensorConstant object hold a constant value.
   */
H
hedaoyuan 已提交
342 343 344 345
  const TensorConstant<hppl::unary::constant<T>, const Derived, T> constant(
      T p) const {
    return TensorConstant<hppl::unary::constant<T>, const Derived, T>(
        hppl::unary::constant<T>(p), derived());
H
hedaoyuan 已提交
346 347
  }

H
hedaoyuan 已提交
348 349 350 351
  /**
   * return a TensorAssignOp, and use AssignEvaluate to evaluate one or more
   * TensorAssignOp objects.
   */
H
hedaoyuan 已提交
352 353 354 355
  template <typename ExpressionType>
  TensorAssignOp<Derived, ExpressionType, T> lazyAssign(
      const ExpressionType& expr) const {
    return TensorAssignOp<Derived, ExpressionType, T>(derived(), expr);
H
hedaoyuan 已提交
356 357
  }

H
hedaoyuan 已提交
358 359 360 361 362 363 364
protected:
  const Derived& derived() const { return *static_cast<const Derived*>(this); }
};

/**
 * \brief Unary Operator Expression
 */
H
hedaoyuan 已提交
365
template <class OP, typename ExprType, class T>
H
hedaoyuan 已提交
366 367 368 369
class TensorUnaryOp
    : public TensorExpression<TensorUnaryOp<OP, ExprType, T>, T> {
public:
  explicit TensorUnaryOp(const OP op, const ExprType& expr)
H
hedaoyuan 已提交
370
      : op_(op), expr_(expr) {}
H
hedaoyuan 已提交
371 372 373 374 375 376 377 378

  const OP op_;
  const ExprType expr_;
};

/**
 * \brief Binary Operator Expression
 */
H
hedaoyuan 已提交
379
template <class OP, typename LhsType, typename RhsType, class T>
H
hedaoyuan 已提交
380 381 382 383
class TensorBinaryOp
    : public TensorExpression<TensorBinaryOp<OP, LhsType, RhsType, T>, T> {
public:
  explicit TensorBinaryOp(const OP op, const LhsType& lhs, const RhsType& rhs)
H
hedaoyuan 已提交
384
      : op_(op), lhs_(lhs), rhs_(rhs) {}
H
hedaoyuan 已提交
385 386 387 388 389 390 391 392 393

  const OP op_;
  const LhsType lhs_;
  const RhsType rhs_;
};

/**
 * \brief Ternary Operator Expression
 */
H
hedaoyuan 已提交
394 395 396 397
template <typename ExprType1, typename ExprType2, typename ExprType3, class T>
class TensorTernaryOp : public TensorExpression<
                            TensorTernaryOp<ExprType1, ExprType2, ExprType3, T>,
                            T> {
H
hedaoyuan 已提交
398
public:
H
hedaoyuan 已提交
399 400 401 402
  explicit TensorTernaryOp(const ExprType1& expr1,
                           const ExprType2& expr2,
                           const ExprType3& expr3)
      : expr1_(expr1), expr2_(expr2), expr3_(expr3) {}
H
hedaoyuan 已提交
403 404 405 406 407 408 409 410 411

  const ExprType1 expr1_;
  const ExprType2 expr2_;
  const ExprType3 expr3_;
};

/**
 * \brief Constant Expression
 */
H
hedaoyuan 已提交
412
template <class OP, typename ExprType, class T>
H
hedaoyuan 已提交
413 414 415 416
class TensorConstant
    : public TensorExpression<TensorConstant<OP, ExprType, T>, T> {
public:
  explicit TensorConstant(const OP op, const ExprType& expr)
H
hedaoyuan 已提交
417
      : op_(op), expr_(expr) {}
H
hedaoyuan 已提交
418 419 420 421 422 423 424 425 426

  const OP op_;
  const ExprType expr_;
};

/**
 * \brief operator+ overload
 * \return a unary operator expression
 */
H
hedaoyuan 已提交
427 428 429
template <typename Derived, class T>
const TensorUnaryOp<hppl::unary::add_scale<T>, const Derived, T> operator+(
    T p, const TensorExpression<Derived, T>& expr) {
H
hedaoyuan 已提交
430 431 432 433 434 435 436
  return expr + p;
}

/**
 * \brief operator* overload
 * \return a unary operator expression
 */
H
hedaoyuan 已提交
437 438 439
template <typename Derived, class T>
const TensorUnaryOp<hppl::unary::mul_scale<T>, const Derived, T> operator*(
    T p, const TensorExpression<Derived, T>& expr) {
H
hedaoyuan 已提交
440 441 442 443 444 445 446
  return expr * p;
}

}  // namespace paddle

#include "TensorApply.h"
#include "TensorEvaluate.h"