/** * TensorExpression.h * * Author: hedaoyuan (hedaoyuan@baidu.com) * Created on: 2016-06-06 * * Copyright (c) Baidu.com, Inc. All Rights Reserved * */ #pragma once #include #include #include "paddle/utils/TypeDefs.h" #include "paddle/utils/Logging.h" #include "hl_tensor_ops.h" namespace paddle { template class TensorConstant; template class TensorUnaryOp; template< class OP, typename LhsType, typename RhsType, class T> class TensorBinaryOp; template< typename ExprType1, typename ExprType2, typename ExprType3, class T> class TensorTernaryOp; /** * \brief Tensor base class. * * This is the base class of all Tensor and Expression class. */ template class TensorExpression { public: /** * Element wise unary expression. */ template const TensorUnaryOp unaryExpression(const UnaryOp& op) const { return TensorUnaryOp(op, derived()); } const TensorUnaryOp, const Derived, T> operator+(T p) const { return unaryExpression(hppl::unary::add_scale(p)); } const TensorUnaryOp, const Derived, T> operator-(T p) const { return unaryExpression(hppl::unary::sub_scale(p)); } const TensorUnaryOp, const Derived, T> operator*(T p) const { return unaryExpression(hppl::unary::mul_scale(p)); } const TensorUnaryOp, const Derived, T> operator/(T p) const { return unaryExpression(hppl::unary::div_scale(p)); } const TensorUnaryOp, const Derived, T> operator-() const { return unaryExpression(hppl::unary::neg()); } const TensorUnaryOp, const Derived, T> exp() const { return unaryExpression(hppl::unary::exp_op()); } const TensorUnaryOp, const Derived, T> log() const { return unaryExpression(hppl::unary::log_op()); } const TensorUnaryOp, const Derived, T> sqrt() const { return unaryExpression(hppl::unary::sqrt_op()); } const TensorUnaryOp, const Derived, T> square() const { return unaryExpression(hppl::unary::square()); } const TensorUnaryOp, const Derived, T> reciprocal() const { return unaryExpression(hppl::unary::reciprocal()); } const TensorUnaryOp, const Derived, T> abs() const { return unaryExpression(hppl::unary::abs()); } const TensorUnaryOp, const Derived, T> sign() const { return unaryExpression(hppl::unary::sign()); } const TensorUnaryOp, const Derived, T> pow(T p) const { return unaryExpression(hppl::unary::pow_op(p)); } const TensorUnaryOp, const Derived, T> min(T p) const { return unaryExpression(hppl::unary::min(p)); } const TensorUnaryOp, const Derived, T> max(T p) const { return unaryExpression(hppl::unary::max(p)); } const TensorUnaryOp, const Derived, T> operator==(T p) const { return unaryExpression(hppl::unary::cmp_eq(p)); } const TensorUnaryOp, const Derived, T> operator!=(T p) const { return unaryExpression(hppl::unary::cmp_ne(p)); } const TensorUnaryOp, const Derived, T> operator<=(T p) const { return unaryExpression(hppl::unary::cmp_le(p)); } const TensorUnaryOp, const Derived, T> operator<(T p) const { return unaryExpression(hppl::unary::cmp_lt(p)); } const TensorUnaryOp, const Derived, T> operator>=(T p) const { return unaryExpression(hppl::unary::cmp_ge(p)); } const TensorUnaryOp, const Derived, T> operator>(T p) const { return unaryExpression(hppl::unary::cmp_gt(p)); } const TensorUnaryOp, const Derived, T> operator&&(T p) const { return unaryExpression(hppl::unary::and_op(p)); } const TensorUnaryOp, const Derived, T> operator||(T p) const { return unaryExpression(hppl::unary::or_op(p)); } /** * Element wise binary expression. */ template const TensorBinaryOp binaryExpression(const BinaryOp& op, const ExpressionType& expr) const { return TensorBinaryOp( op, derived(), expr); } template const TensorBinaryOp< hppl::binary::cmp_eq, const Derived, const ExpressionType, T> operator==(const ExpressionType& expr) const { return binaryExpression(hppl::binary::cmp_eq(), expr); } template const TensorBinaryOp< hppl::binary::cmp_ne, const Derived, const ExpressionType, T> operator!=(const ExpressionType& expr) const { return binaryExpression(hppl::binary::cmp_ne(), expr); } template const TensorBinaryOp< hppl::binary::cmp_le, const Derived, const ExpressionType, T> operator<=(const ExpressionType& expr) const { return binaryExpression(hppl::binary::cmp_le(), expr); } template const TensorBinaryOp< hppl::binary::cmp_lt, const Derived, const ExpressionType, T> operator<(const ExpressionType& expr) const { return binaryExpression(hppl::binary::cmp_lt(), expr); } template const TensorBinaryOp< hppl::binary::cmp_ge, const Derived, const ExpressionType, T> operator>=(const ExpressionType& expr) const { return binaryExpression(hppl::binary::cmp_ge(), expr); } template const TensorBinaryOp< hppl::binary::cmp_gt, const Derived, const ExpressionType, T> operator>(const ExpressionType& expr) const { return binaryExpression(hppl::binary::cmp_gt(), expr); } template const TensorBinaryOp< hppl::binary::and_op, const Derived, const ExpressionType, T> operator&&(const ExpressionType& expr) const { return binaryExpression(hppl::binary::and_op(), expr); } template const TensorBinaryOp< hppl::binary::or_op, const Derived, const ExpressionType, T> operator||(const ExpressionType& expr) const { return binaryExpression(hppl::binary::or_op(), expr); } template const TensorBinaryOp< hppl::binary::add, const Derived, const ExpressionType, T> operator+(const ExpressionType& expr) const { return binaryExpression(hppl::binary::add(), expr); } template const TensorBinaryOp< hppl::binary::sub, const Derived, const ExpressionType, T> operator-(const ExpressionType& expr) const { return binaryExpression(hppl::binary::sub(), expr); } template const TensorBinaryOp< hppl::binary::mul, const Derived, const ExpressionType, T> operator*(const ExpressionType& expr) const { return binaryExpression(hppl::binary::mul(), expr); } template const TensorBinaryOp< hppl::binary::div, const Derived, const ExpressionType, T> operator/(const ExpressionType& expr) const { return binaryExpression(hppl::binary::div(), expr); } template const TensorBinaryOp< hppl::binary::min, const Derived, const ExpressionType, T> min(const ExpressionType& expr) const { return binaryExpression(hppl::binary::min(), expr); } template const TensorBinaryOp< hppl::binary::max, const Derived, const ExpressionType, T> max(const ExpressionType& expr) const { return binaryExpression(hppl::binary::max(), expr); } /** * Element wise ternary expression. * * ternary conditional operator(?: operator). * The conditional expression returns one of two values depending on * the result of derived expression. * If derived expression evaluates to true, then expression1 is evaluated. * If derived expression evaluates to false, then expression2 is evaluated. */ template const TensorTernaryOp condition(const ExprType1& expr1, const ExprType2& expr2) const { return TensorTernaryOp (derived(), expr1, expr2); } template const TensorTernaryOp< const Derived, const TensorConstant, const Derived, T>, const ExprType, T> condition(T p, const ExprType& expr) const { return condition(constant(p), expr); } template const TensorTernaryOp< const Derived, const ExprType, const TensorConstant, const Derived, T>, T> condition(const ExprType& expr, T p) const { return condition(expr, constant(p)); } const TensorTernaryOp< const Derived, const TensorConstant, const Derived, T>, const TensorConstant, const Derived, T>, T> condition(T p1, T p2) const { return condition(constant(p1), constant(p2)); } const TensorConstant, const Derived, T> constant(T p) const { return TensorConstant, const Derived, T> (hppl::unary::constant(p), derived()); } protected: const Derived& derived() const { return *static_cast(this); } }; /** * \brief Unary Operator Expression */ template class TensorUnaryOp : public TensorExpression, T> { public: explicit TensorUnaryOp(const OP op, const ExprType& expr) : op_(op), expr_(expr) {} const OP op_; const ExprType expr_; }; /** * \brief Binary Operator Expression */ template class TensorBinaryOp : public TensorExpression, T> { public: explicit TensorBinaryOp(const OP op, const LhsType& lhs, const RhsType& rhs) : op_(op), lhs_(lhs), rhs_(rhs) {} const OP op_; const LhsType lhs_; const RhsType rhs_; }; /** * \brief Ternary Operator Expression */ template class TensorTernaryOp : public TensorExpression< TensorTernaryOp, T> { public: explicit TensorTernaryOp( const ExprType1& expr1, const ExprType2& expr2, const ExprType3& expr3) : expr1_(expr1), expr2_(expr2), expr3_(expr3) {} const ExprType1 expr1_; const ExprType2 expr2_; const ExprType3 expr3_; }; /** * \brief Constant Expression */ template class TensorConstant : public TensorExpression, T> { public: explicit TensorConstant(const OP op, const ExprType& expr) : op_(op), expr_(expr) {} const OP op_; const ExprType expr_; }; /** * \brief operator+ overload * \return a unary operator expression */ template const TensorUnaryOp, const Derived, T> operator+(T p, const TensorExpression& expr) { return expr + p; } /** * \brief operator* overload * \return a unary operator expression */ template const TensorUnaryOp, const Derived, T> operator*(T p, const TensorExpression& expr) { return expr * p; } } // namespace paddle #include "TensorApply.h" #include "TensorEvaluate.h"