diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc
index 37a8ec12680abaa01cb2b540631d202013016235..7717bcfc3e96249bd99b80525728718ee18300b5 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.cc
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -2262,11 +2262,26 @@ PDNode *patterns::QuantizePlacement::operator()(
 PDNode *patterns::Bfloat16Placement::operator()(
     const std::unordered_set<std::string> &bfloat16_enabled_op_types) {
   std::unordered_set<std::string> supported_op_types =
-      std::unordered_set<std::string>(
-          {"concat", "conv2d", "conv2d_transpose", "elementwise_add",
-           "elementwise_mul", "fc", "fusion_gru", "fusion_lstm", "gelu",
-           "layer_norm", "matmul", "matmul_v2", "pool2d", "relu", "reshape2",
-           "softmax", "split", "sum", "transpose2"});
+      std::unordered_set<std::string>({"concat",
+                                       "conv2d",
+                                       "conv2d_transpose",
+                                       "elementwise_add",
+                                       "elementwise_mul",
+                                       "fc",
+                                       "fusion_gru",
+                                       "fusion_lstm",
+                                       "gelu",
+                                       "layer_norm",
+                                       "matmul",
+                                       "matmul_v2",
+                                       "pool2d",
+                                       "prelu",
+                                       "relu",
+                                       "reshape2",
+                                       "softmax",
+                                       "split",
+                                       "sum",
+                                       "transpose2"});
   if (!bfloat16_enabled_op_types.empty()) {
     supported_op_types = bfloat16_enabled_op_types;
   }
diff --git a/paddle/fluid/operators/mkldnn/prelu_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/prelu_mkldnn_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..e2a4482666a1ace818777e9e7e3abaa1e6ff2f22
--- /dev/null
+++ b/paddle/fluid/operators/mkldnn/prelu_mkldnn_op.cc
@@ -0,0 +1,187 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/platform/mkldnn_reuse.h"
+
+namespace paddle {
+namespace operators {
+
+using dnnl::memory;
+using framework::Tensor;
+using platform::GetMKLDNNFormat;
+using platform::MKLDNNDeviceContext;
+using platform::MKLDNNGetDataType;
+using platform::to_void_cast;
+
+namespace {
+template <typename T>
+class PReluMKLDNNHandler
+    : public platform::MKLDNNHandlerT<T, dnnl::prelu_forward,
+                                      dnnl::prelu_backward> {
+ public:
+  PReluMKLDNNHandler(const MKLDNNDeviceContext& dev_ctx,
+                     const mkldnn::engine engine, platform::Place cpu_place,
+                     const Tensor* x, const Tensor* weights,
+                     const std::string& uniq_name, const std::string& mode,
+                     bool is_test = false)
+      : platform::MKLDNNHandlerT<T, dnnl::prelu_forward, dnnl::prelu_backward>(
+            dev_ctx, engine, cpu_place,
+            platform::CreateKey(dev_ctx, framework::vectorize(x->dims()),
+                                uniq_name)) {
+    if (!this->isCached()) {
+      auto x_md = memory::desc(framework::vectorize(x->dims()),
+                               MKLDNNGetDataType<T>(), x->format());
+
+      auto weights_dims = framework::vectorize(weights->dims());
+
+      // weights must have the same rank as X only in the "element" mode;
+      // otherwise pad them with 1s up to X's rank
+      if (weights->dims().size() != x->dims().size()) {
+        auto new_weights_dims = std::vector<int64_t>(x->dims().size(), 1);
+        if (mode == "channel") {
+          new_weights_dims[1] =
+              *std::max_element(weights_dims.begin(), weights_dims.end());
+        }
+        weights_dims = std::move(new_weights_dims);
+      }
+      auto weights_md = memory::desc(weights_dims, MKLDNNGetDataType<T>(),
+                                     memory::format_tag::any);
+
+      this->AcquireForwardPrimitiveDescriptor(
+          dnnl::prop_kind::forward_training, x_md, weights_md);
+      if (!is_test)
+        this->AcquireBackwardPrimitiveDescriptor(x_md, weights_md, x_md,
+                                                 weights_md);
+    }
+  }
+
+  std::shared_ptr<memory> AcquireWeightsMemoryPossiblyWithReorder(
+      const Tensor* input, const bool is_test) {
+    const T* input_data = input->data<T>();
+
+    // if weights are 1D, every format tag is correct, so we accept
+    // format_tag::any's output and no reorder is needed
+    if (input->dims().size() == 1) {
+      return this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc(),
+                                              to_void_cast<T>(input_data),
+                                              "@alpha_mem_p");
+    }
+
+    auto user_weights_md =
+        memory::desc(framework::vectorize(input->dims()),
+                     MKLDNNGetDataType<T>(), input->format());
+    return this->AcquireMemoryWithReorder(
+        user_weights_md, this->fwd_pd_->weights_desc(),
+        to_void_cast<T>(input_data), "@alpha_mem_p", is_test);
+  }
+
+  std::shared_ptr<memory> AcquireDiffWeightsMemory(Tensor* output) {
+    T* output_data = output->mutable_data<T>(
+        this->place_, this->bwd_pd_->diff_weights_desc().get_size());
+    return this->AcquireMemoryFromPrimitive(this->bwd_pd_->diff_weights_desc(),
+                                            output_data,
+                                            "@diff_weights_mem_p");
+  }
+};
+}  // anonymous namespace
+
+template <typename T>
+class PReluMKLDNNKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    this->RunKernel(ctx);
+  }
+
+  void RunKernel(const framework::ExecutionContext& ctx) const {
+    const auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
+    const auto& onednn_engine = dev_ctx.GetEngine();
+
+    const auto* x = ctx.Input<Tensor>("X");
+    const auto* alpha = ctx.Input<Tensor>("Alpha");
+    auto* out = ctx.Output<Tensor>("Out");
+    const bool is_test = ctx.Attr<bool>("is_test");
+    const auto mode = ctx.Attr<std::string>("mode");
+
+    PReluMKLDNNHandler<T> handler(dev_ctx, onednn_engine, ctx.GetPlace(), x,
+                                  alpha, ctx.InputName("X"), mode, is_test);
+
+    auto src_memory_p = handler.AcquireSrcMemory(x);
+    auto weights_memory_p =
+        handler.AcquireWeightsMemoryPossiblyWithReorder(alpha, is_test);
+    auto dst_memory_p = handler.AcquireDstMemory(out);
+    auto prelu_p = handler.AcquireForwardPrimitive();
+
+    auto& astream = MKLDNNDeviceContext::tls().get_stream();
+    prelu_p->execute(astream, {{DNNL_ARG_SRC, *src_memory_p},
+                               {DNNL_ARG_WEIGHTS, *weights_memory_p},
+                               {DNNL_ARG_DST, *dst_memory_p}});
+    astream.wait();
+
+    out->set_layout(framework::DataLayout::kMKLDNN);
+    out->set_format(GetMKLDNNFormat(*dst_memory_p));
+  }
+};
+
+template <typename T>
+class PReluGradMKLDNNKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    this->RunKernel(ctx);
+  }
+
+  void RunKernel(const framework::ExecutionContext& ctx) const {
+    const auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
+    const auto& onednn_engine = dev_ctx.GetEngine();
+
+    auto* x = ctx.Input<Tensor>("X");
+    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    auto* dalpha = ctx.Output<Tensor>(framework::GradVarName("Alpha"));
+    auto* alpha = ctx.Input<Tensor>("Alpha");
+    const bool is_test = ctx.Attr<bool>("is_test");
+    const auto mode = ctx.Attr<std::string>("mode");
+
+    PReluMKLDNNHandler<T> handler(dev_ctx, onednn_engine, ctx.GetPlace(), x,
+                                  alpha, framework::GradVarName("X"), mode);
+
+    auto src_memory_p = handler.AcquireSrcMemory(x);
+    auto weights_memory_p =
+        handler.AcquireWeightsMemoryPossiblyWithReorder(alpha, is_test);
+    auto diff_src_memory_p = handler.AcquireDiffSrcMemory(dx);
+    auto diff_weights_memory_p = handler.AcquireDiffWeightsMemory(dalpha);
+    auto diff_dst_memory_p = handler.AcquireDiffDstMemory(dout);
+    auto prelu_p = handler.AcquireBackwardPrimitive();
+
+    auto& astream = MKLDNNDeviceContext::tls().get_stream();
+    prelu_p->execute(astream,
+                     {{DNNL_ARG_SRC, *src_memory_p},
+                      {DNNL_ARG_WEIGHTS, *weights_memory_p},
+                      {DNNL_ARG_DIFF_DST, *diff_dst_memory_p},
+                      {DNNL_ARG_DIFF_SRC, *diff_src_memory_p},
+                      {DNNL_ARG_DIFF_WEIGHTS, *diff_weights_memory_p}});
+    astream.wait();
+
+    dx->set_layout(framework::DataLayout::kMKLDNN);
+    dx->set_format(GetMKLDNNFormat(*diff_src_memory_p));
+  }
+};
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_KERNEL(prelu, MKLDNN, paddle::platform::CPUPlace,
+                   ops::PReluMKLDNNKernel<float>,
+                   ops::PReluMKLDNNKernel<paddle::platform::bfloat16>);
+
+REGISTER_OP_KERNEL(prelu_grad, MKLDNN, paddle::platform::CPUPlace,
+                   ops::PReluGradMKLDNNKernel<float>,
+                   ops::PReluGradMKLDNNKernel<paddle::platform::bfloat16>);
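
Note on the kernel above: the subtlest part is the alpha-shape normalization in PReluMKLDNNHandler. When Alpha has a lower rank than X, its dims are padded with 1s up to X's rank, and in "channel" mode the channel axis (axis 1) keeps the real alpha length. A minimal Python sketch of that rule (hypothetical helper, not part of this patch):

# Sketch of the weights-dims normalization performed by the C++ handler.
def normalize_alpha_dims(x_dims, alpha_dims, mode):
    if len(alpha_dims) == len(x_dims):  # "element" mode: shapes already match
        return list(alpha_dims)
    new_dims = [1] * len(x_dims)
    if mode == "channel":
        # e.g. a 1D alpha of shape (100,) against a 3D X becomes (1, 100, 1)
        new_dims[1] = max(alpha_dims)
    return new_dims

assert normalize_alpha_dims([1, 100, 1], [100], "channel") == [1, 100, 1]
assert normalize_alpha_dims([2, 4, 5, 5], [1], "all") == [1, 1, 1, 1]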
diff --git a/paddle/fluid/operators/prelu_op.cc b/paddle/fluid/operators/prelu_op.cc
index 8a18843a97263689efed737741c71dc19f593897..b5509e760e8380eb0d85545670d67d346ce3796b 100644
--- a/paddle/fluid/operators/prelu_op.cc
+++ b/paddle/fluid/operators/prelu_op.cc
@@ -95,9 +95,17 @@ class PReluOp : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-        ctx.device_context());
+    auto input_data_type =
+        framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
+
+#ifdef PADDLE_WITH_MKLDNN
+    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
+      return framework::OpKernelType(input_data_type, ctx.GetPlace(),
+                                     framework::DataLayout::kMKLDNN,
+                                     framework::LibraryType::kMKLDNN);
+    }
+#endif
+    return framework::OpKernelType(input_data_type, ctx.GetPlace());
   }
 };
 
@@ -126,6 +134,18 @@ There are modes:
 )DOC");
     AddAttr<std::string>("mode", "The mode for inputs to share weights.")
         .SetDefault("all");
+    AddAttr<bool>("use_mkldnn",
+                  "(bool, default false) Only used in mkldnn kernel")
+        .SetDefault(false);
+    AddAttr<std::string>(
+        "mkldnn_data_type",
+        "(string, default \"float32\"). Data type of mkldnn kernel")
+        .SetDefault("float32")
+        .InEnum({"float32", "bfloat16"});
+    AddAttr<bool>("is_test",
+                  "(bool, default false) Set to true for inference only, false "
+                  "for training. Some layers may run faster when this is true.")
+        .SetDefault(false);
   }
 };
 
@@ -153,9 +173,17 @@ class PReluGradOp : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-        ctx.device_context());
+    auto input_data_type =
+        framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
+
+#ifdef PADDLE_WITH_MKLDNN
+    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
+      return framework::OpKernelType(input_data_type, ctx.GetPlace(),
+                                     framework::DataLayout::kMKLDNN,
+                                     framework::LibraryType::kMKLDNN);
+    }
+#endif
+    return framework::OpKernelType(input_data_type, ctx.GetPlace());
   }
 };
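
With the use_mkldnn, mkldnn_data_type and is_test attributes registered, the oneDNN kernel is selected whenever CanMKLDNNBeUsed() holds. A hedged end-to-end sketch of exercising it from static mode; the exact opt-in entry point may differ between Paddle versions, and the FLAGS_use_mkldnn environment variable is one common way to enable oneDNN kernels globally:

import os
os.environ['FLAGS_use_mkldnn'] = '1'  # assumed opt-in; set before importing paddle

import numpy as np
import paddle

paddle.enable_static()
x = paddle.static.data(name='x', shape=[2, 4, 5, 5], dtype='float32')
y = paddle.static.nn.prelu(x, mode='channel')

exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(paddle.static.default_startup_program())
out, = exe.run(feed={'x': np.random.rand(2, 4, 5, 5).astype('float32')},
               fetch_list=[y])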
diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_prelu_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_prelu_mkldnn_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..5489bf109dd54aea3440e66811f75960ed117fc7
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/mkldnn/test_prelu_mkldnn_op.py
@@ -0,0 +1,185 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+import paddle
+import paddle.fluid.core as core
+from paddle.fluid.tests.unittests.op_test import OpTest, convert_float_to_uint16
+
+
+def ref_prelu(x, weight, mode):
+    result = x.copy()
+
+    if mode == "all":
+        result = np.where(x > 0, x, x * weight[0])
+    elif mode == "channel":
+        if len(weight.shape) > 1:
+            for i in range(x.shape[1]):
+                result[:, i] = np.where(x[:, i] > 0, x[:, i],
+                                        x[:, i] * weight[0, i])
+        else:
+            for i in range(x.shape[1]):
+                result[:, i] = np.where(x[:, i] > 0, x[:, i],
+                                        x[:, i] * weight[i])
+    elif mode == "element":
+        result = np.where(x[:] > 0, x[:], x[:] * weight)
+
+    return result
+
+
+class TestPReluModeChannelOneDNNOp(OpTest):
+    def init_attrs(self):
+        self.mode = "element"
+        self.alpha = np.random.random((1, 4, 5, 5)).astype("float32")
+
+    def set_dtype_attr(self):
+        pass
+
+    def set_inputs(self):
+        self.inputs = {'X': self.x, 'Alpha': self.alpha}
+
+    def setUp(self):
+        self.op_type = "prelu"
+        self.x = np.random.random((2, 4, 5, 5)).astype("float32") + 1
+        self.init_attrs()
+        self.set_inputs()
+        self.attrs = {'mode': self.mode, 'use_mkldnn': True}
+        self.set_dtype_attr()
+
+        self.outputs = {'Out': ref_prelu(self.x, self.alpha, self.mode)}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X', 'Alpha'], 'Out')
+
+
+class TestPReluModeAllOneDNNOp(TestPReluModeChannelOneDNNOp):
+    def init_attrs(self):
+        self.mode = "all"
+        self.alpha = np.random.random((1, 1, 1, 1)).astype("float32")
+
+    # Skip the 'Alpha' input check because in mode = 'all' it has to be a
+    # single value, so checking that it has at least 100 values would fail
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out')
+
+
+class TestPReluModeElementOneDNNOp(TestPReluModeChannelOneDNNOp):
+    def init_attrs(self):
+        self.mode = "element"
+        self.alpha = np.random.random((1, 4, 5, 5)).astype("float32")
+
+
+class TestPReluModeChannel3DOneDNNOp(TestPReluModeChannelOneDNNOp):
+    def init_attrs(self):
+        self.mode = "channel"
+        self.x = np.random.random((1, 100, 1)).astype("float32")
+        self.alpha = np.random.random((1, 100, 1)).astype("float32")
+
+
+class TestPReluModeChannelAlpha1DOneDNNOp(TestPReluModeChannelOneDNNOp):
+    def init_attrs(self):
+        self.mode = "channel"
+        self.x = np.random.random((1, 100, 1)).astype("float32")
+        self.alpha = np.random.random((100)).astype("float32")
+
+
+class TestPReluModeAllAlpha1DOneDNNOp(TestPReluModeAllOneDNNOp):
+    def init_attrs(self):
+        self.mode = "channel"
+        self.x = np.random.random((1, 1, 100)).astype("float32")
+        self.alpha = np.random.random((1)).astype("float32")
+
+
+# BF16 TESTS
+def create_bf16_test_class(parent):
+    class TestPReluBF16OneDNNOp(parent):
+        def set_inputs(self):
+            self.inputs = {
+                'X': convert_float_to_uint16(self.x),
+                'Alpha': convert_float_to_uint16(self.alpha)
+            }
+
+        def set_dtype_attr(self):
+            self.attrs['mkldnn_data_type'] = "bfloat16"
+
+        def calculate_grads(self):
+            dout = self.outputs['Out']
+            self.dx = self.x.copy()
+            self.dalpha = self.alpha.copy()
+
+            if self.mode == "all":
+                self.dx = np.where(self.x > 0, dout, dout * self.alpha[0])
+            elif self.mode == "channel":
+                if len(self.alpha.shape) > 1:
+                    for i in range(self.x.shape[1]):
+                        self.dx[:, i] = np.where(self.x[:, i] > 0, dout[:, i],
+                                                 dout[:, i] * self.alpha[0, i])
+                else:
+                    for i in range(self.x.shape[1]):
+                        self.dx[:, i] = np.where(self.x[:, i] > 0, dout[:, i],
+                                                 dout[:, i] * self.alpha[i])
+            elif self.mode == "element":
+                self.dx = np.where(self.x[:] > 0, dout[:],
+                                   dout[:] * self.alpha)
+
+            self.dalpha = np.where(self.x < 0, dout * self.x, 0)
+            self.dout = dout
+
+        def test_check_output(self):
+            if core.is_compiled_with_cuda():
+                self.skipTest(
+                    "OneDNN doesn't support bf16 with CUDA, skipping UT " +
+                    self.__class__.__name__)
+            elif not core.supports_bfloat16():
+                self.skipTest("Core doesn't support bf16, skipping UT " +
+                              self.__class__.__name__)
+            else:
+                self.check_output_with_place(core.CPUPlace())
+
+        def test_check_grad(self):
+            if core.is_compiled_with_cuda() or not core.supports_bfloat16():
+                self.skipTest(
+                    "Core is compiled with CUDA or doesn't support bf16, "
+                    "skipping UT " + self.__class__.__name__)
+            else:
+                self.calculate_grads()
+                self.check_grad_with_place(
+                    core.CPUPlace(), ["X", "Alpha"],
+                    "Out",
+                    user_defined_grads=[self.dx, self.dalpha],
+                    user_defined_grad_outputs=[
+                        convert_float_to_uint16(self.dout)
+                    ])
+
+    cls_name = "{0}_{1}".format(parent.__name__, "BF16")
+    TestPReluBF16OneDNNOp.__name__ = cls_name
+    globals()[cls_name] = TestPReluBF16OneDNNOp
+
+
+# TODO(jakpiase): enable the bf16 tests once the oneDNN bf16 prelu
+# implementation is ready
+# create_bf16_test_class(TestPReluModeChannelOneDNNOp)
+# create_bf16_test_class(TestPReluModeElementOneDNNOp)
+# create_bf16_test_class(TestPReluModeChannel3DOneDNNOp)
+# create_bf16_test_class(TestPReluModeChannelAlpha1DOneDNNOp)
+
+if __name__ == "__main__":
+    paddle.enable_static()
+    unittest.main()
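
The hand-computed gradients in calculate_grads above follow the standard PReLU derivative: dX = dOut where X > 0, else dOut * alpha; dAlpha collects dOut * X on the negative side (element-wise, before any mode-dependent reduction). A small numeric check of that rule, independent of the test harness:

import numpy as np

x = np.array([[-2.0, 3.0], [0.5, -1.0]])
alpha = np.full_like(x, 0.25)
dout = np.ones_like(x)

dx = np.where(x > 0, dout, dout * alpha)  # pass-through on the positive side
dalpha = np.where(x < 0, dout * x, 0)     # only negative inputs contribute

assert np.allclose(dx, [[0.25, 1.0], [1.0, 0.25]])
assert np.allclose(dalpha, [[-2.0, 0.0], [0.0, -1.0]])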
"element": + self.dx = np.where(self.x[:] > 0, dout[:], dout[:] * self.alpha) + + self.dalpha = np.where(self.x < 0, dout * self.x, 0) + self.dout = dout + + def test_check_output(self): + if core.is_compiled_with_cuda(): + self.skipTest( + "OneDNN doesn't support bf16 with CUDA, skipping UT" + + self.__class__.__name__) + elif not core.supports_bfloat16(): + self.skipTest("Core doesn't support bf16, skipping UT" + + self.__class__.__name__) + else: + self.check_output_with_place(core.CPUPlace()) + + def test_check_grad(self): + if core.is_compiled_with_cuda() or not core.supports_bfloat16(): + self.skipTest( + "Core is compiled with cuda or doesn't support bf16, kipping UT" + + self.__class__.__name__) + else: + self.calculate_grads() + self.check_grad_with_place( + core.CPUPlace(), ["X", "Alpha"], + "Out", + user_defined_grads=[self.dx, self.dalpha], + user_defined_grad_outputs=[ + convert_float_to_uint16(self.dout) + ]) + + cls_name = "{0}_{1}".format(parent.__name__, "BF16") + TestPReluBF16OneDNNOp.__name__ = cls_name + globals()[cls_name] = TestPReluBF16OneDNNOp + + +#TODO jakpiase +#enable bf16 tests back when oneDNN bf16 class will be ready +#create_bf16_test_class(TestPReluModeChannelOneDNNOp) +#create_bf16_test_class(TestPReluModeElementOneDNNOp) +#create_bf16_test_class(TestPReluModeChannel3DOneDNNOp) +#create_bf16_test_class(TestPReluModeChannelAlpha1DOneDNNOp) + +if __name__ == "__main__": + paddle.enable_static() + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index 4f78eceee4f157225d8ab34004ea49f8b5f83ad5..f6de13b6fd4ce5980b930c1df06cb6676603ee45 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -360,7 +360,9 @@ class OpTest(unittest.TestCase): def is_bfloat16_op(self): return self.dtype == np.uint16 or ( hasattr(self, 'mkldnn_data_type') and - getattr(self, 'mkldnn_data_type') is "bfloat16") + getattr(self, 'mkldnn_data_type') is "bfloat16") or ( + hasattr(self, 'attrs') and 'mkldnn_data_type' in self.attrs and + self.attrs['mkldnn_data_type'] == 'bfloat16') def infer_dtype_from_inputs_outputs(self, inputs, outputs): def is_np_data(input): @@ -1436,6 +1438,9 @@ class OpTest(unittest.TestCase): op_outputs = self.outputs if hasattr(self, "outputs") else dict() op_attrs = self.attrs if hasattr(self, "attrs") else dict() + if self.is_bfloat16_op(): + check_dygraph = False + self._check_grad_helper() if self.dtype == np.float64 and \ self.op_type not in op_threshold_white_list.NEED_FIX_FP64_CHECK_GRAD_THRESHOLD_OP_LIST: diff --git a/tools/static_mode_white_list.py b/tools/static_mode_white_list.py index 09029b6ad821ee60fa9db6c633c0cc3257bf1dbf..616d5ae280ad1acdaa3e2812981d27bbac8f2ab0 100644 --- a/tools/static_mode_white_list.py +++ b/tools/static_mode_white_list.py @@ -390,6 +390,7 @@ STATIC_MODE_TESTING_LIST = [ 'test_positive_negative_pair_op', 'test_precision_recall_op', 'test_prelu_op', + 'test_prelu_mkldnn_op', 'test_print_op', 'test_prior_box_op', 'test_profiler',