Commit ed31936b authored by Michal Gallus

MKLDNN elementwise_mul: Support NCHW, update UT

Parent: 4e54ab76
@@ -97,6 +97,20 @@ class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker {
        .EqualGreaterThan(-1);
    AddAttr<bool>("use_mkldnn", "(bool, default false). Used by MKLDNN.")
        .SetDefault(false);
    AddAttr<std::string>(
        "x_data_format",
        "(string, default \"\") Only used in MKLDNN. "
        "An optional string from: \"NHWC\", \"NCHW\", \"NCHW16C\", \"NCHW8C\". "
        "Specifies the MKLDNN memory format of input X. ")
        .SetDefault("");
    AddAttr<std::string>(
        "y_data_format",
        "(string, default \"\") Only used in MKLDNN. "
        "An optional string from: \"NHWC\", \"NCHW\", \"NCHW16C\", \"NCHW8C\", \"NC\". "
        "Specifies the MKLDNN memory format of input Y. ")
        .SetDefault("");
    AddComment(string::Sprintf(R"DOC(
Elementwise %s Operator
......
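For context, the two attributes above are set from Python like any other operator attribute. A minimal, hypothetical sketch of the values a caller would pass so the MKLDNN kernel takes its blocked-format fast path (only the attribute names come from the AddAttr calls above; the dict mirrors the unit tests at the end of this diff):

    # Hypothetical attrs for an elementwise_mul op instance.
    attrs = {
        "use_mkldnn": True,          # route to the MKLDNN kernel
        "x_data_format": "nchw16c",  # X stored in 16-channel blocked layout
        "y_data_format": "nc",       # Y is a per-channel (N, C) tensor
    }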
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include <string>
#include <mkldnn/include/mkldnn.hpp>
#include "paddle/fluid/operators/elementwise_op.h" #include "paddle/fluid/operators/elementwise_op.h"
#include "paddle/fluid/operators/elementwise_op_function.h" #include "paddle/fluid/operators/elementwise_op_function.h"
@@ -24,6 +25,7 @@ namespace paddle {
namespace operators {

using framework::DataLayout;
using mkldnn::memory;
struct vector_mul : public Xbyak::CodeGenerator {
  vector_mul() {
@@ -66,6 +68,33 @@ void check(const float* x, const float* y, float* z, int w) {
  }
}
static mkldnn::memory::format StringToMKLDNNFormat(std::string& format) {
  std::transform(format.begin(), format.end(), format.begin(), ::tolower);

  if (!format.compare("nchw")) {
    return memory::format::nchw;
  } else if (!format.compare("nchw16c")) {
    return memory::format::nChw16c;
  } else if (!format.compare("nchw8c")) {
    return memory::format::nChw8c;
  } else if (!format.compare("nhwc")) {
    return memory::format::nhwc;
  } else if (!format.compare("nc")) {
    // "nc" is the value the kernel's fast path and the unit tests pass for Y.
    return memory::format::nc;
  } else {
    return memory::format::any;
  }
}
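To make the blocked format concrete: nChw16c groups channels in blocks of 16, so an NCHW array of shape (N, C, H, W) is stored as (N, C/16, H, W, 16) with the intra-block channel as the fastest-varying index. A NumPy sketch of the mapping (illustrative only, assuming C is a multiple of 16; not part of this commit):

    import numpy as np

    def nchw_to_nchw16c(a):
        # (N, C, H, W) -> (N, C//16, H, W, 16): split channels into blocks
        # of 16 and move the intra-block channel to the innermost position.
        n, c, h, w = a.shape
        return a.reshape(n, c // 16, 16, h, w).transpose(0, 1, 3, 4, 2)

    x = np.random.rand(2, 32, 4, 4).astype(np.float32)
    blocked = nchw_to_nchw16c(x)
    assert blocked.shape == (2, 2, 4, 4, 16)

    # The round trip back to NCHW recovers the original tensor.
    restored = blocked.transpose(0, 1, 4, 2, 3).reshape(2, 32, 4, 4)
    assert np.array_equal(x, restored)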
static void UpdateDataFormat(const framework::ExecutionContext& ctx,
                             framework::Tensor* tensor, const char* attribute) {
  if (ctx.op().HasAttr(attribute)) {
    auto format_as_string = ctx.Attr<std::string>(attribute);
    auto format = StringToMKLDNNFormat(format_as_string);
    if (format != memory::format::any) {
      tensor->set_format(format);
    }
  }
}
template <typename T>
class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
 public:
@@ -83,52 +112,87 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
    auto x_dims = x->dims();
    auto y_dims_untrimmed = y->dims();
    UpdateDataFormat(ctx, (Tensor*)x, "x_data_format");
    UpdateDataFormat(ctx, (Tensor*)y, "y_data_format");

    // Fast path: X in 16-channel blocked layout, Y broadcast per channel.
    if (x->format() == memory::format::nChw16c &&
        y->format() == memory::format::nc) {
      if (x_dims != y_dims_untrimmed) {
        int pre, n, post;
        get_mid_dims(x_dims, y_dims_untrimmed, axis, &pre, &n, &post);

        if (post == 1) {
          PADDLE_THROW("Not implemented when post is 1");
        } else {
          // Just check whether it works for RE-Resnext.
          PADDLE_ENFORCE_EQ(x_dims.size(), 4, "X should have 4 dimensions");

          int n = x_dims[0];
          int c = x_dims[1];
          int h = x_dims[2];
          int w = x_dims[3];

          PADDLE_ENFORCE(y_dims_untrimmed[0] == n && y_dims_untrimmed[1] == c,
                         "Y should be in nc format");

          constexpr int simd_width = 16;
          int C = c / simd_width;

          vector_mul mul;

          using mul_func_t =
              void (*)(const float*, const float*, float*, int, int);

          mul_func_t mul_func = (mul_func_t)mul.getCode();

          // Multiply each h*w*16 block of X by the matching 16 values of Y.
          for (int ni = 0; ni < n; ni++) {
            for (int ci = 0; ci < C; ci++) {
              auto ptr_x = x_data + ni * C * h * w * simd_width +
                           ci * h * w * simd_width;
              auto ptr_y = y_data + ni * C * simd_width + ci * simd_width;
              auto ptr_z = z_data + ni * C * h * w * simd_width +
                           ci * h * w * simd_width;

              mul_func(ptr_x, ptr_y, ptr_z, h, w);
            }
          }
        }

        z->set_layout(DataLayout::kMKLDNN);
        z->set_format(x->format());
      } else {
        PADDLE_THROW("Not implemented when dims are equal");
      }
    } else {
      // Fallback to naive version:
      auto mul_func = [](T a, T b) -> T { return a * b; };

      TransformFunctor<decltype(mul_func), T,
                       paddle::platform::CPUDeviceContext, T>
          functor(
              x, y, z,
              ctx.template device_context<paddle::platform::CPUDeviceContext>(),
              mul_func);

      axis = (axis == -1 ? x_dims.size() - y_dims_untrimmed.size() : axis);
      PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
                     "Axis should be in range [0, x_dims)");

      auto y_dims = trim_trailing_singular_dims(y_dims_untrimmed);
      axis = (y_dims.size() == 0) ? x_dims.size() : axis;

      int pre, n, post;
      get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post);

      if (post == 1) {
        functor.RunRowWise(n, pre);
      } else {
        functor.RunMidWise(n, pre, post);
      }

      z->set_layout(DataLayout::kMKLDNN);
      z->set_format(x->format());
    }
  }
};
......
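The fast path above walks the blocked buffer one (batch, channel-block) pair at a time and multiplies each h*w*16 slab of X by the matching 16 values of Y. A NumPy model of that indexing, useful for sanity-checking the pointer arithmetic (a sketch, not committed code):

    import numpy as np

    n, C, h, w, simd_width = 2, 2, 4, 4, 16   # C = channels / simd_width
    x = np.random.rand(n, C, h, w, simd_width).astype(np.float32)  # nChw16c
    y = np.random.rand(n, C * simd_width).astype(np.float32)       # nc

    # Blocked-path equivalent of the kernel's per-(ni, ci) loop: broadcast
    # the 16 channel values of each block over all h*w spatial positions.
    z = x * y.reshape(n, C, simd_width)[:, :, None, None, :]

    # Reference: the same multiply done in plain NCHW layout.
    x_nchw = x.transpose(0, 1, 4, 2, 3).reshape(n, C * simd_width, h, w)
    ref = x_nchw * y.reshape(n, C * simd_width, 1, 1)
    z_nchw = z.transpose(0, 1, 4, 2, 3).reshape(n, C * simd_width, h, w)
    assert np.allclose(ref, z_nchw)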
@@ -20,8 +20,7 @@ import paddle.fluid.core as core
from paddle.fluid.op import Operator
from test_elementwise_mul_op import *
class TestElementwiseMulMKLDNNOp_BroadcastNCHW16c(ElementwiseMulOp):
    def init_input_output(self):
        x = np.random.rand(1, 16, 2, 2).astype(self.dtype)
        # Store X in nChw16c blocked order while keeping the NCHW shape.
        self.x = x.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2)
@@ -30,6 +29,11 @@ class ElementwiseMulMKLDNNOp(ElementwiseMulOp):
        self.out = x * self.y.reshape(1, 16, 1, 1)
        self.out = self.out.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2)
    def setUp(self):
        super(TestElementwiseMulMKLDNNOp_BroadcastNCHW16c, self).setUp()
        self.attrs["x_data_format"] = "nchw16c"
        self.attrs["y_data_format"] = "nc"
    def init_kernel_type(self):
        self.use_mkldnn = True
@@ -45,6 +49,27 @@ class ElementwiseMulMKLDNNOp(ElementwiseMulOp):
    def test_check_grad_ingore_y(self):
        pass
class TestElementwiseMulMKLDNNOp_UnsupportedFormat(ElementwiseMulOp):
    def init_input_output(self):
        self.x = np.random.rand(1, 16, 2, 2).astype(self.dtype)
        self.y = np.random.rand(1, 16).astype(self.dtype)
        self.out = self.x * self.y.reshape(1, 16, 1, 1)

    def init_kernel_type(self):
        self.use_mkldnn = True

    def init_axis(self):
        self.axis = 0

    def test_check_grad_normal(self):
        pass

    def test_check_grad_ingore_x(self):
        pass

    def test_check_grad_ingore_y(self):
        pass
if __name__ == '__main__':
    unittest.main()