提交 a6ad9a16 编写于 作者: X xuwei06

Fix unittest

Change-Id: Ic80845c892c96c37a0df0ddc433fe1aeaa5a9d1c
上级 bf6f690f
...@@ -605,7 +605,7 @@ public: ...@@ -605,7 +605,7 @@ public:
int batchSize = input->getHeight(); int batchSize = input->getHeight();
int size = 1; int size = 1;
resizeOutput(batchSize, size); resizeOutput(batchSize, size);
output_.value->sumRows(*input); output_.value->sumRows(*input, /* scaleSum= */1, /* scaleDest= */0);
} }
virtual void backward(const UpdateCallback& callback = nullptr) { virtual void backward(const UpdateCallback& callback = nullptr) {
......
...@@ -1473,6 +1473,21 @@ int BaseMatrixT<real>::applyRow(Agg agg, Saver sv, BaseMatrixT& b) { ...@@ -1473,6 +1473,21 @@ int BaseMatrixT<real>::applyRow(Agg agg, Saver sv, BaseMatrixT& b) {
return 0; return 0;
} }
template<>
template <class Agg>
int BaseMatrixT<real>::applyRow(
Agg agg, real scaleDest, real scaleAgg, BaseMatrixT& b) {
if (scaleDest != 0) {
applyRow(agg, base::binary::add2(scaleDest, scaleAgg), b);
} else {
applyRow(agg, base::binary::second(), b);
if (scaleAgg != 1) {
mulScalar(scaleAgg);
}
}
return 0;
}
template<> template<>
template <class Agg, class Op, class Saver> template <class Agg, class Op, class Saver>
int BaseMatrixT<real>::applyRow(Agg agg, Op op, Saver sv, int BaseMatrixT<real>::applyRow(Agg agg, Op op, Saver sv,
...@@ -1490,6 +1505,21 @@ int BaseMatrixT<real>::applyRow(Agg agg, Op op, Saver sv, ...@@ -1490,6 +1505,21 @@ int BaseMatrixT<real>::applyRow(Agg agg, Op op, Saver sv,
return 0; return 0;
} }
template<>
template <class Agg, class Op>
int BaseMatrixT<real>::applyRow(Agg agg, Op op, real scaleDest, real scaleAgg,
BaseMatrixT& b, BaseMatrixT& c) {
if (scaleDest != 0) {
applyRow(agg, op, base::binary::add2(scaleDest, scaleAgg), b, c);
} else {
applyRow(agg, op, base::binary::second(), b, c);
if (scaleAgg != 1) {
mulScalar(scaleAgg);
}
}
return 0;
}
template<> template<>
template <class Agg> template <class Agg>
int BaseMatrixT<real>::applyCol(Agg agg, BaseMatrixT& b) { int BaseMatrixT<real>::applyCol(Agg agg, BaseMatrixT& b) {
...@@ -1518,9 +1548,24 @@ int BaseMatrixT<real>::applyCol(Agg agg, Saver sv, BaseMatrixT& b) { ...@@ -1518,9 +1548,24 @@ int BaseMatrixT<real>::applyCol(Agg agg, Saver sv, BaseMatrixT& b) {
return 0; return 0;
} }
template<>
template <class Agg>
int BaseMatrixT<real>::applyCol(
Agg agg, real scaleDest, real scaleAgg, BaseMatrixT& b) {
if (scaleDest != 0) {
applyCol(agg, base::binary::add2(scaleDest, scaleAgg), b);
} else {
applyCol(agg, base::binary::second(), b);
if (scaleAgg != 1) {
mulScalar(scaleAgg);
}
}
return 0;
}
template<> template<>
void BaseMatrixT<real>::sumRows(BaseMatrixT& b, real scaleSum, real scaleDest) { void BaseMatrixT<real>::sumRows(BaseMatrixT& b, real scaleSum, real scaleDest) {
applyRow(aggregate::sum(), base::binary::add2(scaleDest, scaleSum), b); applyRow(aggregate::sum(), scaleDest, scaleSum, b);
} }
template<> template<>
...@@ -1550,21 +1595,21 @@ void BaseMatrixT<real>::minCols(BaseMatrixT& b) { ...@@ -1550,21 +1595,21 @@ void BaseMatrixT<real>::minCols(BaseMatrixT& b) {
template<> template<>
void BaseMatrixT<real>::sumCols(BaseMatrixT& b, real scaleSum, real scaleDest) { void BaseMatrixT<real>::sumCols(BaseMatrixT& b, real scaleSum, real scaleDest) {
applyCol(aggregate::sum(), base::binary::add2(scaleDest, scaleSum), b); applyCol(aggregate::sum(), scaleDest, scaleSum, b);
} }
template<> template<>
void BaseMatrixT<real>::sumOfSquaredDiffs( void BaseMatrixT<real>::sumOfSquaredDiffs(
BaseMatrixT& b, BaseMatrixT& c, real scaleSum, real scaleDest) { BaseMatrixT& b, BaseMatrixT& c, real scaleSum, real scaleDest) {
applyRow(aggregate::sum(), base::binary::squaredDiff(), applyRow(aggregate::sum(), base::binary::squaredDiff(),
base::binary::add2(scaleDest, scaleSum), b, c); scaleDest, scaleSum, b, c);
} }
template<> template<>
void BaseMatrixT<real>::sumOfProducts( void BaseMatrixT<real>::sumOfProducts(
BaseMatrixT& b, BaseMatrixT& c, real scaleSum, real scaleDest) { BaseMatrixT& b, BaseMatrixT& c, real scaleSum, real scaleDest) {
applyRow(aggregate::sum(), base::binary::mul(), applyRow(aggregate::sum(), base::binary::mul(),
base::binary::add2(scaleDest, scaleSum), b, c); scaleDest, scaleSum, b, c);
} }
template class BaseMatrixT<real>; template class BaseMatrixT<real>;
......
...@@ -317,6 +317,11 @@ public: ...@@ -317,6 +317,11 @@ public:
template <class Agg, class Op, class Saver> template <class Agg, class Op, class Saver>
int applyRow(Agg agg, Op op, Saver sv, BaseMatrixT& b, BaseMatrixT& c); int applyRow(Agg agg, Op op, Saver sv, BaseMatrixT& b, BaseMatrixT& c);
// Same as the above with the special handing of sv=add2(scaleDest, scaleAgg)
template <class Agg, class Op>
int applyRow(Agg agg, Op op, real scaleDest, real scaleAgg,
BaseMatrixT& b, BaseMatrixT& c);
/** /**
* a aggregate expression that apply each row of matrix b. * a aggregate expression that apply each row of matrix b.
* *
...@@ -329,6 +334,10 @@ public: ...@@ -329,6 +334,10 @@ public:
template <class Agg, class Saver> template <class Agg, class Saver>
int applyRow(Agg agg, Saver sv, BaseMatrixT& b); int applyRow(Agg agg, Saver sv, BaseMatrixT& b);
// Same as the above with the special handing of sv=add2(scaleDest, scaleAgg)
template <class Agg>
int applyRow(Agg agg, real scaleDest, real scaleAgg, BaseMatrixT& b);
/** /**
* a aggregate expression that apply each column of matrix b. * a aggregate expression that apply each column of matrix b.
* *
...@@ -352,6 +361,10 @@ public: ...@@ -352,6 +361,10 @@ public:
template <class Agg, class Saver> template <class Agg, class Saver>
int applyCol(Agg agg, Saver sv, BaseMatrixT& b); int applyCol(Agg agg, Saver sv, BaseMatrixT& b);
// Same as the above with the special handing of sv=add2(scaleDest, scaleAgg)
template <class Agg>
int applyCol(Agg agg, real scaleDest, real scaleAgg, BaseMatrixT& b);
bool useGpu() const { return useGpu_; } bool useGpu() const { return useGpu_; }
const T* rowBuf(size_t row) const { return data_ + width_ * row; } const T* rowBuf(size_t row) const { return data_ + width_ * row; }
......
...@@ -29,7 +29,6 @@ except ImportError: ...@@ -29,7 +29,6 @@ except ImportError:
import pickle import pickle
import copy import copy
<<<<<<< 0ba0f02c685e52b14632f6b9bfca4321494505c7
__all__ = [ __all__ = [
"full_matrix_projection", "full_matrix_projection",
"AggregateLevel", "AggregateLevel",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册