未验证 提交 97aa938f 编写于 作者: HappyHeavyRain's avatar HappyHeavyRain 提交者: GitHub

Generate static graph code of some ops by yaml (#48698)

* generate static graph code of some ops by yaml, test = develop

* generate static graph code of some ops by yaml, test = develop
上级 61a1f688
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle {
namespace operators {
class LU_UnpackOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddComment(R"DOC(Unpack L U and P to single matrix tensor,
unpack L and U matrix from LU, unpack permutation matrix Pmat from Pivtos .
)DOC");
AddInput("X", "(Tensor) The input LU tensor, shape of (*,m,n)");
AddInput("Pivots",
"(Tensor) The input Pivots tensor, shape of (*,min(m,n))");
AddOutput(
"Pmat",
"(Tensor) The output permutation matrix tensor, shape of (*, m, m)");
AddOutput("L", "(Tensor) The output lower triangular matrix tensor");
AddOutput("U", "(Tensor) The output upper triangular matrix tensor");
AddAttr<bool>("unpack_ludata", "Whether to unpack L and U")
.SetDefault(true);
AddAttr<bool>("unpack_pivots", "Whether to unpack permutation matrix")
.SetDefault(true);
}
};
class LU_UnpackOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
};
class LU_UnpackOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto var_type = ctx->GetInputType("X", 0);
auto data_type = ctx->GetInputDataType("X", 0);
ctx->SetOutputType("L", var_type, framework::ALL_ELEMENTS);
ctx->SetOutputDataType("L", data_type, framework::ALL_ELEMENTS);
ctx->SetOutputType("U", var_type, framework::ALL_ELEMENTS);
ctx->SetOutputDataType("U", data_type, framework::ALL_ELEMENTS);
ctx->SetOutputType("Pmat", var_type, framework::ALL_ELEMENTS);
ctx->SetOutputDataType("Pmat", data_type, framework::ALL_ELEMENTS);
}
};
template <typename T>
class LU_UnpackOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> retv) const override {
retv->SetType("lu_unpack_grad");
retv->SetInput("X", this->Input("X"));
retv->SetInput("Pivots", this->Input("Pivots"));
retv->SetInput("L", this->Output("L"));
retv->SetInput("U", this->Output("U"));
retv->SetInput("Pmat", this->Output("Pmat"));
retv->SetInput(framework::GradVarName("L"), this->OutputGrad("L"));
retv->SetInput(framework::GradVarName("U"), this->OutputGrad("U"));
retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
retv->SetAttrMap(this->Attrs());
}
};
class LU_UnpackGradOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto var_type = ctx->GetInputType("X", 0);
auto data_type = ctx->GetInputDataType("X", 0);
ctx->SetOutputType(
framework::GradVarName("X"), var_type, framework::ALL_ELEMENTS);
ctx->SetOutputDataType(
framework::GradVarName("X"), data_type, framework::ALL_ELEMENTS);
}
};
class LU_UnpackGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(dtype, ctx.GetPlace());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
DECLARE_INFER_SHAPE_FUNCTOR(lu_unpack,
LUUnpackInferMetaFunctor,
PD_INFER_META(phi::LUUnpackInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(lu_unpack_grad,
LUUnpackGradInferMetaFunctor,
PD_INFER_META(phi::LUUnpackGradInferMeta));
REGISTER_OPERATOR(lu_unpack,
ops::LU_UnpackOp,
ops::LU_UnpackOpMaker,
ops::LU_UnpackOpVarTypeInference,
ops::LU_UnpackOpGradMaker<paddle::framework::OpDesc>,
ops::LU_UnpackOpGradMaker<paddle::imperative::OpBase>,
LUUnpackInferMetaFunctor);
REGISTER_OPERATOR(lu_unpack_grad,
ops::LU_UnpackGradOp,
ops::LU_UnpackGradOpVarTypeInference,
LUUnpackGradInferMetaFunctor);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
class ModeOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
framework::LibraryType library_{framework::LibraryType::kPlain};
phi::DataLayout layout_ = phi::DataLayout::kAnyLayout;
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context(),
layout_,
library_);
}
};
class ModeOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor) The input of Mode op");
AddOutput("Out", "(Tensor) The output tensor of Topk op");
AddOutput("Indices", "(Tensor) The indices of Topk elements of input");
AddAttr<int>("axis",
"the axis to calculate mode values."
"if not set, will calculate on last axis.")
.SetDefault(-1);
AddAttr<bool>("keepdim", "Keep the dim that to reduce.").SetDefault(false);
AddComment(R"DOC(
This operator finds the mode of input Tensor. And outputs their values and indices as vectors.
)DOC");
}
};
class ModeOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("X"),
true,
platform::errors::InvalidArgument("Input(X) should be not null"));
PADDLE_ENFORCE_EQ(
ctx->HasInput("Indices"),
true,
platform::errors::InvalidArgument("Input(Indices) should be not null"));
PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")),
true,
platform::errors::InvalidArgument(
"Grad Input(Out) should be not null"));
PADDLE_ENFORCE_EQ(
ctx->HasOutput(framework::GradVarName("X")),
true,
platform::errors::InvalidArgument("Grad Output(X) should be not null"));
auto x_dims = ctx->GetInputDim("X");
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
return framework::OpKernelType(data_type, ctx.device_context());
}
};
template <typename T>
class ModeGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("mode_grad");
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetInput("X", this->Input("X"));
op->SetInput("Indices", this->Output("Indices"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(mode,
ModeInferShapeFunctor,
PD_INFER_META(phi::ModeInferMeta));
REGISTER_OPERATOR(mode,
ops::ModeOp,
ops::ModeOpMaker,
ops::ModeGradOpMaker<paddle::framework::OpDesc>,
ops::ModeGradOpMaker<paddle::imperative::OpBase>,
ModeInferShapeFunctor);
REGISTER_OPERATOR(mode_grad, ops::ModeOpGrad);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include <string>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/ternary.h"
namespace paddle {
namespace operators {
class NLLLossOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
class NLLLossOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor, default Tensor<float>) A tensor whose last dimension "
"size is equal to the number of classes. It is expected to "
"contain log-probabilities of each class. "
"The X tensor's shape has to be either [batch_size, C] or"
"[batch_size, C, dim1, ..., dimK] in with K >= 1 in the case "
" K-dimensional loss.");
AddInput("Label",
"(Tensor, default Tensor<int64_t>) A tensor which represents the "
"the ground truth. It contains the class index in the range "
"[0, C-1] where C = number of classes. The Lable tensor's "
"shape has to be (batch_size), or "
"(batch_size, dim1, ..., dimK) "
"with K >= 1 in the case K-dimensional loss.");
AddInput("Weight",
"(Tensor, optional) A tensor should be a 1D tensor assigning "
"weight to each of the classes. It's shape must be [C], where "
"C is the class number.")
.AsDispensable();
AddOutput("Out",
"(Tensor, default Tensor<float>) A tensor that represents the "
"NLL loss.");
AddOutput("Total_weight",
"(Tensor, default Tensor<float>) A tensor saves the total"
"weight value in the forward process.");
AddAttr<int64_t>("ignore_index",
"(int64_t, default -100), Specifies a target value that is"
"ignored and does not contribute to the input gradient.")
.SetDefault(-100);
AddAttr<std::string>(
"reduction",
"(string, default mean), Specifies the reduction to apply"
"to the output. The options include \"none\", \"mean\","
"\"sum\".")
.SetDefault("mean");
AddComment(R"DOC(
NLL(Negative Log Likelihood) Loss Operator.
This operator computes the NLL loss according to the inputs.
The loss can be described as:
$Out[i] = -X[Label[i]]*Weight[Label[i]]$
It can also be used for higher dimension inputs, such as 2D images, by
providing an input of shape (batch_size, C, d1, d2, ..., dK), with
K >= 1, where K is the number of dimensions, and a Label of
appropriate shape. In the case of images, it computes NLL loss
per-pixel.
)DOC");
}
};
class NLLLossGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
template <typename T>
class NLLLossGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("nll_loss_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("Label", this->Input("Label"));
op->SetInput("Total_weight", this->Output("Total_weight"));
if (this->HasInput("Weight")) {
op->SetInput("Weight", this->Input("Weight"));
}
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
}
};
} // namespace operators
} // namespace paddle
DECLARE_INFER_SHAPE_FUNCTOR(nll_loss,
NllLossRawInferShapeFunctor,
PD_INFER_META(phi::NllLossRawInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(nll_loss_grad,
NllLossGradInferShapeFunctor,
PD_INFER_META(phi::NllLossGradInferMeta));
namespace ops = paddle::operators;
REGISTER_OPERATOR(nll_loss,
ops::NLLLossOp,
ops::NLLLossOpMaker,
ops::NLLLossGradMaker<paddle::framework::OpDesc>,
ops::NLLLossGradMaker<paddle::imperative::OpBase>,
NllLossRawInferShapeFunctor);
REGISTER_OPERATOR(nll_loss_grad,
ops::NLLLossGradOp,
NllLossGradInferShapeFunctor);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
using DDim = framework::DDim;
class QrOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
};
class QrOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of qr op.");
AddOutput("Q", "(Tensor), The output Q tensor of qr op.");
AddOutput("R", "(Tensor), The output R tensor of qr op.");
AddAttr<std::string>(
"mode",
"(string, default \"reduced\"). "
"If mode is \"reduced\", Qr op will return reduced Q and R matrices. "
"If mode is \"complete\", Qr op will return complete Q and R matrices. "
"If mode is \"r\", Qr op will only return reduced R matrix.")
.SetDefault("reduced");
AddComment(R"DOC(
Qr Operator.
This operator is used to perform QR operation for batched matrics $X$.
$$Q, R = qr(X)$$
)DOC");
}
};
class QrGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Q")),
"Input",
"Q@Grad",
"QrGrad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("R")),
"Input",
"R@Grad",
"QrGrad");
OP_INOUT_CHECK(ctx->HasInput("Q"), "Input", "Q", "QrGrad");
OP_INOUT_CHECK(ctx->HasInput("R"), "Input", "R", "QrGrad");
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")),
"Output",
"X@Grad",
"QrGrad");
auto x_dims = ctx->GetInputDim(("X"));
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(dtype, ctx.GetPlace());
}
};
template <typename T>
class QrGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
void Apply(GradOpPtr<T> retv) const override {
retv->SetType("qr_grad");
retv->SetInput(framework::GradVarName("Q"), this->OutputGrad("Q"));
retv->SetInput(framework::GradVarName("R"), this->OutputGrad("R"));
retv->SetInput("Q", this->Output("Q"));
retv->SetInput("R", this->Output("R"));
retv->SetInput("X", this->Input("X"));
retv->SetAttrMap(this->Attrs());
retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(qr,
QrInferShapeFunctor,
PD_INFER_META(phi::QrInferMeta));
REGISTER_OPERATOR(qr,
ops::QrOp,
ops::QrOpMaker,
ops::QrGradMaker<paddle::framework::OpDesc>,
ops::QrGradMaker<paddle::imperative::OpBase>,
QrInferShapeFunctor);
REGISTER_OPERATOR(qr_grad, ops::QrGradOp);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
class RenormOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
using DDim = paddle::framework::DDim;
};
class RenormOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of renorm op.");
AddOutput("Out", "(Tensor), The output tensor of renorm op.");
AddAttr<float>("p", "(float, norm's power");
AddAttr<int>("axis",
"int,the dimension to slice over to get the sub-tensors");
AddAttr<float>("max_norm", "(float, the norm upper-bound");
AddComment(R"DOC(
Renorm Operator.
This operator is used to scale tensor sliced by axis if its p-norm execeeds maxnorm
)DOC");
}
};
class RenormGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
};
template <typename T>
class RenormGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
void Apply(GradOpPtr<T> retv) const override {
retv->SetType("renorm_grad");
retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
retv->SetInput("X", this->Input("X"));
retv->SetAttrMap(this->Attrs());
retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(renorm,
RenormInferShapeFunctor,
PD_INFER_META(phi::UnchangedInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(renorm_grad,
RenormGradInferShapeFunctor,
PD_INFER_META(phi::UnchangedInferMeta));
REGISTER_OPERATOR(renorm,
ops::RenormOp,
ops::RenormOpMaker,
ops::RenormGradMaker<paddle::framework::OpDesc>,
ops::RenormGradMaker<paddle::imperative::OpBase>,
RenormInferShapeFunctor)
REGISTER_OPERATOR(renorm_grad, ops::RenormGradOp, RenormGradInferShapeFunctor);
......@@ -687,6 +687,15 @@
func : logsigmoid_grad
inplace : (out_grad -> x_grad)
- backward_op : lu_unpack_grad
forward : lu_unpack (Tensor x, Tensor y, bool unpack_ludata = true, bool unpack_pivots = true) -> Tensor(pmat), Tensor(l), Tensor(u)
args : (Tensor x, Tensor y, Tensor l, Tensor u, Tensor pmat, Tensor l_grad, Tensor u_grad, bool unpack_ludata, bool unpack_pivots)
output : Tensor(x_grad)
infer_meta :
func : LUUnpackGradInferMeta
kernel :
func : lu_unpack_grad
- backward_op : masked_select_grad
forward : masked_select (Tensor x, Tensor mask) -> Tensor(out)
args : (Tensor x, Tensor mask, Tensor out_grad)
......@@ -719,6 +728,16 @@
kernel :
func : maxout_grad
- backward_op : mode_grad
forward : mode(Tensor x, int axis = -1, bool keepdim = false) -> Tensor(out), Tensor(indices)
args : (Tensor x, Tensor indices, Tensor out_grad, int axis, bool keepdim)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param: [x]
kernel :
func : mode_grad
- backward_op : mv_grad
forward : mv (Tensor x, Tensor vec) -> Tensor(out)
args : (Tensor x, Tensor vec, Tensor out_grad)
......@@ -729,6 +748,17 @@
kernel :
func : mv_grad
- backward_op : nll_loss_grad
forward : nll_loss (Tensor input, Tensor label, Tensor weight, int64_t ignore_index = -100, str reduction = "mean") -> Tensor(out), Tensor(total_weight)
args : (Tensor input, Tensor label, Tensor weight, Tensor total_weight, Tensor out_grad, int64_t ignore_index, str reduction)
output : Tensor(input_grad)
infer_meta :
func : NllLossGradInferMeta
kernel :
func : nll_loss_grad
data_type : input
optional : weight
- backward_op : poisson_grad
forward : poisson (Tensor x) -> Tensor(out)
args : (Tensor out_grad)
......@@ -739,6 +769,16 @@
kernel :
func : poisson_grad
- backward_op : qr_grad
forward : qr (Tensor x, str mode = "reduced") -> Tensor(q), Tensor(r)
args : (Tensor x, Tensor q, Tensor r, Tensor q_grad, Tensor r_grad, str mode)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : qr_grad
- backward_op : reciprocal_grad
forward : reciprocal (Tensor x) -> Tensor(out)
args : (Tensor out, Tensor out_grad)
......@@ -773,6 +813,16 @@
backward: relu_double_grad
inplace : (out_grad -> x_grad)
- backward_op : renorm_grad
forward : renorm (Tensor x, float p, int axis, float max_norm) -> Tensor(out)
args : (Tensor x, Tensor out_grad, float p, int axis, float max_norm)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [out_grad]
kernel :
func : renorm_grad
- backward_op : round_grad
forward : round(Tensor x) -> Tensor(out)
args : (Tensor out_grad)
......
......@@ -804,15 +804,6 @@
kernel :
func : lu_grad
- backward_op : lu_unpack_grad
forward : lu_unpack (Tensor x, Tensor y, bool unpack_ludata, bool unpack_pivots) -> Tensor(pmat), Tensor(l), Tensor(u)
args : (Tensor x, Tensor y, Tensor l, Tensor u, Tensor pmat, Tensor l_grad, Tensor u_grad, bool unpack_ludata, bool unpack_pivots)
output : Tensor(x_grad)
infer_meta :
func : LUUnpackGradInferMeta
kernel :
func : lu_unpack_grad
- backward_op : margin_cross_entropy_grad
forward : margin_cross_entropy (Tensor logits, Tensor label, bool return_softmax, int ring_id, int rank, int nranks, float margin1, float margin2, float margin3, float scale) -> Tensor(softmax), Tensor(loss)
args : (Tensor logits, Tensor label, Tensor softmax, Tensor loss_grad, bool return_softmax, int ring_id, int rank, int nranks, float margin1, float margin2, float margin3, float scale)
......@@ -964,16 +955,6 @@
func : mish_grad
inplace : (out_grad -> x_grad)
- backward_op : mode_grad
forward : mode(Tensor x, int axis, bool keepdim) -> Tensor(out), Tensor(indices)
args : (Tensor x, Tensor indices, Tensor out_grad, int axis, bool keepdim)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param: [x]
kernel :
func : mode_grad
- backward_op : multi_dot_grad
forward : multi_dot (Tensor[] x) -> Tensor(out)
args : (Tensor[] x, Tensor out_grad)
......@@ -1041,17 +1022,6 @@
func : nearest_interp_grad
data_type : output_grad
- backward_op : nll_loss_grad
forward : nll_loss (Tensor input, Tensor label, Tensor weight, int64_t ignore_index, str reduction) -> Tensor(out), Tensor(total_weight)
args : (Tensor input, Tensor label, Tensor weight, Tensor total_weight, Tensor out_grad, int64_t ignore_index, str reduction)
output : Tensor(input_grad)
infer_meta :
func : NllLossGradInferMeta
kernel :
func : nll_loss_grad
data_type : input
optional : weight
- backward_op : norm_grad
forward : norm (Tensor x, int axis, float epsilon, bool is_test) -> Tensor(out), Tensor(norm)
args : (Tensor x, Tensor norm, Tensor out_grad, int axis, float epsilon, bool is_test)
......@@ -1246,16 +1216,6 @@
kernel :
func : put_along_axis_grad
- backward_op : qr_grad
forward : qr (Tensor x, str mode) -> Tensor(q), Tensor(r)
args : (Tensor x, Tensor q, Tensor r, Tensor q_grad, Tensor r_grad, str mode)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : qr_grad
- backward_op : real_grad
forward : real (Tensor x) -> Tensor(out)
args : (Tensor out_grad)
......@@ -1273,16 +1233,6 @@
func : relu6_grad
inplace : (out_grad -> x_grad)
- backward_op : renorm_grad
forward : renorm (Tensor x, float p, int axis, float max_norm) -> Tensor(out)
args : (Tensor x, Tensor out_grad, float p, int axis, float max_norm)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [out_grad]
kernel :
func : renorm_grad
- backward_op : repeat_interleave_grad
forward : repeat_interleave(Tensor x, int repeats, int axis) -> Tensor(out)
args : (Tensor x, Tensor out_grad, int repeats, int axis)
......
......@@ -1164,16 +1164,6 @@
func : lu
backward : lu_grad
- op : lu_unpack
args : (Tensor x, Tensor y, bool unpack_ludata, bool unpack_pivots)
output : Tensor(pmat), Tensor(l), Tensor(u)
infer_meta :
func : LUUnpackInferMeta
kernel :
func : lu_unpack
data_type : x
backward : lu_unpack_grad
- op : margin_cross_entropy
args : (Tensor logits, Tensor label, bool return_softmax, int ring_id, int rank, int nranks, float margin1, float margin2, float margin3, float scale)
output : Tensor(softmax), Tensor(loss)
......@@ -1339,15 +1329,6 @@
func : mish
backward : mish_grad
- op : mode
args : (Tensor x, int axis, bool keepdim)
output : Tensor(out), Tensor(indices)
infer_meta :
func : ModeInferMeta
kernel :
func : mode
backward : mode_grad
- op : momentum_
args : (Tensor param, Tensor grad, Tensor velocity, Tensor learning_rate, Tensor master_param, float mu, bool use_nesterov = false, str regularization_method = "", float regularization_coeff = 0.0, bool multi_precision = false, float rescale_grad = 1.0f)
output : Tensor(param_out), Tensor(velocity_out), Tensor(master_param_out)
......@@ -1416,17 +1397,6 @@
data_type : x
backward : nearest_interp_grad
- op : nll_loss
args : (Tensor input, Tensor label, Tensor weight, int64_t ignore_index, str reduction)
output : Tensor(out), Tensor(total_weight)
infer_meta :
func : NllLossRawInferMeta
kernel :
func : nll_loss
data_type : input
optional : weight
backward : nll_loss_grad
- op : nms
args : (Tensor x, float threshold)
output : Tensor(out)
......@@ -1615,15 +1585,6 @@
inplace : (arr -> out)
backward : put_along_axis_grad
- op : qr
args : (Tensor x, str mode)
output : Tensor(q), Tensor(r)
infer_meta :
func : QrInferMeta
kernel :
func : qr
backward : qr_grad
- op : randint
args : (int low, int high, IntArray shape, DataType dtype=DataType::INT64, Place place={})
output : Tensor(out)
......@@ -1676,16 +1637,6 @@
func : remainder
inplace : (x -> out)
- op : renorm
args : (Tensor x, float p, int axis, float max_norm)
output : Tensor
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : renorm
backward : renorm_grad
- op : repeat_interleave
args : (Tensor x, int repeats, int axis)
output : Tensor(out)
......
......@@ -761,6 +761,13 @@
extra :
attrs : [bool use_mkldnn = false, bool is_test = false]
- op : lu_unpack
backward : lu_unpack_grad
inputs :
{x : X, y : Pivots}
outputs :
{pmat : Pmat, l : L, u : U}
- op : masked_select
inputs :
{x : X, mask : Mask}
......@@ -809,6 +816,13 @@
extra :
attrs : [bool use_mkldnn = false]
- op : mode
backward : mode_grad
inputs :
x : X
outputs :
{out : Out, indices : Indices}
- op : multiply (elementwise_mul)
backward : multiply_grad (elementwise_mul_grad)
extra :
......@@ -832,6 +846,13 @@
extra :
attrs : [bool use_mkldnn = false]
- op : nll_loss
backward : nll_loss_grad
inputs :
{input : X, label : Label, weight : Weight}
outputs :
{out : Out, total_weight : Total_weight}
- op : pad2d
backward : pad2d_grad
extra :
......@@ -869,6 +890,13 @@
extra :
attrs : [bool use_mkldnn = false, str mkldnn_data_type = "float32", bool is_test = false]
- op : qr
backward : qr_grad
inputs :
x : X
outputs :
{q : Q, r : R}
- op : quantize_linear
extra :
attrs : [float moving_rate = 0.9]
......@@ -946,6 +974,10 @@
- op : renorm
backward : renorm_grad
inputs :
x : X
outputs :
out : Out
extra :
attrs : [bool use_mkldnn = false, bool use_cudnn = false]
......
......@@ -637,6 +637,16 @@
func : logsigmoid
backward : logsigmoid_grad
- op : lu_unpack
args : (Tensor x, Tensor y, bool unpack_ludata = true, bool unpack_pivots = true)
output : Tensor(pmat), Tensor(l), Tensor(u)
infer_meta :
func : LUUnpackInferMeta
kernel :
func : lu_unpack
data_type : x
backward : lu_unpack_grad
- op : masked_select
args : (Tensor x, Tensor mask)
output : Tensor (out)
......@@ -665,6 +675,15 @@
func : maxout
backward : maxout_grad
- op : mode
args : (Tensor x, int axis = -1, bool keepdim = false)
output : Tensor(out), Tensor(indices)
infer_meta :
func : ModeInferMeta
kernel :
func : mode
backward : mode_grad
- op : mv
args : (Tensor x, Tensor vec)
output : Tensor
......@@ -674,6 +693,17 @@
func : mv
backward : mv_grad
- op : nll_loss
args : (Tensor input, Tensor label, Tensor weight, int64_t ignore_index = -100, str reduction = "mean")
output : Tensor(out), Tensor(total_weight)
infer_meta :
func : NllLossRawInferMeta
kernel :
func : nll_loss
data_type : input
optional : weight
backward : nll_loss_grad
- op : npu_identity
args : (Tensor x, int format = -1)
output : Tensor
......@@ -692,6 +722,15 @@
func : poisson
backward : poisson_grad
- op : qr
args : (Tensor x, str mode = "reduced")
output : Tensor(q), Tensor(r)
infer_meta :
func : QrInferMeta
kernel :
func : qr
backward : qr_grad
- op : reciprocal
args : (Tensor x)
output : Tensor(out)
......@@ -712,6 +751,16 @@
inplace : (x -> out)
backward : relu_grad
- op : renorm
args : (Tensor x, float p, int axis, float max_norm)
output : Tensor
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : renorm
backward : renorm_grad
- op : round
args : (Tensor x)
output : Tensor(out)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature LUUnpackOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("lu_unpack",
{"X", "Pivots"},
{"unpack_ludata", "unpack_pivots"},
{"Pmat", "L", "U"});
}
KernelSignature LUUnpackGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("lu_unpack_grad",
{"X", "Pivots", "L", "U", "Pmat", "L@GRAD", "U@GRAD"},
{"unpack_ludata", "unpack_pivots"},
{"X@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(lu_unpack, phi::LUUnpackOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(lu_unpack_grad, phi::LUUnpackGradOpArgumentMapping);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature ModeOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"mode", {"X"}, {"axis", "keepdim"}, {"Out", "Indices"});
}
KernelSignature ModeGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("mode_grad",
{"X", "Indices", "Out@GRAD"},
{"axis", "keepdim"},
{"X@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(mode, phi::ModeOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(mode_grad, phi::ModeGradOpArgumentMapping);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature NllLossOpArgumentMapping(const ArgumentMappingContext& ctx) {
// TODO(xiongkun): can't remove the forward mapping, because the Weight is
// optional
return KernelSignature("nll_loss",
{"X", "Label", "Weight"},
{"ignore_index", "reduction"},
{"Out", "Total_weight"});
}
KernelSignature NllLossGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("nll_loss_grad",
{"X", "Label", "Weight", "Total_weight", "Out@GRAD"},
{"ignore_index", "reduction"},
{"X@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(nll_loss_grad, phi::NllLossGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(nll_loss, phi::NllLossOpArgumentMapping);
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature QrOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("qr", {"X"}, {"mode"}, {"Q", "R"});
}
KernelSignature QrGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"qr_grad", {"X", "Q", "R", "Q@GRAD", "R@GRAD"}, {"mode"}, {"X@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(qr, phi::QrOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(qr_grad, phi::QrGradOpArgumentMapping);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature RenormOpArgumentMapping(const ArgumentMappingContext& ctx) {
VLOG(3) << "in renrom arguments mapping";
return KernelSignature("renorm", {"X"}, {"p", "axis", "max_norm"}, {"Out"});
}
KernelSignature RenormGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
VLOG(3) << "in renrom grad arguments mapping";
return KernelSignature(
"renorm_grad", {"X", "Out@GRAD"}, {"p", "axis", "max_norm"}, {"X@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(renorm, phi::RenormOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(renorm_grad, phi::RenormGradOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册