提交 02fb420f 编写于 作者: T tienfeek

FPGA support quantitative model

test=develop
上级 0699ee90
......@@ -223,14 +223,24 @@ if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND WITH_TESTING)
CL_DEPS ${opencl_kernels}
FPGA_DEPS ${fpga_kernels})
lite_cc_test(test_ssd_fpga SRCS test_ssd_fpga.cc
DEPS ${lite_model_test_DEPS}
CL_DEPS ${opencl_kernels}
FPGA_DEPS ${fpga_kernels})
lite_cc_test(test_inceptionv3_fpga SRCS inceptionv3_test_fpga.cc
DEPS ${lite_model_test_DEPS}
CL_DEPS ${opencl_kernels}
FPGA_DEPS ${fpga_kernels})
lite_cc_test(test_inceptionv4 SRCS inceptionv4_test.cc
DEPS ${lite_model_test_DEPS}
CL_DEPS ${opencl_kernels}
ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl
--model_dir=${LITE_MODEL_DIR}/inception_v4 SERIAL)
add_dependencies(test_inceptionv4 extern_lite_download_inception_v4_simple_tar_gz)
# lite_cc_test(test_ocr_attention SRCS ocr_attention_test.cc
# DEPS ${lite_model_test_DEPS})
lite_cc_test(test_ocr_attention_fpga SRCS ocr_attention_test_fpga.cc
DEPS ${lite_model_test_DEPS})
# lite_cc_test(model_run_test_image SRCS model_run_test_image.cc
# DEPS ${lite_model_test_DEPS}
......
......@@ -41,6 +41,7 @@ USE_MIR_PASS(lite_quant_dequant_fuse_pass);
USE_MIR_PASS(type_precision_cast_pass);
USE_MIR_PASS(type_layout_cast_pass);
USE_MIR_PASS(memory_optimize_pass);
USE_MIR_PASS(kernel_place_correct_pass)
USE_MIR_PASS(elementwise_mul_constant_eliminate_pass)
USE_MIR_PASS(npu_subgraph_pass);
USE_MIR_PASS(xpu_subgraph_pass);
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <dirent.h>
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <vector>
#include "lite/api/cxx_api.h"
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/api/paddle_use_passes.h"
#include "lite/api/test_helper.h"
#include "lite/core/op_registry.h"
DEFINE_string(input_file, "", "input_file");
namespace paddle {
namespace lite {
std::vector<std::string> GetDirectoryFiles(const std::string& dir) {
std::vector<std::string> files;
std::shared_ptr<DIR> directory_ptr(opendir(dir.c_str()),
[](DIR* dir) { dir&& closedir(dir); });
struct dirent* dirent_ptr;
if (!directory_ptr) {
std::cout << "Error opening : " << std::strerror(errno) << dir << std::endl;
return files;
}
while ((dirent_ptr = readdir(directory_ptr.get())) != nullptr) {
files.push_back(std::string(dirent_ptr->d_name));
}
return files;
}
void readFromFile(int num, std::string path, float* data) {
std::ifstream file_stream(path);
// file_stream.open(path);
if (!file_stream.good()) {
std::cout << "file: " << path << " dones not exist!\n";
exit(-1);
return;
}
// float* data = mutableData<float>();
for (int i = 0; i < num; ++i) {
float value = 0;
file_stream >> value;
data[i] = value;
}
file_stream.close();
}
// #ifdef LITE_WITH_FPGA
TEST(ResNet50, test) {
lite::Predictor predictor;
std::vector<Place> valid_places({
Place{TARGET(kFPGA), PRECISION(kFP16), DATALAYOUT(kNHWC)},
Place{TARGET(kHost), PRECISION(kFloat)},
Place{TARGET(kARM), PRECISION(kFloat)},
});
// predictor.Build(FLAGS_model_dir, "", "", valid_places);
predictor.Build("",
FLAGS_model_dir + "/model",
FLAGS_model_dir + "/params",
valid_places);
auto* input_tensor = predictor.GetInput(0);
int width = 300;
int height = 300;
// std::ifstream file_stream(FLAGS_input_file);
// if (!file_stream.good()) {
// std::cout << "file: " << FLAGS_input_file << " dones not exist!\n";
// exit(-1);
// return;
// }
// file_stream >> height;
// file_stream >> width;
input_tensor->Resize(
DDim(std::vector<DDim::value_type>({1, 3, height, width})));
auto* data = input_tensor->mutable_data<float>();
auto item_size = input_tensor->dims().production();
for (int i = 0; i < item_size; i++) {
data[i] = 1;
}
// readFromFile(item_size, "car.data", data);
int num = 3 * width * height;
// for (int i = 0; i < num; ++i) {
// float value = 0;
// file_stream >> value;
// data[i] = value;
// }
// file_stream.close();
for (int i = 0; i < 2; ++i) {
predictor.Run();
}
auto* out = predictor.GetOutput(0);
for (int i = 0; i < out->dims().production(); i++) {
std::cout << ":" << out->data<float>()[i] << std::endl;
}
std::string file = "output/" + FLAGS_input_file.substr(6);
std::cout << "file:::" << file << std::endl;
std::ofstream ofs;
ofs.open(file);
for (int i = 0; i < out->dims().production(); i++) {
float value = out->data<float>()[i];
ofs << value << std::endl;
}
ofs.close();
LOG(INFO) << "================== Speed Report ===================";
}
// #endif
} // namespace lite
} // namespace paddle
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "lite/backends/fpga/KD/fpga_cv.hpp"
using paddle::zynqmp::float16;
void fpga_resize(float* input,
int input_width,
int input_height,
int input_channel,
uint8_t* output,
int output_width,
int output_height) {
paddle::zynqmp::InplaceArgs inplace_args = {0, 0, 0};
paddle::zynqmp::config_inplace(inplace_args);
paddle::zynqmp::ImageInputArgs input_args = {nullptr};
input_args.address = nullptr;
input_args.scale_address = nullptr;
float16* input_image_address =
reinterpret_cast<float16*>(paddle::zynqmp::fpga_malloc(
input_width * input_height * input_channel * sizeof(float16)));
int index = 0;
for (int i = 0; i < input_width * input_height * input_channel; i++) {
input_image_address[i] = float16(1.0 * input[i]);
}
paddle::zynqmp::ResizeArgs resize_args = {0};
resize_args.input_width = input_width;
resize_args.input_height = input_height;
resize_args.image_channel = input_channel;
resize_args.output_width = output_width;
resize_args.output_height = output_height;
float height_ratio = static_cast<float>(input_height) /
static_cast<float>(resize_args.output_height);
float width_ratio = static_cast<float>(input_width) /
static_cast<float>(resize_args.output_width);
resize_args.height_ratio = *reinterpret_cast<uint32_t*>(&height_ratio);
resize_args.width_ratio = *reinterpret_cast<uint32_t*>(&width_ratio);
int output_size =
resize_args.output_width * resize_args.output_height * input_channel;
float16* fpga_output = reinterpret_cast<float16*>(
paddle::zynqmp::fpga_malloc(output_size * sizeof(float16)));
resize_args.input_image_address = input_image_address;
resize_args.output_image_address = fpga_output;
memset(fpga_output, 0, output_size * sizeof(float16));
paddle::zynqmp::fpga_flush(
input_image_address,
input_width * input_height * input_channel * sizeof(float16));
paddle::zynqmp::fpga_flush(resize_args.output_image_address,
output_size * sizeof(float16));
int ret = paddle::zynqmp::compute_fpga_resize(resize_args);
if (ret == 0) {
paddle::zynqmp::fpga_invalidate(resize_args.output_image_address,
output_size * sizeof(float16));
}
for (int i = 0; i < output_size; i++) {
output[i] = fpga_output[i];
}
}
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <stdlib.h>
#include "lite/backends/fpga/KD/float16.hpp"
#include "lite/backends/fpga/KD/llapi/zynqmp_api.h"
#include "lite/backends/fpga/KD/pe.hpp"
void fpga_resize(float* input,
int input_width,
int input_height,
int input_channel,
uint8_t* output,
int output_width,
int output_height);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#define PADDLE_LITE_ZU5
#define FPGA_PRINT_MODE
#define PADDLE_LITE_PROFILE
......@@ -25,6 +25,7 @@ lite_cc_library(mir_passes
elimination/elementwise_mul_constant_eliminate_pass.cc
static_kernel_pick_pass.cc
variable_place_inference_pass.cc
kernel_place_correct_pass.cc
type_target_cast_pass.cc
type_layout_cast_pass.cc
type_precision_cast_pass.cc
......
......@@ -27,10 +27,24 @@ namespace mir {
void QuantDequantFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
// delete quant node
std::vector<std::string> quant_op_types = {
"fake_quantize_range_abs_max", "fake_quantize_moving_average_abs_max"};
"fake_quantize_abs_max",
"fake_quantize_range_abs_max",
"fake_quantize_moving_average_abs_max"};
/*
for (auto& op_type : {"conv2d", "mul", "depthwise_conv2d"}) {
for (int i = 5; i >= 1; --i){
fusion::DynamicQuantDequantOpFuser fuser("fake_quantize_abs_max", op_type,
i);
fuser(graph.get());
}
}
*/
for (auto& op_type : quant_op_types) {
fusion::DeleteQuantOpFuser fuser(op_type);
fuser(graph.get());
fusion::DeleteDynamicQuantOpFuser dfuser(op_type);
dfuser(graph.get());
}
// fuse quantized node and dequant node
......
......@@ -77,6 +77,55 @@ cpp::OpDesc DeleteQuantOpFuser::GenOpDesc(const key2nodes_t& matched) {
return op_desc;
}
void DeleteDynamicQuantOpFuser::BuildPattern() {
auto* input_act_node =
VarNode("input_act_node")->assert_is_op_input(quant_op_type_, "X");
auto* quant_node =
OpNode("quant_node", quant_op_type_)->assert_is_op(quant_op_type_);
auto* output_scale_node =
VarNode("output_scale_node")
->assert_is_op_output(quant_op_type_, "OutScale");
auto* output_act_node =
VarNode("output_act_node")->assert_is_op_output(quant_op_type_, "Out");
quant_node->LinksFrom({input_act_node});
output_scale_node->LinksFrom({quant_node});
output_act_node->LinksFrom({quant_node});
VLOG(4) << "DeleteQuantOpFuser BuildPattern quant_op_type:" << quant_op_type_;
}
void DeleteDynamicQuantOpFuser::InsertNewNode(SSAGraph* graph,
const key2nodes_t& matched) {
auto* input_act_node = matched.at("input_act_node");
auto* quant_node = matched.at("quant_node");
auto* output_scale_node = matched.at("output_scale_node");
auto* output_act_node = matched.at("output_act_node");
// obtain values, save values and relink node
int bit_length = quant_node->stmt()->op_info()->GetAttr<int>("bit_length");
int range = ((1 << (bit_length - 1)) - 1);
auto* scope = quant_node->stmt()->op()->scope();
auto* scale_tensor = scope->FindVar(output_scale_node->arg()->name)
->GetMutable<lite::Tensor>();
float scale_value = scale_tensor->data<float>()[0] / range;
auto outlinks = output_act_node->outlinks;
for (auto* quantized_node : outlinks) {
auto* op_desc = quantized_node->stmt()->mutable_op_info();
op_desc->SetAttr<int>("bit_length", bit_length);
IR_NODE_LINK_TO(input_act_node, quantized_node)
}
// delete nodes and edges
std::unordered_set<const Node*> nodes2rm = {
quant_node, output_scale_node, output_act_node};
GraphSafeRemoveNodes(graph, nodes2rm);
}
cpp::OpDesc DeleteDynamicQuantOpFuser::GenOpDesc(const key2nodes_t& matched) {
cpp::OpDesc op_desc;
return op_desc;
}
void DequantOpFuser::BuildPattern() {
std::string weight_name = "";
if (quantized_op_type_ == "conv2d" ||
......@@ -130,8 +179,11 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph,
auto& valid_places = quantized_op->stmt()->op()->valid_places();
int bit_length = quantized_op->stmt()->op_info()->GetAttr<int>("bit_length");
int range = ((1 << (bit_length - 1)) - 1);
float input_scale =
quantized_op->stmt()->op_info()->GetAttr<float>("input_scale");
float input_scale = 0;
if (quantized_op->stmt()->op_info()->HasAttr("input_scale")) {
input_scale =
quantized_op->stmt()->op_info()->GetAttr<float>("input_scale");
}
float max_range = dequant_op->stmt()->op_info()->GetAttr<float>("max_range");
float whole_weight_scale =
static_cast<float>(range * range) / max_range / range;
......@@ -162,8 +214,12 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph,
for (int i = 0; i < weight_scale_size; i++) {
weight_scale.push_back(whole_weight_scale);
}
#ifndef LITE_WITH_FPGA
op_desc.SetAttr("enable_int8", true);
op_desc.SetAttr("input_scale", input_scale);
#endif
if (quantized_op->stmt()->op_info()->HasAttr("input_scale")) {
op_desc.SetAttr("input_scale", input_scale);
}
op_desc.SetAttr("weight_scale", weight_scale);
// change the weight from the float type to int8 type.
......@@ -171,12 +227,29 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph,
temp_tensor.CopyDataFrom(*quantized_weight_t);
float* temp_data = temp_tensor.mutable_data<float>();
size_t weight_num = quantized_weight_t->data_size();
#ifdef LITE_WITH_FPGA
float* quantized_weight_data = quantized_weight_t->mutable_data<float>();
for (size_t i = 0; i < weight_num; i++) {
quantized_weight_data[i] = temp_data[i] * whole_weight_scale;
}
quantized_weight_t->set_persistable(true);
quantized_weight_t->set_precision(PRECISION(kFloat));
#else
int8_t* quantized_weight_data = quantized_weight_t->mutable_data<int8_t>();
for (size_t i = 0; i < weight_num; i++) {
quantized_weight_data[i] = static_cast<int8_t>(temp_data[i]);
}
quantized_weight_t->set_persistable(true);
quantized_weight_t->set_precision(PRECISION(kInt8));
#endif
// int8_t* quantized_weight_data = quantized_weight_t->mutable_data<int8_t>();
// for (size_t i = 0; i < weight_num; i++) {
// quantized_weight_data[i] = static_cast<int8_t>(temp_data[i]);
// }
// quantized_weight_t->set_persistable(true);
// quantized_weight_t->set_precision(PRECISION(kInt8));
// new op and relink nodes
auto new_quantized_op = LiteOpRegistry::Global().Create(quantized_op_type_);
......@@ -464,6 +537,197 @@ cpp::OpDesc DeleteQuantDequantOpFuser::GenOpDesc(const key2nodes_t& matched) {
cpp::OpDesc op_desc;
return op_desc;
}
// ================dynamic quant fuse==============
// #define DYNAMIC_RANGE
void DynamicQuantDequantOpFuser::BuildPattern() {
const int kNumFields = 5;
const int kQuantizedWeightOffset = 0;
const int kQuantizedOpOffset = 1;
const int kQuantizedOpOutOffset = 2;
const int kDequantOpOffset = 3;
const int kDequantOpOutOffset = 4;
std::string weight_name = "";
if (op_type_ == "conv2d" || op_type_ == "depthwise_conv2d") {
weight_name = "Filter";
} else {
weight_name = "Y";
}
auto* quant_op_input = VarNode("quant_op_input")
->assert_is_op_input(quant_type_, "X")
->AsInput();
#ifdef DYNAMIC_RANGE
auto* quant_op_in_scale = VarNode("quant_op_in_scale")
->assert_is_op_input(quant_type_, "InScale")
->AsIntermediate();
#endif
auto* quant_op = OpNode("quant_op", quant_type_)
->assert_is_op(quant_type_)
->AsIntermediate();
auto* quant_op_out_scale =
VarNode("quant_op_out_scale")
->assert_is_op_output(quant_type_, "OutScale")
->assert_is_op_input("fake_dequantize_max_abs", "Scale")
->AsIntermediate();
auto* quant_op_out = VarNode("quant_op_out")
->assert_is_op_output(quant_type_, "Out")
->assert_is_op_input(op_type_)
->AsIntermediate();
std::vector<PMNode*> nodes;
for (int i = 0; i < times_; i++) {
nodes.push_back(VarNode(string_format("quantized_op_weight%d", i))
->assert_is_op_input(op_type_, weight_name)
->AsInput());
nodes.push_back(OpNode(string_format("quantized_op%d", i), op_type_)
->assert_is_op(op_type_)
->AsIntermediate());
nodes.push_back(VarNode(string_format("quantized_op_out%d", i))
->assert_is_op_output(op_type_)
->assert_is_op_input("fake_dequantize_max_abs", "X")
->AsIntermediate());
nodes.push_back(
OpNode(string_format("dequant_op%d", i), "fake_dequantize_max_abs")
->assert_is_op("fake_dequantize_max_abs")
->AsIntermediate());
nodes.push_back(VarNode(string_format("dequant_op_out%d", i))
->assert_is_op_output("fake_dequantize_max_abs", "Out")
->AsOutput());
}
#ifdef DYNAMIC_RANGE
quant_op->LinksFrom({quant_op_input, quant_op_in_scale});
#endif
quant_op->LinksFrom({quant_op_input});
quant_op_out->LinksFrom({quant_op});
quant_op_out_scale->LinksFrom({quant_op});
for (int i = 0; i < times_; i++) {
nodes[i * kNumFields + kQuantizedOpOffset]->LinksFrom(
{quant_op_out, nodes[i * kNumFields + kQuantizedWeightOffset]});
nodes[i * kNumFields + kQuantizedOpOutOffset]->LinksFrom(
{nodes[i * kNumFields + kQuantizedOpOffset]});
nodes[i * kNumFields + kDequantOpOffset]->LinksFrom(
{nodes[i * kNumFields + kQuantizedOpOutOffset], quant_op_out_scale});
nodes[i * kNumFields + kDequantOpOutOffset]->LinksFrom(
{nodes[i * kNumFields + kDequantOpOffset]});
}
}
void DynamicQuantDequantOpFuser::InsertNewNode(SSAGraph* graph,
const key2nodes_t& matched) {
const int kNumFields = 5;
const int kQuantizedWeightOffset = 0;
const int kQuantizedOpOffset = 1;
const int kDequantOpOffset = 3;
const int kDequantOpOutOffset = 4;
auto* quant_op_input = matched.at("quant_op_input");
#ifdef DYNAMIC_RANGE
auto* quant_op_in_scale = matched.at("quant_op_in_scale");
#endif
auto* quant_op = matched.at("quant_op");
std::vector<Node*> nodes;
for (int i = 0; i < times_; i++) {
nodes.push_back(matched.at(string_format("quantized_op_weight%d", i)));
nodes.push_back(matched.at(string_format("quantized_op%d", i)));
nodes.push_back(matched.at(string_format("quantized_op_out%d", i)));
nodes.push_back(matched.at(string_format("dequant_op%d", i)));
nodes.push_back(matched.at(string_format("dequant_op_out%d", i)));
}
int bit_length = quant_op->stmt()->op_info()->GetAttr<int>("bit_length");
auto* scope = quant_op->stmt()->op()->scope();
auto& valid_places = quant_op->stmt()->op()->valid_places();
int range = ((1 << (bit_length - 1)) - 1);
#ifdef DYNAMIC_RANGE
auto input_scale_t = scope->FindVar(quant_op_in_scale->arg()->name)
->GetMutable<lite::Tensor>();
float input_scale = input_scale_t->data<float>()[0] / range;
VLOG(4) << "range: " << range << " input_scale: " << input_scale;
#endif
for (int i = 0; i < times_; i++) {
float max_range = nodes[i * kNumFields + kDequantOpOffset]
->stmt()
->op_info()
->GetAttr<float>("max_range");
// weight_scale = max(abs(weight))
float whole_weight_scale =
static_cast<float>(range * range) / max_range / range;
cpp::OpDesc op_desc =
*nodes[i * kNumFields + kQuantizedOpOffset]->stmt()->op_info();
auto quantized_weight_var_name =
nodes[i * kNumFields + kQuantizedWeightOffset]->arg()->name;
auto quantized_weight_t =
scope->FindVar(quantized_weight_var_name)->GetMutable<lite::Tensor>();
std::vector<float> weight_scale;
int weight_scale_size;
if (op_type_ == "conv2d" || op_type_ == "depthwise_conv2d") {
op_desc.SetInput("Input", {matched.at("quant_op_input")->arg()->name});
op_desc.SetOutput(
"Output", {nodes[i * kNumFields + kDequantOpOutOffset]->arg()->name});
// Conv weight shape: Cout * Cin * kh * hw, the weight_scale_size should
// be Cout.
weight_scale_size = quantized_weight_t->dims()[0];
} else if (op_type_ == "mul") {
op_desc.SetInput("X", {matched.at("quant_op_input")->arg()->name});
op_desc.SetOutput(
"Out", {nodes[i * kNumFields + kDequantOpOutOffset]->arg()->name});
// Fc weight: Cin * Cout, the weight_scale_size should be Cout.
weight_scale_size = quantized_weight_t->dims()[1];
}
for (int i = 0; i < weight_scale_size; i++) {
weight_scale.push_back(whole_weight_scale);
}
// op_desc.SetAttr("enable_int8", true);
// op_desc.SetAttr("input_scale", input_scale);
op_desc.SetAttr("weight_scale", weight_scale);
Tensor temp_tensor;
temp_tensor.CopyDataFrom(*quantized_weight_t);
float* temp_data = temp_tensor.mutable_data<float>();
size_t weight_num = quantized_weight_t->data_size();
quantized_weight_t->set_persistable(true);
std::cout << "DynamicQuantDequantOpFuser::InsertNewNode===================="
"========================================"
<< std::endl;
#ifdef LITE_WITH_FPGA
float* quantized_weight_data = quantized_weight_t->mutable_data<float>();
for (size_t i = 0; i < weight_num; i++) {
quantized_weight_data[i] = temp_data[i] * whole_weight_scale;
std::cout << whole_weight_scale << "," << temp_data[i] << ","
<< quantized_weight_data[i] << std::endl;
}
quantized_weight_t->set_precision(PRECISION(kFloat));
#else
int8_t* quantized_weight_data = quantized_weight_t->mutable_data<int8_t>();
for (size_t i = 0; i < weight_num; i++) {
quantized_weight_data[i] = static_cast<int8_t>(temp_data[i]);
}
quantized_weight_t->set_precision(PRECISION(kInt8));
#endif
auto quantized_op = LiteOpRegistry::Global().Create(op_type_);
quantized_op->Attach(op_desc, scope);
auto* new_op_node =
graph->GraphCreateInstructNode(quantized_op, valid_places);
IR_NODE_LINK_TO(quant_op_input, new_op_node);
IR_NODE_LINK_TO(nodes[i * kNumFields + kQuantizedWeightOffset],
new_op_node);
IR_NODE_LINK_TO(new_op_node, nodes[i * kNumFields + kDequantOpOutOffset]);
}
}
cpp::OpDesc DynamicQuantDequantOpFuser::GenOpDesc(const key2nodes_t& matched) {
cpp::OpDesc op_desc;
return op_desc;
}
} // namespace fusion
} // namespace mir
......
......@@ -52,6 +52,19 @@ class DeleteQuantOpFuser : public FuseBase {
private:
std::string quant_op_type_{};
};
class DeleteDynamicQuantOpFuser : public FuseBase {
public:
explicit DeleteDynamicQuantOpFuser(const std::string& quant_op_type)
: quant_op_type_(quant_op_type) {}
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private:
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
private:
std::string quant_op_type_{};
};
/* DequantOpFuser process conv2d/depthwise_conv2d/mul + fake_dequantize_max_abs.
*/
......@@ -106,6 +119,24 @@ class DeleteQuantDequantOpFuser : public FuseBase {
private:
std::string quantized_op_type_{};
};
// dynamic quantdequant op fuser
class DynamicQuantDequantOpFuser : public FuseBase {
public:
explicit DynamicQuantDequantOpFuser(const std::string& quantized_op_type,
const std::string& op_type,
int i)
: op_type_(op_type), quant_type_(quantized_op_type), times_(i) {}
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private:
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
private:
std::string op_type_{};
std::string quant_type_{};
int times_{1};
};
} // namespace fusion
} // namespace mir
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/kernel_place_correct_pass.h"
#include <memory>
#include "lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
void KernelPlaceCorrectPass::Apply(const std::unique_ptr<SSAGraph> &graph) {
CorrectArgumentPlace(graph.get());
}
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(kernel_place_correct_pass,
paddle::lite::mir::KernelPlaceCorrectPass)
.BindTargets({TARGET(kFPGA)});
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "lite/core/mir/pass.h"
#include "lite/core/target_wrapper.h"
namespace paddle {
namespace lite {
namespace mir {
/*
* Correct the place of the variables in the SSAGrpah, it will inference the
* variables' place by the kernels outputs them.
*/
class KernelPlaceCorrectPass : public DebugPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
private:
void CorrectArgumentPlace(SSAGraph* graph) {
auto& valid_places = graph->valid_places();
auto valid_places_has_target = [&](TargetType t) -> bool {
for (auto& p : valid_places) {
if (p.target == t) {
return true;
}
}
return false;
};
std::map<std::string, bool> lite_with_targets{
{"kOpenCL", valid_places_has_target(TARGET(kOpenCL))},
{"kFPGA", valid_places_has_target(TARGET(kFPGA))}};
VLOG(4) << "lite_with_targets['kOpenCL']:" << lite_with_targets["kOpenCL"];
VLOG(4) << "lite_with_targets['kFPGA']:" << lite_with_targets["kFPGA"];
VLOG(3) << "param-type-registry:\n" << ParamTypeRegistry::Global();
for (auto& x : graph->StmtTopologicalOrder()) {
auto& inst = x->AsStmt();
// The IoCopyOp is a tool operator, it won't support the type inference.
// in fpga, we has io_copy+cali+layout tool ops, so we need type inference
// for
// tool operator
if ((!lite_with_targets["kFPGA"]) && (!lite_with_targets["kOpenCL"])) {
VLOG(3) << "inst.op_type() == 'io_copy', continue";
if (inst.op_type() == "io_copy") continue;
}
// deal with inputs
VLOG(4) << "checking op " << inst.op_info()->Repr();
auto get_argname = [&](
const std::string& node_name,
const std::map<std::string, std::vector<std::string>>& argname_map)
-> std::string {
for (auto& ele : argname_map) {
auto it =
std::find(ele.second.begin(), ele.second.end(), node_name);
if (it != ele.second.end()) return ele.first;
}
return "";
};
bool need_correct_place = true;
std::vector<TargetType> in_types;
std::vector<TargetType> out_types;
for (auto* x_in : x->inlinks) {
std::string node_name = x_in->AsArg().name;
std::string arg_name = get_argname(node_name, inst.op_info()->inputs());
CHECK(arg_name.size() > 0) << "can not found op arguments for node "
<< node_name;
VLOG(4) << "-- input arg_name:" << arg_name << " "
<< "-- node name:" << node_name;
auto type = inst.picked_kernel().GetInputDeclType(arg_name);
if (!x_in->AsArg().type) {
need_correct_place &= false;
} else {
if (in_types.empty()) {
in_types.push_back(x_in->AsArg().type->target());
} else {
if (in_types[0] != x_in->AsArg().type->target()) {
need_correct_place &= false;
}
}
}
}
for (auto* x_out : x->outlinks) {
std::string node_name = x_out->AsArg().name;
std::string arg_name =
get_argname(node_name, inst.op_info()->outputs());
CHECK(arg_name.size() > 0) << "can not found op arguments for node "
<< node_name << " in Inst "
<< inst.op_type();
VLOG(4) << "-- output arg_name " << arg_name;
auto type = inst.picked_kernel().GetOutputDeclType(arg_name);
if (!x_out->AsArg().type) {
need_correct_place &= false;
} else {
if (out_types.empty()) {
out_types.push_back(x_out->AsArg().type->target());
} else {
if (out_types[0] != x_out->AsArg().type->target()) {
need_correct_place &= false;
}
}
}
}
auto this_type = inst.picked_kernel().target();
bool io_target_same = (in_types[0] == out_types[0]);
need_correct_place &= (io_target_same && (in_types[0] != this_type));
if (need_correct_place) {
// update this kernel's valid place;
UpdateTarget(inst, in_types[0]);
}
}
}
// Update me's kUnk fields by other's fields.
void UpdateTarget(mir::Node::Stmt& inst, TargetType new_target) { // NOLINT
auto new_place = inst.place();
new_place.target = new_target;
std::vector<Place> places;
places.push_back(new_place);
inst.ResetKernels(places);
}
};
} // namespace mir
} // namespace lite
} // namespace paddle
......@@ -140,10 +140,12 @@ void SSAGraph::Build(const Program &program,
arg_node->AsArg(name, node_storage_.size() - 1);
arg_update_node_map_[name] = arg_node;
}
/*
if (var_types.count(name) && !arg_node->arg()->type) {
arg_node->arg()->type = LiteType::GetTensorTy(
TARGET(kUnk), var_types[name], DATALAYOUT(kUnk));
}
*/
if (is_weights(name)) arg_node->AsArg().is_weight = true;
CHECK(arg_node->IsRoleSet());
DirectedLink(arg_node, op_node);
......@@ -153,10 +155,12 @@ void SSAGraph::Build(const Program &program,
auto *arg_node = &node_storage_.back();
arg_node->AsArg(name, node_storage_.size() - 1);
arg_update_node_map_[name] = arg_node;
/*
if (var_types.count(name) && !arg_node->arg()->type) {
arg_node->arg()->type = LiteType::GetTensorTy(
TARGET(kUnk), var_types[name], DATALAYOUT(kUnk));
}
*/
if (is_weights(name)) arg_node->AsArg().is_weight = true;
CHECK(arg_node->IsRoleSet());
......
......@@ -101,7 +101,6 @@ void TypeTargetTransformPass::AddIoCopyInst(
auto io_copy_output_name =
string_format("%s/target_trans", in->AsArg().name.c_str());
// string_format("%s/target_trans/%d", in->AsArg().name.c_str(), node_id());
if (copied_nodes->count(in->AsArg().name)) {
// Remove the old link
RemoveDirectedLink(in, inst_node);
......@@ -116,12 +115,14 @@ void TypeTargetTransformPass::AddIoCopyInst(
} else {
// TODO(MyPandaShaoxiang) should set same place with input?
auto* io_copy_output_arg = graph->NewArgumentNode(io_copy_output_name);
// Set the place for io_copy_output_arg node, the target should be equal to
// to.target()
// The precision and layout should be equal to from.precision(),
// from.layout()
// Set the place for io_copy_output_arg node, the target should be equal to
// to.target()
// The precision and layout should be equal to from.precision(),
// from.layout()
#ifndef LITE_WITH_FPGA
io_copy_output_arg->AsArg().type =
LiteType::GetTensorTy(to.target(), from.precision(), from.layout());
#endif
auto* io_copy_inst = graph->NewInstructNode();
bool in_persist = in->AsArg().is_weight || in->AsArg().is_persist;
......
......@@ -77,6 +77,7 @@ class Optimizer {
#endif
"static_kernel_pick_pass", // pick original kernel from graph
"variable_place_inference_pass", // inference arg/var's
"kernel_place_correct_pass",
// info(target/precision/layout/device)
// using kernel info
"argument_type_display_pass", // debug pass: show arg-type-node's
......@@ -108,7 +109,9 @@ class Optimizer {
"runtime_context_assign_pass",
"argument_type_display_pass",
#ifndef LITE_WITH_FPGA
"memory_optimize_pass",
#endif
"npu_subgraph_pass",
"xpu_subgraph_pass"}};
RunPasses(passes_local);
......
......@@ -139,6 +139,9 @@ void RuntimeProgram::Run() {
for (auto& inst : instructions_) {
#ifndef LITE_WITH_FPGA
if (inst.is_feed_fetch_op()) continue;
std::string op_type = inst.op()->op_info()->Type();
VLOG(4) << ">> Running kernel: " << inst.op()->op_info()->Repr()
<< " on Target " << TargetToStr(inst.kernel()->target());
#endif
inst.Run();
#ifdef LITE_WITH_PROFILE
......
文件模式从 100644 更改为 100755
......@@ -46,7 +46,7 @@ class Tensor {
*/
class PaddlePredictor {
public:
void Init();
void Init() {}
std::unique_ptr<Tensor> GetTensor(const std::string &id) const;
std::unique_ptr<Tensor> GetMutableTensor(const std::string &id);
......
......@@ -62,6 +62,10 @@ void CastCompute::Run() {
int32_t* out_data = param.Out->mutable_data<int32_t>();
std::transform(
x_data_begin, x_data_end, out_data, TransOp<int64_t, int32_t>);
} else if (param.in_dtype == 3 && param.out_dtype == 5) {
const auto* x_data = param.X->data<float>();
auto* o_data = param.Out->mutable_data<float>();
memcpy(o_data, x_data, sizeof(float) * param.X->numel());
} else {
LOG(FATAL) << "other has not been implemented";
}
......
......@@ -60,25 +60,10 @@ class FillConstantCompute : public KernelLite<TARGET(kARM), PRECISION(kAny)> {
auto& param = *param_.get_mutable<param_t>();
auto& context = ctx_->As<ARMContext>();
if (param.dtype == static_cast<int32_t>(lite::core::FluidType::FP32)) {
auto data = param.Out->template mutable_data<float>();
for (int i = 0; i < param.Out->numel(); i++) {
data[i] = param.value;
}
} else if (param.dtype ==
static_cast<int32_t>(lite::core::FluidType::INT32)) {
auto data = param.Out->template mutable_data<int32_t>();
for (int i = 0; i < param.Out->numel(); i++) {
data[i] = param.value;
}
} else if (param.dtype ==
static_cast<int32_t>(lite::core::FluidType::INT8)) {
auto data = param.Out->template mutable_data<int8_t>();
for (int i = 0; i < param.Out->numel(); i++) {
data[i] = param.value;
}
} else {
LOG(FATAL) << "not supported dtype " << param.dtype;
// auto data = param.Out->template mutable_data<T>();
auto data = param.Out->template mutable_data<float>();
for (int i = 0; i < param.Out->numel(); i++) {
data[i] = param.value;
}
}
......@@ -94,32 +79,38 @@ class FillConstantBatchLikeCompute
auto& param = *param_.get_mutable<param_t>();
auto& context = ctx_->As<ARMContext>();
if (param.input->lod().size() && param.input_dim_idx == 0) {
auto odims = param.out->dims();
odims[param.output_dim_idx] = param.input->lod().back().size() - 1;
param.out->Resize(odims);
// auto data = param.out->template mutable_data<T>();
auto data = param.out->template mutable_data<float>();
for (int i = 0; i < param.out->numel(); i++) {
data[i] = param.value;
}
if (param.dtype == static_cast<int32_t>(lite::core::FluidType::FP32)) {
auto data = param.out->template mutable_data<float>();
for (int i = 0; i < param.out->numel(); i++) {
data[i] = param.value;
}
} else if (param.dtype ==
static_cast<int32_t>(lite::core::FluidType::INT32)) {
auto data = param.out->template mutable_data<int32_t>();
for (int i = 0; i < param.out->numel(); i++) {
data[i] = param.value;
}
} else if (param.dtype ==
static_cast<int32_t>(lite::core::FluidType::INT8)) {
auto data = param.out->template mutable_data<int8_t>();
for (int i = 0; i < param.out->numel(); i++) {
data[i] = param.value;
}
} else {
LOG(FATAL) << "not supported dtype " << param.dtype;
}
// if (param.input->lod().size() && param.input_dim_idx == 0) {
// auto odims = param.out->dims();
// odims[param.output_dim_idx] = param.input->lod().back().size() - 1;
// param.out->Resize(odims);
// }
// if (param.dtype == static_cast<int32_t>(lite::core::FluidType::FP32)) {
// auto data = param.out->template mutable_data<float>();
// for (int i = 0; i < param.out->numel(); i++) {
// data[i] = param.value;
// }
// } else if (param.dtype ==
// static_cast<int32_t>(lite::core::FluidType::INT32)) {
// auto data = param.out->template mutable_data<int32_t>();
// for (int i = 0; i < param.out->numel(); i++) {
// data[i] = param.value;
// }
// } else if (param.dtype ==
// static_cast<int32_t>(lite::core::FluidType::INT8)) {
// auto data = param.out->template mutable_data<int8_t>();
// for (int i = 0; i < param.out->numel(); i++) {
// data[i] = param.value;
// }
// } else {
// LOG(FATAL) << "not supported dtype " << param.dtype;
// }
}
virtual ~FillConstantBatchLikeCompute() = default;
......@@ -142,8 +133,9 @@ REGISTER_LITE_KERNEL(fill_constant,
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("ShapeTensorList",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
REGISTER_LITE_KERNEL(fill_constant_batch_size_like,
kARM,
kAny,
......
......@@ -36,7 +36,7 @@ void LookupTableCompute::Run() {
auto table_dim = w->dims();
int64_t ids_numel = ids->numel();
auto ids_data = ids->data<int64_t>();
auto ids_data = ids->data<float>();
int64_t row_number = table_dim[0];
int64_t row_width = table_dim[1];
......@@ -75,7 +75,6 @@ REGISTER_LITE_KERNEL(lookup_table,
.BindInput("Ids", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
REGISTER_LITE_KERNEL(lookup_table_v2,
kARM,
kFloat,
......
......@@ -4,6 +4,7 @@ add_kernel(feed_compute_host Host basic SRCS feed_compute.cc DEPS ${lite_kernel_
add_kernel(fetch_compute_host Host basic SRCS fetch_compute.cc DEPS ${lite_kernel_deps})
add_kernel(reshape_compute_host Host basic SRCS reshape_compute.cc DEPS ${lite_kernel_deps} reshape_op)
add_kernel(multiclass_nms_compute_host Host basic SRCS multiclass_nms_compute.cc DEPS ${lite_kernel_deps})
add_kernel(one_hot_compute_host Host basic SRCS one_hot_compute.cc DEPS ${lite_kernel_deps})
#lite_cc_test(test_reshape_compute_host SRCS reshape_compute_test.cc DEPS reshape_compute_host any)
#lite_cc_test(test_multiclass_nms_compute_host SRCS multiclass_nms_compute_test.cc DEPS multiclass_nms_compute_host any)
......@@ -426,8 +426,14 @@ REGISTER_LITE_KERNEL(multiclass_nms,
kNCHW,
paddle::lite::kernels::host::MulticlassNmsCompute,
def)
.BindInput("BBoxes", {LiteType::GetTensorTy(TARGET(kHost))})
.BindInput("Scores", {LiteType::GetTensorTy(TARGET(kHost))})
.BindInput("BBoxes",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kAny),
DATALAYOUT(kAny))})
.BindInput("Scores",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kAny),
DATALAYOUT(kAny))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))})
.BindOutput("Index",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <map>
#include <utility>
#include <vector>
#include "lite/backends/fpga/KD/debugger.hpp"
#include "lite/kernels/host/one_hot_compute.h"
#include "lite/utils/paddle_enforce.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace host {
void OneHotCompute::Run() {
auto& param = Param<operators::OneHotParam>();
param.Out->mutable_data<float>();
int depth = param.depth;
if (param.depth_tensor) {
auto* depth_tensor = param.depth_tensor;
auto* depth_data = depth_tensor->data<int32_t>();
depth = depth_data[0];
auto in_dims = param.X->dims();
DDim out_dims(in_dims);
out_dims[out_dims.size() - 1] = depth;
param.Out->Resize(out_dims);
}
auto* p_in_data = param.X->data<float>();
auto numel = param.X->numel();
auto* p_out_data = param.Out->mutable_data<float>();
for (int i = 0; i < param.Out->numel(); ++i) {
p_out_data[i] = 0;
}
if (param.allow_out_of_range) {
for (int i = 0; i < numel; ++i) {
if (p_in_data[i] >= 0 && p_in_data[i] < param.depth) {
*(p_out_data + i * param.depth + (int)(p_in_data[i])) = 1.0; // NOLINT
}
}
} else {
for (int i = 0; i < numel; ++i) {
PADDLE_ENFORCE_GE(
p_in_data[i], 0, "Illegal index value, should be at least 0.");
PADDLE_ENFORCE_LT(p_in_data[i],
param.depth,
"Illegal index value, should be less than depth (%d).",
param.depth);
*(p_out_data + i * param.depth + (int)(p_in_data[i])) = 1.0; // NOLINT
}
}
}
} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(one_hot,
kHost,
kFloat,
kNCHW,
paddle::lite::kernels::host::OneHotCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kHost))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace host {
class OneHotCompute
: public KernelLite<TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)> {
public:
void Run() override;
virtual ~OneHotCompute() = default;
};
} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -46,17 +46,21 @@ REGISTER_LITE_KERNEL(reshape,
paddle::lite::kernels::host::ReshapeCompute,
def)
.BindInput("X",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kAny),
DATALAYOUT(kAny))})
.BindInput("ShapeTensor",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kAny),
DATALAYOUT(kAny))})
.BindInput("Shape",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kAny),
DATALAYOUT(kAny))})
.BindOutput("Out",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kAny),
DATALAYOUT(kAny))})
.Finalize();
REGISTER_LITE_KERNEL(reshape2,
......
......@@ -135,6 +135,8 @@ add_operator(search_seq_fc_op extra SRCS search_seq_fc_op.cc DEPS ${op_DEPS})
add_operator(sequence_topk_avg_pooling_op basic SRCS sequence_topk_avg_pooling_op.cc DEPS ${op_DEPS})
add_operator(search_fc_op basic SRCS search_fc_op.cc DEPS ${op_DEPS})
add_operator(one_hot basic SRCS one_hot_op.cc DEPS ${op_DEPS})
if (NOT LITE_WITH_X86)
lite_cc_test(test_fc_op SRCS fc_op_test.cc
DEPS fc_op memory
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/one_hot_op.h"
#include "lite/core/op_registry.h"
#include "lite/backends/fpga/KD/debugger.hpp"
namespace paddle {
namespace lite {
namespace operators {
bool OneHotOp::CheckShape() const {
CHECK_OR_FALSE(param_.X);
CHECK_OR_FALSE(param_.Out);
return true;
}
bool OneHotOp::InferShape() const {
CHECK_OR_FALSE(param_.Out);
// TODO(Superjomn) Enable data sharing.
auto out_dims = param_.X->dims();
out_dims[out_dims.size() - 1] = param_.depth;
param_.Out->Resize(out_dims);
return true;
}
bool OneHotOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
param_.X =
scope->FindVar(opdesc.Input("X").front())->GetMutable<lite::Tensor>();
param_.Out =
scope->FindVar(opdesc.Output("Out").front())->GetMutable<lite::Tensor>();
if (opdesc.HasInput("depth_tensor")) {
auto depth_tensor = opdesc.Input("depth_tensor").front();
param_.depth_tensor =
scope->FindVar(depth_tensor)->GetMutable<lite::Tensor>();
}
CHECK(param_.X);
CHECK(param_.Out);
param_.depth = opdesc.GetAttr<int>("depth");
param_.dtype = opdesc.GetAttr<int>("dtype");
if (opdesc.HasAttr("allow_out_of_range")) {
param_.allow_out_of_range = opdesc.GetAttr<bool>("allow_out_of_range");
}
auto out_lod = param_.Out->mutable_lod();
*out_lod = param_.X->lod();
// param_.allow_out_of_range = opdesc.GetAttr<bool>("allow_out_of_range");
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(one_hot, paddle::lite::operators::OneHotOp);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class OneHotOp : public OpLite {
public:
OneHotOp() {}
explicit OneHotOp(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "one_hot"; }
private:
mutable OneHotParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
......@@ -1133,7 +1133,15 @@ struct GridSamplerParam {
lite::Tensor* out{};
lite::Tensor* grid{};
};
} // namespace operators
} // namespace lite
} // namespace paddle
/// --------------------- attentions operators --------------
struct OneHotParam {
lite::Tensor* X{};
lite::Tensor* depth_tensor{nullptr};
lite::Tensor* Out{};
int depth{-1};
int dtype{};
bool allow_out_of_range{false};
};
}; // namespace operators
}; // namespace lite
}; // namespace paddle
......@@ -2,12 +2,16 @@
build_dir=build_fpga
mkdir -p ${build_dir}
cd ${build_dir}
GEN_CODE_PATH_PREFIX=lite/gen_code
mkdir -p ./${GEN_CODE_PATH_PREFIX}
touch ./${GEN_CODE_PATH_PREFIX}/__generated_code__.cc
root_dir=$(pwd)
build_dir=${build_dir}
# in build directory
# 1. Prepare gen_code file
GEN_CODE_PATH_PREFIX=${build_dir}/lite/gen_code
mkdir -p ${GEN_CODE_PATH_PREFIX}
touch ${GEN_CODE_PATH_PREFIX}/__generated_code__.cc
cd ${build_dir}
cmake .. \
-DWITH_GPU=OFF \
-DWITH_MKL=OFF \
......@@ -19,8 +23,9 @@ cmake .. \
-DLITE_WITH_OPENMP=ON \
-DLITE_WITH_LIGHT_WEIGHT_FRAMEWORK=ON \
-DWITH_TESTING=OFF \
-DARM_TARGET_OS=armlinux
make -j8
-DARM_TARGET_OS=armlinux \
-DLITE_BUILD_EXTRA=ON \
-DLITE_WITH_PROFILE=OFF
make -j42
cd -
......@@ -29,7 +29,6 @@ namespace zynqmp {
class ConvPE : public PE {
public:
bool init() {
std::cout << "Conv init" << std::endl;
return true;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册