未验证 提交 ab97b760 编写于 作者: H Hui Zhang 提交者: GitHub

[mkldnn] Fix elementwise_sub sign reverse for mkldnn (#46049)

* fix sub sign reverse for mkldnn

* refactor code as comment

* remove useless

* format code
上级 d8edf487
...@@ -149,12 +149,20 @@ class EltwiseMKLDNNGradKernel : public ElemwiseGradKernel<T> { ...@@ -149,12 +149,20 @@ class EltwiseMKLDNNGradKernel : public ElemwiseGradKernel<T> {
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X")); auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y")); auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out")); auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
VLOG(4) << "element sub: dx " << dx << " dy " << dy << " dout " << dout;
// oneDNN's binary is optimized for broadcasting y into x, so in other case // oneDNN's binary is optimized for broadcasting y into x, so in other case
// we have to swap tensors to achieve optimal performance // we have to swap tensors to achieve optimal performance
bool swap_x_y = false;
if (x->numel() < y->numel()) { if (x->numel() < y->numel()) {
std::swap(x, y); std::swap(x, y);
std::swap(dx, dy); std::swap(dx, dy);
swap_x_y = true;
}
std::vector<float> scales{1.0};
if (swap_x_y) {
scales[0] = (BINARY_OP == dnnl::algorithm::binary_add) ? 1 : -1;
} }
int axis = ctx.Attr<int>("axis"); int axis = ctx.Attr<int>("axis");
...@@ -172,7 +180,6 @@ class EltwiseMKLDNNGradKernel : public ElemwiseGradKernel<T> { ...@@ -172,7 +180,6 @@ class EltwiseMKLDNNGradKernel : public ElemwiseGradKernel<T> {
dout->mem_desc(), platform::to_void_cast(dout->data<T>())); dout->mem_desc(), platform::to_void_cast(dout->data<T>()));
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
if (dx) { if (dx) {
std::shared_ptr<dnnl::memory> dst_memory; std::shared_ptr<dnnl::memory> dst_memory;
...@@ -181,8 +188,11 @@ class EltwiseMKLDNNGradKernel : public ElemwiseGradKernel<T> { ...@@ -181,8 +188,11 @@ class EltwiseMKLDNNGradKernel : public ElemwiseGradKernel<T> {
BINARY_OP == dnnl::algorithm::binary_sub) { BINARY_OP == dnnl::algorithm::binary_sub) {
dst_memory = reorder_handler.AcquireDstMemory( dst_memory = reorder_handler.AcquireDstMemory(
dx, dout->mem_desc(), ctx.GetPlace()); dx, dout->mem_desc(), ctx.GetPlace());
auto reorder_p =
reorder_handler.AcquireReorder(dst_memory, reorder_src_memory_p); dnnl::primitive_attr reorder_attr;
reorder_attr.set_output_scales(0, scales);
auto reorder_p = reorder_handler.AcquireReorder(
dst_memory, reorder_src_memory_p, reorder_attr);
platform::RecordEvent record_reorder( platform::RecordEvent record_reorder(
"int_reorder", "int_reorder",
platform::TracerEventType::UserDefined, platform::TracerEventType::UserDefined,
...@@ -190,6 +200,7 @@ class EltwiseMKLDNNGradKernel : public ElemwiseGradKernel<T> { ...@@ -190,6 +200,7 @@ class EltwiseMKLDNNGradKernel : public ElemwiseGradKernel<T> {
platform::EventRole::kUniqueOp); platform::EventRole::kUniqueOp);
reorder_p->execute(astream, *reorder_src_memory_p, *dst_memory); reorder_p->execute(astream, *reorder_src_memory_p, *dst_memory);
} else { // elementwise_mul & elementwise_div } else { // elementwise_mul & elementwise_div
platform::BinaryMKLDNNHandler<T> binary_handler(BINARY_OP, platform::BinaryMKLDNNHandler<T> binary_handler(BINARY_OP,
axis, axis,
...@@ -233,11 +244,10 @@ class EltwiseMKLDNNGradKernel : public ElemwiseGradKernel<T> { ...@@ -233,11 +244,10 @@ class EltwiseMKLDNNGradKernel : public ElemwiseGradKernel<T> {
dy, dout->mem_desc(), ctx.GetPlace()); dy, dout->mem_desc(), ctx.GetPlace());
dnnl::primitive_attr reorder_attr; dnnl::primitive_attr reorder_attr;
std::vector<float> scales(1);
scales[0] = (BINARY_OP == dnnl::algorithm::binary_add) ? 1 : -1;
reorder_attr.set_output_scales(0, scales); reorder_attr.set_output_scales(0, scales);
auto reorder_p = std::make_shared<dnnl::reorder>(
*(reorder_src_memory_p), *(reorder_dst_memory_p), reorder_attr); auto reorder_p = reorder_handler.AcquireReorder(
reorder_dst_memory_p, reorder_src_memory_p, reorder_attr);
platform::RecordEvent record_reorder( platform::RecordEvent record_reorder(
"int_reorder", "int_reorder",
platform::TracerEventType::UserDefined, platform::TracerEventType::UserDefined,
...@@ -331,7 +341,8 @@ class EltwiseMKLDNNGradKernel : public ElemwiseGradKernel<T> { ...@@ -331,7 +341,8 @@ class EltwiseMKLDNNGradKernel : public ElemwiseGradKernel<T> {
// Broadcasting // Broadcasting
if (BINARY_OP == dnnl::algorithm::binary_sub) { if (BINARY_OP == dnnl::algorithm::binary_sub) {
dnnl::post_ops po; dnnl::post_ops po;
po.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, -1.0f, 0); po.append_eltwise(
1.0f, dnnl::algorithm::eltwise_linear, scales[0], 0);
broadcast_reduction_attr.set_post_ops(po); broadcast_reduction_attr.set_post_ops(po);
} }
......
...@@ -877,6 +877,10 @@ class BinaryOneDNNHandler : public OneDNNHandlerNoCachingT<T, dnnl::binary> { ...@@ -877,6 +877,10 @@ class BinaryOneDNNHandler : public OneDNNHandlerNoCachingT<T, dnnl::binary> {
CreateAttributes(algo, scale_x, scale_y, scale_out, post_ops); CreateAttributes(algo, scale_x, scale_y, scale_out, post_ops);
if (x->numel() < y->numel()) { if (x->numel() < y->numel()) {
if (algo == dnnl::algorithm::binary_sub) {
attributes = CreateAttributes(
algo, -1.0 * scale_x, -1.0 * scale_y, scale_out, post_ops);
}
this->AcquireForwardPrimitiveDescriptor( this->AcquireForwardPrimitiveDescriptor(
attributes, algo, src1_md, src0_md, dst_md); attributes, algo, src1_md, src0_md, dst_md);
} else { } else {
......
...@@ -89,6 +89,23 @@ class TestMKLDNNElementwiseSubOp4(TestMKLDNNElementwiseSubOp): ...@@ -89,6 +89,23 @@ class TestMKLDNNElementwiseSubOp4(TestMKLDNNElementwiseSubOp):
self.out = np.subtract(self.x, self.y) self.out = np.subtract(self.x, self.y)
class TestMKLDNNElementwiseSubOp40(TestMKLDNNElementwiseSubOp):
def init_input_output(self):
self.x = np.random.uniform(0.1, 2, [180, 1]).astype(self.dtype)
self.y = np.random.uniform(0.1, 1, [1, 256]).astype(self.dtype)
self.out = np.subtract(self.x, self.y)
def test_check_grad_normal(self):
self.check_grad(['X', 'Y'], 'Out')
def test_check_grad_ignore_x(self):
self.check_grad(['Y'], 'Out', no_grad_set=set("X"))
def test_check_grad_ignore_y(self):
self.check_grad(['X'], 'Out', no_grad_set=set('Y'))
class TestMKLDNNElementwiseSubOp5(TestMKLDNNElementwiseSubOp): class TestMKLDNNElementwiseSubOp5(TestMKLDNNElementwiseSubOp):
def init_input_output(self): def init_input_output(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册