提交 6fa52f83 编写于 作者: L luotao1

Merge branch 'develop' into fc_infershape

......@@ -68,7 +68,7 @@ paddle.fluid.initializer.MSRAInitializer.__init__ (ArgSpec(args=['self', 'unifor
paddle.fluid.initializer.force_init_on_cpu (ArgSpec(args=[], varargs=None, keywords=None, defaults=None), ('document', '6d0f3e22c90d9d500d36ff57daf056ee'))
paddle.fluid.initializer.init_on_cpu (ArgSpec(args=[], varargs=None, keywords=None, defaults=None), ('document', 'a6d7011ca3d8c0d454dac3a56eae0c29'))
paddle.fluid.initializer.NumpyArrayInitializer.__init__ (ArgSpec(args=['self', 'value'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.layers.fc (ArgSpec(args=['input', 'size', 'num_flatten_dims', 'param_attr', 'bias_attr', 'act', 'is_test', 'name'], varargs=None, keywords=None, defaults=(1, None, None, None, False, None)), ('document', '1929058262994f212620599c63aea6bd'))
paddle.fluid.layers.fc (ArgSpec(args=['input', 'size', 'num_flatten_dims', 'param_attr', 'bias_attr', 'act', 'is_test', 'name'], varargs=None, keywords=None, defaults=(1, None, None, None, False, None)), ('document', '424e898365195e3ccbc2e7dc8b63605e'))
paddle.fluid.layers.embedding (ArgSpec(args=['input', 'size', 'is_sparse', 'is_distributed', 'padding_idx', 'param_attr', 'dtype'], varargs=None, keywords=None, defaults=(False, False, None, None, 'float32')), ('document', '89c2c55a0b0656b106064048e068e77a'))
paddle.fluid.layers.dynamic_lstm (ArgSpec(args=['input', 'size', 'h_0', 'c_0', 'param_attr', 'bias_attr', 'use_peepholes', 'is_reverse', 'gate_activation', 'cell_activation', 'candidate_activation', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, None, None, None, True, False, 'sigmoid', 'tanh', 'tanh', 'float32', None)), ('document', 'dfbb624f85015df29e994ca6999e8ff6'))
paddle.fluid.layers.dynamic_lstmp (ArgSpec(args=['input', 'size', 'proj_size', 'param_attr', 'bias_attr', 'use_peepholes', 'is_reverse', 'gate_activation', 'cell_activation', 'candidate_activation', 'proj_activation', 'dtype', 'name', 'h_0', 'c_0', 'cell_clip', 'proj_clip'], varargs=None, keywords=None, defaults=(None, None, True, False, 'sigmoid', 'tanh', 'tanh', 'tanh', 'float32', None, None, None, None, None)), ('document', 'b4b608b986eb9617aa0525e1be21d32d'))
......@@ -330,7 +330,8 @@ paddle.fluid.layers.generate_mask_labels (ArgSpec(args=['im_info', 'gt_classes',
paddle.fluid.layers.iou_similarity (ArgSpec(args=['x', 'y', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '587845f60c5d97ffdf2dfd21da52eca1'))
paddle.fluid.layers.box_coder (ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'name', 'axis'], varargs=None, keywords=None, defaults=('encode_center_size', True, None, 0)), ('document', '032d0f4b7d8f6235ee5d91e473344f0e'))
paddle.fluid.layers.polygon_box_transform (ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '0e5ac2507723a0b5adec473f9556799b'))
paddle.fluid.layers.yolov3_loss (ArgSpec(args=['x', 'gtbox', 'gtlabel', 'anchors', 'anchor_mask', 'class_num', 'ignore_thresh', 'downsample_ratio', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '991e934c3e09abf0edec7c9c978b4691'))
paddle.fluid.layers.yolov3_loss (ArgSpec(args=['x', 'gtbox', 'gtlabel', 'anchors', 'anchor_mask', 'class_num', 'ignore_thresh', 'downsample_ratio', 'gtscore', 'use_label_smooth', 'name'], varargs=None, keywords=None, defaults=(None, True, None)), ('document', '57fa96922e42db8f064c3fb77f2255e8'))
paddle.fluid.layers.yolo_box (ArgSpec(args=['x', 'img_size', 'anchors', 'class_num', 'conf_thresh', 'downsample_ratio', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '5566169a5ab993d177792c023c7fb340'))
paddle.fluid.layers.box_clip (ArgSpec(args=['input', 'im_info', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '397e9e02b451d99c56e20f268fa03f2e'))
paddle.fluid.layers.multiclass_nms (ArgSpec(args=['bboxes', 'scores', 'score_threshold', 'nms_top_k', 'keep_top_k', 'nms_threshold', 'normalized', 'nms_eta', 'background_label', 'name'], varargs=None, keywords=None, defaults=(0.3, True, 1.0, 0, None)), ('document', 'ca7d1107b6c5d2d6d8221039a220fde0'))
paddle.fluid.layers.distribute_fpn_proposals (ArgSpec(args=['fpn_rois', 'min_level', 'max_level', 'refer_level', 'refer_scale', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '7bb011ec26bace2bc23235aa4a17647d'))
......@@ -367,7 +368,7 @@ paddle.fluid.contrib.BeamSearchDecoder.read_array (ArgSpec(args=['self', 'init',
paddle.fluid.contrib.BeamSearchDecoder.update_array (ArgSpec(args=['self', 'array', 'value'], varargs=None, keywords=None, defaults=None), ('document', '5754e9b3212b7c09497151516a0de5a7'))
paddle.fluid.contrib.memory_usage (ArgSpec(args=['program', 'batch_size'], varargs=None, keywords=None, defaults=None), ('document', '8fcb2f93bb743693baa8d4860a5ccc47'))
paddle.fluid.contrib.op_freq_statistic (ArgSpec(args=['program'], varargs=None, keywords=None, defaults=None), ('document', '4d43687113c4bf5b29d15aee2f4e4afa'))
paddle.fluid.contrib.QuantizeTranspiler.__init__ (ArgSpec(args=['self', 'weight_bits', 'activation_bits', 'activation_quantize_type', 'weight_quantize_type', 'window_size'], varargs=None, keywords=None, defaults=(8, 8, 'abs_max', 'abs_max', 10000)), ('document', '14b39f1fcd5667ff556b1aad94357d1d'))
paddle.fluid.contrib.QuantizeTranspiler.__init__ (ArgSpec(args=['self', 'weight_bits', 'activation_bits', 'activation_quantize_type', 'weight_quantize_type', 'window_size', 'moving_rate'], varargs=None, keywords=None, defaults=(8, 8, 'abs_max', 'abs_max', 10000, 0.9)), ('document', '14b39f1fcd5667ff556b1aad94357d1d'))
paddle.fluid.contrib.QuantizeTranspiler.convert_to_int8 (ArgSpec(args=['self', 'program', 'place', 'scope'], varargs=None, keywords=None, defaults=(None,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.contrib.QuantizeTranspiler.freeze_program (ArgSpec(args=['self', 'program', 'place', 'fuse_bn', 'scope'], varargs=None, keywords=None, defaults=(False, None)), ('document', '909675a1ab055c69b436a7893fcae4fd'))
paddle.fluid.contrib.QuantizeTranspiler.training_transpile (ArgSpec(args=['self', 'program', 'startup_program'], varargs=None, keywords=None, defaults=(None, None)), ('document', '6dd9909f10b283ba2892a99058a72884'))
......
......@@ -46,6 +46,7 @@ cc_library(fuse_pass_base SRCS fuse_pass_base.cc DEPS pass)
pass_library(graph_to_program_pass base)
pass_library(graph_viz_pass base)
pass_library(lock_free_optimize_pass base)
pass_library(cpu_quantize_pass inference)
pass_library(cpu_quantize_squash_pass inference)
pass_library(fc_fuse_pass inference)
pass_library(attention_lstm_fuse_pass inference)
......@@ -102,8 +103,11 @@ cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS g
cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto)
cc_test(test_seqpool_concat_fuse_pass SRCS seqpool_concat_fuse_pass_tester.cc DEPS seqpool_concat_fuse_pass framework_proto)
cc_test(test_is_test_pass SRCS is_test_pass_tester.cc DEPS is_test_pass)
cc_test(test_sync_batch_norm_pass SRCS sync_batch_norm_pass_tester.cc DEPS sync_batch_norm_pass)
cc_test(test_cpu_quantize_pass SRCS cpu_quantize_pass_tester.cc DEPS cpu_quantize_pass naive_executor)
cc_test(test_cpu_quantize_squash_pass SRCS cpu_quantize_squash_pass_tester.cc DEPS cpu_quantize_squash_pass naive_executor)
if(NOT WIN32)
cc_test(test_sync_batch_norm_pass SRCS sync_batch_norm_pass_tester.cc DEPS sync_batch_norm_pass)
endif()
if (WITH_MKLDNN)
cc_test(test_depthwise_conv_mkldnn_pass SRCS mkldnn/depthwise_conv_mkldnn_pass_tester.cc DEPS depthwise_conv_mkldnn_pass)
cc_test(test_conv_bias_mkldnn_fuse_pass SRCS mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc DEPS conv_bias_mkldnn_fuse_pass naive_executor)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/cpu_quantize_pass.h"
#include <utility>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/string/pretty_log.h"
namespace paddle {
namespace framework {
namespace ir {
namespace {
void UnlinkNodes(ir::Node* a, ir::Node* b) {
a->outputs.erase(std::remove(a->outputs.begin(), a->outputs.end(), b),
a->outputs.end());
b->inputs.erase(std::remove(b->inputs.begin(), b->inputs.end(), a),
b->inputs.end());
}
} // namespace
enum { U8_MAX = 255, S8_MAX = 127 };
using EigenVectorArrayMap = Eigen::Map<Eigen::Array<double, Eigen::Dynamic, 1>>;
using string::PrettyLogDetail;
void CPUQuantizePass::QuantizeInput(Graph* g, Node* op, Node* input,
std::string input_name, double scale_to_one,
bool is_unsigned,
std::string scale_attr_name) const {
unsigned max = is_unsigned ? U8_MAX : S8_MAX;
float scale = scale_to_one * max;
// Create quantize output variable
VarDesc quantize_out_desc(patterns::PDNodeName("quantize", "out"));
auto* quantize_out_node = g->CreateVarNode(&quantize_out_desc);
// create a quantize op node
OpDesc q_desc;
q_desc.SetType("quantize");
q_desc.SetInput("Input", std::vector<std::string>({input->Name()}));
q_desc.SetOutput("Output",
std::vector<std::string>({quantize_out_node->Name()}));
q_desc.SetAttr("Scale", scale);
q_desc.SetAttr("is_negative_input", !is_unsigned);
auto quantize_op = g->CreateOpNode(&q_desc); // OpDesc will be copied.
// update op's input
op->Op()->SetInput(input_name,
std::vector<std::string>({quantize_out_node->Name()}));
// link quantize op
UnlinkNodes(input, op);
IR_NODE_LINK_TO(input, quantize_op);
IR_NODE_LINK_TO(quantize_op, quantize_out_node);
IR_NODE_LINK_TO(quantize_out_node, op);
if (!scale_attr_name.empty()) op->Op()->SetAttr(scale_attr_name, scale);
}
void CPUQuantizePass::DequantizeOutput(Graph* g, Node* op, Node* output,
std::string output_name,
double scale_to_one, bool is_unsigned,
std::string scale_attr_name) const {
unsigned max = is_unsigned ? U8_MAX : S8_MAX;
float scale = scale_to_one * max;
// Create dequantize input variable
VarDesc dequantize_in_desc(patterns::PDNodeName("dequantize", "in"));
auto* dequantize_in_node = g->CreateVarNode(&dequantize_in_desc);
// create a dequantize op node for output.
OpDesc deq_desc;
deq_desc.SetType("dequantize");
deq_desc.SetInput("Input",
std::vector<std::string>({dequantize_in_node->Name()}));
deq_desc.SetOutput("Output", std::vector<std::string>({output->Name()}));
deq_desc.SetAttr("Scale", scale);
auto dequantize_op = g->CreateOpNode(&deq_desc); // OpDesc will be copied.
// update op's output
op->Op()->SetOutput(output_name,
std::vector<std::string>({dequantize_in_node->Name()}));
// link dequantize op
UnlinkNodes(op, output);
IR_NODE_LINK_TO(op, dequantize_in_node);
IR_NODE_LINK_TO(dequantize_in_node, dequantize_op);
IR_NODE_LINK_TO(dequantize_op, output);
if (!scale_attr_name.empty()) op->Op()->SetAttr(scale_attr_name, scale);
}
void CPUQuantizePass::QuantizeConv(Graph* graph,
bool with_residual_data) const {
GraphPatternDetector gpd;
auto pattern = gpd.mutable_pattern();
patterns::ConvResidual conv_pattern{pattern, name_scope_};
conv_pattern(with_residual_data);
int quantize_conv_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "Quantize conv2d op";
GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern);
auto* conv_op_desc = conv_op->Op();
// skip if should not be quantized
if (!conv_op_desc->HasAttr("use_quantizer") ||
!boost::get<bool>(conv_op_desc->GetAttr("use_quantizer")))
return;
GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern);
// get scales calculated after warmup, they scale variables to MAX=1.0
auto scales = Get<VarQuantScale>("quant_var_scales");
auto input_scale = scales[conv_input->Name()].second.data<double>()[0];
bool is_input_unsigned = scales[conv_input->Name()].first;
QuantizeInput(g, conv_op, conv_input, "Input", input_scale,
is_input_unsigned, "Scale_in");
auto filter_scale_tensor = scales[conv_filter->Name()].second;
EigenVectorArrayMap eigen_tensor{filter_scale_tensor.data<double>(),
filter_scale_tensor.numel(), 1};
eigen_tensor *= static_cast<double>(S8_MAX);
std::vector<float> filter_scale{
filter_scale_tensor.data<double>(),
filter_scale_tensor.data<double>() + filter_scale_tensor.numel()};
conv_op->Op()->SetAttr("Scale_weights", filter_scale);
if (with_residual_data) {
GET_IR_NODE_FROM_SUBGRAPH(conv_residual_data, conv_residual_data,
conv_pattern);
auto residual_scale =
scales[conv_residual_data->Name()].second.data<double>()[0];
bool is_residual_unsigned = scales[conv_residual_data->Name()].first;
QuantizeInput(g, conv_op, conv_residual_data, "ResidualData",
residual_scale, is_residual_unsigned, "Scale_in_eltwise");
}
auto output_scale = scales[conv_output->Name()].second.data<double>()[0];
bool is_output_unsigned = scales[conv_output->Name()].first;
DequantizeOutput(g, conv_op, conv_output, "Output", output_scale,
is_output_unsigned, "Scale_out");
++quantize_conv_count;
};
gpd(graph, handler);
AddStatis(quantize_conv_count);
std::stringstream msg_ss;
msg_ss << "--- quantized " << quantize_conv_count << " conv2d ops";
if (with_residual_data) msg_ss << " with residual connection";
PrettyLogDetail(msg_ss.str().c_str());
}
void CPUQuantizePass::QuantizePool(Graph* graph) const {
GraphPatternDetector gpd;
auto pattern = gpd.mutable_pattern();
patterns::Pool pool_pattern{pattern, name_scope_};
pool_pattern();
int quantize_pool_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "Quantize pool2d op";
GET_IR_NODE_FROM_SUBGRAPH(pool_op, pool_op, pool_pattern);
auto* pool_op_desc = pool_op->Op();
// skip if should not be quantized
if (!pool_op_desc->HasAttr("use_quantizer") ||
!boost::get<bool>(pool_op_desc->GetAttr("use_quantizer")))
return;
GET_IR_NODE_FROM_SUBGRAPH(pool_input, pool_input, pool_pattern);
GET_IR_NODE_FROM_SUBGRAPH(pool_output, pool_output, pool_pattern);
// get scales calculated after warmup, they scale variables to MAX=1.0
auto scales = Get<VarQuantScale>("quant_var_scales");
auto input_scale = scales[pool_input->Name()].second.data<double>()[0];
bool is_input_unsigned = scales[pool_input->Name()].first;
QuantizeInput(g, pool_op, pool_input, "X", input_scale, is_input_unsigned);
auto output_scale = scales[pool_output->Name()].second.data<double>()[0];
bool is_output_unsigned = scales[pool_output->Name()].first;
DequantizeOutput(g, pool_op, pool_output, "Out", output_scale,
is_output_unsigned);
++quantize_pool_count;
};
gpd(graph, handler);
AddStatis(quantize_pool_count);
PrettyLogDetail("--- quantized %d pool2d ops", quantize_pool_count);
}
std::unique_ptr<ir::Graph> CPUQuantizePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
VLOG(3) << "Quantizing the graph.";
PADDLE_ENFORCE(graph.get());
FusePassBase::Init(name_scope_, graph.get());
PADDLE_ENFORCE(param_scope());
QuantizeConv(graph.get(), true /* with_residual_data */);
QuantizeConv(graph.get());
QuantizePool(graph.get());
return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(cpu_quantize_pass, paddle::framework::ir::CPUQuantizePass)
.RequirePassAttr("quant_var_scales");
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
/*
* Map variable name to tensor of scaling factors scaling it to MAX=1.0.
* bool denotes whether quantization of the variable should be done to unsigned
* type.
*/
using VarQuantScale =
std::unordered_map<std::string, std::pair<bool, LoDTensor>>;
/*
* Quantize all supported operators.
*/
class CPUQuantizePass : public FusePassBase {
public:
virtual ~CPUQuantizePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void QuantizeConv(Graph* graph, bool with_residual_data = false) const;
void QuantizePool(Graph* graph) const;
void QuantizeInput(Graph* g, Node* op, Node* input, std::string input_name,
double scale_to_one, bool is_unsigned,
std::string scale_attr_name = "") const;
void DequantizeOutput(Graph* g, Node* op, Node* output,
std::string output_name, double scale_to_one,
bool is_unsigned,
std::string scale_attr_name = "") const;
const std::string name_scope_{"quantize"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/cpu_quantize_pass.h"
#include <gtest/gtest.h>
#include "paddle/fluid/framework/naive_executor.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace framework {
namespace ir {
void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs, bool use_mkldnn,
bool use_quantizer = false) {
auto* op = prog->MutableBlock(0)->AppendOp();
op->SetType(type);
op->SetAttr("use_mkldnn", use_mkldnn);
op->SetAttr("name", name);
if (type == "conv2d") {
op->SetInput("Input", {inputs[0]});
op->SetInput("Filter", {inputs[1]});
if (inputs.size() > 2)
op->SetInput("Bias", {inputs[2]});
else
op->SetInput("Bias", {});
if (inputs.size() > 3) {
op->SetInput("ResidualData", {inputs[3]});
op->SetAttr("fuse_residual_connection", true);
} else {
op->SetInput("ResidualData", {});
op->SetAttr("fuse_residual_connection", false);
}
op->SetOutput("Output", {outputs[0]});
op->SetAttr("use_quantizer", use_quantizer);
op->SetAttr("Scale_in", 1.0f);
op->SetAttr("Scale_out", 1.0f);
op->SetAttr("Scale_weights", std::vector<float>{1.0f});
} else if (type == "pool2d") {
op->SetInput("X", {inputs[0]});
op->SetOutput("Out", {outputs[0]});
op->SetAttr("use_quantizer", use_quantizer);
} else if (type == "dropout") {
op->SetInput("X", {inputs[0]});
op->SetOutput("Out", {outputs[0]});
} else if (type == "fc") {
op->SetInput("Input", {inputs[0]});
if (inputs.size() > 1) op->SetInput("W", {inputs[1]});
if (inputs.size() > 2) op->SetInput("Bias", {inputs[2]});
op->SetOutput("Out", {outputs[0]});
}
}
static const std::initializer_list<std::string> variable_names{
"a", "w1", "c", "d", "w2", "e", "f", "g",
"h", "w3", "b1", "i", "j", "w4", "b2"};
// (a,w1)->Conv1->c and c->Pool1->d
//
// (d,w2)->Conv2->e and e->Pool2->f
//
// d->Dropout1->g and g->Fc1->h and (h,w3,b1,i)->Conv3->j
//
// (d,w4, b2)->Conv4->i
ProgramDesc BuildProgramDesc(bool use_mkldnn, bool use_quantizer) {
ProgramDesc prog;
for (auto& v : variable_names) {
auto* var = prog.MutableBlock(0)->Var(v);
if (v.find("w") == 0 || v.find("b") == 0) {
var->SetPersistable(true);
}
}
SetOp(&prog, "conv2d", "Conv1", {"a", "w1"}, {"c"}, use_mkldnn,
use_quantizer);
SetOp(&prog, "pool2d", "Pool1", {"c"}, {"d"}, use_mkldnn, use_quantizer);
SetOp(&prog, "conv2d", "Conv2", {"d", "w2"}, {"e"}, use_mkldnn,
use_quantizer);
SetOp(&prog, "pool2d", "Pool2", {"e"}, {"f"}, use_mkldnn, use_quantizer);
SetOp(&prog, "dropout", "Dropout1", {"d"}, {"g"}, use_mkldnn);
SetOp(&prog, "fc", "Fc1", {"g"}, {"h"}, use_mkldnn);
SetOp(&prog, "conv2d", "Conv3", {"h", "w3", "b1", "i"}, {"j"}, use_mkldnn,
use_quantizer);
SetOp(&prog, "conv2d", "Conv4", {"c", "w4", "b2"}, {"i"}, use_mkldnn,
use_quantizer);
return prog;
}
void InitTensorHolder(Scope* scope, const paddle::platform::Place& place,
const char* var_name) {
auto x = scope->Var(var_name);
auto tensor = x->GetMutable<LoDTensor>();
tensor->mutable_data(place, proto::VarType::FP32,
::paddle::memory::Allocator::kDefault, 1);
}
void MainTest(const ProgramDesc& prog, int conv_count, int pool_count,
int quant_count, int dequant_count, int added_nodes_count,
float scale) {
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
// Init scope, as it is used in pass
auto place = paddle::platform::CPUPlace();
NaiveExecutor exe{place};
Scope scope;
exe.CreateVariables(prog, 0, true, &scope);
auto* scales = new VarQuantScale();
for (auto& v : variable_names) {
InitTensorHolder(&scope, place, v.c_str());
LoDTensor tensor;
tensor.Resize({1});
auto* ptr = tensor.mutable_data<double>(place);
ptr[0] = 2.0;
(*scales)[v] = std::make_pair(false, std::move(tensor));
}
graph->Set(kParamScopeAttr, new framework::Scope*(&scope));
auto pass = PassRegistry::Instance().Get("cpu_quantize_pass");
pass->Set("quant_var_scales", scales);
int original_nodes_num = graph->Nodes().size();
graph = pass->Apply(std::move(graph));
int current_nodes_num = graph->Nodes().size();
int quantize_nodes_count = 0;
int dequantize_nodes_count = 0;
int conv2d_nodes_count = 0;
int pool2d_nodes_count = 0;
for (auto* node : graph->Nodes()) {
if (node->IsOp()) {
auto* op = node->Op();
if (op->Type() == "conv2d") {
conv2d_nodes_count++;
auto op_name = boost::get<std::string>(op->GetAttr("name"));
EXPECT_EQ(boost::get<float>(op->GetAttr("Scale_in")), scale)
<< "Scale_in for node '" + op_name + "'.";
EXPECT_EQ(boost::get<float>(op->GetAttr("Scale_out")), scale)
<< "Scale_out for node '" + op_name + "'.";
EXPECT_EQ(
boost::get<std::vector<float>>(op->GetAttr("Scale_weights"))[0],
scale)
<< "Scale_weights for node '" + op_name + "'.";
} else if (op->Type() == "pool2d") {
pool2d_nodes_count++;
} else if (op->Type() == "quantize") {
quantize_nodes_count++;
} else if (op->Type() == "dequantize") {
dequantize_nodes_count++;
}
}
}
EXPECT_EQ(conv2d_nodes_count, conv_count);
EXPECT_EQ(pool2d_nodes_count, pool_count);
EXPECT_EQ(quantize_nodes_count, quant_count);
EXPECT_EQ(dequantize_nodes_count, dequant_count);
EXPECT_EQ(original_nodes_num + added_nodes_count, current_nodes_num);
}
TEST(CpuQuantizePass, quantize) {
bool use_mkldnn = true;
bool use_quantizer = true;
// (a->QUANT1->IN1,w1)->Conv1->OUT1->DEQUANT1->c and
// c->QUANT2->IN2->Pool1->OUT2->DEQUANT2->d
//
// (d->QUANT3->IN3,w2)->Conv2->OUT3->DEQUANT3->e and
// e->QUANT4->IN4->Pool2->OUT4->DEQUANT4->f
//
// d->Dropout1->g and g->Fc1->h and
// (h->QUANT5->IN5,w3,b1,i->QUANT6->IN6)->Conv3->OUT5->DEQUANT5->j
//
// (d->QUANT7->IN7,w4, b2)->Conv4->DEQUANT6->OUT6->i
// Insert nodes: 7 Quant + 7 IN + 6 OUT + 6 DEQUANT
int added_nodes = 7 + 7 + 6 + 6;
MainTest(BuildProgramDesc(use_mkldnn, use_quantizer), 4, 2, 7, 6, added_nodes,
2.0f * 127);
}
TEST(CpuQuantizePass, do_not_quantize) {
bool use_mkldnn = true;
bool use_quantizer = false;
int added_nodes = 0;
MainTest(BuildProgramDesc(use_mkldnn, use_quantizer), 4, 2, 0, 0, added_nodes,
1.0f);
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(cpu_quantize_pass);
......@@ -90,7 +90,8 @@ void GraphPatternDetector::operator()(Graph *graph,
ValidateByNodeRole(&subgraphs);
if (subgraphs.empty()) return;
PrettyLogEndl(Style::detail(), "--- detect %d subgraphs", subgraphs.size());
PrettyLogEndl(Style::detail(), "--- detected %d subgraphs",
subgraphs.size());
int id = 0;
for (auto &g : subgraphs) {
VLOG(3) << "optimizing #" << id++ << " subgraph";
......@@ -1074,9 +1075,53 @@ PDNode *patterns::Conv::operator()() {
->AsOutput()
->assert_is_op_output("conv2d", "Output");
conv_op->LinksFrom({input_var, filter_var});
conv_op->LinksTo({output_var});
conv_op->LinksFrom({input_var, filter_var}).LinksTo({output_var});
return output_var;
}
PDNode *patterns::ConvResidual::operator()(bool with_residual_data) {
auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d");
if (!with_residual_data)
conv_op->assert_op_attr("fuse_residual_connection", false);
auto input_var = pattern->NewNode(conv_input_repr())
->AsInput()
->assert_is_op_input("conv2d", "Input");
auto filter_var = pattern->NewNode(conv_filter_repr())
->AsInput()
->assert_is_op_input("conv2d", "Filter");
auto output_var = pattern->NewNode(conv_output_repr())
->AsOutput()
->assert_is_op_output("conv2d", "Output");
std::vector<PDNode *> links_from{input_var, filter_var};
if (with_residual_data) {
auto res_conn_var = pattern->NewNode(conv_residual_data_repr())
->AsInput()
->assert_is_op_input("conv2d", "ResidualData");
links_from.push_back(res_conn_var);
}
conv_op->LinksFrom(links_from).LinksTo({output_var});
return output_var;
}
PDNode *patterns::Pool::operator()() {
auto pool_op = pattern->NewNode(pool_op_repr())->assert_is_op("pool2d");
auto input_var = pattern->NewNode(pool_input_repr())
->AsInput()
->assert_is_op_input("pool2d", "X");
auto output_var = pattern->NewNode(pool_output_repr())
->AsOutput()
->assert_is_op_output("pool2d", "Out");
pool_op->LinksFrom({input_var}).LinksTo({output_var});
return output_var;
}
......
......@@ -659,6 +659,35 @@ struct Conv : public PatternBase {
PATTERN_DECL_NODE(conv_output);
};
// Convolution op with residual data
struct ConvResidual : public PatternBase {
ConvResidual(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "conv_residual") {}
PDNode* operator()(bool with_residual_data);
PATTERN_DECL_NODE(conv_op);
PATTERN_DECL_NODE(conv_input);
PATTERN_DECL_NODE(conv_filter);
PATTERN_DECL_NODE(conv_residual_data);
PATTERN_DECL_NODE(conv_output);
};
// Pool op
// Forward pass for pooling.
// pool_input is the input.
// pool_output is a result of the operator.
struct Pool : public PatternBase {
Pool(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "pooling") {}
PDNode* operator()();
PATTERN_DECL_NODE(pool_op);
PATTERN_DECL_NODE(pool_input);
PATTERN_DECL_NODE(pool_output);
};
// ElementwiseAdd used in residual connections.
// y_var is used and convolution output.
// The operator is removed, when residual
......
......@@ -27,6 +27,7 @@
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/ir/graph.h"
......@@ -38,7 +39,10 @@
namespace paddle {
namespace inference {
namespace analysis {
using framework::ir::Graph;
using VarQuantScale =
std::unordered_map<std::string, std::pair<bool, framework::LoDTensor>>;
/*
* The argument definition of both Pass and PassManagers.
......@@ -127,6 +131,8 @@ struct Argument {
// Pass a set of op types to enable its mkldnn kernel
DECL_ARGUMENT_FIELD(mkldnn_enabled_op_types, MKLDNNEnabledOpTypes,
std::unordered_set<std::string>);
// Scales for variables to be quantized
DECL_ARGUMENT_FIELD(quant_var_scales, QuantVarScales, VarQuantScale);
// Passed from config.
DECL_ARGUMENT_FIELD(use_gpu, UseGPU, bool);
......
......@@ -14,6 +14,7 @@
#include "paddle/fluid/inference/analysis/ir_pass_manager.h"
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
......@@ -55,14 +56,14 @@ void IRPassManager::CreatePasses(Argument *argument,
".dot";
pass->Set("graph_viz_path", new std::string(std::move(dot_file_path)));
pass_num++;
}
if (pass_name == "mkldnn_placement_pass") {
} else if (pass_name == "mkldnn_placement_pass") {
pass->Set("mkldnn_enabled_op_types",
new std::unordered_set<std::string>(
argument->mkldnn_enabled_op_types()));
}
if (pass_name == "tensorrt_subgraph_pass") {
} else if (pass_name == "cpu_quantize_pass") {
pass->Set("quant_var_scales",
new VarQuantScale(argument->quant_var_scales()));
} else if (pass_name == "tensorrt_subgraph_pass") {
pass->Set("workspace_size", new int(argument->tensorrt_workspace_size()));
pass->Set("max_batch_size", new int(argument->tensorrt_max_batch_size()));
pass->Set("min_subgraph_size",
......
......@@ -219,7 +219,14 @@ void AnalysisConfig::Update() {
}
if (enable_memory_optim_) {
pass_builder()->AppendAnalysisPass("memory_optimize_pass");
auto analysis_passes = pass_builder()->AnalysisPasses();
auto memory_opti_pass_name = "memory_optimize_pass";
bool already_exists =
std::find(analysis_passes.begin(), analysis_passes.end(),
memory_opti_pass_name) != analysis_passes.end();
if (!already_exists) {
pass_builder()->AppendAnalysisPass(memory_opti_pass_name);
}
}
if (ir_debug_) {
......
......@@ -58,8 +58,10 @@ if (WITH_GPU)
op_library(conv_fusion_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(conv2d_fusion);\n")
endif()
op_library(sync_batch_norm_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(sync_batch_norm);\n")
if (NOT WIN32)
op_library(sync_batch_norm_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(sync_batch_norm);\n")
endif()
else()
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
endif()
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/conv_op.h"
#include <memory>
#include <string>
#include <vector>
......@@ -194,6 +195,12 @@ void Conv2DOpMaker::Make() {
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<bool>("use_quantizer",
"(bool, default false) "
"Set to true for operators that should be quantized and use "
"int8 kernel. "
"Only used on CPU.")
.SetDefault(false);
AddAttr<bool>("fuse_relu", "(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<bool>("fuse_residual_connection",
......
......@@ -33,6 +33,7 @@ detection_library(rpn_target_assign_op SRCS rpn_target_assign_op.cc)
detection_library(generate_proposal_labels_op SRCS generate_proposal_labels_op.cc)
detection_library(box_clip_op SRCS box_clip_op.cc box_clip_op.cu)
detection_library(yolov3_loss_op SRCS yolov3_loss_op.cc)
detection_library(yolo_box_op SRCS yolo_box_op.cc yolo_box_op.cu)
detection_library(box_decoder_and_assign_op SRCS box_decoder_and_assign_op.cc box_decoder_and_assign_op.cu)
if(WITH_GPU)
......
......@@ -60,14 +60,15 @@ class BoxCoderOp : public framework::OperatorWithKernel {
} else if (code_type == BoxCodeType::kDecodeCenterSize) {
PADDLE_ENFORCE_EQ(target_box_dims.size(), 3,
"The rank of Input TargetBox must be 3");
if (axis == 0) {
PADDLE_ENFORCE_EQ(target_box_dims[1], prior_box_dims[0]);
} else if (axis == 1) {
PADDLE_ENFORCE_EQ(target_box_dims[0], prior_box_dims[0]);
} else {
PADDLE_THROW("axis must be 0 or 1.");
PADDLE_ENFORCE(axis == 0 || axis == 1, "axis must be 0 or 1");
if (ctx->IsRuntime()) {
if (axis == 0) {
PADDLE_ENFORCE_EQ(target_box_dims[1], prior_box_dims[0]);
} else if (axis == 1) {
PADDLE_ENFORCE_EQ(target_box_dims[0], prior_box_dims[0]);
}
PADDLE_ENFORCE_EQ(target_box_dims[2], prior_box_dims[1]);
}
PADDLE_ENFORCE_EQ(target_box_dims[2], prior_box_dims[1]);
ctx->ShareDim("TargetBox", /*->*/ "OutputBox");
}
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/detection/yolo_box_op.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using framework::Tensor;
class YoloBoxOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of YoloBoxOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("ImgSize"),
"Input(ImgSize) of YoloBoxOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Boxes"),
"Output(Boxes) of YoloBoxOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Scores"),
"Output(Scores) of YoloBoxOp should not be null.");
auto dim_x = ctx->GetInputDim("X");
auto dim_imgsize = ctx->GetInputDim("ImgSize");
auto anchors = ctx->Attrs().Get<std::vector<int>>("anchors");
int anchor_num = anchors.size() / 2;
auto class_num = ctx->Attrs().Get<int>("class_num");
PADDLE_ENFORCE_EQ(dim_x.size(), 4, "Input(X) should be a 4-D tensor.");
PADDLE_ENFORCE_EQ(
dim_x[1], anchor_num * (5 + class_num),
"Input(X) dim[1] should be equal to (anchor_mask_number * (5 "
"+ class_num)).");
PADDLE_ENFORCE_EQ(dim_imgsize.size(), 2,
"Input(ImgSize) should be a 2-D tensor.");
PADDLE_ENFORCE_EQ(
dim_imgsize[0], dim_x[0],
"Input(ImgSize) dim[0] and Input(X) dim[0] should be same.");
PADDLE_ENFORCE_EQ(dim_imgsize[1], 2, "Input(ImgSize) dim[1] should be 2.");
PADDLE_ENFORCE_GT(anchors.size(), 0,
"Attr(anchors) length should be greater than 0.");
PADDLE_ENFORCE_EQ(anchors.size() % 2, 0,
"Attr(anchors) length should be even integer.");
PADDLE_ENFORCE_GT(class_num, 0,
"Attr(class_num) should be an integer greater than 0.");
int box_num = dim_x[2] * dim_x[3] * anchor_num;
std::vector<int64_t> dim_boxes({dim_x[0], box_num, 4});
ctx->SetOutputDim("Boxes", framework::make_ddim(dim_boxes));
std::vector<int64_t> dim_scores({dim_x[0], box_num, class_num});
ctx->SetOutputDim("Scores", framework::make_ddim(dim_scores));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.GetPlace());
}
};
class YoloBoxOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"The input tensor of YoloBox operator is a 4-D tensor with "
"shape of [N, C, H, W]. The second dimension(C) stores "
"box locations, confidence score and classification one-hot "
"keys of each anchor box. Generally, X should be the output "
"of YOLOv3 network.");
AddInput("ImgSize",
"The image size tensor of YoloBox operator, "
"This is a 2-D tensor with shape of [N, 2]. This tensor holds "
"height and width of each input image used for resizing output "
"box in input image scale.");
AddOutput("Boxes",
"The output tensor of detection boxes of YoloBox operator, "
"This is a 3-D tensor with shape of [N, M, 4], N is the "
"batch num, M is output box number, and the 3rd dimension "
"stores [xmin, ymin, xmax, ymax] coordinates of boxes.");
AddOutput("Scores",
"The output tensor of detection boxes scores of YoloBox "
"operator, This is a 3-D tensor with shape of "
"[N, M, :attr:`class_num`], N is the batch num, M is "
"output box number.");
AddAttr<int>("class_num", "The number of classes to predict.");
AddAttr<std::vector<int>>("anchors",
"The anchor width and height, "
"it will be parsed pair by pair.")
.SetDefault(std::vector<int>{});
AddAttr<int>("downsample_ratio",
"The downsample ratio from network input to YoloBox operator "
"input, so 32, 16, 8 should be set for the first, second, "
"and thrid YoloBox operators.")
.SetDefault(32);
AddAttr<float>("conf_thresh",
"The confidence scores threshold of detection boxes. "
"Boxes with confidence scores under threshold should "
"be ignored.")
.SetDefault(0.01);
AddComment(R"DOC(
This operator generates YOLO detection boxes from output of YOLOv3 network.
The output of previous network is in shape [N, C, H, W], while H and W
should be the same, H and W specify the grid size, each grid point predict
given number boxes, this given number, which following will be represented as S,
is specified by the number of anchors. In the second dimension(the channel
dimension), C should be equal to S * (5 + class_num), class_num is the object
category number of source dataset(such as 80 in coco dataset), so the
second(channel) dimension, apart from 4 box location coordinates x, y, w, h,
also includes confidence score of the box and class one-hot key of each anchor
box.
Assume the 4 location coordinates are :math:`t_x, t_y, t_w, t_h`, the box
predictions should be as follows:
$$
b_x = \\sigma(t_x) + c_x
$$
$$
b_y = \\sigma(t_y) + c_y
$$
$$
b_w = p_w e^{t_w}
$$
$$
b_h = p_h e^{t_h}
$$
in the equation above, :math:`c_x, c_y` is the left top corner of current grid
and :math:`p_w, p_h` is specified by anchors.
The logistic regression value of the 5th channel of each anchor prediction boxes
represents the confidence score of each prediction box, and the logistic
regression value of the last :attr:`class_num` channels of each anchor prediction
boxes represents the classifcation scores. Boxes with confidence scores less than
:attr:`conf_thresh` should be ignored, and box final scores is the product of
confidence scores and classification scores.
$$
score_{pred} = score_{conf} * score_{class}
$$
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(yolo_box, ops::YoloBoxOp, ops::YoloBoxOpMaker,
paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(yolo_box, ops::YoloBoxKernel<float>,
ops::YoloBoxKernel<double>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/detection/yolo_box_op.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
__global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes,
T* scores, const float conf_thresh,
const int* anchors, const int n, const int h,
const int w, const int an_num, const int class_num,
const int box_num, int input_size) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
T box[4];
for (; tid < n * box_num; tid += stride) {
int grid_num = h * w;
int i = tid / box_num;
int j = (tid % box_num) / grid_num;
int k = (tid % grid_num) / w;
int l = tid % w;
int an_stride = (5 + class_num) * grid_num;
int img_height = imgsize[2 * i];
int img_width = imgsize[2 * i + 1];
int obj_idx =
GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 4);
T conf = sigmoid<T>(input[obj_idx]);
if (conf < conf_thresh) {
continue;
}
int box_idx =
GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 0);
GetYoloBox<T>(box, input, anchors, l, k, j, h, input_size, box_idx,
grid_num, img_height, img_width);
box_idx = (i * box_num + j * grid_num + k * w + l) * 4;
CalcDetectionBox<T>(boxes, box, box_idx, img_height, img_width);
int label_idx =
GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 5);
int score_idx = (i * box_num + j * grid_num + k * w + l) * class_num;
CalcLabelScore<T>(scores, input, label_idx, score_idx, class_num, conf,
grid_num);
}
}
template <typename T>
class YoloBoxOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<Tensor>("X");
auto* img_size = ctx.Input<Tensor>("ImgSize");
auto* boxes = ctx.Output<Tensor>("Boxes");
auto* scores = ctx.Output<Tensor>("Scores");
auto anchors = ctx.Attr<std::vector<int>>("anchors");
int class_num = ctx.Attr<int>("class_num");
float conf_thresh = ctx.Attr<float>("conf_thresh");
int downsample_ratio = ctx.Attr<int>("downsample_ratio");
const int n = input->dims()[0];
const int h = input->dims()[2];
const int w = input->dims()[3];
const int box_num = boxes->dims()[1];
const int an_num = anchors.size() / 2;
int input_size = downsample_ratio * h;
auto& dev_ctx = ctx.cuda_device_context();
auto& allocator =
platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx);
int bytes = sizeof(int) * anchors.size();
auto anchors_ptr = allocator.Allocate(sizeof(int) * anchors.size());
int* anchors_data = reinterpret_cast<int*>(anchors_ptr->ptr());
const auto gplace = boost::get<platform::CUDAPlace>(ctx.GetPlace());
const auto cplace = platform::CPUPlace();
memory::Copy(gplace, anchors_data, cplace, anchors.data(), bytes,
dev_ctx.stream());
const T* input_data = input->data<T>();
const int* imgsize_data = img_size->data<int>();
T* boxes_data = boxes->mutable_data<T>({n, box_num, 4}, ctx.GetPlace());
T* scores_data =
scores->mutable_data<T>({n, box_num, class_num}, ctx.GetPlace());
math::SetConstant<platform::CUDADeviceContext, T> set_zero;
set_zero(dev_ctx, boxes, static_cast<T>(0));
set_zero(dev_ctx, scores, static_cast<T>(0));
int grid_dim = (n * box_num + 512 - 1) / 512;
grid_dim = grid_dim > 8 ? 8 : grid_dim;
KeYoloBoxFw<T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
input_data, imgsize_data, boxes_data, scores_data, conf_thresh,
anchors_data, n, h, w, an_num, class_num, box_num, input_size);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(yolo_box, ops::YoloBoxOpCUDAKernel<float>,
ops::YoloBoxOpCUDAKernel<double>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/hostdevice.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
HOSTDEVICE inline T sigmoid(T x) {
return 1.0 / (1.0 + std::exp(-x));
}
template <typename T>
HOSTDEVICE inline void GetYoloBox(T* box, const T* x, const int* anchors, int i,
int j, int an_idx, int grid_size,
int input_size, int index, int stride,
int img_height, int img_width) {
box[0] = (i + sigmoid<T>(x[index])) * img_width / grid_size;
box[1] = (j + sigmoid<T>(x[index + stride])) * img_height / grid_size;
box[2] = std::exp(x[index + 2 * stride]) * anchors[2 * an_idx] * img_width /
input_size;
box[3] = std::exp(x[index + 3 * stride]) * anchors[2 * an_idx + 1] *
img_height / input_size;
}
HOSTDEVICE inline int GetEntryIndex(int batch, int an_idx, int hw_idx,
int an_num, int an_stride, int stride,
int entry) {
return (batch * an_num + an_idx) * an_stride + entry * stride + hw_idx;
}
template <typename T>
HOSTDEVICE inline void CalcDetectionBox(T* boxes, T* box, const int box_idx,
const int img_height,
const int img_width) {
boxes[box_idx] = box[0] - box[2] / 2;
boxes[box_idx + 1] = box[1] - box[3] / 2;
boxes[box_idx + 2] = box[0] + box[2] / 2;
boxes[box_idx + 3] = box[1] + box[3] / 2;
boxes[box_idx] = boxes[box_idx] > 0 ? boxes[box_idx] : static_cast<T>(0);
boxes[box_idx + 1] =
boxes[box_idx + 1] > 0 ? boxes[box_idx + 1] : static_cast<T>(0);
boxes[box_idx + 2] = boxes[box_idx + 2] < img_width - 1
? boxes[box_idx + 2]
: static_cast<T>(img_width - 1);
boxes[box_idx + 3] = boxes[box_idx + 3] < img_height - 1
? boxes[box_idx + 3]
: static_cast<T>(img_height - 1);
}
template <typename T>
HOSTDEVICE inline void CalcLabelScore(T* scores, const T* input,
const int label_idx, const int score_idx,
const int class_num, const T conf,
const int stride) {
for (int i = 0; i < class_num; i++) {
scores[score_idx + i] = conf * sigmoid<T>(input[label_idx + i * stride]);
}
}
template <typename T>
class YoloBoxKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<Tensor>("X");
auto* imgsize = ctx.Input<Tensor>("ImgSize");
auto* boxes = ctx.Output<Tensor>("Boxes");
auto* scores = ctx.Output<Tensor>("Scores");
auto anchors = ctx.Attr<std::vector<int>>("anchors");
int class_num = ctx.Attr<int>("class_num");
float conf_thresh = ctx.Attr<float>("conf_thresh");
int downsample_ratio = ctx.Attr<int>("downsample_ratio");
const int n = input->dims()[0];
const int h = input->dims()[2];
const int w = input->dims()[3];
const int box_num = boxes->dims()[1];
const int an_num = anchors.size() / 2;
int input_size = downsample_ratio * h;
const int stride = h * w;
const int an_stride = (class_num + 5) * stride;
Tensor anchors_;
auto anchors_data =
anchors_.mutable_data<int>({an_num * 2}, ctx.GetPlace());
std::copy(anchors.begin(), anchors.end(), anchors_data);
const T* input_data = input->data<T>();
const int* imgsize_data = imgsize->data<int>();
T* boxes_data = boxes->mutable_data<T>({n, box_num, 4}, ctx.GetPlace());
memset(boxes_data, 0, boxes->numel() * sizeof(T));
T* scores_data =
scores->mutable_data<T>({n, box_num, class_num}, ctx.GetPlace());
memset(scores_data, 0, scores->numel() * sizeof(T));
T box[4];
for (int i = 0; i < n; i++) {
int img_height = imgsize_data[2 * i];
int img_width = imgsize_data[2 * i + 1];
for (int j = 0; j < an_num; j++) {
for (int k = 0; k < h; k++) {
for (int l = 0; l < w; l++) {
int obj_idx =
GetEntryIndex(i, j, k * w + l, an_num, an_stride, stride, 4);
T conf = sigmoid<T>(input_data[obj_idx]);
if (conf < conf_thresh) {
continue;
}
int box_idx =
GetEntryIndex(i, j, k * w + l, an_num, an_stride, stride, 0);
GetYoloBox<T>(box, input_data, anchors_data, l, k, j, h, input_size,
box_idx, stride, img_height, img_width);
box_idx = (i * box_num + j * stride + k * w + l) * 4;
CalcDetectionBox<T>(boxes_data, box, box_idx, img_height,
img_width);
int label_idx =
GetEntryIndex(i, j, k * w + l, an_num, an_stride, stride, 5);
int score_idx = (i * box_num + j * stride + k * w + l) * class_num;
CalcLabelScore<T>(scores_data, input_data, label_idx, score_idx,
class_num, conf, stride);
}
}
}
}
}
};
} // namespace operators
} // namespace paddle
......@@ -10,6 +10,7 @@
limitations under the License. */
#include "paddle/fluid/operators/detection/yolov3_loss_op.h"
#include <memory>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
......@@ -72,6 +73,18 @@ class Yolov3LossOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_GT(class_num, 0,
"Attr(class_num) should be an integer greater then 0.");
if (ctx->HasInput("GTScore")) {
auto dim_gtscore = ctx->GetInputDim("GTScore");
PADDLE_ENFORCE_EQ(dim_gtscore.size(), 2,
"Input(GTScore) should be a 2-D tensor");
PADDLE_ENFORCE_EQ(
dim_gtscore[0], dim_gtbox[0],
"Input(GTBox) and Input(GTScore) dim[0] should be same");
PADDLE_ENFORCE_EQ(
dim_gtscore[1], dim_gtbox[1],
"Input(GTBox) and Input(GTScore) dim[1] should be same");
}
std::vector<int64_t> dim_out({dim_x[0]});
ctx->SetOutputDim("Loss", framework::make_ddim(dim_out));
......@@ -112,6 +125,12 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker {
"This is a 2-D tensor with shape of [N, max_box_num], "
"and each element should be an integer to indicate the "
"box class id.");
AddInput("GTScore",
"The score of GTLabel, This is a 2-D tensor in same shape "
"GTLabel, and score values should in range (0, 1). This "
"input is for GTLabel score can be not 1.0 in image mixup "
"augmentation.")
.AsDispensable();
AddOutput("Loss",
"The output yolov3 loss tensor, "
"This is a 1-D tensor with shape of [N]");
......@@ -143,6 +162,9 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<float>("ignore_thresh",
"The ignore threshold to ignore confidence loss.")
.SetDefault(0.7);
AddAttr<bool>("use_label_smooth",
"Whether to use label smooth. Default True.")
.SetDefault(true);
AddComment(R"DOC(
This operator generates yolov3 loss based on given predict result and ground
truth boxes.
......@@ -204,6 +226,15 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker {
loss = (loss_{xy} + loss_{wh}) * weight_{box}
+ loss_{conf} + loss_{class}
$$
While :attr:`use_label_smooth` is set to be :attr:`True`, the classification
target will be smoothed when calculating classification loss, target of
positive samples will be smoothed to :math:`1.0 - 1.0 / class\_num` and target of
negetive samples will be smoothed to :math:`1.0 / class\_num`.
While :attr:`GTScore` is given, which means the mixup score of ground truth
boxes, all losses incured by a ground truth box will be multiplied by its
mixup score.
)DOC");
}
};
......@@ -240,6 +271,7 @@ class Yolov3LossGradMaker : public framework::SingleGradOpDescMaker {
op->SetInput("X", Input("X"));
op->SetInput("GTBox", Input("GTBox"));
op->SetInput("GTLabel", Input("GTLabel"));
op->SetInput("GTScore", Input("GTScore"));
op->SetInput(framework::GradVarName("Loss"), OutputGrad("Loss"));
op->SetInput("ObjectnessMask", Output("ObjectnessMask"));
op->SetInput("GTMatchMask", Output("GTMatchMask"));
......@@ -249,6 +281,7 @@ class Yolov3LossGradMaker : public framework::SingleGradOpDescMaker {
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetOutput(framework::GradVarName("GTBox"), {});
op->SetOutput(framework::GradVarName("GTLabel"), {});
op->SetOutput(framework::GradVarName("GTScore"), {});
return std::unique_ptr<framework::OpDesc>(op);
}
};
......
......@@ -37,8 +37,8 @@ static T SigmoidCrossEntropy(T x, T label) {
}
template <typename T>
static T L2Loss(T x, T y) {
return 0.5 * (y - x) * (y - x);
static T L1Loss(T x, T y) {
return std::abs(y - x);
}
template <typename T>
......@@ -47,8 +47,8 @@ static T SigmoidCrossEntropyGrad(T x, T label) {
}
template <typename T>
static T L2LossGrad(T x, T y) {
return x - y;
static T L1LossGrad(T x, T y) {
return x > y ? 1.0 : -1.0;
}
static int GetMaskIndex(std::vector<int> mask, int val) {
......@@ -121,47 +121,49 @@ template <typename T>
static void CalcBoxLocationLoss(T* loss, const T* input, Box<T> gt,
std::vector<int> anchors, int an_idx,
int box_idx, int gi, int gj, int grid_size,
int input_size, int stride) {
int input_size, int stride, T score) {
T tx = gt.x * grid_size - gi;
T ty = gt.y * grid_size - gj;
T tw = std::log(gt.w * input_size / anchors[2 * an_idx]);
T th = std::log(gt.h * input_size / anchors[2 * an_idx + 1]);
T scale = (2.0 - gt.w * gt.h);
T scale = (2.0 - gt.w * gt.h) * score;
loss[0] += SigmoidCrossEntropy<T>(input[box_idx], tx) * scale;
loss[0] += SigmoidCrossEntropy<T>(input[box_idx + stride], ty) * scale;
loss[0] += L2Loss<T>(input[box_idx + 2 * stride], tw) * scale;
loss[0] += L2Loss<T>(input[box_idx + 3 * stride], th) * scale;
loss[0] += L1Loss<T>(input[box_idx + 2 * stride], tw) * scale;
loss[0] += L1Loss<T>(input[box_idx + 3 * stride], th) * scale;
}
template <typename T>
static void CalcBoxLocationLossGrad(T* input_grad, const T loss, const T* input,
Box<T> gt, std::vector<int> anchors,
int an_idx, int box_idx, int gi, int gj,
int grid_size, int input_size, int stride) {
int grid_size, int input_size, int stride,
T score) {
T tx = gt.x * grid_size - gi;
T ty = gt.y * grid_size - gj;
T tw = std::log(gt.w * input_size / anchors[2 * an_idx]);
T th = std::log(gt.h * input_size / anchors[2 * an_idx + 1]);
T scale = (2.0 - gt.w * gt.h);
T scale = (2.0 - gt.w * gt.h) * score;
input_grad[box_idx] =
SigmoidCrossEntropyGrad<T>(input[box_idx], tx) * scale * loss;
input_grad[box_idx + stride] =
SigmoidCrossEntropyGrad<T>(input[box_idx + stride], ty) * scale * loss;
input_grad[box_idx + 2 * stride] =
L2LossGrad<T>(input[box_idx + 2 * stride], tw) * scale * loss;
L1LossGrad<T>(input[box_idx + 2 * stride], tw) * scale * loss;
input_grad[box_idx + 3 * stride] =
L2LossGrad<T>(input[box_idx + 3 * stride], th) * scale * loss;
L1LossGrad<T>(input[box_idx + 3 * stride], th) * scale * loss;
}
template <typename T>
static inline void CalcLabelLoss(T* loss, const T* input, const int index,
const int label, const int class_num,
const int stride) {
const int stride, const T pos, const T neg,
T score) {
for (int i = 0; i < class_num; i++) {
T pred = input[index + i * stride];
loss[0] += SigmoidCrossEntropy<T>(pred, (i == label) ? 1.0 : 0.0);
loss[0] += SigmoidCrossEntropy<T>(pred, (i == label) ? pos : neg) * score;
}
}
......@@ -169,11 +171,13 @@ template <typename T>
static inline void CalcLabelLossGrad(T* input_grad, const T loss,
const T* input, const int index,
const int label, const int class_num,
const int stride) {
const int stride, const T pos, const T neg,
T score) {
for (int i = 0; i < class_num; i++) {
T pred = input[index + i * stride];
input_grad[index + i * stride] =
SigmoidCrossEntropyGrad<T>(pred, (i == label) ? 1.0 : 0.0) * loss;
SigmoidCrossEntropyGrad<T>(pred, (i == label) ? pos : neg) * score *
loss;
}
}
......@@ -188,8 +192,8 @@ static inline void CalcObjnessLoss(T* loss, const T* input, const T* objness,
for (int l = 0; l < w; l++) {
T obj = objness[k * w + l];
if (obj > 1e-5) {
// positive sample: obj = 1
loss[i] += SigmoidCrossEntropy<T>(input[k * w + l], 1.0);
// positive sample: obj = mixup score
loss[i] += SigmoidCrossEntropy<T>(input[k * w + l], 1.0) * obj;
} else if (obj > -0.5) {
// negetive sample: obj = 0
loss[i] += SigmoidCrossEntropy<T>(input[k * w + l], 0.0);
......@@ -215,7 +219,8 @@ static inline void CalcObjnessLossGrad(T* input_grad, const T* loss,
T obj = objness[k * w + l];
if (obj > 1e-5) {
input_grad[k * w + l] =
SigmoidCrossEntropyGrad<T>(input[k * w + l], 1.0) * loss[i];
SigmoidCrossEntropyGrad<T>(input[k * w + l], 1.0) * obj *
loss[i];
} else if (obj > -0.5) {
input_grad[k * w + l] =
SigmoidCrossEntropyGrad<T>(input[k * w + l], 0.0) * loss[i];
......@@ -252,6 +257,7 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
auto* input = ctx.Input<Tensor>("X");
auto* gt_box = ctx.Input<Tensor>("GTBox");
auto* gt_label = ctx.Input<Tensor>("GTLabel");
auto* gt_score = ctx.Input<Tensor>("GTScore");
auto* loss = ctx.Output<Tensor>("Loss");
auto* objness_mask = ctx.Output<Tensor>("ObjectnessMask");
auto* gt_match_mask = ctx.Output<Tensor>("GTMatchMask");
......@@ -260,6 +266,7 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
int class_num = ctx.Attr<int>("class_num");
float ignore_thresh = ctx.Attr<float>("ignore_thresh");
int downsample_ratio = ctx.Attr<int>("downsample_ratio");
bool use_label_smooth = ctx.Attr<bool>("use_label_smooth");
const int n = input->dims()[0];
const int h = input->dims()[2];
......@@ -272,6 +279,13 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
const int stride = h * w;
const int an_stride = (class_num + 5) * stride;
T label_pos = 1.0;
T label_neg = 0.0;
if (use_label_smooth) {
label_pos = 1.0 - 1.0 / static_cast<T>(class_num);
label_neg = 1.0 / static_cast<T>(class_num);
}
const T* input_data = input->data<T>();
const T* gt_box_data = gt_box->data<T>();
const int* gt_label_data = gt_label->data<int>();
......@@ -283,6 +297,19 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
int* gt_match_mask_data =
gt_match_mask->mutable_data<int>({n, b}, ctx.GetPlace());
const T* gt_score_data;
if (!gt_score) {
Tensor gtscore;
gtscore.mutable_data<T>({n, b}, ctx.GetPlace());
math::SetConstant<platform::CPUDeviceContext, T>()(
ctx.template device_context<platform::CPUDeviceContext>(), &gtscore,
static_cast<T>(1.0));
gt_score = &gtscore;
gt_score_data = gtscore.data<T>();
} else {
gt_score_data = gt_score->data<T>();
}
// calc valid gt box mask, avoid calc duplicately in following code
Tensor gt_valid_mask;
bool* gt_valid_mask_data =
......@@ -355,19 +382,20 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
int mask_idx = GetMaskIndex(anchor_mask, best_n);
gt_match_mask_data[i * b + t] = mask_idx;
if (mask_idx >= 0) {
T score = gt_score_data[i * b + t];
int box_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num,
an_stride, stride, 0);
CalcBoxLocationLoss<T>(loss_data + i, input_data, gt, anchors, best_n,
box_idx, gi, gj, h, input_size, stride);
box_idx, gi, gj, h, input_size, stride, score);
int obj_idx = (i * mask_num + mask_idx) * stride + gj * w + gi;
obj_mask_data[obj_idx] = 1.0;
obj_mask_data[obj_idx] = score;
int label = gt_label_data[i * b + t];
int label_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num,
an_stride, stride, 5);
CalcLabelLoss<T>(loss_data + i, input_data, label_idx, label,
class_num, stride);
class_num, stride, label_pos, label_neg, score);
}
}
}
......@@ -384,6 +412,7 @@ class Yolov3LossGradKernel : public framework::OpKernel<T> {
auto* input = ctx.Input<Tensor>("X");
auto* gt_box = ctx.Input<Tensor>("GTBox");
auto* gt_label = ctx.Input<Tensor>("GTLabel");
auto* gt_score = ctx.Input<Tensor>("GTScore");
auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* loss_grad = ctx.Input<Tensor>(framework::GradVarName("Loss"));
auto* objness_mask = ctx.Input<Tensor>("ObjectnessMask");
......@@ -392,6 +421,7 @@ class Yolov3LossGradKernel : public framework::OpKernel<T> {
auto anchor_mask = ctx.Attr<std::vector<int>>("anchor_mask");
int class_num = ctx.Attr<int>("class_num");
int downsample_ratio = ctx.Attr<int>("downsample_ratio");
bool use_label_smooth = ctx.Attr<bool>("use_label_smooth");
const int n = input_grad->dims()[0];
const int c = input_grad->dims()[1];
......@@ -404,6 +434,13 @@ class Yolov3LossGradKernel : public framework::OpKernel<T> {
const int stride = h * w;
const int an_stride = (class_num + 5) * stride;
T label_pos = 1.0;
T label_neg = 0.0;
if (use_label_smooth) {
label_pos = 1.0 - 1.0 / static_cast<T>(class_num);
label_neg = 1.0 / static_cast<T>(class_num);
}
const T* input_data = input->data<T>();
const T* gt_box_data = gt_box->data<T>();
const int* gt_label_data = gt_label->data<int>();
......@@ -414,25 +451,41 @@ class Yolov3LossGradKernel : public framework::OpKernel<T> {
input_grad->mutable_data<T>({n, c, h, w}, ctx.GetPlace());
memset(input_grad_data, 0, input_grad->numel() * sizeof(T));
const T* gt_score_data;
if (!gt_score) {
Tensor gtscore;
gtscore.mutable_data<T>({n, b}, ctx.GetPlace());
math::SetConstant<platform::CPUDeviceContext, T>()(
ctx.template device_context<platform::CPUDeviceContext>(), &gtscore,
static_cast<T>(1.0));
gt_score = &gtscore;
gt_score_data = gtscore.data<T>();
} else {
gt_score_data = gt_score->data<T>();
}
for (int i = 0; i < n; i++) {
for (int t = 0; t < b; t++) {
int mask_idx = gt_match_mask_data[i * b + t];
if (mask_idx >= 0) {
T score = gt_score_data[i * b + t];
Box<T> gt = GetGtBox(gt_box_data, i, b, t);
int gi = static_cast<int>(gt.x * w);
int gj = static_cast<int>(gt.y * h);
int box_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num,
an_stride, stride, 0);
CalcBoxLocationLossGrad<T>(
input_grad_data, loss_grad_data[i], input_data, gt, anchors,
anchor_mask[mask_idx], box_idx, gi, gj, h, input_size, stride);
CalcBoxLocationLossGrad<T>(input_grad_data, loss_grad_data[i],
input_data, gt, anchors,
anchor_mask[mask_idx], box_idx, gi, gj, h,
input_size, stride, score);
int label = gt_label_data[i * b + t];
int label_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num,
an_stride, stride, 5);
CalcLabelLossGrad<T>(input_grad_data, loss_grad_data[i], input_data,
label_idx, label, class_num, stride);
label_idx, label, class_num, stride, label_pos,
label_neg, score);
}
}
}
......
......@@ -81,6 +81,30 @@ struct FindRangeAbsMaxFunctor<platform::CPUDeviceContext, T> {
template struct FindRangeAbsMaxFunctor<platform::CPUDeviceContext, float>;
template <typename T>
struct FindMovingAverageAbsMaxFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& ctx,
const framework::Tensor& in_accum,
const framework::Tensor& in_state, const T* cur_scale,
const float rate, framework::Tensor* out_state,
framework::Tensor* out_accum, framework::Tensor* out_scale) {
T accum = in_accum.data<T>()[0];
T state = in_state.data<T>()[0];
T scale = cur_scale[0];
state = rate * state + 1;
accum = rate * accum + scale;
scale = accum / state;
out_state->mutable_data<T>(ctx.GetPlace())[0] = state;
out_accum->mutable_data<T>(ctx.GetPlace())[0] = accum;
out_scale->mutable_data<T>(ctx.GetPlace())[0] = scale;
}
};
template struct FindMovingAverageAbsMaxFunctor<platform::CPUDeviceContext,
float>;
class FakeQuantizeAbsMaxOp : public framework::OperatorWithKernel {
public:
FakeQuantizeAbsMaxOp(const std::string& type,
......@@ -255,6 +279,78 @@ $$Out = round(X/scale * range)$$
}
};
class FakeQuantizeMovingAverageAbsMaxOp : public framework::OperatorWithKernel {
public:
FakeQuantizeMovingAverageAbsMaxOp(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(
ctx->HasInput("X"),
"Input(X) of FakeQuantizeMovingAverageAbsMaxOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("Out"),
"Output(Out) of FakeQuantizeMovingAverageAbsMaxOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("OutScale"),
"Output(OutScale) of FakeQuantizeMovingAverageAbsMaxOp "
"should not be null");
if (ctx->HasOutput("OutState")) {
ctx->SetOutputDim("OutState", {1});
}
if (ctx->HasOutput("OutAccum")) {
ctx->SetOutputDim("OutAccum", {1});
}
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->SetOutputDim("OutScale", {1});
ctx->ShareLoD("X", /*->*/ "Out");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
ctx.device_context());
}
};
class FakeQuantizeMovingAverageAbsMaxOpMaker
: public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor) Input is float data type.");
AddInput("InScale", "Last scale.");
AddInput("InAccum", "Last accum.").AsDispensable();
AddInput("InState", "Last state.").AsDispensable();
AddOutput("Out", "(Tensor) Output of quantized low level tensor.");
AddOutput("OutScale", " Current scale");
AddOutput("OutState", "(Tensor) state buffer.").AsDispensable();
AddOutput("OutAccum", "(Tensor) accum buffer.").AsDispensable();
AddAttr<float>("moving_rate", "(float, default 0.9) moving rate.")
.SetDefault(0.9);
AddAttr<int>("bit_length", "(int, default 8), quantization bit number.")
.SetDefault(8)
.AddCustomChecker([](const int& bit_length) {
PADDLE_ENFORCE(bit_length >= 1 && bit_length <= 16,
"'bit_length' should be between 1 and 16.");
});
AddAttr<bool>("is_test",
"(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true.")
.SetDefault(false);
AddComment(R"DOC(
FakeQuantize operator is used in static quantization.
$$scale = (0.9*max(abs(x))+accum)/(0.9*state+1)$$
$$range = 2^{bit_length - 1} - 1$$
$$Out = round(X/scale * range)$$
)DOC");
}
};
} // namespace operators
} // namespace paddle
......@@ -273,6 +369,12 @@ REGISTER_OPERATOR(fake_quantize_range_abs_max, ops::FakeQuantizeRangeAbsMaxOp,
REGISTER_OP_CPU_KERNEL(fake_quantize_range_abs_max,
ops::FakeQuantizeRangeAbsMaxKernel<CPU, float>);
REGISTER_OPERATOR(fake_quantize_moving_average_abs_max,
ops::FakeQuantizeMovingAverageAbsMaxOp,
ops::FakeQuantizeMovingAverageAbsMaxOpMaker,
paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(fake_quantize_moving_average_abs_max,
ops::FakeQuantizeMovingAverageAbsMaxKernel<CPU, float>);
REGISTER_OPERATOR(fake_channel_wise_quantize_abs_max,
ops::FakeChannelWiseQuantizeAbsMaxOp,
ops::FakeChannelWiseQuantizeAbsMaxOpMaker,
......
......@@ -147,6 +147,41 @@ struct FindRangeAbsMaxFunctor<platform::CUDADeviceContext, T> {
template struct FindRangeAbsMaxFunctor<platform::CUDADeviceContext, float>;
template <typename T>
struct FindMovingAverageAbsMaxFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& ctx,
const framework::Tensor& in_accum,
const framework::Tensor& in_state, const T* cur_scale,
const float rate, framework::Tensor* out_state,
framework::Tensor* out_accum, framework::Tensor* out_scale) {
const auto gpu_place = boost::get<platform::CUDAPlace>(ctx.GetPlace());
T accum;
memory::Copy(platform::CPUPlace(), &accum, gpu_place, in_accum.data<T>(),
sizeof(T), 0);
T state;
memory::Copy(platform::CPUPlace(), &state, gpu_place, in_state.data<T>(),
sizeof(T), 0);
T scale;
memory::Copy(platform::CPUPlace(), &scale, gpu_place, cur_scale, sizeof(T),
0);
state = rate * state + 1;
accum = rate * accum + scale;
scale = accum / state;
memory::Copy(gpu_place, out_accum->mutable_data<T>(gpu_place),
platform::CPUPlace(), &accum, sizeof(T), 0);
memory::Copy(gpu_place, out_state->mutable_data<T>(gpu_place),
platform::CPUPlace(), &state, sizeof(T), 0);
memory::Copy(gpu_place, out_scale->mutable_data<T>(gpu_place),
platform::CPUPlace(), &scale, sizeof(T), 0);
}
};
template struct FindMovingAverageAbsMaxFunctor<platform::CUDADeviceContext,
float>;
template <typename T>
struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& ctx,
......@@ -178,3 +213,6 @@ REGISTER_OP_CUDA_KERNEL(fake_channel_wise_quantize_abs_max,
ops::FakeChannelWiseQuantizeAbsMaxKernel<CUDA, float>);
REGISTER_OP_CUDA_KERNEL(fake_quantize_range_abs_max,
ops::FakeQuantizeRangeAbsMaxKernel<CUDA, float>);
REGISTER_OP_CUDA_KERNEL(
fake_quantize_moving_average_abs_max,
ops::FakeQuantizeMovingAverageAbsMaxKernel<CUDA, float>);
......@@ -42,12 +42,20 @@ struct FindRangeAbsMaxFunctor {
framework::Tensor* scales_arr, framework::Tensor* out_scale);
};
template <typename DeviceContext, typename T>
struct FindMovingAverageAbsMaxFunctor {
void operator()(const DeviceContext& ctx, const framework::Tensor& in_accum,
const framework::Tensor& in_state,
const framework::Tensor& cur_scale,
framework::Tensor* out_state, framework::Tensor* out_accum,
framework::Tensor* out_scale);
};
template <typename DeviceContext, typename T>
class FakeQuantizeAbsMaxKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<framework::Tensor>("X");
auto* out = context.Output<framework::Tensor>("Out");
auto* out_scale = context.Output<framework::Tensor>("OutScale");
T* out_s = out_scale->mutable_data<T>(context.GetPlace());
......@@ -138,5 +146,54 @@ class FakeQuantizeRangeAbsMaxKernel : public framework::OpKernel<T> {
}
};
template <typename DeviceContext, typename T>
class FakeQuantizeMovingAverageAbsMaxKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<framework::Tensor>("X");
auto* in_scale = context.Input<framework::Tensor>("InScale");
auto* out = context.Output<framework::Tensor>("Out");
out->mutable_data<T>(context.GetPlace());
bool is_test = context.Attr<bool>("is_test");
int bit_length = context.Attr<int>("bit_length");
int bin_cnt = std::pow(2, bit_length - 1) - 1;
auto& dev_ctx = context.template device_context<DeviceContext>();
// testing
if (is_test) {
ClipAndFakeQuantFunctor<DeviceContext, T>()(dev_ctx, *in, *in_scale,
bin_cnt, out);
return;
}
// training
auto* in_accum = context.Input<framework::Tensor>("InAccum");
auto* in_state = context.Input<framework::Tensor>("InState");
auto& allocator =
platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx);
auto cur_scale = allocator.Allocate(1 * sizeof(T));
T* cur_scale_data = static_cast<T*>(cur_scale->ptr());
FindAbsMaxFunctor<DeviceContext, T>()(dev_ctx, in->data<T>(), in->numel(),
cur_scale_data);
auto* out_state = context.Output<framework::Tensor>("OutState");
auto* out_accum = context.Output<framework::Tensor>("OutAccum");
auto* out_scale = context.Output<framework::Tensor>("OutScale");
out_state->mutable_data<T>(context.GetPlace());
out_accum->mutable_data<T>(context.GetPlace());
out_scale->mutable_data<T>(context.GetPlace());
float moving_rate = context.Attr<float>("moving_rate");
FindMovingAverageAbsMaxFunctor<DeviceContext, T>()(
dev_ctx, *in_accum, *in_state, cur_scale_data, moving_rate, out_state,
out_accum, out_scale);
ClipAndFakeQuantFunctor<DeviceContext, T>()(dev_ctx, *in, *out_scale,
bin_cnt, out);
}
};
} // namespace operators
} // namespace paddle
......@@ -592,6 +592,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
platform::SetDstMemoryHandler<uint8_t>(ctx, output, handler,
&dst_memory_p);
} else {
need_s8_to_u8 = fuse_relu;
platform::SetDstMemoryHandler<int8_t>(ctx, output, handler,
&dst_memory_p);
}
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/pool_op.h"
#include <unordered_map>
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cudnn_helper.h"
#endif
......@@ -212,6 +213,12 @@ void Pool2dOpMaker::Make() {
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<bool>("use_quantizer",
"(bool, default false) "
"Set to true for operators that should be quantized and use "
"int8 kernel. "
"Only used on CPU.")
.SetDefault(false);
AddAttr<std::string>(
"data_format",
"(string, default NCHW) Only used in "
......
......@@ -439,7 +439,8 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
context.Input<Tensor>(framework::GradVarName("Loss"))->data<T>();
Tensor* logit_grad =
context.Output<Tensor>(framework::GradVarName("Logits"));
logit_grad->ShareDataWith(*context.Input<Tensor>("Softmax"));
framework::TensorCopy(*context.Input<Tensor>("Softmax"), context.GetPlace(),
context.device_context(), logit_grad);
T* logit_grad_data = logit_grad->data<T>();
const int batch_size = logit_grad->dims()[0];
......
......@@ -94,6 +94,7 @@ class SqueezeOpInferShape : public framework::InferShapeBase {
}
};
// TODO(paddle-dev): Should use OpKernel.
class SqueezeOp : public framework::OperatorBase {
public:
using OperatorBase::OperatorBase;
......
......@@ -316,7 +316,9 @@ CUDADeviceContext::~CUDADeviceContext() {
eigen_stream_.reset();
eigen_device_.reset();
PADDLE_ENFORCE(cudaStreamDestroy(stream_));
#if !defined(_WIN32)
PADDLE_ENFORCE(dynload::ncclCommDestroy(nccl_comm_));
#endif
}
Place CUDADeviceContext::GetPlace() const { return place_; }
......
......@@ -265,11 +265,13 @@ class CUDADeviceContext : public DeviceContext {
/*! \brief Return cuda stream in the device context. */
cudaStream_t stream() const;
#if !defined(_WIN32)
/*! \brief Return nccl communicators. */
ncclComm_t nccl_comm() const { return nccl_comm_; }
/*! \brief Set nccl communicators. */
void set_nccl_comm(ncclComm_t comm) { nccl_comm_ = comm; }
#endif
template <typename Callback>
void RecordEvent(cudaEvent_t ev, Callback callback) {
......@@ -295,12 +297,14 @@ class CUDADeviceContext : public DeviceContext {
std::unique_ptr<CublasHandleHolder> cublas_handle_;
std::unique_ptr<CublasHandleHolder> cublas_tensor_core_handle_;
#if !defined(_WIN32)
// NCCL communicator (single process version) for NCCL collective operations.
// NCCL collective operations provides fast collectives over multiple GPUs
// both within and across nodes.
// But, this collectives is used for collectives over multiple GPUs within
// nodes.
ncclComm_t nccl_comm_{nullptr};
#endif
int compute_capability_;
int runtime_version_;
......
......@@ -84,7 +84,8 @@ class QuantizeTranspiler(object):
activation_bits=8,
activation_quantize_type='abs_max',
weight_quantize_type='abs_max',
window_size=10000):
window_size=10000,
moving_rate=0.9):
"""
Convert and rewrite the fluid Program according to weight and
activation quantization type.
......@@ -117,23 +118,27 @@ class QuantizeTranspiler(object):
"""
self.weight_bits = weight_bits
self.activation_bits = activation_bits
quant_type = ['abs_max', 'range_abs_max']
quant_type = ['abs_max', 'range_abs_max', 'moving_average_abs_max']
if weight_quantize_type not in quant_type:
raise ValueError(
"Unknown weight_quantize_type: '%s'. It can only be ",
"'abs_max' or 'range_abs_max'.", str(weight_quantize_type))
"'abs_max' or 'range_abs_max' or 'moving_average_abs_max'.",
str(weight_quantize_type))
if activation_quantize_type not in quant_type:
raise ValueError(
"Unknown activation_quantize_type : '%s'. It can only be ",
"'abs_max' or 'range_abs_max'.", str(activation_quantize_type))
"'abs_max' or 'range_abs_max' or 'moving_average_abs_max'.",
str(activation_quantize_type))
self.weight_quantize_type = weight_quantize_type
self.activation_quantize_type = activation_quantize_type
self.window_size = window_size
self.moving_rate = moving_rate
self.helper = LayerHelper(self.__class__.__name__)
self.fake_quant_op_types = [
'fake_quantize_abs_max', 'fake_quantize_range_abs_max'
'fake_quantize_abs_max', 'fake_quantize_range_abs_max',
'fake_quantize_moving_average_abs_max'
]
self.fake_dequant_op_types = ['fake_dequantize_max_abs']
self.is_test = None
......@@ -168,6 +173,7 @@ class QuantizeTranspiler(object):
block_id = block.idx
# insert quant op and dequant op
for name in op.input_arg_names:
#if share input between ops
if name in dequanted_vars[block_id]:
dequant_var = dequanted_vars[block_id][name]
else:
......@@ -261,6 +267,7 @@ class QuantizeTranspiler(object):
max_range = None
scale_var = None
for name in op.input_arg_names:
#rename input name of the op to the input name of last op which has be removed
if name in op_in_rename_map[block_id]:
op._rename_input(name, op_in_rename_map[block_id][name])
......@@ -272,8 +279,7 @@ class QuantizeTranspiler(object):
max_range = param_range * act_range / scale_v
else:
assert isinstance(scale_v, Variable)
scale_var = var_scale_map[block_id][_original_var_name(
name)]
scale_var = scale_v
if len(op.output_arg_names) != 1:
raise ValueError("Only support one output, but op %s has"
......@@ -309,7 +315,7 @@ class QuantizeTranspiler(object):
op_type = op.type
# insert dequant_op after fc/conv, need to rename
# input of the followed ops
# input of the followed ops(of fc/conv) to the dquant_op
for name in op.input_arg_names:
if name in op_out_rename_map[block_id]:
op._rename_input(name,
......@@ -389,8 +395,8 @@ class QuantizeTranspiler(object):
for op in block.ops:
args += op.input_arg_names
args += op.output_arg_names
args = list(set(args))
var_names = block.vars.keys()
args = list(set(args)) #vals of all left ops
var_names = block.vars.keys() # all vals
sub_block_remove_vars = []
for var in var_names:
if var not in args:
......@@ -471,6 +477,61 @@ class QuantizeTranspiler(object):
return quant_var, scale
def _insert_quant_moving_average_abs_max_op(self, block, idx, var,
quant_bits):
"""Insert fake_quantize_moving_average_abs_max
"""
quant_var = block.create_var(
name=_quantized_var_name(var.name),
type=var.type,
shape=var.shape,
dtype=var.dtype)
state = self.helper.create_global_variable(
name=unique_name.generate('state'),
persistable=True,
dtype=var.dtype,
shape=[1])
self.helper.set_variable_initializer(
state, initializer=Constant(value=1))
accum = self.helper.create_global_variable(
name=unique_name.generate('accum'),
persistable=True,
dtype=var.dtype,
shape=[1])
self.helper.set_variable_initializer(
accum, initializer=Constant(value=1))
scale = self.helper.create_parameter(
attr=ParamAttr(
name=_quantized_scale_name(var.name),
initializer=Constant(0.001),
trainable=False),
shape=[1],
dtype=var.dtype)
scale.stop_gradient = True
ins = {'X': var, 'InScale': scale}
outs = {'Out': quant_var, 'OutScale': scale}
if not self.is_test:
ins['InState'] = state
ins['InAccum'] = accum
outs['OutState'] = state
outs['OutAccum'] = accum
attrs = {
'bit_length': quant_bits,
'moving_rate': self.moving_rate,
'is_test': self.is_test
}
quant_op = block._insert_op(
idx,
type='fake_quantize_moving_average_abs_max',
attrs=attrs,
inputs=ins,
outputs=outs)
return quant_var, scale
def _insert_quant_op(self, block, idx, var, quant_bits, quant_type):
"""
Insert fake_quantize_op
......@@ -480,6 +541,9 @@ class QuantizeTranspiler(object):
elif quant_type == 'range_abs_max':
return self._insert_quant_range_abs_max_op(block, idx, var,
quant_bits)
elif quant_type == 'moving_average_abs_max':
return self._insert_quant_moving_average_abs_max_op(block, idx, var,
quant_bits)
def _insert_dequant_op(self, block, idx, var, scale, quant_bits):
"""
......
......@@ -38,7 +38,8 @@ class QuantizationTransformPass(object):
activation_bits=8,
activation_quantize_type='abs_max',
weight_quantize_type='abs_max',
window_size=10000):
window_size=10000,
moving_rate=0.9):
"""
Convert and rewrite the IrGraph according to weight and
activation quantization type.
......@@ -83,19 +84,22 @@ class QuantizationTransformPass(object):
self._weight_bits = weight_bits
self._activation_bits = activation_bits
quant_type = ['abs_max', 'range_abs_max']
quant_type = ['abs_max', 'range_abs_max', 'moving_average_abs_max']
if activation_quantize_type not in quant_type:
raise ValueError(
"Unknown activation_quantize_type : '%s'. It can only be ",
"'abs_max' or 'range_abs_max'.", str(activation_quantize_type))
"'abs_max' or 'range_abs_max' or 'moving_average_abs_max'.",
str(activation_quantize_type))
if weight_quantize_type not in quant_type:
raise ValueError(
"Unknown weight_quantize_type: '%s'. It can only be ",
"'abs_max' or 'range_abs_max'.", str(weight_quantize_type))
"'abs_max' or 'range_abs_max' or 'moving_average_abs_max'.",
str(weight_quantize_type))
self._activation_quantize_type = activation_quantize_type
self._weight_quantize_type = weight_quantize_type
self._window_size = window_size
self._moving_rate = moving_rate
self._need_initialized = collections.OrderedDict()
self._quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul']
......@@ -222,6 +226,9 @@ class QuantizationTransformPass(object):
elif quant_type == 'range_abs_max':
return self._insert_quant_range_abs_max_op(graph, var_node,
quant_bits)
elif quant_type == 'moving_average_abs_max':
return self._insert_quant_moving_average_abs_max_op(graph, var_node,
quant_bits)
def _insert_quant_abs_max_op(self, graph, var_node, quant_bits):
"""
......@@ -309,6 +316,74 @@ class QuantizationTransformPass(object):
return quant_var_node, scale_out_node
def _insert_quant_moving_average_abs_max_op(self, graph, var_node,
quant_bits):
"""Insert fake_quantize_moving_average_abs_max
"""
quant_var_node = graph.create_var_node(
name=self._quantized_var_name(var_node.name()),
var_type=var_node.type(),
shape=var_node.shape(),
var_dtype=var_node.dtype())
scale_in_node = graph.create_persistable_node(
name=self._quantized_scale_name(var_node.name()),
var_type=core.VarDesc.VarType.LOD_TENSOR,
shape=[1],
var_dtype=var_node.dtype())
self._need_initialized[scale_in_node.var()] = Constant(value=0.001)
scale_out_node = graph.create_var_node_from_desc(scale_in_node.var())
ins = {'X': var_node, 'InScale': scale_in_node}
outs = {'Out': quant_var_node, 'OutScale': scale_out_node}
if not self._is_test:
state_in_node = graph.create_persistable_node(
name=unique_name.generate('state'),
var_type=core.VarDesc.VarType.LOD_TENSOR,
var_dtype=var_node.dtype(),
shape=[1])
self._need_initialized[state_in_node.var()] = Constant(value=1)
accum_in_node = graph.create_persistable_node(
name=unique_name.generate('accum'),
var_type=core.VarDesc.VarType.LOD_TENSOR,
var_dtype=var_node.dtype(),
shape=[1])
self._need_initialized[accum_in_node.var()] = Constant(value=1)
state_out_node = graph.create_var_node_from_desc(state_in_node.var(
))
accum_out_node = graph.create_var_node_from_desc(accum_in_node.var(
))
ins['InState'] = state_in_node
ins['InAccum'] = accum_in_node
outs['OutState'] = state_out_node
outs['OutAccum'] = accum_out_node
attrs = {
'bit_length': quant_bits,
'moving_rate': self._moving_rate,
'is_test': self._is_test,
'op_role': core.op_proto_and_checker_maker.OpRole.Forward
}
quant_op_node = graph.create_op_node(
op_type='fake_quantize_moving_average_abs_max',
attrs=attrs,
inputs=ins,
outputs=outs)
graph.link_to(var_node, quant_op_node)
graph.link_to(scale_in_node, quant_op_node)
graph.link_to(quant_op_node, quant_var_node)
graph.link_to(quant_op_node, scale_out_node)
if not self._is_test:
graph.link_to(state_in_node, quant_op_node)
graph.link_to(accum_in_node, quant_op_node)
graph.link_to(quant_op_node, state_out_node)
graph.link_to(quant_op_node, accum_out_node)
return quant_var_node, scale_out_node
def _insert_dequant_op(self, graph, var_node, scale_var_node, quant_bits):
"""
Insert fake_dequantize_op in the graph.
......@@ -389,7 +464,8 @@ class QuantizationFreezePass(object):
self._weight_quantize_type = weight_quantize_type
self._quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul']
self._fake_quant_op_names = [
'fake_quantize_abs_max', 'fake_quantize_range_abs_max'
'fake_quantize_abs_max', 'fake_quantize_range_abs_max',
'fake_quantize_moving_average_abs_max'
]
self._fake_dequant_op_names = ['fake_dequantize_max_abs']
self._op_input_rename_map = collections.OrderedDict()
......
......@@ -164,6 +164,9 @@ class TestQuantizationTransformPass(unittest.TestCase):
def test_linear_fc_quant_range_abs_max(self):
self.linear_fc_quant('range_abs_max', for_ci=True)
def test_linear_fc_quant_moving_average_abs_max(self):
self.linear_fc_quant('moving_average_abs_max', for_ci=True)
def residual_block_quant(self, quant_type, for_ci=False):
main = fluid.Program()
startup = fluid.Program()
......@@ -201,6 +204,9 @@ class TestQuantizationTransformPass(unittest.TestCase):
def test_residual_block_range_abs_max(self):
self.residual_block_quant('range_abs_max', for_ci=True)
def test_residual_block_moving_average_abs_max(self):
self.residual_block_quant('moving_average_abs_max', for_ci=True)
class TestQuantizationFreezePass(unittest.TestCase):
def freeze_graph(self, use_cuda, seed, quant_type, for_ci=False):
......@@ -380,11 +386,18 @@ class TestQuantizationFreezePass(unittest.TestCase):
with fluid.unique_name.guard():
self.freeze_graph(
True, seed=1, quant_type='range_abs_max', for_ci=True)
self.freeze_graph(
True,
seed=1,
quant_type='moving_average_abs_max',
for_ci=True)
def test_freeze_graph_cpu_static(self):
with fluid.unique_name.guard():
self.freeze_graph(
False, seed=2, quant_type='range_abs_max', for_ci=True)
self.freeze_graph(
False, seed=2, quant_type='moving_average_abs_max', for_ci=True)
if __name__ == '__main__':
......
......@@ -430,6 +430,11 @@ class Variable(object):
Returns:
str: The debug string.
"""
if _in_imperative_mode():
# TODO(panyx0718): add more imperative debug info.
return 'name %s, dtype: %s shape: %s' % (self.name, self.dtype,
self.shape)
assert isinstance(throw_on_error, bool) and isinstance(with_details,
bool)
protostr = self.desc.serialize_to_string()
......
......@@ -49,6 +49,7 @@ __all__ = [
'box_coder',
'polygon_box_transform',
'yolov3_loss',
'yolo_box',
'box_clip',
'multiclass_nms',
'distribute_fpn_proposals',
......@@ -515,6 +516,8 @@ def yolov3_loss(x,
class_num,
ignore_thresh,
downsample_ratio,
gtscore=None,
use_label_smooth=True,
name=None):
"""
${comment}
......@@ -533,28 +536,35 @@ def yolov3_loss(x,
class_num (int): ${class_num_comment}
ignore_thresh (float): ${ignore_thresh_comment}
downsample_ratio (int): ${downsample_ratio_comment}
name (string): the name of yolov3 loss
name (string): the name of yolov3 loss. Default None.
gtscore (Variable): mixup score of ground truth boxes, shoud be in shape
of [N, B]. Default None.
use_label_smooth (bool): ${use_label_smooth_comment}
Returns:
Variable: A 1-D tensor with shape [1], the value of yolov3 loss
Variable: A 1-D tensor with shape [N], the value of yolov3 loss
Raises:
TypeError: Input x of yolov3_loss must be Variable
TypeError: Input gtbox of yolov3_loss must be Variable"
TypeError: Input gtlabel of yolov3_loss must be Variable"
TypeError: Input gtbox of yolov3_loss must be Variable
TypeError: Input gtlabel of yolov3_loss must be Variable
TypeError: Input gtscore of yolov3_loss must be None or Variable
TypeError: Attr anchors of yolov3_loss must be list or tuple
TypeError: Attr class_num of yolov3_loss must be an integer
TypeError: Attr ignore_thresh of yolov3_loss must be a float number
TypeError: Attr use_label_smooth of yolov3_loss must be a bool value
Examples:
.. code-block:: python
x = fluid.layers.data(name='x', shape=[255, 13, 13], dtype='float32')
gtbox = fluid.layers.data(name='gtbox', shape=[6, 5], dtype='float32')
gtlabel = fluid.layers.data(name='gtlabel', shape=[6, 1], dtype='int32')
gtbox = fluid.layers.data(name='gtbox', shape=[6, 4], dtype='float32')
gtlabel = fluid.layers.data(name='gtlabel', shape=[6], dtype='int32')
gtscore = fluid.layers.data(name='gtscore', shape=[6], dtype='float32')
anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, 373, 326]
anchor_mask = [0, 1, 2]
loss = fluid.layers.yolov3_loss(x=x, gtbox=gtbox, gtlabel=gtlabel, anchors=anchors,
loss = fluid.layers.yolov3_loss(x=x, gtbox=gtbox, gtlabel=gtlabel,
gtscore=gtscore, anchors=anchors,
anchor_mask=anchor_mask, class_num=80,
ignore_thresh=0.7, downsample_ratio=32)
"""
......@@ -566,6 +576,8 @@ def yolov3_loss(x,
raise TypeError("Input gtbox of yolov3_loss must be Variable")
if not isinstance(gtlabel, Variable):
raise TypeError("Input gtlabel of yolov3_loss must be Variable")
if gtscore is not None and not isinstance(gtscore, Variable):
raise TypeError("Input gtscore of yolov3_loss must be Variable")
if not isinstance(anchors, list) and not isinstance(anchors, tuple):
raise TypeError("Attr anchors of yolov3_loss must be list or tuple")
if not isinstance(anchor_mask, list) and not isinstance(anchor_mask, tuple):
......@@ -575,6 +587,9 @@ def yolov3_loss(x,
if not isinstance(ignore_thresh, float):
raise TypeError(
"Attr ignore_thresh of yolov3_loss must be a float number")
if not isinstance(use_label_smooth, bool):
raise TypeError(
"Attr use_label_smooth of yolov3_loss must be a bool value")
if name is None:
loss = helper.create_variable_for_type_inference(dtype=x.dtype)
......@@ -585,21 +600,26 @@ def yolov3_loss(x,
objectness_mask = helper.create_variable_for_type_inference(dtype='int32')
gt_match_mask = helper.create_variable_for_type_inference(dtype='int32')
inputs = {
"X": x,
"GTBox": gtbox,
"GTLabel": gtlabel,
}
if gtscore:
inputs["GTScore"] = gtscore
attrs = {
"anchors": anchors,
"anchor_mask": anchor_mask,
"class_num": class_num,
"ignore_thresh": ignore_thresh,
"downsample_ratio": downsample_ratio,
"use_label_smooth": use_label_smooth,
}
helper.append_op(
type='yolov3_loss',
inputs={
"X": x,
"GTBox": gtbox,
"GTLabel": gtlabel,
},
inputs=inputs,
outputs={
'Loss': loss,
'ObjectnessMask': objectness_mask,
......@@ -609,6 +629,83 @@ def yolov3_loss(x,
return loss
@templatedoc(op_type="yolo_box")
def yolo_box(x,
img_size,
anchors,
class_num,
conf_thresh,
downsample_ratio,
name=None):
"""
${comment}
Args:
x (Variable): ${x_comment}
img_size (Variable): ${img_size_comment}
anchors (list|tuple): ${anchors_comment}
class_num (int): ${class_num_comment}
conf_thresh (float): ${conf_thresh_comment}
downsample_ratio (int): ${downsample_ratio_comment}
name (string): the name of yolo box layer. Default None.
Returns:
Variable: A 3-D tensor with shape [N, M, 4], the coordinates of boxes,
and a 3-D tensor with shape [N, M, :attr:`class_num`], the classification
scores of boxes.
Raises:
TypeError: Input x of yolov_box must be Variable
TypeError: Attr anchors of yolo box must be list or tuple
TypeError: Attr class_num of yolo box must be an integer
TypeError: Attr conf_thresh of yolo box must be a float number
Examples:
.. code-block:: python
x = fluid.layers.data(name='x', shape=[255, 13, 13], dtype='float32')
anchors = [10, 13, 16, 30, 33, 23]
loss = fluid.layers.yolo_box(x=x, class_num=80, anchors=anchors,
conf_thresh=0.01, downsample_ratio=32)
"""
helper = LayerHelper('yolo_box', **locals())
if not isinstance(x, Variable):
raise TypeError("Input x of yolo_box must be Variable")
if not isinstance(img_size, Variable):
raise TypeError("Input img_size of yolo_box must be Variable")
if not isinstance(anchors, list) and not isinstance(anchors, tuple):
raise TypeError("Attr anchors of yolo_box must be list or tuple")
if not isinstance(class_num, int):
raise TypeError("Attr class_num of yolo_box must be an integer")
if not isinstance(conf_thresh, float):
raise TypeError("Attr ignore_thresh of yolo_box must be a float number")
boxes = helper.create_variable_for_type_inference(dtype=x.dtype)
scores = helper.create_variable_for_type_inference(dtype=x.dtype)
attrs = {
"anchors": anchors,
"class_num": class_num,
"conf_thresh": conf_thresh,
"downsample_ratio": downsample_ratio,
}
helper.append_op(
type='yolo_box',
inputs={
"X": x,
"ImgSize": img_size,
},
outputs={
'Boxes': boxes,
'Scores': scores,
},
attrs=attrs)
return boxes, scores
@templatedoc()
def detection_map(detect_res,
label,
......
......@@ -23,7 +23,7 @@ import os
import inspect
from ..layer_helper import LayerHelper
from ..initializer import Normal, Constant, NumpyArrayInitializer
from ..framework import Variable, OpProtoHolder
from ..framework import Variable, OpProtoHolder, _in_imperative_mode
from ..param_attr import ParamAttr
from .layer_function_generator import autodoc, templatedoc, _generate_doc_string_
from .tensor import concat, assign
......@@ -205,16 +205,23 @@ def fc(input,
**Fully Connected Layer**
This function creates a fully connected layer in the network. It can take
multiple tensors as its inputs. It creates a variable called weights for
each input tensor, which represents a fully connected weight matrix from
each input unit to each output unit. The fully connected layer multiplies
each input tensor with its coresponding weight to produce an output Tensor.
If multiple input tensors are given, the results of multiple multiplications
will be sumed up. If bias_attr is not None, a bias variable will be created
and added to the output. Finally, if activation is not None, it will be applied
to the output as well.
one or multiple tensors as its inputs(input can be a list of Variable, see
Args in detail). It creates a variable called weights for each input tensor,
which represents a fully connected weight matrix from each input unit to
each output unit. The fully connected layer multiplies each input tensor
with its corresponding weight to produce an output Tensor with shape [M, `size`],
where M is batch size. If multiple input tensors are given, the results of
multiple output tensors with shape [M, `size`] will be summed up. If bias_attr
is not None, a bias variable will be created and added to the output.
Finally, if activation is not None, it will be applied to the output as well.
When the input is single tensor:
This process can be formulated as follows:
.. math::
Out = Act({XW + b})
When the input are multiple tensors:
.. math::
......@@ -222,13 +229,31 @@ def fc(input,
In the above equation:
* :math:`N`: Number of the input.
* :math:`X_i`: The input tensor.
* :math:`W`: The weights created by this layer.
* :math:`N`: Number of the input. N equals to len(input) if input is list of Variable.
* :math:`X_i`: The i-th input tensor.
* :math:`W_i`: The i-th weights matrix corresponding i-th input tensor.
* :math:`b`: The bias parameter created by this layer (if needed).
* :math:`Act`: The activation function.
* :math:`Out`: The output tensor.
See below for an example.
.. code-block:: text
Given:
data_1.data = [[[0.1, 0.2],
[0.3, 0.4]]]
data_1.shape = (1, 2, 2) # 1 is batch_size
data_2 = [[[0.1, 0.2, 0.3]]]
data_2.shape = (1, 1, 3)
out = fluid.layers.fc(input=[data_1, data_2], size=2)
Then:
out.data = [[0.18669507, 0.1893476]]
out.shape = (1, 2)
Args:
input (Variable|list of Variable): The input tensor(s) of this layer, and the dimension of
the input tensor(s) is at least 2.
......@@ -260,8 +285,14 @@ def fc(input,
Examples:
.. code-block:: python
# when input is single tensor
data = fluid.layers.data(name="data", shape=[32, 32], dtype="float32")
fc = fluid.layers.fc(input=data, size=1000, act="tanh")
# when input are multiple tensors
data_1 = fluid.layers.data(name="data_1", shape=[32, 32], dtype="float32")
data_2 = fluid.layers.data(name="data_2", shape=[24, 36], dtype="float32")
fc = fluid.layers.fc(input=[data_1, data_2], size=1000, act="tanh")
"""
helper = LayerHelper("fc", **locals())
......@@ -4864,7 +4895,8 @@ def matmul(x, y, transpose_x=False, transpose_y=False, alpha=1.0, name=None):
if transpose_y:
y_shape[-2], y_shape[-1] = y_shape[-1], y_shape[-2]
if x_shape[-1] != y_shape[-2]:
raise ValueError("Invalid inputs for matmul.")
raise ValueError("Invalid inputs for matmul. x: %s, y: %s\n" %
(x_shape, y_shape))
if len(y_shape) > 2 and len(x_shape) > 2:
for i, dim_x in enumerate(x_shape[:-2]):
......@@ -6367,6 +6399,8 @@ def squeeze(input, axes, name=None):
x = layers.data(name='x', shape=[5, 1, 10])
y = layers.sequeeze(input=x, axes=[1])
"""
assert not _in_imperative_mode(), (
"squeeze layer is not supported in imperative mode yet.")
helper = LayerHelper("squeeze", **locals())
out = helper.create_variable_for_type_inference(dtype=input.dtype)
x_shape = helper.create_variable_for_type_inference(dtype=input.dtype)
......
......@@ -476,11 +476,29 @@ class TestYoloDetection(unittest.TestCase):
x = layers.data(name='x', shape=[30, 7, 7], dtype='float32')
gtbox = layers.data(name='gtbox', shape=[10, 4], dtype='float32')
gtlabel = layers.data(name='gtlabel', shape=[10], dtype='int32')
loss = layers.yolov3_loss(x, gtbox, gtlabel, [10, 13, 30, 13],
[0, 1], 10, 0.7, 32)
gtscore = layers.data(name='gtscore', shape=[10], dtype='float32')
loss = layers.yolov3_loss(
x,
gtbox,
gtlabel, [10, 13, 30, 13], [0, 1],
10,
0.7,
32,
gtscore=gtscore,
use_label_smooth=False)
self.assertIsNotNone(loss)
def test_yolo_box(self):
program = Program()
with program_guard(program):
x = layers.data(name='x', shape=[30, 7, 7], dtype='float32')
img_size = layers.data(name='img_size', shape=[2], dtype='int32')
boxes, scores = layers.yolo_box(x, img_size, [10, 13, 30, 13], 10,
0.01, 32)
self.assertIsNotNone(boxes)
self.assertIsNotNone(scores)
class TestBoxClip(unittest.TestCase):
def test_box_clip(self):
......
......@@ -17,6 +17,7 @@ from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid.core as core
class TestFakeQuantizeOp(OpTest):
......@@ -75,6 +76,7 @@ class TestFakeQuantizeRangeAbsMaxOp(OpTest):
'InScale': np.zeros(1).astype("float32")
}
scale = np.max(np.abs(self.inputs['X'])).astype("float32")
out_scales = np.zeros(self.attrs['window_size']).astype("float32")
out_scales[0] = scale
self.outputs = {
......@@ -88,6 +90,46 @@ class TestFakeQuantizeRangeAbsMaxOp(OpTest):
self.check_output()
class TestFakeQuantizeMovingOp(OpTest):
def setUp(self):
self.op_type = "fake_quantize_moving_average_abs_max"
self.attrs = {
'bit_length': int(5),
'moving_rate': float(0.9),
'is_test': False
}
accum = np.zeros(1).astype("float32")
accum[0] = 1
state = np.zeros(1).astype("float32")
state[0] = 1
scale = np.zeros(1).astype("float32")
scale[0] = 0.001
self.inputs = {
'X': np.random.random((8, 16, 7, 7)).astype("float32"),
'InScale': scale,
'InAccum': accum,
'InState': state,
}
out_accum = np.zeros(1).astype("float32")
out_state = np.zeros(1).astype("float32")
out_scale = np.zeros(1).astype("float32")
out_accum[0] = self.attrs['moving_rate'] * accum[0] + np.max(
np.abs(self.inputs['X'])).astype("float32")
out_state[0] = self.attrs['moving_rate'] * state[0] + 1
out_scale = out_accum / out_state
self.outputs = {
'Out': np.round(self.inputs['X'] / out_scale * (
(1 << (self.attrs['bit_length'] - 1)) - 1)),
'OutAccum': out_accum,
'OutState': out_state,
'OutScale': out_scale,
}
def test_check_output(self):
self.check_output()
class TestFakeQuantizeRangeAbsMaxOp2(OpTest):
def setUp(self):
self.op_type = "fake_quantize_range_abs_max"
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import unittest
import numpy as np
import six
import sys
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.optimizer import AdamOptimizer
from paddle.fluid.imperative.nn import Conv2D, Pool2D, FC
from test_imperative_base import new_program_scope
from paddle.fluid.imperative.base import to_variable
def gen_data():
pass
class GraphConv(fluid.imperative.Layer):
def __init__(self, name_scope, in_features, out_features):
super(GraphConv, self).__init__(name_scope)
self._in_features = in_features
self._out_features = out_features
self.weight = self.create_parameter(
attr=None,
dtype='float32',
shape=[self._in_features, self._out_features])
self.bias = self.create_parameter(
attr=None, dtype='float32', shape=[self._out_features])
def forward(self, features, adj):
support = fluid.layers.matmul(features, self.weight)
# TODO(panyx0718): sparse matmul?
return fluid.layers.matmul(adj, support) + self.bias
class GCN(fluid.imperative.Layer):
def __init__(self, name_scope, num_hidden):
super(GCN, self).__init__(name_scope)
self.gc = GraphConv(self.full_name(), num_hidden, 32)
self.gc2 = GraphConv(self.full_name(), 32, 10)
def forward(self, x, adj):
x = fluid.layers.relu(self.gc(x, adj))
return self.gc2(x, adj)
class TestImperativeGNN(unittest.TestCase):
def test_gnn_float32(self):
seed = 90
startup = fluid.Program()
startup.random_seed = seed
main = fluid.Program()
main.random_seed = seed
scope = fluid.core.Scope()
with new_program_scope(main=main, startup=startup, scope=scope):
features = fluid.layers.data(
name='features',
shape=[1, 100, 50],
dtype='float32',
append_batch_size=False)
# Use selected rows when it's supported.
adj = fluid.layers.data(
name='adj',
shape=[1, 100, 100],
dtype='float32',
append_batch_size=False)
labels = fluid.layers.data(
name='labels',
shape=[100, 1],
dtype='int64',
append_batch_size=False)
model = GCN('test_gcn', 50)
logits = model(features, adj)
logits = fluid.layers.reshape(logits, logits.shape[1:])
# In other example, it's nll with log_softmax. However, paddle's
# log_loss only supports binary classification now.
loss = fluid.layers.softmax_with_cross_entropy(logits, labels)
loss = fluid.layers.reduce_sum(loss)
adam = AdamOptimizer(learning_rate=1e-3)
adam.minimize(loss)
exe = fluid.Executor(fluid.CPUPlace(
) if not core.is_compiled_with_cuda() else fluid.CUDAPlace(0))
exe.run(startup)
static_loss = exe.run(feed={
'features': np.zeros(
[1, 100, 50], dtype=np.float32),
'adj': np.zeros(
[1, 100, 100], dtype=np.float32),
'labels': np.zeros(
[100, 1], dtype=np.int64)
},
fetch_list=[loss])[0]
static_weight = np.array(
scope.find_var(model.gc.weight.name).get_tensor())
with fluid.imperative.guard():
fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed
features = np.zeros([1, 100, 50], dtype=np.float32)
# Use selected rows when it's supported.
adj = np.zeros([1, 100, 100], dtype=np.float32)
labels = np.zeros([100, 1], dtype=np.int64)
model = GCN('test_gcn', 50)
logits = model(to_variable(features), to_variable(adj))
logits = fluid.layers.reshape(logits, logits.shape[1:])
# In other example, it's nll with log_softmax. However, paddle's
# log_loss only supports binary classification now.
loss = fluid.layers.softmax_with_cross_entropy(logits,
to_variable(labels))
loss = fluid.layers.reduce_sum(loss)
adam = AdamOptimizer(learning_rate=1e-3)
adam.minimize(loss)
self.assertEqual(static_loss, loss._numpy())
self.assertTrue(
np.allclose(static_weight, model.gc.weight._numpy()))
sys.stderr.write('%s %s\n' % (static_loss, loss._numpy()))
if __name__ == '__main__':
unittest.main()
......@@ -84,6 +84,27 @@ class TestLayer(LayerTest):
self.assertTrue(np.allclose(static_ret, dy_ret._numpy()))
def test_matmul(self):
with self.static_graph():
t = layers.data(name='t', shape=[3, 3], dtype='float32')
t2 = layers.data(name='t2', shape=[3, 3], dtype='float32')
ret = layers.matmul(t, t2)
static_ret = self.get_static_graph_result(
feed={
't': np.ones(
[3, 3], dtype='float32'),
't2': np.ones(
[3, 3], dtype='float32')
},
fetch_list=[ret])[0]
with self.dynamic_graph():
t = np.ones([3, 3], dtype='float32')
t2 = np.ones([3, 3], dtype='float32')
dy_ret = layers.matmul(base.to_variable(t), base.to_variable(t2))
self.assertTrue(np.allclose(static_ret, dy_ret._numpy()))
def test_conv2d(self):
with self.static_graph():
images = layers.data(name='pixel', shape=[3, 5, 5], dtype='float32')
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
import unittest
import numpy as np
from op_test import OpTest
from paddle.fluid import core
def sigmoid(x):
return 1.0 / (1.0 + np.exp(-1.0 * x))
def YoloBox(x, img_size, attrs):
n, c, h, w = x.shape
anchors = attrs['anchors']
an_num = int(len(anchors) // 2)
class_num = attrs['class_num']
conf_thresh = attrs['conf_thresh']
downsample = attrs['downsample']
input_size = downsample * h
x = x.reshape((n, an_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2))
pred_box = x[:, :, :, :, :4].copy()
grid_x = np.tile(np.arange(w).reshape((1, w)), (h, 1))
grid_y = np.tile(np.arange(h).reshape((h, 1)), (1, w))
pred_box[:, :, :, :, 0] = (grid_x + sigmoid(pred_box[:, :, :, :, 0])) / w
pred_box[:, :, :, :, 1] = (grid_y + sigmoid(pred_box[:, :, :, :, 1])) / h
anchors = [(anchors[i], anchors[i + 1]) for i in range(0, len(anchors), 2)]
anchors_s = np.array(
[(an_w / input_size, an_h / input_size) for an_w, an_h in anchors])
anchor_w = anchors_s[:, 0:1].reshape((1, an_num, 1, 1))
anchor_h = anchors_s[:, 1:2].reshape((1, an_num, 1, 1))
pred_box[:, :, :, :, 2] = np.exp(pred_box[:, :, :, :, 2]) * anchor_w
pred_box[:, :, :, :, 3] = np.exp(pred_box[:, :, :, :, 3]) * anchor_h
pred_conf = sigmoid(x[:, :, :, :, 4:5])
pred_conf[pred_conf < conf_thresh] = 0.
pred_score = sigmoid(x[:, :, :, :, 5:]) * pred_conf
pred_box = pred_box * (pred_conf > 0.).astype('float32')
pred_box = pred_box.reshape((n, -1, 4))
pred_box[:, :, :2], pred_box[:, :, 2:4] = \
pred_box[:, :, :2] - pred_box[:, :, 2:4] / 2., \
pred_box[:, :, :2] + pred_box[:, :, 2:4] / 2.0
pred_box[:, :, 0] = pred_box[:, :, 0] * img_size[:, 1][:, np.newaxis]
pred_box[:, :, 1] = pred_box[:, :, 1] * img_size[:, 0][:, np.newaxis]
pred_box[:, :, 2] = pred_box[:, :, 2] * img_size[:, 1][:, np.newaxis]
pred_box[:, :, 3] = pred_box[:, :, 3] * img_size[:, 0][:, np.newaxis]
for i in range(len(pred_box)):
pred_box[i, :, 0] = np.clip(pred_box[i, :, 0], 0, np.inf)
pred_box[i, :, 1] = np.clip(pred_box[i, :, 1], 0, np.inf)
pred_box[i, :, 2] = np.clip(pred_box[i, :, 2], -np.inf,
img_size[i, 1] - 1)
pred_box[i, :, 3] = np.clip(pred_box[i, :, 3], -np.inf,
img_size[i, 0] - 1)
return pred_box, pred_score.reshape((n, -1, class_num))
class TestYoloBoxOp(OpTest):
def setUp(self):
self.initTestCase()
self.op_type = 'yolo_box'
x = np.random.random(self.x_shape).astype('float32')
img_size = np.random.randint(10, 20, self.imgsize_shape).astype('int32')
self.attrs = {
"anchors": self.anchors,
"class_num": self.class_num,
"conf_thresh": self.conf_thresh,
"downsample": self.downsample,
}
self.inputs = {
'X': x,
'ImgSize': img_size,
}
boxes, scores = YoloBox(x, img_size, self.attrs)
self.outputs = {
"Boxes": boxes,
"Scores": scores,
}
def test_check_output(self):
self.check_output()
def initTestCase(self):
self.anchors = [10, 13, 16, 30, 33, 23]
an_num = int(len(self.anchors) // 2)
self.batch_size = 32
self.class_num = 2
self.conf_thresh = 0.5
self.downsample = 32
self.x_shape = (self.batch_size, an_num * (5 + self.class_num), 13, 13)
self.imgsize_shape = (self.batch_size, 2)
if __name__ == "__main__":
unittest.main()
......@@ -23,8 +23,8 @@ from op_test import OpTest
from paddle.fluid import core
def l2loss(x, y):
return 0.5 * (y - x) * (y - x)
def l1loss(x, y):
return abs(x - y)
def sce(x, label):
......@@ -66,7 +66,7 @@ def batch_xywh_box_iou(box1, box2):
return inter_area / union
def YOLOv3Loss(x, gtbox, gtlabel, attrs):
def YOLOv3Loss(x, gtbox, gtlabel, gtscore, attrs):
n, c, h, w = x.shape
b = gtbox.shape[1]
anchors = attrs['anchors']
......@@ -75,21 +75,21 @@ def YOLOv3Loss(x, gtbox, gtlabel, attrs):
mask_num = len(anchor_mask)
class_num = attrs["class_num"]
ignore_thresh = attrs['ignore_thresh']
downsample = attrs['downsample']
input_size = downsample * h
downsample_ratio = attrs['downsample_ratio']
use_label_smooth = attrs['use_label_smooth']
input_size = downsample_ratio * h
x = x.reshape((n, mask_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2))
loss = np.zeros((n)).astype('float32')
label_pos = 1.0 - 1.0 / class_num if use_label_smooth else 1.0
label_neg = 1.0 / class_num if use_label_smooth else 0.0
pred_box = x[:, :, :, :, :4].copy()
grid_x = np.tile(np.arange(w).reshape((1, w)), (h, 1))
grid_y = np.tile(np.arange(h).reshape((h, 1)), (1, w))
pred_box[:, :, :, :, 0] = (grid_x + sigmoid(pred_box[:, :, :, :, 0])) / w
pred_box[:, :, :, :, 1] = (grid_y + sigmoid(pred_box[:, :, :, :, 1])) / h
x[:, :, :, :, 5:] = np.where(x[:, :, :, :, 5:] < -0.5, x[:, :, :, :, 5:],
np.ones_like(x[:, :, :, :, 5:]) * 1.0 /
class_num)
mask_anchors = []
for m in anchor_mask:
mask_anchors.append((anchors[2 * m], anchors[2 * m + 1]))
......@@ -138,21 +138,22 @@ def YOLOv3Loss(x, gtbox, gtlabel, attrs):
ty = gtbox[i, j, 1] * w - gj
tw = np.log(gtbox[i, j, 2] * input_size / mask_anchors[an_idx][0])
th = np.log(gtbox[i, j, 3] * input_size / mask_anchors[an_idx][1])
scale = (2.0 - gtbox[i, j, 2] * gtbox[i, j, 3])
scale = (2.0 - gtbox[i, j, 2] * gtbox[i, j, 3]) * gtscore[i, j]
loss[i] += sce(x[i, an_idx, gj, gi, 0], tx) * scale
loss[i] += sce(x[i, an_idx, gj, gi, 1], ty) * scale
loss[i] += l2loss(x[i, an_idx, gj, gi, 2], tw) * scale
loss[i] += l2loss(x[i, an_idx, gj, gi, 3], th) * scale
loss[i] += l1loss(x[i, an_idx, gj, gi, 2], tw) * scale
loss[i] += l1loss(x[i, an_idx, gj, gi, 3], th) * scale
objness[i, an_idx * h * w + gj * w + gi] = 1.0
objness[i, an_idx * h * w + gj * w + gi] = gtscore[i, j]
for label_idx in range(class_num):
loss[i] += sce(x[i, an_idx, gj, gi, 5 + label_idx],
float(label_idx == gtlabel[i, j]))
loss[i] += sce(x[i, an_idx, gj, gi, 5 + label_idx], label_pos
if label_idx == gtlabel[i, j] else
label_neg) * gtscore[i, j]
for j in range(mask_num * h * w):
if objness[i, j] > 0:
loss[i] += sce(pred_obj[i, j], 1.0)
loss[i] += sce(pred_obj[i, j], 1.0) * objness[i, j]
elif objness[i, j] == 0:
loss[i] += sce(pred_obj[i, j], 0.0)
......@@ -176,7 +177,8 @@ class TestYolov3LossOp(OpTest):
"anchor_mask": self.anchor_mask,
"class_num": self.class_num,
"ignore_thresh": self.ignore_thresh,
"downsample": self.downsample,
"downsample_ratio": self.downsample_ratio,
"use_label_smooth": self.use_label_smooth,
}
self.inputs = {
......@@ -184,7 +186,14 @@ class TestYolov3LossOp(OpTest):
'GTBox': gtbox.astype('float32'),
'GTLabel': gtlabel.astype('int32'),
}
loss, objness, gt_matches = YOLOv3Loss(x, gtbox, gtlabel, self.attrs)
gtscore = np.ones(self.gtbox_shape[:2]).astype('float32')
if self.gtscore:
gtscore = np.random.random(self.gtbox_shape[:2]).astype('float32')
self.inputs['GTScore'] = gtscore
loss, objness, gt_matches = YOLOv3Loss(x, gtbox, gtlabel, gtscore,
self.attrs)
self.outputs = {
'Loss': loss,
'ObjectnessMask': objness,
......@@ -193,24 +202,57 @@ class TestYolov3LossOp(OpTest):
def test_check_output(self):
place = core.CPUPlace()
self.check_output_with_place(place, atol=1e-3)
self.check_output_with_place(place, atol=2e-3)
def test_check_grad_ignore_gtbox(self):
place = core.CPUPlace()
self.check_grad_with_place(
place, ['X'],
'Loss',
no_grad_set=set(["GTBox", "GTLabel"]),
max_relative_error=0.3)
self.check_grad_with_place(place, ['X'], 'Loss', max_relative_error=0.2)
def initTestCase(self):
self.anchors = [
10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198,
373, 326
]
self.anchor_mask = [0, 1, 2]
self.class_num = 5
self.ignore_thresh = 0.7
self.downsample_ratio = 32
self.x_shape = (3, len(self.anchor_mask) * (5 + self.class_num), 5, 5)
self.gtbox_shape = (3, 5, 4)
self.gtscore = True
self.use_label_smooth = True
class TestYolov3LossWithoutLabelSmooth(TestYolov3LossOp):
def initTestCase(self):
self.anchors = [
10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198,
373, 326
]
self.anchor_mask = [0, 1, 2]
self.class_num = 5
self.ignore_thresh = 0.7
self.downsample_ratio = 32
self.x_shape = (3, len(self.anchor_mask) * (5 + self.class_num), 5, 5)
self.gtbox_shape = (3, 5, 4)
self.gtscore = True
self.use_label_smooth = False
class TestYolov3LossNoGTScore(TestYolov3LossOp):
def initTestCase(self):
self.anchors = [10, 13, 16, 30, 33, 23]
self.anchor_mask = [1, 2]
self.anchors = [
10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198,
373, 326
]
self.anchor_mask = [0, 1, 2]
self.class_num = 5
self.ignore_thresh = 0.5
self.downsample = 32
self.ignore_thresh = 0.7
self.downsample_ratio = 32
self.x_shape = (3, len(self.anchor_mask) * (5 + self.class_num), 5, 5)
self.gtbox_shape = (3, 5, 4)
self.gtscore = False
self.use_label_smooth = True
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册