提交 61403f87 编写于 作者: W wozna

Add transpose2 INT8 for mkl-dnn

test=develop
上级 08fa98f7
......@@ -174,7 +174,10 @@ function(op_library TARGET)
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, FP32);\n")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, S8);\n")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, U8);\n")
elseif(${MKLDNN_FILE} STREQUAL "transpose_mkldnn_op")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(transpose2, MKLDNN, FP32);\n")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(transpose2, MKLDNN, S8);\n")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(transpose2, MKLDNN, U8);\n")
else()
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, MKLDNN);\n")
endif()
......
......@@ -1155,6 +1155,27 @@ PDNode *patterns::Conv::operator()() {
return output_var;
}
PDNode *patterns::Transpose::operator()() {
auto prev_op = pattern->NewNode(prev_op_repr())->assert_is_op();
auto transpose_op =
pattern->NewNode(transpose_op_repr())->assert_is_op("transpose2");
auto transpose_in = pattern->NewNode(transpose_in_repr())
->AsInput()
->assert_is_op_input("transpose2");
auto transpose_out = pattern->NewNode(transpose_out_repr())
->AsOutput()
->assert_is_op_output("transpose2", "Out");
auto next_op = pattern->NewNode(next_op_repr())->assert_is_op();
prev_op->LinksTo({transpose_in});
transpose_op->LinksFrom({transpose_in}).LinksTo({transpose_out});
next_op->LinksFrom({transpose_out});
return transpose_out;
}
PDNode *patterns::ConvResidual::operator()(bool with_residual_data) {
auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d");
......
......@@ -750,6 +750,21 @@ struct ElementwiseAdd : public PatternBase {
PATTERN_DECL_NODE(elementwise_add_out);
};
// Transpose op
// Forward pass for transpose.
// transpose_out is a result of the operator.
struct Transpose : public PatternBase {
Transpose(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "transpose2") {}
PDNode* operator()();
PATTERN_DECL_NODE(prev_op);
PATTERN_DECL_NODE(transpose_in);
PATTERN_DECL_NODE(transpose_op);
PATTERN_DECL_NODE(transpose_out);
PATTERN_DECL_NODE(next_op);
};
// Concat op
// Forward pass for concat.
// concat_out is a result of the operator.
......
......@@ -343,6 +343,65 @@ void CPUQuantizePass::QuantizePriorBox(Graph* graph) const {
quantize_prior_box_count);
}
void CPUQuantizePass::QuantizeTranspose(Graph* graph) const {
GraphPatternDetector gpd;
auto pattern = gpd.mutable_pattern();
patterns::Transpose transpose_pattern{pattern, name_scope_};
transpose_pattern();
int quantize_transpose_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "Quantize transpose op";
GET_IR_NODE_FROM_SUBGRAPH(transpose_op, transpose_op, transpose_pattern);
auto* transpose_op_desc = transpose_op->Op();
if (!transpose_op_desc->HasAttr("use_quantizer")) {
return;
}
// skip if should not be quantized
if (!transpose_op_desc->HasAttr("use_quantizer") ||
!boost::get<bool>(transpose_op_desc->GetAttr("use_quantizer"))) {
return;
}
GET_IR_NODE_FROM_SUBGRAPH(prev_op, prev_op, transpose_pattern);
GET_IR_NODE_FROM_SUBGRAPH(next_op, next_op, transpose_pattern);
// skip if prev op is not quantized
// in future we should checked if next_op is quantized
// transpose INT8 schould be used only between INT8 operators
if (!(prev_op->Op()->Type() == "dequantize" ||
(prev_op->Op()->HasAttr("use_quantizer") &&
boost::get<bool>(prev_op->Op()->GetAttr("use_quantizer"))))) {
return;
}
GET_IR_NODE_FROM_SUBGRAPH(transpose_in, transpose_in, transpose_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose_out, transpose_out, transpose_pattern);
// get scales calculated after warmup, they scale variables to MAX=1.0
auto scales = Get<VarQuantScale>("quant_var_scales");
auto input_scale = scales[transpose_in->Name()].second.data<double>()[0];
bool is_input_unsigned = scales[transpose_in->Name()].first;
QuantizeInput(g, transpose_op, transpose_in, "X", input_scale,
is_input_unsigned);
auto output_scale = scales[transpose_out->Name()].second.data<double>()[0];
bool is_output_unsigned = scales[transpose_out->Name()].first;
DequantizeOutput(g, transpose_op, transpose_out, "Out", output_scale,
is_output_unsigned);
++quantize_transpose_count;
};
gpd(graph, handler);
AddStatis(quantize_transpose_count);
PrettyLogDetail("--- quantized %d transpose ops",
quantize_transpose_count);
}
void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const {
VLOG(3) << "Quantizing the graph.";
PADDLE_ENFORCE(graph);
......@@ -355,6 +414,7 @@ void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const {
QuantizePool(graph);
QuantizeConcat(graph);
QuantizePriorBox(graph);
QuantizeTranspose(graph);
}
} // namespace ir
......
......@@ -52,6 +52,8 @@ class CPUQuantizePass : public FusePassBase {
void QuantizePriorBox(Graph* graph) const;
void QuantizeTranspose(Graph* graph) const;
void QuantizeInput(Graph* g, Node* op, Node* input, std::string input_name,
double scale_to_one, bool is_unsigned,
std::string scale_attr_name = "") const;
......
......@@ -48,7 +48,7 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
op->SetAttr("Scale_in", 1.0f);
op->SetAttr("Scale_out", 1.0f);
op->SetAttr("Scale_weights", std::vector<float>{1.0f});
} else if (type == "pool2d") {
} else if (type == "pool2d" || type == "transpose2") {
op->SetInput("X", {inputs[0]});
op->SetOutput("Out", {outputs[0]});
op->SetAttr("use_quantizer", use_quantizer);
......@@ -113,19 +113,14 @@ void InitTensorHolder(Scope* scope, const paddle::platform::Place& place,
tensor->mutable_data(place, proto::VarType::FP32, 1);
}
void MainTest(const ProgramDesc& prog, int conv_count, int pool_count,
int quant_count, int dequant_count, int added_nodes_count,
float scale) {
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
// Init scope, as it is used in pass
void PreparePass(std::unique_ptr<ir::Graph>* graph, const ProgramDesc& prog,
const std::initializer_list<std::string> variable_names,
int* original_nodes_num, int* current_nodes_num) {
auto place = paddle::platform::CPUPlace();
NaiveExecutor exe{place};
Scope scope;
exe.CreateVariables(prog, 0, true, &scope);
auto* scales = new VarQuantScale();
for (auto& v : variable_names) {
InitTensorHolder(&scope, place, v.c_str());
LoDTensor tensor;
......@@ -136,16 +131,23 @@ void MainTest(const ProgramDesc& prog, int conv_count, int pool_count,
(*scales)[v] = std::make_pair(false, std::move(tensor));
}
graph->SetNotOwned(kParamScopeAttr, &scope);
auto pass = PassRegistry::Instance().Get("cpu_quantize_pass");
(*graph)->SetNotOwned(kParamScopeAttr, &scope);
std::unique_ptr<Pass> pass =
PassRegistry::Instance().Get("cpu_quantize_pass");
pass->Set("quant_var_scales", scales);
int original_nodes_num = graph->Nodes().size();
graph.reset(pass->Apply(graph.release()));
*original_nodes_num = (*graph)->Nodes().size();
(*graph).reset(pass->Apply((*graph).release()));
*current_nodes_num = (*graph)->Nodes().size();
}
int current_nodes_num = graph->Nodes().size();
void MainTest(const ProgramDesc& prog, int conv_count, int pool_count,
int quant_count, int dequant_count, int added_nodes_count,
float scale) {
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
int original_nodes_num, current_nodes_num;
PreparePass(&graph, prog, variable_names, &original_nodes_num,
&current_nodes_num);
int quantize_nodes_count = 0;
int dequantize_nodes_count = 0;
......@@ -232,35 +234,9 @@ ProgramDesc BuildProgramDescConcat() {
void MainTestConcat(const ProgramDesc& prog, int pool_count, int concat_count,
int quant_count, int dequant_count, int added_nodes_count) {
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
// Init scope, as it is used in pass
auto place = paddle::platform::CPUPlace();
NaiveExecutor exe{place};
Scope scope;
exe.CreateVariables(prog, 0, true, &scope);
auto* scales = new VarQuantScale();
for (auto& v : variable_names_concat) {
InitTensorHolder(&scope, place, v.c_str());
LoDTensor tensor;
tensor.Resize({1});
auto* ptr = tensor.mutable_data<double>(place);
ptr[0] = 2.0;
(*scales)[v] = std::make_pair(false, std::move(tensor));
}
graph->SetNotOwned(kParamScopeAttr, &scope);
auto pass = PassRegistry::Instance().Get("cpu_quantize_pass");
pass->Set("quant_var_scales", scales);
int original_nodes_num = graph->Nodes().size();
graph.reset(pass->Apply(graph.release()));
int current_nodes_num = graph->Nodes().size();
int original_nodes_num, current_nodes_num;
PreparePass(&graph, prog, variable_names_concat, &original_nodes_num,
&current_nodes_num);
int quantize_nodes_count = 0;
int dequantize_nodes_count = 0;
......@@ -300,9 +276,93 @@ TEST(CpuQuantizePass, concat) {
MainTestConcat(BuildProgramDescConcat(), pool_count, concat_count,
quant_count, dequant_count, added_nodes_count);
}
} // namespace
namespace {
static const std::initializer_list<std::string> variable_names_transpose = {
"a", "w1", "b", "c", "w2", "d", "e", "f"};
// a->Conv1->b
// b->Transpose1->c
// c->Conv2->d
// d->Transpose2->e
// e->Dropout->f
ProgramDesc BuildProgramDescTranspose() {
ProgramDesc prog;
for (auto& v : variable_names_transpose) {
auto* var = prog.MutableBlock(0)->Var(v);
if (v.find("w") == 0) {
var->SetPersistable(true);
}
}
SetOp(&prog, "conv2d", "Conv1", {"a", "w1"}, {"b"}, true, true);
SetOp(&prog, "transpose2", "Transpose1", {"b"}, {"c"}, true, true);
SetOp(&prog, "conv2d", "Conv1", {"c", "w2"}, {"d"}, true, true);
SetOp(&prog, "transpose2", "Transpose2", {"d"}, {"e"}, true, true);
SetOp(&prog, "dropout", "Dropout", {"e"}, {"f"}, true, false);
return prog;
}
void MainTestTranspose(const ProgramDesc& prog, int conv_count,
int transpose_count, int quant_count, int dequant_count,
int added_nodes_count, float scale) {
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
int original_nodes_num, current_nodes_num;
PreparePass(&graph, prog, variable_names_transpose, &original_nodes_num,
&current_nodes_num);
int quantize_nodes_count = 0;
int dequantize_nodes_count = 0;
int transpose_nodes_count = 0;
int conv_nodes_count = 0;
for (auto* node : graph->Nodes()) {
if (node->IsOp()) {
auto* op = node->Op();
if (op->Type() == "transpose2") {
transpose_nodes_count++;
} else if (op->Type() == "conv2d") {
conv_nodes_count++;
auto op_name = boost::get<std::string>(op->GetAttr("name"));
EXPECT_EQ(boost::get<float>(op->GetAttr("Scale_in")), scale)
<< "Scale_in for node '" + op_name + "'.";
EXPECT_EQ(boost::get<float>(op->GetAttr("Scale_out")), scale)
<< "Scale_out for node '" + op_name + "'.";
EXPECT_EQ(
boost::get<std::vector<float>>(op->GetAttr("Scale_weights"))[0],
scale)
<< "Scale_weights for node '" + op_name + "'.";
} else if (op->Type() == "quantize") {
quantize_nodes_count++;
} else if (op->Type() == "dequantize") {
dequantize_nodes_count++;
}
}
}
EXPECT_EQ(transpose_nodes_count, transpose_count);
EXPECT_EQ(conv_nodes_count, conv_count);
EXPECT_EQ(quantize_nodes_count, quant_count);
EXPECT_EQ(dequantize_nodes_count, dequant_count);
EXPECT_EQ(original_nodes_num + added_nodes_count, current_nodes_num);
}
TEST(CpuQuantizePass, transpose) {
// a1->Quant->a2->Conv1->b1->Dequant->b2
// b2->Quant->b3->Transpose->c1->Dequant->c2
// c2->Quant->c3->Conv2->d1->Dequant->d2
// d2->Quant->d3->Transpose->e1->Dequant->e2
// e2->Dropout->f
int conv_count = 2;
int transpose_count = 2;
int quant_count = 4;
int dequant_count = 4;
// 4 Quant + 4 IN + 4 DeQuant + 4 OUT
int added_nodes_count = 16;
MainTestTranspose(BuildProgramDescTranspose(), conv_count, transpose_count,
quant_count, dequant_count, added_nodes_count, 2.0f * 127);
}
} // namespace
} // namespace ir
} // namespace framework
} // namespace paddle
......
......@@ -34,6 +34,9 @@ MkldnnQuantizerConfig::MkldnnQuantizerConfig() {
rules_["prior_box"]["Image"] = ScaleAlgo::NONE;
rules_["prior_box"]["Boxes"] = ScaleAlgo::NONE;
rules_["prior_box"]["Variances"] = ScaleAlgo::NONE;
rules_["transpose"]["X"] = ScaleAlgo::KL;
rules_["transpose"]["Out"] = ScaleAlgo::KL;
}
ScaleAlgo MkldnnQuantizerConfig::scale_algo(
......
......@@ -268,7 +268,7 @@ TEST(Analyzer_int8_mobilenet_ssd, quantization) {
q_cfg.EnableMkldnnQuantizer();
q_cfg.mkldnn_quantizer_config();
std::unordered_set<std::string> quantize_operators(
{"conv2d", "depthwise_conv2d", "prior_box"});
{"conv2d", "depthwise_conv2d", "prior_box", "transpose2"});
q_cfg.mkldnn_quantizer_config()->SetEnabledOpTypes(quantize_operators);
q_cfg.mkldnn_quantizer_config()->SetWarmupData(warmup_data);
q_cfg.mkldnn_quantizer_config()->SetWarmupBatchSize(FLAGS_warmup_batch_size);
......
......@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/data_layout_transform.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
namespace paddle {
......@@ -29,6 +30,7 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace.");
mkldnn::memory::data_type in_type = platform::MKLDNNGetDataType<T>();
auto& dev_ctx =
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
......@@ -49,8 +51,8 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
nchw_tz, axis,
ctx.op().Output("Out") + std::to_string(input->format()));
platform::TransposeMKLDNNHandler handler(nchw_tz, axis, dev_ctx,
mkldnn_engine, key);
platform::TransposeMKLDNNHandler handler(
nchw_tz, axis, input->type(), in_type, dev_ctx, mkldnn_engine, key);
auto transpose_src_memory_p = handler.AcquireSrcMemory(
input->format(), platform::to_void_cast<T>(input_data));
......@@ -78,7 +80,7 @@ class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* x_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
if (!x_grad) return;
mkldnn::memory::data_type in_type = platform::MKLDNNGetDataType<T>();
auto& dev_ctx =
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
......@@ -103,7 +105,8 @@ class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
const std::string key = platform::TransposeMKLDNNHandler::GetHash(
nchw_tz, axis, ctx.op().Output(framework::GradVarName("X")));
platform::TransposeMKLDNNHandler handler(nchw_tz, reversed_axis, dev_ctx,
platform::TransposeMKLDNNHandler handler(nchw_tz, reversed_axis,
x_grad->type(), in_type, dev_ctx,
mkldnn_engine, key);
auto transpose_src_memory_p = handler.AcquireSrcMemory(
......@@ -124,11 +127,35 @@ class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
namespace ops = paddle::operators;
REGISTER_OP_KERNEL(transpose2, MKLDNN, ::paddle::platform::CPUPlace,
ops::TransposeMKLDNNOpKernel<float>);
REGISTER_OP_KERNEL(transpose, MKLDNN, ::paddle::platform::CPUPlace,
ops::TransposeMKLDNNOpKernel<float>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(transpose2, MKLDNN,
::paddle::platform::CPUPlace, FP32,
ops::kTransposeMKLDNNFP32,
ops::TransposeMKLDNNOpKernel<float>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(transpose2, MKLDNN,
::paddle::platform::CPUPlace, U8,
ops::kTransposeMKLDNNINT8,
ops::TransposeMKLDNNOpKernel<uint8_t>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(transpose2, MKLDNN,
::paddle::platform::CPUPlace, S8,
ops::kTransposeMKLDNNINT8,
ops::TransposeMKLDNNOpKernel<int8_t>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(transpose, MKLDNN,
::paddle::platform::CPUPlace, FP32,
ops::kTransposeMKLDNNFP32,
ops::TransposeMKLDNNOpKernel<float>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(transpose, MKLDNN,
::paddle::platform::CPUPlace, U8,
ops::kTransposeMKLDNNINT8,
ops::TransposeMKLDNNOpKernel<uint8_t>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(transpose, MKLDNN,
::paddle::platform::CPUPlace, S8,
ops::kTransposeMKLDNNINT8,
ops::TransposeMKLDNNOpKernel<int8_t>);
REGISTER_OP_KERNEL(transpose_grad, MKLDNN, ::paddle::platform::CPUPlace,
ops::TransposeMKLDNNGradOpKernel<float>);
......
......@@ -65,15 +65,23 @@ class TransposeOp : public framework::OperatorWithKernel {
framework::LibraryType library_{framework::LibraryType::kPlain};
std::string data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
int customized_type_value =
framework::OpKernelType::kDefaultCustomizedTypeValue;
auto input_data_type = ctx.Input<Tensor>("X")->type();
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
using framework::proto::VarType;
customized_type_value = (input_data_type == VarType::INT8 ||
input_data_type == VarType::UINT8)
? kTransposeMKLDNNINT8
: kTransposeMKLDNNFP32;
}
#endif
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.GetPlace(), layout_, library_);
return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_,
library_, customized_type_value);
}
};
......@@ -99,6 +107,13 @@ class TransposeOpMaker : public framework::OpProtoAndCheckerMaker {
"Defaults to \"NHWC\". Specify the data format of the output data, "
"the input will be transformed automatically. ")
.SetDefault("AnyLayout");
/* int8 parameters */
AddAttr<bool>("use_quantizer",
"(bool, default false) "
"Set to true for operators that should be quantized and use "
"int8 kernel. "
"Only used on CPU.")
.SetDefault(false);
AddComment(R"DOC(
Transpose Operator.
......@@ -196,16 +211,24 @@ class Transpose2Op : public TransposeOp {
const framework::ExecutionContext &ctx) const override {
framework::LibraryType library_{framework::LibraryType::kPlain};
std::string data_format = ctx.Attr<std::string>("data_format");
int customized_type_value =
framework::OpKernelType::kDefaultCustomizedTypeValue;
auto input_data_type = ctx.Input<Tensor>("X")->type();
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
using framework::proto::VarType;
customized_type_value = (input_data_type == VarType::INT8 ||
input_data_type == VarType::UINT8)
? kTransposeMKLDNNINT8
: kTransposeMKLDNNFP32;
}
#endif
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.GetPlace(), layout_, library_);
return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_,
library_, customized_type_value);
}
};
......
......@@ -21,6 +21,8 @@ limitations under the License. */
namespace paddle {
namespace operators {
enum { kTransposeMKLDNNFP32 = 1, kTransposeMKLDNNINT8 = 2 };
template <typename DeviceContext, typename T>
inline void TransCompute(const int dim, const DeviceContext& dev_ctx,
const framework::Tensor& in, framework::Tensor* out,
......
......@@ -828,12 +828,16 @@ class TransposeMKLDNNHandler : public MKLDNNHandler {
public:
TransposeMKLDNNHandler(std::vector<int>& dims, // NOLINT
std::vector<int>& axis, // NOLINT
framework::proto::VarType::Type vtype,
mkldnn::memory::data_type dtype,
const platform::MKLDNNDeviceContext& dev_ctx,
mkldnn::engine engine, const std::string& base_key)
: platform::MKLDNNHandler(dev_ctx, engine, base_key),
dims_(dims),
axis_(axis),
logical_axis_(dims.size(), 0) {}
logical_axis_(dims.size(), 0),
vtype_(vtype),
dtype_(dtype) {}
std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
const mkldnn::memory::format& fmt, void* ptr) {
......@@ -847,9 +851,8 @@ class TransposeMKLDNNHandler : public MKLDNNHandler {
logical_axis_[i] = i;
}
auto src_md = fmt != mkldnn::memory::format::nchw
? platform::MKLDNNMemDesc(
dims_, platform::MKLDNNGetDataType<float>(), fmt)
: Axis2MemoryDesc(dims_, logical_axis_);
? platform::MKLDNNMemDesc(dims_, dtype_, fmt)
: Axis2MemoryDesc(dims_, logical_axis_, dtype_);
mem_p = std::make_shared<mkldnn::memory>(
mkldnn::memory::primitive_desc{src_md, engine_}, ptr);
dev_ctx_.SetBlob(local_key, mem_p);
......@@ -866,14 +869,14 @@ class TransposeMKLDNNHandler : public MKLDNNHandler {
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
if (mem_p == nullptr) {
auto dst_mdp = mkldnn::memory::primitive_desc{
Axis2MemoryDesc(dims_, axis_), engine_};
Axis2MemoryDesc(dims_, axis_, dtype_), engine_};
auto dst_data = output->mutable_data<float>(place, dst_mdp.get_size());
auto dst_data = output->mutable_data(place, vtype_);
mem_p = std::make_shared<mkldnn::memory>(dst_mdp, dst_data);
dev_ctx_.SetBlob(local_key, mem_p);
} else {
auto dst_data = output->mutable_data<float>(place);
auto dst_data = output->mutable_data(place, vtype_);
mem_p->set_data_handle(dst_data);
}
return mem_p;
......@@ -901,8 +904,8 @@ class TransposeMKLDNNHandler : public MKLDNNHandler {
protected:
mkldnn_memory_desc_t Axis2MemoryDesc(std::vector<int>& nchw_tz, // NOLINT
std::vector<int>& axis // NOLINT
) {
std::vector<int>& axis, // NOLINT
mkldnn::memory::data_type dtype) {
mkldnn_memory_desc_t mem_fmt;
mem_fmt.primitive_kind = mkldnn_memory;
......@@ -911,7 +914,12 @@ class TransposeMKLDNNHandler : public MKLDNNHandler {
mem_fmt.dims[i] = nchw_tz[i]; // logical dimensions (nchw format,
// regardless physical layout)
}
mem_fmt.data_type = mkldnn_f32;
if (dtype == mkldnn::memory::data_type::s8)
mem_fmt.data_type = mkldnn_s8;
else if (dtype == mkldnn::memory::data_type::u8)
mem_fmt.data_type = mkldnn_u8;
else
mem_fmt.data_type = mkldnn_f32;
mem_fmt.format = mkldnn_blocked;
unsigned int total_stride = 1;
......@@ -933,6 +941,8 @@ class TransposeMKLDNNHandler : public MKLDNNHandler {
std::vector<int> dims_;
std::vector<int> axis_;
std::vector<int> logical_axis_;
framework::proto::VarType::Type vtype_;
mkldnn::memory::data_type dtype_;
};
class ReorderMKLDNNHandler : public MKLDNNHandler {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册