未验证 提交 d5387de2 编写于 作者: HappyHeavyRain's avatar HappyHeavyRain 提交者: GitHub

Generate static graph code for lerp by yaml (#48322)

* generate static graph code for lerp by yaml, test=develop

* modify the op_compat.yaml of lerp, test=develop

* generate static graph code for lerp by yaml, test=develop

* modify the op_compat.yaml of lerp, test=develop

* remove the 'attrs' of lerp, test=develop
Signed-off-by: HappyHeavyRain's avatarlizhiyu02 <1528794076@qq.com>
Signed-off-by: HappyHeavyRain's avatarlizhiyu02 <1528794076@qq.com>
上级 ed33b860
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/ternary.h"
namespace paddle {
namespace operators {
class LerpOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
};
class LerpOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of lerp op.");
AddInput("Y", "(Tensor), The input tensor of lerp op.");
AddInput("Weight", "(Tensor, optional), The input tensor of lerp op.");
AddOutput("Out", "(Tensor), The output tensor of lerp op.");
AddComment(R"DOC(
Lerp Operator.
This operator is used to do a linear interpolation of input $X$ and $Y$ with $Weight$.
The equation is:
$$Out = X + Weight * (Y - X)$$
Both the input $X$ and $Y$ can carry the LoD (Level of Details) information,
or not. But the output only shares the LoD information with input $X$.
)DOC");
}
};
class LerpGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
if (ctx->HasOutput(framework::GradVarName("Y"))) {
ctx->SetOutputDim(framework::GradVarName("Y"), ctx->GetInputDim("Y"));
}
}
};
template <typename T>
class LerpOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
void Apply(GradOpPtr<T> op) const override {
op->SetType("lerp_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("Y", this->Input("Y"));
op->SetInput("Weight", this->Input("Weight"));
op->SetInput("Out", this->Output("Out"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
op->SetAttrMap(this->Attrs());
}
};
DECLARE_INPLACE_OP_INFERER(LerpInplaceInferer, {"X", "Out"});
} // namespace operators
} // namespace paddle
DECLARE_INFER_SHAPE_FUNCTOR(lerp,
LerpInferShapeFunctor,
PD_INFER_META(phi::LerpInferMeta));
REGISTER_OPERATOR(
lerp,
paddle::operators::LerpOp,
paddle::operators::LerpOpMaker,
paddle::operators::LerpOpGradMaker<paddle::framework::OpDesc>,
paddle::operators::LerpOpGradMaker<paddle::imperative::OpBase>,
paddle::operators::LerpInplaceInferer,
LerpInferShapeFunctor);
REGISTER_OPERATOR(lerp_grad, paddle::operators::LerpGradOp);
...@@ -588,6 +588,16 @@ ...@@ -588,6 +588,16 @@
backward : leaky_relu_double_grad backward : leaky_relu_double_grad
inplace : (out_grad -> x_grad) inplace : (out_grad -> x_grad)
- backward_op : lerp_grad
forward : lerp (Tensor x, Tensor y, Tensor weight) -> Tensor(out)
args : (Tensor x, Tensor y, Tensor weight, Tensor out, Tensor out_grad)
output : Tensor(x_grad), Tensor(y_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
param : [x, y]
kernel :
func : lerp_grad
- backward_op : lgamma_grad - backward_op : lgamma_grad
forward : lgamma(Tensor x) -> Tensor(out) forward : lgamma(Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad) args : (Tensor x, Tensor out_grad)
......
...@@ -749,16 +749,6 @@ ...@@ -749,16 +749,6 @@
no_need_buffer : bias no_need_buffer : bias
optional : scale, bias optional : scale, bias
- backward_op : lerp_grad
forward : lerp (Tensor x, Tensor y, Tensor weight) -> Tensor(out)
args : (Tensor x, Tensor y, Tensor weight, Tensor out, Tensor out_grad)
output : Tensor(x_grad), Tensor(y_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
param : [x, y]
kernel :
func : lerp_grad
- backward_op : linear_interp_grad - backward_op : linear_interp_grad
forward : linear_interp (Tensor x, Tensor out_size, Tensor[] size_tensor, Tensor scale_tensor, str data_layout, int out_d, int out_h, int out_w, float[] scale, str interp_method, bool align_corners, int align_mode) -> Tensor(output) forward : linear_interp (Tensor x, Tensor out_size, Tensor[] size_tensor, Tensor scale_tensor, str data_layout, int out_d, int out_h, int out_w, float[] scale, str interp_method, bool align_corners, int align_mode) -> Tensor(output)
args : (Tensor x, Tensor out_size, Tensor[] size_tensor, Tensor scale_tensor, Tensor output_grad, str data_layout, int out_d, int out_h, int out_w, float[] scale, str interp_method, bool align_corners, int align_mode) args : (Tensor x, Tensor out_size, Tensor[] size_tensor, Tensor scale_tensor, Tensor output_grad, str data_layout, int out_d, int out_h, int out_w, float[] scale, str interp_method, bool align_corners, int align_mode)
......
...@@ -1039,16 +1039,6 @@ ...@@ -1039,16 +1039,6 @@
backward : layer_norm_grad backward : layer_norm_grad
optional : scale, bias optional : scale, bias
- op : lerp
args : (Tensor x, Tensor y, Tensor weight)
output : Tensor(out)
infer_meta :
func : LerpInferMeta
kernel :
func : lerp
inplace : (x -> out)
backward : lerp_grad
- op : less_equal - op : less_equal
args : (Tensor x, Tensor y) args : (Tensor x, Tensor y)
output : Tensor(out) output : Tensor(out)
......
...@@ -680,6 +680,13 @@ ...@@ -680,6 +680,13 @@
extra : extra :
attrs : [bool use_mkldnn = false] attrs : [bool use_mkldnn = false]
- op : lerp
backward : lerp_grad
inputs :
{x : X, y : Y, weight : Weight}
outputs :
out : Out
- op : lgamma - op : lgamma
inputs : inputs :
x : X x : X
......
...@@ -563,6 +563,16 @@ ...@@ -563,6 +563,16 @@
func : leaky_relu func : leaky_relu
backward : leaky_relu_grad backward : leaky_relu_grad
- op : lerp
args : (Tensor x, Tensor y, Tensor weight)
output : Tensor(out)
infer_meta :
func : LerpInferMeta
kernel :
func : lerp
inplace : (x -> out)
backward : lerp_grad
- op : lgamma - op : lgamma
args : (Tensor x) args : (Tensor x)
output : Tensor(out) output : Tensor(out)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature LerpOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("lerp", {"X", "Y", "Weight"}, {}, {"Out"});
}
KernelSignature LerpGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("lerp_grad",
{"X", "Y", "Weight", "Out", "Out@GRAD"},
{},
{"X@GRAD", "Y@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(lerp, phi::LerpOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(lerp_grad, phi::LerpGradOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册