diff --git a/paddle/fluid/operators/elementwise/elementwise_op.h b/paddle/fluid/operators/elementwise/elementwise_op.h
index f01f67692e1e5dd040971cb0dd1dd793648da97a..16d919689cc2916fd491e81dac05462cc67053d0 100644
--- a/paddle/fluid/operators/elementwise/elementwise_op.h
+++ b/paddle/fluid/operators/elementwise/elementwise_op.h
@@ -97,6 +97,20 @@ class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker {
         .EqualGreaterThan(-1);
     AddAttr<bool>("use_mkldnn", "(bool, default false). Used by MKLDNN.")
         .SetDefault(false);
+    AddAttr<std::string>(
+        "x_data_format",
+        "(string, default \"\") Only used in MKLDNN. "
+        "An optional string from: \"NHWC\", \"NCHW\", \"NCHW16C\", \"NCHW8C\". "
+        "Specifies the data format of input X; "
+        "the input will be transformed automatically. ")
+        .SetDefault("");
+    AddAttr<std::string>(
+        "y_data_format",
+        "(string, default \"\") Only used in MKLDNN. "
+        "An optional string from: \"NHWC\", \"NCHW\", \"NCHW16C\", \"NCHW8C\". "
+        "Specifies the data format of input Y; "
+        "the input will be transformed automatically. ")
+        .SetDefault("");
     AddComment(string::Sprintf(R"DOC(
 Elementwise %s Operator
 
diff --git a/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc b/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc
index 21716e271d36365b9759c1616ec3b4c0e09a3cba..d66c58bd450737fc9bc3ce0788f9c3e3e9c126ba 100644
--- a/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc
+++ b/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
+#include <string>
 #include "paddle/fluid/operators/elementwise_op.h"
 #include "paddle/fluid/operators/elementwise_op_function.h"
 
@@ -24,6 +25,7 @@ namespace paddle {
 namespace operators {
 
 using framework::DataLayout;
+using mkldnn::memory;
 
 struct vector_mul : public Xbyak::CodeGenerator {
   vector_mul() {
@@ -66,6 +68,33 @@ void check(const float* x, const float* y, float* z, int w) {
   }
 }
 
+static mkldnn::memory::format StringToMKLDNNFormat(std::string& format) {
+  std::transform(format.begin(), format.end(), format.begin(), ::tolower);
+
+  if (!format.compare("nchw")) {
+    return memory::format::nchw;
+  } else if (!format.compare("nchw16c")) {
+    return memory::format::nChw16c;
+  } else if (!format.compare("nchw8c")) {
+    return memory::format::nChw8c;
+  } else if (!format.compare("nhwc")) {
+    return memory::format::nhwc;
+  } else {
+    return memory::format::any;
+  }
+}
+
+static void UpdateDataFormat(const framework::ExecutionContext& ctx,
+                             framework::Tensor* tensor, const char* attribute) {
+  if (ctx.op().HasAttr(attribute)) {
+    auto format_as_string = ctx.Attr<std::string>(attribute);
+    auto format = StringToMKLDNNFormat(format_as_string);
+    if (format != memory::format::any) {
+      tensor->set_format(format);
+    }
+  }
+}
+
 template <typename T>
 class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
  public:
@@ -83,52 +112,87 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
     auto x_dims = x->dims();
     auto y_dims_untrimmed = y->dims();
 
-    if (x_dims != y_dims_untrimmed) {
-      int pre, n, post;
-      get_mid_dims(x_dims, y_dims_untrimmed, axis, &pre, &n, &post);
+    UpdateDataFormat(ctx, (Tensor*)x, "x_data_format");
+    UpdateDataFormat(ctx, (Tensor*)y, "y_data_format");
 
-      if (post == 1) {
-        PADDLE_THROW("Not implemented when post is 1");
-      } else {
-        // Just check whether it works for RE-Resnext.
-        PADDLE_ENFORCE_EQ(x_dims.size(), 4, "X should have 4 dimensions");
+    if (x->format() == memory::format::nChw16c &&
+        y->format() == memory::format::nc) {
+      if (x_dims != y_dims_untrimmed) {
+        int pre, n, post;
+        get_mid_dims(x_dims, y_dims_untrimmed, axis, &pre, &n, &post);
+
+        if (post == 1) {
+          PADDLE_THROW("Not implemented when post is 1");
+        } else {
+          // Just check whether it works for RE-Resnext.
+          PADDLE_ENFORCE_EQ(x_dims.size(), 4, "X should have 4 dimensions");
 
-        int n = x_dims[0];
-        int c = x_dims[1];
-        int h = x_dims[2];
-        int w = x_dims[3];
+          int n = x_dims[0];
+          int c = x_dims[1];
+          int h = x_dims[2];
+          int w = x_dims[3];
 
-        PADDLE_ENFORCE(y_dims_untrimmed[0] == n && y_dims_untrimmed[1] == c,
-                       "Y should be in nc format");
+          PADDLE_ENFORCE(y_dims_untrimmed[0] == n && y_dims_untrimmed[1] == c,
+                         "Y should be in nc format");
 
-        constexpr int simd_width = 16;
-        int C = c / simd_width;
+          constexpr int simd_width = 16;
+          int C = c / simd_width;
 
-        vector_mul mul;
+          vector_mul mul;
 
-        using mul_func_t =
-            void (*)(const float*, const float*, float*, int, int);
+          using mul_func_t =
+              void (*)(const float*, const float*, float*, int, int);
 
-        mul_func_t mul_func = (mul_func_t)mul.getCode();
+          mul_func_t mul_func = (mul_func_t)mul.getCode();
 
-        for (int ni = 0; ni < n; ni++) {
-          for (int ci = 0; ci < C; ci++) {
-            auto ptr_x =
-                x_data + ni * C * h * w * simd_width + ci * h * w * simd_width;
+          for (int ni = 0; ni < n; ni++) {
+            for (int ci = 0; ci < C; ci++) {
+              auto ptr_x =
+                  x_data + ni * C * h * w * simd_width +
+                  ci * h * w * simd_width;
 
-            auto ptr_y = y_data + ni * C * simd_width + ci * simd_width;
-            auto ptr_z =
-                z_data + ni * C * h * w * simd_width + ci * h * w * simd_width;
+              auto ptr_y = y_data + ni * C * simd_width + ci * simd_width;
+              auto ptr_z =
+                  z_data + ni * C * h * w * simd_width +
+                  ci * h * w * simd_width;
 
-            mul_func(ptr_x, ptr_y, ptr_z, h, w);
+              mul_func(ptr_x, ptr_y, ptr_z, h, w);
+            }
           }
         }
+
+        z->set_layout(DataLayout::kMKLDNN);
+        z->set_format(x->format());
+      } else {
+        PADDLE_THROW("Not implemented when dims are equal");
       }
+    } else {
+      // Fallback to naive version:
+      auto mul_func = [](T a, T b) -> T { return a * b; };
+
+      TransformFunctor<decltype(mul_func), T, platform::CPUDeviceContext, T>
+          functor(
+              x, y, z,
+              ctx.template device_context<platform::CPUDeviceContext>(),
+              mul_func);
+      axis = (axis == -1 ? x_dims.size() - y_dims_untrimmed.size() : axis);
+      PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
+                     "Axis should be in range [0, x_dims)");
+
+      auto y_dims = trim_trailing_singular_dims(y_dims_untrimmed);
+      axis = (y_dims.size() == 0) ? x_dims.size() : axis;
+
+      int pre, n, post;
+      get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post);
+
+      if (post == 1) {
+        functor.RunRowWise(n, pre);
+      } else {
+        functor.RunMidWise(n, pre, post);
+      }
 
       z->set_layout(DataLayout::kMKLDNN);
       z->set_format(x->format());
-    } else {
-      PADDLE_THROW("Not implemented when dims are equal");
     }
   }
 };
diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_mul_mkldnn_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_mul_mkldnn_op.py
index a0581d16de12fc1084960e28b7b120fa6dd1b733..a89f439664d11c2fb3c5b16639848d7ba4b545d6 100644
--- a/python/paddle/fluid/tests/unittests/test_elementwise_mul_mkldnn_op.py
+++ b/python/paddle/fluid/tests/unittests/test_elementwise_mul_mkldnn_op.py
@@ -20,8 +20,7 @@ import paddle.fluid.core as core
 from paddle.fluid.op import Operator
 from test_elementwise_mul_op import *
 
-
-class ElementwiseMulMKLDNNOp(ElementwiseMulOp):
+class TestElementwiseMulMKLDNNOp_BroadcastNCHW16c(ElementwiseMulOp):
     def init_input_output(self):
         x = np.random.rand(1, 16, 2, 2).astype(self.dtype)
         self.x = x.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2)
@@ -30,6 +29,11 @@ class ElementwiseMulMKLDNNOp(ElementwiseMulOp):
         self.out = x * self.y.reshape(1, 16, 1, 1)
         self.out = self.out.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2)
 
+    def setUp(self):
+        super(TestElementwiseMulMKLDNNOp_BroadcastNCHW16c, self).setUp()
+        self.attrs["x_data_format"] = "nchw16c"
+        self.attrs["y_data_format"] = "nc"
+
     def init_kernel_type(self):
         self.use_mkldnn = True
 
@@ -45,6 +49,27 @@ class ElementwiseMulMKLDNNOp(ElementwiseMulOp):
     def test_check_grad_ingore_y(self):
         pass
 
+class TestElementwiseMulMKLDNNOp_UnsupportedFormat(ElementwiseMulOp):
+    def init_input_output(self):
+        self.x = np.random.rand(1, 16, 2, 2).astype(self.dtype)
+        self.y = np.random.rand(1, 16).astype(self.dtype)
+
+        self.out = self.x * self.y.reshape(1, 16, 1, 1)
+
+    def init_kernel_type(self):
+        self.use_mkldnn = True
+
+    def init_axis(self):
+        self.axis = 0
+
+    def test_check_grad_normal(self):
+        pass
+
+    def test_check_grad_ingore_x(self):
+        pass
+
+    def test_check_grad_ingore_y(self):
+        pass
 
 if __name__ == '__main__':
     unittest.main()
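Note on the blocked layout used by TestElementwiseMulMKLDNNOp_BroadcastNCHW16c: the test generates X in plain NCHW and then applies transpose(0, 2, 3, 1).reshape(1, 16, 2, 2) so that the array's memory order matches MKL-DNN's nChw16c blocking while the declared dims stay NCHW (with exactly 16 channels there is a single channel block, so the blocked order coincides with NHWC). Below is a minimal NumPy sketch of the same repacking for an arbitrary multiple-of-16 channel count; the helper name pack_nchw16c is illustrative only and not part of the patch or of Paddle.

    import numpy as np

    def pack_nchw16c(t):
        """Rearrange an NCHW array into nChw16c memory order, keeping NCHW dims."""
        n, c, h, w = t.shape
        assert c % 16 == 0, "nChw16c requires a channel count divisible by 16"
        # Split channels into blocks of 16 and move the 16-wide block innermost:
        # [n, c/16, 16, h, w] -> [n, c/16, h, w, 16]
        blocked = t.reshape(n, c // 16, 16, h, w).transpose(0, 1, 3, 4, 2)
        # Lay the data out contiguously and view it with the logical NCHW shape,
        # mirroring what the unit test does with transpose(...).reshape(...).
        return np.ascontiguousarray(blocked).reshape(n, c, h, w)

    x_nchw = np.random.rand(1, 32, 4, 4).astype(np.float32)
    x_blocked = pack_nchw16c(x_nchw)  # memory order matches memory::format::nChw16c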