未验证 提交 d48172f2 编写于 作者: D dzhwinter 提交者: GitHub

split reduce op into multiple libraries, accelerate the compiling (#11029)

* "split into multiple .ccl"

* "refine file structure"

* "refine files"

* "remove the cmakelist"

* "fix typo"

* "fix typo"

* fix ci
上级 58031157
......@@ -156,15 +156,15 @@ class OpKernelRegistrar : public Registrar {
/**
* Macro to register OperatorKernel.
*/
#define REGISTER_OP_KERNEL(op_type, LIBRARY_TYPE, place_class, ...) \
#define REGISTER_OP_KERNEL(op_type, library_type, place_class, ...) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_op_kernel_##op_type##_##LIBRARY_TYPE##__, \
__reg_op_kernel_##op_type##_##library_type##__, \
"REGISTER_OP_KERNEL must be called in global namespace"); \
static ::paddle::framework::OpKernelRegistrar<place_class, __VA_ARGS__> \
__op_kernel_registrar_##op_type##_##LIBRARY_TYPE##__(#op_type, \
#LIBRARY_TYPE); \
int TouchOpKernelRegistrar_##op_type##_##LIBRARY_TYPE() { \
__op_kernel_registrar_##op_type##_##LIBRARY_TYPE##__.Touch(); \
__op_kernel_registrar_##op_type##_##library_type##__(#op_type, \
#library_type); \
int TouchOpKernelRegistrar_##op_type##_##library_type() { \
__op_kernel_registrar_##op_type##_##library_type##__.Touch(); \
return 0; \
}
......
......@@ -166,8 +166,6 @@ function(op_library TARGET)
# NOTE(*): activation use macro to regist the kernels, set use_op manually.
if(${TARGET} STREQUAL "activation")
file(APPEND ${pybind_file} "USE_OP(relu);\n")
elseif(${TARGET} STREQUAL "reduce")
file(APPEND ${pybind_file} "USE_OP(reduce_sum);\n")
elseif(${TARGET} STREQUAL "fake_dequantize")
file(APPEND ${pybind_file} "USE_OP(fake_dequantize_max_abs);\n")
else()
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reduce_min_max_op.h"
REGISTER_REDUCE_OP(reduce_max);
REGISTER_OP_CPU_KERNEL(
reduce_max, ops::ReduceKernel<paddle::platform::CPUDeviceContext, float,
ops::MaxFunctor>,
ops::ReduceKernel<paddle::platform::CPUDeviceContext, double,
ops::MaxFunctor>,
ops::ReduceKernel<paddle::platform::CPUDeviceContext, int, ops::MaxFunctor>,
ops::ReduceKernel<paddle::platform::CPUDeviceContext, int64_t,
ops::MaxFunctor>);
REGISTER_OP_CPU_KERNEL(
reduce_max_grad, ops::ReduceGradKernel<paddle::platform::CPUDeviceContext,
float, ops::MaxOrMinGradFunctor>,
ops::ReduceGradKernel<paddle::platform::CPUDeviceContext, double,
ops::MaxOrMinGradFunctor>,
ops::ReduceGradKernel<paddle::platform::CPUDeviceContext, int,
ops::MaxOrMinGradFunctor>,
ops::ReduceGradKernel<paddle::platform::CPUDeviceContext, int64_t,
ops::MaxOrMinGradFunctor>);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reduce_min_max_op.h"
REGISTER_OP_CUDA_KERNEL(reduce_max,
ops::ReduceKernel<paddle::platform::CUDADeviceContext,
float, ops::MaxFunctor>,
ops::ReduceKernel<paddle::platform::CUDADeviceContext,
double, ops::MaxFunctor>,
ops::ReduceKernel<paddle::platform::CUDADeviceContext,
int, ops::MaxFunctor>,
ops::ReduceKernel<paddle::platform::CUDADeviceContext,
int64_t, ops::MaxFunctor>);
REGISTER_OP_CUDA_KERNEL(
reduce_max_grad, ops::ReduceGradKernel<paddle::platform::CUDADeviceContext,
float, ops::MaxOrMinGradFunctor>,
ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, double,
ops::MaxOrMinGradFunctor>,
ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, int,
ops::MaxOrMinGradFunctor>,
ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, int64_t,
ops::MaxOrMinGradFunctor>);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reduce_mean_op.h"
REGISTER_REDUCE_OP(reduce_mean);
REGISTER_OP_CPU_KERNEL(reduce_mean,
ops::ReduceKernel<paddle::platform::CPUDeviceContext,
float, ops::MeanFunctor>,
ops::ReduceKernel<paddle::platform::CPUDeviceContext,
double, ops::MeanFunctor>,
ops::ReduceKernel<paddle::platform::CPUDeviceContext,
int, ops::MeanFunctor>,
ops::ReduceKernel<paddle::platform::CPUDeviceContext,
int64_t, ops::MeanFunctor>);
REGISTER_OP_CPU_KERNEL(reduce_mean_grad,
ops::ReduceGradKernel<paddle::platform::CPUDeviceContext,
float, ops::MeanGradFunctor>,
ops::ReduceGradKernel<paddle::platform::CPUDeviceContext,
double, ops::MeanGradFunctor>,
ops::ReduceGradKernel<paddle::platform::CPUDeviceContext,
int, ops::MeanGradFunctor>,
ops::ReduceGradKernel<paddle::platform::CPUDeviceContext,
int64_t, ops::MeanGradFunctor>);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reduce_mean_op.h"
REGISTER_OP_CUDA_KERNEL(reduce_mean,
ops::ReduceKernel<paddle::platform::CUDADeviceContext,
float, ops::MeanFunctor>,
ops::ReduceKernel<paddle::platform::CUDADeviceContext,
double, ops::MeanFunctor>,
ops::ReduceKernel<paddle::platform::CUDADeviceContext,
int, ops::MeanFunctor>,
ops::ReduceKernel<paddle::platform::CUDADeviceContext,
int64_t, ops::MeanFunctor>);
REGISTER_OP_CUDA_KERNEL(
reduce_mean_grad, ops::ReduceGradKernel<paddle::platform::CUDADeviceContext,
float, ops::MeanGradFunctor>,
ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, double,
ops::MeanGradFunctor>,
ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, int,
ops::MeanGradFunctor>,
ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, int64_t,
ops::MeanGradFunctor>);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/operators/reduce_op.h"
namespace paddle {
namespace operators {
struct MeanFunctor {
template <typename DeviceContext, typename X, typename Y, typename Dim>
void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) {
y->device(place) = x->mean(dim);
}
};
struct MeanGradFunctor {
template <typename DeviceContext, typename X, typename Y, typename DX,
typename DY, typename Dim>
void operator()(const DeviceContext& place, X* x, Y* y, DX* dx, DY* dy,
const Dim& dim, int size) {
dx->device(place) = dy->broadcast(dim) / dx->constant(size);
}
};
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/operators/reduce_op.h"
namespace paddle {
namespace operators {
struct MaxFunctor {
template <typename DeviceContext, typename X, typename Y, typename Dim>
void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) {
y->device(place) = x->maximum(dim);
}
};
struct MinFunctor {
template <typename DeviceContext, typename X, typename Y, typename Dim>
void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) {
y->device(place) = x->minimum(dim);
}
};
struct MaxOrMinGradFunctor {
template <typename DeviceContext, typename X, typename Y, typename DX,
typename DY, typename Dim>
void operator()(const DeviceContext& place, X* x, Y* y, DX* dx, DY* dy,
const Dim& dim, int size) {
auto equals = (*x) == y->broadcast(dim);
auto ones = dx->constant(1);
auto zeros = dx->constant(0);
// If there are multiple minimum or maximum elements, the subgradient of
// each is the set [0, 1], and we pass gradient to all of them here.
dx->device(place) = dy->broadcast(dim) * equals.select(ones, zeros);
}
};
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reduce_min_max_op.h"
REGISTER_REDUCE_OP(reduce_min);
REGISTER_OP_CPU_KERNEL(
reduce_min, ops::ReduceKernel<paddle::platform::CPUDeviceContext, float,
ops::MinFunctor>,
ops::ReduceKernel<paddle::platform::CPUDeviceContext, double,
ops::MinFunctor>,
ops::ReduceKernel<paddle::platform::CPUDeviceContext, int, ops::MinFunctor>,
ops::ReduceKernel<paddle::platform::CPUDeviceContext, int64_t,
ops::MinFunctor>);
REGISTER_OP_CPU_KERNEL(
reduce_min_grad, ops::ReduceGradKernel<paddle::platform::CPUDeviceContext,
float, ops::MaxOrMinGradFunctor>,
ops::ReduceGradKernel<paddle::platform::CPUDeviceContext, double,
ops::MaxOrMinGradFunctor>,
ops::ReduceGradKernel<paddle::platform::CPUDeviceContext, int,
ops::MaxOrMinGradFunctor>,
ops::ReduceGradKernel<paddle::platform::CPUDeviceContext, int64_t,
ops::MaxOrMinGradFunctor>);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reduce_min_max_op.h"
REGISTER_OP_CUDA_KERNEL(reduce_min,
ops::ReduceKernel<paddle::platform::CUDADeviceContext,
float, ops::MinFunctor>,
ops::ReduceKernel<paddle::platform::CUDADeviceContext,
double, ops::MinFunctor>,
ops::ReduceKernel<paddle::platform::CUDADeviceContext,
int, ops::MinFunctor>,
ops::ReduceKernel<paddle::platform::CUDADeviceContext,
int64_t, ops::MinFunctor>);
REGISTER_OP_CUDA_KERNEL(
reduce_min_grad, ops::ReduceGradKernel<paddle::platform::CUDADeviceContext,
float, ops::MaxOrMinGradFunctor>,
ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, double,
ops::MaxOrMinGradFunctor>,
ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, int,
ops::MaxOrMinGradFunctor>,
ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, int64_t,
ops::MaxOrMinGradFunctor>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/reduce_op.h"
#include <algorithm>
#include <string>
#include <vector>
namespace paddle {
namespace operators {
using framework::Tensor;
class ReduceOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of ReduceOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of ReduceOp should not be null.");
auto x_dims = ctx->GetInputDim("X");
auto x_rank = x_dims.size();
PADDLE_ENFORCE_LE(x_rank, 6, "Tensors with rank at most 6 are supported.");
auto dims = ctx->Attrs().Get<std::vector<int>>("dim");
for (size_t i = 0; i < dims.size(); ++i) {
if (dims[i] < 0) dims[i] = x_rank + dims[i];
PADDLE_ENFORCE_LT(
dims[i], x_rank,
"The dim should be in the range [-rank(input), rank(input)).");
}
sort(dims.begin(), dims.end());
bool reduce_all = ctx->Attrs().Get<bool>("reduce_all");
bool keep_dim = ctx->Attrs().Get<bool>("keep_dim");
if (reduce_all) {
if (keep_dim)
ctx->SetOutputDim(
"Out", framework::make_ddim(std::vector<int64_t>(x_rank, 1)));
else
ctx->SetOutputDim("Out", {1});
} else {
auto dims_vector = vectorize(x_dims);
if (keep_dim) {
for (size_t i = 0; i < dims.size(); ++i) {
dims_vector[dims[i]] = 1;
}
} else {
const int kDelFlag = -2;
for (size_t i = 0; i < dims.size(); ++i) {
dims_vector[dims[i]] = kDelFlag;
}
dims_vector.erase(
remove(dims_vector.begin(), dims_vector.end(), kDelFlag),
dims_vector.end());
}
auto out_dims = framework::make_ddim(dims_vector);
ctx->SetOutputDim("Out", out_dims);
if (dims[0] != 0) {
// Only pass LoD when not reducing on the first dim.
ctx->ShareLoD("X", /*->*/ "Out");
}
}
}
};
class ReduceGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null.");
auto x_dims = ctx->GetInputDim("X");
auto x_rank = x_dims.size();
PADDLE_ENFORCE_LE(x_rank, 6, "Tensors with rank at most 6 are supported.");
auto dims = ctx->Attrs().Get<std::vector<int>>("dim");
for (size_t i = 0; i < dims.size(); ++i) {
if (dims[i] < 0) dims[i] = x_rank + dims[i];
PADDLE_ENFORCE_LT(
dims[i], x_rank,
"The dim should be in the range [-rank(input), rank(input)).");
}
sort(dims.begin(), dims.end());
auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
ctx->ShareLoD("X", /*->*/ x_grad_name);
}
}
};
class ReduceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() final {
AddInput("X",
"(Tensor) The input tensor. Tensors with rank at most 6 are "
"supported.");
AddOutput("Out", "(Tensor) The result tensor.");
AddAttr<std::vector<int>>(
"dim",
"(list<int>, default {0}) The dimensions to reduce. "
"Must be in the range [-rank(input), rank(input)). "
"If `dim[i] < 0`, the dims[i] to reduce is `rank + dims[i]`. "
"Note that reducing on the first dim will make the LoD info lost.")
.SetDefault({0});
AddAttr<bool>("keep_dim",
"(bool, default false) "
"If true, retain the reduced dimension with length 1.")
.SetDefault(false);
AddAttr<bool>("reduce_all",
"(bool, default false) "
"If true, output a scalar reduced along all dimensions.")
.SetDefault(false);
AddComment(string::Sprintf(R"DOC(
%s Operator.
This operator computes the %s of input tensor along the given dimension.
The result tensor has 1 fewer dimension than the input unless keep_dim is true.
If reduce_all is true, just reduce along all dimensions and output a scalar.
)DOC",
GetOpType(), GetName()));
}
protected:
virtual std::string GetName() const = 0;
virtual std::string GetOpType() const = 0;
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
#define REGISTER_REDUCE_OP(op_name) \
class __##op_name##Maker__ : public ops::ReduceOpMaker { \
protected: \
virtual std::string GetName() const { return #op_name; } \
virtual std::string GetOpType() const { return "Reduce " #op_name; } \
}; \
REGISTER_OPERATOR(reduce_##op_name, ops::ReduceOp, __##op_name##Maker__, \
paddle::framework::DefaultGradOpDescMaker<true>); \
REGISTER_OPERATOR(reduce_##op_name##_grad, ops::ReduceGradOp)
REGISTER_REDUCE_OP(sum);
REGISTER_REDUCE_OP(mean);
REGISTER_REDUCE_OP(max);
REGISTER_REDUCE_OP(min);
REGISTER_REDUCE_OP(prod);
#define REGISTER_REDUCE_CPU_KERNEL(reduce_type, functor, grad_functor) \
REGISTER_OP_CPU_KERNEL(reduce_type, \
ops::ReduceKernel<paddle::platform::CPUDeviceContext, \
float, ops::functor>, \
ops::ReduceKernel<paddle::platform::CPUDeviceContext, \
double, ops::functor>, \
ops::ReduceKernel<paddle::platform::CPUDeviceContext, \
int, ops::functor>, \
ops::ReduceKernel<paddle::platform::CPUDeviceContext, \
int64_t, ops::functor>); \
REGISTER_OP_CPU_KERNEL( \
reduce_type##_grad, \
ops::ReduceGradKernel<paddle::platform::CPUDeviceContext, float, \
ops::grad_functor>, \
ops::ReduceGradKernel<paddle::platform::CPUDeviceContext, double, \
ops::grad_functor>, \
ops::ReduceGradKernel<paddle::platform::CPUDeviceContext, int, \
ops::grad_functor>, \
ops::ReduceGradKernel<paddle::platform::CPUDeviceContext, int64_t, \
ops::grad_functor>);
FOR_EACH_KERNEL_FUNCTOR(REGISTER_REDUCE_CPU_KERNEL);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/fluid/operators/reduce_op.h"
namespace ops = paddle::operators;
#define REGISTER_REDUCE_GPU_KERNEL(reduce_type, functor, grad_functor) \
REGISTER_OP_CUDA_KERNEL( \
reduce_type, ops::ReduceKernel<paddle::platform::CUDADeviceContext, \
float, ops::functor>, \
ops::ReduceKernel<paddle::platform::CUDADeviceContext, double, \
ops::functor>, \
ops::ReduceKernel<paddle::platform::CUDADeviceContext, int, \
ops::functor>, \
ops::ReduceKernel<paddle::platform::CUDADeviceContext, int64_t, \
ops::functor>); \
REGISTER_OP_CUDA_KERNEL( \
reduce_type##_grad, \
ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, float, \
ops::grad_functor>, \
ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, double, \
ops::grad_functor>, \
ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, int, \
ops::grad_functor>, \
ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, int64_t, \
ops::grad_functor>);
FOR_EACH_KERNEL_FUNCTOR(REGISTER_REDUCE_GPU_KERNEL);
......@@ -14,105 +14,20 @@ limitations under the License. */
#pragma once
#include <algorithm>
#include <string>
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/reduce_op_function.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using DDim = framework::DDim;
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenScalar = framework::EigenScalar<T, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
struct SumFunctor {
template <typename DeviceContext, typename X, typename Y, typename Dim>
void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) {
y->device(place) = x->sum(dim);
}
};
struct SumGradFunctor {
template <typename DeviceContext, typename X, typename Y, typename DX,
typename DY, typename Dim>
void operator()(const DeviceContext& place, X* x, Y* y, DX* dx, DY* dy,
const Dim& dim, int size) {
dx->device(place) = dy->broadcast(dim);
}
};
struct MeanFunctor {
template <typename DeviceContext, typename X, typename Y, typename Dim>
void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) {
y->device(place) = x->mean(dim);
}
};
struct MeanGradFunctor {
template <typename DeviceContext, typename X, typename Y, typename DX,
typename DY, typename Dim>
void operator()(const DeviceContext& place, X* x, Y* y, DX* dx, DY* dy,
const Dim& dim, int size) {
dx->device(place) = dy->broadcast(dim) / dx->constant(size);
}
};
struct MaxFunctor {
template <typename DeviceContext, typename X, typename Y, typename Dim>
void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) {
y->device(place) = x->maximum(dim);
}
};
struct MinFunctor {
template <typename DeviceContext, typename X, typename Y, typename Dim>
void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) {
y->device(place) = x->minimum(dim);
}
};
struct MaxOrMinGradFunctor {
template <typename DeviceContext, typename X, typename Y, typename DX,
typename DY, typename Dim>
void operator()(const DeviceContext& place, X* x, Y* y, DX* dx, DY* dy,
const Dim& dim, int size) {
auto equals = (*x) == y->broadcast(dim);
auto ones = dx->constant(1);
auto zeros = dx->constant(0);
// If there are multiple minimum or maximum elements, the subgradient of
// each is the set [0, 1], and we pass gradient to all of them here.
dx->device(place) = dy->broadcast(dim) * equals.select(ones, zeros);
}
};
struct ProdFunctor {
template <typename DeviceContext, typename X, typename Y, typename Dim>
void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) {
y->device(place) = x->prod(dim);
}
};
struct ProdGradFunctor {
template <typename DeviceContext, typename X, typename Y, typename DX,
typename DY, typename Dim>
void operator()(const DeviceContext& place, X* x, Y* y, DX* dx, DY* dy,
const Dim& dim, int size) {
dx->device(place) = dy->broadcast(dim) * y->broadcast(dim) * x->inverse();
}
};
#define HANDLE_DIM(NDIM, RDIM) \
if (ndim == NDIM && rdim == RDIM) { \
ReduceCompute<NDIM, RDIM>(context); \
#define HANDLE_DIM(NDIM, RDIM) \
if (ndim == NDIM && rdim == RDIM) { \
ReduceFunctor<DeviceContext, T, NDIM, RDIM, Functor>( \
context.template device_context<DeviceContext>(), *input, output, \
dims, keep_dim); \
}
template <typename DeviceContext, typename T, typename Functor>
......@@ -120,11 +35,15 @@ class ReduceKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
bool reduce_all = context.Attr<bool>("reduce_all");
auto* input = context.Input<Tensor>("X");
auto* output = context.Output<Tensor>("Out");
output->mutable_data<T>(context.GetPlace());
auto dims = context.Attr<std::vector<int>>("dim");
bool keep_dim = context.Attr<bool>("keep_dim");
if (reduce_all) {
// Flatten and reduce 1-D tensor
auto* input = context.Input<Tensor>("X");
auto* output = context.Output<Tensor>("Out");
output->mutable_data<T>(context.GetPlace());
auto x = EigenVector<T>::Flatten(*input);
auto out = EigenScalar<T>::From(*output);
auto& place =
......@@ -133,8 +52,8 @@ class ReduceKernel : public framework::OpKernel<T> {
Functor functor;
functor(place, &x, &out, reduce_dim);
} else {
int ndim = context.Input<Tensor>("X")->dims().size();
int rdim = context.Attr<std::vector<int>>("dim").size();
int ndim = input->dims().size();
int rdim = dims.size();
// comments for accelerating compiling temporarily.
// HANDLE_DIM(6, 5);
// HANDLE_DIM(6, 4);
......@@ -154,48 +73,6 @@ class ReduceKernel : public framework::OpKernel<T> {
HANDLE_DIM(1, 1);
}
}
private:
template <size_t D, size_t R_D>
void ReduceCompute(const framework::ExecutionContext& context) const {
auto* input = context.Input<Tensor>("X");
auto* output = context.Output<Tensor>("Out");
output->mutable_data<T>(context.GetPlace());
auto x = EigenTensor<T, D>::From(*input);
auto x_rank = static_cast<int>(x.dimensions().size());
auto dims = context.Attr<std::vector<int>>("dim");
auto reduce_dim = Eigen::array<int, R_D>();
for (size_t i = 0; i < dims.size(); ++i) {
if (dims[i] < 0) dims[i] = x_rank + dims[i];
reduce_dim[i] = dims[i];
}
// construct the squeezed output tensor
bool keep_dim = context.Attr<bool>("keep_dim");
DDim out_dims = output->dims();
if (keep_dim && x_rank > 1) {
const int kDelFlag = -2;
auto dims_vector = vectorize(out_dims);
for (size_t i = 0; i < dims.size(); ++i) {
dims_vector[dims[i]] = kDelFlag;
}
dims_vector.erase(
remove(dims_vector.begin(), dims_vector.end(), kDelFlag),
dims_vector.end());
out_dims = framework::make_ddim(dims_vector);
}
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
Functor functor;
if (D == 1) {
auto out = EigenScalar<T>::From(*output);
functor(place, &x, &out, reduce_dim);
} else {
auto out = EigenTensor<T, (D - R_D)>::From(*output, out_dims);
functor(place, &x, &out, reduce_dim);
}
}
};
template <typename DeviceContext, typename T, typename Functor>
......@@ -203,12 +80,15 @@ class ReduceGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
bool reduce_all = context.Attr<bool>("reduce_all");
auto dims = context.Attr<std::vector<int>>("dim");
auto* input0 = context.Input<Tensor>("X");
auto* input1 = context.Input<Tensor>("Out");
auto* input2 = context.Input<Tensor>(framework::GradVarName("Out"));
auto* output = context.Output<Tensor>(framework::GradVarName("X"));
output->mutable_data<T>(context.GetPlace());
if (reduce_all) {
auto* input0 = context.Input<Tensor>("X");
auto* input1 = context.Input<Tensor>("Out");
auto* input2 = context.Input<Tensor>(framework::GradVarName("Out"));
auto* output = context.Output<Tensor>(framework::GradVarName("X"));
output->mutable_data<T>(context.GetPlace());
auto x = EigenVector<T>::Flatten(*input0);
auto x_reduce = EigenVector<T>::From(*input1);
auto x_reduce_grad = EigenVector<T>::From(*input2);
......@@ -221,74 +101,172 @@ class ReduceGradKernel : public framework::OpKernel<T> {
functor(place, &x, &x_reduce, &x_grad, &x_reduce_grad, broadcast_dim,
broadcast_dim[0]);
} else {
int rank = context.Input<Tensor>("X")->dims().size();
int rank = input0->dims().size();
switch (rank) {
case 1:
ReduceGradCompute<1>(context);
ReduceGradFunctor<DeviceContext, T, 1, Functor>(
context.template device_context<DeviceContext>(), *input0,
*input1, *input2, output, dims);
break;
case 2:
ReduceGradCompute<2>(context);
ReduceGradFunctor<DeviceContext, T, 2, Functor>(
context.template device_context<DeviceContext>(), *input0,
*input1, *input2, output, dims);
break;
case 3:
ReduceGradCompute<3>(context);
ReduceGradFunctor<DeviceContext, T, 3, Functor>(
context.template device_context<DeviceContext>(), *input0,
*input1, *input2, output, dims);
break;
case 4:
ReduceGradCompute<4>(context);
ReduceGradFunctor<DeviceContext, T, 4, Functor>(
context.template device_context<DeviceContext>(), *input0,
*input1, *input2, output, dims);
break;
case 5:
ReduceGradCompute<5>(context);
ReduceGradFunctor<DeviceContext, T, 5, Functor>(
context.template device_context<DeviceContext>(), *input0,
*input1, *input2, output, dims);
break;
case 6:
ReduceGradCompute<6>(context);
ReduceGradFunctor<DeviceContext, T, 6, Functor>(
context.template device_context<DeviceContext>(), *input0,
*input1, *input2, output, dims);
break;
}
}
}
};
private:
template <size_t D>
void ReduceGradCompute(const framework::ExecutionContext& context) const {
auto* input0 = context.Input<Tensor>("X");
auto* input1 = context.Input<Tensor>("Out");
auto* input2 = context.Input<Tensor>(framework::GradVarName("Out"));
auto* output = context.Output<Tensor>(framework::GradVarName("X"));
class ReduceOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
output->mutable_data<T>(context.GetPlace());
auto x = EigenTensor<T, D>::From(*input0);
auto x_grad = EigenTensor<T, D>::From(*output);
auto x_rank = static_cast<int>(x.dimensions().size());
auto dims = context.Attr<std::vector<int>>("dim");
auto x_dims = input0->dims();
auto reduced_dims_v = vectorize(x_dims);
Eigen::array<int, D> broadcast_dim;
for (size_t i = 0; i < D; ++i) broadcast_dim[i] = 1;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of ReduceOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of ReduceOp should not be null.");
auto x_dims = ctx->GetInputDim("X");
auto x_rank = x_dims.size();
PADDLE_ENFORCE_LE(x_rank, 6, "Tensors with rank at most 6 are supported.");
auto dims = ctx->Attrs().Get<std::vector<int>>("dim");
for (size_t i = 0; i < dims.size(); ++i) {
if (dims[i] < 0) dims[i] = x_rank + dims[i];
PADDLE_ENFORCE_LT(
dims[i], x_rank,
"The dim should be in the range [-rank(input), rank(input)).");
}
sort(dims.begin(), dims.end());
bool reduce_all = ctx->Attrs().Get<bool>("reduce_all");
bool keep_dim = ctx->Attrs().Get<bool>("keep_dim");
if (reduce_all) {
if (keep_dim)
ctx->SetOutputDim(
"Out", framework::make_ddim(std::vector<int64_t>(x_rank, 1)));
else
ctx->SetOutputDim("Out", {1});
} else {
auto dims_vector = vectorize(x_dims);
if (keep_dim) {
for (size_t i = 0; i < dims.size(); ++i) {
dims_vector[dims[i]] = 1;
}
} else {
const int kDelFlag = -2;
for (size_t i = 0; i < dims.size(); ++i) {
dims_vector[dims[i]] = kDelFlag;
}
dims_vector.erase(
remove(dims_vector.begin(), dims_vector.end(), kDelFlag),
dims_vector.end());
}
auto out_dims = framework::make_ddim(dims_vector);
ctx->SetOutputDim("Out", out_dims);
if (dims[0] != 0) {
// Only pass LoD when not reducing on the first dim.
ctx->ShareLoD("X", /*->*/ "Out");
}
}
}
};
class ReduceGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
int broad_cats_times = 1;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null.");
auto x_dims = ctx->GetInputDim("X");
auto x_rank = x_dims.size();
PADDLE_ENFORCE_LE(x_rank, 6, "Tensors with rank at most 6 are supported.");
auto dims = ctx->Attrs().Get<std::vector<int>>("dim");
for (size_t i = 0; i < dims.size(); ++i) {
if (dims[i] < 0) dims[i] = x_rank + dims[i];
reduced_dims_v[dims[i]] = 1;
broadcast_dim[dims[i]] = x_dims[dims[i]];
broad_cats_times *= x_dims[dims[i]];
PADDLE_ENFORCE_LT(
dims[i], x_rank,
"The dim should be in the range [-rank(input), rank(input)).");
}
sort(dims.begin(), dims.end());
auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
ctx->ShareLoD("X", /*->*/ x_grad_name);
}
auto reduced_dims = framework::make_ddim(reduced_dims_v);
auto x_reduce = EigenTensor<T, D>::From(*input1, reduced_dims);
auto x_reduce_grad = EigenTensor<T, D>::From(*input2, reduced_dims);
}
};
class ReduceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() final {
AddInput("X",
"(Tensor) The input tensor. Tensors with rank at most 6 are "
"supported.");
AddOutput("Out", "(Tensor) The result tensor.");
AddAttr<std::vector<int>>(
"dim",
"(list<int>, default {0}) The dimensions to reduce. "
"Must be in the range [-rank(input), rank(input)). "
"If `dim[i] < 0`, the dims[i] to reduce is `rank + dims[i]`. "
"Note that reducing on the first dim will make the LoD info lost.")
.SetDefault({0});
AddAttr<bool>("keep_dim",
"(bool, default false) "
"If true, retain the reduced dimension with length 1.")
.SetDefault(false);
AddAttr<bool>("reduce_all",
"(bool, default false) "
"If true, output a scalar reduced along all dimensions.")
.SetDefault(false);
AddComment(string::Sprintf(R"DOC(
%s Operator.
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
This operator computes the %s of input tensor along the given dimension.
The result tensor has 1 fewer dimension than the input unless keep_dim is true.
If reduce_all is true, just reduce along all dimensions and output a scalar.
Functor functor;
functor(place, &x, &x_reduce, &x_grad, &x_reduce_grad, broadcast_dim,
broad_cats_times);
)DOC",
GetOpType(), GetName()));
}
protected:
virtual std::string GetName() const = 0;
virtual std::string GetOpType() const = 0;
};
} // namespace operators
} // namespace paddle
#define FOR_EACH_KERNEL_FUNCTOR(__macro) \
__macro(reduce_sum, SumFunctor, SumGradFunctor); \
__macro(reduce_mean, MeanFunctor, MeanGradFunctor); \
__macro(reduce_max, MaxFunctor, MaxOrMinGradFunctor); \
__macro(reduce_min, MinFunctor, MaxOrMinGradFunctor); \
__macro(reduce_prod, ProdFunctor, ProdGradFunctor);
namespace ops = paddle::operators;
#define REGISTER_REDUCE_OP(op_name) \
class __##op_name##Maker__ : public ops::ReduceOpMaker { \
protected: \
virtual std::string GetName() const { return #op_name; } \
virtual std::string GetOpType() const { return "Reduce " #op_name; } \
}; \
REGISTER_OPERATOR(op_name, ops::ReduceOp, __##op_name##Maker__, \
paddle::framework::DefaultGradOpDescMaker<true>); \
REGISTER_OPERATOR(op_name##_grad, ops::ReduceGradOp)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using DDim = framework::DDim;
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenScalar = framework::EigenScalar<T, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename DeviceContext, typename T, size_t D, size_t R_D,
typename Functor>
void ReduceFunctor(const DeviceContext& context, const framework::Tensor& input,
framework::Tensor* output, const std::vector<int>& dims,
bool keep_dim) {
auto x = EigenTensor<T, D>::From(input);
auto x_rank = static_cast<int>(x.dimensions().size());
auto reduce_dim = Eigen::array<int, R_D>();
std::vector<int> dims_ref = dims;
for (size_t i = 0; i < dims_ref.size(); ++i) {
if (dims_ref[i] < 0) dims_ref[i] = x_rank + dims_ref[i];
reduce_dim[i] = dims_ref[i];
}
// construct the squeezed output tensor
DDim out_dims = output->dims();
if (keep_dim && x_rank > 1) {
const int kDelFlag = -2;
auto dims_vector = framework::vectorize(out_dims);
for (size_t i = 0; i < dims_ref.size(); ++i) {
dims_vector[dims_ref[i]] = kDelFlag;
}
dims_vector.erase(remove(dims_vector.begin(), dims_vector.end(), kDelFlag),
dims_vector.end());
out_dims = framework::make_ddim(dims_vector);
}
auto& place = *context.eigen_device();
Functor functor;
if (D == 1) {
auto out = EigenScalar<T>::From(*output);
functor(place, &x, &out, reduce_dim);
} else {
auto out = EigenTensor<T, (D - R_D)>::From(*output, out_dims);
functor(place, &x, &out, reduce_dim);
}
}
template <typename DeviceContext, typename T, size_t D, typename Functor>
void ReduceGradFunctor(const DeviceContext& context,
const framework::Tensor& input0,
const framework::Tensor& input1,
const framework::Tensor& input2,
framework::Tensor* output,
const std::vector<int>& dims) {
auto x = EigenTensor<T, D>::From(input0);
auto x_grad = EigenTensor<T, D>::From(*output);
auto x_rank = static_cast<int>(x.dimensions().size());
auto x_dims = input0.dims();
auto reduced_dims_v = framework::vectorize(x_dims);
std::vector<int> dims_ref = dims;
Eigen::array<int, D> broadcast_dim;
for (size_t i = 0; i < D; ++i) broadcast_dim[i] = 1;
int broad_cats_times = 1;
for (size_t i = 0; i < dims_ref.size(); ++i) {
if (dims_ref[i] < 0) {
dims_ref[i] = x_rank + dims_ref[i];
}
reduced_dims_v[dims_ref[i]] = 1;
broadcast_dim[dims_ref[i]] = x_dims[dims_ref[i]];
broad_cats_times *= x_dims[dims_ref[i]];
}
auto reduced_dims = framework::make_ddim(reduced_dims_v);
auto x_reduce = EigenTensor<T, D>::From(input1, reduced_dims);
auto x_reduce_grad = EigenTensor<T, D>::From(input2, reduced_dims);
auto& place = *context.eigen_device();
Functor functor;
functor(place, &x, &x_reduce, &x_grad, &x_reduce_grad, broadcast_dim,
broad_cats_times);
}
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reduce_prod_op.h"
REGISTER_REDUCE_OP(reduce_prod);
REGISTER_OP_CPU_KERNEL(reduce_prod,
ops::ReduceKernel<paddle::platform::CPUDeviceContext,
float, ops::ProdFunctor>,
ops::ReduceKernel<paddle::platform::CPUDeviceContext,
double, ops::ProdFunctor>,
ops::ReduceKernel<paddle::platform::CPUDeviceContext,
int, ops::ProdFunctor>,
ops::ReduceKernel<paddle::platform::CPUDeviceContext,
int64_t, ops::ProdFunctor>);
REGISTER_OP_CPU_KERNEL(reduce_prod_grad,
ops::ReduceGradKernel<paddle::platform::CPUDeviceContext,
float, ops::ProdGradFunctor>,
ops::ReduceGradKernel<paddle::platform::CPUDeviceContext,
double, ops::ProdGradFunctor>,
ops::ReduceGradKernel<paddle::platform::CPUDeviceContext,
int, ops::ProdGradFunctor>,
ops::ReduceGradKernel<paddle::platform::CPUDeviceContext,
int64_t, ops::ProdGradFunctor>);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reduce_prod_op.h"
REGISTER_OP_CUDA_KERNEL(reduce_prod,
ops::ReduceKernel<paddle::platform::CUDADeviceContext,
float, ops::ProdFunctor>,
ops::ReduceKernel<paddle::platform::CUDADeviceContext,
double, ops::ProdFunctor>,
ops::ReduceKernel<paddle::platform::CUDADeviceContext,
int, ops::ProdFunctor>,
ops::ReduceKernel<paddle::platform::CUDADeviceContext,
int64_t, ops::ProdFunctor>);
REGISTER_OP_CUDA_KERNEL(
reduce_prod_grad, ops::ReduceGradKernel<paddle::platform::CUDADeviceContext,
float, ops::ProdGradFunctor>,
ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, double,
ops::ProdGradFunctor>,
ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, int,
ops::ProdGradFunctor>,
ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, int64_t,
ops::ProdGradFunctor>);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/operators/reduce_op.h"
namespace paddle {
namespace operators {
struct ProdFunctor {
template <typename DeviceContext, typename X, typename Y, typename Dim>
void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) {
y->device(place) = x->prod(dim);
}
};
struct ProdGradFunctor {
template <typename DeviceContext, typename X, typename Y, typename DX,
typename DY, typename Dim>
void operator()(const DeviceContext& place, X* x, Y* y, DX* dx, DY* dy,
const Dim& dim, int size) {
dx->device(place) = dy->broadcast(dim) * y->broadcast(dim) * x->inverse();
}
};
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reduce_sum_op.h"
REGISTER_REDUCE_OP(reduce_sum);
REGISTER_OP_CPU_KERNEL(
reduce_sum, ops::ReduceKernel<paddle::platform::CPUDeviceContext, float,
ops::SumFunctor>,
ops::ReduceKernel<paddle::platform::CPUDeviceContext, double,
ops::SumFunctor>,
ops::ReduceKernel<paddle::platform::CPUDeviceContext, int, ops::SumFunctor>,
ops::ReduceKernel<paddle::platform::CPUDeviceContext, int64_t,
ops::SumFunctor>);
REGISTER_OP_CPU_KERNEL(reduce_sum_grad,
ops::ReduceGradKernel<paddle::platform::CPUDeviceContext,
float, ops::SumGradFunctor>,
ops::ReduceGradKernel<paddle::platform::CPUDeviceContext,
double, ops::SumGradFunctor>,
ops::ReduceGradKernel<paddle::platform::CPUDeviceContext,
int, ops::SumGradFunctor>,
ops::ReduceGradKernel<paddle::platform::CPUDeviceContext,
int64_t, ops::SumGradFunctor>);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reduce_sum_op.h"
REGISTER_OP_CUDA_KERNEL(reduce_sum,
ops::ReduceKernel<paddle::platform::CUDADeviceContext,
float, ops::SumFunctor>,
ops::ReduceKernel<paddle::platform::CUDADeviceContext,
double, ops::SumFunctor>,
ops::ReduceKernel<paddle::platform::CUDADeviceContext,
int, ops::SumFunctor>,
ops::ReduceKernel<paddle::platform::CUDADeviceContext,
int64_t, ops::SumFunctor>);
REGISTER_OP_CUDA_KERNEL(
reduce_sum_grad, ops::ReduceGradKernel<paddle::platform::CUDADeviceContext,
float, ops::SumGradFunctor>,
ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, double,
ops::SumGradFunctor>,
ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, int,
ops::SumGradFunctor>,
ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, int64_t,
ops::SumGradFunctor>);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/operators/reduce_op.h"
namespace paddle {
namespace operators {
struct SumFunctor {
template <typename DeviceContext, typename X, typename Y, typename Dim>
void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) {
y->device(place) = x->sum(dim);
}
};
struct SumGradFunctor {
template <typename DeviceContext, typename X, typename Y, typename DX,
typename DY, typename Dim>
void operator()(const DeviceContext& place, X* x, Y* y, DX* dx, DY* dy,
const Dim& dim, int size) {
dx->device(place) = dy->broadcast(dim);
}
};
} // namespace operators
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册