Commit c33ddc74 authored by yangyaming

Fix some bugs and add more unit tests: use a rank-1 axis array for the Y-gradient reduction, flatten sub_result to 2-D for higher-rank inputs, reshape Out to an (n, 1) column, and add broadcast and rank-3 test cases.

Parent e9cc3282
...
@@ -49,7 +49,9 @@ class SquaredL2DistanceOp : public framework::OperatorWithKernel {
                    "First dimension of target must be equal to input "
                    "or to 1.");
-    ctx.Output<Tensor>("sub_result")->Resize(x_dims);
+    ctx.Output<Tensor>("sub_result")
+        ->Resize({static_cast<int>(x_dims[0]),
+                  static_cast<int>(framework::product(x_dims) / x_dims[0])});
     ctx.Output<Tensor>("Out")->Resize({x_dims[0], 1});
   }
 };
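
Note (illustrative, not part of the diff): the new sub_result shape keeps the batch dimension and flattens everything else, which is what framework::product(x_dims) / x_dims[0] computes. A minimal NumPy sketch of the same shape logic, assuming a rank-3 input:

import numpy as np

x = np.random.rand(32, 64, 128).astype('float32')
batch = x.shape[0]
# keep the batch dim and flatten everything else, mirroring
# product(x_dims) / x_dims[0] in the Resize above
sub_result_shape = (batch, x.size // batch)
assert sub_result_shape == (32, 64 * 128)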
...
@@ -97,8 +99,8 @@ class SquaredL2DistanceGradOp : public framework::OperatorWithKernel {
                    "must be 1.");
     auto* x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto* y_grad = ctx.Output<Tensor>(framework::GradVarName("Y"));
-    if (x_grad != nullptr) x_grad->Resize(x_dims);
-    if (y_grad != nullptr) y_grad->Resize(y_dims);
+    if (x_grad) x_grad->Resize(x_dims);
+    if (y_grad) y_grad->Resize(y_dims);
   }
 };
...
@@ -53,14 +53,16 @@ class SquaredL2DistanceKernel : public framework::OpKernel {
     auto y_dims = y.dimensions();
     // buffer the subtraction result
     if (y_dims[0] == 1 && x_dims[0] > y_dims[0]) {
-      auto y_broadcast_dims = y_dims;
-      y_broadcast_dims[0] = x_dims[0];
-      sub_result.device(place) = x - y.broadcast(y_broadcast_dims);
+      sub_result.device(place) =
+          x -
+          y.broadcast(Eigen::array<int, 2>({static_cast<int>(x_dims[0]), 1}));
     } else {
       sub_result.device(place) = x - y;
     }
-    z.device(place) = sub_result.pow(2).sum(Eigen::array<int, 1>({1}));
+    auto sub_res_pow2 = sub_result * sub_result;
+    z.device(place) =
+        sub_res_pow2.sum(Eigen::array<int, 1>({1}))
+            .reshape(Eigen::array<int, 2>({static_cast<int>(x_dims[0]), 1}));
   }
 };
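
Note (illustrative, not part of the diff): the forward pass now broadcasts Y over the batch dimension when its first dim is 1, squares the difference, row-sums, and reshapes to a column. A NumPy sketch of the same computation (the helper name squared_l2_distance_ref is ours):

import numpy as np

def squared_l2_distance_ref(x, y):
    # y broadcasts over the batch dim when its first dim is 1,
    # like the Eigen broadcast branch above
    sub = x - y
    # row-wise sum of squares, reshaped to an (n, 1) column like 'Out'
    return sub, (sub * sub).sum(axis=1).reshape(-1, 1)

x = np.random.rand(4, 3).astype('float32')
y = np.random.rand(1, 3).astype('float32')
sub_result, out = squared_l2_distance_ref(x, y)
assert sub_result.shape == (4, 3) and out.shape == (4, 1)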
...
@@ -86,7 +88,7 @@ class SquaredL2DistanceGradKernel : public framework::OpKernel {
     // propagate back to input
     auto eigen_place = context.GetEigenDevice<Place>();
-    if (x_g != nullptr) {
+    if (x_g) {
       x_g->mutable_data<T>(context.GetPlace());
       // eigen matrix
       auto x_grad =
...
@@ -95,7 +97,7 @@ class SquaredL2DistanceGradKernel : public framework::OpKernel {
       x_grad.device(eigen_place) = grad_mat;
     }
-    if (y_g != nullptr) {
+    if (y_g) {
       y_g->mutable_data<T>(context.GetPlace());
       auto y_grad =
           EigenMatrix<T>::From(*y_g, framework::make_ddim({y_dims[0], cols}));
...
@@ -107,8 +109,9 @@ class SquaredL2DistanceGradKernel : public framework::OpKernel {
       if (sub_result.dimensions()[0] == y_dims[0]) {
         y_grad.device(eigen_place) = -1 * grad_mat;
       } else {
-        y_grad.device(eigen_place) =
-            -1 * (grad_mat.sum(Eigen::array<int, 2>({0})));
+        auto col_sum_res = -1 * (grad_mat.sum(Eigen::array<int, 1>({0})));
+        y_grad.device(eigen_place) =
+            col_sum_res.reshape(Eigen::array<int, 2>({1, cols}));
       }
     }
   }
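
Note (illustrative, not part of the diff): when Y was broadcast in the forward pass, its gradient is the column-wise sum of -grad_mat kept as a single row; the old code passed a rank-2 axis array to what must be a rank-1 reduction. The NumPy equivalent of the fixed branch:

import numpy as np

grad_mat = np.random.rand(4, 3).astype('float32')  # one row per sample
# reduce over the batch axis (axis 0), then keep a (1, cols) row,
# matching the rank-1 sum plus reshape above
y_grad = (-1 * grad_mat).sum(axis=0).reshape(1, -1)
assert y_grad.shape == (1, 3)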
...
@@ -4,30 +4,84 @@ from gradient_checker import GradientChecker, create_op
 import numpy as np
 
 
-class TestSquaredL2DistanceOp(unittest.TestCase):
+class TestSquaredL2DistanceOp_f0(unittest.TestCase):
     __metaclass__ = OpTestMeta
 
     def setUp(self):
         self.type = 'squared_l2_distance'
         self.inputs = {
-            'X': np.random.uniform(0.1, 1., (2, 3)).astype('float32'),
-            'Y': np.random.uniform(0.1, 1., (2, 3)).astype('float32')
+            'X': np.random.uniform(0.1, 1., (32, 64)).astype('float32'),
+            'Y': np.random.uniform(0.1, 1., (32, 64)).astype('float32')
         }
-        subRes = self.inputs['X'] - self.inputs['Y']
-        output = subRes * subRes
+        sub_res = self.inputs['X'] - self.inputs['Y']
+        output = sub_res * sub_res
         self.outputs = {
-            'sub_result': subRes,
+            'sub_result': sub_res,
             'Out': np.expand_dims(output.sum(1), 1)
         }
 
 
+class TestSquaredL2DistanceOp_f1(unittest.TestCase):
+    __metaclass__ = OpTestMeta
+
+    def setUp(self):
+        self.type = 'squared_l2_distance'
+        self.inputs = {
+            'X': np.random.uniform(0.1, 1., (32, 64)).astype('float32'),
+            'Y': np.random.uniform(0.1, 1., (1, 64)).astype('float32')
+        }
+        sub_res = self.inputs['X'] - self.inputs['Y']
+        output = sub_res * sub_res
+        self.outputs = {
+            'sub_result': sub_res,
+            'Out': np.expand_dims(output.sum(1), 1)
+        }
+
+
+class TestSquaredL2DistanceOp_f2(unittest.TestCase):
+    __metaclass__ = OpTestMeta
+
+    def setUp(self):
+        self.type = 'squared_l2_distance'
+        self.inputs = {
+            'X': np.random.uniform(0.1, 1., (32, 64, 128)).astype('float32'),
+            'Y': np.random.uniform(0.1, 1., (1, 64, 128)).astype('float32')
+        }
+        sub_res = self.inputs['X'] - self.inputs['Y']
+        sub_res = sub_res.reshape((32, 64 * 128))
+        output = sub_res * sub_res
+        self.outputs = {
+            'sub_result': sub_res,
+            'Out': np.expand_dims(output.sum(1), 1)
+        }
+
+
 class TestSquaredL2DistanceGradOp(GradientChecker):
-    def test_squared_l2_distance(self):
+    def test_squared_l2_distance_b0(self):
+        op = create_op("squared_l2_distance")
+        inputs = {
+            'X': np.random.uniform(0.1, .6, (2, 3)).astype('float32'),
+            'Y': np.random.uniform(0.1, .6, (2, 3)).astype('float32')
+        }
+        self.compare_grad(op, inputs)
+        self.check_grad(op, inputs, set(["X", "Y"]), "Out")
+
+    def test_squared_l2_distance_b1(self):
+        op = create_op("squared_l2_distance")
+        inputs = {
+            'X': np.random.uniform(0.1, .6, (2, 3)).astype('float32'),
+            'Y': np.random.uniform(0.1, .6, (1, 3)).astype('float32')
+        }
+        self.compare_grad(op, inputs)
+        self.check_grad(op, inputs, set(["X", "Y"]), "Out")
+
+    def test_squared_l2_distance_b2(self):
         op = create_op("squared_l2_distance")
         inputs = {
-            'X': np.random.uniform(0.1, 1., (2, 3)).astype('float32'),
-            'Y': np.random.uniform(0.1, 1., (2, 3)).astype('float32')
+            'X': np.random.uniform(0.1, .6, (2, 3, 4)).astype('float32'),
+            'Y': np.random.uniform(0.1, .6, (1, 3, 4)).astype('float32')
         }
+        self.compare_grad(op, inputs)
         self.check_grad(op, inputs, set(["X", "Y"]), "Out")
...
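
Note (illustrative, not part of the diff): the new b1/b2 tests exercise the broadcast case that the gradient fix targets. An independent NumPy sanity check that the analytic gradient w.r.t. a broadcast Y matches finite differences (the forward helper below is ours):

import numpy as np

def forward(x, y):
    sub = x - y  # y broadcasts when its first dim is 1
    return (sub * sub).sum(axis=1).reshape(-1, 1)

x = np.random.rand(4, 3)
y = np.random.rand(1, 3)

# analytic gradient of sum(Out) w.r.t. the broadcast y:
# column-wise sum of -2 * (x - y), kept as a (1, 3) row
analytic = (-2 * (x - y)).sum(axis=0, keepdims=True)

# central finite differences, one column of y at a time
eps = 1e-6
numeric = np.zeros_like(y)
for j in range(y.shape[1]):
    y_pos, y_neg = y.copy(), y.copy()
    y_pos[0, j] += eps
    y_neg[0, j] -= eps
    numeric[0, j] = (forward(x, y_pos).sum() - forward(x, y_neg).sum()) / (2 * eps)

assert np.allclose(analytic, numeric, atol=1e-4)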