diff --git a/CMakeLists.txt b/CMakeLists.txt
index 8ce51bb0f4109f596840b4b4cadf8d01525f6dba..5da3aa30d1a019e988b63fcffbb9accc91293404 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -87,6 +87,12 @@ if(NOT WITH_GPU)
     add_definitions(-DHPPL_STUB_FUNC)
     list(APPEND CMAKE_CXX_SOURCE_FILE_EXTENSIONS cu)
 else()
+    if(${CUDA_VERSION_MAJOR} GREATER 6)
+        if(COMPILER_SUPPORT_CXX11)
+            LIST(APPEND CUDA_NVCC_FLAGS -std=c++11)
+        endif()
+    endif()
+
     # TODO(yuyang18): Change it to remove std=c++11 in cuda compile.
     set(CUDA_PROPAGATE_HOST_FLAGS OFF)
     if(NOT CUDNN_FOUND)
diff --git a/paddle/math/TensorAssign.h b/paddle/math/TensorAssign.h
new file mode 100644
index 0000000000000000000000000000000000000000..0dc52a7dc27c96f9052e22bb2ef0ac75297a2f45
--- /dev/null
+++ b/paddle/math/TensorAssign.h
@@ -0,0 +1,142 @@
+/**
+ * TensorAssign.h
+ *
+ * Author: hedaoyuan (hedaoyuan@baidu.com)
+ * Created on: 2016-10-08
+ *
+ * Copyright (c) Baidu.com, Inc. All Rights Reserved
+ *
+ */
+
+#pragma once
+
+#include <algorithm>
+#include "paddle/utils/Logging.h"
+
+namespace paddle {
+
+template <typename LhsType, typename RhsType, class T>
+class TensorAssignOp {
+public:
+  explicit TensorAssignOp(const LhsType& lhs, const RhsType& rhs)
+    : lhs_(lhs), rhs_(rhs) {
+    #ifndef __CUDA_ARCH__
+    CHECK_EQ(lhs_.getWidth(), rhs_.getWidth());
+    CHECK_EQ(lhs_.getHeight(), rhs_.getHeight());
+    CHECK_EQ(lhs_.useGpu(), rhs_.useGpu());
+    #endif
+  }
+
+  INLINE void apply(const int i, const int j) {
+    lhs_.applyRef(i, j) = rhs_.apply(i, j);
+  }
+  INLINE void apply(const int index) {
+    lhs_.applyRef(index) = rhs_.apply(index);
+  }
+
+  INLINE size_t getWidth() const { return lhs_.getWidth(); }
+  INLINE size_t getHeight() const { return rhs_.getHeight(); }
+  INLINE bool isContiguous() const {
+    return lhs_.isContiguous() && rhs_.isContiguous();
+  }
+  INLINE bool useGpu() const { return lhs_.useGpu(); }
+
+private:
+  TensorApply<LhsType, T> lhs_;
+  TensorApply<const RhsType, T> rhs_;
+};
+
+template <typename Assign, typename... AssignOp>
+void AssignCpuEvaluate(int height, int width, bool isContiguous,
+                       Assign&& assign, AssignOp&& ... args) {
+  if (isContiguous) {
+    int size = height * width;
+    for (int index = 0; index < size; index++) {
+      assign.apply(index);
+      // Pack expansion: apply every remaining assignment at the same index,
+      // in argument order; the unused dummy array only forces the expansion.
+      __attribute__((unused)) int dummy[] = { (((args)).apply(index), 0)... };
+    }
+  } else {
+    for (int i = 0; i < height; i++) {
+      for (int j = 0; j < width; j++) {
+        assign.apply(i, j);
+        __attribute__((unused)) int dummy[] = { (((args)).apply(i, j), 0)... };
+      }
+    }
+  }
+}
+
+#ifdef __NVCC__
+template <typename Assign, typename... AssignOp>
+__global__
+void AssignGpuEvaluate1(const int border, Assign assign, AssignOp ... args) {
+  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (idx < border) {
+    assign.apply(idx);
+    __attribute__((unused)) int dummy[] = { (((args)).apply(idx), 0)... };
+  }
+}
+
+template <typename Assign, typename... AssignOp>
+__global__
+void AssignGpuEvaluate2(const int height, const int width,
+                        Assign assign, AssignOp ... args) {
+  const int colIdx = blockIdx.x * blockDim.x + threadIdx.x;
+  const int rowIdx = blockIdx.y * blockDim.y + threadIdx.y;
+  for (int i = rowIdx; i < height; i += gridDim.y * blockDim.y) {
+    for (int j = colIdx; j < width; j += gridDim.x * blockDim.x) {
+      assign.apply(i, j);
+      __attribute__((unused)) int dummy[] = { (((args)).apply(i, j), 0)... };
+    }
+  }
+}
+#endif
+
+// At least one assignment expression is required.
+template <typename Assign, typename... AssignOp>
+void AssignEvaluate(Assign&& assign, AssignOp&& ... args) {
+  const bool useGpu_ = assign.useGpu();
+  bool isContiguous_ = assign.isContiguous();
+  const size_t height = assign.getHeight();
+  const size_t width = assign.getWidth();
+
+  const int packSize = sizeof...(args);
+  const bool packUseGpu[] = { ((args)).useGpu()... };
+  const bool packIsContiguous[] = { ((args)).isContiguous()... };
+  const size_t packHeight[] = { ((args)).getHeight()... };
+  const size_t packWidth[] = { ((args)).getWidth()... };
+
+  for (int i = 0; i < packSize; i++) {
+    CHECK_EQ(useGpu_, packUseGpu[i]);
+    CHECK_EQ(height, packHeight[i]);
+    CHECK_EQ(width, packWidth[i]);
+    isContiguous_ = isContiguous_ && packIsContiguous[i];
+  }
+
+  if (useGpu_) {
+#ifdef __NVCC__
+    if (isContiguous_) {
+      int size = height * width;
+      int blockSize = size <= 1024 ? size : 1024;
+      int gridSize = (size + 1024 - 1) / 1024;
+      AssignGpuEvaluate1
+        <<<gridSize, blockSize, 0, STREAM_DEFAULT>>>(size, assign, args...);
+    } else {
+      int blockSizeY = std::min(32, (int)height);
+      int blockSizeX = (32 / blockSizeY) * 32;
+      int gridSizeX = std::min(32, (int)(width + blockSizeX - 1) / blockSizeX);
+      int gridSizeY = std::min(32, (int)(height + blockSizeY - 1) / blockSizeY);
+      dim3 threads(blockSizeX, blockSizeY);
+      dim3 grid(gridSizeX, gridSizeY);
+      AssignGpuEvaluate2
+        <<<grid, threads, 0, STREAM_DEFAULT>>>(height, width, assign, args...);
+    }
+
+    CHECK_SYNC("AssignEvaluate failed");
+#endif
+  } else {
+    AssignCpuEvaluate(height, width, isContiguous_, assign, args...);
+  }
+}
+
+}  // namespace paddle
+
diff --git a/paddle/math/TensorEvaluate.h b/paddle/math/TensorEvaluate.h
index 87f32996e7897524ac9ee1733d06e9bff616119b..1abfee602faec44f650e70fa6ded49b9ae3a7390 100644
--- a/paddle/math/TensorEvaluate.h
+++ b/paddle/math/TensorEvaluate.h
@@ -27,14 +27,16 @@ inline void TensorCpuApply(LeftType& lhs, const RightType& rhs) {
   CHECK_EQ(lhs_.getHeight(), rhs_.getHeight());
   CHECK_EQ(lhs_.useGpu(), rhs_.useGpu());
 
+  int height = lhs_.getHeight();
+  int width = lhs_.getWidth();
   if (lhs_.isContiguous() && rhs_.isContiguous()) {
-    int size = lhs_.getHeight() * lhs_.getWidth();
+    int size = height * width;
     for (int index = 0; index < size; index++) {
       lhs_.applyRef(index) = rhs_.apply(index);
     }
   } else {
-    for (size_t i = 0; i < lhs_.getHeight(); i++) {
-      for (size_t j = 0; j < lhs_.getWidth(); j++) {
+    for (int i = 0; i < height; i++) {
+      for (int j = 0; j < width; j++) {
         lhs_.applyRef(i, j) = rhs_.apply(i, j);
       }
     }
diff --git a/paddle/math/TensorExpression.h b/paddle/math/TensorExpression.h
index f9ebd8b68fbe2f5fa51af5f50fb9fc0e4c3b1176..a5d4fa9d72e59eaaf977d8f4e5ba211296ef45b7 100644
--- a/paddle/math/TensorExpression.h
+++ b/paddle/math/TensorExpression.h
@@ -27,6 +27,8 @@ typename ExprType2, typename ExprType3, class T>
 class TensorTernaryOp;
 
+template <typename LhsType, typename RhsType, class T> class TensorAssignOp;
+
 /**
  * \brief Tensor base class.
 *
@@ -318,6 +320,12 @@ public:
       (hppl::unary::constant<T>(p), derived());
   }
 
+  template <typename ExpressionType>
+  TensorAssignOp<Derived, ExpressionType, T>
+  lazyAssign(const ExpressionType& expr) const {
+    return TensorAssignOp<Derived, ExpressionType, T>(derived(), expr);
+  }
+
 protected:
   const Derived& derived() const { return *static_cast<const Derived*>(this); }
 };
diff --git a/paddle/math/TrainingAlgorithmOp.cu b/paddle/math/TrainingAlgorithmOp.cu
index 7fae2c7bdd7164552e12c7b5a9fd764688712910..4698fb179bd84fdb37bc9ddacfa604d6055980b3 100644
--- a/paddle/math/TrainingAlgorithmOp.cu
+++ b/paddle/math/TrainingAlgorithmOp.cu
@@ -12,6 +12,175 @@
 #include "BaseMatrix.h"
 #include "TrainingAlgorithmOp.h"
 
+#if __cplusplus > 199711L
+
+#include "TensorAssign.h"
+
+namespace paddle {
+
+void sparseMomentumApply(BaseMatrix& value,
+                         BaseMatrix& grad,
+                         BaseMatrix& momU,
+                         BaseMatrix& momV,
+                         real alpha,
+                         real beta,
+                         real gamma,
+                         real tau,
+                         real learningRate) {
+  auto expr1 = momU.lazyAssign(momU - (alpha * gamma * learningRate) * grad);
+  auto expr2 = momV.lazyAssign(
+    momV + (tau * alpha * gamma * learningRate) * grad);
+  auto expr3 = value.lazyAssign(
+    (tau / beta + (real)1 / alpha) * momU + ((real)1 / beta) * momV);
+
+  AssignEvaluate(expr1, expr2, expr3);
+}
+
+void adadeltaApply(BaseMatrix& value,
+                   BaseMatrix& grad,
+                   BaseMatrix& mom,
+                   BaseMatrix& accum,
+                   BaseMatrix& accum_update,
+                   BaseMatrix& lr,
+                   real rou,
+                   real epsilon,
+                   real learningRate,
+                   real momentum,
+                   real decayRate) {
+  auto expr1 = accum.lazyAssign(rou * accum + ((real)1 - rou) * grad.square());
+  auto expr2 = lr.lazyAssign(
+    ((accum_update + epsilon) / (accum + epsilon)).sqrt());
+  auto expr3 = accum_update.lazyAssign(
+    rou * accum_update + ((real)1 - rou) * (grad * lr).square());
+  auto expr4 = mom.lazyAssign(
+    mom * momentum - learningRate * lr * (grad + value * decayRate));
+  auto expr5 = value.lazyAssign(value + mom);
+
+  AssignEvaluate(expr1, expr2, expr3, expr4, expr5);
+}
+
+void adagradApply(BaseMatrix& value,
+                  BaseMatrix& grad,
+                  BaseMatrix& mom,
+                  BaseMatrix& accum_buffer,
+                  BaseMatrix& accum,
+                  BaseMatrix& lr,
+                  real epsilon,
+                  real learningRate,
+                  real momentum,
+                  real decayRate) {
+  auto expr1 = accum.lazyAssign(accum + grad.square());
+  auto expr2 = lr.lazyAssign(
+    (accum_buffer + accum + epsilon).sqrt().reciprocal());
+  auto expr3 = mom.lazyAssign(
+    mom * momentum - learningRate * lr * (grad + value * decayRate));
+  auto expr4 = value.lazyAssign(value + mom);
+
+  AssignEvaluate(expr1, expr2, expr3, expr4);
+}
+
+void rmspropApply(BaseMatrix& value,
+                  BaseMatrix& grad,
+                  BaseMatrix& mom,
+                  BaseMatrix& g,
+                  BaseMatrix& f,
+                  BaseMatrix& lr,
+                  real accumulatedRou,
+                  real rou,
+                  real epsilon,
+                  real learningRate,
+                  real momentum,
+                  real decayRate,
+                  bool firstTime) {
+  auto expr2 = f.lazyAssign(accumulatedRou * f + ((real)1 - rou) * grad);
+  auto expr3 = lr.lazyAssign((g - f.square() + epsilon).sqrt().reciprocal());
+  auto expr4 = mom.lazyAssign(
+    mom * momentum - learningRate * lr * (grad + value * decayRate));
+  auto expr5 = value.lazyAssign(value + mom);
+
+  if (firstTime) {
+    auto expr1 = g.lazyAssign(accumulatedRou * g + grad.square());
+
+    AssignEvaluate(expr1, expr2, expr3, expr4, expr5);
+  } else {
+    auto expr1 = g.lazyAssign(
+      accumulatedRou * g + ((real)1 - rou) * grad.square());
+
+    AssignEvaluate(expr1, expr2, expr3, expr4, expr5);
+  }
+}
+
+void decayedAdagradApply(BaseMatrix& value,
+                         BaseMatrix& grad,
+                         BaseMatrix& mom,
+                         BaseMatrix& accum,
+                         BaseMatrix& lr,
+                         real accumulatedRou,
+                         real rou,
+                         real epsilon,
+                         real learningRate,
+                         real momentum,
+                         real decayRate,
+                         bool firstTime) {
+  auto expr2 = lr.lazyAssign((accum + epsilon).sqrt().reciprocal());
+  auto expr3 = mom.lazyAssign(
+    mom * momentum - learningRate * lr * (grad + value * decayRate));
+  auto expr4 = value.lazyAssign(value + mom);
+
+  if (firstTime) {
+    auto expr1 = accum.lazyAssign(accumulatedRou * accum + grad.square());
+
+    AssignEvaluate(expr1, expr2, expr3, expr4);
+  } else {
+    auto expr1 = accum.lazyAssign(
+      accumulatedRou * accum + ((real)1 - rou) * grad.square());
+
+    AssignEvaluate(expr1, expr2, expr3, expr4);
+  }
+}
+
+void adamApply(BaseMatrix& value,
+               BaseMatrix& grad,
+               BaseMatrix& mom,  // first moment
+               BaseMatrix& v,    // second moment
+               real beta1,
+               real beta2,
+               real beta1_power,
+               real beta2_power,
+               real epsilon,
+               real learningRate) {
+  real alpha = learningRate *
+    std::sqrt((real)1 - beta2_power) / ((real)1 - beta1_power);
+
+  auto expr1 = mom.lazyAssign(beta1 * mom + ((real)1 - beta1) * grad);
+  auto expr2 = v.lazyAssign(beta2 * v + ((real)1 - beta2) * grad.square());
+  auto expr3 = value.lazyAssign(
+    value - (mom * alpha) / (v.sqrt() + epsilon));
+
+  AssignEvaluate(expr1, expr2, expr3);
+}
+
+void adamaxApply(BaseMatrix& value,
+                 BaseMatrix& grad,
+                 BaseMatrix& mom,  // first moment
+                 BaseMatrix& u,    // weighted infinity norm
+                 real beta1,
+                 real beta2,
+                 int64_t step,
+                 real alpha) {
+  auto expr1 = mom.lazyAssign(beta1 * mom + ((real)1 - beta1) * grad);
+  auto expr2 = u.lazyAssign(
+    (beta2 * u > grad.abs()).condition(beta2 * u, grad.abs()));
+  auto expr3 = value.lazyAssign(
+    value - (alpha / ((real)1 - (real)std::pow(beta1, step))) * (mom / u));
+
+  AssignEvaluate(expr1, expr2, expr3);
+}
+
+}  // namespace paddle
+
+#else
+
 namespace paddle {
 
 void sparseMomentumApply(BaseMatrix& value,
@@ -180,3 +349,6 @@ void adamaxApply(BaseMatrix& value,
 }
 
 }  // namespace paddle
+
+#endif
+
diff --git a/paddle/math/tests/CMakeLists.txt b/paddle/math/tests/CMakeLists.txt
index c239289ee539dd0fb50990af2e5157c2c62b3884..35cb9ee5784752543e211c90c721045d8a127a5d 100644
--- a/paddle/math/tests/CMakeLists.txt
+++ b/paddle/math/tests/CMakeLists.txt
@@ -15,13 +15,16 @@ add_simple_unittest(test_perturbation)
 add_simple_unittest(test_CpuGpuVector)
 add_simple_unittest(test_Allocator)
 
 if(WITH_GPU)
+    CUDA_ADD_EXECUTABLE(test_Tensor test_Tensor.cu)
+    link_paddle_test(test_Tensor)
     if(COMPILER_SUPPORT_CXX11)
-        LIST(APPEND CUDA_NVCC_FLAGS -std=c++11)
-        CUDA_ADD_EXECUTABLE(test_Tensor test_Tensor.cu)
-        link_paddle_test(test_Tensor)
+        CUDA_ADD_EXECUTABLE(test_lazyAssign test_lazyAssign.cu)
+        link_paddle_test(test_lazyAssign)
     endif()
 else()
     compile_cu_as_cpp(test_Tensor.cu)
     add_unittest(test_Tensor test_Tensor.cu)
+    compile_cu_as_cpp(test_lazyAssign.cu)
+    add_unittest(test_lazyAssign test_lazyAssign.cu)
 endif(WITH_GPU)
diff --git a/paddle/math/tests/TensorCheck.h b/paddle/math/tests/TensorCheck.h
new file mode 100644
index 0000000000000000000000000000000000000000..d11e1376e4e26892a254c365a1b0a8015b036ec3
--- /dev/null
+++ b/paddle/math/tests/TensorCheck.h
@@ -0,0 +1,179 @@
+/**
+ * TensorCheck.h
+ *
+ * Author: hedaoyuan (hedaoyuan@baidu.com)
+ * Created on: 2016-06-06
+ *
+ * Copyright (c) Baidu.com, Inc. All Rights Reserved
+ */
+
+#pragma once
+
+#include <gtest/gtest.h>
+#include "paddle/math/Matrix.h"
+
+// NOTE: assumes the P_DECLARE_double macro from paddle/utils/CommandLineParser.h;
+// the max_diff flag itself is defined by the test binary that includes this header.
+P_DECLARE_double(max_diff);
+
+using namespace paddle;  // NOLINT
+using namespace std;     // NOLINT
+
+template <typename Tensor>
+extern void TensorCheckEqual(const Tensor& tensor1, const Tensor& tensor2);
+
+void TensorCheckEqual(const CpuMatrix& matrix1, const CpuMatrix& matrix2) {
+  CHECK(matrix1.getHeight() == matrix2.getHeight());
+  CHECK(matrix1.getWidth() == matrix2.getWidth());
+
+  int height = matrix1.getHeight();
+  int width = matrix1.getWidth();
+  const real* data1 = matrix1.getData();
+  const real* data2 = matrix2.getData();
+  int count = 0;
+  for (int i = 0; i < height; i++) {
+    for (int j = 0; j < width; j++) {
+      if (data1[i * width + j] != data2[i * width + j]) {
+        count++;
+      }
+    }
+  }
+  EXPECT_EQ(count, 0) << "There are " << count << " different elements.";
+}
+
+void TensorCheckEqual(const GpuMatrix& matrix1, const GpuMatrix& matrix2) {
+  CpuMatrix cpu1(matrix1.getHeight(), matrix1.getWidth());
+  CpuMatrix cpu2(matrix2.getHeight(), matrix2.getWidth());
+  cpu1.copyFrom(matrix1);
+  cpu2.copyFrom(matrix2);
+  TensorCheckEqual(cpu1, cpu2);
+}
+
+void TensorCheckErr(const CpuMatrix& matrix1, const CpuMatrix& matrix2) {
+  CHECK(matrix1.getHeight() == matrix2.getHeight());
+  CHECK(matrix1.getWidth() == matrix2.getWidth());
+#ifndef PADDLE_TYPE_DOUBLE
+  real err = 1e-5;
+#else
+  real err = 1e-10;
+#endif
+
+  int height = matrix1.getHeight();
+  int width = matrix1.getWidth();
+  const real* data1 = matrix1.getData();
+  const real* data2 = matrix2.getData();
+  int count = 0;
+  for (int i = 0; i < height; i++) {
+    for (int j = 0; j < width; j++) {
+      real a = data1[i * width + j];
+      real b = data2[i * width + j];
+      if (fabs(a - b) > err) {
+        if ((fabsf(a - b) / fabsf(a)) > (err / 10.0f)) {
+          count++;
+        }
+      }
+    }
+  }
+  EXPECT_EQ(count, 0) << "There are " << count << " different elements.";
+}
+
+void TensorCheckErr(const GpuMatrix& matrix1, const GpuMatrix& matrix2) {
+  CpuMatrix cpu1(matrix1.getHeight(), matrix1.getWidth());
+  CpuMatrix cpu2(matrix2.getHeight(), matrix2.getWidth());
+  cpu1.copyFrom(matrix1);
+  cpu2.copyFrom(matrix2);
+  TensorCheckErr(cpu1, cpu2);
+}
+
+template <class T>
+void TensorCheckEqual(const CpuVectorT<T>& vector1,
+                      const CpuVectorT<T>& vector2) {
+  CHECK(vector1.getSize() == vector2.getSize());
+
+  const T* data1 = vector1.getData();
+  const T* data2 = vector2.getData();
+  size_t size = vector1.getSize();
+  int count = 0;
+  for (size_t i = 0; i < size; i++) {
+    if (data1[i] != data2[i]) {
+      count++;
+    }
+  }
+  EXPECT_EQ(count, 0) << "There are " << count << " different elements.";
+}
+
+template <class T>
+void TensorCheckEqual(const GpuVectorT<T>& vector1,
+                      const GpuVectorT<T>& vector2) {
+  CpuVectorT<T> cpu1(vector1.getSize());
+  CpuVectorT<T> cpu2(vector2.getSize());
+  cpu1.copyFrom(vector1);
+  cpu2.copyFrom(vector2);
+  TensorCheckEqual(cpu1, cpu2);
+}
+
+int VectorCheckErr(const Vector& vector1, const Vector& vector2) {
+  CHECK(vector1.getSize() == vector2.getSize());
+
+  const real* data1 = vector1.getData();
+  const real* data2 = vector2.getData();
+  size_t size = vector1.getSize();
+  int count = 0;
+  for (size_t i = 0; i < size; i++) {
+    real a = data1[i];
+    real b = data2[i];
+    if (fabs(a - b) > FLAGS_max_diff) {
+      if ((fabsf(a - b) / fabsf(a)) > (FLAGS_max_diff / 10.0f)) {
+        count++;
+      }
+    }
+  }
+
+  return count;
+}
+
+#define INIT_UNARY(A1, A2)                  \
+  Tensor A1(height, width);                 \
+  Tensor A2(height, width);                 \
+  A1.randomizeUniform();                    \
+  A2.copyFrom(A1)
+#define INIT_BINARY(A1, A2, B)              \
+  INIT_UNARY(A1, A2);                       \
+  Tensor B(height, width);                  \
+  B.randomizeUniform()
+#define INIT_TERNARY(A1, A2, B, C)          \
+  INIT_BINARY(A1, A2, B);                   \
+  Tensor C(height, width);                  \
+  C.randomizeUniform()
+#define INIT_QUATERNARY(A1, A2, B, C, D)    \
+  INIT_TERNARY(A1, A2, B, C);               \
+  Tensor D(height, width);                  \
+  D.randomizeUniform()
+
+// Performance Check
+#ifdef PADDLE_DISABLE_TIMER
+
+#define CHECK_VECTORPTR(vector1, vector2) \
+  EXPECT_EQ(VectorCheckErr(vector1, vector2), 0)
+
+#define EXPRESSION_PERFORMANCE(expression) \
+  expression;
+
+#else
+
+#include "paddle/utils/Stat.h"
+
+#define CHECK_VECTORPTR(vector1, vector2)
+
+#define EXPRESSION_PERFORMANCE(expression) \
+  do {\
+    char expr[30];\
+    strncpy(expr, #expression, 30);\
+    if (expr[29] != '\0') {\
+      expr[27] = '.'; expr[28] = '.'; expr[29] = '\0';\
+    }\
+    expression;\
+    for (int i = 0; i < 20; i++) {\
+      REGISTER_TIMER(expr);\
+      expression;\
+    }\
+    LOG(INFO) << std::setiosflags(std::ios::left) << std::setfill(' ')\
+      << *globalStat.getStat(expr);\
+    globalStat.reset();\
+  } while (0)
+
+#endif
+
diff --git a/paddle/math/tests/test_TrainingAlgorithm.cpp b/paddle/math/tests/test_TrainingAlgorithm.cpp
index 1759d221e51de121f304ba01dcc9ee26c78e1800..a4218118b81a1f947c358cfae36d22917245857d 100644
--- a/paddle/math/tests/test_TrainingAlgorithm.cpp
+++ b/paddle/math/tests/test_TrainingAlgorithm.cpp
@@ -11,6 +11,7 @@
 #include "paddle/utils/Util.h"
 #include "paddle/math/TrainingAlgorithmOp.h"
 #include "OriginalOptimizerApi.h"
+#include "TensorCheck.h"
 
 using namespace paddle;  // NOLINT
 
@@ -33,26 +34,6 @@ private:
   double max_diff_;
 };
 
-int VectorCheckErr(const Vector& vector1, const Vector& vector2) {
-  CHECK(vector1.getSize() == vector2.getSize());
-
-  const real* data1 = vector1.getData();
-  const real* data2 = vector2.getData();
-  size_t size = vector1.getSize();
-  int count = 0;
-  for (size_t i = 0; i < size; i++) {
-    real a = data1[i];
-    real b = data2[i];
-    if (fabs(a - b) > FLAGS_max_diff) {
-      if ((fabsf(a - b) / fabsf(a)) > (FLAGS_max_diff / 10.0f)) {
-        count++;
-      }
-    }
-  }
-
-  return count;
-}
-
 #define COPY_VECTOR_TO_CPU(cpuVec, vector) \
   do {\
     if (vector->useGpu()) {\
@@ -71,39 +52,6 @@ int VectorCheckErr(const VectorPtr& vector1, const VectorPtr& vector2) {
   return VectorCheckErr(*tmp1, *tmp2);
 }
 
-#ifdef PADDLE_DISABLE_TIMER
-
-#define CHECK_VECTORPTR(vector1, vector2) \
-  EXPECT_EQ(VectorCheckErr(vector1, vector2), 0)
-
-#define EXPRESSION_PERFORMANCE(expression) \
-  expression;
-
-#else
-
-#include "paddle/utils/Stat.h"
-
-#define CHECK_VECTORPTR(vector1, vector2)
-
-#define EXPRESSION_PERFORMANCE(expression) \
-  do {\
-    char expr[30];\
-    strncpy(expr, #expression, 30);\
-    if (expr[29] != '\0') {\
-      expr[27] = '.'; expr[28] = '.'; expr[29] = '\0';\
-    }\
-    expression;\
-    for (int i = 0; i < 20; i++) {\
-      REGISTER_TIMER(expr);\
-      expression;\
-    }\
-    LOG(INFO) << std::setiosflags(std::ios::left) << std::setfill(' ')\
-      << *globalStat.getStat(expr);\
-    globalStat.reset();\
-  } while (0)
-
-#endif
-
 typedef std::function<void(size_t size, bool useGpu)> testMatrixFunc;
 
 void testCase(testMatrixFunc matrixFunc) {
diff --git a/paddle/math/tests/test_lazyAssign.cu b/paddle/math/tests/test_lazyAssign.cu
new file mode 100644
index 0000000000000000000000000000000000000000..070a9a92dd6038204a8dd5a968e782e9aa9ba59c
--- /dev/null
+++ b/paddle/math/tests/test_lazyAssign.cu
@@ -0,0 +1,131 @@
+/**
+ * test_lazyAssign.cu
+ *
+ * Author: hedaoyuan (hedaoyuan@baidu.com)
+ * Created on: 2016-10-15
+ *
+ * Copyright (c) Baidu.com, Inc. All Rights Reserved
+ */
+
+#include <gtest/gtest.h>
+#include "paddle/math/Matrix.h"
+#include "paddle/math/TensorAssign.h"
+#include "TensorCheck.h"
+
+using namespace paddle;  // NOLINT
+using namespace std;     // NOLINT
+
+typedef std::function<void(int height, int width)> testMatrixFunc;
+void testMatrixCase(testMatrixFunc matrixFunc) {
+  for (auto height : {1}) {
+    for (auto width : {1, 32, 64, 128, 512, 1024, 4096, 32768, 65536, 131072,
+                       262144, 524288, 1048576, 2097152, 4194304, 8388608}) {
+      matrixFunc(height, width);
+    }
+  }
+}
+
+template <typename Tensor>
+void testLazyAssign(int height, int width) {
+  INIT_QUATERNARY(A1, A2, B, C, D);
+
+  EXPRESSION_PERFORMANCE(A1 = B + C; A1 = A1 * D;);
+
+  EXPRESSION_PERFORMANCE(
+    auto expr1 = A2.lazyAssign(B + C);
+    auto expr2 = A2.lazyAssign(A2 * D);
+    AssignEvaluate(expr1, expr2););
+
+  TensorCheckErr(A1, A2);
+}
+
+TEST(lazyAssign, CPU) {
+  testMatrixCase(testLazyAssign<CpuMatrix>);
+}
+
+#ifndef PADDLE_ONLY_CPU
+TEST(lazyAssign, GPU) {
+  testMatrixCase(testLazyAssign<GpuMatrix>);
+}
+#endif
+
+template <typename Tensor>
+void sgdUpdateTensor(Tensor& A, Tensor& B, Tensor& C, Tensor& D,
+                     real p1, real p2, real p3) {
+  C = C * p2 - D * (B + A * p3) * p1;
+  A += C;
+}
+
+void sgdUpdateLazyAssign(BaseMatrix& A, BaseMatrix& B,
+                         BaseMatrix& C, BaseMatrix& D,
+                         real p1, real p2, real p3) {
+  auto expr1 = C.lazyAssign(C * p2 - D * (B + A * p3) * p1);
+  auto expr2 = A.lazyAssign(A + C);
+  AssignEvaluate(expr1, expr2);
+}
+
+template <typename Tensor>
+void testSgdUpdate(int height, int width) {
+  Tensor A1(height, width);
+  Tensor A2(height, width);
+  Tensor A3(height, width);
+  A1.randomizeUniform();
+  A2.copyFrom(A1);
+  A3.copyFrom(A1);
+
+  Tensor B(height, width);
+  B.randomizeUniform();
+
+  Tensor C1(height, width);
+  Tensor C2(height, width);
+  Tensor C3(height, width);
+  C1.randomizeUniform();
+  C2.copyFrom(C1);
+  C3.copyFrom(C1);
+
+  Tensor D(height, width);
+  D.randomizeUniform();
+
+  real p1 = 0.2;
+  real p2 = 0.3;
+  real p3 = 0.5;
+
+  /**
+   * c = p2 * c - p1 * (b + p3 * a);
+   * a = a + c;
+   */
+  // BaseMatrix API
+  EXPRESSION_PERFORMANCE(
+    A1.sgdUpdate(B, C1, D, p1, p2, p3););
+
+  // Tensor expression
+  EXPRESSION_PERFORMANCE(
+    sgdUpdateTensor(A2, B, C2, D, p1, p2, p3));
+
+  // lazyAssign
+  EXPRESSION_PERFORMANCE(
+    sgdUpdateLazyAssign(A3, B, C3, D, p1, p2, p3));
+
+  TensorCheckErr(A1, A2);
+  TensorCheckErr(A1, A3);
+  TensorCheckErr(C1, C2);
+  TensorCheckErr(C1, C3);
+}
+
+TEST(sgdUpdate, CPU) {
+  testMatrixCase(testSgdUpdate<CpuMatrix>);
+}
+
+#ifndef PADDLE_ONLY_CPU
+TEST(sgdUpdate, GPU) {
+  testMatrixCase(testSgdUpdate<GpuMatrix>);
+}
+#endif
+
+int main(int argc, char** argv) {
+  testing::InitGoogleTest(&argc, argv);
+  hl_start();
+  hl_init(0);
+  return RUN_ALL_TESTS();
+}
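
Note on the pattern this patch introduces: lazyAssign only records an assignment, and AssignEvaluate then applies every recorded assignment element by element, so several updates share a single pass over memory (a single kernel launch on the GPU) instead of one full pass per statement. At each index the assignments run in argument order, which is why testLazyAssign above can let its second expression read the value the first one just wrote. The sketch below is a minimal stand-alone rendering of that evaluation order in plain C++11; MiniTensor, lazyAssign, and assignEvaluate here are simplified stand-ins invented for this sketch, not the Paddle APIs from the diff.

#include <cassert>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for a Paddle matrix: a flat float buffer.
struct MiniTensor {
  std::vector<float> data;
  explicit MiniTensor(int size) : data(size) {}
};

// One deferred assignment: lhs[i] = expr(i). Mirrors TensorAssignOp::apply(int).
template <typename Expr>
struct AssignOp {
  MiniTensor& lhs;
  Expr expr;
  void apply(int i) { lhs.data[i] = expr(i); }
};

template <typename Expr>
AssignOp<Expr> lazyAssign(MiniTensor& lhs, Expr expr) {
  return AssignOp<Expr>{lhs, expr};
}

// Single fused pass: applies all deferred assignments at each index, in
// argument order, via the same dummy-array pack expansion as AssignCpuEvaluate.
template <typename... Ops>
void assignEvaluate(int size, Ops&&... ops) {
  for (int i = 0; i < size; i++) {
    int dummy[] = { (ops.apply(i), 0)... };
    (void)dummy;
  }
}

int main() {
  const int n = 4;
  MiniTensor a(n), b(n), c(n);
  for (int i = 0; i < n; i++) { b.data[i] = i; c.data[i] = 10 * i; }

  // b + c into a, then a * 2 in place: one pass over memory instead of two.
  auto e1 = lazyAssign(a, [&](int i) { return b.data[i] + c.data[i]; });
  auto e2 = lazyAssign(a, [&](int i) { return a.data[i] * 2; });
  assignEvaluate(n, e1, e2);

  for (int i = 0; i < n; i++) assert(a.data[i] == 2 * (b.data[i] + c.data[i]));
  std::printf("fused result ok\n");
  return 0;
}

The design choice this trades on is the same one the patch's interleaved semantics impose: fusion is safe for elementwise updates that only touch the current index, which is exactly the shape of the optimizer update rules (adadelta, adagrad, rmsprop, adam, adamax) rewritten above.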