提交 ecef2e6b 编写于 作者: G Guo Sheng 提交者: GitHub

Merge pull request #4086 from guoshengCS/add-ReduceOp

Add reduce op
......@@ -62,6 +62,13 @@ function(op_library TARGET)
file(APPEND ${pybind_file} "USE_OP(sigmoid);\n")
endif()
# reduce_op contains several operators
if ("${TARGET}" STREQUAL "reduce_op")
set(pybind_flag 1)
# It's enough to just adding one operator to pybind
file(APPEND ${pybind_file} "USE_OP(reduce_sum);\n")
endif()
# pybind USE_NO_KERNEL_OP
file(READ ${TARGET}.cc TARGET_CONTENT)
string(REGEX MATCH "OperatorWithKernel" regex_result "${TARGET_CONTENT}")
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/reduce_op.h"
namespace paddle {
namespace operators {
using framework::Tensor;
class ReduceOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of ReduceOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of ReduceOp should not be null.");
auto x_dims = ctx->GetInputDim("X");
auto x_rank = x_dims.size();
PADDLE_ENFORCE_LE(x_rank, 6, "Tensors with rank at most 6 are supported.");
int dim = ctx->Attrs().Get<int>("dim");
if (dim < 0) dim = x_rank + dim;
PADDLE_ENFORCE_LT(
dim, x_rank,
"The dim should be in the range [-rank(input), rank(input)).");
bool keep_dim = ctx->Attrs().Get<bool>("keep_dim");
auto dims_vector = vectorize(x_dims);
if (keep_dim || x_rank == 1) {
dims_vector[dim] = 1;
} else {
dims_vector.erase(dims_vector.begin() + dim);
}
auto out_dims = framework::make_ddim(dims_vector);
ctx->SetOutputDim("Out", out_dims);
if (dim != 0) {
// Only pass LoD when not reducing on the first dim.
ctx->ShareLoD("X", /*->*/ "Out");
}
}
};
class ReduceGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null.");
auto x_dims = ctx->GetInputDim("X");
auto x_rank = x_dims.size();
PADDLE_ENFORCE_LE(x_rank, 6, "Tensors with rank at most 6 are supported.");
int dim = ctx->Attrs().Get<int>("dim");
if (dim < 0) dim = x_rank + dim;
PADDLE_ENFORCE_LT(
dim, x_rank,
"The dim should be in the range [-rank(input), rank(input)).");
auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
}
}
};
class ReduceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ReduceOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(
"X",
"(Tensor) The input tensor. Tensors with rank at most 6 are supported");
AddOutput("Out", "(Tensor) The result tensor.");
AddAttr<int>(
"dim",
"(int, default 1) The dimension to reduce. "
"Must be in the range [-rank(input), rank(input)). "
"If `dim < 0`, the dim to reduce is `rank + dim`. "
"Noting that reducing on the first dim will make the LoD info lost.")
.SetDefault(0);
AddAttr<bool>("keep_dim",
"(bool, default false) "
"If true, retain the reduced dimension with length 1.")
.SetDefault(false);
comment_ = R"DOC(
{ReduceOP} operator computes the {reduce} of input tensor along the given dimension.
The result tensor has 1 fewer dimension than the input unless `keep_dim` is true.
)DOC";
AddComment(comment_);
}
protected:
std::string comment_;
void Replace(std::string &src, std::string from, std::string to) {
std::size_t len_from = std::strlen(from.c_str());
std::size_t len_to = std::strlen(to.c_str());
for (std::size_t pos = src.find(from); pos != std::string::npos;
pos = src.find(from, pos + len_to)) {
src.replace(pos, len_from, to);
}
}
void SetComment(std::string name, std::string op) {
Replace(comment_, "{ReduceOP}", name);
Replace(comment_, "{reduce}", op);
}
};
class ReduceSumOpMaker : public ReduceOpMaker {
public:
ReduceSumOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: ReduceOpMaker(proto, op_checker) {
SetComment("ReduceSum", "sum");
AddComment(comment_);
}
};
class ReduceMeanOpMaker : public ReduceOpMaker {
public:
ReduceMeanOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: ReduceOpMaker(proto, op_checker) {
SetComment("ReduceMean", "mean");
AddComment(comment_);
}
};
class ReduceMaxOpMaker : public ReduceOpMaker {
public:
ReduceMaxOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: ReduceOpMaker(proto, op_checker) {
SetComment("ReduceMax", "max");
AddComment(comment_);
}
};
class ReduceMinOpMaker : public ReduceOpMaker {
public:
ReduceMinOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: ReduceOpMaker(proto, op_checker) {
SetComment("ReduceMin", "min");
AddComment(comment_);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(reduce_sum, ops::ReduceOp, ops::ReduceSumOpMaker, reduce_sum_grad,
ops::ReduceGradOp);
REGISTER_OP_CPU_KERNEL(
reduce_sum,
ops::ReduceKernel<paddle::platform::CPUPlace, float, ops::SumFunctor>);
REGISTER_OP_CPU_KERNEL(reduce_sum_grad,
ops::ReduceGradKernel<paddle::platform::CPUPlace, float,
ops::SumGradFunctor>);
REGISTER_OP(reduce_mean, ops::ReduceOp, ops::ReduceMeanOpMaker,
reduce_mean_grad, ops::ReduceGradOp);
REGISTER_OP_CPU_KERNEL(
reduce_mean,
ops::ReduceKernel<paddle::platform::CPUPlace, float, ops::MeanFunctor>);
REGISTER_OP_CPU_KERNEL(reduce_mean_grad,
ops::ReduceGradKernel<paddle::platform::CPUPlace, float,
ops::MeanGradFunctor>);
REGISTER_OP(reduce_max, ops::ReduceOp, ops::ReduceMaxOpMaker, reduce_max_grad,
ops::ReduceGradOp);
REGISTER_OP_CPU_KERNEL(
reduce_max,
ops::ReduceKernel<paddle::platform::CPUPlace, float, ops::MaxFunctor>);
REGISTER_OP_CPU_KERNEL(reduce_max_grad,
ops::ReduceGradKernel<paddle::platform::CPUPlace, float,
ops::MaxOrMinGradFunctor>);
REGISTER_OP(reduce_min, ops::ReduceOp, ops::ReduceMaxOpMaker, reduce_min_grad,
ops::ReduceGradOp);
REGISTER_OP_CPU_KERNEL(
reduce_min,
ops::ReduceKernel<paddle::platform::CPUPlace, float, ops::MinFunctor>);
REGISTER_OP_CPU_KERNEL(reduce_min_grad,
ops::ReduceGradKernel<paddle::platform::CPUPlace, float,
ops::MaxOrMinGradFunctor>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/reduce_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
reduce_sum,
ops::ReduceKernel<paddle::platform::GPUPlace, float, ops::SumFunctor>);
REGISTER_OP_GPU_KERNEL(reduce_sum_grad,
ops::ReduceGradKernel<paddle::platform::GPUPlace, float,
ops::SumGradFunctor>);
REGISTER_OP_GPU_KERNEL(
reduce_mean,
ops::ReduceKernel<paddle::platform::GPUPlace, float, ops::MeanFunctor>);
REGISTER_OP_GPU_KERNEL(reduce_mean_grad,
ops::ReduceGradKernel<paddle::platform::GPUPlace, float,
ops::MeanGradFunctor>);
REGISTER_OP_GPU_KERNEL(
reduce_max,
ops::ReduceKernel<paddle::platform::GPUPlace, float, ops::MaxFunctor>);
REGISTER_OP_GPU_KERNEL(reduce_max_grad,
ops::ReduceGradKernel<paddle::platform::GPUPlace, float,
ops::MaxOrMinGradFunctor>);
REGISTER_OP_GPU_KERNEL(
reduce_min,
ops::ReduceKernel<paddle::platform::GPUPlace, float, ops::MinFunctor>);
REGISTER_OP_GPU_KERNEL(reduce_min_grad,
ops::ReduceGradKernel<paddle::platform::GPUPlace, float,
ops::MaxOrMinGradFunctor>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using DDim = framework::DDim;
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
struct SumFunctor {
template <typename Place, typename X, typename Y, typename Dim>
void operator()(const Place& place, X& x, Y& y, const Dim& dim) {
y.device(place) = x.sum(dim);
}
};
struct SumGradFunctor {
template <typename Place, typename X, typename Y, typename DX, typename DY,
typename Dim>
void operator()(const Place& place, X& x, Y& y, DX& dx, DY& dy,
const Dim& dim, int size) {
dx.device(place) = dy.broadcast(dim);
}
};
struct MeanFunctor {
template <typename Place, typename X, typename Y, typename Dim>
void operator()(const Place& place, X& x, Y& y, const Dim& dim) {
y.device(place) = x.mean(dim);
}
};
struct MeanGradFunctor {
template <typename Place, typename X, typename Y, typename DX, typename DY,
typename Dim>
void operator()(const Place& place, X& x, Y& y, DX& dx, DY& dy,
const Dim& dim, int size) {
dx.device(place) = dy.broadcast(dim) / dx.constant(size);
}
};
struct MaxFunctor {
template <typename Place, typename X, typename Y, typename Dim>
void operator()(const Place& place, X& x, Y& y, const Dim& dim) {
y.device(place) = x.maximum(dim);
}
};
struct MinFunctor {
template <typename Place, typename X, typename Y, typename Dim>
void operator()(const Place& place, X& x, Y& y, const Dim& dim) {
y.device(place) = x.minimum(dim);
}
};
struct MaxOrMinGradFunctor {
template <typename Place, typename X, typename Y, typename DX, typename DY,
typename Dim>
void operator()(const Place& place, X& x, Y& y, DX& dx, DY& dy,
const Dim& dim, int size) {
auto equals = x == y.broadcast(dim);
auto ones = dx.constant(1);
auto zeros = dx.constant(0);
// If there are multiple minimum or maximum elements, the subgradient of
// each is the set [0, 1], and we pass gradient to all of them here.
dx.device(place) = dy.broadcast(dim) * equals.select(ones, zeros);
}
};
template <typename Place, typename T, typename Functor>
class ReduceKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
int rank = context.Input<Tensor>("X")->dims().size();
switch (rank) {
case 1:
ReduceCompute<1>(context);
break;
case 2:
ReduceCompute<2>(context);
break;
case 3:
ReduceCompute<3>(context);
break;
case 4:
ReduceCompute<4>(context);
break;
case 5:
ReduceCompute<5>(context);
break;
case 6:
ReduceCompute<6>(context);
break;
}
}
private:
template <size_t D>
void ReduceCompute(const framework::ExecutionContext& context) const {
auto* input = context.Input<Tensor>("X");
auto* output = context.Output<Tensor>("Out");
output->mutable_data<T>(context.GetPlace());
auto x = EigenTensor<T, D>::From(*input);
auto x_rank = static_cast<int>(x.dimensions().size());
int dim = static_cast<int>(context.Attr<int>("dim"));
if (dim < 0) dim = x_rank + dim;
auto reduce_dim = Eigen::array<int, 1>({{dim}});
// construct the squeezed output tensor
bool keep_dim = context.Attr<bool>("keep_dim");
DDim dims = output->dims();
auto dims_vector = vectorize(dims);
if (keep_dim && x_rank > 1) {
dims_vector.erase(dims_vector.begin() + dim);
dims = framework::make_ddim(dims_vector);
}
auto out = EigenTensor < T, D == 1 ? 1 : (D - 1) > ::From(*output, dims);
auto& place = context.GetEigenDevice<Place>();
Functor functor;
functor(place, x, out, reduce_dim);
}
};
template <typename Place, typename T, typename Functor>
class ReduceGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
int rank = context.Input<Tensor>("X")->dims().size();
switch (rank) {
case 1:
ReduceGradCompute<1>(context);
break;
case 2:
ReduceGradCompute<2>(context);
break;
case 3:
ReduceGradCompute<3>(context);
break;
case 4:
ReduceGradCompute<4>(context);
break;
case 5:
ReduceGradCompute<5>(context);
break;
case 6:
ReduceGradCompute<6>(context);
break;
}
}
private:
template <size_t D>
void ReduceGradCompute(const framework::ExecutionContext& context) const {
auto* input0 = context.Input<Tensor>("X");
auto* input1 = context.Input<Tensor>("Out");
auto* input2 = context.Input<Tensor>(framework::GradVarName("Out"));
auto* output = context.Output<Tensor>(framework::GradVarName("X"));
output->mutable_data<T>(context.GetPlace());
auto x = EigenTensor<T, D>::From(*input0);
auto x_grad = EigenTensor<T, D>::From(*output);
auto x_rank = static_cast<int>(x.dimensions().size());
int dim = static_cast<int>(context.Attr<int>("dim"));
if (dim < 0) dim = x_rank + dim;
DDim dims = input0->dims();
dims[dim] = 1;
auto x_reduce = EigenTensor<T, D>::From(*input1, dims);
auto x_reduce_grad = EigenTensor<T, D>::From(*input2, dims);
Eigen::array<int, D> braodcast_dim;
for (size_t i = 0; i < D; ++i) braodcast_dim[i] = 1;
braodcast_dim[dim] = input0->dims()[dim];
auto& place = context.GetEigenDevice<Place>();
Functor functor;
functor(place, x, x_reduce, x_grad, x_reduce_grad, braodcast_dim,
braodcast_dim[dim]);
}
};
} // namespace operators
} // namespace paddle
import unittest
import numpy as np
from op_test import OpTest
class TestSumOp(OpTest):
def setUp(self):
self.op_type = "reduce_sum"
self.inputs = {'X': np.random.random((5, 6, 10)).astype("float32")}
self.outputs = {'Out': self.inputs['X'].sum(axis=0)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
class TestMeanOp(OpTest):
def setUp(self):
self.op_type = "reduce_mean"
self.inputs = {'X': np.random.random((5, 6, 2, 10)).astype("float32")}
self.attrs = {'dim': 1}
self.outputs = {'Out': self.inputs['X'].mean(axis=self.attrs['dim'])}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
class TestMaxOp(OpTest):
"""Remove Max with subgradient from gradient check to confirm the success of CI."""
def setUp(self):
self.op_type = "reduce_max"
self.inputs = {'X': np.random.random((5, 6, 10)).astype("float32")}
self.attrs = {'dim': -1}
self.outputs = {'Out': self.inputs['X'].max(axis=self.attrs['dim'])}
def test_check_output(self):
self.check_output()
class TestMinOp(OpTest):
"""Remove Min with subgradient from gradient check to confirm the success of CI."""
def setUp(self):
self.op_type = "reduce_min"
self.inputs = {'X': np.random.random((5, 6, 10)).astype("float32")}
self.attrs = {'dim': 2}
self.outputs = {'Out': self.inputs['X'].min(axis=self.attrs['dim'])}
def test_check_output(self):
self.check_output()
class TestKeepDimReduce(OpTest):
def setUp(self):
self.op_type = "reduce_sum"
self.inputs = {'X': np.random.random((5, 6, 10)).astype("float32")}
self.attrs = {'dim': -2, 'keep_dim': True}
self.outputs = {
'Out': self.inputs['X'].sum(axis=self.attrs['dim'], keepdims=True)
}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
class Test1DReduce(OpTest):
def setUp(self):
self.op_type = "reduce_sum"
self.inputs = {'X': np.random.random(20).astype("float32")}
self.outputs = {'Out': self.inputs['X'].sum(axis=0)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册