未验证 提交 015db48e 编写于 作者: HappyHeavyRain's avatar HappyHeavyRain 提交者: GitHub

Generate static graph code of some ops by yaml (#48977)

* generate static graph code of some ops by yaml

* fix the code-style of yaml

* fix the framework_ci for triangular_solve

* change the 'data_type' of scatter

* add the 'out: Out' of scatter_nd_add
上级 aaee07a3
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/ternary.h"
namespace paddle {
namespace operators {
class ScatterNdAddOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE_EQ(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
OperatorWithKernel::IndicateVarDataType(ctx, "Updates"),
platform::errors::InvalidArgument(
"Ref and Updates must have same type"));
return framework::OpKernelType(
framework::TransToProtoVarType(
ctx.Input<phi::DenseTensor>("X")->type()),
ctx.device_context());
}
};
class ScatterNdAddGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.device_context());
}
};
class ScatterNdAddOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "The source input of scatter_nd_add op");
AddInput("Index",
"The index input of scatter_nd_add op where X will be updated");
AddInput("Updates", "The updated value of scatter_nd_add op");
AddOutput("Out", "The output of scatter_nd_add op");
AddComment(R"DOC(
Scatter_nd_add Operator.
Output is obtained by applying sparse addition to a single value or slice in a Variable.
Given:
* Case 1:
ref = [0, 1, 2, 3, 4, 5]
index = [[1], [2], [3], [1]]
updates = [9, 10, 11, 12]
we get:
output = [0, 22, 12, 14, 4, 5]
* Case 2:
ref = [[65, 17], [-14, -25]]
index = [[], []]
updates = [[[-1, -2], [1, 2]],
[[3, 4], [-3, -4]]]
ref.shape = (2, 2)
index.shape = (2, 0)
updates.shape = (2, 2, 2)
we get:
output = [[67, 19], [-16, -27]]
)DOC");
}
};
template <typename T>
class ScatterNdAddGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("scatter_nd_add_grad");
op->SetInput("Index", this->Input("Index"));
op->SetInput("Updates", this->Input("Updates"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetOutput(framework::GradVarName("Updates"),
this->InputGrad("Updates"));
op->SetAttrMap(this->Attrs());
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(ScatterNdAddGradNoNeedBufferVarsInferer,
"Updates");
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(scatter_nd_add,
ScatterNdAddInferShapeFunctor,
PD_INFER_META(phi::ScatterNdAddInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(scatter_nd_add_grad,
ScatterNdAddGradInferShapeFunctor,
PD_INFER_META(phi::ScatterNdAddGradInferMeta));
REGISTER_OPERATOR(scatter_nd_add,
ops::ScatterNdAddOp,
ops::ScatterNdAddOpMaker,
ops::ScatterNdAddGradMaker<paddle::framework::OpDesc>,
ops::ScatterNdAddGradMaker<paddle::imperative::OpBase>,
ScatterNdAddInferShapeFunctor);
REGISTER_OPERATOR(scatter_nd_add_grad,
ops::ScatterNdAddGradOp,
ops::ScatterNdAddGradNoNeedBufferVarsInferer,
ScatterNdAddGradInferShapeFunctor);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/ternary.h"
namespace paddle {
namespace operators {
class ScatterOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
class ScatterGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.device_context());
}
};
class ScatterOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "The source input of scatter op");
AddInput("Ids", "The index input of scatter op where X will be updated");
AddInput("Updates", "The updated value of scatter op");
AddOutput("Out", "The output of scatter op");
AddAttr<bool>("overwrite",
"(bool, default: True) "
"The mode that updating the output when has same index,"
"If True, use the overwrite mode to update the output"
"of the same index, if False, use the accumulate mode to"
"update the output of the same index,Default value is True."
"You can set overwrite=False to implement scatter_add.")
.SetDefault(true);
AddComment(R"DOC(
Scatter Operator.
This operator obtains output by updating the input on selected indices on the first axis:
$$
Out = X \\
Out[Ids] = Updates
$$
)DOC");
}
};
template <typename T>
class ScatterGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("scatter_grad");
op->SetInput("Ids", this->Input("Ids"));
op->SetInput("Updates", this->Input("Updates"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetOutput(framework::GradVarName("Updates"),
this->InputGrad("Updates"));
op->SetAttrMap(this->Attrs());
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(ScatterGradNoNeedBufferVarsInferer,
"Updates");
DECLARE_INPLACE_OP_INFERER(ScatterInplaceInferer, {"X", "Out"});
} // namespace operators
} // namespace paddle
DECLARE_INFER_SHAPE_FUNCTOR(scatter,
ScatterInferShapeFunctor,
PD_INFER_META(phi::ScatterInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(scatter_grad,
ScatterGradInferShapeFunctor,
PD_INFER_META(phi::ScatterGradInferMeta));
namespace ops = paddle::operators;
REGISTER_OPERATOR(scatter,
ops::ScatterOp,
ops::ScatterOpMaker,
ops::ScatterGradMaker<paddle::framework::OpDesc>,
ops::ScatterGradMaker<paddle::imperative::OpBase>,
ops::ScatterInplaceInferer,
ScatterInferShapeFunctor);
REGISTER_OPERATOR(scatter_grad,
ops::ScatterGradOp,
ops::ScatterGradNoNeedBufferVarsInferer,
ScatterGradInferShapeFunctor);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
class SeluOp : public framework::OperatorWithKernel {
public:
SeluOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
};
class SeluOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string> &GetInputOutputWithSameType()
const override {
static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
return m;
}
};
class SeluOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "The input tensor of selu operator.");
AddOutput("Out", "The output tensor of selu operator.");
AddAttr<float>("scale",
"(float) the default value is 1.0507~. For more "
"information about this value, please refer to:"
"https://arxiv.org/abs/1706.02515.")
.SetDefault(1.0507009873554804934193349852946);
AddAttr<float>("alpha",
"(float) the default value is 1.6732~. For more "
"information about this value, please refer to:"
"https://arxiv.org/abs/1706.02515.")
.SetDefault(1.6732632423543772848170429916717);
AddComment(R"DOC(
Selu Operator.
The equation is:
$$
f(x) =\lambda*
\begin{cases}
\quad \quad x, \quad \quad \quad \text{if} \ x > 0 \\
\alpha * e^x - \alpha, \qquad \text{if} \ x <= 0
\end{cases}
$$
The input `X` can carry the LoD (Level of Details) information,
or not. And the output shares the LoD information with input `X`.
)DOC");
}
};
template <typename T>
class SeluGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("selu_grad");
grad_op->SetInput("Out", this->Output("Out"));
grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
grad_op->SetAttrMap(this->Attrs());
}
};
class SeluGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")),
"Input",
"Out@GRAD",
"selu_grad");
OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "selu_grad");
auto x_grad_name = framework::GradVarName("X");
ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("Out"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Out"), ctx.GetPlace());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(selu,
SeluInferShapeFunctor,
PD_INFER_META(phi::UnchangedInferMeta));
REGISTER_OPERATOR(selu,
ops::SeluOp,
ops::SeluOpMaker,
ops::SeluOpInferVarType,
ops::SeluGradMaker<paddle::framework::OpDesc>,
ops::SeluGradMaker<paddle::imperative::OpBase>,
SeluInferShapeFunctor);
REGISTER_OPERATOR(selu_grad, ops::SeluGradOp);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
class ShardIndexOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
class ShardIndexOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(phi::DenseTensor, phi::DenseTensor<int|int64>) Input variable. "
"Each value "
"of X is an index.");
AddOutput(
"Out",
"(Tensor, Tensor<int|int64>) Output tensor with same shape as X. "
"The tensor consists of sharding representations of values in X.");
AddAttr<int>("index_num",
"A positive integer to specify the range of the input X.");
AddAttr<int>("nshards",
"A positive integer to specify the number of shards.");
AddAttr<int>("shard_id", "The current shard id");
AddAttr<int>("ignore_value", "An integer value out of sharded range")
.SetDefault(-1);
AddComment(R"DOC(
This layer creates the sharded index for input. This layers is used in
model- and data- parallel mixed training generally, in which the index
data (usually the label) should be recaculated in each trainer according
to
.. math::
assert index_num % nshards == 0
shard_size = index_num / nshards
y = x % shard_size if x / shard_size == shard_id else ignore_value
We take the distributed one-hot representation to show what this layer is
used for. The distributed one-hot representation is separated into multiple
shards, and each shard is filling zeros except the one with the index
inside. In order to create these sharded representation in each trainer,
the original index should be recalculated (i.e. sharded) before.
Examples:
X is a Tensor of integer values:
X.shape = [4, 1]
X.data = [[1], [6], [12], [19]]
suppose index_num = 20 and nshards = 2, then we get shard_size = 10
if shard_id == 0, we get the Out:
Out.shape = [4, 1]
Out.data = [[1], [6], [-1], [-1]]
if shard_id == 1, we get the Out:
Out.shape = [4, 1]
Out.data = [[-1], [-1], [2], [9]]
the default `ignore_value` -1 is used in this example.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(shard_index,
ShardIndexInferShapeFunctor,
PD_INFER_META(phi::ShardIndexInferMeta));
REGISTER_OPERATOR(
shard_index,
ops::ShardIndexOp,
ops::ShardIndexOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
ShardIndexInferShapeFunctor);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/ternary.h"
namespace paddle {
namespace operators {
class ViterbiDecodeOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
ctx.device_context());
}
};
class ViterbiDecodeOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput(
"Input",
"The unary emission tensor. The shape of Input must be (batch_size,"
"sequence_length, num_tags). ");
AddInput("Transition",
"The transition matrix. The shape of Transition must be ( "
"num_tags, num_tags). ");
AddInput("Length",
"The input length tensor storing real length of each sequence for "
"correctness. The shape of Length MUST be (batch_size).");
AddOutput("Scores",
"The scores tensor containing the score for the Viterbi "
"sequence. The shape of Scores MUST be (batch_size).");
AddOutput("Path",
"The paths tensor containing the highest scoring tag indices. "
"The shape of Scores MUST be (batch_size, sequence_length).");
AddAttr<bool>("include_bos_eos_tag",
"If set to True, the last row and the last column of "
"transitions will be considered as start tag.")
.SetDefault(true);
AddComment(R"DOC(
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace platform = paddle::platform;
DECLARE_INFER_SHAPE_FUNCTOR(viterbi_decode,
ViterbiDecodeInferShapeFunctor,
PD_INFER_META(phi::ViterbiDecodeInferMeta));
REGISTER_OP_WITHOUT_GRADIENT(viterbi_decode,
ops::ViterbiDecodeOp,
ops::ViterbiDecodeOpMaker,
ViterbiDecodeInferShapeFunctor);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/multiary.h"
namespace paddle {
namespace operators {
class WhereOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
};
class WhereGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Condition"), "Input", "Condition", "Where");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Where");
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "Where");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")),
"Input",
framework::GradVarName("Out"),
"Where");
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
auto x_grad_name = framework::GradVarName("X");
auto y_grad_name = framework::GradVarName("Y");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
}
if (ctx->HasOutput(y_grad_name)) {
ctx->SetOutputDim(y_grad_name, y_dims);
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.GetPlace());
}
};
class WhereOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Condition",
"(Tensor) A bool tensor whose rank is at least 1. When Condition "
"is True, yield x, otherwise yield y");
AddInput("X",
"(Tensor), The first input tensor of where op. When the "
"corresponding position of the condition is true, the output "
"takes the element of X.");
AddInput("Y",
"(Tensor), The second input tensor of where op. When the "
"corresponding position of condition is false, the output takes "
"the element of Y.");
AddOutput("Out", "(Tensor), The output tensor of where op.");
AddComment(R"DOC(
Where Operator.
Return a tensor of elements selected from either $X$ or $Y$, depending on condition.
The equation is:
$$
Out_i =
\begin{cases}
\X_i, \quad \text{if} \ cond_i is True \\
\Y_i, \quad \text{if} \ cond_i is False \\
\end{cases}
$$
)DOC");
}
};
template <typename T>
class WhereOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> grad) const override {
grad->SetType("where_grad");
grad->SetInput("Condition", this->Input("Condition"));
grad->SetInput("X", this->Input("X"));
grad->SetInput("Y", this->Input("Y"));
grad->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
grad->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(WhereGradNoNeedBufferVarsInferer, "X", "Y");
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(where,
WhereInferShapeFunctor,
PD_INFER_META(phi::WhereInferMeta));
REGISTER_OPERATOR(where,
ops::WhereOp,
ops::WhereOpMaker,
ops::WhereOpGradMaker<paddle::framework::OpDesc>,
ops::WhereOpGradMaker<paddle::imperative::OpBase>,
WhereInferShapeFunctor);
REGISTER_OPERATOR(where_grad,
ops::WhereGradOp,
ops::WhereGradNoNeedBufferVarsInferer);
......@@ -887,6 +887,39 @@
backward : rsqrt_double_grad
inplace : (out_grad -> x_grad)
- backward_op : scatter_grad
forward : scatter (Tensor x, Tensor index, Tensor updates, bool overwrite=true) -> Tensor(out)
args : (Tensor index, Tensor updates, Tensor out_grad, bool overwrite)
output : Tensor(x_grad), Tensor(updates_grad)
infer_meta :
func : ScatterGradInferMeta
param : [index, updates, out_grad, overwrite]
kernel :
func : scatter_grad
no_need_buffer : updates
- backward_op : scatter_nd_add_grad
forward : scatter_nd_add (Tensor x, Tensor index, Tensor updates) -> Tensor(out)
args : (Tensor index, Tensor updates, Tensor out_grad)
output : Tensor(x_grad), Tensor(updates_grad)
infer_meta :
func : ScatterNdAddGradInferMeta
param : [index, updates, out_grad]
kernel :
func : scatter_nd_add_grad
no_need_buffer : updates
- backward_op : selu_grad
forward : selu (Tensor x, float scale=1.0507009873554804934193349852946, float alpha=1.6732632423543772848170429916717) -> Tensor(out)
args : (Tensor out, Tensor out_grad, float scale, float alpha)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [out]
kernel :
func : selu_grad
data_type : out
- backward_op : send_uv_grad
forward : send_uv (Tensor x, Tensor y, Tensor src_index, Tensor dst_index, str message_op = "ADD") -> Tensor(out)
args: (Tensor x, Tensor y, Tensor src_index, Tensor dst_index, Tensor out_grad, str message_op = "ADD")
......@@ -1216,3 +1249,14 @@
func : unfold_grad
data_type : out_grad
no_need_buffer : x
- backward_op : where_grad
forward : where (Tensor condition, Tensor x, Tensor y) -> Tensor(out)
args : (Tensor condition, Tensor x, Tensor y, Tensor out_grad)
output : Tensor(x_grad), Tensor(y_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
param : [x, y]
kernel :
func : where_grad
no_need_buffer : x, y
......@@ -1312,28 +1312,6 @@
output : Tensor(x_grad)
invoke : scale(out_grad, scale, 0.0, bias_after_scale)
- backward_op : scatter_grad
forward : scatter (Tensor x, Tensor index, Tensor updates, bool overwrite) -> Tensor(out)
args : (Tensor index, Tensor updates, Tensor out_grad, bool overwrite)
output : Tensor(x_grad), Tensor(updates_grad)
infer_meta :
func : ScatterGradInferMeta
param : [index, updates, out_grad, overwrite]
kernel :
func : scatter_grad
no_need_buffer : updates
- backward_op : scatter_nd_add_grad
forward : scatter_nd_add (Tensor x, Tensor index, Tensor updates) -> Tensor(out)
args : (Tensor index, Tensor updates, Tensor out_grad)
output : Tensor(x_grad), Tensor(updates_grad)
infer_meta :
func : ScatterNdAddGradInferMeta
param : [index, updates, out_grad]
kernel :
func : scatter_nd_add_grad
no_need_buffer : updates
- backward_op : segment_pool_grad
forward : segment_pool (Tensor x, Tensor segment_ids, str pooltype) -> Tensor(out), Tensor(summed_ids)
args : (Tensor x, Tensor segment_ids, Tensor out, Tensor summed_ids, Tensor out_grad, str pooltype)
......@@ -1346,16 +1324,6 @@
data_type : x
optional : summed_ids
- backward_op : selu_grad
forward : selu (Tensor x, float scale, float alpha) -> Tensor(out)
args : (Tensor out, Tensor out_grad, float scale, float alpha)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [out]
kernel :
func : selu_grad
- backward_op : send_u_recv_grad
forward : send_u_recv (Tensor x, Tensor src_index, Tensor dst_index, str reduce_op = "SUM", IntArray out_size = {0}) -> Tensor(out), Tensor(dst_count)
args : (Tensor x, Tensor src_index, Tensor dst_index, Tensor out, Tensor dst_count, Tensor out_grad, str reduce_op = "SUM")
......@@ -1719,17 +1687,6 @@
optional : logits_length
no_need_buffer : logits
- backward_op : where_grad
forward : where (Tensor condition, Tensor x, Tensor y) -> Tensor(out)
args : (Tensor condition, Tensor x, Tensor y, Tensor out_grad)
output : Tensor(x_grad), Tensor(y_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
param : [x, y]
kernel :
func : where_grad
no_need_buffer : x, y
- backward_op : yolo_loss_grad
forward : yolo_loss(Tensor x, Tensor gt_box, Tensor gt_label, Tensor gt_score, int[] anchors, int[] anchor_mask, int class_num, float ignore_thresh, int downsample_ratio, bool use_label_smooth=true, float scale_x_y=1.0) -> Tensor(loss), Tensor(objectness_mask), Tensor(gt_match_mask)
args : (Tensor x, Tensor gt_box, Tensor gt_label, Tensor gt_score, Tensor objectness_mask, Tensor gt_match_mask, Tensor loss_grad, int[] anchors, int[] anchor_mask, int class_num, float ignore_thresh, int downsample_ratio, bool use_label_smooth=true, float scale_x_y=1.0)
......
......@@ -1708,27 +1708,6 @@
inplace : (x -> out)
backward : scale_grad
- op : scatter
args : (Tensor x, Tensor index, Tensor updates, bool overwrite)
output : Tensor(out)
infer_meta :
func : ScatterInferMeta
dtype : x
kernel :
func : scatter
inplace : (x -> out)
backward : scatter_grad
- op : scatter_nd_add
args : (Tensor x, Tensor index, Tensor updates)
output : Tensor
infer_meta :
func : ScatterNdAddInferMeta
dtype : x
kernel :
func : scatter_nd_add
backward : scatter_nd_add_grad
- op : segment_pool
args : (Tensor x, Tensor segment_ids, str pooltype)
output : Tensor(out), Tensor(summed_ids)
......@@ -1739,16 +1718,6 @@
data_type : x
backward : segment_pool_grad
- op : selu
args : (Tensor x, float scale, float alpha)
output : Tensor
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : selu
backward : selu_grad
- op : send_u_recv
args : (Tensor x, Tensor src_index, Tensor dst_index, str reduce_op = "SUM", IntArray out_size = {0})
output : Tensor(out), Tensor(dst_count)
......@@ -1797,14 +1766,6 @@
data_transform:
skip_transform : input
- op : shard_index
args : (Tensor input, int index_num, int nshards, int shard_id, int ignore_value)
output : Tensor(out)
infer_meta :
func : ShardIndexInferMeta
kernel :
func : shard_index
- op : sigmoid_cross_entropy_with_logits
args : (Tensor x, Tensor label, bool normalize, int ignore_index)
output : Tensor
......@@ -1993,6 +1954,7 @@
func : TriangularSolveInferMeta
kernel :
func : triangular_solve
data_type : x
backward : triangular_solve_grad
- op : tril
......@@ -2164,15 +2126,6 @@
data_type : x
inplace : (x -> out), (prev_loss_scaling -> loss_scaling), (in_good_steps -> out_good_steps), (in_bad_steps -> out_bad_steps)
- op : viterbi_decode
args : (Tensor potentials, Tensor transition_params, Tensor lengths, bool include_bos_eos_tag)
output : Tensor(scores), Tensor(path)
infer_meta :
func : ViterbiDecodeInferMeta
kernel :
func : viterbi_decode
data_type : potentials
- op : warpctc
args : (Tensor logits, Tensor label, Tensor logits_length, Tensor labels_length, int blank, bool norm_by_times)
output : Tensor(loss), Tensor(warpctcgrad)
......@@ -2185,15 +2138,6 @@
intermediate: warpctcgrad
backward : warpctc_grad
- op : where
args : (Tensor condition, Tensor x, Tensor y)
output : Tensor
infer_meta :
func : WhereInferMeta
kernel :
func : where
backward : where_grad
- op : yolo_box
args : (Tensor x, Tensor img_size, int[] anchors, int class_num, float conf_thresh, int downsample_ratio, bool clip_bbox, float scale_x_y=1.0, bool iou_aware=false, float iou_aware_factor=0.5)
output : Tensor(boxes), Tensor(scores)
......
......@@ -1045,6 +1045,20 @@
extra :
attrs : [bool use_mkldnn = false]
- op : scatter
backward : scatter_grad
inputs :
{x : X, index : Ids, updates : Updates}
outputs :
out : Out
- op : scatter_nd_add
backward : scatter_nd_add_grad
inputs :
{x : X, index : Index, updates : Updates}
outputs :
out : Out
- op : searchsorted
inputs :
{sorted_sequence : SortedSequence, values : Values}
......@@ -1055,6 +1069,13 @@
extra :
attrs : [bool deterministic = false, str rng_name = "", bool force_cpu = false]
- op : selu
backward : selu_grad
inputs :
x : X
outputs :
out : Out
- op : send_uv (graph_send_uv)
backward : send_uv_grad (graph_send_uv_grad)
......@@ -1067,6 +1088,12 @@
extra :
attrs : [bool use_mkldnn = false, str mkldnn_data_type = "float32"]
- op : shard_index
inputs :
input : X
outputs :
out : Out
- op : share_buffer
inputs :
x : X
......@@ -1297,6 +1324,19 @@
outputs :
out : Y
- op : viterbi_decode
inputs :
{potentials : Input, transition_params : Transition, lengths : Length}
outputs :
{scores : Scores, path : Path}
- op : where
backward : where_grad
inputs :
{condition : Condition, x : X, y : Y}
outputs :
out : Out
- op : while
backward : while_grad
extra :
......
......@@ -821,6 +821,27 @@
inplace : (x -> out)
backward : rsqrt_grad
- op : scatter
args : (Tensor x, Tensor index, Tensor updates, bool overwrite=true)
output : Tensor(out)
infer_meta :
func : ScatterInferMeta
kernel :
func : scatter
data_type : x
inplace : (x -> out)
backward : scatter_grad
- op : scatter_nd_add
args : (Tensor x, Tensor index, Tensor updates)
output : Tensor
infer_meta :
func : ScatterNdAddInferMeta
kernel :
func : scatter_nd_add
data_type : x
backward : scatter_nd_add_grad
- op : searchsorted
args : (Tensor sorted_sequence, Tensor values, bool out_int32 = false, bool right = false)
output : Tensor(out)
......@@ -830,6 +851,16 @@
func : searchsorted
data_type : sorted_sequence
- op : selu
args : (Tensor x, float scale=1.0507009873554804934193349852946, float alpha=1.6732632423543772848170429916717)
output : Tensor
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : selu
backward : selu_grad
- op : send_uv
args : (Tensor x, Tensor y, Tensor src_index, Tensor dst_index, str message_op = "ADD")
output : Tensor(out)
......@@ -840,6 +871,14 @@
data_type : x
backward : send_uv_grad
- op : shard_index
args : (Tensor input, int index_num, int nshards, int shard_id, int ignore_value=-1)
output : Tensor(out)
infer_meta :
func : ShardIndexInferMeta
kernel :
func : shard_index
- op : sigmoid
args : (Tensor x)
output : Tensor
......@@ -1031,3 +1070,21 @@
kernel :
func : unfold
backward : unfold_grad
- op : viterbi_decode
args : (Tensor potentials, Tensor transition_params, Tensor lengths, bool include_bos_eos_tag = true)
output : Tensor(scores), Tensor(path)
infer_meta :
func : ViterbiDecodeInferMeta
kernel :
func : viterbi_decode
data_type : potentials
- op : where
args : (Tensor condition, Tensor x, Tensor y)
output : Tensor
infer_meta :
func : WhereInferMeta
kernel :
func : where
backward : where_grad
......@@ -21,24 +21,6 @@ KernelSignature GatherNdGradArgumentMapping(const ArgumentMappingContext& ctx) {
"gather_nd_grad", {"X", "Index", "Out@GRAD"}, {}, {"X@GRAD"});
}
KernelSignature ScatterGradArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("scatter_grad",
{"Ids", "Updates", "Out@GRAD"},
{"overwrite"},
{"X@GRAD", "Updates@GRAD"});
}
KernelSignature ScatterNdAddGradArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("scatter_nd_add_grad",
{"Index", "Updates", "Out@GRAD"},
{},
{"X@GRAD", "Updates@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(gather_nd_grad, phi::GatherNdGradArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(scatter_grad, phi::ScatterGradArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(scatter_nd_add_grad,
phi::ScatterNdAddGradArgumentMapping);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature SeluGradGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"selu_grad", {"Out", "Out@GRAD"}, {"scale", "alpha"}, {"X@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(selu_grad, phi::SeluGradGradOpArgumentMapping);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature WhereGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("where_grad",
{"Condition", "X", "Y", "Out@GRAD"},
{},
{"X@GRAD", "Y@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(where_grad, phi::WhereGradOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册