diff --git a/paddle/fluid/operators/elementwise/mkldnn/elementwise_sub_mkldnn_op.cc b/paddle/fluid/operators/elementwise/mkldnn/elementwise_sub_mkldnn_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..be8dad62c3c055e82c8d0f2588e630f8944dee99 --- /dev/null +++ b/paddle/fluid/operators/elementwise/mkldnn/elementwise_sub_mkldnn_op.cc @@ -0,0 +1,132 @@ + +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h" +namespace paddle { +namespace framework { +class ExecutionContext; +} // namespace framework +namespace platform { +class CPUDeviceContext; +struct CPUPlace; +} // namespace platform +} // namespace paddle + +namespace paddle { +namespace operators { +template +class EltwiseSubMKLDNNGradKernel : public ElemwiseGradKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + ElemwiseGradKernel::Compute(ctx); + using Tensor = framework::Tensor; + + auto& dev_ctx = + ctx.template device_context(); + const auto& onednn_engine = dev_ctx.GetEngine(); + + auto* dout = ctx.Input(framework::GradVarName("Out")); + auto* dx = ctx.Output(framework::GradVarName("X")); + auto* dy = ctx.Output(framework::GradVarName("Y")); + + auto tz = framework::vectorize(dout->dims()); + memory::data_type dout_type = framework::ToMKLDNNDataType(dout->type()); + platform::ReorderMKLDNNHandler handler(tz, dout->type(), dout_type, + onednn_engine); + + auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); + auto reorder_src_memory_p = handler.AcquireSrcMemory( + dout->format(), platform::to_void_cast(dout->data())); + + if (dx) { + auto reorder_dst_memory_p = + handler.AcquireDstMemory(dx, dout->format(), ctx.GetPlace()); + auto reorder_p = + handler.AcquireReorder(reorder_dst_memory_p, reorder_src_memory_p); + platform::RecordEvent record_reorder("int_reorder", + platform::EventRole::kUniqueOp); + + reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p); + astream.wait(); + + dx->set_layout(DataLayout::kMKLDNN); + dx->set_format(platform::GetMKLDNNFormat(*reorder_dst_memory_p)); + } + + if (dy) { + // Direct copy + if (dout->dims() == dy->dims()) { + auto reorder_dst_memory_p = + handler.AcquireDstMemory(dy, dout->format(), ctx.GetPlace()); + + dnnl::primitive_attr reorder_attr; + std::vector scales = {-1}; + reorder_attr.set_output_scales(0, scales); + auto reorder_p = std::make_shared( + *(reorder_src_memory_p), *(reorder_dst_memory_p), reorder_attr); + platform::RecordEvent record_reorder("int_reorder", + platform::EventRole::kUniqueOp); + reorder_p->execute(astream, *reorder_src_memory_p, + *reorder_dst_memory_p); + astream.wait(); + + dy->set_layout(DataLayout::kMKLDNN); + dy->set_format(platform::GetMKLDNNFormat(*reorder_dst_memory_p)); + } else { + // Broadcasting + + dnnl::post_ops po; + po.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, -1.0f, 0); + dnnl::primitive_attr attr; + attr.set_post_ops(po); + + platform::ReductionMKLDNNHandler handler_sum( + dnnl::algorithm::reduction_sum, 0.0f, 0.0f, onednn_engine, + ctx.GetPlace(), dout, dy, CalculateBroadcastedDims(dout, dy), attr); + + auto dy_memory_p = handler_sum.AcquireDstMemory(dy); + auto reduction_p = handler_sum.AcquireForwardPrimitive(); + + reduction_p->execute(astream, { + {DNNL_ARG_SRC, *reorder_src_memory_p}, + {DNNL_ARG_DST, *dy_memory_p}, + }); + astream.wait(); + + dy->set_layout(DataLayout::kMKLDNN); + dy->set_format( + platform::GetMKLDNNFormat(dy_memory_p->get_desc().reshape( + paddle::framework::vectorize(dy->dims())))); + } + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_KERNEL( + elementwise_sub, MKLDNN, paddle::platform::CPUPlace, + ops::EltwiseMKLDNNKernel, + ops::EltwiseMKLDNNKernel, + ops::EltwiseMKLDNNKernel, + ops::EltwiseMKLDNNKernel) + +REGISTER_OP_KERNEL(elementwise_sub_grad, MKLDNN, ::paddle::platform::CPUPlace, + ops::EltwiseSubMKLDNNGradKernel, + ops::EltwiseSubMKLDNNGradKernel) diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index d6ab9e50a066e01663800fb424d2edd2a5dc4b9f..1aa8c0cdb57f977bc8d8f016394091612b0290af 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -19,6 +19,7 @@ limitations under the License. */ #include #include #include + #include "boost/optional.hpp" #include "paddle/fluid/framework/data_layout_transform.h" #include "paddle/fluid/framework/operator.h" @@ -927,7 +928,6 @@ class BroadcastDataMKLDNNHandler std::shared_ptr AcquireDstMemory(framework::Tensor* output) { T_out* ptr = output->mutable_data( this->place_, this->fwd_pd_->dst_desc().get_size()); - ; memset(ptr, 0, this->fwd_pd_->dst_desc().get_size()); return this->AcquireMemoryFromPrimitive(this->fwd_pd_->dst_desc(), ptr); } @@ -940,7 +940,8 @@ class ReductionMKLDNNHandler ReductionMKLDNNHandler(const dnnl::algorithm algo, const float p, const float eps, const mkldnn::engine engine, platform::Place cpu_place, const Tensor* x, - const Tensor* y, std::vector y_tz) + const Tensor* y, std::vector y_tz, + const dnnl::primitive_attr& attr = NULL) : platform::MKLDNNHandlerNoCachingT(engine, cpu_place) { PADDLE_ENFORCE_EQ( @@ -957,7 +958,10 @@ class ReductionMKLDNNHandler const auto y_md = memory::desc(y_tz, platform::MKLDNNGetDataType(), x->format()); - this->AcquireForwardPrimitiveDescriptor(algo, x_md, y_md, p, eps); + if (attr) + this->AcquireForwardPrimitiveDescriptor(attr, algo, x_md, y_md, p, eps); + else + this->AcquireForwardPrimitiveDescriptor(algo, x_md, y_md, p, eps); } }; @@ -979,8 +983,9 @@ class ActivationMKLDNNHandler if (ctx.Type() == "scale") { bool bias_after_scale = ctx.Attr("bias_after_scale"); auto* scale_tensor = ctx.Input("ScaleTensor"); - alpha = (scale_tensor == nullptr) ? ctx.Attr("scale") - : (float)*(scale_tensor->data()); + alpha = (scale_tensor == nullptr) + ? ctx.Attr("scale") + : static_cast(*(scale_tensor->data())); beta = ctx.Attr("bias"); // if bias_after_scale == true // out = scale*X + bias @@ -1504,6 +1509,7 @@ static void SetDstMemoryQuantized( T* output_data = output->mutable_data(ctx.GetPlace()); const size_t dst_dims = dst_tz.size(); MKLDNNMemoryFormat dst_fmt; + PADDLE_ENFORCE_LE(dst_dims, 5, platform::errors::InvalidArgument( "Dst memory for quantization can not have " "dims > 5. But received dst_dims is %d.", diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_elementwise_sub_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_elementwise_sub_mkldnn_op.py new file mode 100644 index 0000000000000000000000000000000000000000..62c8c9571b7935fd5d571377327ef07d9cacbf9c --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_elementwise_sub_mkldnn_op.py @@ -0,0 +1,236 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import unittest +import numpy as np +from paddle import enable_static +from paddle.fluid.tests.unittests.op_test import OpTest, OpTestTool, convert_float_to_uint16 +from paddle.fluid.framework import _current_expected_place +import paddle.fluid.core as core + + +@OpTestTool.skip_if(not (isinstance(_current_expected_place(), core.CPUPlace)), + "GPU is not supported") +class TestMKLDNNElementwiseSubOp(OpTest): + def setUp(self): + self.op_type = "elementwise_sub" + self.init_dtype() + self.init_input_output() + self.init_kernel_type() + self.init_axis() + self.inputs = { + 'X': OpTest.np_dtype_to_fluid_dtype(self.x), + 'Y': OpTest.np_dtype_to_fluid_dtype(self.y) + } + self.attrs = {'axis': self.axis, 'use_mkldnn': self.use_mkldnn} + self.outputs = {'Out': self.out} + + def init_input_output(self): + self.x = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype) + self.y = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype) + self.out = np.subtract(self.x, self.y) + + def test_check_grad_normal(self): + self.check_grad(['X', 'Y'], 'Out') + + def test_check_grad_ignore_x(self): + self.check_grad(['Y'], 'Out', no_grad_set=set("X")) + + def test_check_grad_ignore_y(self): + self.check_grad(['X'], 'Out', no_grad_set=set('Y')) + + def init_axis(self): + self.axis = -1 + + def init_kernel_type(self): + self.use_mkldnn = True + + def init_dtype(self): + self.dtype = np.float32 + + def test_check_output(self): + self.check_output() + + +class TestMKLDNNElementwiseSubOp2(TestMKLDNNElementwiseSubOp): + def init_input_output(self): + self.x = np.random.random((100, )).astype(self.dtype) + self.y = np.random.random((100, )).astype(self.dtype) + self.out = np.subtract(self.x, self.y) + + +class TestMKLDNNElementwiseSubOp3(TestMKLDNNElementwiseSubOp): + def init_input_output(self): + self.x = np.random.uniform(0.1, 1, [2, 3, 4, 5]).astype(self.dtype) + self.y = np.random.uniform(0.1, 1, [2, 3, 4, 5]).astype(self.dtype) + self.out = np.subtract(self.x, self.y) + + +class TestMKLDNNElementwiseSubOp4(TestMKLDNNElementwiseSubOp): + def init_input_output(self): + self.x = np.random.uniform(1, 2, [2, 3, 4, 32]).astype(self.dtype) + self.y = np.random.uniform(1, 2, [4, 32]).astype(self.dtype) + self.out = np.subtract(self.x, self.y) + + +class TestMKLDNNElementwiseSubOp5(TestMKLDNNElementwiseSubOp): + def init_input_output(self): + self.x = np.random.uniform(1, 2, [2, 3, 4, 100]).astype(self.dtype) + self.y = np.random.uniform(1, 2, [100]).astype(self.dtype) + self.out = np.subtract(self.x, self.y) + + +class TestMKLDNNElementwiseSubOp_broadcast(TestMKLDNNElementwiseSubOp): + def init_input_output(self): + self.x = np.random.rand(2, 10, 12, 3).astype(self.dtype) + self.y = np.random.rand(10, 12).astype(self.dtype) + self.out = self.x - self.y.reshape(1, 10, 12, 1) + + def init_axis(self): + self.axis = 1 + + +class TestElementwiseSubOp_xsize_lessthan_ysize_sub(TestMKLDNNElementwiseSubOp): + def init_input_output(self): + self.x = np.random.rand(10, 12).astype(self.dtype) + self.y = np.random.rand(2, 2, 10, 12).astype(self.dtype) + self.out = self.x - self.y + + def init_axis(self): + self.axis = 2 + + def test_check_grad_normal(self): + pass + + def test_check_grad_ignore_y(self): + pass + + def test_check_grad_ignore_x(self): + pass + + +@OpTestTool.skip_if_not_cpu_bf16() +class TestBf16(TestMKLDNNElementwiseSubOp): + def setUp(self): + self.op_type = "elementwise_sub" + self.init_dtype() + self.init_input_output() + self.init_kernel_type() + self.init_axis() + + self.x_bf16 = convert_float_to_uint16(self.x) + self.y_bf16 = convert_float_to_uint16(self.y) + self.inputs = {'X': self.x_bf16, 'Y': self.y_bf16} + self.attrs = {'axis': self.axis, 'use_mkldnn': self.use_mkldnn} + self.outputs = {'Out': convert_float_to_uint16(self.out)} + + def init_dtype(self): + self.dtype = np.float32 + self.mkldnn_data_type = "bfloat16" + + def init_input_output(self): + self.x = np.random.random(100, ).astype(self.dtype) + self.y = np.random.random(100, ).astype(self.dtype) + self.out = np.subtract(self.x, self.y) + + def test_check_output(self): + self.check_output_with_place(core.CPUPlace()) + + def test_check_grad_normal(self): + self.check_grad_with_place( + core.CPUPlace(), ["X", "Y"], + "Out", + user_defined_grads=[self.x, -self.x], + user_defined_grad_outputs=[self.x_bf16]) + + def test_check_grad_ignore_x(self): + self.check_grad_with_place( + core.CPUPlace(), ["Y"], + "Out", + user_defined_grads=[-self.y], + user_defined_grad_outputs=[self.y_bf16]) + + def test_check_grad_ignore_y(self): + self.check_grad_with_place( + core.CPUPlace(), ["X"], + "Out", + user_defined_grads=[self.x], + user_defined_grad_outputs=[self.x_bf16]) + + +class TestBf16Broadcasting(TestBf16): + def init_input_output(self): + self.x = np.random.uniform(1, 2, [2, 3, 4, 100]).astype(self.dtype) + self.y = np.random.uniform(1, 2, [100]).astype(self.dtype) + self.out = np.subtract(self.x, self.y) + + def compute_reduced_gradients(self, out_grads): + part_sum = np.add.reduceat(out_grads, [0], axis=0) + part_sum = np.add.reduceat(part_sum, [0], axis=1) + part_sum = np.add.reduceat(part_sum, [0], axis=2) + return -part_sum.flatten() + + def test_check_grad_normal(self): + self.check_grad_with_place( + core.CPUPlace(), ["X", "Y"], + "Out", + user_defined_grads=[ + self.x, self.compute_reduced_gradients(self.x) + ], + user_defined_grad_outputs=[self.x_bf16]) + + def test_check_grad_ignore_x(self): + self.check_grad_with_place( + core.CPUPlace(), ["Y"], + "Out", + user_defined_grads=[self.compute_reduced_gradients(self.x)], + user_defined_grad_outputs=[self.x_bf16]) + + +class TestInt8(TestMKLDNNElementwiseSubOp): + def init_kernel_type(self): + self.use_mkldnn = True + self._cpu_only = True + + def init_dtype(self): + self.dtype = np.int8 + + def init_input_output(self): + self.x = np.random.randint(0, 3, (12, 9)).astype("int8") + self.y = np.random.randint(0, 3, (12, 9)).astype("int8") + self.out = np.subtract(self.x, self.y) + + def init_scales(self): + self.attrs['Scale_x'] = 1.0 + self.attrs['Scale_y'] = 1.0 + self.attrs['Scale_out'] = 1.0 + + def test_check_output(self): + self.init_scales() + self.check_output() + + def test_check_grad_normal(self): + pass + + def test_check_grad_ignore_x(self): + pass + + def test_check_grad_ignore_y(self): + pass + + +if __name__ == '__main__': + enable_static() + unittest.main() diff --git a/tools/static_mode_white_list.py b/tools/static_mode_white_list.py index 43281d4375ed0f56ef0eb01682ea95143704ca18..7d0a2a8953fc82faccea14bf1e99f9afef5b8389 100644 --- a/tools/static_mode_white_list.py +++ b/tools/static_mode_white_list.py @@ -610,6 +610,7 @@ STATIC_MODE_TESTING_LIST = [ 'test_dequantize_mkldnn_op', 'test_elementwise_add_mkldnn_op', 'test_elementwise_add_bf16_mkldnn_op', + 'test_elementwise_sub_mkldnn_op', 'test_elementwise_mul_mkldnn_op', 'test_elementwise_mul_bf16_mkldnn_op', 'test_fc_mkldnn_op',