未验证 提交 f866bfed 编写于 作者: S Sylwester Fraczek 提交者: GitHub

support mnist and resnet dygraph_to_static test (#25774)

* support mnist and resnet dygraph_to_static test

* make FLAGS_use_mkldnn a public flag

* fix test_get_set_flags

* Change name of a function

* Rerun CIs commit

* Fix oneDNN dygraph training
Co-authored-by: NAdam <38704900+grygielski@users.noreply.github.com>
Co-authored-by: Ngrygielski <adam.grygielski@gmail.com>
上级 ebc5f997
......@@ -147,8 +147,10 @@ class LayerHelper(LayerHelperBase):
if 'use_cudnn' in self.kwargs and self.kwargs.get('use_cudnn'):
act['use_cudnn'] = self.kwargs.get('use_cudnn')
if 'use_mkldnn' in self.kwargs:
act['use_mkldnn'] = self.kwargs.get('use_mkldnn')
use_mkldnn = self.kwargs.get(
'use_mkldnn', core.globals().get("FLAGS_use_mkldnn", False))
if use_mkldnn:
act['use_mkldnn'] = use_mkldnn
act_type = act.pop('type')
tmp = self.create_variable_for_type_inference(dtype=input_var.dtype)
......
......@@ -154,6 +154,18 @@ class TestMNISTWithToStatic(TestMNIST):
msg='dygraph is {}\n static_res is \n{}'.format(dygraph_loss,
static_loss))
def test_mnist_declarative_cpu_vs_mkldnn(self):
dygraph_loss_cpu = self.train_dygraph()
fluid.set_flags({'FLAGS_use_mkldnn': True})
try:
dygraph_loss_mkldnn = self.train_dygraph()
finally:
fluid.set_flags({'FLAGS_use_mkldnn': False})
self.assertTrue(
np.allclose(dygraph_loss_cpu, dygraph_loss_mkldnn),
msg='cpu dygraph is {}\n mkldnn dygraph is \n{}'.format(
dygraph_loss_cpu, dygraph_loss_mkldnn))
def train(self, to_static=False):
prog_trans = ProgramTranslator()
prog_trans.enable(to_static)
......
......@@ -346,6 +346,13 @@ class TestResnet(unittest.TestCase):
dygraph_loss))
self.verify_predict()
def test_in_static_mode_mkldnn(self):
fluid.set_flags({'FLAGS_use_mkldnn': True})
try:
train(to_static=True)
finally:
fluid.set_flags({'FLAGS_use_mkldnn': False})
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册