diff --git a/paddle/math/BaseMatrix.cu b/paddle/math/BaseMatrix.cu index 2f32b3fdd1a26c5b1bca43d0bd0ebb0896a012c4..a723ef7bc8329329fa82113f8e96a1bdbe750277 100644 --- a/paddle/math/BaseMatrix.cu +++ b/paddle/math/BaseMatrix.cu @@ -1240,6 +1240,12 @@ void BaseMatrixT::assignAtOffset(BaseMatrixT& b, int64_t columnOffset) { } } +DEFINE_MATRIX_BINARY_OP(DeepSwap, T tmp = a; a = b; b = tmp); +template +void BaseMatrixT::deepSwap(BaseMatrixT& b) { + applyBinary(binary::DeepSwap(), b); +} + template<> void BaseMatrixT::rowDotMul(size_t destCol, BaseMatrixT& b, diff --git a/paddle/math/BaseMatrix.h b/paddle/math/BaseMatrix.h index d41dcee682cce15e94d45dafeb12bb0dce19b221..dbc217c30f2f2e6d35536d99115282fdb6ed5e06 100644 --- a/paddle/math/BaseMatrix.h +++ b/paddle/math/BaseMatrix.h @@ -455,6 +455,13 @@ public: */ void assign(T p); + /** + * @code + * swap(this, b) + * @endcode + */ + void deepSwap(BaseMatrixT& b); + /** * @code * this = this + p