Unverified commit 368a1bf9, authored by HappyAngel, committed by GitHub

Merge pull request #93 from PaddlePaddle/develop

update code
@@ -307,6 +307,9 @@ function(add_kernel TARGET device level)
   if ("${level}" STREQUAL "extra" AND (NOT LITE_BUILD_EXTRA))
     return()
   endif()
+  if ("${level}" STREQUAL "train" AND (NOT LITE_WITH_TRAIN))
+    return()
+  endif()
   if ("${device}" STREQUAL "Host")
@@ -434,11 +437,13 @@ function(add_operator TARGET level)
     ARGS)
   cmake_parse_arguments(args "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
   if ("${level}" STREQUAL "extra" AND (NOT LITE_BUILD_EXTRA))
     return()
   endif()
+  if ("${level}" STREQUAL "train" AND (NOT LITE_WITH_TRAIN))
+    return()
+  endif()
   foreach(src ${args_SRCS})
     if(LITE_BUILD_TAILOR)
......
@@ -9,7 +9,7 @@ if (LITE_ON_TINY_PUBLISH)
   set(CMAKE_C_FLAGS_RELEASE "-Os -DNDEBUG")
 endif()
 set(light_lib_DEPS light_api paddle_api paddle_api_light optimizer)
-if ((NOT LITE_ON_TINY_PUBLISH) AND (LITE_WITH_CUDA OR LITE_WITH_X86 OR ARM_TARGET_OS STREQUAL "android" OR ARM_TARGET_OS STREQUAL "armlinux"))
+if ((NOT LITE_ON_TINY_PUBLISH) AND (LITE_WITH_CUDA OR LITE_WITH_X86 OR LITE_WITH_BM OR ARM_TARGET_OS STREQUAL "android" OR ARM_TARGET_OS STREQUAL "armlinux"))
   #full api dynamic library
   add_library(paddle_full_api_shared SHARED "")
   target_sources(paddle_full_api_shared PUBLIC ${__lite_cc_files} paddle_api.cc light_api.cc cxx_api.cc cxx_api_impl.cc light_api_impl.cc)
@@ -262,7 +262,8 @@ if (NOT LITE_ON_TINY_PUBLISH)
     CV_DEPS paddle_cv_arm
     NPU_DEPS ${npu_kernels}
     CL_DEPS ${opencl_kernels}
-    FPGA_DEPS ${fpga_kernels})
+    FPGA_DEPS ${fpga_kernels}
+    BM_DEPS ${bm_kernels})
   # The final inference library for just MobileConfig.
   bundle_static_library(paddle_api_full paddle_api_full_bundled bundle_full_api)
   target_link_libraries(paddle_api_full ${cuda_deps})
......
@@ -44,7 +44,10 @@ DEFINE_string(input_shape,
               "set input shapes according to the model, "
               "separated by colon and comma, "
               "such as 1,3,244,244");
-DEFINE_string(input_img_path, "", "the path of input image");
+DEFINE_string(input_img_path,
+              "",
+              "the path of input image, if not set "
+              "input_img_path, the input of model will be 1.0.");
 DEFINE_int32(warmup, 0, "warmup times");
 DEFINE_int32(repeats, 1, "repeats times");
 DEFINE_int32(power_mode,
@@ -57,16 +60,11 @@ DEFINE_int32(power_mode,
 DEFINE_int32(threads, 1, "threads num");
 DEFINE_string(result_filename,
               "result.txt",
-              "save benchmark "
-              "result to the file");
+              "save the inference time to the file.");
 DEFINE_bool(run_model_optimize,
             false,
             "if set true, apply model_optimize_tool to "
             "model and use optimized model to test. ");
-DEFINE_bool(is_quantized_model,
-            false,
-            "if set true, "
-            "test the performance of the quantized model. ");
 namespace paddle {
 namespace lite_api {
@@ -87,10 +85,6 @@ void OutputOptModel(const std::string& save_optimized_model_dir) {
   std::vector<Place> vaild_places = {
       Place{TARGET(kARM), PRECISION(kFloat)},
   };
-  if (FLAGS_is_quantized_model) {
-    vaild_places.insert(vaild_places.begin(),
-                        Place{TARGET(kARM), PRECISION(kInt8)});
-  }
   config.set_valid_places(vaild_places);
   auto predictor = lite_api::CreatePaddlePredictor(config);
@@ -181,8 +175,8 @@ void Run(const std::vector<int64_t>& input_shape,
 int main(int argc, char** argv) {
   gflags::ParseCommandLineFlags(&argc, &argv, true);
-  if (FLAGS_model_dir == "" || FLAGS_result_filename == "") {
-    LOG(INFO) << "please run ./benchmark_bin --help to obtain usage.";
+  if (FLAGS_model_dir == "") {
+    LOG(INFO) << "Please run ./benchmark_bin --help to obtain usage.";
     exit(0);
   }
......
@@ -295,6 +295,8 @@ void Predictor::Build(const cpp::ProgramDesc &desc,
   inner_places.emplace_back(
       TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW));
+  // Analyze whether the model is quantized.
+  // For a quantized model, add place(arm, int8) to inner_places.
   const std::vector<std::string> quant_dequant_op = {
       "fake_quantize_abs_max",
       "fake_quantize_range_abs_max",
@@ -317,7 +319,8 @@ void Predictor::Build(const cpp::ProgramDesc &desc,
     }
   }
   if (is_quantized_model) {
-    inner_places.emplace_back(Place{TARGET(kARM), PRECISION(kInt8)});
+    inner_places.insert(inner_places.begin(),
+                        Place{TARGET(kARM), PRECISION(kInt8)});
   }
   Program program(desc, scope_, inner_places);
......
@@ -36,7 +36,8 @@ void TestModel(const std::vector<Place>& valid_places) {
   predictor.Build(FLAGS_model_dir, "", "", valid_places, passes);
   auto* input_tensor = predictor.GetInput(0);
-  input_tensor->Resize(DDim(std::vector<DDim::value_type>({1, 3, 224, 224})));
+  input_tensor->Resize(DDim(
+      std::vector<DDim::value_type>({1, 3, FLAGS_im_height, FLAGS_im_width})));
   auto* data = input_tensor->mutable_data<float>();
   auto item_size = input_tensor->dims().production();
   if (FLAGS_input_img_txt_path.empty()) {
@@ -67,15 +68,13 @@ void TestModel(const std::vector<Place>& valid_places) {
             << ", spend " << (GetCurrentUS() - start) / FLAGS_repeats / 1000.0
             << " ms in average.";
-  auto* out = predictor.GetOutput(0);
-  ASSERT_EQ(out->dims().size(), 2);
-  ASSERT_EQ(out->dims()[0], 1);
-  ASSERT_EQ(out->dims()[1], 1000);
-  auto* out_data = out->data<float>();
+  auto out = predictor.GetOutputs();
   FILE* fp = fopen("result.txt", "wb");
-  for (int i = 0; i < out->numel(); i++) {
-    fprintf(fp, "%f\n", out_data[i]);
+  for (int i = 0; i < out.size(); i++) {
+    auto* out_data = out[i]->data<float>();
+    for (int j = 0; j < out[i]->numel(); j++) {
+      fprintf(fp, "%f\n", out_data[j]);
+    }
   }
   fclose(fp);
 }
......
@@ -13,7 +13,9 @@
 // limitations under the License.
 #include <gflags/gflags.h>
+#ifdef PADDLE_WITH_TESTING
 #include <gtest/gtest.h>
+#endif
 #include <string>
 #include <vector>
 #include "lite/api/cxx_api.h"
......
@@ -44,11 +44,9 @@ void QuantDequantFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
     fuser(graph.get());
   }
-  // delete quant_dequant_node
-  for (auto op_type : {"pool2d", "softmax", "elementwise_add"}) {
-    fusion::DeleteQuantDequantOpFuser fuser(op_type);
-    fuser(graph.get());
-  }
+  // process quant_dequant_node
+  fusion::DeleteQuantDequantOpFuser dqd_fuser;
+  dqd_fuser(graph.get());
 }
 }  // namespace mir
......
@@ -50,7 +50,7 @@ void DeleteQuantOpFuser::InsertNewNode(SSAGraph* graph,
   auto* output_scale_node = matched.at("output_scale_node");
   auto* output_act_node = matched.at("output_act_node");
-  // obtain values, save values and relink node
+  // obtain scale, save attrs and relink node
   int bit_length = quant_node->stmt()->op_info()->GetAttr<int>("bit_length");
   int range = ((1 << (bit_length - 1)) - 1);
   auto* scope = quant_node->stmt()->op()->scope();
@@ -58,11 +58,22 @@ void DeleteQuantOpFuser::InsertNewNode(SSAGraph* graph,
                            ->GetMutable<lite::Tensor>();
   float scale_value = scale_tensor->data<float>()[0] / range;
+  auto in_act_name = input_act_node->arg()->name;
+  auto out_act_name = output_act_node->arg()->name;
   auto outlinks = output_act_node->outlinks;
   for (auto* quantized_node : outlinks) {
-    auto* op_desc = quantized_node->stmt()->mutable_op_info();
-    op_desc->SetAttr<int>("bit_length", bit_length);
-    op_desc->SetAttr<float>("input_scale", scale_value);
+    // save input scale in quantized op by input argname + index
+    auto op_desc = *quantized_node->stmt()->mutable_op_info();
+    std::string argname;
+    int index;
+    op_desc.GetInputArgname(out_act_name, &argname);
+    op_desc.GetInputIndex(out_act_name, &index);
+    op_desc.SetAttr<float>(argname + std::to_string(index) + "_input_scale",
+                           scale_value);
+    op_desc.SetAttr<float>("input_scale", scale_value);  // save it for now
+    op_desc.SetAttr<int>("bit_length", bit_length);
+    op_desc.UpdateAllInputs(out_act_name, in_act_name);
+    quantized_node->stmt()->ResetOp(op_desc, graph->valid_places());
     IR_NODE_LINK_TO(input_act_node, quantized_node)
   }
@@ -125,19 +136,18 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph,
   auto* dequant_op = matched.at("dequant_op");
   auto* dequant_op_out = matched.at("dequant_op_out");
-  // obtain input_scale and weight_scale
+  // obtain weight_scale from max_range
   auto* scope = quantized_op->stmt()->op()->scope();
   auto& valid_places = quantized_op->stmt()->op()->valid_places();
   int bit_length = quantized_op->stmt()->op_info()->GetAttr<int>("bit_length");
   int range = ((1 << (bit_length - 1)) - 1);
-  float input_scale =
-      quantized_op->stmt()->op_info()->GetAttr<float>("input_scale");
   float max_range = dequant_op->stmt()->op_info()->GetAttr<float>("max_range");
   float whole_weight_scale =
       static_cast<float>(range * range) / max_range / range;
-  // max_range = range * range / max(abs(weight))
-  // weight_scale = range * range / (range * range / max(abs(weight))) / range
-  //              = max(abs(weight)) / range
+  // As: max_range = range * range / max(abs(weight))
+  // So: whole_weight_scale
+  //       = range * range / (range * range / max(abs(weight))) / range
+  //       = max(abs(weight)) / range
   // set op desc
   cpp::OpDesc op_desc = *quantized_op->stmt()->op_info();
@@ -153,7 +163,7 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph,
     // Conv weight shape: Cout * Cin * kh * hw, the weight_scale_size should
     // be Cout.
     weight_scale_size = quantized_weight_t->dims()[0];
-  } else if (quantized_op_type_ == "mul") {
+  } else if (quantized_op_type_ == "mul" || quantized_op_type_ == "matmul") {
     op_desc.SetInput("X", {quantized_op_input->arg()->name});
     op_desc.SetOutput("Out", {dequant_op_out->arg()->name});
     // Fc weight: Cin * Cout, the weight_scale_size should be Cout.
@@ -163,7 +173,6 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph,
     weight_scale.push_back(whole_weight_scale);
   }
   op_desc.SetAttr("enable_int8", true);
-  op_desc.SetAttr("input_scale", input_scale);
   op_desc.SetAttr("weight_scale", weight_scale);
   // change the weight from the float type to int8 type.
@@ -209,6 +218,7 @@ void ChannelWiseDequantOpFuser::BuildPattern() {
                                  ->assert_is_op_output(quantized_op_type_)
                                  ->assert_is_op_input(dequant_op_type, "X")
                                  ->AsIntermediate();
+  // The scale var_node of input activation is deleted in DeleteQuantOpFuser
   auto* dequant_op_channel_scale = VarNode("dequant_op_channel_scale")
                                        ->assert_is_op_input(dequant_op_type)
                                        ->AsIntermediate();
@@ -237,11 +247,9 @@ void ChannelWiseDequantOpFuser::InsertNewNode(SSAGraph* graph,
   auto* dequant_op = matched.at("dequant_op");
   auto* dequant_op_out = matched.at("dequant_op_out");
-  // obtain input_scale and weight_scale
+  // obtain weight_scale from fake_dequant op
   auto* scope = quantized_op->stmt()->op()->scope();
   auto& valid_places = quantized_op->stmt()->op()->valid_places();
-  float input_scale =
-      quantized_op->stmt()->op_info()->GetAttr<float>("input_scale");
   std::vector<float> weight_scale;
   std::vector<int> quant_bits =
@@ -258,11 +266,15 @@ void ChannelWiseDequantOpFuser::InsertNewNode(SSAGraph* graph,
   // set op desc
   cpp::OpDesc op_desc = *quantized_op->stmt()->op_info();
-  op_desc.SetInput("Input", {quantized_op_input->arg()->name});
-  op_desc.SetOutput("Output", {dequant_op_out->arg()->name});
+  if (quantized_op_type_ == "conv2d" ||
+      quantized_op_type_ == "depthwise_conv2d") {
+    op_desc.SetInput("Input", {quantized_op_input->arg()->name});
+    op_desc.SetOutput("Output", {dequant_op_out->arg()->name});
+  } else if (quantized_op_type_ == "mul" || quantized_op_type_ == "matmul") {
+    op_desc.SetInput("X", {quantized_op_input->arg()->name});
+    op_desc.SetOutput("Out", {dequant_op_out->arg()->name});
+  }
   op_desc.SetAttr("enable_int8", true);
-  op_desc.SetAttr("input_scale", input_scale);
   op_desc.SetAttr("weight_scale", weight_scale);
   // change the weight from the float type to int8 type.
@@ -297,167 +309,65 @@ cpp::OpDesc ChannelWiseDequantOpFuser::GenOpDesc(const key2nodes_t& matched) {
 void DeleteQuantDequantOpFuser::BuildPattern() {
   std::string quant_dequant_op_type =
       "fake_quantize_dequantize_moving_average_abs_max";
-  if (quantized_op_type_ == "pool2d" || quantized_op_type_ == "softmax") {
-    auto* input_scale_node =
-        VarNode("input_scale_node")
-            ->assert_is_op_input(quant_dequant_op_type, "InScale");
-    auto* input_act_node = VarNode("input_act_node")
-                               ->assert_is_op_input(quant_dequant_op_type, "X");
-    auto* quant_dequant_node =
-        OpNode("quant_dequant_node", quant_dequant_op_type)
-            ->assert_is_op(quant_dequant_op_type);
-    auto* output_scale_node =
-        VarNode("output_scale_node")
-            ->assert_is_op_output(quant_dequant_op_type, "OutScale");
-    auto* output_act_node =
-        VarNode("output_act_node")
-            ->assert_is_op_output(quant_dequant_op_type, "Out");
-    auto* quantized_node = OpNode("quantized_node", quantized_op_type_)
-                               ->assert_is_op(quantized_op_type_);
-    quant_dequant_node->LinksFrom({input_scale_node, input_act_node});
-    output_scale_node->LinksFrom({quant_dequant_node});
-    output_act_node->LinksFrom({quant_dequant_node});
-    quantized_node->LinksFrom({output_act_node});
-  } else if (quantized_op_type_ == "elementwise_add") {
-    auto* input_scale_left_node =
-        VarNode("input_scale_left_node")
-            ->assert_is_op_input(quant_dequant_op_type, "InScale");
-    auto* input_act_left_node =
-        VarNode("input_act_left_node")
-            ->assert_is_op_input(quant_dequant_op_type, "X");
-    auto* quant_dequant_left_node =
-        OpNode("quant_dequant_left_node", quant_dequant_op_type)
-            ->assert_is_op(quant_dequant_op_type);
-    auto* output_scale_left_node =
-        VarNode("output_scale_left_node")
-            ->assert_is_op_output(quant_dequant_op_type, "OutScale");
-    auto* output_act_left_node =
-        VarNode("output_act_left_node")
-            ->assert_is_op_output(quant_dequant_op_type, "Out")
-            ->assert_is_op_input(quantized_op_type_, "X");
-    quant_dequant_left_node->LinksFrom(
-        {input_scale_left_node, input_act_left_node});
-    output_scale_left_node->LinksFrom({quant_dequant_left_node});
-    output_act_left_node->LinksFrom({quant_dequant_left_node});
-    auto* input_scale_right_node =
-        VarNode("input_scale_right_node")
-            ->assert_is_op_input(quant_dequant_op_type, "InScale");
-    auto* input_act_right_node =
-        VarNode("input_act_right_node")
-            ->assert_is_op_input(quant_dequant_op_type, "X");
-    auto* quant_dequant_right_node =
-        OpNode("quant_dequant_right_node", quant_dequant_op_type)
-            ->assert_is_op(quant_dequant_op_type);
-    auto* output_scale_right_node =
-        VarNode("output_scale_right_node")
-            ->assert_is_op_output(quant_dequant_op_type, "OutScale");
-    auto* output_act_right_node =
-        VarNode("output_act_right_node")
-            ->assert_is_op_output(quant_dequant_op_type, "Out")
-            ->assert_is_op_input(quantized_op_type_, "Y");
-    quant_dequant_right_node->LinksFrom(
-        {input_scale_right_node, input_act_right_node});
-    output_scale_right_node->LinksFrom({quant_dequant_right_node});
-    output_act_right_node->LinksFrom({quant_dequant_right_node});
-    auto* quantized_node = OpNode("quantized_node", quantized_op_type_)
-                               ->assert_is_op(quantized_op_type_);
-    quantized_node->LinksFrom({output_act_left_node, output_act_right_node});
-  } else {
-    LOG(FATAL) << "No support quantized_op_type:" << quantized_op_type_;
-  }
-  VLOG(4) << "DeleteQuantDequantOpFuser BuildPattern op_type:"
-          << quantized_op_type_;
+  auto* input_scale_node =
+      VarNode("input_scale_node")
+          ->assert_is_op_input(quant_dequant_op_type, "InScale");
+  auto* input_act_node =
+      VarNode("input_act_node")->assert_is_op_input(quant_dequant_op_type, "X");
+  auto* quant_dequant_node = OpNode("quant_dequant_node", quant_dequant_op_type)
+                                 ->assert_is_op(quant_dequant_op_type);
+  auto* output_scale_node =
+      VarNode("output_scale_node")
+          ->assert_is_op_output(quant_dequant_op_type, "OutScale");
+  auto* output_act_node =
+      VarNode("output_act_node")
+          ->assert_is_op_output(quant_dequant_op_type, "Out");
+
+  quant_dequant_node->LinksFrom({input_scale_node, input_act_node});
+  output_scale_node->LinksFrom({quant_dequant_node});
+  output_act_node->LinksFrom({quant_dequant_node});
 }
 void DeleteQuantDequantOpFuser::InsertNewNode(SSAGraph* graph,
                                               const key2nodes_t& matched) {
-  if (quantized_op_type_ == "pool2d" || quantized_op_type_ == "softmax") {
-    auto* input_scale_node = matched.at("input_scale_node");
-    auto* input_act_node = matched.at("input_act_node");
-    auto* quant_dequant_node = matched.at("quant_dequant_node");
-    auto* output_scale_node = matched.at("output_scale_node");
-    auto* output_act_node = matched.at("output_act_node");
-    auto* quantized_node = matched.at("quantized_node");
-    // obtain values, save values and relink node
-    int bit_length =
-        quant_dequant_node->stmt()->op_info()->GetAttr<int>("bit_length");
-    int range = ((1 << (bit_length - 1)) - 1);
-    auto* scope = quant_dequant_node->stmt()->op()->scope();
-    auto* scale_tensor = scope->FindVar(output_scale_node->arg()->name)
-                             ->GetMutable<lite::Tensor>();
-    float scale_value = scale_tensor->data<float>()[0] / range;
-    auto* op_desc = quantized_node->stmt()->mutable_op_info();
-    op_desc->SetAttr<int>("bit_length", bit_length);
-    op_desc->SetAttr<float>("input_scale", scale_value);
-    op_desc->SetInput("X", {input_act_node->arg()->name});
-    IR_NODE_LINK_TO(input_act_node, quantized_node)
-    auto update_op_desc = *quantized_node->stmt()->mutable_op_info();
-    quantized_node->stmt()->ResetOp(update_op_desc, graph->valid_places());
-    // delete nodes and edges
-    std::unordered_set<const Node*> nodes2rm = {input_scale_node,
-                                                quant_dequant_node,
-                                                output_scale_node,
-                                                output_act_node};
-    GraphSafeRemoveNodes(graph, nodes2rm);
-  } else if (quantized_op_type_ == "elementwise_add") {
-    auto* input_scale_left_node = matched.at("input_scale_left_node");
-    auto* input_act_left_node = matched.at("input_act_left_node");
-    auto* quant_dequant_left_node = matched.at("quant_dequant_left_node");
-    auto* output_scale_left_node = matched.at("output_scale_left_node");
-    auto* output_act_left_node = matched.at("output_act_left_node");
-    auto* input_scale_right_node = matched.at("input_scale_right_node");
-    auto* input_act_right_node = matched.at("input_act_right_node");
-    auto* quant_dequant_right_node = matched.at("quant_dequant_right_node");
-    auto* output_scale_right_node = matched.at("output_scale_right_node");
-    auto* output_act_right_node = matched.at("output_act_right_node");
-    auto* quantized_node = matched.at("quantized_node");
-    // obtain values, save values and relink node
-    int bit_length =
-        quant_dequant_left_node->stmt()->op_info()->GetAttr<int>("bit_length");
-    int range = ((1 << (bit_length - 1)) - 1);
-    auto* scope = quant_dequant_left_node->stmt()->op()->scope();
-    auto* left_scale_tensor =
-        scope->FindVar(output_scale_left_node->arg()->name)
-            ->GetMutable<lite::Tensor>();
-    float left_scale_value = left_scale_tensor->data<float>()[0] / range;
-    auto* right_scale_tensor =
-        scope->FindVar(output_scale_right_node->arg()->name)
-            ->GetMutable<lite::Tensor>();
-    float right_scale_value = right_scale_tensor->data<float>()[0] / range;
-    auto* op_desc = quantized_node->stmt()->mutable_op_info();
-    op_desc->SetAttr<int>("bit_length", bit_length);
-    op_desc->SetAttr<float>("x_input_scale", left_scale_value);
-    op_desc->SetAttr<float>("y_input_scale", right_scale_value);
-    op_desc->SetInput("X", {input_act_left_node->arg()->name});
-    op_desc->SetInput("Y", {input_act_right_node->arg()->name});
-    IR_NODE_LINK_TO(input_act_left_node, quantized_node)
-    IR_NODE_LINK_TO(input_act_right_node, quantized_node)
-    auto update_op_desc = *quantized_node->stmt()->mutable_op_info();
-    quantized_node->stmt()->ResetOp(update_op_desc, graph->valid_places());
-    // delete nodes and edges
-    std::unordered_set<const Node*> nodes2rm = {input_scale_left_node,
-                                                quant_dequant_left_node,
-                                                output_scale_left_node,
-                                                output_act_left_node,
-                                                input_scale_right_node,
-                                                quant_dequant_right_node,
-                                                output_scale_right_node,
-                                                output_act_right_node};
-    GraphSafeRemoveNodes(graph, nodes2rm);
-  } else {
-    LOG(FATAL) << "No support quantized_op_type:" << quantized_op_type_;
-  }
+  auto* input_scale_node = matched.at("input_scale_node");
+  auto* input_act_node = matched.at("input_act_node");
+  auto* quant_dequant_node = matched.at("quant_dequant_node");
+  auto* output_scale_node = matched.at("output_scale_node");
+  auto* output_act_node = matched.at("output_act_node");
+  auto input_act_name = input_act_node->arg()->name;
+  auto output_act_name = output_act_node->arg()->name;
+  // Get scale value from scale var node
+  int bit_length =
+      quant_dequant_node->stmt()->op_info()->GetAttr<int>("bit_length");
+  int range = ((1 << (bit_length - 1)) - 1);
+  auto* scope = quant_dequant_node->stmt()->op()->scope();
+  auto* scale_tensor = scope->FindVar(output_scale_node->arg()->name)
+                           ->GetMutable<lite::Tensor>();
+  float scale_value = scale_tensor->data<float>()[0] / range;
+
+  auto quantized_nodes = output_act_node->outlinks;
+  for (auto* quantized_node : quantized_nodes) {
+    // Save quantization info in op_info attr
+    auto op_info = *quantized_node->stmt()->op_info();
+    std::string argname;
+    int index;
+    op_info.GetInputArgname(output_act_name, &argname);
+    op_info.GetInputIndex(output_act_name, &index);
+    op_info.SetAttr<float>(argname + std::to_string(index) + "_input_scale",
+                           scale_value);
+    op_info.SetAttr<float>("input_scale", scale_value);  // Save it for now
+    op_info.SetAttr<int>("bit_length", bit_length);
+    op_info.UpdateAllInputs(output_act_name, input_act_name);
+    quantized_node->stmt()->ResetOp(op_info, graph->valid_places());
+    IR_NODE_LINK_TO(input_act_node, quantized_node);
+  }
+
+  // delete nodes and edges
+  std::unordered_set<const Node*> nodes2rm = {
+      input_scale_node, quant_dequant_node, output_scale_node, output_act_node};
+  GraphSafeRemoveNodes(graph, nodes2rm);
 }
 
 cpp::OpDesc DeleteQuantDequantOpFuser::GenOpDesc(const key2nodes_t& matched) {
......
@@ -87,24 +87,16 @@ class ChannelWiseDequantOpFuser : public FuseBase {
 };
 
-/* The pattern like "fake_quantize_dequantize_moving_average_abs_max +
- * pooled/elementwise_add" can be deteted by this fuser. The fuser
- * extract the input_scale form fake_quant_dequant_op and save into
- * the quantized_op. Besides, the fuser delete fake_quant_dequant_op in
- * the graph.
- */
+/* The pattern like "fake_quantize_dequantize_moving_average_abs_max +
+ * quantized_op" can be detected by this fuser. The fuser modifies the input
+ * scale for the quantized_op and deletes the fake_quant_dequant_op.
+ */
 class DeleteQuantDequantOpFuser : public FuseBase {
  public:
-  explicit DeleteQuantDequantOpFuser(const std::string& quantized_op_type)
-      : quantized_op_type_(quantized_op_type) {}
   void BuildPattern() override;
   void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
 
  private:
   cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
-
- private:
-  std::string quantized_op_type_{};
 };
 }  // namespace fusion
......
@@ -225,6 +225,32 @@ class OpInfo : public cpp::OpDesc {
     return false;
   }
 
+  // For the input variable name, find the index of the corresponding
+  // input argname
+  bool GetInputIndex(const std::string &value_name, int *out) const {
+    for (auto &item : inputs_) {
+      auto it = std::find(item.second.begin(), item.second.end(), value_name);
+      if (it != item.second.end()) {
+        *out = it - item.second.begin();
+        return true;
+      }
+    }
+    return false;
+  }
+
+  // For the output variable name, find the index of the corresponding
+  // output argname
+  bool GetOutputIndex(const std::string &value_name, int *out) const {
+    for (auto &item : outputs_) {
+      auto it = std::find(item.second.begin(), item.second.end(), value_name);
+      if (it != item.second.end()) {
+        *out = it - item.second.begin();
+        return true;
+      }
+    }
+    return false;
+  }
+
   void UpdateAllInputs(const std::string &from, const std::string &to) {
     for (auto &item : inputs_) {
       for (auto &var : item.second) {
......
# Introduction
Paddle-Lite is well known for mobile inference, but it also supports training models on mobile devices. This document walks through a training example with Paddle-Lite; the task is "Boston house price prediction", also known as "fit-a-line".
You can learn more about the definition of this task and its modeling process from the
[document](https://paddlepaddle.org.cn/documentation/docs/zh/user_guides/simple_case/fit_a_line/README.cn.html)
and
[source code](https://github.com/PaddlePaddle/book/tree/develop/01.fit_a_line)
in the book repository; the task is modeled with linear regression (Linear Regression).
This document focuses on how to port it to Paddle-Lite for training.
Note: this tutorial trains the model with the C++ API; the other APIs do not support training yet.
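For orientation, the model being trained is a single fully connected layer over the 13 housing features with a squared-error loss, which matches the network defined in the `train.py` listed at the end of this demo:

$$\hat{y} = w^\top x + b, \qquad L = \frac{1}{N}\sum_{i=1}^{N}(\hat{y}_i - y_i)^2$$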
# Requirements
- An Android phone, used to run the training program
- Python with Paddle (version: 1.7.0) installed
# Quick start
## Step1 build paddle-lite
Please follow the tutorial in the [official Paddle-Lite documentation](https://paddle-lite.readthedocs.io/zh/latest/user_guides/source_compile.html#paddlelite) to build the full_publish Paddle-Lite lib. Taking a Linux build as an example, the commands are:
```shell
## set up the environment
wget -c https://mms-res.cdn.bcebos.com/cmake-3.10.3-Linux-x86_64.tar.gz --no-check-certificate
tar xzf cmake-3.10.3-Linux-x86_64.tar.gz
export PATH=${PWD}'/cmake-3.10.3-Linux-x86_64/bin':$PATH
wget https://dl.google.com/android/repository/android-ndk-r17c-linux-x86_64.zip
unzip android-ndk-r17c-linux-x86_64.zip
export NDK_ROOT=/opt/android-ndk-r17c
## build
git clone https://github.com/PaddlePaddle/Paddle-Lite.git
cd Paddle-Lite
./lite/tools/build.sh \
--arm_os=android \
--arm_abi=armv7 \
--build_extra=ON \
--arm_lang=gcc \
--android_stl=c++_static \
--build_train=ON full_publish
```
Artifacts:
```shell
Paddle-Lite/build.lite.android.armv7.gcc/inference_lite_lib.android.armv7/cxx/lib/libpaddle_full_api_shared.so
```
## Step2 build lr_trainer
```shell
cd Paddle-Lite/lite/demo/cxx/train_demo/cplus_train/
sh run_build.sh /path/to/your/Paddle-Lite/build.lite.android.armv7.gcc/ /path/to/your/android-ndk-r17c
```
Artifacts:
```shell
bin/
`-- demo_trainer
```
## Step3 download model and run it!
On your laptop, connect the phone via USB, enable developer mode on the phone, and run the following in any directory:
```shell
local_path=/data/local/tmp/linear_regression
adb shell "mkdir "${local_path}
# download model and push to mobile
wget http://paddle-tar.bj.bcebos.com/paddle-lite/lite_lr_model.tar.gz
tar -zxvf lite_lr_model.tar.gz
adb push lite_lr_model/housing.data ${local_path}
adb push lite_lr_model/model_dir ${local_path}
# push lib and executable file to mobile
adb push libpaddle_full_api_shared.so ${local_path}
adb push demo_trainer ${local_path}
adb shell chmod +x ${local_path}/demo_trainer
# run it!
adb shell "export LD_LIBRARY_PATH="${local_path}" && export LIBRARY_PATH="${local_path}" && cd "${local_path}" && ./demo_trainer true"
```
Expected output:
```
sample 0: Loss: 564.317
sample 1: Loss: 463.9
sample 2: Loss: 1197.54
sample 3: Loss: 1093.83
sample 4: Loss: 1282.76
sample 5: Loss: 792.097
sample 6: Loss: 491.776
sample 7: Loss: 698.496
sample 8: Loss: 248.445
sample 9: Loss: 325.135
```
# More details
The model above was downloaded directly. If you want to generate it yourself, run the following commands:
```shell
git clone https://github.com/PaddlePaddle/Paddle-Lite.git
cd Paddle-Lite/lite/demo/cxx/train_demo/
python train.py --save_model
```
Artifacts:
```shell
model_dir/
|-- fc_0.b_0
|-- fc_0.w_0
|-- learning_rate_0
`-- __model__
md5sum fc_0.w_0: 2c7b3649b2a9cf7bcd19f8b256ce795d
```
If you want to produce your own model for training, refer to the way `train.py` saves the model, sketched below.
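A condensed sketch of that saving flow, taken from the `train.py` listed at the end of this demo: persist the parameters, add feed/fetch ops, then serialize the program desc to `__model__`:

```python
import paddle.fluid as fluid

# after building main_program and running the startup program:
fluid.io.save_persistables(exe, "model_dir")               # parameter files
fluid.io.prepend_feed_ops(main_program, ['x', 'y'])        # feed ops for the inputs
fluid.io.append_fetch_ops(main_program, ['mean_0.tmp_0'])  # fetch op for the loss
with open("model_dir/__model__", "wb") as f:               # the program description
    f.write(main_program.desc.serialize_to_string())
```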
# Checking against Paddle training results
## First 10 loss values
To verify consistency between Paddle and Lite, we trained for 10 batches with identical model parameters, identical data, and batch size = 1, and recorded the loss values of both.
Python + Paddle command:
```shell
python train.py --num_steps=10 --batch_size=1
```
Python + Paddle output:
```shell
Train cost, Step 0, Cost 564.317017
Train cost, Step 1, Cost 463.900238
Train cost, Step 2, Cost 1197.537354
Train cost, Step 3, Cost 1093.833008
Train cost, Step 4, Cost 1282.760254
Train cost, Step 5, Cost 792.097351
Train cost, Step 6, Cost 491.775848
Train cost, Step 7, Cost 698.496033
Train cost, Step 8, Cost 248.444885
Train cost, Step 9, Cost 325.135132
```
C++ + Paddle-Lite command:
```
./demo_trainer true
```
C++ + Paddle-Lite output:
```
sample 0: Loss: 564.317
sample 1: Loss: 463.9
sample 2: Loss: 1197.54
sample 3: Loss: 1093.83
sample 4: Loss: 1282.76
sample 5: Loss: 792.097
sample 6: Loss: 491.776
sample 7: Loss: 698.496
sample 8: Loss: 248.445
sample 9: Loss: 325.135
```
## Loss curves
With the training batch size fixed at 20 and a global shuffle of the training data every epoch, the loss curves of Paddle and Lite after 100 epochs of training compare as follows.
![lr_loss](image/lr_loss.png)
To reproduce the result above, the Paddle + Python commands are:
```
git clone https://github.com/PaddlePaddle/book.git
cd book/01.fit_a_line
python train.py
```
The Lite + C++ command is:
```
./demo_trainer false
```
cmake_minimum_required(VERSION 2.8)
set (CMAKE_CXX_STANDARD 11)
# Project's name
if(NOT DEFINED LITE_ROOT)
message(FATAL_ERROR "please set LITE_ROOT with
-DLITE_ROOT=/path/to/your/build.lite.android.armv7.gcc/")
endif()
project(demo_trainer)
# Set the output folder where your program will be created
set(CMAKE_BINARY_DIR ${CMAKE_SOURCE_DIR}/bin)
set(EXECUTABLE_OUTPUT_PATH ${CMAKE_BINARY_DIR})
set(LIBRARY_OUTPUT_PATH ${CMAKE_BINARY_DIR})
# The following folder will be included
include_directories("include")
include_directories("${LITE_ROOT}/inference_lite_lib.android.armv7/cxx/include")
add_executable(demo_trainer ${PROJECT_SOURCE_DIR}/demo_trainer.cc ${PROJECT_SOURCE_DIR}/data_reader.cc)
TARGET_LINK_LIBRARIES(demo_trainer
"${LITE_ROOT}/inference_lite_lib.android.armv7/cxx/lib/libpaddle_full_api_shared.so")
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "include/data_reader.h"
#include <limits>
using std::string;
using std::vector;
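// The housing dataset stores 13 feature columns followed by the label;
// the first `rate` fraction (80%) of the samples forms the training split.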
int FEATURE_NUM = 13;
float rate = 0.8;
int get_samples(string line, vector<float>* feature, float* label) {
std::istringstream reader(line);
std::vector<float> numbers;
do {
// read as many numbers as possible.
for (float number; reader >> number;) {
numbers.push_back(number);
}
// consume and discard token from stream.
if (reader.fail()) {
reader.clear();
std::string token;
reader >> token;
}
} while (!reader.eof());
assert(numbers.size() == FEATURE_NUM + 1);
for (int i = 0; i < FEATURE_NUM; i++) {
feature->push_back(numbers[i]);
}
*label = numbers[FEATURE_NUM];
return 0;
}
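// Min-max normalize each feature column using its min, max and mean over the
// whole dataset, and keep only the training split of the samples.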
int normalize(const vector<vector<float>>& origin_features,
vector<vector<float>>* features,
float rate) {
int inf = std::numeric_limits<int>::max();
vector<float> min_vec(FEATURE_NUM, static_cast<float>(inf));
vector<float> max_vec(FEATURE_NUM, -(static_cast<float>(inf)));
vector<float> sum_vec(FEATURE_NUM, 0);
vector<float> avg_vec(FEATURE_NUM, 0);
for (int i = 0; i < origin_features.size(); i++) {
for (int j = 0; j < FEATURE_NUM; j++) {
min_vec[j] = min(min_vec[j], origin_features[i][j]);
max_vec[j] = max(max_vec[j], origin_features[i][j]);
sum_vec[j] += origin_features[i][j];
}
}
for (int i = 0; i < FEATURE_NUM; i++) {
avg_vec[i] = sum_vec[i] / origin_features.size();
}
for (int i = 0; i < origin_features.size() * rate - 1; i++) {
vector<float> feat;
for (int j = 0; j < FEATURE_NUM; j++) {
feat.push_back((origin_features[i][j] - avg_vec[j]) /
(max_vec[j] - min_vec[j]));
}
features->push_back(feat);
}
  return 0;
}
int read_samples(const string fname,
vector<vector<float>>* features,
vector<float>* labels) {
fstream fin;
fin.open(fname);
if (!static_cast<bool>(fin)) {
return 1;
}
vector<vector<float>> origin_features;
vector<string> lines;
string line;
while (getline(fin, line)) {
lines.push_back(line);
}
fin.close();
for (int i = 0; i < lines.size(); i++) {
vector<float> feat;
float lbl = 0;
get_samples(lines[i], &feat, &lbl);
origin_features.push_back(feat);
if (i < lines.size() * rate - 1) {
labels->push_back(lbl);
}
}
cout << "finish read fata" << endl;
normalize(origin_features, features, rate);
assert(features->size() == labels->size());
return 0;
}
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <math.h>
#include <algorithm>
#include <iomanip>
#include <iostream>
#include <vector>
#include "include/data_reader.h"
#include "paddle_api.h" // NOLINT
using namespace paddle::lite_api; // NOLINT
class LRModel {
public:
void InitModel() {
// 1. Set CxxConfig
CxxConfig config;
config.set_model_dir("model_dir");
std::vector<Place> valid_places{Place{TARGET(kARM), PRECISION(kFloat)}};
config.set_valid_places(valid_places);
predictor_ = CreatePaddlePredictor<CxxConfig>(config);
}
float Predict(const vector<vector<float>>& features,
const vector<float>& labels) {
// Create Tensor
assert(features.size() == labels.size());
int batch_size = features.size();
std::unique_ptr<Tensor> input_tensor(std::move(predictor_->GetInput(0)));
input_tensor->Resize(shape_t({batch_size, FEATURE_NUM}));
auto* data = input_tensor->mutable_data<float>();
for (int i = 0; i < batch_size; i++) {
for (int j = 0; j < FEATURE_NUM; j++) {
data[FEATURE_NUM * i + j] = features[i][j];
}
}
std::unique_ptr<Tensor> y_tensor(std::move(predictor_->GetInput(1)));
y_tensor->Resize(shape_t({batch_size, 1}));
auto* y_data = y_tensor->mutable_data<float>();
for (int i = 0; i < batch_size; i++) {
y_data[i] = labels[i];
}
predictor_->Run();
std::unique_ptr<const Tensor> output_tensor(
std::move(predictor_->GetOutput(0)));
return output_tensor->data<float>()[0];
}
private:
std::shared_ptr<PaddlePredictor> predictor_;
};
int shuffle(vector<vector<float>>* features, vector<float>* labels) {
assert(features->size() == labels->size());
vector<int> index;
for (int i = 0; i < features->size(); i++) {
index.push_back(i);
}
random_shuffle(index.begin(), index.end());
vector<vector<float>> tmp_features;
vector<float> tmp_labels;
for (int i = 0; i < features->size(); i++) {
tmp_features.push_back((*features)[index[i]]);
tmp_labels.push_back((*labels)[index[i]]);
}
for (int i = 0; i < features->size(); i++) {
for (int j = 0; j < FEATURE_NUM; j++) {
(*features)[i][j] = tmp_features[i][j];
}
(*labels)[i] = tmp_labels[i];
}
return 0;
}
int main(int argc, char* argv[]) {
if (argc < 2) {
cerr << "usage: ./demo_trainer is_small" << endl;
cerr << " if is_small is true, the batch size is set to 1, " << endl;
cerr << " and it will only runs for 10 steps." << endl;
return 1;
}
string is_small = argv[1];
vector<vector<float>> features;
vector<float> labels;
read_samples("housing.data", &features, &labels);
cout << "sample count: " << features.size() << " " << endl;
std::shared_ptr<LRModel> local_model(new LRModel());
local_model->InitModel();
if (is_small == "true") {
cout << "small mode" << endl;
    for (int i = 0; i < 10; i++) {
vector<vector<float>> batch_feature;
vector<float> batch_label;
batch_feature.push_back(features[i]);
batch_label.push_back(labels[i]);
auto loss = local_model->Predict(batch_feature, batch_label);
cout << "sample " << i << ": " << loss << endl;
}
} else if (is_small == "false") {
// shuffle
cout << "full model" << endl;
int epoch = 100;
int batch_size = 20;
int step = 0;
    for (int i = 0; i < epoch; i++) {
shuffle(&features, &labels);
for (int j = 0;
j < ceil(static_cast<float>(features.size()) / batch_size);
j++) {
int start_idx = j * batch_size;
int end_idx =
min((j + 1) * batch_size, static_cast<int>(features.size()));
auto batch_feature = vector<vector<float>>(features.begin() + start_idx,
features.begin() + end_idx);
auto batch_label =
vector<float>(labels.begin() + start_idx, labels.begin() + end_idx);
auto loss = local_model->Predict(batch_feature, batch_label);
if (step % 10 == 0) {
std::cout << "batch: " << i << ", step: " << step
<< ", Loss: " << loss << endl;
}
step += 1;
}
}
} else {
cerr << "wrong arg for is_small: " << is_small << endl;
}
}
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <assert.h>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
using std::string;
using std::vector;
using std::cerr;
using std::cout;
using std::endl;
using std::min;
using std::max;
using std::fstream;
extern int FEATURE_NUM;
int get_samples(string line, vector<float>* feature, float* label);
int read_samples(const string fname,
vector<vector<float>>* features,
vector<float>* labels);
rm -rf build
mkdir build
cd build
LITE_ROOT=$1
NDK_ROOT=$2
cmake .. \
-DLITE_ROOT=${LITE_ROOT} \
-DNDK_ROOT=${NDK_ROOT} \
-DCMAKE_TOOLCHAIN_FILE=${NDK_ROOT}/build/cmake/android.toolchain.cmake \
-DANDROID_TOOLCHAIN=gcc \
-DANDROID_ABI="armeabi-v7a" \
-DANDROID_PLATFORM=android-23 \
-DANDROID=true \
-DANDROID_STL=c++_static
make
cd ..
# ./bin/demo_trainer
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import sys
import argparse
import math
import numpy
import paddle
import paddle.fluid as fluid
def parse_args():
parser = argparse.ArgumentParser("fit_a_line")
parser.add_argument(
'--save_model',
action='store_true',
help="Whether to save main program")
parser.add_argument(
'--num_steps',
type=int,
default=1000000000000,
help="train steps")
parser.add_argument(
'--num_epochs', type=int, default=100, help="number of epochs.")
parser.add_argument(
'--batch_size', type=int, default=20, help="batch size.")
parser.add_argument(
'--shuffle',
action='store_true',
help="Whether to shuffle train data.")
args = parser.parse_args()
return args
# For training test cost
def train_test(executor, program, reader, feeder, fetch_list):
accumulated = 1 * [0]
count = 0
for data_test in reader():
outs = executor.run(
program=program, feed=feeder.feed(data_test), fetch_list=fetch_list)
accumulated = [x_c[0] + x_c[1][0] for x_c in zip(accumulated, outs)]
count += 1
return [x_d / count for x_d in accumulated]
def main():
if args.shuffle:
print("doing shuffle")
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.uci_housing.train(), buf_size=500),
batch_size=args.batch_size)
else:
train_reader = paddle.batch(
paddle.dataset.uci_housing.train(), batch_size=args.batch_size)
# feature vector of length 13
x = fluid.data(name='x', shape=[None, 13], dtype='float32')
y = fluid.data(name='y', shape=[None, 1], dtype='float32')
main_program = fluid.default_main_program()
startup_program = fluid.default_startup_program()
main_program.random_seed = 90
startup_program.random_seed = 90
y_predict = fluid.layers.fc(input=x, size=1, act=None)
cost = fluid.layers.square_error_cost(input=y_predict, label=y)
avg_loss = fluid.layers.mean(cost)
test_program = main_program.clone(for_test=True)
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
sgd_optimizer.minimize(avg_loss)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
num_epochs = args.num_epochs
# main train loop.
feeder = fluid.DataFeeder(place=place, feed_list=[x, y])
exe.run(startup_program)
if args.save_model:
fluid.io.save_persistables(exe, "model_dir")
# add feed and fetch op
feeded_var_names = ['x', 'y']
fetch_var_names = ['mean_0.tmp_0']
fluid.io.prepend_feed_ops(main_program, feeded_var_names)
fluid.io.append_fetch_ops(main_program, fetch_var_names)
with open("model_dir/__model__", "wb") as f:
f.write(main_program.desc.serialize_to_string())
with open("debug_main_program", "w") as f:
f.write(str(main_program))
print("train model saved to model_dir")
return
train_prompt = "Train cost"
step = 0
for pass_id in range(num_epochs):
for data_train in train_reader():
avg_loss_value, = exe.run(
main_program,
feed=feeder.feed(data_train),
fetch_list=[avg_loss])
print("%s, Step %d, Cost %f" %
(train_prompt, step, avg_loss_value[0]))
if step == args.num_steps - 1:
return
step += 1
if math.isnan(float(avg_loss_value[0])):
sys.exit("got NaN loss, training failed.")
if __name__ == '__main__':
args = parse_args()
main()
@@ -106,13 +106,12 @@ add_kernel(lstm_arm ARM extra SRCS lstm_compute.cc DEPS ${lite_kernel_deps} math
 # 4. training kernels
 add_kernel(mean_compute_arm ARM extra SRCS mean_compute.cc DEPS ${lite_kernel_deps} math_arm)
-if(LITE_WITH_TRAIN)
-  add_kernel(mean_grad_compute_arm ARM extra SRCS mean_grad_compute.cc DEPS ${lite_kernel_deps} math_arm)
-  add_kernel(activation_grad_compute_arm ARM basic SRCS activation_grad_compute.cc DEPS ${lite_kernel_deps} math_arm)
-  add_kernel(elementwise_grad_compute_arm ARM basic SRCS elementwise_grad_compute.cc DEPS ${lite_kernel_deps} math_arm)
-  add_kernel(mul_grad_compute_arm ARM extra SRCS mul_grad_compute.cc DEPS ${lite_kernel_deps} math_arm)
-  add_kernel(sgd_compute_arm ARM extra SRCS sgd_compute.cc DEPS ${lite_kernel_deps} math_arm)
-endif()
+add_kernel(mean_grad_compute_arm ARM train SRCS mean_grad_compute.cc DEPS ${lite_kernel_deps} math_arm)
+add_kernel(activation_grad_compute_arm ARM train SRCS activation_grad_compute.cc DEPS ${lite_kernel_deps} math_arm)
+add_kernel(elementwise_grad_compute_arm ARM train SRCS elementwise_grad_compute.cc DEPS ${lite_kernel_deps} math_arm)
+add_kernel(mul_grad_compute_arm ARM train SRCS mul_grad_compute.cc DEPS ${lite_kernel_deps} math_arm)
+add_kernel(sgd_compute_arm ARM train SRCS sgd_compute.cc DEPS ${lite_kernel_deps} math_arm)
 lite_cc_test(test_scale_compute_arm SRCS scale_compute_test.cc DEPS scale_compute_arm)
 lite_cc_test(test_softmax_compute_arm SRCS softmax_compute_test.cc DEPS softmax_compute_arm)
......
@@ -30,6 +30,8 @@ lite_cc_library(subgraph_bridge_conv_transpose_op_bm SRCS conv_transpose_op.cc DEPS ${bm_subgraph_bridge_deps})
 lite_cc_library(subgraph_bridge_reduce_full_op_bm SRCS reduce_full_op.cc DEPS ${bm_subgraph_bridge_deps})
 lite_cc_library(subgraph_bridge_squeeze_op_bm SRCS squeeze_op.cc DEPS ${bm_subgraph_bridge_deps})
 lite_cc_library(subgraph_bridge_cast_op_bm SRCS cast_op.cc DEPS ${bm_subgraph_bridge_deps})
+lite_cc_library(subgraph_bridge_fill_constant_op_bm SRCS fill_constant_op.cc DEPS ${bm_subgraph_bridge_deps})
+lite_cc_library(subgraph_bridge_assign_value_op_bm SRCS assign_value_op.cc DEPS ${bm_subgraph_bridge_deps})
 set(bm_subgraph_bridges
     subgraph_bridge_registry
@@ -58,4 +60,6 @@ set(bm_subgraph_bridges
     subgraph_bridge_reduce_full_op_bm
     subgraph_bridge_squeeze_op_bm
     subgraph_bridge_cast_op_bm
+    subgraph_bridge_fill_constant_op_bm
+    subgraph_bridge_assign_value_op_bm
     CACHE INTERNAL "bm_subgraph_bridges")
@@ -13,6 +13,7 @@
 // limitations under the License.
 #include <bmcompiler_if.h>
+#include <bmcompiler_if_lite.h>
 #include <bmcompiler_op_code.h>
 #include "lite/kernels/bm/bridges/graph.h"
 #include "lite/kernels/npu/bridges/registry.h"
@@ -35,16 +36,14 @@ int ActConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   auto output_var_name = op_info->Output("Out").front();
   auto output = scope->FindVar(output_var_name)->GetMutable<lite::Tensor>();
   auto output_dims = output->dims();
-  const int64_t* x_shape_data = const_cast<const int64_t*>(&x_dims.data()[0]);
-  const int64_t* output_shape_data =
-      const_cast<const int64_t*>(&output_dims.data()[0]);
+  bool x_is_const = !graph->HasNode(x_var_name);
   std::vector<int32_t> i_x_shape_data(x_dims.size());
   std::vector<int32_t> i_output_shape_data(output_dims.size());
   for (size_t i = 0; i < x_dims.size(); i++) {
-    i_x_shape_data[i] = static_cast<int>(x_shape_data[i]);
+    i_x_shape_data[i] = x_dims[i];
   }
   for (size_t i = 0; i < output_dims.size(); i++) {
-    i_output_shape_data[i] = static_cast<int>(output_shape_data[i]);
+    i_output_shape_data[i] = output_dims[i];
   }
   float alpha = 0.f;
   int active_type_id = 0;
@@ -59,6 +58,15 @@ int ActConverter(void* ctx, OpLite* op, KernelBase* kernel) {
     LOG(FATAL) << "[BM] unsupport act type";
     return FAILED;
   }
+  const float* x_data = const_cast<const float*>(x->mutable_data<float>());
+  if (x_is_const) {
+    bm_add_const_tensor(graph->GetCompilerHandle(),
+                        static_cast<const char*>(x_var_name.c_str()),
+                        const_cast<const int*>(&i_x_shape_data[0]),
+                        x_dims.size(),
+                        static_cast<bm_data_type_t>(DTYPE_FP32),
+                        static_cast<const void*>(x_data));
+  }
   if (op_type == "relu" || op_type == "leaky_relu") {
     add_relu_layer(graph->GetCompilerHandle(),
                    const_cast<const int*>(&i_x_shape_data[0]),
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <bmcompiler_defs.h>
#include <bmcompiler_if.h>
#include <bmcompiler_if_lite.h>
#include <cstring>  // memcpy
#include "lite/kernels/bm/bridges/graph.h"
#include "lite/kernels/bm/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace bm {
int AssignValueConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto scope = op->scope();
auto op_info = op->op_info();
auto output_var_name = op_info->Output("Out").front();
auto output = scope->FindVar(output_var_name)->GetMutable<lite::Tensor>();
auto output_dims = output->dims();
std::vector<int32_t> i_output_shape_data(output_dims.size());
int buffer_size = 1;
for (size_t i = 0; i < output_dims.size(); i++) {
i_output_shape_data[i] = static_cast<int>(output_dims[i]);
buffer_size *= i_output_shape_data[i];
}
auto fp32_values = op_info->GetAttr<std::vector<float>>("fp32_values");
float* assign_data =
reinterpret_cast<float*>(malloc(buffer_size * sizeof(float)));
CHECK(assign_data != nullptr);
  CHECK_EQ(buffer_size, static_cast<int>(fp32_values.size()));
  // copy the attribute values into the buffer before registering the tensor
  memcpy(assign_data, fp32_values.data(), buffer_size * sizeof(float));
bm_add_const_tensor(graph->GetCompilerHandle(),
static_cast<const char*>(output_var_name.c_str()),
const_cast<const int*>(i_output_shape_data.data()),
output_dims.size(),
static_cast<bm_data_type_t>(DTYPE_FP32),
reinterpret_cast<const void*>(assign_data));
graph->AddNode(output_var_name);
return SUCCESS;
}
} // namespace bm
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(assign_value,
kBM,
paddle::lite::subgraph::bm::AssignValueConverter);
@@ -39,6 +39,7 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   auto filter_var_name = op_info->Input("Filter").front();
   auto filter = scope->FindVar(filter_var_name)->GetMutable<lite::Tensor>();
   auto filter_dims = filter->dims();
+
   CHECK_EQ(input_dims.size(), 4);
   CHECK_EQ(output_dims.size(), 4);
   CHECK_EQ(filter_dims.size(), 4);
@@ -90,6 +91,7 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
                  dilations[1],
                  static_cast<int>(has_bias));
   graph->AddNode(output_var_name);
+  LOG(INFO) << output_var_name << input_dims << " " << output_dims;
   return SUCCESS;
 }
......
@@ -65,6 +65,7 @@ int ElementwiseConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   auto output_dims = output->dims();
   const int64_t* output_shape_data =
       const_cast<const int64_t*>(&output_dims.data()[0]);
+  LOG(INFO) << x_dims << " " << output_dims;
   std::vector<int32_t> i_output_shape_data(output_dims.size());
   for (size_t i = 0; i < output_dims.size(); i++) {
     i_output_shape_data[i] = static_cast<int>(output_shape_data[i]);
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <bmcompiler_defs.h>
#include <bmcompiler_if.h>
#include <bmcompiler_if_lite.h>
#include "lite/kernels/bm/bridges/graph.h"
#include "lite/kernels/bm/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace bm {
int FillConstantConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto scope = op->scope();
auto op_info = op->op_info();
auto output_var_name = op_info->Output("Out").front();
auto output = scope->FindVar(output_var_name)->GetMutable<lite::Tensor>();
auto output_dims = output->dims();
  std::vector<int32_t> i_output_shape_data(output_dims.size());
  int buffer_size = 1;
  for (size_t i = 0; i < output_dims.size(); i++) {
    i_output_shape_data[i] = static_cast<int>(output_dims[i]);
    // Accumulate the element count; without this the buffer below would
    // hold a single float regardless of the output shape.
    buffer_size *= i_output_shape_data[i];
  }
  float* const_data =
      reinterpret_cast<float*>(malloc(buffer_size * sizeof(float)));
  CHECK(const_data != nullptr);
  auto value = op_info->GetAttr<float>("value");
  for (int i = 0; i < buffer_size; i++) {
    const_data[i] = value;
  }
bm_add_const_tensor(graph->GetCompilerHandle(),
static_cast<const char*>(output_var_name.c_str()),
const_cast<const int*>(i_output_shape_data.data()),
output_dims.size(),
static_cast<bm_data_type_t>(DTYPE_FP32),
reinterpret_cast<const void*>(const_data));
graph->AddNode(output_var_name);
return SUCCESS;
}
} // namespace bm
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(fill_constant,
kBM,
paddle::lite::subgraph::bm::FillConstantConverter);
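
fill_constant is the scalar counterpart of assign_value: every output element takes the single value attribute. A one-line reference sketch (illustrative name):

#include <vector>

// Hypothetical reference for fill_constant: broadcast one scalar over a shape.
std::vector<float> FillConstantRef(size_t numel, float value) {
  return std::vector<float>(numel, value);
}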
@@ -29,7 +29,6 @@ int MulConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   auto op_info = op->op_info();
   auto op_type = op_info->Type();
   auto unique_op_name = lite::subgraph::bm::UniqueName(op_type);
-  // only support y is const
   // input
   auto x_var_name = op_info->Input("X").front();
   auto x = scope->FindVar(x_var_name)->GetMutable<lite::Tensor>();

@@ -61,6 +60,12 @@ int MulConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   auto y_var_name = op_info->Input("Y").front();
   auto y = scope->FindVar(y_var_name)->GetMutable<lite::Tensor>();
   auto y_dims = y->dims();
+  bool y_is_const = !graph->HasNode(y_var_name);
+  CHECK_EQ(y_dims.size(), 2);
+  int i_y_shape_data[2];
+  for (size_t i = 0; i < 2; i++) {
+    i_y_shape_data[i] = y_dims[i];
+  }
   // output
   auto output_var_name = op_info->Output("Out").front();
   auto output = scope->FindVar(output_var_name)->GetMutable<lite::Tensor>();

@@ -71,20 +76,39 @@ int MulConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   for (size_t i = 0; i < output_dims.size(); i++) {
     i_output_shape_data[i] = static_cast<int>(output_shape_data[i]);
   }
-  add_fc_layer(graph->GetCompilerHandle(),
-               const_cast<const int*>(&i_x_reshape_shape_data[0]),
-               2,
-               static_cast<const char*>(unique_op_reshape_name.c_str()),
-               const_cast<const int*>(&i_output_shape_data[0]),
-               output_dims.size(),
-               static_cast<const char*>(output_var_name.c_str()),
-               static_cast<const char*>(unique_op_name.c_str()),
-               i_x_reshape_shape_data[1],
-               i_output_shape_data[1],
-               static_cast<const float*>(y->mutable_data<float>()),
-               nullptr,
-               0,
-               0);
+  if (y_is_const) {
+    add_fc_layer(graph->GetCompilerHandle(),
+                 const_cast<const int*>(&i_x_reshape_shape_data[0]),
+                 2,
+                 static_cast<const char*>(unique_op_reshape_name.c_str()),
+                 const_cast<const int*>(&i_output_shape_data[0]),
+                 output_dims.size(),
+                 static_cast<const char*>(output_var_name.c_str()),
+                 static_cast<const char*>(unique_op_name.c_str()),
+                 i_x_reshape_shape_data[1],
+                 i_output_shape_data[1],
+                 static_cast<const float*>(y->mutable_data<float>()),
+                 nullptr,
+                 0,
+                 0);
+  } else {
+    add_fc_weight_layer(
+        graph->GetCompilerHandle(),
+        const_cast<const int*>(&i_x_reshape_shape_data[0]),
+        2,
+        static_cast<const char*>(unique_op_reshape_name.c_str()),
+        const_cast<const int*>(&i_output_shape_data[0]),
+        output_dims.size(),
+        static_cast<const char*>(output_var_name.c_str()),
+        static_cast<const char*>(unique_op_name.c_str()),
+        const_cast<const int*>(&i_y_shape_data[0]),
+        2,
+        static_cast<const char*>(y_var_name.c_str()),
+        i_x_reshape_shape_data[1],
+        nullptr,
+        0,
+        0);
+  }
   graph->AddNode(output_var_name);
   return SUCCESS;
 }
......
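
The new y_is_const branch keys off whether Y already exists as a node in the BM graph: a const Y can be baked into the FC layer as weights (add_fc_layer), while a runtime Y is wired in as a second input (add_fc_weight_layer). In both cases the computed function is a plain 2-D mul; a reference sketch of that math (plain C++, illustrative only, not the BM compiler API):

// Out[m][n] = sum_k X[m][k] * Y[k][n], with X already reshaped to 2-D.
void MulRef(const float* x, const float* y, float* out, int M, int K, int N) {
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.f;
      for (int k = 0; k < K; ++k) {
        acc += x[m * K + k] * y[k * N + n];
      }
      out[m * N + n] = acc;
    }
  }
}

Here M is the reshaped batch dimension, K corresponds to i_x_reshape_shape_data[1], and N to i_output_shape_data[1], matching the arguments passed above.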
@@ -51,3 +51,5 @@ USE_SUBGRAPH_BRIDGE(reduce_mean, kBM);
 USE_SUBGRAPH_BRIDGE(squeeze, kBM);
 USE_SUBGRAPH_BRIDGE(squeeze2, kBM);
 USE_SUBGRAPH_BRIDGE(cast, kBM);
+USE_SUBGRAPH_BRIDGE(fill_constant, kBM);
+USE_SUBGRAPH_BRIDGE(assign_value, kBM);
@@ -35,7 +35,7 @@ int SubgraphEngine::BuildDeviceProgram() {
   graph.CreateCompilerHandle();
   auto& ctx = this->ctx_->template As<BMContext>();
   for (auto& inst : origin_program_) {
-    auto op = inst.op();
+    auto op = const_cast<OpLite*>(inst.op());
     CHECK(op);
     op->CheckShape();
     op->InferShape();
......
@@ -5,6 +5,3 @@ add_kernel(fetch_compute_host Host basic SRCS fetch_compute.cc DEPS ${lite_kernel_deps})
 add_kernel(reshape_compute_host Host basic SRCS reshape_compute.cc DEPS ${lite_kernel_deps} reshape_op)
 add_kernel(multiclass_nms_compute_host Host basic SRCS multiclass_nms_compute.cc DEPS ${lite_kernel_deps})
 add_kernel(crf_decoding_compute_host Host extra SRCS crf_decoding_compute.cc DEPS ${lite_kernel_deps})
-#lite_cc_test(test_reshape_compute_host SRCS reshape_compute_test.cc DEPS reshape_compute_host any)
-#lite_cc_test(test_multiclass_nms_compute_host SRCS multiclass_nms_compute_test.cc DEPS multiclass_nms_compute_host any)
@@ -369,6 +369,7 @@ void MulticlassNmsCompute::Run() {
     }
   } else {
     outs->Resize({static_cast<int64_t>(num_kept), out_dim});
+    outs->mutable_data<float>();
     int offset = 0;
     int* oindices = nullptr;
     for (int i = 0; i < n; ++i) {
......
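
The added outs->mutable_data<float>() matters because Resize only records the new dims; it is the mutable_data<T>() call that actually allocates the buffer the subsequent Slice writes go through. A minimal sketch of the pattern (assuming only the lite::Tensor calls already used in this diff):

lite::Tensor outs;
outs.Resize({static_cast<int64_t>(8), 6});  // records the shape only
float* data = outs.mutable_data<float>();   // allocation happens here
data[0] = -1.f;                             // now safe to write through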
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/host/multiclass_nms_compute.h"
#include <gtest/gtest.h>
#include <map>
#include <utility>
#include <vector>
namespace paddle {
namespace lite {
namespace kernels {
namespace host {
template <typename dtype>
static bool sort_score_pair_descend(const std::pair<float, dtype>& pair1,
const std::pair<float, dtype>& pair2) {
return pair1.first > pair2.first;
}
template <typename dtype>
void get_max_score_index(const dtype* scores,
int num,
float threshold,
int top_k,
std::vector<std::pair<dtype, int>>* score_index_vec) {
//! Generate index score pairs.
for (int i = 0; i < num; ++i) {
if (scores[i] > threshold) {
score_index_vec->push_back(std::make_pair(scores[i], i));
}
}
//! Sort the score pair according to the scores in descending order
std::stable_sort(score_index_vec->begin(),
score_index_vec->end(),
sort_score_pair_descend<int>);
//! Keep top_k scores if needed.
if (top_k > -1 && top_k < score_index_vec->size()) {
score_index_vec->resize(top_k);
}
}
template <typename dtype>
dtype bbox_size(const dtype* bbox, bool normalized = true) {
if (bbox[2] < bbox[0] || bbox[3] < bbox[1]) {
// If bbox is invalid (e.g. xmax < xmin or ymax < ymin), return 0.
return dtype(0.);
} else {
const dtype width = bbox[2] - bbox[0];
const dtype height = bbox[3] - bbox[1];
if (normalized) {
return width * height;
} else {
// If bbox is not within range [0, 1].
return (width + 1) * (height + 1);
}
}
}
template <typename dtype>
dtype jaccard_overlap(const dtype* bbox1, const dtype* bbox2) {
if (bbox2[0] > bbox1[2] || bbox2[2] < bbox1[0] || bbox2[1] > bbox1[3] ||
bbox2[3] < bbox1[1]) {
return dtype(0.);
} else {
const dtype inter_xmin = std::max(bbox1[0], bbox2[0]);
const dtype inter_ymin = std::max(bbox1[1], bbox2[1]);
const dtype inter_xmax = std::min(bbox1[2], bbox2[2]);
const dtype inter_ymax = std::min(bbox1[3], bbox2[3]);
const dtype inter_width = inter_xmax - inter_xmin;
const dtype inter_height = inter_ymax - inter_ymin;
const dtype inter_size = inter_width * inter_height;
const dtype bbox1_size = bbox_size(bbox1);
const dtype bbox2_size = bbox_size(bbox2);
return inter_size / (bbox1_size + bbox2_size - inter_size);
}
}
template <typename dtype>
void apply_nms_fast(const dtype* bboxes,
const dtype* scores,
int num,
float score_threshold,
float nms_threshold,
float eta,
int top_k,
std::vector<int>* indices) {
// Get top_k scores (with corresponding indices).
std::vector<std::pair<dtype, int>> score_index_vec;
get_max_score_index(scores, num, score_threshold, top_k, &score_index_vec);
// Do nms.
float adaptive_threshold = nms_threshold;
indices->clear();
while (score_index_vec.size() != 0) {
const int idx = score_index_vec.front().second;
bool keep = true;
for (int k = 0; k < indices->size(); ++k) {
if (keep) {
const int kept_idx = (*indices)[k];
float overlap =
jaccard_overlap(bboxes + idx * 4, bboxes + kept_idx * 4);
keep = overlap <= adaptive_threshold;
} else {
break;
}
}
if (keep) {
indices->push_back(idx);
}
score_index_vec.erase(score_index_vec.begin());
if (keep && eta < 1 && adaptive_threshold > 0.5) {
adaptive_threshold *= eta;
}
}
}
template <typename dtype>
void multiclass_nms_compute_ref(const operators::MulticlassNmsParam& param,
int class_num,
const std::vector<int>& priors,
bool share_location,
std::vector<float>* result) {
int background_id = param.background_label;
int keep_topk = param.keep_top_k;
int nms_topk = param.nms_top_k;
float conf_thresh = param.score_threshold;
float nms_thresh = param.nms_threshold;
float nms_eta = param.nms_eta;
const dtype* bbox_data = param.bboxes->data<const dtype>();
const dtype* conf_data = param.scores->data<const dtype>();
dtype* out = param.out->mutable_data<dtype>();
(*result).clear();
int num_kept = 0;
std::vector<std::map<int, std::vector<int>>> all_indices;
int64_t conf_offset = 0;
int64_t bbox_offset = 0;
for (int i = 0; i < priors.size(); ++i) {
std::map<int, std::vector<int>> indices;
int num_det = 0;
int num_priors = priors[i];
int conf_idx = class_num * conf_offset;
int bbox_idx =
share_location ? bbox_offset * 4 : bbox_offset * 4 * class_num;
for (int c = 0; c < class_num; ++c) {
if (c == background_id) {
// Ignore background class
continue;
}
const dtype* cur_conf_data = conf_data + conf_idx + c * num_priors;
const dtype* cur_bbox_data = bbox_data + bbox_idx;
if (!share_location) {
cur_bbox_data += c * num_priors * 4;
}
apply_nms_fast(cur_bbox_data,
cur_conf_data,
num_priors,
conf_thresh,
nms_thresh,
nms_eta,
nms_topk,
&(indices[c]));
num_det += indices[c].size();
}
if (keep_topk > -1 && num_det > keep_topk) {
std::vector<std::pair<float, std::pair<int, int>>> score_index_pairs;
for (auto it = indices.begin(); it != indices.end(); ++it) {
int label = it->first;
const std::vector<int>& label_indices = it->second;
for (int j = 0; j < label_indices.size(); ++j) {
int idx = label_indices[j];
float score = conf_data[conf_idx + label * num_priors + idx];
score_index_pairs.push_back(
std::make_pair(score, std::make_pair(label, idx)));
}
}
// Keep top k results per image.
std::stable_sort(score_index_pairs.begin(),
score_index_pairs.end(),
sort_score_pair_descend<std::pair<int, int>>);
score_index_pairs.resize(keep_topk);
// Store the new indices.
std::map<int, std::vector<int>> new_indices;
for (int j = 0; j < score_index_pairs.size(); ++j) {
int label = score_index_pairs[j].second.first;
int idx = score_index_pairs[j].second.second;
new_indices[label].push_back(idx);
}
all_indices.push_back(new_indices);
num_kept += keep_topk;
} else {
all_indices.push_back(indices);
num_kept += num_det;
}
conf_offset += num_priors;
bbox_offset += num_priors;
}
if (num_kept == 0) {
(*result).clear();
(*result).resize(1);
(*result)[0] = -1;
return;
} else {
(*result).resize(num_kept * 6);
}
int count = 0;
conf_offset = 0;
bbox_offset = 0;
for (int i = 0; i < priors.size(); ++i) {
int num_priors = priors[i];
int conf_idx = class_num * conf_offset;
int bbox_idx =
share_location ? bbox_offset * 4 : bbox_offset * 4 * class_num;
for (auto it = all_indices[i].begin(); it != all_indices[i].end(); ++it) {
int label = it->first;
std::vector<int>& indices = it->second;
const dtype* cur_conf_data = conf_data + conf_idx + label * num_priors;
const dtype* cur_bbox_data = bbox_data + bbox_idx;
if (!share_location) {
cur_bbox_data += label * num_priors * 4;
}
for (int j = 0; j < indices.size(); ++j) {
int idx = indices[j];
(*result)[count * 6] = label;
(*result)[count * 6 + 1] = cur_conf_data[idx];
for (int k = 0; k < 4; ++k) {
(*result)[count * 6 + 2 + k] = cur_bbox_data[idx * 4 + k];
}
++count;
}
}
conf_offset += num_priors;
bbox_offset += num_priors;
}
}
TEST(multiclass_nms_host, init) {
MulticlassNmsCompute multiclass_nms;
ASSERT_EQ(multiclass_nms.precision(), PRECISION(kFloat));
ASSERT_EQ(multiclass_nms.target(), TARGET(kHost));
}
TEST(multiclass_nms_host, retrive_op) {
auto multiclass_nms =
KernelRegistry::Global().Create<TARGET(kHost), PRECISION(kFloat)>(
"multiclass_nms");
ASSERT_FALSE(multiclass_nms.empty());
ASSERT_TRUE(multiclass_nms.front());
}
TEST(multiclass_nms_host, compute) {
MulticlassNmsCompute multiclass_nms;
operators::MulticlassNmsParam param;
lite::Tensor bbox, conf, out;
std::vector<float> out_ref;
for (std::vector<int> priors : {std::vector<int>({2, 2, 2})}) {
int N = priors.size();
for (bool share_location : {true}) {
for (int class_num : {1, 4, 10}) {
DDim* bbox_dim;
DDim* conf_dim;
int M = priors[0];
if (share_location) {
bbox_dim = new DDim({N, M, 4});
} else {
bbox_dim = new DDim({class_num, M, 4});
}
conf_dim = new DDim({N, class_num, M});
bbox.Resize(*bbox_dim);
conf.Resize(*conf_dim);
for (int background_id : {0}) {
for (int keep_topk : {1, 5, 10}) {
for (int nms_topk : {1, 5, 10}) {
for (float nms_eta : {1.0, 0.99, 0.9}) {
for (float nms_thresh : {0.5, 0.7}) {
for (float conf_thresh : {0.5, 0.7}) {
auto* conf_data = conf.mutable_data<float>();
auto* bbox_data = bbox.mutable_data<float>();
for (int i = 0; i < bbox_dim->production(); ++i) {
bbox_data[i] = i * 1. / bbox_dim->production();
}
for (int i = 0; i < conf_dim->production(); ++i) {
conf_data[i] = i * 1. / conf_dim->production();
}
param.bboxes = &bbox;
param.scores = &conf;
param.out = &out;
param.background_label = background_id;
param.keep_top_k = keep_topk;
param.nms_top_k = nms_topk;
param.score_threshold = conf_thresh;
param.nms_threshold = nms_thresh;
param.nms_eta = nms_eta;
multiclass_nms.SetParam(param);
multiclass_nms.Run();
auto* out_data = out.mutable_data<float>();
out_ref.clear();
multiclass_nms_compute_ref<float>(
param, class_num, priors, share_location, &out_ref);
EXPECT_EQ(out.dims().production(), out_ref.size());
if (out.dims().production() == out_ref.size()) {
auto* out_ref_data = out_ref.data();
for (int i = 0; i < out.dims().production(); i++) {
EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-5);
}
}
}
}
}
}
}
}
delete bbox_dim;
delete conf_dim;
}
}
}
}
} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(multiclass_nms, kHost, kFloat, kNCHW, def);
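
To make the jaccard_overlap arithmetic above concrete, here is a small worked example with hand-picked boxes (values chosen purely for illustration):

#include <algorithm>
#include <cassert>

// Two normalized boxes: box1 = [0, 0, 0.5, 0.5], box2 = [0.25, 0.25, 0.75, 0.75].
// Intersection = 0.25 * 0.25 = 0.0625; union = 0.25 + 0.25 - 0.0625 = 0.4375;
// IoU = 0.0625 / 0.4375 ~= 0.143, below a 0.45 NMS threshold, so both survive.
int main() {
  const float b1[4] = {0.f, 0.f, 0.5f, 0.5f};
  const float b2[4] = {0.25f, 0.25f, 0.75f, 0.75f};
  const float inter_w = std::min(b1[2], b2[2]) - std::max(b1[0], b2[0]);
  const float inter_h = std::min(b1[3], b2[3]) - std::max(b1[1], b2[1]);
  const float inter = inter_w * inter_h;
  const float a1 = (b1[2] - b1[0]) * (b1[3] - b1[1]);
  const float a2 = (b2[2] - b2[0]) * (b2[3] - b2[1]);
  const float iou = inter / (a1 + a2 - inter);
  assert(iou > 0.14f && iou < 0.15f);
  return 0;
}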
@@ -25,6 +25,7 @@ lite_cc_library(subgraph_bridge_layer_norm_op_xpu SRCS layer_norm_op.cc DEPS ${xpu_subgraph_bridge_deps})
 lite_cc_library(subgraph_bridge_dropout_op_xpu SRCS dropout_op.cc DEPS ${xpu_subgraph_bridge_deps})
 lite_cc_library(subgraph_bridge_matmul_op_xpu SRCS matmul_op.cc DEPS ${xpu_subgraph_bridge_deps})
 lite_cc_library(subgraph_bridge_cast_op_xpu SRCS cast_op.cc DEPS ${xpu_subgraph_bridge_deps})
+lite_cc_library(subgraph_bridge_yolo_box_op_xpu SRCS yolo_box_op.cc DEPS ${xpu_subgraph_bridge_deps})

 set(xpu_subgraph_bridges
     subgraph_bridge_registry

@@ -48,6 +49,7 @@ set(xpu_subgraph_bridges
     subgraph_bridge_dropout_op_xpu
     subgraph_bridge_matmul_op_xpu
     subgraph_bridge_cast_op_xpu
+    subgraph_bridge_yolo_box_op_xpu
     CACHE INTERNAL "xpu_subgraph_bridges")

 message(STATUS "+++++ xpu_subgraph_bridges: ${xpu_subgraph_bridges}")
@@ -37,3 +37,4 @@ USE_SUBGRAPH_BRIDGE(gelu, kXPU);
 USE_SUBGRAPH_BRIDGE(dropout, kXPU);
 USE_SUBGRAPH_BRIDGE(matmul, kXPU);
 USE_SUBGRAPH_BRIDGE(cast, kXPU);
+USE_SUBGRAPH_BRIDGE(yolo_box, kXPU);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/xpu/bridges/graph.h"
#include "lite/kernels/xpu/bridges/utility.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace xpu {
int YoloBoxConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[XPU] Converting " + op_type + "...";
// Get input and output vars and op attributes
auto x_name = op_info->Input("X").front();
auto x = scope->FindTensor(x_name);
auto img_size_name = op_info->Input("ImgSize").front();
auto img_size = scope->FindTensor(img_size_name);
auto boxes_name = op_info->Output("Boxes").front();
auto scores_name = op_info->Output("Scores").front();
auto anchors = op_info->GetAttr<std::vector<int>>("anchors");
auto class_num = op_info->GetAttr<int>("class_num");
auto conf_thresh = op_info->GetAttr<float>("conf_thresh");
auto downsample_ratio = op_info->GetAttr<int>("downsample_ratio");
// X node
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->Add(x_name, *x);
}
// ImgSize node
std::shared_ptr<Node> img_size_node = nullptr;
if (graph->Has(img_size_name)) {
img_size_node = graph->Get(img_size_name);
} else {
img_size_node = graph->Add(img_size_name, *img_size);
}
  // YoloBox node
auto yolo_box_data =
graph->builder_.CreateYoloBox(*x_node->data(),
*img_size_node->data(),
CvtShape<xtcl::Integer>(anchors),
class_num,
conf_thresh,
downsample_ratio);
graph->Add(boxes_name, graph->builder_.GetField(yolo_box_data, 0));
graph->Add(scores_name, graph->builder_.GetField(yolo_box_data, 1));
return SUCCESS;
}
} // namespace xpu
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(yolo_box,
kXPU,
paddle::lite::subgraph::xpu::YoloBoxConverter);
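
For context, yolo_box performs the standard YOLOv3 box decode per anchor and grid cell before the outputs feed into NMS. Below is a hedged sketch of that decode under the usual formulation; the authoritative semantics live in the Paddle op definition, and XTCL's CreateYoloBox is assumed here to match it. The op additionally filters scores by conf_thresh, which this sketch omits.

#include <cmath>

// Decode one anchor's raw outputs (tx, ty, tw, th) at grid cell (row, col)
// into [xmin, ymin, xmax, ymax] in image coordinates.
void DecodeBox(float tx, float ty, float tw, float th,
               int col, int row, int grid_w, int grid_h,
               float anchor_w, float anchor_h, int downsample_ratio,
               float img_w, float img_h, float box[4]) {
  auto sigmoid = [](float v) { return 1.f / (1.f + std::exp(-v)); };
  const float input_w = static_cast<float>(downsample_ratio * grid_w);
  const float input_h = static_cast<float>(downsample_ratio * grid_h);
  const float cx = (col + sigmoid(tx)) / grid_w * img_w;  // box center x
  const float cy = (row + sigmoid(ty)) / grid_h * img_h;  // box center y
  const float w = std::exp(tw) * anchor_w / input_w * img_w;
  const float h = std::exp(th) * anchor_h / input_h * img_h;
  box[0] = cx - w / 2;
  box[1] = cy - h / 2;
  box[2] = cx + w / 2;
  box[3] = cy + h / 2;
}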
@@ -34,7 +34,7 @@ int SubgraphEngine::BuildDeviceProgram() {
   subgraph::xpu::Graph graph;
   const auto& bridges = subgraph::Registry::Instance();
   for (auto& inst : origin_program_) {
-    auto op = inst.op();
+    auto op = const_cast<OpLite*>(inst.op());
     CHECK(op);
     op->CheckShape();
     op->InferShape();

@@ -43,10 +43,8 @@ int SubgraphEngine::BuildDeviceProgram() {
       return subgraph::FAILED;
     }
     auto kernel = inst.kernel();
-    status |=
-        bridges.Select(op_type, TARGET(kXPU))(reinterpret_cast<void*>(&graph),
-                                              const_cast<OpLite*>(op),
-                                              const_cast<KernelBase*>(kernel));
+    status |= bridges.Select(op_type, TARGET(kXPU))(
+        reinterpret_cast<void*>(&graph), op, const_cast<KernelBase*>(kernel));
     if (subgraph::CHECK_FAILED(status)) {
       return subgraph::FAILED;
     }
......
@@ -141,13 +141,12 @@ add_operator(lstm_op extra SRCS lstm_op.cc DEPS ${op_DEPS})

 # 4. training op
 add_operator(mean_op extra SRCS mean_op.cc DEPS ${op_DEPS})
-if (LITE_WITH_TRAIN)
-  add_operator(mean_grad_op extra SRCS mean_grad_op.cc DEPS ${op_DEPS})
-  add_operator(activation_grad_ops basic SRCS activation_grad_ops.cc DEPS ${op_DEPS})
-  add_operator(elementwise_grad_op extra SRCS elementwise_grad_ops.cc DEPS ${op_DEPS})
-  add_operator(mul_grad_op basic SRCS mul_grad_op.cc DEPS ${op_DEPS})
-  add_operator(sgd_op extra SRCS sgd_op.cc DEPS ${op_DEPS})
-endif()
+add_operator(mean_grad_op train SRCS mean_grad_op.cc DEPS ${op_DEPS})
+add_operator(activation_grad_ops train SRCS activation_grad_ops.cc DEPS ${op_DEPS})
+add_operator(elementwise_grad_op train SRCS elementwise_grad_ops.cc DEPS ${op_DEPS})
+add_operator(mul_grad_op train SRCS mul_grad_op.cc DEPS ${op_DEPS})
+add_operator(sgd_op train SRCS sgd_op.cc DEPS ${op_DEPS})

 if (NOT LITE_WITH_X86)
   lite_cc_test(test_fc_op SRCS fc_op_test.cc
......
@@ -42,14 +42,8 @@ bool MulticlassNmsOpLite::CheckShape() const {
 }

 bool MulticlassNmsOpLite::InferShapeImpl() const {
-  auto box_dims = param_.bboxes->dims();
-  auto score_dims = param_.scores->dims();
-  auto score_size = score_dims.size();
-  if (score_size == 3) {
-    param_.out->Resize({box_dims[1], box_dims[2], 3});
-  } else {
-    param_.out->Resize({-1, box_dims[2] + 2});
-  }
+  // InferShape is useless for multiclass_nms:
+  // the output dims are not known until the computation has finished.
   return true;
 }
......
@@ -32,6 +32,7 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA AND NOT LITE_WITH_BM) AND (LITE_
     lite_cc_test(test_kernel_dropout_compute SRCS dropout_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
     lite_cc_test(test_kernel_softmax_compute SRCS softmax_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
     lite_cc_test(test_kernel_mul_compute SRCS mul_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
+    lite_cc_test(test_kernel_multiclass_nms_compute SRCS multiclass_nms_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
     lite_cc_test(test_kernel_batch_norm_compute SRCS batch_norm_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
     lite_cc_test(test_kernel_pool_compute SRCS pool_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
     lite_cc_test(test_kernel_fill_constant_compute SRCS fill_constant_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <cmath>
#include <string>
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h"
#include "lite/tests/utils/fill_data.h"
namespace paddle {
namespace lite {
template <class T>
bool SortScorePairDescend(const std::pair<float, T>& pair1,
const std::pair<float, T>& pair2) {
return pair1.first > pair2.first;
}
template <class T>
static void GetMaxScoreIndex(const std::vector<T>& scores,
const T threshold,
int top_k,
std::vector<std::pair<T, int>>* sorted_indices) {
for (size_t i = 0; i < scores.size(); ++i) {
if (scores[i] > threshold) {
sorted_indices->push_back(std::make_pair(scores[i], i));
}
}
// Sort the score pair according to the scores in descending order
std::stable_sort(sorted_indices->begin(),
sorted_indices->end(),
SortScorePairDescend<int>);
// Keep top_k scores if needed.
if (top_k > -1 && top_k < static_cast<int>(sorted_indices->size())) {
sorted_indices->resize(top_k);
}
}
template <class T>
static T BBoxArea(const T* box, const bool normalized) {
if (box[2] < box[0] || box[3] < box[1]) {
// If coordinate values are is invalid
// (e.g. xmax < xmin or ymax < ymin), return 0.
return static_cast<T>(0.);
} else {
const T w = box[2] - box[0];
const T h = box[3] - box[1];
if (normalized) {
return w * h;
} else {
// If coordinate values are not within range [0, 1].
return (w + 1) * (h + 1);
}
}
}
template <class T>
static T JaccardOverlap(const T* box1, const T* box2, const bool normalized) {
if (box2[0] > box1[2] || box2[2] < box1[0] || box2[1] > box1[3] ||
box2[3] < box1[1]) {
return static_cast<T>(0.);
} else {
const T inter_xmin = std::max(box1[0], box2[0]);
const T inter_ymin = std::max(box1[1], box2[1]);
const T inter_xmax = std::min(box1[2], box2[2]);
const T inter_ymax = std::min(box1[3], box2[3]);
T norm = normalized ? static_cast<T>(0.) : static_cast<T>(1.);
T inter_w = inter_xmax - inter_xmin + norm;
T inter_h = inter_ymax - inter_ymin + norm;
const T inter_area = inter_w * inter_h;
const T bbox1_area = BBoxArea<T>(box1, normalized);
const T bbox2_area = BBoxArea<T>(box2, normalized);
return inter_area / (bbox1_area + bbox2_area - inter_area);
}
}
template <class T>
void SliceOneClass(const Tensor& items,
const int class_id,
Tensor* one_class_item) {
T* item_data = one_class_item->mutable_data<T>();
const T* items_data = items.data<T>();
const int64_t num_item = items.dims()[0];
const int64_t class_num = items.dims()[1];
if (items.dims().size() == 3) {
int64_t item_size = items.dims()[2];
for (int i = 0; i < num_item; ++i) {
std::memcpy(item_data + i * item_size,
items_data + i * class_num * item_size + class_id * item_size,
sizeof(T) * item_size);
}
} else {
for (int i = 0; i < num_item; ++i) {
item_data[i] = items_data[i * class_num + class_id];
}
}
}
template <typename T>
void NMSFast(const Tensor& bbox,
const Tensor& scores,
const T score_threshold,
const T nms_threshold,
const T eta,
const int64_t top_k,
std::vector<int>* selected_indices,
const bool normalized) {
// The total boxes for each instance.
int64_t num_boxes = bbox.dims()[0];
// 4: [xmin ymin xmax ymax]
// 8: [x1 y1 x2 y2 x3 y3 x4 y4]
// 16, 24, or 32: [x1 y1 x2 y2 ... xn yn], n = 8, 12 or 16
int64_t box_size = bbox.dims()[1];
std::vector<T> scores_data(num_boxes);
std::copy_n(scores.data<T>(), num_boxes, scores_data.begin());
std::vector<std::pair<T, int>> sorted_indices;
GetMaxScoreIndex(scores_data, score_threshold, top_k, &sorted_indices);
selected_indices->clear();
T adaptive_threshold = nms_threshold;
const T* bbox_data = bbox.data<T>();
while (sorted_indices.size() != 0) {
const int idx = sorted_indices.front().second;
bool keep = true;
for (size_t k = 0; k < selected_indices->size(); ++k) {
if (keep) {
const int kept_idx = (*selected_indices)[k];
T overlap = T(0.);
// 4: [xmin ymin xmax ymax]
if (box_size == 4) {
overlap = JaccardOverlap<T>(bbox_data + idx * box_size,
bbox_data + kept_idx * box_size,
normalized);
} else {
LOG(FATAL) << "not support";
}
keep = overlap <= adaptive_threshold;
} else {
break;
}
}
if (keep) {
selected_indices->push_back(idx);
}
sorted_indices.erase(sorted_indices.begin());
if (keep && eta < 1 && adaptive_threshold > 0.5) {
adaptive_threshold *= eta;
}
}
}
template <typename T>
void MultiClassNMS(const Tensor& scores,
const Tensor& bboxes,
const int scores_size,
std::map<int, std::vector<int>>* indices,
int* num_nmsed_out,
int64_t background_label,
int64_t nms_top_k,
int64_t keep_top_k,
bool normalized,
T nms_threshold,
T nms_eta,
T score_threshold) {
int num_det = 0;
int64_t class_num = scores_size == 3 ? scores.dims()[0] : scores.dims()[1];
Tensor bbox_slice, score_slice;
for (int64_t c = 0; c < class_num; ++c) {
if (c == background_label) continue;
if (scores_size == 3) {
score_slice = scores.Slice<T>(c, c + 1);
bbox_slice = bboxes;
} else {
score_slice.Resize({scores.dims()[0], 1});
bbox_slice.Resize({scores.dims()[0], 4});
SliceOneClass<T>(scores, c, &score_slice);
SliceOneClass<T>(bboxes, c, &bbox_slice);
}
NMSFast(bbox_slice,
score_slice,
score_threshold,
nms_threshold,
nms_eta,
nms_top_k,
&((*indices)[c]),
normalized);
if (scores_size == 2) {
std::stable_sort((*indices)[c].begin(), (*indices)[c].end());
}
num_det += (*indices)[c].size();
}
*num_nmsed_out = num_det;
const T* scores_data = scores.data<T>();
if (keep_top_k > -1 && num_det > keep_top_k) {
const T* sdata;
std::vector<std::pair<float, std::pair<int, int>>> score_index_pairs;
for (const auto& it : *indices) {
int label = it.first;
if (scores_size == 3) {
sdata = scores_data + label * scores.dims()[1];
} else {
score_slice.Resize({scores.dims()[0], 1});
SliceOneClass<T>(scores, label, &score_slice);
sdata = score_slice.data<T>();
}
const std::vector<int>& label_indices = it.second;
for (size_t j = 0; j < label_indices.size(); ++j) {
int idx = label_indices[j];
score_index_pairs.push_back(
std::make_pair(sdata[idx], std::make_pair(label, idx)));
}
}
// Keep top k results per image.
std::stable_sort(score_index_pairs.begin(),
score_index_pairs.end(),
SortScorePairDescend<std::pair<int, int>>);
score_index_pairs.resize(keep_top_k);
// Store the new indices.
std::map<int, std::vector<int>> new_indices;
for (size_t j = 0; j < score_index_pairs.size(); ++j) {
int label = score_index_pairs[j].second.first;
int idx = score_index_pairs[j].second.second;
new_indices[label].push_back(idx);
}
if (scores_size == 2) {
for (const auto& it : new_indices) {
int label = it.first;
std::stable_sort(new_indices[label].begin(), new_indices[label].end());
}
}
new_indices.swap(*indices);
*num_nmsed_out = keep_top_k;
}
}
template <typename T>
void MultiClassOutput(const Tensor& scores,
const Tensor& bboxes,
const std::map<int, std::vector<int>>& selected_indices,
const int scores_size,
Tensor* outs,
int* oindices = nullptr,
const int offset = 0) {
int64_t class_num = scores.dims()[1];
int64_t predict_dim = scores.dims()[1];
int64_t box_size = bboxes.dims()[1];
if (scores_size == 2) {
box_size = bboxes.dims()[2];
}
int64_t out_dim = box_size + 2;
auto* scores_data = scores.data<T>();
auto* bboxes_data = bboxes.data<T>();
auto* odata = outs->mutable_data<T>();
const T* sdata;
Tensor bbox;
bbox.Resize({scores.dims()[0], box_size});
int count = 0;
for (const auto& it : selected_indices) {
int label = it.first;
const std::vector<int>& indices = it.second;
if (scores_size == 2) {
SliceOneClass<T>(bboxes, label, &bbox);
} else {
sdata = scores_data + label * predict_dim;
}
for (size_t j = 0; j < indices.size(); ++j) {
int idx = indices[j];
odata[count * out_dim] = label; // label
const T* bdata;
if (scores_size == 3) {
bdata = bboxes_data + idx * box_size;
odata[count * out_dim + 1] = sdata[idx]; // score
if (oindices != nullptr) {
oindices[count] = offset + idx;
}
} else {
bdata = bbox.data<T>() + idx * box_size;
odata[count * out_dim + 1] = *(scores_data + idx * class_num + label);
if (oindices != nullptr) {
oindices[count] = offset + idx * class_num + label;
}
}
// xmin, ymin, xmax, ymax or multi-points coordinates
std::memcpy(odata + count * out_dim + 2, bdata, box_size * sizeof(T));
count++;
}
}
}
class MulticlassNmsComputeTester : public arena::TestCase {
protected:
// common attributes for this op.
std::string type_ = "multiclass_nms";
std::string bboxes_ = "bboxes";
std::string scores_ = "scores";
std::string out_ = "out";
DDim bboxes_dims_{};
DDim scores_dims_{};
int keep_top_k_{2};
float nms_threshold_{0.45f};
float nms_eta_{1.f};
int nms_top_k_{1};
int background_label_{-1};
float score_threshold_{0.01f};
bool normalized_{false};
public:
MulticlassNmsComputeTester(const Place& place,
const std::string& alias,
DDim bboxes_dims,
DDim scores_dims,
int keep_top_k = 2,
float nms_threshold = 0.45f,
float nms_eta = 1.f,
int nms_top_k = 1,
int background_label = 1,
float score_threshold = 0.01f,
bool normalized = false)
: TestCase(place, alias),
bboxes_dims_(bboxes_dims),
scores_dims_(scores_dims),
keep_top_k_(keep_top_k),
nms_threshold_(nms_threshold),
nms_eta_(nms_eta),
nms_top_k_(nms_top_k),
background_label_(background_label),
score_threshold_(score_threshold),
normalized_(normalized) {}
void RunBaseline(Scope* scope) override {
auto* boxes = scope->FindTensor(bboxes_);
auto* scores = scope->FindTensor(scores_);
auto* outs = scope->NewTensor(out_);
CHECK(outs);
outs->set_precision(PRECISION(kFloat));
auto score_size = scores_dims_.size();
std::vector<std::map<int, std::vector<int>>> all_indices;
std::vector<uint64_t> batch_starts = {0};
int64_t batch_size = scores_dims_[0];
int64_t box_dim = bboxes_dims_[2];
int64_t out_dim = box_dim + 2;
int num_nmsed_out = 0;
Tensor boxes_slice, scores_slice;
int n = score_size == 3 ? batch_size : boxes->lod().back().size() - 1;
for (int i = 0; i < n; ++i) {
if (score_size == 3) {
scores_slice = scores->Slice<float>(i, i + 1);
scores_slice.Resize({scores_dims_[1], scores_dims_[2]});
boxes_slice = boxes->Slice<float>(i, i + 1);
boxes_slice.Resize({scores_dims_[2], box_dim});
} else {
auto boxes_lod = boxes->lod().back();
scores_slice = scores->Slice<float>(boxes_lod[i], boxes_lod[i + 1]);
boxes_slice = boxes->Slice<float>(boxes_lod[i], boxes_lod[i + 1]);
}
std::map<int, std::vector<int>> indices;
MultiClassNMS<float>(scores_slice,
boxes_slice,
score_size,
&indices,
&num_nmsed_out,
background_label_,
nms_top_k_,
keep_top_k_,
normalized_,
nms_threshold_,
nms_eta_,
score_threshold_);
all_indices.push_back(indices);
batch_starts.push_back(batch_starts.back() + num_nmsed_out);
}
uint64_t num_kept = batch_starts.back();
if (num_kept == 0) {
outs->Resize({1, 1});
float* od = outs->mutable_data<float>();
od[0] = -1;
batch_starts = {0, 1};
} else {
outs->Resize({static_cast<int64_t>(num_kept), out_dim});
outs->mutable_data<float>();
int offset = 0;
int* oindices = nullptr;
for (int i = 0; i < n; ++i) {
if (score_size == 3) {
scores_slice = scores->Slice<float>(i, i + 1);
boxes_slice = boxes->Slice<float>(i, i + 1);
scores_slice.Resize({scores_dims_[1], scores_dims_[2]});
boxes_slice.Resize({scores_dims_[2], box_dim});
} else {
auto boxes_lod = boxes->lod().back();
scores_slice = scores->Slice<float>(boxes_lod[i], boxes_lod[i + 1]);
boxes_slice = boxes->Slice<float>(boxes_lod[i], boxes_lod[i + 1]);
}
int64_t s = static_cast<int64_t>(batch_starts[i]);
int64_t e = static_cast<int64_t>(batch_starts[i + 1]);
if (e > s) {
Tensor out = outs->Slice<float>(s, e);
MultiClassOutput<float>(scores_slice,
boxes_slice,
all_indices[i],
scores_dims_.size(),
&out,
oindices,
offset);
}
}
}
LoD lod;
lod.emplace_back(batch_starts);
outs->set_lod(lod);
}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType(type_);
op_desc->SetInput("BBoxes", {bboxes_});
op_desc->SetInput("Scores", {scores_});
op_desc->SetOutput("Out", {out_});
op_desc->SetAttr("keep_top_k", keep_top_k_);
op_desc->SetAttr("nms_threshold", nms_threshold_);
op_desc->SetAttr("nms_eta", nms_eta_);
op_desc->SetAttr("nms_top_k", nms_top_k_);
op_desc->SetAttr("background_label", background_label_);
op_desc->SetAttr("score_threshold", score_threshold_);
op_desc->SetAttr("normalized", normalized_);
}
void PrepareData() override {
std::vector<float> bboxes(bboxes_dims_.production());
for (int i = 0; i < bboxes_dims_.production(); ++i) {
bboxes[i] = i * 1. / bboxes_dims_.production();
}
SetCommonTensor(bboxes_, bboxes_dims_, bboxes.data());
std::vector<float> scores(scores_dims_.production());
for (int i = 0; i < scores_dims_.production(); ++i) {
scores[i] = i * 1. / scores_dims_.production();
}
SetCommonTensor(scores_, scores_dims_, scores.data());
}
};
void TestMulticlassNms(Place place, float abs_error) {
int N = 3;
int M = 2500;
for (int class_num : {2, 4, 10}) {
std::vector<int64_t> bbox_shape{N, M, 4};
std::vector<int64_t> score_shape{N, class_num, M};
std::unique_ptr<arena::TestCase> tester(new MulticlassNmsComputeTester(
place, "def", DDim(bbox_shape), DDim(score_shape)));
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision();
}
}
TEST(multiclass_nms, precision) {
float abs_error = 2e-5;
Place place;
#if defined(LITE_WITH_ARM)
place = TARGET(kHost);
#elif defined(LITE_WITH_XPU)
place = TARGET(kXPU);
#else
return;
#endif
TestMulticlassNms(place, abs_error);
}
} // namespace lite
} // namespace paddle
@@ -228,14 +228,14 @@ class YoloBoxComputeTester : public arena::TestCase {
   }
 };

-void test_yolobox(Place place) {
-  for (int class_num : {1, 2, 3, 4}) {
-    for (float conf_thresh : {0.01, 0.2, 0.7}) {
+void TestYoloBox(Place place, float abs_error) {
+  for (int class_num : {1, 4}) {
+    for (float conf_thresh : {0.01, 0.2}) {
       for (int downsample_ratio : {16, 32}) {
-        std::vector<int> anchor({10, 13, 16, 30});
+        std::vector<int> anchor{10, 13, 16, 30, 33, 30};
         std::unique_ptr<arena::TestCase> tester(new YoloBoxComputeTester(
             place, "def", anchor, class_num, conf_thresh, downsample_ratio));
-        arena::Arena arena(std::move(tester), place, 2e-5);
+        arena::Arena arena(std::move(tester), place, abs_error);
         arena.TestPrecision();
       }
     }

@@ -243,13 +243,17 @@ void TestYoloBox(Place place, float abs_error) {
 }

 TEST(YoloBox, precision) {
-// #ifdef LITE_WITH_X86
-//   Place place(TARGET(kX86));
-// #endif
-#ifdef LITE_WITH_ARM
-  Place place(TARGET(kARM));
-  test_yolobox(place);
+  float abs_error = 2e-5;
+  Place place;
+#if defined(LITE_WITH_ARM)
+  place = TARGET(kARM);
+#elif defined(LITE_WITH_XPU)
+  place = TARGET(kXPU);
+#else
+  return;
 #endif
+  TestYoloBox(place, abs_error);
 }

 } // namespace lite
......
@@ -14,6 +14,7 @@ readonly NUM_PROC=${LITE_BUILD_THREADS:-4}

 # global variables
 BUILD_EXTRA=OFF
+BUILD_TRAIN=OFF
 BUILD_JAVA=ON
 BUILD_PYTHON=OFF
 BUILD_DIR=$(pwd)

@@ -226,6 +227,7 @@ function make_full_publish_so {
             -DNPU_DDK_ROOT=$NPU_DDK_ROOT \
             -DLITE_WITH_XPU=$BUILD_XPU \
             -DXPU_SDK_ROOT=$XPU_SDK_ROOT \
+            -DLITE_WITH_TRAIN=$BUILD_TRAIN \
             -DARM_TARGET_OS=${os} -DARM_TARGET_ARCH_ABI=${abi} -DARM_TARGET_LANG=${lang}

     make publish_inference -j$NUM_PROC

@@ -388,6 +390,7 @@ function print_usage {
     echo -e "optional argument:"
     echo -e "--shutdown_log: (OFF|ON); controls whether to shutdown log, default is ON"
     echo -e "--build_extra: (OFF|ON); controls whether to publish extra operators and kernels for (sequence-related model such as OCR or NLP)"
+    echo -e "--build_train: (OFF|ON); controls whether to publish training operators and kernels; currently only effective for the full_publish library"
     echo -e "--build_python: (OFF|ON); controls whether to publish python api lib (ANDROID and IOS is not supported)"
     echo -e "--build_java: (OFF|ON); controls whether to publish java api lib (Only ANDROID is supported)"
     echo -e "--build_dir: directory for building"

@@ -436,6 +439,10 @@ function main {
                 BUILD_EXTRA="${i#*=}"
                 shift
                 ;;
+            --build_train=*)
+                BUILD_TRAIN="${i#*=}"
+                shift
+                ;;
             --build_cv=*)
                 BUILD_CV="${i#*=}"
                 shift
                 ;;
......
@@ -5,7 +5,7 @@ set -ex

 BM_SDK_ROOT="$(pwd)/third-party/bmlibs/bm_sc3_libs"     # BM SDK
 TARGET_NAME="BM1682"     # default target
 BUILD_EXTRA=OFF                     # ON(with sequence ops)/OFF
-WITH_TESTING=ON                     # ON/OFF
+WITH_TESTING=OFF                    # ON/OFF

 function print_usage {
     echo -e "\nUSAGE:"
......