diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 8a0ff1eb535a542e106ceafca6713aefff2526d5..f9ea25ab045a02be5ab9ed81ef9c679126d3a188 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -59,7 +59,7 @@ set(DEPS_OPS op_library(identity_op DEPS scale_op) op_library(minus_op DEPS scale_op) op_library(mul_op DEPS math_function) -op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc +op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc DEPS framework_proto tensor operator net_op) op_library(scale_op DEPS net_op) diff --git a/paddle/operators/squared_l2_distance_op.cc b/paddle/operators/squared_l2_distance_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..dc30644a5e7e33d4289e48cac093aa5fde7e75e7 --- /dev/null +++ b/paddle/operators/squared_l2_distance_op.cc @@ -0,0 +1,118 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/squared_l2_distance_op.h" + +namespace paddle { +namespace operators { + +class SquaredL2DistanceOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext& ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), + "Input of SquaredL2DistanceOp " + "must be initialized."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), + "Target of SquaredL2DistanceOp " + "must be initialized."); + + auto* x = ctx.Input("X"); + auto x_dims = x->dims(); + auto* y = ctx.Input("Y"); + auto y_dims = y->dims(); + + PADDLE_ENFORCE_EQ(framework::arity(x_dims), framework::arity(y_dims), + "Tensor rank of both SquaredL2DistanceOp's " + "inputs must be same."); + + int rank = framework::arity(x_dims); + PADDLE_ENFORCE_GE(rank, 2, "Tensor rank should be at least equal to 2."); + PADDLE_ENFORCE_EQ(framework::product(x_dims) / x_dims[0], + framework::product(y_dims) / y_dims[0], + "Product of dimensions expcet the first dimension of " + "input and target must be equal."); + PADDLE_ENFORCE(y_dims[0] == 1 || y_dims[0] == x_dims[0], + "First dimension of target must be equal to input " + "or to 1."); + + ctx.Output("sub_result") + ->Resize({static_cast(x_dims[0]), + static_cast(framework::product(x_dims) / x_dims[0])}); + ctx.Output("Out")->Resize({x_dims[0], 1}); + } +}; + +class SquaredL2DistanceOpMaker : public framework::OpProtoAndCheckerMaker { + public: + SquaredL2DistanceOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "Input of SquaredL2DistanceOp."); + AddInput("Y", "Target of SquaredL2DistanceOp."); + AddOutput("sub_result", + "Buffering substraction result which " + "will be reused in backward.") + .AsIntermediate(); + AddOutput("Out", "Squared l2 distance between input and target."); + AddComment(R"DOC( + SquaredL2DistanceOp will cacluate the squared L2 distance for + input and target. Number of distance value equals to the + first dimension of input. First dimension of target could be equal to + input or to 1. If the first dimension of target is 1, SquaredL2DistanceOp + will broadcast target's first dimension to input's first dimension. + You can decide whether calculate the gradient of input and target. + )DOC"); + } +}; + +class SquaredL2DistanceGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext& ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), + "Gradient of Out should not be null"); + auto out_dims = ctx.Input(framework::GradVarName("Out"))->dims(); + auto x_dims = ctx.Input("X")->dims(); + auto y_dims = ctx.Input("Y")->dims(); + PADDLE_ENFORCE_EQ(out_dims[0], x_dims[0], + "First dimension of output gradient and " + "input value must be equal."); + PADDLE_ENFORCE_EQ(out_dims[1], 1, + "Second dimension of output gradient " + "must be 1."); + auto* x_grad = ctx.Output(framework::GradVarName("X")); + auto* y_grad = ctx.Output(framework::GradVarName("Y")); + if (x_grad) x_grad->Resize(x_dims); + if (y_grad) y_grad->Resize(y_dims); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(squared_l2_distance, ops::SquaredL2DistanceOp, + ops::SquaredL2DistanceOpMaker, squared_l2_distance_grad, + ops::SquaredL2DistanceGradOp); +REGISTER_OP_CPU_KERNEL( + squared_l2_distance, + ops::SquaredL2DistanceKernel); +REGISTER_OP_CPU_KERNEL( + squared_l2_distance_grad, + ops::SquaredL2DistanceGradKernel); diff --git a/paddle/operators/squared_l2_distance_op.cu b/paddle/operators/squared_l2_distance_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..3fe62f1a9cb56722ea544b0fed052ac384e799aa --- /dev/null +++ b/paddle/operators/squared_l2_distance_op.cu @@ -0,0 +1,25 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#define EIGEN_USE_GPU + +#include "paddle/operators/squared_l2_distance_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL( + squared_l2_distance, + ops::SquaredL2DistanceKernel); +REGISTER_OP_GPU_KERNEL( + squared_l2_distance_grad, + ops::SquaredL2DistanceGradKernel); diff --git a/paddle/operators/squared_l2_distance_op.h b/paddle/operators/squared_l2_distance_op.h new file mode 100644 index 0000000000000000000000000000000000000000..ad3347a0b35f3385c5adbcd7ceaa94fe134105e3 --- /dev/null +++ b/paddle/operators/squared_l2_distance_op.h @@ -0,0 +1,123 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +template +using EigenVector = framework::EigenVector; +template +using EigenMatrix = framework::EigenMatrix; + +template +class SquaredL2DistanceKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* in0 = context.Input("X"); + auto* in1 = context.Input("Y"); + auto* out0 = context.Output("sub_result"); + auto* out1 = context.Output("Out"); + + auto in0_dims = in0->dims(); + auto in1_dims = in1->dims(); + + int cols = framework::product(in0_dims) / in0_dims[0]; + // reduce dimensions except the first + auto x = + EigenMatrix::From(*in0, framework::make_ddim({in0_dims[0], cols})); + auto y = + EigenMatrix::From(*in1, framework::make_ddim({in1_dims[0], cols})); + + out0->mutable_data(context.GetPlace()); + out1->mutable_data(context.GetPlace()); + auto sub_result = EigenMatrix::From(*out0); + auto z = EigenVector::Flatten(*out1); + + auto place = context.GetEigenDevice(); + auto x_dims = x.dimensions(); + auto y_dims = y.dimensions(); + // buffer the substraction result + if (y_dims[0] == 1 && x_dims[0] > y_dims[0]) { + sub_result.device(place) = + x - + y.broadcast(Eigen::array({{static_cast(x_dims[0]), 1}})); + } else { + sub_result.device(place) = x - y; + } + auto sub_res_pow2 = sub_result * sub_result; + z.device(place) = sub_res_pow2.sum(Eigen::array({{1}})); + } +}; + +template +class SquaredL2DistanceGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* in0 = context.Input("sub_result"); + auto* in1 = context.Input(framework::GradVarName("Out")); + auto* x_g = context.Output(framework::GradVarName("X")); + auto* y_g = context.Output(framework::GradVarName("Y")); + + auto sub_result = EigenMatrix::From(*in0); + auto out_grad = EigenMatrix::From(*in1); + + auto x_dims = x_g->dims(); + auto y_dims = y_g->dims(); + + int cols = framework::product(x_dims) / x_dims[0]; + // calculate gradient + auto grad_mat = 2 * + (out_grad.broadcast(Eigen::array({{1, cols}}))) * + sub_result; + + // propagate back to input + auto eigen_place = context.GetEigenDevice(); + if (x_g) { + x_g->mutable_data(context.GetPlace()); + // eigen matrix + auto x_grad = + EigenMatrix::From(*x_g, framework::make_ddim({x_dims[0], cols})); + // dimensions are same with subResult + x_grad.device(eigen_place) = grad_mat; + } + + if (y_g) { + y_g->mutable_data(context.GetPlace()); + + PADDLE_ENFORCE_GE(sub_result.dimensions()[0], y_dims[0], + "First dimension of gradient must be greater or " + "equal than first dimension of target."); + + if (sub_result.dimensions()[0] == y_dims[0]) { + auto y_grad = + EigenMatrix::From(*y_g, framework::make_ddim({y_dims[0], cols})); + y_grad.device(eigen_place) = -1 * grad_mat; + } else { + auto col_sum_res = -1 * (grad_mat.sum(Eigen::array({{0}}))); + auto y_grad = EigenVector::Flatten(*y_g); + y_grad.device(eigen_place) = col_sum_res; + } + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index 6e637fa40fbd80fdfa0323a645c57c42d7ca502e..c21ad3470ba97f533b6c42bc2966be04bc6f7976 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -49,6 +49,7 @@ USE_OP(minus); USE_OP(cos_sim); USE_CPU_ONLY_OP(gather); USE_CPU_ONLY_OP(scatter); +USE_OP(squared_l2_distance); namespace paddle { namespace framework { diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt index e0f77d797390be0461f466726f63a70dd485a290..a9c33ea1631e8358c41a8566de9db4bd00fc9b74 100644 --- a/python/paddle/v2/framework/tests/CMakeLists.txt +++ b/python/paddle/v2/framework/tests/CMakeLists.txt @@ -33,3 +33,4 @@ py_test(test_gradient_checker SRCS test_gradient_checker.py) py_test(test_lookup_table SRCS test_lookup_table.py) py_test(test_scale_and_identity_op SRCS test_scale_and_identity_op.py) py_test(mnist SRCS mnist.py) +py_test(test_squared_l2_distance_op SRCS test_squared_l2_distance_op.py) diff --git a/python/paddle/v2/framework/tests/op_test_util.py b/python/paddle/v2/framework/tests/op_test_util.py index a4899355b53d62903b97999ebf9c2c7ecfc6c4cd..370f27eaf658dadbf7e82262c118140a10d15c41 100644 --- a/python/paddle/v2/framework/tests/op_test_util.py +++ b/python/paddle/v2/framework/tests/op_test_util.py @@ -66,7 +66,7 @@ class OpTestMeta(type): self.assertTrue( numpy.allclose( actual, expect, atol=1e-05), - "output name: " + out_name + "has diff") + "output name: " + out_name + " has diff") obj.test_all = test_all return obj diff --git a/python/paddle/v2/framework/tests/test_squared_l2_distance_op.py b/python/paddle/v2/framework/tests/test_squared_l2_distance_op.py new file mode 100644 index 0000000000000000000000000000000000000000..2bcdf37df434c9a089d75438d876114156261a5c --- /dev/null +++ b/python/paddle/v2/framework/tests/test_squared_l2_distance_op.py @@ -0,0 +1,89 @@ +import unittest +from op_test_util import OpTestMeta +from gradient_checker import GradientChecker, create_op +import numpy as np + + +class TestSquaredL2DistanceOp_f0(unittest.TestCase): + __metaclass__ = OpTestMeta + + def setUp(self): + self.type = 'squared_l2_distance' + self.inputs = { + 'X': np.random.uniform(0.1, 1., (32, 64)).astype('float32'), + 'Y': np.random.uniform(0.1, 1., (32, 64)).astype('float32') + } + sub_res = self.inputs['X'] - self.inputs['Y'] + output = sub_res * sub_res + self.outputs = { + 'sub_result': sub_res, + 'Out': np.expand_dims(output.sum(1), 1) + } + + +class TestSquaredL2DistanceOp_f1(unittest.TestCase): + __metaclass__ = OpTestMeta + + def setUp(self): + self.type = 'squared_l2_distance' + self.inputs = { + 'X': np.random.uniform(0.1, 1., (32, 64)).astype('float32'), + 'Y': np.random.uniform(0.1, 1., (1, 64)).astype('float32') + } + sub_res = self.inputs['X'] - self.inputs['Y'] + output = sub_res * sub_res + self.outputs = { + 'sub_result': sub_res, + 'Out': np.expand_dims(output.sum(1), 1) + } + + +class TestSquaredL2DistanceOp_f2(unittest.TestCase): + __metaclass__ = OpTestMeta + + def setUp(self): + self.type = 'squared_l2_distance' + self.inputs = { + 'X': np.random.uniform(0.1, 1., (32, 64, 128)).astype('float32'), + 'Y': np.random.uniform(0.1, 1., (1, 64, 128)).astype('float32') + } + sub_res = self.inputs['X'] - self.inputs['Y'] + sub_res = sub_res.reshape((32, 64 * 128)) + output = sub_res * sub_res + self.outputs = { + 'sub_result': sub_res, + 'Out': np.expand_dims(output.sum(1), 1) + } + + +class TestSquaredL2DistanceGradOp(GradientChecker): + def test_squared_l2_distance_b0(self): + op = create_op("squared_l2_distance") + inputs = { + 'X': np.random.uniform(0.1, .6, (2, 3)).astype('float32'), + 'Y': np.random.uniform(0.1, .6, (2, 3)).astype('float32') + } + self.compare_grad(op, inputs) + self.check_grad(op, inputs, set(["X", "Y"]), "Out") + + def test_squared_l2_distance_b1(self): + op = create_op("squared_l2_distance") + inputs = { + 'X': np.random.uniform(0.1, .6, (2, 3)).astype('float32'), + 'Y': np.random.uniform(0.1, .6, (1, 3)).astype('float32') + } + self.compare_grad(op, inputs) + self.check_grad(op, inputs, set(["X", "Y"]), "Out") + + def test_squared_l2_distance_b2(self): + op = create_op("squared_l2_distance") + inputs = { + 'X': np.random.uniform(0.1, .6, (2, 3, 4)).astype('float32'), + 'Y': np.random.uniform(0.1, .6, (1, 3, 4)).astype('float32') + } + self.compare_grad(op, inputs) + self.check_grad(op, inputs, set(["X", "Y"]), "Out") + + +if __name__ == '__main__': + unittest.main()