diff --git a/paddle/fluid/lite/api/cxx_api_bin.cc b/paddle/fluid/lite/api/cxx_api_bin.cc
index dec0b65eb2791b45bdf3fa54715af97a844342fc..6e78d2012b2e8857286e9a42e38dbbaacb4f3935 100644
--- a/paddle/fluid/lite/api/cxx_api_bin.cc
+++ b/paddle/fluid/lite/api/cxx_api_bin.cc
@@ -86,6 +86,8 @@ USE_LITE_OP(depthwise_conv2d);
 USE_LITE_OP(pool2d);
 USE_LITE_OP(elementwise_add);
 USE_LITE_OP(softmax);
+USE_LITE_OP(fake_quantize_moving_average_abs_max);
+USE_LITE_OP(fake_dequantize_max_abs);
 USE_LITE_KERNEL(feed, kHost, kAny, kAny, def);
 USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def);
diff --git a/paddle/fluid/lite/core/mir/CMakeLists.txt b/paddle/fluid/lite/core/mir/CMakeLists.txt
index fe7defcf73e6bea6819c62ae36c87b59eb4f09b2..412c23324cf2a2ca5b04cf21fecd8a380af0d393 100644
--- a/paddle/fluid/lite/core/mir/CMakeLists.txt
+++ b/paddle/fluid/lite/core/mir/CMakeLists.txt
@@ -9,6 +9,7 @@ cc_library(mir_passes
     fc_fuse_pass.cc
     conv_elementwise_add_relu_fuse_pass.cc
     conv_bn_fuse_pass.cc
+    quant_dequant_fuse_pass.cc
     static_kernel_pick_pass.cc
     variable_place_inference_pass.cc
     type_target_transform_pass.cc
diff --git a/paddle/fluid/lite/core/mir/fusion/CMakeLists.txt b/paddle/fluid/lite/core/mir/fusion/CMakeLists.txt
index fbc7ffe730bca1e2d1c5c9fa48e81bc3b98de45c..2bf9296eb0ea37d999bdcb7fd55fd1b93439f668 100644
--- a/paddle/fluid/lite/core/mir/fusion/CMakeLists.txt
+++ b/paddle/fluid/lite/core/mir/fusion/CMakeLists.txt
@@ -8,10 +8,15 @@ cc_library(fuse_conv_bn
     SRCS conv_bn_fuser.cc
     DEPS pattern_matcher_high_api)
 
+cc_library(fuse_quant_dequant
+    SRCS quant_dequant_op_fuser.cc
+    DEPS pattern_matcher_high_api)
+
 set(mir_fusers
     fuse_fc
     fuse_conv_elementwise_add_relu
     fuse_conv_bn
+    fuse_quant_dequant
     CACHE INTERNAL "fusers")
 
 if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
diff --git a/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.cc b/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.cc
index 421c920e6214756a823622b4f24dfb651d63951b..889586a3bc6bc980a19082046f189b25422b1ed2 100644
--- a/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.cc
+++ b/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.cc
@@ -79,7 +79,7 @@ void ConvElementwiseAddReLUFuser::InsertNewNode(SSAGraph* graph,
 
 cpp::OpDesc ConvElementwiseAddReLUFuser::GenOpDesc(const key2nodes_t& matched) {
   auto* desc = matched.at("conv2d")->stmt()->op_info();
-  cpp::OpDesc op_desc;
+  cpp::OpDesc op_desc = *desc;
   op_desc.SetType(conv_type_);
   op_desc.SetInput("Input", {matched.at("input")->arg()->name});
   op_desc.SetInput("Filter", {matched.at("filter")->arg()->name});
@@ -92,7 +92,6 @@ cpp::OpDesc ConvElementwiseAddReLUFuser::GenOpDesc(const key2nodes_t& matched) {
                 "ResidualData") != input_arg_names.end()) {
     op_desc.SetInput("ResidualData", desc->Input("ResidualData"));
   }
-  // Only consider strides, padding, groups, dilations, fuse_relu for now
   op_desc.SetAttr("strides", desc->GetAttr<std::vector<int>>("strides"));
   op_desc.SetAttr("paddings", desc->GetAttr<std::vector<int>>("paddings"));
diff --git a/paddle/fluid/lite/core/mir/fusion/quant_dequant_op_fuser.cc b/paddle/fluid/lite/core/mir/fusion/quant_dequant_op_fuser.cc
new file mode 100644
index 0000000000000000000000000000000000000000..a767d277e76890d9d5fa5f837779d9ce14bb41a1
--- /dev/null
+++ b/paddle/fluid/lite/core/mir/fusion/quant_dequant_op_fuser.cc
@@ -0,0 +1,177 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/lite/core/mir/fusion/quant_dequant_op_fuser.h"
+#include <string>
+#include <vector>
+
+namespace paddle {
+namespace lite {
+namespace mir {
+namespace fusion {
+
+void QuantDequantOpFuser::BuildPattern() {
+  const int kNumFields = 5;
+  const int kQuantizedWeightOffset = 0;
+  const int kQuantizedOpOffset = 1;
+  const int kQuantizedOpOutOffset = 2;
+  const int kDequantOpOffset = 3;
+  const int kDequantOpOutOffset = 4;
+
+  std::string weight_name = "";
+  if (op_type_ == "conv2d" || op_type_ == "depthwise_conv2d") {
+    weight_name = "Filter";
+  } else {
+    weight_name = "Y";
+  }
+  auto* quant_op_input = VarNode("quant_op_input")
+                             ->assert_is_op_input(quant_type_, "X")
+                             ->AsInput();
+  auto* quant_op_in_scale = VarNode("quant_op_in_scale")
+                                ->assert_is_op_input(quant_type_, "InScale")
+                                ->AsIntermediate();
+  auto* quant_op = OpNode("quant_op", quant_type_)
+                       ->assert_is_op(quant_type_)
+                       ->AsIntermediate();
+
+  auto* quant_op_out_scale =
+      VarNode("quant_op_out_scale")
+          ->assert_is_op_output(quant_type_, "OutScale")
+          ->assert_is_op_input("fake_dequantize_max_abs", "Scale")
+          ->AsIntermediate();
+
+  auto* quant_op_out = VarNode("quant_op_out")
+                           ->assert_is_op_output(quant_type_, "Out")
+                           ->assert_is_op_input(op_type_)
+                           ->AsIntermediate();
+  std::vector<PMNode*> nodes;
+  for (int i = 0; i < times_; i++) {
+    nodes.push_back(VarNode("quantized_op_weight" + std::to_string(i))
+                        ->assert_is_op_input(op_type_, weight_name)
+                        ->AsInput());
+
+    nodes.push_back(OpNode("quantized_op" + std::to_string(i), op_type_)
+                        ->assert_is_op(op_type_)
+                        ->AsIntermediate());
+
+    nodes.push_back(VarNode("quantized_op_out" + std::to_string(i))
+                        ->assert_is_op_output(op_type_)
+                        ->assert_is_op_input("fake_dequantize_max_abs", "X")
+                        ->AsIntermediate());
+
+    nodes.push_back(
+        OpNode("dequant_op" + std::to_string(i), "fake_dequantize_max_abs")
+            ->assert_is_op("fake_dequantize_max_abs")
+            ->AsIntermediate());
+    nodes.push_back(VarNode("dequant_op_out" + std::to_string(i))
+                        ->assert_is_op_output("fake_dequantize_max_abs", "Out")
+                        ->AsOutput());
+  }
+
+  quant_op->LinksFrom({quant_op_input, quant_op_in_scale});
+  quant_op_out->LinksFrom({quant_op});
+  quant_op_out_scale->LinksFrom({quant_op});
+  for (int i = 0; i < times_; i++) {
+    nodes[i * kNumFields + kQuantizedOpOffset]->LinksFrom(
+        {quant_op_out, nodes[i * kNumFields + kQuantizedWeightOffset]});
+    nodes[i * kNumFields + kQuantizedOpOutOffset]->LinksFrom(
+        {nodes[i * kNumFields + kQuantizedOpOffset]});
+    nodes[i * kNumFields + kDequantOpOffset]->LinksFrom(
+        {nodes[i * kNumFields + kQuantizedOpOutOffset], quant_op_out_scale});
+    nodes[i * kNumFields + kDequantOpOutOffset]->LinksFrom(
+        {nodes[i * kNumFields + kDequantOpOffset]});
+  }
+}
+
+void QuantDequantOpFuser::InsertNewNode(SSAGraph* graph,
+                                        const key2nodes_t& matched) {
+  const int kNumFields = 5;
+  const int kQuantizedWeightOffset = 0;
+  const int kQuantizedOpOffset = 1;
+  const int kQuantizedOpOutOffset = 2;
+  const int kDequantOpOffset = 3;
+  const int kDequantOpOutOffset = 4;
+
+  auto* quant_op_input = matched.at("quant_op_input");
+  auto* quant_op_in_scale = matched.at("quant_op_in_scale");
+  auto* quant_op = matched.at("quant_op");
+  auto* quant_op_out_scale = matched.at("quant_op_out_scale");
+  auto* quant_op_out = matched.at("quant_op_out");
+
+  std::vector<Node*> nodes;
+  for (int i = 0; i < times_; i++) {
+    nodes.push_back(matched.at("quantized_op_weight" + std::to_string(i)));
+    nodes.push_back(matched.at("quantized_op" + std::to_string(i)));
+    nodes.push_back(matched.at("quantized_op_out" + std::to_string(i)));
+    nodes.push_back(matched.at("dequant_op" + std::to_string(i)));
+    nodes.push_back(matched.at("dequant_op_out" + std::to_string(i)));
+  }
+  int bit_length = quant_op->stmt()->op_info()->GetAttr<int>("bit_length");
+  auto* scope = quant_op->stmt()->op->scope();
+  auto& valid_places = quant_op->stmt()->op->valid_places();
+  int range = ((1 << (bit_length - 1)) - 1);
+  auto input_scale_t = scope->FindVar(quant_op_in_scale->arg()->name)
+                           ->GetMutable<lite::Tensor>();
+  float input_scale = input_scale_t->data<float>()[0];
+
+  for (int i = 0; i < times_; i++) {
+    float max_range = nodes[i * kNumFields + kDequantOpOffset]
+                          ->stmt()
+                          ->op_info()
+                          ->GetAttr<float>("max_range");
+    float weight_scale = (range * range) / max_range;
+
+    cpp::OpDesc op_desc =
+        *nodes[i * kNumFields + kQuantizedOpOffset]->stmt()->op_info();
+    if (op_type_ == "conv2d" || op_type_ == "depthwise_conv2d") {
+      op_desc.SetInput("Input", {matched.at("quant_op_input")->arg()->name});
+      op_desc.SetOutput(
+          "Output", {nodes[i * kNumFields + kDequantOpOutOffset]->arg()->name});
+    } else if (op_type_ == "mul") {
+      op_desc.SetInput("X", {matched.at("quant_op_input")->arg()->name});
+      op_desc.SetOutput(
+          "Out", {nodes[i * kNumFields + kDequantOpOutOffset]->arg()->name});
+    }
+    op_desc.SetAttr("enable_int8", true);
+    op_desc.SetAttr("input_scale", input_scale);
+    auto quantized_weight_var_name =
+        nodes[i * kNumFields + kQuantizedWeightOffset]->arg()->name;
+    auto quantized_weight_t =
+        scope->FindVar(quantized_weight_var_name)->GetMutable<lite::Tensor>();
+    float* quantized_weight_data = quantized_weight_t->mutable_data<float>();
+    size_t weight_num = quantized_weight_t->data_size();
+    for (size_t j = 0; j < weight_num; j++) {
+      quantized_weight_data[j] *= (weight_scale / range);
+    }
+    auto quantized_op = LiteOpRegistry::Global().Create(op_type_);
+
+    quantized_op->Attach(op_desc, scope);
+    auto* new_op_node =
+        graph->GraphCreateInstructNode(quantized_op, valid_places);
+    IR_NODE_LINK_TO(quant_op_input, new_op_node);
+    IR_NODE_LINK_TO(nodes[i * kNumFields + kQuantizedWeightOffset],
+                    new_op_node);
+    IR_NODE_LINK_TO(new_op_node, nodes[i * kNumFields + kDequantOpOutOffset]);
+  }
+}
+
+cpp::OpDesc QuantDequantOpFuser::GenOpDesc(const key2nodes_t& matched) {
+  cpp::OpDesc op_desc;
+  return op_desc;
+}
+
+}  // namespace fusion
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
diff --git a/paddle/fluid/lite/core/mir/fusion/quant_dequant_op_fuser.h b/paddle/fluid/lite/core/mir/fusion/quant_dequant_op_fuser.h
new file mode 100644
index 0000000000000000000000000000000000000000..be084eaf804a65781e13a44879c9bcd88a1363db
--- /dev/null
+++ b/paddle/fluid/lite/core/mir/fusion/quant_dequant_op_fuser.h
@@ -0,0 +1,46 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include "paddle/fluid/lite/core/mir/pattern_matcher_high_api.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+namespace fusion {
+
+class QuantDequantOpFuser : public FuseBase {
+ public:
+  explicit QuantDequantOpFuser(const std::string& op_type,
+                               const std::string& quant_type, int times)
+      : op_type_(op_type), quant_type_(quant_type), times_(times) {}
+  void BuildPattern() override;
+  void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
+
+ private:
+  cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
+
+ private:
+  std::string op_type_{"conv2d"};
+  std::string quant_type_;
+  int times_;
+};
+
+}  // namespace fusion
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
diff --git a/paddle/fluid/lite/core/mir/passes.h b/paddle/fluid/lite/core/mir/passes.h
index a6abb16e3eaabe6a0f12b75248f3db1f7cfeeb81..c3226819698ecf5644981796579c0fad99439c08 100644
--- a/paddle/fluid/lite/core/mir/passes.h
+++ b/paddle/fluid/lite/core/mir/passes.h
@@ -22,6 +22,7 @@ namespace mir {}  // namespace mir
 }  // namespace paddle
 
 #ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
+#endif
 USE_MIR_PASS(demo);
 USE_MIR_PASS(static_kernel_pick_pass);
 USE_MIR_PASS(variable_place_inference_pass);
@@ -29,9 +30,9 @@ USE_MIR_PASS(type_target_transform_pass);
 USE_MIR_PASS(generate_program_pass);
 USE_MIR_PASS(io_copy_kernel_pick_pass);
 USE_MIR_PASS(argument_type_display_pass);
-#endif
 USE_MIR_PASS(runtime_context_assign_pass);
 USE_MIR_PASS(lite_conv_bn_fuse_pass);
 USE_MIR_PASS(graph_visualze);
 USE_MIR_PASS(lite_fc_fuse_pass);
 USE_MIR_PASS(lite_conv_elementwise_add_act_fuse_pass);
+USE_MIR_PASS(lite_quant_dequant_fuse_pass);
diff --git a/paddle/fluid/lite/core/mir/pattern_matcher.cc b/paddle/fluid/lite/core/mir/pattern_matcher.cc
index bff313432f50b936f15c63b44c3e130460384317..3cda96c307c29391235c8e14e68d67497aadab2d 100644
--- a/paddle/fluid/lite/core/mir/pattern_matcher.cc
+++ b/paddle/fluid/lite/core/mir/pattern_matcher.cc
@@ -115,7 +115,6 @@ void PatternMatcher::operator()(SSAGraph *graph,
 bool PatternMatcher::MarkPMNodesInGraph(SSAGraph *graph) {
   VLOG(3) << "mark pmnodes in graph";
   if (graph->nodes().empty()) return false;
-
   for (auto &node : graph->mutable_nodes()) {
     for (const auto &pmnode : pattern_.nodes()) {
       if (pmnode->Tell(&node)) {
@@ -398,7 +397,7 @@ PMNode *PMNode::assert_is_op_output(const std::string &op_type) {
   asserts_.emplace_back([=](const Node *x) {
     for (auto *op : x->inlinks) {
       if (op && op->IsStmt()) {
-        auto *op_info = x->stmt()->op_info();
+        auto *op_info = op->stmt()->op_info();
         if (op_info->Type() == op_type) return true;
       }
     }
diff --git a/paddle/fluid/lite/core/mir/quant_dequant_fuse_pass.cc b/paddle/fluid/lite/core/mir/quant_dequant_fuse_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..deb06ae6d7d43b84f7a8e1f66331ef87307bc9d7
--- /dev/null
+++ b/paddle/fluid/lite/core/mir/quant_dequant_fuse_pass.cc
@@ -0,0 +1,45 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/lite/core/mir/quant_dequant_fuse_pass.h"
+#include <memory>
+#include <unordered_set>
+#include "paddle/fluid/lite/core/mir/fusion/quant_dequant_op_fuser.h"
+#include "paddle/fluid/lite/core/mir/pass_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+void QuantDequantFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
+  std::unordered_set<std::string> quant_types = {
+      "fake_quantize_range_abs_max", "fake_quantize_moving_average_abs_max"};
+  std::unordered_set<std::string> quantized_op_types = {"conv2d", "mul",
+                                                        "depthwise_conv2d"};
+  for (auto& quant_type : quant_types) {
+    for (auto& op_type : quantized_op_types) {
+      for (int i = 6; i >= 1; i--) {
+        fusion::QuantDequantOpFuser fuser(op_type, quant_type, i);
+        fuser(graph.get());
+      }
+    }
+  }
+}
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_MIR_PASS(lite_quant_dequant_fuse_pass,
+                  paddle::lite::mir::QuantDequantFusePass);
diff --git a/paddle/fluid/lite/core/mir/quant_dequant_fuse_pass.h b/paddle/fluid/lite/core/mir/quant_dequant_fuse_pass.h
new file mode 100644
index 0000000000000000000000000000000000000000..5cd38de51de0184bdb7e56abf811ba51d78bae06
--- /dev/null
+++ b/paddle/fluid/lite/core/mir/quant_dequant_fuse_pass.h
@@ -0,0 +1,33 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <unordered_set>
+#include "paddle/fluid/lite/core/mir/pass.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+class QuantDequantFusePass : public ProgramPass {
+ public:
+  void Apply(const std::unique_ptr<SSAGraph>& graph) override;
+};
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
diff --git a/paddle/fluid/lite/core/optimizer.h b/paddle/fluid/lite/core/optimizer.h
index a3e0641b1c7a44809e2a8fdc1b34a49772f71085..3424024f14bd1909421782cbc80abab495260c7f 100644
--- a/paddle/fluid/lite/core/optimizer.h
+++ b/paddle/fluid/lite/core/optimizer.h
@@ -48,19 +48,19 @@ class Optimizer {
     if (passes.empty()) {
       RunPasses(std::vector<std::string>{{
+          "lite_quant_dequant_fuse_pass",             //
           "lite_conv_bn_fuse_pass",                   //
           "lite_conv_elementwise_add_act_fuse_pass",  //
           "lite_fc_fuse_pass",                        //
+          "static_kernel_pick_pass",                  //
+          "variable_place_inference_pass",            //
+          "argument_type_display_pass",               //
+          "type_target_transform_pass",               //
+          "variable_place_inference_pass",            //
+          "argument_type_display_pass",               //
+          "io_copy_kernel_pick_pass",                 //
+          "variable_place_inference_pass",            //
 #ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
-          "static_kernel_pick_pass",        //
-          "variable_place_inference_pass",  //
-          "argument_type_display_pass",     //
-          "type_target_transform_pass",     //
-          "argument_type_display_pass",     //
-          "variable_place_inference_pass",  //
-          "argument_type_display_pass",     //
-          "io_copy_kernel_pick_pass",       //
-          "variable_place_inference_pass",  //
 #endif
           "runtime_context_assign_pass",  //
       }});
diff --git a/paddle/fluid/lite/core/target_wrapper.h b/paddle/fluid/lite/core/target_wrapper.h
index 1029bf5300e6782762f5cc235bea53ff66e953a0..c4a870ab83f0c61fc4a5116f8c3dd379e8ead9db 100644
--- a/paddle/fluid/lite/core/target_wrapper.h
+++ b/paddle/fluid/lite/core/target_wrapper.h
@@ -55,8 +55,8 @@ enum class DataLayoutType : int {
 #define DATALAYOUT(item__) paddle::lite::DataLayoutType::item__
 
 static const std::string& TargetToStr(TargetType target) {
-  static const std::string target2string[] = {"unk", "host", "x86", "cuda",
-                                              "any"};
+  static const std::string target2string[] = {"unk", "host", "x86",
+                                              "cuda", "arm", "any"};
   auto x = static_cast<int>(target);
   CHECK_LT(x, static_cast<int>(TARGET(NUM)));
   return target2string[x];
diff --git a/paddle/fluid/lite/core/type_system.h b/paddle/fluid/lite/core/type_system.h
index 46d17e3c33e9058288f9a73649cb88ea8c3ed868..8bd1f3739498367fd47dfcceceee7b345c9499b8 100644
--- a/paddle/fluid/lite/core/type_system.h
+++ b/paddle/fluid/lite/core/type_system.h
@@ -165,8 +165,8 @@ class Type : public DataType {
 // -------------------------------- compatible check ---------------------------
 static bool TargetCompatibleTo(const Type& a, const Type& b) {
-  auto is_host = [](TargetType x) {
-    return x == TARGET(kHost) || x == TARGET(kX86);
+  auto is_host = [](TargetType x) -> bool {
+    return x == TARGET(kHost) || x == TARGET(kX86) || x == TARGET(kARM);
   };
   if (a.IsVoid() || b.IsVoid()) return true;
   if (a.IsTensor() || b.IsTensor()) {
diff --git a/paddle/fluid/lite/kernels/arm/conv_compute.cc b/paddle/fluid/lite/kernels/arm/conv_compute.cc
index a8a2ac790a3c045642277ef75367bbdd878f0d6d..5e9ddb6271684120c8cab68e6e10bade3a3ab015 100644
--- a/paddle/fluid/lite/kernels/arm/conv_compute.cc
+++ b/paddle/fluid/lite/kernels/arm/conv_compute.cc
@@ -100,7 +100,7 @@ void ConvCompute::Run() {
 REGISTER_LITE_KERNEL(conv2d, kARM, kFloat, kNCHW,
                      paddle::lite::kernels::arm::ConvCompute, def)
     .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))})
-    // .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))})
     .BindInput("Filter", {LiteType::GetTensorTy(TARGET(kARM))})
     .BindOutput("Output", {LiteType::GetTensorTy(TARGET(kARM))})
     .Finalize();
@@ -108,7 +108,7 @@ REGISTER_LITE_KERNEL(conv2d, kARM, kFloat, kNCHW,
 REGISTER_LITE_KERNEL(depthwise_conv2d, kARM, kFloat, kNCHW,
                      paddle::lite::kernels::arm::ConvCompute, def)
     .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))})
-    // .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))})
    .BindInput("Filter", {LiteType::GetTensorTy(TARGET(kARM))})
     .BindOutput("Output", {LiteType::GetTensorTy(TARGET(kARM))})
     .Finalize();
diff --git a/paddle/fluid/lite/operators/CMakeLists.txt b/paddle/fluid/lite/operators/CMakeLists.txt
index 536fcb75ef47c33c3bb0ef1996526fca50bf5497..9269e46e6624770aceab439ef5eb85505643e950 100644
--- a/paddle/fluid/lite/operators/CMakeLists.txt
+++ b/paddle/fluid/lite/operators/CMakeLists.txt
@@ -21,6 +21,8 @@ cc_library(op_params_lite SRCS op_params.cc DEPS ${tensor_lite} any_lite framewo
 cc_library(dropout_op_lite SRCS dropout_op.cc DEPS ${op_DEPS})
 cc_library(concat_op_lite SRCS concat_op.cc DEPS ${op_DEPS})
 # cc_library(split_op_lite SRCS split_op.cc DEPS ${op_DEPS})
+cc_library(fake_quant SRCS fake_quantize_moving_avg_max_abs.cc DEPS ${op_DEPS})
+cc_library(fake_dequant SRCS fake_dequantize_max_abs.cc DEPS ${op_DEPS})
 
 set(ops_lite
     conv_op_lite
@@ -42,6 +44,8 @@ set(ops_lite
     dropout_op_lite
     concat_op_lite
     #split_op_lite
+    fake_quant
+    fake_dequant
     PARENT_SCOPE)
 
 lite_cc_test(test_fc_op_lite SRCS fc_op_test.cc
diff --git a/paddle/fluid/lite/operators/fake_dequantize_max_abs.cc b/paddle/fluid/lite/operators/fake_dequantize_max_abs.cc
new file mode 100644
index 0000000000000000000000000000000000000000..8c3c8c7fd79ee40a5d87e5d395899a6b124988cd
--- /dev/null
+++ b/paddle/fluid/lite/operators/fake_dequantize_max_abs.cc
@@ -0,0 +1,25 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/lite/operators/fake_dequantize_max_abs.h"
+#include "paddle/fluid/lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_OP(fake_dequantize_max_abs,
+                 paddle::lite::operators::FakeDequantizeMaxAbsOpLite);
diff --git a/paddle/fluid/lite/operators/fake_dequantize_max_abs.h b/paddle/fluid/lite/operators/fake_dequantize_max_abs.h
new file mode 100644
index 0000000000000000000000000000000000000000..4df7215ff061e4ba14732ff8507fbcf6eb3cb0fe
--- /dev/null
+++ b/paddle/fluid/lite/operators/fake_dequantize_max_abs.h
@@ -0,0 +1,64 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <string>
+#include <vector>
+#include "paddle/fluid/lite/core/compatible_tensor.h"
+#include "paddle/fluid/lite/core/kernel.h"
+#include "paddle/fluid/lite/core/op_lite.h"
+#include "paddle/fluid/lite/core/scope.h"
+#include "paddle/fluid/lite/operators/op_params.h"
+#include "paddle/fluid/lite/utils/all.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
+class FakeDequantizeMaxAbsOpLite : public OpLite {
+ public:
+  FakeDequantizeMaxAbsOpLite() {}
+
+  explicit FakeDequantizeMaxAbsOpLite(const std::string &type) : OpLite(type) {}
+
+  bool CheckShape() const override { return true; }
+
+  bool InferShape() const override { return true; }
+
+  bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override {
+    auto x = op_desc.Input("X").front();
+    auto in_scale = op_desc.Input("Scale").front();
+
+    auto out = op_desc.Output("Out").front();
+
+    param_.x = scope->FindVar(x)->GetMutable<lite::Tensor>();
+    param_.in_scale = scope->FindVar(in_scale)->GetMutable<lite::Tensor>();
+
+    param_.out = scope->FindVar(out)->GetMutable<lite::Tensor>();
+    param_.max_range = op_desc.GetAttr<float>("max_range");
+    return true;
+  }
+
+  void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
+
+  std::string DebugString() const override { return "fake_dequantize_max_abs"; }
+
+ private:
+  mutable FakeDequantizeMaxAbsParam param_;
+};
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
diff --git a/paddle/fluid/lite/operators/fake_quantize_moving_avg_max_abs.cc b/paddle/fluid/lite/operators/fake_quantize_moving_avg_max_abs.cc
new file mode 100644
index 0000000000000000000000000000000000000000..59f48d4380f4a7954af73bb512b92c03ed513735
--- /dev/null
+++ b/paddle/fluid/lite/operators/fake_quantize_moving_avg_max_abs.cc
@@ -0,0 +1,25 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/lite/operators/fake_quantize_moving_avg_max_abs.h"
+#include "paddle/fluid/lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_OP(fake_quantize_moving_average_abs_max,
+                 paddle::lite::operators::FakeQuantizeMovingAvgMaxAbsOpLite);
diff --git a/paddle/fluid/lite/operators/fake_quantize_moving_avg_max_abs.h b/paddle/fluid/lite/operators/fake_quantize_moving_avg_max_abs.h
new file mode 100644
index 0000000000000000000000000000000000000000..1db4f3bf62064ef38c654d557b6f986d0d806fd6
--- /dev/null
+++ b/paddle/fluid/lite/operators/fake_quantize_moving_avg_max_abs.h
@@ -0,0 +1,69 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <string>
+#include <vector>
+#include "paddle/fluid/lite/core/compatible_tensor.h"
+#include "paddle/fluid/lite/core/kernel.h"
+#include "paddle/fluid/lite/core/op_lite.h"
+#include "paddle/fluid/lite/core/scope.h"
+#include "paddle/fluid/lite/operators/op_params.h"
+#include "paddle/fluid/lite/utils/all.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
+class FakeQuantizeMovingAvgMaxAbsOpLite : public OpLite {
+ public:
+  FakeQuantizeMovingAvgMaxAbsOpLite() {}
+
+  explicit FakeQuantizeMovingAvgMaxAbsOpLite(const std::string &type)
+      : OpLite(type) {}
+
+  bool CheckShape() const override { return true; }
+
+  bool InferShape() const override { return true; }
+
+  bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override {
+    auto x = op_desc.Input("X").front();
+    auto in_scale = op_desc.Input("InScale").front();
+
+    auto out = op_desc.Output("Out").front();
+    auto out_scale = op_desc.Output("OutScale").front();
+
+    param_.x = scope->FindVar(x)->GetMutable<lite::Tensor>();
+    param_.in_scale = scope->FindVar(in_scale)->GetMutable<lite::Tensor>();
+
+    param_.out = scope->FindVar(out)->GetMutable<lite::Tensor>();
+    param_.out_scale = scope->FindVar(out_scale)->GetMutable<lite::Tensor>();
+    param_.bit_length = op_desc.GetAttr<int>("bit_length");
+    return true;
+  }
+
+  void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
+
+  std::string DebugString() const override {
+    return "fake_quantize_moving_avg_max_abs";
+  }
+
+ private:
+  mutable FakeQuantizeMovingAvgMaxAbsParam param_;
+};
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
diff --git a/paddle/fluid/lite/operators/op_params.h b/paddle/fluid/lite/operators/op_params.h
index 91a6067959854f608e31a6151a4e63e26df7eb64..bf10c717c49d0b63aa68e54c9d26bd5798517706 100644
--- a/paddle/fluid/lite/operators/op_params.h
+++ b/paddle/fluid/lite/operators/op_params.h
@@ -256,6 +256,28 @@ struct FillConstantParam {
   lite::Tensor* Out{};
 };
 
+// ----------------------- quantize operators ----------------------
+struct FakeQuantizeMovingAvgMaxAbsParam {
+  const lite::Tensor* x{};
+  const lite::Tensor* in_scale{};
+  const lite::Tensor* in_accum{};
+  const lite::Tensor* in_state{};
+  lite::Tensor* out{};
+  lite::Tensor* out_scale{};
+  lite::Tensor* out_state{};
+  lite::Tensor* out_accum{};
+  int bit_length;
+  bool is_test{true};
+  float moving_rate{0.9f};
+};
+
+struct FakeDequantizeMaxAbsParam {
+  const lite::Tensor* x{};
+  const lite::Tensor* in_scale{};
+  lite::Tensor* out{};
+  float max_range;
+};
+
 /// ----------------------- sgd operators ----------------------
 struct SGDParam {
   int dtype{framework::proto::VarType::FP32};
diff --git a/paddle/fluid/lite/operators/softmax_op.cc b/paddle/fluid/lite/operators/softmax_op.cc
index 41d7b335e80bc0a878885c3f2d09324e36130bb3..7c554db0b562857c7750997ee0dab45195c9c077 100644
--- a/paddle/fluid/lite/operators/softmax_op.cc
+++ b/paddle/fluid/lite/operators/softmax_op.cc
@@ -39,7 +39,12 @@ bool SoftmaxOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
       &scope->FindVar(opdesc.Input("X").front())->Get<lite::Tensor>());
   param_.output =
       scope->FindVar(opdesc.Output("Out").front())->GetMutable<lite::Tensor>();
-  param_.axis = opdesc.GetAttr<int>("axis");
+
+  if (opdesc.HasAttr("axis")) {
+    param_.axis = opdesc.GetAttr<int>("axis");
+  } else {
+    param_.axis = -1;
+  }
   CHECK(param_.x);
   CHECK(param_.output);
   return true;