diff --git a/python/paddle/fluid/layer_helper.py b/python/paddle/fluid/layer_helper.py index 19822e410c71aa993e2d90a92c57c3522023ad81..db556913384785e1f11ba05dcc524ef1f1de92ab 100644 --- a/python/paddle/fluid/layer_helper.py +++ b/python/paddle/fluid/layer_helper.py @@ -147,8 +147,10 @@ class LayerHelper(LayerHelperBase): if 'use_cudnn' in self.kwargs and self.kwargs.get('use_cudnn'): act['use_cudnn'] = self.kwargs.get('use_cudnn') - if 'use_mkldnn' in self.kwargs: - act['use_mkldnn'] = self.kwargs.get('use_mkldnn') + use_mkldnn = self.kwargs.get( + 'use_mkldnn', core.globals().get("FLAGS_use_mkldnn", False)) + if use_mkldnn: + act['use_mkldnn'] = use_mkldnn act_type = act.pop('type') tmp = self.create_variable_for_type_inference(dtype=input_var.dtype) diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_mnist.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_mnist.py index 1ef3bd1bf150056816283c83fa3ff6af1e589732..bd600d2f2dbd6341ff7a83d6636047d01cae7859 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_mnist.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_mnist.py @@ -154,6 +154,18 @@ class TestMNISTWithToStatic(TestMNIST): msg='dygraph is {}\n static_res is \n{}'.format(dygraph_loss, static_loss)) + def test_mnist_declarative_cpu_vs_mkldnn(self): + dygraph_loss_cpu = self.train_dygraph() + fluid.set_flags({'FLAGS_use_mkldnn': True}) + try: + dygraph_loss_mkldnn = self.train_dygraph() + finally: + fluid.set_flags({'FLAGS_use_mkldnn': False}) + self.assertTrue( + np.allclose(dygraph_loss_cpu, dygraph_loss_mkldnn), + msg='cpu dygraph is {}\n mkldnn dygraph is \n{}'.format( + dygraph_loss_cpu, dygraph_loss_mkldnn)) + def train(self, to_static=False): prog_trans = ProgramTranslator() prog_trans.enable(to_static) diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet.py index 6556b2f03bd5304e290792d07d1d969ab255bfdc..203c8ddb3488c0fef9a0a590378505e5b61233cf 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet.py @@ -346,6 +346,13 @@ class TestResnet(unittest.TestCase): dygraph_loss)) self.verify_predict() + def test_in_static_mode_mkldnn(self): + fluid.set_flags({'FLAGS_use_mkldnn': True}) + try: + train(to_static=True) + finally: + fluid.set_flags({'FLAGS_use_mkldnn': False}) + if __name__ == '__main__': unittest.main()