Unverified commit 1804f834, authored by HappyHeavyRain, committed by GitHub

generate the static graph code of some ops (#49212)

* generate the static graph code of some ops

* add the VERSION of pixel_shuffle

* change the API doc of isclose

* change the API doc of isclose

* fix the isclose op comment
Parent 71bde066
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/enforce.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
class FrameOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
const auto in_dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(in_dtype, ctx.GetPlace());
}
};
class FrameOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of frame op.");
AddOutput("Out", "(Tensor), The output tensor of frame op.");
AddAttr<int>(
"frame_length",
"Length of the frame and `0 < frame_length <= x.shape[axis]`.");
AddAttr<int>("hop_length",
"Number of steps to advance between adjacent frames and "
"`0 < hop_length`.");
AddAttr<int>("axis",
"Specify the axis to operate on the input Tensors. Its value "
"should be 0(the first dimension) or -1(the last dimension).")
.SetDefault(-1);
AddComment(R"DOC(
Slice the N-dimensional (where N >= 1) input into (overlapping) frames.
)DOC");
}
};
class FrameOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
const auto in_dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(in_dtype, ctx.GetPlace());
}
};
template <typename T>
class FrameOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
void Apply(GradOpPtr<T> retv) const override {
retv->SetType("frame_grad");
retv->SetInput("X", this->Input("X"));
retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
retv->SetAttrMap(this->Attrs());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(frame,
FrameInferShapeFunctor,
PD_INFER_META(phi::FrameInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(frame_grad,
FrameGradInferShapeFunctor,
PD_INFER_META(phi::UnchangedInferMeta));
REGISTER_OPERATOR(frame,
ops::FrameOp,
ops::FrameOpMaker,
ops::FrameOpGradMaker<paddle::framework::OpDesc>,
ops::FrameOpGradMaker<paddle::imperative::OpBase>,
FrameInferShapeFunctor);
REGISTER_OPERATOR(frame_grad, ops::FrameOpGrad, FrameGradInferShapeFunctor);
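For reference, this operator backs paddle.signal.frame on the Python side. A minimal usage sketch (assuming a PaddlePaddle 2.x environment):

import paddle

x = paddle.arange(8, dtype='float32')  # 8 samples
# frame_length=4, hop_length=2: 4-sample frames, advancing 2 samples per frame
frames = paddle.signal.frame(x, frame_length=4, hop_length=2, axis=-1)
print(frames.shape)  # [4, 3] -> [frame_length, num_frames] when axis=-1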
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle {
namespace operators {
class GatherNdOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<phi::DenseTensor>("X");
const auto& x_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(
x_type,
x_type == framework::proto::VarType::BOOL
? x->place() // to be consistent with compare and logical ops
: ctx.device_context().GetPlace());
}
};
class GatherNdGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.device_context());
}
};
class GatherNdOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "The source input of gather_nd op");
AddInput("Index", "The index input of gather_nd op");
AddOutput("Out", "The output of gather_nd op");
AddComment(R"DOC(
Gather_Nd Operator.
This function is a high-dimensional extension of gather
and supports simultaneous indexing by multiple axes. Out is
obtained by gathering slices from X into a tensor with shape
Index.shape[:-1] + X.shape[Index.shape[-1]:].
Example:
Given:
X = [[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]]
X.shape = (2, 3, 4)
*Case 1:
Index = [[1]]
we get:
Out =
[[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]]
*Case 2:
Index = [[0,2]]
we get:
Out = [[8, 9, 10, 11]]
*Case 3:
Index = [[1, 2, 3]]
we get:
Out = [23]
)DOC");
}
};
template <typename T>
class GatherNdGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("gather_nd_grad");
op->SetInput("Index", this->Input("Index"));
op->SetInput("X", this->Input("X"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(GatherNdGradNoNeedBufferVarInferer, "X");
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(gather_nd,
GatherNdInferShapeFunctor,
PD_INFER_META(phi::GatherNdInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(gather_nd_grad,
GatherNdGradInferShapeFunctor,
PD_INFER_META(phi::GatherNdGradInferMeta));
REGISTER_OPERATOR(gather_nd,
ops::GatherNdOp,
ops::GatherNdOpMaker,
ops::GatherNdGradOpMaker<paddle::framework::OpDesc>,
ops::GatherNdGradOpMaker<paddle::imperative::OpBase>,
GatherNdInferShapeFunctor);
REGISTER_OPERATOR(gather_nd_grad,
ops::GatherNdGradOp,
ops::GatherNdGradNoNeedBufferVarInferer,
GatherNdGradInferShapeFunctor);
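Case 2 of the maker's doc above can be reproduced through the Python API. A minimal sketch (assuming paddle.gather_nd as in PaddlePaddle 2.x):

import paddle

x = paddle.arange(24).reshape([2, 3, 4])  # the X from the doc example
index = paddle.to_tensor([[0, 2]])        # Index.shape = (1, 2)
out = paddle.gather_nd(x, index)          # shape = Index.shape[:-1] + X.shape[2:] = (1, 4)
print(out.numpy())                        # [[ 8  9 10 11]]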
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cmath>
#include <string>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle {
namespace operators {
class IscloseOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Input",
"The input tensor, it's data type should be float32, float64.");
AddInput("Other",
"The input tensor, it's data type should be float32, float64.");
AddInput("Rtol", "The relative tolerance.").AsDispensable();
AddInput("Atol", "The absolute tolerance.").AsDispensable();
AddOutput("Out", "The output tensor, it's data type is bool.");
AddAttr<std::string>("rtol",
"The relative tolerance. Default: :math:`1e-5` .")
.SetDefault("1e-5");
AddAttr<std::string>("atol",
"The absolute tolerance. Default: :math:`1e-8` .")
.SetDefault("1e-8");
AddAttr<bool>("equal_nan",
"If :math:`True` , then two :math:`NaNs` will be "
"compared as equal. Default: :math:`False` .")
.SetDefault(false);
AddComment(R"DOC(
This operator checks if all :math:`x` and :math:`y` satisfy the condition:
.. math::
\left| x - y \right| \leq atol + rtol \times \left| y \right|
elementwise, for all elements of :math:`x` and :math:`y`. The behaviour of this
operator is analogous to :math:`numpy.isclose`, namely that it returns :math:`True` if
two tensors are elementwise equal within a tolerance.
)DOC");
}
};
class IscloseOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
ctx.device_context());
}
};
class IscloseOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext* ctx) const override {
ctx->SetOutputDataType("Out", framework::proto::VarType::BOOL);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(isclose,
IscloseInferShapeFunctor,
PD_INFER_META(phi::ValueCompareInferMeta));
REGISTER_OPERATOR(
isclose,
ops::IscloseOp,
ops::IscloseOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
ops::IscloseOpVarTypeInference,
IscloseInferShapeFunctor);
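A minimal usage sketch of the Python API this operator backs (paddle.isclose, assuming a PaddlePaddle 2.x environment), illustrating the tolerance formula and equal_nan:

import paddle

x = paddle.to_tensor([1.0, 2.0, float('nan')])
y = paddle.to_tensor([1.0 + 1e-6, 2.1, float('nan')])
# elementwise test: |x - y| <= atol + rtol * |y|
print(paddle.isclose(x, y))                  # [True, False, False]
print(paddle.isclose(x, y, equal_nan=True))  # [True, False, True]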
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
class OverlapAddOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
const auto in_dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(in_dtype, ctx.GetPlace());
}
};
class OverlapAddOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of overlap_add op.");
AddOutput("Out", "(Tensor), The output tensor of overlap_add op.");
AddAttr<int>("hop_length",
"Number of steps to advance between adjacent frames and "
"`0 < hop_length <= frame_length`.");
AddAttr<int>("axis",
"Specify the axis to operate on the input Tensors. Its value "
"should be 0(the first dimension) or -1(the last dimension).")
.SetDefault(-1);
AddComment(R"DOC(
Reconstructs a tensor consisting of overlap-added sequences from input frames.
)DOC");
}
};
class OverlapAddOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
const auto in_dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(in_dtype, ctx.GetPlace());
}
};
template <typename T>
class OverlapAddOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
void Apply(GradOpPtr<T> retv) const override {
retv->SetType("overlap_add_grad");
retv->SetInput("X", this->Input("X"));
retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
retv->SetAttrMap(this->Attrs());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(overlap_add,
OverlapAddInferShapeFunctor,
PD_INFER_META(phi::OverlapAddInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(overlap_add_grad,
OverlapAddGradInferShapeFunctor,
PD_INFER_META(phi::OverlapAddGradInferMeta));
REGISTER_OPERATOR(overlap_add,
ops::OverlapAddOp,
ops::OverlapAddOpMaker,
ops::OverlapAddOpGradMaker<paddle::framework::OpDesc>,
ops::OverlapAddOpGradMaker<paddle::imperative::OpBase>,
OverlapAddInferShapeFunctor);
REGISTER_OPERATOR(overlap_add_grad,
ops::OverlapAddOpGrad,
OverlapAddGradInferShapeFunctor);
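overlap_add backs paddle.signal.overlap_add, the summed inverse of paddle.signal.frame. A minimal round-trip sketch (assuming PaddlePaddle 2.x):

import paddle

x = paddle.arange(8, dtype='float32')
frames = paddle.signal.frame(x, frame_length=4, hop_length=2)  # [4, 3]
y = paddle.signal.overlap_add(frames, hop_length=2)            # (3 - 1) * 2 + 4 = 8 samples
print(y.shape)  # [8]; overlapping positions are summed, not averaged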
/*Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
class PixelShuffleOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
};
class PixelShuffleOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor, default Tensor<float>), "
"the input feature data of PixelShuffleOp, the layout is [N, C, "
"H, W] or [N, H, W, C].");
AddOutput("Out",
"(Tensor, default Tensor<float>), the output of "
"PixelShuffleOp. The layout is [N, C/factor^2, H*factor, "
"W*factor] or [N, H*factor, W*factor, C/factor^2].");
AddAttr<int>("upscale_factor",
"the factor to increase spatial resolution by.")
.SetDefault(1)
.AddCustomChecker([](const int& upscale_factor) {
PADDLE_ENFORCE_GE(upscale_factor,
1,
platform::errors::InvalidArgument(
"upscale_factor should be larger than 0."));
});
AddAttr<std::string>(
"data_format",
"An optional string from: \"NHWC\", \"NCHW\". "
"Defaults to \"NHWC\", Specify the data format of the input data.")
.SetDefault("NCHW");
AddComment(R"DOC(
Pixel Shuffle operator
This operator rearranges elements in a tensor of shape :math:`(*, C \times r^2, H, W)`
to a tensor of shape :math:`(*, C, H \times r, W \times r)`.
This is useful for implementing efficient sub-pixel convolution
with a stride of :math:`1/r`.
Please refer to the paper:
`Real-Time Single Image and Video Super-Resolution Using an Efficient
Sub-Pixel Convolutional Neural Network <https://arxiv.org/abs/1609.05158v2>`_
by Shi et al. (2016) for more details.
)DOC");
}
};
template <typename T>
class PixelShuffleGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
void Apply(GradOpPtr<T> op) const override {
op->SetType("pixel_shuffle_grad");
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetAttrMap(this->Attrs());
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
}
};
class PixelShuffleGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(pixel_shuffle,
PixelShuffleInferShapeFunctor,
PD_INFER_META(phi::PixelShuffleInferMeta));
REGISTER_OPERATOR(pixel_shuffle,
ops::PixelShuffleOp,
ops::PixelShuffleOpMaker,
ops::PixelShuffleGradMaker<paddle::framework::OpDesc>,
ops::PixelShuffleGradMaker<paddle::imperative::OpBase>,
PixelShuffleInferShapeFunctor);
DECLARE_INFER_SHAPE_FUNCTOR(pixel_shuffle_grad,
PixelShuffleGradInferShapeFunctor,
PD_INFER_META(phi::PixelShuffleGradInferMeta));
REGISTER_OPERATOR(pixel_shuffle_grad,
ops::PixelShuffleGradOp,
PixelShuffleGradInferShapeFunctor);
REGISTER_OP_VERSION(pixel_shuffle)
.AddCheckpoint(
R"ROC(
Compatible upgrade of pixel_shuffle, add a new attribute [data_format])ROC",
paddle::framework::compatible::OpVersionDesc().NewAttr(
"data_format", "Specify the data format of the input data", true));
@@ -491,6 +491,26 @@
data_type : out_grad
no_need_buffer : x
- backward_op : frame_grad
forward : frame(Tensor x, int frame_length, int hop_length, int axis=-1) -> Tensor(out)
args : (Tensor x, Tensor out_grad, int frame_length, int hop_length, int axis)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : frame_grad
- backward_op : gather_nd_grad
forward : gather_nd (Tensor x, Tensor index) -> Tensor(out)
args : (Tensor x, Tensor index, Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : GatherNdGradInferMeta
kernel :
func : gather_nd_grad
no_need_buffer : x
- backward_op : gelu_grad
forward : gelu(Tensor x, bool approximate) -> Tensor(out)
args : (Tensor x, Tensor out_grad, bool approximate)
@@ -799,6 +819,25 @@
data_type : input
optional : weight
- backward_op : overlap_add_grad
forward : overlap_add(Tensor x, int hop_length, int axis) -> Tensor(out)
args : (Tensor x, Tensor out_grad, int hop_length, int axis)
output : Tensor(x_grad)
infer_meta :
func : OverlapAddGradInferMeta
kernel :
func : overlap_add_grad
data_type : x
- backward_op : pixel_shuffle_grad
forward : pixel_shuffle (Tensor x, int upscale_factor=1, str data_format="NCHW") -> Tensor(out)
args : (Tensor out_grad, int upscale_factor, str data_format)
output : Tensor(x_grad)
infer_meta :
func : PixelShuffleGradInferMeta
kernel :
func : pixel_shuffle_grad
- backward_op : poisson_grad
forward : poisson (Tensor x) -> Tensor(out)
args : (Tensor out_grad)
@@ -539,16 +539,6 @@
kernel :
func : fmin_grad
- backward_op : frame_grad
forward : frame(Tensor x, int frame_length, int hop_length, int axis) -> Tensor(out)
args : (Tensor x, Tensor out_grad, int frame_length, int hop_length, int axis)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : frame_grad
- backward_op : frobenius_norm_grad
forward : frobenius_norm(Tensor x, int64_t[] axis, bool keep_dim, bool reduce_all) -> Tensor(out)
args : (Tensor x, Tensor out, Tensor out_grad, int64_t[] axis, bool keep_dim, bool reduce_all)
@@ -571,17 +561,6 @@
func : gather_grad
no_need_buffer : x
- backward_op : gather_nd_grad
forward : gather_nd (Tensor x, Tensor index) -> Tensor(out)
args : (Tensor x, Tensor index, Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : gather_nd_grad
no_need_buffer : x
- backward_op : group_norm_grad
forward : group_norm (Tensor x, Tensor scale, Tensor bias, float epsilon, int groups, str data_layout) -> Tensor(y), Tensor(mean), Tensor(variance)
args : (Tensor x, Tensor scale, Tensor bias, Tensor y, Tensor mean, Tensor variance, Tensor y_grad, float epsilon, int groups, str data_layout)
@@ -992,16 +971,6 @@
kernel :
func : norm_grad
- backward_op : overlap_add_grad
forward : overlap_add(Tensor x, int hop_length, int axis) -> Tensor(out)
args : (Tensor x, Tensor out_grad, int hop_length, int axis)
output : Tensor(x_grad)
infer_meta :
func : OverlapAddGradInferMeta
kernel :
func : overlap_add_grad
data_type : x
- backward_op : p_norm_grad
forward : p_norm(Tensor x, float porder, int axis, float epsilon, bool keepdim, bool asvector=false) -> Tensor(out)
args : (Tensor x, Tensor out, Tensor out_grad, float porder, int axis, float epsilon, bool keepdim, bool asvector)
@@ -1055,15 +1024,6 @@
no_need_buffer : x
backward : pad_double_grad
- backward_op : pixel_shuffle_grad
forward : pixel_shuffle (Tensor x, int upscale_factor, str data_format) -> Tensor(out)
args : (Tensor out_grad, int upscale_factor, str data_format)
output : Tensor(x_grad)
infer_meta :
func : PixelShuffleGradInferMeta
kernel :
func : pixel_shuffle_grad
- backward_op : pool2d_double_grad
forward : pool2d_grad(Tensor x, Tensor out, Tensor grad_out, IntArray kernel_size, int[] strides, int[] paddings, bool ceil_mode, bool exclusive, str data_format, str pooling_type, bool global_pooling, bool adaptive, str padding_algorithm) -> Tensor(grad_x)
args : (Tensor x, Tensor grad_x_grad, IntArray kernel_size, int[] strides, int[] paddings, bool ceil_mode, bool exclusive, str data_format, str pooling_type, bool global_pooling, bool adaptive, str padding_algorithm)
@@ -743,15 +743,6 @@
func : fmin
backward : fmin_grad
- op : frame
args : (Tensor x, int frame_length, int hop_length, int axis)
output : Tensor(out)
infer_meta :
func : FrameInferMeta
kernel :
func : frame
backward : frame_grad
- op : frobenius_norm
args : (Tensor x, int64_t[] axis, bool keep_dim, bool reduce_all)
output : Tensor(out)
@@ -822,16 +813,6 @@
data_type: x
backward : gather_grad
- op : gather_nd
args : (Tensor x, Tensor index)
output : Tensor
infer_meta :
func : GatherNdInferMeta
kernel :
func : gather_nd
data_type : x
backward : gather_nd_grad
- op : gaussian
args : (IntArray shape, float mean, float std, int seed, DataType dtype, Place place={})
output: Tensor(out)
@@ -960,15 +941,6 @@
intermediate : saved_mean, saved_variance
backward : instance_norm_grad
- op : isclose
args : (Tensor x, Tensor y, Scalar rtol, Scalar atol, bool equal_nan)
output : Tensor(out)
infer_meta :
func : ValueCompareInferMeta
param: [x, y]
kernel :
func : isclose
- op : kldiv_loss
args : (Tensor x, Tensor label, str reduction)
output : Tensor(out)
@@ -1414,15 +1386,6 @@
output : Tensor(out)
invoke : full_like(x, 1, dtype, place)
- op : overlap_add
args: (Tensor x, int hop_length, int axis)
output: Tensor
infer_meta:
func: OverlapAddInferMeta
kernel:
func: overlap_add
backward: overlap_add_grad
- op : p_norm
args : (Tensor x, float porder, int axis, float epsilon, bool keepdim, bool asvector=false)
output : Tensor(out)
@@ -1450,15 +1413,6 @@
func : pad3d
backward : pad3d_grad
- op : pixel_shuffle
args : (Tensor x, int upscale_factor, str data_format)
output : Tensor
infer_meta :
func : PixelShuffleInferMeta
kernel :
func : pixel_shuffle
backward : pixel_shuffle_grad
- op : pool2d
args : (Tensor x, IntArray kernel_size, int[] strides, int[] paddings, bool ceil_mode, bool exclusive, str data_format, str pooling_type, bool global_pooling, bool adaptive, str padding_algorithm)
output : Tensor(out)
@@ -547,6 +547,13 @@
outputs :
out : Y
- op : frame
backward : frame_grad
inputs :
x : X
outputs :
out : Out
- op : frobenius_norm
backward : frobenius_norm_grad
extra :
@@ -561,6 +568,13 @@
extra :
attrs : [bool overwrite = true]
- op : gather_nd
backward : gather_nd_grad
inputs :
{x : X, index : Index}
outputs :
out : Out
- op : gather_tree
inputs :
{ids : Ids, parents : Parents}
@@ -663,6 +677,19 @@
outputs :
out : Out
- op : isclose
inputs :
{x : Input, y : Other}
outputs :
out : Out
scalar :
rtol :
data_type : std::string
tensor_name : Rtol
atol :
data_type : std::string
tensor_name : Atol
- op : isfinite (isfinite_v2)
inputs :
x : X
@@ -899,6 +926,13 @@
outputs :
{out : Out, total_weight : Total_weight}
- op : overlap_add
backward : overlap_add_grad
inputs :
x : X
outputs :
out : Out
- op : pad2d
backward : pad2d_grad
extra :
@@ -914,6 +948,13 @@
extra :
attrs : [bool use_mkldnn = false]
- op : pixel_shuffle
backward : pixel_shuffle_grad
inputs :
x : X
outputs :
out : Out
- op : poisson
inputs :
x : X
@@ -16,6 +16,15 @@
comment : In order to specify interpolation mode
default : std::string("bilinear")
- op : pixel_shuffle
version :
- checkpoint : Compatible upgrade of pixel_shuffle, add a new attribute [data_format]
action :
- add_attr :
name : data_format
comment : Specify the data format of the input data
default : "true"
- op : roll
version :
- checkpoint : Upgrade roll add 1 attribute [axis], delete 1 attribute[dims].
@@ -431,6 +431,25 @@
func: fold
backward: fold_grad
- op : frame
args : (Tensor x, int frame_length, int hop_length, int axis=-1)
output : Tensor(out)
infer_meta :
func : FrameInferMeta
kernel :
func : frame
backward : frame_grad
- op : gather_nd
args : (Tensor x, Tensor index)
output : Tensor
infer_meta :
func : GatherNdInferMeta
kernel :
func : gather_nd
data_type : x
backward : gather_nd_grad
- op : gather_tree
args : (Tensor ids, Tensor parents)
output : Tensor(out)
@@ -535,6 +554,16 @@
kernel :
func : is_empty
- op : isclose
args : (Tensor x, Tensor y, Scalar rtol="1e-5", Scalar atol="1e-8", bool equal_nan=false)
output : Tensor(out)
infer_meta :
func : ValueCompareInferMeta
param: [x, y]
kernel :
func : isclose
data_type : x
- op : isfinite
args : (Tensor x)
output : Tensor(out)
@@ -761,6 +790,25 @@
kernel :
func : npu_identity
- op : overlap_add
args: (Tensor x, int hop_length, int axis=-1)
output: Tensor
infer_meta:
func: OverlapAddInferMeta
kernel:
func: overlap_add
data_type : x
backward: overlap_add_grad
- op : pixel_shuffle
args : (Tensor x, int upscale_factor=1, str data_format="NCHW")
output : Tensor
infer_meta :
func : PixelShuffleInferMeta
kernel :
func : pixel_shuffle
backward : pixel_shuffle_grad
- op : poisson
args : (Tensor x)
output : Tensor
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature FrameGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("frame_grad",
{"X", "Out@GRAD"},
{"frame_length", "hop_length", "axis"},
{"X@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(frame_grad, phi::FrameGradOpArgumentMapping);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature GatherNdGradArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"gather_nd_grad", {"X", "Index", "Out@GRAD"}, {}, {"X@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(gather_nd_grad, phi::GatherNdGradArgumentMapping);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature IscloseOpArgumentMapping(const ArgumentMappingContext& ctx) {
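  // The Python-side rtol/atol may arrive either as the dispensable tensor
  // inputs Rtol/Atol or as the string attributes rtol/atol; dispatch on which
  // of the four combinations is present in this call.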
if (ctx.HasInput("Rtol")) {
if (ctx.HasInput("Atol")) {
return KernelSignature("isclose",
{"Input", "Other"},
{"Rtol", "Atol", "equal_nan"},
{"Out"});
} else {
return KernelSignature("isclose",
{"Input", "Other"},
{"Rtol", "atol", "equal_nan"},
{"Out"});
}
} else {
if (ctx.HasInput("Atol")) {
return KernelSignature("isclose",
{"Input", "Other"},
{"rtol", "Atol", "equal_nan"},
{"Out"});
} else {
return KernelSignature("isclose",
{"Input", "Other"},
{"rtol", "atol", "equal_nan"},
{"Out"});
}
}
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(isclose, phi::IscloseOpArgumentMapping);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature OverlapAddGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("overlap_add_grad",
{"X", "Out@GRAD"},
{"hop_length", "axis"},
{"X@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(overlap_add_grad,
phi::OverlapAddGradOpArgumentMapping);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature PixelShuffleOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"pixel_shuffle", {"X"}, {"upscale_factor", "data_format"}, {"Out"});
}
KernelSignature PixelShuffleGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("pixel_shuffle_grad",
{"Out@GRAD"},
{"upscale_factor", "data_format"},
{"X@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(pixel_shuffle, phi::PixelShuffleOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(pixel_shuffle_grad,
phi::PixelShuffleGradOpArgumentMapping);
@@ -936,20 +936,28 @@ def bitwise_not(x, out=None, name=None):
@templatedoc()
def isclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False, name=None):
"""
${comment}
r"""
Checks if all :math:`x` and :math:`y` satisfy the condition:
.. math::
\left| x - y \right| \leq atol + rtol \times \left| y \right|
elementwise, for all elements of :math:`x` and :math:`y`. The behaviour of this
operator is analogous to :math:`numpy.isclose`, namely that it returns :math:`True` if
two tensors are elementwise equal within a tolerance.
Args:
x(Tensor): ${input_comment}.
y(Tensor): ${other_comment}.
x(Tensor): The input tensor, its data type should be float32 or float64.
y(Tensor): The input tensor, its data type should be float32 or float64.
rtol(rtoltype, optional): The relative tolerance. Default: :math:`1e-5` .
atol(atoltype, optional): The absolute tolerance. Default: :math:`1e-8` .
equal_nan(equalnantype, optional): ${equal_nan_comment}.
equal_nan(equalnantype, optional): If :math:`True` , then two :math:`NaNs` will be compared as equal. Default: :math:`False` .
name (str, optional): Name for the operation. For more information, please
refer to :ref:`api_guide_Name`. Default: None.
Returns:
Tensor: ${out_comment}.
Tensor: The output tensor, its data type is bool.
Examples:
.. code-block:: python