提交 61403f87 编写于 作者: W wozna

Add transpose2 INT8 for mkl-dnn

test=develop
上级 08fa98f7
...@@ -174,7 +174,10 @@ function(op_library TARGET) ...@@ -174,7 +174,10 @@ function(op_library TARGET)
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, FP32);\n") file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, FP32);\n")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, S8);\n") file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, S8);\n")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, U8);\n") file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, U8);\n")
elseif(${MKLDNN_FILE} STREQUAL "transpose_mkldnn_op")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(transpose2, MKLDNN, FP32);\n")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(transpose2, MKLDNN, S8);\n")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(transpose2, MKLDNN, U8);\n")
else() else()
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, MKLDNN);\n") file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, MKLDNN);\n")
endif() endif()
......
...@@ -1155,6 +1155,27 @@ PDNode *patterns::Conv::operator()() { ...@@ -1155,6 +1155,27 @@ PDNode *patterns::Conv::operator()() {
return output_var; return output_var;
} }
PDNode *patterns::Transpose::operator()() {
auto prev_op = pattern->NewNode(prev_op_repr())->assert_is_op();
auto transpose_op =
pattern->NewNode(transpose_op_repr())->assert_is_op("transpose2");
auto transpose_in = pattern->NewNode(transpose_in_repr())
->AsInput()
->assert_is_op_input("transpose2");
auto transpose_out = pattern->NewNode(transpose_out_repr())
->AsOutput()
->assert_is_op_output("transpose2", "Out");
auto next_op = pattern->NewNode(next_op_repr())->assert_is_op();
prev_op->LinksTo({transpose_in});
transpose_op->LinksFrom({transpose_in}).LinksTo({transpose_out});
next_op->LinksFrom({transpose_out});
return transpose_out;
}
PDNode *patterns::ConvResidual::operator()(bool with_residual_data) { PDNode *patterns::ConvResidual::operator()(bool with_residual_data) {
auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d"); auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d");
......
...@@ -750,6 +750,21 @@ struct ElementwiseAdd : public PatternBase { ...@@ -750,6 +750,21 @@ struct ElementwiseAdd : public PatternBase {
PATTERN_DECL_NODE(elementwise_add_out); PATTERN_DECL_NODE(elementwise_add_out);
}; };
// Transpose op
// Forward pass for transpose.
// transpose_out is a result of the operator.
struct Transpose : public PatternBase {
Transpose(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "transpose2") {}
PDNode* operator()();
PATTERN_DECL_NODE(prev_op);
PATTERN_DECL_NODE(transpose_in);
PATTERN_DECL_NODE(transpose_op);
PATTERN_DECL_NODE(transpose_out);
PATTERN_DECL_NODE(next_op);
};
// Concat op // Concat op
// Forward pass for concat. // Forward pass for concat.
// concat_out is a result of the operator. // concat_out is a result of the operator.
......
...@@ -343,6 +343,65 @@ void CPUQuantizePass::QuantizePriorBox(Graph* graph) const { ...@@ -343,6 +343,65 @@ void CPUQuantizePass::QuantizePriorBox(Graph* graph) const {
quantize_prior_box_count); quantize_prior_box_count);
} }
void CPUQuantizePass::QuantizeTranspose(Graph* graph) const {
GraphPatternDetector gpd;
auto pattern = gpd.mutable_pattern();
patterns::Transpose transpose_pattern{pattern, name_scope_};
transpose_pattern();
int quantize_transpose_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "Quantize transpose op";
GET_IR_NODE_FROM_SUBGRAPH(transpose_op, transpose_op, transpose_pattern);
auto* transpose_op_desc = transpose_op->Op();
if (!transpose_op_desc->HasAttr("use_quantizer")) {
return;
}
// skip if should not be quantized
if (!transpose_op_desc->HasAttr("use_quantizer") ||
!boost::get<bool>(transpose_op_desc->GetAttr("use_quantizer"))) {
return;
}
GET_IR_NODE_FROM_SUBGRAPH(prev_op, prev_op, transpose_pattern);
GET_IR_NODE_FROM_SUBGRAPH(next_op, next_op, transpose_pattern);
// skip if prev op is not quantized
// in future we should checked if next_op is quantized
// transpose INT8 schould be used only between INT8 operators
if (!(prev_op->Op()->Type() == "dequantize" ||
(prev_op->Op()->HasAttr("use_quantizer") &&
boost::get<bool>(prev_op->Op()->GetAttr("use_quantizer"))))) {
return;
}
GET_IR_NODE_FROM_SUBGRAPH(transpose_in, transpose_in, transpose_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose_out, transpose_out, transpose_pattern);
// get scales calculated after warmup, they scale variables to MAX=1.0
auto scales = Get<VarQuantScale>("quant_var_scales");
auto input_scale = scales[transpose_in->Name()].second.data<double>()[0];
bool is_input_unsigned = scales[transpose_in->Name()].first;
QuantizeInput(g, transpose_op, transpose_in, "X", input_scale,
is_input_unsigned);
auto output_scale = scales[transpose_out->Name()].second.data<double>()[0];
bool is_output_unsigned = scales[transpose_out->Name()].first;
DequantizeOutput(g, transpose_op, transpose_out, "Out", output_scale,
is_output_unsigned);
++quantize_transpose_count;
};
gpd(graph, handler);
AddStatis(quantize_transpose_count);
PrettyLogDetail("--- quantized %d transpose ops",
quantize_transpose_count);
}
void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const { void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const {
VLOG(3) << "Quantizing the graph."; VLOG(3) << "Quantizing the graph.";
PADDLE_ENFORCE(graph); PADDLE_ENFORCE(graph);
...@@ -355,6 +414,7 @@ void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const { ...@@ -355,6 +414,7 @@ void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const {
QuantizePool(graph); QuantizePool(graph);
QuantizeConcat(graph); QuantizeConcat(graph);
QuantizePriorBox(graph); QuantizePriorBox(graph);
QuantizeTranspose(graph);
} }
} // namespace ir } // namespace ir
......
...@@ -52,6 +52,8 @@ class CPUQuantizePass : public FusePassBase { ...@@ -52,6 +52,8 @@ class CPUQuantizePass : public FusePassBase {
void QuantizePriorBox(Graph* graph) const; void QuantizePriorBox(Graph* graph) const;
void QuantizeTranspose(Graph* graph) const;
void QuantizeInput(Graph* g, Node* op, Node* input, std::string input_name, void QuantizeInput(Graph* g, Node* op, Node* input, std::string input_name,
double scale_to_one, bool is_unsigned, double scale_to_one, bool is_unsigned,
std::string scale_attr_name = "") const; std::string scale_attr_name = "") const;
......
...@@ -48,7 +48,7 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, ...@@ -48,7 +48,7 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
op->SetAttr("Scale_in", 1.0f); op->SetAttr("Scale_in", 1.0f);
op->SetAttr("Scale_out", 1.0f); op->SetAttr("Scale_out", 1.0f);
op->SetAttr("Scale_weights", std::vector<float>{1.0f}); op->SetAttr("Scale_weights", std::vector<float>{1.0f});
} else if (type == "pool2d") { } else if (type == "pool2d" || type == "transpose2") {
op->SetInput("X", {inputs[0]}); op->SetInput("X", {inputs[0]});
op->SetOutput("Out", {outputs[0]}); op->SetOutput("Out", {outputs[0]});
op->SetAttr("use_quantizer", use_quantizer); op->SetAttr("use_quantizer", use_quantizer);
...@@ -113,19 +113,14 @@ void InitTensorHolder(Scope* scope, const paddle::platform::Place& place, ...@@ -113,19 +113,14 @@ void InitTensorHolder(Scope* scope, const paddle::platform::Place& place,
tensor->mutable_data(place, proto::VarType::FP32, 1); tensor->mutable_data(place, proto::VarType::FP32, 1);
} }
void MainTest(const ProgramDesc& prog, int conv_count, int pool_count, void PreparePass(std::unique_ptr<ir::Graph>* graph, const ProgramDesc& prog,
int quant_count, int dequant_count, int added_nodes_count, const std::initializer_list<std::string> variable_names,
float scale) { int* original_nodes_num, int* current_nodes_num) {
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
// Init scope, as it is used in pass
auto place = paddle::platform::CPUPlace(); auto place = paddle::platform::CPUPlace();
NaiveExecutor exe{place}; NaiveExecutor exe{place};
Scope scope; Scope scope;
exe.CreateVariables(prog, 0, true, &scope); exe.CreateVariables(prog, 0, true, &scope);
auto* scales = new VarQuantScale(); auto* scales = new VarQuantScale();
for (auto& v : variable_names) { for (auto& v : variable_names) {
InitTensorHolder(&scope, place, v.c_str()); InitTensorHolder(&scope, place, v.c_str());
LoDTensor tensor; LoDTensor tensor;
...@@ -136,16 +131,23 @@ void MainTest(const ProgramDesc& prog, int conv_count, int pool_count, ...@@ -136,16 +131,23 @@ void MainTest(const ProgramDesc& prog, int conv_count, int pool_count,
(*scales)[v] = std::make_pair(false, std::move(tensor)); (*scales)[v] = std::make_pair(false, std::move(tensor));
} }
graph->SetNotOwned(kParamScopeAttr, &scope); (*graph)->SetNotOwned(kParamScopeAttr, &scope);
std::unique_ptr<Pass> pass =
auto pass = PassRegistry::Instance().Get("cpu_quantize_pass"); PassRegistry::Instance().Get("cpu_quantize_pass");
pass->Set("quant_var_scales", scales); pass->Set("quant_var_scales", scales);
int original_nodes_num = graph->Nodes().size(); *original_nodes_num = (*graph)->Nodes().size();
(*graph).reset(pass->Apply((*graph).release()));
graph.reset(pass->Apply(graph.release())); *current_nodes_num = (*graph)->Nodes().size();
}
int current_nodes_num = graph->Nodes().size(); void MainTest(const ProgramDesc& prog, int conv_count, int pool_count,
int quant_count, int dequant_count, int added_nodes_count,
float scale) {
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
int original_nodes_num, current_nodes_num;
PreparePass(&graph, prog, variable_names, &original_nodes_num,
&current_nodes_num);
int quantize_nodes_count = 0; int quantize_nodes_count = 0;
int dequantize_nodes_count = 0; int dequantize_nodes_count = 0;
...@@ -232,35 +234,9 @@ ProgramDesc BuildProgramDescConcat() { ...@@ -232,35 +234,9 @@ ProgramDesc BuildProgramDescConcat() {
void MainTestConcat(const ProgramDesc& prog, int pool_count, int concat_count, void MainTestConcat(const ProgramDesc& prog, int pool_count, int concat_count,
int quant_count, int dequant_count, int added_nodes_count) { int quant_count, int dequant_count, int added_nodes_count) {
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog)); std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
int original_nodes_num, current_nodes_num;
// Init scope, as it is used in pass PreparePass(&graph, prog, variable_names_concat, &original_nodes_num,
auto place = paddle::platform::CPUPlace(); &current_nodes_num);
NaiveExecutor exe{place};
Scope scope;
exe.CreateVariables(prog, 0, true, &scope);
auto* scales = new VarQuantScale();
for (auto& v : variable_names_concat) {
InitTensorHolder(&scope, place, v.c_str());
LoDTensor tensor;
tensor.Resize({1});
auto* ptr = tensor.mutable_data<double>(place);
ptr[0] = 2.0;
(*scales)[v] = std::make_pair(false, std::move(tensor));
}
graph->SetNotOwned(kParamScopeAttr, &scope);
auto pass = PassRegistry::Instance().Get("cpu_quantize_pass");
pass->Set("quant_var_scales", scales);
int original_nodes_num = graph->Nodes().size();
graph.reset(pass->Apply(graph.release()));
int current_nodes_num = graph->Nodes().size();
int quantize_nodes_count = 0; int quantize_nodes_count = 0;
int dequantize_nodes_count = 0; int dequantize_nodes_count = 0;
...@@ -300,9 +276,93 @@ TEST(CpuQuantizePass, concat) { ...@@ -300,9 +276,93 @@ TEST(CpuQuantizePass, concat) {
MainTestConcat(BuildProgramDescConcat(), pool_count, concat_count, MainTestConcat(BuildProgramDescConcat(), pool_count, concat_count,
quant_count, dequant_count, added_nodes_count); quant_count, dequant_count, added_nodes_count);
} }
} // namespace } // namespace
namespace {
static const std::initializer_list<std::string> variable_names_transpose = {
"a", "w1", "b", "c", "w2", "d", "e", "f"};
// a->Conv1->b
// b->Transpose1->c
// c->Conv2->d
// d->Transpose2->e
// e->Dropout->f
ProgramDesc BuildProgramDescTranspose() {
ProgramDesc prog;
for (auto& v : variable_names_transpose) {
auto* var = prog.MutableBlock(0)->Var(v);
if (v.find("w") == 0) {
var->SetPersistable(true);
}
}
SetOp(&prog, "conv2d", "Conv1", {"a", "w1"}, {"b"}, true, true);
SetOp(&prog, "transpose2", "Transpose1", {"b"}, {"c"}, true, true);
SetOp(&prog, "conv2d", "Conv1", {"c", "w2"}, {"d"}, true, true);
SetOp(&prog, "transpose2", "Transpose2", {"d"}, {"e"}, true, true);
SetOp(&prog, "dropout", "Dropout", {"e"}, {"f"}, true, false);
return prog;
}
void MainTestTranspose(const ProgramDesc& prog, int conv_count,
int transpose_count, int quant_count, int dequant_count,
int added_nodes_count, float scale) {
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
int original_nodes_num, current_nodes_num;
PreparePass(&graph, prog, variable_names_transpose, &original_nodes_num,
&current_nodes_num);
int quantize_nodes_count = 0;
int dequantize_nodes_count = 0;
int transpose_nodes_count = 0;
int conv_nodes_count = 0;
for (auto* node : graph->Nodes()) {
if (node->IsOp()) {
auto* op = node->Op();
if (op->Type() == "transpose2") {
transpose_nodes_count++;
} else if (op->Type() == "conv2d") {
conv_nodes_count++;
auto op_name = boost::get<std::string>(op->GetAttr("name"));
EXPECT_EQ(boost::get<float>(op->GetAttr("Scale_in")), scale)
<< "Scale_in for node '" + op_name + "'.";
EXPECT_EQ(boost::get<float>(op->GetAttr("Scale_out")), scale)
<< "Scale_out for node '" + op_name + "'.";
EXPECT_EQ(
boost::get<std::vector<float>>(op->GetAttr("Scale_weights"))[0],
scale)
<< "Scale_weights for node '" + op_name + "'.";
} else if (op->Type() == "quantize") {
quantize_nodes_count++;
} else if (op->Type() == "dequantize") {
dequantize_nodes_count++;
}
}
}
EXPECT_EQ(transpose_nodes_count, transpose_count);
EXPECT_EQ(conv_nodes_count, conv_count);
EXPECT_EQ(quantize_nodes_count, quant_count);
EXPECT_EQ(dequantize_nodes_count, dequant_count);
EXPECT_EQ(original_nodes_num + added_nodes_count, current_nodes_num);
}
TEST(CpuQuantizePass, transpose) {
// a1->Quant->a2->Conv1->b1->Dequant->b2
// b2->Quant->b3->Transpose->c1->Dequant->c2
// c2->Quant->c3->Conv2->d1->Dequant->d2
// d2->Quant->d3->Transpose->e1->Dequant->e2
// e2->Dropout->f
int conv_count = 2;
int transpose_count = 2;
int quant_count = 4;
int dequant_count = 4;
// 4 Quant + 4 IN + 4 DeQuant + 4 OUT
int added_nodes_count = 16;
MainTestTranspose(BuildProgramDescTranspose(), conv_count, transpose_count,
quant_count, dequant_count, added_nodes_count, 2.0f * 127);
}
} // namespace
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
......
...@@ -34,6 +34,9 @@ MkldnnQuantizerConfig::MkldnnQuantizerConfig() { ...@@ -34,6 +34,9 @@ MkldnnQuantizerConfig::MkldnnQuantizerConfig() {
rules_["prior_box"]["Image"] = ScaleAlgo::NONE; rules_["prior_box"]["Image"] = ScaleAlgo::NONE;
rules_["prior_box"]["Boxes"] = ScaleAlgo::NONE; rules_["prior_box"]["Boxes"] = ScaleAlgo::NONE;
rules_["prior_box"]["Variances"] = ScaleAlgo::NONE; rules_["prior_box"]["Variances"] = ScaleAlgo::NONE;
rules_["transpose"]["X"] = ScaleAlgo::KL;
rules_["transpose"]["Out"] = ScaleAlgo::KL;
} }
ScaleAlgo MkldnnQuantizerConfig::scale_algo( ScaleAlgo MkldnnQuantizerConfig::scale_algo(
......
...@@ -268,7 +268,7 @@ TEST(Analyzer_int8_mobilenet_ssd, quantization) { ...@@ -268,7 +268,7 @@ TEST(Analyzer_int8_mobilenet_ssd, quantization) {
q_cfg.EnableMkldnnQuantizer(); q_cfg.EnableMkldnnQuantizer();
q_cfg.mkldnn_quantizer_config(); q_cfg.mkldnn_quantizer_config();
std::unordered_set<std::string> quantize_operators( std::unordered_set<std::string> quantize_operators(
{"conv2d", "depthwise_conv2d", "prior_box"}); {"conv2d", "depthwise_conv2d", "prior_box", "transpose2"});
q_cfg.mkldnn_quantizer_config()->SetEnabledOpTypes(quantize_operators); q_cfg.mkldnn_quantizer_config()->SetEnabledOpTypes(quantize_operators);
q_cfg.mkldnn_quantizer_config()->SetWarmupData(warmup_data); q_cfg.mkldnn_quantizer_config()->SetWarmupData(warmup_data);
q_cfg.mkldnn_quantizer_config()->SetWarmupBatchSize(FLAGS_warmup_batch_size); q_cfg.mkldnn_quantizer_config()->SetWarmupBatchSize(FLAGS_warmup_batch_size);
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/data_layout_transform.h" #include "paddle/fluid/framework/data_layout_transform.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/malloc.h" #include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/platform/mkldnn_reuse.h" #include "paddle/fluid/platform/mkldnn_reuse.h"
namespace paddle { namespace paddle {
...@@ -29,6 +30,7 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -29,6 +30,7 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
void Compute(const paddle::framework::ExecutionContext& ctx) const override { void Compute(const paddle::framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()), PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace."); "It must use CPUPlace.");
mkldnn::memory::data_type in_type = platform::MKLDNNGetDataType<T>();
auto& dev_ctx = auto& dev_ctx =
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>(); ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine(); const auto& mkldnn_engine = dev_ctx.GetEngine();
...@@ -49,8 +51,8 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -49,8 +51,8 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
nchw_tz, axis, nchw_tz, axis,
ctx.op().Output("Out") + std::to_string(input->format())); ctx.op().Output("Out") + std::to_string(input->format()));
platform::TransposeMKLDNNHandler handler(nchw_tz, axis, dev_ctx, platform::TransposeMKLDNNHandler handler(
mkldnn_engine, key); nchw_tz, axis, input->type(), in_type, dev_ctx, mkldnn_engine, key);
auto transpose_src_memory_p = handler.AcquireSrcMemory( auto transpose_src_memory_p = handler.AcquireSrcMemory(
input->format(), platform::to_void_cast<T>(input_data)); input->format(), platform::to_void_cast<T>(input_data));
...@@ -78,7 +80,7 @@ class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -78,7 +80,7 @@ class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
ctx.Input<framework::Tensor>(framework::GradVarName("Out")); ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* x_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X")); auto* x_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
if (!x_grad) return; if (!x_grad) return;
mkldnn::memory::data_type in_type = platform::MKLDNNGetDataType<T>();
auto& dev_ctx = auto& dev_ctx =
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>(); ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine(); const auto& mkldnn_engine = dev_ctx.GetEngine();
...@@ -103,7 +105,8 @@ class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -103,7 +105,8 @@ class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
const std::string key = platform::TransposeMKLDNNHandler::GetHash( const std::string key = platform::TransposeMKLDNNHandler::GetHash(
nchw_tz, axis, ctx.op().Output(framework::GradVarName("X"))); nchw_tz, axis, ctx.op().Output(framework::GradVarName("X")));
platform::TransposeMKLDNNHandler handler(nchw_tz, reversed_axis, dev_ctx, platform::TransposeMKLDNNHandler handler(nchw_tz, reversed_axis,
x_grad->type(), in_type, dev_ctx,
mkldnn_engine, key); mkldnn_engine, key);
auto transpose_src_memory_p = handler.AcquireSrcMemory( auto transpose_src_memory_p = handler.AcquireSrcMemory(
...@@ -124,12 +127,36 @@ class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -124,12 +127,36 @@ class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_KERNEL(transpose2, MKLDNN, ::paddle::platform::CPUPlace, REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(transpose2, MKLDNN,
::paddle::platform::CPUPlace, FP32,
ops::kTransposeMKLDNNFP32,
ops::TransposeMKLDNNOpKernel<float>); ops::TransposeMKLDNNOpKernel<float>);
REGISTER_OP_KERNEL(transpose, MKLDNN, ::paddle::platform::CPUPlace, REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(transpose2, MKLDNN,
::paddle::platform::CPUPlace, U8,
ops::kTransposeMKLDNNINT8,
ops::TransposeMKLDNNOpKernel<uint8_t>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(transpose2, MKLDNN,
::paddle::platform::CPUPlace, S8,
ops::kTransposeMKLDNNINT8,
ops::TransposeMKLDNNOpKernel<int8_t>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(transpose, MKLDNN,
::paddle::platform::CPUPlace, FP32,
ops::kTransposeMKLDNNFP32,
ops::TransposeMKLDNNOpKernel<float>); ops::TransposeMKLDNNOpKernel<float>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(transpose, MKLDNN,
::paddle::platform::CPUPlace, U8,
ops::kTransposeMKLDNNINT8,
ops::TransposeMKLDNNOpKernel<uint8_t>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(transpose, MKLDNN,
::paddle::platform::CPUPlace, S8,
ops::kTransposeMKLDNNINT8,
ops::TransposeMKLDNNOpKernel<int8_t>);
REGISTER_OP_KERNEL(transpose_grad, MKLDNN, ::paddle::platform::CPUPlace, REGISTER_OP_KERNEL(transpose_grad, MKLDNN, ::paddle::platform::CPUPlace,
ops::TransposeMKLDNNGradOpKernel<float>); ops::TransposeMKLDNNGradOpKernel<float>);
REGISTER_OP_KERNEL(transpose2_grad, MKLDNN, ::paddle::platform::CPUPlace, REGISTER_OP_KERNEL(transpose2_grad, MKLDNN, ::paddle::platform::CPUPlace,
......
...@@ -65,15 +65,23 @@ class TransposeOp : public framework::OperatorWithKernel { ...@@ -65,15 +65,23 @@ class TransposeOp : public framework::OperatorWithKernel {
framework::LibraryType library_{framework::LibraryType::kPlain}; framework::LibraryType library_{framework::LibraryType::kPlain};
std::string data_format = ctx.Attr<std::string>("data_format"); std::string data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format); framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
int customized_type_value =
framework::OpKernelType::kDefaultCustomizedTypeValue;
auto input_data_type = ctx.Input<Tensor>("X")->type();
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain && if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) { platform::CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN; library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN;
using framework::proto::VarType;
customized_type_value = (input_data_type == VarType::INT8 ||
input_data_type == VarType::UINT8)
? kTransposeMKLDNNINT8
: kTransposeMKLDNNFP32;
} }
#endif #endif
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(), return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_,
ctx.GetPlace(), layout_, library_); library_, customized_type_value);
} }
}; };
...@@ -99,6 +107,13 @@ class TransposeOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -99,6 +107,13 @@ class TransposeOpMaker : public framework::OpProtoAndCheckerMaker {
"Defaults to \"NHWC\". Specify the data format of the output data, " "Defaults to \"NHWC\". Specify the data format of the output data, "
"the input will be transformed automatically. ") "the input will be transformed automatically. ")
.SetDefault("AnyLayout"); .SetDefault("AnyLayout");
/* int8 parameters */
AddAttr<bool>("use_quantizer",
"(bool, default false) "
"Set to true for operators that should be quantized and use "
"int8 kernel. "
"Only used on CPU.")
.SetDefault(false);
AddComment(R"DOC( AddComment(R"DOC(
Transpose Operator. Transpose Operator.
...@@ -196,16 +211,24 @@ class Transpose2Op : public TransposeOp { ...@@ -196,16 +211,24 @@ class Transpose2Op : public TransposeOp {
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
framework::LibraryType library_{framework::LibraryType::kPlain}; framework::LibraryType library_{framework::LibraryType::kPlain};
std::string data_format = ctx.Attr<std::string>("data_format"); std::string data_format = ctx.Attr<std::string>("data_format");
int customized_type_value =
framework::OpKernelType::kDefaultCustomizedTypeValue;
auto input_data_type = ctx.Input<Tensor>("X")->type();
framework::DataLayout layout_ = framework::StringToDataLayout(data_format); framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain && if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) { platform::CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN; library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN;
using framework::proto::VarType;
customized_type_value = (input_data_type == VarType::INT8 ||
input_data_type == VarType::UINT8)
? kTransposeMKLDNNINT8
: kTransposeMKLDNNFP32;
} }
#endif #endif
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(), return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_,
ctx.GetPlace(), layout_, library_); library_, customized_type_value);
} }
}; };
......
...@@ -21,6 +21,8 @@ limitations under the License. */ ...@@ -21,6 +21,8 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
enum { kTransposeMKLDNNFP32 = 1, kTransposeMKLDNNINT8 = 2 };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
inline void TransCompute(const int dim, const DeviceContext& dev_ctx, inline void TransCompute(const int dim, const DeviceContext& dev_ctx,
const framework::Tensor& in, framework::Tensor* out, const framework::Tensor& in, framework::Tensor* out,
......
...@@ -828,12 +828,16 @@ class TransposeMKLDNNHandler : public MKLDNNHandler { ...@@ -828,12 +828,16 @@ class TransposeMKLDNNHandler : public MKLDNNHandler {
public: public:
TransposeMKLDNNHandler(std::vector<int>& dims, // NOLINT TransposeMKLDNNHandler(std::vector<int>& dims, // NOLINT
std::vector<int>& axis, // NOLINT std::vector<int>& axis, // NOLINT
framework::proto::VarType::Type vtype,
mkldnn::memory::data_type dtype,
const platform::MKLDNNDeviceContext& dev_ctx, const platform::MKLDNNDeviceContext& dev_ctx,
mkldnn::engine engine, const std::string& base_key) mkldnn::engine engine, const std::string& base_key)
: platform::MKLDNNHandler(dev_ctx, engine, base_key), : platform::MKLDNNHandler(dev_ctx, engine, base_key),
dims_(dims), dims_(dims),
axis_(axis), axis_(axis),
logical_axis_(dims.size(), 0) {} logical_axis_(dims.size(), 0),
vtype_(vtype),
dtype_(dtype) {}
std::shared_ptr<mkldnn::memory> AcquireSrcMemory( std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
const mkldnn::memory::format& fmt, void* ptr) { const mkldnn::memory::format& fmt, void* ptr) {
...@@ -847,9 +851,8 @@ class TransposeMKLDNNHandler : public MKLDNNHandler { ...@@ -847,9 +851,8 @@ class TransposeMKLDNNHandler : public MKLDNNHandler {
logical_axis_[i] = i; logical_axis_[i] = i;
} }
auto src_md = fmt != mkldnn::memory::format::nchw auto src_md = fmt != mkldnn::memory::format::nchw
? platform::MKLDNNMemDesc( ? platform::MKLDNNMemDesc(dims_, dtype_, fmt)
dims_, platform::MKLDNNGetDataType<float>(), fmt) : Axis2MemoryDesc(dims_, logical_axis_, dtype_);
: Axis2MemoryDesc(dims_, logical_axis_);
mem_p = std::make_shared<mkldnn::memory>( mem_p = std::make_shared<mkldnn::memory>(
mkldnn::memory::primitive_desc{src_md, engine_}, ptr); mkldnn::memory::primitive_desc{src_md, engine_}, ptr);
dev_ctx_.SetBlob(local_key, mem_p); dev_ctx_.SetBlob(local_key, mem_p);
...@@ -866,14 +869,14 @@ class TransposeMKLDNNHandler : public MKLDNNHandler { ...@@ -866,14 +869,14 @@ class TransposeMKLDNNHandler : public MKLDNNHandler {
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key)); std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
if (mem_p == nullptr) { if (mem_p == nullptr) {
auto dst_mdp = mkldnn::memory::primitive_desc{ auto dst_mdp = mkldnn::memory::primitive_desc{
Axis2MemoryDesc(dims_, axis_), engine_}; Axis2MemoryDesc(dims_, axis_, dtype_), engine_};
auto dst_data = output->mutable_data<float>(place, dst_mdp.get_size()); auto dst_data = output->mutable_data(place, vtype_);
mem_p = std::make_shared<mkldnn::memory>(dst_mdp, dst_data); mem_p = std::make_shared<mkldnn::memory>(dst_mdp, dst_data);
dev_ctx_.SetBlob(local_key, mem_p); dev_ctx_.SetBlob(local_key, mem_p);
} else { } else {
auto dst_data = output->mutable_data<float>(place); auto dst_data = output->mutable_data(place, vtype_);
mem_p->set_data_handle(dst_data); mem_p->set_data_handle(dst_data);
} }
return mem_p; return mem_p;
...@@ -901,8 +904,8 @@ class TransposeMKLDNNHandler : public MKLDNNHandler { ...@@ -901,8 +904,8 @@ class TransposeMKLDNNHandler : public MKLDNNHandler {
protected: protected:
mkldnn_memory_desc_t Axis2MemoryDesc(std::vector<int>& nchw_tz, // NOLINT mkldnn_memory_desc_t Axis2MemoryDesc(std::vector<int>& nchw_tz, // NOLINT
std::vector<int>& axis // NOLINT std::vector<int>& axis, // NOLINT
) { mkldnn::memory::data_type dtype) {
mkldnn_memory_desc_t mem_fmt; mkldnn_memory_desc_t mem_fmt;
mem_fmt.primitive_kind = mkldnn_memory; mem_fmt.primitive_kind = mkldnn_memory;
...@@ -911,6 +914,11 @@ class TransposeMKLDNNHandler : public MKLDNNHandler { ...@@ -911,6 +914,11 @@ class TransposeMKLDNNHandler : public MKLDNNHandler {
mem_fmt.dims[i] = nchw_tz[i]; // logical dimensions (nchw format, mem_fmt.dims[i] = nchw_tz[i]; // logical dimensions (nchw format,
// regardless physical layout) // regardless physical layout)
} }
if (dtype == mkldnn::memory::data_type::s8)
mem_fmt.data_type = mkldnn_s8;
else if (dtype == mkldnn::memory::data_type::u8)
mem_fmt.data_type = mkldnn_u8;
else
mem_fmt.data_type = mkldnn_f32; mem_fmt.data_type = mkldnn_f32;
mem_fmt.format = mkldnn_blocked; mem_fmt.format = mkldnn_blocked;
...@@ -933,6 +941,8 @@ class TransposeMKLDNNHandler : public MKLDNNHandler { ...@@ -933,6 +941,8 @@ class TransposeMKLDNNHandler : public MKLDNNHandler {
std::vector<int> dims_; std::vector<int> dims_;
std::vector<int> axis_; std::vector<int> axis_;
std::vector<int> logical_axis_; std::vector<int> logical_axis_;
framework::proto::VarType::Type vtype_;
mkldnn::memory::data_type dtype_;
}; };
class ReorderMKLDNNHandler : public MKLDNNHandler { class ReorderMKLDNNHandler : public MKLDNNHandler {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册