diff --git a/CMakeLists.txt b/CMakeLists.txt
old mode 100644
new mode 100755
index 77a94bea1efcdafaa67b4c078bfb0a756f7b1cec..786b1322b346631d1570a6ebd9c572302531db4e
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -22,6 +22,8 @@ if (WITH_PADDLE_MOBILE)
   return()
 endif(WITH_PADDLE_MOBILE)
 
+# set(CMAKE_BUILD_TYPE DEBUG)
+
 set(PADDLE_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR})
 set(PADDLE_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR})
 set(CMAKE_CXX_STANDARD 11)
diff --git a/lite/backends/fpga/KD/debugger.hpp b/lite/backends/fpga/KD/debugger.hpp
index cab02208d73591dfa7631c77b1586d6b6041efbb..7cf67aba3ca3333f445526fdbb1929660d05afd3 100755
--- a/lite/backends/fpga/KD/debugger.hpp
+++ b/lite/backends/fpga/KD/debugger.hpp
@@ -41,19 +41,20 @@ class Debugger {
  private:
   std::unordered_map<std::string, bool> op_config;
   Debugger() {
-    // op_config["concat"] = true;
-    // op_config["pooling"] = true;
-    // op_config["conv"] = true;
-    // op_config["crop"] = true;
-    // op_config["feed"] = true;
-    // op_config["fetch"] = true;
-    // op_config["boxes"] = true;
-    // op_config["scores"] = true;
-    // op_config["nms"] = true;
-    // op_config["pb_boxes"] = true;
-    // op_config["pb_variances"] = true;
-    // // op_config["fc"] = true;
-    // op_config["softmax"] = true;
+    op_config["concat"] = true;
+    op_config["pooling"] = true;
+    op_config["conv"] = true;
+    op_config["crop"] = true;
+    op_config["feed"] = true;
+    op_config["mul"] = true;
+    op_config["fetch"] = true;
+    op_config["boxes"] = true;
+    op_config["scores"] = true;
+    op_config["nms"] = true;
+    op_config["pb_boxes"] = true;
+    op_config["pb_variances"] = true;
+    // op_config["fc"] = true;
+    op_config["softmax"] = true;
   }
 };
diff --git a/lite/kernels/arm/cast_compute.cc b/lite/kernels/arm/cast_compute.cc
old mode 100644
new mode 100755
index bc274ea22485e84a1cc9145e62fc967f2847c5dd..b13f2dea9c154943432088cbd760881e3c7a0731
--- a/lite/kernels/arm/cast_compute.cc
+++ b/lite/kernels/arm/cast_compute.cc
@@ -56,6 +56,10 @@ void CastCompute::Run() {
     float* out_data = param.Out->mutable_data<float>();
     std::transform(x_data_begin, x_data_end, out_data, TransOp<char, float>);
+  } else if (param.in_dtype == 3 && param.out_dtype == 5) {
+    const auto* x_data = param.X->data<int64_t>();
+    auto* o_data = param.Out->mutable_data<float>();
+    memcpy(o_data, x_data, sizeof(float) * param.X->numel());
   } else {
     LOG(FATAL) << "other has not been implemented";
   }
diff --git a/lite/kernels/arm/lookup_table_compute.cc b/lite/kernels/arm/lookup_table_compute.cc
old mode 100644
new mode 100755
index ba58b378f4dda22fd78ce76b80bdbca8d8f284a3..fa7e2c0c3ae4580f5d19e82f7c48c74db3058847
--- a/lite/kernels/arm/lookup_table_compute.cc
+++ b/lite/kernels/arm/lookup_table_compute.cc
@@ -28,6 +28,7 @@ namespace arm {
 
 void LookupTableCompute::Run() {
   auto& param = this->Param<param_t>();
+  auto& ctx = this->ctx_->template As<ARMContext>();
   // inputs
   auto w = param.W;
   auto ids = param.Ids;
@@ -36,7 +37,7 @@ void LookupTableCompute::Run() {
   auto table_dim = w->dims();
   int64_t ids_numel = ids->numel();
 
-  auto ids_data = ids->data<int64_t>();
+  auto ids_data = ids->data<float>();
 
   int64_t row_number = table_dim[0];
   int64_t row_width = table_dim[1];
@@ -75,14 +76,3 @@ REGISTER_LITE_KERNEL(lookup_table,
     .BindInput("Ids", {LiteType::GetTensorTy(TARGET(kARM))})
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
     .Finalize();
-
-REGISTER_LITE_KERNEL(lookup_table_v2,
-                     kARM,
-                     kFloat,
-                     kNCHW,
-                     paddle::lite::kernels::arm::LookupTableCompute,
-                     def)
-    .BindInput("W", {LiteType::GetTensorTy(TARGET(kARM))})
-    .BindInput("Ids", {LiteType::GetTensorTy(TARGET(kARM))})
-    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
-    .Finalize();
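Note on the cast_compute.cc hunk above: in Paddle's framework.proto, dtype code 3 is INT64 and 5 is FP32, so the new branch handles int64-to-float casts. As written, the memcpy copies sizeof(float) * numel bytes out of an int64 buffer, i.e. it reinterprets (and truncates) raw bytes rather than converting values. A value-converting variant, as a minimal sketch that is not part of this patch, would follow the std::transform pattern the file already uses:

    } else if (param.in_dtype == 3 && param.out_dtype == 5) {  // INT64 -> FP32
      const int64_t* x_data = param.X->data<int64_t>();
      float* o_data = param.Out->mutable_data<float>();
      // Convert each element by value instead of copying raw bytes.
      std::transform(x_data, x_data + param.X->numel(), o_data,
                     [](int64_t v) { return static_cast<float>(v); });
    }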
diff --git a/lite/kernels/fpga/CMakeLists.txt b/lite/kernels/fpga/CMakeLists.txt
index 7c47e72872ecae6216288c20fa1a6ae30fac65bd..89d2d26a1042bd092e91608d778dda7ad3430d41 100755
--- a/lite/kernels/fpga/CMakeLists.txt
+++ b/lite/kernels/fpga/CMakeLists.txt
@@ -7,7 +7,9 @@ set(fpga_deps fpga_target_wrapper kernel_fpga)
 # add_kernel(activation_compute_fpga FPGA basic SRCS activation_compute.cc DEPS ${fpga_deps})
 # add_kernel(box_coder_compute_fpga FPGA basic SRCS box_coder_compute.cc DEPS ${fpga_deps})
-# add_kernel(concat_compute_fpga FPGA basic SRCS concat_compute.cc DEPS ${fpga_deps})
+
+add_kernel(concat_compute_fpga FPGA basic SRCS concat_compute.cc DEPS ${fpga_deps})
+
 add_kernel(conv_compute_fpga FPGA basic SRCS conv_compute.cc DEPS ${fpga_deps})
 # add_kernel(density_prior_box_compute_fpga FPGA basic SRCS density_prior_box_compute.cc DEPS ${fpga_deps})
 add_kernel(dropout_compute_fpga FPGA basic SRCS dropout_compute.cc DEPS ${fpga_deps})
@@ -16,9 +18,12 @@ add_kernel(elementwise_compute_fpga FPGA basic SRCS elementwise_compute.cc DEPS
 add_kernel(fc_compute_fpga FPGA basic SRCS fc_compute.cc DEPS ${fpga_deps})
 add_kernel(gru_compute_fpga FPGA extra SRCS gru_compute.cc DEPS ${fpga_deps})
+
 # add_kernel(mul_compute_fpga FPGA basic SRCS mul_compute.cc DEPS ${fpga_deps})
+
 add_kernel(multiclass_nms_compute_fpga FPGA basic SRCS multiclass_nms_compute.cc DEPS ${fpga_deps})
 add_kernel(norm_compute_fpga FPGA basic SRCS norm_compute.cc DEPS ${fpga_deps})
+
 # add_kernel(im2sequence_compute_fpga FPGA basic SRCS im2sequence_compute.cc DEPS ${fpga_deps})
 add_kernel(pooling_compute_fpga FPGA basic SRCS pooling_compute.cc DEPS ${fpga_deps})
 add_kernel(prior_box_compute_fpga FPGA basic SRCS prior_box_compute.cc DEPS ${fpga_deps})
diff --git a/lite/kernels/fpga/feed_compute.cc b/lite/kernels/fpga/feed_compute.cc
old mode 100755
new mode 100644
index 7670bf0007def88c27c12ea54c569a7fcf263693..79329e99a3e5e812dca487c17452f3f5d1e96449
--- a/lite/kernels/fpga/feed_compute.cc
+++ b/lite/kernels/fpga/feed_compute.cc
@@ -67,3 +67,13 @@ REGISTER_LITE_KERNEL(
                                        PRECISION(kFP16),
                                        DATALAYOUT(kNHWC))})
     .Finalize();
+
+REGISTER_LITE_KERNEL(feed,
+                     kFPGA,
+                     kFP16,
+                     kNHWC,
+                     paddle::lite::kernels::fpga::FeedCompute,
+                     def_host)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kHost))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))})
+    .Finalize();
diff --git a/lite/kernels/fpga/fetch_compute.cc b/lite/kernels/fpga/fetch_compute.cc
index 9b5f3f60232bb8527f823395693cf3b3851bc04e..b1559770d7cd7d46cc4c3b744668113474565a70 100644
--- a/lite/kernels/fpga/fetch_compute.cc
+++ b/lite/kernels/fpga/fetch_compute.cc
@@ -43,8 +43,15 @@ void FetchCompute::PrepareForRun() {
 }
 
 void FetchCompute::Run() {
-  pe_.dispatch();
   auto& param = this->Param<param_t>();
+  auto fetch_list = param.fetch_list;
+  if (fetch_list->size() <= static_cast<size_t>(param.col)) {
+    fetch_list->resize(param.col + 1);
+  }
+  Tensor& out = param.fetch_list->at(param.col);
+  out.Resize(param.input->dims());
+
+  pe_.dispatch();
 
 #ifdef FPGA_PRINT_TENSOR
   zynqmp::OutputParam& fetch_param = pe_.param();
@@ -67,10 +74,7 @@ REGISTER_LITE_KERNEL(fetch,
                     {LiteType::GetTensorTy(TARGET(kFPGA),
                                            PRECISION(kAny),
                                            DATALAYOUT(kAny))})
-    .BindOutput("Out",
-                {LiteType::GetTensorTy(TARGET(kHost),
-                                       PRECISION(kAny),
-                                       DATALAYOUT(kAny))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))})
     .Finalize();
 
 REGISTER_LITE_KERNEL(fetch,
@@ -79,12 +83,6 @@ REGISTER_LITE_KERNEL(fetch,
                     kFPGA,
                     kAny,
                     kNHWC,
                     paddle::lite::kernels::fpga::FetchCompute,
                     host_host)
-    .BindInput("X",
-               {LiteType::GetTensorTy(TARGET(kHost),
-                                      PRECISION(kAny),
-                                      DATALAYOUT(kAny))})
-    .BindOutput("Out",
-                {LiteType::GetTensorTy(TARGET(kHost),
-                                       PRECISION(kAny),
-                                       DATALAYOUT(kAny))})
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kHost))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))})
     .Finalize();
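Note on the fetch_compute.cc hunk: the output slot is now prepared before pe_.dispatch() runs, so the PE writes into a tensor that already exists and has its final shape. The guard matters because param.fetch_list->at(param.col) throws when the vector has not grown to cover index col; with col = 2 and an empty list, the resize first brings the size to 3. A condensed sketch of the pattern, assuming fetch_list is a std::vector<Tensor>* as in FetchParam:

    std::vector<Tensor>* fetch_list = param.fetch_list;
    if (fetch_list->size() <= static_cast<size_t>(param.col)) {
      fetch_list->resize(param.col + 1);  // make index `col` valid
    }
    Tensor& out = fetch_list->at(param.col);  // safe after the resize
    out.Resize(param.input->dims());          // shape it before dispatch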
diff --git a/lite/kernels/fpga/io_copy_compute.cc b/lite/kernels/fpga/io_copy_compute.cc
index 51ed60c22d0c8f63c611662d3662863dac1f7e07..bd9e0b3e34d42fdaf26c75a138c8f62fc6d37963 100644
--- a/lite/kernels/fpga/io_copy_compute.cc
+++ b/lite/kernels/fpga/io_copy_compute.cc
@@ -122,7 +122,81 @@ class IoCopyFpgaToHostCompute
     // param.x->ZynqTensor()->saveToFile("io_x", true);
     // param.y->ZynqTensor()->saveToFile("io_y", true);
   }
+  std::string doc() const override { return "Copy IO from FPGA to HOST"; }
+};
+
+void hwc_to_chw(float* chw_data,
+                float* hwc_data,
+                int num,
+                int channel,
+                int height,
+                int width) {
+  int chw = channel * height * width;
+  int wc = width * channel;
+  int wh = width * height;
+  int index = 0;
+  for (int n = 0; n < num; n++) {
+    for (int h = 0; h < height; h++) {
+      for (int w = 0; w < width; w++) {
+        for (int c = 0; c < channel; c++) {
+          chw_data[n * chw + c * wh + h * width + w] = hwc_data[index];
+          index++;
+        }
+      }
+    }
+  }
+}
+
+class IoCopyFpgaToHostCHWCompute
+    : public KernelLite<TARGET(kFPGA), PRECISION(kAny), DATALAYOUT(kAny)> {
+ public:
+  void Run() override {
+    auto& param = Param<operators::IoCopyParam>();
+    CHECK(param.x->target() == TARGET(kHost) ||
+          param.x->target() == TARGET(kFPGA));
+
+    Tensor hwc;
+    hwc.Resize(param.y->dims());
+    float* hwc_data = hwc.mutable_data<float>();
+
+    float* chw_data = param.y->mutable_data<float>();
+    param.y->ZynqTensor()->setDataType(zynqmp::FP32);
+    param.x->ZynqTensor()->syncToDevice();
+
+    if (param.x->ZynqTensor()->aligned() &&
+        param.x->ZynqTensor()->shape().shouldAlign()) {
+      zynqmp::Tensor tempTensor;
+      tempTensor.mutableData<zynqmp::float16>(zynqmp::FP16,
+                                              param.x->ZynqTensor()->shape());
+      tempTensor.copyFrom(param.x->ZynqTensor());
+      tempTensor.setAligned(true);
+      tempTensor.unalignImage();
+      hwc.ZynqTensor()->copyFrom(&tempTensor);
+    } else {
+      hwc.ZynqTensor()->copyFrom(param.x->ZynqTensor());
+    }
+
+    int num = 1;
+    int channel = 1;
+    int height = 1;
+    int width = 1;
+
+    auto dims = param.y->ZynqTensor()->shape();
+
+    hwc_to_chw(chw_data,
+               hwc_data,
+               dims.num(),
+               dims.channel(),
+               dims.height(),
+               dims.width());
+
+    param.y->ZynqTensor()->copyScaleFrom(param.x->ZynqTensor());
+    param.y->ZynqTensor()->flush();
+    auto out_lod = param.y->mutable_lod();
+    *out_lod = param.x->lod();
+    // param.x->ZynqTensor()->saveToFile("io_x", true);
+    // param.y->ZynqTensor()->saveToFile("io_y", true);
+  }
   std::string doc() const override { return "Copy IO from FPGA to HOST"; }
 };
@@ -173,7 +247,7 @@ REGISTER_LITE_KERNEL(io_copy,
                                        PRECISION(kFP16),
                                        DATALAYOUT(kNHWC))})
     .BindOutput("Out",
-                {LiteType::GetTensorTy(TARGET(kARM),
+                {LiteType::GetTensorTy(TARGET(kHost),
                                        PRECISION(kFloat),
                                        DATALAYOUT(kNHWC))})
     .Finalize();
@@ -182,7 +256,7 @@ REGISTER_LITE_KERNEL(io_copy,
                     kFPGA,
                     kAny,
                     kAny,
-                    paddle::lite::kernels::fpga::IoCopyFpgaToHostCompute,
+                    paddle::lite::kernels::fpga::IoCopyFpgaToHostCHWCompute,
                     device_to_host_22)
     .BindInput("Input",
                {LiteType::GetTensorTy(TARGET(kFPGA),
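Note on hwc_to_chw in io_copy_compute.cc: the function walks the NHWC source sequentially and scatters each element (n, h, w, c) to offset n*CHW + c*HW + h*W + w in the NCHW destination (the wc local is computed but never used, and the num/channel/height/width locals in Run() are shadowed by the dims accessors). A self-contained check, with made-up values that encode their own (h, w, c) coordinates:

    #include <cassert>

    int main() {
      // 1x2x2x2 NHWC input; each value is h*100 + w*10 + c.
      float hwc[8] = {0, 1, 10, 11, 100, 101, 110, 111};
      float chw[8] = {0};
      hwc_to_chw(chw, hwc, /*num=*/1, /*channel=*/2, /*height=*/2, /*width=*/2);
      // NCHW order: all of channel 0 first, then channel 1.
      assert(chw[0] == 0 && chw[1] == 10 && chw[2] == 100 && chw[3] == 110);
      assert(chw[4] == 1 && chw[5] == 11 && chw[6] == 101 && chw[7] == 111);
      return 0;
    }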
diff --git a/lite/kernels/fpga/multiclass_nms_compute.cc b/lite/kernels/fpga/multiclass_nms_compute.cc
old mode 100644
new mode 100755
index cee5e16205370df7faabc6f37d57fe360e8a9e67..5740df1f66de69ed6dd813bcaa6dacf8fe8800d6
--- a/lite/kernels/fpga/multiclass_nms_compute.cc
+++ b/lite/kernels/fpga/multiclass_nms_compute.cc
@@ -379,11 +379,13 @@ void MulticlassNmsCompute::Run() {
       if (e > s) {
         Tensor out;
+        std::cout << "Slice:" << s << " -- " << e << std::endl;
         outs->Slice(out, s, e);
         MultiClassOutput<float>(
             scores_slice, boxes_slice, all_indices[i], score_dims.size(), &out);
         outs->ZynqTensor()->copyFrom(out.ZynqTensor());
       }
+      outs->Resize({static_cast<int64_t>(e - s), out_dim});
     }
   }
   LoD lod;
@@ -412,19 +414,19 @@ REGISTER_LITE_KERNEL(multiclass_nms,
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))})
     .Finalize();
 
-REGISTER_LITE_KERNEL(multiclass_nms,
-                     kFPGA,
-                     kFP16,
-                     kNHWC,
-                     paddle::lite::kernels::fpga::MulticlassNmsCompute,
-                     def2)
-    .BindInput("BBoxes",
-               {LiteType::GetTensorTy(TARGET(kFPGA),
-                                      PRECISION(kFP16),
-                                      DATALAYOUT(kNHWC))})
-    .BindInput("Scores",
-               {LiteType::GetTensorTy(TARGET(kFPGA),
-                                      PRECISION(kFP16),
-                                      DATALAYOUT(kNHWC))})
-    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))})
-    .Finalize();
+// REGISTER_LITE_KERNEL(multiclass_nms,
+//                      kFPGA,
+//                      kFP16,
+//                      kNHWC,
+//                      paddle::lite::kernels::fpga::MulticlassNmsCompute,
+//                      def2)
+//     .BindInput("BBoxes",
+//                {LiteType::GetTensorTy(TARGET(kFPGA),
+//                                       PRECISION(kFP16),
+//                                       DATALAYOUT(kNHWC))})
+//     .BindInput("Scores",
+//                {LiteType::GetTensorTy(TARGET(kFPGA),
+//                                       PRECISION(kFP16),
+//                                       DATALAYOUT(kNHWC))})
+//     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))})
+//     .Finalize();
diff --git a/lite/kernels/host/CMakeLists.txt b/lite/kernels/host/CMakeLists.txt
index 428cc213ce63b8d24193a44f23d61fea78f63d6a..c6f2721d80b6fd584ce96e817476372e37b17ed8 100755
--- a/lite/kernels/host/CMakeLists.txt
+++ b/lite/kernels/host/CMakeLists.txt
@@ -4,6 +4,7 @@ add_kernel(feed_compute_host Host basic SRCS feed_compute.cc DEPS ${lite_kernel_
 add_kernel(fetch_compute_host Host basic SRCS fetch_compute.cc DEPS ${lite_kernel_deps})
 add_kernel(reshape_compute_host Host basic SRCS reshape_compute.cc DEPS ${lite_kernel_deps} reshape_op)
 add_kernel(multiclass_nms_compute_host Host basic SRCS multiclass_nms_compute.cc DEPS ${lite_kernel_deps})
+add_kernel(one_hot_compute_host Host basic SRCS one_hot_compute.cc DEPS ${lite_kernel_deps})
 
 #lite_cc_test(test_reshape_compute_host SRCS reshape_compute_test.cc DEPS reshape_compute_host any)
 #lite_cc_test(test_multiclass_nms_compute_host SRCS multiclass_nms_compute_test.cc DEPS multiclass_nms_compute_host any)
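Note on the multiclass_nms_compute.cc hunk: resizing outs to {e - s, out_dim} trims the output to the detections actually kept for the last batch slice (for a single-image batch that is the whole result). In the standard multiclass_nms output layout each row is [label, score, xmin, ymin, xmax, ymax], so out_dim is 6 for 4-point boxes; reading the resized tensor then looks roughly like this sketch, assuming float output and out_dim == 6:

    const float* res = outs->data<float>();
    for (int64_t i = 0; i < outs->dims()[0]; ++i) {
      const float* det = res + i * 6;  // one detection per row
      printf("label=%d score=%.3f box=(%.1f, %.1f, %.1f, %.1f)\n",
             static_cast<int>(det[0]), det[1], det[2], det[3], det[4], det[5]);
    }

The std::cout trace added above .Slice() looks like leftover debug output; the surrounding code reports through the Debugger/LOG machinery instead.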
diff --git a/lite/kernels/host/one_hot_compute.cc b/lite/kernels/host/one_hot_compute.cc
new file mode 100755
index 0000000000000000000000000000000000000000..e0af6f5173f367bb9b2e06de10499ee36806379c
--- /dev/null
+++ b/lite/kernels/host/one_hot_compute.cc
@@ -0,0 +1,81 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <algorithm>
+#include <string>
+#include <vector>
+
+#include "lite/backends/fpga/KD/debugger.hpp"
+#include "lite/kernels/host/one_hot_compute.h"
+#include "lite/utils/paddle_enforce.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace host {
+
+void OneHotCompute::Run() {
+  auto& param = Param<operators::OneHotParam>();
+  param.Out->mutable_data<float>();
+  int depth = param.depth;
+  if (param.depth_tensor) {
+    auto* depth_tensor = param.depth_tensor;
+    auto* depth_data = depth_tensor->data<int32_t>();
+    depth = depth_data[0];
+    auto in_dims = param.X->dims();
+    DDim out_dims(in_dims);
+    out_dims[out_dims.size() - 1] = depth;
+    param.Out->Resize(out_dims);
+  }
+
+  auto* p_in_data = param.X->data<float>();
+  auto numel = param.X->numel();
+  auto* p_out_data = param.Out->mutable_data<float>();
+
+  for (int i = 0; i < param.Out->numel(); ++i) {
+    p_out_data[i] = 0;
+  }
+
+  if (param.allow_out_of_range) {
+    for (int i = 0; i < numel; ++i) {
+      if (p_in_data[i] >= 0 && p_in_data[i] < param.depth) {
+        *(p_out_data + i * param.depth + (int)(p_in_data[i])) = 1.0;  // NOLINT
+      }
+    }
+  } else {
+    for (int i = 0; i < numel; ++i) {
+      PADDLE_ENFORCE_GE(
+          p_in_data[i], 0, "Illegal index value, should be at least 0.");
+      PADDLE_ENFORCE_LT(p_in_data[i],
+                        param.depth,
+                        "Illegal index value, should be less than depth (%d).",
+                        param.depth);
+      *(p_out_data + i * param.depth + (int)(p_in_data[i])) = 1.0;  // NOLINT
+    }
+  }
+}
+
+}  // namespace host
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_KERNEL(one_hot,
+                     kHost,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::host::OneHotCompute,
+                     def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kHost))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))})
+    .Finalize();
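A worked example of the kernel above: indices arrive as floats and select which column of each output row is set. With depth = 4 and inputs {1, 3}, the 2x4 output is {0,1,0,0, 0,0,0,1}. The core loop, reduced to a standalone sketch:

    #include <cstdio>
    #include <vector>

    int main() {
      const int depth = 4;
      std::vector<float> in = {1.0f, 3.0f};             // index values, stored as float
      std::vector<float> out(in.size() * depth, 0.0f);  // zeroed, like the kernel
      for (size_t i = 0; i < in.size(); ++i) {
        out[i * depth + static_cast<int>(in[i])] = 1.0f;  // one hot bit per row
      }
      for (size_t i = 0; i < out.size(); ++i) printf("%.0f ", out[i]);  // 0 1 0 0 0 0 0 1
      return 0;
    }

One observation: when depth comes from depth_tensor, Run() resizes Out with the local depth, but the bounds checks and the row stride still use param.depth; if the two can differ, the loops should use the local depth instead.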
diff --git a/lite/kernels/host/one_hot_compute.h b/lite/kernels/host/one_hot_compute.h
new file mode 100755
index 0000000000000000000000000000000000000000..3a6c47fee31bc28f130c3de782c0c912c9f4b769
--- /dev/null
+++ b/lite/kernels/host/one_hot_compute.h
@@ -0,0 +1,36 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <algorithm>
+#include "lite/core/kernel.h"
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace host {
+
+class OneHotCompute
+    : public KernelLite<TARGET(kHost), PRECISION(kFloat)> {
+ public:
+  void Run() override;
+
+  virtual ~OneHotCompute() = default;
+};
+
+}  // namespace host
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/operators/CMakeLists.txt b/lite/operators/CMakeLists.txt
old mode 100644
new mode 100755
index 7c4048c204b0889f9a9bd72a7e94da3777441d37..754270865aa9ca0a0f1100f2d2dfbc01dead7068
--- a/lite/operators/CMakeLists.txt
+++ b/lite/operators/CMakeLists.txt
@@ -128,6 +128,8 @@ add_operator(search_seq_fc_op extra SRCS search_seq_fc_op.cc DEPS ${op_DEPS})
 add_operator(sequence_topk_avg_pooling_op basic SRCS sequence_topk_avg_pooling_op.cc DEPS ${op_DEPS})
 add_operator(search_fc_op basic SRCS search_fc_op.cc DEPS ${op_DEPS})
+add_operator(one_hot basic SRCS one_hot_op.cc DEPS ${op_DEPS})
+
 if (NOT LITE_WITH_X86)
   lite_cc_test(test_fc_op SRCS fc_op_test.cc DEPS fc_op memory
diff --git a/lite/operators/one_hot_op.cc b/lite/operators/one_hot_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..023cdc23aeb8329736b7438af2c52cbfa899c75c
--- /dev/null
+++ b/lite/operators/one_hot_op.cc
@@ -0,0 +1,71 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/operators/one_hot_op.h"
+#include "lite/core/op_registry.h"
+
+#include "lite/backends/fpga/KD/debugger.hpp"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
+bool OneHotOp::CheckShape() const {
+  CHECK_OR_FALSE(param_.X);
+  CHECK_OR_FALSE(param_.Out);
+  return true;
+}
+
+bool OneHotOp::InferShape() const {
+  CHECK_OR_FALSE(param_.Out);
+  // TODO(Superjomn) Enable data sharing.
+  auto out_dims = param_.X->dims();
+
+  out_dims[out_dims.size() - 1] = param_.depth;
+  param_.Out->Resize(out_dims);
+  return true;
+}
+
+bool OneHotOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
+  param_.X =
+      scope->FindVar(opdesc.Input("X").front())->GetMutable<lite::Tensor>();
+  param_.Out =
+      scope->FindVar(opdesc.Output("Out").front())->GetMutable<lite::Tensor>();
+
+  if (opdesc.HasInput("depth_tensor")) {
+    auto depth_tensor = opdesc.Input("depth_tensor").front();
+    param_.depth_tensor =
+        scope->FindVar(depth_tensor)->GetMutable<lite::Tensor>();
+  }
+
+  CHECK(param_.X);
+  CHECK(param_.Out);
+  param_.depth = opdesc.GetAttr<int>("depth");
+  param_.dtype = opdesc.GetAttr<int>("dtype");
+
+  if (opdesc.HasAttr("allow_out_of_range")) {
+    param_.allow_out_of_range = opdesc.GetAttr<bool>("allow_out_of_range");
+  }
+
+  auto out_lod = param_.Out->mutable_lod();
+  *out_lod = param_.X->lod();
+  // param_.allow_out_of_range = opdesc.GetAttr<bool>("allow_out_of_range");
+  return true;
+}
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_OP(one_hot, paddle::lite::operators::OneHotOp);
diff --git a/lite/operators/one_hot_op.h b/lite/operators/one_hot_op.h
new file mode 100755
index 0000000000000000000000000000000000000000..4a0613952520279699a0f4a56d002483de325241
--- /dev/null
+++ b/lite/operators/one_hot_op.h
@@ -0,0 +1,47 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <string>
+#include <vector>
+#include "lite/core/op_lite.h"
+#include "lite/core/scope.h"
+#include "lite/utils/all.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
+class OneHotOp : public OpLite {
+ public:
+  OneHotOp() {}
+  explicit OneHotOp(const std::string &op_type) : OpLite(op_type) {}
+
+  bool CheckShape() const override;
+
+  bool InferShape() const override;
+
+  bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
+
+  void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
+
+  std::string DebugString() const override { return "one_hot"; }
+
+ private:
+  mutable OneHotParam param_;
+};
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h
index 4f0c707484f6a66148dabc80968665c1d38de445..df2e3c1217c1795af95477f549f4dec63b3c41ef 100644
--- a/lite/operators/op_params.h
+++ b/lite/operators/op_params.h
@@ -1056,6 +1056,16 @@ struct SearchGrnnParam {
   lite::Tensor* layout_input{};
 };
 
+// --------------------- attention operators ---------------------
+struct OneHotParam {
+  lite::Tensor* X{};
+  lite::Tensor* depth_tensor{nullptr};
+  lite::Tensor* Out{};
+  int depth{-1};
+  int dtype{};
+  bool allow_out_of_range{false};
+};
+
 }  // namespace operators
 }  // namespace lite
 }  // namespace paddle
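One usage note on OneHotParam: allow_out_of_range mirrors the attribute of the fluid one_hot op. When it is false (the default here), the kernel aborts via PADDLE_ENFORCE on any index outside [0, depth); when true, out-of-range indices are silently skipped and their rows stay all zero:

    // depth = 3, allow_out_of_range = true
    // input index 5 -> row stays {0, 0, 0} (no write, no error)
    // input index 1 -> row becomes {0, 1, 0}

The dtype field is attached from the op desc but, as in the host kernel above, the output is currently always produced as float.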