Commit 49b09327 authored by: Michal Gallus

MKLDNN elementwise_mul: Reorder on non-nchw input, fallback on non-16-divisible fm

test=develop
Parent f820573b
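For context: nChw16c is the MKL-DNN blocked layout that stores channels in groups of 16. A minimal sketch of the offset computation, assuming a logical index (n, c, h, w) on an (N, C, H, W) tensor:

    // Sketch only: how nChw16c maps a logical (n, c, h, w) index to storage.
    // Channels are tiled in blocks of 16, hence the C % 16 == 0 requirement.
    size_t offset =
        ((((n * (C / 16) + c / 16) * H + h) * W + w) * 16) + c % 16;

This blocking is why the optimized path below is gated on the channel count being divisible by 16; otherwise the kernel falls back to the naive version.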
......@@ -95,6 +95,26 @@ static void UpdateDataFormat(const framework::ExecutionContext& ctx,
}
}
template <typename T>
static void ReorderInput(framework::Tensor* tensor,
                         const platform::Place& place,
                         const mkldnn::engine& engine, bool isFourDim) {
  using platform::to_void_cast;
  auto dims = paddle::framework::vectorize2int(tensor->dims());
  framework::Tensor out_tensor;
  out_tensor.Resize(tensor->dims());
  // Target a plain layout: nchw for 4-D tensors, nc for 2-D ones.
  out_tensor.set_format(isFourDim ? memory::format::nchw : memory::format::nc);
  out_tensor.set_layout(tensor->layout());
  mkldnn::memory input_memory = {
      {{dims, platform::MKLDNNGetDataType<T>(), tensor->format()}, engine},
      to_void_cast<T>(tensor->data<T>())};
  mkldnn::memory output_memory = {
      {{dims, platform::MKLDNNGetDataType<T>(), out_tensor.format()}, engine},
      to_void_cast<T>(out_tensor.mutable_data<T>(place))};
  // Convert the data into the plain layout, then swap it in for the original.
  platform::Reorder(input_memory, output_memory);
  tensor->ShareDataWith(out_tensor);
}
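platform::Reorder (defined in Paddle's MKLDNN helpers, not shown in this diff) wraps the MKLDNN reorder primitive; a minimal sketch of what it does, assuming the MKLDNN 0.x API in use at the time:

    // Sketch only: submit a reorder primitive converting src's layout to dst's.
    static void Reorder(const mkldnn::memory& src, const mkldnn::memory& dst) {
      auto reorder_prim = mkldnn::reorder(src, dst);
      std::vector<mkldnn::primitive> pipeline = {reorder_prim};
      mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
    }

ReorderInput thus materializes a plain nchw (or nc) copy of the data and points the original tensor at that copy via ShareDataWith.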
template <typename T>
class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
public:
......@@ -111,12 +131,15 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
    auto x_dims = x->dims();
    auto y_dims_untrimmed = y->dims();
    auto x_int_dims = paddle::framework::vectorize2int(x_dims);

    UpdateDataFormat(ctx, (Tensor*)x, "x_data_format");
    UpdateDataFormat(ctx, (Tensor*)y, "y_data_format");

    if (x->format() == memory::format::nChw16c && y->format() == memory::format::nc) {
    // The blocked nChw16c path requires the channel count to be divisible by 16.
    const bool are_dims_divisible = !(x_int_dims[1] % 16);
    const bool is_x_format_correct = x->format() == memory::format::nChw16c;
    const bool is_y_format_correct = y->format() == memory::format::nc;
    if (is_x_format_correct && is_y_format_correct && are_dims_divisible) {
      if (x_dims != y_dims_untrimmed) {
        int pre, n, post;
        get_mid_dims(x_dims, y_dims_untrimmed, axis, &pre, &n, &post);
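        // Illustration with assumed shapes: for x_dims = (2, 16, 3, 4),
        // y_dims = (16,) and axis = 1, get_mid_dims yields pre = 2, n = 16,
        // post = 12; x is viewed as (pre, n, post) and y is broadcast along n.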
......@@ -163,11 +186,23 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
        z->set_layout(DataLayout::kMKLDNN);
        z->set_format(x->format());
      } else {
        PADDLE_THROW("Not implemented when dims are equal");
      }
    } else {
      // Fallback to naive version:
      const bool are_inputs_in_same_format = x->format() == y->format();
      const bool is_x_nchw = x->format() == memory::format::nchw;
      const bool is_x_nc = x->format() == memory::format::nc;
      const bool is_y_nchw = y->format() == memory::format::nchw;
      const bool is_y_nc = y->format() == memory::format::nc;
      if (!are_inputs_in_same_format) {
        using platform::MKLDNNDeviceContext;
        auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
        const auto& mkldnn_engine = dev_ctx.GetEngine();
        // Reorder any blocked input to a plain layout (nchw or nc) so that
        // both inputs end up in formats the naive version understands.
        if (!(is_x_nchw || is_x_nc))
          ReorderInput<T>((Tensor*)x, ctx.GetPlace(), mkldnn_engine,
                          x->dims().size() == 4);
        if (!(is_y_nchw || is_y_nc))
          ReorderInput<T>((Tensor*)y, ctx.GetPlace(), mkldnn_engine,
                          y->dims().size() == 4);
      }
      auto mul_func = [](T a, T b) -> T { return a * b; };

      TransformFunctor<decltype(mul_func), T,
......
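For context: TransformFunctor comes from Paddle's elementwise utilities and applies the binary functor across the inputs. With equal shapes, the fallback effectively reduces to the loop below (a minimal sketch with assumed raw pointers, not code from this commit):

    // Sketch only: the naive elementwise multiply the fallback boils down to.
    for (int64_t i = 0; i < x->numel(); ++i) {
      z_data[i] = x_data[i] * y_data[i];
    }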
......@@ -49,7 +49,7 @@ class TestElementwiseMulMKLDNNOp_BroadcastNCHW16c(ElementwiseMulOp):
    def test_check_grad_ingore_y(self):
        pass

@unittest.skip("Not implemented yet.")
@unittest.skip("Not implemented yet.")  # TODO(mgallus): enable when implemented.
class TestElementwiseMulMKLDNNOp_BroadcastNCHW8c(ElementwiseMulOp):
    def init_input_output(self):
        x = np.random.rand(1, 8, 2, 2).astype(self.dtype)
......@@ -159,8 +159,7 @@ class TestElementwiseMulMKLDNNOp_FallbackNoReorders(ElementwiseMulOp):
    def test_check_grad_ingore_y(self):
        pass

@unittest.skip("Not implemented yet.")
class TestElementwiseMulMKLDNNOp_FallbackWithReorder(ElementwiseMulOp):
class TestElementwiseMulMKLDNNOp_FallbackWithReorder1(ElementwiseMulOp):
    def init_input_output(self):
        self.x = np.random.rand(1, 16, 2, 2).astype(self.dtype)
        y = np.random.rand(1, 16, 2, 2).astype(self.dtype)
......@@ -169,7 +168,7 @@ class TestElementwiseMulMKLDNNOp_FallbackWithReorder(ElementwiseMulOp):
        self.out = self.x * y

    def setUp(self):
        super(TestElementwiseMulMKLDNNOp_FallbackNCHW16C, self).setUp()
        super(TestElementwiseMulMKLDNNOp_FallbackWithReorder1, self).setUp()
        self.attrs["x_data_format"] = "nchw"
        self.attrs["y_data_format"] = "nchw16c"
......@@ -188,5 +187,60 @@ class TestElementwiseMulMKLDNNOp_FallbackWithReorder(ElementwiseMulOp):
    def test_check_grad_ingore_y(self):
        pass

class TestElementwiseMulMKLDNNOp_FallbackWithReorder2(ElementwiseMulOp):
    def init_input_output(self):
        self.y = np.random.rand(1, 16, 2, 2).astype(self.dtype)
        x = np.random.rand(1, 16, 2, 2).astype(self.dtype)
        # The transpose + reshape scrambles the element order to emulate a
        # blocked (nchw16c) storage of the same buffer.
        self.x = x.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2)
        self.out = x * self.y

    def setUp(self):
        super(TestElementwiseMulMKLDNNOp_FallbackWithReorder2, self).setUp()
        self.attrs["x_data_format"] = "nchw16c"
        self.attrs["y_data_format"] = "nchw"

    def init_kernel_type(self):
        self.use_mkldnn = True

    def init_axis(self):
        self.axis = 0

    def test_check_grad_normal(self):
        pass

    def test_check_grad_ingore_x(self):
        pass

    def test_check_grad_ingore_y(self):
        pass
class TestElementwiseMulMKLDNNOp_FallbackNoReorders2(ElementwiseMulOp):
    def init_input_output(self):
        self.x = np.random.rand(1, 16).astype(self.dtype)
        self.y = np.random.rand(1, 16).astype(self.dtype)
        self.out = self.x * self.y

    def setUp(self):
        super(TestElementwiseMulMKLDNNOp_FallbackNoReorders2, self).setUp()
        self.attrs["x_data_format"] = "nc"
        self.attrs["y_data_format"] = "nc"

    def init_kernel_type(self):
        self.use_mkldnn = True

    def init_axis(self):
        self.axis = 0

    def test_check_grad_normal(self):
        pass

    def test_check_grad_ingore_x(self):
        pass

    def test_check_grad_ingore_y(self):
        pass

if __name__ == '__main__':
    unittest.main()