Unverified Commit 4f848aa9 authored by R RedContritio, committed by GitHub

support auto generate for static op reduce_sum (#54304)

* remove reduce_sum_op.h

* support auto generate for static op reduce_sum

* remove reduce_sum_op in CMakeLists.txt
Parent d71baff6
......@@ -21,7 +21,6 @@ limitations under the License. */
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h"
#include <string>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace framework {
class OpDesc;
} // namespace framework
namespace imperative {
class OpBase;
} // namespace imperative
} // namespace paddle
namespace paddle {
namespace operators {
// NOTE: Input(Out) is unnecessary in reduce_sum_grad, and Input(X) needs no
// buffer
template <typename T>
class ReduceSumOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("reduce_sum_grad");
op->SetInput("X", this->Input("X"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetAttrMap(this->Attrs());
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
}
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
int in_dtype = ctx.Attr<int>("out_dtype");
if (in_dtype >= 0) {
return phi::KernelKey(
static_cast<framework::proto::VarType::Type>(in_dtype),
ctx.GetPlace());
}
return phi::KernelKey(framework::OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.GetPlace());
}
};
class ReduceSumCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
public:
using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
void Apply() override {
// get inputs
paddle::Tensor x = this->GetSingleForwardInput("X");
paddle::Tensor out_grad = this->GetSingleOutputGrad("Out");
// get attr
std::vector<int> axis = this->Attr<std::vector<int>>("dim");
bool keep_dim = this->Attr<bool>("keep_dim");
bool reduce_all = this->Attr<bool>("reduce_all");
// get output
paddle::Tensor x_grad_t = this->GetSingleInputGrad("X");
// get output ptr
paddle::Tensor* x_grad = this->GetOutputPtr(&x_grad_t);
// get output original name
std::string x_grad_name = this->GetOutputName(x_grad_t);
VLOG(6) << "Running sum_grad composite func";
// call composite backward func
prim::sum_grad<prim::DescTensor>(
x, out_grad, axis, keep_dim, reduce_all, x_grad);
// recover output name
this->RecoverOutputName(x_grad_t, x_grad_name);
}
};
template <typename T>
class ReduceSumDoubleOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetInput("X", this->OutputGrad(framework::GradVarName("X")));
op->SetOutput("Out", this->InputGrad(framework::GradVarName("Out")));
op->SetAttrMap(this->Attrs());
op->SetType("reduce_sum");
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(ReduceSumGradNoNeedBufferVarInferer, "X");
class ReduceSumVarTypeInference : public paddle::framework::VarTypeInference {
public:
void operator()(paddle::framework::InferVarTypeContext* ctx) const override {
auto data_type = static_cast<paddle::framework::proto::VarType::Type>(
PADDLE_GET_CONST(int, ctx->GetAttr("out_dtype")));
if (data_type >= 0) {
ctx->SetOutputDataType("Out", data_type);
} else {
auto x_type = ctx->GetInputDataType("X");
if (x_type == framework::proto::VarType::BOOL ||
x_type == framework::proto::VarType::INT32) {
ctx->SetOutputDataType("Out", framework::proto::VarType::INT64);
}
}
}
};
} // namespace operators
} // namespace paddle
class ReduceSumOpMaker : public ops::ReduceBaseOpMaker {
protected:
virtual std::string GetName() const { return "reduce_sum"; }
virtual std::string GetOpType() const { return "Reduce reduce_sum"; }
};
DECLARE_INFER_SHAPE_FUNCTOR(reduce_sum,
ReduceSumInferShapeFunctor,
PD_INFER_META(phi::SumRawInferMeta));
REGISTER_OPERATOR(reduce_sum,
ops::ReduceBaseOp,
ReduceSumOpMaker,
ops::ReduceSumVarTypeInference,
ops::ReduceSumOpGradMaker<paddle::framework::OpDesc>,
ops::ReduceSumOpGradMaker<paddle::imperative::OpBase>,
ops::ReduceSumCompositeGradOpMaker,
ReduceSumInferShapeFunctor);
REGISTER_OPERATOR(reduce_sum_grad,
ops::ReduceGradOp,
ops::ReduceSumDoubleOpGradMaker<paddle::framework::OpDesc>,
ops::ReduceSumDoubleOpGradMaker<paddle::imperative::OpBase>,
ops::ReduceSumGradNoNeedBufferVarInferer);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
namespace paddle {
namespace operators {
// Use a for loop to speed up Eigen broadcast; roughly 4 times faster than broadcast.
template <typename DeviceContext,
typename T,
typename Functor,
bool kNoNeedBufferX = false>
class ReduceSumGradKernel : public framework::OpKernel<T> {
public:
void ComputeFromInput(const phi::DenseTensor* input2,
const framework::ExecutionContext& context) const {
auto dims = context.Attr<std::vector<int>>("dim");
auto* input0 = context.Input<phi::DenseTensor>("X");
auto* output =
context.Output<phi::DenseTensor>(framework::GradVarName("X"));
output->mutable_data<T>(context.GetPlace());
const auto* input2_d = input2->data<T>();
auto* output_d = output->data<T>();
// handle reduce_all
if (input2->dims().size() == 1 && input2->dims()[0] == 1) {
for (int64_t i = 0; i < phi::product(input0->dims()); ++i) {
output_d[i] = input2_d[0];
}
return;
}
// handle reduce by one dimension
int reduce_dim_index = dims[0];
if (reduce_dim_index < 0) {
reduce_dim_index += input0->dims().size();
}
auto& input_dim = input0->dims();
int64_t before_dim = 1;
for (int i = 0; i < reduce_dim_index; ++i) {
before_dim *= input_dim[i];
}
int64_t reduce_dim = input_dim[reduce_dim_index];
int64_t after_dim = 1;
for (int i = reduce_dim_index + 1; i < input_dim.size(); ++i) {
after_dim *= input_dim[i];
}
for (int64_t i = 0; i < before_dim; ++i) {
for (int64_t j = 0; j < reduce_dim; ++j) {
for (int64_t k = 0; k < after_dim; ++k) {
output_d[i * reduce_dim * after_dim + j * after_dim + k] =
input2_d[i * after_dim + k];
}
}
}
}
void Compute(const framework::ExecutionContext& context) const override {
auto dims = context.Attr<std::vector<int>>("dim");
if (context.GetPlace().GetType() == platform::CPUPlace().GetType() &&
dims.size() == 1) {
int in_dtype = context.Attr<int>("out_dtype");
if (in_dtype >= 0) {
phi::DenseTensor tmp_tensor;
auto* pre_input =
context.Input<phi::DenseTensor>(framework::GradVarName("Out"));
auto in_kernel_type = phi::KernelKey(context.GetPlace(),
phi::DataLayout::ALL_LAYOUT,
pre_input->dtype());
auto out_kernel_type = phi::KernelKey(in_dtype, context.GetPlace());
framework::TransDataType(
in_kernel_type, out_kernel_type, *pre_input, &tmp_tensor);
ComputeFromInput(&tmp_tensor, context);
} else {
auto* input2 =
context.Input<phi::DenseTensor>(framework::GradVarName("Out"));
ComputeFromInput(input2, context);
}
return;
}
// default use Eigen broadcast
ReduceGradKernel<DeviceContext, T, Functor, kNoNeedBufferX> kernel;
kernel.Compute(context);
}
};
struct SumFunctor {
template <typename DeviceContext, typename X, typename Y, typename Dim>
void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) {
y->device(place) = x->sum(dim);
}
};
struct SumGradFunctor {
template <typename DeviceContext,
typename X,
typename Y,
typename DX,
typename DY,
typename Dim>
void operator()(const DeviceContext& place,
X* x,
Y* y,
DX* dx,
DY* dy,
const Dim& dim,
int size) {
dx->device(place) = dy->broadcast(dim);
}
};
} // namespace operators
} // namespace paddle
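For orientation, the fast path in `ReduceSumGradKernel::ComputeFromInput` above flattens the input into `[before_dim, reduce_dim, after_dim]` around the single reduced axis and copies each output-gradient value across the reduced dimension. A minimal self-contained sketch of that loop in plain C++ (names and `main` are mine for illustration; this is not Paddle code):

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

// Expand the gradient of a single-axis sum back to the input shape:
// grad_out has x's shape with `axis` removed; every slice of grad_x
// along `axis` receives a copy of grad_out.
std::vector<float> SumGradExpand(const std::vector<float>& grad_out,
                                 const std::vector<int64_t>& x_dims,
                                 int axis) {
  int64_t before = 1, after = 1;
  for (int i = 0; i < axis; ++i) before *= x_dims[i];
  for (size_t i = static_cast<size_t>(axis) + 1; i < x_dims.size(); ++i)
    after *= x_dims[i];
  const int64_t reduce = x_dims[axis];
  std::vector<float> grad_x(static_cast<size_t>(before * reduce * after));
  for (int64_t i = 0; i < before; ++i)
    for (int64_t j = 0; j < reduce; ++j)
      for (int64_t k = 0; k < after; ++k)
        grad_x[(i * reduce + j) * after + k] = grad_out[i * after + k];
  return grad_x;
}

int main() {
  // x has shape [2, 3]; summing over axis 1 gives out of shape [2].
  // Every element in row i of grad_x equals grad_out[i].
  std::vector<float> grad_out = {1.f, 2.f};
  std::vector<float> grad_x = SumGradExpand(grad_out, {2, 3}, /*axis=*/1);
  assert(grad_x[0] == 1.f && grad_x[2] == 1.f && grad_x[5] == 2.f);
  std::printf("ok\n");
  return 0;
}
```

This copy is the same operation `SumGradFunctor` expresses with Eigen's `broadcast`; the handwritten loop simply avoids Eigen overhead in the common single-axis CPU case, per the comment above the kernel.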
......@@ -4,16 +4,9 @@
# Generally, the combination rules in this file do not need to be modified.
# If there are redefinition errors when compiling a source file included in a
# combination rule, you can remove that source file from the following rules.
register_unity_group(cc reduce_all_op.cc reduce_any_op.cc reduce_prod_op.cc
reduce_sum_op.cc)
register_unity_group(
cu
reduce_all_op.cu
reduce_any_op.cu
reduce_prod_op.cu
reduce_prod_op.part.cu
reduce_sum_op.cu
reduce_sum_op.part.cu)
register_unity_group(cc reduce_all_op.cc reduce_any_op.cc reduce_prod_op.cc)
register_unity_group(cu reduce_all_op.cu reduce_any_op.cu reduce_prod_op.cu
reduce_prod_op.part.cu)
# The following groups are to make better use of `/MP`, MSVC's parallel
# compilation option, when compiling in Unity Build.
register_unity_group(cu frobenius_norm_op.cu)
......
......@@ -2382,18 +2382,23 @@
bool use_quantizer = false, float Scale_x = 1.0f, float Scale_y = 1.0f, float Scale_out = 1.0f]
- op : sum (reduce_sum)
  backward : sum_grad (reduce_sum_grad)
  backward : sum_grad (reduce_sum_grad), sum_double_grad
  inputs:
    {x : X}
  attrs:
    { axis : dim, keepdim : keep_dim, dtype : out_dtype}
  outputs:
    out : Out
  attrs:
    { axis : dim, keepdim : keep_dim, dtype : out_dtype}
  extra :
    attrs : [bool use_mkldnn = false]
  int_array:
    axis :
      data_type : int
  extra :
    attrs : [bool use_mkldnn = false]
      support_tensor : true
  get_expected_kernel_type :
    sum : GetReduceExpectedKernelType
    sum_grad : GetReduceGradExpectedKernelType
  manual_signature : [sum]

- op : svd
backward : svd_grad
......
......@@ -132,6 +132,25 @@
    data_type : out_grad
  no_need_buffer : x

- backward_op : sum_double_grad
  forward : sum_grad (Tensor x, Tensor grad_out, IntArray axis, bool keepdim, bool reduce_all=false) -> Tensor(grad_x)
  args : (Tensor grad_x_grad, IntArray axis={}, bool keepdim=false, bool reduce_all=false)
  output : Tensor(grad_out_grad)
  invoke : sum(grad_x_grad, axis, keepdim, reduce_all)

- backward_op : sum_grad
  forward : sum (Tensor x, IntArray axis={0}, bool keepdim=false, bool reduce_all=false, int in_dtype=-1, DataType out_dtype=DataType::UNDEFINED) -> Tensor(out)
  args : (Tensor x, Tensor out_grad, IntArray axis, bool keepdim, bool reduce_all=false)
  output : Tensor(x_grad)
  infer_meta :
    func : UnchangedInferMeta
    param : [x]
  kernel :
    func : sum_grad
  composite : sum_grad(x, out_grad, axis, keepdim, reduce_all, x_grad)
  no_need_buffer : x
  backward : sum_double_grad

- backward_op : swish_grad
  forward : swish (Tensor x, float beta = 1.0f) -> Tensor(out)
  args : (Tensor x, Tensor out_grad)
......
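A note on why `sum_double_grad` above can simply `invoke : sum`: summation is linear, so `sum_grad` is a broadcast of `out_grad` back to the input shape, and differentiating that broadcast with respect to `out_grad` is summation again. In VJP notation (mine, not from the source):

$$
y=\sum_{i\in\text{axis}} x_i
\;\Longrightarrow\;
\bar{x}_i=\bar{y}\;\;(\text{broadcast}),
\qquad
\overline{\bar{y}}=\sum_{i\in\text{axis}}\bigl(\overline{\bar{x}}\bigr)_i,
$$

which matches `sum(grad_x_grad, axis, keepdim, reduce_all)`.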
......@@ -419,6 +419,18 @@
    param : [x, axes, starts, ends, strides]
  backward : strided_slice_grad

- op : sum
  args : (Tensor x, IntArray axis={0}, bool keepdim=false, bool reduce_all=false, int in_dtype=-1, DataType out_dtype=DataType::UNDEFINED)
  output : Tensor(out)
  infer_meta :
    func : SumRawInferMeta
    param : [x, axis, keepdim, reduce_all, out_dtype]
  kernel :
    func : sum_raw
    param : [x, axis, keepdim, reduce_all, out_dtype]
    data_type : x
  backward : sum_grad

- op : swish
  args : (Tensor x, float beta = 1.0f)
  output : Tensor(out)
......
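The `SumRawInferMeta` referenced above determines the output shape from `axis`, `keepdim`, and `reduce_all`. A hedged sketch of that shape rule in plain C++ (my own reconstruction from the attributes, not Paddle's implementation; edge cases such as 0-D outputs are glossed over):

```cpp
#include <cstdint>
#include <vector>

std::vector<int64_t> SumOutShape(const std::vector<int64_t>& x_dims,
                                 const std::vector<int64_t>& axes,
                                 bool keepdim,
                                 bool reduce_all) {
  std::vector<bool> reduced(x_dims.size(), reduce_all);
  for (int64_t a : axes) {
    if (a < 0) a += static_cast<int64_t>(x_dims.size());  // negative axis wraps
    reduced[static_cast<size_t>(a)] = true;
  }
  std::vector<int64_t> out;
  for (size_t i = 0; i < x_dims.size(); ++i) {
    if (!reduced[i]) {
      out.push_back(x_dims[i]);  // axis kept unchanged
    } else if (keepdim) {
      out.push_back(1);          // reduced axis kept as size 1
    }                            // else: reduced axis dropped
  }
  return out;
}
```

For example, `SumOutShape({2, 3, 4}, {1}, false, false)` yields `{2, 4}`, and `{2, 1, 4}` with `keepdim = true`.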
......@@ -159,14 +159,6 @@ KernelSignature ReduceAllOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("unregistered", {}, {}, {});
}
KernelSignature ReduceSumGradOpArgumentMapping(
const ArgumentMappingContext& ctx UNUSED) {
return KernelSignature("sum_grad",
{"X", "Out@GRAD"},
{"dim", "keep_dim", "reduce_all"},
{"X@GRAD"});
}
KernelSignature ReduceMeanGradOpArgumentMapping(
const ArgumentMappingContext& ctx UNUSED) {
return KernelSignature("mean_grad",
......@@ -203,7 +195,6 @@ PD_REGISTER_BASE_KERNEL_NAME(reduce_prod, prod);
PD_REGISTER_BASE_KERNEL_NAME(reduce_all, all);
PD_REGISTER_BASE_KERNEL_NAME(reduce_any, any);
PD_REGISTER_BASE_KERNEL_NAME(reduce_sum_grad, sum_grad);
PD_REGISTER_BASE_KERNEL_NAME(reduce_mean_grad, mean_grad);
PD_REGISTER_BASE_KERNEL_NAME(reduce_prod_grad, prod_grad);
PD_REGISTER_BASE_KERNEL_NAME(reduce_min_grad, min_grad);
......@@ -218,8 +209,6 @@ PD_REGISTER_ARG_MAPPING_FN(reduce_amin, phi::ReduceAMinOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(reduce_all, phi::ReduceAllOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(reduce_any, phi::ReduceAnyOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(reduce_sum_grad,
phi::ReduceSumGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(reduce_mean_grad,
phi::ReduceMeanGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(reduce_prod_grad,
......
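With the handwritten `ReduceSumGradOpArgumentMapping` deleted above, the fluid-to-phi mapping for `reduce_sum_grad` presumably comes from the generated code driven by the YAML entries earlier in this diff. A self-contained sketch of what that mapping expresses (plain structs for illustration, not Paddle's `KernelSignature` type):

```cpp
#include <string>
#include <vector>

// Plain stand-in for a fluid->phi kernel signature (illustration only).
struct SignatureSketch {
  std::string kernel_name;          // phi kernel to dispatch to
  std::vector<std::string> inputs;  // fluid op input names
  std::vector<std::string> attrs;   // fluid op attribute names
  std::vector<std::string> outputs; // fluid op output names
};

// Mirrors the handwritten mapping removed in this diff: sum_grad consumes
// X and Out@GRAD, reads dim/keep_dim/reduce_all, and produces X@GRAD.
const SignatureSketch kReduceSumGradSignature{
    "sum_grad",
    {"X", "Out@GRAD"},
    {"dim", "keep_dim", "reduce_all"},
    {"X@GRAD"}};
```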
......@@ -64,7 +64,7 @@ cc_test(
op_registry
variable_helper
mul_op
reduce_sum_op
generated_static_op
elementwise_add_op
memcpy)
cc_test(
......
......@@ -25,7 +25,6 @@ if(WITH_GPU
elementwise_mul_op
softmax_with_cross_entropy_op
reduce_mean_op
reduce_sum_op
activation_op
sum_op
elementwise_max_op
......