提交 16fddf32 编写于 作者: X Xinghai Sun

Add broadcasting support (e.g. matrix-vector) for cos sim operator.

上级 b59f3018
...@@ -25,16 +25,29 @@ class CosSimOp : public framework::OperatorWithKernel { ...@@ -25,16 +25,29 @@ class CosSimOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
// notnull check
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null."); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) must not be null."); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) must not be null.");
PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("X")->dims(),
ctx.Input<Tensor>("Y")->dims(), // shape check
"Dimensions of Input(X) and Input(Y) must be the same."); auto x_dims = ctx.Input<Tensor>("X")->dims();
auto y_dims = ctx.Input<Tensor>("Y")->dims();
auto dims = ctx.Input<Tensor>("X")->dims(); PADDLE_ENFORCE_EQ(framework::arity(x_dims), framework::arity(y_dims),
ctx.Output<Tensor>("Out")->Resize({dims[0], 1}); "Ranks of Input(X) and Input(Y) must be equal.");
ctx.Output<Tensor>("XNorm")->Resize({dims[0], 1}); PADDLE_ENFORCE_GE(framework::arity(x_dims), 2,
ctx.Output<Tensor>("YNorm")->Resize({dims[0], 1}); "Rank of Input(X) must not be less than 2.");
PADDLE_ENFORCE_EQ(
framework::slice_ddim(x_dims, 1, framework::arity(x_dims)),
framework::slice_ddim(y_dims, 1, framework::arity(y_dims)),
"All dimensions except 1st of Input(X) and Input(Y) must be equal.");
PADDLE_ENFORCE(x_dims[0] == y_dims[0] || y_dims[0] == 1,
"1st dimension of Input(Y) must be equal to Input(X) or "
"just 1 (which will be broadcasted to match Input(X)).");
// resize tensor
ctx.Output<Tensor>("Out")->Resize({x_dims[0], 1});
ctx.Output<Tensor>("XNorm")->Resize({x_dims[0], 1});
ctx.Output<Tensor>("YNorm")->Resize({y_dims[0], 1});
} }
}; };
...@@ -42,8 +55,8 @@ class CosSimOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -42,8 +55,8 @@ class CosSimOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
CosSimOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) CosSimOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The first input of cos_sim op."); AddInput("X", "The 1st input of cos_sim op.");
AddInput("Y", "The second input of cos_sim op."); AddInput("Y", "The 2nd input of cos_sim op.");
AddOutput("Out", "The output of cos_sim op."); AddOutput("Out", "The output of cos_sim op.");
AddOutput("XNorm", "Row norm of the first input.").AsIntermediate(); AddOutput("XNorm", "Row norm of the first input.").AsIntermediate();
AddOutput("YNorm", "Row norm of the second input.").AsIntermediate(); AddOutput("YNorm", "Row norm of the second input.").AsIntermediate();
...@@ -51,7 +64,12 @@ class CosSimOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -51,7 +64,12 @@ class CosSimOpMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC( AddComment(R"DOC(
Cosine Similarity Operator. Cosine Similarity Operator.
The equation is: Out = X^T * Y / (sqrt(X^T * X) * sqrt(Y^T * Y)) The equation is: Out = X^T * Y / (sqrt(X^T * X) * sqrt(Y^T * Y)).
Input(X) and Input(Y) must have the same shape, except that the 1st dimension
of Input(Y) could be just 1 (different from Input(X)), which will be
broadcasted to match the shape of Input(X) before computing their cosine
similarity.
)DOC"); )DOC");
} }
}; };
...@@ -62,32 +80,47 @@ class CosSimOpGrad : public framework::OperatorWithKernel { ...@@ -62,32 +80,47 @@ class CosSimOpGrad : public framework::OperatorWithKernel {
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
// notnull check
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null."); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) must not be null."); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) must not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("XNorm"), PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("XNorm"),
"Input(XNorm) must not be null."); "Input(XNorm) must not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("YNorm"), PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("YNorm"),
"Input(YNorm) must not be null."); "Input(YNorm) must not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Out"),
"Input(Out) must not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
"Input(Out@GRAD) must not be null."); "Input(Out@GRAD) must not be null.");
// shape check
auto x_dims = ctx.Input<Tensor>("X")->dims(); auto x_dims = ctx.Input<Tensor>("X")->dims();
auto y_dims = ctx.Input<Tensor>("Y")->dims(); auto y_dims = ctx.Input<Tensor>("Y")->dims();
PADDLE_ENFORCE_GE(framework::arity(x_dims), framework::arity(y_dims),
"Ranks of Input(X) and Input(Y) must be equal.");
PADDLE_ENFORCE_GE(framework::arity(x_dims), 2,
"Rank of Input(X) must not be less than 2.");
PADDLE_ENFORCE_EQ(
framework::slice_ddim(x_dims, 1, framework::arity(x_dims)),
framework::slice_ddim(y_dims, 1, framework::arity(y_dims)),
"All dimensions except 1st of Input(X) and Input(Y) must be equal.");
PADDLE_ENFORCE(x_dims[0] == y_dims[0] || y_dims[0] == 1,
"1st dimension of Input(Y) must be equal to Input(X) or "
"just 1 (which will be broadcasted to match Input(X)).");
auto xnorm_dims = ctx.Input<Tensor>("XNorm")->dims(); auto xnorm_dims = ctx.Input<Tensor>("XNorm")->dims();
PADDLE_ENFORCE_EQ(xnorm_dims, framework::make_ddim({x_dims[0], 1}),
"Shape of Input(XNorm) must be [X.Dim(0), 1].");
auto ynorm_dims = ctx.Input<Tensor>("YNorm")->dims(); auto ynorm_dims = ctx.Input<Tensor>("YNorm")->dims();
auto out_dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims(); PADDLE_ENFORCE_EQ(ynorm_dims, framework::make_ddim({y_dims[0], 1}),
PADDLE_ENFORCE_EQ(x_dims, y_dims, "Shape of Input(YNorm) must be [Y.Dim(0), 1].");
"Dimensions of Input(X) and Input(Y) must be the same."); auto out_dims = ctx.Input<Tensor>("Out")->dims();
PADDLE_ENFORCE_EQ(xnorm_dims[0], x_dims[0], PADDLE_ENFORCE_EQ(out_dims, framework::make_ddim({x_dims[0], 1}),
"1st dimension of XNorm must equal that of Input(X)."); "Shape of Input(Out) must be [X.Dim(0), 1].");
PADDLE_ENFORCE_EQ(xnorm_dims[1], 1, "2st dimension of XNorm must be one."); auto out_grad_dims =
PADDLE_ENFORCE_EQ(ynorm_dims[0], y_dims[0], ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
"1st dimension of YNorm must equal that of Input(Y)."); PADDLE_ENFORCE_EQ(out_grad_dims, framework::make_ddim({x_dims[0], 1}),
PADDLE_ENFORCE_EQ(ynorm_dims[1], 1, "2st dimension of YNorm must be one."); "Shape of Input(Out@Grad) must be [X.Dim(0), 1].");
PADDLE_ENFORCE_EQ(out_dims[0], x_dims[0],
"1st dimension of Out@GRAD must equal that of Input(X)"); // resize tensor
PADDLE_ENFORCE_EQ(out_dims[1], 1, "1st dimension of Out@GRAD must be one.");
auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X")); auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *y_grad = ctx.Output<Tensor>(framework::GradVarName("Y")); auto *y_grad = ctx.Output<Tensor>(framework::GradVarName("Y"));
if (x_grad) x_grad->Resize(x_dims); if (x_grad) x_grad->Resize(x_dims);
......
...@@ -28,30 +28,38 @@ template <typename Place, typename T> ...@@ -28,30 +28,38 @@ template <typename Place, typename T>
class CosSimKernel : public framework::OpKernel { class CosSimKernel : public framework::OpKernel {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* input_x = context.Input<Tensor>("X"); // get Tensor
auto* input_y = context.Input<Tensor>("Y"); auto* in_x = context.Input<Tensor>("X");
auto* output_z = context.Output<Tensor>("Out"); auto* in_y = context.Input<Tensor>("Y");
auto* output_x_norm = context.Output<Tensor>("XNorm"); auto* out_z = context.Output<Tensor>("Out");
auto* output_y_norm = context.Output<Tensor>("YNorm"); auto* out_x_norm = context.Output<Tensor>("XNorm");
auto* out_y_norm = context.Output<Tensor>("YNorm");
out_z->mutable_data<T>(context.GetPlace());
out_x_norm->mutable_data<T>(context.GetPlace());
out_y_norm->mutable_data<T>(context.GetPlace());
output_z->mutable_data<T>(context.GetPlace()); // convert Tensor to Eigen Tensor
output_x_norm->mutable_data<T>(context.GetPlace()); int rows_x = in_x->dims()[0];
output_y_norm->mutable_data<T>(context.GetPlace()); int rows_y = in_y->dims()[0];
int cols = framework::product(in_x->dims()) / rows_x;
auto dims = input_x->dims(); auto x = EigenMatrix<T>::From(*in_x, framework::make_ddim({rows_x, cols}));
int size = static_cast<int>(framework::product(dims)); auto y = EigenMatrix<T>::From(*in_y, framework::make_ddim({rows_y, cols}));
auto new_dims = framework::make_ddim({dims[0], size / dims[0]}); auto z = EigenMatrix<T>::From(*out_z);
auto x = EigenMatrix<T>::From(*input_x, new_dims); auto x_norm = EigenMatrix<T>::From(*out_x_norm);
auto y = EigenMatrix<T>::From(*input_y, new_dims); auto y_norm = EigenMatrix<T>::From(*out_y_norm);
auto z = EigenMatrix<T>::From(*output_z);
auto x_norm = EigenMatrix<T>::From(*output_x_norm);
auto y_norm = EigenMatrix<T>::From(*output_y_norm);
// compute
auto place = context.GetEigenDevice<Place>(); auto place = context.GetEigenDevice<Place>();
auto xy = (x * y).sum(Eigen::array<int, 1>({1}));
x_norm.device(place) = x.square().sum(Eigen::array<int, 1>({1})).sqrt(); x_norm.device(place) = x.square().sum(Eigen::array<int, 1>({1})).sqrt();
y_norm.device(place) = y.square().sum(Eigen::array<int, 1>({1})).sqrt(); y_norm.device(place) = y.square().sum(Eigen::array<int, 1>({1})).sqrt();
z.device(place) = xy / x_norm / y_norm; if (rows_x == rows_y) {
auto xy = (x * y).sum(Eigen::array<int, 1>({1}));
z.device(place) = xy / x_norm / y_norm;
} else {
Eigen::DSizes<int, 2> bcast(rows_x, 1);
auto xy = (x * y.broadcast(bcast)).sum(Eigen::array<int, 1>({1}));
z.device(place) = xy / x_norm / y_norm.broadcast(bcast);
}
} }
}; };
...@@ -59,43 +67,75 @@ template <typename Place, typename T> ...@@ -59,43 +67,75 @@ template <typename Place, typename T>
class CosSimGradKernel : public framework::OpKernel { class CosSimGradKernel : public framework::OpKernel {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* input_x = context.Input<Tensor>("X"); // get Tensor
auto* input_y = context.Input<Tensor>("Y"); auto* in_x = context.Input<Tensor>("X");
auto* input_z = context.Input<Tensor>("Out"); auto* in_y = context.Input<Tensor>("Y");
auto* input_x_norm = context.Input<Tensor>("XNorm"); auto* in_z = context.Input<Tensor>("Out");
auto* input_y_norm = context.Input<Tensor>("YNorm"); auto* in_x_norm = context.Input<Tensor>("XNorm");
auto* output_grad_x = context.Output<Tensor>(framework::GradVarName("X")); auto* in_y_norm = context.Input<Tensor>("YNorm");
auto* output_grad_y = context.Output<Tensor>(framework::GradVarName("Y")); auto* out_grad_x = context.Output<Tensor>(framework::GradVarName("X"));
auto* input_grad_z = context.Input<Tensor>(framework::GradVarName("Out")); auto* out_grad_y = context.Output<Tensor>(framework::GradVarName("Y"));
auto* in_grad_z = context.Input<Tensor>(framework::GradVarName("Out"));
auto dims = input_x->dims(); // convert Tensor to Eigen Tensor
int size = static_cast<int>(framework::product(dims)); int rows_x = in_x->dims()[0];
auto new_dims = framework::make_ddim({dims[0], size / dims[0]}); int rows_y = in_y->dims()[0];
auto x = EigenMatrix<T>::From(*input_x, new_dims); int cols = framework::product(in_x->dims()) / rows_x;
auto y = EigenMatrix<T>::From(*input_y, new_dims); auto x = EigenMatrix<T>::From(*in_x, framework::make_ddim({rows_x, cols}));
auto z = EigenMatrix<T>::From(*input_z); auto y = EigenMatrix<T>::From(*in_y, framework::make_ddim({rows_y, cols}));
auto x_norm = EigenMatrix<T>::From(*input_x_norm); auto z = EigenMatrix<T>::From(*in_z);
auto y_norm = EigenMatrix<T>::From(*input_y_norm); auto x_norm = EigenMatrix<T>::From(*in_x_norm);
auto dz = EigenMatrix<T>::From(*input_grad_z); auto y_norm = EigenMatrix<T>::From(*in_y_norm);
auto dz = EigenMatrix<T>::From(*in_grad_z);
Eigen::DSizes<int, 2> bcast(1, new_dims[1]); // compute gradident
Eigen::DSizes<int, 2> bcast(1, cols);
auto z_bcast = z.broadcast(bcast); auto z_bcast = z.broadcast(bcast);
auto dz_bcast = dz.broadcast(bcast); auto dz_bcast = dz.broadcast(bcast);
auto place = context.GetEigenDevice<Place>();
auto x_snorm_bcast = x_norm.square().eval().broadcast(bcast); auto x_snorm_bcast = x_norm.square().eval().broadcast(bcast);
auto y_snorm_bcast = y_norm.square().eval().broadcast(bcast); auto place = context.GetEigenDevice<Place>();
auto norm_prod_bcast = (x_norm * y_norm).eval().broadcast(bcast); if (rows_x == rows_y) {
if (output_grad_x) { auto y_snorm_bcast = y_norm.square().eval().broadcast(bcast);
output_grad_x->mutable_data<T>(context.GetPlace()); auto norm_prod_bcast = (x_norm * y_norm).eval().broadcast(bcast);
auto dx = EigenMatrix<T>::From(*output_grad_x, new_dims); // compute dx
dx.device(place) = if (out_grad_x) {
dz_bcast * (y / norm_prod_bcast - z_bcast * x / x_snorm_bcast); out_grad_x->mutable_data<T>(context.GetPlace());
} auto dx = EigenMatrix<T>::From(*out_grad_x,
if (output_grad_y) { framework::make_ddim({rows_x, cols}));
output_grad_y->mutable_data<T>(context.GetPlace()); auto grad = y / norm_prod_bcast - z_bcast * x / x_snorm_bcast;
auto dy = EigenMatrix<T>::From(*output_grad_y, new_dims); dx.device(place) = dz_bcast * grad;
dy.device(place) = }
dz_bcast * (x / norm_prod_bcast - z_bcast * y / y_snorm_bcast); // compute dy
if (out_grad_y) {
out_grad_y->mutable_data<T>(context.GetPlace());
auto dy = EigenMatrix<T>::From(*out_grad_y,
framework::make_ddim({rows_y, cols}));
auto grad = x / norm_prod_bcast - z_bcast * y / y_snorm_bcast;
dy.device(place) = dz_bcast * grad;
}
} else {
Eigen::DSizes<int, 2> bcast_row(rows_x, 1);
auto y_bcast = y.broadcast(bcast_row);
auto y_snorm_bcast =
y_norm.square().eval().broadcast(bcast_row).eval().broadcast(bcast);
auto norm_prod_bcast =
(x_norm * y_norm.broadcast(bcast_row)).eval().broadcast(bcast);
// compute dx
if (out_grad_x) {
out_grad_x->mutable_data<T>(context.GetPlace());
auto dx = EigenMatrix<T>::From(
*out_grad_x, framework::make_ddim({rows_x, cols}));
auto grad = y_bcast / norm_prod_bcast - z_bcast * x / x_snorm_bcast;
dx.device(place) = dz_bcast * grad;
}
// compute dy
if (out_grad_y) {
out_grad_y->mutable_data<T>(context.GetPlace());
auto dy = EigenMatrix<T>::From(
*out_grad_y, framework::make_ddim({rows_y, cols}));
auto grad = x / norm_prod_bcast - z_bcast * y_bcast / y_snorm_bcast;
dy.device(place) = (dz_bcast * grad).sum(Eigen::array<int, 1>({0}));
}
} }
} }
}; };
......
...@@ -4,7 +4,7 @@ from gradient_checker import GradientChecker, create_op ...@@ -4,7 +4,7 @@ from gradient_checker import GradientChecker, create_op
from op_test_util import OpTestMeta from op_test_util import OpTestMeta
class TestCosSimOp(unittest.TestCase): class TestCosSimOpWithRank2(unittest.TestCase):
__metaclass__ = OpTestMeta __metaclass__ = OpTestMeta
def setUp(self): def setUp(self):
...@@ -24,12 +24,72 @@ class TestCosSimOp(unittest.TestCase): ...@@ -24,12 +24,72 @@ class TestCosSimOp(unittest.TestCase):
} }
class TestCosSimOpWithRank2Bcast(unittest.TestCase):
__metaclass__ = OpTestMeta
def setUp(self):
self.type = "cos_sim"
self.inputs = {
'X': np.random.random((32, 64)).astype("float32"),
'Y': np.random.random((1, 64)).astype("float32")
}
expect_x_norm = np.linalg.norm(self.inputs['X'], axis=1)
expect_y_norm = np.linalg.norm(self.inputs['Y'], axis=1)
expect_out = (self.inputs['X'] * self.inputs['Y']).sum(axis=1) / \
expect_x_norm / expect_y_norm
self.outputs = {
'XNorm': np.expand_dims(expect_x_norm, 1),
'YNorm': np.expand_dims(expect_y_norm, 1),
'Out': np.expand_dims(expect_out, 1)
}
class TestCosSimOpWithRank3(unittest.TestCase):
__metaclass__ = OpTestMeta
def setUp(self):
self.type = "cos_sim"
self.inputs = {
'X': np.random.random((32, 64, 10)).astype("float32"),
'Y': np.random.random((32, 64, 10)).astype("float32")
}
expect_x_norm = np.linalg.norm(self.inputs['X'], axis=(1, 2))
expect_y_norm = np.linalg.norm(self.inputs['Y'], axis=(1, 2))
expect_out = (self.inputs['X'] * self.inputs['Y']).sum(axis=(1, 2)) / \
expect_x_norm / expect_y_norm
self.outputs = {
'XNorm': np.expand_dims(expect_x_norm, 1),
'YNorm': np.expand_dims(expect_y_norm, 1),
'Out': np.expand_dims(expect_out, 1)
}
class TestCosSimOpWithRank3Bcast(unittest.TestCase):
__metaclass__ = OpTestMeta
def setUp(self):
self.type = "cos_sim"
self.inputs = {
'X': np.random.random((32, 64, 10)).astype("float32"),
'Y': np.random.random((1, 64, 10)).astype("float32")
}
expect_x_norm = np.linalg.norm(self.inputs['X'], axis=(1, 2))
expect_y_norm = np.linalg.norm(self.inputs['Y'], axis=(1, 2))
expect_out = (self.inputs['X'] * self.inputs['Y']).sum(axis=(1, 2)) / \
expect_x_norm / expect_y_norm
self.outputs = {
'XNorm': np.expand_dims(expect_x_norm, 1),
'YNorm': np.expand_dims(expect_y_norm, 1),
'Out': np.expand_dims(expect_out, 1)
}
class TestCosSimGradOp(GradientChecker): class TestCosSimGradOp(GradientChecker):
def setUp(self): def setUp(self):
self.op = create_op("cos_sim") self.op = create_op("cos_sim")
self.inputs = { self.inputs = {
'X': np.random.random((10, 5)).astype("float32"), 'X': np.random.random((6, 5)).astype("float32"),
'Y': np.random.random((10, 5)).astype("float32") 'Y': np.random.random((6, 5)).astype("float32")
} }
def test_cpu_gpu_compare(self): def test_cpu_gpu_compare(self):
...@@ -56,5 +116,32 @@ class TestCosSimGradOp(GradientChecker): ...@@ -56,5 +116,32 @@ class TestCosSimGradOp(GradientChecker):
no_grad_set={"Y"}) no_grad_set={"Y"})
class TestCosSimGradOpWithRank2Bcast(TestCosSimGradOp):
def setUp(self):
self.op = create_op("cos_sim")
self.inputs = {
'X': np.random.random((6, 5)).astype("float32"),
'Y': np.random.random((1, 5)).astype("float32")
}
class TestCosSimGradOpWithRank3(TestCosSimGradOp):
def setUp(self):
self.op = create_op("cos_sim")
self.inputs = {
'X': np.random.random((6, 5, 2)).astype("float32"),
'Y': np.random.random((6, 5, 2)).astype("float32")
}
class TestCosSimGradOpWithRank3Bcast(TestCosSimGradOp):
def setUp(self):
self.op = create_op("cos_sim")
self.inputs = {
'X': np.random.random((6, 5, 2)).astype("float32"),
'Y': np.random.random((1, 5, 2)).astype("float32")
}
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册