Unverified commit af96d1e8, authored by lzydev, committed by GitHub

Support auto-gen concat (#54970)

* support auto-gen concat

* fix bug in legacy_backward.yaml

* fix bug in get_expected_kernel_type
Parent 802613cc
@@ -130,18 +130,21 @@ def ReadFwdFile(filepath):
return contents if contents is not None else []
def ReadBwdFile(filepath):
def ReadBwdFile(filepath, bw_ops=None):
f = open(filepath, 'r')
contents = yaml.load(f, Loader=yaml.FullLoader)
# not all fused ops support dygraph
if filepath.endswith("fused_backward.yaml") is True:
new_apis = [
api
for api in contents
if "support_dygraph_mode" in api
and api["support_dygraph_mode"] is True
]
contents = new_apis
if bw_ops is None:
contents = yaml.load(f, Loader=yaml.FullLoader)
# not all fused ops support dygraph
if filepath.endswith("fused_backward.yaml") is True:
new_apis = [
api
for api in contents
if "support_dygraph_mode" in api
and api["support_dygraph_mode"] is True
]
contents = new_apis
else:
contents = bw_ops
ret = {}
if contents is not None:
@@ -595,15 +598,16 @@ class FunctionGeneratorBase:
class GeneratorBase:
def __init__(self, api_yaml_path):
def __init__(self, api_yaml_path, fw_ops=None):
self.namespace = ""
self.api_yaml_path = api_yaml_path
self.forward_api_list = []
self.forward_api_list = fw_ops
def ParseForwardYamlContents(self):
api_yaml_path = self.api_yaml_path
self.forward_api_list = ReadFwdFile(api_yaml_path)
if self.forward_api_list is None:
self.forward_api_list = ReadFwdFile(api_yaml_path)
def InferNameSpace(self):
api_yaml_path = self.api_yaml_path
@@ -16,6 +16,7 @@ import argparse
import os
import re
import yaml
from codegen_utils import (
AssertMessage,
FindForwardName,
@@ -2552,14 +2553,17 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
class DygraphForwardAndNodesGenerator(GeneratorBase):
def __init__(self, api_yaml_path, backward_yaml_path):
def __init__(
self, api_yaml_path, backward_yaml_path, fw_ops=None, bw_ops=None
):
# Parent members:
# self.namespace
# self.api_yaml_path
# self.forward_api_list
GeneratorBase.__init__(self, api_yaml_path)
GeneratorBase.__init__(self, api_yaml_path, fw_ops)
self.backward_yaml_path = backward_yaml_path
self.bw_ops = bw_ops
self.grad_api_dict = {}
self.forward_declaration_str = ""
@@ -2580,7 +2584,7 @@ class DygraphForwardAndNodesGenerator(GeneratorBase):
# string api is forward_only, with no corresponding backward_yaml
if backward_yaml_path is not None:
self.grad_api_dict = ReadBwdFile(backward_yaml_path)
self.grad_api_dict = ReadBwdFile(backward_yaml_path, self.bw_ops)
def GetBackwardAPIContents(self, forward_api_contents):
grad_api_dict = self.grad_api_dict
@@ -2747,6 +2751,23 @@ if __name__ == "__main__":
forward_declaration_str = ""
forward_definition_str = ""
# merge legacy_ops.yaml and ops.yaml, legacy_backward.yaml and backward.yaml
all_ops = []
all_bw = []
for api_yaml_path in api_yaml_paths:
if api_yaml_path.endswith("legacy_ops.yaml") or api_yaml_path.endswith(
"/ops.yaml"
):
with open(api_yaml_path, 'r') as f:
all_ops += yaml.safe_load(f)
for bw_yaml_path in backward_yaml_paths:
if bw_yaml_path.endswith(
"legacy_backward.yaml"
) or bw_yaml_path.endswith("/backward.yaml"):
with open(bw_yaml_path, 'r') as f:
all_bw += yaml.safe_load(f)
for i in range(len(api_yaml_paths)):
api_yaml_path = api_yaml_paths[i]
@@ -2756,9 +2777,16 @@ if __name__ == "__main__":
else:
backward_yaml_path = None
generator = DygraphForwardAndNodesGenerator(
api_yaml_path, backward_yaml_path
)
if api_yaml_path.endswith('/legacy_ops.yaml'):
continue
if api_yaml_path.endswith('/ops.yaml'):
generator = DygraphForwardAndNodesGenerator(
api_yaml_path, backward_yaml_path, all_ops, all_bw
)
else:
generator = DygraphForwardAndNodesGenerator(
api_yaml_path, backward_yaml_path
)
generator.run()
node_declaration_str += generator.node_declaration_str + "\n"
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/concat_op.h"
#include <paddle/fluid/platform/complex.h>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
#include "paddle/phi/infermeta/multiary.h"
#include "paddle/phi/kernels/funcs/concat_funcs.h"
namespace paddle {
namespace operators {
class ConcatOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto inputs = ctx.MultiInput<phi::DenseTensor>("X");
auto input_data_type = framework::proto::VarType::Type(0);
bool flag = 0;
for (auto *input : inputs) {
if (input->IsInitialized()) {
input_data_type = framework::TransToProtoVarType(input->dtype());
flag = 1;
break;
}
}
if (flag == 0) {
PADDLE_THROW(platform::errors::InvalidArgument(
"All Inputs of Concat OP are Empty!"));
}
return phi::KernelKey(input_data_type, ctx.GetPlace());
}
phi::KernelKey GetKernelTypeForVar(
const std::string &var_name,
const phi::DenseTensor &tensor,
const phi::KernelKey &expected_kernel_type) const override {
if (var_name == "AxisTensor") {
return phi::KernelKey(phi::Backend::ALL_BACKEND,
expected_kernel_type.layout(),
expected_kernel_type.dtype());
}
return phi::KernelKey(
tensor.place(), tensor.layout(), expected_kernel_type.dtype());
}
};
class ConcatOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "Input tensors of concat operator.").AsDuplicable();
AddOutput("Out", "Output tensor of concat operator.");
AddAttr<int>("axis",
"The axis along which the input tensors will be concatenated."
"The axis could also be negative numbers. Negative axis is "
"interpreted as counting from the end of the rank."
"i.e., axis + rank(X) th dimension.")
.SetDefault(0)
.SupportTensor();
AddInput("AxisTensor",
"(Tensor) The axis along which the input tensors will be "
"concatenated. "
"It has higher priority than Attr(axis). "
"The shape of AxisTensor must be [1].")
.AsDispensable();
AddComment(R"DOC(
Concat Operator.
Concatenate the input tensors along dimension axis.
Examples:
Input[0] = [[1,2],[3,4]]
Input[1] = [[5,6]]
axis = 0
Output = [[1,2],
[3,4],
[5,6]]
)DOC");
}
};
class ConcatOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
auto in_x = "X";
auto out_x_g_n = framework::GradVarName(in_x);
ctx->SetOutputsDim(out_x_g_n, ctx->GetInputsDim(in_x));
ctx->ShareAllLoD(in_x, out_x_g_n);
}
protected:
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
return phi::KernelKey(input_data_type, ctx.GetPlace());
}
phi::KernelKey GetKernelTypeForVar(
const std::string &var_name,
const phi::DenseTensor &tensor,
const phi::KernelKey &expected_kernel_type) const override {
if (var_name == "AxisTensor") {
return phi::KernelKey(phi::Backend::ALL_BACKEND,
expected_kernel_type.layout(),
expected_kernel_type.dtype());
}
return phi::KernelKey(
tensor.place(), tensor.layout(), expected_kernel_type.dtype());
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(ConcatOpGradNoNeedBufferVarInferer, "X");
template <typename T>
class ConcatGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("concat_grad");
op->SetInput("X", this->Input("X"));
if (this->HasInput("AxisTensor")) {
op->SetInput("AxisTensor", this->Input("AxisTensor"));
}
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X", false));
op->SetAttrMap(this->Attrs());
}
};
class ConcatCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
public:
void Apply() override {
std::vector<paddle::Tensor> input = this->GetMultiForwardInput("X");
paddle::optional<paddle::Tensor> tensor_axis =
this->GetOptionalSingleForwardInput("AxisTensor");
paddle::Tensor out_grad = this->GetSingleOutputGrad("Out");
std::vector<paddle::Tensor> input_grad = this->GetMultiInputGrad("X");
std::vector<paddle::Tensor *> input_grad_ptr;
for (auto i = 0; i < static_cast<int>(input_grad.size()); ++i) {
input_grad_ptr.push_back(&input_grad[i]);
}
int axis = static_cast<int>(this->Attr<int>("axis"));
std::vector<paddle::Tensor *> dx_ptr = this->GetOutputPtr(input_grad_ptr);
std::vector<std::string> dx_name = this->GetOutputName(input_grad);
VLOG(6) << "Running concat_grad composite func";
if (tensor_axis.is_initialized()) {
PADDLE_THROW(platform::errors::Unimplemented(
"We don't support dynamic index from tensor for concat composite "
"grad for now. "));
} else {
prim::concat_grad<prim::DescTensor>(input, out_grad, axis, dx_ptr);
}
this->RecoverOutputName(input_grad, dx_name);
}
};
template <typename T>
class ConcatDoubleGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("concat");
grad_op->SetInput("X", this->OutputGrad(framework::GradVarName("X")));
grad_op->SetOutput("Out", this->InputGrad(framework::GradVarName("Out")));
grad_op->SetAttrMap(this->Attrs());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(concat,
ConcatInferShapeFunctor,
PD_INFER_META(phi::ConcatInferMeta));
REGISTER_OPERATOR(concat,
ops::ConcatOp,
ops::ConcatOpMaker,
ops::ConcatGradOpMaker<paddle::framework::OpDesc>,
ops::ConcatGradOpMaker<paddle::imperative::OpBase>,
ops::ConcatCompositeGradOpMaker,
ConcatInferShapeFunctor);
REGISTER_OPERATOR(concat_grad,
ops::ConcatOpGrad,
ops::ConcatDoubleGradOpMaker<paddle::framework::OpDesc>,
ops::ConcatDoubleGradOpMaker<paddle::imperative::OpBase>,
ops::ConcatOpGradNoNeedBufferVarInferer);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/operators/utils.h"
#include "paddle/phi/kernels/concat_kernel.h"
#include "paddle/phi/kernels/funcs/concat_funcs.h"
#include "paddle/phi/kernels/funcs/strided_memcpy.h"
namespace paddle {
namespace operators {
static inline int64_t ComputeAxis(int64_t axis, int64_t rank) {
PADDLE_ENFORCE_EQ(
axis >= -rank && axis < rank,
true,
platform::errors::InvalidArgument(
"The axis is expected to be in range of [%d, %d), but got %d",
-rank,
rank,
axis));
if (axis < 0) {
axis = axis + rank;
}
return axis > 0 ? axis : 0;
}
} // namespace operators
} // namespace paddle
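For reference, `ComputeAxis` normalizes a possibly negative `axis` into `[0, rank)`; a negative value counts from the end of the rank. A minimal Python sketch of the same rule (the helper name `compute_axis` is mine, for illustration only):

```python
def compute_axis(axis: int, rank: int) -> int:
    # Mirrors operators::ComputeAxis: axis must lie in [-rank, rank).
    if not (-rank <= axis < rank):
        raise ValueError(
            f"The axis is expected to be in range of [{-rank}, {rank}), but got {axis}"
        )
    if axis < 0:
        axis += rank
    return axis if axis > 0 else 0


assert compute_axis(-1, 3) == 2  # negative axis counts from the end
assert compute_axis(0, 3) == 0
```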
@@ -479,17 +479,22 @@ def parse_drop_empty_grad(op_fluid_list: list, bw_op_dict: dict):
bw_name.split('(')[0].strip()
for bw_name in op_comp_map['backward'].split(',')
]
for bw_name in bw_names:
# static_ops.yaml and ops.yaml use the common op_compat.yaml
if bw_name in bw_op_dict:
new_bw_names = [
bw_name for bw_name in bw_names if bw_name in bw_op_dict
]
if len(new_bw_names) != 0:
bws_has_out_grad = False
for bw_name in bw_names:
# static_ops.yaml and ops.yaml use the common op_compat.yaml
for out_grad in op_comp_map['drop_empty_grad']:
assert (
out_grad in bw_op_dict[bw_name]['output_dict']
), f'''
{bw_name} with {out_grad} is not existed in output_dict '''
bw_op_dict[bw_name]['output_dict'][out_grad][
'drop_empty_grad'
] = False
if out_grad in bw_op_dict[bw_name]['output_dict']:
bw_op_dict[bw_name]['output_dict'][out_grad][
'drop_empty_grad'
] = False
bws_has_out_grad = True
assert (
bws_has_out_grad
), f'''{bw_names} with {op_comp_map['drop_empty_grad']} does not exist in output_dict '''
def parse_get_expected_kerneltype(
@@ -85,6 +85,27 @@ phi::KernelKey GetCheckFiniteAndUnscaleExpectedKernelType(
return phi::KernelKey(dtype, ctx.GetPlace());
}
phi::KernelKey GetConcatExpectedKernelType(
const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel* op_ptr) {
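// op_ptr is intentionally unused here; the parameter appears to exist only so
// that the signature matches the shared get_expected_kernel_type hook
// (compare GetReduceExpectedKernelType below).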
(void)op_ptr;
auto inputs = ctx.MultiInput<phi::DenseTensor>("X");
auto input_data_type = framework::proto::VarType::Type(0);
bool flag = 0;
for (auto* input : inputs) {
if (input->IsInitialized()) {
input_data_type = framework::TransToProtoVarType(input->dtype());
flag = 1;
break;
}
}
if (flag == 0) {
PADDLE_THROW(platform::errors::InvalidArgument(
"All Inputs of Concat OP are Empty!"));
}
return phi::KernelKey(input_data_type, ctx.GetPlace());
}
phi::KernelKey GetReduceExpectedKernelType(
const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel* op_ptr) {
@@ -20,6 +20,10 @@ limitations under the License. */
namespace paddle {
namespace operators {
phi::KernelKey GetConcatExpectedKernelType(
const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel* op_ptr);
phi::KernelKey GetReduceExpectedKernelType(
const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel* op_ptr);
@@ -289,7 +289,7 @@ def parse_invoke(op_name: str, invoke_config: str) -> Dict[str, Any]:
invoke_config = invoke_config.strip()
func, rest = invoke_config.split("(", 1)
func = func.strip()
args = rest.rstrip(")").strip()
args = rest[:-1].strip()  # drop only the trailing ')' that closes the invoke
invocation = {"func": func, "args": args}
return invocation
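The switch from `rest.rstrip(")")` to `rest[:-1]` matters whenever the last argument of an `invoke` expression is itself a call: `rstrip` removes every trailing `)`, not just the one that closes the invoke. A small sketch of the difference, using a hypothetical invoke string:

```python
invoke_config = "concat(grad_x_grad, out_grad.shape())"  # hypothetical: last arg ends with ')'
func, rest = invoke_config.split("(", 1)

# rstrip(")") strips *all* trailing parentheses, clipping the last argument:
assert rest.rstrip(")").strip() == "grad_x_grad, out_grad.shape("
# rest[:-1] drops only the ')' that closes the invoke itself:
assert rest[:-1].strip() == "grad_x_grad, out_grad.shape()"
```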
@@ -318,6 +318,26 @@
func : complex_grad
data_type : real
- backward_op : concat_double_grad
forward : concat_grad (Tensor[] x, Tensor grad_out, Scalar axis=0) -> Tensor[](grad_x)
args : (Tensor[] grad_x_grad, Scalar axis = 0)
output : Tensor(grad_out_grad)
invoke : concat(grad_x_grad, axis)
- backward_op : concat_grad
forward : concat (Tensor[] x, Scalar axis=0) -> Tensor(out)
args : (Tensor[] x, Tensor out_grad, Scalar axis = 0)
output : Tensor[](x_grad){x.size()}
infer_meta :
func : UnchangedMultiInferMeta
param : [x]
kernel :
func : concat_grad
data_type : out_grad
composite : concat_grad(x, out_grad, axis, x_grad)
no_need_buffer : x
backward : concat_double_grad
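For intuition, `concat_grad` simply slices `out_grad` back into one piece per forward input along `axis`, which is why the output is declared as the sized list `Tensor[](x_grad){x.size()}`. A NumPy reference sketch of that semantics (not the actual kernel):

```python
import numpy as np


def concat_grad_reference(xs, out_grad, axis=0):
    # Split the upstream gradient back into per-input slices along `axis`,
    # one slice per forward input, matching each input's extent on that axis.
    sizes = [x.shape[axis] for x in xs]
    return np.split(out_grad, np.cumsum(sizes)[:-1], axis=axis)


x0 = np.array([[1.0, 2.0], [3.0, 4.0]])
x1 = np.array([[5.0, 6.0]])
out_grad = np.ones((3, 2))
gx0, gx1 = concat_grad_reference([x0, x1], out_grad, axis=0)
assert gx0.shape == x0.shape and gx1.shape == x1.shape
```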
- backward_op : conj_grad
forward : conj (Tensor x) -> Tensor(out)
args : (Tensor out_grad)
@@ -119,25 +119,6 @@
kernel :
func : channel_shuffle_grad
- backward_op : concat_double_grad
forward : concat_grad (Tensor[] x, Tensor grad_out, Scalar axis) -> Tensor[](grad_x)
args : (Tensor[] grad_x_grad, Scalar axis = 0)
output : Tensor(grad_out_grad)
invoke : concat(grad_x_grad, axis)
- backward_op : concat_grad
forward : concat (Tensor[] x, Scalar axis) -> Tensor(out)
args : (Tensor[] x, Tensor out_grad, Scalar axis = 0)
output : Tensor[](x_grad){x.size()}
infer_meta :
func : UnchangedMultiInferMeta
param : [x]
kernel :
func : concat_grad
composite : concat_grad(x, out_grad, axis, x_grad)
no_need_buffer : x
backward : concat_double_grad
- backward_op : conv2d_transpose_double_grad
forward : conv2d_transpose_grad(Tensor x, Tensor filter, Tensor grad_out, int[] strides, int[] paddings, int[] output_padding, IntArray output_size, str padding_algorithm, int groups, int[] dilations, str data_format) -> Tensor(grad_x), Tensor(grad_filter)
args : (Tensor x, Tensor filter, Tensor grad_out, Tensor grad_x_grad, Tensor grad_filter_grad, int[] strides, int[] paddings, int[] output_padding, IntArray output_size, str padding_algorithm, int groups, int[] dilations, str data_format)
@@ -143,16 +143,6 @@
func : channel_shuffle
backward : channel_shuffle_grad
- op : concat
args : (Tensor[] x, Scalar(int64_t) axis)
output : Tensor
infer_meta :
func : ConcatInferMeta
param : [x, axis]
kernel :
func : concat
backward : concat_grad
- op : conv2d_transpose
args : (Tensor x, Tensor filter, int[] strides={1, 1}, int[] paddings={0, 0}, int[] output_padding={}, IntArray output_size={}, str padding_algorithm="EXPLICIT", int groups=1, int[] dilations={1, 1}, str data_format="NCHW")
output : Tensor(out)
@@ -509,7 +509,7 @@
out : Out
- op : concat
backward : concat_grad
backward : concat_grad, concat_double_grad
inputs:
x: X
outputs:
@@ -520,8 +520,11 @@
axis :
data_type : int
tensor_name : AxisTensor
drop_empty_grad : [x_grad]
extra :
attrs : [bool use_mkldnn = false, bool use_quantizer = false, str mkldnn_data_type = "float32"]
get_expected_kernel_type :
concat : GetConcatExpectedKernelType
- op : conditional_block
backward : conditional_block_grad
@@ -474,6 +474,17 @@
data_type : real
backward : complex_grad
- op : concat
args : (Tensor[] x, Scalar axis=0)
output : Tensor
infer_meta :
func : ConcatInferMeta
param : [x, axis]
kernel :
func : concat
data_type : x
backward : concat_grad
- op : conj
args : (Tensor x)
output : Tensor (out)
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature ConcatOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasInput("AxisTensor")) {
return KernelSignature("concat", {"X"}, {"AxisTensor"}, {"Out"});
}
return KernelSignature("concat", {"X"}, {"axis"}, {"Out"});
}
KernelSignature ConcatGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasInput("AxisTensor")) {
return KernelSignature(
"concat_grad", {"X", "Out@GRAD"}, {"AxisTensor"}, {"X@GRAD"});
}
return KernelSignature(
"concat_grad", {"X", "Out@GRAD"}, {"axis"}, {"X@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(concat, phi::ConcatOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(concat_grad, phi::ConcatGradOpArgumentMapping);
@@ -173,7 +173,6 @@ class TestRegiterSupportTensorInOpMaker(unittest.TestCase):
self.support_tensor_attrs = {
'dropout': ['dropout_prob'],
'tile': ['repeat_times'],
'concat': ['axis'],
}
# Just add a op example to test not support tensor
self.not_support_tensor_attrs = {'svd': ['full_matrices']}