提交 13816dd4 编写于 作者: J Jacek Czaja 提交者: Tao Luo

[MKL-DNN] Fix to crash of Transformer when mkldnn is to be used (#16233)

* - Fix to crash of Transformer when mkldnn is to be used

Desc: TensorCopy was not setting MKLDNN primitive descriptor when layout was to be kMKLDNN

test=develop

* - Enable transformer for mkl-dnn

test=develo

* - Compilation fix

test=develop

* - Removed manual selection of MKL-DNN ops to be used in Transformer test

test=develop
上级 7e20e769
...@@ -44,6 +44,11 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place, ...@@ -44,6 +44,11 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place,
<< dst_place; << dst_place;
return; return;
} }
#ifdef PADDLE_WITH_MKLDNN
if (src.layout() == DataLayout::kMKLDNN) {
dst->set_mkldnn_prim_desc(src.get_mkldnn_prim_desc());
}
#endif
memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr, memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr,
boost::get<platform::CPUPlace>(src_place), src_ptr, size); boost::get<platform::CPUPlace>(src_place), src_ptr, size);
} }
......
...@@ -110,7 +110,7 @@ set(TRANSFORMER_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/transformer") ...@@ -110,7 +110,7 @@ set(TRANSFORMER_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/transformer")
download_model_and_data(${TRANSFORMER_INSTALL_DIR} "temp%2Ftransformer_model.tar.gz" "temp%2Ftransformer_data.txt.tar.gz") download_model_and_data(${TRANSFORMER_INSTALL_DIR} "temp%2Ftransformer_model.tar.gz" "temp%2Ftransformer_data.txt.tar.gz")
inference_analysis_test(test_analyzer_transformer SRCS analyzer_transformer_tester.cc inference_analysis_test(test_analyzer_transformer SRCS analyzer_transformer_tester.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${TRANSFORMER_INSTALL_DIR}/model --infer_data=${TRANSFORMER_INSTALL_DIR}/data.txt --batch_size=8) ARGS --infer_model=${TRANSFORMER_INSTALL_DIR}/model --infer_data=${TRANSFORMER_INSTALL_DIR}/data.txt --batch_size=8 SERIAL)
# ocr # ocr
set(OCR_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/ocr") set(OCR_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/ocr")
......
...@@ -183,10 +183,13 @@ void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) { ...@@ -183,10 +183,13 @@ void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
} }
// Easy for profiling independently. // Easy for profiling independently.
TEST(Analyzer_Transformer, profile) { void profile(bool use_mkldnn = false) {
AnalysisConfig cfg; AnalysisConfig cfg;
SetConfig(&cfg); SetConfig(&cfg);
std::vector<PaddleTensor> outputs; std::vector<PaddleTensor> outputs;
if (use_mkldnn) {
cfg.EnableMKLDNN();
}
std::vector<std::vector<PaddleTensor>> input_slots_all; std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all); SetInput(&input_slots_all);
...@@ -194,6 +197,11 @@ TEST(Analyzer_Transformer, profile) { ...@@ -194,6 +197,11 @@ TEST(Analyzer_Transformer, profile) {
input_slots_all, &outputs, FLAGS_num_threads); input_slots_all, &outputs, FLAGS_num_threads);
} }
TEST(Analyzer_Transformer, profile) { profile(); }
#ifdef PADDLE_WITH_MKLDNN
TEST(Analyzer_Transformer, profile_mkldnn) { profile(true); }
#endif
// Check the fuse status // Check the fuse status
TEST(Analyzer_Transformer, fuse_statis) { TEST(Analyzer_Transformer, fuse_statis) {
AnalysisConfig cfg; AnalysisConfig cfg;
...@@ -206,9 +214,12 @@ TEST(Analyzer_Transformer, fuse_statis) { ...@@ -206,9 +214,12 @@ TEST(Analyzer_Transformer, fuse_statis) {
} }
// Compare result of NativeConfig and AnalysisConfig // Compare result of NativeConfig and AnalysisConfig
TEST(Analyzer_Transformer, compare) { void compare(bool use_mkldnn = false) {
AnalysisConfig cfg; AnalysisConfig cfg;
SetConfig(&cfg); SetConfig(&cfg);
if (use_mkldnn) {
cfg.EnableMKLDNN();
}
std::vector<std::vector<PaddleTensor>> input_slots_all; std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all); SetInput(&input_slots_all);
...@@ -216,5 +227,10 @@ TEST(Analyzer_Transformer, compare) { ...@@ -216,5 +227,10 @@ TEST(Analyzer_Transformer, compare) {
reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all); reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all);
} }
TEST(Analyzer_Transformer, compare) { compare(); }
#ifdef PADDLE_WITH_MKLDNN
TEST(Analyzer_Transformer, compare_mkldnn) { compare(true /* use_mkldnn */); }
#endif
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册