From e256bfaf28a0984a15d594110ad1e868380a3e25 Mon Sep 17 00:00:00 2001 From: dongzhihong Date: Tue, 15 Aug 2017 17:12:35 +0800 Subject: [PATCH] "update paddle enforce" --- paddle/operators/mul_op.cc | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc index 9a57e6b68f..5645df6677 100644 --- a/paddle/operators/mul_op.cc +++ b/paddle/operators/mul_op.cc @@ -68,16 +68,16 @@ class MulOpGrad : public framework::OperatorWithKernel { "Input(Out@GRAD) should not be null"); auto *x_grad = ctx.Output(framework::GradVarName("X")); auto *y_grad = ctx.Output(framework::GradVarName("Y")); - auto dim0 = ctx.Output(framework::GradVarName("X"))->dims(); - auto dim1 = ctx.Output(framework::GradVarName("Y"))->dims(); + auto x_dims = ctx.Output(framework::GradVarName("X"))->dims(); + auto y_dims = ctx.Output(framework::GradVarName("Y"))->dims(); auto out_dims = ctx.Input(framework::GradVarName("Out"))->dims(); - PADDLE_ENFORCE(dim0[0] * dim1[0] == out_dims[0], - "Out@GRAD[0] must equal to X[0] * Y[0]"); - PADDLE_ENFORCE(dim0[1] * dim1[1] == out_dims[1], - "Out@GRAD shape must equal to X[1] * Y[1]"); + PADDLE_ENFORCE(x_dims[0] == out_dims[0], + "Out@GRAD M X N must equal to X dims 0, M "); + PADDLE_ENFORCE(y_dims[1] == out_dims[1], + "Out@GRAD M X N must equal to Y dims 1, N "); - x_grad->Resize(dim1); - y_grad->Resize(dim0); + x_grad->Resize(x_dims); + y_grad->Resize(y_dims); } }; -- GitLab