未验证 提交 f866bfed 编写于 作者: S Sylwester Fraczek 提交者: GitHub

support mnist and resnet dygraph_to_static test (#25774)

* support mnist and resnet dygraph_to_static test

* make FLAGS_use_mkldnn a public flag

* fix test_get_set_flags

* Change name of a function

* Rerun CIs commit

* Fix oneDNN dygraph training
Co-authored-by: NAdam <38704900+grygielski@users.noreply.github.com>
Co-authored-by: Ngrygielski <adam.grygielski@gmail.com>
上级 ebc5f997
...@@ -147,8 +147,10 @@ class LayerHelper(LayerHelperBase): ...@@ -147,8 +147,10 @@ class LayerHelper(LayerHelperBase):
if 'use_cudnn' in self.kwargs and self.kwargs.get('use_cudnn'): if 'use_cudnn' in self.kwargs and self.kwargs.get('use_cudnn'):
act['use_cudnn'] = self.kwargs.get('use_cudnn') act['use_cudnn'] = self.kwargs.get('use_cudnn')
if 'use_mkldnn' in self.kwargs: use_mkldnn = self.kwargs.get(
act['use_mkldnn'] = self.kwargs.get('use_mkldnn') 'use_mkldnn', core.globals().get("FLAGS_use_mkldnn", False))
if use_mkldnn:
act['use_mkldnn'] = use_mkldnn
act_type = act.pop('type') act_type = act.pop('type')
tmp = self.create_variable_for_type_inference(dtype=input_var.dtype) tmp = self.create_variable_for_type_inference(dtype=input_var.dtype)
......
...@@ -154,6 +154,18 @@ class TestMNISTWithToStatic(TestMNIST): ...@@ -154,6 +154,18 @@ class TestMNISTWithToStatic(TestMNIST):
msg='dygraph is {}\n static_res is \n{}'.format(dygraph_loss, msg='dygraph is {}\n static_res is \n{}'.format(dygraph_loss,
static_loss)) static_loss))
def test_mnist_declarative_cpu_vs_mkldnn(self):
dygraph_loss_cpu = self.train_dygraph()
fluid.set_flags({'FLAGS_use_mkldnn': True})
try:
dygraph_loss_mkldnn = self.train_dygraph()
finally:
fluid.set_flags({'FLAGS_use_mkldnn': False})
self.assertTrue(
np.allclose(dygraph_loss_cpu, dygraph_loss_mkldnn),
msg='cpu dygraph is {}\n mkldnn dygraph is \n{}'.format(
dygraph_loss_cpu, dygraph_loss_mkldnn))
def train(self, to_static=False): def train(self, to_static=False):
prog_trans = ProgramTranslator() prog_trans = ProgramTranslator()
prog_trans.enable(to_static) prog_trans.enable(to_static)
......
...@@ -346,6 +346,13 @@ class TestResnet(unittest.TestCase): ...@@ -346,6 +346,13 @@ class TestResnet(unittest.TestCase):
dygraph_loss)) dygraph_loss))
self.verify_predict() self.verify_predict()
def test_in_static_mode_mkldnn(self):
fluid.set_flags({'FLAGS_use_mkldnn': True})
try:
train(to_static=True)
finally:
fluid.set_flags({'FLAGS_use_mkldnn': False})
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册