提交 a6ad9a16 编写于 作者: X xuwei06

Fix unittest

Change-Id: Ic80845c892c96c37a0df0ddc433fe1aeaa5a9d1c
上级 bf6f690f
......@@ -605,7 +605,7 @@ public:
int batchSize = input->getHeight();
int size = 1;
resizeOutput(batchSize, size);
output_.value->sumRows(*input);
output_.value->sumRows(*input, /* scaleSum= */1, /* scaleDest= */0);
}
virtual void backward(const UpdateCallback& callback = nullptr) {
......
......@@ -1473,6 +1473,21 @@ int BaseMatrixT<real>::applyRow(Agg agg, Saver sv, BaseMatrixT& b) {
return 0;
}
template<>
template <class Agg>
int BaseMatrixT<real>::applyRow(
Agg agg, real scaleDest, real scaleAgg, BaseMatrixT& b) {
if (scaleDest != 0) {
applyRow(agg, base::binary::add2(scaleDest, scaleAgg), b);
} else {
applyRow(agg, base::binary::second(), b);
if (scaleAgg != 1) {
mulScalar(scaleAgg);
}
}
return 0;
}
template<>
template <class Agg, class Op, class Saver>
int BaseMatrixT<real>::applyRow(Agg agg, Op op, Saver sv,
......@@ -1490,6 +1505,21 @@ int BaseMatrixT<real>::applyRow(Agg agg, Op op, Saver sv,
return 0;
}
template<>
template <class Agg, class Op>
int BaseMatrixT<real>::applyRow(Agg agg, Op op, real scaleDest, real scaleAgg,
BaseMatrixT& b, BaseMatrixT& c) {
if (scaleDest != 0) {
applyRow(agg, op, base::binary::add2(scaleDest, scaleAgg), b, c);
} else {
applyRow(agg, op, base::binary::second(), b, c);
if (scaleAgg != 1) {
mulScalar(scaleAgg);
}
}
return 0;
}
template<>
template <class Agg>
int BaseMatrixT<real>::applyCol(Agg agg, BaseMatrixT& b) {
......@@ -1518,9 +1548,24 @@ int BaseMatrixT<real>::applyCol(Agg agg, Saver sv, BaseMatrixT& b) {
return 0;
}
template<>
template <class Agg>
int BaseMatrixT<real>::applyCol(
Agg agg, real scaleDest, real scaleAgg, BaseMatrixT& b) {
if (scaleDest != 0) {
applyCol(agg, base::binary::add2(scaleDest, scaleAgg), b);
} else {
applyCol(agg, base::binary::second(), b);
if (scaleAgg != 1) {
mulScalar(scaleAgg);
}
}
return 0;
}
template<>
void BaseMatrixT<real>::sumRows(BaseMatrixT& b, real scaleSum, real scaleDest) {
applyRow(aggregate::sum(), base::binary::add2(scaleDest, scaleSum), b);
applyRow(aggregate::sum(), scaleDest, scaleSum, b);
}
template<>
......@@ -1550,21 +1595,21 @@ void BaseMatrixT<real>::minCols(BaseMatrixT& b) {
template<>
void BaseMatrixT<real>::sumCols(BaseMatrixT& b, real scaleSum, real scaleDest) {
applyCol(aggregate::sum(), base::binary::add2(scaleDest, scaleSum), b);
applyCol(aggregate::sum(), scaleDest, scaleSum, b);
}
template<>
void BaseMatrixT<real>::sumOfSquaredDiffs(
BaseMatrixT& b, BaseMatrixT& c, real scaleSum, real scaleDest) {
applyRow(aggregate::sum(), base::binary::squaredDiff(),
base::binary::add2(scaleDest, scaleSum), b, c);
scaleDest, scaleSum, b, c);
}
template<>
void BaseMatrixT<real>::sumOfProducts(
BaseMatrixT& b, BaseMatrixT& c, real scaleSum, real scaleDest) {
applyRow(aggregate::sum(), base::binary::mul(),
base::binary::add2(scaleDest, scaleSum), b, c);
scaleDest, scaleSum, b, c);
}
template class BaseMatrixT<real>;
......
......@@ -317,6 +317,11 @@ public:
template <class Agg, class Op, class Saver>
int applyRow(Agg agg, Op op, Saver sv, BaseMatrixT& b, BaseMatrixT& c);
// Same as the above with the special handing of sv=add2(scaleDest, scaleAgg)
template <class Agg, class Op>
int applyRow(Agg agg, Op op, real scaleDest, real scaleAgg,
BaseMatrixT& b, BaseMatrixT& c);
/**
* a aggregate expression that apply each row of matrix b.
*
......@@ -329,6 +334,10 @@ public:
template <class Agg, class Saver>
int applyRow(Agg agg, Saver sv, BaseMatrixT& b);
// Same as the above with the special handing of sv=add2(scaleDest, scaleAgg)
template <class Agg>
int applyRow(Agg agg, real scaleDest, real scaleAgg, BaseMatrixT& b);
/**
* a aggregate expression that apply each column of matrix b.
*
......@@ -352,6 +361,10 @@ public:
template <class Agg, class Saver>
int applyCol(Agg agg, Saver sv, BaseMatrixT& b);
// Same as the above with the special handing of sv=add2(scaleDest, scaleAgg)
template <class Agg>
int applyCol(Agg agg, real scaleDest, real scaleAgg, BaseMatrixT& b);
bool useGpu() const { return useGpu_; }
const T* rowBuf(size_t row) const { return data_ + width_ * row; }
......
......@@ -29,7 +29,6 @@ except ImportError:
import pickle
import copy
<<<<<<< 0ba0f02c685e52b14632f6b9bfca4321494505c7
__all__ = [
"full_matrix_projection",
"AggregateLevel",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册