未验证 提交 04b2c9fa 编写于 作者: H huzhiqiang 提交者: GitHub

[Op] Add one_hot op for host backend (#4093)

上级 6992a62f
......@@ -23,7 +23,9 @@ add_kernel(print_compute_host Host extra SRCS print_compute.cc DEPS ${lite_kerne
add_kernel(while_compute_host Host extra SRCS while_compute.cc DEPS ${lite_kernel_deps} program)
add_kernel(conditional_block_compute_host Host extra SRCS conditional_block_compute.cc DEPS ${lite_kernel_deps} program)
add_kernel(activation_grad_compute_host Host train SRCS activation_grad_compute.cc DEPS ${lite_kernel_deps})
add_kernel(one_hot_compute_host Host extra SRCS one_hot_compute.cc DEPS ${lite_kernel_deps})
if(LITE_BUILD_EXTRA)
if(LITE_BUILD_EXTRA AND LITE_WITH_x86)
lite_cc_test(test_where_index_compute_host SRCS where_index_compute.cc DEPS where_index_compute_host)
lite_cc_test(test_one_hot_compute_host SRCS one_hot_compute_test.cc DEPS one_hot_compute_host)
endif()
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/host/one_hot_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace host {
template <typename T>
void OneHotKernelFunctor(const Tensor* in,
Tensor* out,
int depth,
bool allow_out_of_range = false) {
auto* p_in_data = in->data<T>();
auto numel = in->numel();
auto* p_out_data = out->mutable_data<T>();
memset(p_out_data, 0, out->numel() * sizeof(T));
if (allow_out_of_range) {
for (int i = 0; i < numel; ++i) {
if (p_in_data[i] >= 0 && p_in_data[i] < depth) {
p_out_data[i * depth + static_cast<int>(p_in_data[i])] = 1.0;
}
}
} else {
for (int i = 0; i < numel; ++i) {
CHECK_GE(p_in_data[i], 0) << "Illegal index value, Input(input) value "
"should be at least 0, but received input ("
<< p_in_data[i] << ") less than 0";
CHECK_LE(p_in_data[i], depth)
<< "Illegal index value, Input(input) value should be less than "
"Input(depth), but received input ("
<< p_in_data[i] << ") not less than depth (" << depth << ")";
p_out_data[i * depth + static_cast<int>(p_in_data[i])] = 1.0;
}
}
}
void OneHotCompute::Run() {
auto& param = this->template Param<param_t>();
switch (param.dtype) {
case static_cast<int>(lite::core::FluidType::INT64):
OneHotKernelFunctor<int64_t>(
param.X, param.Out, param.depth, param.allow_out_of_range);
break;
case static_cast<int>(lite::core::FluidType::INT32):
OneHotKernelFunctor<int32_t>(
param.X, param.Out, param.depth, param.allow_out_of_range);
break;
case static_cast<int>(lite::core::FluidType::FP32):
OneHotKernelFunctor<float>(
param.X, param.Out, param.depth, param.allow_out_of_range);
break;
default:
LOG(ERROR) << "Unsupported data type for one_hot op:" << param.dtype;
}
}
} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(
one_hot, kHost, kAny, kAny, paddle::lite::kernels::host::OneHotCompute, def)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kAny),
DATALAYOUT(kAny))})
.BindInput("depth_tensor",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kAny),
DATALAYOUT(kAny))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kAny),
DATALAYOUT(kAny))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace host {
class OneHotCompute
: public KernelLite<TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)> {
public:
using param_t = operators::OneHotParam;
void Run() override;
virtual ~OneHotCompute() = default;
};
} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <algorithm>
#include <iostream>
#include <memory>
#include <random>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/host/one_hot_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace host {
/* note:
One Hot Operator. This operator creates the one-hot representations for input
index values. The following example will help to explain the function of this
operator:
X is a LoDTensor:
X.lod = [[0, 1, 4]]
X.shape = [4, 1]
X.data = [[1], [1], [3], [0]]
set depth = 4
Out is a LoDTensor:
Out.lod = [[0, 1, 4]]
Out.shape = [4, 4]
Out.data = [[0., 1., 0., 0.],
[0., 1., 0., 0.],
[0., 0., 0., 1.],
[1., 0., 0., 0.]] */
TEST(one_hot, test) {
using T = float;
lite::Tensor x, out;
x.Resize({4, 1});
out.Resize({4, 4});
auto* x_data = x.mutable_data<T>();
x_data[0] = 1;
x_data[1] = 1;
x_data[2] = 3;
x_data[3] = 0;
auto* out_data = out.mutable_data<T>();
float out_ref[4][4] = {
{0, 1, 0, 0}, {0, 1, 0, 0}, {0, 0, 0, 1}, {1, 0, 0, 0}};
OneHotCompute one_hot;
operators::OneHotParam param;
param.X = &x;
param.Out = &out;
param.depth = 4;
// static_cast<int>(lite::core::FluidType::FP32) = 5;
param.dtype = 5;
one_hot.SetParam(param);
one_hot.PrepareForRun();
one_hot.Run();
for (int i = 0; i < out.numel(); i++) {
EXPECT_NEAR(out_data[i], out_ref[i], 1e-5);
}
}
} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(one_hot, kHost, kAny, kAny, def);
......@@ -147,6 +147,7 @@ add_operator(layer_norm_op extra SRCS layer_norm_op.cc DEPS ${op_DEPS})
add_operator(sequence_softmax_op extra SRCS sequence_softmax_op.cc DEPS ${op_DEPS})
add_operator(retinanet_detection_output_op extra SRCS retinanet_detection_output_op.cc DEPS ${op_DEPS})
add_operator(where_index_op extra SRCS where_index_op.cc DEPS ${op_DEPS})
add_operator(one_hot_op extra SRCS one_hot_op.cc DEPS ${op_DEPS})
# for content-dnn specific
add_operator(search_aligned_mat_mul_op extra SRCS search_aligned_mat_mul_op.cc DEPS ${op_DEPS})
add_operator(search_seq_fc_op extra SRCS search_seq_fc_op.cc DEPS ${op_DEPS})
......@@ -175,7 +176,7 @@ add_operator(__xpu__fc_op extra SRCS __xpu__fc_op.cc DEPS ${op_DEPS})
add_operator(__xpu__resnet_cbam_op extra SRCS __xpu__resnet_cbam_op.cc DEPS ${op_DEPS})
add_operator(__xpu__search_attention_op extra SRCS __xpu__search_attention_op.cc DEPS ${op_DEPS})
add_operator(__xpu__mmdnn_op extra SRCS __xpu__mmdnn_op.cc DEPS ${op_DEPS})
lite_cc_test(test_one_hot_op SRCS one_hot_op_test.cc DEPS one_hot_op memory scope ${op_deps} one_hot_compute_host)
if (NOT LITE_WITH_X86)
lite_cc_test(test_fc_op SRCS fc_op_test.cc
DEPS fc_op memory
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/one_hot_op.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool OneHotOp::CheckShape() const {
CHECK_OR_FALSE(param_.X);
CHECK_OR_FALSE(param_.Out);
return true;
}
bool OneHotOp::InferShapeImpl() const {
auto out_dims = param_.X->dims();
CHECK_GE(out_dims.size(), 2);
int depth = param_.depth_tensor ? param_.depth
: param_.depth_tensor->data<int32_t>()[0];
out_dims[out_dims.size() - 1] = depth;
param_.Out->Resize(out_dims);
param_.Out->set_lod(param_.X->lod());
return true;
}
bool OneHotOp::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
auto x = op_desc.Input("X").front();
auto out = op_desc.Output("Out").front();
param_.X = scope->FindVar(x)->GetMutable<Tensor>();
param_.Out = scope->FindMutableTensor(out);
if (op_desc.HasInput("depth_tensor") &&
!op_desc.Input("depth_tensor").empty()) {
auto depth_tensor = op_desc.Input("depth_tensor").front();
param_.depth_tensor = scope->FindVar(depth_tensor)->GetMutable<Tensor>();
}
if (op_desc.HasAttr("depth")) {
param_.depth = op_desc.GetAttr<int>("depth");
}
if (op_desc.HasAttr("allow_out_of_range")) {
param_.allow_out_of_range = op_desc.GetAttr<bool>("allow_out_of_range");
}
param_.dtype = op_desc.GetAttr<int>("dtype");
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(one_hot, paddle::lite::operators::OneHotOp);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
/* note:
One Hot Operator. This operator creates the one-hot representations for input
index values. The following example will help to explain the function of this
operator:
X is a LoDTensor:
X.lod = [[0, 1, 4]]
X.shape = [4, 1]
X.data = [[1], [1], [3], [0]]
set depth = 4
Out is a LoDTensor:
Out.lod = [[0, 1, 4]]
Out.shape = [4, 4]
Out.data = [[0., 1., 0., 0.],
[0., 1., 0., 0.],
[0., 0., 0., 1.],
[1., 0., 0., 0.]] */
class OneHotOp : public OpLite {
public:
OneHotOp() {}
explicit OneHotOp(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "one_hot"; }
#ifdef LITE_WITH_PROFILE
void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter *ch) {
ch->input_shape = ch->DimToStr(param_.X->dims());
ch->output_shape = ch->DimToStr(param_.Out->dims());
ch->macs = param_.X->numel() * 1.f;
}
#endif
private:
mutable OneHotParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/one_hot_op.h"
#include <gtest/gtest.h>
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
TEST(one_hot_op_lite, TestHost) {
// prepare variables
Scope scope;
auto* x = scope.Var("X")->GetMutable<Tensor>();
auto* depth_tensor = scope.Var("depth_tensor")->GetMutable<Tensor>();
auto* output = scope.Var("Out")->GetMutable<Tensor>();
depth_tensor->dims();
output->dims();
// set data
x->Resize(DDim(std::vector<int64_t>({4, 1})));
auto* x_data = x->mutable_data<int32_t>();
x_data[0] = 1;
x_data[1] = 1;
x_data[2] = 3;
x_data[3] = 0;
// prepare op desc
cpp::OpDesc desc;
desc.SetType("one_hot");
desc.SetInput("X", {"X"});
desc.SetInput("depth_tensor", {"depth_tensor"});
desc.SetOutput("Out", {"Out"});
desc.SetAttr("depth", static_cast<int>(4));
desc.SetAttr("dtype", static_cast<int>(1));
desc.SetAttr("allow_out_of_range", static_cast<bool>(0));
OneHotOp one_hot("one_hot");
one_hot.SetValidPlaces({Place{TARGET(kHost), PRECISION(kAny)}});
one_hot.Attach(desc, &scope);
auto kernels = one_hot.CreateKernels({Place{TARGET(kHost), PRECISION(kAny)}});
ASSERT_FALSE(kernels.empty());
}
} // namespace operators
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(one_hot, kHost, kAny, kAny, def);
......@@ -1824,6 +1824,15 @@ struct PrintParam : ParamBase {
bool is_forward{true};
};
struct OneHotParam : ParamBase {
const lite::Tensor* X{};
const lite::Tensor* depth_tensor{nullptr};
lite::Tensor* Out{};
int depth;
int dtype;
bool allow_out_of_range;
};
} // namespace operators
} // namespace lite
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册