提交 ed72af48 编写于 作者: X Xinghai Sun

Add cos_sim op.

上级 2d31ab5f
......@@ -56,7 +56,7 @@ list(REMOVE_ITEM GENERAL_OPS
op_library(net_op SRCS net_op.cc)
op_library(minus_op SRCS minus_op.cc minus_op.cu DEPS scale_op)
op_library(mul_op SRCS mul_op.cc mul_op.cu DEPS math_function)
op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
DEPS framework_proto tensor operator net_op)
op_library(scale_op SRCS scale_op.cc scale_op.cu DEPS net_op)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/cos_sim_op.h"
namespace paddle {
namespace operators {
using framework::Tensor;
class CosSimOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) should not be null.");
PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("X")->dims(),
ctx.Input<Tensor>("Y")->dims(),
"Dimensions of Input(X) and Input(Y) must be the same.");
auto dims = ctx.Input<Tensor>("X")->dims();
ctx.Output<Tensor>("Out")->Resize({dims[0], 1});
}
};
class CosSimOpMaker : public framework::OpProtoAndCheckerMaker {
public:
CosSimOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The first input of cos_sim op.");
AddInput("Y", "The second input of cos_sim op.");
AddOutput("Out", "The output of cos_sim op.");
AddComment(R"DOC(
Cosine Similarity Operator.
The equation is: Out = X^T * Y / (sqrt(X^T * X) * sqrt(Y^T * Y))
)DOC");
}
};
class CosSimOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null.");
auto x_dims = ctx.Input<Tensor>("X")->dims();
auto y_dims = ctx.Input<Tensor>("Y")->dims();
auto out_dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
PADDLE_ENFORCE_EQ(x_dims, y_dims,
"Dimensions of Input(X) and Input(Y) must be the same.");
PADDLE_ENFORCE_EQ(out_dims[0], x_dims[0],
"1st dimension of Out@GRAD must equal to Input(X)");
PADDLE_ENFORCE_EQ(out_dims[1], 1,
"1st dimension of Out@GRAD must equal to Input(X)");
auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *y_grad = ctx.Output<Tensor>(framework::GradVarName("Y"));
x_grad->Resize(x_dims);
y_grad->Resize(y_dims);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(cos_sim, ops::CosSimOp, ops::CosSimOpMaker, cos_sim_grad,
ops::CosSimOpGrad);
REGISTER_OP_CPU_KERNEL(cos_sim,
ops::CosSimKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
cos_sim_grad, ops::CosSimGradKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/cos_sim_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(cos_sim,
ops::CosSimKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
cos_sim_grad, ops::CosSimGradKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename Place, typename T>
class CosSimKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<Tensor>("X");
auto* y = context.Input<Tensor>("Y");
auto* z = context.Output<Tensor>("Out");
z->mutable_data<T>(context.GetPlace());
auto dims = x->dims();
int size = static_cast<int>(framework::product(dims));
auto new_dims = framework::make_ddim({dims[0], size / dims[0]});
auto X = EigenMatrix<T>::From(*x, new_dims);
auto Y = EigenMatrix<T>::From(*y, new_dims);
auto Z = EigenMatrix<T>::From(*z, new_dims);
auto XY = (X * Y).sum(Eigen::array<int, 1>({1}));
auto XX = (X * X).sum(Eigen::array<int, 1>({1}));
auto YY = (Y * Y).sum(Eigen::array<int, 1>({1}));
auto place = context.GetEigenDevice<Place>();
Z.device(place) = XY / XX.sqrt() / YY.sqrt();
}
};
template <typename Place, typename T>
class CosSimGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<Tensor>("X");
auto* y = context.Input<Tensor>("Y");
auto* z = context.Input<Tensor>("Out");
auto* grad_x = context.Output<Tensor>(framework::GradVarName("X"));
auto* grad_y = context.Output<Tensor>(framework::GradVarName("Y"));
auto* grad_z = context.Input<Tensor>(framework::GradVarName("Out"));
grad_x->mutable_data<T>(context.GetPlace());
grad_y->mutable_data<T>(context.GetPlace());
auto dims = x->dims();
int size = static_cast<int>(framework::product(dims));
auto new_dims = framework::make_ddim({dims[0], size / dims[0]});
auto X = EigenMatrix<T>::From(*x, new_dims);
auto Y = EigenMatrix<T>::From(*y, new_dims);
auto Z = EigenMatrix<T>::From(*z);
auto dX = EigenMatrix<T>::From(*grad_x, new_dims);
auto dY = EigenMatrix<T>::From(*grad_y, new_dims);
auto dZ = EigenMatrix<T>::From(*grad_z);
auto XX = (X * X).sum(Eigen::array<int, 1>({1}));
auto YY = (Y * Y).sum(Eigen::array<int, 1>({1}));
Eigen::DSizes<int, 2> bcast(1, dims[1]);
auto denominator_bcast = (XX.sqrt() * YY.sqrt()).broadcast(bcast);
auto Z_bcast = Z.broadcast(bcast);
auto dZ_bcast = dZ.broadcast(bcast);
auto place = context.GetEigenDevice<Place>();
dX.device(place) =
dZ_bcast * (Y / denominator_bcast - Z_bcast * X / XX.broadcast(bcast));
dY.device(place) =
dZ_bcast * (X / denominator_bcast - Z_bcast * Y / YY.broadcast(bcast));
// dX.device(place) = X;
// Y.device(place) = Y;
}
};
} // namespace operators
} // namespace paddle
......@@ -46,6 +46,7 @@ USE_OP(lookup_table);
USE_OP(scale);
USE_OP_ITSELF(identity);
USE_OP(minus);
USE_OP(cos_sim);
USE_CPU_ONLY_OP(gather);
USE_CPU_ONLY_OP(scatter);
......
......@@ -4,6 +4,7 @@ py_test(test_scope SRCS test_scope.py)
py_test(test_tensor SRCS test_tensor.py)
py_test(test_mul_op SRCS test_mul_op.py)
py_test(test_cos_sim_op SRCS test_cos_sim_op.py)
py_test(test_mean_op SRCS test_mean_op.py)
......
......@@ -36,13 +36,13 @@ def get_numeric_gradient(op,
in_place=False):
"""
Get Numeric Gradient for an operator's input.
:param op: C++ operator instance, could be an network
:param input_values: The input variables. Should be an dictionary, key is
:param op: C++ operator instance, could be an network
:param input_values: The input variables. Should be an dictionary, key is
variable name. Value is numpy array.
:param output_name: The final output variable name.
:param output_name: The final output variable name.
:param input_to_check: The input variable need to get gradient.
:param delta: The perturbation value for numeric gradient method. The
:param delta: The perturbation value for numeric gradient method. The
smaller delta is, the more accurate result will get. But if that delta is
too small, it could occur numerical stability problem.
:param local_scope: The local scope used for get_numeric_gradient.
......@@ -229,9 +229,9 @@ class GradientChecker(unittest.TestCase):
"""Use relative error for the comparison.
:param numeric_grads: the numerical graidents.
:type numeric_grads: a list of numpy.array
:type numeric_grads: a list of numpy.array
:param analytic_grads: the analytical graidents.
:type analytic_grads: a list of numpy.array
:type analytic_grads: a list of numpy.array
:param name: the names of gradients, used to print for debug.
:type names: a list of string
:param msg_prefix: string info, used to print for debug.
......@@ -304,6 +304,13 @@ class GradientChecker(unittest.TestCase):
# get analytical gradients according to different device
analytic_grads = self.__get_gradient(forward_op, backward_op,
input_vars, check_names, place)
#print(numeric_grads[0], numeric_grads[0].shape)
print("dim0: ", numeric_grads[0], numeric_grads[0].shape)
print("dim0: ", analytic_grads[0], analytic_grads[0].shape)
print("---------------------")
print("dim1: ", numeric_grads[1], numeric_grads[1].shape)
print("dim1: ", analytic_grads[1], analytic_grads[1].shape)
assert False
self.__assert_is_close(numeric_grads, analytic_grads, check_names,
max_relative_error,
"Gradient Check On %s" % str(place))
......@@ -6,13 +6,13 @@ from paddle.v2.framework.op import Operator
class OpTestMeta(type):
"""
Operator Test ClassMeta.
It injects `test_all` method into user's OperatorTest class, to make Python
It injects `test_all` method into user's OperatorTest class, to make Python
unittest module run that method.
The `test_all` read what value is stored in `self`. It use self's values to
create and run a operator, and check whether that op is OK or not.
See `test_add_two_op` for example usage.
"""
......
import unittest
import numpy as np
from gradient_checker import GradientChecker, create_op
from op_test_util import OpTestMeta
class TestCosSimOp(unittest.TestCase):
__metaclass__ = OpTestMeta
def setUp(self):
self.type = "cos_sim"
self.inputs = {
'X': np.random.random((32, 84)).astype("float32"),
'Y': np.random.random((32, 84)).astype("float32")
}
expect = (self.inputs['X'] * self.inputs['Y']).sum(axis=1) / \
np.linalg.norm(self.inputs['X'], axis=1) / \
np.linalg.norm(self.inputs['Y'], axis=1)
expect = np.expand_dims(expect, 1)
self.outputs = {'Out': expect}
class CosSimGradOpTest(GradientChecker):
def test_cos_sim(self):
op = create_op("cos_sim")
#inputs = {
#'X': np.random.random((2, 2)).astype("float32"),
#'Y': np.random.random((2, 2)).astype("float32")
#}
inputs = {
'X': np.array([[0.9, 0.6], [1.9, 1.6]]).astype("float32"),
'Y': np.array([[0.7, 0.8], [1.7, 1.8]]).astype("float32")
}
print(inputs)
self.check_grad(
op, inputs, set(["X", "Y"]), "Out", max_relative_error=0.5)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册