Commit 03ea7320 authored by Xinghai Sun

Update cos_sim operator by following reviewer's comments.

Parent 16fddf32
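For context on the operator being touched: cos_sim flattens each input to a matrix of shape `[rows, cols]` and computes a per-row cosine similarity, where `Y` may have a single row that is broadcast against every row of `X`. A sketch of the forward definition, as implemented by the kernels in this diff:

```latex
Out_i = \frac{\sum_j X_{ij} Y_{ij}}{\lVert X_i \rVert \, \lVert Y_i \rVert},
\qquad
XNorm_i = \lVert X_i \rVert = \sqrt{\sum_j X_{ij}^2},
\qquad
YNorm_i = \lVert Y_i \rVert .
```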
@@ -32,17 +32,18 @@ class CosSimOp : public framework::OperatorWithKernel {
     // shape check
     auto x_dims = ctx.Input<Tensor>("X")->dims();
     auto y_dims = ctx.Input<Tensor>("Y")->dims();
-    PADDLE_ENFORCE_EQ(framework::arity(x_dims), framework::arity(y_dims),
+    PADDLE_ENFORCE_EQ(x_dims.size(), y_dims.size(),
                       "Ranks of Input(X) and Input(Y) must be equal.");
-    PADDLE_ENFORCE_GE(framework::arity(x_dims), 2,
+    PADDLE_ENFORCE_GE(x_dims.size(), 2,
                       "Rank of Input(X) must not be less than 2.");
-    PADDLE_ENFORCE_EQ(
-        framework::slice_ddim(x_dims, 1, framework::arity(x_dims)),
-        framework::slice_ddim(y_dims, 1, framework::arity(y_dims)),
-        "All dimensions except 1st of Input(X) and Input(Y) must be equal.");
+    PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 1, x_dims.size()),
+                      framework::slice_ddim(y_dims, 1, y_dims.size()),
+                      "All dimensions except the 1st of Input(X) and Input(Y) "
+                      "must be equal.");
     PADDLE_ENFORCE(x_dims[0] == y_dims[0] || y_dims[0] == 1,
-                   "1st dimension of Input(Y) must be equal to Input(X) or "
-                   "just 1 (which will be broadcasted to match Input(X)).");
+                   "The 1st dimension of Input(Y) must be equal to Input(X) or"
+                   " just 1 (which will be broadcasted to match Input(X)).");
     // resize tensor
     ctx.Output<Tensor>("Out")->Resize({x_dims[0], 1});
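To make the accepted shapes concrete, a few illustrative cases (the concrete numbers are made up, not from the patch); `Out` is always resized to `[X.dims()[0], 1]`:

```cpp
// Hypothetical shape examples for the checks above:
// X: [16, 32, 4], Y: [16, 32, 4]  -> OK, Out: [16, 1]
// X: [16, 32, 4], Y: [1, 32, 4]   -> OK, Y is broadcast along the 1st dim
// X: [16, 32, 4], Y: [8, 32, 4]   -> rejected: 1st dims differ and Y's is not 1
// X: [16, 32, 4], Y: [16, 32]     -> rejected: ranks differ
```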
@@ -58,8 +59,14 @@ class CosSimOpMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("X", "The 1st input of cos_sim op.");
     AddInput("Y", "The 2nd input of cos_sim op.");
     AddOutput("Out", "The output of cos_sim op.");
-    AddOutput("XNorm", "Row norm of the first input.").AsIntermediate();
-    AddOutput("YNorm", "Row norm of the second input.").AsIntermediate();
+    AddOutput("XNorm",
+              "Norm of the first input, reduced along the 1st "
+              "dimension.")
+        .AsIntermediate();
+    AddOutput("YNorm",
+              "Norm of the second input, reduced along the 1st "
+              "dimension.")
+        .AsIntermediate();
     AddComment(R"DOC(
 Cosine Similarity Operator.
@@ -95,29 +102,32 @@ class CosSimOpGrad : public framework::OperatorWithKernel {
     // shape check
     auto x_dims = ctx.Input<Tensor>("X")->dims();
     auto y_dims = ctx.Input<Tensor>("Y")->dims();
-    PADDLE_ENFORCE_GE(framework::arity(x_dims), framework::arity(y_dims),
+    auto xnorm_dims = ctx.Input<Tensor>("XNorm")->dims();
+    auto ynorm_dims = ctx.Input<Tensor>("YNorm")->dims();
+    auto out_dims = ctx.Input<Tensor>("Out")->dims();
+    auto out_grad_dims =
+        ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
+    PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
                       "Ranks of Input(X) and Input(Y) must be equal.");
-    PADDLE_ENFORCE_GE(framework::arity(x_dims), 2,
+    PADDLE_ENFORCE_GE(x_dims.size(), 2,
                       "Rank of Input(X) must not be less than 2.");
-    PADDLE_ENFORCE_EQ(
-        framework::slice_ddim(x_dims, 1, framework::arity(x_dims)),
-        framework::slice_ddim(y_dims, 1, framework::arity(y_dims)),
-        "All dimensions except 1st of Input(X) and Input(Y) must be equal.");
+    PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 1, x_dims.size()),
+                      framework::slice_ddim(y_dims, 1, y_dims.size()),
+                      "All dimensions except the 1st of Input(X) and Input(Y) "
+                      "must be equal.");
     PADDLE_ENFORCE(x_dims[0] == y_dims[0] || y_dims[0] == 1,
-                   "1st dimension of Input(Y) must be equal to Input(X) or "
-                   "just 1 (which will be broadcasted to match Input(X)).");
-    auto xnorm_dims = ctx.Input<Tensor>("XNorm")->dims();
-    PADDLE_ENFORCE_EQ(xnorm_dims, framework::make_ddim({x_dims[0], 1}),
+                   "The 1st dimension of Input(Y) must be equal to Input(X) or"
+                   " just 1 (which will be broadcasted to match Input(X)).");
+    auto target_xnorm_dims = framework::make_ddim({x_dims[0], 1});
+    auto target_ynorm_dims = framework::make_ddim({y_dims[0], 1});
+    PADDLE_ENFORCE_EQ(xnorm_dims, target_xnorm_dims,
                       "Shape of Input(XNorm) must be [X.Dim(0), 1].");
-    auto ynorm_dims = ctx.Input<Tensor>("YNorm")->dims();
-    PADDLE_ENFORCE_EQ(ynorm_dims, framework::make_ddim({y_dims[0], 1}),
+    PADDLE_ENFORCE_EQ(ynorm_dims, target_ynorm_dims,
                       "Shape of Input(YNorm) must be [Y.Dim(0), 1].");
-    auto out_dims = ctx.Input<Tensor>("Out")->dims();
-    PADDLE_ENFORCE_EQ(out_dims, framework::make_ddim({x_dims[0], 1}),
+    PADDLE_ENFORCE_EQ(out_dims, target_xnorm_dims,
                       "Shape of Input(Out) must be [X.Dim(0), 1].");
-    auto out_grad_dims =
-        ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
-    PADDLE_ENFORCE_EQ(out_grad_dims, framework::make_ddim({x_dims[0], 1}),
+    PADDLE_ENFORCE_EQ(out_grad_dims, target_xnorm_dims,
                       "Shape of Input(Out@Grad) must be [X.Dim(0), 1].");
     // resize tensor
......
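For reference, the gradient expressions the kernels below implement, writing `z_i` for `Out_i` (these follow directly from the quotient rule applied to the forward formula above):

```latex
\frac{\partial z_i}{\partial X_{ij}}
  = \frac{Y_{ij}}{\lVert X_i \rVert \lVert Y_i \rVert}
  - z_i \, \frac{X_{ij}}{\lVert X_i \rVert^2},
\qquad
\frac{\partial z_i}{\partial Y_{ij}}
  = \frac{X_{ij}}{\lVert X_i \rVert \lVert Y_i \rVert}
  - z_i \, \frac{Y_{ij}}{\lVert Y_i \rVert^2}.
```

In the code, `norm_prod_bcast`, `x_snorm_bcast`, and `y_snorm_bcast` are exactly these three denominators, broadcast to `[rows, cols]`.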
@@ -42,22 +42,23 @@ class CosSimKernel : public framework::OpKernel {
     int rows_x = in_x->dims()[0];
     int rows_y = in_y->dims()[0];
-    int cols = framework::product(in_x->dims()) / rows_x;
-    auto x = EigenMatrix<T>::From(*in_x, framework::make_ddim({rows_x, cols}));
-    auto y = EigenMatrix<T>::From(*in_y, framework::make_ddim({rows_y, cols}));
+    auto x = EigenMatrix<T>::Reshape(*in_x, 1);
+    auto y = EigenMatrix<T>::Reshape(*in_y, 1);
     auto z = EigenMatrix<T>::From(*out_z);
     auto x_norm = EigenMatrix<T>::From(*out_x_norm);
     auto y_norm = EigenMatrix<T>::From(*out_y_norm);
     // compute
     auto place = context.GetEigenDevice<Place>();
-    x_norm.device(place) = x.square().sum(Eigen::array<int, 1>({1})).sqrt();
-    y_norm.device(place) = y.square().sum(Eigen::array<int, 1>({1})).sqrt();
+    auto row_along = Eigen::array<int, 1>({{1}});
+    x_norm.device(place) = x.square().sum(row_along).sqrt();
+    y_norm.device(place) = y.square().sum(row_along).sqrt();
     if (rows_x == rows_y) {
-      auto xy = (x * y).sum(Eigen::array<int, 1>({1}));
+      auto xy = (x * y).sum(row_along);
       z.device(place) = xy / x_norm / y_norm;
     } else {
       Eigen::DSizes<int, 2> bcast(rows_x, 1);
-      auto xy = (x * y.broadcast(bcast)).sum(Eigen::array<int, 1>({1}));
+      auto xy = (x * y.broadcast(bcast)).sum(row_along);
       z.device(place) = xy / x_norm / y_norm.broadcast(bcast);
     }
   }
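As a sanity check on the new `Reshape`-based indexing (which, per this diff, flattens everything after the first dimension into columns, just as the old explicit `make_ddim({rows, cols})` did), here is a minimal standalone sketch of the same forward computation against the unsupported Eigen Tensor module; it is illustrative only, not Paddle code:

```cpp
#include <unsupported/Eigen/CXX11/Tensor>
#include <iostream>

int main() {
  const int rows = 2, cols = 3;
  Eigen::Tensor<float, 2> x(rows, cols), y(rows, cols);
  x.setValues({{1.f, 2.f, 2.f}, {0.f, 3.f, 4.f}});
  y.setValues({{1.f, 2.f, 2.f}, {0.f, 6.f, 8.f}});

  // Reduce along dimension 1 (the columns), as `row_along` does above.
  Eigen::array<int, 1> row_along({1});
  Eigen::Tensor<float, 1> x_norm = x.square().sum(row_along).sqrt();
  Eigen::Tensor<float, 1> y_norm = y.square().sum(row_along).sqrt();
  Eigen::Tensor<float, 1> xy = (x * y).sum(row_along);
  Eigen::Tensor<float, 1> z = xy / x_norm / y_norm;

  // Each row of y is parallel to the matching row of x,
  // so both similarities print as 1.
  for (int i = 0; i < rows; ++i) std::cout << z(i) << "\n";
  return 0;
}
```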
@@ -78,61 +79,56 @@ class CosSimGradKernel : public framework::OpKernel {
     auto* in_grad_z = context.Input<Tensor>(framework::GradVarName("Out"));
     // convert Tensor to Eigen Tensor
-    int rows_x = in_x->dims()[0];
-    int rows_y = in_y->dims()[0];
-    int cols = framework::product(in_x->dims()) / rows_x;
-    auto x = EigenMatrix<T>::From(*in_x, framework::make_ddim({rows_x, cols}));
-    auto y = EigenMatrix<T>::From(*in_y, framework::make_ddim({rows_y, cols}));
+    auto x = EigenMatrix<T>::Reshape(*in_x, 1);
+    auto y = EigenMatrix<T>::Reshape(*in_y, 1);
     auto z = EigenMatrix<T>::From(*in_z);
     auto x_norm = EigenMatrix<T>::From(*in_x_norm);
     auto y_norm = EigenMatrix<T>::From(*in_y_norm);
     auto dz = EigenMatrix<T>::From(*in_grad_z);
     // compute gradient
-    Eigen::DSizes<int, 2> bcast(1, cols);
-    auto z_bcast = z.broadcast(bcast);
-    auto dz_bcast = dz.broadcast(bcast);
-    auto x_snorm_bcast = x_norm.square().eval().broadcast(bcast);
+    int rows_x = in_x->dims()[0];
+    int rows_y = in_y->dims()[0];
+    int cols = framework::product(in_x->dims()) / rows_x;
+    Eigen::DSizes<int, 2> bcast_cols(1, cols);
+    auto z_bcast = z.broadcast(bcast_cols);
+    auto dz_bcast = dz.broadcast(bcast_cols);
+    auto x_snorm_bcast = x_norm.square().eval().broadcast(bcast_cols);
     auto place = context.GetEigenDevice<Place>();
     if (rows_x == rows_y) {
-      auto y_snorm_bcast = y_norm.square().eval().broadcast(bcast);
-      auto norm_prod_bcast = (x_norm * y_norm).eval().broadcast(bcast);
+      auto y_snorm_bcast = y_norm.square().eval().broadcast(bcast_cols);
+      auto norm_prod_bcast = (x_norm * y_norm).eval().broadcast(bcast_cols);
       // compute dx
       if (out_grad_x) {
         out_grad_x->mutable_data<T>(context.GetPlace());
-        auto dx = EigenMatrix<T>::From(*out_grad_x,
-                                       framework::make_ddim({rows_x, cols}));
+        auto dx = EigenMatrix<T>::Reshape(*out_grad_x, 1);
         auto grad = y / norm_prod_bcast - z_bcast * x / x_snorm_bcast;
         dx.device(place) = dz_bcast * grad;
       }
       // compute dy
       if (out_grad_y) {
         out_grad_y->mutable_data<T>(context.GetPlace());
-        auto dy = EigenMatrix<T>::From(*out_grad_y,
-                                       framework::make_ddim({rows_y, cols}));
-        auto grad = x / norm_prod_bcast - z_bcast * y / y_snorm_bcast;
+        auto dy = EigenMatrix<T>::Reshape(*out_grad_y, 1);
+        auto grad = x / norm_prod_bcast - z_bcast * y / y_snorm_bcast;
         dy.device(place) = dz_bcast * grad;
       }
     } else {
-      Eigen::DSizes<int, 2> bcast_row(rows_x, 1);
-      auto y_bcast = y.broadcast(bcast_row);
-      auto y_snorm_bcast =
-          y_norm.square().eval().broadcast(bcast_row).eval().broadcast(bcast);
-      auto norm_prod_bcast =
-          (x_norm * y_norm.broadcast(bcast_row)).eval().broadcast(bcast);
+      Eigen::DSizes<int, 2> bcast_rows(rows_x, 1);
+      Eigen::DSizes<int, 2> bcast_rows_cols(rows_x, cols);
+      auto y_bcast = y.broadcast(bcast_rows);
+      auto y_snorm_bcast = y_norm.square().eval().broadcast(bcast_rows_cols);
+      auto norm_prod_bcast =
+          (x_norm * y_norm.broadcast(bcast_rows)).eval().broadcast(bcast_cols);
       // compute dx
       if (out_grad_x) {
         out_grad_x->mutable_data<T>(context.GetPlace());
-        auto dx = EigenMatrix<T>::From(
-            *out_grad_x, framework::make_ddim({rows_x, cols}));
+        auto dx = EigenMatrix<T>::Reshape(*out_grad_x, 1);
         auto grad = y_bcast / norm_prod_bcast - z_bcast * x / x_snorm_bcast;
         dx.device(place) = dz_bcast * grad;
       }
       // compute dy
       if (out_grad_y) {
         out_grad_y->mutable_data<T>(context.GetPlace());
-        auto dy = EigenMatrix<T>::From(
-            *out_grad_y, framework::make_ddim({rows_y, cols}));
+        auto dy = EigenMatrix<T>::Reshape(*out_grad_y, 1);
         auto grad = x / norm_prod_bcast - z_bcast * y_bcast / y_snorm_bcast;
         dy.device(place) = (dz_bcast * grad).sum(Eigen::array<int, 1>({0}));
       }
......
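One note on the broadcast branch above: when `Y` has a single row shared by all rows of `X`, every output row depends on that same row of `Y`, so the chain rule sums the per-row contributions; that is what the final reduction over dimension 0 (`.sum(Eigen::array<int, 1>({0}))`) computes:

```latex
\frac{\partial L}{\partial Y_{0j}}
  = \sum_i \frac{\partial L}{\partial z_i}
    \left( \frac{X_{ij}}{\lVert X_i \rVert \lVert Y_0 \rVert}
         - z_i \, \frac{Y_{0j}}{\lVert Y_0 \rVert^2} \right).
```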