From f8c9c889c34dd3530b899fc12523579802d4f582 Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Mon, 16 Jan 2017 21:30:44 +0800 Subject: [PATCH] Fix CrossMapNormalTest --- paddle/function/CrossMapNormalOp.cpp | 9 +++++++-- paddle/function/CrossMapNormalOpTest.cpp | 4 ---- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/paddle/function/CrossMapNormalOp.cpp b/paddle/function/CrossMapNormalOp.cpp index 92980c503fd..8e7dc72524a 100644 --- a/paddle/function/CrossMapNormalOp.cpp +++ b/paddle/function/CrossMapNormalOp.cpp @@ -188,8 +188,13 @@ public: CHECK(inputs[0].shape() == inputs[3].shape()); CHECK(inputs[0].shape() == outputs[0].shape()); - // TODO(hedaoyuan): need support ASSIGN_TO mode. - CHECK_EQ(outputs[0].getArgType(), ADD_TO); + if (outputs[0].getArgType() != ADD_TO) { + // Currently, some algorithm implementations are ASSIGN_TO mode, + // if need to support the ADD_TO calculation, need to clear the output. + typename Tensor::Vector tmp( + outputs[0].shape().getElements(), outputs[0].data()); + tmp.zero(); + } size_t samples = inputs[0].shape()[0]; size_t channels = inputs[0].shape()[1]; diff --git a/paddle/function/CrossMapNormalOpTest.cpp b/paddle/function/CrossMapNormalOpTest.cpp index da196a699cc..51f5da81bfc 100644 --- a/paddle/function/CrossMapNormalOpTest.cpp +++ b/paddle/function/CrossMapNormalOpTest.cpp @@ -47,9 +47,6 @@ TEST(CrossMapNormal, real) { } } -#if 0 -// TODO(hedaoyuan): Now CrossMapNormalGrad not support ASSIGN_TO mode. -// Maybe all Function need support ASSIGN_TO mode. TEST(CrossMapNormalGrad, real) { for (size_t numSamples : {5, 32}) { for (size_t channels : {1, 5, 32}) { @@ -79,6 +76,5 @@ TEST(CrossMapNormalGrad, real) { } } } -#endif } // namespace paddle -- GitLab