Commit abdcb8e1 authored by H hedaoyuan

format some files

Parent d04c206f
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef HL_MATRIX_TYPE_CUH_
#define HL_MATRIX_TYPE_CUH_
......
@@ -21,194 +21,210 @@ limitations under the License. */
namespace hppl {
namespace unary {
template<class T>
class add_scale{
template <class T>
class add_scale {
private:
const T p;
public:
INLINE add_scale(const T s) : p(s) {}
INLINE T operator()(const T a) const { return a + p; }
};
template<class T>
template <class T>
class sub_scale {
private:
const T p;
public:
INLINE sub_scale(const T s) : p(s) {}
INLINE T operator()(const T a) const { return a - p; }
};
template<class T>
template <class T>
class mul_scale {
private:
const T p;
public:
INLINE mul_scale(const T s) : p(s) {}
INLINE T operator()(const T a) const { return a * p; }
};
template<class T>
template <class T>
class div_scale {
private:
const T p;
public:
INLINE div_scale(const T s) : p(s) {}
INLINE T operator()(const T a) const { return a / p; }
};
template<class T>
template <class T>
class neg {
public:
INLINE T operator()(const T a) const { return -a; }
};
template<class T>
template <class T>
class exp_op {
public:
INLINE T operator()(const T a) const { return std::exp(a); }
};
template<class T>
template <class T>
class log_op {
public:
INLINE T operator()(const T a) const { return std::log(a); }
};
template<class T>
template <class T>
class sqrt_op {
public:
INLINE T operator()(const T a) const { return std::sqrt(a); }
};
template<class T>
template <class T>
class square {
public:
INLINE T operator()(const T a) const { return a * a; }
};
template<class T>
template <class T>
class reciprocal {
public:
INLINE T operator()(const T a) const { return T(1) / a; }
};
template<class T>
template <class T>
class abs {
public:
INLINE T operator()(const T a) const { return a > 0 ? a : -a; }
};
template<class T>
template <class T>
class sign {
public:
INLINE T operator()(const T a) const { return (a > 0) - (a < 0); }
};
template<class T>
template <class T>
class min {
private:
const T p;
public:
INLINE min(const T s) : p(s) {}
INLINE T operator()(const T a) const { return a > p ? p : a; }
};
template<class T>
template <class T>
class max {
private:
const T p;
public:
INLINE max(const T s) : p(s) {}
INLINE T operator()(const T a) const { return a < p ? p : a; }
};
template<class T>
template <class T>
class pow_op {
private:
const T p;
public:
INLINE pow_op(const T s) : p(s) {}
INLINE T operator()(const T a) const { return std::pow(a, p); }
};
template<class T>
template <class T>
class constant {
private:
const T p;
public:
INLINE constant(const T s) : p(s) {}
INLINE T operator()(int i) const { return p; }
INLINE T operator()(int i, int j) const { return p; }
};
template<class T>
template <class T>
class cmp_eq {
private:
const T p;
public:
INLINE cmp_eq(const T s) : p(s) {}
INLINE bool operator()(const T a) const { return a == p; }
};
template<class T>
template <class T>
class cmp_ne {
private:
const T p;
public:
INLINE cmp_ne(const T s) : p(s) {}
INLINE bool operator()(const T a) const { return a != p; }
};
template<class T>
template <class T>
class cmp_le {
private:
const T p;
public:
INLINE cmp_le(const T s) : p(s) {}
INLINE bool operator()(const T a) const { return a <= p; }
};
template<class T>
template <class T>
class cmp_lt {
private:
const T p;
public:
INLINE cmp_lt(const T s) : p(s) {}
INLINE bool operator()(const T a) const { return a < p; }
};
template<class T>
template <class T>
class cmp_ge {
private:
const T p;
public:
INLINE cmp_ge(const T s) : p(s) {}
INLINE bool operator()(const T a) const { return a >= p; }
};
template<class T>
template <class T>
class cmp_gt {
private:
const T p;
public:
INLINE cmp_gt(const T s) : p(s) {}
INLINE bool operator()(const T a) const { return a > p; }
};
template<class T>
template <class T>
class and_op {
private:
const T p;
public:
INLINE and_op(const T s) : p(s) {}
INLINE bool operator()(const T a) const { return a && p; }
};
template<class T>
template <class T>
class or_op {
private:
const T p;
public:
INLINE or_op(const T s) : p(s) {}
INLINE bool operator()(const T a) const { return a || p; }
@@ -217,97 +233,96 @@ public:
} // namespace unary
namespace binary {
template<class T>
template <class T>
class add {
public:
INLINE T operator()(const T a, const T b) const { return a + b; }
};
template<class T>
template <class T>
class add_scale {
private:
const T p1;
const T p2;
public:
INLINE add_scale(const T s1, const T s2) : p1(s1), p2(s2) {}
INLINE T operator()(const T a, const T b) const {
return p1 * a + p2 * b;
}
INLINE T operator()(const T a, const T b) const { return p1 * a + p2 * b; }
};
template<class T>
template <class T>
class sub {
public:
INLINE T operator()(const T a, const T b) const { return a - b; }
};
template<class T>
template <class T>
class mul {
public:
INLINE T operator()(const T a, const T b) const { return a * b; }
};
template<class T>
template <class T>
class div {
public:
INLINE T operator()(const T a, const T b) const { return a / b; }
};
template<class T>
template <class T>
class cmp_eq {
public:
INLINE bool operator()(const T a, const T b) const { return a == b; }
};
template<class T>
template <class T>
class cmp_ne {
public:
INLINE bool operator()(const T a, const T b) const { return a != b; }
};
template<class T>
template <class T>
class cmp_le {
public:
INLINE bool operator()(const T a, const T b) const { return a <= b; }
};
template<class T>
template <class T>
class cmp_lt {
public:
INLINE bool operator()(const T a, const T b) const { return a < b; }
};
template<class T>
template <class T>
class cmp_ge {
public:
INLINE bool operator()(const T a, const T b) const { return a >= b; }
};
template<class T>
template <class T>
class cmp_gt {
public:
INLINE bool operator()(const T a, const T b) const { return a > b; }
};
template<class T>
template <class T>
class and_op {
public:
INLINE bool operator()(const T a, const T b) const { return a && b; }
};
template<class T>
template <class T>
class or_op {
public:
INLINE bool operator()(const T a, const T b) const { return a || b; }
};
template<class T>
template <class T>
class min {
public:
INLINE T operator()(const T a, const T b) const { return a > b ? b : a; }
};
template<class T>
template <class T>
class max {
public:
INLINE T operator()(const T a, const T b) const { return a < b ? b : a; }
@@ -317,4 +332,3 @@ public:
} // namespace hppl
#endif // HL_TENSOR_OPS_H_
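These functors are ordinary function objects, so their behavior can be checked outside the framework. A minimal sketch, assuming hl_tensor_ops.h is on the include path and INLINE expands to plain inline in a CPU-only build:

#include <algorithm>
#include <cstdio>
#include <vector>
#include "hl_tensor_ops.h"

int main() {
  std::vector<float> v = {1.0f, -2.0f, 3.0f};
  // scale every element by 2, then take the elementwise absolute value
  std::transform(v.begin(), v.end(), v.begin(), hppl::unary::mul_scale<float>(2.0f));
  std::transform(v.begin(), v.end(), v.begin(), hppl::unary::abs<float>());
  for (float x : v) printf("%g ", x);  // prints: 2 4 6
  return 0;
}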
@@ -19,25 +19,20 @@ namespace paddle {
/**
* \brief The tensor evaluator classes.
*/
template<typename Derived, class T>
template <typename Derived, class T>
class TensorApply {
public:
explicit INLINE TensorApply(const Derived& p)
: data_(p.data_), stride_(p.stride_),
height_(p.height_), width_(p.width_), useGpu_(p.useGpu_) {}
: data_(p.data_),
stride_(p.stride_),
height_(p.height_),
width_(p.width_),
useGpu_(p.useGpu_) {}
INLINE T apply(int i, int j) const {
return data_[i * stride_ + j];
}
INLINE T apply(int index) const {
return data_[index];
}
INLINE T& applyRef(int i, int j) {
return data_[i * stride_ + j];
}
INLINE T& applyRef(int index) {
return data_[index];
}
INLINE T apply(int i, int j) const { return data_[i * stride_ + j]; }
INLINE T apply(int index) const { return data_[index]; }
INLINE T& applyRef(int i, int j) { return data_[i * stride_ + j]; }
INLINE T& applyRef(int index) { return data_[index]; }
INLINE size_t getWidth() const { return width_; }
INLINE size_t getHeight() const { return height_; }
@@ -53,22 +48,20 @@ public:
/**
* \brief The tensor evaluator classes.
*
* evaluator for rvalues
*/
template<typename Derived, class T>
template <typename Derived, class T>
class TensorApply<const Derived, T> {
public:
explicit INLINE TensorApply(const Derived& p)
: data_(p.data_), stride_(p.stride_),
height_(p.height_), width_(p.width_), useGpu_(p.useGpu_) {}
: data_(p.data_),
stride_(p.stride_),
height_(p.height_),
width_(p.width_),
useGpu_(p.useGpu_) {}
INLINE T apply(int i, int j) const {
return data_[i * stride_ + j];
}
INLINE T apply(int index) const {
return data_[index];
}
INLINE T apply(int i, int j) const { return data_[i * stride_ + j]; }
INLINE T apply(int index) const { return data_[index]; }
INLINE size_t getWidth() const { return width_; }
INLINE size_t getHeight() const { return height_; }
@@ -82,18 +75,14 @@ public:
bool useGpu_;
};
template<typename Derived, class T>
template <typename Derived, class T>
class TensorApply<const TensorExpression<Derived, T>, T> {
public:
explicit TensorApply(const TensorExpression<Derived, T>& expr)
: expr_(expr.derived()) {}
INLINE T apply(int i, int j) const {
return expr_.apply(i, j);
}
INLINE T apply(int index) const {
return expr_.apply(index);
}
INLINE T apply(int i, int j) const { return expr_.apply(i, j); }
INLINE T apply(int index) const { return expr_.apply(index); }
INLINE size_t getWidth() const { return expr_.getWidth(); }
INLINE size_t getHeight() const { return expr_.getHeight(); }
@@ -106,18 +95,14 @@ public:
/**
* \brief The unary expression evaluator classes.
*/
template<class OP, typename ArgType, class T>
template <class OP, typename ArgType, class T>
class TensorApply<const TensorUnaryOp<OP, ArgType, T>, T> {
public:
explicit INLINE TensorApply(const TensorUnaryOp<OP, ArgType, T>& expr)
: op_(expr.op_), expr_(expr.expr_) {}
INLINE T apply(int i, int j) const {
return op_(expr_.apply(i, j));
}
INLINE T apply(int index) const {
return op_(expr_.apply(index));
}
INLINE T apply(int i, int j) const { return op_(expr_.apply(i, j)); }
INLINE T apply(int index) const { return op_(expr_.apply(index)); }
INLINE size_t getWidth() const { return expr_.getWidth(); }
INLINE size_t getHeight() const { return expr_.getHeight(); }
@@ -131,17 +116,17 @@ public:
/**
* \brief The binary expression evaluator classes.
*/
template<class OP, typename LhsType, typename RhsType, class T>
template <class OP, typename LhsType, typename RhsType, class T>
class TensorApply<const TensorBinaryOp<OP, LhsType, RhsType, T>, T> {
public:
explicit INLINE TensorApply(
const TensorBinaryOp<OP, LhsType, RhsType, T>& expr)
: op_(expr.op_), lhs_(expr.lhs_), rhs_(expr.rhs_) {
#ifndef __CUDA_ARCH__
CHECK_EQ(lhs_.getWidth(), rhs_.getWidth());
CHECK_EQ(lhs_.getHeight(), rhs_.getHeight());
CHECK_EQ(lhs_.useGpu(), rhs_.useGpu());
#endif
}
INLINE T apply(int i, int j) const {
@@ -166,20 +151,20 @@ public:
/**
* \brief The ternary expression evaluator classes.
*/
template<typename ArgType1, typename ArgType2, typename ArgType3, class T>
template <typename ArgType1, typename ArgType2, typename ArgType3, class T>
class TensorApply<const TensorTernaryOp<ArgType1, ArgType2, ArgType3, T>, T> {
public:
explicit INLINE TensorApply(
const TensorTernaryOp<ArgType1, ArgType2, ArgType3, T>& expr)
: expr1_(expr.expr1_), expr2_(expr.expr2_), expr3_(expr.expr3_) {
#ifndef __CUDA_ARCH__
CHECK_EQ(expr1_.getWidth(), expr2_.getWidth());
CHECK_EQ(expr1_.getWidth(), expr3_.getWidth());
CHECK_EQ(expr1_.getHeight(), expr2_.getHeight());
CHECK_EQ(expr1_.getHeight(), expr3_.getHeight());
CHECK_EQ(expr1_.useGpu(), expr2_.useGpu());
CHECK_EQ(expr1_.useGpu(), expr3_.useGpu());
#endif
}
INLINE T apply(int i, int j) const {
@@ -192,8 +177,8 @@ public:
INLINE size_t getWidth() const { return expr1_.getWidth(); }
INLINE size_t getHeight() const { return expr1_.getHeight(); }
INLINE bool isContiguous() const {
return expr1_.isContiguous() &&
expr2_.isContiguous() && expr3_.isContiguous();
return expr1_.isContiguous() && expr2_.isContiguous() &&
expr3_.isContiguous();
}
INLINE bool useGpu() const { return expr1_.useGpu(); }
@@ -205,18 +190,14 @@ public:
/**
* \brief The const expression evaluator classes.
*/
template<class OP, typename ArgType, class T>
template <class OP, typename ArgType, class T>
class TensorApply<const TensorConstant<OP, ArgType, T>, T> {
public:
explicit INLINE TensorApply(const TensorConstant<OP, ArgType, T>& expr)
: op_(expr.op_), expr_(expr.expr_) {}
INLINE T apply(int i, int j) const {
return op_(i, j);
}
INLINE T apply(int index) const {
return op_(index);
}
INLINE T apply(int i, int j) const { return op_(i, j); }
INLINE T apply(int index) const { return op_(index); }
INLINE size_t getWidth() const { return expr_.getWidth(); }
INLINE size_t getHeight() const { return expr_.getHeight(); }
......
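Every TensorApply specialization above exposes the same apply(i) / apply(i, j) interface, so evaluating a composed expression is just a recursion through its nodes. A standalone analogue of the pattern (illustrative names, not the PaddlePaddle API):

#include <cstdio>

// leaf evaluator: reads straight from memory, like TensorApply<Derived, T>
struct Vec {
  const float* data;
  float apply(int i) const { return data[i]; }
};

// node evaluator: defers to its children, like the TensorBinaryOp evaluator
template <class Lhs, class Rhs>
struct AddEval {
  Lhs lhs;
  Rhs rhs;
  float apply(int i) const { return lhs.apply(i) + rhs.apply(i); }
};

int main() {
  float a[3] = {1, 2, 3}, b[3] = {10, 20, 30};
  AddEval<Vec, Vec> sum{Vec{a}, Vec{b}};
  for (int i = 0; i < 3; i++) printf("%g ", sum.apply(i));  // prints: 11 22 33
  return 0;
}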
@@ -21,18 +21,18 @@ namespace paddle {
/**
* \brief Tensor assign expression (returned by lazyAssign,
* and evaluated by AssignEvaluate)
*/
template<typename LhsType, typename RhsType, class T>
template <typename LhsType, typename RhsType, class T>
class TensorAssignOp {
public:
explicit TensorAssignOp(const LhsType& lhs, const RhsType& rhs)
: lhs_(lhs), rhs_(rhs) {
#ifndef __CUDA_ARCH__
CHECK_EQ(lhs_.getWidth(), rhs_.getWidth());
CHECK_EQ(lhs_.getHeight(), rhs_.getHeight());
CHECK_EQ(lhs_.useGpu(), rhs_.useGpu());
#endif
}
INLINE void apply(const int i, const int j) {
@@ -55,19 +55,22 @@ private:
};
template <typename Assign, typename... AssignOp>
void AssignCpuEvaluate(int height, int width, bool isContiguous,
Assign&& assign, AssignOp&& ... args) {
void AssignCpuEvaluate(int height,
int width,
bool isContiguous,
Assign&& assign,
AssignOp&&... args) {
if (isContiguous) {
int size = height * width;
for (int index = 0; index < size; index++) {
assign.apply(index);
__attribute__((unused)) int dummy[] = { (((args)).apply(index), 0)... };
__attribute__((unused)) int dummy[] = {(((args)).apply(index), 0)...};
}
} else {
for (int i = 0; i < height; i++) {
for (int j = 0; j < width; j++) {
assign.apply(i, j);
__attribute__((unused)) int dummy[] = { (((args)).apply(i, j), 0)... };
__attribute__((unused)) int dummy[] = {(((args)).apply(i, j), 0)...};
}
}
}
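The dummy-array line above is the C++11 idiom for expanding a parameter pack for its side effects: each initializer runs one apply call through the comma operator and contributes a 0 to an array that is never read. A minimal sketch of the same idiom (printAll is a hypothetical helper):

#include <cstdio>

template <typename... Ts>
void printAll(Ts... xs) {
  // one (printf(...), 0) element per pack argument, evaluated left to right
  __attribute__((unused)) int dummy[] = {(printf("%d ", xs), 0)...};
}

int main() {
  printAll(1, 2, 3);  // prints: 1 2 3
  return 0;
}

Note that an empty pack would yield a zero-length array, which is ill-formed; this is presumably why AssignEvaluate below requires at least one assignment expression.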
@@ -75,25 +78,27 @@ void AssignCpuEvaluate(int height, int width, bool isContiguous,
#ifdef __NVCC__
template <typename Assign, typename... AssignOp>
__global__
void AssignGpuEvaluate1(const int border, Assign assign, AssignOp ... args) {
__global__ void AssignGpuEvaluate1(const int border,
Assign assign,
AssignOp... args) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < border) {
assign.apply(idx);
__attribute__((unused)) int dummy[] = { (((args)).apply(idx), 0)... };
__attribute__((unused)) int dummy[] = {(((args)).apply(idx), 0)...};
}
}
template <typename Assign, typename... AssignOp>
__global__
void AssignGpuEvaluate2(const int height, const int width,
Assign assign, AssignOp ... args) {
__global__ void AssignGpuEvaluate2(const int height,
const int width,
Assign assign,
AssignOp... args) {
const int colIdx = blockIdx.x * blockDim.x + threadIdx.x;
const int rowIdx = blockIdx.y * blockDim.y + threadIdx.y;
for (int i = rowIdx; i < height; i += gridDim.y * blockDim.y) {
for (int j = colIdx; j < width; j += gridDim.x * blockDim.x) {
assign.apply(i, j);
__attribute__((unused)) int dummy[] = { (((args)).apply(i, j), 0)... };
__attribute__((unused)) int dummy[] = {(((args)).apply(i, j), 0)...};
}
}
}
@@ -105,23 +110,23 @@ void AssignGpuEvaluate2(const int height, const int width,
* \note At least one assignment expression is required
*/
template <typename Assign, typename... AssignOp>
void AssignEvaluate(Assign&& assign, AssignOp&& ... args) {
void AssignEvaluate(Assign&& assign, AssignOp&&... args) {
const bool useGpu_ = assign.useGpu();
bool isContiguous_ = assign.isContiguous();
const size_t height = assign.getHeight();
const size_t width = assign.getWidth();
const int packSize = sizeof...(args);
const bool packUseGpu[] = { ((args)).useGpu()... };
const bool packIsContiguous[] = { ((args)).isContiguous()... };
const size_t packHeight[] = { ((args)).getHeight()... };
const size_t packWidth[] = { ((args)).getWidth()... };
const bool packUseGpu[] = {((args)).useGpu()...};
const bool packIsContiguous[] = {((args)).isContiguous()...};
const size_t packHeight[] = {((args)).getHeight()...};
const size_t packWidth[] = {((args)).getWidth()...};
for (int i = 0; i < packSize; i++) {
CHECK_EQ(useGpu_, packUseGpu[i]);
CHECK_EQ(height, packHeight[i]);
CHECK_EQ(width, packWidth[i]);
isContiguous_ = isContiguous_ && packIsContiguous[i];
}
if (useGpu_) {
@@ -130,8 +135,8 @@ void AssignEvaluate(Assign&& assign, AssignOp&& ... args) {
int size = height * width;
int blockSize = size <= 1024 ? size : 1024;
int gridSize = (size + 1024 - 1) / 1024;
AssignGpuEvaluate1
<<<gridSize, blockSize, 0, STREAM_DEFAULT>>>(size, assign, args...);
AssignGpuEvaluate1<<<gridSize, blockSize, 0, STREAM_DEFAULT>>>(
size, assign, args...);
} else {
int blockSizeY = std::min(32, (int)height);
int blockSizeX = (32 / blockSizeY) * 32;
@@ -139,8 +144,8 @@ void AssignEvaluate(Assign&& assign, AssignOp&& ... args) {
int gridSizeY = std::min(32, (int)(height + blockSizeY - 1) / blockSizeY);
dim3 threads(blockSizeX, blockSizeY);
dim3 grid(gridSizeX, gridSizeY);
AssignGpuEvaluate2
<<<grid, threads, 0, STREAM_DEFAULT>>>(height, width, assign, args...);
AssignGpuEvaluate2<<<grid, threads, 0, STREAM_DEFAULT>>>(
height, width, assign, args...);
}
CHECK_SYNC("AssignEvaluate failed");
@@ -151,4 +156,3 @@ void AssignEvaluate(Assign&& assign, AssignOp&& ... args) {
}
} // namespace paddle
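Taken together, lazyAssign builds a pending TensorAssignOp and AssignEvaluate then runs every pending assignment in a single pass (one CPU loop nest, or one CUDA kernel launch). A hedged usage sketch, assuming A, B, C, and D are same-shaped tensor objects deriving from TensorExpression:

auto assign1 = A.lazyAssign(B + C);     // pending: A = B + C
auto assign2 = D.lazyAssign(B.max(C));  // pending: D = max(B, C)
AssignEvaluate(assign1, assign2);       // both fused into one traversal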
@@ -23,7 +23,7 @@ namespace paddle {
/**
* \brief The tensor CPU evaluation API.
*/
template<class T, typename LeftType, typename RightType>
template <class T, typename LeftType, typename RightType>
inline void TensorCpuApply(LeftType& lhs, const RightType& rhs) {
TensorApply<LeftType, T> lhs_(lhs);
TensorApply<const RightType, T> rhs_(rhs);
@@ -48,16 +48,17 @@ inline void TensorCpuApply(LeftType& lhs, const RightType& rhs) {
}
#ifdef __NVCC__
template<typename LeftType, typename RightType>
__global__
void TensorElementWiseOp(LeftType lhs, RightType rhs, const int border) {
template <typename LeftType, typename RightType>
__global__ void TensorElementWiseOp(LeftType lhs,
RightType rhs,
const int border) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < border) {
lhs.applyRef(idx) = rhs.apply(idx);
}
}
template<typename LeftType, typename RightType>
template <typename LeftType, typename RightType>
__global__ void TensorElementWiseOp(LeftType lhs, RightType rhs) {
const int colIdx = blockIdx.x * blockDim.x + threadIdx.x;
const int rowIdx = blockIdx.y * blockDim.y + threadIdx.y;
@@ -71,7 +72,7 @@ __global__ void TensorElementWiseOp(LeftType lhs, RightType rhs) {
/**
* \brief The tensor GPU evaluation API.
*/
template<class T, typename LeftType, typename RightType>
template <class T, typename LeftType, typename RightType>
inline void TensorGpuApply(LeftType& lhs, const RightType& rhs) {
TensorApply<LeftType, T> lhs_(lhs);
TensorApply<const RightType, T> rhs_(rhs);
@@ -86,8 +87,8 @@ inline void TensorGpuApply(LeftType& lhs, const RightType& rhs) {
int size = dimM * dimN;
int blockSize = size <= 1024 ? size : 1024;
int gridSize = (size + 1024 - 1) / 1024;
TensorElementWiseOp
<<<gridSize, blockSize, 0, STREAM_DEFAULT>>>(lhs_, rhs_, size);
TensorElementWiseOp<<<gridSize, blockSize, 0, STREAM_DEFAULT>>>(
lhs_, rhs_, size);
} else {
int blockSizeY = std::min(32, dimM);
int blockSizeX = (32 / blockSizeY) * 32;
@@ -95,16 +96,14 @@ inline void TensorGpuApply(LeftType& lhs, const RightType& rhs) {
int gridSizeY = std::min(32, (dimM + blockSizeY - 1) / blockSizeY);
dim3 threads(blockSizeX, blockSizeY);
dim3 grid(gridSizeX, gridSizeY);
TensorElementWiseOp
<<<grid, threads, 0, STREAM_DEFAULT>>>(lhs_, rhs_);
TensorElementWiseOp<<<grid, threads, 0, STREAM_DEFAULT>>>(lhs_, rhs_);
}
CHECK_SYNC("TensorGpuApply failed");
}
#else
template<class T, typename LeftType, typename RightType>
inline void TensorGpuApply(LeftType& lhs, RightType& rhs) {
}
template <class T, typename LeftType, typename RightType>
inline void TensorGpuApply(LeftType& lhs, RightType& rhs) {}
#endif
} // namespace paddle
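The launch configuration in TensorGpuApply is plain integer arithmetic and can be rehearsed on the host. A sketch of the contiguous case for an assumed 100x500 tensor:

#include <cstdio>

int main() {
  int dimM = 100, dimN = 500;
  int size = dimM * dimN;                      // 50000 elements
  int blockSize = size <= 1024 ? size : 1024;  // 1024 threads per block
  int gridSize = (size + 1024 - 1) / 1024;     // ceil(50000 / 1024) = 49 blocks
  printf("grid=%d block=%d\n", gridSize, blockSize);
  return 0;
}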
@@ -21,254 +21,272 @@ limitations under the License. */
namespace paddle {
template<class OP, typename ExprType, class T> class TensorConstant;
template<class OP, typename ExprType, class T> class TensorUnaryOp;
template<
class OP, typename LhsType, typename RhsType, class T> class TensorBinaryOp;
template<
typename ExprType1,
typename ExprType2,
typename ExprType3,
class T> class TensorTernaryOp;
template<typename LhsType, typename RhsType, class T> class TensorAssignOp;
template <class OP, typename ExprType, class T>
class TensorConstant;
template <class OP, typename ExprType, class T>
class TensorUnaryOp;
template <class OP, typename LhsType, typename RhsType, class T>
class TensorBinaryOp;
template <typename ExprType1, typename ExprType2, typename ExprType3, class T>
class TensorTernaryOp;
template <typename LhsType, typename RhsType, class T>
class TensorAssignOp;
/**
* \brief Tensor base class.
*
* This is the base class of all Tensor and Expression classes.
*/
template<typename Derived, class T>
template <typename Derived, class T>
class TensorExpression {
public:
/**
* Element wise unary expression.
*/
template<typename UnaryOp>
const TensorUnaryOp<UnaryOp, const Derived, T>
unaryExpression(const UnaryOp& op) const {
template <typename UnaryOp>
const TensorUnaryOp<UnaryOp, const Derived, T> unaryExpression(
const UnaryOp& op) const {
return TensorUnaryOp<UnaryOp, const Derived, T>(op, derived());
}
const TensorUnaryOp<hppl::unary::add_scale<T>, const Derived, T>
operator+(T p) const {
const TensorUnaryOp<hppl::unary::add_scale<T>, const Derived, T> operator+(
T p) const {
return unaryExpression(hppl::unary::add_scale<T>(p));
}
const TensorUnaryOp<hppl::unary::sub_scale<T>, const Derived, T>
operator-(T p) const {
const TensorUnaryOp<hppl::unary::sub_scale<T>, const Derived, T> operator-(
T p) const {
return unaryExpression(hppl::unary::sub_scale<T>(p));
}
const TensorUnaryOp<hppl::unary::mul_scale<T>, const Derived, T>
operator*(T p) const {
const TensorUnaryOp<hppl::unary::mul_scale<T>, const Derived, T> operator*(
T p) const {
return unaryExpression(hppl::unary::mul_scale<T>(p));
}
const TensorUnaryOp<hppl::unary::div_scale<T>, const Derived, T>
operator/(T p) const {
const TensorUnaryOp<hppl::unary::div_scale<T>, const Derived, T> operator/(
T p) const {
return unaryExpression(hppl::unary::div_scale<T>(p));
}
const TensorUnaryOp<hppl::unary::neg<T>, const Derived, T>
operator-() const {
const TensorUnaryOp<hppl::unary::neg<T>, const Derived, T> operator-() const {
return unaryExpression(hppl::unary::neg<T>());
}
const TensorUnaryOp<hppl::unary::exp_op<T>, const Derived, T>
exp() const {
const TensorUnaryOp<hppl::unary::exp_op<T>, const Derived, T> exp() const {
return unaryExpression(hppl::unary::exp_op<T>());
}
const TensorUnaryOp<hppl::unary::log_op<T>, const Derived, T>
log() const {
const TensorUnaryOp<hppl::unary::log_op<T>, const Derived, T> log() const {
return unaryExpression(hppl::unary::log_op<T>());
}
const TensorUnaryOp<hppl::unary::sqrt_op<T>, const Derived, T>
sqrt() const {
const TensorUnaryOp<hppl::unary::sqrt_op<T>, const Derived, T> sqrt() const {
return unaryExpression(hppl::unary::sqrt_op<T>());
}
const TensorUnaryOp<hppl::unary::square<T>, const Derived, T>
square() const {
const TensorUnaryOp<hppl::unary::square<T>, const Derived, T> square() const {
return unaryExpression(hppl::unary::square<T>());
}
const TensorUnaryOp<hppl::unary::reciprocal<T>, const Derived, T>
reciprocal() const {
const TensorUnaryOp<hppl::unary::reciprocal<T>, const Derived, T> reciprocal()
const {
return unaryExpression(hppl::unary::reciprocal<T>());
}
const TensorUnaryOp<hppl::unary::abs<T>, const Derived, T>
abs() const {
const TensorUnaryOp<hppl::unary::abs<T>, const Derived, T> abs() const {
return unaryExpression(hppl::unary::abs<T>());
}
const TensorUnaryOp<hppl::unary::sign<T>, const Derived, T>
sign() const {
const TensorUnaryOp<hppl::unary::sign<T>, const Derived, T> sign() const {
return unaryExpression(hppl::unary::sign<T>());
}
const TensorUnaryOp<hppl::unary::pow_op<T>, const Derived, T>
pow(T p) const {
const TensorUnaryOp<hppl::unary::pow_op<T>, const Derived, T> pow(T p) const {
return unaryExpression(hppl::unary::pow_op<T>(p));
}
const TensorUnaryOp<hppl::unary::min<T>, const Derived, T>
min(T p) const {
const TensorUnaryOp<hppl::unary::min<T>, const Derived, T> min(T p) const {
return unaryExpression(hppl::unary::min<T>(p));
}
const TensorUnaryOp<hppl::unary::max<T>, const Derived, T>
max(T p) const {
const TensorUnaryOp<hppl::unary::max<T>, const Derived, T> max(T p) const {
return unaryExpression(hppl::unary::max<T>(p));
}
const TensorUnaryOp<hppl::unary::cmp_eq<T>, const Derived, T>
operator==(T p) const {
const TensorUnaryOp<hppl::unary::cmp_eq<T>, const Derived, T> operator==(
T p) const {
return unaryExpression(hppl::unary::cmp_eq<T>(p));
}
const TensorUnaryOp<hppl::unary::cmp_ne<T>, const Derived, T>
operator!=(T p) const {
const TensorUnaryOp<hppl::unary::cmp_ne<T>, const Derived, T> operator!=(
T p) const {
return unaryExpression(hppl::unary::cmp_ne<T>(p));
}
const TensorUnaryOp<hppl::unary::cmp_le<T>, const Derived, T>
operator<=(T p) const {
const TensorUnaryOp<hppl::unary::cmp_le<T>, const Derived, T> operator<=(
T p) const {
return unaryExpression(hppl::unary::cmp_le<T>(p));
}
const TensorUnaryOp<hppl::unary::cmp_lt<T>, const Derived, T>
operator<(T p) const {
const TensorUnaryOp<hppl::unary::cmp_lt<T>, const Derived, T> operator<(
T p) const {
return unaryExpression(hppl::unary::cmp_lt<T>(p));
}
const TensorUnaryOp<hppl::unary::cmp_ge<T>, const Derived, T>
operator>=(T p) const {
const TensorUnaryOp<hppl::unary::cmp_ge<T>, const Derived, T> operator>=(
T p) const {
return unaryExpression(hppl::unary::cmp_ge<T>(p));
}
const TensorUnaryOp<hppl::unary::cmp_gt<T>, const Derived, T>
operator>(T p) const {
const TensorUnaryOp<hppl::unary::cmp_gt<T>, const Derived, T> operator>(
T p) const {
return unaryExpression(hppl::unary::cmp_gt<T>(p));
}
const TensorUnaryOp<hppl::unary::and_op<T>, const Derived, T>
operator&&(T p) const {
const TensorUnaryOp<hppl::unary::and_op<T>, const Derived, T> operator&&(
T p) const {
return unaryExpression(hppl::unary::and_op<T>(p));
}
const TensorUnaryOp<hppl::unary::or_op<T>, const Derived, T>
operator||(T p) const {
const TensorUnaryOp<hppl::unary::or_op<T>, const Derived, T> operator||(
T p) const {
return unaryExpression(hppl::unary::or_op<T>(p));
}
/**
* Element wise binary expression.
*/
template<typename BinaryOp, typename ExpressionType>
template <typename BinaryOp, typename ExpressionType>
const TensorBinaryOp<BinaryOp, const Derived, const ExpressionType, T>
binaryExpression(const BinaryOp& op, const ExpressionType& expr) const {
return TensorBinaryOp<BinaryOp, const Derived, const ExpressionType, T>(
op, derived(), expr);
}
template<typename ExpressionType>
const TensorBinaryOp<
hppl::binary::cmp_eq<T>, const Derived, const ExpressionType, T>
template <typename ExpressionType>
const TensorBinaryOp<hppl::binary::cmp_eq<T>,
const Derived,
const ExpressionType,
T>
operator==(const ExpressionType& expr) const {
return binaryExpression(hppl::binary::cmp_eq<T>(), expr);
}
template<typename ExpressionType>
const TensorBinaryOp<
hppl::binary::cmp_ne<T>, const Derived, const ExpressionType, T>
template <typename ExpressionType>
const TensorBinaryOp<hppl::binary::cmp_ne<T>,
const Derived,
const ExpressionType,
T>
operator!=(const ExpressionType& expr) const {
return binaryExpression(hppl::binary::cmp_ne<T>(), expr);
}
template<typename ExpressionType>
const TensorBinaryOp<
hppl::binary::cmp_le<T>, const Derived, const ExpressionType, T>
template <typename ExpressionType>
const TensorBinaryOp<hppl::binary::cmp_le<T>,
const Derived,
const ExpressionType,
T>
operator<=(const ExpressionType& expr) const {
return binaryExpression(hppl::binary::cmp_le<T>(), expr);
}
template<typename ExpressionType>
const TensorBinaryOp<
hppl::binary::cmp_lt<T>, const Derived, const ExpressionType, T>
template <typename ExpressionType>
const TensorBinaryOp<hppl::binary::cmp_lt<T>,
const Derived,
const ExpressionType,
T>
operator<(const ExpressionType& expr) const {
return binaryExpression(hppl::binary::cmp_lt<T>(), expr);
}
template<typename ExpressionType>
const TensorBinaryOp<
hppl::binary::cmp_ge<T>, const Derived, const ExpressionType, T>
template <typename ExpressionType>
const TensorBinaryOp<hppl::binary::cmp_ge<T>,
const Derived,
const ExpressionType,
T>
operator>=(const ExpressionType& expr) const {
return binaryExpression(hppl::binary::cmp_ge<T>(), expr);
}
template<typename ExpressionType>
const TensorBinaryOp<
hppl::binary::cmp_gt<T>, const Derived, const ExpressionType, T>
template <typename ExpressionType>
const TensorBinaryOp<hppl::binary::cmp_gt<T>,
const Derived,
const ExpressionType,
T>
operator>(const ExpressionType& expr) const {
return binaryExpression(hppl::binary::cmp_gt<T>(), expr);
}
template<typename ExpressionType>
const TensorBinaryOp<
hppl::binary::and_op<T>, const Derived, const ExpressionType, T>
template <typename ExpressionType>
const TensorBinaryOp<hppl::binary::and_op<T>,
const Derived,
const ExpressionType,
T>
operator&&(const ExpressionType& expr) const {
return binaryExpression(hppl::binary::and_op<T>(), expr);
}
template<typename ExpressionType>
const TensorBinaryOp<
hppl::binary::or_op<T>, const Derived, const ExpressionType, T>
template <typename ExpressionType>
const TensorBinaryOp<hppl::binary::or_op<T>,
const Derived,
const ExpressionType,
T>
operator||(const ExpressionType& expr) const {
return binaryExpression(hppl::binary::or_op<T>(), expr);
}
template<typename ExpressionType>
const TensorBinaryOp<
hppl::binary::add<T>, const Derived, const ExpressionType, T>
template <typename ExpressionType>
const TensorBinaryOp<hppl::binary::add<T>,
const Derived,
const ExpressionType,
T>
operator+(const ExpressionType& expr) const {
return binaryExpression(hppl::binary::add<T>(), expr);
}
template<typename ExpressionType>
const TensorBinaryOp<
hppl::binary::sub<T>, const Derived, const ExpressionType, T>
template <typename ExpressionType>
const TensorBinaryOp<hppl::binary::sub<T>,
const Derived,
const ExpressionType,
T>
operator-(const ExpressionType& expr) const {
return binaryExpression(hppl::binary::sub<T>(), expr);
}
template<typename ExpressionType>
const TensorBinaryOp<
hppl::binary::mul<T>, const Derived, const ExpressionType, T>
template <typename ExpressionType>
const TensorBinaryOp<hppl::binary::mul<T>,
const Derived,
const ExpressionType,
T>
operator*(const ExpressionType& expr) const {
return binaryExpression(hppl::binary::mul<T>(), expr);
}
template<typename ExpressionType>
const TensorBinaryOp<
hppl::binary::div<T>, const Derived, const ExpressionType, T>
template <typename ExpressionType>
const TensorBinaryOp<hppl::binary::div<T>,
const Derived,
const ExpressionType,
T>
operator/(const ExpressionType& expr) const {
return binaryExpression(hppl::binary::div<T>(), expr);
}
template<typename ExpressionType>
const TensorBinaryOp<
hppl::binary::min<T>, const Derived, const ExpressionType, T>
template <typename ExpressionType>
const TensorBinaryOp<hppl::binary::min<T>,
const Derived,
const ExpressionType,
T>
min(const ExpressionType& expr) const {
return binaryExpression(hppl::binary::min<T>(), expr);
}
template<typename ExpressionType>
const TensorBinaryOp<
hppl::binary::max<T>, const Derived, const ExpressionType, T>
template <typename ExpressionType>
const TensorBinaryOp<hppl::binary::max<T>,
const Derived,
const ExpressionType,
T>
max(const ExpressionType& expr) const {
return binaryExpression(hppl::binary::max<T>(), expr);
}
@@ -282,38 +300,38 @@ public:
* If the derived expression evaluates to true, then expression1 is evaluated.
* If the derived expression evaluates to false, then expression2 is evaluated.
*/
template<typename ExprType1, typename ExprType2>
template <typename ExprType1, typename ExprType2>
const TensorTernaryOp<const Derived, const ExprType1, const ExprType2, T>
condition(const ExprType1& expr1, const ExprType2& expr2) const {
return TensorTernaryOp<const Derived, const ExprType1, const ExprType2, T>
(derived(), expr1, expr2);
return TensorTernaryOp<const Derived, const ExprType1, const ExprType2, T>(
derived(), expr1, expr2);
}
template<typename ExprType>
template <typename ExprType>
const TensorTernaryOp<
const Derived,
const TensorConstant<hppl::unary::constant<T>, const Derived, T>,
const ExprType,
T>
condition(T p, const ExprType& expr) const {
return condition(constant(p), expr);
}
template<typename ExprType>
template <typename ExprType>
const TensorTernaryOp<
const Derived,
const ExprType,
const TensorConstant<hppl::unary::constant<T>, const Derived, T>,
T>
condition(const ExprType& expr, T p) const {
return condition(expr, constant(p));
}
const TensorTernaryOp<
const Derived,
const TensorConstant<hppl::unary::constant<T>, const Derived, T>,
const TensorConstant<hppl::unary::constant<T>, const Derived, T>,
T>
condition(T p1, T p2) const {
return condition(constant(p1), constant(p2));
}
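// For example, (a > static_cast<T>(0)).condition(a, -a) evaluates elementwise
// to |a|, and (a > b).condition(T(1), T(0)) gives an elementwise indicator of
// a > b.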
@@ -321,20 +339,20 @@ public:
/**
* return a TensorConstant. A TensorConstant object holds a constant value.
*/
const TensorConstant<hppl::unary::constant<T>, const Derived, T>
constant(T p) const {
return TensorConstant<hppl::unary::constant<T>, const Derived, T>
(hppl::unary::constant<T>(p), derived());
const TensorConstant<hppl::unary::constant<T>, const Derived, T> constant(
T p) const {
return TensorConstant<hppl::unary::constant<T>, const Derived, T>(
hppl::unary::constant<T>(p), derived());
}
/**
* return a TensorAssignOp; use AssignEvaluate to evaluate one or more
* TensorAssignOp objects.
*/
template<typename ExpressionType>
TensorAssignOp<Derived, ExpressionType, T>
lazyAssign(const ExpressionType& expr) const {
return TensorAssignOp<Derived, ExpressionType, T> (derived(), expr);
template <typename ExpressionType>
TensorAssignOp<Derived, ExpressionType, T> lazyAssign(
const ExpressionType& expr) const {
return TensorAssignOp<Derived, ExpressionType, T>(derived(), expr);
}
protected:
@@ -344,12 +362,12 @@ protected:
/**
* \brief Unary Operator Expression
*/
template<class OP, typename ExprType, class T>
template <class OP, typename ExprType, class T>
class TensorUnaryOp
: public TensorExpression<TensorUnaryOp<OP, ExprType, T>, T> {
public:
explicit TensorUnaryOp(const OP op, const ExprType& expr)
: op_(op), expr_(expr) {}
const OP op_;
const ExprType expr_;
@@ -358,12 +376,12 @@ public:
/**
* \brief Binary Operator Expression
*/
template<class OP, typename LhsType, typename RhsType, class T>
template <class OP, typename LhsType, typename RhsType, class T>
class TensorBinaryOp
: public TensorExpression<TensorBinaryOp<OP, LhsType, RhsType, T>, T> {
public:
explicit TensorBinaryOp(const OP op, const LhsType& lhs, const RhsType& rhs)
: op_(op), lhs_(lhs), rhs_(rhs) {}
const OP op_;
const LhsType lhs_;
@@ -373,14 +391,15 @@ public:
/**
* \brief Ternary Operator Expression
*/
template<typename ExprType1, typename ExprType2, typename ExprType3, class T>
class TensorTernaryOp
: public TensorExpression<
TensorTernaryOp<ExprType1, ExprType2, ExprType3, T>, T> {
template <typename ExprType1, typename ExprType2, typename ExprType3, class T>
class TensorTernaryOp : public TensorExpression<
TensorTernaryOp<ExprType1, ExprType2, ExprType3, T>,
T> {
public:
explicit TensorTernaryOp(
const ExprType1& expr1, const ExprType2& expr2, const ExprType3& expr3)
: expr1_(expr1), expr2_(expr2), expr3_(expr3) {}
explicit TensorTernaryOp(const ExprType1& expr1,
const ExprType2& expr2,
const ExprType3& expr3)
: expr1_(expr1), expr2_(expr2), expr3_(expr3) {}
const ExprType1 expr1_;
const ExprType2 expr2_;
@@ -390,12 +409,12 @@ public:
/**
* \brief Constant Expression
*/
template<class OP, typename ExprType, class T>
template <class OP, typename ExprType, class T>
class TensorConstant
: public TensorExpression<TensorConstant<OP, ExprType, T>, T> {
public:
explicit TensorConstant(const OP op, const ExprType& expr)
: op_(op), expr_(expr) {}
const OP op_;
const ExprType expr_;
@@ -405,9 +424,9 @@ public:
* \brief operator+ overload
* \return a unary operator expression
*/
template<typename Derived, class T>
const TensorUnaryOp<hppl::unary::add_scale<T>, const Derived, T>
operator+(T p, const TensorExpression<Derived, T>& expr) {
template <typename Derived, class T>
const TensorUnaryOp<hppl::unary::add_scale<T>, const Derived, T> operator+(
T p, const TensorExpression<Derived, T>& expr) {
return expr + p;
}
@@ -415,9 +434,9 @@ operator+(T p, const TensorExpression<Derived, T>& expr) {
* \brief operator* overload
* \return a unary operator expression
*/
template<typename Derived, class T>
const TensorUnaryOp<hppl::unary::mul_scale<T>, const Derived, T>
operator*(T p, const TensorExpression<Derived, T>& expr) {
template <typename Derived, class T>
const TensorUnaryOp<hppl::unary::mul_scale<T>, const Derived, T> operator*(
T p, const TensorExpression<Derived, T>& expr) {
return expr * p;
}
@@ -425,4 +444,3 @@ operator*(T p, const TensorExpression<Derived, T>& expr) {
#include "TensorApply.h"
#include "TensorEvaluate.h"
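All of these operators only build expression nodes; no arithmetic happens until an evaluator walks the tree. A hedged sketch, assuming a, b, and out are same-shaped CPU tensor objects deriving from TensorExpression<..., float>:

// builds a TensorUnaryOp over a TensorBinaryOp; nothing is computed yet
auto expr = (a - b).square() * 0.5f;
// a single elementwise pass evaluates the whole tree
TensorCpuApply<float>(out, expr);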
@@ -355,4 +355,3 @@ void adamaxApply(BaseMatrix& value,
} // namespace paddle
#endif
@@ -119,5 +119,4 @@ extern void adamaxApply(BaseMatrix& value,
real beta2,
int64_t step,
real alpha);
} // namespace paddle
@@ -31,7 +31,8 @@ void SparseMomentumParameterOptimizer(const VectorPtr vecs[],
tau * alpha * gamma * learningRate);
vecs[PARAMETER_VALUE]->add(*vecs[PARAMETER_MOMENTUM_UT],
tau / beta + 1.0 / alpha,
*vecs[PARAMETER_MOMENTUM_VT], 1.0 / beta);
*vecs[PARAMETER_MOMENTUM_VT],
1.0 / beta);
}
void AdagradParameterOptimizer(const VectorPtr vecs[],
@@ -46,10 +47,12 @@ void AdagradParameterOptimizer(const VectorPtr vecs[],
vecs[PARAMETER_LEARNING_RATE]->add(epsilon);
vecs[PARAMETER_LEARNING_RATE]->invSqrt(*vecs[PARAMETER_LEARNING_RATE]);
vecs[PARAMETER_VALUE]->sgdUpdate(
*vecs[PARAMETER_GRADIENT], *vecs[PARAMETER_MOMENTUM],
*vecs[PARAMETER_LEARNING_RATE], learningRate,
momentum, decayRate);
vecs[PARAMETER_VALUE]->sgdUpdate(*vecs[PARAMETER_GRADIENT],
*vecs[PARAMETER_MOMENTUM],
*vecs[PARAMETER_LEARNING_RATE],
learningRate,
momentum,
decayRate);
}
void AdaDeltaParameterOptimizer(const VectorPtr vecs[],
@@ -59,24 +62,29 @@ void AdaDeltaParameterOptimizer(const VectorPtr vecs[],
real momentum,
real decayRate) {
// E(g_t^2) = \rho * E(g_{t-1}^2) + (1-\rho) * g^2
vecs[PARAMETER_GRADIENT_SQURESUM]->decayAddSquare(*vecs[PARAMETER_GRADIENT],
rou, 1.0f - rou);
vecs[PARAMETER_GRADIENT_SQURESUM]->decayAddSquare(
*vecs[PARAMETER_GRADIENT], rou, 1.0f - rou);
// learn_rate = sqrt( ( E(dx_{t-1}^2) + epsilon ) / ( E(g_t^2) + epsilon ) )
vecs[PARAMETER_LEARNING_RATE]->dotDiv(*vecs[PARAMETER_GRADIENT_SQURESUM1],
*vecs[PARAMETER_GRADIENT_SQURESUM],
epsilon, epsilon);
epsilon,
epsilon);
vecs[PARAMETER_LEARNING_RATE]->sqrt2();
// E(dx_t^2) = \rho * E(dx_{t-1}^2) + (1-\rho) * (-g*learn_rate)^2
vecs[PARAMETER_GRADIENT_SQURESUM1]->decayAddSquareMul(
*vecs[PARAMETER_GRADIENT], *vecs[PARAMETER_LEARNING_RATE], rou,
*vecs[PARAMETER_GRADIENT],
*vecs[PARAMETER_LEARNING_RATE],
rou,
1.0f - rou);
vecs[PARAMETER_VALUE]->sgdUpdate(
*vecs[PARAMETER_GRADIENT], *vecs[PARAMETER_MOMENTUM],
*vecs[PARAMETER_LEARNING_RATE], learningRate,
momentum, decayRate);
vecs[PARAMETER_VALUE]->sgdUpdate(*vecs[PARAMETER_GRADIENT],
*vecs[PARAMETER_MOMENTUM],
*vecs[PARAMETER_LEARNING_RATE],
learningRate,
momentum,
decayRate);
}
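Collecting the inline comments, the AdaDelta recurrences implemented above are

E[g^2]_t = \rho \, E[g^2]_{t-1} + (1-\rho) \, g_t^2
\eta_t = \sqrt{ (E[\Delta x^2]_{t-1} + \epsilon) / (E[g^2]_t + \epsilon) }
E[\Delta x^2]_t = \rho \, E[\Delta x^2]_{t-1} + (1-\rho) \, (g_t \eta_t)^2

with sgdUpdate then applying the step scaled by \eta_t together with the global learningRate, momentum, and decayRate.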
void RMSPropParameterOptimizer(const VectorPtr vecs[],
@@ -91,12 +99,11 @@ void RMSPropParameterOptimizer(const VectorPtr vecs[],
// For the first update, make the sum be the current square
// so that the initial estimation of E(g_t^2) will not be too small.
vecs[PARAMETER_GRADIENT_SQURESUM]->decayAddSquare(
*vecs[PARAMETER_GRADIENT], accumulatedRou,
firstTime ? 1.0f : 1.0f - rou);
*vecs[PARAMETER_GRADIENT], accumulatedRou, firstTime ? 1.0f : 1.0f - rou);
// E(g_t) = \rho * E(g_{t-1}) + (1-\rho) * g
vecs[PARAMETER_GRADIENT_SQURESUM1]->add(*vecs[PARAMETER_GRADIENT],
accumulatedRou, 1.0f - rou);
vecs[PARAMETER_GRADIENT_SQURESUM1]->add(
*vecs[PARAMETER_GRADIENT], accumulatedRou, 1.0f - rou);
// learn_rate = 1/sqrt( E(g_t^2) - (E(g_t))^2 + epsilon )
// Basically, if the sign of the gradient changes more often,
@@ -107,10 +114,12 @@ void RMSPropParameterOptimizer(const VectorPtr vecs[],
vecs[PARAMETER_LEARNING_RATE]->add(epsilon);
vecs[PARAMETER_LEARNING_RATE]->invSqrt(*vecs[PARAMETER_LEARNING_RATE]);
vecs[PARAMETER_VALUE]->sgdUpdate(
*vecs[PARAMETER_GRADIENT], *vecs[PARAMETER_MOMENTUM],
*vecs[PARAMETER_LEARNING_RATE], learningRate,
momentum, decayRate);
vecs[PARAMETER_VALUE]->sgdUpdate(*vecs[PARAMETER_GRADIENT],
*vecs[PARAMETER_MOMENTUM],
*vecs[PARAMETER_LEARNING_RATE],
learningRate,
momentum,
decayRate);
}
void DecayedAdagradParameterOptimizer(const VectorPtr vecs[],
@@ -125,8 +134,7 @@ void DecayedAdagradParameterOptimizer(const VectorPtr vecs[],
// For the first update, make the sum be the current square
// so that the initial estimation of E(g_t^2) will not be too small.
vecs[PARAMETER_GRADIENT_SQURESUM]->decayAddSquare(
*vecs[PARAMETER_GRADIENT], accumulatedRou,
firstTime ? 1.0f : 1.0f - rou);
*vecs[PARAMETER_GRADIENT], accumulatedRou, firstTime ? 1.0f : 1.0f - rou);
// learn_rate = 1/sqrt( E(g_t^2) + epsilon )
// Basically, the bigger the magnitude of the gradient is,
@@ -135,10 +143,12 @@ void DecayedAdagradParameterOptimizer(const VectorPtr vecs[],
vecs[PARAMETER_LEARNING_RATE]->add(*vecs[PARAMETER_GRADIENT_SQURESUM]);
vecs[PARAMETER_LEARNING_RATE]->invSqrt(*vecs[PARAMETER_LEARNING_RATE]);
vecs[PARAMETER_VALUE]->sgdUpdate(
*vecs[PARAMETER_GRADIENT], *vecs[PARAMETER_MOMENTUM],
*vecs[PARAMETER_LEARNING_RATE], learningRate,
momentum, decayRate);
vecs[PARAMETER_VALUE]->sgdUpdate(*vecs[PARAMETER_GRADIENT],
*vecs[PARAMETER_MOMENTUM],
*vecs[PARAMETER_LEARNING_RATE],
learningRate,
momentum,
decayRate);
}
void AdamParameterOptimizer(const VectorPtr vecs[],
@@ -164,16 +174,13 @@ void AdamParameterOptimizer(const VectorPtr vecs[],
// \theta_t = \theta_{t-1} - \alpha * \sqrt(1-\beta_2^t) / (1-\beta_1^t) * tmp
g->sqrt2(*v);
g->dotDiv(*m, *g, 0., epsilon);
real alpha = learningRate *
std::sqrt((real)1 - beta2_power) / ((real)1 - beta1_power);
real alpha =
learningRate * std::sqrt((real)1 - beta2_power) / ((real)1 - beta1_power);
theta->add(*theta, 1.0, *g, -alpha);
}
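With m_t and v_t the usual first- and second-moment moving averages (their updates sit in the elided hunk above), the visible tail implements the standard Adam step

\theta_t = \theta_{t-1} - \alpha \, \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \cdot \frac{m_t}{\sqrt{v_t} + \epsilon}

assuming dotDiv(*m, *g, 0., epsilon) computes (m + 0) / (\sqrt{v} + \epsilon) elementwise.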
void AdamaxParameterOptimizer(const VectorPtr vecs[],
real beta1,
real beta2,
int64_t step,
real alpha) {
void AdamaxParameterOptimizer(
const VectorPtr vecs[], real beta1, real beta2, int64_t step, real alpha) {
Vector* m = vecs[PARAMETER_MOMENTUM].get();
Vector* g = vecs[PARAMETER_GRADIENT].get();
Vector* u = vecs[PARAMETER_WEIGHTED_INFINITY_NORM].get();
@@ -192,4 +199,3 @@ void AdamaxParameterOptimizer(const VectorPtr vecs[],
real learningRate = alpha / (1 - std::pow(beta1, step));
theta->add(*theta, 1.0, *g, -learningRate);
}
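The tail corresponds to the Adamax update

\theta_t = \theta_{t-1} - \frac{\alpha}{1-\beta_1^t} \cdot \frac{m_t}{u_t}

where u_t is the exponentially weighted infinity norm held in PARAMETER_WEIGHTED_INFINITY_NORM; its standard update, u_t = max(\beta_2 u_{t-1}, |g_t|), presumably sits in the elided hunk above.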