diff --git a/paddle/function/CrossMapNormalOp.cpp b/paddle/function/CrossMapNormalOp.cpp index 92980c503fdaaaa9ac600070197dba6ba4bfb7a4..8e7dc72524a7680a03ea6eb4770a3e25c09ad913 100644 --- a/paddle/function/CrossMapNormalOp.cpp +++ b/paddle/function/CrossMapNormalOp.cpp @@ -188,8 +188,13 @@ public: CHECK(inputs[0].shape() == inputs[3].shape()); CHECK(inputs[0].shape() == outputs[0].shape()); - // TODO(hedaoyuan): need support ASSIGN_TO mode. - CHECK_EQ(outputs[0].getArgType(), ADD_TO); + if (outputs[0].getArgType() != ADD_TO) { + // Currently, some algorithm implementations are ASSIGN_TO mode, + // if need to support the ADD_TO calculation, need to clear the output. + typename Tensor::Vector tmp( + outputs[0].shape().getElements(), outputs[0].data()); + tmp.zero(); + } size_t samples = inputs[0].shape()[0]; size_t channels = inputs[0].shape()[1]; diff --git a/paddle/function/CrossMapNormalOpTest.cpp b/paddle/function/CrossMapNormalOpTest.cpp index da196a699cc32b079f8e04c65ec6b25b7dc24700..51f5da81bfc9ae870ac9949ba74da01a9449a04d 100644 --- a/paddle/function/CrossMapNormalOpTest.cpp +++ b/paddle/function/CrossMapNormalOpTest.cpp @@ -47,9 +47,6 @@ TEST(CrossMapNormal, real) { } } -#if 0 -// TODO(hedaoyuan): Now CrossMapNormalGrad not support ASSIGN_TO mode. -// Maybe all Function need support ASSIGN_TO mode. TEST(CrossMapNormalGrad, real) { for (size_t numSamples : {5, 32}) { for (size_t channels : {1, 5, 32}) { @@ -79,6 +76,5 @@ TEST(CrossMapNormalGrad, real) { } } } -#endif } // namespace paddle