diff --git a/paddle/operators/squared_l2_distance_op.cc b/paddle/operators/squared_l2_distance_op.cc
index b19c274dcc0ef0164eb525160b79276904e33f00..694b00e493149f332e94d2353d3c13501a59ebd0 100644
--- a/paddle/operators/squared_l2_distance_op.cc
+++ b/paddle/operators/squared_l2_distance_op.cc
@@ -49,7 +49,9 @@ class SquaredL2DistanceOp : public framework::OperatorWithKernel {
                       "First dimension of target must be equal to input "
                       "or to 1.");
-    ctx.Output<Tensor>("sub_result")->Resize(x_dims);
+    ctx.Output<Tensor>("sub_result")
+        ->Resize({static_cast<int>(x_dims[0]),
+                  static_cast<int>(framework::product(x_dims) / x_dims[0])});
     ctx.Output<Tensor>("Out")->Resize({x_dims[0], 1});
   }
 };
 
@@ -97,8 +99,8 @@ class SquaredL2DistanceGradOp : public framework::OperatorWithKernel {
                       "must be 1.");
     auto* x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto* y_grad = ctx.Output<Tensor>(framework::GradVarName("Y"));
-    if (x_grad != nullptr) x_grad->Resize(x_dims);
-    if (y_grad != nullptr) y_grad->Resize(y_dims);
+    if (x_grad) x_grad->Resize(x_dims);
+    if (y_grad) y_grad->Resize(y_dims);
   }
 };
 
diff --git a/paddle/operators/squared_l2_distance_op.h b/paddle/operators/squared_l2_distance_op.h
index ec8c34ddf8d973f1dba76383135ac24879c4a6e8..97907768f7106bfd30ec5ad9fbf5dc7b635405b8 100644
--- a/paddle/operators/squared_l2_distance_op.h
+++ b/paddle/operators/squared_l2_distance_op.h
@@ -53,14 +53,16 @@ class SquaredL2DistanceKernel : public framework::OpKernel {
     auto y_dims = y.dimensions();
     // buffer the substraction result
     if (y_dims[0] == 1 && x_dims[0] > y_dims[0]) {
-      auto y_broadcast_dims = y_dims;
-      y_broadcast_dims[0] = x_dims[0];
-      sub_result.device(place) = x - y.broadcast(y_broadcast_dims);
+      sub_result.device(place) =
+          x -
+          y.broadcast(Eigen::array<int, 2>({static_cast<int>(x_dims[0]), 1}));
     } else {
       sub_result.device(place) = x - y;
     }
-
-    z.device(place) = sub_result.pow(2).sum(Eigen::array<int, 1>({1}));
+    auto sub_res_pow2 = sub_result * sub_result;
+    z.device(place) =
+        sub_res_pow2.sum(Eigen::array<int, 1>({1}))
+            .reshape(Eigen::array<int, 2>({static_cast<int>(x_dims[0]), 1}));
   }
 };
 
@@ -86,7 +88,7 @@ class SquaredL2DistanceGradKernel : public framework::OpKernel {
 
     // propagate back to input
     auto eigen_place = context.GetEigenDevice<Place>();
-    if (x_g != nullptr) {
+    if (x_g) {
       x_g->mutable_data<T>(context.GetPlace());
       // eigen matrix
       auto x_grad =
@@ -95,7 +97,7 @@ class SquaredL2DistanceGradKernel : public framework::OpKernel {
       x_grad.device(eigen_place) = grad_mat;
     }
 
-    if (y_g != nullptr) {
+    if (y_g) {
       y_g->mutable_data<T>(context.GetPlace());
       auto y_grad =
           EigenMatrix<T>::From(*y_g, framework::make_ddim({y_dims[0], cols}));
@@ -107,8 +109,9 @@ class SquaredL2DistanceGradKernel : public framework::OpKernel {
       if (sub_result.dimensions()[0] == y_dims[0]) {
         y_grad.device(eigen_place) = -1 * grad_mat;
       } else {
+        auto col_sum_res = -1 * (grad_mat.sum(Eigen::array<int, 1>({0})));
         y_grad.device(eigen_place) =
-            -1 * (grad_mat.sum(Eigen::array<int, 1>({0})));
+            col_sum_res.reshape(Eigen::array<int, 2>({1, cols}));
       }
     }
   }
diff --git a/python/paddle/v2/framework/tests/test_squared_l2_distance_op.py b/python/paddle/v2/framework/tests/test_squared_l2_distance_op.py
index 51c95b286a8cb8e125bd2f2baa7d0f87465c5c51..2bcdf37df434c9a089d75438d876114156261a5c 100644
--- a/python/paddle/v2/framework/tests/test_squared_l2_distance_op.py
+++ b/python/paddle/v2/framework/tests/test_squared_l2_distance_op.py
@@ -4,30 +4,84 @@ from gradient_checker import GradientChecker, create_op
 import numpy as np
 
 
-class TestSquaredL2DistanceOp(unittest.TestCase):
+class TestSquaredL2DistanceOp_f0(unittest.TestCase):
     __metaclass__ = OpTestMeta
 
     def setUp(self):
         self.type = 'squared_l2_distance'
         self.inputs = {
-            'X': np.random.uniform(0.1, 1., (2, 3)).astype('float32'),
-            'Y': np.random.uniform(0.1, 1., (2, 3)).astype('float32')
+            'X': np.random.uniform(0.1, 1., (32, 64)).astype('float32'),
+            'Y': np.random.uniform(0.1, 1., (32, 64)).astype('float32')
         }
-        subRes = self.inputs['X'] - self.inputs['Y']
-        output = subRes * subRes
+        sub_res = self.inputs['X'] - self.inputs['Y']
+        output = sub_res * sub_res
         self.outputs = {
-            'sub_result': subRes,
+            'sub_result': sub_res,
+            'Out': np.expand_dims(output.sum(1), 1)
+        }
+
+
+class TestSquaredL2DistanceOp_f1(unittest.TestCase):
+    __metaclass__ = OpTestMeta
+
+    def setUp(self):
+        self.type = 'squared_l2_distance'
+        self.inputs = {
+            'X': np.random.uniform(0.1, 1., (32, 64)).astype('float32'),
+            'Y': np.random.uniform(0.1, 1., (1, 64)).astype('float32')
+        }
+        sub_res = self.inputs['X'] - self.inputs['Y']
+        output = sub_res * sub_res
+        self.outputs = {
+            'sub_result': sub_res,
+            'Out': np.expand_dims(output.sum(1), 1)
+        }
+
+
+class TestSquaredL2DistanceOp_f2(unittest.TestCase):
+    __metaclass__ = OpTestMeta
+
+    def setUp(self):
+        self.type = 'squared_l2_distance'
+        self.inputs = {
+            'X': np.random.uniform(0.1, 1., (32, 64, 128)).astype('float32'),
+            'Y': np.random.uniform(0.1, 1., (1, 64, 128)).astype('float32')
+        }
+        sub_res = self.inputs['X'] - self.inputs['Y']
+        sub_res = sub_res.reshape((32, 64 * 128))
+        output = sub_res * sub_res
+        self.outputs = {
+            'sub_result': sub_res,
             'Out': np.expand_dims(output.sum(1), 1)
         }
 
 
 class TestSquaredL2DistanceGradOp(GradientChecker):
-    def test_squared_l2_distance(self):
+    def test_squared_l2_distance_b0(self):
+        op = create_op("squared_l2_distance")
+        inputs = {
+            'X': np.random.uniform(0.1, .6, (2, 3)).astype('float32'),
+            'Y': np.random.uniform(0.1, .6, (2, 3)).astype('float32')
+        }
+        self.compare_grad(op, inputs)
+        self.check_grad(op, inputs, set(["X", "Y"]), "Out")
+
+    def test_squared_l2_distance_b1(self):
+        op = create_op("squared_l2_distance")
+        inputs = {
+            'X': np.random.uniform(0.1, .6, (2, 3)).astype('float32'),
+            'Y': np.random.uniform(0.1, .6, (1, 3)).astype('float32')
+        }
+        self.compare_grad(op, inputs)
+        self.check_grad(op, inputs, set(["X", "Y"]), "Out")
+
+    def test_squared_l2_distance_b2(self):
         op = create_op("squared_l2_distance")
         inputs = {
-            'X': np.random.uniform(0.1, 1., (2, 3)).astype('float32'),
-            'Y': np.random.uniform(0.1, 1., (2, 3)).astype('float32')
+            'X': np.random.uniform(0.1, .6, (2, 3, 4)).astype('float32'),
+            'Y': np.random.uniform(0.1, .6, (1, 3, 4)).astype('float32')
         }
+        self.compare_grad(op, inputs)
         self.check_grad(op, inputs, set(["X", "Y"]), "Out")
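
Note on the math, with a minimal NumPy sketch (illustrative only, not PaddlePaddle API; the forward/backward helper names below are hypothetical). The forward kernel computes sub_result = X - Y, broadcasting Y over rows when its first dimension is 1, and Out[i] = sum_j sub_result[i][j]^2; the grad kernel uses grad_mat = 2 * out_grad * sub_result, so dX = grad_mat and dY = -grad_mat, column-summed back to a single row in the broadcast case, which is the reduction the new col_sum_res.reshape(...) line performs:

import numpy as np

def forward(x, y):
    # y broadcasts over rows when y.shape[0] == 1, mirroring the
    # Eigen broadcast branch in SquaredL2DistanceKernel.
    sub_result = x - y
    out = (sub_result * sub_result).sum(axis=1, keepdims=True)
    return sub_result, out

def backward(sub_result, out_grad, y_rows):
    # d(Out)/d(sub_result) = 2 * out_grad, broadcast across columns.
    grad_mat = 2.0 * out_grad * sub_result
    x_grad = grad_mat
    if y_rows == sub_result.shape[0]:
        y_grad = -grad_mat
    else:
        # Y was broadcast: reduce rows back to shape (1, cols), the
        # same column sum computed by col_sum_res in the grad kernel.
        y_grad = -grad_mat.sum(axis=0, keepdims=True)
    return x_grad, y_grad

# Shapes from TestSquaredL2DistanceOp_f1 (broadcast case).
x = np.random.uniform(0.1, 1., (32, 64)).astype('float32')
y = np.random.uniform(0.1, 1., (1, 64)).astype('float32')
sub_result, out = forward(x, y)
x_grad, y_grad = backward(sub_result, np.ones_like(out), y.shape[0])
assert out.shape == (32, 1)
assert x_grad.shape == (32, 64) and y_grad.shape == (1, 64)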