From 91215bcef907baffdf52bb7894a0a33d0253c16f Mon Sep 17 00:00:00 2001
From: Xinghai Sun
Date: Sat, 2 Sep 2017 17:25:30 +0800
Subject: [PATCH] Fix a bug causing wrong gradient results in cos_sim op.

---
 paddle/operators/cos_sim_op.cc              | 32 ++++++++++----
 paddle/operators/cos_sim_op.h               | 34 +++++++++------
 .../v2/framework/tests/gradient_checker.py  |  7 ----
 .../v2/framework/tests/test_cos_sim_op.py   | 42 ++++++++++++-------
 4 files changed, 71 insertions(+), 44 deletions(-)

diff --git a/paddle/operators/cos_sim_op.cc b/paddle/operators/cos_sim_op.cc
index 21a616522b..3760d0b161 100644
--- a/paddle/operators/cos_sim_op.cc
+++ b/paddle/operators/cos_sim_op.cc
@@ -25,14 +25,16 @@ class CosSimOp : public framework::OperatorWithKernel {
 
  protected:
   void InferShape(const framework::InferShapeContext &ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) should not be null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) must not be null.");
     PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("X")->dims(),
                       ctx.Input<Tensor>("Y")->dims(),
                       "Dimensions of Input(X) and Input(Y) must be the same.");
     auto dims = ctx.Input<Tensor>("X")->dims();
     ctx.Output<Tensor>("Out")->Resize({dims[0], 1});
+    ctx.Output<Tensor>("XNorm")->Resize({dims[0], 1});
+    ctx.Output<Tensor>("YNorm")->Resize({dims[0], 1});
   }
 };
 
@@ -43,6 +45,9 @@ class CosSimOpMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("X", "The first input of cos_sim op.");
     AddInput("Y", "The second input of cos_sim op.");
     AddOutput("Out", "The output of cos_sim op.");
+    AddOutput("XNorm", "Row norm of the first input.").AsIntermediate();
+    AddOutput("YNorm", "Row norm of the second input.").AsIntermediate();
+
     AddComment(R"DOC(
Cosine Similarity Operator.
@@ -57,20 +62,31 @@ class CosSimOpGrad : public framework::OperatorWithKernel {
 
  protected:
   void InferShape(const framework::InferShapeContext &ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) should not be null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) must not be null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("XNorm"),
+                            "Input(XNorm) must not be null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("YNorm"),
+                            "Input(YNorm) must not be null.");
     PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
-                            "Input(Out@GRAD) should not be null.");
+                            "Input(Out@GRAD) must not be null.");
     auto x_dims = ctx.Input<Tensor>("X")->dims();
     auto y_dims = ctx.Input<Tensor>("Y")->dims();
+    auto xnorm_dims = ctx.Input<Tensor>("XNorm")->dims();
+    auto ynorm_dims = ctx.Input<Tensor>("YNorm")->dims();
     auto out_dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
     PADDLE_ENFORCE_EQ(x_dims, y_dims,
                       "Dimensions of Input(X) and Input(Y) must be the same.");
+    PADDLE_ENFORCE_EQ(xnorm_dims[0], x_dims[0],
+                      "1st dimension of XNorm must equal that of Input(X).");
+    PADDLE_ENFORCE_EQ(xnorm_dims[1], 1, "2nd dimension of XNorm must be one.");
+    PADDLE_ENFORCE_EQ(ynorm_dims[0], y_dims[0],
+                      "1st dimension of YNorm must equal that of Input(Y).");
+    PADDLE_ENFORCE_EQ(ynorm_dims[1], 1, "2nd dimension of YNorm must be one.");
     PADDLE_ENFORCE_EQ(out_dims[0], x_dims[0],
-                      "1st dimension of Out@GRAD must equal to Input(X)");
-    PADDLE_ENFORCE_EQ(out_dims[1], 1,
-                      "1st dimension of Out@GRAD must equal to Input(X)");
+                      "1st dimension of Out@GRAD must equal that of Input(X).");
+    PADDLE_ENFORCE_EQ(out_dims[1], 1, "2nd dimension of Out@GRAD must be one.");
 
     auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto *y_grad = ctx.Output<Tensor>(framework::GradVarName("Y"));
diff --git a/paddle/operators/cos_sim_op.h b/paddle/operators/cos_sim_op.h
index 5247087cc1..69d35d8bc2 100644
--- a/paddle/operators/cos_sim_op.h
+++ b/paddle/operators/cos_sim_op.h
@@ -31,21 +31,27 @@ class CosSimKernel : public framework::OpKernel {
     auto* x = context.Input<Tensor>("X");
     auto* y = context.Input<Tensor>("Y");
     auto* z = context.Output<Tensor>("Out");
+    auto* x_norm = context.Output<Tensor>("XNorm");
+    auto* y_norm = context.Output<Tensor>("YNorm");
     z->mutable_data<T>(context.GetPlace());
+    x_norm->mutable_data<T>(context.GetPlace());
+    y_norm->mutable_data<T>(context.GetPlace());
 
     auto dims = x->dims();
     int size = static_cast<int>(framework::product(dims));
     auto new_dims = framework::make_ddim({dims[0], size / dims[0]});
     auto X = EigenMatrix<T>::From(*x, new_dims);
     auto Y = EigenMatrix<T>::From(*y, new_dims);
-    auto Z = EigenMatrix<T>::From(*z, new_dims);
+    auto Z = EigenMatrix<T>::From(*z);
+    auto XNorm = EigenMatrix<T>::From(*x_norm);
+    auto YNorm = EigenMatrix<T>::From(*y_norm);
 
-    auto XY = (X * Y).sum(Eigen::array<int, 1>({1}));
-    auto XX = (X * X).sum(Eigen::array<int, 1>({1}));
-    auto YY = (Y * Y).sum(Eigen::array<int, 1>({1}));
     auto place = context.GetEigenDevice<Place>();
-    Z.device(place) = XY / XX.sqrt() / YY.sqrt();
+    auto XY = (X * Y).sum(Eigen::array<int, 1>({1}));
+    XNorm.device(place) = (X * X).sum(Eigen::array<int, 1>({1})).sqrt();
+    YNorm.device(place) = (Y * Y).sum(Eigen::array<int, 1>({1})).sqrt();
+    Z.device(place) = XY / XNorm / YNorm;
   }
 };
 
@@ -56,6 +62,8 @@ class CosSimGradKernel : public framework::OpKernel {
     auto* x = context.Input<Tensor>("X");
     auto* y = context.Input<Tensor>("Y");
     auto* z = context.Input<Tensor>("Out");
+    auto* x_norm = context.Input<Tensor>("XNorm");
+    auto* y_norm = context.Input<Tensor>("YNorm");
     auto* grad_x = context.Output<Tensor>(framework::GradVarName("X"));
     auto* grad_y = context.Output<Tensor>(framework::GradVarName("Y"));
     auto* grad_z = context.Input<Tensor>(framework::GradVarName("Out"));
@@ -69,23 +77,23 @@ class CosSimGradKernel : public framework::OpKernel {
     auto X = EigenMatrix<T>::From(*x, new_dims);
     auto Y = EigenMatrix<T>::From(*y, new_dims);
     auto Z = EigenMatrix<T>::From(*z);
+    auto X_norm = EigenMatrix<T>::From(*x_norm);
+    auto Y_norm = EigenMatrix<T>::From(*y_norm);
     auto dX = EigenMatrix<T>::From(*grad_x, new_dims);
     auto dY = EigenMatrix<T>::From(*grad_y, new_dims);
     auto dZ = EigenMatrix<T>::From(*grad_z);
 
-    auto XX = (X * X).sum(Eigen::array<int, 1>({1}));
-    auto YY = (Y * Y).sum(Eigen::array<int, 1>({1}));
-    Eigen::DSizes<int, 2> bcast(1, dims[1]);
-    auto denominator_bcast = (XX.sqrt() * YY.sqrt()).broadcast(bcast);
+    Eigen::DSizes<int, 2> bcast(1, new_dims[1]);
     auto Z_bcast = Z.broadcast(bcast);
     auto dZ_bcast = dZ.broadcast(bcast);
     auto place = context.GetEigenDevice<Place>();
+    auto X_snorm_bcast = X_norm.square().eval().broadcast(bcast);
+    auto Y_snorm_bcast = Y_norm.square().eval().broadcast(bcast);
+    auto norm_prod_bcast = (X_norm * Y_norm).eval().broadcast(bcast);
     dX.device(place) =
-        dZ_bcast * (Y / denominator_bcast - Z_bcast * X / XX.broadcast(bcast));
+        dZ_bcast * (Y / norm_prod_bcast - Z_bcast * X / X_snorm_bcast);
     dY.device(place) =
-        dZ_bcast * (X / denominator_bcast - Z_bcast * Y / YY.broadcast(bcast));
-    // dX.device(place) = X;
-    // Y.device(place) = Y;
+        dZ_bcast * (X / norm_prod_bcast - Z_bcast * Y / Y_snorm_bcast);
   }
 };
diff --git a/python/paddle/v2/framework/tests/gradient_checker.py b/python/paddle/v2/framework/tests/gradient_checker.py
index bf01ea4876..409b3caf33 100644
--- a/python/paddle/v2/framework/tests/gradient_checker.py
+++ b/python/paddle/v2/framework/tests/gradient_checker.py
@@ -304,13 +304,6 @@ class GradientChecker(unittest.TestCase):
         # get analytical gradients according to different device
         analytic_grads = self.__get_gradient(forward_op, backward_op,
                                              input_vars, check_names, place)
-        #print(numeric_grads[0], numeric_grads[0].shape)
-        print("dim0: ", numeric_grads[0], numeric_grads[0].shape)
-        print("dim0: ", analytic_grads[0], analytic_grads[0].shape)
-        print("---------------------")
-        print("dim1: ", numeric_grads[1], numeric_grads[1].shape)
-        print("dim1: ", analytic_grads[1], analytic_grads[1].shape)
-        assert False
         self.__assert_is_close(numeric_grads, analytic_grads, check_names,
                                max_relative_error,
                                "Gradient Check On %s" % str(place))
diff --git a/python/paddle/v2/framework/tests/test_cos_sim_op.py b/python/paddle/v2/framework/tests/test_cos_sim_op.py
index f3b04d25f2..a19be47f76 100644
--- a/python/paddle/v2/framework/tests/test_cos_sim_op.py
+++ b/python/paddle/v2/framework/tests/test_cos_sim_op.py
@@ -10,30 +10,40 @@ class TestCosSimOp(unittest.TestCase):
     def setUp(self):
         self.type = "cos_sim"
         self.inputs = {
-            'X': np.random.random((32, 84)).astype("float32"),
-            'Y': np.random.random((32, 84)).astype("float32")
+            'X': np.random.random((32, 64)).astype("float32"),
+            'Y': np.random.random((32, 64)).astype("float32")
+        }
+        expect_x_norm = np.linalg.norm(self.inputs['X'], axis=1)
+        expect_y_norm = np.linalg.norm(self.inputs['Y'], axis=1)
+        expect_out = (self.inputs['X'] * self.inputs['Y']).sum(axis=1) / \
+            expect_x_norm / expect_y_norm
+        self.outputs = {
+            'XNorm': np.expand_dims(expect_x_norm, 1),
+            'YNorm': np.expand_dims(expect_y_norm, 1),
+            'Out': np.expand_dims(expect_out, 1)
         }
-        expect = (self.inputs['X'] * self.inputs['Y']).sum(axis=1) / \
-            np.linalg.norm(self.inputs['X'], axis=1) / \
-            np.linalg.norm(self.inputs['Y'], axis=1)
-        expect = np.expand_dims(expect, 1)
-        self.outputs = {'Out': expect}
 
 
 class CosSimGradOpTest(GradientChecker):
-    def test_cos_sim(self):
+    def test_cos_sim_2d(self):
+        op = create_op("cos_sim")
+        inputs = {
+            'X': np.random.random((10, 5)).astype("float32"),
+            'Y': np.random.random((10, 5)).astype("float32")
+        }
+        self.compare_grad(op, inputs)
+        self.check_grad(
+            op, inputs, set(["X", "Y"]), "Out", max_relative_error=0.05)
+
+    def test_cos_sim_3d(self):
         op = create_op("cos_sim")
-        #inputs = {
-        #'X': np.random.random((2, 2)).astype("float32"),
-        #'Y': np.random.random((2, 2)).astype("float32")
-        #}
         inputs = {
-            'X': np.array([[0.9, 0.6], [1.9, 1.6]]).astype("float32"),
-            'Y': np.array([[0.7, 0.8], [1.7, 1.8]]).astype("float32")
+            'X': np.random.random((10, 5, 2)).astype("float32"),
+            'Y': np.random.random((10, 5, 2)).astype("float32")
         }
-        print(inputs)
+        self.compare_grad(op, inputs)
         self.check_grad(
-            op, inputs, set(["X", "Y"]), "Out", max_relative_error=0.5)
+            op, inputs, set(["X", "Y"]), "Out", max_relative_error=0.05)
 
 
 if __name__ == '__main__':
-- 
GitLab
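
The formulas the rewritten kernels implement can be checked independently of Paddle. Per row, the forward pass computes Z = sum(X * Y) / (|X| * |Y|), and the backward pass uses dX = dZ * (Y / (|X| * |Y|) - Z * X / |X|^2), with dY symmetric; the forward pass now caches |X| and |Y| as the intermediate outputs XNorm and YNorm so the backward pass reuses them. Below is a minimal standalone NumPy sketch of these formulas (function names are illustrative, not part of the patch or the Paddle API), with a finite-difference check in the spirit of check_grad:

import numpy as np


def cos_sim_forward(X, Y):
    # Row-wise cosine similarity. The row norms are returned as well,
    # mirroring the XNorm/YNorm intermediates the patched kernel caches.
    x_norm = np.linalg.norm(X, axis=1, keepdims=True)
    y_norm = np.linalg.norm(Y, axis=1, keepdims=True)
    z = (X * Y).sum(axis=1, keepdims=True) / x_norm / y_norm
    return z, x_norm, y_norm


def cos_sim_backward(X, Y, z, x_norm, y_norm, dz):
    # dX = dZ * (Y / (|X||Y|) - Z * X / |X|^2), broadcast across each row;
    # this mirrors the norm_prod_bcast / X_snorm_bcast expressions in
    # cos_sim_op.h. dY is symmetric in X and Y.
    norm_prod = x_norm * y_norm
    dX = dz * (Y / norm_prod - z * X / np.square(x_norm))
    dY = dz * (X / norm_prod - z * Y / np.square(y_norm))
    return dX, dY


# Finite-difference check of dX, analogous to what check_grad does.
rng = np.random.RandomState(0)
X, Y = rng.rand(4, 3), rng.rand(4, 3)
z, x_norm, y_norm = cos_sim_forward(X, Y)
dX, _ = cos_sim_backward(X, Y, z, x_norm, y_norm, np.ones_like(z))

eps = 1e-5
numeric = np.zeros_like(X)
for i in range(X.shape[0]):
    for j in range(X.shape[1]):
        X_pert = X.copy()
        X_pert[i, j] += eps  # perturbing X[i, j] only affects row i of Z
        numeric[i, j] = (cos_sim_forward(X_pert, Y)[0][i, 0] - z[i, 0]) / eps
assert np.allclose(dX, numeric, rtol=1e-3, atol=1e-5)

Note also that the rewritten backward kernel broadcasts with new_dims[1] (the flattened column count) rather than dims[1], which is what lets the added test_cos_sim_3d case pass: inputs with more than two dimensions are flattened to {dims[0], size / dims[0]} consistently in both the forward and backward passes.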