From 86ea8dceb1b6b0cb1f5b8c6cdd9134e9e6c3c5f7 Mon Sep 17 00:00:00 2001 From: jakpiase <62569058+jakpiase@users.noreply.github.com> Date: Tue, 25 May 2021 04:01:47 +0200 Subject: [PATCH] Added scale op FP32/BF16 FWD/BWD kernels (#32975) --- .../fluid/framework/data_layout_transform.cc | 4 +- .../fluid/framework/data_layout_transform.h | 3 +- .../inference/api/details/zero_copy_tensor.cc | 17 +++ .../fluid/operators/mkldnn/scale_mkldnn_op.cc | 75 +++++++++++ paddle/fluid/operators/scale_op.cc | 20 +++ paddle/fluid/operators/unity_build_rule.cmake | 1 + .../mkldnn/test_scale_bf16_mkldnn_op.py | 122 ++++++++++++++++++ .../unittests/mkldnn/test_scale_mkldnn_op.py | 104 +++++++++++++++ tools/static_mode_white_list.py | 2 + 9 files changed, 345 insertions(+), 3 deletions(-) create mode 100644 paddle/fluid/operators/mkldnn/scale_mkldnn_op.cc create mode 100644 python/paddle/fluid/tests/unittests/mkldnn/test_scale_bf16_mkldnn_op.py create mode 100644 python/paddle/fluid/tests/unittests/mkldnn/test_scale_mkldnn_op.py diff --git a/paddle/fluid/framework/data_layout_transform.cc b/paddle/fluid/framework/data_layout_transform.cc index 8ff94b0277..8708d90485 100644 --- a/paddle/fluid/framework/data_layout_transform.cc +++ b/paddle/fluid/framework/data_layout_transform.cc @@ -143,7 +143,7 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var, void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout, const Tensor& in, Tensor* out, - platform::Place place) { + platform::Place place, bool always_copy) { PADDLE_ENFORCE_NE(in.format(), MKLDNNMemoryFormat::undef, platform::errors::InvalidArgument( "Input tensor format is invalid. Input tensor should " @@ -177,7 +177,7 @@ void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout, // output tensor has the same dims as input. Reorder don't change dims out->Resize(in.dims()); - if (in_format != out_format) { + if ((in_format != out_format) || always_copy) { void* in_data = GetDataFromTensor(in, in_type); std::string key = platform::CreateKey(*dev_ctx, in_tz, in_format, out_format, in_type); diff --git a/paddle/fluid/framework/data_layout_transform.h b/paddle/fluid/framework/data_layout_transform.h index 238f2d2e67..3404ba2db6 100644 --- a/paddle/fluid/framework/data_layout_transform.h +++ b/paddle/fluid/framework/data_layout_transform.h @@ -78,7 +78,8 @@ inline MKLDNNDataType ToMKLDNNDataType(proto::VarType::Type type) { void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout, const Tensor& in, Tensor* out, - platform::Place place); + platform::Place place, + bool always_copy = false); void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var, const OpKernelType& expected_kernel_type, diff --git a/paddle/fluid/inference/api/details/zero_copy_tensor.cc b/paddle/fluid/inference/api/details/zero_copy_tensor.cc index f7dbfd39cd..43306b79fa 100644 --- a/paddle/fluid/inference/api/details/zero_copy_tensor.cc +++ b/paddle/fluid/inference/api/details/zero_copy_tensor.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "paddle/fluid/framework/data_layout_transform.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/inference/api/paddle_inference_api.h" @@ -161,8 +162,24 @@ void Tensor::CopyToCpu(T *data) { auto *t_data = tensor->data(); auto t_place = tensor->place(); + paddle::framework::Tensor out; + auto mem_allocation = std::make_shared( + static_cast(data), ele_num * sizeof(T), + paddle::platform::CPUPlace()); + out.ResetHolder(mem_allocation); + if (paddle::platform::is_cpu_place(t_place)) { +#ifdef PADDLE_WITH_MKLDNN + if (tensor->layout() == paddle::framework::DataLayout::kMKLDNN) + paddle::framework::innerTransDataLayoutFromMKLDNN( + tensor->layout(), paddle::platform::MKLDNNDeviceContext::tls() + .get_cur_paddle_data_layout(), + *tensor, &out, paddle::platform::CPUPlace(), true); + else + std::memcpy(static_cast(data), t_data, ele_num * sizeof(T)); +#else std::memcpy(static_cast(data), t_data, ele_num * sizeof(T)); +#endif } else if (place_ == PlaceType::kGPU) { #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) paddle::platform::DeviceContextPool &pool = diff --git a/paddle/fluid/operators/mkldnn/scale_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/scale_mkldnn_op.cc new file mode 100644 index 0000000000..e91bbd15cf --- /dev/null +++ b/paddle/fluid/operators/mkldnn/scale_mkldnn_op.cc @@ -0,0 +1,75 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/platform/mkldnn_reuse.h" + +namespace paddle { +namespace operators { + +using paddle::framework::Tensor; + +template +class ScaleMKLDNNKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + this->RunKernel(ctx); + } + + void RunKernel(const framework::ExecutionContext& ctx) const { + const auto& dev_ctx = + ctx.template device_context(); + + bool bias_after_scale = ctx.Attr("bias_after_scale"); + auto* x = ctx.Input("X"); + auto* out = ctx.Output("Out"); + auto* scale_tensor = ctx.Input("ScaleTensor"); + + float scale = (scale_tensor == nullptr) ? ctx.Attr("scale") + : (float)*(scale_tensor->data()); + float bias = ctx.Attr("bias"); + + // if bias_after_scale == true + // out = scale*X + bias + // else + // out = scale*(X + bias) = scale*X + scale*bias + + if (!bias_after_scale) bias *= scale; + + auto x_tz = framework::vectorize(x->dims()); + bool is_inplaced = x->IsSharedBufferWith(*out); + + platform::ActivationMKLDNNHandler handler( + x_tz, mkldnn::algorithm::eltwise_linear, scale, bias, x->format(), + dev_ctx, ctx.GetPlace(), ctx.InputName("X"), is_inplaced); + + auto src_memory_p = handler.AcquireSrcMemory(x); + auto dst_memory_p = handler.AcquireDstMemory(out); + auto activation_p = handler.AcquireForwardPrimitive(); + + auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream(); + activation_p->execute(astream, {{MKLDNN_ARG_FROM, *src_memory_p}, + {MKLDNN_ARG_TO, *dst_memory_p}}); + astream.wait(); + + out->set_layout(framework::DataLayout::kMKLDNN); + out->set_format(platform::GetMKLDNNFormat(*dst_memory_p)); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_KERNEL(scale, MKLDNN, paddle::platform::CPUPlace, + ops::ScaleMKLDNNKernel, + ops::ScaleMKLDNNKernel); diff --git a/paddle/fluid/operators/scale_op.cc b/paddle/fluid/operators/scale_op.cc index a9b1f299da..a71f49585b 100644 --- a/paddle/fluid/operators/scale_op.cc +++ b/paddle/fluid/operators/scale_op.cc @@ -54,6 +54,21 @@ class ScaleOp : public framework::OperatorWithKernel { ctx->SetOutputDim("Out", ctx->GetInputDim("X")); ctx->ShareLoD("X", /*->*/ "Out"); } + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + auto input_data_type = + framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); + +#ifdef PADDLE_WITH_MKLDNN + if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { + return framework::OpKernelType(input_data_type, ctx.GetPlace(), + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN); + } +#endif + return framework::OpKernelType(input_data_type, ctx.GetPlace()); + } }; class ScaleOpMaker : public framework::OpProtoAndCheckerMaker { @@ -87,6 +102,9 @@ $$Out = scale*(X + bias)$$ "Apply bias addition after or before scaling. It is useful for " "numeric stability in some circumstances.") .SetDefault(true); + AddAttr("use_mkldnn", + "(bool, default false) Only used in mkldnn kernel") + .SetDefault(false); } }; @@ -112,6 +130,8 @@ class ScaleGradMaker : public framework::SingleGradOpMaker { grad_op->SetAttr("scale", this->GetAttr("scale")); grad_op->SetAttr("bias", 0.0f); grad_op->SetAttr("bias_after_scale", true); + if (grad_op->HasAttr("use_mkldnn")) + grad_op->SetAttr("use_mkldnn", this->GetAttr("use_mkldnn")); } }; diff --git a/paddle/fluid/operators/unity_build_rule.cmake b/paddle/fluid/operators/unity_build_rule.cmake index cd8b31d72e..e9bc351de4 100644 --- a/paddle/fluid/operators/unity_build_rule.cmake +++ b/paddle/fluid/operators/unity_build_rule.cmake @@ -234,6 +234,7 @@ register_unity_group(cc save_combine_op.cc save_op.cc scale_op.cc + mkldnn/scale_mkldnn_op.cc scatter_nd_add_op.cc scatter_op.cc seed_op.cc diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_scale_bf16_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_scale_bf16_mkldnn_op.py new file mode 100644 index 0000000000..8e9f989f06 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_scale_bf16_mkldnn_op.py @@ -0,0 +1,122 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from paddle.fluid.tests.unittests.op_test import OpTest, convert_float_to_uint16 +import paddle +import paddle.fluid as fluid +import paddle.fluid.core as core + + +@unittest.skipIf(not core.supports_bfloat16(), + "place does not support BF16 evaluation") +@unittest.skipIf(core.is_compiled_with_cuda(), + "core is compiled with CUDA which has no BF implementation") +class TestScaleOpBF16(OpTest): + def setUp(self): + self.op_type = "scale" + self.x_fp32 = np.random.random((10, 10)).astype(np.float32) + self.x_bf16 = convert_float_to_uint16(self.x_fp32) + self.scale = -2.3 + self.inputs = {'X': self.x_bf16} + self.attrs = {'scale': self.scale, 'use_mkldnn': True, 'bias': 0.4} + self.use_mkldnn = True + self.outputs = { + 'Out': (self.x_fp32 * self.attrs['scale']) + self.attrs['bias'] + } + + def calculate_grads(self): + bias = 0 + if 'bias' in self.attrs: + bias = self.attrs['bias'] + + scale = self.scale + if 'ScaleTensor' in self.attrs: + scale = self.attrs['ScaleTensor'] + + self.out = (self.x_fp32 * scale) + bias + self.dx = (self.out * scale) + + def test_check_output(self): + self.check_output(check_dygraph=False) + + def test_check_grad(self): + self.calculate_grads() + self.check_grad_with_place( + core.CPUPlace(), ["X"], + "Out", + check_dygraph=False, + user_defined_grads=[self.dx], + user_defined_grad_outputs=[convert_float_to_uint16(self.out)]) + + +class TestScaleOpBF16BiasNotAfterScale(TestScaleOpBF16): + def setUp(self): + self.op_type = "scale" + self.x_fp32 = np.random.random((10, 10)).astype(np.float32) + self.x_bf16 = convert_float_to_uint16(self.x_fp32) + self.scale = 1.5 + self.inputs = {'X': self.x_bf16} + self.attrs = { + 'scale': self.scale, + 'use_mkldnn': True, + 'bias': 0.0, + 'bias_after_scale': False + } + self.use_mkldnn = True + self.outputs = { + 'Out': (self.x_fp32 + self.attrs['bias']) * self.attrs['scale'] + } + + +class TestScaleOpBF16ScaleTensor(TestScaleOpBF16): + def setUp(self): + self.op_type = "scale" + self.scale = -2.3 + self.x_fp32 = np.random.random((10, 10)).astype(np.float32) + self.x_bf16 = convert_float_to_uint16(self.x_fp32) + self.scale_tensor = np.array([self.scale]).astype(np.float32) + self.inputs = { + 'X': self.x_bf16, + 'ScaleTensor': convert_float_to_uint16(self.scale_tensor) + } + self.attrs = {'use_mkldnn': True} + self.outputs = {'Out': self.x_fp32 * self.scale} + + +class TestScaleOpBF16ScaleTensorNotBiasAfterScale(TestScaleOpBF16): + def setUp(self): + self.op_type = "scale" + self.scale = 1.2 + self.x_fp32 = np.random.random((9, 13)).astype(np.float32) + self.x_bf16 = convert_float_to_uint16(self.x_fp32) + self.scale_tensor = np.array([self.scale]).astype(np.float32) + self.inputs = { + 'X': self.x_bf16, + 'ScaleTensor': convert_float_to_uint16(self.scale_tensor) + } + self.attrs = { + 'bias': -1.1, + 'bias_after_scale': False, + 'use_mkldnn': True + } + self.outputs = {'Out': (self.x_fp32 + self.attrs['bias']) * self.scale} + + +if __name__ == "__main__": + paddle.enable_static() + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_scale_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_scale_mkldnn_op.py new file mode 100644 index 0000000000..528b55dcd8 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_scale_mkldnn_op.py @@ -0,0 +1,104 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from paddle.fluid.tests.unittests.op_test import OpTest +import paddle +import paddle.fluid as fluid + + +class TestScaleOp(OpTest): + def setUp(self): + self.op_type = "scale" + self.inputs = {'X': np.random.random((10, 10)).astype(np.float32)} + self.attrs = {'scale': -2.3, 'use_mkldnn': True, 'bias': 0.2} + self.use_mkldnn = True + self.outputs = { + 'Out': (self.inputs['X'] * self.attrs['scale']) + self.attrs['bias'] + } + + def test_check_output(self): + self.check_output(check_dygraph=False) + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + +class TestScaleOpBiasNotAfterScale(OpTest): + def setUp(self): + self.op_type = "scale" + self.inputs = {'X': np.random.random((10, 10)).astype(np.float32)} + self.attrs = { + 'scale': 1.5, + 'use_mkldnn': True, + 'bias': 2.3, + 'bias_after_scale': False + } + self.use_mkldnn = True + self.outputs = { + 'Out': (self.inputs['X'] + self.attrs['bias']) * self.attrs['scale'] + } + + def test_check_output(self): + self.check_output(check_dygraph=False) + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + +class TestScaleOpScaleTensor(OpTest): + def setUp(self): + self.op_type = "scale" + self.scale = -2.3 + self.inputs = { + 'X': np.random.random((10, 10)).astype(np.float32), + 'ScaleTensor': np.array([self.scale]).astype(np.float32) + } + self.attrs = {} + self.outputs = {'Out': self.inputs['X'] * self.scale} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + +class TestScaleOpScaleTensorNotBiasAfterScale(OpTest): + def setUp(self): + self.op_type = "scale" + self.scale = -1.2 + self.inputs = { + 'X': np.random.random((10, 10)).astype(np.float32), + 'ScaleTensor': np.array([self.scale]).astype(np.float32) + } + self.attrs = {'bias': -6.8, 'bias_after_scale': False} + self.outputs = { + 'Out': + (self.inputs['X'] + self.attrs['bias']) * self.inputs['ScaleTensor'] + } + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + +if __name__ == "__main__": + paddle.enable_static() + unittest.main() diff --git a/tools/static_mode_white_list.py b/tools/static_mode_white_list.py index 15bcae8260..c5ea8891a2 100644 --- a/tools/static_mode_white_list.py +++ b/tools/static_mode_white_list.py @@ -447,6 +447,8 @@ STATIC_MODE_TESTING_LIST = [ 'test_sample_logits_op', 'test_save_model_without_var', 'test_scale_op', + 'test_scale_mkldnn_op', + 'test_scale_bf16_mkldnn_op', 'test_scaled_dot_product_attention', 'test_scatter_nd_op', 'test_seed_op', -- GitLab