Unverified commit 4d7ddb49, authored by Zhenghai Zhang, committed by GitHub

static graph autogen code for expand (#54628)

* static graph autogen code for expand

* fix bug

* fix bug

* fix bug

* fix bug

* fix bug

* fix bug
Parent 1c426e0b
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
......@@ -128,23 +125,17 @@ Expand operator tiles the input by the given number of times. You should set the
number of times for each dimension by providing the attribute 'expand_times'. The rank of X
should be in [1, 6]. Please note that the size of 'expand_times' must be the same
as X's rank. The following is a usage example:
Input(X) is a 3-D tensor with shape [2, 3, 1]:
[
[[1], [2], [3]],
[[4], [5], [6]]
]
Attr(expand_times): [1, 2, 2]
Output(Out) is a 3-D tensor with shape [2, 6, 2]:
[
[[1, 1], [2, 2], [3, 3], [1, 1], [2, 2], [3, 3]],
[[4, 4], [5, 5], [6, 6], [4, 4], [5, 5], [6, 6]]
]
)DOC");
}
};
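For readers checking the DOC example above: the legacy expand semantics are plain per-dimension tiling, so numpy.tile should reproduce the worked example. A minimal illustrative sketch (not part of this PR's code):

import numpy as np

# Input(X) from the DOC example: shape [2, 3, 1].
x = np.array([[[1], [2], [3]],
              [[4], [5], [6]]])
# Attr(expand_times) = [1, 2, 2]: tile each dimension the given number of times.
out = np.tile(x, (1, 2, 2))
print(out.shape)       # (2, 6, 2)
print(out[0].tolist())  # [[1, 1], [2, 2], [3, 3], [1, 1], [2, 2], [3, 3]]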
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/expand_v2_op.h"
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
#define MAX_RANK_SUPPORTED 6
namespace paddle {
namespace operators {
class ExpandV2Op : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
return phi::KernelKey(input_data_type, ctx.GetPlace());
}
phi::KernelKey GetKernelTypeForVar(
const std::string& var_name,
const phi::DenseTensor& tensor,
const phi::KernelKey& expected_kernel_type) const override {
if (var_name == "expand_shapes_tensor" || var_name == "Shape") {
return phi::KernelKey(phi::Backend::ALL_BACKEND,
expected_kernel_type.layout(),
expected_kernel_type.dtype());
}
return phi::KernelKey(
tensor.place(), tensor.layout(), expected_kernel_type.dtype());
}
};
class ExpandV2OpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor, default Tensor<float>). A tensor with rank in [1, 6]."
"X is the input to be expanded.");
AddInput("Shape",
"(Tensor<int>), optional). If provided, expand according to "
"this given Shape. It has a higher priority than "
"expand_shapes_tensor and the shape attribute.")
.AsDispensable();
AddInput("expand_shapes_tensor",
"(Tensor Tensor<int>), epxanded shape for X."
"It has a higher priority than shape attribute, but a lower "
"priority than the input Shape")
.AsDuplicable()
.AsDispensable();
AddOutput("Out",
"(Tensor, default Tensor<float>). A tensor with rank in [1, 6]."
"The rank of Output(Out) have the same with Input(X). "
"After expanding, size of each dimension of Output(Out) is equal "
"to size of the corresponding dimension of Input(X) multiplying "
"the corresponding value given by Attr(expand_times).");
AddAttr<std::vector<int>>("shape", "The expanded shape for each dimension.")
.SetDefault({});
AddComment(R"DOC(
Expand the input to the given shape. The rank of X
should be in [1, 6] and the size of 'shape' must also be in [1, 6]. A
dimension of X can only be expanded when its size is 1, and a -1 in
'shape' keeps the corresponding input dimension. The following is a
usage example:
Input(X) is a 3-D tensor with shape [2, 3, 1]:
[
[[1], [2], [3]],
[[4], [5], [6]]
]
Attr(shape): [2, 3, 2]
Output(Out) is a 3-D tensor with shape [2, 3, 2]:
[
[[1, 1], [2, 2], [3, 3]],
[[4, 4], [5, 5], [6, 6]]
]
)DOC");
}
};
class ExpandV2GradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ExpandV2Grad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")),
"Input",
framework::GradVarName("Out"),
"ExpandV2Grad");
auto x_dims = ctx->GetInputDim("X");
std::vector<int> expand_shape = ctx->Attrs().Get<std::vector<int>>("shape");
if (expand_shape.size() == 0) {
expand_shape = std::vector<int>(x_dims.size(), -1);
}
auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
auto x_dim_vec = phi::vectorize<int>(x_dims);
auto diff = expand_shape.size() - x_dim_vec.size();
x_dim_vec.insert(x_dim_vec.begin(), diff, -1);
for (size_t i = 0; i < expand_shape.size(); ++i) {
if (expand_shape[i] < 0 || x_dim_vec[i] == -1) {
continue;
} else {
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(
expand_shape[i],
out_dims[i],
platform::errors::InvalidArgument(
"The size (%d) of the dimension %d of Input(Out@GRAD) should "
"be equal to the crroresponding dimension size of shape(%d).",
out_dims[i],
i,
expand_shape[i]));
}
}
}
auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
}
}
protected:
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
return phi::KernelKey(input_data_type, ctx.GetPlace());
}
phi::KernelKey GetKernelTypeForVar(
const std::string& var_name,
const phi::DenseTensor& tensor,
const phi::KernelKey& expected_kernel_type) const override {
if (var_name == "expand_shapes_tensor" || var_name == "Shape") {
return phi::KernelKey(phi::Backend::ALL_BACKEND,
expected_kernel_type.layout(),
expected_kernel_type.dtype());
}
return phi::KernelKey(
tensor.place(), tensor.layout(), expected_kernel_type.dtype());
}
};
template <typename T>
class ExpandV2GradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("expand_v2_grad");
op->SetInput("X", this->Input("X"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetInput("expand_shapes_tensor", this->Input("expand_shapes_tensor"));
op->SetInput("Shape", this->Input("Shape"));
op->SetAttrMap(this->Attrs());
}
};
class ExpandV2CompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
public:
void Apply() override {
auto x = this->GetSingleForwardInput("X");
auto out_grad = this->GetSingleOutputGrad("Out");
auto x_grad = this->GetSingleInputGrad("X");
auto x_grad_p = this->GetOutputPtr(&x_grad);
auto x_grad_name = this->GetOutputName(x_grad);
auto shape = this->Attr<std::vector<int>>("shape");
prim::expand_grad<prim::DescTensor>(
x, out_grad, paddle::experimental::IntArray(shape), x_grad_p);
VLOG(6) << "Running expand_v2 composite func";
this->RecoverOutputName(x_grad, x_grad_name);
}
};
template <typename T>
class ExpandV2DoubleGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("expand_v2");
op->SetInput("X", this->OutputGrad(framework::GradVarName("X")));
op->SetOutput("Out", this->InputGrad(framework::GradVarName("Out")));
if (this->HasInput("expand_shapes_tensor")) {
op->SetInput("expand_shapes_tensor", this->Input("expand_shapes_tensor"));
}
if (this->HasInput("Shape")) {
op->SetInput("Shape", this->Input("Shape"));
}
op->SetAttrMap(this->Attrs());
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(ExpandV2GradNoNeedBufVarsInferer, "X");
} // namespace operators
} // namespace paddle
DECLARE_INFER_SHAPE_FUNCTOR(expand_v2,
ExpandInferShapeFunctor,
PD_INFER_META(phi::ExpandInferMeta));
namespace ops = paddle::operators;
REGISTER_OPERATOR(expand_v2,
ops::ExpandV2Op,
ops::ExpandV2OpMaker,
ops::ExpandV2CompositeGradOpMaker,
ops::ExpandV2GradOpMaker<paddle::framework::OpDesc>,
ops::ExpandV2GradOpMaker<paddle::imperative::OpBase>,
ExpandInferShapeFunctor);
REGISTER_OPERATOR(expand_v2_grad,
ops::ExpandV2GradOp,
ops::ExpandV2DoubleGradOpMaker<paddle::framework::OpDesc>,
ops::ExpandV2DoubleGradOpMaker<paddle::imperative::OpBase>,
ops::ExpandV2GradNoNeedBufVarsInferer);
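Note the semantic difference from the legacy op: expand_v2 follows broadcast rules rather than tiling, so a dimension can only grow when its input size is 1, and -1 in 'shape' preserves the input dimension. A minimal numpy sketch of this rule (illustrative only; expand_v2_ref is a hypothetical helper, not part of the PR):

import numpy as np

def expand_v2_ref(x, shape):
    # Assumes len(shape) == x.ndim for simplicity; -1 keeps the input dim.
    target = [x.shape[i] if s == -1 else s for i, s in enumerate(shape)]
    # Broadcast semantics: numpy raises if a non-1 dimension would change.
    return np.broadcast_to(x, target)

x = np.ones((2, 3, 1))
print(expand_v2_ref(x, [2, 3, 2]).shape)   # (2, 3, 2)
print(expand_v2_ref(x, [-1, 3, 5]).shape)  # (2, 3, 5)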
......@@ -682,6 +682,26 @@
func : expand_as_grad
no_need_buffer : x
- backward_op : expand_double_grad
forward : expand_grad (Tensor x, Tensor grad_out, IntArray shape) -> Tensor(grad_x)
args : (Tensor grad_x_grad, IntArray shape)
output : Tensor(grad_out_grad)
invoke : expand(grad_x_grad, shape)
- backward_op : expand_grad
forward : expand (Tensor x, IntArray shape) -> Tensor(out)
args : (Tensor x, Tensor out_grad, IntArray shape)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : expand_grad
data_type : out_grad
no_need_buffer : x
backward : expand_double_grad
composite : expand_grad(x, out_grad, shape, x_grad)
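For intuition on the composite rule above: the gradient of a broadcast-style expand is a reduction of out_grad back onto x's shape, summing over prepended axes and over axes expanded from size 1. A rough numpy sketch, assuming the composite expand_grad follows this standard pattern (expand_grad_ref is hypothetical):

import numpy as np

def expand_grad_ref(x, out_grad):
    # Sum over leading axes that expand prepended to x's rank...
    diff = out_grad.ndim - x.ndim
    g = out_grad.sum(axis=tuple(range(diff))) if diff else out_grad
    # ...then sum (keeping dims) over axes broadcast from size 1.
    axes = tuple(i for i, n in enumerate(x.shape) if n == 1 and g.shape[i] > 1)
    return g.sum(axis=axes, keepdims=True) if axes else g

x = np.ones((2, 3, 1))
out_grad = np.ones((4, 2, 3, 5))  # as if x were expanded to [4, 2, 3, 5]
print(expand_grad_ref(x, out_grad).shape)  # (2, 3, 1)

The expand_double_grad entry above is consistent with this picture: it simply re-applies expand to grad_x_grad (invoke : expand(grad_x_grad, shape)).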
- backward_op : expm1_grad
forward : expm1 (Tensor x) -> Tensor(out)
args : (Tensor out, Tensor out_grad)
......
......@@ -255,25 +255,6 @@
invoke : embedding_grad_impl(x, weight, out_grad, padding_idx, sparse, weight_grad)
no_need_buffer : weight
- backward_op : expand_double_grad
forward : expand_grad (Tensor x, Tensor grad_out, IntArray shape) -> Tensor(grad_x)
args : (Tensor grad_x_grad, IntArray shape)
output : Tensor(grad_out_grad)
invoke : expand(grad_x_grad, shape)
- backward_op : expand_grad
forward : expand (Tensor x, IntArray shape) -> Tensor(out)
args : (Tensor x, Tensor out_grad, IntArray shape)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : expand_grad
no_need_buffer : x
backward : expand_double_grad
composite: expand_grad(x, out_grad, shape, x_grad)
- backward_op : exponential__grad
forward : exponential_ (Tensor x, float lam) -> Tensor(out)
args : (Tensor out_grad)
......
......@@ -302,15 +302,6 @@
kernel :
func : equal
- op : expand
args : (Tensor x, IntArray shape)
output : Tensor
infer_meta :
func : ExpandInferMeta
kernel :
func : expand
backward : expand_grad
- op : exponential_
args : (Tensor x, float lam)
output : Tensor(out)
......
......@@ -914,7 +914,7 @@
attrs : [bool use_mkldnn = false, bool use_cudnn = false]
- op : expand (expand_v2)
backward : expand_grad (expand_v2_grad)
backward : expand_grad (expand_v2_grad), expand_double_grad(expand_v2_double_grad)
inputs :
x : X
attrs :
......@@ -928,6 +928,7 @@
tensors_name : expand_shapes_tensor
extra :
attrs : [bool use_mkldnn = false, str mkldnn_data_type = "float32"]
manual_signature : [expand, expand_grad]
- op : expand_as (expand_as_v2)
backward : expand_as_grad (expand_as_v2_grad)
......
......@@ -762,6 +762,16 @@
inplace : (x -> out)
backward : exp_grad
- op : expand
args : (Tensor x, IntArray shape = {})
output : Tensor(out)
infer_meta :
func : ExpandInferMeta
kernel :
func : expand
data_type : x
backward : expand_grad
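The ops.yaml entry above backs the public paddle.expand API; a quick usage sketch against a stock PaddlePaddle install:

import paddle

x = paddle.to_tensor([[1.0], [2.0], [3.0]])  # shape [3, 1]
out = paddle.expand(x, shape=[2, 3, 2])      # prepend a dim, broadcast the last
print(out.shape)  # [2, 3, 2]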
- op : expand_as
args : (Tensor x, Tensor y, int[] target_shape = {})
output : Tensor(out)
......
......@@ -44,7 +44,7 @@ class TestIntarrayInput(unittest.TestCase):
tensor_data = paddle.to_tensor(np_data)
shape = paddle.to_tensor([2, 3, 4])
net = paddle.jit.to_static(fn)
with self.assertRaises(ValueError):
with self.assertRaises(NotImplementedError):
_ = net(tensor_data, shape).numpy()
core._set_prim_all_enabled(False)
......