diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt
index e56895c63a426b782f7b46091bc86c367d49899d..21166354937c378dc3f295f9011d034eb24cfc7c 100644
--- a/paddle/operators/CMakeLists.txt
+++ b/paddle/operators/CMakeLists.txt
@@ -61,6 +61,13 @@ function(op_library TARGET)
     # It's enough to just adding one operator to pybind
     file(APPEND ${pybind_file} "USE_OP(sigmoid);\n")
   endif()
+
+  # reduce_op contains several operators
+  if ("${TARGET}" STREQUAL "reduce_op")
+    set(pybind_flag 1)
+    # It's enough to just add one operator to pybind
+    file(APPEND ${pybind_file} "USE_OP(reduce_sum);\n")
+  endif()
 
   # pybind USE_NO_KERNEL_OP
   file(READ ${TARGET}.cc TARGET_CONTENT)
diff --git a/paddle/operators/reduce_op.cc b/paddle/operators/reduce_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..3ef443d1c7f475cbd578078db02fb5e0d500d060
--- /dev/null
+++ b/paddle/operators/reduce_op.cc
@@ -0,0 +1,203 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#include "paddle/operators/reduce_op.h"
+
+namespace paddle {
+namespace operators {
+
+using framework::Tensor;
+
+class ReduceOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContextBase *ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of ReduceOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   "Output(Out) of ReduceOp should not be null.");
+    auto x_dims = ctx->GetInputDim("X");
+    auto x_rank = x_dims.size();
+    PADDLE_ENFORCE_LE(x_rank, 6, "Tensors with rank at most 6 are supported.");
+    int dim = ctx->Attrs().Get<int>("dim");
+    if (dim < 0) dim = x_rank + dim;
+    PADDLE_ENFORCE_LT(
+        dim, x_rank,
+        "The dim should be in the range [-rank(input), rank(input)).");
+    bool keep_dim = ctx->Attrs().Get<bool>("keep_dim");
+    auto dims_vector = vectorize(x_dims);
+    if (keep_dim || x_rank == 1) {
+      dims_vector[dim] = 1;
+    } else {
+      dims_vector.erase(dims_vector.begin() + dim);
+    }
+    auto out_dims = framework::make_ddim(dims_vector);
+    ctx->SetOutputDim("Out", out_dims);
+    if (dim != 0) {
+      // Only pass LoD when not reducing on the first dim.
+      ctx->ShareLoD("X", /*->*/ "Out");
+    }
+  }
+};
+
+class ReduceGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContextBase *ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
+                   "Input(Out@GRAD) should not be null.");
+    auto x_dims = ctx->GetInputDim("X");
+    auto x_rank = x_dims.size();
+    PADDLE_ENFORCE_LE(x_rank, 6, "Tensors with rank at most 6 are supported.");
+    int dim = ctx->Attrs().Get<int>("dim");
+    if (dim < 0) dim = x_rank + dim;
+    PADDLE_ENFORCE_LT(
+        dim, x_rank,
+        "The dim should be in the range [-rank(input), rank(input)).");
+    auto x_grad_name = framework::GradVarName("X");
+    if (ctx->HasOutput(x_grad_name)) {
+      ctx->SetOutputDim(x_grad_name, x_dims);
+    }
+  }
+};
+
+class ReduceOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  ReduceOpMaker(framework::OpProto *proto,
+                framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput(
+        "X",
+        "(Tensor) The input tensor. Tensors with rank at most 6 are supported.");
+    AddOutput("Out", "(Tensor) The result tensor.");
+    AddAttr<int>(
+        "dim",
+        "(int, default 0) The dimension to reduce. "
+        "Must be in the range [-rank(input), rank(input)). "
+        "If `dim < 0`, the dim to reduce is `rank + dim`. "
+        "Note that reducing on the first dim will cause the LoD info to be lost.")
+        .SetDefault(0);
+    AddAttr<bool>("keep_dim",
+                  "(bool, default false) "
+                  "If true, retain the reduced dimension with length 1.")
+        .SetDefault(false);
+    comment_ = R"DOC(
+{ReduceOP} operator computes the {reduce} of the input tensor along the given dimension.
+The result tensor has 1 fewer dimension than the input unless `keep_dim` is true.
+)DOC";
+    AddComment(comment_);
+  }
+
+ protected:
+  std::string comment_;
+
+  void Replace(std::string &src, std::string from, std::string to) {
+    std::size_t len_from = std::strlen(from.c_str());
+    std::size_t len_to = std::strlen(to.c_str());
+    for (std::size_t pos = src.find(from); pos != std::string::npos;
+         pos = src.find(from, pos + len_to)) {
+      src.replace(pos, len_from, to);
+    }
+  }
+
+  void SetComment(std::string name, std::string op) {
+    Replace(comment_, "{ReduceOP}", name);
+    Replace(comment_, "{reduce}", op);
+  }
+};
+
+class ReduceSumOpMaker : public ReduceOpMaker {
+ public:
+  ReduceSumOpMaker(framework::OpProto *proto,
+                   framework::OpAttrChecker *op_checker)
+      : ReduceOpMaker(proto, op_checker) {
+    SetComment("ReduceSum", "sum");
+    AddComment(comment_);
+  }
+};
+
+class ReduceMeanOpMaker : public ReduceOpMaker {
+ public:
+  ReduceMeanOpMaker(framework::OpProto *proto,
+                    framework::OpAttrChecker *op_checker)
+      : ReduceOpMaker(proto, op_checker) {
+    SetComment("ReduceMean", "mean");
+    AddComment(comment_);
+  }
+};
+
+class ReduceMaxOpMaker : public ReduceOpMaker {
+ public:
+  ReduceMaxOpMaker(framework::OpProto *proto,
+                   framework::OpAttrChecker *op_checker)
+      : ReduceOpMaker(proto, op_checker) {
+    SetComment("ReduceMax", "max");
+    AddComment(comment_);
+  }
+};
+
+class ReduceMinOpMaker : public ReduceOpMaker {
+ public:
+  ReduceMinOpMaker(framework::OpProto *proto,
+                   framework::OpAttrChecker *op_checker)
+      : ReduceOpMaker(proto, op_checker) {
+    SetComment("ReduceMin", "min");
+    AddComment(comment_);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+REGISTER_OP(reduce_sum, ops::ReduceOp, ops::ReduceSumOpMaker, reduce_sum_grad,
+            ops::ReduceGradOp);
+REGISTER_OP_CPU_KERNEL(
+    reduce_sum,
+    ops::ReduceKernel<paddle::platform::CPUPlace, float, ops::SumFunctor>);
+REGISTER_OP_CPU_KERNEL(
+    reduce_sum_grad,
+    ops::ReduceGradKernel<paddle::platform::CPUPlace, float,
+                          ops::SumGradFunctor>);
+
+REGISTER_OP(reduce_mean, ops::ReduceOp, ops::ReduceMeanOpMaker,
+            reduce_mean_grad, ops::ReduceGradOp);
+REGISTER_OP_CPU_KERNEL(
+    reduce_mean,
+    ops::ReduceKernel<paddle::platform::CPUPlace, float, ops::MeanFunctor>);
+REGISTER_OP_CPU_KERNEL(
+    reduce_mean_grad,
+    ops::ReduceGradKernel<paddle::platform::CPUPlace, float,
+                          ops::MeanGradFunctor>);
+
+REGISTER_OP(reduce_max, ops::ReduceOp, ops::ReduceMaxOpMaker, reduce_max_grad,
+            ops::ReduceGradOp);
+REGISTER_OP_CPU_KERNEL(
+    reduce_max,
+    ops::ReduceKernel<paddle::platform::CPUPlace, float, ops::MaxFunctor>);
+REGISTER_OP_CPU_KERNEL(
+    reduce_max_grad,
+    ops::ReduceGradKernel<paddle::platform::CPUPlace, float,
+                          ops::MaxOrMinGradFunctor>);
+
+REGISTER_OP(reduce_min, ops::ReduceOp, ops::ReduceMinOpMaker, reduce_min_grad,
+            ops::ReduceGradOp);
+REGISTER_OP_CPU_KERNEL(
+    reduce_min,
+    ops::ReduceKernel<paddle::platform::CPUPlace, float, ops::MinFunctor>);
+REGISTER_OP_CPU_KERNEL(
+    reduce_min_grad,
+    ops::ReduceGradKernel<paddle::platform::CPUPlace, float,
+                          ops::MaxOrMinGradFunctor>);
diff --git a/paddle/operators/reduce_op.cu b/paddle/operators/reduce_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..595127b858ea8eb41281f92e92c6467e4d90ff1a
--- /dev/null
+++ b/paddle/operators/reduce_op.cu
@@ -0,0 +1,46 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#define EIGEN_USE_GPU
+#include "paddle/operators/reduce_op.h"
+
+namespace ops = paddle::operators;
+
+REGISTER_OP_GPU_KERNEL(
+    reduce_sum,
+    ops::ReduceKernel<paddle::platform::GPUPlace, float, ops::SumFunctor>);
+REGISTER_OP_GPU_KERNEL(
+    reduce_sum_grad,
+    ops::ReduceGradKernel<paddle::platform::GPUPlace, float,
+                          ops::SumGradFunctor>);
+
+REGISTER_OP_GPU_KERNEL(
+    reduce_mean,
+    ops::ReduceKernel<paddle::platform::GPUPlace, float, ops::MeanFunctor>);
+REGISTER_OP_GPU_KERNEL(
+    reduce_mean_grad,
+    ops::ReduceGradKernel<paddle::platform::GPUPlace, float,
+                          ops::MeanGradFunctor>);
+
+REGISTER_OP_GPU_KERNEL(
+    reduce_max,
+    ops::ReduceKernel<paddle::platform::GPUPlace, float, ops::MaxFunctor>);
+REGISTER_OP_GPU_KERNEL(
+    reduce_max_grad,
+    ops::ReduceGradKernel<paddle::platform::GPUPlace, float,
+                          ops::MaxOrMinGradFunctor>);
+
+REGISTER_OP_GPU_KERNEL(
+    reduce_min,
+    ops::ReduceKernel<paddle::platform::GPUPlace, float, ops::MinFunctor>);
+REGISTER_OP_GPU_KERNEL(
+    reduce_min_grad,
+    ops::ReduceGradKernel<paddle::platform::GPUPlace, float,
+                          ops::MaxOrMinGradFunctor>);
diff --git a/paddle/operators/reduce_op.h b/paddle/operators/reduce_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..2fbf94e34f3961a9b3140fb682a7c479f3b71f4d
--- /dev/null
+++ b/paddle/operators/reduce_op.h
@@ -0,0 +1,200 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#pragma once
+
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+using DDim = framework::DDim;
+template <typename T, size_t D, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
+
+struct SumFunctor {
+  template <typename Place, typename X, typename Y, typename Dim>
+  void operator()(const Place& place, X& x, Y& y, const Dim& dim) {
+    y.device(place) = x.sum(dim);
+  }
+};
+
+struct SumGradFunctor {
+  template <typename Place, typename X, typename Y, typename DX, typename DY,
+            typename Dim>
+  void operator()(const Place& place, X& x, Y& y, DX& dx, DY& dy,
+                  const Dim& dim, int size) {
+    dx.device(place) = dy.broadcast(dim);
+  }
+};
+
+struct MeanFunctor {
+  template <typename Place, typename X, typename Y, typename Dim>
+  void operator()(const Place& place, X& x, Y& y, const Dim& dim) {
+    y.device(place) = x.mean(dim);
+  }
+};
+
+struct MeanGradFunctor {
+  template <typename Place, typename X, typename Y, typename DX, typename DY,
+            typename Dim>
+  void operator()(const Place& place, X& x, Y& y, DX& dx, DY& dy,
+                  const Dim& dim, int size) {
+    dx.device(place) = dy.broadcast(dim) / dx.constant(size);
+  }
+};
+
+struct MaxFunctor {
+  template <typename Place, typename X, typename Y, typename Dim>
+  void operator()(const Place& place, X& x, Y& y, const Dim& dim) {
+    y.device(place) = x.maximum(dim);
+  }
+};
+
+struct MinFunctor {
+  template <typename Place, typename X, typename Y, typename Dim>
+  void operator()(const Place& place, X& x, Y& y, const Dim& dim) {
+    y.device(place) = x.minimum(dim);
+  }
+};
+
+struct MaxOrMinGradFunctor {
+  template <typename Place, typename X, typename Y, typename DX, typename DY,
+            typename Dim>
+  void operator()(const Place& place, X& x, Y& y, DX& dx, DY& dy,
+                  const Dim& dim, int size) {
+    auto equals = x == y.broadcast(dim);
+    auto ones = dx.constant(1);
+    auto zeros = dx.constant(0);
+    // If there are multiple minimum or maximum elements, the subgradient of
+    // each is the set [0, 1], and we pass gradient to all of them here.
+    dx.device(place) = dy.broadcast(dim) * equals.select(ones, zeros);
+  }
+};
+
+template <typename Place, typename T, typename Functor>
+class ReduceKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    int rank = context.Input<Tensor>("X")->dims().size();
+    switch (rank) {
+      case 1:
+        ReduceCompute<1>(context);
+        break;
+      case 2:
+        ReduceCompute<2>(context);
+        break;
+      case 3:
+        ReduceCompute<3>(context);
+        break;
+      case 4:
+        ReduceCompute<4>(context);
+        break;
+      case 5:
+        ReduceCompute<5>(context);
+        break;
+      case 6:
+        ReduceCompute<6>(context);
+        break;
+    }
+  }
+
+ private:
+  template <size_t D>
+  void ReduceCompute(const framework::ExecutionContext& context) const {
+    auto* input = context.Input<Tensor>("X");
+    auto* output = context.Output<Tensor>("Out");
+    output->mutable_data<T>(context.GetPlace());
+
+    auto x = EigenTensor<T, D>::From(*input);
+    auto x_rank = static_cast<int>(x.dimensions().size());
+    int dim = static_cast<int>(context.Attr<int>("dim"));
+    if (dim < 0) dim = x_rank + dim;
+    auto reduce_dim = Eigen::array<int, 1>({{dim}});
+    // construct the squeezed output tensor
+    bool keep_dim = context.Attr<bool>("keep_dim");
+    DDim dims = output->dims();
+    auto dims_vector = vectorize(dims);
+    if (keep_dim && x_rank > 1) {
+      dims_vector.erase(dims_vector.begin() + dim);
+      dims = framework::make_ddim(dims_vector);
+    }
+    auto out = EigenTensor < T, D == 1 ? 1 : (D - 1) > ::From(*output, dims);
+    auto& place = context.GetEigenDevice<Place>();
+    Functor functor;
+    functor(place, x, out, reduce_dim);
+  }
+};
+
+template <typename Place, typename T, typename Functor>
+class ReduceGradKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    int rank = context.Input<Tensor>("X")->dims().size();
+    switch (rank) {
+      case 1:
+        ReduceGradCompute<1>(context);
+        break;
+      case 2:
+        ReduceGradCompute<2>(context);
+        break;
+      case 3:
+        ReduceGradCompute<3>(context);
+        break;
+      case 4:
+        ReduceGradCompute<4>(context);
+        break;
+      case 5:
+        ReduceGradCompute<5>(context);
+        break;
+      case 6:
+        ReduceGradCompute<6>(context);
+        break;
+    }
+  }
+
+ private:
+  template <size_t D>
+  void ReduceGradCompute(const framework::ExecutionContext& context) const {
+    auto* input0 = context.Input<Tensor>("X");
+    auto* input1 = context.Input<Tensor>("Out");
+    auto* input2 = context.Input<Tensor>(framework::GradVarName("Out"));
+    auto* output = context.Output<Tensor>(framework::GradVarName("X"));
+
+    output->mutable_data<T>(context.GetPlace());
+    auto x = EigenTensor<T, D>::From(*input0);
+    auto x_grad = EigenTensor<T, D>::From(*output);
+    auto x_rank = static_cast<int>(x.dimensions().size());
+    int dim = static_cast<int>(context.Attr<int>("dim"));
+    if (dim < 0) dim = x_rank + dim;
+    DDim dims = input0->dims();
+    dims[dim] = 1;
+    auto x_reduce = EigenTensor<T, D>::From(*input1, dims);
+    auto x_reduce_grad = EigenTensor<T, D>::From(*input2, dims);
+
+    Eigen::array<int, D> broadcast_dim;
+    for (size_t i = 0; i < D; ++i) broadcast_dim[i] = 1;
+    broadcast_dim[dim] = input0->dims()[dim];
+    auto& place = context.GetEigenDevice<Place>();
+    Functor functor;
+    functor(place, x, x_reduce, x_grad, x_reduce_grad, broadcast_dim,
+            broadcast_dim[dim]);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/v2/framework/tests/test_reduce_op.py b/python/paddle/v2/framework/tests/test_reduce_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..70359d60cbe656150877673c63e81eae92d8ab9a
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_reduce_op.py
@@ -0,0 +1,89 @@
+import unittest
+import numpy as np
+from op_test import OpTest
+
+
+class TestSumOp(OpTest):
+    def setUp(self):
+        self.op_type = "reduce_sum"
+        self.inputs = {'X': np.random.random((5, 6, 10)).astype("float32")}
+        self.outputs = {'Out': self.inputs['X'].sum(axis=0)}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out')
+
+
+class TestMeanOp(OpTest):
+    def setUp(self):
+        self.op_type = "reduce_mean"
+        self.inputs = {'X': np.random.random((5, 6, 2, 10)).astype("float32")}
+        self.attrs = {'dim': 1}
+        self.outputs = {'Out': self.inputs['X'].mean(axis=self.attrs['dim'])}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out')
+
+
+class TestMaxOp(OpTest):
+    """The gradient check is skipped for reduce_max: max only has a
+    subgradient at ties, which makes the numerical check unreliable in CI."""
+
+    def setUp(self):
+        self.op_type = "reduce_max"
+        self.inputs = {'X': np.random.random((5, 6, 10)).astype("float32")}
+        self.attrs = {'dim': -1}
+        self.outputs = {'Out': self.inputs['X'].max(axis=self.attrs['dim'])}
+
+    def test_check_output(self):
+        self.check_output()
+
+
+class TestMinOp(OpTest):
+    """The gradient check is skipped for reduce_min: min only has a
+    subgradient at ties, which makes the numerical check unreliable in CI."""
+
+    def setUp(self):
+        self.op_type = "reduce_min"
+        self.inputs = {'X': np.random.random((5, 6, 10)).astype("float32")}
+        self.attrs = {'dim': 2}
+        self.outputs = {'Out': self.inputs['X'].min(axis=self.attrs['dim'])}
+
+    def test_check_output(self):
+        self.check_output()
+
+
+class TestKeepDimReduce(OpTest):
+    def setUp(self):
+        self.op_type = "reduce_sum"
+        self.inputs = {'X': np.random.random((5, 6, 10)).astype("float32")}
+        self.attrs = {'dim': -2, 'keep_dim': True}
+        self.outputs = {
+            'Out': self.inputs['X'].sum(axis=self.attrs['dim'], keepdims=True)
+        }
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out')
+
+
+class Test1DReduce(OpTest):
+    def setUp(self):
+        self.op_type = "reduce_sum"
+        self.inputs = {'X': np.random.random(20).astype("float32")}
+        self.outputs = {'Out': self.inputs['X'].sum(axis=0)}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out')
+
+
+if __name__ == '__main__':
+    unittest.main()
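
For reference (not part of the patch), here is a minimal NumPy sketch of the semantics the operators above implement: the forward reduction over `dim` with optional `keep_dim`, and the broadcast-style gradients computed by `SumGradFunctor` and `MaxOrMinGradFunctor`. The helper names `reduce_forward`, `reduce_sum_grad`, and `reduce_max_grad` are illustrative only and do not exist in the PaddlePaddle codebase.

```python
import numpy as np


def reduce_forward(x, dim=0, keep_dim=False, op="sum"):
    # Negative dim counts from the end, mirroring `if (dim < 0) dim = x_rank + dim;`.
    if dim < 0:
        dim += x.ndim
    out = getattr(x, op)(axis=dim)  # op is one of "sum", "mean", "max", "min"
    return np.expand_dims(out, dim) if keep_dim else out


def reduce_sum_grad(x, dout, dim=0):
    # Gradient of reduce_sum: broadcast dOut back over the reduced dimension,
    # analogous to SumGradFunctor's `dx = dy.broadcast(dim)`.
    if dim < 0:
        dim += x.ndim
    return np.broadcast_to(np.expand_dims(dout, dim), x.shape).copy()


def reduce_max_grad(x, out, dout, dim=0):
    # Gradient of reduce_max: route dOut to every element equal to the maximum,
    # analogous to `dy.broadcast(dim) * equals.select(ones, zeros)`.
    if dim < 0:
        dim += x.ndim
    out_b = np.expand_dims(out, dim)
    dout_b = np.expand_dims(dout, dim)
    return np.where(x == out_b, dout_b, 0.0)


x = np.random.rand(5, 6, 10).astype("float32")
out = reduce_forward(x, dim=-1, op="max")                 # shape (5, 6)
dx = reduce_max_grad(x, out, np.ones_like(out), dim=-1)   # shape (5, 6, 10)
assert dx.shape == x.shape
```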