未验证 提交 ec877d1f 编写于 作者: R RedContritio 提交者: GitHub

support auto generate for dirichlet (#51601)

* support auto generate for dirichlet

* use uppercase in args

* use op_compat for name mapping
上级 d04c9cda
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
class DirichletOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Alpha", "(Tensor), The dirichlet Alpha parameter");
AddOutput("Out", "(Tensor), The output tensor of sample");
AddComment(R"DOC(Sample random data from dirichlet distribution.)DOC");
}
};
class DirichletOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
};
} // namespace operators
} // namespace paddle
DECLARE_INFER_SHAPE_FUNCTOR(dirichlet,
DirichletInferShapeFunctor,
PD_INFER_META(phi::DirichletInferMeta));
REGISTER_OP_WITHOUT_GRADIENT(dirichlet,
paddle::operators::DirichletOp,
paddle::operators::DirichletOpMaker,
DirichletInferShapeFunctor);
......@@ -448,14 +448,6 @@
func : depthwise_conv2d_transpose
backward : depthwise_conv2d_transpose_grad
- op : dirichlet
args: (Tensor alpha)
output: Tensor(out)
infer_meta:
func: DirichletInferMeta
kernel:
func: dirichlet
- op : distribute_fpn_proposals
args : (Tensor fpn_rois, Tensor rois_num, int min_level, int max_level, int refer_level, int refer_scale, bool pixel_offset)
output : Tensor[](multi_fpn_rois){max_level - min_level + 1}, Tensor[](multi_level_rois_num){max_level - min_level + 1}, Tensor(restore_index)
......
......@@ -458,6 +458,12 @@
outputs :
out : Out
- op : dirichlet
inputs :
alpha : Alpha
outputs :
out : Out
- op : dist
inputs :
{x : X, y : Y}
......
......@@ -384,6 +384,14 @@
func : digamma
backward : digamma_grad
- op : dirichlet
args: (Tensor alpha)
output: Tensor(out)
infer_meta:
func: DirichletInferMeta
kernel:
func: dirichlet
- op : dist
args : (Tensor x, Tensor y, float p = 2.0)
output : Tensor
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册