提交 c4107748 编写于 作者: K Krzysztof Binias

Add support for dim equals 2 in activation functions

上级 c00a5dec
...@@ -40,13 +40,15 @@ void eltwise_forward(const ExecContext &ctx, mkldnn::algorithm algorithm, ...@@ -40,13 +40,15 @@ void eltwise_forward(const ExecContext &ctx, mkldnn::algorithm algorithm,
const T *dst_data = dst->template mutable_data<T>(ctx.GetPlace()); const T *dst_data = dst->template mutable_data<T>(ctx.GetPlace());
// get memory dim // get memory dim
PADDLE_ENFORCE(src->dims().size() == 4, PADDLE_ENFORCE(src->dims().size() == 2 || src->dims().size() == 4,
"Input dim must be with 4, i.e. NCHW"); "Input dim must be with 2 or 4");
std::vector<int> src_tz = framework::vectorize2int(src->dims()); std::vector<int> src_tz = framework::vectorize2int(src->dims());
// create memory description // create memory description
// TODO(kbinias-intel): support more formats auto data_md = src_tz.size() == 2
auto data_md = platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32, ? platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
mkldnn::memory::format::nc)
: platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
mkldnn::memory::format::nchw); mkldnn::memory::format::nchw);
// create memory primitives // create memory primitives
...@@ -91,7 +93,10 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm, ...@@ -91,7 +93,10 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
std::vector<int> src_tz = framework::vectorize2int(x->dims()); std::vector<int> src_tz = framework::vectorize2int(x->dims());
// create memory description // create memory description
auto data_md = platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32, auto data_md = src_tz.size() == 2
? platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
mkldnn::memory::format::nc)
: platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
mkldnn::memory::format::nchw); mkldnn::memory::format::nchw);
// create memory primitives // create memory primitives
......
...@@ -535,9 +535,37 @@ class TestSwish(OpTest): ...@@ -535,9 +535,37 @@ class TestSwish(OpTest):
#--------------------test MKLDNN-------------------- #--------------------test MKLDNN--------------------
class TestMKLDNNRelu(TestRelu): class TestMKLDNNReluDim2(TestRelu):
def setUp(self): def setUp(self):
super(TestMKLDNNRelu, self).setUp() super(TestMKLDNNReluDim2, self).setUp()
self.attrs = {"use_mkldnn": True}
class TestMKLDNNTanhDim2(TestTanh):
def setUp(self):
super(TestMKLDNNTanhDim2, self).setUp()
self.attrs = {"use_mkldnn": True}
class TestMKLDNNSqrtDim2(TestSqrt):
def setUp(self):
super(TestMKLDNNSqrtDim2, self).setUp()
self.attrs = {"use_mkldnn": True}
class TestMKLDNNAbsDim2(TestAbs):
def setUp(self):
super(TestMKLDNNAbsDim2, self).setUp()
self.attrs = {"use_mkldnn": True}
class TestMKLDNNReluDim4(TestRelu):
def setUp(self):
super(TestMKLDNNReluDim4, self).setUp()
x = np.random.uniform(-1, 1, [2, 4, 3, 5]).astype("float32") x = np.random.uniform(-1, 1, [2, 4, 3, 5]).astype("float32")
# The same reason with TestAbs # The same reason with TestAbs
...@@ -549,9 +577,9 @@ class TestMKLDNNRelu(TestRelu): ...@@ -549,9 +577,9 @@ class TestMKLDNNRelu(TestRelu):
self.attrs = {"use_mkldnn": True} self.attrs = {"use_mkldnn": True}
class TestMKLDNNTanh(TestTanh): class TestMKLDNNTanhDim4(TestTanh):
def setUp(self): def setUp(self):
super(TestMKLDNNTanh, self).setUp() super(TestMKLDNNTanhDim4, self).setUp()
self.inputs = { self.inputs = {
'X': np.random.uniform(0.1, 1, [2, 4, 3, 5]).astype("float32") 'X': np.random.uniform(0.1, 1, [2, 4, 3, 5]).astype("float32")
...@@ -560,9 +588,9 @@ class TestMKLDNNTanh(TestTanh): ...@@ -560,9 +588,9 @@ class TestMKLDNNTanh(TestTanh):
self.attrs = {"use_mkldnn": True} self.attrs = {"use_mkldnn": True}
class TestMKLDNNSqrt(TestSqrt): class TestMKLDNNSqrtDim4(TestSqrt):
def setUp(self): def setUp(self):
super(TestMKLDNNSqrt, self).setUp() super(TestMKLDNNSqrtDim4, self).setUp()
self.inputs = { self.inputs = {
'X': np.random.uniform(0.1, 1, [2, 4, 3, 5]).astype("float32") 'X': np.random.uniform(0.1, 1, [2, 4, 3, 5]).astype("float32")
...@@ -571,9 +599,9 @@ class TestMKLDNNSqrt(TestSqrt): ...@@ -571,9 +599,9 @@ class TestMKLDNNSqrt(TestSqrt):
self.attrs = {"use_mkldnn": True} self.attrs = {"use_mkldnn": True}
class TestMKLDNNAbs(TestAbs): class TestMKLDNNAbsDim4(TestAbs):
def setUp(self): def setUp(self):
super(TestMKLDNNAbs, self).setUp() super(TestMKLDNNAbsDim4, self).setUp()
x = np.random.uniform(-1, 1, [2, 4, 3, 5]).astype("float32") x = np.random.uniform(-1, 1, [2, 4, 3, 5]).astype("float32")
# The same reason with TestAbs # The same reason with TestAbs
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册