未验证 提交 a160c417 编写于 作者: F Feiyu Chan 提交者: GitHub

fix bugs in codegen for operators (#43594)

* add codegen for get_expected_kernel, add argument mapping for selected_rows kernels, fix other bugs in codegen for operators.
* move bernoulli, erf, mv, poisson, trunc, erf to api.yaml and corresponding backward api to backward.yaml
* generate EmptyGradOpMaker for ops without grad op
* add code to generate all possible kernel signatures for infrt
上级 e08d33e1
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/bernoulli_op.h"
#include <algorithm>
#include <string>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/common_infer_shape_functions.h"
namespace paddle {
namespace operators {
class BernoulliOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"A tensor with probabilities for generating the random binary "
"number");
AddOutput("Out", "A Tensor filled with random binary number");
AddComment(R"DOC(
This OP returns a Tensor filled with random binary(0 or 1) number from a Bernoulli distribution.
Out ~ Bernoulli(X)
)DOC");
}
};
class BernoulliOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
return UnaryOpUnchangedInferShape(ctx);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OPERATOR(
bernoulli, ops::BernoulliOp, ops::BernoulliOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
class ErfOp : public framework::OperatorWithKernel {
public:
ErfOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
};
class ErfGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput(framework::GradVarName("Out")), true,
platform::errors::InvalidArgument(
"Input(%s) of ErfGradOp should not be null.", "DOut"));
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::InvalidArgument(
"Input(%s) of ErfGradOp should not be null.", "X"));
PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("X")), true,
platform::errors::InvalidArgument(
"Output(%s) of ErfGradOp should not be null.", "DX"));
auto x_grad_name = framework::GradVarName("X");
ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ x_grad_name);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
};
class ErfOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "The input tensor of erf operator.");
AddOutput("Out", "The output tensor of erf operator.");
AddComment(R"DOC(
Erf Operator.
The equation is:
$$
f(x) = \frac{2}{\sqrt{\pi}} \int_{0}^{x}e^{- \eta^{2}}d\eta
$$
The input `X` can carry the LoD (Level of Details) information,
or not. And the output shares the LoD information with input `X`.
)DOC");
}
};
template <typename T>
class ErfGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("erf_grad");
grad_op->SetInput("X", this->Input("X"));
grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
grad_op->SetAttrMap(this->Attrs());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(erf, ErfInferShapeFunctor,
PD_INFER_META(phi::UnchangedInferMeta));
REGISTER_OPERATOR(erf, ops::ErfOp, ops::ErfOpMaker,
ops::ErfGradOpMaker<paddle::framework::OpDesc>,
ops::ErfGradOpMaker<paddle::imperative::OpBase>,
ErfInferShapeFunctor);
REGISTER_OPERATOR(erf_grad, ops::ErfGradOp);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle {
namespace operators {
class MVOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "The matrix input of mv op");
AddInput("Vec", "The vector input of mv op");
AddOutput("Out", "The output of mv op");
AddComment(R"DOC(
MV Operator.
This operator is used to perform matrix vector multiplication
of the input tensors `X` and `Vec`.
)DOC");
}
};
class MVOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
};
template <typename T>
class MVOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> retv) const override {
retv->SetType("mv_grad");
retv->SetInput("X", this->Input("X"));
retv->SetInput("Vec", this->Input("Vec"));
retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
retv->SetOutput(framework::GradVarName("Vec"), this->InputGrad("Vec"));
}
};
class MVOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext *context) const override {
OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "mv");
OP_INOUT_CHECK(context->HasInput("Vec"), "Input", "Vec", "mv");
OP_INOUT_CHECK(context->HasInput(framework::GradVarName("Out")), "Input",
"Out@GRAD", "mv");
auto x_dims = context->GetInputDim("X");
auto vec_dims = context->GetInputDim("Vec");
auto x_grad_name = framework::GradVarName("X");
auto vec_grad_name = framework::GradVarName("Vec");
if (context->HasOutput(x_grad_name)) {
context->SetOutputDim(x_grad_name, x_dims);
}
if (context->HasOutput(vec_grad_name)) {
context->SetOutputDim(vec_grad_name, vec_dims);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
DECLARE_INFER_SHAPE_FUNCTOR(mv, MvInferShapeFunctor,
PD_INFER_META(phi::MvInferMeta));
REGISTER_OPERATOR(mv, ops::MVOp, ops::MVOpMaker,
ops::MVOpGradMaker<paddle::framework::OpDesc>,
ops::MVOpGradMaker<paddle::imperative::OpBase>,
MvInferShapeFunctor);
REGISTER_OPERATOR(mv_grad, ops::MVOpGrad);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
class PoissonOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
};
class PoissonOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor) The input tensor of poisson op");
AddOutput("Out",
"The output tensor of poisson op, it has the same shape and "
"dtype with input. Each element corresponds to input tensor");
AddComment(R"DOC(
This operator generate random value that obey poisson distribution.
)DOC");
}
};
class PoissonOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string> &GetInputOutputWithSameType()
const override {
static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
return m;
}
};
class PoissonGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
"Out_Grad", "PoissonGradOp");
auto dout_dim = ctx->GetInputDim(framework::GradVarName("Out"));
ctx->SetOutputDim(framework::GradVarName("X"), dout_dim);
}
};
template <typename T>
class PoissonGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> retv) const override {
retv->SetType("poisson_grad");
retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
DECLARE_INFER_SHAPE_FUNCTOR(poisson, PoissonInferShapeFunctor,
PD_INFER_META(phi::UnchangedInferMeta));
REGISTER_OPERATOR(poisson, ops::PoissonOp, ops::PoissonOpMaker,
ops::PoissonOpInferVarType,
ops::PoissonGradOpMaker<paddle::framework::OpDesc>,
ops::PoissonGradOpMaker<paddle::imperative::OpBase>,
PoissonInferShapeFunctor);
REGISTER_OPERATOR(poisson_grad, ops::PoissonGradOp);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
class TruncOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
};
class TruncOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of trunc op.");
AddOutput("Out", "(Tensor), The output tensor of trunc op.");
AddComment(R"DOC(
Trunc Operator.
Returns a new tensor with the truncated integer values of input.
$$out = trunc(x)$$
)DOC");
}
};
class TruncGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
framework::GradVarName("Out"), "TruncGrad");
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
framework::GradVarName("X"), "TruncGrad");
auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out"));
ctx->SetOutputDim(framework::GradVarName("X"), dout_dims);
}
};
template <typename T>
class TruncGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
void Apply(GradOpPtr<T> retv) const override {
retv->SetType("trunc_grad");
retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
retv->SetAttrMap(this->Attrs());
retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
}
};
} // namespace operators
} // namespace paddle
DECLARE_INFER_SHAPE_FUNCTOR(trunc, TruncInferShapeFunctor,
PD_INFER_META(phi::UnchangedInferMeta));
namespace ops = paddle::operators;
REGISTER_OPERATOR(trunc, ops::TruncOp, ops::TruncOpMaker,
ops::TruncGradOpMaker<paddle::framework::OpDesc>,
ops::TruncGradOpMaker<paddle::imperative::OpBase>,
TruncInferShapeFunctor);
REGISTER_OPERATOR(trunc_grad, ops::TruncGradOp);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature ErfGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("erf_grad", {"X", "Out@GRAD"}, {}, {"X@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(erf_grad, phi::ErfGradOpArgumentMapping);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature MvGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"mv_grad", {"X", "Vec", "Out@GRAD"}, {}, {"X@GRAD", "Vec@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(mv_grad, phi::MvGradOpArgumentMapping);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature PoissonGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("poisson_grad", {"Out@GRAD"}, {}, {"X@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(poisson_grad, phi::PoissonGradOpArgumentMapping);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature TruncOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("trunc", {"X"}, {}, {"Out"});
}
KernelSignature TruncGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("trunc_grad", {"Out@GRAD"}, {}, {"X@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(trunc, phi::TruncOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(trunc_grad, phi::TruncGradOpArgumentMapping);
# erf
# bernoulli
- api : bernoulli
args : (Tensor x)
output : Tensor
infer_meta :
func : UnchangedInferMeta
kernel :
func : bernoulli
- api : erf
args : (Tensor x)
output : Tensor
infer_meta :
func : UnchangedInferMeta
kernel :
func : erf
backward : erf_grad
- api : mv
args : (Tensor x, Tensor vec)
output : Tensor
infer_meta :
func : MvInferMeta
kernel :
func : mv
backward : mv_grad
# poisson
- api : poisson
args : (Tensor x)
output : Tensor
infer_meta :
func : UnchangedInferMeta
kernel :
func : poisson
backward : poisson_grad
- api : trunc
args : (Tensor x)
output : Tensor
infer_meta :
func : UnchangedInferMeta
kernel :
func : trunc
backward : trunc_grad
- backward_api : erf_grad
forward : erf (Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : erf_grad
data_type : out_grad
- backward_api : mv_grad
forward : mv (Tensor x, Tensor vec) -> Tensor(out)
args : (Tensor x, Tensor vec, Tensor out_grad)
output : Tensor(x_grad), Tensor(vec_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
param : [x, vec]
kernel :
func : mv_grad
- backward_api : poisson_grad
forward : poisson (Tensor x) -> Tensor(out)
args : (Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [out_grad]
kernel :
func : poisson_grad
- backward_api : trunc_grad
forward : trunc (Tensor x) -> Tensor(out)
args : (Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [out_grad]
kernel :
func : trunc_grad
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
from typing import List, Dict from typing import List, Dict
import itertools
import re import re
from jinja2.filters import do_xmlattr from jinja2.filters import do_xmlattr
...@@ -25,6 +26,10 @@ from type_mapping import (dense_input_types_map, dense_optional_input_types_map, ...@@ -25,6 +26,10 @@ from type_mapping import (dense_input_types_map, dense_optional_input_types_map,
phi_attr_types_map) phi_attr_types_map)
def quote(s):
return '"{}"'.format(s)
# ------------------------------ attr ------------------------------------- # ------------------------------ attr -------------------------------------
def to_phi_attr_type(s): def to_phi_attr_type(s):
return phi_attr_types_map[s] return phi_attr_types_map[s]
...@@ -74,15 +79,14 @@ def to_sr_output_type(s): ...@@ -74,15 +79,14 @@ def to_sr_output_type(s):
# -------------- transform argument names from yaml to opmaker ------------ # -------------- transform argument names from yaml to opmaker ------------
def to_opmaker_name(s): def to_opmaker_name(s):
if s.endswith("_grad"): if s.endswith("_grad"):
return 'GradVarName("{}")'.format( return 'GradVarName("{}")'.format(to_pascal_case(s[:-5]))
to_pascal_case(s.removesuffix("_grad")))
else: else:
return '"{}"'.format(to_pascal_case(s)) return '"{}"'.format(to_pascal_case(s))
def to_opmaker_name_cstr(s): def to_opmaker_name_cstr(s):
if s.endswith("_grad"): if s.endswith("_grad"):
return '"{}@GRAD"'.format(to_pascal_case(s.removesuffix("_grad"))) return '"{}@GRAD"'.format(to_pascal_case(s[:-5]))
else: else:
return '"{}"'.format(to_pascal_case(s)) return '"{}"'.format(to_pascal_case(s))
...@@ -105,3 +109,48 @@ def to_input_name(s): ...@@ -105,3 +109,48 @@ def to_input_name(s):
match = re.match(r"(d\d*)(\w+)", s) match = re.match(r"(d\d*)(\w+)", s)
assert (match.group(1) != ""), "it should be a grad style name." assert (match.group(1) != ""), "it should be a grad style name."
return match.group(2) return match.group(2)
def cartesian_prod_attrs(attrs):
items = []
for attr in attrs:
type_name = attr["typename"]
name = attr["name"]
if type_name == "Scalar":
items.append((name, "{}Tensor".format(name)))
elif type_name == "IntArray":
items.append(
(name, "{}Tensor".format(name), "{}TensorList".format(name)))
else:
items.append((name, ))
_combinations = itertools.product(*items)
combinations = []
for x in _combinations:
combinations.append('{' + ", ".join(quote(t) for t in x) + '}')
return combinations
def cartesian_prod_mapping(api):
kernels = api["kernel"]["func"]
inputs = [
x["name"] for x in api["inputs"] if x["name"] in api["kernel"]["param"]
]
inputs = [to_opmaker_name_cstr(input) for input in inputs]
attrs = cartesian_prod_attrs(api["attrs"])
outputs = [
to_opmaker_name_cstr(output["name"]) for output in api["outputs"]
]
def vec(items):
return "{" + ', '.join(items) + "}"
inputs = [vec(inputs)]
outputs = [vec(outputs)]
kernels = [quote(x) for x in kernels]
mappings = itertools.product(kernels, inputs, attrs, outputs)
outs = []
for spec in mappings:
outs.append("return KernelSignature({});".format(", ".join(spec)))
return "\n".join(outs)
...@@ -22,7 +22,7 @@ from jinja2 import Environment, FileSystemLoader, StrictUndefined ...@@ -22,7 +22,7 @@ from jinja2 import Environment, FileSystemLoader, StrictUndefined
from filters import to_op_attr_type, to_opmaker_name, to_opmaker_name_cstr, to_pascal_case from filters import to_op_attr_type, to_opmaker_name, to_opmaker_name_cstr, to_pascal_case
from tests import is_base_api, is_vec, is_scalar, is_initializer_list, supports_inplace, supports_no_need_buffer from tests import is_base_api, is_vec, is_scalar, is_initializer_list, supports_inplace, supports_no_need_buffer
from filters import to_input_name from filters import to_input_name, cartesian_prod_mapping
from parse_utils import to_named_dict from parse_utils import to_named_dict
file_loader = FileSystemLoader(Path(__file__).parent / "templates") file_loader = FileSystemLoader(Path(__file__).parent / "templates")
...@@ -37,6 +37,7 @@ env.filters["to_opmaker_name"] = to_opmaker_name ...@@ -37,6 +37,7 @@ env.filters["to_opmaker_name"] = to_opmaker_name
env.filters["to_pascal_case"] = to_pascal_case env.filters["to_pascal_case"] = to_pascal_case
env.filters["to_input_name"] = to_input_name env.filters["to_input_name"] = to_input_name
env.filters["to_opmaker_name_cstr"] = to_opmaker_name_cstr env.filters["to_opmaker_name_cstr"] = to_opmaker_name_cstr
env.filters["cartesian_prod_mapping"] = cartesian_prod_mapping
env.tests["base_api"] = is_base_api env.tests["base_api"] = is_base_api
env.tests["vec"] = is_vec env.tests["vec"] = is_vec
env.tests["scalar"] = is_scalar env.tests["scalar"] = is_scalar
...@@ -45,14 +46,23 @@ env.tests["supports_inplace"] = supports_inplace ...@@ -45,14 +46,23 @@ env.tests["supports_inplace"] = supports_inplace
env.tests["supports_no_need_buffer"] = supports_no_need_buffer env.tests["supports_no_need_buffer"] = supports_no_need_buffer
def restruct_io(api):
api["input_dict"] = to_named_dict(api["inputs"])
api["attr_dict"] = to_named_dict(api["attrs"])
api["output_dict"] = to_named_dict(api["outputs"])
return api
def main(api_yaml_path, backward_yaml_path, output_op_path, def main(api_yaml_path, backward_yaml_path, output_op_path,
output_arg_map_path): output_arg_map_path):
with open(api_yaml_path, "rt") as f: with open(api_yaml_path, "rt") as f:
apis = yaml.safe_load(f) apis = yaml.safe_load(f)
apis = [restruct_io(api) for api in apis]
forward_api_dict = to_named_dict(apis) forward_api_dict = to_named_dict(apis)
with open(backward_yaml_path, "rt") as f: with open(backward_yaml_path, "rt") as f:
backward_apis = yaml.safe_load(f) backward_apis = yaml.safe_load(f)
backward_apis = [restruct_io(api) for api in backward_apis]
backward_api_dict = to_named_dict(backward_apis) backward_api_dict = to_named_dict(backward_apis)
# fill backward field for an api if another api claims it as forward # fill backward field for an api if another api claims it as forward
......
...@@ -255,15 +255,6 @@ ...@@ -255,15 +255,6 @@
func : bce_loss func : bce_loss
backward : bce_loss_grad backward : bce_loss_grad
# bernoulli
- api : bernoulli
args : (Tensor x)
output : Tensor
infer_meta :
func : UnchangedInferMeta
kernel :
func : bernoulli
# bitwise_and # bitwise_and
- api : bitwise_and - api : bitwise_and
args : (Tensor x, Tensor y) args : (Tensor x, Tensor y)
...@@ -677,16 +668,6 @@ ...@@ -677,16 +668,6 @@
kernel : kernel :
func : equal_all func : equal_all
# erf
- api : erf
args : (Tensor x)
output : Tensor
infer_meta :
func : UnchangedInferMeta
kernel :
func : erf
backward : erf_grad
# erfinv # erfinv
- api : erfinv - api : erfinv
args : (Tensor x) args : (Tensor x)
...@@ -1544,15 +1525,6 @@ ...@@ -1544,15 +1525,6 @@
func : multiply func : multiply
backward : multiply_grad backward : multiply_grad
- api : mv
args : (Tensor x, Tensor vec)
output : Tensor
infer_meta :
func : MvInferMeta
kernel :
func : mv
backward : mv_grad
- api : nll_loss - api : nll_loss
args : (Tensor input, Tensor label, Tensor weight, int64_t ignore_index, str reduction) args : (Tensor input, Tensor label, Tensor weight, int64_t ignore_index, str reduction)
output : Tensor(out), Tensor(total_weight) output : Tensor(out), Tensor(total_weight)
...@@ -1633,16 +1605,6 @@ ...@@ -1633,16 +1605,6 @@
func : pixel_shuffle func : pixel_shuffle
backward : pixel_shuffle_grad backward : pixel_shuffle_grad
# poisson
- api : poisson
args : (Tensor x)
output : Tensor
infer_meta :
func : UnchangedInferMeta
kernel :
func : poisson
backward : poisson_grad
- api : pool2d - api : pool2d
args : (Tensor x, int[] kernel_size, int[] strides, int[] paddings, bool ceil_mode, bool exclusive, str data_format, str pooling_type, bool global_pooling, bool adaptive, str padding_algorithm) args : (Tensor x, int[] kernel_size, int[] strides, int[] paddings, bool ceil_mode, bool exclusive, str data_format, str pooling_type, bool global_pooling, bool adaptive, str padding_algorithm)
output : Tensor(out) output : Tensor(out)
...@@ -2239,15 +2201,6 @@ ...@@ -2239,15 +2201,6 @@
func : tril_triu func : tril_triu
backward : tril_triu_grad backward : tril_triu_grad
- api : trunc
args : (Tensor x)
output : Tensor
infer_meta :
func : UnchangedInferMeta
kernel :
func : trunc
backward : trunc_grad
# python API: paddle.nn.initializer.TruncatedNormal # python API: paddle.nn.initializer.TruncatedNormal
- api : truncated_gaussian_random - api : truncated_gaussian_random
args : (int[] shape, float mean, float std, int seed, DataType dtype=DataType::FLOAT32, Place place={}) args : (int[] shape, float mean, float std, int seed, DataType dtype=DataType::FLOAT32, Place place={})
......
...@@ -664,17 +664,6 @@ ...@@ -664,17 +664,6 @@
output : Tensor(weight_grad) output : Tensor(weight_grad)
invoke : embedding_grad_impl(x, weight, out_grad, padding_idx, sparse, weight_grad) invoke : embedding_grad_impl(x, weight, out_grad, padding_idx, sparse, weight_grad)
- backward_api : erf_grad
forward : erf (Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : erf_grad
data_type : out_grad
- backward_api : erfinv_grad - backward_api : erfinv_grad
forward : erfinv (Tensor x) -> Tensor(out) forward : erfinv (Tensor x) -> Tensor(out)
args : (Tensor out, Tensor out_grad) args : (Tensor out, Tensor out_grad)
...@@ -1431,16 +1420,6 @@ ...@@ -1431,16 +1420,6 @@
func : multiply_triple_grad func : multiply_triple_grad
optional : fwd_grad_grad_x, fwd_grad_grad_y, grad_grad_out_grad optional : fwd_grad_grad_x, fwd_grad_grad_y, grad_grad_out_grad
- backward_api : mv_grad
forward : mv (Tensor x, Tensor vec) -> Tensor(out)
args : (Tensor x, Tensor vec, Tensor out_grad)
output : Tensor(x_grad), Tensor(vec_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
param : [x, vec]
kernel :
func : mv_grad
- backward_api : nll_loss_grad - backward_api : nll_loss_grad
forward : nll_loss (Tensor input, Tensor label, Tensor weight, int64_t ignore_index, str reduction) -> Tensor(out), Tensor(total_weight) forward : nll_loss (Tensor input, Tensor label, Tensor weight, int64_t ignore_index, str reduction) -> Tensor(out), Tensor(total_weight)
args : (Tensor input, Tensor label, Tensor weight, Tensor total_weight, Tensor out_grad, int64_t ignore_index, str reduction) args : (Tensor input, Tensor label, Tensor weight, Tensor total_weight, Tensor out_grad, int64_t ignore_index, str reduction)
...@@ -1524,16 +1503,6 @@ ...@@ -1524,16 +1503,6 @@
kernel : kernel :
func : pixel_shuffle_grad func : pixel_shuffle_grad
- backward_api : poisson_grad
forward : poisson (Tensor x) -> Tensor(out)
args : (Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [out_grad]
kernel :
func : poisson_grad
- backward_api : pool2d_double_grad - backward_api : pool2d_double_grad
forward : pool2d_grad(Tensor x, Tensor out, Tensor grad_out, int[] kernel_size, int[] strides, int[] paddings, bool ceil_mode, bool exclusive, str data_format, str pooling_type, bool global_pooling, bool adaptive, str padding_algorithm) -> Tensor(grad_x) forward : pool2d_grad(Tensor x, Tensor out, Tensor grad_out, int[] kernel_size, int[] strides, int[] paddings, bool ceil_mode, bool exclusive, str data_format, str pooling_type, bool global_pooling, bool adaptive, str padding_algorithm) -> Tensor(grad_x)
args : (Tensor grad_x_grad, int[] kernel_size, int[] strides, int[] paddings, bool ceil_mode, bool exclusive, str data_format, str pooling_type, bool global_pooling, bool adaptive, str padding_algorithm) args : (Tensor grad_x_grad, int[] kernel_size, int[] strides, int[] paddings, bool ceil_mode, bool exclusive, str data_format, str pooling_type, bool global_pooling, bool adaptive, str padding_algorithm)
...@@ -2259,16 +2228,6 @@ ...@@ -2259,16 +2228,6 @@
kernel : kernel :
func : tril_triu_grad func : tril_triu_grad
- backward_api : trunc_grad
forward : trunc (Tensor x) -> Tensor(out)
args : (Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [out_grad]
kernel :
func : trunc_grad
- backward_api : unbind_grad - backward_api : unbind_grad
forward : unbind (Tensor input, int axis) -> Tensor[](out) forward : unbind (Tensor input, int axis) -> Tensor[](out)
args : (Tensor[] out_grad, int axis) args : (Tensor[] out_grad, int axis)
......
{% from "operator_utils.c.j2" import name_map, register_name_map %} {% from "operator_utils.c.j2" import name_map, register_name_map %}
// this file is generated by python/paddle/utils/code_gen/generate_op.py, do not edit. // this file is generated by python/paddle/utils/code_gen/generate_op.py, do not edit.
#include "paddle/phi/core/compat/op_utils.h" #include "paddle/phi/core/compat/op_utils.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/utils/small_vector.h" #include "paddle/utils/small_vector.h"
namespace phi { namespace phi {
using paddle::framework::GradVarName;
{% for api in apis %} {% for api in apis %}
{% if api is base_api %} {% if api is base_api %}
{{name_map(api)}} {{name_map(api)}}
......
...@@ -102,8 +102,24 @@ KernelSignature {{api["name"] | to_pascal_case }}OpArgumentMapping(const Argumen ...@@ -102,8 +102,24 @@ KernelSignature {{api["name"] | to_pascal_case }}OpArgumentMapping(const Argumen
{% endfilter %} {% endfilter %}
{% endfor %} {% endfor %}
{{get_output_list(api["outputs"], kernel_args)}}; {{get_output_list(api["outputs"], kernel_args)}};
return KernelSignature("{{api["name"]}}", std::move(inputs), std::move(attrs), std::move(outputs)); {% if api["kernel"]["func"] | length == 1 %}
KernelSignature sig("{{api["name"]}}", std::move(inputs), std::move(attrs), std::move(outputs));
return sig;
{% else %}{# it has kernel for selected rows #}
const char* kernel_name = ctx.IsSelectedRowsInput({{kernel_args[0] | to_opmaker_name_cstr}}) ? "{{api["kernel"]["func"][1]}}" : "{{api["kernel"]["func"][0]}}";
KernelSignature sig (kernel_name, std::move(inputs), std::move(attrs), std::move(outputs));
return sig;
{%endif%}
} }
/*
******************************************************************
NOTE: The following codes are for 'get_compat_kernel_signature.py'
All possible KernelSignatures returned by {{api["name"] | to_pascal_case }}OpArgumentMapping:
{{api | cartesian_prod_mapping}}
******************************************************************
*/
{% endmacro %} {% endmacro %}
...@@ -151,14 +167,47 @@ paddle::small_vector<const char*> outputs { ...@@ -151,14 +167,47 @@ paddle::small_vector<const char*> outputs {
} }
{%- endmacro %} {%- endmacro %}
{% macro get_expected_kernel(api) %}
{% set kernel = api["kernel"] %}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
{%if kernel["data_type"] is not none %}{# data type ---------------------------------#}
{% if kernel["data_type"]["candidates"] | length == 1 %}
{% set data_type_arg = kernel["data_type"]["candidates"][0] %}
{% set inputs = api["inputs"] | map(attribute="name") | list %}
{% if data_type_arg in inputs %}
auto data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, {{data_type_arg | to_opmaker_name}});
{% else %}{# it is an attribute and probably named dtype#}
auto data_type = framework::proto::VarType::Type(ctx.Attr<int>("{{data_type_arg}}"));
{% endif %}
{% elif kernel["data_type"]["candidates"] | length == 2 %}
{% set data_type_args = kernel["data_type"]["candidates"] %}
auto data_type = framework::proto::VarType::Type(ctx.Attr<int>("{{data_type_args[0]}}");
if (data_type == static_cast<proto::VarType::Type>(-1)) {
data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, {{data_type_args[1] | to_opmaker_name}});
}
{% endif %}
{% endif %}
platform::Place place = ctx.GetPlace();
return framework::OpKernelType(data_type, place);
}
{% endmacro %}
{# --------------------------------------- operator ---------------------------------------------- #} {# --------------------------------------- operator ---------------------------------------------- #}
{% macro operator(api) %} {% macro operator(api) %}
class {{api["name"] | to_pascal_case}}Op : public framework::OperatorWithKernel { class {{api["name"] | to_pascal_case}}Op : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
{# ----------- get expected kernel type function -------------------------- #}
{% set kernel = api["kernel"] %}
{% if kernel["data_type"] is not none %}
protected:
{% filter indent(2, True)%}
{{get_expected_kernel(api)}}
{% endfilter %}
{% endif %}
}; };
{# infershape functor #}
DECLARE_INFER_SHAPE_FUNCTOR({{api["name"]}}, {{api["name"] | to_pascal_case}}InferShapeFunctor, DECLARE_INFER_SHAPE_FUNCTOR({{api["name"]}}, {{api["name"] | to_pascal_case}}InferShapeFunctor,
PD_INFER_META(phi::{{api["infer_meta"]["func"]}})); PD_INFER_META(phi::{{api["infer_meta"]["func"]}}));
{# inplace inferer #} {# inplace inferer #}
...@@ -189,6 +238,9 @@ REGISTER_OPERATOR({{name}}, ops::{{name | to_pascal_case}}Op, ...@@ -189,6 +238,9 @@ REGISTER_OPERATOR({{name}}, ops::{{name | to_pascal_case}}Op,
{% set backward_name = api["backward"] %} {% set backward_name = api["backward"] %}
ops::{{backward_name | to_pascal_case}}OpMaker<paddle::framework::OpDesc>, ops::{{backward_name | to_pascal_case}}OpMaker<paddle::framework::OpDesc>,
ops::{{backward_name | to_pascal_case}}OpMaker<paddle::imperative::OpBase>, ops::{{backward_name | to_pascal_case}}OpMaker<paddle::imperative::OpBase>,
{% else %}
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
{% endif %} {% endif %}
{% if api is supports_inplace %}{# inplace#} {% if api is supports_inplace %}{# inplace#}
ops::{{name | to_pascal_case}}InplaceInferer, ops::{{name | to_pascal_case}}InplaceInferer,
...@@ -219,7 +271,7 @@ class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T> ...@@ -219,7 +271,7 @@ class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T>
grad_op->SetType("{{name}}"); grad_op->SetType("{{name}}");
{% for input in api["inputs"] %} {% for input in api["inputs"] %}
grad_op->SetInput("{{input["name"] | to_pascal_case}}", this->{{extract_input_from_forward( grad_op->SetInput({{input["name"] | to_opmaker_name}}, this->{{extract_input_from_forward(
input["name"], input["name"],
forward_input_names, forward_input_names,
forward_output_names, forward_output_names,
...@@ -228,7 +280,7 @@ class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T> ...@@ -228,7 +280,7 @@ class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T>
{% endfor %} {% endfor %}
{% for output in api["outputs"] %} {% for output in api["outputs"] %}
grad_op->SetOutput("{{output["name"] | to_pascal_case}}", this->{{extract_output_from_forward( grad_op->SetOutput({{output["name"] | to_opmaker_name}}, this->{{extract_output_from_forward(
output["name"], output["name"],
forward_input_names, forward_input_names,
forward_output_names, forward_output_names,
...@@ -266,7 +318,7 @@ Input("{{name_in_forward_orig | to_pascal_case}}") ...@@ -266,7 +318,7 @@ Input("{{name_in_forward_orig | to_pascal_case}}")
{% set name_in_forward_orig = output_orig_names[output_names.index(name)]%} {% set name_in_forward_orig = output_orig_names[output_names.index(name)]%}
Output("{{name | to_pascal_case}}") Output("{{name | to_pascal_case}}")
{%- elif name.endswith("_grad") %}{# output grad#} {%- elif name.endswith("_grad") %}{# output grad#}
{% set name_in_forward = name.removesuffix("_grad") %} {% set name_in_forward = name[:-5] %}
{% if name_in_forward in output_names %} {% if name_in_forward in output_names %}
{% set name_in_forward_orig = output_orig_names[output_names.index(name_in_forward)] %} {% set name_in_forward_orig = output_orig_names[output_names.index(name_in_forward)] %}
OutputGrad("{{name_in_forward_orig | to_pascal_case}}") OutputGrad("{{name_in_forward_orig | to_pascal_case}}")
...@@ -276,10 +328,10 @@ OutputGrad("{{name_in_forward_orig | to_pascal_case}}") ...@@ -276,10 +328,10 @@ OutputGrad("{{name_in_forward_orig | to_pascal_case}}")
{% macro extract_output_from_forward(name, input_names, output_names, {% macro extract_output_from_forward(name, input_names, output_names,
input_orig_names, output_orig_names) %}{# inline #} input_orig_names, output_orig_names) %}{# inline #}
{% if name.removesuffix("_grad") in input_names %} {% if name[:-5] in input_names %}
{% set name_in_forward = name.removesuffix("_grad") %} {% set name_in_forward = name[:-5] %}
{% set name_in_forward_orig = input_orig_names[input_names.index(name_in_forward)]%} {% set name_in_forward_orig = input_orig_names[input_names.index(name_in_forward)]%}
InputGrad("{{name.removesuffix("_grad") | to_pascal_case}}") InputGrad("{{name[:-5] | to_pascal_case}}")
{%- elif (name | to_input_name) in input_names %} {%- elif (name | to_input_name) in input_names %}
{% set name_in_forward = name | to_input_name %} {% set name_in_forward = name | to_input_name %}
{% set name_in_forward_orig = input_orig_names[input_names.index(name_in_forward)]%} {% set name_in_forward_orig = input_orig_names[input_names.index(name_in_forward)]%}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册