From 73249df2141aaba1889ebea231f506d2484920d1 Mon Sep 17 00:00:00 2001 From: nhzlx Date: Wed, 19 Jun 2019 07:27:36 +0000 Subject: [PATCH] ARM int8 support 1. add fake_quant fake_dequant op 2. add quant_dequant fuse pass 3. fix bug for passes for arm 4. softmax axis problem --- paddle/fluid/lite/api/cxx_api_bin.cc | 2 + paddle/fluid/lite/core/mir/CMakeLists.txt | 1 + .../fluid/lite/core/mir/fusion/CMakeLists.txt | 5 + .../fusion/conv_elementwise_add_relu_fuser.cc | 3 +- .../core/mir/fusion/quant_dequant_op_fuser.cc | 177 ++++++++++++++++++ .../core/mir/fusion/quant_dequant_op_fuser.h | 46 +++++ paddle/fluid/lite/core/mir/passes.h | 3 +- paddle/fluid/lite/core/mir/pattern_matcher.cc | 3 +- .../lite/core/mir/quant_dequant_fuse_pass.cc | 45 +++++ .../lite/core/mir/quant_dequant_fuse_pass.h | 33 ++++ paddle/fluid/lite/core/optimizer.h | 18 +- paddle/fluid/lite/core/target_wrapper.h | 4 +- paddle/fluid/lite/core/type_system.h | 4 +- paddle/fluid/lite/kernels/arm/conv_compute.cc | 4 +- paddle/fluid/lite/operators/CMakeLists.txt | 4 + .../lite/operators/fake_dequantize_max_abs.cc | 25 +++ .../lite/operators/fake_dequantize_max_abs.h | 64 +++++++ .../fake_quantize_moving_avg_max_abs.cc | 25 +++ .../fake_quantize_moving_avg_max_abs.h | 69 +++++++ paddle/fluid/lite/operators/op_params.h | 22 +++ paddle/fluid/lite/operators/softmax_op.cc | 7 +- 21 files changed, 543 insertions(+), 21 deletions(-) create mode 100644 paddle/fluid/lite/core/mir/fusion/quant_dequant_op_fuser.cc create mode 100644 paddle/fluid/lite/core/mir/fusion/quant_dequant_op_fuser.h create mode 100644 paddle/fluid/lite/core/mir/quant_dequant_fuse_pass.cc create mode 100644 paddle/fluid/lite/core/mir/quant_dequant_fuse_pass.h create mode 100644 paddle/fluid/lite/operators/fake_dequantize_max_abs.cc create mode 100644 paddle/fluid/lite/operators/fake_dequantize_max_abs.h create mode 100644 paddle/fluid/lite/operators/fake_quantize_moving_avg_max_abs.cc create mode 100644 paddle/fluid/lite/operators/fake_quantize_moving_avg_max_abs.h diff --git a/paddle/fluid/lite/api/cxx_api_bin.cc b/paddle/fluid/lite/api/cxx_api_bin.cc index dec0b65eb27..6e78d2012b2 100644 --- a/paddle/fluid/lite/api/cxx_api_bin.cc +++ b/paddle/fluid/lite/api/cxx_api_bin.cc @@ -86,6 +86,8 @@ USE_LITE_OP(depthwise_conv2d); USE_LITE_OP(pool2d); USE_LITE_OP(elementwise_add); USE_LITE_OP(softmax); +USE_LITE_OP(fake_quantize_moving_average_abs_max); +USE_LITE_OP(fake_dequantize_max_abs); USE_LITE_KERNEL(feed, kHost, kAny, kAny, def); USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def); diff --git a/paddle/fluid/lite/core/mir/CMakeLists.txt b/paddle/fluid/lite/core/mir/CMakeLists.txt index fe7defcf73e..412c23324cf 100644 --- a/paddle/fluid/lite/core/mir/CMakeLists.txt +++ b/paddle/fluid/lite/core/mir/CMakeLists.txt @@ -9,6 +9,7 @@ cc_library(mir_passes SRCS fc_fuse_pass.cc conv_elementwise_add_relu_fuse_pass.cc conv_bn_fuse_pass.cc + quant_dequant_fuse_pass.cc static_kernel_pick_pass.cc variable_place_inference_pass.cc type_target_transform_pass.cc diff --git a/paddle/fluid/lite/core/mir/fusion/CMakeLists.txt b/paddle/fluid/lite/core/mir/fusion/CMakeLists.txt index fbc7ffe730b..2bf9296eb0e 100644 --- a/paddle/fluid/lite/core/mir/fusion/CMakeLists.txt +++ b/paddle/fluid/lite/core/mir/fusion/CMakeLists.txt @@ -8,10 +8,15 @@ cc_library(fuse_conv_bn SRCS conv_bn_fuser.cc DEPS pattern_matcher_high_api) +cc_library(fuse_quant_dequant + SRCS quant_dequant_op_fuser.cc + DEPS pattern_matcher_high_api) + set(mir_fusers fuse_fc fuse_conv_elementwise_add_relu fuse_conv_bn + fuse_quant_dequant CACHE INTERNAL "fusers") if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) diff --git a/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.cc b/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.cc index 421c920e621..889586a3bc6 100644 --- a/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.cc +++ b/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.cc @@ -79,7 +79,7 @@ void ConvElementwiseAddReLUFuser::InsertNewNode(SSAGraph* graph, cpp::OpDesc ConvElementwiseAddReLUFuser::GenOpDesc(const key2nodes_t& matched) { auto* desc = matched.at("conv2d")->stmt()->op_info(); - cpp::OpDesc op_desc; + cpp::OpDesc op_desc = *desc; op_desc.SetType(conv_type_); op_desc.SetInput("Input", {matched.at("input")->arg()->name}); op_desc.SetInput("Filter", {matched.at("filter")->arg()->name}); @@ -92,7 +92,6 @@ cpp::OpDesc ConvElementwiseAddReLUFuser::GenOpDesc(const key2nodes_t& matched) { "ResidualData") != input_arg_names.end()) { op_desc.SetInput("ResidualData", desc->Input("ResidualData")); } - // Only consider strides, padding, groups, dilations, fuse_relu for now op_desc.SetAttr("strides", desc->GetAttr>("strides")); op_desc.SetAttr("paddings", desc->GetAttr>("paddings")); diff --git a/paddle/fluid/lite/core/mir/fusion/quant_dequant_op_fuser.cc b/paddle/fluid/lite/core/mir/fusion/quant_dequant_op_fuser.cc new file mode 100644 index 00000000000..a767d277e76 --- /dev/null +++ b/paddle/fluid/lite/core/mir/fusion/quant_dequant_op_fuser.cc @@ -0,0 +1,177 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/core/mir/fusion/quant_dequant_op_fuser.h" +#include +#include + +namespace paddle { +namespace lite { +namespace mir { +namespace fusion { + +void QuantDequantOpFuser::BuildPattern() { + const int kNumFields = 5; + const int kQuantizedWeightOffset = 0; + const int kQuantizedOpOffset = 1; + const int kQuantizedOpOutOffset = 2; + const int kDequantOpOffset = 3; + const int kDequantOpOutOffset = 4; + + std::string weight_name = ""; + if (op_type_ == "conv2d" || op_type_ == "depthwise_conv2d") { + weight_name = "Filter"; + } else { + weight_name = "Y"; + } + auto* quant_op_input = VarNode("quant_op_input") + ->assert_is_op_input(quant_type_, "X") + ->AsInput(); + auto* quant_op_in_scale = VarNode("quant_op_in_scale") + ->assert_is_op_input(quant_type_, "InScale") + ->AsIntermediate(); + auto* quant_op = OpNode("quant_op", quant_type_) + ->assert_is_op(quant_type_) + ->AsIntermediate(); + + auto* quant_op_out_scale = + VarNode("quant_op_out_scale") + ->assert_is_op_output(quant_type_, "OutScale") + ->assert_is_op_input("fake_dequantize_max_abs", "Scale") + ->AsIntermediate(); + + auto* quant_op_out = VarNode("quant_op_out") + ->assert_is_op_output(quant_type_, "Out") + ->assert_is_op_input(op_type_) + ->AsIntermediate(); + std::vector nodes; + for (int i = 0; i < times_; i++) { + nodes.push_back(VarNode("quantized_op_weight" + std::to_string(i)) + ->assert_is_op_input(op_type_, weight_name) + ->AsInput()); + + nodes.push_back(OpNode("quantized_op" + std::to_string(i), op_type_) + ->assert_is_op(op_type_) + ->AsIntermediate()); + + nodes.push_back(VarNode("quantized_op_out" + std::to_string(i)) + ->assert_is_op_output(op_type_) + ->assert_is_op_input("fake_dequantize_max_abs", "X") + ->AsIntermediate()); + + nodes.push_back( + OpNode("dequant_op" + std::to_string(i), "fake_dequantize_max_abs") + ->assert_is_op("fake_dequantize_max_abs") + ->AsIntermediate()); + nodes.push_back(VarNode("dequant_op_out" + std::to_string(i)) + ->assert_is_op_output("fake_dequantize_max_abs", "Out") + ->AsOutput()); + } + + quant_op->LinksFrom({quant_op_input, quant_op_in_scale}); + quant_op_out->LinksFrom({quant_op}); + quant_op_out_scale->LinksFrom({quant_op}); + for (int i = 0; i < times_; i++) { + nodes[i * kNumFields + kQuantizedOpOffset]->LinksFrom( + {quant_op_out, nodes[i * kNumFields + kQuantizedWeightOffset]}); + nodes[i * kNumFields + kQuantizedOpOutOffset]->LinksFrom( + {nodes[i * kNumFields + kQuantizedOpOffset]}); + nodes[i * kNumFields + kDequantOpOffset]->LinksFrom( + {nodes[i * kNumFields + kQuantizedOpOutOffset], quant_op_out_scale}); + nodes[i * kNumFields + kDequantOpOutOffset]->LinksFrom( + {nodes[i * kNumFields + kDequantOpOffset]}); + } +} + +void QuantDequantOpFuser::InsertNewNode(SSAGraph* graph, + const key2nodes_t& matched) { + const int kNumFields = 5; + const int kQuantizedWeightOffset = 0; + const int kQuantizedOpOffset = 1; + const int kQuantizedOpOutOffset = 2; + const int kDequantOpOffset = 3; + const int kDequantOpOutOffset = 4; + + auto* quant_op_input = matched.at("quant_op_input"); + auto* quant_op_in_scale = matched.at("quant_op_in_scale"); + auto* quant_op = matched.at("quant_op"); + auto* quant_op_out_scale = matched.at("quant_op_out_scale"); + auto* quant_op_out = matched.at("quant_op_out"); + + std::vector nodes; + for (int i = 0; i < times_; i++) { + nodes.push_back(matched.at("quantized_op_weight" + std::to_string(i))); + nodes.push_back(matched.at("quantized_op" + std::to_string(i))); + nodes.push_back(matched.at("quantized_op_out" + std::to_string(i))); + nodes.push_back(matched.at("dequant_op" + std::to_string(i))); + nodes.push_back(matched.at("dequant_op_out" + std::to_string(i))); + } + int bit_length = quant_op->stmt()->op_info()->GetAttr("bit_length"); + auto* scope = quant_op->stmt()->op->scope(); + auto& valid_places = quant_op->stmt()->op->valid_places(); + int range = ((1 << (bit_length - 1)) - 1); + auto input_scale_t = scope->FindVar(quant_op_in_scale->arg()->name) + ->GetMutable(); + float input_scale = input_scale_t->data()[0]; + + for (int i = 0; i < times_; i++) { + float max_range = nodes[i * kNumFields + kDequantOpOffset] + ->stmt() + ->op_info() + ->GetAttr("max_range"); + float weight_scale = (range * range) / max_range; + + cpp::OpDesc op_desc = + *nodes[i * kNumFields + kQuantizedOpOffset]->stmt()->op_info(); + if (op_type_ == "conv2d" || op_type_ == "depthwise_conv2d") { + op_desc.SetInput("Input", {matched.at("quant_op_input")->arg()->name}); + op_desc.SetOutput( + "Output", {nodes[i * kNumFields + kDequantOpOutOffset]->arg()->name}); + } else if (op_type_ == "mul") { + op_desc.SetInput("X", {matched.at("quant_op_input")->arg()->name}); + op_desc.SetOutput( + "Out", {nodes[i * kNumFields + kDequantOpOutOffset]->arg()->name}); + } + op_desc.SetAttr("enable_int8", true); + op_desc.SetAttr("input_scale", input_scale); + auto quantized_weight_var_name = + nodes[i * kNumFields + kQuantizedWeightOffset]->arg()->name; + auto quantized_weight_t = + scope->FindVar(quantized_weight_var_name)->GetMutable(); + float* quantized_weight_data = quantized_weight_t->mutable_data(); + size_t weight_num = quantized_weight_t->data_size(); + for (size_t i = 0; i < weight_num; i++) { + quantized_weight_data[i] *= (weight_scale / range); + } + auto quantized_op = LiteOpRegistry::Global().Create(op_type_); + + quantized_op->Attach(op_desc, scope); + auto* new_op_node = + graph->GraphCreateInstructNode(quantized_op, valid_places); + IR_NODE_LINK_TO(quant_op_input, new_op_node); + IR_NODE_LINK_TO(nodes[i * kNumFields + kQuantizedWeightOffset], + new_op_node); + IR_NODE_LINK_TO(new_op_node, nodes[i * kNumFields + kDequantOpOutOffset]); + } +} + +cpp::OpDesc QuantDequantOpFuser::GenOpDesc(const key2nodes_t& matched) { + cpp::OpDesc op_desc; + return op_desc; +} + +} // namespace fusion +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/core/mir/fusion/quant_dequant_op_fuser.h b/paddle/fluid/lite/core/mir/fusion/quant_dequant_op_fuser.h new file mode 100644 index 00000000000..be084eaf804 --- /dev/null +++ b/paddle/fluid/lite/core/mir/fusion/quant_dequant_op_fuser.h @@ -0,0 +1,46 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "paddle/fluid/lite/core/mir/pattern_matcher_high_api.h" + +namespace paddle { +namespace lite { +namespace mir { +namespace fusion { + +class QuantDequantOpFuser : public FuseBase { + public: + explicit QuantDequantOpFuser(const std::string& op_type, + const std::string& quant_type, int times) + : op_type_(op_type), quant_type_(quant_type), times_(times) {} + void BuildPattern() override; + void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override; + + private: + cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override; + + private: + std::string op_type_{"conv2d"}; + std::string quant_type_; + int times_; +}; + +} // namespace fusion +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/core/mir/passes.h b/paddle/fluid/lite/core/mir/passes.h index a6abb16e3ea..c3226819698 100644 --- a/paddle/fluid/lite/core/mir/passes.h +++ b/paddle/fluid/lite/core/mir/passes.h @@ -22,6 +22,7 @@ namespace mir {} // namespace mir } // namespace paddle #ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK +#endif USE_MIR_PASS(demo); USE_MIR_PASS(static_kernel_pick_pass); USE_MIR_PASS(variable_place_inference_pass); @@ -29,9 +30,9 @@ USE_MIR_PASS(type_target_transform_pass); USE_MIR_PASS(generate_program_pass); USE_MIR_PASS(io_copy_kernel_pick_pass); USE_MIR_PASS(argument_type_display_pass); -#endif USE_MIR_PASS(runtime_context_assign_pass); USE_MIR_PASS(lite_conv_bn_fuse_pass); USE_MIR_PASS(graph_visualze); USE_MIR_PASS(lite_fc_fuse_pass); USE_MIR_PASS(lite_conv_elementwise_add_act_fuse_pass); +USE_MIR_PASS(lite_quant_dequant_fuse_pass); diff --git a/paddle/fluid/lite/core/mir/pattern_matcher.cc b/paddle/fluid/lite/core/mir/pattern_matcher.cc index bff313432f5..3cda96c307c 100644 --- a/paddle/fluid/lite/core/mir/pattern_matcher.cc +++ b/paddle/fluid/lite/core/mir/pattern_matcher.cc @@ -115,7 +115,6 @@ void PatternMatcher::operator()(SSAGraph *graph, bool PatternMatcher::MarkPMNodesInGraph(SSAGraph *graph) { VLOG(3) << "mark pmnodes in graph"; if (graph->nodes().empty()) return false; - for (auto &node : graph->mutable_nodes()) { for (const auto &pmnode : pattern_.nodes()) { if (pmnode->Tell(&node)) { @@ -398,7 +397,7 @@ PMNode *PMNode::assert_is_op_output(const std::string &op_type) { asserts_.emplace_back([=](const Node *x) { for (auto *op : x->inlinks) { if (op && op->IsStmt()) { - auto *op_info = x->stmt()->op_info(); + auto *op_info = op->stmt()->op_info(); if (op_info->Type() == op_type) return true; } } diff --git a/paddle/fluid/lite/core/mir/quant_dequant_fuse_pass.cc b/paddle/fluid/lite/core/mir/quant_dequant_fuse_pass.cc new file mode 100644 index 00000000000..deb06ae6d7d --- /dev/null +++ b/paddle/fluid/lite/core/mir/quant_dequant_fuse_pass.cc @@ -0,0 +1,45 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/core/mir/quant_dequant_fuse_pass.h" +#include +#include +#include "paddle/fluid/lite/core/mir/fusion/quant_dequant_op_fuser.h" +#include "paddle/fluid/lite/core/mir/pass_registry.h" + +namespace paddle { +namespace lite { +namespace mir { + +void QuantDequantFusePass::Apply(const std::unique_ptr& graph) { + std::unordered_set quant_types = { + "fake_quantize_range_abs_max", "fake_quantize_moving_average_abs_max"}; + std::unordered_set quantized_op_types = {"conv2d", "mul", + "depthwise_conv2d"}; + for (auto& quant_type : quant_types) { + for (auto& op_type : quantized_op_types) { + for (int i = 6; i >= 1; i--) { + fusion::QuantDequantOpFuser fuser(op_type, quant_type, i); + fuser(graph.get()); + } + } + } +} + +} // namespace mir +} // namespace lite +} // namespace paddle + +REGISTER_MIR_PASS(lite_quant_dequant_fuse_pass, + paddle::lite::mir::QuantDequantFusePass); diff --git a/paddle/fluid/lite/core/mir/quant_dequant_fuse_pass.h b/paddle/fluid/lite/core/mir/quant_dequant_fuse_pass.h new file mode 100644 index 00000000000..5cd38de51de --- /dev/null +++ b/paddle/fluid/lite/core/mir/quant_dequant_fuse_pass.h @@ -0,0 +1,33 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include "paddle/fluid/lite/core/mir/pass.h" + +namespace paddle { +namespace lite { +namespace mir { + +class QuantDequantFusePass : public ProgramPass { + public: + void Apply(const std::unique_ptr& graph) override; +}; + +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/core/optimizer.h b/paddle/fluid/lite/core/optimizer.h index a3e0641b1c7..3424024f14b 100644 --- a/paddle/fluid/lite/core/optimizer.h +++ b/paddle/fluid/lite/core/optimizer.h @@ -48,19 +48,19 @@ class Optimizer { if (passes.empty()) { RunPasses(std::vector{{ + "lite_quant_dequant_fuse_pass", // "lite_conv_bn_fuse_pass", // "lite_conv_elementwise_add_act_fuse_pass", // "lite_fc_fuse_pass", // + "static_kernel_pick_pass", // + "variable_place_inference_pass", // + "argument_type_display_pass", // + "type_target_transform_pass", // + "variable_place_inference_pass", // + "argument_type_display_pass", // + "io_copy_kernel_pick_pass", // + "variable_place_inference_pass", // #ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK - "static_kernel_pick_pass", // - "variable_place_inference_pass", // - "argument_type_display_pass", // - "type_target_transform_pass", // - "argument_type_display_pass", // - "variable_place_inference_pass", // - "argument_type_display_pass", // - "io_copy_kernel_pick_pass", // - "variable_place_inference_pass", // #endif "runtime_context_assign_pass", // }}); diff --git a/paddle/fluid/lite/core/target_wrapper.h b/paddle/fluid/lite/core/target_wrapper.h index 1029bf5300e..c4a870ab83f 100644 --- a/paddle/fluid/lite/core/target_wrapper.h +++ b/paddle/fluid/lite/core/target_wrapper.h @@ -55,8 +55,8 @@ enum class DataLayoutType : int { #define DATALAYOUT(item__) paddle::lite::DataLayoutType::item__ static const std::string& TargetToStr(TargetType target) { - static const std::string target2string[] = {"unk", "host", "x86", "cuda", - "any"}; + static const std::string target2string[] = {"unk", "host", "x86", + "cuda", "arm", "any"}; auto x = static_cast(target); CHECK_LT(x, static_cast(TARGET(NUM))); return target2string[x]; diff --git a/paddle/fluid/lite/core/type_system.h b/paddle/fluid/lite/core/type_system.h index 46d17e3c33e..8bd1f373949 100644 --- a/paddle/fluid/lite/core/type_system.h +++ b/paddle/fluid/lite/core/type_system.h @@ -165,8 +165,8 @@ class Type : public DataType { // -------------------------------- compatible check --------------------------- static bool TargetCompatibleTo(const Type& a, const Type& b) { - auto is_host = [](TargetType x) { - return x == TARGET(kHost) || x == TARGET(kX86); + auto is_host = [](TargetType x) -> bool { + return x == TARGET(kHost) || x == TARGET(kX86) || x == TARGET(kARM); }; if (a.IsVoid() || b.IsVoid()) return true; if (a.IsTensor() || b.IsTensor()) { diff --git a/paddle/fluid/lite/kernels/arm/conv_compute.cc b/paddle/fluid/lite/kernels/arm/conv_compute.cc index a8a2ac790a3..5e9ddb62716 100644 --- a/paddle/fluid/lite/kernels/arm/conv_compute.cc +++ b/paddle/fluid/lite/kernels/arm/conv_compute.cc @@ -100,7 +100,7 @@ void ConvCompute::Run() { REGISTER_LITE_KERNEL(conv2d, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::ConvCompute, def) .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))}) - // .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("Filter", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Output", {LiteType::GetTensorTy(TARGET(kARM))}) .Finalize(); @@ -108,7 +108,7 @@ REGISTER_LITE_KERNEL(conv2d, kARM, kFloat, kNCHW, REGISTER_LITE_KERNEL(depthwise_conv2d, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::ConvCompute, def) .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))}) - // .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("Filter", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Output", {LiteType::GetTensorTy(TARGET(kARM))}) .Finalize(); diff --git a/paddle/fluid/lite/operators/CMakeLists.txt b/paddle/fluid/lite/operators/CMakeLists.txt index 536fcb75ef4..9269e46e662 100644 --- a/paddle/fluid/lite/operators/CMakeLists.txt +++ b/paddle/fluid/lite/operators/CMakeLists.txt @@ -21,6 +21,8 @@ cc_library(op_params_lite SRCS op_params.cc DEPS ${tensor_lite} any_lite framewo cc_library(dropout_op_lite SRCS dropout_op.cc DEPS ${op_DEPS}) cc_library(concat_op_lite SRCS concat_op.cc DEPS ${op_DEPS}) # cc_library(split_op_lite SRCS split_op.cc DEPS ${op_DEPS}) +cc_library(fake_quant SRCS fake_quantize_moving_avg_max_abs.cc DEPS ${op_DEPS}) +cc_library(fake_dequant SRCS fake_dequantize_max_abs.cc DEPS ${op_DEPS}) set(ops_lite conv_op_lite @@ -42,6 +44,8 @@ set(ops_lite dropout_op_lite concat_op_lite #split_op_lite + fake_quant + fake_dequant PARENT_SCOPE) lite_cc_test(test_fc_op_lite SRCS fc_op_test.cc diff --git a/paddle/fluid/lite/operators/fake_dequantize_max_abs.cc b/paddle/fluid/lite/operators/fake_dequantize_max_abs.cc new file mode 100644 index 00000000000..8c3c8c7fd79 --- /dev/null +++ b/paddle/fluid/lite/operators/fake_dequantize_max_abs.cc @@ -0,0 +1,25 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/operators/fake_dequantize_max_abs.h" +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators {} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(fake_dequantize_max_abs, + paddle::lite::operators::FakeDequantizeMaxAbsOpLite); diff --git a/paddle/fluid/lite/operators/fake_dequantize_max_abs.h b/paddle/fluid/lite/operators/fake_dequantize_max_abs.h new file mode 100644 index 00000000000..4df7215ff06 --- /dev/null +++ b/paddle/fluid/lite/operators/fake_dequantize_max_abs.h @@ -0,0 +1,64 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "paddle/fluid/lite/core/compatible_tensor.h" +#include "paddle/fluid/lite/core/kernel.h" +#include "paddle/fluid/lite/core/op_lite.h" +#include "paddle/fluid/lite/core/scope.h" +#include "paddle/fluid/lite/operators/op_params.h" +#include "paddle/fluid/lite/utils/all.h" + +namespace paddle { +namespace lite { +namespace operators { + +class FakeDequantizeMaxAbsOpLite : public OpLite { + public: + FakeDequantizeMaxAbsOpLite() {} + + explicit FakeDequantizeMaxAbsOpLite(const std::string &type) : OpLite(type) {} + + bool CheckShape() const override {} + + bool InferShape() const override {} + + bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override { + auto x = op_desc.Input("X").front(); + auto in_scale = op_desc.Input("Scale").front(); + + auto out = op_desc.Output("Out").front(); + + param_.x = scope->FindVar(x)->GetMutable(); + param_.in_scale = scope->FindVar(in_scale)->GetMutable(); + + param_.out = scope->FindVar(out)->GetMutable(); + param_.max_range = op_desc.GetAttr("max_range"); + return true; + } + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + + std::string DebugString() const override { return "fake_dequantize_max_abs"; } + + private: + mutable FakeDequantizeMaxAbsParam param_; +}; + +} // namespace operators +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/operators/fake_quantize_moving_avg_max_abs.cc b/paddle/fluid/lite/operators/fake_quantize_moving_avg_max_abs.cc new file mode 100644 index 00000000000..59f48d4380f --- /dev/null +++ b/paddle/fluid/lite/operators/fake_quantize_moving_avg_max_abs.cc @@ -0,0 +1,25 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/operators/fake_quantize_moving_avg_max_abs.h" +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators {} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(fake_quantize_moving_average_abs_max, + paddle::lite::operators::FakeQuantizeMovingAvgMaxAbsOpLite); diff --git a/paddle/fluid/lite/operators/fake_quantize_moving_avg_max_abs.h b/paddle/fluid/lite/operators/fake_quantize_moving_avg_max_abs.h new file mode 100644 index 00000000000..1db4f3bf620 --- /dev/null +++ b/paddle/fluid/lite/operators/fake_quantize_moving_avg_max_abs.h @@ -0,0 +1,69 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "paddle/fluid/lite/core/compatible_tensor.h" +#include "paddle/fluid/lite/core/kernel.h" +#include "paddle/fluid/lite/core/op_lite.h" +#include "paddle/fluid/lite/core/scope.h" +#include "paddle/fluid/lite/operators/op_params.h" +#include "paddle/fluid/lite/utils/all.h" + +namespace paddle { +namespace lite { +namespace operators { + +class FakeQuantizeMovingAvgMaxAbsOpLite : public OpLite { + public: + FakeQuantizeMovingAvgMaxAbsOpLite() {} + + explicit FakeQuantizeMovingAvgMaxAbsOpLite(const std::string &type) + : OpLite(type) {} + + bool CheckShape() const override {} + + bool InferShape() const override {} + + bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override { + auto x = op_desc.Input("X").front(); + auto in_scale = op_desc.Input("InScale").front(); + + auto out = op_desc.Output("Out").front(); + auto out_scale = op_desc.Output("OutScale").front(); + + param_.x = scope->FindVar(x)->GetMutable(); + param_.in_scale = scope->FindVar(in_scale)->GetMutable(); + + param_.out = scope->FindVar(out)->GetMutable(); + param_.out_scale = scope->FindVar(out_scale)->GetMutable(); + param_.bit_length = op_desc.GetAttr("bit_length"); + return true; + } + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + + std::string DebugString() const override { + return "fake_quantize_moving_avg_max_abs"; + } + + private: + mutable FakeQuantizeMovingAvgMaxAbsParam param_; +}; + +} // namespace operators +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/operators/op_params.h b/paddle/fluid/lite/operators/op_params.h index 91a60679598..bf10c717c49 100644 --- a/paddle/fluid/lite/operators/op_params.h +++ b/paddle/fluid/lite/operators/op_params.h @@ -256,6 +256,28 @@ struct FillConstantParam { lite::Tensor* Out{}; }; +// +struct FakeQuantizeMovingAvgMaxAbsParam { + const lite::Tensor* x{}; + const lite::Tensor* in_scale{}; + const lite::Tensor* in_accum{}; + const lite::Tensor* in_state{}; + lite::Tensor* out{}; + lite::Tensor* out_scale{}; + lite::Tensor* out_state{}; + lite::Tensor* out_accum{}; + int bit_length; + bool is_test{true}; + float moving_rate{0.9}; +}; + +struct FakeDequantizeMaxAbsParam { + const lite::Tensor* x{}; + const lite::Tensor* in_scale{}; + lite::Tensor* out{}; + float max_range; +}; + /// ----------------------- sgd operators ---------------------- struct SGDParam { int dtype{framework::proto::VarType::FP32}; diff --git a/paddle/fluid/lite/operators/softmax_op.cc b/paddle/fluid/lite/operators/softmax_op.cc index 41d7b335e80..7c554db0b56 100644 --- a/paddle/fluid/lite/operators/softmax_op.cc +++ b/paddle/fluid/lite/operators/softmax_op.cc @@ -39,7 +39,12 @@ bool SoftmaxOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { &scope->FindVar(opdesc.Input("X").front())->Get()); param_.output = scope->FindVar(opdesc.Output("Out").front())->GetMutable(); - param_.axis = opdesc.GetAttr("axis"); + + if (opdesc.HasAttr("axis")) { + param_.axis = opdesc.GetAttr("axis"); + } else { + param_.axis = -1; + } CHECK(param_.x); CHECK(param_.output); return true; -- GitLab