提交 9e65ceed 编写于 作者: T tianbingsz 提交者: GitHub

Merge pull request #565 from tianbingsz/deepCopy

deepSwap
......@@ -1240,6 +1240,12 @@ void BaseMatrixT<T>::assignAtOffset(BaseMatrixT& b, int64_t columnOffset) {
}
}
DEFINE_MATRIX_BINARY_OP(DeepSwap, T tmp = a; a = b; b = tmp);
template<class T>
void BaseMatrixT<T>::deepSwap(BaseMatrixT& b) {
applyBinary(binary::DeepSwap<T>(), b);
}
template<>
void BaseMatrixT<real>::rowDotMul(size_t destCol,
BaseMatrixT& b,
......
......@@ -455,6 +455,17 @@ public:
*/
void assign(T p);
/**
* @code
* swap(this, b)
* example: swap two Matrices
* MatrixPtr cpuA = std::make_shared<CpuMatrix>(height, width);
* MatrixPtr cpuB = std::make_shared<CpuMatrix>(height, width);
* cpuA->deepSwap(*cpuB);
* @endcode
*/
void deepSwap(BaseMatrixT& b);
/**
* @code
* this = this + p
......
......@@ -448,6 +448,24 @@ void testMatrixZeroAtOffset(int height, int width) {
MatrixCheckEqual(*cpuA, *cpuTest);
}
void testMatrixDeepSwap(int height, int width) {
MatrixPtr cpuA = std::make_shared<CpuMatrix>(height, width);
MatrixPtr cpuB = std::make_shared<CpuMatrix>(height, width);
MatrixPtr cpuCopyA = std::make_shared<CpuMatrix>(height, width);
MatrixPtr cpuCopyB = std::make_shared<CpuMatrix>(height, width);
cpuA->randomizeUniform();
cpuB->randomizeUniform();
cpuCopyA->copyFrom(*cpuA);
cpuCopyB->copyFrom(*cpuB);
// swap matrix cpuA and cpuB
cpuA->deepSwap(*cpuB);
MatrixCheckEqual(*cpuA, *cpuCopyB);
MatrixCheckEqual(*cpuB, *cpuCopyA);
}
void testMatrixBinaryAdd(int height, int width) {
MatrixPtr cpuA = std::make_shared<CpuMatrix>(height, width);
MatrixPtr cpuB = std::make_shared<CpuMatrix>(height, width);
......@@ -480,6 +498,7 @@ void testMatrixAssign(int height, int width) {
MatrixCheckEqual(*cpuA, *outputCheck);
}
void testMatrixAdd(int height, int width) {
MatrixPtr cpuA = std::make_shared<CpuMatrix>(height, width);
MatrixPtr gpuA = std::make_shared<GpuMatrix>(height, width);
......@@ -798,6 +817,7 @@ TEST(Matrix, unary) {
testMatrixBinaryAdd(height, width);
testMatrixTanh(height, width);
testMatrixTanhDerivative(height, width);
testMatrixDeepSwap(height, width);
// applyTernary
testMatrixTernarySub(height, width);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册