Unverified commit 192eb4d5, authored by zyfncg, committed by GitHub

support generate static graph code for imag and real op (#49523)

Parent 017af746
...@@ -187,7 +187,20 @@ def parse_kernel(op_name: str, kernel_config: Dict[str, Any]) -> Dict[str, Any]:
        kernel['layout'] = parse_candidates(kernel_config["layout"])
    if 'data_type' in kernel_config:
        data_type_item = parse_candidates(kernel_config["data_type"])
params_num = len(data_type_item['candidates'])
data_type_item['to_complex_flag'] = [False] * params_num
for i in range(params_num):
complex_match_result = re.match(
r"complex\((?P<param_name>\w+)\)",
data_type_item['candidates'][i],
)
if complex_match_result:
data_type_item['candidates'][i] = complex_match_result.group(
'param_name'
)
data_type_item['to_complex_flag'][i] = True
kernel['data_type'] = data_type_item
    kernel_funcs = re.compile(r'([a-zA-Z0-9_]+)\s*({[^}]+})?').findall(
        kernel_config['func']
......
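Reviewer note: the new `data_type` parsing can be exercised in isolation. Below is a minimal, self-contained sketch of the behavior, with `parse_candidates` replaced by a simplified stand-in (the real helper also handles `>`-separated fallback lists):

```python
import re

# Minimal stand-in for parse_candidates; the real helper also handles
# '>'-separated fallback lists.
def parse_candidates(s):
    return {"candidates": [p.strip() for p in s.split(",")]}

data_type_item = parse_candidates("complex(out_grad)")
params_num = len(data_type_item["candidates"])
data_type_item["to_complex_flag"] = [False] * params_num
for i in range(params_num):
    m = re.match(r"complex\((?P<param_name>\w+)\)",
                 data_type_item["candidates"][i])
    if m:
        # Unwrap complex(...) and remember that codegen must call
        # framework::ToComplexType on this parameter's dtype.
        data_type_item["candidates"][i] = m.group("param_name")
        data_type_item["to_complex_flag"][i] = True

print(data_type_item)
# {'candidates': ['out_grad'], 'to_complex_flag': [True]}
```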
...@@ -262,6 +262,9 @@ phi::KernelKey GetExpectedKernelType(
{% set inputs = op["inputs"] | map(attribute="name") | list %}
{% if data_type_arg in inputs %}
auto data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, {{data_type_arg | to_opmaker_name}});
{% if kernel["data_type"]["to_complex_flag"][0] %}
data_type = framework::ToComplexType(data_type);
{% endif %}
{% else %}{# it is an attribute and probably named dtype #}
auto data_type = framework::proto::VarType::Type(ctx.Attr<int>("{{data_type_arg}}"));
{% endif %}
......
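The `to_complex_flag` computed by the parser is what drives this template branch. A small sketch of the effect, running jinja2 on a stripped-down copy of the branch (illustration only, not the real template file):

```python
from jinja2 import Template

# Stripped-down copy of the template branch above.
snippet = Template(
    'auto data_type = framework::OperatorWithKernel::'
    'IndicateVarDataType(ctx, "Out@GRAD");\n'
    "{% if kernel['data_type']['to_complex_flag'][0] %}"
    "data_type = framework::ToComplexType(data_type);\n"
    "{% endif %}"
)

# Dict shaped like the output of the parse_kernel change above.
kernel = {"data_type": {"candidates": ["out_grad"], "to_complex_flag": [True]}}
print(snippet.render(kernel=kernel))
# auto data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, "Out@GRAD");
# data_type = framework::ToComplexType(data_type);
```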
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
class ImagOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
};
class ImagOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of imag op.");
AddOutput("Out", "(Tensor), The output tensor of imag op.");
AddComment(R"DOC(
Imag Operator.
This operator is used to get a new tensor containing imaginary values
from a tensor with complex data type.
)DOC");
}
};
class ImagGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")),
"Input",
"Out@Grad",
"ImagGrad");
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")),
"Output",
"X@Grad",
"ImagGrad");
auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out"));
ctx->SetOutputDim(framework::GradVarName("X"), dout_dims);
}
protected:
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto dtype = OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
auto complex_dtype = framework::ToComplexType(dtype);
return phi::KernelKey(complex_dtype, ctx.GetPlace());
}
};
template <typename T>
class ImagGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("imag_grad");
grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
}
};
DECLARE_INPLACE_OP_INFERER(ImagOpInplaceInferer, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(ImagGradOpInplaceInferer,
{framework::GradVarName("Out"),
framework::GradVarName("X")});
} // namespace operators
} // namespace paddle
DECLARE_INFER_SHAPE_FUNCTOR(imag,
ImagInferShapeFunctor,
PD_INFER_META(phi::RealAndImagInferMeta));
namespace ops = paddle::operators;
REGISTER_OPERATOR(imag,
ops::ImagOp,
ops::ImagOpMaker,
ops::ImagGradOpMaker<paddle::framework::OpDesc>,
ops::ImagGradOpMaker<paddle::imperative::OpBase>,
ImagInferShapeFunctor);
REGISTER_OPERATOR(imag_grad, ops::ImagGradOp);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
class RealOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
};
class RealOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of real op.");
AddOutput("Out", "(Tensor), The output tensor of real op.");
AddComment(R"DOC(
Real Operator.
This operator is used to get a new tensor containing real values
from a tensor with complex data type.
)DOC");
}
};
class RealGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")),
"Input",
"Out@Grad",
"RealGrad");
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")),
"Output",
"X@Grad",
"RealGrad");
auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out"));
ctx->SetOutputDim(framework::GradVarName("X"), dout_dims);
}
protected:
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto dtype = OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
auto complex_dtype = framework::ToComplexType(dtype);
return phi::KernelKey(complex_dtype, ctx.GetPlace());
}
};
template <typename T>
class RealGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("real_grad");
grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
}
};
DECLARE_INPLACE_OP_INFERER(RealOpInplaceInferer, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(RealGradOpInplaceInferer,
{framework::GradVarName("Out"),
framework::GradVarName("X")});
} // namespace operators
} // namespace paddle
DECLARE_INFER_SHAPE_FUNCTOR(real,
RealInferShapeFunctor,
PD_INFER_META(phi::RealAndImagInferMeta));
namespace ops = paddle::operators;
REGISTER_OPERATOR(real,
ops::RealOp,
ops::RealOpMaker,
ops::RealGradOpMaker<::paddle::framework::OpDesc>,
ops::RealGradOpMaker<::paddle::imperative::OpBase>,
RealInferShapeFunctor);
REGISTER_OPERATOR(real_grad, ops::RealGradOp);
...@@ -141,32 +141,6 @@ Tensor copy_to_impl(const Tensor& x, Place place, bool blocking) {
////////////////// Backward(grad) api impls //////////////////////
void imag_grad_impl(const Tensor& out_grad, Tensor* x_grad) {
phi::KernelKey kernel_key{ParseBackend(out_grad),
out_grad.layout(),
phi::dtype::ToComplex(out_grad.dtype())};
auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"imag_grad", kernel_key);
const auto& kernel = kernel_result.kernel;
VLOG(6) << "imag_grad API kernel key: " << kernel_key;
VLOG(6) << "imag_grad API kernel: " << kernel;
auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
auto dense_out_grad = TensorToDenseTensor(out_grad);
auto kernel_out = SetKernelOutput(x_grad);
phi::MetaTensor meta_out(kernel_out);
phi::RealAndImagGradInferMeta(*dense_out_grad, &meta_out);
using kernel_signature = void (*)(
const phi::DeviceContext&, const phi::DenseTensor&, phi::DenseTensor*);
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)(*dev_ctx, *dense_out_grad, kernel_out);
}
void embedding_grad_impl(const Tensor& x,
                         const Tensor& weight,
                         const Tensor& out_grad,
...@@ -290,31 +264,5 @@ void embedding_grad_impl(const Tensor& x,
  }
}
void real_grad_impl(const Tensor& out_grad, Tensor* x_grad) {
phi::KernelKey kernel_key{ParseBackend(out_grad),
out_grad.layout(),
phi::dtype::ToComplex(out_grad.dtype())};
auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"real_grad", kernel_key);
const auto& kernel = kernel_result.kernel;
VLOG(6) << "real_grad API kernel key: " << kernel_key;
VLOG(6) << "real_grad API kernel: " << kernel;
auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
auto dense_out_grad = TensorToDenseTensor(out_grad);
auto kernel_out = SetKernelOutput(x_grad);
phi::MetaTensor meta_out(kernel_out);
phi::RealAndImagGradInferMeta(*dense_out_grad, &meta_out);
using kernel_signature = void (*)(
const phi::DeviceContext&, const phi::DenseTensor&, phi::DenseTensor*);
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)(*dev_ctx, *dense_out_grad, kernel_out);
}
}  // namespace experimental
}  // namespace paddle
...@@ -563,6 +563,16 @@
  func : hard_sigmoid_grad
  inplace : (out_grad -> x_grad)
- backward_op : imag_grad
forward : imag (Tensor x) -> Tensor(out)
args : (Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : RealAndImagGradInferMeta
kernel :
func : imag_grad
data_type : complex(out_grad)
- backward_op : index_sample_grad
  forward : index_sample (Tensor x, Tensor index) -> Tensor(out)
  args : (Tensor x, Tensor index, Tensor out_grad)
...@@ -868,6 +878,16 @@
  kernel :
    func : qr_grad
- backward_op : real_grad
forward : real (Tensor x) -> Tensor(out)
args : (Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : RealAndImagGradInferMeta
kernel :
func : real_grad
data_type : complex(out_grad)
- backward_op : reciprocal_grad
  forward : reciprocal (Tensor x) -> Tensor(out)
  args : (Tensor out, Tensor out_grad)
......
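For context, `data_type : complex(out_grad)` makes the generated kernel-selection code promote the real-valued dtype of `out_grad` back to its complex counterpart, matching what the deleted hand-written ops did via `framework::ToComplexType`. The user-visible behavior looks roughly like this (a sketch; assumes a Paddle build with complex dtype support):

```python
import paddle

x = paddle.to_tensor([1 + 2j, 3 + 4j], stop_gradient=False)
y = paddle.imag(x)      # real-valued output
y.sum().backward()
print(y.dtype)          # paddle.float64
print(x.grad.dtype)     # paddle.complex128 -- grad kernel keyed on
                        # ToComplex(out_grad.dtype)
```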
...@@ -486,6 +486,17 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d
            )
            if kernel['data_type'] is not None:
def process_data_type_args(args_item):
args_item = args_item.strip()
complex_match_result = re.match(
r"complex\((?P<param_name>\w+)\)", args_item
)
if complex_match_result:
return f"phi::dtype::ToComplex(ParseDataType({complex_match_result.group('param_name')}))"
else:
return f"ParseDataType({args_item})"
                if '>' in kernel['data_type']:
                    vars_list = kernel['data_type'].split('>')
                    assert (
...@@ -511,7 +522,7 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d
                    kernel_select_code = (
                        kernel_select_code
                        + f"""
  kernel_data_type = {process_data_type_args(vars_list[0])};
"""
                    )
......
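The new `process_data_type_args` helper is pure string-to-string, so it is easy to sanity-check standalone (verbatim copy for illustration):

```python
import re

# Verbatim copy of the helper, runnable standalone.
def process_data_type_args(args_item):
    args_item = args_item.strip()
    complex_match_result = re.match(
        r"complex\((?P<param_name>\w+)\)", args_item
    )
    if complex_match_result:
        return f"phi::dtype::ToComplex(ParseDataType({complex_match_result.group('param_name')}))"
    else:
        return f"ParseDataType({args_item})"

print(process_data_type_args("complex(out_grad)"))
# phi::dtype::ToComplex(ParseDataType(out_grad))
print(process_data_type_args("x"))
# ParseDataType(x)
```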
...@@ -335,6 +335,7 @@ def source_include(header_file_path):
#include "paddle/phi/api/lib/api_gen_utils.h"
#include "paddle/phi/api/lib/data_transform.h"
#include "paddle/phi/api/lib/kernel_dispatch.h"
#include "paddle/phi/common/type_traits.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/infermeta/binary.h" #include "paddle/phi/infermeta/binary.h"
#include "paddle/phi/infermeta/multiary.h" #include "paddle/phi/infermeta/multiary.h"
......
...@@ -280,6 +280,7 @@ def source_include(header_file_path):
#include "paddle/phi/api/lib/api_gen_utils.h"
#include "paddle/phi/api/lib/data_transform.h"
#include "paddle/phi/api/lib/kernel_dispatch.h"
#include "paddle/phi/common/type_traits.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/api/include/api.h" #include "paddle/phi/api/include/api.h"
#include "paddle/phi/infermeta/backward.h" #include "paddle/phi/infermeta/backward.h"
......
...@@ -617,12 +617,6 @@
  kernel :
    func : huber_loss_grad
- backward_op : imag_grad
forward : imag (Tensor x) -> Tensor(out)
args : (Tensor out_grad)
output : Tensor(x_grad)
invoke : imag_grad_impl(out_grad, x_grad)
- backward_op : index_add_grad
  forward : index_add(Tensor x, Tensor index, Tensor add_value, int axis) -> Tensor(out)
  args : (Tensor index, Tensor add_value, Tensor out_grad, int axis)
...@@ -1125,12 +1119,6 @@
    data_type : x
  optional : boxes_num
- backward_op : real_grad
forward : real (Tensor x) -> Tensor(out)
args : (Tensor out_grad)
output : Tensor(x_grad)
invoke : real_grad_impl(out_grad, x_grad)
- backward_op : relu6_grad
  forward : relu6 (Tensor x) -> Tensor(out)
  args : (Tensor out, Tensor out_grad, float threshold = 6)
......
...@@ -900,15 +900,6 @@
    func : huber_loss
  backward : huber_loss_grad
- op : imag
args : (Tensor x)
output : Tensor
infer_meta :
func : RealAndImagInferMeta
kernel :
func : imag
backward : imag_grad
- op : increment
  args : (Tensor x, float value = 1.0)
  output : Tensor(out)
...@@ -1507,15 +1498,6 @@
    data_type : dtype
    backend : place
- op : real
args : (Tensor x)
output : Tensor
infer_meta :
func : RealAndImagInferMeta
kernel :
func : real
backward : real_grad
- op : relu6
  args : (Tensor x)
  output : Tensor
......
...@@ -646,6 +646,13 @@
  outputs :
    out : Out
- op : imag
backward : imag_grad
inputs :
x : X
outputs :
out : Out
- op : index_sample
  inputs :
    {x : X, index : Index}
...@@ -997,6 +1004,13 @@
  extra :
    attrs : [float moving_rate = 0.9]
- op : real
backward : real_grad
inputs :
x : X
outputs :
out : Out
- op : reciprocal
  backward : reciprocal_grad
  inputs :
......
...@@ -517,6 +517,15 @@
  kernel :
    func : histogram
- op : imag
args : (Tensor x)
output : Tensor (out)
infer_meta :
func : RealAndImagInferMeta
kernel :
func : imag
backward : imag_grad
- op : index_sample
  args : (Tensor x, Tensor index)
  output : Tensor
...@@ -839,6 +848,15 @@
    func : qr
  backward : qr_grad
- op : real
args : (Tensor x)
output : Tensor (out)
infer_meta :
func : RealAndImagInferMeta
kernel :
func : real
backward : real_grad
- op : reciprocal
  args : (Tensor x)
  output : Tensor(out)
......
...@@ -23,14 +23,18 @@ PD_REGISTER_KERNEL(real_grad,
                   ALL_LAYOUT,
                   phi::RealGradKernel,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>) {
kernel->InputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
}
PD_REGISTER_KERNEL(imag_grad,
                   CPU,
                   ALL_LAYOUT,
                   phi::ImagGradKernel,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>) {
kernel->InputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
}
PD_REGISTER_KERNEL(
    complex_grad, CPU, ALL_LAYOUT, phi::ComplexGradKernel, float, double) {
......
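The new registration body tells the kernel factory that, although `real_grad`/`imag_grad` are registered under complex dtypes, their input `out_grad` actually arrives with the matching real dtype. Conceptually the two conversions are inverses (a plain-Python sketch with hypothetical string names; `phi::dtype::ToReal`/`ToComplex` operate on a DataType enum, not strings):

```python
# Plain-Python sketch of the dtype pairing.
TO_REAL = {"complex64": "float32", "complex128": "float64"}
TO_COMPLEX = {v: k for k, v in TO_REAL.items()}

kernel_key_dtype = "complex64"            # dtype the kernel was selected by
input_dtype = TO_REAL[kernel_key_dtype]   # dtype out_grad actually carries
assert TO_COMPLEX[input_dtype] == kernel_key_dtype
```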
...@@ -23,14 +23,18 @@ PD_REGISTER_KERNEL(imag_grad,
                   ALL_LAYOUT,
                   phi::ImagGradKernel,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>) {
kernel->InputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
}
PD_REGISTER_KERNEL(real_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::RealGradKernel,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>) {
kernel->InputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
}
PD_REGISTER_KERNEL(
    complex_grad, GPU, ALL_LAYOUT, phi::ComplexGradKernel, float, double) {
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature RealGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("real_grad", {"Out@GRAD"}, {}, {"X@GRAD"});
}
KernelSignature ImagGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("imag_grad", {"Out@GRAD"}, {}, {"X@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(real_grad, phi::RealGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(imag_grad, phi::ImagGradOpArgumentMapping);
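These argument-mapping functions bind fluid variable names to phi kernel slots; a KernelSignature is essentially the following record (hypothetical Python rendering for illustration, not the C++ API):

```python
# Hypothetical plain-data view of the two signatures above.
real_grad_sig = {
    "kernel": "real_grad",
    "inputs": ["Out@GRAD"],    # fluid variable bound to the kernel input
    "attrs": [],
    "outputs": ["X@GRAD"],     # fluid variable bound to the kernel output
}
imag_grad_sig = {**real_grad_sig, "kernel": "imag_grad"}
```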