diff --git a/paddle/function/BufferArg.cpp b/paddle/function/BufferArg.cpp index 65c6f303041d830812fb2d99503b2b2166145f4a..fde48a73b61c31d06225cc1763efbc6971c86f57 100644 --- a/paddle/function/BufferArg.cpp +++ b/paddle/function/BufferArg.cpp @@ -15,6 +15,7 @@ limitations under the License. */ #include #include "BufferArg.h" +#include "paddle/math/SparseMatrix.h" namespace paddle { @@ -28,4 +29,14 @@ const SparseMatrixArg& BufferArg::sparse() const { return dynamic_cast(*this); } +SparseMatrixArg::SparseMatrixArg(const CpuSparseMatrix& sparse, ArgType argType) + : BufferArg(sparse, argType), + row_(reinterpret_cast(sparse.getRows()), VALUE_TYPE_INT32), + col_(reinterpret_cast(sparse.getCols()), VALUE_TYPE_INT32) {} + +SparseMatrixArg::SparseMatrixArg(const GpuSparseMatrix& sparse, ArgType argType) + : BufferArg(sparse, argType), + row_(reinterpret_cast(sparse.getRows()), VALUE_TYPE_INT32), + col_(reinterpret_cast(sparse.getCols()), VALUE_TYPE_INT32) {} + } // namespace paddle diff --git a/paddle/function/BufferArg.h b/paddle/function/BufferArg.h index 9649913fa8d9bf82b67fc2ac97ae9f30e7029528..12352ba29e33920ba65bd66088b6f7cc53517b52 100644 --- a/paddle/function/BufferArg.h +++ b/paddle/function/BufferArg.h @@ -18,9 +18,7 @@ limitations under the License. */ #include "TensorShape.h" #include "TensorType.h" -#include "paddle/math/CpuSparseMatrix.h" #include "paddle/math/Matrix.h" -#include "paddle/math/SparseMatrix.h" namespace paddle { @@ -248,15 +246,9 @@ public: } } - SparseMatrixArg(const CpuSparseMatrix& sparse, ArgType argType = UNSPECIFIED) - : BufferArg(sparse, argType), - row_(reinterpret_cast(sparse.getRows()), VALUE_TYPE_INT32), - col_(reinterpret_cast(sparse.getCols()), VALUE_TYPE_INT32) {} + SparseMatrixArg(const CpuSparseMatrix& sparse, ArgType argType = UNSPECIFIED); - SparseMatrixArg(const GpuSparseMatrix& sparse, ArgType argType = UNSPECIFIED) - : BufferArg(sparse, argType), - row_(reinterpret_cast(sparse.getRows()), VALUE_TYPE_INT32), - col_(reinterpret_cast(sparse.getCols()), VALUE_TYPE_INT32) {} + SparseMatrixArg(const GpuSparseMatrix& sparse, ArgType argType = UNSPECIFIED); ~SparseMatrixArg() {} diff --git a/paddle/function/BufferArgTest.cpp b/paddle/function/BufferArgTest.cpp index a9ee3ab079e339b86a9db8602c41e419df9dc544..b345597435c9911ce95b596f5f7f2add47f4cd03 100644 --- a/paddle/function/BufferArgTest.cpp +++ b/paddle/function/BufferArgTest.cpp @@ -16,6 +16,7 @@ limitations under the License. */ #include #include "Function.h" #include "paddle/math/MemoryHandle.h" +#include "paddle/math/SparseMatrix.h" namespace paddle {