From 3994e91a678b8547af77b6b7f4629f122b0d9f07 Mon Sep 17 00:00:00 2001
From: guosheng
Date: Fri, 8 Sep 2017 18:39:01 +0800
Subject: [PATCH] Add reduce_op

---
 paddle/operators/reduce_op.cc                      | 207 +++++++++++++++
 paddle/operators/reduce_op.cu                      |  46 ++++
 paddle/operators/reduce_op.h                       | 251 ++++++++++++++++++
 .../v2/framework/tests/test_reduce_op.py           |  92 +++++++
 4 files changed, 596 insertions(+)
 create mode 100644 paddle/operators/reduce_op.cc
 create mode 100644 paddle/operators/reduce_op.cu
 create mode 100644 paddle/operators/reduce_op.h
 create mode 100644 python/paddle/v2/framework/tests/test_reduce_op.py

diff --git a/paddle/operators/reduce_op.cc b/paddle/operators/reduce_op.cc
new file mode 100644
index 0000000000..ea4bfc50b2
--- /dev/null
+++ b/paddle/operators/reduce_op.cc
@@ -0,0 +1,207 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#include "paddle/operators/reduce_op.h"
+
+namespace paddle {
+namespace operators {
+
+using framework::Tensor;
+using framework::DDim;
+
+class ReduceOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(const framework::InferShapeContext &ctx) const override {
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
+    auto x_dims = ctx.Input<Tensor>("X")->dims();
+    auto x_rank = x_dims.size();
+    PADDLE_ENFORCE_LE(x_rank, 6, "Tensors with rank at most 6 are supported");
+    int dim = static_cast<int>(ctx.Attr<int>("dim"));
+    if (dim < 0) dim = x_rank + dim;
+    PADDLE_ENFORCE_LT(
+        dim, x_rank,
+        "The dim should be in the range [-rank(input), rank(input))");
+    bool keep_dim = ctx.Attr<bool>("keep_dim");
+    auto dims_vector = vectorize(x_dims);
+    if (keep_dim || x_rank == 1) {
+      dims_vector[dim] = 1;
+    } else {
+      dims_vector.erase(dims_vector.begin() + dim);
+    }
+    auto out_dims = framework::make_ddim(dims_vector);
+    ctx.Output<Tensor>("Out")->Resize(out_dims);
+  }
+};
+
+class ReduceGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(const framework::InferShapeContext &ctx) const override {
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
+                            "Input(Out@GRAD) should not be null");
+    auto x_dims = ctx.Input<Tensor>("X")->dims();
+    auto x_rank = x_dims.size();
+    PADDLE_ENFORCE_LE(x_rank, 6, "Tensors with rank at most 6 are supported");
+    int dim = static_cast<int>(ctx.Attr<int>("dim"));
+    if (dim < 0) dim = x_rank + dim;
+    PADDLE_ENFORCE_LT(
+        dim, x_rank,
+        "The dim should be in the range [-rank(input), rank(input))");
+    auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
+    if (x_grad) x_grad->Resize(x_dims);
+  }
+};
+
+class ReduceSumOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  ReduceSumOpMaker(framework::OpProto *proto,
+                   framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput(
+        "X",
+        "(Tensor) The input tensor. Tensors with rank at most 6 are supported");
+    AddOutput("Out", "(Tensor) The result tensor.");
+    AddComment(R"DOC(
+ReduceSum operator computes the sum of the input tensor along the given dimension.
+The result tensor has 1 fewer dimension than the input unless `keep_dim` is true.
+)DOC");
+    AddAttr<int>("dim",
+                 "(int, default 0) The dimension to reduce. "
+                 "Must be in the range [-rank(input), rank(input))")
+        .SetDefault(0);
+    AddAttr<bool>("keep_dim",
+                  "(bool, default false) "
+                  "If true, retain the reduced dimension with length 1.")
+        .SetDefault(false);
+  }
+};
+
+class ReduceMeanOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  ReduceMeanOpMaker(framework::OpProto *proto,
+                    framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput(
+        "X",
+        "(Tensor) The input tensor. Tensors with rank at most 6 are supported");
+    AddOutput("Out", "(Tensor) The result tensor.");
+    AddComment(R"DOC(
+ReduceMean operator computes the mean of the input tensor along the given dimension.
+The result tensor has 1 fewer dimension than the input unless `keep_dim` is true.
+)DOC");
+    AddAttr<int>("dim",
+                 "(int, default 0) The dimension to reduce. "
+                 "Must be in the range [-rank(input), rank(input))")
+        .SetDefault(0);
+    AddAttr<bool>("keep_dim",
+                  "(bool, default false) "
+                  "If true, retain the reduced dimension with length 1.")
+        .SetDefault(false);
+  }
+};
+
+class ReduceMaxOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  ReduceMaxOpMaker(framework::OpProto *proto,
+                   framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput(
+        "X",
+        "(Tensor) The input tensor. Tensors with rank at most 6 are supported");
+    AddOutput("Out", "(Tensor) The result tensor.");
+    AddComment(R"DOC(
+ReduceMax operator computes the maximum of the input tensor along the given dimension.
+The result tensor has 1 fewer dimension than the input unless `keep_dim` is true.
+)DOC");
+    AddAttr<int>("dim",
+                 "(int, default 0) The dimension to reduce. "
+                 "Must be in the range [-rank(input), rank(input))")
+        .SetDefault(0);
+    AddAttr<bool>("keep_dim",
+                  "(bool, default false) "
+                  "If true, retain the reduced dimension with length 1.")
+        .SetDefault(false);
+  }
+};
+
+class ReduceMinOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  ReduceMinOpMaker(framework::OpProto *proto,
+                   framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput(
+        "X",
+        "(Tensor) The input tensor. Tensors with rank at most 6 are supported");
+    AddOutput("Out", "(Tensor) The result tensor.");
+    AddComment(R"DOC(
+ReduceMin operator computes the minimum of the input tensor along the given dimension.
+The result tensor has 1 fewer dimension than the input unless `keep_dim` is true.
+)DOC");
+    AddAttr<int>("dim",
+                 "(int, default 0) The dimension to reduce. "
" + "Must be in the range [-rank(input), rank(input)]") + .SetDefault(0); + AddAttr("keep_dim", + "(bool, default fasle) " + "If true, retain the reduced dimension with length 1.") + .SetDefault(false); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP(reduce_sum, ops::ReduceOp, ops::ReduceSumOpMaker, reduce_sum_grad, + ops::ReduceGradOp); +REGISTER_OP_CPU_KERNEL( + reduce_sum, + ops::ReduceKernel); +REGISTER_OP_CPU_KERNEL(reduce_sum_grad, + ops::ReduceGradKernel); + +REGISTER_OP(reduce_mean, ops::ReduceOp, ops::ReduceMeanOpMaker, + reduce_mean_grad, ops::ReduceGradOp); +REGISTER_OP_CPU_KERNEL( + reduce_mean, + ops::ReduceKernel); +REGISTER_OP_CPU_KERNEL(reduce_mean_grad, + ops::ReduceGradKernel); + +REGISTER_OP(reduce_max, ops::ReduceOp, ops::ReduceMaxOpMaker, reduce_max_grad, + ops::ReduceGradOp); +REGISTER_OP_CPU_KERNEL( + reduce_max, + ops::ReduceKernel); +REGISTER_OP_CPU_KERNEL(reduce_max_grad, + ops::ReduceGradKernel); + +REGISTER_OP(reduce_min, ops::ReduceOp, ops::ReduceMaxOpMaker, reduce_min_grad, + ops::ReduceGradOp); +REGISTER_OP_CPU_KERNEL( + reduce_min, + ops::ReduceKernel); +REGISTER_OP_CPU_KERNEL(reduce_min_grad, + ops::ReduceGradKernel); diff --git a/paddle/operators/reduce_op.cu b/paddle/operators/reduce_op.cu new file mode 100644 index 0000000000..9effc17ed3 --- /dev/null +++ b/paddle/operators/reduce_op.cu @@ -0,0 +1,46 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#define EIGEN_USE_GPU +#include "paddle/operators/reduce_op.h" + +namespace ops = paddle::operators; + +REGISTER_OP_GPU_KERNEL( + reduce_sum, + ops::ReduceKernel); +REGISTER_OP_GPU_KERNEL(reduce_sum_grad, + ops::ReduceGradEigenKernel); + +REGISTER_OP_GPU_KERNEL( + reduce_mean, + ops::ReduceKernel); +REGISTER_OP_GPU_KERNEL(reduce_mean_grad, + ops::ReduceGradKernel); + +REGISTER_OP_GPU_KERNEL( + reduce_max, + ops::ReduceKernel); +REGISTER_OP_GPU_KERNEL(reduce_max_grad, + ops::ReduceGradKernel); + +REGISTER_OP_GPU_KERNEL( + reduce_min, + ops::ReduceKernel); +REGISTER_OP_GPU_KERNEL(reduce_min_grad, + ops::ReduceGradKernel); \ No newline at end of file diff --git a/paddle/operators/reduce_op.h b/paddle/operators/reduce_op.h new file mode 100644 index 0000000000..9fd7d335ac --- /dev/null +++ b/paddle/operators/reduce_op.h @@ -0,0 +1,251 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+
+#pragma once
+
+#include "paddle/operators/math/math_function.h"
+
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+using DDim = framework::DDim;
+template <typename T, size_t D, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
+
+struct SumFunctor {
+  template <typename Place, typename In, typename Out, typename Dim>
+  void operator()(const Place& place, In& in, Out& out, const Dim& dim) {
+    out.device(place) = in.sum(dim);
+  }
+};
+
+struct SumGradFunctor {
+  template <typename Place, typename In_Const, typename In, typename Out,
+            typename Dim>
+  void operator()(const Place& place, In_Const& in, In& in_grad, Out& out,
+                  Out& out_grad, const Dim& dim, int size) {
+    in_grad.device(place) = out_grad.broadcast(dim);
+  }
+};
+
+struct MeanFunctor {
+  template <typename Place, typename In, typename Out, typename Dim>
+  void operator()(const Place& place, In& in, Out& out, const Dim& dim) {
+    out.device(place) = in.mean(dim);
+  }
+};
+
+struct MeanGradFunctor {
+  template <typename Place, typename In_Const, typename In, typename Out,
+            typename Dim>
+  void operator()(const Place& place, In_Const& in, In& in_grad, Out& out,
+                  Out& out_grad, const Dim& dim, int size) {
+    in_grad.device(place) = out_grad.broadcast(dim) / in_grad.constant(size);
+  }
+};
+
+struct MaxFunctor {
+  template <typename Place, typename In, typename Out, typename Dim>
+  void operator()(const Place& place, In& in, Out& out, const Dim& dim) {
+    out.device(place) = in.maximum(dim);
+  }
+};
+
+struct MinFunctor {
+  template <typename Place, typename In, typename Out, typename Dim>
+  void operator()(const Place& place, In& in, Out& out, const Dim& dim) {
+    out.device(place) = in.minimum(dim);
+  }
+};
+
+struct MaxOrMinGradFunctor {
+  template <typename Place, typename In_Const, typename In, typename Out,
+            typename Dim>
+  void operator()(const Place& place, In_Const& in, In& in_grad, Out& out,
+                  Out& out_grad, const Dim& dim, int size) {
+    auto equals = in == out.broadcast(dim);
+    auto ones = in_grad.constant(1);
+    auto zeros = in_grad.constant(0);
+    in_grad.device(place) =
+        out_grad.broadcast(dim) * equals.select(ones, zeros);
+  }
+};
+
+template <typename Place, typename T, typename Functor>
+class ReduceKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    int rank = context.Input<Tensor>("X")->dims().size();
+    switch (rank) {
+      case 1:
+        ReduceCompute<1>(context);
+        break;
+      case 2:
+        ReduceCompute<2>(context);
+        break;
+      case 3:
+        ReduceCompute<3>(context);
+        break;
+      case 4:
+        ReduceCompute<4>(context);
+        break;
+      case 5:
+        ReduceCompute<5>(context);
+        break;
+      case 6:
+        ReduceCompute<6>(context);
+        break;
+    }
+  }
+
+ private:
+  template <size_t D>
+  void ReduceCompute(const framework::ExecutionContext& context) const {
+    auto* input = context.Input<Tensor>("X");
+    auto* output = context.Output<Tensor>("Out");
+    output->mutable_data<T>(context.GetPlace());
+
+    auto x = EigenTensor<T, D>::From(*input);
+    auto x_rank = static_cast<int>(x.dimensions().size());
+    int dim = static_cast<int>(context.Attr<int>("dim"));
+    if (dim < 0) dim = x_rank + dim;
+    auto reduce_dim = Eigen::array<int, 1>({{dim}});
+    // construct the squeezed output tensor
+    bool keep_dim = context.Attr<bool>("keep_dim");
+    DDim dims = output->dims();
+    auto dims_vector = vectorize(dims);
+    if (keep_dim && x_rank > 1) {
+      dims_vector.erase(dims_vector.begin() + dim);
+      dims = framework::make_ddim(dims_vector);
+    }
+    auto out = EigenTensor < T, D == 1 ? 1 : (D - 1) > ::From(*output, dims);
+    auto& place = context.GetEigenDevice<Place>();
+    Functor functor;
+    functor(place, x, out, reduce_dim);
+  }
+};
+
+template <typename Place, typename T, typename Functor>
+class ReduceGradKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    int rank = context.Input<Tensor>("X")->dims().size();
+    switch (rank) {
+      case 1:
+        ReduceCompute<1>(context);
+        break;
+      case 2:
+        ReduceCompute<2>(context);
+        break;
+      case 3:
+        ReduceCompute<3>(context);
+        break;
+      case 4:
+        ReduceCompute<4>(context);
+        break;
+      case 5:
+        ReduceCompute<5>(context);
+        break;
+      case 6:
+        ReduceCompute<6>(context);
+        break;
+    }
+  }
+
+ private:
+  template <size_t D>
+  void ReduceCompute(const framework::ExecutionContext& context) const {
+    auto* input0 = context.Input<Tensor>("X");
+    auto* input1 = context.Input<Tensor>("Out");
+    auto* input2 = context.Input<Tensor>(framework::GradVarName("Out"));
+    auto* output = context.Output<Tensor>(framework::GradVarName("X"));
+
+    if (output != nullptr) {
+      output->mutable_data<T>(context.GetPlace());
+      auto x = EigenTensor<T, D>::From(*input0);
+      auto x_grad = EigenTensor<T, D>::From(*output);
+      auto x_rank = static_cast<int>(x.dimensions().size());
+      int dim = static_cast<int>(context.Attr<int>("dim"));
+      if (dim < 0) dim = x_rank + dim;
+      DDim dims = input0->dims();
+      dims[dim] = 1;
+      auto x_reduce = EigenTensor<T, D>::From(*input1, dims);
+      auto x_reduce_grad = EigenTensor<T, D>::From(*input2, dims);
+
+      Eigen::array<int, D> broadcast_dim;
+      for (size_t i = 0; i < D; ++i) broadcast_dim[i] = 1;
+      broadcast_dim[dim] = input0->dims()[dim];
+      auto& place = context.GetEigenDevice<Place>();
+      Functor functor;
+      functor(place, x, x_grad, x_reduce, x_reduce_grad, broadcast_dim,
+              broadcast_dim[dim]);
+    }
+  }
+};
+
+// Gradient kernel that avoids EigenTensor, for reductions Eigen does not
+// support.
+template <typename Place, typename T, typename Functor>
+class ReduceGradEigenFreeKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* x = context.Input<Tensor>("X");
+    auto* out = context.Input<Tensor>("Out");
+    auto* x_grad = context.Output<Tensor>(framework::GradVarName("X"));
+    auto* out_grad = context.Input<Tensor>(framework::GradVarName("Out"));
+    if (x_grad != nullptr) {
+      DDim dims = x->dims();
+      int rank = dims.size();
+      int dim = static_cast<int>(context.Attr<int>("dim"));
+      if (dim < 0) dim = rank + dim;
+
+      auto* x_data = x->data<T>();
+      auto* x_grad_data = x_grad->mutable_data<T>(context.GetPlace());
+      auto* out_data = out->data<T>();
+      auto* out_grad_data = out_grad->data<T>();
+
+      int outer_count = 1;
+      int inner_count = 1;
+      int mid_count = dims[dim];
+      for (int i = 0; i < dim; ++i) {
+        outer_count *= dims[i];
+      }
+      for (int i = dim + 1; i < rank; ++i) {
+        inner_count *= dims[i];
+      }
+
+      int x_offset = 0;    // offset on raw data
+      int out_offset = 0;  // offset on reduced data
+      Functor functor;
+      for (int i = 0; i < outer_count; ++i) {
+        for (int j = 0; j < inner_count; ++j) {
+          out_offset = inner_count * i + j;
+          for (int k = 0; k < mid_count; ++k) {
+            x_offset = (inner_count * mid_count) * i + inner_count * k + j;
+            functor(x_data + x_offset, x_grad_data + x_offset,
+                    out_data + out_offset, out_grad_data + out_offset,
+                    mid_count);
+          }
+        }
+      }
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/v2/framework/tests/test_reduce_op.py b/python/paddle/v2/framework/tests/test_reduce_op.py
new file mode 100644
index 0000000000..49ef8eabd2
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_reduce_op.py
@@ -0,0 +1,92 @@
+import unittest
+import numpy as np
+from gradient_checker import GradientChecker, create_op
+from op_test_util import OpTestMeta
+from paddle.v2.framework.op import Operator
+
+
+class TestSumOp(unittest.TestCase):
+    __metaclass__ = OpTestMeta
+
+    def setUp(self):
+        self.type = "reduce_sum"
+        self.inputs = {'X': np.random.random((5, 6, 10)).astype("float32")}
+        self.attrs = {'dim': -2}
+        out = self.inputs['X'].sum(axis=self.attrs['dim'])
+        self.outputs = {'Out': out}
+
+
+class TestSumGradOp(GradientChecker):
+    def test_normal(self):
+        op = Operator("reduce_sum", X="X", Out="Out", dim=-2)
+        # use small size to decrease the error of numerical calculation
+        inputs = {'X': np.random.random((5, 6, 10)).astype("float32")}
+        self.check_grad(op, inputs, set(["X"]), "Out")
+
+    def test_1d_tensor(self):
+        op = Operator("reduce_sum", X="X", Out="Out", dim=0)
+        # use small size to decrease the error of numerical calculation
+        inputs = {'X': np.random.random(10).astype("float32")}
+        self.check_grad(op, inputs, set(["X"]), "Out")
+
+
+class TestKeepdimSumOp(unittest.TestCase):
+    __metaclass__ = OpTestMeta
+
+    def setUp(self):
+        self.type = "reduce_sum"
+        self.inputs = {'X': np.random.random((5, 6, 10)).astype("float32")}
+        self.attrs = {'dim': -2, 'keep_dim': True}
+        out = self.inputs['X'].sum(axis=self.attrs['dim'], keepdims=True)
+        self.outputs = {'Out': out}
+
+
+class TestMeanOp(unittest.TestCase):
+    __metaclass__ = OpTestMeta
+
+    def setUp(self):
+        self.type = "reduce_mean"
+        self.inputs = {'X': np.random.random((5, 6, 10)).astype("float32")}
+        self.attrs = {'dim': -1}
+        out = self.inputs['X'].mean(axis=self.attrs['dim'])
+        self.outputs = {'Out': out}
+
+
+class TestMeanGradOp(GradientChecker):
+    def test_normal(self):
+        op = Operator("reduce_mean", X="X", Out="Out", dim=-2)
+        # use small size to decrease the error of numerical calculation
+        inputs = {'X': np.random.random((5, 6, 10)).astype("float32")}
+        self.check_grad(op, inputs, set(["X"]), "Out")
+
+    def test_1d_tensor(self):
+        op = Operator("reduce_mean", X="X", Out="Out", dim=0)
+        # use small size to decrease the error of numerical calculation
+        inputs = {'X': np.random.random(10).astype("float32")}
+        self.check_grad(op, inputs, set(["X"]), "Out")
+
+
+class TestMaxOp(unittest.TestCase):
+    __metaclass__ = OpTestMeta
+
+    def setUp(self):
+        self.type = "reduce_max"
+        self.inputs = {'X': np.random.random((5, 6, 10)).astype("float32")}
+        self.attrs = {'dim': -1}
+        out = self.inputs['X'].max(axis=self.attrs['dim'])
+        self.outputs = {'Out': out}
+
+
+class TestMinOp(unittest.TestCase):
+    __metaclass__ = OpTestMeta
+
+    def setUp(self):
+        self.type = "reduce_min"
+        self.inputs = {'X': np.random.random((5, 6, 10)).astype("float32")}
+        self.attrs = {'dim': -2}
+        out = self.inputs['X'].min(axis=self.attrs['dim'])
+        self.outputs = {'Out': out}
+
+
+if __name__ == '__main__':
+    unittest.main()
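
Reviewer note (illustration only, not applied by git): the NumPy sketch below mirrors the semantics described in the op comments above (reduce along a single `dim`, optionally keeping it as a size-1 axis) and the gradient routing of MaxOrMinGradFunctor, where the upstream gradient is broadcast back along `dim` and masked to the positions that produced the max/min. The helper names `reduce_ref` and `max_grad_ref` are hypothetical and exist only for this sketch.

    import numpy as np

    def reduce_ref(x, dim, keep_dim=False, op="sum"):
        # Reference forward semantics: reduce one axis, optionally keep it as size 1.
        fn = {"sum": np.sum, "mean": np.mean, "max": np.max, "min": np.min}[op]
        return fn(x, axis=dim, keepdims=keep_dim)

    def max_grad_ref(x, out, out_grad, dim):
        # Mirrors MaxOrMinGradFunctor: broadcast the reduced result and its gradient
        # back along `dim`, then pass gradient only to the matching positions.
        out_b = np.expand_dims(out, dim)
        grad_b = np.expand_dims(out_grad, dim)
        return np.where(x == out_b, grad_b, 0.0).astype(x.dtype)

    x = np.random.random((5, 6, 10)).astype("float32")
    print(reduce_ref(x, dim=-2, keep_dim=True, op="sum").shape)    # (5, 1, 10)
    print(reduce_ref(x, dim=-2, keep_dim=False, op="sum").shape)   # (5, 10)
    print(max_grad_ref(x, x.max(axis=1), np.ones((5, 10), "float32"), 1).shape)  # (5, 6, 10)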
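A second reviewer note, again with hypothetical names and not part of the patch: ReduceGradEigenFreeKernel decomposes a flat index into (outer, mid, inner) blocks around the reduced axis. The short check below confirms that x_offset = (inner * mid) * i + inner * k + j addresses element (i, k, j) of a rank-3 tensor, and out_offset = inner * i + j addresses the matching element of the reduced tensor.

    import numpy as np

    dims, dim = (3, 4, 5), 1
    outer = int(np.prod(dims[:dim]))      # elements before the reduced axis
    mid = dims[dim]                       # extent of the reduced axis
    inner = int(np.prod(dims[dim + 1:]))  # elements after the reduced axis

    x = np.arange(np.prod(dims)).reshape(dims)
    reduced = x.sum(axis=dim)
    for i in range(outer):
        for j in range(inner):
            out_offset = inner * i + j
            assert reduced.flat[out_offset] == reduced[i, j]
            for k in range(mid):
                x_offset = (inner * mid) * i + inner * k + j
                assert x.flat[x_offset] == x[i, k, j]
    print("offset arithmetic matches direct indexing")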