Unverified commit 0abf7560, authored by jakpiase, committed by GitHub

Added workaround for elementwise oneDNN kernel (#47080)

* return proper state

* fix for dims

* fix
Parent 06ef3f04
@@ -78,6 +78,7 @@ class EltwiseMKLDNNKernel : public framework::OpKernel<T> {
scale_x,
scale_y,
scale_o,
+ true,
get_post_ops(ctx));
// oneDNN's binary is optimized for broadcasting y into x, so in other case
@@ -126,7 +127,14 @@ class EltwiseMKLDNNKernel : public framework::OpKernel<T> {
binary_prim->execute(astream, args);
astream.wait();
- z->set_mem_desc(dst_memory->get_desc());
+ if (handler.use_broadcasting_hack == false) {
+   z->set_mem_desc(dst_memory->get_desc());
+ } else {
+   auto dims = dst_memory->get_desc().dims();
+   dims.insert(dims.begin(), x->dims()[0]);
+   dims[1] /= dims[0];
+   z->set_mem_desc(dst_memory->get_desc().reshape(dims));
+ }
}
};
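
For clarity, here is a minimal standalone sketch of the reshape-back arithmetic in the else branch above, with plain std::vector standing in for oneDNN's dims type and hypothetical shapes:

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Suppose x has dims [2, 3, 100, 100] (hypothetical) and the broadcasting
  // hack folded the batch dimension into the second dimension, so dst came
  // out of the primitive as [6, 100, 100].
  const std::vector<std::int64_t> x_dims = {2, 3, 100, 100};
  std::vector<std::int64_t> dims = {6, 100, 100};

  // Re-insert the batch dimension, then divide it back out of the second
  // dimension, mirroring what the kernel does before z->set_mem_desc(...).
  dims.insert(dims.begin(), x_dims[0]);  // [2, 6, 100, 100]
  dims[1] /= dims[0];                    // [2, 3, 100, 100]

  assert((dims == std::vector<std::int64_t>{2, 3, 100, 100}));
  return 0;
}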
@@ -210,7 +218,8 @@ class EltwiseMKLDNNGradKernel : public ElemwiseGradKernel<T> {
dx,
1.0f,
1.0f,
- 1.0f);
+ 1.0f,
+ false);
const auto src_dout_memory = binary_handler.AcquireSrcMemory(dout);
const auto src_y_memory = binary_handler.AcquireSecondSrcMemory(y);
@@ -276,7 +285,8 @@ class EltwiseMKLDNNGradKernel : public ElemwiseGradKernel<T> {
nullptr,
1.0f,
1.0f,
- 1.0f);
+ 1.0f,
+ false);
src_1_memory = binary_handler.AcquireSecondSrcMemory(x);
@@ -291,7 +301,8 @@ class EltwiseMKLDNNGradKernel : public ElemwiseGradKernel<T> {
nullptr,
1.0f,
1.0f,
- 1.0f);
+ 1.0f,
+ false);
post_op_memory = post_op_binary_handler.AcquireSrcMemory(y);
@@ -310,6 +321,7 @@ class EltwiseMKLDNNGradKernel : public ElemwiseGradKernel<T> {
-1.0f,
1.0f,
1.0f,
+ false,
po);
src_1_memory = binary_handler.AcquireSecondSrcMemory(out);
......
@@ -825,6 +825,7 @@ class ReorderOneDNNHandler {
template <typename T>
class BinaryOneDNNHandler : public OneDNNHandlerNoCachingT<T, dnnl::binary> {
public:
+ bool use_broadcasting_hack;
BinaryOneDNNHandler(const dnnl::algorithm algo,
const int axis,
const dnnl::engine engine,
@@ -835,14 +836,16 @@ class BinaryOneDNNHandler : public OneDNNHandlerNoCachingT<T, dnnl::binary> {
float scale_x,
float scale_y,
float scale_out,
+ bool allow_hack,
const dnnl::post_ops& post_ops = dnnl::post_ops{})
: OneDNNHandlerNoCachingT<T, dnnl::binary>(engine, cpu_place) {
+ use_broadcasting_hack = false;
const auto src_x_tz = vectorize(x->dims());
const auto src_y_tz = vectorize(y->dims());
// if the output tensor (z) is nullptr, then we are computing into a
// oneDNN-managed buffer
auto rankdiff = x->dims().size() - y->dims().size();
- const auto dst_tz = (out == nullptr) ? (rankdiff > 0 ? src_x_tz : src_y_tz)
+ auto dst_tz = (out == nullptr) ? (rankdiff > 0 ? src_x_tz : src_y_tz)
: vectorize(out->dims());
auto src0_md = x->mem_desc();
@@ -870,12 +873,48 @@ class BinaryOneDNNHandler : public OneDNNHandlerNoCachingT<T, dnnl::binary> {
}
src0_md = src0_md.reshape(dims0_ex);
}
- const auto dst_md =
-     memory::desc(dst_tz, OneDNNGetDataType<T>(), OneDNNMemoryFormat::any);
auto attributes =
CreateAttributes(algo, scale_x, scale_y, scale_out, post_ops);
+ // Workaround for the U2++ model: the first tensor dimension is folded into
+ // the second to enable oneDNN's optimized broadcasting. The output tensor
+ // is reshaped back to its original dimensions at the end of the kernel,
+ // after the computation
+ if (allow_hack && dst_tz.size() == 4 &&
+     src0_md.dims()[2] != src1_md.dims()[2]) {
+   auto are_strides_plain = [](int64_t* strides, int ndims) {
+     for (int i = 0; i < ndims - 1; ++i) {
+       if (strides[i] < strides[i + 1]) {
+         return false;
+       }
+     }
+     return true;
+   };
+   auto src0_strides = src0_md.data.format_desc.blocking.strides;
+   auto src1_strides = src1_md.data.format_desc.blocking.strides;
+   auto src0_dims = src0_md.dims();
+   auto src1_dims = src1_md.dims();
+   bool can_squeeze = src0_dims[0] == src1_dims[0] &&
+                      src0_dims[1] == src1_dims[1] &&
+                      src0_dims[3] == src1_dims[3];
+   if (can_squeeze && are_strides_plain(src0_strides, 4) &&
+       are_strides_plain(src1_strides, 4)) {
+     src0_dims[1] *= dst_tz[0];
+     src1_dims[1] *= dst_tz[0];
+     dst_tz[1] *= dst_tz[0];
+     dst_tz.erase(dst_tz.begin());
+     src0_md = src0_md.reshape({src0_dims.begin() + 1, src0_dims.end()});
+     src1_md = src1_md.reshape({src1_dims.begin() + 1, src1_dims.end()});
+     use_broadcasting_hack = true;
+   }
+ }
+ auto dst_md =
+     memory::desc(dst_tz, OneDNNGetDataType<T>(), OneDNNMemoryFormat::any);
if (x->numel() < y->numel()) {
if (algo == dnnl::algorithm::binary_sub) {
attributes = CreateAttributes(
......
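
To illustrate the squeeze that sets use_broadcasting_hack, here is a minimal sketch under assumed U2++-like shapes (the shapes are hypothetical, and plain std::vector again stands in for the memory-descriptor dims):

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Two 4D operands that differ only in dimension 2, the pattern the
  // workaround targets: y broadcasts into x along a single inner axis.
  std::vector<std::int64_t> src0_dims = {2, 3, 100, 100};  // x
  std::vector<std::int64_t> src1_dims = {2, 3, 1, 100};    // y
  std::vector<std::int64_t> dst_tz = {2, 3, 100, 100};

  // Same precondition as can_squeeze in the handler: dims 0, 1, and 3 match.
  const bool can_squeeze = src0_dims[0] == src1_dims[0] &&
                           src0_dims[1] == src1_dims[1] &&
                           src0_dims[3] == src1_dims[3];
  assert(can_squeeze && src0_dims[2] != src1_dims[2]);

  // Fold the batch dimension into the second dimension and drop it, so that
  // oneDNN sees an equivalent 3D problem.
  src0_dims[1] *= dst_tz[0];  // 3 * 2 == 6
  src1_dims[1] *= dst_tz[0];
  dst_tz[1] *= dst_tz[0];
  dst_tz.erase(dst_tz.begin());
  src0_dims.erase(src0_dims.begin());  // reshape({begin() + 1, end()})
  src1_dims.erase(src1_dims.begin());

  assert((src0_dims == std::vector<std::int64_t>{6, 100, 100}));
  assert((src1_dims == std::vector<std::int64_t>{6, 1, 100}));
  assert((dst_tz == std::vector<std::int64_t>{6, 100, 100}));
  return 0;
}

After the squeeze, y broadcasts into x along a single axis of a 3D problem, which is the case oneDNN's binary primitive optimizes for. The are_strides_plain check guards the descriptor reshape, which is only valid when the layout follows the logical dimension order.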