Unverified commit 91a3d159, authored by HappyHeavyRain, committed by GitHub

Support 'complex promote' in yaml (#50611)

* support 'complex promote' in yaml

* change the complex_promote

* change 'kron' in math.py

* change 'kron' comment in python

* change kron comment in python

* change kron comment in python
Parent ff4ec23a
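For context, the new `complex_promote : [X, Y]` entry (added to the op compat yaml further down in this diff) asks the op generator to emit the same kernel-key logic as the handwritten `KronOp` shown in full below: `IndicateOrPromoteVarDataTypes` promotes the kernel dtype to complex when either input is complex. A minimal sketch of the intended user-visible behavior follows; whether the mixed-dtype call promotes the same way in dygraph mode is an assumption here, not something this diff shows:

```python
import paddle

x = paddle.to_tensor([[1.0, 2.0], [3.0, 4.0]], dtype="float32")
y = paddle.to_tensor([1 + 1j, 2 - 1j], dtype="complex64")

# With complex_promote : [X, Y], the generated GetExpectedKernelType promotes
# the kernel dtype to complex64 because one operand is complex.
out = paddle.kron(x, y)
print(out.dtype)  # expected: paddle.complex64
print(out.shape)  # [2, 4] -- y is treated as shape [1, 2], so [2*1, 2*2]
```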
@@ -293,6 +293,11 @@ def add_compat_name(op_fluid_map_list, forward_op_dict, backward_op_dict):
if new_op_name != op_name:
forward_op_item['op_name'] = op_name
# add complex promote information
if "complex_promote" in op_args:
forward_op_item["complex_promote"] = op_args["complex_promote"]
if has_backward:
backward_op_item["complex_promote"] = op_args["complex_promote"]
scalar_configs = None
int_array_configs = None
if 'scalar' in op_args:
......
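A toy, self-contained illustration of what the snippet above does with an op compat entry that carries `complex_promote` (the dict names mirror the diff; the surrounding generator machinery is omitted and the literal values are hypothetical):

```python
# Simplified: attach the complex_promote list from op_compat to both the
# forward and backward op descriptions, as the generator snippet above does.
op_args = {"complex_promote": ["X", "Y"]}       # parsed from the compat yaml
forward_op_item = {"op_name": "kron"}
backward_op_item = {"op_name": "kron_grad"}
has_backward = True

if "complex_promote" in op_args:
    forward_op_item["complex_promote"] = op_args["complex_promote"]
    if has_backward:
        backward_op_item["complex_promote"] = op_args["complex_promote"]

print(forward_op_item)   # {'op_name': 'kron', 'complex_promote': ['X', 'Y']}
print(backward_op_item)  # {'op_name': 'kron_grad', 'complex_promote': ['X', 'Y']}
```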
@@ -279,30 +279,52 @@ phi::KernelKey GetExpectedKernelType(
data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, {{data_type_args[1] | to_opmaker_name}});
}
{% endif %}
{% elif "complex_promote" in op and "forward" not in op%}
{% set inputs = op["complex_promote"]%}
auto data_type =
OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "{{inputs[0]}}", "{{inputs[1]}}");
{% endif %}
return phi::KernelKey(data_type, ctx.GetPlace());
}
{% endmacro %}
{% endmacro -%}
{% macro get_kernel_for_var(op) %}
{% set skip_args = none %}
{% if op["data_transform"] is not none%}
{% if "skip_transform" in op["data_transform"] %}
{% set skip_args = op["data_transform"]["skip_transform"] %}
{% elif "support_trans_dtype" in op["data_transform"] %}
{% set skip_args = op["data_transform"]["support_trans_dtype"] %}
{% endif %}
{% endif %}
{% set var_name = "var_name" -%}
{% macro get_kernel_for_var(op) %} {# only for data_transform #}
{% set skip_args = op["data_transform"]["skip_transform"] %}
{% set var_name = "var_name" %}
{% set skip_args_len = skip_args | length %}
phi::KernelKey GetKernelTypeForVar(
const std::string& {{var_name}},
const phi::DenseTensor& tensor,
const phi::KernelKey& expected_kernel_type) const override {
{% if skip_args is not none %}{# handle data_transform #}
{% set skip_args_len = skip_args | length %}
if (
{%- for skip_arg in skip_args -%}
var_name == "{{ skip_arg }}"
{%- if skip_args_len != 1 and loop.index != skip_args_len %} || {% endif -%}
{%- endfor -%}
){
{% if "skip_transform" in op["data_transform"] %}
return phi::KernelKey(phi::Backend::ALL_BACKEND,
expected_kernel_type.layout(),
expected_kernel_type.dtype());
{% elif "support_trans_dtype" in op["data_transform"] %}
return phi::KernelKey(tensor.place(), tensor.layout(), tensor.dtype());
{% endif %}
}
{% else %}{# handle complex_promote #}
if (framework::IsComplexType(expected_kernel_type.dtype())) {
// only promote the inputs' dtypes when a complex input is present
return phi::KernelKey(tensor.place(), tensor.layout(), tensor.dtype());
}
{% endif %}
else{
return phi::KernelKey(
tensor.place(), tensor.layout(), expected_kernel_type.dtype());
@@ -317,20 +339,23 @@ class {{op["op_name"] | to_pascal_case}}Op : public framework::OperatorWithKerne
using framework::OperatorWithKernel::OperatorWithKernel;
{# ----------- get expected kernel type function -------------------------- #}
{% set kernel = op["kernel"] %}
{% if kernel["data_type"] is not none %}
{% if kernel["data_type"] is not none or "complex_promote" in op or "data_transform" in op%}
protected:
{% filter indent(2, True)%}
{% if kernel["data_type"] is not none or "complex_promote" in op %}
{% filter indent(2, True)%}
{{get_expected_kernel(op)}}
{% endfilter %}
{%- if "data_transform" in op and op["data_transform"] is not none -%}
{%- if "skip_transform" in op["data_transform"] -%}
{% filter indent(2, True) %}
{% endfilter %}
{% endif %}
{% endif %}
{%- if "data_transform" in op and op["data_transform"] is not none -%}
{% filter indent(2, True) %}
{{get_kernel_for_var(op)}}
{% endfilter %}
{%- elif "complex_promote" in op and op["complex_promote"] is not none -%}
{% filter indent(2, True) %}
{{get_kernel_for_var(op)}}
{% endfilter %}
{%- endif %}
{%- endif -%}
{# TODO(lizhiyu): add the 'support_trans_dtype' #}
{% endif %}
};
DECLARE_INFER_SHAPE_FUNCTOR({{op["op_name"]}}, {{op["op_name"] | to_pascal_case}}InferShapeFunctor,
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle {
namespace operators {
class KronOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type =
OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "X", "Y");
return phi::KernelKey(data_type, ctx.GetPlace());
}
phi::KernelKey GetKernelTypeForVar(
const std::string& var_name,
const phi::DenseTensor& tensor,
const phi::KernelKey& expected_kernel_type) const override {
if (framework::IsComplexType(expected_kernel_type.dtype())) {
// only promote the inputs' dtypes when a complex input is present
return phi::KernelKey(tensor.place(), tensor.layout(), tensor.dtype());
} else {
return phi::KernelKey(
tensor.place(), tensor.layout(), expected_kernel_type.dtype());
}
}
};
class KronOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), the first operand of kron op");
AddInput("Y", "(Tensor), the second operand of kron op");
AddOutput("Out", "(Tensor), the output of kron op.");
AddComment(R"DOC(
Kron Operator.
This operator computes the Kronecker product of two tensors, a
composite tensor made of blocks of the second tensor scaled by the
first.
This operator assumes that the ranks of the two tensors, $X$ and $Y$, are
the same, prepending ones to the shape of the smaller one if necessary. If the
shape of $X$ is [$r_0$, $r_1$, ..., $r_N$] and the shape of $Y$ is
[$s_0$, $s_1$, ..., $s_N$], then the shape of the output tensor is
[$r_{0}s_{0}$, $r_{1}s_{1}$, ..., $r_{N}s_{N}$]. The elements are
products of elements from $X$ and $Y$.
The equation is:
$$
output[k_{0}, k_{1}, ..., k_{N}] = X[i_{0}, i_{1}, ..., i_{N}] *
Y[j_{0}, j_{1}, ..., j_{N}]
$$
where
$$
k_{t} = i_{t} * s_{t} + j_{t}, t = 0, 1, ..., N
$$
)DOC");
}
};
class KronGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "kron_grad");
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "kron_grad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")),
"Input",
framework::GradVarName("Out"),
"kron_grad");
auto x_grad_name = framework::GradVarName("X");
auto y_grad_name = framework::GradVarName("Y");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X"));
}
if (ctx->HasOutput(y_grad_name)) {
ctx->SetOutputDim(y_grad_name, ctx->GetInputDim("Y"));
}
}
protected:
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto out_grad_name = framework::GradVarName("Out");
return phi::KernelKey(
OperatorWithKernel::IndicateVarDataType(ctx, out_grad_name),
ctx.GetPlace());
}
phi::KernelKey GetKernelTypeForVar(
const std::string& var_name,
const phi::DenseTensor& tensor,
const phi::KernelKey& expected_kernel_type) const override {
if (framework::IsComplexType(expected_kernel_type.dtype())) {
// only promote the inputs' dtypes when a complex input is present
return phi::KernelKey(tensor.place(), tensor.layout(), tensor.dtype());
} else {
return phi::KernelKey(
tensor.place(), tensor.layout(), expected_kernel_type.dtype());
}
}
};
template <typename T>
class KronGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("kron_grad");
grad_op->SetInput("X", this->Input("X"));
grad_op->SetInput("Y", this->Input("Y"));
grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
grad_op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
grad_op->SetAttrMap(this->Attrs());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(kron,
KronInferShapeFunctor,
PD_INFER_META(phi::KronInferMeta));
REGISTER_OPERATOR(kron,
ops::KronOp,
ops::KronOpMaker,
ops::KronGradOpMaker<paddle::framework::OpDesc>,
ops::KronGradOpMaker<paddle::imperative::OpBase>,
KronInferShapeFunctor);
REGISTER_OPERATOR(kron_grad, ops::KronGradOp);
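The KronOp maker comment above states the output-shape rule and the index equation $k_{t} = i_{t} s_{t} + j_{t}$. As a quick standalone check of that equation (illustrative only, comparing against NumPy's reference Kronecker product rather than the Paddle kernel):

```python
import numpy as np

x = np.arange(6).reshape(2, 3)    # shape [r0, r1] = [2, 3]
y = np.arange(20).reshape(4, 5)   # shape [s0, s1] = [4, 5]

# Output shape is [r0*s0, r1*s1]; fill it using k_t = i_t * s_t + j_t per axis.
out = np.zeros((2 * 4, 3 * 5), dtype=x.dtype)
for i0 in range(2):
    for i1 in range(3):
        for j0 in range(4):
            for j1 in range(5):
                out[i0 * 4 + j0, i1 * 5 + j1] = x[i0, i1] * y[j0, j1]

assert np.array_equal(out, np.kron(x, y))  # agrees with the reference result
print(out.shape)  # (8, 15)
```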
@@ -663,6 +663,17 @@
kernel :
func : inverse_grad
- backward_op : kron_grad
forward : kron (Tensor x, Tensor y) -> Tensor(out)
args : (Tensor x, Tensor y, Tensor out_grad)
output : Tensor(x_grad), Tensor(y_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
param : [x, y]
kernel :
func : kron_grad
data_type : out_grad
- backward_op : kthvalue_grad
forward : kthvalue(Tensor x, int k, int axis, bool keepdim) -> Tensor(out), Tensor(indices)
args : (Tensor x, Tensor indices, Tensor out_grad, int k, int axis, bool keepdim)
......
@@ -636,17 +636,6 @@
func : kldiv_loss_grad
no_need_buffer : x
- backward_op : kron_grad
forward : kron (Tensor x, Tensor y) -> Tensor(out)
args : (Tensor x, Tensor y, Tensor out_grad)
output : Tensor(x_grad), Tensor(y_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
param : [x, y]
kernel :
func : kron_grad
data_type : out_grad
- backward_op : layer_norm_grad
forward : layer_norm (Tensor x, Tensor scale, Tensor bias, float epsilon, int begin_norm_axis) -> Tensor(out), Tensor(mean), Tensor(variance)
args : (Tensor x, Tensor scale, Tensor bias, Tensor mean, Tensor variance, Tensor out_grad, float epsilon, int begin_norm_axis)
......
@@ -921,15 +921,6 @@
data_type : x
backward : kldiv_loss_grad
- op : kron
args : (Tensor x, Tensor y)
output : Tensor
infer_meta :
func : KronInferMeta
kernel :
func : kron
backward : kron_grad
- op : lamb_
args : (Tensor param, Tensor grad, Tensor learning_rate, Tensor moment1, Tensor moment2, Tensor beta1_pow, Tensor beta2_pow, Tensor master_param, Tensor skip_update, float weight_decay, float beta1, float beta2, float epsilon, bool multi_precision)
output : Tensor(param_out), Tensor(moment1_out), Tensor(moment2_out), Tensor(beta1_pow_out), Tensor(beta2_pow_out), Tensor(master_param_outs)
......
@@ -828,6 +828,14 @@
outputs :
out : Out
- op : kron
backward : kron_grad
inputs :
{x : X, y : Y}
outputs :
{out : Out}
complex_promote : [X, Y]
- op : kthvalue
inputs :
x : X
......
@@ -653,6 +653,15 @@
func : isnan {dense -> dense},
isnan_sr {selected_rows -> selected_rows}
- op : kron
args : (Tensor x, Tensor y)
output : Tensor
infer_meta :
func : KronInferMeta
kernel :
func : kron
backward : kron_grad
- op : kthvalue
args : (Tensor x, int k = 1, int axis = -1, bool keepdim = false)
output : Tensor(out), Tensor(indices)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature KronGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"kron_grad", {"X", "Y", "Out@GRAD"}, {}, {"X@GRAD", "Y@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(kron_grad, phi::KronGradOpArgumentMapping);
@@ -3090,11 +3090,26 @@ def diagonal(x, offset=0, axis1=0, axis2=1, name=None):
return out
@templatedoc(op_type="kron")
def kron(x, y, name=None):
"""
${comment}
r"""
Compute the Kronecker product of two tensors, a
composite tensor made of blocks of the second tensor scaled by the
first.
Assume that the ranks of the two tensors, $X$ and $Y$, are
the same, prepending ones to the shape of the smaller one if necessary. If the
shape of $X$ is [$r_0$, $r_1$, ..., $r_N$] and the shape of $Y$ is
[$s_0$, $s_1$, ..., $s_N$], then the shape of the output tensor is
[$r_{0}s_{0}$, $r_{1}s_{1}$, ..., $r_{N}s_{N}$]. The elements are
products of elements from $X$ and $Y$.
The equation is:
$$
output[k_{0}, k_{1}, ..., k_{N}] = X[i_{0}, i_{1}, ..., i_{N}] *
Y[j_{0}, j_{1}, ..., j_{N}]
$$
where
$$
k_{t} = i_{t} * s_{t} + j_{t}, t = 0, 1, ..., N
$$
Args:
x (Tensor): the first operand of kron op, data type: float16, float32, float64, int32 or int64.
......
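The rewritten docstring is truncated above at the Args section. A short usage sketch consistent with it, with the expected block structure worked out by hand:

```python
import paddle

x = paddle.to_tensor([[1, 2], [3, 4]], dtype="int64")
y = paddle.to_tensor([[1, 10], [10, 100]], dtype="int64")

out = paddle.kron(x, y)
# Each 2x2 block of the result is x[i, j] * y:
# [[ 1,  10,   2,  20],
#  [10, 100,  20, 200],
#  [ 3,  30,   4,  40],
#  [30, 300,  40, 400]]
print(out)
```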