未验证 提交 2ad66a42 编写于 作者: R RedContritio 提交者: GitHub

support auto generate static for empty (#52524)

上级 535915aa
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/nullary.h"
namespace paddle {
namespace operators {
class EmptyOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("ShapeTensor",
"(Tensor<int>), optional). The shape of the output."
"It has a higher priority than Attr(shape).")
.AsDispensable();
AddInput("ShapeTensorList",
"(vector<Tensor<int>>, optional). The shape of the output. "
"It has a higher priority than Attr(shape)."
"The shape of the element in vector must be [1].")
.AsDuplicable()
.AsDispensable();
AddAttr<std::vector<int64_t>>("shape",
"(vector<int64_t>) The shape of the output")
.SetDefault({});
AddAttr<int>("dtype", "The data type of output tensor, Default is float")
.SetDefault(framework::proto::VarType::FP32);
AddOutput("Out", "(Tensor) The output tensor.");
AddComment(R"DOC(empty operator
Returns a tensor filled with uninitialized data. The shape of the tensor is
defined by the variable argument shape.
The type of the tensor is specify by `dtype`.
)DOC");
}
};
class EmptyOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
phi::KernelKey GetKernelTypeForVar(
const std::string& var_name,
const phi::DenseTensor& tensor,
const phi::KernelKey& expected_kernel_type) const override {
if (var_name == "ShapeTensor" || var_name == "ShapeTensorList") {
return phi::KernelKey(phi::Backend::ALL_BACKEND,
expected_kernel_type.layout(),
expected_kernel_type.dtype());
} else {
return phi::KernelKey(
tensor.place(), tensor.layout(), expected_kernel_type.dtype());
}
}
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& context) const override {
return phi::KernelKey(
framework::proto::VarType::Type(context.Attr<int>("dtype")),
context.GetPlace());
}
};
class EmptyOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext* context) const override {
auto data_type = static_cast<framework::proto::VarType::Type>(
PADDLE_GET_CONST(int, context->GetAttr("dtype")));
context->SetOutputDataType("Out", data_type);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
DECLARE_INFER_SHAPE_FUNCTOR(empty,
EmptyInferShapeFunctor,
PD_INFER_META(phi::CreateInferMeta));
REGISTER_OP_WITHOUT_GRADIENT(empty,
ops::EmptyOp,
ops::EmptyOpMaker,
ops::EmptyOpVarTypeInference,
EmptyInferShapeFunctor);
......@@ -82,7 +82,6 @@ register_unity_group(
diag_v2_op.cc
dot_op.cc
edit_distance_op.cc
empty_op.cc
enqueue_op.cc
erf_op.cc
py_func_op.cc
......
......@@ -602,6 +602,15 @@
int trainer_id = 0, int slot = 0, 'int64_t[] height_sections = {}', 'str[] epmap = {}',
'str[] table_names = {}']
- op : empty
outputs :
out : Out
int_array:
shape :
data_type : int64_t
tensor_name : ShapeTensor
tensors_name : ShapeTensorList
- op : equal
inputs :
{x : X, y : Y}
......
......@@ -70,6 +70,17 @@
data_type : weight
backward : embedding_grad
- op : empty
args : (IntArray shape = {}, DataType dtype = DataType::FLOAT32)
output: Tensor(out)
infer_meta :
func : CreateInferMeta
param : [shape, dtype]
kernel :
func : empty
param : [shape, dtype]
data_type : dtype
- op : equal
args : (Tensor x, Tensor y, int axis = -1, bool force_cpu=false)
output : Tensor(out)
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature EmptyOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasInput("ShapeTensor")) {
return KernelSignature("empty", {}, {"ShapeTensor", "dtype"}, {"Out"});
} else if (ctx.InputSize("ShapeTensorList") > 0) {
return KernelSignature("empty", {}, {"ShapeTensorList", "dtype"}, {"Out"});
} else {
return KernelSignature("empty", {}, {"shape", "dtype"}, {"Out"});
}
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(empty, phi::EmptyOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册