提交 52e17bf5 编写于 作者: G guosheng

Fix the unit test of weight normalization

上级 6b9f1d34
......@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/reduce_op.h"
#include "paddle/operators/net_op.h"
namespace paddle {
namespace operators {
......@@ -38,10 +37,14 @@ class ReduceOp : public framework::OperatorWithKernel {
dim, x_rank,
"The dim should be in the range [-rank(input), rank(input)).");
bool reduce_all = ctx->Attrs().Get<bool>("reduce_all");
bool keep_dim = ctx->Attrs().Get<bool>("keep_dim");
if (reduce_all) {
ctx->SetOutputDim("Out", {1});
if (keep_dim)
ctx->SetOutputDim(
"Out", framework::make_ddim(std::vector<int64_t>(x_rank, 1)));
else
ctx->SetOutputDim("Out", {1});
} else {
bool keep_dim = ctx->Attrs().Get<bool>("keep_dim");
auto dims_vector = vectorize(x_dims);
if (keep_dim || x_rank == 1) {
dims_vector[dim] = 1;
......
......@@ -112,9 +112,8 @@ class TestWeightNormalization(unittest.TestCase):
[
self.assertTrue(
numpy.allclose(
numpy.array(actual_output), expect_output, atol=0.001))
for expect_output, actual_output in zip(expect_output,
actual_output)
numpy.array(actual), expect, atol=0.001))
for expect, actual in zip(expect_output, actual_output)
]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册