未验证 提交 8947488c 编写于 作者: W Wang Xin 提交者: GitHub

static graph autogen code support for full_like op (#54698)

* static graph autogen code support for full_like op

* fix

* fix bug
上级 93f7a02a
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
class FillAnyLikeOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "fill_any_like");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "fill_any_like");
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ "Out");
}
protected:
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
phi::KernelKey kt = OperatorWithKernel::GetExpectedKernelType(ctx);
const auto &data_type = ctx.Attr<int>("dtype");
if (data_type >= 0) {
kt.set_dtype(phi::TransToPhiDataType(
static_cast<framework::proto::VarType::Type>(data_type)));
}
return kt;
}
phi::KernelKey GetKernelTypeForVar(
const std::string &var_name,
const phi::DenseTensor &tensor,
const phi::KernelKey &expected_kernel_type) const override {
return phi::KernelKey(phi::Backend::ALL_BACKEND,
tensor.layout(),
expected_kernel_type.dtype());
}
};
class FillAnyLikeOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "The input of fill-zeros-like op.");
AddOutput("Out", "The variable will be filled up with specified value.");
AddAttr<float>("value", "The filled value").SetDefault(0.0);
AddAttr<int>("dtype",
"Output tensor data type. default value is -1,"
"according to the input dtype.")
.SetDefault(-1);
AddComment(R"DOC(
FillAnyLike Operator.
Fill up a variable with Attr(value).
The output will have the same shape and dtype as the input.
)DOC");
}
};
class FillAnyLikeVarTypeInference : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto var_data_type = static_cast<framework::proto::VarType::Type>(
PADDLE_GET_CONST(int, ctx->GetAttr("dtype")));
if (var_data_type < 0) {
ctx->SetOutputDataType("Out", ctx->GetInputDataType("X"));
} else {
ctx->SetOutputDataType("Out", var_data_type);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(
fill_any_like,
ops::FillAnyLikeOp,
ops::FillAnyLikeOpMaker,
::paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
::paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
ops::FillAnyLikeVarTypeInference)
......@@ -335,8 +335,8 @@ phi::KernelKey GetExpectedKernelType(
{% endif %}
{% elif kernel["data_type"]["candidates"] | length == 2 %}
{% set data_type_args = kernel["data_type"]["candidates"] %}
auto data_type = framework::proto::VarType::Type(ctx.Attr<int>("{{data_type_args[0]}}");
if (data_type == static_cast<proto::VarType::Type>(-1)) {
auto data_type = framework::proto::VarType::Type(ctx.Attr<int>("{{data_type_args[0]}}"));
if (data_type == static_cast<framework::proto::VarType::Type>(-1)) {
data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, {{data_type_args[1] | to_opmaker_name}});
}
{% endif %}
......
......@@ -1095,8 +1095,10 @@
x : X
outputs :
out : Out
attrs :
{value: value, dtype: dtype}
scalar :
value :
data_type : float
support_tensor : true
- op : fused_conv2d
extra :
......
......@@ -203,6 +203,16 @@
param : [x, axis, keepdim, reduce_all]
backward : frobenius_norm_grad
- op : full_like
args : (Tensor x, Scalar value = 0.0, DataType dtype = DataType::UNDEFINED)
output: Tensor(out)
infer_meta :
func : FillAnyLikeInferMeta
kernel :
func : full_like
param : [x, value, dtype]
data_type : dtype > x
- op : gaussian
args : (IntArray shape = {}, float mean = .0f, float std = 1.0f, int seed = 0, DataType dtype = DataType::FLOAT32)
output: Tensor(out)
......
......@@ -1111,6 +1111,15 @@ void ExpandInferMeta(const MetaTensor& x,
}
}
void FillAnyLikeInferMeta(const MetaTensor& x,
const Scalar& value,
DataType dtype,
MetaTensor* out) {
out->set_dims(x.dims());
out->set_dtype(dtype == DataType::UNDEFINED ? x.dtype() : dtype);
out->share_lod(x);
}
void FillDiagonalInferMeta(
const MetaTensor& x, float value, int offset, bool wrap, MetaTensor* out) {
PADDLE_ENFORCE_NE(
......
......@@ -178,6 +178,11 @@ void ExpandInferMeta(const MetaTensor& x,
const IntArray& shape,
MetaTensor* out);
void FillAnyLikeInferMeta(const MetaTensor& x,
const Scalar& value,
DataType dtype,
MetaTensor* out);
void FillDiagonalInferMeta(
const MetaTensor& x, float value, int offset, bool wrap, MetaTensor* out);
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature FillAnyLikeOpArgumentMapping(
const ArgumentMappingContext& ctx UNUSED) {
return KernelSignature("full_like", {"X"}, {"value", "dtype"}, {"Out"});
}
} // namespace phi
PD_REGISTER_BASE_KERNEL_NAME(fill_any_like, full_like);
PD_REGISTER_ARG_MAPPING_FN(fill_any_like, phi::FillAnyLikeOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册