From a4df3f5bd8917b2cb510b23dc63bc97a20108f23 Mon Sep 17 00:00:00 2001 From: yangyaming Date: Wed, 30 Aug 2017 22:21:53 +0800 Subject: [PATCH] Finish framework of squared_l2_distance_op. --- paddle/operators/CMakeLists.txt | 2 + paddle/operators/squared_l2_distance_op.cc | 82 ++++++++++++++++++ paddle/operators/squared_l2_distance_op.cu | 25 ++++++ paddle/operators/squared_l2_distance_op.h | 84 +++++++++++++++++++ paddle/pybind/CMakeLists.txt | 3 +- paddle/pybind/pybind.cc | 1 + .../paddle/v2/framework/tests/CMakeLists.txt | 1 + .../paddle/v2/framework/tests/op_test_util.py | 10 +-- .../tests/test_squared_l2_distance_op.py | 25 ++++++ 9 files changed, 227 insertions(+), 6 deletions(-) create mode 100644 paddle/operators/squared_l2_distance_op.cc create mode 100644 paddle/operators/squared_l2_distance_op.cu create mode 100644 paddle/operators/squared_l2_distance_op.h create mode 100644 python/paddle/v2/framework/tests/test_squared_l2_distance_op.py diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index f0fd12f1b5..1c32d1df4a 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -73,3 +73,5 @@ op_library(uniform_random_op SRCS uniform_random_op.cc uniform_random_op.cu) op_library(lookup_table_op SRCS lookup_table_op.cc lookup_table_op.cu) op_library(scale_op SRCS scale_op.cc scale_op.cu DEPS net_op) op_library(minus_op SRCS minus_op.cc minus_op.cu DEPS scale_op) + +op_library(squared_l2_distance_op SRCS squared_l2_distance_op.cc squared_l2_distance_op.cu) diff --git a/paddle/operators/squared_l2_distance_op.cc b/paddle/operators/squared_l2_distance_op.cc new file mode 100644 index 0000000000..9fc498d5a5 --- /dev/null +++ b/paddle/operators/squared_l2_distance_op.cc @@ -0,0 +1,82 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/squared_l2_distance_op.h" + +namespace paddle { +namespace operators { + +class SquaredL2DistanceOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), + "Input of SquaredL2DistanceOp " + "must be initialized."); + PADDLE_ENFORCE_EQ(ctx.Input("X")->dims(), + ctx.Input("Y")->dims(), + "Dimensions of SquaredL2DistanceOp's two inputs " + "must be same.") + framework::DDim dims = ctx.Input("X")->dims(); + ctx.Output("sub_result")->Resize(dims); + ctx.Output("Out")->Resize(framework::make_ddim({dims[0], 1})); + } +}; + +class SquaredL2DistanceOpMaker : public framework::OpProtoAndCheckerMaker { + public: + SquaredL2DistanceOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "Input value."); + AddInput("Y", "Target value."); + AddOutput("sub_result", + "Buffering substraction result which " + "will be reused in backward.") + .AsIntermediate(); + AddOutput("Out", "Squared l2 distance between input and target."); + AddComment(R"DOC( + SquaredL2DistanceOp will cacluate the squared L2 distances for + input and target. Number of distance value equals to the + first dimension of input. + )DOC"); + } +}; + +class SquaredL2DistanceGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + ctx.Output(framework::GradVarName("X")) + ->Resize(ctx.Input("X")->dims()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(squared_l2_distance, ops::SquaredL2DistanceOp, + ops::SquaredL2DistanceOpMaker, squared_l2_distance_grad, + ops::SquaredL2DistanceGradOp); +REGISTER_OP_CPU_KERNEL( + squared_l2_distance, + ops::SquaredL2DistanceKernel); +REGISTER_OP_CPU_KERNEL( + squared_l2_distance_grad, + ops::SquaredL2DistanceGradKernel); diff --git a/paddle/operators/squared_l2_distance_op.cu b/paddle/operators/squared_l2_distance_op.cu new file mode 100644 index 0000000000..3fe62f1a9c --- /dev/null +++ b/paddle/operators/squared_l2_distance_op.cu @@ -0,0 +1,25 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#define EIGEN_USE_GPU + +#include "paddle/operators/squared_l2_distance_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL( + squared_l2_distance, + ops::SquaredL2DistanceKernel); +REGISTER_OP_GPU_KERNEL( + squared_l2_distance_grad, + ops::SquaredL2DistanceGradKernel); diff --git a/paddle/operators/squared_l2_distance_op.h b/paddle/operators/squared_l2_distance_op.h new file mode 100644 index 0000000000..b350fd0117 --- /dev/null +++ b/paddle/operators/squared_l2_distance_op.h @@ -0,0 +1,84 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +template +using EigenMatrix = framework::EigenMatrix; +template +using EigenVector = framework::EigenVector; + +template +class SquaredL2DistanceKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* input0 = context.Input("X"); + auto* input1 = context.Input("Y"); + auto* output0 = context.Output("sub_result"); + auto* output1 = context.Output("Out"); + + output0->mutable_data(context.GetPlace()); + output1->mutable_data(context.GetPlace()); + + auto X = EigenMatrix::From(*input0); + auto Y = EigenMatrix::From(*input1); + auto subResult = EigenMatrix::From(*output0); + auto Z = EigenMatrix::From(*output1); + + auto place = context.GetEigenDevice(); + // buffer the substraction result + subResult.device(place) = X - Y; + const auto& inDims = X.dimensions(); + const auto& subResMat = subResult.reshape(Eigen::array( + {static_cast(inDims[0]), static_cast(X.size() / inDims[0])})); + Z.device(place) = subResMat.pow(2).sum(Eigen::array({1})); + } +}; + +template +class SquaredL2DistanceGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* input0 = context.Input("sub_result"); + auto* OG = context.Input(framework::GradVarName("Out")); + auto* IG = context.Output(framework::GradVarName("X")); + + IG->mutable_data(context.GetPlace()); + + auto subResult = EigenMatrix::From(*input0); + auto outGrad = EigenMatrix::From(*OG); + auto inGrad = EigenMatrix::From(*IG); + + const auto& subResDims = subResult.dimensions(); + int firstDim = static_cast(subResDims[0]); + int cols = subResult.size() / firstDim; + const auto subResMat = + subResult.reshape(Eigen::array({firstDim, cols})); + // create a matrix view for input gradient tensor + auto inGradMat = inGrad.reshape(Eigen::array({firstDim, cols})); + inGradMat.device(context.GetEigenDevice()) = + 2 * (outGrad.broadcast(Eigen::array({1, cols}))) * subResMat; + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/pybind/CMakeLists.txt b/paddle/pybind/CMakeLists.txt index 37e186a408..df8c2b37cf 100644 --- a/paddle/pybind/CMakeLists.txt +++ b/paddle/pybind/CMakeLists.txt @@ -18,5 +18,6 @@ cc_library(paddle_pybind SHARED fill_zeros_like_op lookup_table_op scale_op - minus_op) + minus_op + squared_l2_distance_op) endif(WITH_PYTHON) diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index 3bc150ccb7..69a5f98a43 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -48,6 +48,7 @@ USE_OP_ITSELF(identity); USE_OP(minus); USE_CPU_ONLY_OP(gather); USE_CPU_ONLY_OP(scatter); +USE_OP(squared_l2_distance); namespace paddle { namespace framework { diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt index 661ebd8964..06ff1f4a0c 100644 --- a/python/paddle/v2/framework/tests/CMakeLists.txt +++ b/python/paddle/v2/framework/tests/CMakeLists.txt @@ -32,3 +32,4 @@ py_test(test_gradient_checker SRCS test_gradient_checker.py) py_test(test_lookup_table SRCS test_lookup_table.py) py_test(test_scale_and_identity_op SRCS test_scale_and_identity_op.py) py_test(mnist SRCS mnist.py) +py_test(test_squared_l2_distance_op SRCS test_squared_l2_distance_op.py) diff --git a/python/paddle/v2/framework/tests/op_test_util.py b/python/paddle/v2/framework/tests/op_test_util.py index 3bc05a0fec..370f27eaf6 100644 --- a/python/paddle/v2/framework/tests/op_test_util.py +++ b/python/paddle/v2/framework/tests/op_test_util.py @@ -6,13 +6,13 @@ from paddle.v2.framework.op import Operator class OpTestMeta(type): """ Operator Test ClassMeta. - - It injects `test_all` method into user's OperatorTest class, to make Python + + It injects `test_all` method into user's OperatorTest class, to make Python unittest module run that method. - + The `test_all` read what value is stored in `self`. It use self's values to create and run a operator, and check whether that op is OK or not. - + See `test_add_two_op` for example usage. """ @@ -66,7 +66,7 @@ class OpTestMeta(type): self.assertTrue( numpy.allclose( actual, expect, atol=1e-05), - "output name: " + out_name + "has diff") + "output name: " + out_name + " has diff") obj.test_all = test_all return obj diff --git a/python/paddle/v2/framework/tests/test_squared_l2_distance_op.py b/python/paddle/v2/framework/tests/test_squared_l2_distance_op.py new file mode 100644 index 0000000000..eeddb5a3bf --- /dev/null +++ b/python/paddle/v2/framework/tests/test_squared_l2_distance_op.py @@ -0,0 +1,25 @@ +import unittest +from op_test_util import OpTestMeta +from gradient_checker import GradientChecker, create_op +import numpy as np + + +class TestSquaredL2DistanceOp(unittest.TestCase): + __metaclass__ = OpTestMeta + + def setUp(self): + self.type = 'squared_l2_distance' + self.inputs = { + 'X': np.random.uniform(0.1, 1., (2, 3)).astype('float32'), + 'Y': np.random.uniform(0.1, 1., (2, 3)).astype('float32') + } + subRes = self.inputs['X'] - self.inputs['Y'] + output = subRes * subRes + self.outputs = { + 'sub_result': subRes, + 'Out': np.expand_dims(output.sum(1), 1) + } + + +if __name__ == '__main__': + unittest.main() -- GitLab