提交 8d736813 编写于 作者: H hedaoyuan

fix some errors due to merge

上级 671db8de
......@@ -345,7 +345,7 @@ void forward(Argument& act) {
useGpu(act.deviceId));
act.in->copyFrom(*act.value);
act.value->log(*act.value);
act.value->log2(*act.value);
}
void backward(Argument& act) { act.grad->dotDiv(*act.grad, *act.in); }
......
......@@ -37,13 +37,13 @@ TEST(BaseMatrix, void) {
};
compare(&BaseMatrix::neg);
compare(&BaseMatrix::exp);
compare(&BaseMatrix::log);
compare(&BaseMatrix::sqrt);
compare(&BaseMatrix::square);
compare(&BaseMatrix::reciprocal);
compare(&BaseMatrix::abs);
compare(&BaseMatrix::sign);
compare(&BaseMatrix::exp2);
compare(&BaseMatrix::log2);
compare(&BaseMatrix::sqrt2);
compare(&BaseMatrix::square2);
compare(&BaseMatrix::reciprocal2);
compare(&BaseMatrix::abs2);
compare(&BaseMatrix::sign2);
compare(&BaseMatrix::zero);
compare(&BaseMatrix::one);
}
......@@ -59,7 +59,7 @@ TEST(BaseMatrix, real) {
test.cmpWithoutArg<0>(f, height, width);
};
compare(&BaseMatrix::pow);
compare(&BaseMatrix::pow2);
compare(&BaseMatrix::subScalar);
compare(&BaseMatrix::mulScalar);
compare(&BaseMatrix::divScalar);
......@@ -88,21 +88,21 @@ TEST(BaseMatrix, BaseMatrix) {
compare(&BaseMatrix::softreluDerivative);
compare(&BaseMatrix::brelu);
compare(&BaseMatrix::breluDerivative);
compare(&BaseMatrix::square);
compare(&BaseMatrix::square2);
compare(&BaseMatrix::squareDerivative);
compare(&BaseMatrix::tanh);
compare(&BaseMatrix::tanhDerivative);
compare(&BaseMatrix::reciprocal);
compare(&BaseMatrix::reciprocal2);
compare(&BaseMatrix::reciprocalDerivative);
compare(&BaseMatrix::abs);
compare(&BaseMatrix::abs2);
compare(&BaseMatrix::absDerivative);
compare(&BaseMatrix::sigmoid);
compare(&BaseMatrix::sigmoidDerivative);
compare(&BaseMatrix::expDerivative);
compare(&BaseMatrix::sign);
compare(&BaseMatrix::exp);
compare(&BaseMatrix::log);
compare(&BaseMatrix::sqrt);
compare(&BaseMatrix::sign2);
compare(&BaseMatrix::exp2);
compare(&BaseMatrix::log2);
compare(&BaseMatrix::sqrt2);
compare(&BaseMatrix::dotMul);
compare(&BaseMatrix::dotMulSquare);
compare(&BaseMatrix::dotSquareMul);
......@@ -143,7 +143,7 @@ TEST(BaseMatrix, BaseMatrix_real) {
compare(&BaseMatrix::addBias);
compare(&BaseMatrix::add);
compare(&BaseMatrix::sub);
compare(&BaseMatrix::pow);
compare(&BaseMatrix::pow2);
compare(&BaseMatrix::addScalar);
compare(&BaseMatrix::subScalar);
compare(&BaseMatrix::mulScalar);
......@@ -176,7 +176,7 @@ TEST(BaseMatrix, BaseMatrix_BaseMatrix) {
compare(&BaseMatrix::logisticRegressionLoss);
compare(&BaseMatrix::logisticRegressionLossBp);
compare(&BaseMatrix::biggerThan);
compare(&BaseMatrix::max);
compare(&BaseMatrix::max2);
compare(&BaseMatrix::dotMulSquare);
compare(&BaseMatrix::dotSquareSquare);
}
......
......@@ -18,6 +18,8 @@ limitations under the License. */
using namespace paddle; // NOLINT
using namespace std; // NOLINT
using autotest::TensorCheckEqual;
using autotest::TensorCheckErr;
#define INIT_UNARY(A1, A2) \
Tensor A1(height, width); \
......
......@@ -19,6 +19,8 @@ limitations under the License. */
using namespace paddle; // NOLINT
using namespace std; // NOLINT
using autotest::TensorCheckEqual;
using autotest::TensorCheckErr;
typedef std::function<void(int height, int width)> testMatrixFunc;
void testMatrixCase(testMatrixFunc matrixFunc) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册