/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #pragma once #include #include #include "paddle/utils/TypeDefs.h" #include "paddle/utils/Logging.h" #include "hl_tensor_ops.h" namespace paddle { template class TensorConstant; template class TensorUnaryOp; template class TensorBinaryOp; template class TensorTernaryOp; template class TensorAssignOp; /** * \brief Tensor base class. * * This is the base class of all Tensor and Expression class. */ template class TensorExpression { public: /** * Element wise unary expression. */ template const TensorUnaryOp unaryExpression( const UnaryOp& op) const { return TensorUnaryOp(op, derived()); } const TensorUnaryOp, const Derived, T> operator+( T p) const { return unaryExpression(hppl::unary::add_scale(p)); } const TensorUnaryOp, const Derived, T> operator-( T p) const { return unaryExpression(hppl::unary::sub_scale(p)); } const TensorUnaryOp, const Derived, T> operator*( T p) const { return unaryExpression(hppl::unary::mul_scale(p)); } const TensorUnaryOp, const Derived, T> operator/( T p) const { return unaryExpression(hppl::unary::div_scale(p)); } const TensorUnaryOp, const Derived, T> operator-() const { return unaryExpression(hppl::unary::neg()); } const TensorUnaryOp, const Derived, T> exp() const { return unaryExpression(hppl::unary::exp_op()); } const TensorUnaryOp, const Derived, T> log() const { return unaryExpression(hppl::unary::log_op()); } const TensorUnaryOp, const Derived, T> sqrt() const { return unaryExpression(hppl::unary::sqrt_op()); } const TensorUnaryOp, const Derived, T> square() const { return unaryExpression(hppl::unary::square()); } const TensorUnaryOp, const Derived, T> reciprocal() const { return unaryExpression(hppl::unary::reciprocal()); } const TensorUnaryOp, const Derived, T> abs() const { return unaryExpression(hppl::unary::abs()); } const TensorUnaryOp, const Derived, T> sign() const { return unaryExpression(hppl::unary::sign()); } const TensorUnaryOp, const Derived, T> pow(T p) const { return unaryExpression(hppl::unary::pow_op(p)); } const TensorUnaryOp, const Derived, T> min(T p) const { return unaryExpression(hppl::unary::min(p)); } const TensorUnaryOp, const Derived, T> max(T p) const { return unaryExpression(hppl::unary::max(p)); } const TensorUnaryOp, const Derived, T> operator==( T p) const { return unaryExpression(hppl::unary::cmp_eq(p)); } const TensorUnaryOp, const Derived, T> operator!=( T p) const { return unaryExpression(hppl::unary::cmp_ne(p)); } const TensorUnaryOp, const Derived, T> operator<=( T p) const { return unaryExpression(hppl::unary::cmp_le(p)); } const TensorUnaryOp, const Derived, T> operator<( T p) const { return unaryExpression(hppl::unary::cmp_lt(p)); } const TensorUnaryOp, const Derived, T> operator>=( T p) const { return unaryExpression(hppl::unary::cmp_ge(p)); } const TensorUnaryOp, const Derived, T> operator>( T p) const { return unaryExpression(hppl::unary::cmp_gt(p)); } const TensorUnaryOp, const Derived, T> operator&&( T p) const { return unaryExpression(hppl::unary::and_op(p)); } const TensorUnaryOp, const Derived, T> operator||( T p) const { return unaryExpression(hppl::unary::or_op(p)); } /** * Element wise binary expression. */ template const TensorBinaryOp binaryExpression(const BinaryOp& op, const ExpressionType& expr) const { return TensorBinaryOp( op, derived(), expr); } template const TensorBinaryOp, const Derived, const ExpressionType, T> operator==(const ExpressionType& expr) const { return binaryExpression(hppl::binary::cmp_eq(), expr); } template const TensorBinaryOp, const Derived, const ExpressionType, T> operator!=(const ExpressionType& expr) const { return binaryExpression(hppl::binary::cmp_ne(), expr); } template const TensorBinaryOp, const Derived, const ExpressionType, T> operator<=(const ExpressionType& expr) const { return binaryExpression(hppl::binary::cmp_le(), expr); } template const TensorBinaryOp, const Derived, const ExpressionType, T> operator<(const ExpressionType& expr) const { return binaryExpression(hppl::binary::cmp_lt(), expr); } template const TensorBinaryOp, const Derived, const ExpressionType, T> operator>=(const ExpressionType& expr) const { return binaryExpression(hppl::binary::cmp_ge(), expr); } template const TensorBinaryOp, const Derived, const ExpressionType, T> operator>(const ExpressionType& expr) const { return binaryExpression(hppl::binary::cmp_gt(), expr); } template const TensorBinaryOp, const Derived, const ExpressionType, T> operator&&(const ExpressionType& expr) const { return binaryExpression(hppl::binary::and_op(), expr); } template const TensorBinaryOp, const Derived, const ExpressionType, T> operator||(const ExpressionType& expr) const { return binaryExpression(hppl::binary::or_op(), expr); } template const TensorBinaryOp, const Derived, const ExpressionType, T> operator+(const ExpressionType& expr) const { return binaryExpression(hppl::binary::add(), expr); } template const TensorBinaryOp, const Derived, const ExpressionType, T> operator-(const ExpressionType& expr) const { return binaryExpression(hppl::binary::sub(), expr); } template const TensorBinaryOp, const Derived, const ExpressionType, T> operator*(const ExpressionType& expr) const { return binaryExpression(hppl::binary::mul(), expr); } template const TensorBinaryOp, const Derived, const ExpressionType, T> operator/(const ExpressionType& expr) const { return binaryExpression(hppl::binary::div(), expr); } template const TensorBinaryOp, const Derived, const ExpressionType, T> min(const ExpressionType& expr) const { return binaryExpression(hppl::binary::min(), expr); } template const TensorBinaryOp, const Derived, const ExpressionType, T> max(const ExpressionType& expr) const { return binaryExpression(hppl::binary::max(), expr); } /** * Element wise ternary expression. * * ternary conditional operator(?: operator). * The conditional expression returns one of two values depending on * the result of derived expression. * If derived expression evaluates to true, then expression1 is evaluated. * If derived expression evaluates to false, then expression2 is evaluated. */ template const TensorTernaryOp condition(const ExprType1& expr1, const ExprType2& expr2) const { return TensorTernaryOp( derived(), expr1, expr2); } template const TensorTernaryOp< const Derived, const TensorConstant, const Derived, T>, const ExprType, T> condition(T p, const ExprType& expr) const { return condition(constant(p), expr); } template const TensorTernaryOp< const Derived, const ExprType, const TensorConstant, const Derived, T>, T> condition(const ExprType& expr, T p) const { return condition(expr, constant(p)); } const TensorTernaryOp< const Derived, const TensorConstant, const Derived, T>, const TensorConstant, const Derived, T>, T> condition(T p1, T p2) const { return condition(constant(p1), constant(p2)); } /** * return a TensorConstant. A TensorConstant object hold a constant value. */ const TensorConstant, const Derived, T> constant( T p) const { return TensorConstant, const Derived, T>( hppl::unary::constant(p), derived()); } /** * return a TensorAssignOp, and use AssignEvaluate to evaluate one or more * TensorAssignOp objects. */ template TensorAssignOp lazyAssign( const ExpressionType& expr) const { return TensorAssignOp(derived(), expr); } protected: const Derived& derived() const { return *static_cast(this); } }; /** * \brief Unary Operator Expression */ template class TensorUnaryOp : public TensorExpression, T> { public: explicit TensorUnaryOp(const OP op, const ExprType& expr) : op_(op), expr_(expr) {} const OP op_; const ExprType expr_; }; /** * \brief Binary Operator Expression */ template class TensorBinaryOp : public TensorExpression, T> { public: explicit TensorBinaryOp(const OP op, const LhsType& lhs, const RhsType& rhs) : op_(op), lhs_(lhs), rhs_(rhs) {} const OP op_; const LhsType lhs_; const RhsType rhs_; }; /** * \brief Ternary Operator Expression */ template class TensorTernaryOp : public TensorExpression< TensorTernaryOp, T> { public: explicit TensorTernaryOp(const ExprType1& expr1, const ExprType2& expr2, const ExprType3& expr3) : expr1_(expr1), expr2_(expr2), expr3_(expr3) {} const ExprType1 expr1_; const ExprType2 expr2_; const ExprType3 expr3_; }; /** * \brief Constant Expression */ template class TensorConstant : public TensorExpression, T> { public: explicit TensorConstant(const OP op, const ExprType& expr) : op_(op), expr_(expr) {} const OP op_; const ExprType expr_; }; /** * \brief operator+ overload * \return a unary operator expression */ template const TensorUnaryOp, const Derived, T> operator+( T p, const TensorExpression& expr) { return expr + p; } /** * \brief operator* overload * \return a unary operator expression */ template const TensorUnaryOp, const Derived, T> operator*( T p, const TensorExpression& expr) { return expr * p; } } // namespace paddle #include "TensorApply.h" #include "TensorEvaluate.h"