Commit 73249df2 authored by nhzlx

ARM int8 support

1. Add fake_quant and fake_dequant ops.
2. Add a quant_dequant fuse pass.
3. Fix bugs in the passes used on ARM.
4. Fix the softmax axis handling.
Parent 2c9ef4b7
......@@ -86,6 +86,8 @@ USE_LITE_OP(depthwise_conv2d);
USE_LITE_OP(pool2d);
USE_LITE_OP(elementwise_add);
USE_LITE_OP(softmax);
USE_LITE_OP(fake_quantize_moving_average_abs_max);
USE_LITE_OP(fake_dequantize_max_abs);
USE_LITE_KERNEL(feed, kHost, kAny, kAny, def);
USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def);
......
......@@ -9,6 +9,7 @@ cc_library(mir_passes
SRCS fc_fuse_pass.cc
conv_elementwise_add_relu_fuse_pass.cc
conv_bn_fuse_pass.cc
quant_dequant_fuse_pass.cc
static_kernel_pick_pass.cc
variable_place_inference_pass.cc
type_target_transform_pass.cc
......
......@@ -8,10 +8,15 @@ cc_library(fuse_conv_bn
SRCS conv_bn_fuser.cc
DEPS pattern_matcher_high_api)
cc_library(fuse_quant_dequant
SRCS quant_dequant_op_fuser.cc
DEPS pattern_matcher_high_api)
set(mir_fusers
fuse_fc
fuse_conv_elementwise_add_relu
fuse_conv_bn
fuse_quant_dequant
CACHE INTERNAL "fusers")
if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
......
......@@ -79,7 +79,7 @@ void ConvElementwiseAddReLUFuser::InsertNewNode(SSAGraph* graph,
cpp::OpDesc ConvElementwiseAddReLUFuser::GenOpDesc(const key2nodes_t& matched) {
auto* desc = matched.at("conv2d")->stmt()->op_info();
- cpp::OpDesc op_desc;
+ cpp::OpDesc op_desc = *desc;
op_desc.SetType(conv_type_);
op_desc.SetInput("Input", {matched.at("input")->arg()->name});
op_desc.SetInput("Filter", {matched.at("filter")->arg()->name});
......@@ -92,7 +92,6 @@ cpp::OpDesc ConvElementwiseAddReLUFuser::GenOpDesc(const key2nodes_t& matched) {
"ResidualData") != input_arg_names.end()) {
op_desc.SetInput("ResidualData", desc->Input("ResidualData"));
}
// Only consider strides, padding, groups, dilations, fuse_relu for now
op_desc.SetAttr("strides", desc->GetAttr<std::vector<int>>("strides"));
op_desc.SetAttr("paddings", desc->GetAttr<std::vector<int>>("paddings"));
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/fusion/quant_dequant_op_fuser.h"
#include <memory>
#include <vector>
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
void QuantDequantOpFuser::BuildPattern() {
const int kNumFields = 5;
const int kQuantizedWeightOffset = 0;
const int kQuantizedOpOffset = 1;
const int kQuantizedOpOutOffset = 2;
const int kDequantOpOffset = 3;
const int kDequantOpOutOffset = 4;
std::string weight_name = "";
if (op_type_ == "conv2d" || op_type_ == "depthwise_conv2d") {
weight_name = "Filter";
} else {
weight_name = "Y";
}
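// Pattern being built: one fake quant op whose Out feeds `times_` quantized
// ops (conv2d / depthwise_conv2d / mul), each followed by its own
// fake_dequantize_max_abs op.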
auto* quant_op_input = VarNode("quant_op_input")
->assert_is_op_input(quant_type_, "X")
->AsInput();
auto* quant_op_in_scale = VarNode("quant_op_in_scale")
->assert_is_op_input(quant_type_, "InScale")
->AsIntermediate();
auto* quant_op = OpNode("quant_op", quant_type_)
->assert_is_op(quant_type_)
->AsIntermediate();
auto* quant_op_out_scale =
VarNode("quant_op_out_scale")
->assert_is_op_output(quant_type_, "OutScale")
->assert_is_op_input("fake_dequantize_max_abs", "Scale")
->AsIntermediate();
auto* quant_op_out = VarNode("quant_op_out")
->assert_is_op_output(quant_type_, "Out")
->assert_is_op_input(op_type_)
->AsIntermediate();
std::vector<PMNode*> nodes;
for (int i = 0; i < times_; i++) {
nodes.push_back(VarNode("quantized_op_weight" + std::to_string(i))
->assert_is_op_input(op_type_, weight_name)
->AsInput());
nodes.push_back(OpNode("quantized_op" + std::to_string(i), op_type_)
->assert_is_op(op_type_)
->AsIntermediate());
nodes.push_back(VarNode("quantized_op_out" + std::to_string(i))
->assert_is_op_output(op_type_)
->assert_is_op_input("fake_dequantize_max_abs", "X")
->AsIntermediate());
nodes.push_back(
OpNode("dequant_op" + std::to_string(i), "fake_dequantize_max_abs")
->assert_is_op("fake_dequantize_max_abs")
->AsIntermediate());
nodes.push_back(VarNode("dequant_op_out" + std::to_string(i))
->assert_is_op_output("fake_dequantize_max_abs", "Out")
->AsOutput());
}
quant_op->LinksFrom({quant_op_input, quant_op_in_scale});
quant_op_out->LinksFrom({quant_op});
quant_op_out_scale->LinksFrom({quant_op});
for (int i = 0; i < times_; i++) {
nodes[i * kNumFields + kQuantizedOpOffset]->LinksFrom(
{quant_op_out, nodes[i * kNumFields + kQuantizedWeightOffset]});
nodes[i * kNumFields + kQuantizedOpOutOffset]->LinksFrom(
{nodes[i * kNumFields + kQuantizedOpOffset]});
nodes[i * kNumFields + kDequantOpOffset]->LinksFrom(
{nodes[i * kNumFields + kQuantizedOpOutOffset], quant_op_out_scale});
nodes[i * kNumFields + kDequantOpOutOffset]->LinksFrom(
{nodes[i * kNumFields + kDequantOpOffset]});
}
}
void QuantDequantOpFuser::InsertNewNode(SSAGraph* graph,
const key2nodes_t& matched) {
const int kNumFields = 5;
const int kQuantizedWeightOffset = 0;
const int kQuantizedOpOffset = 1;
const int kQuantizedOpOutOffset = 2;
const int kDequantOpOffset = 3;
const int kDequantOpOutOffset = 4;
auto* quant_op_input = matched.at("quant_op_input");
auto* quant_op_in_scale = matched.at("quant_op_in_scale");
auto* quant_op = matched.at("quant_op");
auto* quant_op_out_scale = matched.at("quant_op_out_scale");
auto* quant_op_out = matched.at("quant_op_out");
std::vector<Node*> nodes;
for (int i = 0; i < times_; i++) {
nodes.push_back(matched.at("quantized_op_weight" + std::to_string(i)));
nodes.push_back(matched.at("quantized_op" + std::to_string(i)));
nodes.push_back(matched.at("quantized_op_out" + std::to_string(i)));
nodes.push_back(matched.at("dequant_op" + std::to_string(i)));
nodes.push_back(matched.at("dequant_op_out" + std::to_string(i)));
}
int bit_length = quant_op->stmt()->op_info()->GetAttr<int>("bit_length");
auto* scope = quant_op->stmt()->op->scope();
auto& valid_places = quant_op->stmt()->op->valid_places();
int range = ((1 << (bit_length - 1)) - 1);  // e.g. 127 for bit_length == 8
auto input_scale_t = scope->FindVar(quant_op_in_scale->arg()->name)
->GetMutable<lite::Tensor>();
float input_scale = input_scale_t->data<float>()[0];
for (int i = 0; i < times_; i++) {
float max_range = nodes[i * kNumFields + kDequantOpOffset]
->stmt()
->op_info()
->GetAttr<float>("max_range");
float weight_scale = (range * range) / max_range;
cpp::OpDesc op_desc =
*nodes[i * kNumFields + kQuantizedOpOffset]->stmt()->op_info();
if (op_type_ == "conv2d" || op_type_ == "depthwise_conv2d") {
op_desc.SetInput("Input", {matched.at("quant_op_input")->arg()->name});
op_desc.SetOutput(
"Output", {nodes[i * kNumFields + kDequantOpOutOffset]->arg()->name});
} else if (op_type_ == "mul") {
op_desc.SetInput("X", {matched.at("quant_op_input")->arg()->name});
op_desc.SetOutput(
"Out", {nodes[i * kNumFields + kDequantOpOutOffset]->arg()->name});
}
op_desc.SetAttr("enable_int8", true);
op_desc.SetAttr("input_scale", input_scale);
auto quantized_weight_var_name =
nodes[i * kNumFields + kQuantizedWeightOffset]->arg()->name;
auto quantized_weight_t =
scope->FindVar(quantized_weight_var_name)->GetMutable<lite::Tensor>();
float* quantized_weight_data = quantized_weight_t->mutable_data<float>();
size_t weight_num = quantized_weight_t->data_size();
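// Rescale each stored int8-valued weight entry by weight_scale / range,
// i.e. recover the float weight w = q * weight_scale / range.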
for (size_t j = 0; j < weight_num; j++) {
quantized_weight_data[j] *= (weight_scale / range);
}
auto quantized_op = LiteOpRegistry::Global().Create(op_type_);
quantized_op->Attach(op_desc, scope);
auto* new_op_node =
graph->GraphCreateInstructNode(quantized_op, valid_places);
IR_NODE_LINK_TO(quant_op_input, new_op_node);
IR_NODE_LINK_TO(nodes[i * kNumFields + kQuantizedWeightOffset],
new_op_node);
IR_NODE_LINK_TO(new_op_node, nodes[i * kNumFields + kDequantOpOutOffset]);
}
}
cpp::OpDesc QuantDequantOpFuser::GenOpDesc(const key2nodes_t& matched) {
cpp::OpDesc op_desc;
return op_desc;  // unused: this fuser overrides InsertNewNode directly
}
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
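A note on the scale arithmetic in InsertNewNode: assuming the training-side freeze pass recorded max_range = range * range / weight_scale (an assumption about the upstream quantization pass, not something this commit states), the fuser recovers weight_scale and folds it into the weights. A minimal standalone sketch, with a hypothetical weight_scale_orig:

// Hedged sketch of the scale math; not part of the commit.
#include <cstdio>

int main() {
  const int bit_length = 8;
  const int range = (1 << (bit_length - 1)) - 1;  // 127 for int8
  const float weight_scale_orig = 0.05f;          // hypothetical value
  // Assumed upstream convention: max_range = range^2 / weight_scale.
  const float max_range = (range * range) / weight_scale_orig;
  // The fuser inverts that relation ...
  const float weight_scale = (range * range) / max_range;  // == 0.05f
  // ... and rescales each stored int8-valued weight by weight_scale / range.
  const float w_q = 100.0f;  // one quantized weight entry
  std::printf("weight_scale=%f rescaled=%f\n", weight_scale,
              w_q * (weight_scale / range));
  return 0;
}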
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/lite/core/mir/pattern_matcher_high_api.h"
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
class QuantDequantOpFuser : public FuseBase {
public:
explicit QuantDequantOpFuser(const std::string& op_type,
const std::string& quant_type, int times)
: op_type_(op_type), quant_type_(quant_type), times_(times) {}
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private:
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
private:
std::string op_type_{"conv2d"};
std::string quant_type_;
int times_;
};
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
......@@ -22,6 +22,7 @@ namespace mir {} // namespace mir
} // namespace paddle
#ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
+ #endif
USE_MIR_PASS(demo);
USE_MIR_PASS(static_kernel_pick_pass);
USE_MIR_PASS(variable_place_inference_pass);
......@@ -29,9 +30,9 @@ USE_MIR_PASS(type_target_transform_pass);
USE_MIR_PASS(generate_program_pass);
USE_MIR_PASS(io_copy_kernel_pick_pass);
USE_MIR_PASS(argument_type_display_pass);
- #endif
USE_MIR_PASS(runtime_context_assign_pass);
USE_MIR_PASS(lite_conv_bn_fuse_pass);
USE_MIR_PASS(graph_visualze);
USE_MIR_PASS(lite_fc_fuse_pass);
USE_MIR_PASS(lite_conv_elementwise_add_act_fuse_pass);
USE_MIR_PASS(lite_quant_dequant_fuse_pass);
......@@ -115,7 +115,6 @@ void PatternMatcher::operator()(SSAGraph *graph,
bool PatternMatcher::MarkPMNodesInGraph(SSAGraph *graph) {
VLOG(3) << "mark pmnodes in graph";
if (graph->nodes().empty()) return false;
for (auto &node : graph->mutable_nodes()) {
for (const auto &pmnode : pattern_.nodes()) {
if (pmnode->Tell(&node)) {
......@@ -398,7 +397,7 @@ PMNode *PMNode::assert_is_op_output(const std::string &op_type) {
asserts_.emplace_back([=](const Node *x) {
for (auto *op : x->inlinks) {
if (op && op->IsStmt()) {
- auto *op_info = x->stmt()->op_info();
+ auto *op_info = op->stmt()->op_info();
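// (x is the output var node; its producing op is the inlink stmt node,
// hence op->stmt() rather than x->stmt())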
if (op_info->Type() == op_type) return true;
}
}
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/quant_dequant_fuse_pass.h"
#include <memory>
#include <vector>
#include "paddle/fluid/lite/core/mir/fusion/quant_dequant_op_fuser.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
void QuantDequantFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
std::unordered_set<std::string> quant_types = {
"fake_quantize_range_abs_max", "fake_quantize_moving_average_abs_max"};
std::unordered_set<std::string> quantized_op_types = {"conv2d", "mul",
"depthwise_conv2d"};
for (auto& quant_type : quant_types) {
for (auto& op_type : quantized_op_types) {
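// Try larger fan-out counts first, so a quant op feeding several
// quantized ops is matched as one pattern rather than piecemeal.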
for (int i = 6; i >= 1; i--) {
fusion::QuantDequantOpFuser fuser(op_type, quant_type, i);
fuser(graph.get());
}
}
}
}
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(lite_quant_dequant_fuse_pass,
paddle::lite::mir::QuantDequantFusePass);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <unordered_set>
#include "paddle/fluid/lite/core/mir/pass.h"
namespace paddle {
namespace lite {
namespace mir {
class QuantDequantFusePass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
};
} // namespace mir
} // namespace lite
} // namespace paddle
......@@ -48,19 +48,19 @@ class Optimizer {
if (passes.empty()) {
RunPasses(std::vector<std::string>{{
"lite_quant_dequant_fuse_pass", //
"lite_conv_bn_fuse_pass", //
"lite_conv_elementwise_add_act_fuse_pass", //
"lite_fc_fuse_pass", //
+ #ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
"static_kernel_pick_pass", //
"variable_place_inference_pass", //
"argument_type_display_pass", //
"type_target_transform_pass", //
"argument_type_display_pass", //
"variable_place_inference_pass", //
"argument_type_display_pass", //
"io_copy_kernel_pick_pass", //
"variable_place_inference_pass", //
- #ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
#endif
"runtime_context_assign_pass", //
}});
......
......@@ -55,8 +55,8 @@ enum class DataLayoutType : int {
#define DATALAYOUT(item__) paddle::lite::DataLayoutType::item__
static const std::string& TargetToStr(TargetType target) {
- static const std::string target2string[] = {"unk", "host", "x86", "cuda",
-                                             "any"};
+ static const std::string target2string[] = {"unk", "host", "x86",
+                                             "cuda", "arm", "any"};
auto x = static_cast<int>(target);
CHECK_LT(x, static_cast<int>(TARGET(NUM)));
return target2string[x];
......
......@@ -165,8 +165,8 @@ class Type : public DataType {
// -------------------------------- compatible check ---------------------------
static bool TargetCompatibleTo(const Type& a, const Type& b) {
- auto is_host = [](TargetType x) {
-   return x == TARGET(kHost) || x == TARGET(kX86);
+ auto is_host = [](TargetType x) -> bool {
+   return x == TARGET(kHost) || x == TARGET(kX86) || x == TARGET(kARM);
};
if (a.IsVoid() || b.IsVoid()) return true;
if (a.IsTensor() || b.IsTensor()) {
......
......@@ -100,7 +100,7 @@ void ConvCompute::Run() {
REGISTER_LITE_KERNEL(conv2d, kARM, kFloat, kNCHW,
paddle::lite::kernels::arm::ConvCompute, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))})
// .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Filter", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Output", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
......@@ -108,7 +108,7 @@ REGISTER_LITE_KERNEL(conv2d, kARM, kFloat, kNCHW,
REGISTER_LITE_KERNEL(depthwise_conv2d, kARM, kFloat, kNCHW,
paddle::lite::kernels::arm::ConvCompute, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))})
// .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Filter", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Output", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
......@@ -21,6 +21,8 @@ cc_library(op_params_lite SRCS op_params.cc DEPS ${tensor_lite} any_lite framewo
cc_library(dropout_op_lite SRCS dropout_op.cc DEPS ${op_DEPS})
cc_library(concat_op_lite SRCS concat_op.cc DEPS ${op_DEPS})
# cc_library(split_op_lite SRCS split_op.cc DEPS ${op_DEPS})
cc_library(fake_quant SRCS fake_quantize_moving_avg_max_abs.cc DEPS ${op_DEPS})
cc_library(fake_dequant SRCS fake_dequantize_max_abs.cc DEPS ${op_DEPS})
set(ops_lite
conv_op_lite
......@@ -42,6 +44,8 @@ set(ops_lite
dropout_op_lite
concat_op_lite
#split_op_lite
fake_quant
fake_dequant
PARENT_SCOPE)
lite_cc_test(test_fc_op_lite SRCS fc_op_test.cc
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/operators/fake_dequantize_max_abs.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(fake_dequantize_max_abs,
paddle::lite::operators::FakeDequantizeMaxAbsOpLite);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/compatible_tensor.h"
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/scope.h"
#include "paddle/fluid/lite/operators/op_params.h"
#include "paddle/fluid/lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class FakeDequantizeMaxAbsOpLite : public OpLite {
public:
FakeDequantizeMaxAbsOpLite() {}
explicit FakeDequantizeMaxAbsOpLite(const std::string &type) : OpLite(type) {}
bool CheckShape() const override { return true; }
bool InferShape() const override { return true; }
bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override {
auto x = op_desc.Input("X").front();
auto in_scale = op_desc.Input("Scale").front();
auto out = op_desc.Output("Out").front();
param_.x = scope->FindVar(x)->GetMutable<lite::Tensor>();
param_.in_scale = scope->FindVar(in_scale)->GetMutable<lite::Tensor>();
param_.out = scope->FindVar(out)->GetMutable<lite::Tensor>();
param_.max_range = op_desc.GetAttr<float>("max_range");
return true;
}
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "fake_dequantize_max_abs"; }
private:
mutable FakeDequantizeMaxAbsParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
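For context, a sketch of the reference semantics behind fake_dequantize_max_abs, following the upstream fluid op (hedged: this commit only defines the op, and the Lite kernel is not shown in this diff):

// Reference math: Out = X * Scale / max_range.
#include <vector>

std::vector<float> FakeDequantizeMaxAbsRef(const std::vector<float>& x,
                                           float scale, float max_range) {
  std::vector<float> out(x.size());
  for (size_t i = 0; i < x.size(); ++i) out[i] = x[i] * scale / max_range;
  return out;
}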
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/operators/fake_quantize_moving_avg_max_abs.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(fake_quantize_moving_average_abs_max,
paddle::lite::operators::FakeQuantizeMovingAvgMaxAbsOpLite);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/compatible_tensor.h"
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/scope.h"
#include "paddle/fluid/lite/operators/op_params.h"
#include "paddle/fluid/lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class FakeQuantizeMovingAvgMaxAbsOpLite : public OpLite {
public:
FakeQuantizeMovingAvgMaxAbsOpLite() {}
explicit FakeQuantizeMovingAvgMaxAbsOpLite(const std::string &type)
: OpLite(type) {}
bool CheckShape() const override { return true; }
bool InferShape() const override { return true; }
bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override {
auto x = op_desc.Input("X").front();
auto in_scale = op_desc.Input("InScale").front();
auto out = op_desc.Output("Out").front();
auto out_scale = op_desc.Output("OutScale").front();
param_.x = scope->FindVar(x)->GetMutable<lite::Tensor>();
param_.in_scale = scope->FindVar(in_scale)->GetMutable<lite::Tensor>();
param_.out = scope->FindVar(out)->GetMutable<lite::Tensor>();
param_.out_scale = scope->FindVar(out_scale)->GetMutable<lite::Tensor>();
param_.bit_length = op_desc.GetAttr<int>("bit_length");
return true;
}
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override {
return "fake_quantize_moving_avg_max_abs";
}
private:
mutable FakeQuantizeMovingAvgMaxAbsParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
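For context, a sketch of the inference-time semantics this op describes, following the upstream fluid op (hedged: the training-time moving-average state/accum updates are omitted, and no ARM kernel for it appears in this diff):

// Sketch: Out = round(clip(x, -scale, scale) * range / scale), scale = InScale.
#include <algorithm>
#include <cmath>
#include <vector>

std::vector<float> FakeQuantizeMovingAvgAbsMaxRef(const std::vector<float>& x,
                                                  float scale, int bit_length) {
  const float range = static_cast<float>((1 << (bit_length - 1)) - 1);
  std::vector<float> out(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    const float clipped = std::min(std::max(x[i], -scale), scale);
    out[i] = std::round(clipped * range / scale);
  }
  return out;
}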
......@@ -256,6 +256,28 @@ struct FillConstantParam {
lite::Tensor* Out{};
};
/// ----------------- fake quantize / dequantize operators -----------------
struct FakeQuantizeMovingAvgMaxAbsParam {
const lite::Tensor* x{};
const lite::Tensor* in_scale{};
const lite::Tensor* in_accum{};
const lite::Tensor* in_state{};
lite::Tensor* out{};
lite::Tensor* out_scale{};
lite::Tensor* out_state{};
lite::Tensor* out_accum{};
int bit_length;
bool is_test{true};
float moving_rate{0.9};
};
struct FakeDequantizeMaxAbsParam {
const lite::Tensor* x{};
const lite::Tensor* in_scale{};
lite::Tensor* out{};
float max_range;
};
/// ----------------------- sgd operators ----------------------
struct SGDParam {
int dtype{framework::proto::VarType::FP32};
......
......@@ -39,7 +39,12 @@ bool SoftmaxOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
&scope->FindVar(opdesc.Input("X").front())->Get<lite::Tensor>());
param_.output =
scope->FindVar(opdesc.Output("Out").front())->GetMutable<lite::Tensor>();
if (opdesc.HasAttr("axis")) {
param_.axis = opdesc.GetAttr<int>("axis");
} else {
param_.axis = -1;
}
CHECK(param_.x);
CHECK(param_.output);
return true;
......
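An axis of -1 follows the usual convention of selecting the last dimension; a kernel would typically canonicalize it along these lines (illustrative only, not part of this commit):

// Illustrative: map a possibly-negative softmax axis into [0, rank).
inline int CanonicalAxis(int axis, int rank) {
  return axis < 0 ? axis + rank : axis;  // e.g. axis = -1, rank = 4 -> 3
}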