diff --git a/paddle/fluid/operators/clip_op.cc b/paddle/fluid/operators/clip_op.cc
index 7176a0466bb831cdbbaf66dfbb2d2625bdbf66cf..362f955ffc60c2539fb70e624349db95c0273c3f 100644
--- a/paddle/fluid/operators/clip_op.cc
+++ b/paddle/fluid/operators/clip_op.cc
@@ -31,6 +31,21 @@ class ClipOp : public framework::OperatorWithKernel {
     ctx->SetOutputDim("Out", x_dims);
     ctx->ShareLoD("X", /*->*/ "Out");
   }
+
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    auto input_data_type =
+        framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
+
+#ifdef PADDLE_WITH_MKLDNN
+    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
+      return framework::OpKernelType(input_data_type, ctx.GetPlace(),
+                                     framework::DataLayout::kMKLDNN,
+                                     framework::LibraryType::kMKLDNN);
+    }
+#endif
+    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+  }
 };
 
 template <typename AttrType>
@@ -54,6 +69,14 @@ class ClipOpMaker : public framework::OpProtoAndCheckerMaker {
              "input(x)");
     AddAttr<AttrType>("min", "float number, the minimum value to clip by.");
     AddAttr<AttrType>("max", "float number, the maximum value to clip by.");
+    AddAttr<bool>("use_mkldnn",
+                  "(bool, default false) Only used in mkldnn kernel")
+        .SetDefault(false);
+    AddAttr<std::string>(
+        "mkldnn_data_type",
+        "(string, default \"float32\"). Data type of mkldnn kernel")
+        .SetDefault("float32")
+        .InEnum({"float32", "bfloat16"});
     AddComment(R"DOC(
 Clip Operator.
 
@@ -81,6 +104,21 @@ class ClipOpGrad : public framework::OperatorWithKernel {
       ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
     }
   }
+
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    auto input_data_type = OperatorWithKernel::IndicateVarDataType(
+        ctx, framework::GradVarName("Out"));
+
+#ifdef PADDLE_WITH_MKLDNN
+    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
+      return framework::OpKernelType(input_data_type, ctx.GetPlace(),
+                                     framework::DataLayout::kMKLDNN,
+                                     framework::LibraryType::kMKLDNN);
+    }
+#endif
+    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+  }
 };
 
 template <typename T>
diff --git a/paddle/fluid/operators/mkldnn/clip_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/clip_mkldnn_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..94c3700da8ca654ab27190d0904e8411b3ceac43
--- /dev/null
+++ b/paddle/fluid/operators/mkldnn/clip_mkldnn_op.cc
@@ -0,0 +1,99 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/platform/mkldnn_reuse.h"
+
+namespace {
+
+using paddle::framework::Tensor;
+
+template <typename T>
+class ClipMKLDNNKernel : public paddle::framework::OpKernel<T> {
+ public:
+  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
+    this->RunKernel(ctx);
+  }
+
+  void RunKernel(const paddle::framework::ExecutionContext& ctx) const {
+    const auto& dev_ctx =
+        ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
+    const auto& mkldnn_engine = dev_ctx.GetEngine();
+
+    auto* x = ctx.Input<Tensor>("X");
+    auto* out = ctx.Output<Tensor>("Out");
+
+    paddle::platform::ActivationMKLDNNHandler<T> handler(
+        mkldnn::algorithm::eltwise_clip_v2, ctx, mkldnn_engine, ctx.GetPlace(),
+        x);
+
+    auto src_memory_p = handler.AcquireSrcMemory(x);
+    auto dst_memory_p = handler.AcquireDstMemory(out);
+    auto activation_p = handler.AcquireForwardPrimitive();
+
+    auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
+    activation_p->execute(astream, {{MKLDNN_ARG_FROM, *src_memory_p},
+                                    {MKLDNN_ARG_TO, *dst_memory_p}});
+    astream.wait();
+
+    out->set_layout(paddle::framework::DataLayout::kMKLDNN);
+    out->set_format(paddle::platform::GetMKLDNNFormat(*dst_memory_p));
+  }
+};
+
+template <typename T>
+class ClipGradMKLDNNKernel : public paddle::framework::OpKernel<T> {
+ public:
+  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
+    this->RunKernel(ctx);
+  }
+
+  void RunKernel(const paddle::framework::ExecutionContext& ctx) const {
+    const auto& dev_ctx =
+        ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
+    const auto& mkldnn_engine = dev_ctx.GetEngine();
+
+    auto* x = ctx.Input<Tensor>("X");
+    auto* dx = ctx.Output<Tensor>(paddle::framework::GradVarName("X"));
+    auto* dout = ctx.Input<Tensor>(paddle::framework::GradVarName("Out"));
+
+    paddle::platform::ActivationMKLDNNHandler<T> handler(
+        mkldnn::algorithm::eltwise_clip_v2, ctx, mkldnn_engine, ctx.GetPlace(),
+        x, dout);
+
+    auto src_memory_p = handler.AcquireBackwardSrcMemory(x);
+    auto diff_dst_memory_p = handler.AcquireDiffDstMemory(dout);
+    auto diff_src_memory_p = handler.AcquireDiffSrcMemory(dx);
+    auto activation_backward_p = handler.AcquireBackwardPrimitive();
+
+    auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
+    activation_backward_p->execute(astream,
+                                   {{MKLDNN_ARG_SRC, *src_memory_p},
+                                    {MKLDNN_ARG_DIFF_DST, *diff_dst_memory_p},
+                                    {MKLDNN_ARG_DIFF_SRC, *diff_src_memory_p}});
+    astream.wait();
+
+    dx->set_layout(paddle::framework::DataLayout::kMKLDNN);
+    dx->set_format(paddle::platform::GetMKLDNNFormat(*diff_dst_memory_p));
+  }
+};
+
+}  // anonymous namespace
+
+REGISTER_OP_KERNEL(clip, MKLDNN, paddle::platform::CPUPlace,
+                   ClipMKLDNNKernel<float>,
+                   ClipMKLDNNKernel<paddle::platform::bfloat16>);
+
+REGISTER_OP_KERNEL(clip_grad, MKLDNN, paddle::platform::CPUPlace,
+                   ClipGradMKLDNNKernel<float>,
+                   ClipGradMKLDNNKernel<paddle::platform::bfloat16>);
diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h
index 370d9b3925226249130559ccca90c26af4af44d4..49160f9463240d848baddebe46a7ea02e31946e7 100644
--- a/paddle/fluid/platform/mkldnn_reuse.h
+++ b/paddle/fluid/platform/mkldnn_reuse.h
@@ -977,8 +977,8 @@ class ActivationMKLDNNHandler
                                             cpu_place) {
     float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 0;
     float beta = ctx.HasAttr("beta") ? ctx.Attr<float>("beta") : 0;
-    // eltwise_linear means we are in scale op
-    if (algorithm == mkldnn::algorithm::eltwise_linear) {
+
+    if (ctx.Type() == "scale") {
       bool bias_after_scale = ctx.Attr<bool>("bias_after_scale");
       auto* scale_tensor = ctx.Input<Tensor>("ScaleTensor");
       alpha = (scale_tensor == nullptr) ? ctx.Attr<float>("scale")
@@ -988,7 +988,14 @@ class ActivationMKLDNNHandler
       // out = scale*X + bias
       // else
       //   out = scale*(X + bias) = scale*X + scale*bias
-      if (!bias_after_scale) beta *= alpha;
+      if (!bias_after_scale) {
+        beta *= alpha;
+      }
+    } else if (ctx.Type() == "clip") {
+      alpha = ctx.HasInput("Min") ? ctx.Input<Tensor>("Min")->data<float>()[0]
+                                  : ctx.Attr<float>("min");
+      beta = ctx.HasInput("Max") ? ctx.Input<Tensor>("Max")->data<float>()[0]
+                                  : ctx.Attr<float>("max");
     } else {
       // paddle uses beta but mkldnn uses alpha for swish
       if (algorithm == mkldnn::algorithm::eltwise_swish) {
@@ -1030,6 +1037,13 @@ class ActivationMKLDNNHandler
       alpha = ctx.Attr<float>("threshold");
     }
 
+    if (ctx.Type() == "clip_grad") {
+      alpha = ctx.HasInput("Min") ? ctx.Input<Tensor>("Min")->data<float>()[0]
+                                  : ctx.Attr<float>("min");
+      beta = ctx.HasInput("Max") ? ctx.Input<Tensor>("Max")->data<float>()[0]
+                                  : ctx.Attr<float>("max");
+    }
+
     auto diff_dst_tz = framework::vectorize(out_grad->dims());
 
     auto src_fmt =
diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_clip_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_clip_mkldnn_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..97a913753184532cd8009de970ca615c3bf51ee6
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/mkldnn/test_clip_mkldnn_op.py
@@ -0,0 +1,120 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import numpy as np
+from paddle.fluid.tests.unittests.op_test import OpTest, OpTestTool, convert_float_to_uint16
+import paddle
+import paddle.fluid as fluid
+import paddle.fluid.core as core
+
+
+@OpTestTool.skip_if_not_cpu_bf16()
+class TestClipOneDNNOp(OpTest):
+    def setUp(self):
+        self.op_type = "clip"
+        self.set_inputs()
+        self.set_attrs()
+        self.set_additional_inputs()
+        self.adjust_op_settings()
+
+        self.min = self.attrs[
+            'min'] if not 'Min' in self.inputs else self.inputs['Min']
+        self.max = self.attrs[
+            'max'] if not 'Max' in self.inputs else self.inputs['Max']
+
+        self.outputs = {'Out': np.clip(self.x_fp32, self.min, self.max)}
+
+    def set_inputs(self):
+        self.inputs = {'X': np.random.random((10, 10)).astype(np.float32) * 25}
+        self.x_fp32 = self.inputs['X']
+
+    def set_additional_inputs(self):
+        pass
+
+    def adjust_op_settings(self):
+        pass
+
+    def set_attrs(self):
+        self.attrs = {'min': 7.2, 'max': 9.6, 'use_mkldnn': True}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out')
+
+
+class TestClipMinAsInputOneDNNOp(TestClipOneDNNOp):
+    def set_additional_inputs(self):
+        self.inputs['Min'] = np.array([6.8]).astype('float32')
+
+
+class TestClipMaxAsInputOneDNNOp(TestClipOneDNNOp):
+    def set_additional_inputs(self):
+        self.inputs['Max'] = np.array([9.1]).astype('float32')
+
+
+class TestClipMaxAndMinAsInputsOneDNNOp(TestClipOneDNNOp):
+    def set_additional_inputs(self):
+        self.inputs['Max'] = np.array([8.5]).astype('float32')
+        self.inputs['Min'] = np.array([7.1]).astype('float32')
+
+
+# BF16 TESTS
+def create_bf16_test_class(parent):
+    @OpTestTool.skip_if_not_cpu_bf16()
+    class TestClipBF16OneDNNOp(parent):
+        def set_inputs(self):
+            self.x_fp32 = np.random.random((10, 10)).astype(np.float32) * 25
+            self.inputs = {'X': convert_float_to_uint16(self.x_fp32)}
+
+        def adjust_op_settings(self):
+            self.dtype = np.uint16
+            self.attrs['mkldnn_data_type'] = "bfloat16"
+
+        def calculate_grads(self):
+            self.dout = self.outputs['Out']
+            self.dx = np.zeros(self.x_fp32.shape).astype("float32")
+
+            for i in range(self.dx.shape[0]):
+                for j in range(self.dx.shape[1]):
+                    if self.x_fp32[j][i] > self.min and self.x_fp32[j][
+                            i] < self.max:
+                        self.dx[j][i] = self.dout[j][i]
+
+        def test_check_output(self):
+            self.check_output_with_place(core.CPUPlace())
+
+        def test_check_grad(self):
+            self.calculate_grads()
+            self.check_grad_with_place(
+                core.CPUPlace(), ["X"],
+                "Out",
+                user_defined_grads=[self.dx],
+                user_defined_grad_outputs=[convert_float_to_uint16(self.dout)])
+
+    cls_name = "{0}_{1}".format(parent.__name__, "BF16")
+    TestClipBF16OneDNNOp.__name__ = cls_name
+    globals()[cls_name] = TestClipBF16OneDNNOp
+
+
+create_bf16_test_class(TestClipOneDNNOp)
+create_bf16_test_class(TestClipMinAsInputOneDNNOp)
+create_bf16_test_class(TestClipMaxAsInputOneDNNOp)
+create_bf16_test_class(TestClipMaxAndMinAsInputsOneDNNOp)
+
+
+if __name__ == "__main__":
+    paddle.enable_static()
+    unittest.main()
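
A minimal smoke-test sketch (not part of the diff) for exercising the new oneDNN clip kernels end to end. It assumes a Paddle build with PADDLE_WITH_MKLDNN enabled and relies on the global FLAGS_use_mkldnn flag, which marks eligible ops with use_mkldnn=True so that the GetExpectedKernelType overrides above select the kMKLDNN kernel; flag handling and the fluid API may differ between Paddle versions, and the file name is hypothetical.

    # Run as: FLAGS_use_mkldnn=1 python clip_smoke_test.py  (hypothetical file name)
    import numpy as np
    import paddle
    import paddle.fluid as fluid

    paddle.enable_static()

    # Static-graph clip op; with FLAGS_use_mkldnn=1 it should dispatch to the
    # new oneDNN eltwise_clip_v2 kernel on CPU.
    x = fluid.data(name='x', shape=[10, 10], dtype='float32')
    out = fluid.layers.clip(x, min=7.2, max=9.6)

    exe = fluid.Executor(fluid.CPUPlace())
    x_np = np.random.random((10, 10)).astype('float32') * 25
    res, = exe.run(feed={'x': x_np}, fetch_list=[out])

    # Every output value should land in [7.2, 9.6]; compare against NumPy.
    assert np.allclose(res, np.clip(x_np, 7.2, 9.6))

Additionally setting DNNL_VERBOSE=1 should print an eltwise_clip_v2 primitive execution line, confirming that the oneDNN kernel, rather than the plain CPU kernel, was dispatched.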