提交 d30a85b9 编写于 作者: X xingzhaolong

INT8 ARM MobilenetV1 union test.

上级 2d941468
......@@ -25,7 +25,7 @@
#include <vector>
// #include "paddle/fluid/lite/utils/logging.h"
// #ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
#include <glog/logging.h>
#include <glog/logging.h> // NOLINT
// #endif
namespace paddle {
......
......@@ -114,9 +114,17 @@ if (WITH_TESTING)
add_dependencies(test_paddle_api_lite extern_lite_download_lite_naive_model_tar_gz)
endif()
#lite_cc_binary(cxx_api_lite_bin SRCS cxx_api_bin.cc
#X86_DEPS operator
#DEPS light_api_lite model_parser_lite target_wrapper_host mir_passes
#ARM_DEPS ${arm_kernels})
lite_cc_binary(cxx_api_lite_bin SRCS cxx_api_bin_int8.cc
DEPS
cxx_api_lite
model_parser_lite
target_wrapper_host
mir_passes
${ops_lite} ${host_kernels}
ARM_DEPS ${arm_kernels})
lite_cc_binary(model_optimize_tool SRCS model_optimize_tool.cc DEPS paddle_api_full)
# lite_cc_binary(cxx_api_lite_bin SRCS cxx_api_bin.cc
# X86_DEPS operator
# DEPS light_api_lite model_parser_lite target_wrapper_host mir_passes
# ARM_DEPS ${arm_kernels})
......@@ -29,16 +29,18 @@ double time_diff(Time t1, Time t2) {
return counter.count() / 1000.0;
}
void Run(const char* model_dir, int repeat, int thread_num) {
void Run(const char* model_dir, int repeat) {
#ifdef LITE_WITH_ARM
DeviceInfo::Init();
DeviceInfo::Global().SetRunMode(LITE_POWER_HIGH, thread_num);
#endif
lite::Predictor predictor;
std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)},
Place{TARGET(kARM), PRECISION(kFloat)}});
std::vector<Place> valid_places({
Place{TARGET(kHost), PRECISION(kFloat)},
Place{TARGET(kARM), PRECISION(kFloat)},
Place{TARGET(kARM), PRECISION(kInt8)},
});
predictor.Build(model_dir, Place{TARGET(kARM), PRECISION(kFloat)},
predictor.Build(model_dir, Place{TARGET(kARM), PRECISION(kInt8)},
valid_places);
auto* input_tensor = predictor.GetInput(0);
......@@ -48,8 +50,6 @@ void Run(const char* model_dir, int repeat, int thread_num) {
data[i] = 1;
}
for (int i = 0; i < 10; i++) predictor.Run();
auto time1 = time();
for (int i = 0; i < repeat; i++) predictor.Run();
auto time2 = time();
......@@ -68,8 +68,8 @@ void Run(const char* model_dir, int repeat, int thread_num) {
} // namespace paddle
int main(int argc, char** argv) {
CHECK_EQ(argc, 4) << "usage: ./cmd <model_dir> <repeat> <thread_num>";
paddle::lite::Run(argv[1], std::stoi(argv[2]), std::stoi(argv[3]));
CHECK_EQ(argc, 3) << "usage: ./cmd <model_dir> <repeat>";
paddle::lite::Run(argv[1], std::stoi(argv[2]));
return 0;
}
......@@ -93,13 +93,18 @@ USE_LITE_OP(fake_dequantize_max_abs);
USE_LITE_KERNEL(feed, kHost, kAny, kAny, def);
USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def);
USE_LITE_OP(calib);
#ifdef LITE_WITH_ARM
USE_LITE_KERNEL(fc, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(fc, kARM, kInt8, kNCHW, int8out);
USE_LITE_KERNEL(fc, kARM, kInt8, kNCHW, fp32out);
USE_LITE_KERNEL(mul, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(scale, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(conv2d, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(conv2d, kARM, kInt8, kNCHW, int8_out);
USE_LITE_KERNEL(conv2d, kARM, kInt8, kNCHW, fp32_out);
USE_LITE_KERNEL(batch_norm, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(relu, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(depthwise_conv2d, kARM, kFloat, kNCHW, def);
......@@ -107,6 +112,9 @@ USE_LITE_KERNEL(pool2d, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(elementwise_add, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(softmax, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(calib, kARM, kInt8, kNCHW, fp32_to_int8);
USE_LITE_KERNEL(calib, kARM, kInt8, kNCHW, int8_to_fp32);
// USE_LITE_KERNEL(feed, kARM, kAny, kAny, def);
// USE_LITE_KERNEL(fetch, kARM, kAny, kAny, def);
#endif // LITE_WITH_ARM
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/api/cxx_api.h"
#include <chrono> // NOLINT
#include "paddle/fluid/lite/api/paddle_use_kernels.h"
#include "paddle/fluid/lite/api/paddle_use_ops.h"
#include "paddle/fluid/lite/api/paddle_use_passes.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
using Time = decltype(std::chrono::high_resolution_clock::now());
Time time() { return std::chrono::high_resolution_clock::now(); }
double time_diff(Time t1, Time t2) {
typedef std::chrono::microseconds ms;
auto diff = t2 - t1;
ms counter = std::chrono::duration_cast<ms>(diff);
return counter.count() / 1000.0;
}
void Run(const char* model_dir, int repeat) {
#ifdef LITE_WITH_ARM
DeviceInfo::Init();
#endif
lite::Predictor predictor;
std::vector<Place> valid_places({
Place{TARGET(kHost), PRECISION(kFloat)},
Place{TARGET(kARM), PRECISION(kFloat)},
Place{TARGET(kARM), PRECISION(kInt8)},
});
predictor.Build(model_dir, Place{TARGET(kARM), PRECISION(kInt8)},
valid_places);
auto* input_tensor = predictor.GetInput(0);
input_tensor->Resize(DDim(std::vector<DDim::value_type>({1, 3, 224, 224})));
auto* data = input_tensor->mutable_data<float>();
for (int i = 0; i < input_tensor->dims().production(); i++) {
data[i] = 1;
}
auto time1 = time();
for (int i = 0; i < repeat; i++) predictor.Run();
auto time2 = time();
std::cout << " predict cost: " << time_diff(time1, time2) / repeat << "ms"
<< std::endl;
auto* out = predictor.GetOutput(0);
LOG(INFO) << out << " memory size " << out->data_size();
LOG(INFO) << "out " << out->data<float>()[0];
LOG(INFO) << "out " << out->data<float>()[1];
LOG(INFO) << "dims " << out->dims();
LOG(INFO) << "out data size: " << out->data_size();
}
} // namespace lite
} // namespace paddle
int main(int argc, char** argv) {
CHECK_EQ(argc, 3) << "usage: ./cmd <model_dir> <repeat>";
paddle::lite::Run(argv[1], std::stoi(argv[2]));
return 0;
}
......@@ -38,6 +38,13 @@ USE_LITE_KERNEL(relu, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(transpose, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(transpose2, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(batch_norm, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(calib, kARM, kInt8, kNCHW, fp32_to_int8);
USE_LITE_KERNEL(calib, kARM, kInt8, kNCHW, int8_to_fp32);
USE_LITE_KERNEL(conv2d, kARM, kInt8, kNCHW, int8_out);
USE_LITE_KERNEL(conv2d, kARM, kInt8, kNCHW, fp32_out);
USE_LITE_KERNEL(fc, kARM, kInt8, kNCHW, int8out);
USE_LITE_KERNEL(fc, kARM, kInt8, kNCHW, fp32out);
#endif
#ifdef LITE_WITH_X86
......
......@@ -38,3 +38,7 @@ USE_LITE_OP(batch_norm)
USE_LITE_OP(fusion_elementwise_sub_activation)
USE_LITE_OP(transpose)
USE_LITE_OP(transpose2)
USE_LITE_OP(fake_quantize_moving_average_abs_max);
USE_LITE_OP(fake_dequantize_max_abs);
USE_LITE_OP(calib);
......@@ -31,3 +31,5 @@ USE_MIR_PASS(identity_scale_eliminate_pass);
USE_MIR_PASS(lite_conv_elementwise_add_activation_fuse_pass);
USE_MIR_PASS(lite_elementwise_add_activation_fuse_pass);
USE_MIR_PASS(lite_quant_dequant_fuse_pass);
USE_MIR_PASS(precision_cast_transform_pass);
USE_MIR_PASS(trans_weight_pass);
......@@ -18,10 +18,12 @@ cc_library(mir_passes
static_kernel_pick_pass.cc
variable_place_inference_pass.cc
type_target_transform_pass.cc
precision_cast_transform_pass.cc
io_copy_kernel_pick_pass.cc
graph_visualize_pass.cc
generate_program_pass.cc
argument_type_display_pass.cc
trans_weigths_pass.cc
demo_pass.cc
runtime_context_assign_pass.cc
DEPS mir_pass types_lite context_lite ${mir_fusers})
......
......@@ -60,7 +60,7 @@ void FcFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
}
cpp::OpDesc FcFuser::GenOpDesc(const key2nodes_t& matched) {
cpp::OpDesc op_desc;
cpp::OpDesc op_desc = *matched.at("mul")->stmt()->op_info();
op_desc.SetType("fc");
op_desc.SetInput("Input", {matched.at("x")->arg()->name});
op_desc.SetInput("W", {matched.at("W")->arg()->name});
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/precision_cast_transform_pass.h"
#include <list>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
void PrecisionCastPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
// Start from inputs of the graph, those should have place set.
std::list<Node*> nodes;
for (auto& node : graph->mutable_nodes()) {
nodes.push_back(&node);
}
for (auto& node : nodes) {
if (!node->IsStmt()) continue;
auto inlinks = node->inlinks;
for (auto* in : inlinks) {
ComplementInputs(graph.get(), node, in);
}
}
VLOG(3) << "\n" << Visualize(graph.get());
}
void PrecisionCastPass::ComplementInputs(SSAGraph* graph, Node* inst_node,
Node* in) {
// If this input is out of date.
if (inst_node->inlinks.end() ==
std::find(inst_node->inlinks.begin(), inst_node->inlinks.end(), in))
return;
CHECK(inst_node->IsStmt());
auto& inst = inst_node->AsStmt();
CHECK(in->IsRoleSet());
CHECK(in->IsArg());
auto in_arg_name = in->AsArg().name;
std::string tmp;
CHECK(inst.op_info()->GetInputArgname(in_arg_name, &tmp));
auto decl_arg_type = inst.picked_kernel().GetInputDeclType(tmp);
CHECK(in->AsArg().type);
LOG(INFO) << inst.picked_kernel().name();
// if (!in->AsArg().is_weight && !PrecisionCompatibleTo(*in->AsArg().type,
// *decl_arg_type)) {
if (!PrecisionCompatibleTo(*in->AsArg().type, *decl_arg_type)) {
LOG(INFO) << "found Target unmatched tensor: " << in->AsArg().name
<< " for kernel " << inst.op()->DebugString() << " "
<< *in->AsArg().type << " -> " << *decl_arg_type;
// Add an Cast instruction to make the input compatible with other dist.
AddCastInst(*in->AsArg().type, *decl_arg_type, in, graph, inst_node,
graph->valid_places());
}
}
void PrecisionCastPass::AddCastInst(const Type& from, const Type& to, Node* in,
SSAGraph* graph, Node* inst_node,
const std::vector<Place>& valid_places) {
CHECK(!valid_places.empty()) << "valid_place should be set";
// var -> new_transform_op -> new_var -> inst
// So there will be a new Argument node and a new Cast Statement Node.
CHECK(in->IsArg());
auto node_id = [&] { return graph->nodes().size(); };
auto cast_op_output_name =
in->AsArg().name + "/trans/" + std::to_string(node_id());
auto* cast_op_output_arg = graph->NewArgumentNode(cast_op_output_name);
auto* cast_inst = graph->NewInstructNode();
// create Op and kernels.
auto cast_op = LiteOpRegistry::Global().Create("calib");
CHECK(cast_op) << "create op [" << cast_op << "] failed";
// Create the new var manually.
inst_node->AsStmt().op()->scope()->Var(cast_op_output_name);
// Create Calib Instruction.
cpp::OpDesc op_desc;
op_desc.SetType("calib");
op_desc.SetInput("Input", {in->AsArg().name});
op_desc.SetOutput("Out", {cast_op_output_name});
CHECK(inst_node->AsStmt().op_info()->HasAttr("input_scale"));
op_desc.SetAttr("scale",
inst_node->AsStmt().op_info()->GetAttr<float>("input_scale"));
cast_op->Attach(op_desc, inst_node->AsStmt().op()->scope());
auto kernels = cast_op->CreateKernels(valid_places);
std::vector<std::unique_ptr<KernelBase>> selected_kernels;
bool is_found = false;
for (auto& kernel : kernels) {
const Type* in_arg_ty = kernel->GetInputDeclType("Input");
const Type* out_arg_ty = kernel->GetOutputDeclType("Out");
if (in_arg_ty->precision() == from.precision() &&
out_arg_ty->precision() == to.precision()) {
is_found = true;
selected_kernels.emplace_back(std::move(kernel));
// we pick the kernel
cast_inst->AsStmt("calib", std::move(selected_kernels), cast_op);
break;
}
}
CHECK(is_found) << "Can't find a Cast kernel for Cast op: " << from << ":"
<< in->AsArg().name << "->" << to << ":"
<< inst_node->AsStmt().op_info()->Type();
// Remove the old link
RemoveDirectedLink(in, inst_node);
// Update the original instruction OpDesc.
// Update its input to the io_copy_output_name
// Add new link, var -> new_inst, new_inst->newarg, newarg->inst
DirectedLink(in, cast_inst);
DirectedLink(cast_inst, cast_op_output_arg);
DirectedLink(cast_op_output_arg, inst_node);
// reset opdesc and update kernel information
UpdateInputTo(inst_node->AsStmt().op()->mutable_op_info(), in->AsArg().name,
cast_op_output_name);
// recreate the op
auto original_selected_kernel =
std::move(inst_node->AsStmt().kernels().front());
auto updated_op_info = *inst_node->AsStmt().mutable_op_info();
inst_node->AsStmt().ResetOp(updated_op_info, graph->valid_places());
inst_node->AsStmt().kernels().clear();
inst_node->AsStmt().kernels().emplace_back(
std::move(original_selected_kernel));
for (auto& kernel : inst_node->AsStmt().kernels()) {
LOG(INFO) << "kernel info: " << kernel->name();
inst_node->AsStmt().op()->AttachKernel(kernel.get());
}
graph->CheckValid();
}
void PrecisionCastPass::SetValidPlaces(const std::vector<Place>& valid_places) {
CHECK(!valid_places.empty());
valid_places_ = valid_places;
}
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(precision_cast_transform_pass,
paddle::lite::mir::PrecisionCastPass);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/mir/pass.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace mir {
static void UpdateInputTo(cpp::OpDesc* desc, const std::string& from,
const std::string& to) {
for (auto& item : *desc->mutable_inputs()) {
for (auto& input : item.second) {
if (input == from) {
input = to;
}
}
}
}
/*
* The pass complement the necessary instruction to make data
* transferring or transformation between different places.
*/
class PrecisionCastPass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
void ComplementInputs(SSAGraph* graph, Node* inst_node, Node* in);
void AddCastInst(const Type& from, const Type& to, Node* in, SSAGraph* graph,
Node* inst_node, const std::vector<Place>& valid_places);
void SetValidPlaces(const std::vector<Place>& valid_places);
const std::vector<Place>& valid_places() const { return valid_places_; }
private:
std::vector<Place> valid_places_;
};
} // namespace mir
} // namespace lite
} // namespace paddle
......@@ -33,9 +33,12 @@ void StaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
<< "kernel_pick_factors should be specified first";
CHECK(graph) << "graph not valid";
// sort kernels by the factors.
for (auto& node : graph->mutable_nodes()) {
if (!node.IsStmt()) continue;
auto& instruct = node.AsStmt();
// Get candidate kernels
std::vector<std::pair<size_t, std::unique_ptr<KernelBase>>> scored;
CHECK(!instruct.kernels().empty()) << "No kernels found for "
<< instruct.op_type();
......@@ -43,15 +46,56 @@ void StaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
size_t score = KernelGrade(*kernel);
scored.emplace_back(score, std::move(kernel));
}
std::sort(scored.begin(), scored.end(), KernelScoreCmp);
// Move kernel back
// Just keep a single best kernel.
// TODO(Superjomn) reconsider this.
instruct.kernels().clear();
instruct.kernels().emplace_back(std::move(scored.front().second));
VLOG(2) << "pick " << instruct.kernels().front()->name();
if (!instruct.op_info()->HasAttr("enable_int8")) {
// Move kernel back
// Just keep a single best kernel.
// TODO(Superjomn) reconsider this.
instruct.kernels().emplace_back(std::move(scored.front().second));
VLOG(2) << "pick " << instruct.kernels().front()->name();
} else {
bool out_type_int8 = true;
// Only if all ops linked to this op output has enable_int8 attr,
// then the op output type is int8, or fp32.
for (auto* out_n : node.outlinks) {
CHECK(out_n->IsArg());
for (auto* tmp_op : out_n->outlinks) {
CHECK(tmp_op->IsStmt());
if (!tmp_op->AsStmt().op_info()->HasAttr("enable_int8")) {
out_type_int8 = false;
break;
}
}
if (!out_type_int8) break;
}
// According to the out type, we pick the kernel.
auto output_arguments = instruct.op_info()->OutputArgumentNames();
for (auto& candidate : scored) {
bool all_output_type_match = true;
auto expect_output_type =
out_type_int8 ? PRECISION(kInt8) : PRECISION(kFloat);
for (auto& arg_name : output_arguments) {
const Type* out_arg_ty =
candidate.second->GetOutputDeclType(arg_name);
if (out_arg_ty->precision() != expect_output_type) {
all_output_type_match = false;
}
}
if (all_output_type_match) {
instruct.kernels().emplace_back(std::move(candidate.second));
VLOG(2) << "pick " << instruct.kernels().front()->name();
break;
}
}
CHECK(!instruct.kernels().empty()) << "No kernels found for "
<< instruct.op_type();
}
}
}
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/trans_weigths_pass.h"
#include <list>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
void TransWeightPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
// Start from inputs of the graph, those should have place set.
std::list<Node*> nodes;
for (auto& node : graph->mutable_nodes()) {
nodes.push_back(&node);
}
for (auto& node : nodes) {
if (!node->IsStmt()) continue;
auto& instruct = node->AsStmt();
if (!instruct.op_info()->HasAttr("enable_int8")) {
continue;
}
std::vector<std::string> output_arg_names =
instruct.op_info()->output_argnames();
CHECK(output_arg_names.size() == 1)
<< "Currently, the op that supports int8 supports only one output";
// After static kernel select pass, there is only one kernel here.
const Type* out_arg_ty =
instruct.kernels()[0]->GetOutputDeclType(output_arg_names[0]);
auto out_precision = out_arg_ty->precision();
bool out_type_int8 = out_precision == PRECISION(kInt8) ? true : false;
float in_scale, out_scale;
in_scale = instruct.op_info()->GetAttr<float>("input_scale");
// Get next input op's input_scale
if (out_type_int8) {
LOG(INFO) << "output_type_int8";
auto out_node = node->outlinks.front();
CHECK(out_node->IsArg());
auto one_adj_op_node = out_node->outlinks.front();
CHECK(one_adj_op_node->IsStmt());
auto& one_adj_instruct = one_adj_op_node->AsStmt();
CHECK(one_adj_instruct.op_info()->HasAttr("enable_int8"));
CHECK(one_adj_instruct.op_info()->HasAttr("input_scale"));
out_scale = one_adj_instruct.op_info()->GetAttr<float>("input_scale");
instruct.mutable_op_info()->SetAttr("output_scale", out_scale);
} else {
LOG(INFO) << "output_type_fp32";
}
std::string op_type = instruct.op_info()->Type();
std::vector<float> weight_scale;
auto* scope = instruct.op()->scope();
if (op_type == "depthwise_conv2d" || op_type == "conv2d") {
std::string weight_var_name = instruct.op_info()->Input("Filter").front();
auto conv_weight_t =
scope->FindVar(weight_var_name)->GetMutable<lite::Tensor>();
// till now, all the weight should be float32 type
float* conv_weight_d = conv_weight_t->mutable_data<float>();
int64_t axis_size = conv_weight_t->dims()[0];
int64_t inner_size = conv_weight_t->data_size() / axis_size;
weight_scale =
GetWeightScale(conv_weight_d, axis_size, inner_size, 127.0);
Tensor temp_tensor;
temp_tensor.Resize(conv_weight_t->dims());
int8_t* temp_data = temp_tensor.mutable_data<int8_t>();
FP32ToInt8(conv_weight_d, temp_data, weight_scale.data(), axis_size, 1,
inner_size);
conv_weight_t->CopyDataFrom(temp_tensor);
} else if (op_type == "fc" || op_type == "mul") {
std::string weight_arg_name = "W";
if (op_type == "mul") weight_arg_name = "Y";
std::string weight_var_name =
instruct.op_info()->Input(weight_arg_name).front();
auto fc_weight_t =
scope->FindVar(weight_var_name)->GetMutable<lite::Tensor>();
// till now, all the weight should be float32 type
float* fc_weight_d = fc_weight_t->mutable_data<float>();
CHECK_EQ(fc_weight_t->dims().size(), 2UL);
int64_t h = fc_weight_t->dims()[0];
int64_t w = fc_weight_t->data_size() / h;
Tensor trans_w_t, int8_temp_t;
trans_w_t.CopyDataFrom(*fc_weight_t);
float* trans_w_data = trans_w_t.mutable_data<float>();
int8_temp_t.Resize(fc_weight_t->dims());
int8_t* int8_temp_data = int8_temp_t.mutable_data<int8_t>();
// trans weight for calc the weight scale.
for (int i = 0; i < h; i++) {
for (int j = 0; j < w; j++) {
trans_w_data[i * w + j] = fc_weight_d[j * h + i];
}
}
weight_scale = GetWeightScale(trans_w_data, w, h, 127.0);
int8_t* fc_weight_int8_d = fc_weight_t->mutable_data<int8_t>();
FP32ToInt8(trans_w_data, int8_temp_data, weight_scale.data(), w, 1, h);
// Retrans back
for (int i = 0; i < w; i++) {
for (int j = 0; j < h; j++) {
fc_weight_int8_d[i * h + j] = int8_temp_data[j * w + i];
}
}
}
// Convert fp32 bias to int8 bias
std::vector<std::string> input_arg_names =
instruct.op_info()->InputArgumentNames();
if (std::find(input_arg_names.begin(), input_arg_names.end(), "Bias") !=
input_arg_names.end() &&
instruct.op_info()->Input("Bias").size() > 0) {
std::string bias_var_name = instruct.op_info()->Input("Bias").front();
auto bias_weight_t =
scope->FindVar(bias_var_name)->GetMutable<lite::Tensor>();
float* bias_weight_d = bias_weight_t->mutable_data<float>();
Tensor temp_bias;
temp_bias.Resize(bias_weight_t->dims());
int* temp_bias_data = temp_bias.mutable_data<int>();
TransFP32BiasToInt32(bias_weight_d, temp_bias_data, temp_bias.data_size(),
in_scale, weight_scale);
bias_weight_t->CopyDataFrom(temp_bias);
}
instruct.mutable_op_info()->SetAttr("weight_scale", weight_scale);
auto original_selected_kernel = std::move(instruct.kernels().front());
auto updated_op_info = *instruct.mutable_op_info();
instruct.ResetOp(updated_op_info, graph->valid_places());
instruct.kernels().clear();
instruct.kernels().emplace_back(std::move(original_selected_kernel));
for (auto& kernel : instruct.kernels()) {
LOG(INFO) << "kernel info: " << kernel->name();
instruct.op()->AttachKernel(kernel.get());
}
}
}
void TransWeightPass::SetValidPlaces(const std::vector<Place>& valid_places) {
CHECK(!valid_places.empty());
valid_places_ = valid_places;
}
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(trans_weight_pass, paddle::lite::mir::TransWeightPass);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cmath>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/lite/arm/math/saturate.h"
#include "paddle/fluid/lite/core/mir/pass.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace mir {
/*
* IoComplementPass complement the necessary instruction to make data
* transferring or transformation between different places.
*/
class TransWeightPass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
std::vector<float> GetWeightScale(float* in_data, int64_t axis_size,
int64_t inner_size, float scale_factor) {
std::vector<float> scale_out(axis_size);
auto calc_abs_max = [&](float* in, size_t data_size) -> float {
float max_data = 0.0;
for (size_t i = 0; i < data_size; i++) {
if (max_data < std::abs(in[i])) max_data = std::abs(in[i]);
}
return max_data;
};
for (int c = 0; c < axis_size; c++) {
float* part_in = in_data + c * inner_size;
scale_out[c] = calc_abs_max(part_in, inner_size) / scale_factor;
}
return scale_out;
}
void FP32ToInt8(const float* din, int8_t* dout, const float* scale,
int axis_size, int64_t outer_size, int64_t inner_size) {
int loop_size = axis_size * outer_size;
for (int i = 0; i < loop_size; ++i) {
float inv_scale = 1.f / scale[i % axis_size];
for (int j = 0; j < inner_size; ++j) {
dout[j] = static_cast<int8_t>(std::roundf(din[j] * inv_scale));
}
dout += inner_size;
din += inner_size;
}
}
void TransFP32BiasToInt32(const float* din, int* dout, size_t data_size,
float in_scale, std::vector<float> weight_scale) {
CHECK(data_size == weight_scale.size())
<< "Bias data size should be equal toe the weight scale data size.";
for (size_t i = 0; i < data_size; i++) {
dout[i] =
static_cast<int>(std::roundf(din[i] / in_scale / weight_scale[i]));
}
}
void SetValidPlaces(const std::vector<Place>& valid_places);
const std::vector<Place>& valid_places() const { return valid_places_; }
private:
std::vector<Place> valid_places_;
};
} // namespace mir
} // namespace lite
} // namespace paddle
......@@ -49,34 +49,37 @@ class Optimizer {
InitTargetTypeTransformPass();
if (passes.empty()) {
RunPasses(std::vector<std::string>{{
"lite_quant_dequant_fuse_pass", //
"lite_conv_bn_fuse_pass", //
RunPasses(std::vector<std::string>{
{"lite_quant_dequant_fuse_pass", //
"lite_conv_bn_fuse_pass", //
// This pass is disabled to force some opencl kernels selected for final
// running, otherwise, they will be fused to ARM fusion kernels, and the OpenCL
// devices will be discarded.
// TODO(Superjomn) Refine the fusion related design to select fusion kernels for
// devices automatically.
#ifndef LITE_WITH_OPENCL
"lite_conv_elementwise_add_activation_fuse_pass", //
"lite_conv_elementwise_add_activation_fuse_pass", //
#endif
"lite_fc_fuse_pass", //
"identity_scale_eliminate_pass", //
"lite_fc_fuse_pass", //
"identity_scale_eliminate_pass", //
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
#ifndef LITE_WITH_OPENCL
"lite_elementwise_add_activation_fuse_pass", //
"lite_elementwise_add_activation_fuse_pass", //
#endif
#endif
"static_kernel_pick_pass", //
"variable_place_inference_pass", //
"argument_type_display_pass", //
"type_target_transform_pass", //
"variable_place_inference_pass", //
"argument_type_display_pass", //
"io_copy_kernel_pick_pass", //
"variable_place_inference_pass", //
"runtime_context_assign_pass", //
}});
"static_kernel_pick_pass", //
"variable_place_inference_pass", //
"argument_type_display_pass", //
"type_target_transform_pass", //
"variable_place_inference_pass", //
"argument_type_display_pass", //
"io_copy_kernel_pick_pass", //
"variable_place_inference_pass", //
"precision_cast_transform_pass", //
"argument_type_display_pass", //
"trans_weight_pass", //
"runtime_context_assign_pass", //
"graph_visualze"}});
} else {
RunPasses(passes);
}
......@@ -134,7 +137,7 @@ class Optimizer {
for (auto& x : passes) {
LOG(INFO) << "== Running pass " << x;
auto* pass = mir::PassManager::Global().LookUp(x);
CHECK(pass);
CHECK(pass) << "Can not find pass: " << x;
pass->Apply(graph_);
}
}
......
......@@ -31,7 +31,7 @@ lite_cc_test(test_mul_compute_arm SRCS mul_compute_test.cc DEPS mul_compute_arm)
lite_cc_test(test_split_compute_arm SRCS split_compute_test.cc DEPS split_compute_arm)
lite_cc_test(test_concat_compute_arm SRCS concat_compute_test.cc DEPS concat_compute_arm)
lite_cc_test(test_dropout_compute_arm SRCS dropout_compute_test.cc DEPS dropout_compute_arm)
lite_cc_test(test_calib_compute_arm SRCS calib_compute_test.cc DEPS calib_compute_arm)
# lite_cc_test(test_calib_compute_arm SRCS calib_compute_test.cc DEPS calib_compute_arm)
lite_cc_test(test_transpose_compute_arm SRCS transpose_compute_test.cc DEPS transpose_compute_arm)
set(arm_kernels
......@@ -48,6 +48,7 @@ set(arm_kernels
concat_compute_arm
dropout_compute_arm
transpose_compute_arm
calib_compute_arm
)
set(arm_kernels "${arm_kernels}" CACHE INTERNAL "arm kernels")
......@@ -23,26 +23,24 @@ namespace lite {
namespace kernels {
namespace arm {
void CalibCompute::Run() {
void CalibComputeFp32ToInt8::Run() {
auto& param = this->Param<operators::CalibParam>();
std::vector<float> scale = {param.in_scale};
if (param.in_dtype == PRECISION(kFloat) &&
param.out_dtype == PRECISION(kInt8)) {
const auto* din = param.input->data<float>();
auto* dout = param.output->mutable_data<signed char>();
lite::arm::math::fp32_to_int8(din, dout, scale.data(), 1, 1,
param.input->numel());
return;
}
if (param.in_dtype == PRECISION(kInt8) &&
param.out_dtype == PRECISION(kFloat)) {
const auto* din = param.input->data<signed char>();
auto* dout = param.output->mutable_data<float>();
lite::arm::math::int8_to_fp32(din, dout, scale.data(), 1, 1,
param.input->numel());
return;
}
LOG(FATAL) << "Unsupport Dtype.";
std::vector<float> scale = {param.scale};
const auto* din = param.input->data<float>();
auto* dout = param.output->mutable_data<signed char>();
lite::arm::math::fp32_to_int8(din, dout, scale.data(), 1, 1,
param.input->numel());
return;
}
void CalibComputeInt8ToFp32::Run() {
auto& param = this->Param<operators::CalibParam>();
const auto* din = param.input->data<signed char>();
std::vector<float> scale = {param.scale};
auto* dout = param.output->mutable_data<float>();
lite::arm::math::int8_to_fp32(din, dout, scale.data(), 1, 1,
param.input->numel());
return;
}
} // namespace arm
......@@ -51,7 +49,16 @@ void CalibCompute::Run() {
} // namespace paddle
REGISTER_LITE_KERNEL(calib, kARM, kInt8, kNCHW,
paddle::lite::kernels::arm::CalibCompute, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
paddle::lite::kernels::arm::CalibComputeFp32ToInt8,
fp32_to_int8)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
.Finalize();
REGISTER_LITE_KERNEL(calib, kARM, kInt8, kNCHW,
paddle::lite::kernels::arm::CalibComputeInt8ToFp32,
int8_to_fp32)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
.Finalize();
......@@ -21,13 +21,26 @@ namespace lite {
namespace kernels {
namespace arm {
class CalibCompute : public KernelLite<TARGET(kARM), PRECISION(kInt8)> {
class CalibComputeFp32ToInt8
: public KernelLite<TARGET(kARM), PRECISION(kInt8)> {
public:
using param_t = operators::CalibParam;
void Run() override;
~CalibCompute() override{};
~CalibComputeFp32ToInt8() override{};
private:
};
class CalibComputeInt8ToFp32
: public KernelLite<TARGET(kARM), PRECISION(kInt8)> {
public:
using param_t = operators::CalibParam;
void Run() override;
~CalibComputeInt8ToFp32() override{};
private:
};
......
......@@ -146,4 +146,5 @@ TEST(calib_arm, int8_to_fp32) {
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(calib, kARM, kInt8, kNCHW, def);
USE_LITE_KERNEL(calib, kARM, kInt8, kNCHW, int8_to_fp32);
USE_LITE_KERNEL(calib, kARM, kInt8, kNCHW, fp32_to_int8);
......@@ -123,13 +123,16 @@ void ConvComputeInt8<Ptype_out>::PrepareForRun() {
// weigth is int8 and bias is int32 so do not need trans
if (param.groups == ic && ic == oc && kps_equal && no_dilation && flag_dw) {
impl_ = new lite::arm::math::DepthwiseConvInt8<Ptype_out>;
VLOG(3) << "DepthwiseConv Int8";
// impl_ = new lite::arm::math::DepthwiseConvInt8<Ptype_out>;
impl_ = new lite::arm::math::GemmLikeConvInt8<Ptype_out>;
VLOG(3) << "Run DepthwiseConv Int8";
} else if (param.groups == 1 && kw == 3 && (sw == 1 || sw == 2) &&
kps_equal && no_dilation) {
impl_ = new lite::arm::math::DirectConvInt8<Ptype_out>;
VLOG(3) << "Run DirectConv Int8";
impl_ = new lite::arm::math::GemmLikeConvInt8<Ptype_out>;
// impl_ = new lite::arm::math::DirectConvInt8<Ptype_out>;
} else {
VLOG(3) << "GemmLikeConvInt8";
VLOG(3) << "Run GemmLikeConvInt8";
impl_ = new lite::arm::math::GemmLikeConvInt8<Ptype_out>;
}
......@@ -189,3 +192,25 @@ REGISTER_LITE_KERNEL(
.BindOutput("Output",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
.Finalize();
REGISTER_LITE_KERNEL(
depthwise_conv2d, kARM, kInt8, kNCHW,
paddle::lite::kernels::arm::ConvComputeInt8<PRECISION(kInt8)>, int8_out)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("Filter",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
.BindOutput("Output",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
.Finalize();
REGISTER_LITE_KERNEL(
depthwise_conv2d, kARM, kInt8, kNCHW,
paddle::lite::kernels::arm::ConvComputeInt8<PRECISION(kFloat)>, fp32_out)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("Filter",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
.BindOutput("Output",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
.Finalize();
......@@ -14,9 +14,13 @@
#include "paddle/fluid/lite/kernels/arm/fc_compute.h"
#include <vector>
#include "paddle/fluid/lite/api/paddle_place.h"
#include "paddle/fluid/lite/arm/math/funcs.h"
#include "paddle/fluid/lite/arm/math/gemm_prepacked_int8.h"
#include "paddle/fluid/lite/arm/math/gemv_arm_int8.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/type_system.h"
namespace paddle {
namespace lite {
namespace kernels {
......@@ -71,8 +75,8 @@ void FcCompute::Run() {
auto& ctx = this->ctx_->template As<ARMContext>();
if (m_ > 1) {
float* packed_in = static_cast<float*>(ctx.workspace_data<float>()) +
ctx.l2_cache_size() / sizeof(float);
float* packed_in =
ctx.workspace_data<float>() + ctx.l2_cache_size() / sizeof(float);
lite::arm::math::prepackA(packed_in, i_data, k_, 0, m_, 0, k_, false, &ctx);
lite::arm::math::sgemm_prepack(packed_in, w_data, b_data, o_data, m_, n_,
k_, false, false, false, &ctx);
......@@ -89,6 +93,97 @@ void FcCompute::Run() {
}
}
template <PrecisionType Ptype_out>
void FcComputeInt8<Ptype_out>::PrepareForRun() {
auto& param = this->Param<operators::FcParam>();
auto x_dims = param.input->dims();
auto w_dims = param.w->dims();
auto& ctx = this->ctx_->template As<ARMContext>();
if (!tmp_int32_out_) {
tmp_int32_out_ = new Tensor;
tmp_int32_out_->Resize(param.output->dims());
}
CHECK_GE(x_dims.size(), 2UL);
CHECK_EQ(w_dims.size(), 2UL);
CHECK_EQ(param.output->dims().size(), 2UL);
this->m_ = x_dims.Slice(0, param.in_num_col_dims).production();
this->k_ = x_dims.Slice(param.in_num_col_dims, x_dims.size()).production();
this->n_ = w_dims[1];
CHECK_EQ(k_, static_cast<int>(w_dims[0]));
if (this->m_ == 1) {
if (!this->transed_weight_) {
this->transed_weight_ = new Tensor;
}
this->transed_weight_->Resize({this->n_, this->k_});
const auto* w_data = param.w->template data<int8_t>();
auto* t_data = this->transed_weight_->template mutable_data<int8_t>();
int i = 0;
for (int nn = 0; nn < this->n_; ++nn) {
for (int kk = 0; kk < this->k_; ++kk) {
t_data[i++] = w_data[kk * this->n_ + nn];
}
}
}
if (this->m_ > 1) {
int hblock = lite::arm::math::get_hblock(ctx.arch());
int m_round = hblock * ((this->m_ + hblock - 1) / hblock);
ctx.ExtendWorkspace(DDimLite(std::vector<int64_t>({m_round * this->k_})));
}
}
template <PrecisionType Ptype_out>
void FcComputeInt8<Ptype_out>::Run() {
auto& param = this->Param<operators::FcParam>();
const auto* i_data = param.input->template data<int8_t>();
const auto* w_data = param.w->template data<int8_t>();
const auto* b_data = param.bias ? param.bias->template data<int>() : nullptr;
int* o_data = nullptr;
auto& ctx = this->ctx_->template As<ARMContext>();
o_data = this->tmp_int32_out_->template mutable_data<int>();
if (m_ > 1) {
int8_t* packed_in =
static_cast<int8_t*>(ctx.template workspace_data<int8_t>()) +
ctx.l2_cache_size() / sizeof(int8_t);
lite::arm::math::prepackA_int8(packed_in, i_data, k_, 0, m_, 0, k_, false);
lite::arm::math::gemm_prepack_int8(packed_in, w_data, b_data, o_data, m_,
n_, k_, false, false, false, nullptr,
&ctx);
if (param.bias) {
CHECK_EQ(param.bias->numel(), n_);
lite::arm::math::fill_bias_fc(o_data, b_data, m_, n_);
}
} else {
CHECK(transed_weight_);
const auto* t_data = transed_weight_->template data<int8_t>();
lite::arm::math::gemv_int8(t_data, i_data, o_data, false, n_, k_, nullptr,
b_data != nullptr, b_data, false);
}
float i_scale = param.input_scale;
std::vector<float> weight_scale = param.weight_scale;
if (Ptype_out == PRECISION(kInt8)) {
float o_scale = param.output_scale;
param.output->template mutable_data<int8_t>();
lite::arm::math::trans_tensor_dtype<PRECISION(kInt32), PRECISION(kInt8)>(
tmp_int32_out_, param.output, i_scale, o_scale, weight_scale);
} else if (Ptype_out == PRECISION(kFloat)) {
param.output->template mutable_data<float>();
lite::arm::math::trans_tensor_dtype<PRECISION(kInt32), PRECISION(kFloat)>(
tmp_int32_out_, param.output, i_scale, 1.f, weight_scale);
} else {
LOG(ERROR) << "unsupported precision type!!";
}
}
} // namespace arm
} // namespace kernels
} // namespace lite
......@@ -101,3 +196,21 @@ REGISTER_LITE_KERNEL(fc, kARM, kFloat, kNCHW,
.BindInput("W", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
REGISTER_LITE_KERNEL(
fc, kARM, kInt8, kNCHW,
paddle::lite::kernels::arm::FcComputeInt8<PRECISION(kInt8)>, int8out)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("W", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
.Finalize();
REGISTER_LITE_KERNEL(
fc, kARM, kInt8, kNCHW,
paddle::lite::kernels::arm::FcComputeInt8<PRECISION(kFloat)>, fp32out)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("W", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
.Finalize();
......@@ -13,6 +13,8 @@
// limitations under the License.
#pragma once
#include <stdint.h>
#include "paddle/fluid/lite/arm/math/type_trans.h"
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/operators/fc_op.h"
......@@ -40,6 +42,27 @@ class FcCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
int m_, n_, k_;
};
template <PrecisionType Ptype_out>
class FcComputeInt8 : public KernelLite<TARGET(kARM), PRECISION(kInt8)> {
public:
using param_t = operators::FcParam;
void PrepareForRun() override;
void Run() override;
~FcComputeInt8() override {
if (transed_weight_) {
delete transed_weight_;
}
};
private:
lite::Tensor* transed_weight_{nullptr};
Tensor* tmp_int32_out_{nullptr};
int m_, n_, k_;
};
} // namespace arm
} // namespace kernels
} // namespace lite
......
......@@ -37,12 +37,8 @@ bool CalibOpLite::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
param_.input = const_cast<lite::Tensor *>(&(x_var->Get<lite::Tensor>()));
param_.output = output_var->GetMutable<lite::Tensor>();
std::vector<std::string> input_arg_names = opdesc.InputArgumentNames();
param_.in_dtype =
static_cast<lite::PrecisionType>(opdesc.GetAttr<int>("in_dtype"));
param_.out_dtype =
static_cast<lite::PrecisionType>(opdesc.GetAttr<int>("out_dtype"));
if (opdesc.HasAttr("in_scale")) {
param_.in_scale = opdesc.GetAttr<float>("in_scale");
if (opdesc.HasAttr("scale")) {
param_.scale = opdesc.GetAttr<float>("scale");
}
CHECK(param_.input) << "Input(X) of CalibOp should not be null.";
CHECK(param_.output) << "Output(Out) of CalibOp should not be null.";
......
......@@ -11,7 +11,6 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/operators/calib_op.h"
#include <gtest/gtest.h>
#include "paddle/fluid/lite/core/op_registry.h"
......@@ -42,9 +41,7 @@ TEST(calib_op_lite, TestARM) {
desc.SetType("calib");
desc.SetInput("Input", {"Input"});
desc.SetOutput("Out", {"output"});
desc.SetAttr("in_dtype", static_cast<int>(PRECISION(kInt8)));
desc.SetAttr("out_dtype", static_cast<int>(PRECISION(kFloat)));
desc.SetAttr("in_scale", 10.0f);
desc.SetAttr("scale", 10.0f);
CalibOpLite calib("calib");
......@@ -60,5 +57,6 @@ TEST(calib_op_lite, TestARM) {
} // namespace paddle
#ifdef LITE_WITH_ARM
USE_LITE_KERNEL(calib, kARM, kInt8, kNCHW, def);
USE_LITE_KERNEL(calib, kARM, kInt8, kNCHW, fp32_to_int8);
USE_LITE_KERNEL(calib, kARM, kInt8, kNCHW, int8_to_fp32);
#endif
......@@ -76,6 +76,17 @@ class ConvOpLite : public OpLite {
}
}
param_.fuse_relu = op_desc.GetAttr<bool>("fuse_relu");
// For Int8
if (op_desc.HasAttr("enable_int8")) {
param_.enable_int8 = op_desc.GetAttr<bool>("enable_int8");
if (op_desc.HasAttr("input_scale"))
param_.input_scale = op_desc.GetAttr<float>("input_scale");
if (op_desc.HasAttr("weight_scale"))
param_.weight_scale =
op_desc.GetAttr<std::vector<float>>("weight_scale");
if (op_desc.HasAttr("output_scale"))
param_.output_scale = op_desc.GetAttr<float>("output_scale");
}
return true;
}
......
......@@ -59,6 +59,17 @@ class FcOpLite : public OpLite {
param_.output = scope->FindVar(out)->GetMutable<lite::Tensor>();
param_.in_num_col_dims = op_desc.GetAttr<int>("in_num_col_dims");
// For Int8
if (op_desc.HasAttr("enable_int8")) {
param_.enable_int8 = op_desc.GetAttr<bool>("enable_int8");
if (op_desc.HasAttr("input_scale"))
param_.input_scale = op_desc.GetAttr<float>("input_scale");
if (op_desc.HasAttr("weight_scale"))
param_.weight_scale =
op_desc.GetAttr<std::vector<float>>("weight_scale");
if (op_desc.HasAttr("output_scale"))
param_.output_scale = op_desc.GetAttr<float>("output_scale");
}
return true;
}
......
......@@ -19,11 +19,6 @@
#include "paddle/fluid/lite/core/framework.pb.h"
#include "paddle/fluid/lite/utils/all.h"
#define WITH_INT8_CONFIG \
bool enable_int8; \
float input_scale; \
std::vector<float> weight_scale{}; \
float output_scale;
/*
* This file contains all the argument parameter data structure for operators.
*/
......@@ -33,6 +28,11 @@ namespace lite {
namespace operators {
using param_t = Any;
#define WITH_INT8_CONFIG \
bool enable_int8{false}; \
float input_scale{1.0}; \
std::vector<float> weight_scale{}; \
float output_scale{1.0};
/// ----------------------- Functional operators ------------------------------
struct FeedParam {
......@@ -56,9 +56,7 @@ struct IoCopyParam {
struct CalibParam {
const lite::Tensor* input{};
lite::Tensor* output{};
float in_scale;
PrecisionType in_dtype;
PrecisionType out_dtype;
float scale;
};
/// -------------------------- NN operators ------------------------------------
......@@ -71,6 +69,8 @@ struct FcParam {
lite::DDim in_mat_dims;
int in_num_col_dims{1};
bool weight_transposed{false};
// for int8
WITH_INT8_CONFIG
};
// For Mul Op
......@@ -81,6 +81,8 @@ struct MulParam {
int x_num_col_dims{1};
int y_num_col_dims{1};
// for int8
WITH_INT8_CONFIG
};
struct MulGradParam {
......@@ -152,6 +154,7 @@ struct ConvParam {
float scale_weights{1.0f}; // only used with mkl-dnn int8
bool force_fp32_output{false}; // only used in mkl-dnn int8
std::string data_format{"Anylayout"};
// for int8
WITH_INT8_CONFIG
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册