提交 f8c9c889 编写于 作者: H hedaoyuan

Fix CrossMapNormalTest

上级 1e233171
...@@ -188,8 +188,13 @@ public: ...@@ -188,8 +188,13 @@ public:
CHECK(inputs[0].shape() == inputs[3].shape()); CHECK(inputs[0].shape() == inputs[3].shape());
CHECK(inputs[0].shape() == outputs[0].shape()); CHECK(inputs[0].shape() == outputs[0].shape());
// TODO(hedaoyuan): need support ASSIGN_TO mode. if (outputs[0].getArgType() != ADD_TO) {
CHECK_EQ(outputs[0].getArgType(), ADD_TO); // Currently, some algorithm implementations are ASSIGN_TO mode,
// if need to support the ADD_TO calculation, need to clear the output.
typename Tensor<real, Device>::Vector tmp(
outputs[0].shape().getElements(), outputs[0].data<real>());
tmp.zero();
}
size_t samples = inputs[0].shape()[0]; size_t samples = inputs[0].shape()[0];
size_t channels = inputs[0].shape()[1]; size_t channels = inputs[0].shape()[1];
......
...@@ -47,9 +47,6 @@ TEST(CrossMapNormal, real) { ...@@ -47,9 +47,6 @@ TEST(CrossMapNormal, real) {
} }
} }
#if 0
// TODO(hedaoyuan): Now CrossMapNormalGrad not support ASSIGN_TO mode.
// Maybe all Function need support ASSIGN_TO mode.
TEST(CrossMapNormalGrad, real) { TEST(CrossMapNormalGrad, real) {
for (size_t numSamples : {5, 32}) { for (size_t numSamples : {5, 32}) {
for (size_t channels : {1, 5, 32}) { for (size_t channels : {1, 5, 32}) {
...@@ -79,6 +76,5 @@ TEST(CrossMapNormalGrad, real) { ...@@ -79,6 +76,5 @@ TEST(CrossMapNormalGrad, real) {
} }
} }
} }
#endif
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册