Unverified commit ac9debee, authored by zyfncg, committed by GitHub

Generate static graph code of stack, unbind, unique_consecutive op (#49726)

* generate static graph code of stack, unbind, unique_consecutive op

* fix bug
Parent 43ec2271
@@ -383,15 +383,21 @@ def parse_op_entry(op_entry: Dict[str, Any], name_field="op"):
# add optional tag for every input
for input in inputs:
input["optional"] = False
for output in outputs:
output["optional"] = False
if "optional" in op_entry:
optional_args = parse_plain_list(op_entry["optional"])
for name in optional_args:
assert (
name in input_names
), f"{op_name} has an optional input: '{name}' which is not an input."
name in input_names or name in output_names
), f"{op_name} has an optional tensor: '{name}' which is not in input or output."
for input in inputs:
if input["name"] in optional_args:
input["optional"] = True
for output in outputs:
if output["name"] in optional_args:
output["optional"] = True
# add intermediate tag for every output
for output in outputs:
......
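The hunk above extends the generator's optional-tag pass from inputs to outputs, so entries such as unique_consecutive's index and counts can be marked dispensable. A minimal sketch of the resulting behavior (a hypothetical standalone helper; the real logic lives inline in parse_op_entry as shown above):

    def tag_optional(op_name, inputs, outputs, optional_args):
        # An optional name must now be a declared input OR output,
        # not just an input as before this change.
        names = {d["name"] for d in inputs} | {d["name"] for d in outputs}
        for name in optional_args:
            assert name in names, (
                f"{op_name} has an optional tensor: '{name}' "
                "which is not in input or output.")
        for d in inputs + outputs:
            d["optional"] = d["name"] in optional_args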
@@ -54,6 +54,10 @@ AddOutput({{name | to_opmaker_name}}, "({{typename}}), output {{i}} of {{op_name
.AsIntermediate()
{%- endif %}
{%- if output["optional"] %}
.AsDispensable()
{%- endif %}
{%- if "is_extra" in output and output["is_extra"] %}
.AsExtra()
......
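With the optional flag now set on outputs, the template branch above appends .AsDispensable() to the generated AddOutput call. A trimmed-down jinja2 rendering of that branch (illustrative only, not the actual template):

    from jinja2 import Template

    t = Template(
        'AddOutput("{{ name }}", "({{ typename }}), output of {{ op }}")'
        "{% if optional %}.AsDispensable(){% endif %};"
    )
    print(t.render(name="Index", typename="Tensor",
                   op="unique_consecutive", optional=True))
    # AddOutput("Index", "(Tensor), output of unique_consecutive").AsDispensable();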
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/multiary.h"
namespace plat = paddle::platform;
namespace ops = paddle::operators;
namespace paddle {
namespace operators {
class StackOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
return phi::KernelKey(input_data_type, ctx.GetPlace());
}
};
class StackOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "The input of stack op.").AsDuplicable();
AddOutput("Y", "The output of stack op.");
AddAttr<int>("axis",
"The axis along which all of the Inputs(X) should be stacked.")
.SetDefault(0);
AddComment(R"DOC(
Stack Operator.
Stack all of the Inputs(X) into one tensor along Attr(axis). The dims of all Inputs(X) must be the same.
)DOC");
}
};
class StackOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
};
template <typename T>
class StackGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("stack_grad");
op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X", false));
op->SetAttrMap(this->Attrs());
}
};
} // namespace operators
} // namespace paddle
DECLARE_INFER_SHAPE_FUNCTOR(stack,
StackInferMetaFunctor,
PD_INFER_META(phi::StackInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(stack_grad,
StackGradInferMetaFunctor,
PD_INFER_META(phi::StackGradInferMeta));
REGISTER_OPERATOR(stack,
ops::StackOp,
ops::StackOpMaker,
ops::StackGradOpMaker<paddle::framework::OpDesc>,
ops::StackGradOpMaker<paddle::imperative::OpBase>,
StackInferMetaFunctor);
REGISTER_OPERATOR(stack_grad, ops::StackOpGrad, StackGradInferMetaFunctor);
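The hand-written stack_op.cc above is presumably what this commit retires in favor of generated code; the operator's semantics are unchanged. For reference, its Python-level behavior via the standard paddle API:

    import paddle

    x = [paddle.to_tensor([1.0, 2.0]), paddle.to_tensor([3.0, 4.0])]
    y = paddle.stack(x, axis=0)  # shape [2, 2]; rows are the stacked inputs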
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/unbind_op.h"
#include <string>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
class UnbindOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("X"),
true,
platform::errors::NotFound("Input(X) of UnbindOp is not found."));
PADDLE_ENFORCE_GE(
ctx->Outputs("Out").size(),
1UL,
platform::errors::NotFound("Outputs(Out) of UnbindOp is not found."));
auto in_dims = ctx->GetInputDim("X");
auto outs_names = ctx->Outputs("Out");
int axis = ctx->Attrs().Get<int>("axis");
const size_t outs_number = outs_names.size();
auto out_dims = UnbindOutsDims(in_dims, axis);
std::vector<framework::DDim> outs_dims(outs_number, out_dims);
ctx->SetOutputsDim("Out", outs_dims);
for (size_t i = 0; i < outs_number; ++i) {
ctx->ShareLoD("X", "Out", 0, i);
}
}
};
class UnbindOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor) Input tensor of the split operator.");
AddOutput("Out", "(Tensor) Output tensors of the unbind operator.")
.AsDuplicable();
AddComment(R"DOC(
Unbind operator
Remove a tensor dimension.
Example:
Input = [[1,2],
[3,4],
[5,6]]
axis = 0
Output[0] = [1,2]
Output[1] = [3,4]
Output[2] = [5,6]
)DOC");
AddAttr<int>("axis",
"(int, default 0) "
"dimension to remove.")
.SetDefault(0);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(unbind,
ops::UnbindOp,
ops::UnbindOpMaker,
ops::UnbindGradMaker<paddle::framework::OpDesc>,
ops::UnbindGradMaker<paddle::imperative::OpBase>);
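The DOC block in UnbindOpMaker above translates directly to the Python API; mirroring its example:

    import paddle

    x = paddle.to_tensor([[1, 2], [3, 4], [5, 6]])
    a, b, c = paddle.unbind(x, axis=0)  # a=[1,2], b=[3,4], c=[5,6]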
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
class UniqueConsecutiveOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.GetPlace());
}
};
class UniqueConsecutiveOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "The input tensor of unique_consecutive op.");
AddAttr<int>("dtype",
"(int, default 5(FP32)) "
"data type for output index")
.SetDefault(framework::proto::VarType::FP32);
AddOutput("Out", "A unique consecutive subsequence for input tensor.");
AddOutput("Index",
"The indices for where elements in the original input ended up "
"in the returned unique tensor.")
.AsDispensable();
AddOutput("Counts", "The counts for each unique element.").AsDispensable();
AddAttr<bool>(
"return_inverse",
"If True, also return the indices for where elements"
" in the original input ended up in the returned unique tensor.")
.SetDefault(false);
AddAttr<bool>("return_counts",
"If True, also return the counts for each unique element.")
.SetDefault(false);
AddAttr<std::vector<int>>(
"axis",
"The axis to apply unique. If None, the input will be flattened.")
.SetDefault({});
AddComment(R"DOC(
This function is different from paddle.unique() in the sense that this
function only eliminates consecutive duplicate values.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(unique_consecutive,
UniqueConsecutiveInferShapeFunctor,
PD_INFER_META(phi::UniqueConsecutiveInferMeta));
REGISTER_OP_WITHOUT_GRADIENT(unique_consecutive,
ops::UniqueConsecutiveOp,
ops::UniqueConsecutiveOpMaker,
UniqueConsecutiveInferShapeFunctor);
REGISTER_OP_VERSION(unique_consecutive)
.AddCheckpoint(
R"ROC(
Upgrade unique_consecutive, add 2 outputs [Indices, Counts] and 3 attributes
[return_inverse, return_counts, axis].
)ROC",
paddle::framework::compatible::OpVersionDesc()
.NewOutput("Counts", "The counts for each unique element.")
.NewAttr("return_inverse",
"If True, also return the indices for where elements"
" in the original input ended up in the returned unique "
"tensor.",
false)
.NewAttr("return_counts",
"If True, also return the counts for each unique element.",
false)
.NewAttr("axis",
"The axis to apply unique. If None, the input will be "
"flattened.",
std::vector<int>{}));
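The Index and Counts outputs registered above are dispensable, matching the return_inverse/return_counts flags. A usage example with both flags enabled (standard paddle API):

    import paddle

    x = paddle.to_tensor([1, 1, 2, 2, 3, 1, 1, 2])
    out, inverse, counts = paddle.unique_consecutive(
        x, return_inverse=True, return_counts=True)
    # out     = [1, 2, 3, 1, 2]
    # inverse = [0, 0, 1, 1, 2, 3, 3, 4]
    # counts  = [2, 2, 1, 2, 1]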
@@ -1300,6 +1300,19 @@
inplace : (out_grad -> x_grad)
backward: squeeze_double_grad
- backward_op : stack_grad
forward : stack (Tensor[] x, int axis) -> Tensor(out)
args : (Tensor[] x, Tensor out_grad, int axis)
output : Tensor[](x_grad){x.size()}
infer_meta :
func : StackGradInferMeta
param: [out_grad, axis]
kernel :
func : stack_grad
param : [out_grad, axis]
data_type : out_grad
no_need_buffer : x
- backward_op : svd_grad
forward : svd (Tensor x, bool full_matrices = false) -> Tensor(u), Tensor(s), Tensor(vh)
args : (Tensor x, Tensor u, Tensor vh, Tensor s, Tensor u_grad, Tensor vh_grad, Tensor s_grad, bool full_matrices)
@@ -1424,6 +1437,12 @@
kernel :
func : trunc_grad
- backward_op : unbind_grad
forward : unbind (Tensor input, int axis) -> Tensor[](out)
args : (Tensor[] out_grad, int axis)
output : Tensor(input_grad)
invoke : stack(out_grad, axis)
- backward_op : unfold_grad
forward : unfold (Tensor x, int[] kernel_sizes, int[] strides, int[] paddings, int[] dilations) -> Tensor(out)
args : (Tensor x, Tensor out_grad, int[] kernel_sizes, int[] strides, int[] paddings, int[] dilations)
......
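Note the unbind_grad entry above: rather than a dedicated kernel, it is expressed as invoke : stack(out_grad, axis), since stacking the per-slice gradients along the same axis reconstructs the input gradient. Conceptually:

    import paddle

    x = paddle.to_tensor([[1.0, 2.0], [3.0, 4.0]])
    slices = paddle.unbind(x, axis=0)       # forward: split along axis 0
    x_grad = paddle.stack(slices, axis=0)   # backward: stack the gradients back;
                                            # same shape as x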
@@ -1291,18 +1291,6 @@
kernel :
func : squared_l2_norm_grad
- backward_op : stack_grad
forward : stack (Tensor[] x, int axis) -> Tensor(out)
args : (Tensor[] x, Tensor out_grad, int axis)
output : Tensor[](x_grad){x.size()}
infer_meta :
func : StackGradInferMeta
param: [out_grad, axis]
kernel :
func : stack_grad
param : [out_grad, axis]
no_need_buffer : x
- backward_op : strided_slice_grad
forward : strided_slice (Tensor x, int[] axes, IntArray starts, IntArray ends, IntArray strides) -> Tensor(out)
args : (Tensor x, Tensor out_grad, int[] axes, IntArray starts, IntArray ends, IntArray strides)
@@ -1468,12 +1456,6 @@
kernel :
func : triu_grad
- backward_op : unbind_grad
forward : unbind (Tensor input, int axis) -> Tensor[](out)
args : (Tensor[] out_grad, int axis)
output : Tensor(input_grad)
invoke : stack(out_grad, axis)
- backward_op : uniform_inplace_grad
forward : uniform_inplace(Tensor x, float min, float max, int seed, int diag_num, int diag_step, float diag_val) -> Tensor(out)
args : (Tensor out_grad, float min, float max, int seed, int diag_num, int diag_step, float diag_val)
......
@@ -1701,15 +1701,6 @@
func : squared_l2_norm
backward : squared_l2_norm_grad
- op : stack
args : (Tensor[] x, int axis)
output : Tensor
infer_meta :
func : StackInferMeta
kernel :
func : stack
backward : stack_grad
- op : strided_slice
args : (Tensor x, int[] axes, IntArray starts, IntArray ends, IntArray strides)
output : Tensor
@@ -1770,8 +1761,8 @@
backward : temporal_shift_grad
- op : tile
args : (Tensor x, IntArray repeat_times)
output : Tensor
args : (Tensor x, IntArray repeat_times = {})
output : Tensor(out)
infer_meta :
func : TileInferMeta
kernel :
@@ -1863,15 +1854,6 @@
backend : place
data_type : dtype
- op : unbind
args : (Tensor input, int axis)
output : Tensor[] {axis<0 ? input.dims()[input.dims().size()+axis]:input.dims()[axis]}
infer_meta :
func : UnbindInferMeta
kernel :
func : unbind
backward : unbind_grad
- op : uniform
args : (IntArray shape, DataType dtype, Scalar min, Scalar max, int seed, Place place={})
output : Tensor(out)
@@ -1905,15 +1887,6 @@
func : unique
data_type : x
- op : unique_consecutive
args : (Tensor x, bool return_inverse, bool return_counts, int[] axis, int dtype)
output : Tensor(out), Tensor(index), Tensor(counts)
infer_meta :
func : UniqueConsecutiveInferMeta
kernel :
func : unique_consecutive
data_type : x
- op : unpool
args: (Tensor x, Tensor indices, int[] ksize, int[] strides, int[] padding, IntArray output_size, str data_format)
output: Tensor(out)
......
@@ -854,13 +854,11 @@
out : Out
- op : logsigmoid
backward : logsigmoid_grad
inputs :
x : X
outputs :
out : Out
- op : logsigmoid
backward : logsigmoid_grad
extra :
attrs : [bool use_mkldnn = false, bool use_cudnn = false]
@@ -1363,13 +1361,13 @@
- op : stack
backward : stack_grad
inputs :
x : X
outputs :
out : Y
extra :
attrs : [bool use_mkldnn = false]
- op : stack
backward : stack_grad
extra :
attrs : [bool use_mkldnn = false]
drop_empty_grad : [x_grad]
- op : subtract (elementwise_sub)
backward : subtract_grad (elementwise_sub_grad)
@@ -1436,6 +1434,18 @@
outputs :
out : Out
- op : tile
backward : tile_grad, tile_double_grad
inputs :
x : X
outputs :
out : Out
int_array:
repeat_times :
data_type : int
tensor_name : RepeatTimes
tensors_name : repeat_times_tensor
- op : topk (top_k_v2)
backward : topk_grad (top_k_v2_grad)
inputs :
@@ -1470,12 +1480,24 @@
outputs :
out : Out
- op : unbind
inputs :
input : X
outputs :
out : Out
- op : unfold
inputs :
x : X
outputs :
out : Y
- op : unique_consecutive
inputs :
x : X
outputs :
{out : Out, index : Index, counts : Counts}
- op : unsqueeze (unsqueeze2)
backward : unsqueeze_grad (unsqueeze2_grad), unsqueeze_double_grad(unsqueeze2_double_grad)
inputs :
......
@@ -58,8 +58,7 @@
version :
- checkpoint : Compatible upgrade of pixel_shuffle, add a new attribute [data_format]
action :
- add_attr :
name : data_format
- add_attr : data_format
comment : Specify the data format of the input data
default : "true"
@@ -84,11 +83,26 @@
- add_attr : axis1
comment : The added attribute 'axis1' is not yet registered.
default : std::vector<float>{0.0f}
- add_attr :
name : axis2
- add_attr : axis2
comment : The added attribute 'axis2' is not yet registered.
default : std::vector<float>{1.0f}
- delete_attr : dim1
comment : The attribute 'dim1' is not recommended according to the 2.0 specification.
- delete_attr : dim2
comment : The attribute 'dim2' is not recommended according to the 2.0 specification.
- op : unique_consecutive
version :
- checkpoint : Upgrade unique_consecutive, add 2 outputs [Indices, Counts] and 3 attributes [return_inverse, return_counts, axis].
action :
- add_output : Counts
comment : The counts for each unique element.
- add_attr : return_inverse
comment : If True, also return the indices for where elements in the original input ended up in the returned unique tensor.
default : "false"
- add_attr : return_counts
comment : If True, also return the counts for each unique element.
default : "false"
- add_attr : axis
comment : The axis to apply unique. If None, the input will be flattened.
default : std::vector<int>{}
@@ -1157,6 +1157,15 @@
intermediate : xshape
backward : squeeze_grad
- op : stack
args : (Tensor[] x, int axis = 0)
output : Tensor (out)
infer_meta :
func : StackInferMeta
kernel :
func : stack
backward : stack_grad
- op : svd
args : (Tensor x, bool full_matrices = false)
output : Tensor(u), Tensor(s), Tensor(vh)
@@ -1243,6 +1252,15 @@
func : trunc
backward : trunc_grad
- op : unbind
args : (Tensor input, int axis = 0)
output : Tensor[] {axis<0 ? input.dims()[input.dims().size()+axis]:input.dims()[axis]}
infer_meta :
func : UnbindInferMeta
kernel :
func : unbind
backward : unbind_grad
- op : unfold
args : (Tensor x, int[] kernel_sizes, int[] strides, int[] paddings, int[] dilations)
output : Tensor(out)
@@ -1252,6 +1270,16 @@
func : unfold
backward : unfold_grad
- op : unique_consecutive
args : (Tensor x, bool return_inverse = false, bool return_counts = false, int[] axis = {}, int dtype = 5)
output : Tensor(out), Tensor(index), Tensor(counts)
infer_meta :
func : UniqueConsecutiveInferMeta
kernel :
func : unique_consecutive
data_type : x
optional : index, counts
- op : unsqueeze
args : (Tensor x, IntArray axis = {})
output : Tensor(out), Tensor(xshape)
......
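The new ops.yaml entry marks index and counts as optional, which is exactly what the parser and template changes at the top of this diff enable: with the default flags the kernel only needs to produce out.

    import paddle

    x = paddle.to_tensor([1, 1, 2])
    out = paddle.unique_consecutive(x)  # Index/Counts are dispensable;
                                        # only `out` is returned
    # out = [1, 2]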
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature StackGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("stack_grad", {"Y@GRAD"}, {"axis"}, {"X@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(stack_grad, phi::StackGradOpArgumentMapping);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature UniqueConsecutiveOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("unique_consecutive",
{"X"},
{"return_inverse", "return_counts", "axis", "dtype"},
{"Out", "Index", "Counts"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(unique_consecutive,
phi::UniqueConsecutiveOpArgumentMapping);