未验证 提交 be04f258 编写于 作者: L lzydev 提交者: GitHub

【fix bug】Fix bug in parse args with '{,}' (#52968)

* fix bug in parse args

* fix bug

* recover legacy_*.yaml

* change 'Out' to Output
上级 b9830634
......@@ -292,7 +292,9 @@ def ParseYamlArgs(string):
# attrs_list = [ [arg_name, arg_type, default_value, orig_position], ...]
attrs_list = []
args = [x.strip() for x in string.strip().split(",")]
patten = re.compile(r',(?![^{]*\})') # support int[] a={1,3}
args = re.split(patten, string.strip())
args = [x.strip() for x in args]
atype = r'((const )?\S+) '
aname = r'(.*)'
pattern = f'{atype}{aname}'
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/multiary.h"
namespace paddle {
namespace operators {
class DeformableConvOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Input",
"(Tensor) The input of deformable conv op. "
"The shape of input is "
"[N, channel_in, H, W]");
AddInput("Offset",
"(Tensor) The input offset. "
"The shape of the offset is "
"[N, deformable_groups * kernel_w * kernel_h * 2, H, W");
AddInput("Mask",
"(Tensor) The input mask. "
"The shape of the mask is "
"[N, deformable_groups * kernel_w * kernel_h, H, W].");
AddInput("Filter",
"(Tensor) The Input Filter "
"The shape of the wight is "
"[num_filters, channel_in, kernel_h, kernel_w.");
AddOutput("Output",
"(Tensor) The output. "
"The shape of the output tensor is "
"[N, num_filters, out_height, out_width]].");
AddAttr<std::vector<int>>("strides",
"(vector<int> default:{1, 1}), the "
"strides(h_stride, w_stride) of "
"convolution operator.")
.SetDefault({1, 1});
AddAttr<std::vector<int>>("paddings",
"(vector<int> default:{0,0}), the "
"paddings(h_pad, w_pad) of "
"convolution operator. ")
.SetDefault({0, 0});
AddAttr<std::vector<int>>("dilations",
"(vector<int> default:{1, 1}), the "
"dilations(h_dilation, w_dilation) of "
"convolution operator.")
.SetDefault({1, 1});
AddAttr<int>(
"groups",
"(int default:1), the groups number of the convolution operator. "
"According to grouped convolution in Alex Krizhevsky's Deep CNN paper: "
"when group=2, the first half of the filters is only connected to the "
"first half of the input channels, while the second half of the "
"filters "
"is only connected to the second half of the input channels.")
.SetDefault(1);
AddAttr<int>("deformable_groups",
"(int default:1), the number of the deformable groups.")
.SetDefault(1);
AddAttr<int>("im2col_step",
"im2col maximum number of image per computation")
.SetDefault(64);
AddComment(R"DOC(
**Deformable Convolution Operator**
Compute 2-D deformable convolution on 4-D input.
Given input image x, output feature map y, the deformable convolution operation can be expressed as follow:
$$
y(p) = \\sum_{k=1}^{K}{w_k * x(p + p_k + \\Delta p_k) * \\Delta m_k}
$$
Where $$\\Delta p_k$$ and $$\Delta m_k$$ are the learnable offset and modulation scalar for the k-th location, respectively.
Refer to 'Deformable ConvNets v2: More Deformable, Better Results
'<https://arxiv.org/abs/1811.11168v2>
Example:
Input:
Input shape: $(N, C_{in}, H_{in}, W_{in})$
Filter shape: $(C_{out}, C_{in}, H_f, W_f)$
Offset shape: $(N, 2 * deformable_groups, * H_f * W_f, H_{out}, W_{out})$
Mask shape: $(N, deformable_groups * H_f * W_f, H_{out}, W_{out})$
Output:
Output shape: $(N, C_{out}, H_{out}, W_{out})$
where $H_{out}, W_{out}$ must be equal to $H_{in}, W_{in}$ respectively.
Where
$$
H_{out}= \frac{(H_{in} + 2 * paddings[0] - (dilations[0] * (H_f - 1) + 1))}{strides[0]}+ 1 \\
W_{out}= \frac{(W_{in} + 2 * paddings[1] - (dilations[1] * (W_f - 1) + 1))}{strides[1]}+ 1
$$
)DOC");
}
};
class DeformableConvOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
ctx.device_context().GetPlace());
}
};
template <typename T>
class DeformableConvGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("deformable_conv_grad");
op->SetInput("Input", this->Input("Input"));
op->SetInput("Filter", this->Input("Filter"));
op->SetInput("Offset", this->Input("Offset"));
op->SetInput("Mask", this->Input("Mask"));
op->SetInput(framework::GradVarName("Output"), this->OutputGrad("Output"));
op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
op->SetOutput(framework::GradVarName("Filter"), this->InputGrad("Filter"));
op->SetOutput(framework::GradVarName("Offset"), this->InputGrad("Offset"));
op->SetOutput(framework::GradVarName("Mask"), this->InputGrad("Mask"));
op->SetAttrMap(this->Attrs());
}
};
class DeformableConvGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
auto in_dims = ctx->GetInputDim("Input");
auto filter_dims = ctx->GetInputDim("Filter");
auto offset_dims = ctx->GetInputDim("Offset");
auto mask_dims = ctx->GetInputDim("Mask");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Output")),
"Input",
"Output@Grad",
"deformable_conv_grad");
if (ctx->HasOutput(framework::GradVarName("Input"))) {
ctx->SetOutputDim(framework::GradVarName("Input"), in_dims);
}
if (ctx->HasOutput(framework::GradVarName("Filter"))) {
ctx->SetOutputDim(framework::GradVarName("Filter"), filter_dims);
}
if (ctx->HasOutput(framework::GradVarName("Offset"))) {
ctx->SetOutputDim(framework::GradVarName("Offset"), offset_dims);
}
if (ctx->HasOutput(framework::GradVarName("Mask"))) {
ctx->SetOutputDim(framework::GradVarName("Mask"), mask_dims);
}
}
protected:
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
ctx.device_context().GetPlace());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(deformable_conv,
DeformableConvInferShapeFunctor,
PD_INFER_META(phi::DeformableConvInferMeta));
REGISTER_OPERATOR(deformable_conv,
ops::DeformableConvOp,
ops::DeformableConvOpMaker,
ops::DeformableConvGradOpMaker<paddle::framework::OpDesc>,
ops::DeformableConvGradOpMaker<paddle::imperative::OpBase>,
DeformableConvInferShapeFunctor);
REGISTER_OPERATOR(deformable_conv_grad, ops::DeformableConvGradOp);
......@@ -167,8 +167,13 @@ def parse_candidates(s: str) -> Dict[str, Any]:
def parse_plain_list(s: str, sep=",") -> List[str]:
items = [item.strip() for item in s.strip().split(sep)]
return items
if sep == ",":
patten = re.compile(r',(?![^{]*\})') # support "int[] a={1,2}"
items = re.split(patten, s.strip())
items = [x.strip() for x in items]
return items
else:
return [item.strip() for item in s.strip().split(sep)]
def parse_kernel(op_name: str, kernel_config: Dict[str, Any]) -> Dict[str, Any]:
......
......@@ -144,7 +144,9 @@ class BaseAPI:
')'
), f"Args declaration should start with '(' and end with ')', please check the args of {api_name} in yaml."
args_str = args_str[1:-1]
args_list = args_str.split(',')
patten = re.compile(r',(?![^{]*\})') # support int[] a={1,3}
args_list = re.split(patten, args_str.strip())
args_list = [x.strip() for x in args_list]
input_types_map = {
'Tensor': 'const Tensor&',
'Tensor[]': 'const std::vector<Tensor>&',
......
......@@ -518,6 +518,13 @@
outputs :
out : Out
- op : deformable_conv
backward : deformable_conv_grad
inputs :
{x : Input, offset : Offset, filter : Filter, mask : Mask}
outputs :
out : Output
- op : depthwise_conv2d
backward : depthwise_conv2d_grad
extra :
......
......@@ -7,6 +7,16 @@
composite: assign_grad(out_grad, x_grad)
invoke : assign(out_grad)
- backward_op : deformable_conv_grad
forward : deformable_conv (Tensor x, Tensor offset, Tensor filter, Tensor mask, int[] strides={1, 1}, int[] paddings={0, 0}, int[] dilations={1, 1}, int deformable_groups=1, int groups=1, int im2col_step=64) -> Tensor(out)
args : (Tensor x, Tensor offset, Tensor filter, Tensor mask, Tensor out_grad, int[] strides, int[] paddings, int[] dilations, int deformable_groups, int groups, int im2col_step)
output : Tensor(x_grad), Tensor(offset_grad), Tensor(filter_grad), Tensor(mask_grad)
infer_meta :
func : DeformableConvGradInferMeta
kernel :
func : deformable_conv_grad
data_type : x
- backward_op : embedding_grad
forward : embedding (Tensor x, Tensor weight, int64_t padding_idx=-1) -> Tensor(out)
args : (Tensor x, Tensor weight, Tensor out_grad, int64_t padding_idx=-1)
......
......@@ -78,6 +78,16 @@
func : decode_jpeg
param : [x, mode]
- op : deformable_conv
args : (Tensor x, Tensor offset, Tensor filter, Tensor mask, int[] strides={1, 1}, int[] paddings={0, 0}, int[] dilations={1, 1}, int deformable_groups=1, int groups=1, int im2col_step=64)
output : Tensor(out)
infer_meta :
func : DeformableConvInferMeta
kernel :
func : deformable_conv
data_type : x
backward : deformable_conv_grad
- op : embedding
args : (Tensor x, Tensor weight, int64_t padding_idx=-1)
output : Tensor
......
......@@ -16,7 +16,7 @@
namespace phi {
KernelSignature DeformableConvOpArgumentMapping(
KernelSignature DeformableConvOpV1ArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("deformable_conv",
{"Input", "Offset", "Filter", "Mask"},
......@@ -29,7 +29,7 @@ KernelSignature DeformableConvOpArgumentMapping(
{"Output"});
}
KernelSignature DeformableConvGradOpArgumentMapping(
KernelSignature DeformableConvGradOpV1ArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"deformable_conv_grad",
......@@ -48,12 +48,7 @@ KernelSignature DeformableConvGradOpArgumentMapping(
PD_REGISTER_BASE_KERNEL_NAME(deformable_conv_v1, deformable_conv);
PD_REGISTER_BASE_KERNEL_NAME(deformable_conv_v1_grad, deformable_conv_grad);
PD_REGISTER_ARG_MAPPING_FN(deformable_conv,
phi::DeformableConvOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(deformable_conv_grad,
phi::DeformableConvGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(deformable_conv_v1,
phi::DeformableConvOpArgumentMapping);
phi::DeformableConvOpV1ArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(deformable_conv_v1_grad,
phi::DeformableConvGradOpArgumentMapping);
phi::DeformableConvGradOpV1ArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册