From 91a3d159fd32c4477259a959a0e025fdbdecbb35 Mon Sep 17 00:00:00 2001 From: lzydev <1528794076@qq.com> Date: Thu, 23 Feb 2023 14:53:49 +0800 Subject: [PATCH] Support 'complex promote' in yaml (#50611) * support 'complex promote' in yaml * change the complex_promote * change 'kron' in math.py * change 'kron' comment in python * change kron comment in python * change kron comment in python --- .../fluid/operators/generator/generate_op.py | 5 + .../generator/templates/operator_utils.c.j2 | 55 ++++-- paddle/fluid/operators/kron_op.cc | 166 ------------------ paddle/phi/api/yaml/backward.yaml | 11 ++ paddle/phi/api/yaml/legacy_backward.yaml | 11 -- paddle/phi/api/yaml/legacy_ops.yaml | 9 - paddle/phi/api/yaml/op_compat.yaml | 8 + paddle/phi/api/yaml/ops.yaml | 9 + paddle/phi/ops/compat/kron_sig.cc | 26 --- python/paddle/tensor/math.py | 23 ++- 10 files changed, 92 insertions(+), 231 deletions(-) delete mode 100644 paddle/fluid/operators/kron_op.cc delete mode 100644 paddle/phi/ops/compat/kron_sig.cc diff --git a/paddle/fluid/operators/generator/generate_op.py b/paddle/fluid/operators/generator/generate_op.py index a3a85946022..4bf70d70586 100644 --- a/paddle/fluid/operators/generator/generate_op.py +++ b/paddle/fluid/operators/generator/generate_op.py @@ -293,6 +293,11 @@ def add_compat_name(op_fluid_map_list, forward_op_dict, backward_op_dict): if new_op_name != op_name: forward_op_item['op_name'] = op_name + # add complex promote information + if "complex_promote" in op_args: + forward_op_item["complex_promote"] = op_args["complex_promote"] + if has_backward: + backward_op_item["complex_promote"] = op_args["complex_promote"] scalar_configs = None int_array_configs = None if 'scalar' in op_args: diff --git a/paddle/fluid/operators/generator/templates/operator_utils.c.j2 b/paddle/fluid/operators/generator/templates/operator_utils.c.j2 index ac260cd4f64..044f8085065 100644 --- a/paddle/fluid/operators/generator/templates/operator_utils.c.j2 +++ 
b/paddle/fluid/operators/generator/templates/operator_utils.c.j2 @@ -279,30 +279,52 @@ phi::KernelKey GetExpectedKernelType( data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, {{data_type_args[1] | to_opmaker_name}}); } {% endif %} +{% elif "complex_promote" in op and "forward" not in op%} + {% set inputs = op["complex_promote"]%} + auto data_type = + OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "{{inputs[0]}}", "{{inputs[1]}}"); {% endif %} return phi::KernelKey(data_type, ctx.GetPlace()); } -{% endmacro %} +{% endmacro -%} + +{% macro get_kernel_for_var(op) %} +{% set skip_args = none %} +{% if op["data_transform"] is not none%} + {% if "skip_transform" in op["data_transform"] %} + {% set skip_args = op["data_transform"]["skip_transform"] %} + {% elif "support_trans_dtype" in op["data_transform"] %} + {% set skip_args = op["data_transform"]["support_trans_dtype"] %} + {% endif %} +{% endif %} +{% set var_name = "var_name" -%} -{% macro get_kernel_for_var(op) %} {# only for data_transform #} -{% set skip_args = op["data_transform"]["skip_transform"] %} -{% set var_name = "var_name" %} -{% set skip_args_len = skip_args | length %} phi::KernelKey GetKernelTypeForVar( const std::string& {{var_name}}, const phi::DenseTensor& tensor, const phi::KernelKey& expected_kernel_type) const override { - +{%if skip_args is not none%}{# deal data_transform #} + {% set skip_args_len = skip_args | length %} if ( {%- for skip_arg in skip_args -%} var_name == "{{ skip_arg }}" {%- if skip_args_len != 1 and loop.index != skip_args_len %} || {% endif -%} {%- endfor -%} ){ + {% if "skip_transform" in op["data_transform"] %} return phi::KernelKey(phi::Backend::ALL_BACKEND, expected_kernel_type.layout(), expected_kernel_type.dtype()); + {% elif "support_trans_dtype" in op["data_transform"] %} + return phi::KernelKey(tensor.place(), tensor.layout(), tensor.dtype()); + {% endif %} + } +{% else %}{# deal complex_promote #} + if 
(framework::IsComplexType(expected_kernel_type.dtype())) { + // only promote inputs' types when they contain a complex input + return phi::KernelKey(tensor.place(), tensor.layout(), tensor.dtype()); } +{% endif %} else{ return phi::KernelKey( tensor.place(), tensor.layout(), expected_kernel_type.dtype()); @@ -317,20 +339,23 @@ class {{op["op_name"] | to_pascal_case}}Op : public framework::OperatorWithKerne using framework::OperatorWithKernel::OperatorWithKernel; {# ----------- get expected kernel type function -------------------------- #} {% set kernel = op["kernel"] %} - {% if kernel["data_type"] is not none %} + {% if kernel["data_type"] is not none or "complex_promote" in op or "data_transform" in op%} protected: - {% filter indent(2, True)%} + {% if kernel["data_type"] is not none or "complex_promote" in op %} + {% filter indent(2, True)%} {{get_expected_kernel(op)}} - {% endfilter %} - {%- if "data_transform" in op and op["data_transform"] is not none -%} - {%- if "skip_transform" in op["data_transform"] -%} - {% filter indent(2, True) %} + {% endfilter %} + {% endif %} + {% endif %} + {%- if "data_transform" in op and op["data_transform"] is not none -%} + {% filter indent(2, True) %} +{{get_kernel_for_var(op)}} + {% endfilter %} + {%- elif "complex_promote" in op and op["complex_promote"] is not none -%} + {% filter indent(2, True) %} {{get_kernel_for_var(op)}} {% endfilter %} {%- endif %} - {%- endif -%} -{# TODO(lizhiyu): add the 'support_trans_dtype' #} - {% endif %} }; DECLARE_INFER_SHAPE_FUNCTOR({{op["op_name"]}}, {{op["op_name"] | to_pascal_case}}InferShapeFunctor, diff --git a/paddle/fluid/operators/kron_op.cc b/paddle/fluid/operators/kron_op.cc deleted file mode 100644 index 6349ec65a96..00000000000 --- a/paddle/fluid/operators/kron_op.cc +++ /dev/null @@ -1,166 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include -#include -#include -#include - -#include "paddle/fluid/framework/infershape_utils.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/phi/infermeta/binary.h" - -namespace paddle { -namespace operators { - -class KronOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - phi::KernelKey GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - auto data_type = - OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "X", "Y"); - return phi::KernelKey(data_type, ctx.GetPlace()); - } - - phi::KernelKey GetKernelTypeForVar( - const std::string& var_name, - const phi::DenseTensor& tensor, - const phi::KernelKey& expected_kernel_type) const override { - if (framework::IsComplexType(expected_kernel_type.dtype())) { - // only promote inputs’s types when contains complex input - return phi::KernelKey(tensor.place(), tensor.layout(), tensor.dtype()); - } else { - return phi::KernelKey( - tensor.place(), tensor.layout(), expected_kernel_type.dtype()); - } - } -}; - -class KronOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", "(Tensor), the first operand of kron op"); - AddInput("Y", "(Tensor), the second operand of kron op"); - AddOutput("Out", "(Tensor), the output of kron op."); - AddComment(R"DOC( - Kron Operator. 
- - This operator computes the Kronecker product of two tensors, a - composite tensor made of blocks of the second tensor scaled by the - first. - - This operator assumes that the rank of the two tensors, $X$ and $Y$ - are the same, if necessary prepending the smallest with ones. If the - shape of $X$ is [$r_0$, $r_1$, ..., $r_N$] and the shape of $Y$ is - [$s_0$, $s_1$, ..., $s_N$], then the shape of the output tensor is - [$r_{0}s_{0}$, $r_{1}s_{1}$, ..., $r_{N}s_{N}$]. The elements are - products of elements from $X$ and $Y$. - - The equation is: - $$ - output[k_{0}, k_{1}, ..., k_{N}] = X[i_{0}, i_{1}, ..., i_{N}] * - Y[j_{0}, j_{1}, ..., j_{N}] - $$ - - where - $$ - k_{t} = i_{t} * s_{t} + j_{t}, t = 0, 1, ..., N - $$ - )DOC"); - } -}; - -class KronGradOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "kron_grad"); - OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "kron_grad"); - OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), - "Input", - framework::GradVarName("Out"), - "kron_grad"); - - auto x_grad_name = framework::GradVarName("X"); - auto y_grad_name = framework::GradVarName("Y"); - if (ctx->HasOutput(x_grad_name)) { - ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X")); - } - if (ctx->HasOutput(y_grad_name)) { - ctx->SetOutputDim(y_grad_name, ctx->GetInputDim("Y")); - } - } - - protected: - phi::KernelKey GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - auto out_grad_name = framework::GradVarName("Out"); - return phi::KernelKey( - OperatorWithKernel::IndicateVarDataType(ctx, out_grad_name), - ctx.GetPlace()); - } - - phi::KernelKey GetKernelTypeForVar( - const std::string& var_name, - const phi::DenseTensor& tensor, - const phi::KernelKey& expected_kernel_type) const override { - if 
(framework::IsComplexType(expected_kernel_type.dtype())) { - // only promote inputs’s types when contains complex input - return phi::KernelKey(tensor.place(), tensor.layout(), tensor.dtype()); - } else { - return phi::KernelKey( - tensor.place(), tensor.layout(), expected_kernel_type.dtype()); - } - } -}; - -template -class KronGradOpMaker : public framework::SingleGradOpMaker { - public: - using framework::SingleGradOpMaker::SingleGradOpMaker; - - protected: - void Apply(GradOpPtr grad_op) const override { - grad_op->SetType("kron_grad"); - - grad_op->SetInput("X", this->Input("X")); - grad_op->SetInput("Y", this->Input("Y")); - grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); - - grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); - grad_op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y")); - - grad_op->SetAttrMap(this->Attrs()); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -DECLARE_INFER_SHAPE_FUNCTOR(kron, - KronInferShapeFunctor, - PD_INFER_META(phi::KronInferMeta)); -REGISTER_OPERATOR(kron, - ops::KronOp, - ops::KronOpMaker, - ops::KronGradOpMaker, - ops::KronGradOpMaker, - KronInferShapeFunctor); -REGISTER_OPERATOR(kron_grad, ops::KronGradOp); diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index 9cfcfdea899..ee75d281b97 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -663,6 +663,17 @@ kernel : func : inverse_grad +- backward_op : kron_grad + forward : kron (Tensor x, Tensor y) -> Tensor(out) + args : (Tensor x, Tensor y, Tensor out_grad) + output : Tensor(x_grad), Tensor(y_grad) + infer_meta : + func : GeneralBinaryGradInferMeta + param : [x, y] + kernel : + func : kron_grad + data_type : out_grad + - backward_op : kthvalue_grad forward : kthvalue(Tensor x, int k, int axis, bool keepdim) -> Tensor(out), Tensor(indices) args : (Tensor x, Tensor indices, Tensor out_grad, 
int k, int axis, bool keepdim) diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index f13b07db9b2..e27c34ad3bc 100755 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -636,17 +636,6 @@ func : kldiv_loss_grad no_need_buffer : x -- backward_op : kron_grad - forward : kron (Tensor x, Tensor y) -> Tensor(out) - args : (Tensor x, Tensor y, Tensor out_grad) - output : Tensor(x_grad), Tensor(y_grad) - infer_meta : - func : GeneralBinaryGradInferMeta - param : [x, y] - kernel : - func : kron_grad - data_type : out_grad - - backward_op : layer_norm_grad forward : layer_norm (Tensor x, Tensor scale, Tensor bias, float epsilon, int begin_norm_axis) -> Tensor(out), Tensor(mean), Tensor(variance) args : (Tensor x, Tensor scale, Tensor bias, Tensor mean, Tensor variance, Tensor out_grad, float epsilon, int begin_norm_axis) diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml index 2f59c858a03..1d4773764ea 100755 --- a/paddle/phi/api/yaml/legacy_ops.yaml +++ b/paddle/phi/api/yaml/legacy_ops.yaml @@ -921,15 +921,6 @@ data_type : x backward : kldiv_loss_grad -- op : kron - args : (Tensor x, Tensor y) - output : Tensor - infer_meta : - func : KronInferMeta - kernel : - func : kron - backward : kron_grad - - op : lamb_ args : (Tensor param, Tensor grad, Tensor learning_rate, Tensor moment1, Tensor moment2, Tensor beta1_pow, Tensor beta2_pow, Tensor master_param, Tensor skip_update, float weight_decay, float beta1, float beta2, float epsilon, bool multi_precision) output : Tensor(param_out), Tensor(moment1_out), Tensor(moment2_out), Tensor(beta1_pow_out), Tensor(beta2_pow_out), Tensor(master_param_outs) diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index 6a89f0e7d2c..2f2d56bfbee 100644 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -828,6 +828,14 @@ outputs : out : Out +- op : kron 
+ backward : kron_grad + inputs : + {x : X, y : Y} + outputs : + {out : Out} + complex_promote : [X, Y] + - op : kthvalue inputs : x : X diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index b783601e6d0..5928defccd2 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -653,6 +653,15 @@ func : isnan {dense -> dense}, isnan_sr {selected_rows -> selected_rows} +- op : kron + args : (Tensor x, Tensor y) + output : Tensor + infer_meta : + func : KronInferMeta + kernel : + func : kron + backward : kron_grad + - op : kthvalue args : (Tensor x, int k = 1, int axis = -1, bool keepdim = false) output : Tensor(out), Tensor(indices) diff --git a/paddle/phi/ops/compat/kron_sig.cc b/paddle/phi/ops/compat/kron_sig.cc deleted file mode 100644 index e2ba41dcadd..00000000000 --- a/paddle/phi/ops/compat/kron_sig.cc +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "paddle/phi/core/compat/op_utils.h" - -namespace phi { - -KernelSignature KronGradOpArgumentMapping(const ArgumentMappingContext& ctx) { - return KernelSignature( - "kron_grad", {"X", "Y", "Out@GRAD"}, {}, {"X@GRAD", "Y@GRAD"}); -} - -} // namespace phi - -PD_REGISTER_ARG_MAPPING_FN(kron_grad, phi::KronGradOpArgumentMapping); diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 15e4c188244..35418ec7f48 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -3090,11 +3090,26 @@ def diagonal(x, offset=0, axis1=0, axis2=1, name=None): return out -@templatedoc(op_type="kron") def kron(x, y, name=None): - """ - - ${comment} + r""" + Compute the Kronecker product of two tensors, a + composite tensor made of blocks of the second tensor scaled by the + first. + Assume that the rank of the two tensors, $X$ and $Y$ + are the same, if necessary prepending the smallest with ones. If the + shape of $X$ is [$r_0$, $r_1$, ..., $r_N$] and the shape of $Y$ is + [$s_0$, $s_1$, ..., $s_N$], then the shape of the output tensor is + [$r_{0}s_{0}$, $r_{1}s_{1}$, ..., $r_{N}s_{N}$]. The elements are + products of elements from $X$ and $Y$. + The equation is: + $$ + output[k_{0}, k_{1}, ..., k_{N}] = X[i_{0}, i_{1}, ..., i_{N}] * + Y[j_{0}, j_{1}, ..., j_{N}] + $$ + where + $$ + k_{t} = i_{t} * s_{t} + j_{t}, t = 0, 1, ..., N + $$ Args: x (Tensor): the fist operand of kron op, data type: float16, float32, float64, int32 or int64. -- GitLab