Commit e63f1e69, authored by hedaoyuan

merge from cooder

Parent e83950b0
......@@ -19,7 +19,6 @@ limitations under the License. */
#include "hl_base.h"
#ifdef __CUDA_ARCH__
// typedef void* vecType;
#include <vector_types.h>
#ifndef PADDLE_TYPE_DOUBLE
typedef float4 vecType;
......@@ -37,4 +36,10 @@ typedef __m128d vecType;
#endif
#endif
#endif /* HL_MATRIX_TYPE_CUH_ */
#ifdef __CUDA_ARCH__
#define INLINE __device__ inline
#else
#define INLINE inline
#endif
#endif // HL_MATRIX_TYPE_CUH_
/**
* hl_tensor_ops.h
*
* Author: hedaoyuan (hedaoyuan@baidu.com)
* Created on: 2016-06-06
*
* Copyright (c) Baidu.com, Inc. All Rights Reserved
*
*/
#ifndef HL_TENSOR_OPS_H_
#define HL_TENSOR_OPS_H_
#include <cmath>
#include "hl_matrix_type.cuh"
namespace hppl {
namespace unary {
template<class T>
class add_scale {
private:
const T p;
public:
INLINE add_scale(const T s) : p(s) {}
INLINE T operator()(const T a) const { return a + p; }
};
template<class T>
class sub_scale {
private:
const T p;
public:
INLINE sub_scale(const T s) : p(s) {}
INLINE T operator()(const T a) const { return a - p; }
};
template<class T>
class mul_scale {
private:
const T p;
public:
INLINE mul_scale(const T s) : p(s) {}
INLINE T operator()(const T a) const { return a * p; }
};
template<class T>
class div_scale {
private:
const T p;
public:
INLINE div_scale(const T s) : p(s) {}
INLINE T operator()(const T a) const { return a / p; }
};
template<class T>
class neg {
public:
INLINE T operator()(const T a) const { return -a; }
};
template<class T>
class exp_op {
public:
INLINE T operator()(const T a) const { return std::exp(a); }
};
template<class T>
class log_op {
public:
INLINE T operator()(const T a) const { return std::log(a); }
};
template<class T>
class sqrt_op {
public:
INLINE T operator()(const T a) const { return std::sqrt(a); }
};
template<class T>
class square {
public:
INLINE T operator()(const T a) const { return a * a; }
};
template<class T>
class reciprocal {
public:
INLINE T operator()(const T a) const { return T(1) / a; }
};
template<class T>
class abs {
public:
INLINE T operator()(const T a) const { return a > 0 ? a : -a; }
};
template<class T>
class sign {
public:
INLINE T operator()(const T a) const { return (a > 0) - (a < 0); }
};
template<class T>
class min {
private:
const T p;
public:
INLINE min(const T s) : p(s) {}
INLINE T operator()(const T a) const { return a > p ? p : a; }
};
template<class T>
class max {
private:
const T p;
public:
INLINE max(const T s) : p(s) {}
INLINE T operator()(const T a) const { return a < p ? p : a; }
};
template<class T>
class pow_op {
private:
const T p;
public:
INLINE pow_op(const T s) : p(s) {}
INLINE T operator()(const T a) const { return std::pow(a, p); }
};
template<class T>
class constant {
private:
const T p;
public:
INLINE constant(const T s) : p(s) {}
INLINE T operator()(int i) const { return p; }
INLINE T operator()(int i, int j) const { return p; }
};
template<class T>
class cmp_eq {
private:
const T p;
public:
INLINE cmp_eq(const T s) : p(s) {}
INLINE bool operator()(const T a) const { return a == p; }
};
template<class T>
class cmp_ne {
private:
const T p;
public:
INLINE cmp_ne(const T s) : p(s) {}
INLINE bool operator()(const T a) const { return a != p; }
};
template<class T>
class cmp_le {
private:
const T p;
public:
INLINE cmp_le(const T s) : p(s) {}
INLINE bool operator()(const T a) const { return a <= p; }
};
template<class T>
class cmp_lt {
private:
const T p;
public:
INLINE cmp_lt(const T s) : p(s) {}
INLINE bool operator()(const T a) const { return a < p; }
};
template<class T>
class cmp_ge {
private:
const T p;
public:
INLINE cmp_ge(const T s) : p(s) {}
INLINE bool operator()(const T a) const { return a >= p; }
};
template<class T>
class cmp_gt {
private:
const T p;
public:
INLINE cmp_gt(const T s) : p(s) {}
INLINE bool operator()(const T a) const { return a > p; }
};
template<class T>
class and_op {
private:
const T p;
public:
INLINE and_op(const T s) : p(s) {}
INLINE bool operator()(const T a) const { return a && p; }
};
template<class T>
class or_op {
private:
const T p;
public:
INLINE or_op(const T s) : p(s) {}
INLINE bool operator()(const T a) const { return a || p; }
};
} // namespace unary
namespace binary {
template<class T>
class add {
public:
INLINE T operator()(const T a, const T b) const { return a + b; }
};
template<class T>
class add_scale {
private:
const T p1;
const T p2;
public:
INLINE add_scale(const T s1, const T s2) : p1(s1), p2(s2) {}
INLINE T operator()(const T a, const T b) const {
return p1 * a + p2 * b;
}
};
template<class T>
class sub {
public:
INLINE T operator()(const T a, const T b) const { return a - b; }
};
template<class T>
class mul {
public:
INLINE T operator()(const T a, const T b) const { return a * b; }
};
template<class T>
class div {
public:
INLINE T operator()(const T a, const T b) const { return a / b; }
};
template<class T>
class cmp_eq {
public:
INLINE bool operator()(const T a, const T b) const { return a == b; }
};
template<class T>
class cmp_ne {
public:
INLINE bool operator()(const T a, const T b) const { return a != b; }
};
template<class T>
class cmp_le {
public:
INLINE bool operator()(const T a, const T b) const { return a <= b; }
};
template<class T>
class cmp_lt {
public:
INLINE bool operator()(const T a, const T b) const { return a < b; }
};
template<class T>
class cmp_ge {
public:
INLINE bool operator()(const T a, const T b) const { return a >= b; }
};
template<class T>
class cmp_gt {
public:
INLINE bool operator()(const T a, const T b) const { return a > b; }
};
template<class T>
class and_op {
public:
INLINE bool operator()(const T a, const T b) const { return a && b; }
};
template<class T>
class or_op {
public:
INLINE bool operator()(const T a, const T b) const { return a || b; }
};
template<class T>
class min {
public:
INLINE T operator()(const T a, const T b) const { return a > b ? b : a; }
};
template<class T>
class max {
public:
INLINE T operator()(const T a, const T b) const { return a < b ? b : a; }
};
} // namespace binary
} // namespace hppl
#endif // HL_TENSOR_OPS_H_
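The functors above are plain callable objects, so they behave identically on host and device (INLINE expands to __device__ inline only under __CUDA_ARCH__). A minimal host-side sketch of using them directly, assuming only the Paddle include paths; the function name is illustrative and not part of this commit:

#include <cassert>
#include "hl_tensor_ops.h"

void hpplFunctorSketch() {
  hppl::unary::add_scale<float> addTwo(2.0f);        // a + 2
  hppl::unary::cmp_gt<float> greaterThanOne(1.0f);   // a > 1
  hppl::binary::max<float> pickMax;                  // max(a, b)
  assert(addTwo(3.0f) == 5.0f);
  assert(greaterThanOne(0.5f) == false);
  assert(pickMax(1.0f, 4.0f) == 4.0f);
}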
......@@ -271,7 +271,7 @@ void forward(Argument& act) {
/* trans */ false, useGpu(act.deviceId));
act.in->copyFrom(*act.value);
act.value->abs(*act.value);
act.value->abs2(*act.value);
}
void backward(Argument& act) { act.grad->absDerivative(*act.in); }
......@@ -290,7 +290,7 @@ void forward(Argument& act) {
/* trans */ false, useGpu(act.deviceId));
act.in->copyFrom(*act.value);
act.value->square(*act.value);
act.value->square2(*act.value);
}
void backward(Argument& act) { act.grad->squareDerivative(*act.in); }
......@@ -302,7 +302,7 @@ END_DEFINE_ACTIVATION(square)
* \f]
*/
BEGIN_DEFINE_ACTIVATION(exponential)
void forward(Argument& act) { act.value->exp(*act.value); }
void forward(Argument& act) { act.value->exp2(*act.value); }
void backward(Argument& act) { act.grad->expDerivative(*act.value); }
END_DEFINE_ACTIVATION(exponential)
......
......@@ -41,7 +41,7 @@ void BatchNormalizationLayer::calMeanAndStd(const MatrixPtr& mat) {
savedMean_->mulScalar(1.0 / numSamples); // E[x]
tmpMat_->assign(*mat);
tmpMat_->square();
tmpMat_->square2();
savedInvVar_->zeroMem();
savedInvVar_->accumulateColSum(*tmpMat_);
savedInvVar_->mulScalar(1.0 / numSamples); // E[x^2]
......@@ -55,7 +55,7 @@ void BatchNormalizationLayer::calMeanAndStd(const MatrixPtr& mat) {
calMovingMeanAndVar();
savedInvVar_->subScalar(-EPS);
savedInvVar_->sqrt(*savedInvVar_);
savedInvVar_->sqrt2(*savedInvVar_);
}
void BatchNormalizationLayer::calMovingMeanAndVar() {
......@@ -86,7 +86,7 @@ void BatchNormalizationLayer::setMeanAndStd() {
savedInvVar_->downClip(real(0.0));
savedInvVar_->subScalar(-EPS);
savedInvVar_->sqrt(*savedInvVar_);
savedInvVar_->sqrt2(*savedInvVar_);
}
void BatchNormalizationLayer::expandMat(const MatrixPtr& in, MatrixPtr& out) {
......
......@@ -114,12 +114,12 @@ void MultiClassCrossEntropyWithSelfNorm::forwardImp(Matrix& output,
Matrix& target) {
Matrix::resizeOrCreate(sftMaxSum_, output.getHeight(), 1, false, useGpu_);
output.rowSum(*sftMaxSum_);
sftMaxSum_->log();
sftMaxSum_->log2();
target.oneHotCrossEntropy(output, *label.ids);
target.add(*sftMaxSum_);
sftMaxSum_->square();
sftMaxSum_->square2();
target.add(*sftMaxSum_, config_.softmax_selfnorm_alpha());
}
......@@ -130,12 +130,12 @@ void MultiClassCrossEntropyWithSelfNorm::backwardImp(Matrix& output,
output.rowSum(*sftMaxSum_);
Matrix::resizeOrCreate(sumInv_, output.getHeight(), 1, false, useGpu_);
sftMaxSum_->reciprocal(*sumInv_);
sftMaxSum_->reciprocal2(*sumInv_);
outputG.oneHotCrossEntropyBp(output, *label.ids);
outputG.addColumnVector(*sumInv_);
sftMaxSum_->log();
sftMaxSum_->log2();
sumInv_->dotMul(*sumInv_, *sftMaxSum_);
sumInv_->mulScalar(2 * config_.softmax_selfnorm_alpha());
......
......@@ -310,12 +310,12 @@ void Layer::showOutputStats() {
auto tmpMat = dynamic_cast<CpuSparseMatrix*>(outSquare.get());
min = tmpMat->getMin();
max = tmpMat->getMax();
tmpMat->square();
tmpMat->square2();
LOG(INFO) << "show statistics of [none zero values] in sparse matrix";
} else {
min = outSquare->getMin();
max = outSquare->getMax();
outSquare->square();
outSquare->square2();
}
real std = (outSquare->getSum() / outSquare->getElementCnt()) - mean * mean;
std = std > 0 ? std : 0;
......
......@@ -61,7 +61,7 @@ real LinearChainCRF::forward(real* x, int* s, int length) {
expX_->assign(*matX);
// subtract max to avoid overflow or underflow
expX_->mul(maxX_, ones_, (real)-1, (real)1);
expX_->exp();
expX_->exp2();
real* a = a_->getData();
real* b = b_->getData();
......@@ -70,7 +70,7 @@ real LinearChainCRF::forward(real* x, int* s, int length) {
real* expX = expX_->getData();
real* maxX = maxX_->getData();
expW_->exp(*w_);
expW_->exp2(*w_);
real* expW = expW_->getData();
for (int i = 0; i < numClasses_; ++i) {
......
......@@ -100,7 +100,7 @@ void PowerLayer::backward(const UpdateCallback& callback) {
Matrix::resizeOrCreate(tmpMtx, batchSize, dataDim, false, useGpu_);
if (inG0) {
tmpMtx->log(*inV1);
tmpMtx->log2(*inV1);
tmpMtx->dotMul(*tmpMtx, *outV);
// inG0 += outG .* (log(inV1) * outV)
......
......@@ -355,11 +355,11 @@ void BaseMatrixT<T>::neg() { applyUnary(unary::Neg<T>()); }
DEFINE_MATRIX_UNARY_OP(Exp, a = exp(a));
template<>
void BaseMatrixT<real>::exp() { applyUnary(unary::Exp<real>()); }
void BaseMatrixT<real>::exp2() { applyUnary(unary::Exp<real>()); }
DEFINE_MATRIX_UNARY_OP(Log, a = log(a));
template<>
void BaseMatrixT<real>::log() {
void BaseMatrixT<real>::log2() {
if (useGpu_) {
applyUnary(unary::Log<real>());
} else {
......@@ -369,23 +369,23 @@ void BaseMatrixT<real>::log() {
DEFINE_MATRIX_UNARY_OP(Sqrt, a = sqrt(a));
template<>
void BaseMatrixT<real>::sqrt() { applyUnary(unary::Sqrt<real>()); }
void BaseMatrixT<real>::sqrt2() { applyUnary(unary::Sqrt<real>()); }
DEFINE_MATRIX_UNARY_OP(Square, a = a * a);
template<class T>
void BaseMatrixT<T>::square() { applyUnary(unary::Square<T>()); }
void BaseMatrixT<T>::square2() { applyUnary(unary::Square<T>()); }
DEFINE_MATRIX_UNARY_OP(Reciprocal, a = 1.0f / a);
template<class T>
void BaseMatrixT<T>::reciprocal() { applyUnary(unary::Reciprocal<T>()); }
void BaseMatrixT<T>::reciprocal2() { applyUnary(unary::Reciprocal<T>()); }
DEFINE_MATRIX_UNARY_OP(Abs, a = a > 0 ? a : -a);
template<class T>
void BaseMatrixT<T>::abs() { applyUnary(unary::Abs<T>()); }
void BaseMatrixT<T>::abs2() { applyUnary(unary::Abs<T>()); }
DEFINE_MATRIX_UNARY_OP(Sign, a = (a > 0) - (a < 0));
template<class T>
void BaseMatrixT<T>::sign() { applyUnary(unary::Sign<T>()); }
void BaseMatrixT<T>::sign2() { applyUnary(unary::Sign<T>()); }
DEFINE_MATRIX_UNARY_OP(Zero, a = 0);
template<class T>
......@@ -405,7 +405,7 @@ void BaseMatrixT<T>::one() { applyUnary(unary::One<T>()); }
DEFINE_MATRIX_UNARY_PARAMETER_OP(Pow, ONE_PARAMETER, a = pow(a, p));
template<>
void BaseMatrixT<real>::pow(real p) {
void BaseMatrixT<real>::pow2(real p) {
if (useGpu_) {
applyUnary(unary::Pow<real>(p));
} else {
......@@ -534,7 +534,7 @@ void BaseMatrixT<T>::add(BaseMatrixT& b, T p) {
DEFINE_MATRIX_BINARY_PARAMETER_OP(Pow, ONE_PARAMETER, a = pow(b, p));
template<>
void BaseMatrixT<real>::pow(BaseMatrixT& b, real p) {
void BaseMatrixT<real>::pow2(BaseMatrixT& b, real p) {
if (useGpu_) {
applyBinary(binary::Pow<real>(p), b);
} else {
......@@ -615,7 +615,7 @@ void BaseMatrixT<T>::breluDerivative(BaseMatrixT& b) {
DEFINE_MATRIX_BINARY_OP(Square, b = a * a);
template<class T>
void BaseMatrixT<T>::square(BaseMatrixT& b) {
void BaseMatrixT<T>::square2(BaseMatrixT& b) {
applyBinary(binary::Square<T>(), b);
}
......@@ -654,7 +654,7 @@ void BaseMatrixT<T>::scaledTanhDerivative(BaseMatrixT& b, T p1, T p2) {
DEFINE_MATRIX_BINARY_OP(Reciprocal, b = 1.0f / a);
template<class T>
void BaseMatrixT<T>::reciprocal(BaseMatrixT& b) {
void BaseMatrixT<T>::reciprocal2(BaseMatrixT& b) {
applyBinary(binary::Reciprocal<T>(), b);
}
......@@ -666,7 +666,7 @@ void BaseMatrixT<T>::reciprocalDerivative(BaseMatrixT& b) {
DEFINE_MATRIX_BINARY_OP(Abs, b = a > 0.0f ? a : -a);
template<class T>
void BaseMatrixT<T>::abs(BaseMatrixT& b) { applyBinary(binary::Abs<T>(), b); }
void BaseMatrixT<T>::abs2(BaseMatrixT& b) { applyBinary(binary::Abs<T>(), b); }
DEFINE_MATRIX_BINARY_OP(AbsDerivative, a = (b > 0) ? a : (b < 0) ? -a : 0);
template<class T>
......@@ -726,17 +726,19 @@ void BaseMatrixT<T>::expDerivative(BaseMatrixT& b) {
DEFINE_MATRIX_BINARY_OP(Sign, b = a > 0.0f ? 1.0f : -1.0f);
template<class T>
void BaseMatrixT<T>::sign(BaseMatrixT& b) { applyBinary(binary::Sign<T>(), b); }
void BaseMatrixT<T>::sign2(BaseMatrixT& b) {
applyBinary(binary::Sign<T>(), b);
}
DEFINE_MATRIX_BINARY_OP(Exp, a = exp(b));
template<>
void BaseMatrixT<real>::exp(BaseMatrixT& b) {
void BaseMatrixT<real>::exp2(BaseMatrixT& b) {
applyBinary(binary::Exp<real>(), b);
}
DEFINE_MATRIX_BINARY_OP(Log, a = log(b));
template<>
void BaseMatrixT<real>::log(BaseMatrixT& b) {
void BaseMatrixT<real>::log2(BaseMatrixT& b) {
if (useGpu_) {
applyBinary(binary::Log<real>(), b);
} else {
......@@ -746,7 +748,7 @@ void BaseMatrixT<real>::log(BaseMatrixT& b) {
DEFINE_MATRIX_BINARY_OP(Sqrt, a = sqrt(b));
template<>
void BaseMatrixT<real>::sqrt(BaseMatrixT& b) {
void BaseMatrixT<real>::sqrt2(BaseMatrixT& b) {
applyBinary(binary::Sqrt<real>(), b);
}
......@@ -1062,7 +1064,7 @@ void BaseMatrixT<T>::biggerThan(BaseMatrixT& b,
DEFINE_MATRIX_TERNARY_OP(Max, a = (b > c) ? b : c);
template<class T>
void BaseMatrixT<T>::max(BaseMatrixT& b, BaseMatrixT& c) { // NOLINT
void BaseMatrixT<T>::max2(BaseMatrixT& b, BaseMatrixT& c) {
applyTernary(ternary::Max<T>(), b, c);
}
......@@ -1165,7 +1167,7 @@ void BaseMatrixT<T>::reciprocalSum(BaseMatrixT& b, BaseMatrixT& c, T p1, T p2,
DEFINE_MATRIX_BINARY_PARAMETER_OP(Reciprocal2, TWO_PARAMETER,
a = 1 / (p1 * b + p2));
template<class T>
void BaseMatrixT<T>::reciprocal(BaseMatrixT& b, T p1, T p2) {
void BaseMatrixT<T>::reciprocal2(BaseMatrixT& b, T p1, T p2) {
applyBinary(binary::Reciprocal2<T>(p1, p2), b);
}
......
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <cstddef>
#include <stdint.h>
#include "paddle/utils/TypeDefs.h"
#include "TensorExpression.h"
namespace paddle {
......@@ -66,7 +67,7 @@ public:
};
template<class T>
class BaseMatrixT {
class BaseMatrixT : public TensorExpression<BaseMatrixT<T>, T> {
public:
size_t height_, width_;
size_t stride_;
......@@ -351,14 +352,14 @@ public:
*
*/
void neg();
void exp();
void pow(T p);
void log();
void sqrt();
void square();
void reciprocal();
void abs();
void sign();
void exp2();
void pow2(T p);
void log2();
void sqrt2();
void square2();
void reciprocal2();
void abs2();
void sign2();
void zero();
/**
......@@ -516,7 +517,7 @@ public:
* b = this * this
* @endcode
*/
void square(BaseMatrixT& b);
void square2(BaseMatrixT& b);
void squareDerivative(BaseMatrixT& b);
/**
......@@ -540,7 +541,7 @@ public:
* b = 1.0f / this
* @endcode
*/
void reciprocal(BaseMatrixT& b);
void reciprocal2(BaseMatrixT& b);
void reciprocalDerivative(BaseMatrixT& b);
/**
......@@ -548,7 +549,7 @@ public:
* b = this > 0.0f ? this : -this
* @endcode
*/
void abs(BaseMatrixT& b);
void abs2(BaseMatrixT& b);
void absDerivative(BaseMatrixT& b);
/**
......@@ -566,12 +567,12 @@ public:
*/
void expDerivative(BaseMatrixT& b);
void sign(BaseMatrixT& b);
void sign2(BaseMatrixT& b);
void exp(BaseMatrixT& b);
void pow(BaseMatrixT& b, T p);
void log(BaseMatrixT& b);
void sqrt(BaseMatrixT& b);
void exp2(BaseMatrixT& b);
void pow2(BaseMatrixT& b, T p);
void log2(BaseMatrixT& b);
void sqrt2(BaseMatrixT& b);
void addScalar(BaseMatrixT& b, T p);
void subScalar(BaseMatrixT& b, T p);
void mulScalar(BaseMatrixT& b, T p);
......@@ -742,7 +743,7 @@ public:
* this = b>c ? b : c
* @endcode
*/
void max(BaseMatrixT& b, BaseMatrixT& c); // NOLINT
void max2(BaseMatrixT& b, BaseMatrixT& c);
/**
* @code
......@@ -837,7 +838,7 @@ public:
* this = 1 / (p1 * b + p2)
* @endcode
*/
void reciprocal(BaseMatrixT& b, T p1, T p2);
void reciprocal2(BaseMatrixT& b, T p1, T p2);
/**
* @code
......@@ -953,6 +954,32 @@ public:
virtual bool isSparse() const {
return false;
}
template<typename ExpressionType>
void operator=(const ExpressionType& expr) {
if (useGpu_) {
TensorGpuApply<T>(*this, expr);
} else {
TensorCpuApply<T>(*this, expr);
}
}
template<typename ExpressionType>
void operator+=(const ExpressionType& expr) {
(*this) = (*this) + expr;
}
template<typename ExpressionType>
void operator-=(const ExpressionType& expr) {
(*this) = (*this) - expr;
}
template<typename ExpressionType>
void operator*=(const ExpressionType& expr) {
(*this) = (*this) * expr;
}
template<typename ExpressionType>
void operator/=(const ExpressionType& expr) {
(*this) = (*this) / expr;
}
};
typedef BaseMatrixT<real> BaseMatrix;
......
......@@ -16,10 +16,12 @@ file(GLOB MATH_HEADERS . *.h)
file(GLOB MATH_SOURCES . *.cpp)
set(MATH_SOURCES
"${PROJ_ROOT}/paddle/math/BaseMatrix.cu"
"${PROJ_ROOT}/paddle/math/TrainingAlgorithmOp.cu"
${MATH_SOURCES})
if(NOT WITH_GPU)
# then compile BaseMatrix.cu as c++ file
compile_cu_as_cpp("${PROJ_ROOT}/paddle/math/BaseMatrix.cu")
compile_cu_as_cpp("${PROJ_ROOT}/paddle/math/TrainingAlgorithmOp.cu")
add_library(paddle_math STATIC
${MATH_SOURCES})
else()
......
......@@ -125,7 +125,7 @@ public:
return sum;
}
virtual void square() {
virtual void square2() {
CHECK(isContiguous());
if (valueType_ == NO_VALUE) {
return;
......
......@@ -930,6 +930,15 @@ public:
virtual void paramReluBackwardDiff(Matrix& oGrad, Matrix& data, Matrix& W) {
LOG(FATAL) << "Not implemented";
}
template<typename ExpressionType>
void operator=(const ExpressionType& expr) {
if (useGpu_) {
TensorGpuApply<real>(*this, expr);
} else {
TensorCpuApply<real>(*this, expr);
}
}
};
inline std::ostream& operator<<(std::ostream& os, const Matrix& mat) {
......@@ -1191,6 +1200,11 @@ public:
int contextLength,
int contextStart, int totalPad,
size_t beginPad);
template<typename ExpressionType>
void operator=(const ExpressionType& expr) {
TensorGpuApply<real>(*this, expr);
}
};
class CpuMatrix : public Matrix {
......@@ -1469,6 +1483,11 @@ public:
void multiBinaryLabelCrossEntropy(Matrix& output, Matrix& label);
void multiBinaryLabelCrossEntropyBp(Matrix& output, Matrix& label);
void classificationErrorMulti(Matrix& output, Matrix& label, real threshold);
template<typename ExpressionType>
void operator=(const ExpressionType& expr) {
TensorCpuApply<real>(*this, expr);
}
};
class SharedCpuMatrix : public CpuMatrix {
......@@ -1504,6 +1523,7 @@ public:
void add(real p1, real p2);
private:
using Matrix::mul;
void initShared(int blockNum);
void initBlock(int blockNum);
......
/**
* TensorApply.h
*
* Author: hedaoyuan (hedaoyuan@baidu.com)
* Created on: 2016-06-06
*
* Copyright (c) Baidu.com, Inc. All Rights Reserved
*
*/
#pragma once
namespace paddle {
/**
* \brief The tensor evaluator classes.
*/
template<typename Derived, class T>
class TensorApply {
public:
explicit INLINE TensorApply(const Derived& p)
: data_(p.data_), stride_(p.stride_),
height_(p.height_), width_(p.width_), useGpu_(p.useGpu_) {}
INLINE T apply(int i, int j) const {
return data_[i * stride_ + j];
}
INLINE T apply(int index) const {
return data_[index];
}
INLINE T& applyRef(int i, int j) {
return data_[i * stride_ + j];
}
INLINE T& applyRef(int index) {
return data_[index];
}
INLINE size_t getWidth() const { return width_; }
INLINE size_t getHeight() const { return height_; }
INLINE bool isContiguous() const { return stride_ == width_ || height_ == 1; }
INLINE bool useGpu() const { return useGpu_; }
T* data_;
size_t stride_;
size_t height_;
size_t width_;
bool useGpu_;
};
/**
* \brief The tensor evaluator classes.
*
* evaluator for rvalues
*/
template<typename Derived, class T>
class TensorApply<const Derived, T> {
public:
explicit INLINE TensorApply(const Derived& p)
: data_(p.data_), stride_(p.stride_),
height_(p.height_), width_(p.width_), useGpu_(p.useGpu_) {}
INLINE T apply(int i, int j) const {
return data_[i * stride_ + j];
}
INLINE T apply(int index) const {
return data_[index];
}
INLINE size_t getWidth() const { return width_; }
INLINE size_t getHeight() const { return height_; }
INLINE bool isContiguous() const { return stride_ == width_ || height_ == 1; }
INLINE bool useGpu() const { return useGpu_; }
const T* data_;
size_t stride_;
size_t height_;
size_t width_;
bool useGpu_;
};
template<typename Derived, class T>
class TensorApply<const TensorExpression<Derived, T>, T> {
public:
explicit TensorApply(const TensorExpression<Derived, T>& expr)
: expr_(expr.derived()) {}
INLINE T apply(int i, int j) const {
return expr_.apply(i, j);
}
INLINE T apply(int index) const {
return expr_.apply(index);
}
INLINE size_t getWidth() const { return expr_.getWidth(); }
INLINE size_t getHeight() const { return expr_.getHeight(); }
INLINE bool isContiguous() const { return expr_.isContiguous(); }
INLINE bool useGpu() const { return expr_.useGpu(); }
TensorApply<const Derived, T> expr_;
};
/**
* \brief The unary expression evaluator classes.
*/
template<class OP, typename ArgType, class T>
class TensorApply<const TensorUnaryOp<OP, ArgType, T>, T> {
public:
explicit INLINE TensorApply(const TensorUnaryOp<OP, ArgType, T>& expr)
: op_(expr.op_), expr_(expr.expr_) {}
INLINE T apply(int i, int j) const {
return op_(expr_.apply(i, j));
}
INLINE T apply(int index) const {
return op_(expr_.apply(index));
}
INLINE size_t getWidth() const { return expr_.getWidth(); }
INLINE size_t getHeight() const { return expr_.getHeight(); }
INLINE bool isContiguous() const { return expr_.isContiguous(); }
INLINE bool useGpu() const { return expr_.useGpu(); }
const OP op_;
TensorApply<ArgType, T> expr_;
};
/**
* \brief The binary expression evaluator classes.
*/
template<class OP, typename LhsType, typename RhsType, class T>
class TensorApply<const TensorBinaryOp<OP, LhsType, RhsType, T>, T> {
public:
explicit INLINE TensorApply(
const TensorBinaryOp<OP, LhsType, RhsType, T>& expr)
: op_(expr.op_), lhs_(expr.lhs_), rhs_(expr.rhs_) {
#ifndef __CUDA_ARCH__
CHECK_EQ(lhs_.getWidth(), rhs_.getWidth());
CHECK_EQ(lhs_.getHeight(), rhs_.getHeight());
CHECK_EQ(lhs_.useGpu(), rhs_.useGpu());
#endif
}
INLINE T apply(int i, int j) const {
return op_(lhs_.apply(i, j), rhs_.apply(i, j));
}
INLINE T apply(int index) const {
return op_(lhs_.apply(index), rhs_.apply(index));
}
INLINE size_t getWidth() const { return lhs_.getWidth(); }
INLINE size_t getHeight() const { return rhs_.getHeight(); }
INLINE bool isContiguous() const {
return lhs_.isContiguous() && rhs_.isContiguous();
}
INLINE bool useGpu() const { return lhs_.useGpu(); }
const OP op_;
TensorApply<LhsType, T> lhs_;
TensorApply<RhsType, T> rhs_;
};
/**
* \brief The ternary expression evaluator classes.
*/
template<typename ArgType1, typename ArgType2, typename ArgType3, class T>
class TensorApply<const TensorTernaryOp<ArgType1, ArgType2, ArgType3, T>, T> {
public:
explicit INLINE TensorApply(
const TensorTernaryOp<ArgType1, ArgType2, ArgType3, T>& expr)
: expr1_(expr.expr1_), expr2_(expr.expr2_), expr3_(expr.expr3_) {
#ifndef __CUDA_ARCH__
CHECK_EQ(expr1_.getWidth(), expr2_.getWidth());
CHECK_EQ(expr1_.getWidth(), expr3_.getWidth());
CHECK_EQ(expr1_.getHeight(), expr2_.getHeight());
CHECK_EQ(expr1_.getHeight(), expr3_.getHeight());
CHECK_EQ(expr1_.useGpu(), expr2_.useGpu());
CHECK_EQ(expr1_.useGpu(), expr3_.useGpu());
#endif
}
INLINE T apply(int i, int j) const {
return expr1_.apply(i, j) ? expr2_.apply(i, j) : expr3_.apply(i, j);
}
INLINE T apply(int index) const {
return expr1_.apply(index) ? expr2_.apply(index) : expr3_.apply(index);
}
INLINE size_t getWidth() const { return expr1_.getWidth(); }
INLINE size_t getHeight() const { return expr1_.getHeight(); }
INLINE bool isContiguous() const {
return expr1_.isContiguous() &&
expr2_.isContiguous() && expr3_.isContiguous();
}
INLINE bool useGpu() const { return expr1_.useGpu(); }
TensorApply<ArgType1, T> expr1_;
TensorApply<ArgType2, T> expr2_;
TensorApply<ArgType3, T> expr3_;
};
/**
* \brief The const expression evaluator classes.
*/
template<class OP, typename ArgType, class T>
class TensorApply<const TensorConstant<OP, ArgType, T>, T> {
public:
explicit INLINE TensorApply(const TensorConstant<OP, ArgType, T>& expr)
: op_(expr.op_), expr_(expr.expr_) {}
INLINE T apply(int i, int j) const {
return op_(i, j);
}
INLINE T apply(int index) const {
return op_(index);
}
INLINE size_t getWidth() const { return expr_.getWidth(); }
INLINE size_t getHeight() const { return expr_.getHeight(); }
INLINE bool isContiguous() const { return true; }
INLINE bool useGpu() const { return expr_.useGpu(); }
const OP op_;
TensorApply<ArgType, T> expr_;
};
} // namespace paddle
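To make the evaluator contract concrete, here is a small sketch (illustrative, not part of this commit); it assumes CpuMatrix, whose BaseMatrixT base exposes the public data_/stride_/height_/width_/useGpu_ fields these evaluators read:

#include "paddle/math/Matrix.h"
using namespace paddle;  // NOLINT

void tensorApplySketch() {
  CpuMatrix m(4, 5);
  m.randomizeUniform();
  TensorApply<CpuMatrix, real> eval(m);
  real x = eval.apply(2, 3);             // reads m.data_[2 * stride_ + 3]
  eval.applyRef(2, 3) = x + (real)1.0;   // writable access, used for assignment targets
}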
/**
* TensorEvaluate.h
*
* Author: hedaoyuan (hedaoyuan@baidu.com)
* Created on: 2016-06-06
*
* Copyright (c) Baidu.com, Inc. All Rights Reserved
*
*/
#pragma once
#include <algorithm>
#include "paddle/utils/Logging.h"
#include "hl_base.h"
namespace paddle {
/**
 * \brief The tensor CPU evaluate API.
*/
template<class T, typename LeftType, typename RightType>
inline void TensorCpuApply(LeftType& lhs, const RightType& rhs) {
TensorApply<LeftType, T> lhs_(lhs);
TensorApply<const RightType, T> rhs_(rhs);
CHECK_EQ(lhs_.getWidth(), rhs_.getWidth());
CHECK_EQ(lhs_.getHeight(), rhs_.getHeight());
CHECK_EQ(lhs_.useGpu(), rhs_.useGpu());
if (lhs_.isContiguous() && rhs_.isContiguous()) {
int size = lhs_.getHeight() * lhs_.getWidth();
for (int index = 0; index < size; index++) {
lhs_.applyRef(index) = rhs_.apply(index);
}
} else {
for (size_t i = 0; i < lhs_.getHeight(); i++) {
for (size_t j = 0; j < lhs_.getWidth(); j++) {
lhs_.applyRef(i, j) = rhs_.apply(i, j);
}
}
}
}
#ifdef __NVCC__
template<typename LeftType, typename RightType>
__global__
void TensorElementWiseOp(LeftType lhs, RightType rhs, const int border) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < border) {
lhs.applyRef(idx) = rhs.apply(idx);
}
}
template<typename LeftType, typename RightType>
__global__ void TensorElementWiseOp(LeftType lhs, RightType rhs) {
const int colIdx = blockIdx.x * blockDim.x + threadIdx.x;
const int rowIdx = blockIdx.y * blockDim.y + threadIdx.y;
for (int i = rowIdx; i < lhs.getHeight(); i += gridDim.y * blockDim.y) {
for (int j = colIdx; j < lhs.getWidth(); j += gridDim.x * blockDim.x) {
lhs.applyRef(i, j) = rhs.apply(i, j);
}
}
}
/**
 * \brief The tensor GPU evaluate API.
*/
template<class T, typename LeftType, typename RightType>
inline void TensorGpuApply(LeftType& lhs, const RightType& rhs) {
TensorApply<LeftType, T> lhs_(lhs);
TensorApply<const RightType, T> rhs_(rhs);
CHECK_EQ(lhs_.getWidth(), rhs_.getWidth());
CHECK_EQ(lhs_.getHeight(), rhs_.getHeight());
CHECK_EQ(lhs_.useGpu(), rhs_.useGpu());
int dimM = lhs_.getHeight();
int dimN = lhs_.getWidth();
if (lhs_.isContiguous() && rhs_.isContiguous()) {
int size = dimM * dimN;
int blockSize = size <= 1024 ? size : 1024;
int gridSize = (size + 1024 - 1) / 1024;
TensorElementWiseOp
<<<gridSize, blockSize, 0, STREAM_DEFAULT>>>(lhs_, rhs_, size);
} else {
int blockSizeY = std::min(32, dimM);
int blockSizeX = (32 / blockSizeY) * 32;
int gridSizeX = std::min(32, (dimN + blockSizeX - 1) / blockSizeX);
int gridSizeY = std::min(32, (dimM + blockSizeY - 1) / blockSizeY);
dim3 threads(blockSizeX, blockSizeY);
dim3 grid(gridSizeX, gridSizeY);
TensorElementWiseOp
<<<grid, threads, 0, STREAM_DEFAULT>>>(lhs_, rhs_);
}
CHECK_SYNC("TensorGpuApply failed");
}
#else
template<class T, typename LeftType, typename RightType>
inline void TensorGpuApply(LeftType& lhs, RightType& rhs) {
}
#endif
} // namespace paddle
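For the GPU side, a rough usage sketch (illustrative, not part of this commit); it requires an NVCC build and relies on the GpuMatrix::operator= added in Matrix.h above, which forwards to TensorGpuApply:

#include "paddle/math/Matrix.h"
using namespace paddle;  // NOLINT

void gpuExpressionSketch() {
  GpuMatrix a(128, 512), b(128, 512), c(128, 512);
  a.randomizeUniform();
  b.randomizeUniform();
  // The whole right-hand side is evaluated by a single TensorElementWiseOp kernel
  // launch; no temporary matrices are allocated for the intermediate results.
  c = a * b + (real)1.0;
}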
/**
* TensorExpression.h
*
* Author: hedaoyuan (hedaoyuan@baidu.com)
* Created on: 2016-06-06
*
* Copyright (c) Baidu.com, Inc. All Rights Reserved
*
*/
#pragma once
#include <cstddef>
#include <stdint.h>
#include "paddle/utils/TypeDefs.h"
#include "paddle/utils/Logging.h"
#include "hl_tensor_ops.h"
namespace paddle {
template<class OP, typename ExprType, class T> class TensorConstant;
template<class OP, typename ExprType, class T> class TensorUnaryOp;
template<
class OP, typename LhsType, typename RhsType, class T> class TensorBinaryOp;
template<
typename ExprType1,
typename ExprType2,
typename ExprType3,
class T> class TensorTernaryOp;
/**
* \brief Tensor base class.
*
 * This is the base class of all Tensor and Expression classes.
*/
template<typename Derived, class T>
class TensorExpression {
public:
/**
* Element wise unary expression.
*/
template<typename UnaryOp>
const TensorUnaryOp<UnaryOp, const Derived, T>
unaryExpression(const UnaryOp& op) const {
return TensorUnaryOp<UnaryOp, const Derived, T>(op, derived());
}
const TensorUnaryOp<hppl::unary::add_scale<T>, const Derived, T>
operator+(T p) const {
return unaryExpression(hppl::unary::add_scale<T>(p));
}
const TensorUnaryOp<hppl::unary::sub_scale<T>, const Derived, T>
operator-(T p) const {
return unaryExpression(hppl::unary::sub_scale<T>(p));
}
const TensorUnaryOp<hppl::unary::mul_scale<T>, const Derived, T>
operator*(T p) const {
return unaryExpression(hppl::unary::mul_scale<T>(p));
}
const TensorUnaryOp<hppl::unary::div_scale<T>, const Derived, T>
operator/(T p) const {
return unaryExpression(hppl::unary::div_scale<T>(p));
}
const TensorUnaryOp<hppl::unary::neg<T>, const Derived, T>
operator-() const {
return unaryExpression(hppl::unary::neg<T>());
}
const TensorUnaryOp<hppl::unary::exp_op<T>, const Derived, T>
exp() const {
return unaryExpression(hppl::unary::exp_op<T>());
}
const TensorUnaryOp<hppl::unary::log_op<T>, const Derived, T>
log() const {
return unaryExpression(hppl::unary::log_op<T>());
}
const TensorUnaryOp<hppl::unary::sqrt_op<T>, const Derived, T>
sqrt() const {
return unaryExpression(hppl::unary::sqrt_op<T>());
}
const TensorUnaryOp<hppl::unary::square<T>, const Derived, T>
square() const {
return unaryExpression(hppl::unary::square<T>());
}
const TensorUnaryOp<hppl::unary::reciprocal<T>, const Derived, T>
reciprocal() const {
return unaryExpression(hppl::unary::reciprocal<T>());
}
const TensorUnaryOp<hppl::unary::abs<T>, const Derived, T>
abs() const {
return unaryExpression(hppl::unary::abs<T>());
}
const TensorUnaryOp<hppl::unary::sign<T>, const Derived, T>
sign() const {
return unaryExpression(hppl::unary::sign<T>());
}
const TensorUnaryOp<hppl::unary::pow_op<T>, const Derived, T>
pow(T p) const {
return unaryExpression(hppl::unary::pow_op<T>(p));
}
const TensorUnaryOp<hppl::unary::min<T>, const Derived, T>
min(T p) const {
return unaryExpression(hppl::unary::min<T>(p));
}
const TensorUnaryOp<hppl::unary::max<T>, const Derived, T>
max(T p) const {
return unaryExpression(hppl::unary::max<T>(p));
}
const TensorUnaryOp<hppl::unary::cmp_eq<T>, const Derived, T>
operator==(T p) const {
return unaryExpression(hppl::unary::cmp_eq<T>(p));
}
const TensorUnaryOp<hppl::unary::cmp_ne<T>, const Derived, T>
operator!=(T p) const {
return unaryExpression(hppl::unary::cmp_ne<T>(p));
}
const TensorUnaryOp<hppl::unary::cmp_le<T>, const Derived, T>
operator<=(T p) const {
return unaryExpression(hppl::unary::cmp_le<T>(p));
}
const TensorUnaryOp<hppl::unary::cmp_lt<T>, const Derived, T>
operator<(T p) const {
return unaryExpression(hppl::unary::cmp_lt<T>(p));
}
const TensorUnaryOp<hppl::unary::cmp_ge<T>, const Derived, T>
operator>=(T p) const {
return unaryExpression(hppl::unary::cmp_ge<T>(p));
}
const TensorUnaryOp<hppl::unary::cmp_gt<T>, const Derived, T>
operator>(T p) const {
return unaryExpression(hppl::unary::cmp_gt<T>(p));
}
const TensorUnaryOp<hppl::unary::and_op<T>, const Derived, T>
operator&&(T p) const {
return unaryExpression(hppl::unary::and_op<T>(p));
}
const TensorUnaryOp<hppl::unary::or_op<T>, const Derived, T>
operator||(T p) const {
return unaryExpression(hppl::unary::or_op<T>(p));
}
/**
* Element wise binary expression.
*/
template<typename BinaryOp, typename ExpressionType>
const TensorBinaryOp<BinaryOp, const Derived, const ExpressionType, T>
binaryExpression(const BinaryOp& op, const ExpressionType& expr) const {
return TensorBinaryOp<BinaryOp, const Derived, const ExpressionType, T>(
op, derived(), expr);
}
template<typename ExpressionType>
const TensorBinaryOp<
hppl::binary::cmp_eq<T>, const Derived, const ExpressionType, T>
operator==(const ExpressionType& expr) const {
return binaryExpression(hppl::binary::cmp_eq<T>(), expr);
}
template<typename ExpressionType>
const TensorBinaryOp<
hppl::binary::cmp_ne<T>, const Derived, const ExpressionType, T>
operator!=(const ExpressionType& expr) const {
return binaryExpression(hppl::binary::cmp_ne<T>(), expr);
}
template<typename ExpressionType>
const TensorBinaryOp<
hppl::binary::cmp_le<T>, const Derived, const ExpressionType, T>
operator<=(const ExpressionType& expr) const {
return binaryExpression(hppl::binary::cmp_le<T>(), expr);
}
template<typename ExpressionType>
const TensorBinaryOp<
hppl::binary::cmp_lt<T>, const Derived, const ExpressionType, T>
operator<(const ExpressionType& expr) const {
return binaryExpression(hppl::binary::cmp_lt<T>(), expr);
}
template<typename ExpressionType>
const TensorBinaryOp<
hppl::binary::cmp_ge<T>, const Derived, const ExpressionType, T>
operator>=(const ExpressionType& expr) const {
return binaryExpression(hppl::binary::cmp_ge<T>(), expr);
}
template<typename ExpressionType>
const TensorBinaryOp<
hppl::binary::cmp_gt<T>, const Derived, const ExpressionType, T>
operator>(const ExpressionType& expr) const {
return binaryExpression(hppl::binary::cmp_gt<T>(), expr);
}
template<typename ExpressionType>
const TensorBinaryOp<
hppl::binary::and_op<T>, const Derived, const ExpressionType, T>
operator&&(const ExpressionType& expr) const {
return binaryExpression(hppl::binary::and_op<T>(), expr);
}
template<typename ExpressionType>
const TensorBinaryOp<
hppl::binary::or_op<T>, const Derived, const ExpressionType, T>
operator||(const ExpressionType& expr) const {
return binaryExpression(hppl::binary::or_op<T>(), expr);
}
template<typename ExpressionType>
const TensorBinaryOp<
hppl::binary::add<T>, const Derived, const ExpressionType, T>
operator+(const ExpressionType& expr) const {
return binaryExpression(hppl::binary::add<T>(), expr);
}
template<typename ExpressionType>
const TensorBinaryOp<
hppl::binary::sub<T>, const Derived, const ExpressionType, T>
operator-(const ExpressionType& expr) const {
return binaryExpression(hppl::binary::sub<T>(), expr);
}
template<typename ExpressionType>
const TensorBinaryOp<
hppl::binary::mul<T>, const Derived, const ExpressionType, T>
operator*(const ExpressionType& expr) const {
return binaryExpression(hppl::binary::mul<T>(), expr);
}
template<typename ExpressionType>
const TensorBinaryOp<
hppl::binary::div<T>, const Derived, const ExpressionType, T>
operator/(const ExpressionType& expr) const {
return binaryExpression(hppl::binary::div<T>(), expr);
}
template<typename ExpressionType>
const TensorBinaryOp<
hppl::binary::min<T>, const Derived, const ExpressionType, T>
min(const ExpressionType& expr) const {
return binaryExpression(hppl::binary::min<T>(), expr);
}
template<typename ExpressionType>
const TensorBinaryOp<
hppl::binary::max<T>, const Derived, const ExpressionType, T>
max(const ExpressionType& expr) const {
return binaryExpression(hppl::binary::max<T>(), expr);
}
/**
* Element wise ternary expression.
*
 * the ternary conditional operator (?: operator).
 * The conditional expression returns one of two values depending on
 * the result of the derived expression.
 * If the derived expression evaluates to true, then expression1 is evaluated.
 * If the derived expression evaluates to false, then expression2 is evaluated.
*/
template<typename ExprType1, typename ExprType2>
const TensorTernaryOp<const Derived, const ExprType1, const ExprType2, T>
condition(const ExprType1& expr1, const ExprType2& expr2) const {
return TensorTernaryOp<const Derived, const ExprType1, const ExprType2, T>
(derived(), expr1, expr2);
}
template<typename ExprType>
const TensorTernaryOp<
const Derived,
const TensorConstant<hppl::unary::constant<T>, const Derived, T>,
const ExprType,
T>
condition(T p, const ExprType& expr) const {
return condition(constant(p), expr);
}
template<typename ExprType>
const TensorTernaryOp<
const Derived,
const ExprType,
const TensorConstant<hppl::unary::constant<T>, const Derived, T>,
T>
condition(const ExprType& expr, T p) const {
return condition(expr, constant(p));
}
const TensorTernaryOp<
const Derived,
const TensorConstant<hppl::unary::constant<T>, const Derived, T>,
const TensorConstant<hppl::unary::constant<T>, const Derived, T>,
T>
condition(T p1, T p2) const {
return condition(constant(p1), constant(p2));
}
const TensorConstant<hppl::unary::constant<T>, const Derived, T>
constant(T p) const {
return TensorConstant<hppl::unary::constant<T>, const Derived, T>
(hppl::unary::constant<T>(p), derived());
}
protected:
const Derived& derived() const { return *static_cast<const Derived*>(this); }
};
/**
* \brief Unary Operator Expression
*/
template<class OP, typename ExprType, class T>
class TensorUnaryOp
: public TensorExpression<TensorUnaryOp<OP, ExprType, T>, T> {
public:
explicit TensorUnaryOp(const OP op, const ExprType& expr)
: op_(op), expr_(expr) {}
const OP op_;
const ExprType expr_;
};
/**
* \brief Binary Operator Expression
*/
template<class OP, typename LhsType, typename RhsType, class T>
class TensorBinaryOp
: public TensorExpression<TensorBinaryOp<OP, LhsType, RhsType, T>, T> {
public:
explicit TensorBinaryOp(const OP op, const LhsType& lhs, const RhsType& rhs)
: op_(op), lhs_(lhs), rhs_(rhs) {}
const OP op_;
const LhsType lhs_;
const RhsType rhs_;
};
/**
* \brief Ternary Operator Expression
*/
template<typename ExprType1, typename ExprType2, typename ExprType3, class T>
class TensorTernaryOp
: public TensorExpression<
TensorTernaryOp<ExprType1, ExprType2, ExprType3, T>, T> {
public:
explicit TensorTernaryOp(
const ExprType1& expr1, const ExprType2& expr2, const ExprType3& expr3)
: expr1_(expr1), expr2_(expr2), expr3_(expr3) {}
const ExprType1 expr1_;
const ExprType2 expr2_;
const ExprType3 expr3_;
};
/**
* \brief Constant Expression
*/
template<class OP, typename ExprType, class T>
class TensorConstant
: public TensorExpression<TensorConstant<OP, ExprType, T>, T> {
public:
explicit TensorConstant(const OP op, const ExprType& expr)
: op_(op), expr_(expr) {}
const OP op_;
const ExprType expr_;
};
/**
* \brief operator+ overload
* \return a unary operator expression
*/
template<typename Derived, class T>
const TensorUnaryOp<hppl::unary::add_scale<T>, const Derived, T>
operator+(T p, const TensorExpression<Derived, T>& expr) {
return expr + p;
}
/**
* \brief operator* overload
* \return a unary operator expression
*/
template<typename Derived, class T>
const TensorUnaryOp<hppl::unary::mul_scale<T>, const Derived, T>
operator*(T p, const TensorExpression<Derived, T>& expr) {
return expr * p;
}
} // namespace paddle
#include "TensorApply.h"
#include "TensorEvaluate.h"
/**
* TrainingAlgorithmOp.cu
*
* Author: hedaoyuan (hedaoyuan@baidu.com)
* Created on: 2016-06-29
*
* Copyright (c) Baidu.com, Inc. All Rights Reserved
*
*/
#include "paddle/utils/Logging.h"
#include "BaseMatrix.h"
#include "TrainingAlgorithmOp.h"
namespace paddle {
void sparseMomentumApply(BaseMatrix& value,
BaseMatrix& grad,
BaseMatrix& momU,
BaseMatrix& momV,
real alpha,
real beta,
real gamma,
real tau,
real learningRate) {
/**
* \alpha_t = \alpha_{t-1} / k
* \beta_t = \beta_{t-1} / (1 + \lambda\gamma_t)
* u_t = u_{t-1} - \alpha_t \gamma_t g_t
* v_t = v_{t-1} + \tau_{t-1} \alpha_t \gamma_t g_t
* \tau_t = \tau_{t-1} + \beta_t / \alpha_t
*/
momU -= (alpha * gamma * learningRate) * grad;
momV += (tau * alpha * gamma * learningRate) * grad;
value = (tau / beta + (real)1 / alpha) * momU + ((real)1 / beta) * momV;
}
void adadeltaApply(BaseMatrix& value,
BaseMatrix& grad,
BaseMatrix& mom,
BaseMatrix& accum,
BaseMatrix& accum_update,
BaseMatrix& lr,
real rou,
real epsilon,
real learningRate,
real momentum,
real decayRate) {
// E(g_t^2) = \rou * E(g_{t-1}^2) + (1-\rou) * g^2
accum = rou * accum + ((real)1 - rou) * grad.square();
// learn_rate: sqrt(( E(dx_{t-1}^2) + epsilon ) / ( E(g_t^2) + epsilon ))
lr = ((accum_update + epsilon) / (accum + epsilon)).sqrt();
// E(dx_t^2) = \rou * E(dx_{t-1}^2) + (1-\rou) * (-g*learn_rate)^2
accum_update = rou * accum_update + ((real)1 - rou) * (grad * lr).square();
mom = mom * momentum - learningRate * lr * (grad + value * decayRate);
value += mom;
}
void adagradApply(BaseMatrix& value,
BaseMatrix& grad,
BaseMatrix& mom,
BaseMatrix& accum_buffer,
BaseMatrix& accum,
BaseMatrix& lr,
real epsilon,
real learningRate,
real momentum,
real decayRate) {
accum += grad.square();
lr = (accum_buffer + accum + epsilon).sqrt().reciprocal();
mom = mom * momentum - learningRate * lr * (grad + value * decayRate);
value += mom;
}
void rmspropApply(BaseMatrix& value,
BaseMatrix& grad,
BaseMatrix& mom,
BaseMatrix& g,
BaseMatrix& f,
BaseMatrix& lr,
real accumulatedRou,
real rou,
real epsilon,
real learningRate,
real momentum,
real decayRate,
bool firstTime) {
// E(g_t^2) = \rou * E(g_{t-1}^2) + (1-\rou) * g^2
// For the first time update, make the sum be the current square
// so that the initial estimation of E(g_t^2) will not be too small.
if (firstTime) {
g = accumulatedRou * g + grad.square();
} else {
g = accumulatedRou * g + ((real)1 - rou) * grad.square();
}
// E(f_t) = \rou * E(f_{t-1}) + (1-\rou) * g
f = accumulatedRou * f + ((real)1 - rou) * grad;
  // learn_rate = 1 / sqrt( E(g_t^2) - (E(f_t))^2 + epsilon )
  // Basically, if the sign of the gradient changes more often,
  // the learning rate will be decreased.
lr = (g - f.square() + epsilon).sqrt().reciprocal();
mom = mom * momentum - learningRate * lr * (grad + value * decayRate);
value += mom;
}
void decayedAdagradApply(BaseMatrix& value,
BaseMatrix& grad,
BaseMatrix& mom,
BaseMatrix& accum,
BaseMatrix& lr,
real accumulatedRou,
real rou,
real epsilon,
real learningRate,
real momentum,
real decayRate,
bool firstTime) {
// E(g_t^2) = \rou * E(g_{t-1}^2) + (1-\rou) * g^2
// For the first time update, make the sum be the current square
// so that the initial estimation of E(g_t^2) will not be too small.
if (firstTime) {
accum = accumulatedRou * accum + grad.square();
} else {
accum = accumulatedRou * accum + ((real)1 - rou) * grad.square();
}
  // learn_rate = 1 / sqrt( E(g_t^2) + epsilon )
  // Basically, the bigger the magnitude of the gradient is,
  // the smaller the learning rate will be.
lr = (accum + epsilon).sqrt().reciprocal();
mom = mom * momentum - learningRate * lr * (grad + value * decayRate);
value += mom;
}
void adamApply(BaseMatrix& value,
BaseMatrix& grad,
BaseMatrix& mom, // first moment
BaseMatrix& v, // second moment
real beta1,
real beta2,
real beta1_power,
real beta2_power,
real epsilon,
real learningRate) {
real alpha = learningRate *
std::sqrt((real)1 - beta2_power) / ((real)1 - beta1_power);
// m_t = \beta_1 * m_{t-1} + (1-\beta_1)* g_t;
mom = beta1 * mom + ((real)1 - beta1) * grad;
// v_t = \beta_2 * v_{t-1} + (1-\beta_2)* g_{t-1}^2
v = beta2 * v + ((real)1 - beta2) * grad.square();
value -= (mom * alpha) / (v.sqrt() + epsilon);
}
void adamaxApply(BaseMatrix& value,
BaseMatrix& grad,
BaseMatrix& mom, // first moment
BaseMatrix& u, // weighted infinity norm
real beta1,
real beta2,
int64_t step,
real alpha) {
// m_t = \beta_1 * m_{t-1} + (1-\beta_1)* g_t;
mom = beta1 * mom + ((real)1 - beta1) * grad;
// u_t = max(\beta_2*u_{t-1}, abs(g_t))
u = (beta2 * u > grad.abs()).condition(beta2 * u, grad.abs());
// \theta_t = \theta_{t-1} - (\alpha/(1-\beta_1^t))*m_t/u_t
value -= (alpha / ((real)1 - (real)std::pow(beta1, step))) * (mom / u);
}
} // namespace paddle
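As a usage sketch (illustrative, not part of this commit), one Adam step written against the new expression-based API; the power arguments are \beta_1^t and \beta_2^t for the current step t (here t = 1), and the sizes are arbitrary:

#include "paddle/math/Matrix.h"
#include "paddle/math/TrainingAlgorithmOp.h"
using namespace paddle;  // NOLINT

void adamStepSketch() {
  CpuMatrix value(1, 100), grad(1, 100), mom(1, 100), v(1, 100);
  value.randomizeUniform();
  grad.randomizeUniform();
  mom.zeroMem();
  v.zeroMem();
  real beta1 = 0.9, beta2 = 0.999, epsilon = 1e-8, learningRate = 1e-3;
  adamApply(value, grad, mom, v, beta1, beta2,
            beta1 /* beta1_power */, beta2 /* beta2_power */,
            epsilon, learningRate);
}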
/**
* TrainingAlgorithmOp.h
*
* Author: hedaoyuan (hedaoyuan@baidu.com)
* Created on: 2016-06-29
*
* Copyright (c) Baidu.com, Inc. All Rights Reserved
*
*/
#pragma once
#include "paddle/utils/Logging.h"
#include "BaseMatrix.h"
namespace paddle {
/**
* \brief Sparse Momentum optimizer.
*/
extern void sparseMomentumApply(BaseMatrix& value,
BaseMatrix& grad,
BaseMatrix& momU,
BaseMatrix& momV,
real alpha,
real beta,
real gamma,
real tau,
real learningRate);
/**
* \brief AdaDelta optimizer.
*/
extern void adadeltaApply(BaseMatrix& value,
BaseMatrix& grad,
BaseMatrix& sum,
BaseMatrix& sum1,
BaseMatrix& mom,
BaseMatrix& lr,
real rou,
real epsilon,
real learningRate,
real momentum,
real decayRate);
/**
* \brief AdaGrad optimizer.
*/
extern void adagradApply(BaseMatrix& value,
BaseMatrix& grad,
BaseMatrix& sum,
BaseMatrix& sum1,
BaseMatrix& mom,
BaseMatrix& lr,
real epsilon,
real learningRate,
real momentum,
real decayRate);
/**
* \brief RMSProp optimizer.
*/
extern void rmspropApply(BaseMatrix& value,
BaseMatrix& grad,
BaseMatrix& g,
BaseMatrix& f,
BaseMatrix& mom,
BaseMatrix& lr,
real accumulatedRou,
real rou,
real epsilon,
real learningRate,
real momentum,
real decayRate,
bool firstTime);
/**
* \brief Decayed AdaGrad optimizer.
*/
extern void decayedAdagradApply(BaseMatrix& value,
BaseMatrix& grad,
BaseMatrix& mom,
BaseMatrix& accum,
BaseMatrix& lr,
real accumulatedRou,
real rou,
real epsilon,
real learningRate,
real momentum,
real decayRate,
bool firstTime);
/**
* \brief Adam optimizer.
*/
extern void adamApply(BaseMatrix& value,
BaseMatrix& grad,
BaseMatrix& mom,
BaseMatrix& v,
real beta1,
real beta2,
real beta1_power,
real beta2_power,
real epsilon,
real learningRate);
/**
* \brief AdaMax optimizer.
*/
extern void adamaxApply(BaseMatrix& value,
BaseMatrix& grad,
BaseMatrix& mom, // first moment
BaseMatrix& u, // weighted infinity norm
real beta1,
real beta2,
int64_t step,
real alpha);
} // namespace paddle
......@@ -258,6 +258,15 @@ public:
/// print the "idx" element of the Vector
virtual void printOneElement(std::ostream& os, size_t idx) const = 0;
template<typename ExpressionType>
void operator=(const ExpressionType& expr) {
if (BaseVector<T>::useGpu_) {
TensorGpuApply<T>(*this, expr);
} else {
TensorCpuApply<T>(*this, expr);
}
}
protected:
friend class GpuVectorT<T>;
friend class CpuVectorT<T>;
......@@ -315,6 +324,11 @@ public:
virtual void print(std::ostream& os, size_t num) const;
virtual void printOneElement(std::ostream& os, size_t idx) const;
template<typename ExpressionType>
void operator=(const ExpressionType& expr) {
TensorGpuApply<T>(*this, expr);
}
protected:
virtual void copyTo(CpuVectorT<T>* dest) const;
virtual void copyTo(GpuVectorT<T>* dest) const;
......@@ -378,6 +392,11 @@ public:
virtual T get(size_t pos);
virtual void print(std::ostream& os, size_t num) const;
virtual void printOneElement(std::ostream& os, size_t idx) const;
template<typename ExpressionType>
void operator=(const ExpressionType& expr) {
TensorCpuApply<T>(*this, expr);
}
};
template <class T>
......
......@@ -3,6 +3,7 @@
add_simple_unittest(test_ExecViaCpu)
add_simple_unittest(test_SIMDFunctions)
add_simple_unittest(test_matrix)
add_simple_unittest(test_TrainingAlgorithm)
# TODO(yuyang18): Refactor TestUtil.cpp. Remove this cross module reference.
add_unittest(test_matrixCompare
......@@ -13,3 +14,8 @@ add_simple_unittest(test_sparseMatrixCompare)
add_simple_unittest(test_perturbation)
add_simple_unittest(test_CpuGpuVector)
add_simple_unittest(test_Allocator)
if(COMPILER_SUPPORT_CXX11)
LIST(APPEND CUDA_NVCC_FLAGS -std=c++11 -Xcompiler -fPIC --use_fast_math)
CUDA_ADD_EXECUTABLE(test_Tensor test_Tensor.cu)
link_paddle_test(test_Tensor)
endif()
/**
* OriginalOptimizerApi.h
*
* Author: hedaoyuan (hedaoyuan@baidu.com)
* Created on: 2016-06-29
*
* Copyright (c) Baidu.com, Inc. All Rights Reserved
*/
#pragma once
#include "paddle/utils/GlobalConstants.h"
#include "paddle/math/Vector.h"
using namespace paddle; // NOLINT
void SparseMomentumParameterOptimizer(const VectorPtr vecs[],
real alpha,
real beta,
real gamma,
real tau,
real learningRate) {
vecs[PARAMETER_MOMENTUM_UT]->add(*vecs[PARAMETER_GRADIENT],
-alpha * gamma * learningRate);
vecs[PARAMETER_MOMENTUM_VT]->add(*vecs[PARAMETER_GRADIENT],
tau * alpha * gamma * learningRate);
vecs[PARAMETER_VALUE]->add(*vecs[PARAMETER_MOMENTUM_UT],
tau / beta + 1.0 / alpha,
*vecs[PARAMETER_MOMENTUM_VT], 1.0 / beta);
}
void AdagradParameterOptimizer(const VectorPtr vecs[],
real epsilon,
real learningRate,
real momentum,
real decayRate) {
vecs[PARAMETER_GRADIENT_SQURESUM1]->addSquare(*vecs[PARAMETER_GRADIENT],
1.0f);
vecs[PARAMETER_LEARNING_RATE]->add(*vecs[PARAMETER_GRADIENT_SQURESUM],
*vecs[PARAMETER_GRADIENT_SQURESUM1]);
vecs[PARAMETER_LEARNING_RATE]->add(epsilon);
vecs[PARAMETER_LEARNING_RATE]->invSqrt(*vecs[PARAMETER_LEARNING_RATE]);
vecs[PARAMETER_VALUE]->sgdUpdate(
*vecs[PARAMETER_GRADIENT], *vecs[PARAMETER_MOMENTUM],
*vecs[PARAMETER_LEARNING_RATE], learningRate,
momentum, decayRate);
}
void AdaDeltaParameterOptimizer(const VectorPtr vecs[],
real rou,
real epsilon,
real learningRate,
real momentum,
real decayRate) {
// E(g_t^2) = \rou * E(g_{t-1}^2) + (1-\rou) * g^2
vecs[PARAMETER_GRADIENT_SQURESUM]->decayAddSquare(*vecs[PARAMETER_GRADIENT],
rou, 1.0f - rou);
// learn_rate = sqrt( ( E(dx_{t-1}^2) + epsilon ) / ( E(g_t^2) + epsilon ) )
vecs[PARAMETER_LEARNING_RATE]->dotDiv(*vecs[PARAMETER_GRADIENT_SQURESUM1],
*vecs[PARAMETER_GRADIENT_SQURESUM],
epsilon, epsilon);
vecs[PARAMETER_LEARNING_RATE]->sqrt2();
// E(dx_t^2) = \rou * E(dx_{t-1}^2) + (1-\rou) * (-g*learn_rate)^2
vecs[PARAMETER_GRADIENT_SQURESUM1]->decayAddSquareMul(
*vecs[PARAMETER_GRADIENT], *vecs[PARAMETER_LEARNING_RATE], rou,
1.0f - rou);
vecs[PARAMETER_VALUE]->sgdUpdate(
*vecs[PARAMETER_GRADIENT], *vecs[PARAMETER_MOMENTUM],
*vecs[PARAMETER_LEARNING_RATE], learningRate,
momentum, decayRate);
}
void RMSPropParameterOptimizer(const VectorPtr vecs[],
real accumulatedRou,
real rou,
real epsilon,
real learningRate,
real momentum,
real decayRate,
bool firstTime) {
// E(g_t^2) = \rou * E(g_{t-1}^2) + (1-\rou) * g^2
// For the first time update, make the sum be the current square
// so that the initial estimation of E(g_t^2) will not be too small.
vecs[PARAMETER_GRADIENT_SQURESUM]->decayAddSquare(
*vecs[PARAMETER_GRADIENT], accumulatedRou,
firstTime ? 1.0f : 1.0f - rou);
// E(g_t) = \rou * E(g_{t-1}) + (1-\rou) * g
vecs[PARAMETER_GRADIENT_SQURESUM1]->add(*vecs[PARAMETER_GRADIENT],
accumulatedRou, 1.0f - rou);
  // learn_rate = 1 / sqrt( E(g_t^2) - (E(g_t))^2 + epsilon )
  // Basically, if the sign of the gradient changes more often,
  // the learning rate will be decreased.
vecs[PARAMETER_LEARNING_RATE]->assign(*vecs[PARAMETER_GRADIENT_SQURESUM]);
vecs[PARAMETER_LEARNING_RATE]->addSquare(*vecs[PARAMETER_GRADIENT_SQURESUM1],
-1.0f);
vecs[PARAMETER_LEARNING_RATE]->add(epsilon);
vecs[PARAMETER_LEARNING_RATE]->invSqrt(*vecs[PARAMETER_LEARNING_RATE]);
vecs[PARAMETER_VALUE]->sgdUpdate(
*vecs[PARAMETER_GRADIENT], *vecs[PARAMETER_MOMENTUM],
*vecs[PARAMETER_LEARNING_RATE], learningRate,
momentum, decayRate);
}
void DecayedAdagradParameterOptimizer(const VectorPtr vecs[],
real accumulatedRou,
real rou,
real epsilon,
real learningRate,
real momentum,
real decayRate,
bool firstTime) {
// E(g_t^2) = \rou * E(g_{t-1}^2) + (1-\rou) * g^2
// For the first time update, make the sum be the current square
// so that the initial estimation of E(g_t^2) will not be too small.
vecs[PARAMETER_GRADIENT_SQURESUM]->decayAddSquare(
*vecs[PARAMETER_GRADIENT], accumulatedRou,
firstTime ? 1.0f : 1.0f - rou);
  // learn_rate = 1 / sqrt( E(g_t^2) + epsilon )
  // Basically, the bigger the magnitude of the gradient is,
  // the smaller the learning rate will be.
vecs[PARAMETER_LEARNING_RATE]->assign(epsilon);
vecs[PARAMETER_LEARNING_RATE]->add(*vecs[PARAMETER_GRADIENT_SQURESUM]);
vecs[PARAMETER_LEARNING_RATE]->invSqrt(*vecs[PARAMETER_LEARNING_RATE]);
vecs[PARAMETER_VALUE]->sgdUpdate(
*vecs[PARAMETER_GRADIENT], *vecs[PARAMETER_MOMENTUM],
*vecs[PARAMETER_LEARNING_RATE], learningRate,
momentum, decayRate);
}
void AdamParameterOptimizer(const VectorPtr vecs[],
real beta1,
real beta2,
real beta1_power,
real beta2_power,
real epsilon,
real learningRate) {
Vector* m = vecs[PARAMETER_MOMENTUM].get();
Vector* g = vecs[PARAMETER_GRADIENT].get();
Vector* v = vecs[PARAMETER_SECOND_MOMENTUM].get();
Vector* theta = vecs[PARAMETER_VALUE].get();
// m_t = \beta_1 * m_{t-1} + (1-\beta_1)* g_t;
m->add(*g, beta1, 1 - beta1);
// v_t = \beta_2 * v_{t-1} + (1-\beta_2)* g_{t-1}^2
g->square2();
v->add(*g, beta2, 1 - beta2);
// tmp = m_t / ( \sqrt{v_t} + \epsilon )
// \theta_t = \theta_{t-1} - \alpha * \sqrt(1-\beta_2^t) / (1-\beta_1^t) * tmp
g->sqrt2(*v);
g->dotDiv(*m, *g, 0., epsilon);
real alpha = learningRate *
std::sqrt((real)1 - beta2_power) / ((real)1 - beta1_power);
theta->add(*theta, 1.0, *g, -alpha);
}
void AdamaxParameterOptimizer(const VectorPtr vecs[],
real beta1,
real beta2,
int64_t step,
real alpha) {
Vector* m = vecs[PARAMETER_MOMENTUM].get();
Vector* g = vecs[PARAMETER_GRADIENT].get();
Vector* u = vecs[PARAMETER_WEIGHTED_INFINITY_NORM].get();
Vector* theta = vecs[PARAMETER_VALUE].get();
// m_t = \beta_1 * m_{t-1} + (1-\beta_1)* g_t;
m->add(*g, beta1, 1 - beta1);
// u_t = max(\beta_2*u_{t-1}, abs(g_t))
u->mulScalar(beta2);
g->abs2();
u->max2(*u, *g);
// \theta_t = \theta_{t-1} - (\alpha/(1-\beta_1^t))*m_t/u_t
g->dotDiv(*m, *u);
real learningRate = alpha / (1 - std::pow(beta1, step));
theta->add(*theta, 1.0, *g, -learningRate);
}
/**
* test_Tensor.cpp
*
* Author: hedaoyuan (hedaoyuan@baidu.com)
* Created on: 2016-06-06
*
* Copyright (c) Baidu.com, Inc. All Rights Reserved
*/
#include <gtest/gtest.h>
#include "paddle/math/Matrix.h"
using namespace paddle; // NOLINT
using namespace std; // NOLINT
template<typename Tensor>
extern void TensorCheckEqual(const Tensor& tensor1, const Tensor& tensor2);
void TensorCheckEqual(const CpuMatrix& matrix1, const CpuMatrix& matrix2) {
CHECK(matrix1.getHeight() == matrix2.getHeight());
CHECK(matrix1.getWidth() == matrix2.getWidth());
int height = matrix1.getHeight();
int width = matrix1.getWidth();
const real* data1 = matrix1.getData();
const real* data2 = matrix2.getData();
int count = 0;
for (int i = 0; i < height; i++) {
for (int j = 0; j < width; j++) {
if (data1[i * width + j] != data2[i * width + j]) {
count++;
}
}
}
EXPECT_EQ(count, 0) << "There are " << count << " different element.";
}
void TensorCheckEqual(const GpuMatrix& matrix1, const GpuMatrix& matrix2) {
CpuMatrix cpu1(matrix1.getHeight(), matrix1.getWidth());
CpuMatrix cpu2(matrix2.getHeight(), matrix2.getWidth());
cpu1.copyFrom(matrix1);
cpu2.copyFrom(matrix2);
TensorCheckEqual(cpu1, cpu2);
}
void TensorCheckErr(const CpuMatrix& matrix1, const CpuMatrix& matrix2) {
CHECK(matrix1.getHeight() == matrix2.getHeight());
CHECK(matrix1.getWidth() == matrix2.getWidth());
#ifndef PADDLE_TYPE_DOUBLE
real err = 1e-5;
#else
real err = 1e-10;
#endif
int height = matrix1.getHeight();
int width = matrix1.getWidth();
const real* data1 = matrix1.getData();
const real* data2 = matrix2.getData();
int count = 0;
for (int i = 0; i < height; i++) {
for (int j = 0; j < width; j++) {
real a = data1[i * width + j];
real b = data2[i * width + j];
if (fabs(a - b) > err) {
if ((fabsf(a - b) / fabsf(a)) > (err / 10.0f)) {
count++;
}
}
}
}
EXPECT_EQ(count, 0) << "There are " << count << " different element.";
}
void TensorCheckErr(const GpuMatrix& matrix1, const GpuMatrix& matrix2) {
CpuMatrix cpu1(matrix1.getHeight(), matrix1.getWidth());
CpuMatrix cpu2(matrix2.getHeight(), matrix2.getWidth());
cpu1.copyFrom(matrix1);
cpu2.copyFrom(matrix2);
TensorCheckErr(cpu1, cpu2);
}
template<class T>
void TensorCheckEqual(const CpuVectorT<T>& vector1,
const CpuVectorT<T>& vector2) {
CHECK(vector1.getSize() == vector2.getSize());
const T* data1 = vector1.getData();
const T* data2 = vector2.getData();
size_t size = vector1.getSize();
int count = 0;
for (size_t i = 0; i < size; i++) {
if (data1[i] != data2[i]) {
count++;
}
}
EXPECT_EQ(count, 0) << "There are " << count << " different element.";
}
template<class T>
void TensorCheckEqual(const GpuVectorT<T>& vector1,
const GpuVectorT<T>& vector2) {
CpuVectorT<T> cpu1(vector1.getSize());
CpuVectorT<T> cpu2(vector2.getSize());
cpu1.copyFrom(vector1);
cpu2.copyFrom(vector2);
TensorCheckEqual(cpu1, cpu2);
}
#define INIT_UNARY(A1, A2) \
Tensor A1(height, width); \
Tensor A2(height, width); \
A1.randomizeUniform(); \
A2.copyFrom(A1)
#define INIT_BINARY(A1, A2, B) \
INIT_UNARY(A1, A2); \
Tensor B(height, width); \
B.randomizeUniform()
#define INIT_TERNARY(A1, A2, B, C) \
INIT_BINARY(A1, A2, B); \
Tensor C(height, width); \
C.randomizeUniform()
#define INIT_QUATERNARY(A1, A2, B, C, D) \
INIT_TERNARY(A1, A2, B, C); \
Tensor D(height, width); \
D.randomizeUniform()
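// For reference, INIT_BINARY(A1, A2, B) expands (via INIT_UNARY) to the
// statements below, so every case starts from A1 == A2 plus an independent
// random B; INIT_TERNARY and INIT_QUATERNARY append C and D in the same way:
//
//   Tensor A1(height, width);
//   Tensor A2(height, width);
//   A1.randomizeUniform();
//   A2.copyFrom(A1);
//   Tensor B(height, width);
//   B.randomizeUniform();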
template<typename Tensor>
struct TestUnaryMatrix {
typedef std::function<void(Tensor& A1, Tensor& A2)> UnaryFunc;
explicit TestUnaryMatrix(UnaryFunc testUnaryFunc) {
for (auto height : {1, 11, 73, 128, 200, 330}) {
for (auto width : {1, 32, 100, 512, 1000, 3210}) {
LOG(INFO) << " height=" << height << " width=" << width;
INIT_UNARY(A1, A2);
testUnaryFunc(A1, A2);
}
}
}
};
template<typename Tensor>
struct TestBinaryMatrix {
typedef std::function<void(Tensor& A1, Tensor& A2, Tensor& B)> BinaryFunc;
explicit TestBinaryMatrix(BinaryFunc testBinaryFunc) {
for (auto height : {1, 11, 73, 128, 200, 330}) {
for (auto width : {1, 32, 100, 512, 1000, 3210}) {
LOG(INFO) << " height=" << height << " width=" << width;
INIT_BINARY(A1, A2, B);
testBinaryFunc(A1, A2, B);
}
}
}
};
template<typename Tensor>
struct TestTernaryMatrix {
typedef std::function<void(
Tensor& A1, Tensor& A2, Tensor& B, Tensor& C)> TernaryFunc;
explicit TestTernaryMatrix(TernaryFunc testTernaryFunc) {
for (auto height : {1, 11, 73, 128, 200, 330}) {
for (auto width : {1, 32, 100, 512, 1000, 3210}) {
LOG(INFO) << " height=" << height << " width=" << width;
INIT_TERNARY(A1, A2, B, C);
testTernaryFunc(A1, A2, B, C);
}
}
}
};
template<typename Tensor>
struct TestQuaternaryMatrix {
typedef std::function<void(
Tensor& A1, Tensor& A2, Tensor& B, Tensor& C, Tensor& D)> QuaternaryFunc;
explicit TestQuaternaryMatrix(QuaternaryFunc testQuaternaryFunc) {
for (auto height : {1, 11, 73, 128, 200, 330}) {
for (auto width : {1, 32, 100, 512, 1000, 3210}) {
LOG(INFO) << " height=" << height << " width=" << width;
INIT_QUATERNARY(A1, A2, B, C, D);
testQuaternaryFunc(A1, A2, B, C, D);
}
}
}
};
template<typename Tensor, class T>
struct TestUnaryVectorT {
typedef std::function<void(Tensor& A1, Tensor& A2)> UnaryFunc;
explicit TestUnaryVectorT(UnaryFunc testUnaryFunc) {
for (auto size : {1, 11, 73, 128, 200, 330, 512, 1000, 4210}) {
LOG(INFO) << " size=" << size;
Tensor A1(size);
Tensor A2(size);
if (typeid(T) == typeid(real)) {
A1.rand();
} else {
A1.rand(1000);
}
A2.copyFrom(A1);
testUnaryFunc(A1, A2);
}
}
};
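// A usage sketch added for illustration (not one of the original cases): the
// harnesses take any callable matching their std::function signature, so a
// check can also be supplied inline as a lambda.
TEST(Unary, LambdaSketch) {
  TestUnaryMatrix<CpuMatrix> testLambda([](CpuMatrix& A1, CpuMatrix& A2) {
    A1.mulScalar(2.0f);  // a *= p
    A2 *= 2.0f;
    TensorCheckEqual(A1, A2);
  });
}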
void SetTensorValue(Matrix& matrix, real value) {
int height = matrix.getHeight();
int width = matrix.getWidth();
int stride = matrix.getStride();
real* data = matrix.getData();
for (int i = 0; i < height; i++) {
int j = rand() % width; // NOLINT
if (typeid(matrix) == typeid(CpuMatrix)) {
data[i * stride + j] = value;
} else if (typeid(matrix) == typeid(GpuMatrix)) {
hl_memcpy(&data[i * stride + j], &value, sizeof(real));
} else {
}
}
}
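// Note on SetTensorValue: a CpuMatrix exposes host memory through getData(),
// so the chosen element is written directly; a GpuMatrix returns a device
// pointer, so the single value is copied in through hl_memcpy instead.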
template<typename Tensor>
void testTensorAddScalar(Tensor& A1, Tensor& A2) {
real p1 = 2.5;
real p2 = 3.0;
A1.add(p1); // a += p
A2 += p1;
TensorCheckEqual(A1, A2);
A1.add(p1, p2); // a = a * p1 + p2
A2 = A2 * p1 + p2;
TensorCheckEqual(A1, A2);
}
template<typename Tensor>
void testTensorSubScalar(Tensor& A1, Tensor& A2) {
real p = 2.5;
A1.subScalar(p); // a -= p
A2 -= p;
TensorCheckEqual(A1, A2);
}
template<typename Tensor>
void testTensorMulScalar(Tensor& A1, Tensor& A2) {
real p = 2.5;
A1.mulScalar(p); // a *= p
A2 *= p;
TensorCheckEqual(A1, A2);
real learningRate = 0.7f;
real decayRate = 1.2f;
A1.applyL2(learningRate, decayRate);
A2 = A2 * (1.0f / (1.0f + learningRate * decayRate));
TensorCheckEqual(A1, A2);
}
template<typename Tensor>
void testTensorDivScalar(Tensor& A1, Tensor& A2) {
real p = 2.5;
A1.divScalar(p); // a /= p
A2 /= p;
TensorCheckEqual(A1, A2);
}
template<typename Tensor>
void testTensorNeg(Tensor& A1, Tensor& A2) {
A1.neg(); // a = -a
A2 = -A2;
TensorCheckEqual(A1, A2);
}
template<typename Tensor>
void testTensorAbs(Tensor& A1, Tensor& A2) {
A1.abs2(); // a = a > 0 ? a : -a
A2 = A2.abs();
TensorCheckEqual(A1, A2);
}
template<typename Tensor>
void testTensorSquare(Tensor& A1, Tensor& A2) {
A1.square2(); // a = a * a
A2 = A2.square();
TensorCheckEqual(A1, A2);
}
template<typename Tensor>
void testTensorReciprocal(Tensor& A1, Tensor& A2) {
A1.reciprocal2(); // a = 1.0f / a
A2 = A2.reciprocal();
TensorCheckEqual(A1, A2);
}
template<typename Tensor>
void testTensorSign(Tensor& A1, Tensor& A2) {
A1.sign2(); // a = (a > 0) - (a < 0)
A2 = A2.sign();
TensorCheckEqual(A1, A2);
}
template<typename Tensor>
void testTensorAssign(Tensor& A1, Tensor& A2) {
A1.assign(1.5); // a = p
A2 = A2.constant(1.5);
TensorCheckEqual(A1, A2);
A1.one(); // a = 1
A2 = A2.constant(1.0);
TensorCheckEqual(A1, A2);
A1.zero(); // a = 0
A2 = A2.constant(0.0);
TensorCheckEqual(A1, A2);
}
template<typename Tensor>
void testUnaryBaseOp(Tensor& A1, Tensor& A2) {
testTensorAddScalar(A1, A2);
testTensorSubScalar(A1, A2);
testTensorMulScalar(A1, A2);
testTensorDivScalar(A1, A2);
testTensorNeg(A1, A2);
testTensorAbs(A1, A2);
testTensorSquare(A1, A2);
testTensorReciprocal(A1, A2);
testTensorSign(A1, A2);
testTensorAssign(A1, A2);
}
template<typename Tensor>
void testUnaryBaseOpInt(Tensor& A1, Tensor& A2) {
A1.add(2); // a += p
A2 += 2;
TensorCheckEqual(A1, A2);
A1.add(3, 2); // a = a * p1 + p2
A2 = A2 * 3 + 2;
TensorCheckEqual(A1, A2);
testTensorNeg(A1, A2);
testTensorAbs(A1, A2);
}
TEST(Unary, BaseOp) {
TestUnaryMatrix<CpuMatrix> testCpuMatrix(testUnaryBaseOp<CpuMatrix>);
TestUnaryVectorT<CpuVector, real> testCpuVector(testUnaryBaseOp<CpuVector>);
TestUnaryVectorT<CpuIVector, int>
testCpuIVector(testUnaryBaseOpInt<CpuIVector>);
#ifndef PADDLE_ONLY_CPU
TestUnaryMatrix<GpuMatrix> testGpuMatrix(testUnaryBaseOp<GpuMatrix>);
TestUnaryVectorT<GpuVector, real> testGpuVector(testUnaryBaseOp<GpuVector>);
TestUnaryVectorT<GpuIVector, int>
testGpuIVector(testUnaryBaseOpInt<GpuIVector>);
#endif
}
template<typename Tensor>
void testTensorExp(Tensor& A1, Tensor& A2) {
A1.exp2(); // a = exp(a)
A2 = A2.exp();
TensorCheckErr(A1, A2);
}
template<typename Tensor>
void testTensorLog(Tensor& A1, Tensor& A2) {
A1.log2(); // a = log(a)
A2 = A2.log();
TensorCheckErr(A1, A2);
}
template<typename Tensor>
void testTensorSqrt(Tensor& A1, Tensor& A2) {
A1.sqrt2(); // a = sqrt(a)
A2 = A2.sqrt();
TensorCheckErr(A1, A2);
}
template<typename Tensor>
void testTensorPow(Tensor& A1, Tensor& A2) {
A1.pow2(3.2); // a = pow(a, p)
A2 = A2.pow(3.2);
TensorCheckErr(A1, A2);
}
template<typename Tensor>
void testUnaryMathOp(Tensor& A1, Tensor& A2) {
testTensorExp(A1, A2);
testTensorLog(A1, A2);
testTensorSqrt(A1, A2);
testTensorPow(A1, A2);
}
TEST(Unary, MathOp) {
  TestUnaryMatrix<CpuMatrix> testCpu(testUnaryMathOp<CpuMatrix>);
#ifndef PADDLE_ONLY_CPU
  TestUnaryMatrix<GpuMatrix> testGpu(testUnaryMathOp<GpuMatrix>);
#endif
}
template<typename Tensor>
void testTensorClip(Tensor& A1, Tensor& A2) {
real p1 = 0.003f;
real p2 = 0.877f;
A1.clip(p1, p2); // a = a < p1 ? p1 : (a > p2 ? p2 : a)
// A2 = A2.min(0.877f).max(0.003f);
A2 = (A2 < p1).condition(p1, (A2 > p2).condition(p2, A2));
TensorCheckEqual(A1, A2);
}
template<typename Tensor>
void testTensorBiggerThanScalar(Tensor& A1, Tensor& A2) {
real p = 0.5f;
A1.biggerThanScalar(p); // a = a > p ? 1.0f : 0.0f
A2 = (A2 > p).condition((real)1.0, (real)0.0);
TensorCheckEqual(A1, A2);
}
template<typename Tensor>
void testTensorapplyL1(Tensor& A1, Tensor& A2) {
/**
* T lambda = p;
* a = (a > lambda) ? (a - lambda)
* : (a < -lambda) ? (a + lambda) : 0
*
* p = learningRate * decayRate;
*/
real learningRate = 0.7f;
real decayRate = 0.6f;
A1.applyL1(learningRate, decayRate);
A2 = (A2 > (learningRate * decayRate)).condition(
(A2 - (learningRate * decayRate)),
(A2 < -(learningRate * decayRate)).condition(
(A2 + (learningRate * decayRate)), (real)0.0));
TensorCheckEqual(A1, A2);
}
template<typename Tensor>
void testUnaryCompareOp(Tensor& A1, Tensor& A2) {
testTensorClip(A1, A2);
testTensorBiggerThanScalar(A1, A2);
A1.randomizeUniform();
A1.subScalar(0.5f);
A2.copyFrom(A1);
testTensorapplyL1(A1, A2);
}
TEST(Unary, CompareOp) {
  TestUnaryMatrix<CpuMatrix> testCpu(testUnaryCompareOp<CpuMatrix>);
#ifndef PADDLE_ONLY_CPU
  TestUnaryMatrix<GpuMatrix> testGpu(testUnaryCompareOp<GpuMatrix>);
#endif
}
template<typename Tensor>
void testTensorAdd(Tensor& A1, Tensor& A2, Tensor& B) {
real p1 = 2.5;
real p2 = 3.2;
A1.add(B); // a += b
A2 += B;
TensorCheckEqual(A1, A2);
A1.add(B, p1); // a += b * p
A2 += B * p1;
TensorCheckEqual(A1, A2);
A1.add(B, p1, p2); // a = p1 * a + p2 * b
A2 = A2 * p1 + B * p2;
TensorCheckEqual(A1, A2);
A1.addScalar(B, p1); // a = b + p
A2 = B + p1;
TensorCheckEqual(A1, A2);
A1.addSquare(B, p1); // a += p * b * b
A2 += B.constant(p1) * B * B;
TensorCheckEqual(A1, A2);
A1.decayAddSquare(B, p1, p2); // a = p1 * a + p2 * b * b
A2 = A2 * p1 + B.constant(p2) * B * B;
TensorCheckEqual(A1, A2);
}
template<typename Tensor>
void testTensorSub(Tensor& A1, Tensor& A2, Tensor& B) {
real p = 2.5;
A1.sub(B); // a -= b
A2 -= B;
TensorCheckEqual(A1, A2);
A1.sub(B, p); // a -= b * p
A2 -= B * p;
TensorCheckEqual(A1, A2);
A1.subScalar(B, p); // a = b - p
A2 = B - p;
TensorCheckEqual(A1, A2);
}
template<typename Tensor>
void testTensorMul(Tensor& A1, Tensor& A2, Tensor& B) {
real p = 2.5;
A1.mulScalar(B, p); // a = b * p
A2 = B * p;
TensorCheckEqual(A1, A2);
A1.dotMulSquare(B); // a *= b * b
A2 *= B * B;
TensorCheckEqual(A1, A2);
A1.dotSquareMul(B); // a = a * a * b
A2 = A2 * A2 * B;
TensorCheckEqual(A1, A2);
A1.dotMul(B); // a *= b
A2 *= B;
TensorCheckEqual(A1, A2);
}
template<typename Tensor>
void testTensorDiv(Tensor& A1, Tensor& A2, Tensor& B) {
real p = 2.5;
A1.divScalar(B, p); // a = b / p
A2 = B / p;
TensorCheckEqual(A1, A2);
A1.scalarDiv(B, p); // a = p / b
A2 = B.constant(p) / B;
TensorCheckEqual(A1, A2);
}
template<typename Tensor>
void testTensorAssign(Tensor& A1, Tensor& A2, Tensor& B) {
A1.assign(B); // a = b
A2 = B;
TensorCheckEqual(A1, A2);
}
template<typename Tensor>
void testTensorSquare(Tensor& A1, Tensor& A2, Tensor& B) {
B.square2(A1); // b = a * a
A2 = B.square();
TensorCheckEqual(A1, A2);
}
template<typename Tensor>
void testTensorSquareDerivative(Tensor& A1, Tensor& A2, Tensor& B) {
A1.squareDerivative(B); // a *= 2.0 * b
A2 = A2 * (real)2.0 * B;
TensorCheckEqual(A1, A2);
}
template<typename Tensor>
void testTensorReciprocal(Tensor& A1, Tensor& A2, Tensor& B) {
B.reciprocal2(A1); // b = 1.0f / a
A2 = B.reciprocal();
TensorCheckEqual(A1, A2);
real p1 = 0.58;
real p2 = 0.32;
A1.reciprocal2(B, p1, p2); // a = 1 / (p1 * b + p2)
A2 = (B * p1 + p2).reciprocal();
TensorCheckEqual(A1, A2);
real learningRate = 0.7f;
real decayRate = 1.2f;
A1.applyL2(B, learningRate, decayRate); // a *= (1.0f / (1.0f + p * b))
A2 *= (B.constant(1.0f) +
B.constant(learningRate * decayRate) * B).reciprocal();
TensorCheckEqual(A1, A2);
}
template<typename Tensor>
void testTensorReciprocalDerivative(Tensor& A1, Tensor& A2, Tensor& B) {
A1.reciprocalDerivative(B); // a *= -b * b
A2 *= (-B) * B;
TensorCheckEqual(A1, A2);
}
template<typename Tensor>
void testTensorSign(Tensor& A1, Tensor& A2, Tensor& B) {
B.sign2(A1); // b = a > 0.0f ? 1.0f : -1.0f
A2 = B.sign();
TensorCheckEqual(A1, A2);
}
template<typename Tensor>
void testTensorAbs(Tensor& A1, Tensor& A2, Tensor& B) {
B.abs2(A1); // b = a > 0.0f ? a : -a
A2 = B.abs();
TensorCheckEqual(A1, A2);
}
template<typename Tensor>
void testBinaryBaseOp(Tensor& A1, Tensor& A2, Tensor& B) {
testTensorAdd(A1, A2, B);
testTensorSub(A1, A2, B);
testTensorMul(A1, A2, B);
testTensorDiv(A1, A2, B);
testTensorSquare(A1, A2, B);
testTensorSquareDerivative(A1, A2, B);
testTensorReciprocal(A1, A2, B);
testTensorReciprocalDerivative(A1, A2, B);
testTensorAbs(A1, A2, B);
testTensorSign(A1, A2, B);
testTensorAssign(A1, A2, B);
}
TEST(Binary, BaseOp) {
TestBinaryMatrix<CpuMatrix> testCpu(testBinaryBaseOp<CpuMatrix>);
#ifndef PADDLE_ONLY_CPU
TestBinaryMatrix<GpuMatrix> testGpu(testBinaryBaseOp<GpuMatrix>);
#endif
}
template<typename Tensor>
void testTensorExp(Tensor& A1, Tensor& A2, Tensor& B) {
// a = exp(b)
A1.exp2(B);
A2 = B.exp();
TensorCheckErr(A1, A2);
}
template<typename Tensor>
void testTensorExpDerivative(Tensor& A1, Tensor& A2, Tensor& B) {
A1.expDerivative(B); // a *= b
A2 *= B;
TensorCheckEqual(A1, A2);
}
template<typename Tensor>
void testTensorLog(Tensor& A1, Tensor& A2, Tensor& B) {
// a = log(b)
A1.log2(B);
A2 = B.log();
TensorCheckErr(A1, A2);
}
template<typename Tensor>
void testTensorSqrt(Tensor& A1, Tensor& A2, Tensor& B) {
// a = sqrt(b)
A1.sqrt2(B);
A2 = B.sqrt();
TensorCheckErr(A1, A2);
}
template<typename Tensor>
void testTensorInvSqrt(Tensor& A1, Tensor& A2, Tensor& B) {
// a = 1.0f / sqrt(b)
A1.invSqrt(B);
A2 = B.sqrt().reciprocal();
TensorCheckErr(A1, A2);
}
template<typename Tensor>
void testTensorPow(Tensor& A1, Tensor& A2, Tensor& B) {
A1.pow2(B, 2.5f); // a = pow(b, p)
A2 = B.pow(2.5f);
TensorCheckErr(A1, A2);
}
template<typename Tensor>
void testTensorSoftrelu(Tensor& A1, Tensor& A2, Tensor& B) {
/*
* const T THRESHOLD = 40.0;
* b = log(1.0 +
* exp((a > THRESHOLD) ? THRESHOLD
* : ((a < -THRESHOLD) ? (-THRESHOLD) : a)))
*/
B.softrelu(A1);
real THRESHOLD = 40.0;
A2 = (B.constant(1.0f) +
(B > THRESHOLD).condition(
THRESHOLD, (B < -THRESHOLD).condition(-THRESHOLD, B)).exp()).log();
TensorCheckErr(A1, A2);
}
template<typename Tensor>
void testTensorSoftreluDerivative(Tensor& A1, Tensor& A2, Tensor& B) {
/*
* const T THRESHOLD = 40.0;
* a *= (1.0 - exp(-1.0 * ((b > THRESHOLD)
* ? THRESHOLD
 * : ((b < -THRESHOLD) ? (-THRESHOLD) : b))));
*/
A1.softreluDerivative(B);
real THRESHOLD = 40.0;
A2 = A2 * (B.constant(1.0f) -
(B.constant(-1.0f) *
(B > THRESHOLD).condition(
THRESHOLD, (B < -THRESHOLD).condition(-THRESHOLD, B))).exp());
TensorCheckErr(A1, A2);
}
template<typename Tensor>
void testTensorSigmoid(Tensor& A1, Tensor& A2, Tensor& B) {
/*
const T THRESHOLD_MIN = -40.0;
const T THRESHOLD_MAX = 13.0;
T tmp = (a < THRESHOLD_MIN) ? THRESHOLD_MIN
: ((a > THRESHOLD_MAX) ? THRESHOLD_MAX : a);
b = 1.0f / (1.0f + exp(-tmp))
*/
B.sigmoid(A1);
const real THRESHOLD_MIN = -40.0;
const real THRESHOLD_MAX = 13.0;
auto tmp = (B < THRESHOLD_MIN).condition(
THRESHOLD_MIN, (B > THRESHOLD_MAX).condition(THRESHOLD_MAX, B));
A2 = (B.constant(1.0f) + (-tmp).exp()).reciprocal();
TensorCheckErr(A1, A2);
}
template<typename Tensor>
void testTensorSigmoidDerivative(Tensor& A1, Tensor& A2, Tensor& B) {
A1.sigmoidDerivative(B); // a *= b * (1 - b)
A2 *= B * (B.constant(1.0f) - B);
TensorCheckEqual(A1, A2);
}
template<typename Tensor>
void testTensorTanh(Tensor& A1, Tensor& A2, Tensor& B) {
B.tanh(A1); // b = 2.0 / (1.0 + exp(-2 * a)) - 1.0
A2 = B.constant(2.0f) / ((B * ((real)-2.0f)).exp() + (real)1.0f) - (real)1.0f;
TensorCheckErr(A1, A2);
}
template<typename Tensor>
void testTensorTanhDerivative(Tensor& A1, Tensor& A2, Tensor& B) {
A1.tanhDerivative(B); // a *= 1 - b * b
A2 *= B.constant(1.0f) - B * B;
TensorCheckEqual(A1, A2);
}
template<typename Tensor>
void testTensorScaledTanh(Tensor& A1, Tensor& A2, Tensor& B) {
real p1 = 2.5;
real p2 = 3.1;
// b = p1 * (2.0 / (1.0 + exp(-2 * p2 * a)) - 1.0)
B.scaledTanh(A1, p1, p2);
A2 = B.constant(p1) *
(B.constant(2.0f) / ((B.constant(-2.0f) * p2 * B).exp() + (real)1.0)
- (real)1.0);
TensorCheckErr(A1, A2);
}
template<typename Tensor>
void testTensorScaledTanhDerivative(Tensor& A1, Tensor& A2, Tensor& B) {
real p1 = 2.5;
real p2 = 3.1;
  // a *= (p2 / p1) * (p1 * p1 - b * b);
A1.scaledTanhDerivative(B, p1, p2);
A2 = A2 * (B.constant(p2 / p1) * (B.constant(p1 * p1) - B * B));
TensorCheckEqual(A1, A2);
}
template<typename Tensor>
void testBinaryMathOp(Tensor& A1, Tensor& A2, Tensor& B) {
testTensorTanhDerivative(A1, A2, B);
testTensorScaledTanhDerivative(A1, A2, B);
testTensorSigmoidDerivative(A1, A2, B);
testTensorExpDerivative(A1, A2, B);
testTensorScaledTanh(A1, A2, B);
testTensorTanh(A1, A2, B);
testTensorExp(A1, A2, B);
testTensorLog(A1, A2, B);
testTensorSqrt(A1, A2, B);
testTensorInvSqrt(A1, A2, B);
testTensorPow(A1, A2, B);
testTensorSoftrelu(A1, A2, B);
testTensorSoftreluDerivative(A1, A2, B);
testTensorSigmoid(A1, A2, B);
}
TEST(Binary, MathOp) {
TestBinaryMatrix<CpuMatrix> testCpu(testBinaryMathOp<CpuMatrix>);
#ifndef PADDLE_ONLY_CPU
TestBinaryMatrix<GpuMatrix> testGpu(testBinaryMathOp<GpuMatrix>);
#endif
}
template<typename Tensor>
void testTensorRelu(Tensor& A1, Tensor& A2, Tensor& B) {
B.relu(A1); // b = a > 0.0f ? a : 0.0f
A2 = (B > (real)0.0f).condition(B, (real)0.0f);
TensorCheckEqual(A1, A2);
}
template<typename Tensor>
void testTensorReluDerivative(Tensor& A1, Tensor& A2, Tensor& B) {
A1.reluDerivative(B); // a *= (b > 0.0f ? 1.0f : 0.0f)
A2 *= (B > (real)0.0).condition((real)1.0, (real)0.0);
TensorCheckEqual(A1, A2);
}
template<typename Tensor>
void testTensorBrelu(Tensor& A1, Tensor& A2, Tensor& B) {
/*
* b = a > p1 ? a : p1
* b = b < p2 ? b : p2
* int p1 = 0, p2 = 24;
*/
SetTensorValue(B, 32.0f);
B.brelu(A1);
auto tmp = (B > (real)0.0f).condition(B, (real)0.0f);
A2 = (tmp < (real)24.0f).condition(tmp, (real)24.0f);
TensorCheckEqual(A1, A2);
}
template<typename Tensor>
void testTensorBreluDerivative(Tensor& A1, Tensor& A2, Tensor& B) {
SetTensorValue(B, 32.0f);
/*
* a *= (b > p1 && b < p2) ? 1.0 : 0.0
* int p1 = 0, p2 = 24;
*/
A1.breluDerivative(B);
A2 *= (B > (real)0.0f && B < (real)24.0f).condition((real)1.0f, (real)0.0f);
TensorCheckEqual(A1, A2);
}
template<typename Tensor>
void testTensorAbsDerivative(Tensor& A1, Tensor& A2, Tensor& B) {
A1.absDerivative(B); // a = (b > 0) ? a : (b < 0) ? -a : 0
A2 = (B > (real)0.0f).condition(A2,
(B < (real)0.0f).condition(-A2, (real)0.0f));
TensorCheckEqual(A1, A2);
}
template<typename Tensor>
void testTensorIsEqualTo(Tensor& A1, Tensor& A2, Tensor& B) {
real p = 0.613;
SetTensorValue(B, p);
A1.isEqualTo(B, p); // a = (b == p)
A2 = (B == p);
TensorCheckEqual(A1, A2);
}
template<typename Tensor>
void testTensorapplyL1(Tensor& A1, Tensor& A2, Tensor& B) {
/**
* T lambda = p * b;
* a = (a > lambda) ? (a - lambda)
* : (a < -lambda) ? (a + lambda) : 0
*
* p = learningRate * decayRate;
*/
real learningRate = 0.7f;
real decayRate = 0.6f;
A1.applyL1(B, learningRate, decayRate);
auto lambda = B.constant(learningRate * decayRate) * B;
A2 = (A2 > lambda).condition(
(A2 - lambda), (A2 < -lambda).condition((A2 + lambda), (real)0.0f));
TensorCheckEqual(A1, A2);
}
template<typename Tensor>
void testBinaryCompareOp(Tensor& A1, Tensor& A2, Tensor& B) {
B.subScalar(0.5f);
SetTensorValue(B, 0.0f);
testTensorReluDerivative(A1, A2, B);
A1.randomizeUniform();
A2.copyFrom(A1);
testTensorBreluDerivative(A1, A2, B);
testTensorAbsDerivative(A1, A2, B);
testTensorRelu(A1, A2, B);
testTensorBrelu(A1, A2, B);
testTensorIsEqualTo(A1, A2, B);
}
TEST(Binary, CompareOp) {
TestBinaryMatrix<CpuMatrix> testCpu(testBinaryCompareOp<CpuMatrix>);
#ifndef PADDLE_ONLY_CPU
TestBinaryMatrix<GpuMatrix> testGpu(testBinaryCompareOp<GpuMatrix>);
#endif
}
template<typename Tensor>
void testTensorAdd(Tensor& A1, Tensor& A2, Tensor& B, Tensor& C) {
A1.add(B, C); // a = b + c
A2 = B + C;
TensorCheckEqual(A1, A2);
real p1 = 1.5;
real p2 = 2.5;
real p3 = 3.8;
A1.add(B, p1, C, p2); // a = p1 * b + p2 * c
A2 = B * p1 + C * p2;
TensorCheckEqual(A1, A2);
A1.add2(B, C); // a = a + b + c
A2 = A2 + B + C;
TensorCheckEqual(A1, A2);
A1.add2(B, C, p1, p2, p3); // a = p1 * a + p2 * b + p3 * c
A2 = A2 * p1 + B * p2 + C * p3;
TensorCheckEqual(A1, A2);
A1.decayAddSquareMul(B, C, p1, p2); // a = p1 * a + p2 * b * b * c * c
A2 = A2 * p1 + B.constant(p2) * B * B * C * C;
TensorCheckEqual(A1, A2);
}
template<typename Tensor>
void testTensorSub(Tensor& A1, Tensor& A2, Tensor& B, Tensor& C) {
A1.sub(B, C); // a = b - c
A2 = B - C;
TensorCheckEqual(A1, A2);
real p1 = 1.5;
real p2 = 2.5;
A1.sub(B, p1, C, p2); // a = p1 * b - p2 * c
A2 = B * p1 - C * p2;
TensorCheckEqual(A1, A2);
}
template<typename Tensor>
void testTensorMul(Tensor& A1, Tensor& A2, Tensor& B, Tensor& C) {
A1.dotMul(B, C); // a = b * c
A2 = B * C;
TensorCheckEqual(A1, A2);
A1.dotMulSquare(B, C); // a = b * c * c
A2 = B * C * C;
TensorCheckEqual(A1, A2);
A1.dotSquareSquare(B, C); // a = b * b * c * c
A2 = B * B * C * C;
TensorCheckEqual(A1, A2);
real p1 = 1.5;
real p2 = 2.5;
/*
* T tmp = p1 * b + p2 * c;
* a *= tmp * tmp
*/
A1.dotMulSquareSum(B, C, p1, p2);
auto tmp = B * p1 + C * p2;
A2 *= tmp * tmp;
TensorCheckEqual(A1, A2);
/*
* T tmp = p1 * b + p2 * c;
* a = tmp * tmp
*/
A1.dotSquareSum(B, C, p1, p2);
auto tmp2 = B * p1 + C * p2;
A2 = tmp2 * tmp2;
TensorCheckEqual(A1, A2);
// a *= p1 * b + p2 * c
A1.dotMulSum(B, C, p1, p2);
A2 *= B * p1 + C * p2;
TensorCheckEqual(A1, A2);
// a = p1 * a + p2 * b * c
A1.addDotMul(B, C, p1, p2);
A2 = A2 * p1 + B.constant(p2) * B * C;
TensorCheckEqual(A1, A2);
}
template<typename Tensor>
void testTensorDiv(Tensor& A1, Tensor& A2, Tensor& B, Tensor& C) {
A1.dotDiv(B, C); // a = (b == 0.0) ? 0.0 : b / c
A2 = (B == (real)0.0).condition((real)0.0, B / C);
TensorCheckEqual(A1, A2);
real p1 = 1.5;
real p2 = 2.5;
A1.dotDiv(B, C, p1, p2); // a = (b + p1) / (c + p2)
A2 = (B + p1) / (C + p2);
TensorCheckEqual(A1, A2);
}
template<typename Tensor>
void testTensorReciprocal(Tensor& A1, Tensor& A2, Tensor& B, Tensor& C) {
real p1 = 1.5;
real p2 = 2.5;
real p3 = 3.5;
A1.reciprocalSum(B, C, p1, p2, p3); // a = 1 / (p1 * b + p2 * c + p3)
A2 = (B * p1 + C * p2 + p3).reciprocal();
TensorCheckEqual(A1, A2);
}
template<typename Tensor>
void testTensorSoftCrossEntropy(Tensor& A1, Tensor& A2, Tensor& B, Tensor& C) {
A1.softCrossEntropy(B, C); // a = -c * log(b) - (1 - c) * log(1 - b)
A2 = -C * B.log() - (C.constant(1.0f) - C) * (B.constant(1.0f) - B).log();
TensorCheckErr(A1, A2);
}
template<typename Tensor>
void testTensorSoftCrossEntropyBp(Tensor& A1,
Tensor& A2,
Tensor& B,
Tensor& C) {
A1.softCrossEntropyBp(B, C); // a += (b - c) / (b * (1 - b))
A2 += (B - C) / (B * (B.constant(1.0f) - B));
TensorCheckEqual(A1, A2);
}
template<typename Tensor>
void testTernaryBaseOp(Tensor& A1, Tensor& A2, Tensor& B, Tensor& C) {
testTensorAdd(A1, A2, B, C);
testTensorSub(A1, A2, B, C);
testTensorMul(A1, A2, B, C);
testTensorDiv(A1, A2, B, C);
testTensorReciprocal(A1, A2, B, C);
testTensorSoftCrossEntropyBp(A1, A2, B, C);
testTensorSoftCrossEntropy(A1, A2, B, C);
}
TEST(Ternary, BaseOp) {
TestTernaryMatrix<CpuMatrix> testCpu(testTernaryBaseOp<CpuMatrix>);
#ifndef PADDLE_ONLY_CPU
TestTernaryMatrix<GpuMatrix> testGpu(testTernaryBaseOp<GpuMatrix>);
#endif
}
template<typename Tensor>
void testTensorBinaryLabelCrossEntropy(Tensor& A1,
Tensor& A2,
Tensor& B,
Tensor& C) {
A1.binaryLabelCrossEntropy(B, C); // a = c > 0.5 ? -log(b) : -log(1.0 - b)
A2 = (C > (real)0.5).condition(
-(B.log()), -((B.constant(1.0f) - B).log()));
TensorCheckErr(A1, A2);
}
template<typename Tensor>
void testTensorBinaryLabelCrossEntropyBp(Tensor& A1,
Tensor& A2,
Tensor& B,
Tensor& C) {
// a += c > 0.5 ? -1.0 / b : 1.0 / (1.0 - b)
A1.binaryLabelCrossEntropyBp(B, C);
A2 += (C > (real)0.5).condition(
(B.constant(-1.0f) / B), (B.constant(1.0f) - B).reciprocal());
TensorCheckErr(A1, A2);
}
template<typename Tensor>
void testTensorLogisticRegressionLoss(Tensor& A1,
Tensor& A2,
Tensor& B,
Tensor& C) {
SetTensorValue(B, 50.0f);
SetTensorValue(B, -50.0f);
/**
* const T THRESHOLD = 40.0;
* T x = (b > THRESHOLD) ? THRESHOLD : (b < -THRESHOLD)
* ? -THRESHOLD
* : b;
* a = log(1 + exp(x)) - c * x
*/
A1.logisticRegressionLoss(B, C);
real THRESHOLD = 40.0;
auto tmp = (B > THRESHOLD).condition(
THRESHOLD, (B < -THRESHOLD).condition(-THRESHOLD, B));
A2 = (C.constant(1.0f) + tmp.exp()).log() - C * tmp;
TensorCheckErr(A1, A2);
}
template<typename Tensor>
void testTensorLogisticRegressionLossBp(Tensor& A1,
Tensor& A2,
Tensor& B,
Tensor& C) {
SetTensorValue(B, 50.0f);
SetTensorValue(B, -50.0f);
/**
* const T THRESHOLD = 40.0;
* T x = (b > THRESHOLD) ? THRESHOLD : (b < -THRESHOLD)
* ? -THRESHOLD
* : b;
* x = exp(x); a = x / (1 + x) - c
*/
A1.logisticRegressionLossBp(B, C);
real THRESHOLD = 40.0;
auto tmp = (B > THRESHOLD).condition(
THRESHOLD, (B < -THRESHOLD).condition(-THRESHOLD, B));
auto tmp2 = tmp.exp();
A2 = tmp2 / (C.constant(1.0) + tmp2) - C;
TensorCheckErr(A1, A2);
}
template<typename Tensor>
void testTensorBiggerThan(Tensor& A1, Tensor& A2, Tensor& B, Tensor& C) {
A1.biggerThan(B, C); // a = (b > c) ? 1.0f : 0.0f
A2 = (B > C).condition((real)1.0f, (real)0.0f);
TensorCheckEqual(A1, A2);
}
template<typename Tensor>
void testTensorMax(Tensor& A1, Tensor& A2, Tensor& B, Tensor& C) {
A1.max2(B, C); // a = (b > c) ? b : c
A2 = (B > C).condition(B, C);
TensorCheckEqual(A1, A2);
}
template<typename Tensor>
void testTernaryCompareOp(Tensor& A1, Tensor& A2, Tensor& B, Tensor& C) {
testTensorBinaryLabelCrossEntropyBp(A1, A2, B, C);
testTensorBinaryLabelCrossEntropy(A1, A2, B, C);
testTensorBiggerThan(A1, A2, B, C);
testTensorMax(A1, A2, B, C);
testTensorLogisticRegressionLoss(A1, A2, B, C);
testTensorLogisticRegressionLossBp(A1, A2, B, C);
}
TEST(Ternary, CompareOp) {
TestTernaryMatrix<CpuMatrix> testCpu(testTernaryCompareOp<CpuMatrix>);
#ifndef PADDLE_ONLY_CPU
TestTernaryMatrix<GpuMatrix> testGpu(testTernaryCompareOp<GpuMatrix>);
#endif
}
template<typename Tensor>
void testQuaternaryAdd(Tensor& A1,
Tensor& A2,
Tensor& B,
Tensor& C,
Tensor& D) {
// A1.add3(B, C, D, 1.5f, 2.5f, 3.5f); // a = p1 * b + p2 * c + p3 * d
// A2 = B * 1.5f + C * 2.5f + D * 3.5f;
// TensorCheckEqual(A1, A2);
/*
* T tmp = p1 * b + p2 * c + p3 * d;
* a += tmp * tmp
*/
real p1 = 1.5f;
real p2 = 2.5f;
real p3 = 3.5f;
A1.addSquareSum(B, C, D, p1, p2, p3);
auto tmp = B * p1 + C * p2 + D * p3;
A2 += tmp * tmp;
TensorCheckEqual(A1, A2);
}
TEST(Quaternary, BaseOp) {
TestQuaternaryMatrix<CpuMatrix> testCpu(testQuaternaryAdd<CpuMatrix>);
#ifndef PADDLE_ONLY_CPU
TestQuaternaryMatrix<GpuMatrix> testGpu(testQuaternaryAdd<GpuMatrix>);
#endif
}
template<typename Tensor>
void testTensorBiggerThan(Tensor& A1,
Tensor& A2,
Tensor& B,
Tensor& C,
Tensor& D) {
  // a = ((b > c && d > 0.5f) || (b < c && d < 0.5f)) ? 1.0f : 0.0f;
A1.biggerThan(B, C, D);
A2 = ((B > C && D > (real)0.5)
|| (B < C && D < (real)0.5)).condition((real)1.0, (real)0.0);
TensorCheckEqual(A1, A2);
}
template<typename Tensor>
void testTensorRankLoss(Tensor& A1,
Tensor& A2,
Tensor& B,
Tensor& C,
Tensor& D) {
/**
* const T THRESHOLD = 40.0; a = b - c;
* a = (a > THRESHOLD)
* ? THRESHOLD
* : ((a < -THRESHOLD) ? (-THRESHOLD) : a);
* a = log(1 + exp(a)) - a * d
*/
A1.rankLoss(B, C, D);
real THRESHOLD = 40.0;
auto tmp = B - C;
auto tmp2 = (tmp > THRESHOLD).condition(
THRESHOLD, (tmp < -THRESHOLD).condition(-THRESHOLD, tmp));
A2 = (D.constant(1.0f) + tmp2.exp()).log() - tmp2 * D;
TensorCheckErr(A1, A2);
}
template<typename Tensor>
void testTensorRankLossBp(Tensor& A1,
Tensor& A2,
Tensor& B,
Tensor& C,
Tensor& D) {
/**
* const T THRESHOLD = 40.0; a = b - c;
* a = (a > THRESHOLD)
* ? THRESHOLD
* : ((a < -THRESHOLD) ? (-THRESHOLD) : a);
* a = exp(a); a = (a / (1 + a) - d)
*/
A1.rankLossBp(B, C, D);
real THRESHOLD = 40.0;
auto tmp = B - C;
auto tmp2 = (tmp > THRESHOLD).condition(
THRESHOLD, (tmp < -THRESHOLD).condition(-THRESHOLD, tmp));
auto tmp3 = tmp2.exp();
A2 = tmp3 / (D.constant(1.0f) + tmp3) - D;
TensorCheckErr(A1, A2);
}
template<typename Tensor>
void testQuaternaryCompareOp(Tensor& A1,
Tensor& A2,
Tensor& B,
Tensor& C,
Tensor& D) {
testTensorBiggerThan(A1, A2, B, C, D);
testTensorRankLoss(A1, A2, B, C, D);
testTensorRankLossBp(A1, A2, B, C, D);
}
TEST(Quaternary, CompareOp) {
TestQuaternaryMatrix<CpuMatrix> testCpu(testQuaternaryCompareOp<CpuMatrix>);
#ifndef PADDLE_ONLY_CPU
TestQuaternaryMatrix<GpuMatrix> testGpu(testQuaternaryCompareOp<GpuMatrix>);
#endif
}
int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
hl_start();
hl_init(0);
return RUN_ALL_TESTS();
}
/**
* test_TrainingAlgorithm.cpp
*
* Author: hedaoyuan (hedaoyuan@baidu.com)
* Created on: 2016-06-29
*
* Copyright (c) Baidu.com, Inc. All Rights Reserved
*/
#include <gtest/gtest.h>
#include "paddle/utils/Util.h"
#include "paddle/math/TrainingAlgorithmOp.h"
#include "OriginalOptimizerApi.h"
using namespace paddle; // NOLINT
#ifndef PADDLE_TYPE_DOUBLE
P_DEFINE_double(max_diff, 1e-5, "max diff allowed");
#else
P_DEFINE_double(max_diff, 1e-13, "max diff allowed");
#endif
class SetMaxDiff {
public:
explicit SetMaxDiff(double max_diff) {
max_diff_ = FLAGS_max_diff;
FLAGS_max_diff = max_diff;
}
~SetMaxDiff() {
FLAGS_max_diff = max_diff_;
}
private:
double max_diff_;
};
int VectorCheckErr(const Vector& vector1, const Vector& vector2) {
CHECK(vector1.getSize() == vector2.getSize());
const real* data1 = vector1.getData();
const real* data2 = vector2.getData();
size_t size = vector1.getSize();
int count = 0;
for (size_t i = 0; i < size; i++) {
real a = data1[i];
real b = data2[i];
if (fabs(a - b) > FLAGS_max_diff) {
if ((fabsf(a - b) / fabsf(a)) > (FLAGS_max_diff / 10.0f)) {
count++;
}
}
}
return count;
}
#define COPY_VECTOR_TO_CPU(cpuVec, vector) \
do {\
if (vector->useGpu()) {\
cpuVec = Vector::create(vector->getSize(), false);\
cpuVec->copyFrom(*vector);\
} else {\
cpuVec = vector;\
}\
} while (0)
int VectorCheckErr(const VectorPtr& vector1, const VectorPtr& vector2) {
VectorPtr tmp1;
VectorPtr tmp2;
COPY_VECTOR_TO_CPU(tmp1, vector1);
COPY_VECTOR_TO_CPU(tmp2, vector2);
return VectorCheckErr(*tmp1, *tmp2);
}
#ifdef PADDLE_DISABLE_TIMER
#define CHECK_VECTORPTR(vector1, vector2) \
EXPECT_EQ(VectorCheckErr(vector1, vector2), 0)
#define EXPRESSION_PERFORMANCE(expression) \
expression;
#else
#include "paddle/common/Stat.h"
#define CHECK_VECTORPTR(vector1, vector2)
#define EXPRESSION_PERFORMANCE(expression) \
do {\
char expr[30];\
strncpy(expr, #expression, 30);\
if (expr[29] != '\0') {\
expr[27] = '.'; expr[28] = '.'; expr[29] = '\0';\
}\
expression;\
for (int i = 0; i < 20; i++) {\
REGISTER_TIMER(expr);\
expression;\
}\
LOG(INFO) << std::setiosflags(std::ios::left) << std::setfill(' ')\
<< *globalStat.getStat(expr);\
globalStat.reset();\
} while (0)
#endif
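// Summary of the two build modes above: with PADDLE_DISABLE_TIMER defined,
// EXPRESSION_PERFORMANCE simply evaluates the expression once and
// CHECK_VECTORPTR asserts the two buffers match; otherwise CHECK_VECTORPTR is
// a no-op and the expression is run once for warm-up plus 20 timed
// iterations, with the accumulated stat logged under a name derived from the
// expression text (truncated if longer than 30 characters).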
typedef std::function<void(size_t size, bool useGpu)> testMatrixFunc;
void testCase(testMatrixFunc matrixFunc) {
for (auto useGpu : {false, true}) {
for (auto size : {1, 32, 64, 128, 512, 1024, 4096, 32768, 65536, 131072,
262144, 524288, 1048576, 2097152}) {
LOG(INFO) << " size=" << size << " useGpu=" << useGpu;
matrixFunc(size, useGpu);
}
}
}
#define INIT_VECTOR(vec1, vec2, type, size, useGpu) \
vec1[type] = Vector::create(size, useGpu); \
vec2[type] = Vector::create(size, useGpu); \
vec1[type]->rand(); \
vec2[type]->copyFrom(*vec1[type]);
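// INIT_VECTOR creates the vector of the given parameter type in both buffer
// sets, fills bufs1 with uniform random values and copies them into bufs2, so
// the reference optimizer (bufs1) and the tested apply routine (bufs2) start
// from identical state.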
void testAdagrad(size_t size, bool useGpu) {
VectorPtr bufs1[NUM_PARAMETER_TYPES];
VectorPtr bufs2[NUM_PARAMETER_TYPES];
INIT_VECTOR(bufs1, bufs2, PARAMETER_VALUE, size, useGpu);
INIT_VECTOR(bufs1, bufs2, PARAMETER_GRADIENT, size, useGpu);
INIT_VECTOR(bufs1, bufs2, PARAMETER_MOMENTUM, size, useGpu);
INIT_VECTOR(bufs1, bufs2, PARAMETER_GRADIENT_SQURESUM, size, useGpu);
INIT_VECTOR(bufs1, bufs2, PARAMETER_GRADIENT_SQURESUM1, size, useGpu);
INIT_VECTOR(bufs1, bufs2, PARAMETER_LEARNING_RATE, size, useGpu);
real epsilon = (real)rand() / (real)RAND_MAX; // NOLINT
real learningRate = (real)rand() / (real)RAND_MAX; // NOLINT
real momentum = (real)rand() / (real)RAND_MAX; // NOLINT
real decayRate = (real)rand() / (real)RAND_MAX; // NOLINT
EXPRESSION_PERFORMANCE(AdagradParameterOptimizer(bufs1,
epsilon, learningRate, momentum, decayRate));
BaseMatrix& value = *bufs2[PARAMETER_VALUE];
BaseMatrix& grad = *bufs2[PARAMETER_GRADIENT];
BaseMatrix& mom = *bufs2[PARAMETER_MOMENTUM];
BaseMatrix& accum_buffer = *bufs2[PARAMETER_GRADIENT_SQURESUM];
BaseMatrix& accum = *bufs2[PARAMETER_GRADIENT_SQURESUM1];
BaseMatrix& lr = *bufs2[PARAMETER_LEARNING_RATE];
EXPRESSION_PERFORMANCE(adagradApply(value, grad, mom, accum_buffer, accum, lr,
epsilon, learningRate, momentum, decayRate));
CHECK_VECTORPTR(bufs1[PARAMETER_VALUE], bufs2[PARAMETER_VALUE]);
CHECK_VECTORPTR(bufs1[PARAMETER_MOMENTUM], bufs2[PARAMETER_MOMENTUM]);
CHECK_VECTORPTR(bufs1[PARAMETER_GRADIENT_SQURESUM1],
bufs2[PARAMETER_GRADIENT_SQURESUM1]);
CHECK_VECTORPTR(bufs1[PARAMETER_LEARNING_RATE],
bufs2[PARAMETER_LEARNING_RATE]);
}
TEST(Training, Adagrad) {
testCase(testAdagrad);
}
void testAdaDelta(size_t size, bool useGpu) {
VectorPtr bufs1[NUM_PARAMETER_TYPES];
VectorPtr bufs2[NUM_PARAMETER_TYPES];
INIT_VECTOR(bufs1, bufs2, PARAMETER_VALUE, size, useGpu);
INIT_VECTOR(bufs1, bufs2, PARAMETER_GRADIENT, size, useGpu);
INIT_VECTOR(bufs1, bufs2, PARAMETER_MOMENTUM, size, useGpu);
INIT_VECTOR(bufs1, bufs2, PARAMETER_GRADIENT_SQURESUM, size, useGpu);
INIT_VECTOR(bufs1, bufs2, PARAMETER_GRADIENT_SQURESUM1, size, useGpu);
INIT_VECTOR(bufs1, bufs2, PARAMETER_LEARNING_RATE, size, useGpu);
real rou = (real)rand() / (real)RAND_MAX; // NOLINT
real epsilon = (real)rand() / (real)RAND_MAX; // NOLINT
real learningRate = (real)rand() / (real)RAND_MAX; // NOLINT
real momentum = (real)rand() / (real)RAND_MAX; // NOLINT
real decayRate = (real)rand() / (real)RAND_MAX; // NOLINT
EXPRESSION_PERFORMANCE(AdaDeltaParameterOptimizer(bufs1,
rou, epsilon, learningRate, momentum, decayRate));
BaseMatrix& value = *bufs2[PARAMETER_VALUE];
BaseMatrix& grad = *bufs2[PARAMETER_GRADIENT];
BaseMatrix& mom = *bufs2[PARAMETER_MOMENTUM];
BaseMatrix& accum = *bufs2[PARAMETER_GRADIENT_SQURESUM];
BaseMatrix& accum_update = *bufs2[PARAMETER_GRADIENT_SQURESUM1];
BaseMatrix& lr = *bufs2[PARAMETER_LEARNING_RATE];
EXPRESSION_PERFORMANCE(adadeltaApply(value, grad, mom, accum, accum_update,
lr, rou, epsilon, learningRate, momentum, decayRate));
CHECK_VECTORPTR(bufs1[PARAMETER_VALUE], bufs2[PARAMETER_VALUE]);
CHECK_VECTORPTR(bufs1[PARAMETER_MOMENTUM], bufs2[PARAMETER_MOMENTUM]);
CHECK_VECTORPTR(bufs1[PARAMETER_GRADIENT_SQURESUM],
bufs2[PARAMETER_GRADIENT_SQURESUM]);
CHECK_VECTORPTR(bufs1[PARAMETER_GRADIENT_SQURESUM1],
bufs2[PARAMETER_GRADIENT_SQURESUM1]);
CHECK_VECTORPTR(bufs1[PARAMETER_LEARNING_RATE],
bufs2[PARAMETER_LEARNING_RATE]);
}
TEST(Training, AdaDelta) {
testCase(testAdaDelta);
}
template<bool isFirstTime>
void testRMSProp(size_t size, bool useGpu) {
VectorPtr bufs1[NUM_PARAMETER_TYPES];
VectorPtr bufs2[NUM_PARAMETER_TYPES];
INIT_VECTOR(bufs1, bufs2, PARAMETER_VALUE, size, useGpu);
INIT_VECTOR(bufs1, bufs2, PARAMETER_GRADIENT, size, useGpu);
INIT_VECTOR(bufs1, bufs2, PARAMETER_MOMENTUM, size, useGpu);
INIT_VECTOR(bufs1, bufs2, PARAMETER_GRADIENT_SQURESUM, size, useGpu);
INIT_VECTOR(bufs1, bufs2, PARAMETER_GRADIENT_SQURESUM1, size, useGpu);
INIT_VECTOR(bufs1, bufs2, PARAMETER_LEARNING_RATE, size, useGpu);
  /* make sure 'g - f.square()' is greater than 0 */
bufs1[PARAMETER_GRADIENT_SQURESUM]->add(1.0);
bufs2[PARAMETER_GRADIENT_SQURESUM]->copyFrom(
*bufs1[PARAMETER_GRADIENT_SQURESUM]);
real rou = (real)rand() / (real)RAND_MAX; // NOLINT
real epsilon = (real)rand() / (real)RAND_MAX; // NOLINT
real learningRate = (real)rand() / (real)RAND_MAX; // NOLINT
real momentum = (real)rand() / (real)RAND_MAX; // NOLINT
real decayRate = (real)rand() / (real)RAND_MAX; // NOLINT
real accumulatedRou = rou;
EXPRESSION_PERFORMANCE(RMSPropParameterOptimizer(bufs1,
accumulatedRou, rou, epsilon, learningRate, momentum, decayRate,
isFirstTime));
BaseMatrix& value = *bufs2[PARAMETER_VALUE];
BaseMatrix& grad = *bufs2[PARAMETER_GRADIENT];
BaseMatrix& mom = *bufs2[PARAMETER_MOMENTUM];
BaseMatrix& sum = *bufs2[PARAMETER_GRADIENT_SQURESUM];
BaseMatrix& sum1 = *bufs2[PARAMETER_GRADIENT_SQURESUM1];
BaseMatrix& lr = *bufs2[PARAMETER_LEARNING_RATE];
EXPRESSION_PERFORMANCE(rmspropApply(value, grad, mom, sum, sum1, lr,
accumulatedRou, rou, epsilon, learningRate, momentum, decayRate,
isFirstTime));
CHECK_VECTORPTR(bufs1[PARAMETER_VALUE], bufs2[PARAMETER_VALUE]);
CHECK_VECTORPTR(bufs1[PARAMETER_MOMENTUM], bufs2[PARAMETER_MOMENTUM]);
CHECK_VECTORPTR(bufs1[PARAMETER_GRADIENT_SQURESUM],
bufs2[PARAMETER_GRADIENT_SQURESUM]);
CHECK_VECTORPTR(bufs1[PARAMETER_GRADIENT_SQURESUM1],
bufs2[PARAMETER_GRADIENT_SQURESUM1]);
CHECK_VECTORPTR(bufs1[PARAMETER_LEARNING_RATE],
bufs2[PARAMETER_LEARNING_RATE]);
}
TEST(Training, RMSProp) {
testCase(testRMSProp<true>);
testCase(testRMSProp<false>);
}
template<bool isFirstTime>
void testDecayedAdagrad(size_t size, bool useGpu) {
VectorPtr bufs1[NUM_PARAMETER_TYPES];
VectorPtr bufs2[NUM_PARAMETER_TYPES];
INIT_VECTOR(bufs1, bufs2, PARAMETER_VALUE, size, useGpu);
INIT_VECTOR(bufs1, bufs2, PARAMETER_GRADIENT, size, useGpu);
INIT_VECTOR(bufs1, bufs2, PARAMETER_MOMENTUM, size, useGpu);
INIT_VECTOR(bufs1, bufs2, PARAMETER_GRADIENT_SQURESUM, size, useGpu);
INIT_VECTOR(bufs1, bufs2, PARAMETER_LEARNING_RATE, size, useGpu);
real rou = (real)rand() / (real)RAND_MAX; // NOLINT
real epsilon = (real)rand() / (real)RAND_MAX; // NOLINT
real learningRate = (real)rand() / (real)RAND_MAX; // NOLINT
real momentum = (real)rand() / (real)RAND_MAX; // NOLINT
real decayRate = (real)rand() / (real)RAND_MAX; // NOLINT
real accumulatedRou = rou;
if (isFirstTime) {
bufs1[PARAMETER_GRADIENT_SQURESUM]->zeroMem();
bufs2[PARAMETER_GRADIENT_SQURESUM]->zeroMem();
}
EXPRESSION_PERFORMANCE(DecayedAdagradParameterOptimizer(bufs1,
accumulatedRou, rou, epsilon, learningRate, momentum, decayRate,
isFirstTime));
BaseMatrix& value = *bufs2[PARAMETER_VALUE];
BaseMatrix& grad = *bufs2[PARAMETER_GRADIENT];
BaseMatrix& mom = *bufs2[PARAMETER_MOMENTUM];
BaseMatrix& sum = *bufs2[PARAMETER_GRADIENT_SQURESUM];
BaseMatrix& lr = *bufs2[PARAMETER_LEARNING_RATE];
EXPRESSION_PERFORMANCE(decayedAdagradApply(value, grad, mom, sum, lr,
accumulatedRou, rou, epsilon, learningRate, momentum, decayRate,
isFirstTime));
CHECK_VECTORPTR(bufs1[PARAMETER_VALUE], bufs2[PARAMETER_VALUE]);
CHECK_VECTORPTR(bufs1[PARAMETER_MOMENTUM], bufs2[PARAMETER_MOMENTUM]);
CHECK_VECTORPTR(bufs1[PARAMETER_GRADIENT_SQURESUM],
bufs2[PARAMETER_GRADIENT_SQURESUM]);
CHECK_VECTORPTR(bufs1[PARAMETER_LEARNING_RATE],
bufs2[PARAMETER_LEARNING_RATE]);
}
TEST(Training, DecayedAdagrad) {
testCase(testDecayedAdagrad<false>);
testCase(testDecayedAdagrad<true>);
}
void testAdam(size_t size, bool useGpu) {
VectorPtr bufs1[NUM_PARAMETER_TYPES];
VectorPtr bufs2[NUM_PARAMETER_TYPES];
INIT_VECTOR(bufs1, bufs2, PARAMETER_VALUE, size, useGpu);
INIT_VECTOR(bufs1, bufs2, PARAMETER_GRADIENT, size, useGpu);
INIT_VECTOR(bufs1, bufs2, PARAMETER_MOMENTUM, size, useGpu);
INIT_VECTOR(bufs1, bufs2, PARAMETER_SECOND_MOMENTUM, size, useGpu);
real beta1 = (real)rand() / (real)RAND_MAX; // NOLINT
real beta2 = (real)rand() / (real)RAND_MAX; // NOLINT
real beta1_power = (real)rand() / (real)RAND_MAX; // NOLINT
real beta2_power = (real)rand() / (real)RAND_MAX; // NOLINT
real epsilon = (real)rand() / (real)RAND_MAX; // NOLINT
real learningRate = (real)rand() / (real)RAND_MAX; // NOLINT
EXPRESSION_PERFORMANCE(AdamParameterOptimizer(bufs1,
beta1, beta2, beta1_power, beta2_power, epsilon, learningRate));
BaseMatrix& value = *bufs2[PARAMETER_VALUE];
BaseMatrix& grad = *bufs2[PARAMETER_GRADIENT];
BaseMatrix& mom = *bufs2[PARAMETER_MOMENTUM];
BaseMatrix& v = *bufs2[PARAMETER_SECOND_MOMENTUM];
EXPRESSION_PERFORMANCE(adamApply(value, grad, mom, v,
beta1, beta2, beta1_power, beta2_power, epsilon, learningRate));
CHECK_VECTORPTR(bufs1[PARAMETER_VALUE], bufs2[PARAMETER_VALUE]);
CHECK_VECTORPTR(bufs1[PARAMETER_MOMENTUM], bufs2[PARAMETER_MOMENTUM]);
CHECK_VECTORPTR(bufs1[PARAMETER_SECOND_MOMENTUM],
bufs2[PARAMETER_SECOND_MOMENTUM]);
}
TEST(Training, Adam) {
testCase(testAdam);
}
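// A minimal scalar sketch of the Adam step exercised above, added purely for
// illustration; adamApply itself operates on whole BaseMatrix buffers, and
// the helper name and scalar signature here are assumptions (it also relies
// on <cmath> being available through the headers above).
inline void adamStepScalarSketch(real& value, real& mom, real& v, real grad,
                                 real beta1, real beta2,
                                 real beta1_power, real beta2_power,
                                 real epsilon, real learningRate) {
  // m_t = \beta_1 * m_{t-1} + (1-\beta_1) * g_t
  mom = beta1 * mom + (1 - beta1) * grad;
  // v_t = \beta_2 * v_{t-1} + (1-\beta_2) * g_t^2
  v = beta2 * v + (1 - beta2) * grad * grad;
  // \theta_t = \theta_{t-1} - \alpha * sqrt(1-\beta_2^t) / (1-\beta_1^t)
  //            * m_t / (sqrt(v_t) + \epsilon)
  real alpha = learningRate * std::sqrt(1 - beta2_power) / (1 - beta1_power);
  value -= alpha * mom / (std::sqrt(v) + epsilon);
}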
void testAdamax(size_t size, bool useGpu) {
VectorPtr bufs1[NUM_PARAMETER_TYPES];
VectorPtr bufs2[NUM_PARAMETER_TYPES];
INIT_VECTOR(bufs1, bufs2, PARAMETER_VALUE, size, useGpu);
INIT_VECTOR(bufs1, bufs2, PARAMETER_GRADIENT, size, useGpu);
INIT_VECTOR(bufs1, bufs2, PARAMETER_MOMENTUM, size, useGpu);
INIT_VECTOR(bufs1, bufs2, PARAMETER_WEIGHTED_INFINITY_NORM, size, useGpu);
real beta1 = (real)rand() / (real)RAND_MAX; // NOLINT
real beta2 = (real)rand() / (real)RAND_MAX; // NOLINT
real alpha = (real)rand() / (real)RAND_MAX; // NOLINT
int64_t step = 2;
EXPRESSION_PERFORMANCE(AdamaxParameterOptimizer(bufs1,
beta1, beta2, step, alpha));
BaseMatrix& value = *bufs2[PARAMETER_VALUE];
BaseMatrix& grad = *bufs2[PARAMETER_GRADIENT];
BaseMatrix& mom = *bufs2[PARAMETER_MOMENTUM];
BaseMatrix& u = *bufs2[PARAMETER_WEIGHTED_INFINITY_NORM];
EXPRESSION_PERFORMANCE(adamaxApply(value, grad, mom, u,
beta1, beta2, step, alpha));
CHECK_VECTORPTR(bufs1[PARAMETER_VALUE], bufs2[PARAMETER_VALUE]);
CHECK_VECTORPTR(bufs1[PARAMETER_MOMENTUM], bufs2[PARAMETER_MOMENTUM]);
CHECK_VECTORPTR(bufs1[PARAMETER_WEIGHTED_INFINITY_NORM],
bufs2[PARAMETER_WEIGHTED_INFINITY_NORM]);
}
TEST(Training, Adamax) {
#ifndef PADDLE_TYPE_DOUBLE
SetMaxDiff diff(1e-4);
#endif
testCase(testAdamax);
}
void testSparseMomentum(size_t size, bool useGpu) {
VectorPtr bufs1[NUM_PARAMETER_TYPES];
VectorPtr bufs2[NUM_PARAMETER_TYPES];
INIT_VECTOR(bufs1, bufs2, PARAMETER_VALUE, size, useGpu);
INIT_VECTOR(bufs1, bufs2, PARAMETER_GRADIENT, size, useGpu);
INIT_VECTOR(bufs1, bufs2, PARAMETER_MOMENTUM_UT, size, useGpu);
INIT_VECTOR(bufs1, bufs2, PARAMETER_MOMENTUM_VT, size, useGpu);
real alpha = (real)rand() / (real)RAND_MAX; // NOLINT
real beta = (real)rand() / (real)RAND_MAX; // NOLINT
real gamma = (real)rand() / (real)RAND_MAX; // NOLINT
real tau = (real)rand() / (real)RAND_MAX; // NOLINT
real learningRate = (real)rand() / (real)RAND_MAX; // NOLINT
EXPRESSION_PERFORMANCE(SparseMomentumParameterOptimizer(bufs1,
alpha, beta, gamma, tau, learningRate));
BaseMatrix& value = *bufs2[PARAMETER_VALUE];
BaseMatrix& grad = *bufs2[PARAMETER_GRADIENT];
BaseMatrix& momU = *bufs2[PARAMETER_MOMENTUM_UT];
BaseMatrix& momV = *bufs2[PARAMETER_MOMENTUM_VT];
EXPRESSION_PERFORMANCE(sparseMomentumApply(value, grad, momU, momV,
alpha, beta, gamma, tau, learningRate));
CHECK_VECTORPTR(bufs1[PARAMETER_VALUE], bufs2[PARAMETER_VALUE]);
CHECK_VECTORPTR(bufs1[PARAMETER_MOMENTUM_UT],
bufs2[PARAMETER_MOMENTUM_UT]);
CHECK_VECTORPTR(bufs1[PARAMETER_MOMENTUM_VT],
bufs2[PARAMETER_MOMENTUM_VT]);
}
TEST(Training, SparseMomentum) {
testCase(testSparseMomentum);
}
int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
initMain(argc, argv);
hl_start();
hl_init(FLAGS_gpu_id);
return RUN_ALL_TESTS();
}
......@@ -381,8 +381,8 @@ void testMatrixSqrt(int height, int width) {
cpuA->randomizeUniform();
gpuA->copyFrom(*cpuA);
cpuA->sqrt();
gpuA->sqrt();
cpuA->sqrt2();
gpuA->sqrt2();
MatrixPtr outputCheck = std::make_shared<CpuMatrix>(height, width);
outputCheck->copyFrom(*gpuA);
......
......@@ -15,7 +15,7 @@ limitations under the License. */
#include "paddle/utils/Util.h"
#include "paddle/utils/Flags.h"
#include "paddle/math/TrainingAlgorithmOp.h"
#include "FirstOrderOptimizer.h"
#include <cmath>
......@@ -113,17 +113,20 @@ void SparseMomentumParameterOptimizer::finishBatch() {
void AdagradParameterOptimizer::update(const VectorPtr vecs[],
const ParameterConfig& config,
size_t sparseId) const {
vecs[PARAMETER_GRADIENT_SQURESUM1]->addSquare(*vecs[PARAMETER_GRADIENT],
1.0f);
vecs[PARAMETER_LEARNING_RATE]->add(*vecs[PARAMETER_GRADIENT_SQURESUM],
*vecs[PARAMETER_GRADIENT_SQURESUM1]);
vecs[PARAMETER_LEARNING_RATE]->add(optConfig_.ada_epsilon());
vecs[PARAMETER_LEARNING_RATE]->invSqrt(*vecs[PARAMETER_LEARNING_RATE]);
vecs[PARAMETER_VALUE]->sgdUpdate(
*vecs[PARAMETER_GRADIENT], *vecs[PARAMETER_MOMENTUM],
*vecs[PARAMETER_LEARNING_RATE], learningRate_ * config.learning_rate(),
config.momentum(), applyDecay_ ? config.decay_rate() : 0);
BaseMatrix& value = *vecs[PARAMETER_VALUE];
BaseMatrix& grad = *vecs[PARAMETER_GRADIENT];
BaseMatrix& mom = *vecs[PARAMETER_MOMENTUM];
BaseMatrix& accum_buffer = *vecs[PARAMETER_GRADIENT_SQURESUM];
BaseMatrix& accum = *vecs[PARAMETER_GRADIENT_SQURESUM1];
BaseMatrix& lr = *vecs[PARAMETER_LEARNING_RATE];
real epsilon = optConfig_.ada_epsilon();
real learningRate = learningRate_ * config.learning_rate();
real momentum = config.momentum();
real decayRate = applyDecay_ ? config.decay_rate() : 0;
adagradApply(value, grad, mom, accum_buffer, accum, lr,
epsilon, learningRate, momentum, decayRate);
}
ParameterOptimizer::TraverseCallback
......@@ -147,32 +150,32 @@ void AdaDeltaParameterOptimizer::update(const VectorPtr vecs[],
const ParameterConfig& config,
size_t sparseId) const {
CHECK(sparseId == -1LU) << "Sparse update is not supported";
// E(g_t^2) = \rou * E(g_{t-1}^2) + (1-\rou) * g^2
vecs[PARAMETER_GRADIENT_SQURESUM]->decayAddSquare(*vecs[PARAMETER_GRADIENT],
rou_, 1.0f - rou_);
// learn_rate = sqrt( ( E(dx_{t-1}^2) + epsilon ) / ( E(g_t^2) + epsilon ) )
vecs[PARAMETER_LEARNING_RATE]->dotDiv(*vecs[PARAMETER_GRADIENT_SQURESUM1],
*vecs[PARAMETER_GRADIENT_SQURESUM],
epsilon_, epsilon_);
vecs[PARAMETER_LEARNING_RATE]->sqrt();
// E(dx_t^2) = \rou * E(dx_{t-1}^2) + (1-\rou) * (-g*learn_rate)^2
vecs[PARAMETER_GRADIENT_SQURESUM1]->decayAddSquareMul(
*vecs[PARAMETER_GRADIENT], *vecs[PARAMETER_LEARNING_RATE], rou_,
1.0f - rou_);
vecs[PARAMETER_VALUE]->sgdUpdate(
*vecs[PARAMETER_GRADIENT], *vecs[PARAMETER_MOMENTUM],
*vecs[PARAMETER_LEARNING_RATE], learningRate_ * config.learning_rate(),
config.momentum(), applyDecay_ ? config.decay_rate() : 0);
BaseMatrix& value = *vecs[PARAMETER_VALUE];
BaseMatrix& grad = *vecs[PARAMETER_GRADIENT];
BaseMatrix& mom = *vecs[PARAMETER_MOMENTUM];
BaseMatrix& accum = *vecs[PARAMETER_GRADIENT_SQURESUM];
BaseMatrix& accum_update = *vecs[PARAMETER_GRADIENT_SQURESUM1];
BaseMatrix& lr = *vecs[PARAMETER_LEARNING_RATE];
real learningRate = learningRate_ * config.learning_rate();
real momentum = config.momentum();
real decayRate = applyDecay_ ? config.decay_rate() : 0;
adadeltaApply(value, grad, mom, accum, accum_update, lr,
rou_, epsilon_, learningRate, momentum, decayRate);
}
void RMSPropParameterOptimizer::update(const VectorPtr vecs[],
const ParameterConfig& config,
size_t sparseId) const {
real accumulatedRou = rou_;
BaseMatrix& value = *vecs[PARAMETER_VALUE];
BaseMatrix& grad = *vecs[PARAMETER_GRADIENT];
BaseMatrix& mom = *vecs[PARAMETER_MOMENTUM];
BaseMatrix& sum = *vecs[PARAMETER_GRADIENT_SQURESUM];
BaseMatrix& sum1 = *vecs[PARAMETER_GRADIENT_SQURESUM1];
BaseMatrix& lr = *vecs[PARAMETER_LEARNING_RATE];
real accumulatedRou = rou_;
bool firstTime = timer_ == 0;
if (sparseId != -1LU) {
CHECK_LT(sparseId, t0Vec_.size());
......@@ -181,37 +184,26 @@ void RMSPropParameterOptimizer::update(const VectorPtr vecs[],
t0Vec_[sparseId] = timer_ + 1;
}
// E(g_t^2) = \rou * E(g_{t-1}^2) + (1-\rou) * g^2
// For the first time update, make the sum be the current square
// so that the initial estimation of E(g_t^2) will not be too small.
vecs[PARAMETER_GRADIENT_SQURESUM]->decayAddSquare(
*vecs[PARAMETER_GRADIENT], accumulatedRou,
firstTime ? 1.0f : 1.0f - rou_);
// E(g_t) = \rou * E(g_{t-1}) + (1-\rou) * g
vecs[PARAMETER_GRADIENT_SQURESUM1]->add(*vecs[PARAMETER_GRADIENT],
accumulatedRou, 1.0f - rou_);
// learn_rate = 1/sqrt( ( E(g_t^2) - (E(g_t))^2 + epsilon )
  // Basically, if the sign of the gradient changes more often,
// the learning rate will be decreased.
vecs[PARAMETER_LEARNING_RATE]->assign(*vecs[PARAMETER_GRADIENT_SQURESUM]);
vecs[PARAMETER_LEARNING_RATE]->addSquare(*vecs[PARAMETER_GRADIENT_SQURESUM1],
-1.0f);
vecs[PARAMETER_LEARNING_RATE]->add(optConfig_.ada_epsilon());
vecs[PARAMETER_LEARNING_RATE]->invSqrt(*vecs[PARAMETER_LEARNING_RATE]);
vecs[PARAMETER_VALUE]->sgdUpdate(
*vecs[PARAMETER_GRADIENT], *vecs[PARAMETER_MOMENTUM],
*vecs[PARAMETER_LEARNING_RATE], learningRate_ * config.learning_rate(),
config.momentum(), applyDecay_ ? config.decay_rate() : 0);
real epsilon = optConfig_.ada_epsilon();
real learningRate = learningRate_ * config.learning_rate();
real momentum = config.momentum();
real decayRate = applyDecay_ ? config.decay_rate() : 0;
rmspropApply(value, grad, mom, sum, sum1, lr,
accumulatedRou, rou_, epsilon, learningRate, momentum, decayRate,
firstTime);
}
void DecayedAdagradParameterOptimizer::update(const VectorPtr vecs[],
const ParameterConfig& config,
size_t sparseId) const {
real accumulatedRou = rou_;
BaseMatrix& value = *vecs[PARAMETER_VALUE];
BaseMatrix& grad = *vecs[PARAMETER_GRADIENT];
BaseMatrix& mom = *vecs[PARAMETER_MOMENTUM];
BaseMatrix& sum = *vecs[PARAMETER_GRADIENT_SQURESUM];
BaseMatrix& lr = *vecs[PARAMETER_LEARNING_RATE];
real accumulatedRou = rou_;
bool firstTime = timer_ == 0;
if (sparseId != -1LU) {
CHECK_LT(sparseId, t0Vec_.size());
......@@ -220,77 +212,48 @@ void DecayedAdagradParameterOptimizer::update(const VectorPtr vecs[],
t0Vec_[sparseId] = timer_ + 1;
}
// E(g_t^2) = \rou * E(g_{t-1}^2) + (1-\rou) * g^2
// For the first time update, make the sum be the current square
// so that the initial estimation of E(g_t^2) will not be too small.
vecs[PARAMETER_GRADIENT_SQURESUM]->decayAddSquare(
*vecs[PARAMETER_GRADIENT], accumulatedRou,
firstTime ? 1.0f : 1.0f - rou_);
// learn_rate = 1/sqrt( ( E(g_t^2) + epsilon )
  // Basically, the bigger the magnitude of the gradient is,
  // the smaller the learning rate will be.
vecs[PARAMETER_LEARNING_RATE]->assign(optConfig_.ada_epsilon());
vecs[PARAMETER_LEARNING_RATE]->add(*vecs[PARAMETER_GRADIENT_SQURESUM]);
vecs[PARAMETER_LEARNING_RATE]->invSqrt(*vecs[PARAMETER_LEARNING_RATE]);
vecs[PARAMETER_VALUE]->sgdUpdate(
*vecs[PARAMETER_GRADIENT], *vecs[PARAMETER_MOMENTUM],
*vecs[PARAMETER_LEARNING_RATE], learningRate_ * config.learning_rate(),
config.momentum(), applyDecay_ ? config.decay_rate() : 0);
real epsilon = optConfig_.ada_epsilon();
real learningRate = learningRate_ * config.learning_rate();
real momentum = config.momentum();
real decayRate = applyDecay_ ? config.decay_rate() : 0;
decayedAdagradApply(value, grad, mom, sum, lr,
accumulatedRou, rou_, epsilon, learningRate, momentum, decayRate,
firstTime);
}
void AdamParameterOptimizer::update(const VectorPtr vecs[],
const ParameterConfig& config,
size_t sparseId) const {
CHECK(sparseId == -1UL) << "Sparse update is not supported";
Vector* m = vecs[PARAMETER_MOMENTUM].get();
Vector* g = vecs[PARAMETER_GRADIENT].get();
Vector* v = vecs[PARAMETER_SECOND_MOMENTUM].get();
Vector* theta = vecs[PARAMETER_VALUE].get();
// m_t = \beta_1 * m_{t-1} + (1-\beta_1)* g_t;
m->add(*g, beta1_, 1 - beta1_);
// v_t = \beta_2 * v_{t-1} + (1-\beta_2)* g_{t-1}^2
g->square();
v->add(*g, beta2_, 1 - beta2_);
// tmp = m_t / ( \sqrt{v_t} + \epsilon )
// \theta_t = \theta_{t-1} - \alpha * \sqrt(1-\beta_2^t) / (1-\beta_1^t) * tmp
g->sqrt(*v);
g->dotDiv(*m, *g, 0., epsilon_);
real alpha = config.learning_rate() * learningRate_;
alpha = alpha * std::sqrt(1 - std::pow(beta2_, step_)) /
(1 - std::pow(beta1_, step_));
theta->add(*theta, 1.0, *g, -alpha);
real beta1_power = std::pow(beta1_, step_);
real beta2_power = std::pow(beta2_, step_);
real learningRate = config.learning_rate() * learningRate_;
BaseMatrix& value = *vecs[PARAMETER_VALUE];
BaseMatrix& grad = *vecs[PARAMETER_GRADIENT];
BaseMatrix& mom = *vecs[PARAMETER_MOMENTUM];
BaseMatrix& v = *vecs[PARAMETER_SECOND_MOMENTUM];
adamApply(value, grad, mom, v,
beta1_, beta2_, beta1_power, beta2_power, epsilon_, learningRate);
}
void AdamaxParameterOptimizer::update(const VectorPtr vecs[],
const ParameterConfig& config,
size_t sparseId) const {
CHECK(sparseId == -1UL) << "Sparse update is not supported";
Vector* m = vecs[PARAMETER_MOMENTUM].get();
Vector* g = vecs[PARAMETER_GRADIENT].get();
Vector* u = vecs[PARAMETER_WEIGHTED_INFINITY_NORM].get();
Vector* theta = vecs[PARAMETER_VALUE].get();
// m_t = \beta_1 * m_{t-1} + (1-\beta_1)* g_t;
m->add(*g, beta1_, 1 - beta1_);
real learningRate = config.learning_rate() * learningRate_;
// u_t = max(\beta_2*u_{t-1}, abs(g_t))
u->mulScalar(beta2_);
g->abs();
u->max(*u, *g);
BaseMatrix& value = *vecs[PARAMETER_VALUE];
BaseMatrix& grad = *vecs[PARAMETER_GRADIENT];
BaseMatrix& mom = *vecs[PARAMETER_MOMENTUM];
BaseMatrix& u = *vecs[PARAMETER_WEIGHTED_INFINITY_NORM];
// \theta_t = \theta_{t-1} - (\alpha/(1-\beta_1^t))*m_t/u_t
g->dotDiv(*m, *u);
real learningRate = config.learning_rate() * learningRate_;
learningRate /= (1 - std::pow(beta1_, step_));
theta->add(*theta, 1.0, *g, -learningRate);
adamaxApply(value, grad, mom, u,
beta1_, beta2_, step_, learningRate);
}
void OptimizerWithGradientClipping::update(const VectorPtr vecs[],
const ParameterConfig& config,
size_t sparseId) const {
......