未验证 提交 2c5c7b2a 编写于 作者: T Tao Luo 提交者: GitHub

Merge pull request #15922 from kbinias/kbinias/reuse-primitives-activations-and-softmax-mkldnn-ut

MKL-DNN: Add Activations and Softmax UTs to check if primitives already exist in backward
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import paddle.fluid.core as core
import paddle.fluid as fluid
def check_if_mkldnn_primitives_exist_in_bwd(test_case, op_type, x, out,
out_grad, x_grad):
def __assert_close(tensor, np_array, msg, atol=1e-4):
test_case.assertTrue(
np.allclose(
np.array(tensor), np_array, atol=atol), msg)
place = core.CPUPlace()
var_dict = {'x': x, 'out': out, 'out@GRAD': out_grad, 'x@GRAD': x_grad}
var_names = list(var_dict.keys())
ground_truth = {name: var_dict[name] for name in var_names}
program = fluid.Program()
with fluid.program_guard(program):
block = program.global_block()
for name in ground_truth:
block.create_var(
name=name, dtype=np.float32, shape=ground_truth[name].shape)
op = block.append_op(
type=op_type,
inputs={'X': block.var('x'), },
outputs={'Out': block.var('out')},
attrs={'use_mkldnn': True})
# Generate backward op_desc
grad_op_desc_list, op_grad_to_var = core.get_grad_op_desc(op.desc,
set(), [])
grad_op_desc = grad_op_desc_list[0]
new_op_desc = block.desc.append_op()
new_op_desc.copy_from(grad_op_desc)
for var_name in grad_op_desc.output_arg_names():
block.desc.var(var_name.encode('ascii'))
grad_op_desc.infer_var_type(block.desc)
grad_op_desc.infer_shape(block.desc)
for arg in grad_op_desc.output_arg_names():
grad_var = block.desc.find_var(arg.encode('ascii'))
grad_var.set_dtype(core.VarDesc.VarType.FP32)
exe = fluid.Executor(place)
# Do at least 2 iterations
for i in range(2):
out = exe.run(
program,
feed={name: var_dict[name]
for name in ['x', 'out@GRAD']},
fetch_list=['x@GRAD', 'out'])
__assert_close(x_grad, out[0], 'x@GRAD')
......@@ -19,7 +19,7 @@ import numpy as np
import paddle.fluid.core as core
from paddle.fluid.tests.unittests.op_test import OpTest
from paddle.fluid.tests.unittests.test_activation_op import TestRelu, TestTanh, TestSqrt, TestAbs
import paddle.fluid as fluid
from mkldnn_op_test import check_if_mkldnn_primitives_exist_in_bwd
class TestMKLDNNReluDim2(TestRelu):
......@@ -98,62 +98,24 @@ class TestMKLDNNAbsDim4(TestAbs):
# Check if primitives already exist in backward
class TestMKLDNNReluPrimitivesAlreadyExist(unittest.TestCase):
def __assert_close(self, tensor, np_array, msg, atol=1e-4):
self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg)
def test_check_forward_backward(self):
place = core.CPUPlace()
class TestMKLDNNAbsPrimitivesAlreadyExist(unittest.TestCase):
def setUp(self):
super(TestMKLDNNAbsPrimitivesAlreadyExist, self).setUp()
np.random.seed(123)
x = np.random.uniform(-1, 1, [2, 2]).astype(np.float32)
out = np.abs(x)
out_grad = np.random.random_sample(x.shape).astype(np.float32)
x_grad = out_grad * np.sign(x) # Abs grad calculation
var_dict = {'x': x, 'out': out, 'out@GRAD': out_grad, 'x@GRAD': x_grad}
var_names = list(var_dict.keys())
ground_truth = {name: var_dict[name] for name in var_names}
program = fluid.Program()
with fluid.program_guard(program):
block = program.global_block()
for name in ground_truth:
block.create_var(
name=name, dtype='float32', shape=ground_truth[name].shape)
relu_op = block.append_op(
type="abs",
inputs={"X": block.var('x'), },
outputs={"Out": block.var('out')},
attrs={"use_mkldnn": True})
# Generate backward op_desc
grad_op_desc_list, op_grad_to_var = core.get_grad_op_desc(
relu_op.desc, set(), [])
grad_op_desc = grad_op_desc_list[0]
new_op_desc = block.desc.append_op()
new_op_desc.copy_from(grad_op_desc)
for var_name in grad_op_desc.output_arg_names():
block.desc.var(var_name.encode("ascii"))
grad_op_desc.infer_var_type(block.desc)
grad_op_desc.infer_shape(block.desc)
for arg in grad_op_desc.output_arg_names():
grad_var = block.desc.find_var(arg.encode("ascii"))
grad_var.set_dtype(core.VarDesc.VarType.FP32)
exe = fluid.Executor(place)
# Do at least 2 iterations
for i in range(2):
out = exe.run(
program,
feed={name: var_dict[name]
for name in ['x', 'out@GRAD']},
fetch_list=['x@GRAD'])
self.__assert_close(x_grad, out[0], "x@GRAD")
self.op_type = 'abs'
self.x = np.random.uniform(-1, 1, [2, 2]).astype(np.float32)
self.out = np.abs(self.x)
self.out_grad = np.random.random_sample(self.x.shape).astype(np.float32)
self.x_grad = self.__abs_bwd(self.x, self.out_grad)
# Abs grad calculation
def __abs_bwd(self, x, out_grad):
return out_grad * np.sign(x)
def test_check(self):
check_if_mkldnn_primitives_exist_in_bwd(
self, self.op_type, self.x, self.out, self.out_grad, self.x_grad)
if __name__ == '__main__':
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from paddle.fluid.tests.unittests.op_test import OpTest
import paddle.fluid.core as core
from paddle.fluid.tests.unittests.test_softmax_op import TestSoftmaxOp, stable_softmax
from mkldnn_op_test import check_if_mkldnn_primitives_exist_in_bwd
class TestSoftmaxMKLDNNOp(TestSoftmaxOp):
def init_kernel_type(self):
self.use_mkldnn = True
class TestSoftmaxMKLDNNOp2(TestSoftmaxMKLDNNOp):
def get_x_shape(self):
return [2, 3, 4, 5]
# Check if primitives already exist in backward
class TestSoftmaxMKLDNNPrimitivesAlreadyExist(unittest.TestCase):
def setUp(self):
super(TestSoftmaxMKLDNNPrimitivesAlreadyExist, self).setUp()
np.random.seed(123)
self.op_type = 'softmax'
self.x = np.random.uniform(-1, 1, 2).astype(np.float32)
self.out = stable_softmax(self.x)
self.out_grad = np.random.random_sample(self.x.shape).astype(np.float32)
self.x_grad = self.__softmax_bwd(self.out, self.out_grad)
# Softmax grad calculation
def __softmax_bwd(self, out, out_grad):
return out * (out_grad - np.dot(out, out_grad))
def test_check(self):
check_if_mkldnn_primitives_exist_in_bwd(
self, self.op_type, self.x, self.out, self.out_grad, self.x_grad)
if __name__ == '__main__':
unittest.main()
......@@ -144,15 +144,5 @@ class TestSoftmaxFP16CUDNNOp2(TestSoftmaxFP16CUDNNOp):
return [2, 3, 4, 5]
class TestSoftmaxMKLDNNOp(TestSoftmaxOp):
def init_kernel_type(self):
self.use_mkldnn = True
class TestSoftmaxMKLDNNOp2(TestSoftmaxMKLDNNOp):
def get_x_shape(self):
return [2, 3, 4, 5]
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册