未验证 提交 b6af72eb 编写于 作者: R RedContritio 提交者: GitHub

support auto generate static for one_hot_v2 (#52134)

* support auto generate static for one_hot_v2

* format
上级 8c8c6d9d
...@@ -92,6 +92,8 @@ static PyObject * eager_api_{}(PyObject *self, PyObject *args, PyObject *kwargs) ...@@ -92,6 +92,8 @@ static PyObject * eager_api_{}(PyObject *self, PyObject *args, PyObject *kwargs)
PyThreadState *tstate = nullptr; PyThreadState *tstate = nullptr;
try {{ try {{
VLOG(6) << "Running Eager Final State API: {}"; VLOG(6) << "Running Eager Final State API: {}";
VLOG(8) << "args count: " << (PyTuple_Size(args) / 2);
// Get EagerTensors from args // Get EagerTensors from args
{} {}
// Parse Attributes if needed // Parse Attributes if needed
......
...@@ -107,10 +107,10 @@ AddAttr<{{typename | to_op_attr_type}}>("{{name}}", "({{typename | to_op_attr_ty ...@@ -107,10 +107,10 @@ AddAttr<{{typename | to_op_attr_type}}>("{{name}}", "({{typename | to_op_attr_ty
{% if default_value == "DataType::UNDEFINED" %} {% if default_value == "DataType::UNDEFINED" %}
-1 -1
{%- else %} {%- else %}
static_cast<int>(framework::TransToProtoVarType(experimental::{{default_value}})) static_cast<int>(framework::TransToProtoVarType(phi::{{default_value}}))
{%- endif %} {%- endif %}
{%- elif typename == "DataLayout" %} {# does DataLayout need any processing?#} {%- elif typename == "DataLayout" %} {# does DataLayout need any processing?#}
static_cast<int>(experimental::{{default_value}}) static_cast<int>(phi::{{default_value}})
{%- elif typename == "Place" %}{# construct a Place to get the type #} {%- elif typename == "Place" %}{# construct a Place to get the type #}
static_cast<int>(phi::Place({{"phi::" if not default_value is initializer_list}}{{default_value}}).GetType()) static_cast<int>(phi::Place({{"phi::" if not default_value is initializer_list}}{{default_value}}).GetType())
{%- else %}{# pass through as-is #} {%- else %}{# pass through as-is #}
...@@ -385,7 +385,7 @@ phi::KernelKey GetKernelTypeForVar( ...@@ -385,7 +385,7 @@ phi::KernelKey GetKernelTypeForVar(
var_name == "{{ skip_arg }}" var_name == "{{ skip_arg }}"
{%- if skip_args_len != 1 and loop.index != skip_args_len %} || {% endif -%} {%- if skip_args_len != 1 and loop.index != skip_args_len %} || {% endif -%}
{%- endfor -%} {%- endfor -%}
){ ) {
{% if "skip_transform" in op["data_transform"] %} {% if "skip_transform" in op["data_transform"] %}
return phi::KernelKey(phi::Backend::ALL_BACKEND, return phi::KernelKey(phi::Backend::ALL_BACKEND,
expected_kernel_type.layout(), expected_kernel_type.layout(),
...@@ -400,7 +400,7 @@ phi::KernelKey GetKernelTypeForVar( ...@@ -400,7 +400,7 @@ phi::KernelKey GetKernelTypeForVar(
return phi::KernelKey(tensor.place(), tensor.layout(), tensor.dtype()); return phi::KernelKey(tensor.place(), tensor.layout(), tensor.dtype());
} }
{% endif %} {% endif %}
else{ else {
return phi::KernelKey( return phi::KernelKey(
tensor.place(), tensor.layout(), expected_kernel_type.dtype()); tensor.place(), tensor.layout(), expected_kernel_type.dtype());
} }
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
class OneHotV2Op : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  // The expected kernel is keyed on the dtype of input "X" and the place
  // the op is being executed on.
  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
    return phi::KernelKey(input_data_type, ctx.GetPlace());
  }

  // Per-variable kernel key: for "depth_tensor" the key is built with
  // phi::Backend::ALL_BACKEND and the expected layout/dtype — presumably so
  // no data transform is applied to it (TODO confirm against the framework's
  // transform logic). Every other input keeps its own place and layout but
  // adopts the expected dtype.
  phi::KernelKey GetKernelTypeForVar(
      const std::string& var_name,
      const phi::DenseTensor& tensor,
      const phi::KernelKey& expected_kernel_type) const override {
    const bool is_depth_input = (var_name == "depth_tensor");
    if (is_depth_input) {
      return phi::KernelKey(phi::Backend::ALL_BACKEND,
                            expected_kernel_type.layout(),
                            expected_kernel_type.dtype());
    }
    return phi::KernelKey(
        tensor.place(), tensor.layout(), expected_kernel_type.dtype());
  }
};
// Declares the proto for the one_hot_v2 operator: its inputs ("X" and the
// optional "depth_tensor"), its output ("Out"), its attributes ("depth",
// "dtype", "allow_out_of_range"), and the user-facing documentation string.
class OneHotV2OpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "(phi::DenseTensor, phi::DenseTensor<int>) Input variable with "
             "rank at least 2. "
             "The last dimension of X should be 1. Each value of X is an index "
             "to indicate the position.");
    // Runtime alternative to the "depth" attribute; marked dispensable so the
    // op also works when only the attribute is provided.
    AddInput("depth_tensor", "(Tensor, Tensor<int>), Length of one-hot vector")
        .AsDispensable();
    AddOutput("Out",
              "(Tensor, Tensor<float>) Output tensor with same rank as X. "
              "The tensor consists of one-hot representations of values in X.");
    // Compile-time depth; -1 means "not set", in which case the value is
    // presumably taken from depth_tensor — TODO confirm in the kernel.
    AddAttr<int>("depth",
                 "A positive integer to specify the length of one-hot vector.")
        .SetDefault(-1);
    // Output element type, stored as a framework::proto::VarType enum value.
    AddAttr<int>("dtype",
                 "An integer to specify the data type of one-hot "
                 "vector. The default value is FP32.")
        .SetDefault(paddle::framework::proto::VarType::FP32);
    AddAttr<bool>("allow_out_of_range",
                  "If it is set true and the input data is out of range, "
                  "the output tensor will be filled zeros. The default value "
                  "is false.")
        .SetDefault(false);
    AddComment(R"DOC(
One Hot Operator. This operator creates the one-hot representations for input
index values. The following example will help to explain the function of this
operator:
X is a LoDTensor:
  X.lod = [[0, 1, 4]]
  X.shape = [4]
  X.data = [1, 1, 3, 0]
set depth = 4
Out is a LoDTensor:
  Out.lod = [[0, 1, 4]]
  Out.shape = [4, 4]
  Out.data = [[0., 1., 0., 0.],
              [0., 1., 0., 0.],
              [0., 0., 0., 1.],
              [1., 0., 0., 0.]]
)DOC");
  }
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;

// InferShape is delegated to the shared phi meta function so static graph and
// dygraph use a single shape-inference implementation.
DECLARE_INFER_SHAPE_FUNCTOR(one_hot_v2,
                            OneHotInferShapeFunctor,
                            PD_INFER_META(phi::OneHotRawInferMeta));
// one_hot_v2 registers empty grad-op makers for both static (OpDesc) and
// imperative (OpBase) modes, i.e. it exposes no gradient op.
REGISTER_OPERATOR(
    one_hot_v2,
    ops::OneHotV2Op,
    ops::OneHotV2OpMaker,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
    OneHotInferShapeFunctor);
...@@ -1328,6 +1328,16 @@ ...@@ -1328,6 +1328,16 @@
outputs : outputs :
size : Out size : Out
- op : one_hot (one_hot_v2)
inputs :
x : X
outputs :
out : Out
scalar :
depth :
data_type : int
tensor_name : depth_tensor
- op : overlap_add - op : overlap_add
backward : overlap_add_grad backward : overlap_add_grad
inputs : inputs :
......
...@@ -129,6 +129,15 @@ ...@@ -129,6 +129,15 @@
backend : x backend : x
force_backend : force_cpu force_backend : force_cpu
- op : one_hot
args : (Tensor x, Scalar(int) depth = -1, DataType dtype = DataType::FLOAT32, bool allow_out_of_range = false)
output : Tensor(out)
infer_meta :
func : OneHotRawInferMeta
kernel :
func : one_hot_raw
data_type : x
- op : reduce - op : reduce
args : (Tensor x, int ring_id = 0, int root_id = 0, int reduce_type = 0) args : (Tensor x, int ring_id = 0, int root_id = 0, int reduce_type = 0)
output : Tensor(out) output : Tensor(out)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
// Maps the legacy one_hot_v2 op onto the "one_hot_raw" phi kernel. The depth
// argument is taken from the "depth_tensor" input when the op carries one,
// and from the "depth" attribute otherwise; the remaining arguments are
// identical in both cases.
KernelSignature OneHotOpArgumentMapping(const ArgumentMappingContext& ctx) {
  const char* depth_arg =
      ctx.HasInput("depth_tensor") ? "depth_tensor" : "depth";
  return KernelSignature("one_hot_raw",
                         {"X"},
                         {depth_arg, "dtype", "allow_out_of_range"},
                         {"Out"});
}
} // namespace phi
// Alias the legacy op name "one_hot_v2" to the phi base kernel name "one_hot"
// and register the argument-mapping function defined above.
PD_REGISTER_BASE_KERNEL_NAME(one_hot_v2, one_hot);
PD_REGISTER_ARG_MAPPING_FN(one_hot_v2, phi::OneHotOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册