未验证 提交 3ecf6bb3 编写于 作者: X Xin Pan 提交者: GitHub

Merge pull request #15028 from yihuaxu/develop_641313ea_elementwise_mul_mkldnn_bug_fix

Fix the exception when tensor format is x
...@@ -94,11 +94,15 @@ void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) { ...@@ -94,11 +94,15 @@ void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
} }
// Easy for profiling independently. // Easy for profiling independently.
TEST(Analyzer_MM_DNN, profile) { void profile(bool use_mkldnn = false) {
contrib::AnalysisConfig cfg; contrib::AnalysisConfig cfg;
SetConfig(&cfg); SetConfig(&cfg);
std::vector<PaddleTensor> outputs; std::vector<PaddleTensor> outputs;
if (use_mkldnn) {
cfg.EnableMKLDNN();
}
std::vector<std::vector<PaddleTensor>> input_slots_all; std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all); SetInput(&input_slots_all);
TestPrediction(reinterpret_cast<const PaddlePredictor::Config *>(&cfg), TestPrediction(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
...@@ -119,6 +123,11 @@ TEST(Analyzer_MM_DNN, profile) { ...@@ -119,6 +123,11 @@ TEST(Analyzer_MM_DNN, profile) {
} }
} }
TEST(Analyzer_MM_DNN, profile) { profile(); }
#ifdef PADDLE_WITH_MKLDNN
TEST(Analyzer_MM_DNN, profile_mkldnn) { profile(true /* use_mkldnn */); }
#endif
// Check the fuse status // Check the fuse status
TEST(Analyzer_MM_DNN, fuse_statis) { TEST(Analyzer_MM_DNN, fuse_statis) {
contrib::AnalysisConfig cfg; contrib::AnalysisConfig cfg;
...@@ -131,16 +140,25 @@ TEST(Analyzer_MM_DNN, fuse_statis) { ...@@ -131,16 +140,25 @@ TEST(Analyzer_MM_DNN, fuse_statis) {
} }
// Compare result of NativeConfig and AnalysisConfig // Compare result of NativeConfig and AnalysisConfig
TEST(Analyzer_MM_DNN, compare) { void compare(bool use_mkldnn = false) {
contrib::AnalysisConfig cfg; contrib::AnalysisConfig cfg;
SetConfig(&cfg); SetConfig(&cfg);
if (use_mkldnn) {
cfg.EnableMKLDNN();
}
std::vector<std::vector<PaddleTensor>> input_slots_all; std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all); SetInput(&input_slots_all);
CompareNativeAndAnalysis( CompareNativeAndAnalysis(
reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all); reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all);
} }
TEST(Analyzer_MM_DNN, compare) { compare(); }
#ifdef PADDLE_WITH_MKLDNN
TEST(Analyzer_MM_DNN, compare_mkldnn) { compare(true /* use_mkldnn */); }
#endif
// Compare Deterministic result // Compare Deterministic result
TEST(Analyzer_MM_DNN, compare_determine) { TEST(Analyzer_MM_DNN, compare_determine) {
AnalysisConfig cfg; AnalysisConfig cfg;
......
...@@ -134,16 +134,18 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> { ...@@ -134,16 +134,18 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
const bool are_inputs_in_same_format = x->format() == y->format(); const bool are_inputs_in_same_format = x->format() == y->format();
const bool is_x_nchw = x->format() == memory::format::nchw; const bool is_x_nchw = x->format() == memory::format::nchw;
const bool is_x_nc = x->format() == memory::format::nc; const bool is_x_nc = x->format() == memory::format::nc;
const bool is_x_x = x->format() == memory::format::x;
const bool is_y_nchw = y->format() == memory::format::nchw; const bool is_y_nchw = y->format() == memory::format::nchw;
const bool is_y_nc = y->format() == memory::format::nc; const bool is_y_nc = y->format() == memory::format::nc;
const bool is_y_x = y->format() == memory::format::x;
if (!are_inputs_in_same_format) { if (!are_inputs_in_same_format) {
using platform::MKLDNNDeviceContext; using platform::MKLDNNDeviceContext;
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>(); auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine(); const auto& mkldnn_engine = dev_ctx.GetEngine();
if (!(is_x_nchw || is_x_nc)) if (!(is_x_nchw || is_x_nc || is_x_x))
ReorderInput<T>(const_cast<Tensor*>(x), ctx.GetPlace(), mkldnn_engine, ReorderInput<T>(const_cast<Tensor*>(x), ctx.GetPlace(), mkldnn_engine,
x->dims().size() == 4); x->dims().size() == 4);
if (!(is_y_nchw || is_y_nc)) if (!(is_y_nchw || is_y_nc || is_y_x))
ReorderInput<T>(const_cast<Tensor*>(y), ctx.GetPlace(), mkldnn_engine, ReorderInput<T>(const_cast<Tensor*>(y), ctx.GetPlace(), mkldnn_engine,
y->dims().size() == 4); y->dims().size() == 4);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册