提交 ed31936b 编写于 作者: M Michal Gallus

MKLDNN elementwise_mul: Support NCHW, update UT

上级 4e54ab76
......@@ -97,6 +97,20 @@ class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker {
.EqualGreaterThan(-1);
AddAttr<bool>("use_mkldnn", "(bool, default false). Used by MKLDNN.")
.SetDefault(false);
AddAttr<std::string>(
"x_data_format",
"(string, default NCHW) Only used in mkldnn"
"An optional string from: \"NHWC\", \"NCHW\", \"NCHW16C\", \"NCHW8C\". "
"Defaults to \"\". Specify the data format of the output data, "
"the input will be transformed automatically. ")
.SetDefault("");
AddAttr<std::string>(
"y_data_format",
"(string, default \"\") Only used in mkldnn"
"An optional string from: \"NHWC\", \"NCHW\", \"NCHW16C\", \"NCHW8C\". "
"Defaults to \"\". Specify the data format of the output data, "
"the input will be transformed automatically. ")
.SetDefault("");
AddComment(string::Sprintf(R"DOC(
Elementwise %s Operator
......
......@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <mkldnn/include/mkldnn.hpp>
#include "paddle/fluid/operators/elementwise_op.h"
#include "paddle/fluid/operators/elementwise_op_function.h"
......@@ -24,6 +25,7 @@ namespace paddle {
namespace operators {
using framework::DataLayout;
using mkldnn::memory;
struct vector_mul : public Xbyak::CodeGenerator {
vector_mul() {
......@@ -66,6 +68,33 @@ void check(const float* x, const float* y, float* z, int w) {
}
}
static mkldnn::memory::format StringToMKLDNNFormat(std::string& format) {
std::transform(format.begin(), format.end(), format.begin(), ::tolower);
if(!format.compare("nchw")) {
return memory::format::nchw;
} else if(!format.compare("nchw16c")) {
return memory::format::nChw16c;
} else if(!format.compare("nchw8c")) {
return memory::format::nChw8c;
} else if(!format.compare("nhwc")) {
return memory::format::nhwc;
} else {
return memory::format::any;
}
}
static void UpdateDataFormat(const framework::ExecutionContext& ctx,
framework::Tensor* tensor, const char* attribute) {
if(ctx.op().HasAttr(attribute)) {
auto format_as_string = ctx.Attr<std::string>(attribute);
auto format = StringToMKLDNNFormat(format_as_string);
if (format != memory::format::any) {
tensor->set_format(format);
}
}
}
template <typename T>
class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
public:
......@@ -83,6 +112,10 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
auto x_dims = x->dims();
auto y_dims_untrimmed = y->dims();
UpdateDataFormat(ctx, (Tensor*)x, "x_data_format");
UpdateDataFormat(ctx, (Tensor*)y, "y_data_format");
if (x->format() == memory::format::nChw16c && y->format() == memory::format::nc) {
if (x_dims != y_dims_untrimmed) {
int pre, n, post;
get_mid_dims(x_dims, y_dims_untrimmed, axis, &pre, &n, &post);
......@@ -107,18 +140,20 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
vector_mul mul;
using mul_func_t =
void (*)(const float*, const float*, float*, int, int);
void (*)(const float *, const float *, float *, int, int);
mul_func_t mul_func = (mul_func_t)mul.getCode();
mul_func_t mul_func = (mul_func_t) mul.getCode();
for (int ni = 0; ni < n; ni++) {
for (int ci = 0; ci < C; ci++) {
auto ptr_x =
x_data + ni * C * h * w * simd_width + ci * h * w * simd_width;
x_data + ni * C * h * w * simd_width +
ci * h * w * simd_width;
auto ptr_y = y_data + ni * C * simd_width + ci * simd_width;
auto ptr_z =
z_data + ni * C * h * w * simd_width + ci * h * w * simd_width;
z_data + ni * C * h * w * simd_width +
ci * h * w * simd_width;
mul_func(ptr_x, ptr_y, ptr_z, h, w);
}
......@@ -130,6 +165,35 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
} else {
PADDLE_THROW("Not implemented when dims are equal");
}
} else {
// Fallback to naive version:
auto mul_func = [](T a, T b) -> T { return a * b; };
TransformFunctor<decltype(mul_func), T,
paddle::platform::CPUDeviceContext, T>
functor(
x, y, z,
ctx.template device_context<paddle::platform::CPUDeviceContext>(),
mul_func);
axis = (axis == -1 ? x_dims.size() - y_dims_untrimmed.size() : axis);
PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
"Axis should be in range [0, x_dims)");
auto y_dims = trim_trailing_singular_dims(y_dims_untrimmed);
axis = (y_dims.size() == 0) ? x_dims.size() : axis;
int pre, n, post;
get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post);
if (post == 1) {
functor.RunRowWise(n, pre);
} else {
functor.RunMidWise(n, pre, post);
}
z->set_layout(DataLayout::kMKLDNN);
z->set_format(x->format());
}
}
};
} // namespace operators
......
......@@ -20,8 +20,7 @@ import paddle.fluid.core as core
from paddle.fluid.op import Operator
from test_elementwise_mul_op import *
class ElementwiseMulMKLDNNOp(ElementwiseMulOp):
class TestElementwiseMulMKLDNNOp_BroadcastNCHW16c(ElementwiseMulOp):
def init_input_output(self):
x = np.random.rand(1, 16, 2, 2).astype(self.dtype)
self.x = x.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2)
......@@ -30,6 +29,11 @@ class ElementwiseMulMKLDNNOp(ElementwiseMulOp):
self.out = x * self.y.reshape(1, 16, 1, 1)
self.out = self.out.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2)
def setUp(self):
super(TestElementwiseMulMKLDNNOp_BroadcastNCHW16c, self).setUp()
self.attrs["x_data_format"] = "nchw16c"
self.attrs["y_data_format"] = "nc"
def init_kernel_type(self):
self.use_mkldnn = True
......@@ -45,6 +49,27 @@ class ElementwiseMulMKLDNNOp(ElementwiseMulOp):
def test_check_grad_ingore_y(self):
pass
class TestElementwiseMulMKLDNNOp_UnsupportedFormat(ElementwiseMulOp):
def init_input_output(self):
self.x = np.random.rand(1, 16, 2, 2).astype(self.dtype)
self.y = np.random.rand(1, 16).astype(self.dtype)
self.out = self.x * self.y.reshape(1, 16, 1, 1)
def init_kernel_type(self):
self.use_mkldnn = True
def init_axis(self):
self.axis = 0
def test_check_grad_normal(self):
pass
def test_check_grad_ingore_x(self):
pass
def test_check_grad_ingore_y(self):
pass
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册