提交 c4107748 编写于 作者: K Krzysztof Binias

Add support for dim equals 2 in activation functions

上级 c00a5dec
......@@ -40,13 +40,15 @@ void eltwise_forward(const ExecContext &ctx, mkldnn::algorithm algorithm,
const T *dst_data = dst->template mutable_data<T>(ctx.GetPlace());
// get memory dim
PADDLE_ENFORCE(src->dims().size() == 4,
"Input dim must be with 4, i.e. NCHW");
PADDLE_ENFORCE(src->dims().size() == 2 || src->dims().size() == 4,
"Input dim must be with 2 or 4");
std::vector<int> src_tz = framework::vectorize2int(src->dims());
// create memory description
// TODO(kbinias-intel): support more formats
auto data_md = platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
auto data_md = src_tz.size() == 2
? platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
mkldnn::memory::format::nc)
: platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
mkldnn::memory::format::nchw);
// create memory primitives
......@@ -91,7 +93,10 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
std::vector<int> src_tz = framework::vectorize2int(x->dims());
// create memory description
auto data_md = platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
auto data_md = src_tz.size() == 2
? platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
mkldnn::memory::format::nc)
: platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
mkldnn::memory::format::nchw);
// create memory primitives
......
......@@ -535,9 +535,37 @@ class TestSwish(OpTest):
#--------------------test MKLDNN--------------------
class TestMKLDNNRelu(TestRelu):
class TestMKLDNNReluDim2(TestRelu):
def setUp(self):
super(TestMKLDNNRelu, self).setUp()
super(TestMKLDNNReluDim2, self).setUp()
self.attrs = {"use_mkldnn": True}
class TestMKLDNNTanhDim2(TestTanh):
def setUp(self):
super(TestMKLDNNTanhDim2, self).setUp()
self.attrs = {"use_mkldnn": True}
class TestMKLDNNSqrtDim2(TestSqrt):
def setUp(self):
super(TestMKLDNNSqrtDim2, self).setUp()
self.attrs = {"use_mkldnn": True}
class TestMKLDNNAbsDim2(TestAbs):
def setUp(self):
super(TestMKLDNNAbsDim2, self).setUp()
self.attrs = {"use_mkldnn": True}
class TestMKLDNNReluDim4(TestRelu):
def setUp(self):
super(TestMKLDNNReluDim4, self).setUp()
x = np.random.uniform(-1, 1, [2, 4, 3, 5]).astype("float32")
# The same reason with TestAbs
......@@ -549,9 +577,9 @@ class TestMKLDNNRelu(TestRelu):
self.attrs = {"use_mkldnn": True}
class TestMKLDNNTanh(TestTanh):
class TestMKLDNNTanhDim4(TestTanh):
def setUp(self):
super(TestMKLDNNTanh, self).setUp()
super(TestMKLDNNTanhDim4, self).setUp()
self.inputs = {
'X': np.random.uniform(0.1, 1, [2, 4, 3, 5]).astype("float32")
......@@ -560,9 +588,9 @@ class TestMKLDNNTanh(TestTanh):
self.attrs = {"use_mkldnn": True}
class TestMKLDNNSqrt(TestSqrt):
class TestMKLDNNSqrtDim4(TestSqrt):
def setUp(self):
super(TestMKLDNNSqrt, self).setUp()
super(TestMKLDNNSqrtDim4, self).setUp()
self.inputs = {
'X': np.random.uniform(0.1, 1, [2, 4, 3, 5]).astype("float32")
......@@ -571,9 +599,9 @@ class TestMKLDNNSqrt(TestSqrt):
self.attrs = {"use_mkldnn": True}
class TestMKLDNNAbs(TestAbs):
class TestMKLDNNAbsDim4(TestAbs):
def setUp(self):
super(TestMKLDNNAbs, self).setUp()
super(TestMKLDNNAbsDim4, self).setUp()
x = np.random.uniform(-1, 1, [2, 4, 3, 5]).astype("float32")
# The same reason with TestAbs
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册