提交 13816dd4 编写于 作者: J Jacek Czaja 提交者: Tao Luo

[MKL-DNN] Fix to crash of Transformer when mkldnn is to be used (#16233)

* - Fix to crash of Transformer when mkldnn is to be used

Desc: TensorCopy was not setting MKLDNN primitive descriptor when layout was to be kMKLDNN

test=develop

* - Enable transformer for mkl-dnn

test=develo

* - Compilation fix

test=develop

* - Removed manual selection of MKL-DNN ops to be used in Transformer test

test=develop
上级 7e20e769
......@@ -44,6 +44,11 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place,
<< dst_place;
return;
}
#ifdef PADDLE_WITH_MKLDNN
if (src.layout() == DataLayout::kMKLDNN) {
dst->set_mkldnn_prim_desc(src.get_mkldnn_prim_desc());
}
#endif
memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr,
boost::get<platform::CPUPlace>(src_place), src_ptr, size);
}
......
......@@ -110,7 +110,7 @@ set(TRANSFORMER_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/transformer")
download_model_and_data(${TRANSFORMER_INSTALL_DIR} "temp%2Ftransformer_model.tar.gz" "temp%2Ftransformer_data.txt.tar.gz")
inference_analysis_test(test_analyzer_transformer SRCS analyzer_transformer_tester.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${TRANSFORMER_INSTALL_DIR}/model --infer_data=${TRANSFORMER_INSTALL_DIR}/data.txt --batch_size=8)
ARGS --infer_model=${TRANSFORMER_INSTALL_DIR}/model --infer_data=${TRANSFORMER_INSTALL_DIR}/data.txt --batch_size=8 SERIAL)
# ocr
set(OCR_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/ocr")
......
......@@ -183,10 +183,13 @@ void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
}
// Easy for profiling independently.
TEST(Analyzer_Transformer, profile) {
void profile(bool use_mkldnn = false) {
AnalysisConfig cfg;
SetConfig(&cfg);
std::vector<PaddleTensor> outputs;
if (use_mkldnn) {
cfg.EnableMKLDNN();
}
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
......@@ -194,6 +197,11 @@ TEST(Analyzer_Transformer, profile) {
input_slots_all, &outputs, FLAGS_num_threads);
}
TEST(Analyzer_Transformer, profile) { profile(); }
#ifdef PADDLE_WITH_MKLDNN
TEST(Analyzer_Transformer, profile_mkldnn) { profile(true); }
#endif
// Check the fuse status
TEST(Analyzer_Transformer, fuse_statis) {
AnalysisConfig cfg;
......@@ -206,9 +214,12 @@ TEST(Analyzer_Transformer, fuse_statis) {
}
// Compare result of NativeConfig and AnalysisConfig
TEST(Analyzer_Transformer, compare) {
void compare(bool use_mkldnn = false) {
AnalysisConfig cfg;
SetConfig(&cfg);
if (use_mkldnn) {
cfg.EnableMKLDNN();
}
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
......@@ -216,5 +227,10 @@ TEST(Analyzer_Transformer, compare) {
reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all);
}
TEST(Analyzer_Transformer, compare) { compare(); }
#ifdef PADDLE_WITH_MKLDNN
TEST(Analyzer_Transformer, compare_mkldnn) { compare(true /* use_mkldnn */); }
#endif
} // namespace inference
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册