From af96d1e8b015b17e4223b5ee66a1ed2750951941 Mon Sep 17 00:00:00 2001 From: lzydev Date: Mon, 3 Jul 2023 15:33:39 +0800 Subject: [PATCH] Support auto-gen concat (#54970) * support auto-gen concat * fix bug in legacy_backward.yaml * fix bug in get_expeceted_kernel_type --- .../generator/codegen_utils.py | 32 +-- .../generator/eager_gen.py | 40 +++- paddle/fluid/operators/concat_op.cc | 224 ------------------ paddle/fluid/operators/concat_op.h | 47 ---- .../fluid/operators/generator/generate_op.py | 25 +- .../generator/get_expected_kernel_func.cc | 21 ++ .../generator/get_expected_kernel_func.h | 4 + .../fluid/operators/generator/parse_utils.py | 2 +- paddle/phi/api/yaml/backward.yaml | 20 ++ paddle/phi/api/yaml/legacy_backward.yaml | 19 -- paddle/phi/api/yaml/legacy_ops.yaml | 10 - paddle/phi/api/yaml/op_compat.yaml | 5 +- paddle/phi/api/yaml/ops.yaml | 11 + paddle/phi/ops/compat/concat_sig.cc | 38 --- test/legacy_test/test_attribute_var.py | 1 - 15 files changed, 128 insertions(+), 371 deletions(-) delete mode 100644 paddle/fluid/operators/concat_op.cc delete mode 100644 paddle/fluid/operators/concat_op.h delete mode 100644 paddle/phi/ops/compat/concat_sig.cc diff --git a/paddle/fluid/eager/auto_code_generator/generator/codegen_utils.py b/paddle/fluid/eager/auto_code_generator/generator/codegen_utils.py index 0ec006555e4..fd19005cec3 100644 --- a/paddle/fluid/eager/auto_code_generator/generator/codegen_utils.py +++ b/paddle/fluid/eager/auto_code_generator/generator/codegen_utils.py @@ -130,18 +130,21 @@ def ReadFwdFile(filepath): return contents if contents is not None else [] -def ReadBwdFile(filepath): +def ReadBwdFile(filepath, bw_ops=None): f = open(filepath, 'r') - contents = yaml.load(f, Loader=yaml.FullLoader) - # not all fused ops supoort dygraph - if filepath.endswith("fused_backward.yaml") is True: - new_apis = [ - api - for api in contents - if "support_dygraph_mode" in api - and api["support_dygraph_mode"] is True - ] - contents = new_apis + if bw_ops is None: + contents = yaml.load(f, Loader=yaml.FullLoader) + # not all fused ops supoort dygraph + if filepath.endswith("fused_backward.yaml") is True: + new_apis = [ + api + for api in contents + if "support_dygraph_mode" in api + and api["support_dygraph_mode"] is True + ] + contents = new_apis + else: + contents = bw_ops ret = {} if contents is not None: @@ -595,15 +598,16 @@ class FunctionGeneratorBase: class GeneratorBase: - def __init__(self, api_yaml_path): + def __init__(self, api_yaml_path, fw_ops=None): self.namespace = "" self.api_yaml_path = api_yaml_path - self.forward_api_list = [] + self.forward_api_list = fw_ops def ParseForwardYamlContents(self): api_yaml_path = self.api_yaml_path - self.forward_api_list = ReadFwdFile(api_yaml_path) + if self.forward_api_list is None: + self.forward_api_list = ReadFwdFile(api_yaml_path) def InferNameSpace(self): api_yaml_path = self.api_yaml_path diff --git a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py index b3d8db0e100..a90f73c8209 100644 --- a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py @@ -16,6 +16,7 @@ import argparse import os import re +import yaml from codegen_utils import ( AssertMessage, FindForwardName, @@ -2552,14 +2553,17 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase): class DygraphForwardAndNodesGenerator(GeneratorBase): - def __init__(self, api_yaml_path, backward_yaml_path): + def __init__( + self, api_yaml_path, backward_yaml_path, fw_ops=None, bw_ops=None + ): # Parent members: # self.namespace # self.api_yaml_path # self.forward_api_list - GeneratorBase.__init__(self, api_yaml_path) + GeneratorBase.__init__(self, api_yaml_path, fw_ops) self.backward_yaml_path = backward_yaml_path + self.bw_ops = bw_ops self.grad_api_dict = {} self.forward_declaration_str = "" @@ -2580,7 +2584,7 @@ class DygraphForwardAndNodesGenerator(GeneratorBase): # string api is forward_only, no backward_yaml respectively if backward_yaml_path is not None: - self.grad_api_dict = ReadBwdFile(backward_yaml_path) + self.grad_api_dict = ReadBwdFile(backward_yaml_path, self.bw_ops) def GetBackwardAPIContents(self, forward_api_contents): grad_api_dict = self.grad_api_dict @@ -2747,6 +2751,23 @@ if __name__ == "__main__": forward_declaration_str = "" forward_definition_str = "" + # merge legacy_ops.yaml and ops.yaml, legacy_backward.yaml and backward.yaml + all_ops = [] + all_bw = [] + for api_yaml_path in api_yaml_paths: + if api_yaml_path.endswith("legacy_ops.yaml") or api_yaml_path.endswith( + "/ops.yaml" + ): + with open(api_yaml_path, 'r') as f: + all_ops += yaml.safe_load(f) + + for bw_yaml_path in backward_yaml_paths: + if bw_yaml_path.endswith( + "legacy_backward.yaml" + ) or bw_yaml_path.endswith("/backward.yaml"): + with open(bw_yaml_path, 'r') as f: + all_bw += yaml.safe_load(f) + for i in range(len(api_yaml_paths)): api_yaml_path = api_yaml_paths[i] @@ -2756,9 +2777,16 @@ if __name__ == "__main__": else: backward_yaml_path = None - generator = DygraphForwardAndNodesGenerator( - api_yaml_path, backward_yaml_path - ) + if api_yaml_path.endswith('/legacy_ops.yaml'): + continue + if api_yaml_path.endswith('/ops.yaml'): + generator = DygraphForwardAndNodesGenerator( + api_yaml_path, backward_yaml_path, all_ops, all_bw + ) + else: + generator = DygraphForwardAndNodesGenerator( + api_yaml_path, backward_yaml_path + ) generator.run() node_declaration_str += generator.node_declaration_str + "\n" diff --git a/paddle/fluid/operators/concat_op.cc b/paddle/fluid/operators/concat_op.cc deleted file mode 100644 index ec3f0c76df2..00000000000 --- a/paddle/fluid/operators/concat_op.cc +++ /dev/null @@ -1,224 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/concat_op.h" - -#include - -#include -#include -#include - -#include "paddle/fluid/framework/infershape_utils.h" -#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h" -#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h" -#include "paddle/fluid/prim/utils/static/desc_tensor.h" -#include "paddle/phi/infermeta/multiary.h" -#include "paddle/phi/kernels/funcs/concat_funcs.h" - -namespace paddle { -namespace operators { - -class ConcatOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - phi::KernelKey GetExpectedKernelType( - const framework::ExecutionContext &ctx) const override { - auto inputs = ctx.MultiInput("X"); - auto input_data_type = framework::proto::VarType::Type(0); - bool flag = 0; - for (auto *input : inputs) { - if (input->IsInitialized()) { - input_data_type = framework::TransToProtoVarType(input->dtype()); - flag = 1; - break; - } - } - if (flag == 0) { - PADDLE_THROW(platform::errors::InvalidArgument( - "All Inputs of Concat OP are Empty!")); - } - return phi::KernelKey(input_data_type, ctx.GetPlace()); - } - - phi::KernelKey GetKernelTypeForVar( - const std::string &var_name, - const phi::DenseTensor &tensor, - const phi::KernelKey &expected_kernel_type) const override { - if (var_name == "AxisTensor") { - return phi::KernelKey(phi::Backend::ALL_BACKEND, - expected_kernel_type.layout(), - expected_kernel_type.dtype()); - } - return phi::KernelKey( - tensor.place(), tensor.layout(), expected_kernel_type.dtype()); - } -}; - -class ConcatOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", "Input tensors of concat operator.").AsDuplicable(); - AddOutput("Out", "Output tensor of concat operator."); - AddAttr("axis", - "The axis along which the input tensors will be concatenated." - "The axis could also be negative numbers. Negative axis is " - "interpreted as counting from the end of the rank." - "i.e., axis + rank(X) th dimension.") - .SetDefault(0) - .SupportTensor(); - AddInput("AxisTensor", - "(Tensor) The axis along which the input tensors will be " - "concatenated. " - "It has higher priority than Attr(axis). " - "The shape of AxisTensor must be [1].") - .AsDispensable(); - AddComment(R"DOC( -Concat Operator. - -Concatenate the input tensors along dimension axis. -Examples: - Input[0] = [[1,2],[3,4]] - Input[1] = [[5,6]] - axis = 0 - Output = [[1,2], - [3,4], - [5,6]] - -)DOC"); - } -}; - -class ConcatOpGrad : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext *ctx) const override { - auto in_x = "X"; - auto out_x_g_n = framework::GradVarName(in_x); - ctx->SetOutputsDim(out_x_g_n, ctx->GetInputsDim(in_x)); - - ctx->ShareAllLoD(in_x, out_x_g_n); - } - - protected: - phi::KernelKey GetExpectedKernelType( - const framework::ExecutionContext &ctx) const override { - auto input_data_type = OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")); - return phi::KernelKey(input_data_type, ctx.GetPlace()); - } - - phi::KernelKey GetKernelTypeForVar( - const std::string &var_name, - const phi::DenseTensor &tensor, - const phi::KernelKey &expected_kernel_type) const override { - if (var_name == "AxisTensor") { - return phi::KernelKey(phi::Backend::ALL_BACKEND, - expected_kernel_type.layout(), - expected_kernel_type.dtype()); - } - return phi::KernelKey( - tensor.place(), tensor.layout(), expected_kernel_type.dtype()); - } -}; - -DECLARE_NO_NEED_BUFFER_VARS_INFERER(ConcatOpGradNoNeedBufferVarInferer, "X"); - -template -class ConcatGradOpMaker : public framework::SingleGradOpMaker { - public: - using framework::SingleGradOpMaker::SingleGradOpMaker; - - protected: - void Apply(GradOpPtr op) const override { - op->SetType("concat_grad"); - op->SetInput("X", this->Input("X")); - if (this->HasInput("AxisTensor")) { - op->SetInput("AxisTensor", this->Input("AxisTensor")); - } - op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); - op->SetOutput(framework::GradVarName("X"), this->InputGrad("X", false)); - op->SetAttrMap(this->Attrs()); - } -}; - -class ConcatCompositeGradOpMaker : public prim::CompositeGradOpMakerBase { - using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase; - - public: - void Apply() override { - std::vector input = this->GetMultiForwardInput("X"); - paddle::optional tensor_axis = - this->GetOptionalSingleForwardInput("AxisTensor"); - paddle::Tensor out_grad = this->GetSingleOutputGrad("Out"); - std::vector input_grad = this->GetMultiInputGrad("X"); - - std::vector input_grad_ptr; - for (auto i = 0; i < static_cast(input_grad.size()); ++i) { - input_grad_ptr.push_back(&input_grad[i]); - } - int axis = static_cast(this->Attr("axis")); - std::vector dx_ptr = this->GetOutputPtr(input_grad_ptr); - std::vector dx_name = this->GetOutputName(input_grad); - - VLOG(6) << "Runing concat_grad composite func"; - if (tensor_axis.is_initialized()) { - PADDLE_THROW(platform::errors::Unimplemented( - "We don't support dynamic index from tensor for concat composite " - "grad for now. ")); - } else { - prim::concat_grad(input, out_grad, axis, dx_ptr); - } - this->RecoverOutputName(input_grad, dx_name); - } -}; - -template -class ConcatDoubleGradOpMaker : public framework::SingleGradOpMaker { - public: - using framework::SingleGradOpMaker::SingleGradOpMaker; - - protected: - void Apply(GradOpPtr grad_op) const override { - grad_op->SetType("concat"); - grad_op->SetInput("X", this->OutputGrad(framework::GradVarName("X"))); - grad_op->SetOutput("Out", this->InputGrad(framework::GradVarName("Out"))); - grad_op->SetAttrMap(this->Attrs()); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -DECLARE_INFER_SHAPE_FUNCTOR(concat, - ConcatInferShapeFunctor, - PD_INFER_META(phi::ConcatInferMeta)); - -REGISTER_OPERATOR(concat, - ops::ConcatOp, - ops::ConcatOpMaker, - ops::ConcatGradOpMaker, - ops::ConcatGradOpMaker, - ops::ConcatCompositeGradOpMaker, - ConcatInferShapeFunctor); -REGISTER_OPERATOR(concat_grad, - ops::ConcatOpGrad, - ops::ConcatDoubleGradOpMaker, - ops::ConcatDoubleGradOpMaker, - ops::ConcatOpGradNoNeedBufferVarInferer); diff --git a/paddle/fluid/operators/concat_op.h b/paddle/fluid/operators/concat_op.h deleted file mode 100644 index 58d978ea9c7..00000000000 --- a/paddle/fluid/operators/concat_op.h +++ /dev/null @@ -1,47 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include -#include -#include - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/concat_and_split.h" -#include "paddle/fluid/operators/utils.h" -#include "paddle/phi/kernels/concat_kernel.h" -#include "paddle/phi/kernels/funcs/concat_funcs.h" -#include "paddle/phi/kernels/funcs/strided_memcpy.h" - -namespace paddle { -namespace operators { - -static inline int64_t ComputeAxis(int64_t axis, int64_t rank) { - PADDLE_ENFORCE_EQ( - axis >= -rank && axis < rank, - true, - platform::errors::InvalidArgument( - "The axis is expected to be in range of [%d, %d), but got %d", - -rank, - rank, - axis)); - if (axis < 0) { - axis = axis + rank; - } - return axis > 0 ? axis : 0; -} - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/generator/generate_op.py b/paddle/fluid/operators/generator/generate_op.py index fa544f38763..777a9d19ba9 100644 --- a/paddle/fluid/operators/generator/generate_op.py +++ b/paddle/fluid/operators/generator/generate_op.py @@ -479,17 +479,22 @@ def parse_drop_empty_grad(op_fluid_list: list, bw_op_dict: dict): bw_name.split('(')[0].strip() for bw_name in op_comp_map['backward'].split(',') ] - for bw_name in bw_names: - # static_ops.yaml and ops.yaml use the common op_compat.yaml - if bw_name in bw_op_dict: + new_bw_names = [ + bw_name for bw_name in bw_names if bw_name in bw_op_dict + ] + if len(new_bw_names) != 0: + bws_has_out_grad = False + for bw_name in bw_names: + # static_ops.yaml and ops.yaml use the common op_compat.yaml for out_grad in op_comp_map['drop_empty_grad']: - assert ( - out_grad in bw_op_dict[bw_name]['output_dict'] - ), f''' - {bw_name} with {out_grad} is not existed in output_dict ''' - bw_op_dict[bw_name]['output_dict'][out_grad][ - 'drop_empty_grad' - ] = False + if out_grad in bw_op_dict[bw_name]['output_dict']: + bw_op_dict[bw_name]['output_dict'][out_grad][ + 'drop_empty_grad' + ] = False + bws_has_out_grad = True + assert ( + bws_has_out_grad + ), f'''{bw_names} with {op_comp_map['drop_empty_grad']} is not existed in output_dict ''' def parse_get_expected_kerneltype( diff --git a/paddle/fluid/operators/generator/get_expected_kernel_func.cc b/paddle/fluid/operators/generator/get_expected_kernel_func.cc index 155493b632e..aef9434ef0a 100644 --- a/paddle/fluid/operators/generator/get_expected_kernel_func.cc +++ b/paddle/fluid/operators/generator/get_expected_kernel_func.cc @@ -85,6 +85,27 @@ phi::KernelKey GetCheckFiniteAndUnscaleExpectedKernelType( return phi::KernelKey(dtype, ctx.GetPlace()); } +phi::KernelKey GetConcatExpectedKernelType( + const framework::ExecutionContext& ctx, + const framework::OperatorWithKernel* op_ptr) { + (void)op_ptr; + auto inputs = ctx.MultiInput("X"); + auto input_data_type = framework::proto::VarType::Type(0); + bool flag = 0; + for (auto* input : inputs) { + if (input->IsInitialized()) { + input_data_type = framework::TransToProtoVarType(input->dtype()); + flag = 1; + break; + } + } + if (flag == 0) { + PADDLE_THROW(platform::errors::InvalidArgument( + "All Inputs of Concat OP are Empty!")); + } + return phi::KernelKey(input_data_type, ctx.GetPlace()); +} + phi::KernelKey GetReduceExpectedKernelType( const framework::ExecutionContext& ctx, const framework::OperatorWithKernel* op_ptr) { diff --git a/paddle/fluid/operators/generator/get_expected_kernel_func.h b/paddle/fluid/operators/generator/get_expected_kernel_func.h index c1f288c61ca..a8a50413d8f 100644 --- a/paddle/fluid/operators/generator/get_expected_kernel_func.h +++ b/paddle/fluid/operators/generator/get_expected_kernel_func.h @@ -20,6 +20,10 @@ limitations under the License. */ namespace paddle { namespace operators { +phi::KernelKey GetConcatExpectedKernelType( + const framework::ExecutionContext& ctx, + const framework::OperatorWithKernel* op_ptr); + phi::KernelKey GetReduceExpectedKernelType( const framework::ExecutionContext& ctx, const framework::OperatorWithKernel* op_ptr); diff --git a/paddle/fluid/operators/generator/parse_utils.py b/paddle/fluid/operators/generator/parse_utils.py index 8f803fa4e70..ed061f9967a 100644 --- a/paddle/fluid/operators/generator/parse_utils.py +++ b/paddle/fluid/operators/generator/parse_utils.py @@ -289,7 +289,7 @@ def parse_invoke(op_name: str, invoke_config: str) -> Dict[str, Any]: invoke_config = invoke_config.strip() func, rest = invoke_config.split("(", 1) func = func.strip() - args = rest.rstrip(")").strip() + args = rest[:-1].strip() # deal the last ')' invocation = {"func": func, "args": args} return invocation diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index 46969cda6be..781ac69ee2e 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -318,6 +318,26 @@ func : complex_grad data_type : real +- backward_op : concat_double_grad + forward : concat_grad (Tensor[] x, Tensor grad_out, Scalar axis=0) -> Tensor[](grad_x) + args : (Tensor[] grad_x_grad, Scalar axis = 0) + output : Tensor(grad_out_grad) + invoke : concat(grad_x_grad, axis) + +- backward_op : concat_grad + forward : concat (Tensor[] x, Scalar axis=0) -> Tensor(out) + args : (Tensor[] x, Tensor out_grad, Scalar axis = 0) + output : Tensor[](x_grad){x.size()} + infer_meta : + func : UnchangedMultiInferMeta + param : [x] + kernel : + func : concat_grad + data_type : out_grad + composite : concat_grad(x, out_grad, axis, x_grad) + no_need_buffer : x + backward : concat_double_grad + - backward_op : conj_grad forward : conj (Tensor x) -> Tensor(out) args : (Tensor out_grad) diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index b3f50900c36..ff50cda0d47 100755 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -119,25 +119,6 @@ kernel : func : channel_shuffle_grad -- backward_op : concat_double_grad - forward : concat_grad (Tensor[] x, Tensor grad_out, Scalar axis) -> Tensor[](grad_x) - args : (Tensor[] grad_x_grad, Scalar axis = 0) - output : Tensor(grad_out_grad) - invoke : concat(grad_x_grad, axis) - -- backward_op : concat_grad - forward : concat (Tensor[] x, Scalar axis) -> Tensor(out) - args : (Tensor[] x, Tensor out_grad, Scalar axis = 0) - output : Tensor[](x_grad){x.size()} - infer_meta : - func : UnchangedMultiInferMeta - param : [x] - kernel : - func : concat_grad - composite : concat_grad(x, out_grad, axis, x_grad) - no_need_buffer : x - backward : concat_double_grad - - backward_op : conv2d_transpose_double_grad forward : conv2d_transpose_grad(Tensor x, Tensor filter, Tensor grad_out, int[] strides, int[] paddings, int[] output_padding, IntArray output_size, str padding_algorithm, int groups, int[] dilations, str data_format) -> Tensor(grad_x), Tensor(grad_filter) args : (Tensor x, Tensor filter, Tensor grad_out, Tensor grad_x_grad, Tensor grad_filter_grad, int[] strides, int[] paddings, int[] output_padding, IntArray output_size, str padding_algorithm, int groups, int[] dilations, str data_format) diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml index 7d51086456a..bc898e84635 100755 --- a/paddle/phi/api/yaml/legacy_ops.yaml +++ b/paddle/phi/api/yaml/legacy_ops.yaml @@ -143,16 +143,6 @@ func : channel_shuffle backward : channel_shuffle_grad -- op : concat - args : (Tensor[] x, Scalar(int64_t) axis) - output : Tensor - infer_meta : - func : ConcatInferMeta - param : [x, axis] - kernel : - func : concat - backward : concat_grad - - op : conv2d_transpose args : (Tensor x, Tensor filter, int[] strides={1, 1}, int[] paddings={0, 0}, int[] output_padding={}, IntArray output_size={}, str padding_algorithm="EXPLICIT", int groups=1, int[] dilations={1, 1}, str data_format="NCHW") output : Tensor(out) diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index 7e1059ea02f..1424453c841 100755 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -509,7 +509,7 @@ out : Out - op : concat - backward : concat_grad + backward : concat_grad, concat_double_grad inputs: x: X outputs: @@ -520,8 +520,11 @@ axis : data_type : int tensor_name : AxisTensor + drop_empty_grad : [x_grad] extra : attrs : [bool use_mkldnn = false, bool use_quantizer = false, str mkldnn_data_type = "float32"] + get_expected_kernel_type : + concat : GetConcatExpectedKernelType - op : conditional_block backward : conditional_block_grad diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index 2c5d7bc2565..09501b51924 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -474,6 +474,17 @@ data_type : real backward : complex_grad +- op : concat + args : (Tensor[] x, Scalar axis=0) + output : Tensor + infer_meta : + func : ConcatInferMeta + param : [x, axis] + kernel : + func : concat + data_type : x + backward : concat_grad + - op : conj args : (Tensor x) output : Tensor (out) diff --git a/paddle/phi/ops/compat/concat_sig.cc b/paddle/phi/ops/compat/concat_sig.cc deleted file mode 100644 index c5e6a4b0200..00000000000 --- a/paddle/phi/ops/compat/concat_sig.cc +++ /dev/null @@ -1,38 +0,0 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/phi/core/compat/op_utils.h" - -namespace phi { - -KernelSignature ConcatOpArgumentMapping(const ArgumentMappingContext& ctx) { - if (ctx.HasInput("AxisTensor")) { - return KernelSignature("concat", {"X"}, {"AxisTensor"}, {"Out"}); - } - return KernelSignature("concat", {"X"}, {"axis"}, {"Out"}); -} - -KernelSignature ConcatGradOpArgumentMapping(const ArgumentMappingContext& ctx) { - if (ctx.HasInput("AxisTensor")) { - return KernelSignature( - "concat_grad", {"X", "Out@GRAD"}, {"AxisTensor"}, {"X@GRAD"}); - } - return KernelSignature( - "concat_grad", {"X", "Out@GRAD"}, {"axis"}, {"X@GRAD"}); -} - -} // namespace phi - -PD_REGISTER_ARG_MAPPING_FN(concat, phi::ConcatOpArgumentMapping); -PD_REGISTER_ARG_MAPPING_FN(concat_grad, phi::ConcatGradOpArgumentMapping); diff --git a/test/legacy_test/test_attribute_var.py b/test/legacy_test/test_attribute_var.py index 082b1970a46..b74141a10eb 100644 --- a/test/legacy_test/test_attribute_var.py +++ b/test/legacy_test/test_attribute_var.py @@ -173,7 +173,6 @@ class TestRegiterSupportTensorInOpMaker(unittest.TestCase): self.support_tensor_attrs = { 'dropout': ['dropout_prob'], 'tile': ['repeat_times'], - 'concat': ['axis'], } # Just add a op example to test not support tensor self.not_support_tensor_attrs = {'svd': ['full_matrices']} -- GitLab