提交 08f63c4d 编写于 作者: M Michal Gallus

MKLDNN elementwise_mul: Lint changes to UT & integration

test=develop
上级 73b7cd04
...@@ -98,19 +98,19 @@ class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -98,19 +98,19 @@ class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<bool>("use_mkldnn", "(bool, default false). Used by MKLDNN.") AddAttr<bool>("use_mkldnn", "(bool, default false). Used by MKLDNN.")
.SetDefault(false); .SetDefault(false);
AddAttr<std::string>( AddAttr<std::string>(
"x_data_format", "x_data_format",
"(string, default NCHW) Only used in mkldnn" "(string, default NCHW) Only used in mkldnn"
"An optional string from: \"NHWC\", \"NCHW\", \"NCHW16C\", \"NCHW8C\". " "An optional string from: \"NHWC\", \"NCHW\", \"NCHW16C\", \"NCHW8C\". "
"Defaults to \"\". Specify the data format of the output data, " "Defaults to \"\". Specify the data format of the output data, "
"the input will be transformed automatically. ") "the input will be transformed automatically. ")
.SetDefault(""); .SetDefault("");
AddAttr<std::string>( AddAttr<std::string>(
"y_data_format", "y_data_format",
"(string, default \"\") Only used in mkldnn" "(string, default \"\") Only used in mkldnn"
"An optional string from: \"NHWC\", \"NCHW\", \"NCHW16C\", \"NCHW8C\". " "An optional string from: \"NHWC\", \"NCHW\", \"NCHW16C\", \"NCHW8C\". "
"Defaults to \"\". Specify the data format of the output data, " "Defaults to \"\". Specify the data format of the output data, "
"the input will be transformed automatically. ") "the input will be transformed automatically. ")
.SetDefault(""); .SetDefault("");
AddComment(string::Sprintf(R"DOC( AddComment(string::Sprintf(R"DOC(
Elementwise %s Operator Elementwise %s Operator
......
...@@ -71,13 +71,13 @@ void check(const float* x, const float* y, float* z, int w) { ...@@ -71,13 +71,13 @@ void check(const float* x, const float* y, float* z, int w) {
static mkldnn::memory::format StringToMKLDNNFormat(std::string& format) { static mkldnn::memory::format StringToMKLDNNFormat(std::string& format) {
std::transform(format.begin(), format.end(), format.begin(), ::tolower); std::transform(format.begin(), format.end(), format.begin(), ::tolower);
if(!format.compare("nchw")) { if (!format.compare("nchw")) {
return memory::format::nchw; return memory::format::nchw;
} else if(!format.compare("nchw16c")) { } else if (!format.compare("nchw16c")) {
return memory::format::nChw16c; return memory::format::nChw16c;
} else if(!format.compare("nchw8c")) { } else if (!format.compare("nchw8c")) {
return memory::format::nChw8c; return memory::format::nChw8c;
} else if(!format.compare("nhwc")) { } else if (!format.compare("nhwc")) {
return memory::format::nhwc; return memory::format::nhwc;
} else { } else {
return memory::format::any; return memory::format::any;
...@@ -85,8 +85,8 @@ static mkldnn::memory::format StringToMKLDNNFormat(std::string& format) { ...@@ -85,8 +85,8 @@ static mkldnn::memory::format StringToMKLDNNFormat(std::string& format) {
} }
static void UpdateDataFormat(const framework::ExecutionContext& ctx, static void UpdateDataFormat(const framework::ExecutionContext& ctx,
framework::Tensor* tensor, const char* attribute) { framework::Tensor* tensor, const char* attribute) {
if(ctx.op().HasAttr(attribute)) { if (ctx.op().HasAttr(attribute)) {
auto format_as_string = ctx.Attr<std::string>(attribute); auto format_as_string = ctx.Attr<std::string>(attribute);
auto format = StringToMKLDNNFormat(format_as_string); auto format = StringToMKLDNNFormat(format_as_string);
if (format != memory::format::any) { if (format != memory::format::any) {
...@@ -98,19 +98,19 @@ static void UpdateDataFormat(const framework::ExecutionContext& ctx, ...@@ -98,19 +98,19 @@ static void UpdateDataFormat(const framework::ExecutionContext& ctx,
template <typename T> template <typename T>
static void ReorderInput(framework::Tensor* tensor, static void ReorderInput(framework::Tensor* tensor,
const platform::Place& place, const platform::Place& place,
const mkldnn::engine& engine, const mkldnn::engine& engine, bool isFourDim) {
bool isFourDim) {
using platform::to_void_cast; using platform::to_void_cast;
auto dims = paddle::framework::vectorize2int(tensor->dims()); auto dims = paddle::framework::vectorize2int(tensor->dims());
framework::Tensor out_tensor; framework::Tensor out_tensor;
out_tensor.Resize(tensor->dims()); out_tensor.Resize(tensor->dims());
out_tensor.set_format(isFourDim ? memory::format::nchw : memory::format::nc); out_tensor.set_format(isFourDim ? memory::format::nchw : memory::format::nc);
out_tensor.set_layout(tensor->layout()); out_tensor.set_layout(tensor->layout());
mkldnn::memory input_memory = {{{dims, platform::MKLDNNGetDataType<T>(), mkldnn::memory input_memory = {
tensor->format()}, engine}, to_void_cast<T>(tensor->data<T>())}; {{dims, platform::MKLDNNGetDataType<T>(), tensor->format()}, engine},
mkldnn::memory output_memory = {{{dims, platform::MKLDNNGetDataType<T>(), to_void_cast<T>(tensor->data<T>())};
out_tensor.format()}, engine}, mkldnn::memory output_memory = {
to_void_cast<T>(out_tensor.mutable_data<T>(place))}; {{dims, platform::MKLDNNGetDataType<T>(), out_tensor.format()}, engine},
to_void_cast<T>(out_tensor.mutable_data<T>(place))};
platform::Reorder(input_memory, output_memory); platform::Reorder(input_memory, output_memory);
tensor->ShareDataWith(out_tensor); tensor->ShareDataWith(out_tensor);
} }
...@@ -163,21 +163,19 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> { ...@@ -163,21 +163,19 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
vector_mul mul; vector_mul mul;
using mul_func_t = using mul_func_t =
void (*)(const float *, const float *, float *, int, int); void (*)(const float*, const float*, float*, int, int);
mul_func_t mul_func = (mul_func_t) mul.getCode(); mul_func_t mul_func = (mul_func_t)mul.getCode();
#pragma omp parallel for collapse(2) #pragma omp parallel for collapse(2)
for (int ni = 0; ni < n; ni++) { for (int ni = 0; ni < n; ni++) {
for (int ci = 0; ci < C; ci++) { for (int ci = 0; ci < C; ci++) {
auto ptr_x = auto ptr_x =
x_data + ni * C * h * w * simd_width + x_data + ni * C * h * w * simd_width + ci * h * w * simd_width;
ci * h * w * simd_width;
auto ptr_y = y_data + ni * C * simd_width + ci * simd_width; auto ptr_y = y_data + ni * C * simd_width + ci * simd_width;
auto ptr_z = auto ptr_z =
z_data + ni * C * h * w * simd_width + z_data + ni * C * h * w * simd_width + ci * h * w * simd_width;
ci * h * w * simd_width;
mul_func(ptr_x, ptr_y, ptr_z, h, w); mul_func(ptr_x, ptr_y, ptr_z, h, w);
} }
...@@ -189,18 +187,20 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> { ...@@ -189,18 +187,20 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
} else { } else {
// Fallback to naive version: // Fallback to naive version:
const bool are_inputs_in_same_format = x->format() == y->format(); const bool are_inputs_in_same_format = x->format() == y->format();
const bool is_x_nchw= x->format() == memory::format::nchw; const bool is_x_nchw = x->format() == memory::format::nchw;
const bool is_x_nc = x->format() == memory::format::nc; const bool is_x_nc = x->format() == memory::format::nc;
const bool is_y_nchw= y->format() == memory::format::nchw; const bool is_y_nchw = y->format() == memory::format::nchw;
const bool is_y_nc = y->format() == memory::format::nc; const bool is_y_nc = y->format() == memory::format::nc;
if(!are_inputs_in_same_format) { if (!are_inputs_in_same_format) {
using platform::MKLDNNDeviceContext; using platform::MKLDNNDeviceContext;
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>(); auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine(); const auto& mkldnn_engine = dev_ctx.GetEngine();
if(!(is_x_nchw || is_x_nc)) if (!(is_x_nchw || is_x_nc))
ReorderInput<T>((Tensor*)x, ctx.GetPlace(), mkldnn_engine, x->dims().size() == 4); ReorderInput<T>((Tensor*)x, ctx.GetPlace(), mkldnn_engine,
if(!(is_y_nchw || is_y_nc)) x->dims().size() == 4);
ReorderInput<T>((Tensor*)y, ctx.GetPlace(), mkldnn_engine, y->dims().size() == 4); if (!(is_y_nchw || is_y_nc))
ReorderInput<T>((Tensor*)y, ctx.GetPlace(), mkldnn_engine,
y->dims().size() == 4);
} }
auto mul_func = [](T a, T b) -> T { return a * b; }; auto mul_func = [](T a, T b) -> T { return a * b; };
......
...@@ -20,6 +20,7 @@ import paddle.fluid.core as core ...@@ -20,6 +20,7 @@ import paddle.fluid.core as core
from paddle.fluid.op import Operator from paddle.fluid.op import Operator
from test_elementwise_mul_op import * from test_elementwise_mul_op import *
class TestElementwiseMulMKLDNNOp_BroadcastNCHW16c(ElementwiseMulOp): class TestElementwiseMulMKLDNNOp_BroadcastNCHW16c(ElementwiseMulOp):
def init_input_output(self): def init_input_output(self):
x = np.random.rand(1, 16, 2, 2).astype(self.dtype) x = np.random.rand(1, 16, 2, 2).astype(self.dtype)
...@@ -49,7 +50,9 @@ class TestElementwiseMulMKLDNNOp_BroadcastNCHW16c(ElementwiseMulOp): ...@@ -49,7 +50,9 @@ class TestElementwiseMulMKLDNNOp_BroadcastNCHW16c(ElementwiseMulOp):
def test_check_grad_ingore_y(self): def test_check_grad_ingore_y(self):
pass pass
@unittest.skip("Not implemented yet.") # TODO(mgallus): enable when implemented.
@unittest.skip(
"Not implemented yet.") # TODO(mgallus): enable when implemented.
class TestElementwiseMulMKLDNNOp_BroadcastNCHW8c(ElementwiseMulOp): class TestElementwiseMulMKLDNNOp_BroadcastNCHW8c(ElementwiseMulOp):
def init_input_output(self): def init_input_output(self):
x = np.random.rand(1, 8, 2, 2).astype(self.dtype) x = np.random.rand(1, 8, 2, 2).astype(self.dtype)
...@@ -79,6 +82,7 @@ class TestElementwiseMulMKLDNNOp_BroadcastNCHW8c(ElementwiseMulOp): ...@@ -79,6 +82,7 @@ class TestElementwiseMulMKLDNNOp_BroadcastNCHW8c(ElementwiseMulOp):
def test_check_grad_ingore_y(self): def test_check_grad_ingore_y(self):
pass pass
class TestElementwiseMulMKLDNNOp_FallbackNCHW(ElementwiseMulOp): class TestElementwiseMulMKLDNNOp_FallbackNCHW(ElementwiseMulOp):
def init_input_output(self): def init_input_output(self):
self.x = np.random.rand(1, 16, 2, 2).astype(self.dtype) self.x = np.random.rand(1, 16, 2, 2).astype(self.dtype)
...@@ -101,6 +105,7 @@ class TestElementwiseMulMKLDNNOp_FallbackNCHW(ElementwiseMulOp): ...@@ -101,6 +105,7 @@ class TestElementwiseMulMKLDNNOp_FallbackNCHW(ElementwiseMulOp):
def test_check_grad_ingore_y(self): def test_check_grad_ingore_y(self):
pass pass
class TestElementwiseMulMKLDNNOp_FallbackNCHW16C(ElementwiseMulOp): class TestElementwiseMulMKLDNNOp_FallbackNCHW16C(ElementwiseMulOp):
def init_input_output(self): def init_input_output(self):
x = np.random.rand(1, 16, 2, 2).astype(self.dtype) x = np.random.rand(1, 16, 2, 2).astype(self.dtype)
...@@ -130,6 +135,7 @@ class TestElementwiseMulMKLDNNOp_FallbackNCHW16C(ElementwiseMulOp): ...@@ -130,6 +135,7 @@ class TestElementwiseMulMKLDNNOp_FallbackNCHW16C(ElementwiseMulOp):
def test_check_grad_ingore_y(self): def test_check_grad_ingore_y(self):
pass pass
class TestElementwiseMulMKLDNNOp_FallbackNoReorders(ElementwiseMulOp): class TestElementwiseMulMKLDNNOp_FallbackNoReorders(ElementwiseMulOp):
def init_input_output(self): def init_input_output(self):
x = np.random.rand(1, 16, 2, 2).astype(self.dtype) x = np.random.rand(1, 16, 2, 2).astype(self.dtype)
...@@ -159,6 +165,7 @@ class TestElementwiseMulMKLDNNOp_FallbackNoReorders(ElementwiseMulOp): ...@@ -159,6 +165,7 @@ class TestElementwiseMulMKLDNNOp_FallbackNoReorders(ElementwiseMulOp):
def test_check_grad_ingore_y(self): def test_check_grad_ingore_y(self):
pass pass
class TestElementwiseMulMKLDNNOp_FallbackWithReorder1(ElementwiseMulOp): class TestElementwiseMulMKLDNNOp_FallbackWithReorder1(ElementwiseMulOp):
def init_input_output(self): def init_input_output(self):
self.x = np.random.rand(1, 16, 2, 2).astype(self.dtype) self.x = np.random.rand(1, 16, 2, 2).astype(self.dtype)
...@@ -187,6 +194,7 @@ class TestElementwiseMulMKLDNNOp_FallbackWithReorder1(ElementwiseMulOp): ...@@ -187,6 +194,7 @@ class TestElementwiseMulMKLDNNOp_FallbackWithReorder1(ElementwiseMulOp):
def test_check_grad_ingore_y(self): def test_check_grad_ingore_y(self):
pass pass
class TestElementwiseMulMKLDNNOp_FallbackWithReorder2(ElementwiseMulOp): class TestElementwiseMulMKLDNNOp_FallbackWithReorder2(ElementwiseMulOp):
def init_input_output(self): def init_input_output(self):
self.y = np.random.rand(1, 16, 2, 2).astype(self.dtype) self.y = np.random.rand(1, 16, 2, 2).astype(self.dtype)
...@@ -215,6 +223,7 @@ class TestElementwiseMulMKLDNNOp_FallbackWithReorder2(ElementwiseMulOp): ...@@ -215,6 +223,7 @@ class TestElementwiseMulMKLDNNOp_FallbackWithReorder2(ElementwiseMulOp):
def test_check_grad_ingore_y(self): def test_check_grad_ingore_y(self):
pass pass
class TestElementwiseMulMKLDNNOp_FallbackNoReorders2(ElementwiseMulOp): class TestElementwiseMulMKLDNNOp_FallbackNoReorders2(ElementwiseMulOp):
def init_input_output(self): def init_input_output(self):
self.x = np.random.rand(1, 16).astype(self.dtype) self.x = np.random.rand(1, 16).astype(self.dtype)
...@@ -242,5 +251,6 @@ class TestElementwiseMulMKLDNNOp_FallbackNoReorders2(ElementwiseMulOp): ...@@ -242,5 +251,6 @@ class TestElementwiseMulMKLDNNOp_FallbackNoReorders2(ElementwiseMulOp):
def test_check_grad_ingore_y(self): def test_check_grad_ingore_y(self):
pass pass
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册