未验证 提交 7cc0d171 编写于 作者: Z zyfncg 提交者: GitHub

generate static graph code for some op (#48036)

上级 992b30ba
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
class ConjOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
};
class ConjOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of conj op.");
AddOutput("Out", "(Tensor), The output tensor of conj op.");
AddComment(R"DOC(
Conj Operator.
This operator is used to perform elementwise conjugate for input $X$.
)DOC");
}
};
template <typename T>
class ConjGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
void Apply(GradOpPtr<T> retv) const override {
retv->SetType("conj");
retv->SetInput("X", this->OutputGrad("Out"));
retv->SetAttrMap(this->Attrs());
retv->SetOutput("Out", this->InputGrad("X"));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(conj,
ConjInferShapeFunctor,
PD_INFER_META(phi::UnchangedInferMeta));
REGISTER_OPERATOR(conj,
ops::ConjOp,
ops::ConjOpMaker,
ops::ConjGradMaker<paddle::framework::OpDesc>,
ops::ConjGradMaker<paddle::imperative::OpBase>,
ConjInferShapeFunctor);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include <string>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle {
namespace operators {
using Tensor = phi::DenseTensor;
class GridSampleOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(data_type, ctx.GetPlace());
}
};
class GridSampleOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor) The input data of GridSampleOp, "
"This is a 4-D tensor with shape of [N, C, H, W] or"
" a 5-D tensot with shape of [N, C, D, H, W]");
AddInput(
"Grid",
"(Tensor) The input grid of GridSampleOp generated by AffineGridOp, "
"This is a 4-D tensor with shape of [N, H, W, 2] is the concatenation "
"of x and y coordinates with shape [N, H, W] in last dimension or "
"a 5-D tensor with shape of [N, D, H, W, 3] is the concatenation "
"of depth, x and y coordinates with shape [N, D, H, W] in last "
"dimension ");
AddOutput("Output",
"(Tensor) Output tensor with shape [N, C, H, W] or shape [N,C, "
"D, H ,W]");
AddAttr<bool>(
"align_corners",
"(bool, default true) If align_corners is true, it will project"
"-1 and 1 to the centers of the corner pixels. Otherwise, it will "
"project"
"-1 and 1 to the image edges.")
.SetDefault(true);
AddAttr<std::string>(
"mode",
"(bool, default true) The interpolation method which can be 'bilinear'"
" or 'nearest'.")
.SetDefault("bilinear");
AddAttr<std::string>(
"padding_mode",
"(bool, default true) The padding method used when source"
"index is out of input images. It can be 'zeros', 'reflection' and "
"'border'.")
.SetDefault("zeros");
AddComment(R"DOC(
This operation samples input X by using bilinear or nearest interpolation based on
flow field grid, which is usually generated by affine_grid. The grid of
shape [N, H, W, 2] is the concatenation of (grid_x, grid_y) coordinates
with shape [N, H, W] each, where grid_x is indexing the 4th dimension
(in width dimension) of input data x and grid_y is indexing the 3rd
dimension (in height dimension), finally results is the bilinear
interpolation value or nearest value of 4 nearest corner points.
For bilinear interpolation mode:
Step 1:
Get (x, y) grid coordinates and scale to [0, H-1/W-1].
grid_x = 0.5 * (grid[:, :, :, 0] + 1) * (W - 1)
grid_y = 0.5 * (grid[:, :, :, 1] + 1) * (H - 1)
Step 2:
Indices input data X with grid (x, y) in each [H, W] area, and bilinear
interpolate point value by 4 nearest points.
wn ------- y_n ------- en
| | |
| d_n |
| | |
x_w --d_w-- grid--d_e-- x_e
| | |
| d_s |
| | |
ws ------- y_s ------- wn
x_w = floor(x) // west side x coord
x_e = x_w + 1 // east side x coord
y_n = floor(y) // north side y coord
y_s = y_s + 1 // south side y coord
d_w = grid_x - x_w // distance to west side
d_e = x_e - grid_x // distance to east side
d_n = grid_y - y_n // distance to north side
d_s = y_s - grid_y // distance to south side
wn = X[:, :, y_n, x_w] // north-west point value
en = X[:, :, y_n, x_e] // north-east point value
ws = X[:, :, y_s, x_w] // south-east point value
es = X[:, :, y_s, x_w] // north-east point value
output = wn * d_e * d_s + en * d_w * d_s
+ ws * d_e * d_n + es * d_w * d_n
)DOC");
}
};
class GridSampleOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(data_type, ctx.GetPlace());
}
};
template <typename T>
class GridSampleGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("grid_sampler_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("Grid", this->Input("Grid"));
op->SetInput(framework::GradVarName("Output"), this->OutputGrad("Output"));
op->SetAttrMap(this->Attrs());
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetOutput(framework::GradVarName("Grid"), this->InputGrad("Grid"));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(grid_sampler,
GridSamplerInferShapeFunctor,
PD_INFER_META(phi::GridSampleBaseInferMeta));
REGISTER_OPERATOR(grid_sampler,
ops::GridSampleOp,
ops::GridSampleOpMaker,
ops::GridSampleGradMaker<paddle::framework::OpDesc>,
ops::GridSampleGradMaker<paddle::imperative::OpBase>,
GridSamplerInferShapeFunctor);
DECLARE_INFER_SHAPE_FUNCTOR(grid_sampler_grad,
GridSamplerGradInferShapeFunctor,
PD_INFER_META(phi::GeneralBinaryGradInferMeta));
REGISTER_OPERATOR(grid_sampler_grad,
ops::GridSampleOpGrad,
GridSamplerGradInferShapeFunctor);
REGISTER_OP_VERSION(grid_sampler)
.AddCheckpoint(
R"ROC(
Upgrade grid_sampler add a new attribute [mode].
)ROC",
paddle::framework::compatible::OpVersionDesc().NewAttr(
"mode", "In order to specify interpolation mode", "bilinear"));
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
using framework::OpKernelType;
class HistogramOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(data_type, ctx.device_context());
}
};
class HistogramOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor) The input tensor of Histogram op,");
AddOutput("Out", "(Tensor) The output tensor of Histogram op,");
AddAttr<int64_t>("bins", "(int) number of histogram bins")
.SetDefault(100)
.EqualGreaterThan(1);
AddAttr<int>("min", "(int) lower end of the range (inclusive)")
.SetDefault(0);
AddAttr<int>("max", "(int) upper end of the range (inclusive)")
.SetDefault(0);
AddComment(R"DOC(
Histogram Operator.
Computes the histogram of a tensor. The elements are sorted
into equal width bins between min and max. If min and max are
both zero, the minimum and maximum values of the data are used.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(histogram,
HistogramInferShapeFunctor,
PD_INFER_META(phi::HistogramInferMeta));
REGISTER_OPERATOR(
histogram,
ops::HistogramOp,
ops::HistogramOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
HistogramInferShapeFunctor);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/no_need_buffer_vars_inference.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle {
namespace operators {
class IndexSampleOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "Input(Tensor), dtype support int32/int64/float/double");
AddInput("Index", "Index(Tensor), dtype support int32/int64");
AddOutput("Out", "Return the element of input at index");
AddComment(R"DOC(
IndexSample OP returns the element of the specified location of X,
and the location is specified by Index.
X tensor and Index tensor's shape must be 2-D,
dimension at 0 which usually is batch size must be equal.
The returned tensor has the same shape and dimensions as the Index tensor.
)DOC");
}
};
class IndexSampleOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(data_type, ctx.device_context());
}
};
class IndexSampleGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("Index"),
true,
platform::errors::InvalidArgument("Input(Index) should be not null."));
PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")),
true,
platform::errors::InvalidArgument(
"Input(Out@GRAD) should be not null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("X")),
true,
platform::errors::InvalidArgument(
"Output(X@GRAD) should be not null."));
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
return framework::OpKernelType(data_type, ctx.device_context());
}
};
template <typename T>
class IndexSampleGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("index_sample_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("Index", this->Input("Index"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(IndexSampleGradNoNeedBufferVarInferer, "X");
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(index_sample,
IndexSampleInferShapeFunctor,
PD_INFER_META(phi::IndexSampleInferMeta));
REGISTER_OPERATOR(index_sample,
ops::IndexSampleOp,
ops::IndexSampleOpMaker,
ops::IndexSampleGradMaker<paddle::framework::OpDesc>,
ops::IndexSampleGradMaker<paddle::imperative::OpBase>,
IndexSampleInferShapeFunctor);
REGISTER_OPERATOR(index_sample_grad,
ops::IndexSampleGradOp,
ops::IndexSampleGradNoNeedBufferVarInferer);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/index_select_op.h"
#include <memory>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle {
namespace operators {
class IndexSelectOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(data_type, ctx.device_context());
}
};
class IndexSelectGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("Index"),
true,
platform::errors::InvalidArgument("Input(Index) should be not null."));
PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")),
true,
platform::errors::InvalidArgument(
"Input(Out@GRAD) should be not null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("X")),
true,
platform::errors::InvalidArgument(
"Output(X@GRAD) should be not null."));
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.device_context());
}
};
class IndexSelectOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor) the input tensor.");
AddInput("Index", "the 1-D tensor containing the indices to index.");
AddOutput("Out", "the output tensor.");
AddAttr<int>("dim", "the dimension in which we index.").SetDefault(0);
AddComment(R"DOC(
Returns a new tensor which indexes the input tensor
along dimension dim using the entries in index which
is a Tensor.
The returned tensor has the same number of dimensions
as the original tensor (input). The dim-th dimension
has the same size as the length of index; other dimensions
have the same size as in the original tensor.
)DOC");
}
};
template <typename T>
class IndexSelectGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("index_select_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("Index", this->Input("Index"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(IndexSelectGradNoNeedBufferVarsInferer,
"X");
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(index_select,
IndexSelectInferShapeFunctor,
PD_INFER_META(phi::IndexSelectInferMeta));
REGISTER_OPERATOR(index_select,
ops::IndexSelectOp,
ops::IndexSelectOpMaker,
ops::IndexSelectGradMaker<paddle::framework::OpDesc>,
ops::IndexSelectGradMaker<paddle::imperative::OpBase>,
IndexSelectInferShapeFunctor);
REGISTER_OPERATOR(index_select_grad,
ops::IndexSelectGradOp,
ops::IndexSelectGradNoNeedBufferVarsInferer);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/matrix_inverse.h"
namespace paddle {
namespace operators {
class InverseOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
};
class InverseOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
const override {
static std::unordered_map<std::string, std::string> m{
{"Input", /*->*/ "Output"}};
return m;
}
};
class InverseGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
};
class InverseOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput(
"Input",
"(Tensor) A square matrix (2-D Tensor) or batches of square matrices"
" to inverse.");
AddOutput("Output", "(Tensor) The inverse of input matrix.");
AddComment(R"DOC(
Inverse Operator
Takes the inverse of the square matrix.
)DOC");
}
};
template <typename T>
class InverseGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> grad) const override {
grad->SetType(this->ForwardOpType() + "_grad");
grad->SetInput("Output", this->Output("Output"));
grad->SetInput(framework::GradVarName("Output"),
this->OutputGrad("Output"));
grad->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(inverse,
InverseInferShapeFunctor,
PD_INFER_META(phi::InverseInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(inverse_grad,
InverseGradInferShapeFunctor,
PD_INFER_META(phi::InverseGradInferMeta));
REGISTER_OPERATOR(inverse,
ops::InverseOp,
ops::InverseOpMaker,
ops::InverseOpInferVarType,
ops::InverseGradOpMaker<paddle::framework::OpDesc>,
ops::InverseGradOpMaker<paddle::imperative::OpBase>,
InverseInferShapeFunctor);
REGISTER_OPERATOR(inverse_grad,
ops::InverseGradOp,
InverseGradInferShapeFunctor);
......@@ -172,6 +172,12 @@
kernel :
func : cholesky_solve_grad
- backward_op : conj_grad
forward : conj (Tensor x) -> Tensor(out)
args : (Tensor out_grad)
output : Tensor(x_grad)
invoke : conj(out_grad)
- backward_op : cos_double_grad
forward : cos_grad (Tensor x, Tensor grad_out) -> Tensor(grad_x)
args : (Tensor x, Tensor grad_out, Tensor grad_x_grad)
......@@ -451,6 +457,17 @@
kernel :
func : gelu_grad
- backward_op : grid_sample_grad
forward : grid_sample (Tensor x, Tensor grid, str mode, str padding_mode, bool align_corners) -> Tensor(out)
args : (Tensor x, Tensor grid, Tensor out_grad, str mode, str padding_mode, bool align_corners)
output : Tensor(x_grad), Tensor(grid_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
param : [x, grid]
kernel :
func : grid_sample_grad
data_type : x
- backward_op : gumbel_softmax_grad
forward : gumbel_softmax (Tensor x, float temperature, bool hard, int axis) -> Tensor(out)
args : (Tensor out, Tensor out_grad, int axis)
......@@ -482,6 +499,39 @@
func : hard_sigmoid_grad
inplace : (out_grad -> x_grad)
- backward_op : index_sample_grad
forward : index_sample (Tensor x, Tensor index) -> Tensor(out)
args : (Tensor x, Tensor index, Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : index_sample_grad
data_type : out_grad
no_need_buffer : x
- backward_op : index_select_grad
forward : index_select(Tensor x, Tensor index, int axis) -> Tensor(out)
args : (Tensor x, Tensor index, Tensor out_grad, int axis)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : index_select_grad
data_type : out_grad
no_need_buffer : x
- backward_op : inverse_grad
forward : inverse(Tensor x) -> Tensor(out)
args : (Tensor out, Tensor out_grad)
output : Tensor(x_grad)
infer_meta:
func : InverseGradInferMeta
kernel :
func : inverse_grad
- backward_op : leaky_relu_double_grad
forward : leaky_relu_grad (Tensor x, Tensor grad_out, float negative_slope) -> Tensor(grad_x)
args : (Tensor x, Tensor grad_x_grad, float negative_slope)
......
......@@ -253,16 +253,6 @@
no_need_buffer : x
backward : concat_double_grad
- backward_op : conj_grad
forward : conj (Tensor x) -> Tensor(out)
args : (Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param: [out_grad]
kernel :
func : conj
- backward_op : conv2d_grad
forward : conv2d (Tensor input, Tensor filter, int[] strides, int[] paddings, str padding_algorithm, int[] dilations, int groups, str data_format) -> Tensor(out)
args : (Tensor input, Tensor filter, Tensor out_grad, int[] strides, int[] paddings, str padding_algorithm, int[] dilations, int groups, str data_format)
......@@ -629,17 +619,6 @@
func : gather_nd_grad
no_need_buffer : x
- backward_op : grid_sample_grad
forward : grid_sample (Tensor x, Tensor grid, str mode, str padding_mode, bool align_corners) -> Tensor(out)
args : (Tensor x, Tensor grid, Tensor out_grad, str mode, str padding_mode, bool align_corners)
output : Tensor(x_grad), Tensor(grid_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
param : [x, grid]
kernel :
func : grid_sample_grad
data_type : x
- backward_op : group_norm_grad
forward : group_norm (Tensor x, Tensor scale, Tensor bias, float epsilon, int groups, str data_layout) -> Tensor(y), Tensor(mean), Tensor(variance)
args : (Tensor x, Tensor scale, Tensor bias, Tensor y, Tensor mean, Tensor variance, Tensor y_grad, float epsilon, int groups, str data_layout)
......@@ -702,6 +681,12 @@
output : Tensor(x_grad)
invoke : imag_grad_impl(out_grad, x_grad)
- backward_op : increment_grad
forward : increment (Tensor x, float value) -> Tensor(out)
args : (Tensor out, float value)
output : Tensor(x_grad)
invoke : increment (out, -value)
- backward_op : index_add_grad
forward : index_add(Tensor x, Tensor index, Tensor add_value, int axis) -> Tensor(out)
args : (Tensor index, Tensor add_value, Tensor out_grad, int axis)
......@@ -713,30 +698,6 @@
data_type : out_grad
inplace : (out_grad -> x_grad)
- backward_op : index_sample_grad
forward : index_sample (Tensor x, Tensor index) -> Tensor(out)
args : (Tensor x, Tensor index, Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : index_sample_grad
data_type : out_grad
no_need_buffer : x
- backward_op : index_select_grad
forward : index_select(Tensor x, Tensor index, int axis) -> Tensor(out)
args : (Tensor x, Tensor index, Tensor out_grad, int axis)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : index_select_grad
data_type : x
no_need_buffer : x
- backward_op : instance_norm_double_grad
forward : instance_norm_grad(Tensor x, Tensor fwd_scale, Tensor saved_mean, Tensor saved_variance, Tensor grad_y, float epsilon) -> Tensor(grad_x), Tensor(grad_scale), Tensor(grad_bias)
args : (Tensor x, Tensor fwd_scale, Tensor saved_mean, Tensor saved_variance, Tensor grad_y, Tensor grad_x_grad, Tensor grad_scale_grad, Tensor grad_bias_grad, float epsilon)
......@@ -760,15 +721,6 @@
optional : scale
backward : instance_norm_double_grad
- backward_op : inverse_grad
forward : inverse(Tensor x) -> Tensor(out)
args : (Tensor out, Tensor out_grad)
output : Tensor(x_grad)
infer_meta:
func : InverseGradInferMeta
kernel :
func : inverse_grad
- backward_op : kldiv_loss_grad
forward : kldiv_loss(Tensor x, Tensor label, str reduction) -> Tensor(out)
args : (Tensor x, Tensor label, Tensor out_grad, str reduction)
......
......@@ -424,15 +424,6 @@
func : concat
backward : concat_grad
- op : conj
args : (Tensor x)
output : Tensor
infer_meta :
func : UnchangedInferMeta
kernel :
func : conj
backward : conj_grad
- op : conv2d
args : (Tensor input, Tensor filter, int[] strides, int[] paddings, str padding_algorithm, int[] dilations, int groups, str data_format)
output : Tensor
......@@ -911,17 +902,6 @@
kernel :
func : greater_than
- op : grid_sample
args : (Tensor x, Tensor grid, str mode, str padding_mode, bool align_corners)
output : Tensor(out)
infer_meta :
func : GridSampleBaseInferMeta
param : [x, grid]
kernel:
func : grid_sample
data_type : x
backward : grid_sample_grad
- op : group_norm
args : (Tensor x, Tensor scale, Tensor bias, float epsilon, int groups, str data_layout)
output : Tensor(y), Tensor(mean), Tensor(variance)
......@@ -953,14 +933,6 @@
func : hard_tanh
backward : hardtanh_grad
- op : histogram
args : (Tensor input, int64_t bins, int min, int max)
output : Tensor(out)
infer_meta :
func : HistogramInferMeta
kernel :
func : histogram
- op : hsigmoid_loss
args : (Tensor x, Tensor w, Tensor label, Tensor path, Tensor code, Tensor bias, int num_classes, bool remote_prefetch, int trainer_id, int64_t[] height_sections, str[] epmap, str[] table_names, bool is_sparse)
output : Tensor(out), Tensor(pre_out), Tensor(w_out)
......@@ -991,7 +963,7 @@
backward : imag_grad
- op : increment
args : (Tensor x, float value)
args : (Tensor x, float value = 1.0)
output : Tensor(out)
infer_meta :
func : IncrementInferMeta
......@@ -1010,26 +982,6 @@
inplace : (x -> out)
backward : index_add_grad
- op : index_sample
args : (Tensor x, Tensor index)
output : Tensor
infer_meta :
func : IndexSampleInferMeta
kernel :
func : index_sample
data_type : x
backward : index_sample_grad
- op : index_select
args : (Tensor x, Tensor index, int axis)
output : Tensor(out)
infer_meta :
func : IndexSelectInferMeta
kernel :
func : index_select
data_type : x
backward : index_select_grad
- op : instance_norm
args : (Tensor x, Tensor scale, Tensor bias, float epsilon)
output : Tensor(y), Tensor(saved_mean), Tensor(saved_variance)
......@@ -1042,15 +994,6 @@
intermediate : saved_mean, saved_variance
backward : instance_norm_grad
- op : inverse
args : (Tensor x)
output : Tensor(out)
infer_meta :
func : InverseInferMeta
kernel :
func : inverse
backward : inverse_grad
- op : is_empty
args : (Tensor x)
output : Tensor(out)
......
......@@ -188,6 +188,12 @@
extra :
attrs : ['str[] skip_eager_deletion_vars = {}']
- op : conj
inputs :
x : X
outputs :
out : Out
- op : conv2d
backward : conv2d_grad
extra :
......@@ -546,8 +552,12 @@
attrs : [bool use_mkldnn = false, str x_data_format = "", str y_data_format = "", str mkldnn_data_type = "float32",
bool use_quantizer = false, float Scale_x = 1.0f, float Scale_y = 1.0f, float Scale_out = 1.0f]
- op : grid_sampler
backward : grid_sampler_grad
- op : grid_sample(grid_sampler)
backward : grid_sample_grad (grid_sampler_grad)
inputs :
{x : X, grid : Grid}
outputs :
out : Output
extra :
attrs : [bool use_cudnn = true]
......@@ -587,11 +597,37 @@
attrs : [bool use_mkldnn = false, str x_data_format = "", str y_data_format = "", str mkldnn_data_type = "float32",
bool use_quantizer = false, float Scale_x = 1.0f, float Scale_y = 1.0f, float Scale_out = 1.0f]
- op : histogram
inputs :
input : X
outputs :
out : Out
- op : index_sample
inputs :
{x : X, index : Index}
outputs :
out : Out
- op : index_select
inputs :
{x : X, index : Index}
outputs :
out : Out
attrs :
axis : dim
- op : inplace_abn
backward : inplace_abn_grad
extra :
attrs : [bool use_mkldnn = false, bool fuse_with_relu = false]
- op : inverse
inputs :
x : Input
outputs :
out : Output
- op : layer_norm
backward : layer_norm_grad
extra :
......
......@@ -8,6 +8,14 @@
- delete_attr : dims
comment : The attr 'dims' is deleted.
- op : grid_sample
version :
- checkpoint : Upgrade grid_sampler add a new attribute [mode]
action :
- add_attr : mode
comment : In order to specify interpolation mode
default : std::string("bilinear")
- op : trace
version :
- checkpoint : Upgrade trace add a new attribute [axis2]
......
......@@ -152,6 +152,15 @@
func : cholesky_solve
backward : cholesky_solve_grad
- op : conj
args : (Tensor x)
output : Tensor (out)
infer_meta :
func : UnchangedInferMeta
kernel :
func : conj
backward : conj_grad
- op : cos
args : (Tensor x)
output : Tensor
......@@ -411,6 +420,17 @@
func : gelu
backward : gelu_grad
- op : grid_sample
args : (Tensor x, Tensor grid, str mode = "bilinear", str padding_mode = "zeros", bool align_corners = true)
output : Tensor(out)
infer_meta :
func : GridSampleBaseInferMeta
param : [x, grid]
kernel:
func : grid_sample
data_type : x
backward : grid_sample_grad
- op : gumbel_softmax
args : (Tensor x, float temperature = 1.0, bool hard = false, int axis = -1)
output : Tensor
......@@ -440,6 +460,43 @@
func : hard_sigmoid
backward : hardsigmoid_grad
- op : histogram
args : (Tensor input, int64_t bins = 100, int min = 0, int max = 0)
output : Tensor(out)
infer_meta :
func : HistogramInferMeta
kernel :
func : histogram
- op : index_sample
args : (Tensor x, Tensor index)
output : Tensor
infer_meta :
func : IndexSampleInferMeta
kernel :
func : index_sample
data_type : x
backward : index_sample_grad
- op : index_select
args : (Tensor x, Tensor index, int axis = 0)
output : Tensor(out)
infer_meta :
func : IndexSelectInferMeta
kernel :
func : index_select
data_type : x
backward : index_select_grad
- op : inverse
args : (Tensor x)
output : Tensor(out)
infer_meta :
func : InverseInferMeta
kernel :
func : inverse
backward : inverse_grad
- op : leaky_relu
args : (Tensor x, float negative_slope = 0.02f)
output : Tensor
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature GridSamplerOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("grid_sample",
{"X", "Grid"},
{"mode", "padding_mode", "align_corners"},
{"Output"});
}
KernelSignature GridSamplerGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("grid_sample_grad",
{"X", "Grid", "Output@GRAD"},
{"mode", "padding_mode", "align_corners"},
{"X@GRAD", "Grid@GRAD"});
}
} // namespace phi
// use Python API name as kernel name
PD_REGISTER_BASE_KERNEL_NAME(grid_sampler, grid_sample);
PD_REGISTER_BASE_KERNEL_NAME(grid_sampler_grad, grid_sample_grad);
PD_REGISTER_ARG_MAPPING_FN(grid_sampler, phi::GridSamplerOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(grid_sampler_grad,
phi::GridSamplerGradOpArgumentMapping);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature HistogramOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("histogram", {"X"}, {"bins", "min", "max"}, {"Out"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(histogram, phi::HistogramOpArgumentMapping);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature IndexSampleGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"index_sample_grad", {"X", "Index", "Out@GRAD"}, {}, {"X@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(index_sample_grad,
phi::IndexSampleGradOpArgumentMapping);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature IndexSelectGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"index_select_grad", {"X", "Index", "Out@GRAD"}, {"dim"}, {"X@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(index_select_grad,
phi::IndexSelectGradOpArgumentMapping);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature InverseGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"inverse_grad", {"Output", "Output@GRAD"}, {}, {"Input@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(inverse_grad, phi::InverseGradOpArgumentMapping);
......@@ -88,7 +88,7 @@ class TestHistogramOpError(unittest.TestCase):
)
paddle.histogram(input=input_value, bins=-1, min=1, max=5)
with self.assertRaises(IndexError):
with self.assertRaises(ValueError):
self.run_network(net_func)
def test_min_max_error(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册