diff --git a/paddle/fluid/operators/log_softmax_op.cc b/paddle/fluid/operators/log_softmax_op.cc index d6e2b3ecff8c83e47a9016cc3d233d1aa03fb52b..0e69b397e04c7eda7f515350caf870be5d7b57a5 100644 --- a/paddle/fluid/operators/log_softmax_op.cc +++ b/paddle/fluid/operators/log_softmax_op.cc @@ -31,9 +31,17 @@ class LogSoftmaxOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + auto input_data_type = + framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); + +#ifdef PADDLE_WITH_MKLDNN + if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { + return framework::OpKernelType(input_data_type, ctx.GetPlace(), + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN); + } +#endif + return framework::OpKernelType(input_data_type, ctx.GetPlace()); } }; @@ -48,6 +56,10 @@ class LogSoftmaxOpMaker : public framework::OpProtoAndCheckerMaker { "The dimension index of Input(x) to perform log_softmax," "default -1 for last dimension") .SetDefault(-1); + AddAttr("use_mkldnn", + "(bool, default false) Only used in mkldnn kernel") + .SetDefault(false) + .AsExtra(); AddComment(R"DOC( LogSoftmax Operator. diff --git a/paddle/fluid/operators/mkldnn/log_softmax_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/log_softmax_mkldnn_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..450462e7d4bb995a4bd60a3a93fe6f2c6e91042f --- /dev/null +++ b/paddle/fluid/operators/mkldnn/log_softmax_mkldnn_op.cc @@ -0,0 +1,78 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/softmax_op.h" +#include "paddle/fluid/platform/mkldnn_reuse.h" + +namespace paddle { +namespace operators { + +using framework::Tensor; + +template +class LogSoftmaxMKLDNNHandler + : public platform::MKLDNNHandlerNoCachingT { + public: + LogSoftmaxMKLDNNHandler(const dnnl::engine mkldnn_engine, + platform::Place cpu_place, const Tensor* x, + const int axis) + : platform::MKLDNNHandlerNoCachingT( + mkldnn_engine, cpu_place) { + const auto logsoftmax_tz = phi::vectorize(x->dims()); + const auto md = dnnl::memory::desc( + logsoftmax_tz, platform::MKLDNNGetDataType(), x->format()); + + this->AcquireForwardPrimitiveDescriptor(dnnl::prop_kind::forward_inference, + md, axis); + } +}; + +template +class LogSoftmaxMKLDNNKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto& dev_ctx = + ctx.template device_context(); + const auto& mkldnn_engine = dev_ctx.GetEngine(); + + const Tensor* x = ctx.Input("X"); + Tensor* out = ctx.Output("Out"); + + int axis = ctx.Attr("axis"); + axis = axis >= 0 ? axis : x->dims().size() + axis; + + LogSoftmaxMKLDNNHandler handler(mkldnn_engine, ctx.GetPlace(), x, axis); + + auto src_memory_p = handler.AcquireSrcMemory(x); + auto dst_memory_p = handler.AcquireDstMemory(out); + + auto logsoftmax_p = handler.AcquireForwardPrimitive(); + + auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); + logsoftmax_p->execute(astream, {{DNNL_ARG_SRC, *src_memory_p}, + {DNNL_ARG_DST, *dst_memory_p}}); + astream.wait(); + + out->set_layout(framework::DataLayout::kMKLDNN); + out->set_format(x->format()); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_KERNEL(log_softmax, MKLDNN, ::paddle::platform::CPUPlace, + ops::LogSoftmaxMKLDNNKernel, + ops::LogSoftmaxMKLDNNKernel); diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_log_softmax_op.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_log_softmax_op.py new file mode 100644 index 0000000000000000000000000000000000000000..3dc0623a112f51fb74654f575db4194f06c79e5b --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_log_softmax_op.py @@ -0,0 +1,63 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from auto_scan_test import MkldnnAutoScanTest +from program_config import TensorConfig, ProgramConfig, OpConfig +import numpy as np +from functools import partial +import unittest +from hypothesis import given +import hypothesis.strategies as st + + +class TestMKLDNNLogSoftmaxOp(MkldnnAutoScanTest): + def is_program_valid(self, program_config: ProgramConfig) -> bool: + return True + + def sample_program_configs(self, *args, **kwargs): + def generate_input(*args, **kwargs): + return np.random.random(kwargs['in_shape']).astype(np.float32) + + logsoftmax_op = OpConfig( + type="log_softmax", + inputs={"X": ["input_data"]}, + outputs={"Out": ["output_data"]}, + attrs={"axis": kwargs['axis']}) + + program_config = ProgramConfig( + ops=[logsoftmax_op], + weights={}, + inputs={ + "input_data": TensorConfig(data_gen=partial(generate_input, + *args, **kwargs)), + }, + outputs=["output_data"]) + + yield program_config + + def sample_predictor_configs(self, program_config): + config = self.create_inference_config(use_mkldnn=True) + yield config, (1e-5, 1e-5) + + @given( + axis=st.sampled_from([-2, -1, 0, 1]), + in_shape=st.lists( + st.integers( + min_value=2, max_value=5), min_size=3, max_size=5)) + def test(self, *args, **kwargs): + self.run_test(quant=False, *args, **kwargs) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_log_softmax_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_log_softmax_mkldnn_op.py new file mode 100644 index 0000000000000000000000000000000000000000..7477eaf3339b25a9c40fcf0870b55544e7cf5a2e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_log_softmax_mkldnn_op.py @@ -0,0 +1,93 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +import paddle +from paddle.fluid import core +from paddle.fluid.tests.unittests.test_log_softmax import ref_log_softmax +from paddle.fluid.tests.unittests.op_test import OpTest, OpTestTool, convert_float_to_uint16 + + +@OpTestTool.skip_if_not_cpu_bf16() +class TestLogSoftmaxOneDNNOp(OpTest): + def setUp(self): + self.op_type = 'log_softmax' + self.set_dtype() + self.set_shape() + self.set_axis() + + x = np.random.uniform(0.1, 1.0, self.shape).astype(np.float32) + out = np.apply_along_axis(ref_log_softmax, self.axis, x) + + if self.dtype == np.uint16: + x = convert_float_to_uint16(x) + + self.inputs = {'X': x} + self.outputs = {'Out': out} + self.attrs = {'axis': self.axis, 'use_mkldnn': True} + + def set_dtype(self): + self.dtype = np.float32 + + def set_shape(self): + self.shape = [2, 3, 4, 5] + + def set_axis(self): + self.axis = -1 + + def test_check_output(self): + self.check_output_with_place(core.CPUPlace()) + + +class TestLogSoftmax1DOneDNNOp(TestLogSoftmaxOneDNNOp): + def set_shape(self): + self.shape = [100] + + +class TestLogSoftmax3DOneDNNOp(TestLogSoftmaxOneDNNOp): + def set_shape(self): + self.shape = [12, 10, 3] + + +class TestLogSoftmax5DOneDNNOp(TestLogSoftmaxOneDNNOp): + def set_shape(self): + self.shape = [2, 3, 4, 5, 6] + + +class TestLogSoftmaxPositiveAxisOneDNNOp(TestLogSoftmaxOneDNNOp): + def set_axis(self): + self.axis = 2 + + +# BF16 TESTS +class TestLogSoftmax1DBF16OneDNNOp(TestLogSoftmax1DOneDNNOp): + def set_dtype(self): + self.dtype = np.uint16 + + +class TestLogSoftmaxPositiveAxisBF16OneDNNOp( + TestLogSoftmaxPositiveAxisOneDNNOp): + def set_dtype(self): + self.dtype = np.uint16 + + +class TestLogSoftmax5DBF16OneDNNOp(TestLogSoftmax5DOneDNNOp): + def set_shape(self): + self.shape = [2, 3, 4, 5, 6] + + +if __name__ == "__main__": + paddle.enable_static() + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_log_softmax.py b/python/paddle/fluid/tests/unittests/test_log_softmax.py index 0dd6c9f893e2a78dff9f77617853b3d8e35a6648..d1437ca9c96f1ba5fd2b9e1e420f91414d4f923a 100644 --- a/python/paddle/fluid/tests/unittests/test_log_softmax.py +++ b/python/paddle/fluid/tests/unittests/test_log_softmax.py @@ -14,7 +14,7 @@ import unittest import numpy as np -from op_test import OpTest +from paddle.fluid.tests.unittests.op_test import OpTest import paddle import paddle.nn.functional as F