未验证 提交 ffeac6d5 编写于 作者: C cyberslack_lee 提交者: GitHub

support auto generation for gather (#54084)

上级 b547c4ac
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle {
namespace operators {
class GatherOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  // Kernel data type follows input "X"; the kernel runs on the current
  // device context's place.
  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
    return phi::KernelKey(input_data_type, ctx.device_context().GetPlace());
  }

  phi::KernelKey GetKernelTypeForVar(
      const std::string& var_name,
      const phi::DenseTensor& tensor,
      const phi::KernelKey& expected_kernel_type) const override {
    const bool is_axis_var = (var_name == "Axis");
    if (!is_axis_var) {
      // Ordinary inputs keep their own place/layout; only the dtype comes
      // from the expected kernel key.
      return phi::KernelKey(
          tensor.place(), tensor.layout(), expected_kernel_type.dtype());
    }
    // The dispensable "Axis" tensor is accepted on any backend, so no data
    // transfer is forced for it.
    return phi::KernelKey(phi::Backend::ALL_BACKEND,
                          expected_kernel_type.layout(),
                          expected_kernel_type.dtype());
  }
};
class GatherGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  // Kernel data type follows the incoming gradient Out@GRAD; the kernel runs
  // on the current device context's place.
  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    auto grad_data_type = OperatorWithKernel::IndicateVarDataType(
        ctx, framework::GradVarName("Out"));
    return phi::KernelKey(grad_data_type, ctx.device_context().GetPlace());
  }

  phi::KernelKey GetKernelTypeForVar(
      const std::string& var_name,
      const phi::DenseTensor& tensor,
      const phi::KernelKey& expected_kernel_type) const override {
    const bool is_axis_var = (var_name == "Axis");
    if (!is_axis_var) {
      // Ordinary inputs keep their own place/layout; only the dtype comes
      // from the expected kernel key.
      return phi::KernelKey(
          tensor.place(), tensor.layout(), expected_kernel_type.dtype());
    }
    // The dispensable "Axis" tensor is accepted on any backend, so no data
    // transfer is forced for it.
    return phi::KernelKey(phi::Backend::ALL_BACKEND,
                          expected_kernel_type.layout(),
                          expected_kernel_type.dtype());
  }
};
// Declares the proto of the "gather" op: inputs X, Index and the dispensable
// Axis tensor, output Out, plus the static "axis" attribute used when the
// Axis tensor is absent.
class GatherOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "The source input of gather op");
    AddInput("Index", "The index input of gather op");
    AddInput("Axis",
             "The Tensor which contains the axis that we do gather operation.")
        .AsDispensable();
    AddOutput("Out", "The output of gather op");
    // NOTE: the original description was a copy-paste of the Axis *tensor*
    // input's text; "axis" is a plain int attribute, not a Tensor.
    AddAttr<int>(
        "axis",
        "The axis along which we do gather operation; used when the "
        "dispensable Axis input is not provided.")
        .SetDefault(0);
    AddComment(R"DOC(
Gather Operator.
$Out = X[Index]$
Out is obtained by gathering entries of the outer-most dimension
of X indexed by Index and concatenate them together.
Example:
X = [[1, 2],
[3, 4],
[5, 6]]
Index = [[1, 2]]
Then:
Out = [[3, 4],
[5, 6]]
)DOC");
  }
};
// Generates the description of "gather_grad" from the forward "gather" op:
// forwards X, Index and Axis, wires Out@GRAD -> X@GRAD, and copies all
// attributes.
template <typename T>
class GatherGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> grad_op) const override {
    grad_op->SetType("gather_grad");
    grad_op->SetInput("X", this->Input("X"));
    grad_op->SetInput("Index", this->Input("Index"));
    grad_op->SetInput("Axis", this->Input("Axis"));
    grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    grad_op->SetAttrMap(this->Attrs());
  }
};
// Builds the composite (primitive-op) backward for gather. Only the static
// "axis" attribute path is supported; a runtime Axis tensor raises
// Unimplemented.
class GatherCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
 public:
  using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;

 protected:
  void Apply() override {
    paddle::Tensor index = this->GetSingleForwardInput("Index");
    // "Axis" is declared dispensable on the forward op, so it may be absent.
    paddle::optional<paddle::Tensor> tensor_axis =
        this->GetOptionalSingleForwardInput("Axis");
    paddle::Tensor x = this->GetSingleForwardInput("X");
    paddle::Tensor dout = this->GetSingleOutputGrad("Out");
    paddle::Tensor dx = this->GetSingleInputGrad("X");
    auto* dx_ptr = this->GetOutputPtr(&dx);
    std::string dx_name = this->GetOutputName(*dx_ptr);
    int axis = static_cast<int>(this->Attr<int>("axis"));
    // Fixed typo in the log message: "Runing" -> "Running".
    VLOG(3) << "Running gather_grad composite func";
    if (tensor_axis.is_initialized()) {
      // A dynamic axis supplied as a tensor cannot be folded into the
      // composite rule, which needs a compile-time int axis.
      PADDLE_THROW(platform::errors::Unimplemented(
          "We don't support dynamic index from tensor for gather composite "
          "grad for now. "));
    } else {
      prim::gather_grad<prim::DescTensor>(x, index, dout, axis, dx_ptr);
    }
    this->RecoverOutputName(dx, dx_name);
  }
};
// gather_grad never reads the contents of X (only its shape), so X's buffer
// can be freed early.
DECLARE_NO_NEED_BUFFER_VARS_INFERER(GatherGradNoNeedBufferVarInferer, "X");
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
// Forward shape inference is delegated to the phi GatherInferMeta function.
DECLARE_INFER_SHAPE_FUNCTOR(gather,
GatherInferShapeFunctor,
PD_INFER_META(phi::GatherInferMeta));
REGISTER_OPERATOR(gather,
ops::GatherOp,
ops::GatherOpMaker,
ops::GatherGradOpMaker<paddle::framework::OpDesc>,
ops::GatherGradOpMaker<paddle::imperative::OpBase>,
ops::GatherCompositeGradOpMaker,
GatherInferShapeFunctor);
// x_grad has the same meta as x, so the generic unary-grad infermeta is used.
DECLARE_INFER_SHAPE_FUNCTOR(gather_grad,
GatherGradInferShapeFunctor,
PD_INFER_META(phi::GeneralUnaryGradInferMeta));
REGISTER_OPERATOR(gather_grad,
ops::GatherGradOp,
ops::GatherGradNoNeedBufferVarInferer,
GatherGradInferShapeFunctor);
// Version checkpoint recording the addition of the Axis input.
// NOTE(review): "upgrad" looks like a typo for "Upgrade", but checkpoint
// text may be matched by compatibility tooling — confirm before changing.
REGISTER_OP_VERSION(gather).AddCheckpoint(
R"ROC(upgrad gather, add a new input [Axis])ROC",
paddle::framework::compatible::OpVersionDesc().NewInput(
"Axis", "Specify the axis of gather operation."));
......@@ -104,7 +104,6 @@ register_unity_group(
flatten_op.cc
fsp_op.cc
gather_nd_op.cc
gather_op.cc
gather_tree_op.cc
gaussian_random_batch_size_like_op.cc
mkldnn/gaussian_random_mkldnn_op.cc
......
......@@ -840,6 +840,19 @@
kernel :
func : frame_grad
- backward_op : gather_grad
forward : gather(Tensor x, Tensor index, Scalar axis=0) -> Tensor(out)
args : (Tensor x, Tensor index, Tensor out_grad, Scalar axis=0)
output : Tensor(x_grad)
infer_meta :
func : GeneralUnaryGradInferMeta
param: [x]
kernel :
data_type: out_grad
func : gather_grad
composite : gather_grad(x, index, out_grad, axis, x_grad)
no_need_buffer : x
- backward_op : gather_nd_grad
forward : gather_nd (Tensor x, Tensor index) -> Tensor(out)
args : (Tensor x, Tensor index, Tensor out_grad)
......
......@@ -302,19 +302,6 @@
kernel :
func : frobenius_norm_grad
- backward_op : gather_grad
forward : gather(Tensor x, Tensor index, Scalar axis=0) -> Tensor(out)
args : (Tensor x, Tensor index, Tensor out_grad, Scalar axis=0)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param: [x]
kernel :
data_type: x
func : gather_grad
composite : gather_grad(x, index, out_grad, axis, x_grad)
no_need_buffer : x
- backward_op : hardswish_grad
forward : hardswish (Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
......
......@@ -422,16 +422,6 @@
optional : skip_update, master_params
inplace : (params -> params_out), (moments1 -> moments1_out), (moments2 -> moments2_out), (beta1_pows -> beta1_pows_out), (beta2_pows -> beta2_pows_out), (master_params -> master_params_out)
- op : gather
args : (Tensor x, Tensor index, Scalar(int) axis=0)
output : Tensor(out)
infer_meta :
func : GatherInferMeta
kernel :
func : gather
data_type: x
backward : gather_grad
- op : gaussian
args : (IntArray shape, float mean, float std, int seed, DataType dtype, Place place={})
output: Tensor(out)
......
......@@ -1132,6 +1132,14 @@
- op : gather
backward : gather_grad
inputs :
{x : X, index : Index}
outputs :
out : Out
scalar :
axis :
data_type : int
tensor_name : Axis
- op : gather_nd
backward : gather_nd_grad
......
......@@ -197,6 +197,13 @@
- delete_attr : dims
comment : The attr 'dims' is deleted.
- op : gather
version :
- checkpoint : Upgrade gather, add a new input [Axis]
action :
- add_input : Axis
comment : Specify the axis of gather operation.
- op : gaussian_random
version :
- checkpoint : Upgrade gaussian_random add new inputs [ShapeTensor] and [ShapeTensorList]
......
......@@ -940,6 +940,16 @@
data_type : dtype
backend : place
- op : gather
args : (Tensor x, Tensor index, Scalar axis=0)
output : Tensor(out)
infer_meta :
func : GatherInferMeta
kernel :
func : gather
data_type: x
backward : gather_grad
- op : gather_nd
args : (Tensor x, Tensor index)
output : Tensor
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
// Maps the fluid "gather" op to the phi kernel signature. The axis argument
// comes from the runtime "Axis" tensor input when it is present, otherwise
// from the "axis" attribute.
KernelSignature GatherOpArgumentMapping(const ArgumentMappingContext& ctx) {
  const char* axis_arg = ctx.HasInput("Axis") ? "Axis" : "axis";
  return KernelSignature("gather", {"X", "Index"}, {axis_arg}, {"Out"});
}
// Maps the fluid "gather_grad" op to the phi kernel signature, choosing the
// tensor "Axis" input over the "axis" attribute when it exists.
KernelSignature GatherGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
  const char* axis_arg = ctx.HasInput("Axis") ? "Axis" : "axis";
  return KernelSignature(
      "gather_grad", {"X", "Index", "Out@GRAD"}, {axis_arg}, {"X@GRAD"});
}
} // namespace phi
// Register the legacy-op -> phi-kernel argument mappings defined above.
PD_REGISTER_ARG_MAPPING_FN(gather, phi::GatherOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(gather_grad, phi::GatherGradOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册