提交 b0de9835 编写于 作者: S Shixiaowei02

Merge branch 'incubate/lite' of http://10.87.145.36/inference/paddlelite into shixiaowei02/calib

...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
#include <vector> #include <vector>
// #include "paddle/fluid/lite/utils/logging.h" // #include "paddle/fluid/lite/utils/logging.h"
// #ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK // #ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
#include <glog/logging.h> #include <glog/logging.h> // NOLINT
// #endif // #endif
namespace paddle { namespace paddle {
......
...@@ -104,7 +104,7 @@ file(WRITE ${offline_lib_registry_file} "") # clean ...@@ -104,7 +104,7 @@ file(WRITE ${offline_lib_registry_file} "") # clean
# LIGHT_DEPS: LITE_WITH_LIGHT_WEIGHT_FRAMEWORK # LIGHT_DEPS: LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
# HVY_DEPS: NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK # HVY_DEPS: NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
function(lite_cc_library TARGET) function(lite_cc_library TARGET)
set(options "") set(options STATIC static SHARED shared)
set(oneValueArgs "") set(oneValueArgs "")
set(multiValueArgs SRCS DEPS X86_DEPS CUDA_DEPS ARM_DEPS PROFILE_DEPS LIGHT_DEPS set(multiValueArgs SRCS DEPS X86_DEPS CUDA_DEPS ARM_DEPS PROFILE_DEPS LIGHT_DEPS
HVY_DEPS ARGS) HVY_DEPS ARGS)
...@@ -120,8 +120,11 @@ function(lite_cc_library TARGET) ...@@ -120,8 +120,11 @@ function(lite_cc_library TARGET)
LIGHT_DEPS ${args_LIGHT_DEPS} LIGHT_DEPS ${args_LIGHT_DEPS}
HVY_DEPS ${args_HVY_DEPS} HVY_DEPS ${args_HVY_DEPS}
) )
if (${args_SHARED} OR ${args_shared})
cc_library(${TARGET} SRCS ${args_SRCS} DEPS ${deps} ${args_DEPS} SHARED)
else()
cc_library(${TARGET} SRCS ${args_SRCS} DEPS ${deps} ${args_DEPS}) cc_library(${TARGET} SRCS ${args_SRCS} DEPS ${deps} ${args_DEPS})
endif()
# collect targets need to compile for lite # collect targets need to compile for lite
add_dependencies(lite_compile_deps ${TARGET}) add_dependencies(lite_compile_deps ${TARGET})
......
...@@ -100,14 +100,12 @@ lite_cc_test(test_apis_lite SRCS apis_test.cc ...@@ -100,14 +100,12 @@ lite_cc_test(test_apis_lite SRCS apis_test.cc
ARGS --model_dir=${LITE_MODEL_DIR}/lite_naive_model ARGS --model_dir=${LITE_MODEL_DIR}/lite_naive_model
--optimized_model=${LITE_MODEL_DIR}/lite_naive_model_opt SERIAL) --optimized_model=${LITE_MODEL_DIR}/lite_naive_model_opt SERIAL)
lite_cc_library(cxx_api_impl_lite SRCS cxx_api_impl.cc DEPS cxx_api_lite) lite_cc_library(paddle_api_lite SRCS paddle_api.cc DEPS op_params_lite)
lite_cc_library(light_api_impl_lite SRCS light_api_impl.cc DEPS light_api_lite)
lite_cc_library(paddle_api_full SRCS paddle_api.cc DEPS cxx_api_impl_lite light_api_impl_lite) lite_cc_library(paddle_api_full SRCS cxx_api_impl.cc DEPS cxx_api_lite paddle_api_lite light_api_lite)
lite_cc_library(paddle_api_light SRCS paddle_api.cc DEPS light_api_impl_lite) lite_cc_library(paddle_api_light SRCS light_api_impl.cc DEPS light_api_lite paddle_api_lite)
lite_cc_test(test_paddle_api_lite SRCS paddle_api_test.cc DEPS paddle_api_full paddle_api_light
lite_cc_test(test_paddle_api_lite SRCS paddle_api_test.cc DEPS cxx_api_lite light_api_lite paddle_api_full
${ops_lite} ${ops_lite}
ARM_DEPS ${arm_kernels} ARM_DEPS ${arm_kernels}
X86_DEPS ${x86_kernels} X86_DEPS ${x86_kernels}
...@@ -120,3 +118,13 @@ endif() ...@@ -120,3 +118,13 @@ endif()
#X86_DEPS operator #X86_DEPS operator
#DEPS light_api_lite model_parser_lite target_wrapper_host mir_passes #DEPS light_api_lite model_parser_lite target_wrapper_host mir_passes
#ARM_DEPS ${arm_kernels}) #ARM_DEPS ${arm_kernels})
lite_cc_binary(cxx_api_lite_bin SRCS cxx_api_bin_int8.cc
DEPS
cxx_api_lite
model_parser_lite
target_wrapper_host
mir_passes
${ops_lite} ${host_kernels}
ARM_DEPS ${arm_kernels})
lite_cc_binary(model_optimize_tool SRCS model_optimize_tool.cc DEPS paddle_api_full)
...@@ -29,16 +29,18 @@ double time_diff(Time t1, Time t2) { ...@@ -29,16 +29,18 @@ double time_diff(Time t1, Time t2) {
return counter.count() / 1000.0; return counter.count() / 1000.0;
} }
void Run(const char* model_dir, int repeat, int thread_num) { void Run(const char* model_dir, int repeat) {
#ifdef LITE_WITH_ARM #ifdef LITE_WITH_ARM
DeviceInfo::Init(); DeviceInfo::Init();
DeviceInfo::Global().SetRunMode(LITE_POWER_HIGH, thread_num);
#endif #endif
lite::Predictor predictor; lite::Predictor predictor;
std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)}, std::vector<Place> valid_places({
Place{TARGET(kARM), PRECISION(kFloat)}}); Place{TARGET(kHost), PRECISION(kFloat)},
Place{TARGET(kARM), PRECISION(kFloat)},
Place{TARGET(kARM), PRECISION(kInt8)},
});
predictor.Build(model_dir, Place{TARGET(kARM), PRECISION(kFloat)}, predictor.Build(model_dir, Place{TARGET(kARM), PRECISION(kInt8)},
valid_places); valid_places);
auto* input_tensor = predictor.GetInput(0); auto* input_tensor = predictor.GetInput(0);
...@@ -48,8 +50,6 @@ void Run(const char* model_dir, int repeat, int thread_num) { ...@@ -48,8 +50,6 @@ void Run(const char* model_dir, int repeat, int thread_num) {
data[i] = 1; data[i] = 1;
} }
for (int i = 0; i < 10; i++) predictor.Run();
auto time1 = time(); auto time1 = time();
for (int i = 0; i < repeat; i++) predictor.Run(); for (int i = 0; i < repeat; i++) predictor.Run();
auto time2 = time(); auto time2 = time();
...@@ -68,8 +68,8 @@ void Run(const char* model_dir, int repeat, int thread_num) { ...@@ -68,8 +68,8 @@ void Run(const char* model_dir, int repeat, int thread_num) {
} // namespace paddle } // namespace paddle
int main(int argc, char** argv) { int main(int argc, char** argv) {
CHECK_EQ(argc, 4) << "usage: ./cmd <model_dir> <repeat> <thread_num>"; CHECK_EQ(argc, 3) << "usage: ./cmd <model_dir> <repeat>";
paddle::lite::Run(argv[1], std::stoi(argv[2]), std::stoi(argv[3])); paddle::lite::Run(argv[1], std::stoi(argv[2]));
return 0; return 0;
} }
...@@ -93,13 +93,18 @@ USE_LITE_OP(fake_dequantize_max_abs); ...@@ -93,13 +93,18 @@ USE_LITE_OP(fake_dequantize_max_abs);
USE_LITE_KERNEL(feed, kHost, kAny, kAny, def); USE_LITE_KERNEL(feed, kHost, kAny, kAny, def);
USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def); USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def);
USE_LITE_OP(calib);
#ifdef LITE_WITH_ARM #ifdef LITE_WITH_ARM
USE_LITE_KERNEL(fc, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(fc, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(fc, kARM, kInt8, kNCHW, int8out);
USE_LITE_KERNEL(fc, kARM, kInt8, kNCHW, fp32out);
USE_LITE_KERNEL(mul, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(mul, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(scale, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(scale, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(conv2d, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(conv2d, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(conv2d, kARM, kInt8, kNCHW, int8_out);
USE_LITE_KERNEL(conv2d, kARM, kInt8, kNCHW, fp32_out);
USE_LITE_KERNEL(batch_norm, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(batch_norm, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(relu, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(relu, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(depthwise_conv2d, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(depthwise_conv2d, kARM, kFloat, kNCHW, def);
...@@ -107,6 +112,9 @@ USE_LITE_KERNEL(pool2d, kARM, kFloat, kNCHW, def); ...@@ -107,6 +112,9 @@ USE_LITE_KERNEL(pool2d, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(elementwise_add, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(elementwise_add, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(softmax, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(softmax, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(calib, kARM, kInt8, kNCHW, fp32_to_int8);
USE_LITE_KERNEL(calib, kARM, kInt8, kNCHW, int8_to_fp32);
// USE_LITE_KERNEL(feed, kARM, kAny, kAny, def); // USE_LITE_KERNEL(feed, kARM, kAny, kAny, def);
// USE_LITE_KERNEL(fetch, kARM, kAny, kAny, def); // USE_LITE_KERNEL(fetch, kARM, kAny, kAny, def);
#endif // LITE_WITH_ARM #endif // LITE_WITH_ARM
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/api/cxx_api.h"
#include <chrono> // NOLINT
#include "paddle/fluid/lite/api/paddle_use_kernels.h"
#include "paddle/fluid/lite/api/paddle_use_ops.h"
#include "paddle/fluid/lite/api/paddle_use_passes.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
using Time = decltype(std::chrono::high_resolution_clock::now());
Time time() { return std::chrono::high_resolution_clock::now(); }
double time_diff(Time t1, Time t2) {
typedef std::chrono::microseconds ms;
auto diff = t2 - t1;
ms counter = std::chrono::duration_cast<ms>(diff);
return counter.count() / 1000.0;
}
void Run(const char* model_dir, int repeat) {
#ifdef LITE_WITH_ARM
DeviceInfo::Init();
#endif
lite::Predictor predictor;
std::vector<Place> valid_places({
Place{TARGET(kHost), PRECISION(kFloat)},
Place{TARGET(kARM), PRECISION(kFloat)},
Place{TARGET(kARM), PRECISION(kInt8)},
});
predictor.Build(model_dir, Place{TARGET(kARM), PRECISION(kInt8)},
valid_places);
auto* input_tensor = predictor.GetInput(0);
input_tensor->Resize(DDim(std::vector<DDim::value_type>({1, 3, 224, 224})));
auto* data = input_tensor->mutable_data<float>();
for (int i = 0; i < input_tensor->dims().production(); i++) {
data[i] = 1;
}
auto time1 = time();
for (int i = 0; i < repeat; i++) predictor.Run();
auto time2 = time();
std::cout << " predict cost: " << time_diff(time1, time2) / repeat << "ms"
<< std::endl;
auto* out = predictor.GetOutput(0);
LOG(INFO) << out << " memory size " << out->data_size();
LOG(INFO) << "out " << out->data<float>()[0];
LOG(INFO) << "out " << out->data<float>()[1];
LOG(INFO) << "dims " << out->dims();
LOG(INFO) << "out data size: " << out->data_size();
}
} // namespace lite
} // namespace paddle
int main(int argc, char** argv) {
CHECK_EQ(argc, 3) << "usage: ./cmd <model_dir> <repeat>";
paddle::lite::Run(argv[1], std::stoi(argv[2]));
return 0;
}
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include <glog/logging.h>
#include <gtest/gtest.h>
#include "paddle/fluid/lite/api/paddle_api.h"
#include "paddle/fluid/lite/api/paddle_use_kernels.h"
#include "paddle/fluid/lite/api/paddle_use_ops.h"
#include "paddle/fluid/lite/api/paddle_use_passes.h"
#include "paddle/fluid/lite/utils/string.h"
DEFINE_string(model_dir, "", "path of the model");
DEFINE_string(optimize_out, "", "path of the output optimized model");
DEFINE_string(valid_targets, "ARM",
"The targets this model optimized for, should be one of (arm, "
"opencl, x86), splitted by space");
DEFINE_bool(int8_mode, false, "Support Int8 quantitative mode");
namespace paddle {
namespace lite_api {
void Main() {
lite_api::CxxConfig config;
config.set_model_dir(FLAGS_model_dir);
std::vector<Place> valid_places;
auto target_reprs = lite::Split(FLAGS_valid_targets, " ");
for (auto& target_repr : target_reprs) {
if (target_repr == "arm") {
valid_places.emplace_back(TARGET(kARM));
} else if (target_repr == "opencl") {
valid_places.emplace_back(TARGET(kOpenCL));
} else if (target_repr == "x86") {
valid_places.emplace_back(TARGET(kX86));
} else {
LOG(FATAL) << lite::string_format(
"Wrong target '%s' found, please check the command flag "
"'valid_targets'",
target_repr.c_str());
}
}
CHECK(!valid_places.empty())
<< "At least one target should be set, should set the "
"command argument 'valid_targets'";
if (FLAGS_int8_mode) {
LOG(WARNING) << "Int8 mode is only support by ARM target";
valid_places.push_back(Place{TARGET(kARM), PRECISION(kInt8)});
config.set_preferred_place(Place{TARGET(kARM), PRECISION(kInt8)});
}
config.set_valid_places(valid_places);
auto predictor = lite_api::CreatePaddlePredictor(config);
predictor->SaveOptimizedModel(FLAGS_optimize_out);
}
} // namespace lite_api
} // namespace paddle
int main(int argc, char** argv) {
google::ParseCommandLineFlags(&argc, &argv, false);
paddle::lite_api::Main();
return 0;
}
...@@ -56,6 +56,7 @@ TEST(CxxApi, run) { ...@@ -56,6 +56,7 @@ TEST(CxxApi, run) {
predictor->SaveOptimizedModel(FLAGS_model_dir + ".opt2"); predictor->SaveOptimizedModel(FLAGS_model_dir + ".opt2");
} }
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
TEST(LightApi, run) { TEST(LightApi, run) {
lite_api::MobileConfig config; lite_api::MobileConfig config;
config.set_model_dir(FLAGS_model_dir + ".opt2"); config.set_model_dir(FLAGS_model_dir + ".opt2");
...@@ -79,6 +80,7 @@ TEST(LightApi, run) { ...@@ -79,6 +80,7 @@ TEST(LightApi, run) {
EXPECT_NEAR(out[0], 50.2132, 1e-3); EXPECT_NEAR(out[0], 50.2132, 1e-3);
EXPECT_NEAR(out[1], -28.8729, 1e-3); EXPECT_NEAR(out[1], -28.8729, 1e-3);
} }
#endif
} // namespace lite_api } // namespace lite_api
} // namespace paddle } // namespace paddle
...@@ -83,7 +83,7 @@ struct Place { ...@@ -83,7 +83,7 @@ struct Place {
int16_t device{0}; // device ID int16_t device{0}; // device ID
Place() = default; Place() = default;
Place(TargetType target, PrecisionType precision, Place(TargetType target, PrecisionType precision = PRECISION(kFloat),
DataLayoutType layout = DATALAYOUT(kNCHW), int16_t device = 0) DataLayoutType layout = DATALAYOUT(kNCHW), int16_t device = 0)
: target(target), precision(precision), layout(layout), device(device) {} : target(target), precision(precision), layout(layout), device(device) {}
......
...@@ -38,6 +38,13 @@ USE_LITE_KERNEL(relu, kARM, kFloat, kNCHW, def); ...@@ -38,6 +38,13 @@ USE_LITE_KERNEL(relu, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(transpose, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(transpose, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(transpose2, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(transpose2, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(batch_norm, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(batch_norm, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(calib, kARM, kInt8, kNCHW, fp32_to_int8);
USE_LITE_KERNEL(calib, kARM, kInt8, kNCHW, int8_to_fp32);
USE_LITE_KERNEL(conv2d, kARM, kInt8, kNCHW, int8_out);
USE_LITE_KERNEL(conv2d, kARM, kInt8, kNCHW, fp32_out);
USE_LITE_KERNEL(fc, kARM, kInt8, kNCHW, int8out);
USE_LITE_KERNEL(fc, kARM, kInt8, kNCHW, fp32out);
#endif #endif
#ifdef LITE_WITH_X86 #ifdef LITE_WITH_X86
......
...@@ -38,3 +38,7 @@ USE_LITE_OP(batch_norm) ...@@ -38,3 +38,7 @@ USE_LITE_OP(batch_norm)
USE_LITE_OP(fusion_elementwise_sub_activation) USE_LITE_OP(fusion_elementwise_sub_activation)
USE_LITE_OP(transpose) USE_LITE_OP(transpose)
USE_LITE_OP(transpose2) USE_LITE_OP(transpose2)
USE_LITE_OP(fake_quantize_moving_average_abs_max);
USE_LITE_OP(fake_dequantize_max_abs);
USE_LITE_OP(calib);
...@@ -31,3 +31,5 @@ USE_MIR_PASS(identity_scale_eliminate_pass); ...@@ -31,3 +31,5 @@ USE_MIR_PASS(identity_scale_eliminate_pass);
USE_MIR_PASS(lite_conv_elementwise_add_activation_fuse_pass); USE_MIR_PASS(lite_conv_elementwise_add_activation_fuse_pass);
USE_MIR_PASS(lite_elementwise_add_activation_fuse_pass); USE_MIR_PASS(lite_elementwise_add_activation_fuse_pass);
USE_MIR_PASS(lite_quant_dequant_fuse_pass); USE_MIR_PASS(lite_quant_dequant_fuse_pass);
USE_MIR_PASS(precision_cast_transform_pass);
USE_MIR_PASS(trans_weight_pass);
...@@ -31,7 +31,7 @@ cc_library(types_lite SRCS types.cc) ...@@ -31,7 +31,7 @@ cc_library(types_lite SRCS types.cc)
cc_library(type_system SRCS type_system.cc DEPS ${tensor_lite} target_wrapper_lite) cc_library(type_system SRCS type_system.cc DEPS ${tensor_lite} target_wrapper_lite)
lite_cc_library(program_lite SRCS program.cc lite_cc_library(program_lite SRCS program.cc
DEPS op_lite kernel_lite compatible_pb_lite model_parser_lite DEPS op_lite kernel_lite compatible_pb_lite model_parser_lite ${ops_lite}
HVY_DEPS framework_proto HVY_DEPS framework_proto
PROFILE_DEPS basic_profiler_lite) PROFILE_DEPS basic_profiler_lite)
cc_library(optimizer_lite SRCS optimizer.cc DEPS mir_pass_manager model_parser_lite program_lite) cc_library(optimizer_lite SRCS optimizer.cc DEPS mir_pass_manager model_parser_lite program_lite)
......
...@@ -18,10 +18,12 @@ cc_library(mir_passes ...@@ -18,10 +18,12 @@ cc_library(mir_passes
static_kernel_pick_pass.cc static_kernel_pick_pass.cc
variable_place_inference_pass.cc variable_place_inference_pass.cc
type_target_transform_pass.cc type_target_transform_pass.cc
precision_cast_transform_pass.cc
io_copy_kernel_pick_pass.cc io_copy_kernel_pick_pass.cc
graph_visualize_pass.cc graph_visualize_pass.cc
generate_program_pass.cc generate_program_pass.cc
argument_type_display_pass.cc argument_type_display_pass.cc
trans_weigths_pass.cc
demo_pass.cc demo_pass.cc
runtime_context_assign_pass.cc runtime_context_assign_pass.cc
DEPS mir_pass types_lite context_lite ${mir_fusers}) DEPS mir_pass types_lite context_lite ${mir_fusers})
......
...@@ -60,7 +60,7 @@ void FcFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) { ...@@ -60,7 +60,7 @@ void FcFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
} }
cpp::OpDesc FcFuser::GenOpDesc(const key2nodes_t& matched) { cpp::OpDesc FcFuser::GenOpDesc(const key2nodes_t& matched) {
cpp::OpDesc op_desc; cpp::OpDesc op_desc = *matched.at("mul")->stmt()->op_info();
op_desc.SetType("fc"); op_desc.SetType("fc");
op_desc.SetInput("Input", {matched.at("x")->arg()->name}); op_desc.SetInput("Input", {matched.at("x")->arg()->name});
op_desc.SetInput("W", {matched.at("W")->arg()->name}); op_desc.SetInput("W", {matched.at("W")->arg()->name});
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/precision_cast_transform_pass.h"
#include <list>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
void PrecisionCastPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
// Start from inputs of the graph, those should have place set.
std::list<Node*> nodes;
for (auto& node : graph->mutable_nodes()) {
nodes.push_back(&node);
}
for (auto& node : nodes) {
if (!node->IsStmt()) continue;
auto inlinks = node->inlinks;
for (auto* in : inlinks) {
ComplementInputs(graph.get(), node, in);
}
}
VLOG(3) << "\n" << Visualize(graph.get());
}
void PrecisionCastPass::ComplementInputs(SSAGraph* graph, Node* inst_node,
Node* in) {
// If this input is out of date.
if (inst_node->inlinks.end() ==
std::find(inst_node->inlinks.begin(), inst_node->inlinks.end(), in))
return;
CHECK(inst_node->IsStmt());
auto& inst = inst_node->AsStmt();
CHECK(in->IsRoleSet());
CHECK(in->IsArg());
auto in_arg_name = in->AsArg().name;
std::string tmp;
CHECK(inst.op_info()->GetInputArgname(in_arg_name, &tmp));
auto decl_arg_type = inst.picked_kernel().GetInputDeclType(tmp);
CHECK(in->AsArg().type);
LOG(INFO) << inst.picked_kernel().name();
// if (!in->AsArg().is_weight && !PrecisionCompatibleTo(*in->AsArg().type,
// *decl_arg_type)) {
if (!PrecisionCompatibleTo(*in->AsArg().type, *decl_arg_type)) {
LOG(INFO) << "found Target unmatched tensor: " << in->AsArg().name
<< " for kernel " << inst.op()->DebugString() << " "
<< *in->AsArg().type << " -> " << *decl_arg_type;
// Add an Cast instruction to make the input compatible with other dist.
AddCastInst(*in->AsArg().type, *decl_arg_type, in, graph, inst_node,
graph->valid_places());
}
}
void PrecisionCastPass::AddCastInst(const Type& from, const Type& to, Node* in,
SSAGraph* graph, Node* inst_node,
const std::vector<Place>& valid_places) {
CHECK(!valid_places.empty()) << "valid_place should be set";
// var -> new_transform_op -> new_var -> inst
// So there will be a new Argument node and a new Cast Statement Node.
CHECK(in->IsArg());
auto node_id = [&] { return graph->nodes().size(); };
auto cast_op_output_name =
in->AsArg().name + "/trans/" + std::to_string(node_id());
auto* cast_op_output_arg = graph->NewArgumentNode(cast_op_output_name);
auto* cast_inst = graph->NewInstructNode();
// create Op and kernels.
auto cast_op = LiteOpRegistry::Global().Create("calib");
CHECK(cast_op) << "create op [" << cast_op << "] failed";
// Create the new var manually.
inst_node->AsStmt().op()->scope()->Var(cast_op_output_name);
// Create Calib Instruction.
cpp::OpDesc op_desc;
op_desc.SetType("calib");
op_desc.SetInput("Input", {in->AsArg().name});
op_desc.SetOutput("Out", {cast_op_output_name});
CHECK(inst_node->AsStmt().op_info()->HasAttr("input_scale"));
op_desc.SetAttr("scale",
inst_node->AsStmt().op_info()->GetAttr<float>("input_scale"));
cast_op->Attach(op_desc, inst_node->AsStmt().op()->scope());
auto kernels = cast_op->CreateKernels(valid_places);
std::vector<std::unique_ptr<KernelBase>> selected_kernels;
bool is_found = false;
for (auto& kernel : kernels) {
const Type* in_arg_ty = kernel->GetInputDeclType("Input");
const Type* out_arg_ty = kernel->GetOutputDeclType("Out");
if (in_arg_ty->precision() == from.precision() &&
out_arg_ty->precision() == to.precision()) {
is_found = true;
selected_kernels.emplace_back(std::move(kernel));
// we pick the kernel
cast_inst->AsStmt("calib", std::move(selected_kernels), cast_op);
break;
}
}
CHECK(is_found) << "Can't find a Cast kernel for Cast op: " << from << ":"
<< in->AsArg().name << "->" << to << ":"
<< inst_node->AsStmt().op_info()->Type();
// Remove the old link
RemoveDirectedLink(in, inst_node);
// Update the original instruction OpDesc.
// Update its input to the io_copy_output_name
// Add new link, var -> new_inst, new_inst->newarg, newarg->inst
DirectedLink(in, cast_inst);
DirectedLink(cast_inst, cast_op_output_arg);
DirectedLink(cast_op_output_arg, inst_node);
// reset opdesc and update kernel information
UpdateInputTo(inst_node->AsStmt().op()->mutable_op_info(), in->AsArg().name,
cast_op_output_name);
// recreate the op
auto original_selected_kernel =
std::move(inst_node->AsStmt().kernels().front());
auto updated_op_info = *inst_node->AsStmt().mutable_op_info();
inst_node->AsStmt().ResetOp(updated_op_info, graph->valid_places());
inst_node->AsStmt().kernels().clear();
inst_node->AsStmt().kernels().emplace_back(
std::move(original_selected_kernel));
for (auto& kernel : inst_node->AsStmt().kernels()) {
LOG(INFO) << "kernel info: " << kernel->name();
inst_node->AsStmt().op()->AttachKernel(kernel.get());
}
graph->CheckValid();
}
void PrecisionCastPass::SetValidPlaces(const std::vector<Place>& valid_places) {
CHECK(!valid_places.empty());
valid_places_ = valid_places;
}
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(precision_cast_transform_pass,
paddle::lite::mir::PrecisionCastPass);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/mir/pass.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace mir {
static void UpdateInputTo(cpp::OpDesc* desc, const std::string& from,
const std::string& to) {
for (auto& item : *desc->mutable_inputs()) {
for (auto& input : item.second) {
if (input == from) {
input = to;
}
}
}
}
/*
* The pass complement the necessary instruction to make data
* transferring or transformation between different places.
*/
class PrecisionCastPass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
void ComplementInputs(SSAGraph* graph, Node* inst_node, Node* in);
void AddCastInst(const Type& from, const Type& to, Node* in, SSAGraph* graph,
Node* inst_node, const std::vector<Place>& valid_places);
void SetValidPlaces(const std::vector<Place>& valid_places);
const std::vector<Place>& valid_places() const { return valid_places_; }
private:
std::vector<Place> valid_places_;
};
} // namespace mir
} // namespace lite
} // namespace paddle
...@@ -33,9 +33,12 @@ void StaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) { ...@@ -33,9 +33,12 @@ void StaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
<< "kernel_pick_factors should be specified first"; << "kernel_pick_factors should be specified first";
CHECK(graph) << "graph not valid"; CHECK(graph) << "graph not valid";
// sort kernels by the factors. // sort kernels by the factors.
for (auto& node : graph->mutable_nodes()) { for (auto& node : graph->mutable_nodes()) {
if (!node.IsStmt()) continue; if (!node.IsStmt()) continue;
auto& instruct = node.AsStmt(); auto& instruct = node.AsStmt();
// Get candidate kernels
std::vector<std::pair<size_t, std::unique_ptr<KernelBase>>> scored; std::vector<std::pair<size_t, std::unique_ptr<KernelBase>>> scored;
CHECK(!instruct.kernels().empty()) << "No kernels found for " CHECK(!instruct.kernels().empty()) << "No kernels found for "
<< instruct.op_type(); << instruct.op_type();
...@@ -43,15 +46,56 @@ void StaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) { ...@@ -43,15 +46,56 @@ void StaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
size_t score = KernelGrade(*kernel); size_t score = KernelGrade(*kernel);
scored.emplace_back(score, std::move(kernel)); scored.emplace_back(score, std::move(kernel));
} }
std::sort(scored.begin(), scored.end(), KernelScoreCmp); std::sort(scored.begin(), scored.end(), KernelScoreCmp);
instruct.kernels().clear();
if (!instruct.op_info()->HasAttr("enable_int8")) {
// Move kernel back // Move kernel back
// Just keep a single best kernel. // Just keep a single best kernel.
// TODO(Superjomn) reconsider this. // TODO(Superjomn) reconsider this.
instruct.kernels().clear();
instruct.kernels().emplace_back(std::move(scored.front().second)); instruct.kernels().emplace_back(std::move(scored.front().second));
VLOG(2) << "pick " << instruct.kernels().front()->name(); VLOG(2) << "pick " << instruct.kernels().front()->name();
} else {
bool out_type_int8 = true;
// Only if all ops linked to this op output has enable_int8 attr,
// then the op output type is int8, or fp32.
for (auto* out_n : node.outlinks) {
CHECK(out_n->IsArg());
for (auto* tmp_op : out_n->outlinks) {
CHECK(tmp_op->IsStmt());
if (!tmp_op->AsStmt().op_info()->HasAttr("enable_int8")) {
out_type_int8 = false;
break;
}
}
if (!out_type_int8) break;
}
// According to the out type, we pick the kernel.
auto output_arguments = instruct.op_info()->OutputArgumentNames();
for (auto& candidate : scored) {
bool all_output_type_match = true;
auto expect_output_type =
out_type_int8 ? PRECISION(kInt8) : PRECISION(kFloat);
for (auto& arg_name : output_arguments) {
const Type* out_arg_ty =
candidate.second->GetOutputDeclType(arg_name);
if (out_arg_ty->precision() != expect_output_type) {
all_output_type_match = false;
}
}
if (all_output_type_match) {
instruct.kernels().emplace_back(std::move(candidate.second));
VLOG(2) << "pick " << instruct.kernels().front()->name();
break;
}
}
CHECK(!instruct.kernels().empty()) << "No kernels found for "
<< instruct.op_type();
}
} }
} }
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/trans_weigths_pass.h"
#include <list>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
void TransWeightPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
// Start from inputs of the graph, those should have place set.
std::list<Node*> nodes;
for (auto& node : graph->mutable_nodes()) {
nodes.push_back(&node);
}
for (auto& node : nodes) {
if (!node->IsStmt()) continue;
auto& instruct = node->AsStmt();
if (!instruct.op_info()->HasAttr("enable_int8")) {
continue;
}
std::vector<std::string> output_arg_names =
instruct.op_info()->output_argnames();
CHECK(output_arg_names.size() == 1)
<< "Currently, the op that supports int8 supports only one output";
// After static kernel select pass, there is only one kernel here.
const Type* out_arg_ty =
instruct.kernels()[0]->GetOutputDeclType(output_arg_names[0]);
auto out_precision = out_arg_ty->precision();
bool out_type_int8 = out_precision == PRECISION(kInt8) ? true : false;
float in_scale, out_scale;
in_scale = instruct.op_info()->GetAttr<float>("input_scale");
// Get next input op's input_scale
if (out_type_int8) {
LOG(INFO) << "output_type_int8";
auto out_node = node->outlinks.front();
CHECK(out_node->IsArg());
auto one_adj_op_node = out_node->outlinks.front();
CHECK(one_adj_op_node->IsStmt());
auto& one_adj_instruct = one_adj_op_node->AsStmt();
CHECK(one_adj_instruct.op_info()->HasAttr("enable_int8"));
CHECK(one_adj_instruct.op_info()->HasAttr("input_scale"));
out_scale = one_adj_instruct.op_info()->GetAttr<float>("input_scale");
instruct.mutable_op_info()->SetAttr("output_scale", out_scale);
} else {
LOG(INFO) << "output_type_fp32";
}
std::string op_type = instruct.op_info()->Type();
std::vector<float> weight_scale;
auto* scope = instruct.op()->scope();
if (op_type == "depthwise_conv2d" || op_type == "conv2d") {
std::string weight_var_name = instruct.op_info()->Input("Filter").front();
auto conv_weight_t =
scope->FindVar(weight_var_name)->GetMutable<lite::Tensor>();
// till now, all the weight should be float32 type
float* conv_weight_d = conv_weight_t->mutable_data<float>();
int64_t axis_size = conv_weight_t->dims()[0];
int64_t inner_size = conv_weight_t->data_size() / axis_size;
weight_scale =
GetWeightScale(conv_weight_d, axis_size, inner_size, 127.0);
Tensor temp_tensor;
temp_tensor.Resize(conv_weight_t->dims());
int8_t* temp_data = temp_tensor.mutable_data<int8_t>();
FP32ToInt8(conv_weight_d, temp_data, weight_scale.data(), axis_size, 1,
inner_size);
conv_weight_t->CopyDataFrom(temp_tensor);
} else if (op_type == "fc" || op_type == "mul") {
std::string weight_arg_name = "W";
if (op_type == "mul") weight_arg_name = "Y";
std::string weight_var_name =
instruct.op_info()->Input(weight_arg_name).front();
auto fc_weight_t =
scope->FindVar(weight_var_name)->GetMutable<lite::Tensor>();
// till now, all the weight should be float32 type
float* fc_weight_d = fc_weight_t->mutable_data<float>();
CHECK_EQ(fc_weight_t->dims().size(), 2UL);
int64_t h = fc_weight_t->dims()[0];
int64_t w = fc_weight_t->data_size() / h;
Tensor trans_w_t, int8_temp_t;
trans_w_t.CopyDataFrom(*fc_weight_t);
float* trans_w_data = trans_w_t.mutable_data<float>();
int8_temp_t.Resize(fc_weight_t->dims());
int8_t* int8_temp_data = int8_temp_t.mutable_data<int8_t>();
// trans weight for calc the weight scale.
for (int i = 0; i < h; i++) {
for (int j = 0; j < w; j++) {
trans_w_data[i * w + j] = fc_weight_d[j * h + i];
}
}
weight_scale = GetWeightScale(trans_w_data, w, h, 127.0);
int8_t* fc_weight_int8_d = fc_weight_t->mutable_data<int8_t>();
FP32ToInt8(trans_w_data, int8_temp_data, weight_scale.data(), w, 1, h);
// Retrans back
for (int i = 0; i < w; i++) {
for (int j = 0; j < h; j++) {
fc_weight_int8_d[i * h + j] = int8_temp_data[j * w + i];
}
}
}
// Convert fp32 bias to int8 bias
std::vector<std::string> input_arg_names =
instruct.op_info()->InputArgumentNames();
if (std::find(input_arg_names.begin(), input_arg_names.end(), "Bias") !=
input_arg_names.end() &&
instruct.op_info()->Input("Bias").size() > 0) {
std::string bias_var_name = instruct.op_info()->Input("Bias").front();
auto bias_weight_t =
scope->FindVar(bias_var_name)->GetMutable<lite::Tensor>();
float* bias_weight_d = bias_weight_t->mutable_data<float>();
Tensor temp_bias;
temp_bias.Resize(bias_weight_t->dims());
int* temp_bias_data = temp_bias.mutable_data<int>();
TransFP32BiasToInt32(bias_weight_d, temp_bias_data, temp_bias.data_size(),
in_scale, weight_scale);
bias_weight_t->CopyDataFrom(temp_bias);
}
instruct.mutable_op_info()->SetAttr("weight_scale", weight_scale);
auto original_selected_kernel = std::move(instruct.kernels().front());
auto updated_op_info = *instruct.mutable_op_info();
instruct.ResetOp(updated_op_info, graph->valid_places());
instruct.kernels().clear();
instruct.kernels().emplace_back(std::move(original_selected_kernel));
for (auto& kernel : instruct.kernels()) {
LOG(INFO) << "kernel info: " << kernel->name();
instruct.op()->AttachKernel(kernel.get());
}
}
}
void TransWeightPass::SetValidPlaces(const std::vector<Place>& valid_places) {
CHECK(!valid_places.empty());
valid_places_ = valid_places;
}
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(trans_weight_pass, paddle::lite::mir::TransWeightPass);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cmath>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/lite/arm/math/saturate.h"
#include "paddle/fluid/lite/core/mir/pass.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace mir {
/*
* IoComplementPass complement the necessary instruction to make data
* transferring or transformation between different places.
*/
class TransWeightPass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
std::vector<float> GetWeightScale(float* in_data, int64_t axis_size,
int64_t inner_size, float scale_factor) {
std::vector<float> scale_out(axis_size);
auto calc_abs_max = [&](float* in, size_t data_size) -> float {
float max_data = 0.0;
for (size_t i = 0; i < data_size; i++) {
if (max_data < std::abs(in[i])) max_data = std::abs(in[i]);
}
return max_data;
};
for (int c = 0; c < axis_size; c++) {
float* part_in = in_data + c * inner_size;
scale_out[c] = calc_abs_max(part_in, inner_size) / scale_factor;
}
return scale_out;
}
void FP32ToInt8(const float* din, int8_t* dout, const float* scale,
int axis_size, int64_t outer_size, int64_t inner_size) {
int loop_size = axis_size * outer_size;
for (int i = 0; i < loop_size; ++i) {
float inv_scale = 1.f / scale[i % axis_size];
for (int j = 0; j < inner_size; ++j) {
dout[j] = static_cast<int8_t>(std::roundf(din[j] * inv_scale));
}
dout += inner_size;
din += inner_size;
}
}
void TransFP32BiasToInt32(const float* din, int* dout, size_t data_size,
float in_scale, std::vector<float> weight_scale) {
CHECK(data_size == weight_scale.size())
<< "Bias data size should be equal toe the weight scale data size.";
for (size_t i = 0; i < data_size; i++) {
dout[i] =
static_cast<int>(std::roundf(din[i] / in_scale / weight_scale[i]));
}
}
void SetValidPlaces(const std::vector<Place>& valid_places);
const std::vector<Place>& valid_places() const { return valid_places_; }
private:
std::vector<Place> valid_places_;
};
} // namespace mir
} // namespace lite
} // namespace paddle
...@@ -49,8 +49,8 @@ class Optimizer { ...@@ -49,8 +49,8 @@ class Optimizer {
InitTargetTypeTransformPass(); InitTargetTypeTransformPass();
if (passes.empty()) { if (passes.empty()) {
RunPasses(std::vector<std::string>{{ RunPasses(std::vector<std::string>{
"lite_quant_dequant_fuse_pass", // {"lite_quant_dequant_fuse_pass", //
"lite_conv_bn_fuse_pass", // "lite_conv_bn_fuse_pass", //
// This pass is disabled to force some opencl kernels selected for final // This pass is disabled to force some opencl kernels selected for final
// running, otherwise, they will be fused to ARM fusion kernels, and the OpenCL // running, otherwise, they will be fused to ARM fusion kernels, and the OpenCL
...@@ -75,8 +75,11 @@ class Optimizer { ...@@ -75,8 +75,11 @@ class Optimizer {
"argument_type_display_pass", // "argument_type_display_pass", //
"io_copy_kernel_pick_pass", // "io_copy_kernel_pick_pass", //
"variable_place_inference_pass", // "variable_place_inference_pass", //
"precision_cast_transform_pass", //
"argument_type_display_pass", //
"trans_weight_pass", //
"runtime_context_assign_pass", // "runtime_context_assign_pass", //
}}); "graph_visualze"}});
} else { } else {
RunPasses(passes); RunPasses(passes);
} }
...@@ -134,7 +137,7 @@ class Optimizer { ...@@ -134,7 +137,7 @@ class Optimizer {
for (auto& x : passes) { for (auto& x : passes) {
LOG(INFO) << "== Running pass " << x; LOG(INFO) << "== Running pass " << x;
auto* pass = mir::PassManager::Global().LookUp(x); auto* pass = mir::PassManager::Global().LookUp(x);
CHECK(pass); CHECK(pass) << "Can not find pass: " << x;
pass->Apply(graph_); pass->Apply(graph_);
} }
} }
......
...@@ -26,3 +26,5 @@ if (NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) ...@@ -26,3 +26,5 @@ if (NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
add_dependencies(__generated_code__ test_gen_code_lite) add_dependencies(__generated_code__ test_gen_code_lite)
add_dependencies(__generated_code__ extern_lite_download_lite_naive_model_tar_gz) add_dependencies(__generated_code__ extern_lite_download_lite_naive_model_tar_gz)
endif() endif()
lite_cc_binary(paddle_code_generator SRCS paddle_code_generator.cc DEPS model_parser_lite gen_code_lite)
...@@ -111,6 +111,15 @@ void Module::AddOpDescHelper(const std::string &op_id, ...@@ -111,6 +111,15 @@ void Module::AddOpDescHelper(const std::string &op_id,
return std::to_string(desc.GetAttr<bool>(name)); return std::to_string(desc.GetAttr<bool>(name));
case AttrType::STRING: case AttrType::STRING:
return "\"" + desc.GetAttr<std::string>(name) + "\""; return "\"" + desc.GetAttr<std::string>(name) + "\"";
case AttrType::FLOATS: {
auto vals = desc.GetAttr<std::vector<float>>(name);
return "{" + Join(vals, ",") + "}";
}
case AttrType::INTS: {
auto vals = desc.GetAttr<std::vector<int>>(name);
return "{" + Join(vals, ",") + "}";
}
case AttrType::STRINGS: { case AttrType::STRINGS: {
std::vector<std::string> tmp; std::vector<std::string> tmp;
auto vals = desc.GetAttr<std::vector<std::string>>(name); auto vals = desc.GetAttr<std::vector<std::string>>(name);
...@@ -137,8 +146,12 @@ void Module::AddOpDescHelper(const std::string &op_id, ...@@ -137,8 +146,12 @@ void Module::AddOpDescHelper(const std::string &op_id,
return "bool"; return "bool";
case AttrType::STRING: case AttrType::STRING:
return "std::string"; return "std::string";
case AttrType::FLOATS:
return "std::vector<float>";
case AttrType::STRINGS: case AttrType::STRINGS:
return "std::vector<std::string>"; return "std::vector<std::string>";
case AttrType::INTS:
return "std::vector<int>";
default: default:
LOG(FATAL) << "Unsupported attribute type: " << static_cast<int>(type); LOG(FATAL) << "Unsupported attribute type: " << static_cast<int>(type);
} }
...@@ -160,6 +173,8 @@ void Module::AddOp(const cpp::OpDesc &op) { ...@@ -160,6 +173,8 @@ void Module::AddOp(const cpp::OpDesc &op) {
auto op_name = OpUniqueName(); auto op_name = OpUniqueName();
AddOpDescHelper(op_name, op); AddOpDescHelper(op_name, op);
LOG(INFO) << "add op " << op_name;
Line(string_format("// Create Op: %s", op.Type().c_str())); Line(string_format("// Create Op: %s", op.Type().c_str()));
Line(string_format("auto %s = lite::LiteOpRegistry::Global().Create(\"%s\");", Line(string_format("auto %s = lite::LiteOpRegistry::Global().Create(\"%s\");",
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include "paddle/fluid/lite/gen_code/gen_code.h"
#include "paddle/fluid/lite/model_parser/model_parser.h"
DEFINE_string(optimized_model, "", "");
DEFINE_string(generated_code_file, "__generated_code__.cc", "");
namespace paddle {
namespace lite {
namespace gencode {
void GenCode(const std::string& model_dir, const std::string& out_file) {
lite::Scope scope;
framework::proto::ProgramDesc desc;
LoadModel(model_dir, &scope, &desc);
ProgramCodeGenerator codegen(desc, scope);
std::ofstream file(out_file);
file << codegen.GenCode();
file.close();
}
} // namespace gencode
} // namespace lite
} // namespace paddle
int main(int argc, char** argv) {
google::ParseCommandLineFlags(&argc, &argv, false);
paddle::lite::gencode::GenCode(FLAGS_optimized_model,
FLAGS_generated_code_file);
return 0;
}
...@@ -31,7 +31,7 @@ lite_cc_test(test_mul_compute_arm SRCS mul_compute_test.cc DEPS mul_compute_arm) ...@@ -31,7 +31,7 @@ lite_cc_test(test_mul_compute_arm SRCS mul_compute_test.cc DEPS mul_compute_arm)
lite_cc_test(test_split_compute_arm SRCS split_compute_test.cc DEPS split_compute_arm) lite_cc_test(test_split_compute_arm SRCS split_compute_test.cc DEPS split_compute_arm)
lite_cc_test(test_concat_compute_arm SRCS concat_compute_test.cc DEPS concat_compute_arm) lite_cc_test(test_concat_compute_arm SRCS concat_compute_test.cc DEPS concat_compute_arm)
lite_cc_test(test_dropout_compute_arm SRCS dropout_compute_test.cc DEPS dropout_compute_arm) lite_cc_test(test_dropout_compute_arm SRCS dropout_compute_test.cc DEPS dropout_compute_arm)
lite_cc_test(test_calib_compute_arm SRCS calib_compute_test.cc DEPS calib_compute_arm) # lite_cc_test(test_calib_compute_arm SRCS calib_compute_test.cc DEPS calib_compute_arm)
lite_cc_test(test_transpose_compute_arm SRCS transpose_compute_test.cc DEPS transpose_compute_arm) lite_cc_test(test_transpose_compute_arm SRCS transpose_compute_test.cc DEPS transpose_compute_arm)
set(arm_kernels set(arm_kernels
...@@ -48,6 +48,7 @@ set(arm_kernels ...@@ -48,6 +48,7 @@ set(arm_kernels
concat_compute_arm concat_compute_arm
dropout_compute_arm dropout_compute_arm
transpose_compute_arm transpose_compute_arm
calib_compute_arm
) )
set(arm_kernels "${arm_kernels}" CACHE INTERNAL "arm kernels") set(arm_kernels "${arm_kernels}" CACHE INTERNAL "arm kernels")
...@@ -23,26 +23,24 @@ namespace lite { ...@@ -23,26 +23,24 @@ namespace lite {
namespace kernels { namespace kernels {
namespace arm { namespace arm {
void CalibCompute::Run() { void CalibComputeFp32ToInt8::Run() {
auto& param = this->Param<operators::CalibParam>(); auto& param = this->Param<operators::CalibParam>();
std::vector<float> scale = {param.in_scale}; std::vector<float> scale = {param.scale};
if (param.in_dtype == PRECISION(kFloat) &&
param.out_dtype == PRECISION(kInt8)) {
const auto* din = param.input->data<float>(); const auto* din = param.input->data<float>();
auto* dout = param.output->mutable_data<signed char>(); auto* dout = param.output->mutable_data<signed char>();
lite::arm::math::fp32_to_int8(din, dout, scale.data(), 1, 1, lite::arm::math::fp32_to_int8(din, dout, scale.data(), 1, 1,
param.input->numel()); param.input->numel());
return; return;
} }
if (param.in_dtype == PRECISION(kInt8) &&
param.out_dtype == PRECISION(kFloat)) { void CalibComputeInt8ToFp32::Run() {
auto& param = this->Param<operators::CalibParam>();
const auto* din = param.input->data<signed char>(); const auto* din = param.input->data<signed char>();
std::vector<float> scale = {param.scale};
auto* dout = param.output->mutable_data<float>(); auto* dout = param.output->mutable_data<float>();
lite::arm::math::int8_to_fp32(din, dout, scale.data(), 1, 1, lite::arm::math::int8_to_fp32(din, dout, scale.data(), 1, 1,
param.input->numel()); param.input->numel());
return; return;
}
LOG(FATAL) << "Unsupport Dtype.";
} }
} // namespace arm } // namespace arm
...@@ -51,7 +49,16 @@ void CalibCompute::Run() { ...@@ -51,7 +49,16 @@ void CalibCompute::Run() {
} // namespace paddle } // namespace paddle
REGISTER_LITE_KERNEL(calib, kARM, kInt8, kNCHW, REGISTER_LITE_KERNEL(calib, kARM, kInt8, kNCHW,
paddle::lite::kernels::arm::CalibCompute, def) paddle::lite::kernels::arm::CalibComputeFp32ToInt8,
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))}) fp32_to_int8)
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("Input",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
.Finalize();
REGISTER_LITE_KERNEL(calib, kARM, kInt8, kNCHW,
paddle::lite::kernels::arm::CalibComputeInt8ToFp32,
int8_to_fp32)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
.Finalize(); .Finalize();
...@@ -21,13 +21,26 @@ namespace lite { ...@@ -21,13 +21,26 @@ namespace lite {
namespace kernels { namespace kernels {
namespace arm { namespace arm {
class CalibCompute : public KernelLite<TARGET(kARM), PRECISION(kInt8)> { class CalibComputeFp32ToInt8
: public KernelLite<TARGET(kARM), PRECISION(kInt8)> {
public: public:
using param_t = operators::CalibParam; using param_t = operators::CalibParam;
void Run() override; void Run() override;
~CalibCompute() override{}; ~CalibComputeFp32ToInt8() override{};
private:
};
class CalibComputeInt8ToFp32
: public KernelLite<TARGET(kARM), PRECISION(kInt8)> {
public:
using param_t = operators::CalibParam;
void Run() override;
~CalibComputeInt8ToFp32() override{};
private: private:
}; };
......
...@@ -146,4 +146,5 @@ TEST(calib_arm, int8_to_fp32) { ...@@ -146,4 +146,5 @@ TEST(calib_arm, int8_to_fp32) {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
USE_LITE_KERNEL(calib, kARM, kInt8, kNCHW, def); USE_LITE_KERNEL(calib, kARM, kInt8, kNCHW, int8_to_fp32);
USE_LITE_KERNEL(calib, kARM, kInt8, kNCHW, fp32_to_int8);
...@@ -123,13 +123,16 @@ void ConvComputeInt8<Ptype_out>::PrepareForRun() { ...@@ -123,13 +123,16 @@ void ConvComputeInt8<Ptype_out>::PrepareForRun() {
// weigth is int8 and bias is int32 so do not need trans // weigth is int8 and bias is int32 so do not need trans
if (param.groups == ic && ic == oc && kps_equal && no_dilation && flag_dw) { if (param.groups == ic && ic == oc && kps_equal && no_dilation && flag_dw) {
impl_ = new lite::arm::math::DepthwiseConvInt8<Ptype_out>; // impl_ = new lite::arm::math::DepthwiseConvInt8<Ptype_out>;
VLOG(3) << "DepthwiseConv Int8"; impl_ = new lite::arm::math::GemmLikeConvInt8<Ptype_out>;
VLOG(3) << "Run DepthwiseConv Int8";
} else if (param.groups == 1 && kw == 3 && (sw == 1 || sw == 2) && } else if (param.groups == 1 && kw == 3 && (sw == 1 || sw == 2) &&
kps_equal && no_dilation) { kps_equal && no_dilation) {
impl_ = new lite::arm::math::DirectConvInt8<Ptype_out>; VLOG(3) << "Run DirectConv Int8";
impl_ = new lite::arm::math::GemmLikeConvInt8<Ptype_out>;
// impl_ = new lite::arm::math::DirectConvInt8<Ptype_out>;
} else { } else {
VLOG(3) << "GemmLikeConvInt8"; VLOG(3) << "Run GemmLikeConvInt8";
impl_ = new lite::arm::math::GemmLikeConvInt8<Ptype_out>; impl_ = new lite::arm::math::GemmLikeConvInt8<Ptype_out>;
} }
...@@ -189,3 +192,25 @@ REGISTER_LITE_KERNEL( ...@@ -189,3 +192,25 @@ REGISTER_LITE_KERNEL(
.BindOutput("Output", .BindOutput("Output",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))}) {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
.Finalize(); .Finalize();
REGISTER_LITE_KERNEL(
depthwise_conv2d, kARM, kInt8, kNCHW,
paddle::lite::kernels::arm::ConvComputeInt8<PRECISION(kInt8)>, int8_out)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("Filter",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
.BindOutput("Output",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
.Finalize();
REGISTER_LITE_KERNEL(
depthwise_conv2d, kARM, kInt8, kNCHW,
paddle::lite::kernels::arm::ConvComputeInt8<PRECISION(kFloat)>, fp32_out)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("Filter",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
.BindOutput("Output",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
.Finalize();
...@@ -14,9 +14,13 @@ ...@@ -14,9 +14,13 @@
#include "paddle/fluid/lite/kernels/arm/fc_compute.h" #include "paddle/fluid/lite/kernels/arm/fc_compute.h"
#include <vector> #include <vector>
#include "paddle/fluid/lite/api/paddle_place.h"
#include "paddle/fluid/lite/arm/math/funcs.h" #include "paddle/fluid/lite/arm/math/funcs.h"
#include "paddle/fluid/lite/arm/math/gemm_prepacked_int8.h"
#include "paddle/fluid/lite/arm/math/gemv_arm_int8.h"
#include "paddle/fluid/lite/core/op_registry.h" #include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/type_system.h" #include "paddle/fluid/lite/core/type_system.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
namespace kernels { namespace kernels {
...@@ -71,8 +75,8 @@ void FcCompute::Run() { ...@@ -71,8 +75,8 @@ void FcCompute::Run() {
auto& ctx = this->ctx_->template As<ARMContext>(); auto& ctx = this->ctx_->template As<ARMContext>();
if (m_ > 1) { if (m_ > 1) {
float* packed_in = static_cast<float*>(ctx.workspace_data<float>()) + float* packed_in =
ctx.l2_cache_size() / sizeof(float); ctx.workspace_data<float>() + ctx.l2_cache_size() / sizeof(float);
lite::arm::math::prepackA(packed_in, i_data, k_, 0, m_, 0, k_, false, &ctx); lite::arm::math::prepackA(packed_in, i_data, k_, 0, m_, 0, k_, false, &ctx);
lite::arm::math::sgemm_prepack(packed_in, w_data, b_data, o_data, m_, n_, lite::arm::math::sgemm_prepack(packed_in, w_data, b_data, o_data, m_, n_,
k_, false, false, false, &ctx); k_, false, false, false, &ctx);
...@@ -89,6 +93,97 @@ void FcCompute::Run() { ...@@ -89,6 +93,97 @@ void FcCompute::Run() {
} }
} }
template <PrecisionType Ptype_out>
void FcComputeInt8<Ptype_out>::PrepareForRun() {
auto& param = this->Param<operators::FcParam>();
auto x_dims = param.input->dims();
auto w_dims = param.w->dims();
auto& ctx = this->ctx_->template As<ARMContext>();
if (!tmp_int32_out_) {
tmp_int32_out_ = new Tensor;
tmp_int32_out_->Resize(param.output->dims());
}
CHECK_GE(x_dims.size(), 2UL);
CHECK_EQ(w_dims.size(), 2UL);
CHECK_EQ(param.output->dims().size(), 2UL);
this->m_ = x_dims.Slice(0, param.in_num_col_dims).production();
this->k_ = x_dims.Slice(param.in_num_col_dims, x_dims.size()).production();
this->n_ = w_dims[1];
CHECK_EQ(k_, static_cast<int>(w_dims[0]));
if (this->m_ == 1) {
if (!this->transed_weight_) {
this->transed_weight_ = new Tensor;
}
this->transed_weight_->Resize({this->n_, this->k_});
const auto* w_data = param.w->template data<int8_t>();
auto* t_data = this->transed_weight_->template mutable_data<int8_t>();
int i = 0;
for (int nn = 0; nn < this->n_; ++nn) {
for (int kk = 0; kk < this->k_; ++kk) {
t_data[i++] = w_data[kk * this->n_ + nn];
}
}
}
if (this->m_ > 1) {
int hblock = lite::arm::math::get_hblock(ctx.arch());
int m_round = hblock * ((this->m_ + hblock - 1) / hblock);
ctx.ExtendWorkspace(DDimLite(std::vector<int64_t>({m_round * this->k_})));
}
}
template <PrecisionType Ptype_out>
void FcComputeInt8<Ptype_out>::Run() {
auto& param = this->Param<operators::FcParam>();
const auto* i_data = param.input->template data<int8_t>();
const auto* w_data = param.w->template data<int8_t>();
const auto* b_data = param.bias ? param.bias->template data<int>() : nullptr;
int* o_data = nullptr;
auto& ctx = this->ctx_->template As<ARMContext>();
o_data = this->tmp_int32_out_->template mutable_data<int>();
if (m_ > 1) {
int8_t* packed_in =
static_cast<int8_t*>(ctx.template workspace_data<int8_t>()) +
ctx.l2_cache_size() / sizeof(int8_t);
lite::arm::math::prepackA_int8(packed_in, i_data, k_, 0, m_, 0, k_, false);
lite::arm::math::gemm_prepack_int8(packed_in, w_data, b_data, o_data, m_,
n_, k_, false, false, false, nullptr,
&ctx);
if (param.bias) {
CHECK_EQ(param.bias->numel(), n_);
lite::arm::math::fill_bias_fc(o_data, b_data, m_, n_);
}
} else {
CHECK(transed_weight_);
const auto* t_data = transed_weight_->template data<int8_t>();
lite::arm::math::gemv_int8(t_data, i_data, o_data, false, n_, k_, nullptr,
b_data != nullptr, b_data, false);
}
float i_scale = param.input_scale;
std::vector<float> weight_scale = param.weight_scale;
if (Ptype_out == PRECISION(kInt8)) {
float o_scale = param.output_scale;
param.output->template mutable_data<int8_t>();
lite::arm::math::trans_tensor_dtype<PRECISION(kInt32), PRECISION(kInt8)>(
tmp_int32_out_, param.output, i_scale, o_scale, weight_scale);
} else if (Ptype_out == PRECISION(kFloat)) {
param.output->template mutable_data<float>();
lite::arm::math::trans_tensor_dtype<PRECISION(kInt32), PRECISION(kFloat)>(
tmp_int32_out_, param.output, i_scale, 1.f, weight_scale);
} else {
LOG(ERROR) << "unsupported precision type!!";
}
}
} // namespace arm } // namespace arm
} // namespace kernels } // namespace kernels
} // namespace lite } // namespace lite
...@@ -101,3 +196,21 @@ REGISTER_LITE_KERNEL(fc, kARM, kFloat, kNCHW, ...@@ -101,3 +196,21 @@ REGISTER_LITE_KERNEL(fc, kARM, kFloat, kNCHW,
.BindInput("W", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("W", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize(); .Finalize();
REGISTER_LITE_KERNEL(
fc, kARM, kInt8, kNCHW,
paddle::lite::kernels::arm::FcComputeInt8<PRECISION(kInt8)>, int8out)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("W", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
.Finalize();
REGISTER_LITE_KERNEL(
fc, kARM, kInt8, kNCHW,
paddle::lite::kernels::arm::FcComputeInt8<PRECISION(kFloat)>, fp32out)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("W", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
.Finalize();
...@@ -13,6 +13,8 @@ ...@@ -13,6 +13,8 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <stdint.h>
#include "paddle/fluid/lite/arm/math/type_trans.h"
#include "paddle/fluid/lite/core/kernel.h" #include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/operators/fc_op.h" #include "paddle/fluid/lite/operators/fc_op.h"
...@@ -40,6 +42,27 @@ class FcCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> { ...@@ -40,6 +42,27 @@ class FcCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
int m_, n_, k_; int m_, n_, k_;
}; };
template <PrecisionType Ptype_out>
class FcComputeInt8 : public KernelLite<TARGET(kARM), PRECISION(kInt8)> {
public:
using param_t = operators::FcParam;
void PrepareForRun() override;
void Run() override;
~FcComputeInt8() override {
if (transed_weight_) {
delete transed_weight_;
}
};
private:
lite::Tensor* transed_weight_{nullptr};
Tensor* tmp_int32_out_{nullptr};
int m_, n_, k_;
};
} // namespace arm } // namespace arm
} // namespace kernels } // namespace kernels
} // namespace lite } // namespace lite
......
...@@ -37,12 +37,8 @@ bool CalibOpLite::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { ...@@ -37,12 +37,8 @@ bool CalibOpLite::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
param_.input = const_cast<lite::Tensor *>(&(x_var->Get<lite::Tensor>())); param_.input = const_cast<lite::Tensor *>(&(x_var->Get<lite::Tensor>()));
param_.output = output_var->GetMutable<lite::Tensor>(); param_.output = output_var->GetMutable<lite::Tensor>();
std::vector<std::string> input_arg_names = opdesc.InputArgumentNames(); std::vector<std::string> input_arg_names = opdesc.InputArgumentNames();
param_.in_dtype = if (opdesc.HasAttr("scale")) {
static_cast<lite::PrecisionType>(opdesc.GetAttr<int>("in_dtype")); param_.scale = opdesc.GetAttr<float>("scale");
param_.out_dtype =
static_cast<lite::PrecisionType>(opdesc.GetAttr<int>("out_dtype"));
if (opdesc.HasAttr("in_scale")) {
param_.in_scale = opdesc.GetAttr<float>("in_scale");
} }
CHECK(param_.input) << "Input(X) of CalibOp should not be null."; CHECK(param_.input) << "Input(X) of CalibOp should not be null.";
CHECK(param_.output) << "Output(Out) of CalibOp should not be null."; CHECK(param_.output) << "Output(Out) of CalibOp should not be null.";
......
...@@ -11,7 +11,6 @@ ...@@ -11,7 +11,6 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/lite/operators/calib_op.h" #include "paddle/fluid/lite/operators/calib_op.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "paddle/fluid/lite/core/op_registry.h" #include "paddle/fluid/lite/core/op_registry.h"
...@@ -42,9 +41,7 @@ TEST(calib_op_lite, TestARM) { ...@@ -42,9 +41,7 @@ TEST(calib_op_lite, TestARM) {
desc.SetType("calib"); desc.SetType("calib");
desc.SetInput("Input", {"Input"}); desc.SetInput("Input", {"Input"});
desc.SetOutput("Out", {"output"}); desc.SetOutput("Out", {"output"});
desc.SetAttr("in_dtype", static_cast<int>(PRECISION(kInt8))); desc.SetAttr("scale", 10.0f);
desc.SetAttr("out_dtype", static_cast<int>(PRECISION(kFloat)));
desc.SetAttr("in_scale", 10.0f);
CalibOpLite calib("calib"); CalibOpLite calib("calib");
...@@ -60,5 +57,6 @@ TEST(calib_op_lite, TestARM) { ...@@ -60,5 +57,6 @@ TEST(calib_op_lite, TestARM) {
} // namespace paddle } // namespace paddle
#ifdef LITE_WITH_ARM #ifdef LITE_WITH_ARM
USE_LITE_KERNEL(calib, kARM, kInt8, kNCHW, def); USE_LITE_KERNEL(calib, kARM, kInt8, kNCHW, fp32_to_int8);
USE_LITE_KERNEL(calib, kARM, kInt8, kNCHW, int8_to_fp32);
#endif #endif
...@@ -76,6 +76,17 @@ class ConvOpLite : public OpLite { ...@@ -76,6 +76,17 @@ class ConvOpLite : public OpLite {
} }
} }
param_.fuse_relu = op_desc.GetAttr<bool>("fuse_relu"); param_.fuse_relu = op_desc.GetAttr<bool>("fuse_relu");
// For Int8
if (op_desc.HasAttr("enable_int8")) {
param_.enable_int8 = op_desc.GetAttr<bool>("enable_int8");
if (op_desc.HasAttr("input_scale"))
param_.input_scale = op_desc.GetAttr<float>("input_scale");
if (op_desc.HasAttr("weight_scale"))
param_.weight_scale =
op_desc.GetAttr<std::vector<float>>("weight_scale");
if (op_desc.HasAttr("output_scale"))
param_.output_scale = op_desc.GetAttr<float>("output_scale");
}
return true; return true;
} }
......
...@@ -59,6 +59,17 @@ class FcOpLite : public OpLite { ...@@ -59,6 +59,17 @@ class FcOpLite : public OpLite {
param_.output = scope->FindVar(out)->GetMutable<lite::Tensor>(); param_.output = scope->FindVar(out)->GetMutable<lite::Tensor>();
param_.in_num_col_dims = op_desc.GetAttr<int>("in_num_col_dims"); param_.in_num_col_dims = op_desc.GetAttr<int>("in_num_col_dims");
// For Int8
if (op_desc.HasAttr("enable_int8")) {
param_.enable_int8 = op_desc.GetAttr<bool>("enable_int8");
if (op_desc.HasAttr("input_scale"))
param_.input_scale = op_desc.GetAttr<float>("input_scale");
if (op_desc.HasAttr("weight_scale"))
param_.weight_scale =
op_desc.GetAttr<std::vector<float>>("weight_scale");
if (op_desc.HasAttr("output_scale"))
param_.output_scale = op_desc.GetAttr<float>("output_scale");
}
return true; return true;
} }
......
...@@ -19,11 +19,6 @@ ...@@ -19,11 +19,6 @@
#include "paddle/fluid/lite/core/framework.pb.h" #include "paddle/fluid/lite/core/framework.pb.h"
#include "paddle/fluid/lite/utils/all.h" #include "paddle/fluid/lite/utils/all.h"
#define WITH_INT8_CONFIG \
bool enable_int8; \
float input_scale; \
std::vector<float> weight_scale{}; \
float output_scale;
/* /*
* This file contains all the argument parameter data structure for operators. * This file contains all the argument parameter data structure for operators.
*/ */
...@@ -33,6 +28,11 @@ namespace lite { ...@@ -33,6 +28,11 @@ namespace lite {
namespace operators { namespace operators {
using param_t = Any; using param_t = Any;
#define WITH_INT8_CONFIG \
bool enable_int8{false}; \
float input_scale{1.0}; \
std::vector<float> weight_scale{}; \
float output_scale{1.0};
/// ----------------------- Functional operators ------------------------------ /// ----------------------- Functional operators ------------------------------
struct FeedParam { struct FeedParam {
...@@ -56,9 +56,7 @@ struct IoCopyParam { ...@@ -56,9 +56,7 @@ struct IoCopyParam {
struct CalibParam { struct CalibParam {
const lite::Tensor* input{}; const lite::Tensor* input{};
lite::Tensor* output{}; lite::Tensor* output{};
float in_scale; float scale;
PrecisionType in_dtype;
PrecisionType out_dtype;
}; };
/// -------------------------- NN operators ------------------------------------ /// -------------------------- NN operators ------------------------------------
...@@ -71,6 +69,8 @@ struct FcParam { ...@@ -71,6 +69,8 @@ struct FcParam {
lite::DDim in_mat_dims; lite::DDim in_mat_dims;
int in_num_col_dims{1}; int in_num_col_dims{1};
bool weight_transposed{false}; bool weight_transposed{false};
// for int8
WITH_INT8_CONFIG
}; };
// For Mul Op // For Mul Op
...@@ -81,6 +81,8 @@ struct MulParam { ...@@ -81,6 +81,8 @@ struct MulParam {
int x_num_col_dims{1}; int x_num_col_dims{1};
int y_num_col_dims{1}; int y_num_col_dims{1};
// for int8
WITH_INT8_CONFIG
}; };
struct MulGradParam { struct MulGradParam {
...@@ -152,6 +154,7 @@ struct ConvParam { ...@@ -152,6 +154,7 @@ struct ConvParam {
float scale_weights{1.0f}; // only used with mkl-dnn int8 float scale_weights{1.0f}; // only used with mkl-dnn int8
bool force_fp32_output{false}; // only used in mkl-dnn int8 bool force_fp32_output{false}; // only used in mkl-dnn int8
std::string data_format{"Anylayout"}; std::string data_format{"Anylayout"};
// for int8
WITH_INT8_CONFIG WITH_INT8_CONFIG
}; };
......
...@@ -4,6 +4,7 @@ set -ex ...@@ -4,6 +4,7 @@ set -ex
TESTS_FILE="./lite_tests.txt" TESTS_FILE="./lite_tests.txt"
LIBS_FILE="./lite_libs.txt" LIBS_FILE="./lite_libs.txt"
readonly ADB_WORK_DIR="/data/local/tmp"
readonly common_flags="-DWITH_LITE=ON -DLITE_WITH_LIGHT_WEIGHT_FRAMEWORK=OFF -DWITH_PYTHON=OFF -DWITH_TESTING=ON -DLITE_WITH_ARM=OFF" readonly common_flags="-DWITH_LITE=ON -DLITE_WITH_LIGHT_WEIGHT_FRAMEWORK=OFF -DWITH_PYTHON=OFF -DWITH_TESTING=ON -DLITE_WITH_ARM=OFF"
NUM_CORES_FOR_COMPILE=8 NUM_CORES_FOR_COMPILE=8
...@@ -183,7 +184,36 @@ function test_arm_model { ...@@ -183,7 +184,36 @@ function test_arm_model {
adb -s emulator-${port} shell chmod +x "${adb_work_dir}/${test_name}" adb -s emulator-${port} shell chmod +x "${adb_work_dir}/${test_name}"
local adb_model_path="${adb_work_dir}/`basename ${model_dir}`" local adb_model_path="${adb_work_dir}/`basename ${model_dir}`"
adb -s emulator-${port} shell "${adb_work_dir}/${test_name} --model_dir=$adb_model_path" adb -s emulator-${port} shell "${adb_work_dir}/${test_name} --model_dir=$adb_model_path"
}
function _test_model_optimize_tool {
local port=$1
local remote_model_path=$ADB_WORK_DIR/lite_naive_model
local remote_test=$ADB_WORK_DIR/model_optimize_tool
local adb="adb -s emulator-${port}"
make model_optimize_tool -j$NUM_CORES_FOR_COMPILE
local test_path=$(find . -name model_optimize_tool)
local model_path=$(find . -name lite_naive_model)
$adb push ${test_path} ${ADB_WORK_DIR}
$adb shell mkdir -p $remote_model_path
$adb push $model_path/* $remote_model_path
$adb shell $remote_test --model_dir $remote_model_path --optimize_out ${remote_model_path}.opt \
--valid_targets "arm"
}
function _test_paddle_code_generator {
local port=$1
local test_name=paddle_code_generator
local remote_test=$ADB_WORK_DIR/$test_name
local remote_model=$ADB_WORK_DIR/lite_naive_model.opt
local adb="adb -s emulator-${port}"
make paddle_code_generator -j$NUM_CORES_FOR_COMPILE
local test_path=$(find . -name $test_name)
$adb push $test_path $remote_test
$adb shell $remote_test --optimized_model $remote_model --generated_code_file $ADB_WORK_DIR/gen_code.cc
} }
function cmake_arm { function cmake_arm {
...@@ -273,6 +303,9 @@ function test_arm { ...@@ -273,6 +303,9 @@ function test_arm {
# test finally # test finally
test_arm_api $port test_arm_api $port
_test_model_optimize_tool $port
_test_paddle_code_generator $port
} }
function prepare_emulator { function prepare_emulator {
......
...@@ -52,8 +52,8 @@ static std::string to_string_with_precision(const T& v, const int n = 6) { ...@@ -52,8 +52,8 @@ static std::string to_string_with_precision(const T& v, const int n = 6) {
return ss.str(); return ss.str();
} }
static std::string Join(const std::vector<std::string>& vec, template <typename T>
const std::string& delim) { std::string Join(const std::vector<T>& vec, const std::string& delim) {
if (vec.empty()) return ""; if (vec.empty()) return "";
std::stringstream ss; std::stringstream ss;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册