Unverified commit 1bc00955, authored by lzydev, committed by GitHub

Autogen segment_pool (#52538)

* autogen segment_pool

* delete legacy_dygraph about segment_pool
Parent 0b89cb1d
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include <string>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle {
namespace operators {
class SegmentPoolOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.GetPlace());
}
};
class SegmentPoolOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor) The input data of SegmentPoolOp");
AddInput("SegmentIds",
"(Tensor) 1-D tensor which have the same size with the fist "
"dimension of input X.");
AddOutput("Out", "(Tensor) The output of SegmentPoolOp.");
AddOutput("SummedIds",
"(Tensor) This tensor is used to counts of segment ids for the "
"backward of the mean pool.")
.AsIntermediate();
AddAttr<std::string>(
"pooltype",
"(string, default 'SUM') the pooling type of SegmentPoolOp.")
.SetDefault("SUM")
.InEnum({"SUM", "MEAN", "MIN", "MAX"});
AddComment(R"DOC(
Segment Pool Operator.
This operator pools the elements of input `X` which have the same index
in `SegmentIds`.
For SUM operation, it computes a tensor such that $Out_i = \sum_{j} X_{j}$
where sum is over j such that `SegmentIds[j] == i`.
For MEAN operation, it computes a tensor such that
$Out_i = \frac{1}{n_i} \sum_{j} X_{j}$ where sum is over j such that
`SegmentIds[j] == i` and $n_i$ is the number of indices j with `SegmentIds[j] == i`.
For MIN operation, it computes a tensor such that $Out_i = \min_{j} X_{j}$
where min is over j such that `SegmentIds[j] == i`.
For MAX operation, it computes a tensor such that $Out_i = \max_{j} X_{j}$
where max is over j such that `SegmentIds[j] == i`.
)DOC");
}
};
class SegmentPoolGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")),
"Input",
framework::GradVarName("Out"),
"SegmentPoolGrad");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "SegmentPoolGrad");
auto og_dims = ctx->GetInputDim(framework::GradVarName("Out"));
auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(og_dims.size(),
x_dims.size(),
platform::errors::InvalidArgument(
"The rank of output grad must equal to Input(X). But "
"received: input rank %u, input shape [%s].",
og_dims.size(),
og_dims));
for (int64_t i = 1; i < og_dims.size(); ++i) {
PADDLE_ENFORCE_EQ(
og_dims[i],
x_dims[i],
platform::errors::InvalidArgument(
"The dimension mismatch between Input(OUT@GRAD) and "
"Input(X). Received Input(OUT@GRAD): input rank %u, "
"input shape [%s]; received Input(X): input rank %u, "
"input shape [%s].",
og_dims.size(),
og_dims,
x_dims.size(),
x_dims));
}
ctx->ShareDim("X", /*->*/ framework::GradVarName("X"));
}
protected:
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.GetPlace());
}
};
template <typename T>
class SegmentPoolGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op_desc_ptr) const override {
op_desc_ptr->SetType("segment_pool_grad");
op_desc_ptr->SetInput("X", this->Input("X"));
op_desc_ptr->SetInput("SegmentIds", this->Input("SegmentIds"));
op_desc_ptr->SetInput("Out", this->Output("Out"));
if (PADDLE_GET_CONST(std::string, this->GetAttr("pooltype")) == "MEAN") {
op_desc_ptr->SetInput("SummedIds", this->Output("SummedIds"));
}
op_desc_ptr->SetInput(framework::GradVarName("Out"),
this->OutputGrad("Out"));
op_desc_ptr->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op_desc_ptr->SetAttrMap(this->Attrs());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(segment_pool,
SegmentPoolInferShapeFunctor,
PD_INFER_META(phi::SegmentPoolInferMeta));
REGISTER_OPERATOR(segment_pool,
ops::SegmentPoolOp,
ops::SegmentPoolOpMaker,
ops::SegmentPoolGradOpMaker<paddle::framework::OpDesc>,
ops::SegmentPoolGradOpMaker<paddle::imperative::OpBase>,
SegmentPoolInferShapeFunctor);
REGISTER_OPERATOR(segment_pool_grad, ops::SegmentPoolGradOp);
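For context on what this deleted operator computes: the DOC block above defines SUM/MEAN/MIN/MAX pooling over the rows of `X` grouped by `SegmentIds`. A minimal NumPy sketch of those semantics, as an illustrative reference only (not the phi kernel that actually backs `segment_pool`):

```python
import numpy as np

def segment_pool_ref(x, segment_ids, pooltype="SUM"):
    """Pool rows of x that share the same id in segment_ids, per the DOC above."""
    num_segments = int(segment_ids.max()) + 1
    out = np.zeros((num_segments,) + x.shape[1:], dtype=x.dtype)
    for i in range(num_segments):
        rows = x[segment_ids == i]
        if rows.size == 0:
            continue  # segments with no members stay zero
        op = {"SUM": np.sum, "MEAN": np.mean, "MIN": np.min, "MAX": np.max}[pooltype]
        out[i] = op(rows, axis=0)
    return out

x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
ids = np.array([0, 0, 1])
print(segment_pool_ref(x, ids, "SUM"))   # [[4. 6.] [5. 6.]]
print(segment_pool_ref(x, ids, "MEAN"))  # [[2. 3.] [5. 6.]]
```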
@@ -1404,6 +1404,18 @@
    func : scatter_nd_add_grad
  no_need_buffer : updates
- backward_op : segment_pool_grad
forward : segment_pool (Tensor x, Tensor segment_ids, str pooltype="SUM") -> Tensor(out), Tensor(summed_ids)
args : (Tensor x, Tensor segment_ids, Tensor out, Tensor summed_ids, Tensor out_grad, str pooltype)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : segment_pool_grad
data_type : out_grad
optional : summed_ids
- backward_op : selu_grad
  forward : selu (Tensor x, float scale=1.0507009873554804934193349852946, float alpha=1.6732632423543772848170429916717) -> Tensor(out)
  args : (Tensor out, Tensor out_grad, float scale, float alpha)
...
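One note on the `segment_pool_grad` entry added above: `summed_ids` is marked `optional` because the forward op only feeds it to the MEAN backward, which divides each input row's incoming gradient by its segment's element count. A hedged NumPy sketch of that gradient rule (the standard mean-pool derivation, not the registered `segment_pool_grad` kernel):

```python
import numpy as np

def segment_mean_grad_ref(out_grad, segment_ids, summed_ids):
    # Row j of x receives out_grad[segment_ids[j]] / n_i, where
    # n_i = summed_ids[segment_ids[j]] is the size of row j's segment.
    return out_grad[segment_ids] / summed_ids[segment_ids]

out_grad = np.array([[1.0, 1.0], [2.0, 2.0]])  # d(loss)/d(out), 2 segments
ids = np.array([0, 0, 1])
counts = np.array([[2.0], [1.0]])              # summed_ids from the forward pass
print(segment_mean_grad_ref(out_grad, ids, counts))
# [[0.5 0.5]
#  [0.5 0.5]
#  [2.  2. ]]
```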
@@ -938,18 +938,6 @@
    func : rrelu_grad
    data_type : x
- backward_op : segment_pool_grad
forward : segment_pool (Tensor x, Tensor segment_ids, str pooltype) -> Tensor(out), Tensor(summed_ids)
args : (Tensor x, Tensor segment_ids, Tensor out, Tensor summed_ids, Tensor out_grad, str pooltype)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : segment_pool_grad
data_type : x
optional : summed_ids
- backward_op : slice_double_grad
  forward : slice_grad (Tensor input, Tensor grad_out, int64_t[] axes, IntArray starts, IntArray ends, int64_t[] infer_flags, int64_t[] decrease_axis) -> Tensor(grad_input)
  args : (Tensor grad_input_grad, int64_t[] axes, IntArray starts, IntArray ends, int64_t[] infer_flags, int64_t[] decrease_axis)
...
@@ -1226,16 +1226,6 @@
  intermediate : noise
  backward : rrelu_grad
- op : segment_pool
args : (Tensor x, Tensor segment_ids, str pooltype)
output : Tensor(out), Tensor(summed_ids)
infer_meta :
func : SegmentPoolInferMeta
kernel :
func : segment_pool
data_type : x
backward : segment_pool_grad
- op : shape
  args : (Tensor input)
  output : Tensor(out)
...
@@ -1805,6 +1805,13 @@
  extra :
    attrs : [bool deterministic = false, str rng_name = "", bool force_cpu = false]
- op : segment_pool
backward : segment_pool_grad
inputs :
{x : X, segment_ids : SegmentIds}
outputs :
{out : Out, summed_ids : SummedIds}
- op : selu
  backward : selu_grad
  inputs :
...
@@ -1485,6 +1485,17 @@
    func : searchsorted
    data_type : sorted_sequence
- op : segment_pool
args : (Tensor x, Tensor segment_ids, str pooltype="SUM")
output : Tensor(out), Tensor(summed_ids)
infer_meta :
func : SegmentPoolInferMeta
kernel :
func : segment_pool
data_type : x
intermediate : summed_ids
backward : segment_pool_grad
- op : selu
  args : (Tensor x, float scale=1.0507009873554804934193349852946, float alpha=1.6732632423543772848170429916717)
  output : Tensor
...
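The key difference from the deleted legacy entry is `intermediate : summed_ids`: the generated dygraph binding then surfaces only `out` to Python and keeps `summed_ids` internal for the backward pass, which is why the `[0]` indexing disappears in the Python changes below. A short sketch of the resulting call (assuming a Paddle build that includes this change):

```python
import paddle

data = paddle.to_tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
segment_ids = paddle.to_tensor([0, 0, 1], dtype='int32')

# _C_ops.segment_pool now hands the wrapper a single Tensor instead of
# the old (out, summed_ids) pair, so no [0] indexing is needed.
out = paddle.geometric.segment_mean(data, segment_ids)
print(out.numpy())  # [[2. 3.] [5. 6.]]
```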
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature SegmentPoolGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("segment_pool_grad",
{
"X",
"SegmentIds",
"Out",
"SummedIds",
"Out@GRAD",
},
{"pooltype"},
{"X@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(segment_pool_grad,
phi::SegmentPoolGradOpArgumentMapping);
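The hand-written argument mapping deleted above is the same correspondence the autogen now derives from the compat entry earlier in this diff (`{x : X, segment_ids : SegmentIds}`, `{out : Out, summed_ids : SummedIds}`). A hypothetical Python rendering of that signature, purely for reference (the real generator emits C++):

```python
# Fluid-era names as consumed by the phi `segment_pool_grad` kernel.
# Illustrative only; this mirrors the deleted KernelSignature above.
segment_pool_grad_signature = {
    "kernel": "segment_pool_grad",
    "inputs": ["X", "SegmentIds", "Out", "SummedIds", "Out@GRAD"],
    "attrs": ["pooltype"],
    "outputs": ["X@GRAD"],
}
```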
@@ -51,7 +51,7 @@ def segment_sum(data, segment_ids, name=None):
    """
    if in_dygraph_mode():
-        return _C_ops.segment_pool(data, segment_ids, "SUM")[0]
+        return _C_ops.segment_pool(data, segment_ids, "SUM")
    else:
        check_variable_and_dtype(
            data,
@@ -108,7 +108,7 @@ def segment_mean(data, segment_ids, name=None):
    """
    if in_dygraph_mode():
-        return _C_ops.segment_pool(data, segment_ids, "MEAN")[0]
+        return _C_ops.segment_pool(data, segment_ids, "MEAN")
    else:
        check_variable_and_dtype(
@@ -165,7 +165,7 @@ def segment_min(data, segment_ids, name=None):
    """
    if in_dygraph_mode():
-        return _C_ops.segment_pool(data, segment_ids, "MIN")[0]
+        return _C_ops.segment_pool(data, segment_ids, "MIN")
    else:
        check_variable_and_dtype(
            data,
@@ -221,7 +221,7 @@ def segment_max(data, segment_ids, name=None):
    """
    if in_dygraph_mode():
-        return _C_ops.segment_pool(data, segment_ids, "MAX")[0]
+        return _C_ops.segment_pool(data, segment_ids, "MAX")
    else:
        check_variable_and_dtype(
            data,
...
@@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from paddle import _C_ops, _legacy_C_ops
+from paddle import _C_ops
from paddle.fluid.data_feeder import check_variable_and_dtype
from paddle.fluid.framework import in_dygraph_mode
-from paddle.fluid.layer_helper import LayerHelper, _non_static_mode
+from paddle.fluid.layer_helper import LayerHelper
from paddle.utils import deprecated

__all__ = []
@@ -64,7 +64,7 @@ def segment_sum(data, segment_ids, name=None):
    """
    if in_dygraph_mode():
-        return _C_ops.segment_pool(data, segment_ids, "SUM")[0]
+        return _C_ops.segment_pool(data, segment_ids, "SUM")
    else:
        check_variable_and_dtype(
            data, "X", ("float32", "float64", "int32", "int64"), "segment_pool"
@@ -130,12 +130,7 @@ def segment_mean(data, segment_ids, name=None):
    """
    if in_dygraph_mode():
-        return _C_ops.segment_pool(data, segment_ids, "MEAN")[0]
+        return _C_ops.segment_pool(data, segment_ids, "MEAN")
-    if _non_static_mode():
-        out, tmp = _legacy_C_ops.segment_pool(
-            data, segment_ids, 'pooltype', "MEAN"
-        )
-        return out
    check_variable_and_dtype(
        data, "X", ("float32", "float64", "int32", "int64"), "segment_pool"
@@ -200,13 +195,7 @@ def segment_min(data, segment_ids, name=None):
    """
    if in_dygraph_mode():
-        return _C_ops.segment_pool(data, segment_ids, "MIN")[0]
+        return _C_ops.segment_pool(data, segment_ids, "MIN")
-    if _non_static_mode():
-        out, tmp = _legacy_C_ops.segment_pool(
-            data, segment_ids, 'pooltype', "MIN"
-        )
-        return out
    check_variable_and_dtype(
        data, "X", ("float32", "float64", "int32", "int64"), "segment_pool"
@@ -271,13 +260,7 @@ def segment_max(data, segment_ids, name=None):
    """
    if in_dygraph_mode():
-        out, tmp = _C_ops.segment_pool(data, segment_ids, "MAX")
+        out = _C_ops.segment_pool(data, segment_ids, "MAX")
-        return out
-    if _non_static_mode():
-        out, tmp = _legacy_C_ops.segment_pool(
-            data, segment_ids, 'pooltype', "MAX"
-        )
        return out
    check_variable_and_dtype(
...
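End to end, the four public wrappers behave as before; only the dispatch plumbing changed. A quick usage check across all pool types (expected values follow the DOC semantics; assumes a Paddle build with this commit):

```python
import paddle

data = paddle.to_tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
segment_ids = paddle.to_tensor([0, 0, 1], dtype='int32')

# Rows 0 and 1 form segment 0; row 2 forms segment 1.
print(paddle.geometric.segment_sum(data, segment_ids).numpy())   # [[4. 6.] [5. 6.]]
print(paddle.geometric.segment_mean(data, segment_ids).numpy())  # [[2. 3.] [5. 6.]]
print(paddle.geometric.segment_min(data, segment_ids).numpy())   # [[1. 2.] [5. 6.]]
print(paddle.geometric.segment_max(data, segment_ids).numpy())   # [[3. 4.] [5. 6.]]
```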