/** * hl_tensor_ops.h * * Author: hedaoyuan (hedaoyuan@baidu.com) * Created on: 2016-06-06 * * Copyright (c) Baidu.com, Inc. All Rights Reserved * */ #ifndef HL_TENSOR_OPS_H_ #define HL_TENSOR_OPS_H_ #include #include "hl_matrix_type.cuh" namespace hppl { namespace unary { template class add_scale{ private: const T p; public: INLINE add_scale(const T s) : p(s) {} INLINE T operator()(const T a) const { return a + p; } }; template class sub_scale { private: const T p; public: INLINE sub_scale(const T s) : p(s) {} INLINE T operator()(const T a) const { return a - p; } }; template class mul_scale { private: const T p; public: INLINE mul_scale(const T s) : p(s) {} INLINE T operator()(const T a) const { return a * p; } }; template class div_scale { private: const T p; public: INLINE div_scale(const T s) : p(s) {} INLINE T operator()(const T a) const { return a / p; } }; template class neg { public: INLINE T operator()(const T a) const { return -a; } }; template class exp_op { public: INLINE T operator()(const T a) const { return std::exp(a); } }; template class log_op { public: INLINE T operator()(const T a) const { return std::log(a); } }; template class sqrt_op { public: INLINE T operator()(const T a) const { return std::sqrt(a); } }; template class square { public: INLINE T operator()(const T a) const { return a * a; } }; template class reciprocal { public: INLINE T operator()(const T a) const { return T(1) / a; } }; template class abs { public: INLINE T operator()(const T a) const { return a > 0 ? a : -a; } }; template class sign { public: INLINE T operator()(const T a) const { return (a > 0) - (a < 0); } }; template class min { private: const T p; public: INLINE min(const T s) : p(s) {} INLINE T operator()(const T a) const { return a > p ? p : a; } }; template class max { private: const T p; public: INLINE max(const T s) : p(s) {} INLINE T operator()(const T a) const { return a < p ? p : a; } }; template class pow_op { private: const T p; public: INLINE pow_op(const T s) : p(s) {} INLINE T operator()(const T a) const { return std::pow(a, p); } }; template class constant { private: const T p; public: INLINE constant(const T s) : p(s) {} INLINE T operator()(int i) const { return p; } INLINE T operator()(int i, int j) const { return p; } }; template class cmp_eq { private: const T p; public: INLINE cmp_eq(const T s) : p(s) {} INLINE bool operator()(const T a) const { return a == p; } }; template class cmp_ne { private: const T p; public: INLINE cmp_ne(const T s) : p(s) {} INLINE bool operator()(const T a) const { return a != p; } }; template class cmp_le { private: const T p; public: INLINE cmp_le(const T s) : p(s) {} INLINE bool operator()(const T a) const { return a <= p; } }; template class cmp_lt { private: const T p; public: INLINE cmp_lt(const T s) : p(s) {} INLINE bool operator()(const T a) const { return a < p; } }; template class cmp_ge { private: const T p; public: INLINE cmp_ge(const T s) : p(s) {} INLINE bool operator()(const T a) const { return a >= p; } }; template class cmp_gt { private: const T p; public: INLINE cmp_gt(const T s) : p(s) {} INLINE bool operator()(const T a) const { return a > p; } }; template class and_op { private: const T p; public: INLINE and_op(const T s) : p(s) {} INLINE bool operator()(const T a) const { return a && p; } }; template class or_op { private: const T p; public: INLINE or_op(const T s) : p(s) {} INLINE bool operator()(const T a) const { return a || p; } }; } // namespace unary namespace binary { template class add { public: INLINE T operator()(const T a, const T b) const { return a + b; } }; template class add_scale { private: const T p1; const T p2; public: INLINE add_scale(const T s1, const T s2) : p1(s1), p2(s2) {} INLINE T operator()(const T a, const T b) const { return p1 * a + p2 * b; } }; template class sub { public: INLINE T operator()(const T a, const T b) const { return a - b; } }; template class mul { public: INLINE T operator()(const T a, const T b) const { return a * b; } }; template class div { public: INLINE T operator()(const T a, const T b) const { return a / b; } }; template class cmp_eq { public: INLINE bool operator()(const T a, const T b) const { return a == b; } }; template class cmp_ne { public: INLINE bool operator()(const T a, const T b) const { return a != b; } }; template class cmp_le { public: INLINE bool operator()(const T a, const T b) const { return a <= b; } }; template class cmp_lt { public: INLINE bool operator()(const T a, const T b) const { return a < b; } }; template class cmp_ge { public: INLINE bool operator()(const T a, const T b) const { return a >= b; } }; template class cmp_gt { public: INLINE bool operator()(const T a, const T b) const { return a > b; } }; template class and_op { public: INLINE bool operator()(const T a, const T b) const { return a && b; } }; template class or_op { public: INLINE bool operator()(const T a, const T b) const { return a || b; } }; template class min { public: INLINE T operator()(const T a, const T b) const { return a > b ? b : a; } }; template class max { public: INLINE T operator()(const T a, const T b) const { return a < b ? b : a; } }; } // namespace binary } // namespace hppl #endif // HL_TENSOR_OPS_H_