Commit b4e07620 authored by chonwhite

added one_hot operator

Parent: b02b1822
@@ -22,6 +22,8 @@ if (WITH_PADDLE_MOBILE)
     return()
 endif(WITH_PADDLE_MOBILE)
+# set(CMAKE_BUILD_TYPE DEBUG)
 set(PADDLE_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR})
 set(PADDLE_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR})
 set(CMAKE_CXX_STANDARD 11)
...
@@ -41,19 +41,20 @@ class Debugger {
  private:
   std::unordered_map<std::string, bool> op_config;
   Debugger() {
-    // op_config["concat"] = true;
-    // op_config["pooling"] = true;
-    // op_config["conv"] = true;
-    // op_config["crop"] = true;
-    // op_config["feed"] = true;
-    // op_config["fetch"] = true;
-    // op_config["boxes"] = true;
-    // op_config["scores"] = true;
-    // op_config["nms"] = true;
-    // op_config["pb_boxes"] = true;
-    // op_config["pb_variances"] = true;
-    // // op_config["fc"] = true;
-    // op_config["softmax"] = true;
+    op_config["concat"] = true;
+    op_config["pooling"] = true;
+    op_config["conv"] = true;
+    op_config["crop"] = true;
+    op_config["feed"] = true;
+    op_config["mul"] = true;
+    op_config["fetch"] = true;
+    op_config["boxes"] = true;
+    op_config["scores"] = true;
+    op_config["nms"] = true;
+    op_config["pb_boxes"] = true;
+    op_config["pb_variances"] = true;
+    // op_config["fc"] = true;
+    op_config["softmax"] = true;
   }
 };
...
@@ -56,6 +56,10 @@ void CastCompute::Run() {
     float* out_data = param.Out->mutable_data<float>();
     std::transform(
         x_data_begin, x_data_end, out_data, TransOp<unsigned char, float>);
+  } else if (param.in_dtype == 3 && param.out_dtype == 5) {
+    const auto* x_data = param.X->data<float>();
+    auto* o_data = param.Out->mutable_data<float>();
+    memcpy(o_data, x_data, sizeof(float) * param.X->numel());
   } else {
     LOG(FATAL) << "other has not been implemented";
   }
...
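For context: the dtype codes compared in this new branch follow Paddle's VarType numbering (to the best of my knowledge BOOL = 0, INT16 = 1, INT32 = 2, INT64 = 3, FP16 = 4, FP32 = 5), so the branch covers INT64 -> FP32. It copies the buffer as raw floats, which assumes the producer already stored float values in the nominally int64 tensor; a strict element-wise conversion would look like the hypothetical helper below (not part of the diff):

#include <algorithm>
#include <cstdint>

// Hypothetical helper: convert an INT64 buffer to FP32 element by element
// instead of reinterpreting the bytes.
void cast_int64_to_fp32(const int64_t* src, float* dst, int64_t numel) {
  std::transform(src, src + numel, dst,
                 [](int64_t v) { return static_cast<float>(v); });
}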
@@ -28,6 +28,7 @@ namespace arm {
 void LookupTableCompute::Run() {
   auto& param = this->Param<param_t>();
+  auto& ctx = this->ctx_->template As<ARMContext>();
   // inputs
   auto w = param.W;
   auto ids = param.Ids;
@@ -36,7 +37,7 @@ void LookupTableCompute::Run() {
   auto table_dim = w->dims();
   int64_t ids_numel = ids->numel();
-  auto ids_data = ids->data<int64_t>();
+  auto ids_data = ids->data<float>();
   int64_t row_number = table_dim[0];
   int64_t row_width = table_dim[1];
@@ -75,14 +76,3 @@ REGISTER_LITE_KERNEL(lookup_table,
     .BindInput("Ids", {LiteType::GetTensorTy(TARGET(kARM))})
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
     .Finalize();
-
-REGISTER_LITE_KERNEL(lookup_table_v2,
-                     kARM,
-                     kFloat,
-                     kNCHW,
-                     paddle::lite::kernels::arm::LookupTableCompute,
-                     def)
-    .BindInput("W", {LiteType::GetTensorTy(TARGET(kARM))})
-    .BindInput("Ids", {LiteType::GetTensorTy(TARGET(kARM))})
-    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
-    .Finalize();
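With this change the Ids tensor is read as float rather than int64, matching the relaxed cast path above. A minimal standalone sketch of the row gather that lookup_table performs under that assumption (hypothetical helper, not the Lite kernel itself; padding_idx handling omitted):

#include <cstdint>
#include <cstring>

// Gather rows of `table` ([row_number, row_width]) into `out`
// ([ids_numel, row_width]); ids are stored as float and truncated to int.
void lookup_table_gather(const float* table, const float* ids, float* out,
                         int64_t ids_numel, int64_t row_number,
                         int64_t row_width) {
  for (int64_t i = 0; i < ids_numel; ++i) {
    const int64_t id = static_cast<int64_t>(ids[i]);
    if (id >= 0 && id < row_number) {
      std::memcpy(out + i * row_width, table + id * row_width,
                  sizeof(float) * row_width);
    } else {
      std::memset(out + i * row_width, 0, sizeof(float) * row_width);
    }
  }
}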
@@ -7,7 +7,9 @@ set(fpga_deps fpga_target_wrapper kernel_fpga)
 # add_kernel(activation_compute_fpga FPGA basic SRCS activation_compute.cc DEPS ${fpga_deps})
 # add_kernel(box_coder_compute_fpga FPGA basic SRCS box_coder_compute.cc DEPS ${fpga_deps})
-# add_kernel(concat_compute_fpga FPGA basic SRCS concat_compute.cc DEPS ${fpga_deps})
+add_kernel(concat_compute_fpga FPGA basic SRCS concat_compute.cc DEPS ${fpga_deps})
 add_kernel(conv_compute_fpga FPGA basic SRCS conv_compute.cc DEPS ${fpga_deps})
 # add_kernel(density_prior_box_compute_fpga FPGA basic SRCS density_prior_box_compute.cc DEPS ${fpga_deps})
 add_kernel(dropout_compute_fpga FPGA basic SRCS dropout_compute.cc DEPS ${fpga_deps})
@@ -16,9 +18,12 @@ add_kernel(elementwise_compute_fpga FPGA basic SRCS elementwise_compute.cc DEPS
 add_kernel(fc_compute_fpga FPGA basic SRCS fc_compute.cc DEPS ${fpga_deps})
 add_kernel(gru_compute_fpga FPGA extra SRCS gru_compute.cc DEPS ${fpga_deps})
 # add_kernel(mul_compute_fpga FPGA basic SRCS mul_compute.cc DEPS ${fpga_deps})
 add_kernel(multiclass_nms_compute_fpga FPGA basic SRCS multiclass_nms_compute.cc DEPS ${fpga_deps})
 add_kernel(norm_compute_fpga FPGA basic SRCS norm_compute.cc DEPS ${fpga_deps})
 # add_kernel(im2sequence_compute_fpga FPGA basic SRCS im2sequence_compute.cc DEPS ${fpga_deps})
 add_kernel(pooling_compute_fpga FPGA basic SRCS pooling_compute.cc DEPS ${fpga_deps})
 add_kernel(prior_box_compute_fpga FPGA basic SRCS prior_box_compute.cc DEPS ${fpga_deps})
...
@@ -67,3 +67,13 @@ REGISTER_LITE_KERNEL(
                                        PRECISION(kFP16),
                                        DATALAYOUT(kNHWC))})
     .Finalize();
+
+REGISTER_LITE_KERNEL(feed,
+                     kFPGA,
+                     kFP16,
+                     kNHWC,
+                     paddle::lite::kernels::fpga::FeedCompute,
+                     def_host)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kHost))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))})
+    .Finalize();
@@ -43,8 +43,15 @@ void FetchCompute::PrepareForRun() {
 }
 
 void FetchCompute::Run() {
-  pe_.dispatch();
   auto& param = this->Param<param_t>();
+  auto fetch_list = param.fetch_list;
+  if (fetch_list->size() <= static_cast<size_t>(param.col)) {
+    fetch_list->resize(param.col + 1);
+  }
+  Tensor& out = param.fetch_list->at(param.col);
+  out.Resize(param.input->dims());
+  pe_.dispatch();
 
 #ifdef FPGA_PRINT_TENSOR
   zynqmp::OutputParam& fetch_param = pe_.param();
@@ -67,10 +74,7 @@ REGISTER_LITE_KERNEL(fetch,
                      {LiteType::GetTensorTy(TARGET(kFPGA),
                                             PRECISION(kAny),
                                             DATALAYOUT(kAny))})
-    .BindOutput("Out",
-                {LiteType::GetTensorTy(TARGET(kHost),
-                                       PRECISION(kAny),
-                                       DATALAYOUT(kAny))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))})
     .Finalize();
 
 REGISTER_LITE_KERNEL(fetch,
@@ -79,12 +83,6 @@ REGISTER_LITE_KERNEL(fetch,
                      kNHWC,
                      paddle::lite::kernels::fpga::FetchCompute,
                      host_host)
-    .BindInput("X",
-               {LiteType::GetTensorTy(TARGET(kHost),
-                                      PRECISION(kAny),
-                                      DATALAYOUT(kAny))})
-    .BindOutput("Out",
-                {LiteType::GetTensorTy(TARGET(kHost),
-                                       PRECISION(kAny),
-                                       DATALAYOUT(kAny))})
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kHost))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))})
     .Finalize();
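The new lines in FetchCompute::Run above grow param.fetch_list on demand before writing slot param.col. A generic sketch of that guard as a hypothetical standalone helper (names are illustrative, not from the diff):

#include <vector>

// Ensure index `col` exists in the list, growing it if needed, then return
// a reference to that slot.
template <typename T>
T& slot_at(std::vector<T>* list, int col) {
  if (list->size() <= static_cast<size_t>(col)) {
    list->resize(col + 1);
  }
  return list->at(col);
}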
@@ -122,7 +122,81 @@ class IoCopyFpgaToHostCompute
     // param.x->ZynqTensor()->saveToFile("io_x", true);
     // param.y->ZynqTensor()->saveToFile("io_y", true);
   }
+  std::string doc() const override { return "Copy IO from FPGA to HOST"; }
+};
+
+void hwc_to_chw(float* chw_data,
+                float* hwc_data,
+                int num,
+                int channel,
+                int height,
+                int width) {
+  int chw = channel * height * width;
+  int wc = width * channel;
+  int wh = width * height;
+  int index = 0;
+  for (int n = 0; n < num; n++) {
+    for (int h = 0; h < height; h++) {
+      for (int w = 0; w < width; w++) {
+        for (int c = 0; c < channel; c++) {
+          chw_data[n * chw + c * wh + h * width + w] = hwc_data[index];
+          index++;
+        }
+      }
+    }
+  }
+}
+
+class IoCopyFpgaToHostCHWCompute
+    : public KernelLite<TARGET(kFPGA), PRECISION(kAny), DATALAYOUT(kAny)> {
+ public:
+  void Run() override {
+    auto& param = Param<operators::IoCopyParam>();
+    CHECK(param.x->target() == TARGET(kHost) ||
+          param.x->target() == TARGET(kFPGA));
+
+    Tensor hwc;
+    hwc.Resize(param.y->dims());
+    float* hwc_data = hwc.mutable_data<float>();
+    float* chw_data = param.y->mutable_data<float>();
+    param.y->ZynqTensor()->setDataType(zynqmp::FP32);
+    param.x->ZynqTensor()->syncToDevice();
+
+    if (param.x->ZynqTensor()->aligned() &&
+        param.x->ZynqTensor()->shape().shouldAlign()) {
+      zynqmp::Tensor tempTensor;
+      tempTensor.mutableData<float16>(zynqmp::FP16,
+                                      param.x->ZynqTensor()->shape());
+      tempTensor.copyFrom(param.x->ZynqTensor());
+      tempTensor.setAligned(true);
+      tempTensor.unalignImage();
+      hwc.ZynqTensor()->copyFrom(&tempTensor);
+    } else {
+      hwc.ZynqTensor()->copyFrom(param.x->ZynqTensor());
+    }
+
+    int num = 1;
+    int channel = 1;
+    int height = 1;
+    int width = 1;
+    auto dims = param.y->ZynqTensor()->shape();
+    hwc_to_chw(chw_data,
+               hwc_data,
+               dims.num(),
+               dims.channel(),
+               dims.height(),
+               dims.width());
+
+    param.y->ZynqTensor()->copyScaleFrom(param.x->ZynqTensor());
+    param.y->ZynqTensor()->flush();
+    auto out_lod = param.y->mutable_lod();
+    *out_lod = param.x->lod();
+    // param.x->ZynqTensor()->saveToFile("io_x", true);
+    // param.y->ZynqTensor()->saveToFile("io_y", true);
+  }
+
   std::string doc() const override { return "Copy IO from FPGA to HOST"; }
 };
@@ -173,7 +247,7 @@ REGISTER_LITE_KERNEL(io_copy,
                                            PRECISION(kFP16),
                                            DATALAYOUT(kNHWC))})
     .BindOutput("Out",
-                {LiteType::GetTensorTy(TARGET(kARM),
+                {LiteType::GetTensorTy(TARGET(kHost),
                                        PRECISION(kFloat),
                                        DATALAYOUT(kNHWC))})
     .Finalize();
@@ -182,7 +256,7 @@ REGISTER_LITE_KERNEL(io_copy,
                      kFPGA,
                      kAny,
                      kAny,
-                     paddle::lite::kernels::fpga::IoCopyFpgaToHostCompute,
+                     paddle::lite::kernels::fpga::IoCopyFpgaToHostCHWCompute,
                      device_to_host_22)
     .BindInput("Input",
                {LiteType::GetTensorTy(TARGET(kFPGA),
...
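The hwc_to_chw helper added above maps element (n, h, w, c) of an NHWC buffer to offset n*C*H*W + c*H*W + h*W + w of the NCHW buffer. A small standalone check of that mapping, assuming the hwc_to_chw from the diff is visible and using illustrative sizes:

#include <cstdio>

int main() {
  const int N = 1, C = 3, H = 2, W = 2;
  float hwc[N * H * W * C];  // NHWC source, filled with 0..11
  float chw[N * C * H * W];  // NCHW destination
  for (int i = 0; i < N * H * W * C; ++i) hwc[i] = static_cast<float>(i);
  hwc_to_chw(chw, hwc, N, C, H, W);
  // Channel 0 of the CHW buffer now holds hwc[0], hwc[3], hwc[6], hwc[9].
  printf("%.0f %.0f %.0f %.0f\n", chw[0], chw[1], chw[2], chw[3]);  // 0 3 6 9
  return 0;
}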
@@ -379,11 +379,13 @@ void MulticlassNmsCompute::Run() {
       if (e > s) {
         Tensor out;
+        std::cout << "Slice:" << s << " -- " << e << std::endl;
         outs->Slice<float>(out, s, e);
         MultiClassOutput<float>(
             scores_slice, boxes_slice, all_indices[i], score_dims.size(), &out);
         outs->ZynqTensor()->copyFrom(out.ZynqTensor());
       }
+      outs->Resize({static_cast<int64_t>(e - s), out_dim});
     }
   }
   LoD lod;
@@ -412,19 +414,19 @@ REGISTER_LITE_KERNEL(multiclass_nms,
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))})
     .Finalize();
 
-REGISTER_LITE_KERNEL(multiclass_nms,
-                     kFPGA,
-                     kFP16,
-                     kNHWC,
-                     paddle::lite::kernels::fpga::MulticlassNmsCompute,
-                     def2)
-    .BindInput("BBoxes",
-               {LiteType::GetTensorTy(TARGET(kFPGA),
-                                      PRECISION(kFP16),
-                                      DATALAYOUT(kNHWC))})
-    .BindInput("Scores",
-               {LiteType::GetTensorTy(TARGET(kFPGA),
-                                      PRECISION(kFP16),
-                                      DATALAYOUT(kNHWC))})
-    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))})
-    .Finalize();
+// REGISTER_LITE_KERNEL(multiclass_nms,
+//                      kFPGA,
+//                      kFP16,
+//                      kNHWC,
+//                      paddle::lite::kernels::fpga::MulticlassNmsCompute,
+//                      def2)
+//     .BindInput("BBoxes",
+//                {LiteType::GetTensorTy(TARGET(kFPGA),
+//                                       PRECISION(kFP16),
+//                                       DATALAYOUT(kNHWC))})
+//     .BindInput("Scores",
+//                {LiteType::GetTensorTy(TARGET(kFPGA),
+//                                       PRECISION(kFP16),
+//                                       DATALAYOUT(kNHWC))})
+//     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))})
+//     .Finalize();
@@ -4,6 +4,7 @@ add_kernel(feed_compute_host Host basic SRCS feed_compute.cc DEPS ${lite_kernel_
 add_kernel(fetch_compute_host Host basic SRCS fetch_compute.cc DEPS ${lite_kernel_deps})
 add_kernel(reshape_compute_host Host basic SRCS reshape_compute.cc DEPS ${lite_kernel_deps} reshape_op)
 add_kernel(multiclass_nms_compute_host Host basic SRCS multiclass_nms_compute.cc DEPS ${lite_kernel_deps})
+add_kernel(one_hot_compute_host Host basic SRCS one_hot_compute.cc DEPS ${lite_kernel_deps})
 
 #lite_cc_test(test_reshape_compute_host SRCS reshape_compute_test.cc DEPS reshape_compute_host any)
 #lite_cc_test(test_multiclass_nms_compute_host SRCS multiclass_nms_compute_test.cc DEPS multiclass_nms_compute_host any)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <map>
#include <utility>
#include <vector>
#include "lite/backends/fpga/KD/debugger.hpp"
#include "lite/kernels/host/one_hot_compute.h"
#include "lite/utils/paddle_enforce.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace host {
void OneHotCompute::Run() {
auto& param = Param<operators::OneHotParam>();
param.Out->mutable_data<float>();
int depth = param.depth;
if (param.depth_tensor) {
auto* depth_tensor = param.depth_tensor;
auto* depth_data = depth_tensor->data<int32_t>();
depth = depth_data[0];
auto in_dims = param.X->dims();
DDim out_dims(in_dims);
out_dims[out_dims.size() - 1] = depth;
param.Out->Resize(out_dims);
}
auto* p_in_data = param.X->data<float>();
auto numel = param.X->numel();
auto* p_out_data = param.Out->mutable_data<float>();
for (int i = 0; i < param.Out->numel(); ++i) {
p_out_data[i] = 0;
}
if (param.allow_out_of_range) {
for (int i = 0; i < numel; ++i) {
if (p_in_data[i] >= 0 && p_in_data[i] < param.depth) {
*(p_out_data + i * param.depth + (int)(p_in_data[i])) = 1.0; // NOLINT
}
}
} else {
for (int i = 0; i < numel; ++i) {
PADDLE_ENFORCE_GE(
p_in_data[i], 0, "Illegal index value, should be at least 0.");
PADDLE_ENFORCE_LT(p_in_data[i],
param.depth,
"Illegal index value, should be less than depth (%d).",
param.depth);
*(p_out_data + i * param.depth + (int)(p_in_data[i])) = 1.0; // NOLINT
}
}
}
} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(one_hot,
kHost,
kFloat,
kNCHW,
paddle::lite::kernels::host::OneHotCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kHost))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))})
.Finalize();
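For intuition about what OneHotCompute::Run produces: with depth = 4, an input holding the indices {1, 0, 3} expands to the rows {0,1,0,0}, {1,0,0,0}, {0,0,0,1}, with out-of-range indices either skipped or rejected depending on allow_out_of_range. A minimal standalone sketch of that expansion (hypothetical helper, not the registered kernel):

#include <vector>

// Expand float-encoded indices into rows of length `depth`, writing 1.0 at
// each index and leaving everything else at 0.
std::vector<float> one_hot_expand(const std::vector<float>& indices, int depth) {
  std::vector<float> out(indices.size() * depth, 0.0f);
  for (size_t i = 0; i < indices.size(); ++i) {
    const int idx = static_cast<int>(indices[i]);
    if (idx >= 0 && idx < depth) out[i * depth + idx] = 1.0f;
  }
  return out;
}
// one_hot_expand({1, 0, 3}, 4) -> {0,1,0,0, 1,0,0,0, 0,0,0,1}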
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace host {
class OneHotCompute
: public KernelLite<TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)> {
public:
void Run() override;
virtual ~OneHotCompute() = default;
};
} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
@@ -128,6 +128,8 @@ add_operator(search_seq_fc_op extra SRCS search_seq_fc_op.cc DEPS ${op_DEPS})
 add_operator(sequence_topk_avg_pooling_op basic SRCS sequence_topk_avg_pooling_op.cc DEPS ${op_DEPS})
 add_operator(search_fc_op basic SRCS search_fc_op.cc DEPS ${op_DEPS})
+add_operator(one_hot basic SRCS one_hot_op.cc DEPS ${op_DEPS})
 
 if (NOT LITE_WITH_X86)
     lite_cc_test(test_fc_op SRCS fc_op_test.cc
             DEPS fc_op memory
...
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/one_hot_op.h"
#include "lite/core/op_registry.h"
#include "lite/backends/fpga/KD/debugger.hpp"
namespace paddle {
namespace lite {
namespace operators {
bool OneHotOp::CheckShape() const {
CHECK_OR_FALSE(param_.X);
CHECK_OR_FALSE(param_.Out);
return true;
}
bool OneHotOp::InferShape() const {
CHECK_OR_FALSE(param_.Out);
// TODO(Superjomn) Enable data sharing.
auto out_dims = param_.X->dims();
out_dims[out_dims.size() - 1] = param_.depth;
param_.Out->Resize(out_dims);
return true;
}
bool OneHotOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
param_.X =
scope->FindVar(opdesc.Input("X").front())->GetMutable<lite::Tensor>();
param_.Out =
scope->FindVar(opdesc.Output("Out").front())->GetMutable<lite::Tensor>();
if (opdesc.HasInput("depth_tensor")) {
auto depth_tensor = opdesc.Input("depth_tensor").front();
param_.depth_tensor =
scope->FindVar(depth_tensor)->GetMutable<lite::Tensor>();
}
CHECK(param_.X);
CHECK(param_.Out);
param_.depth = opdesc.GetAttr<int>("depth");
param_.dtype = opdesc.GetAttr<int>("dtype");
if (opdesc.HasAttr("allow_out_of_range")) {
param_.allow_out_of_range = opdesc.GetAttr<bool>("allow_out_of_range");
}
auto out_lod = param_.Out->mutable_lod();
*out_lod = param_.X->lod();
// param_.allow_out_of_range = opdesc.GetAttr<bool>("allow_out_of_range");
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(one_hot, paddle::lite::operators::OneHotOp);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class OneHotOp : public OpLite {
public:
OneHotOp() {}
explicit OneHotOp(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "one_hot"; }
private:
mutable OneHotParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
@@ -1056,6 +1056,16 @@ struct SearchGrnnParam {
   lite::Tensor* layout_input{};
 };
 
+// --------------------- one_hot operator --------------
+struct OneHotParam {
+  lite::Tensor* X{};
+  lite::Tensor* depth_tensor{nullptr};
+  lite::Tensor* Out{};
+  int depth{-1};
+  int dtype{};
+  bool allow_out_of_range{false};
+};
+
 }  // namespace operators
 }  // namespace lite
 }  // namespace paddle