diff --git a/paddle/fluid/operators/expand_v2_op.cc b/paddle/fluid/operators/expand_v2_op.cc index 618c1560c5eac709e32f928f4142cf159ac5c39d..3c2b939e799577e999ebb0feb1634ac0f3e35c7b 100644 --- a/paddle/fluid/operators/expand_v2_op.cc +++ b/paddle/fluid/operators/expand_v2_op.cc @@ -89,9 +89,17 @@ class ExpandV2Op : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + auto input_data_type = + framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); + +#ifdef PADDLE_WITH_MKLDNN + if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { + return framework::OpKernelType(input_data_type, ctx.GetPlace(), + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN); + } +#endif + return framework::OpKernelType(input_data_type, ctx.GetPlace()); } framework::OpKernelType GetKernelTypeForVar( @@ -130,6 +138,14 @@ class ExpandV2OpMaker : public framework::OpProtoAndCheckerMaker { "the corresponding value given by Attr(expand_times)."); AddAttr>("shape", "The expanded shape for each dimension.") .SetDefault({}); + AddAttr("use_mkldnn", + "(bool, default false) Only used in mkldnn kernel") + .SetDefault(false); + AddAttr( + "mkldnn_data_type", + "(string, default \"float32\"). Data type of mkldnn kernel") + .SetDefault("float32") + .InEnum({"float32", "bfloat16"}); AddComment(R"DOC( Expand the input to the given shape. The rank of X should be in [1, 6] and size of 'shape' must be in [1, 6] also. @@ -200,9 +216,17 @@ class ExpandV2GradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.device_context()); + auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")); + +#ifdef PADDLE_WITH_MKLDNN + if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { + return framework::OpKernelType(input_data_type, ctx.GetPlace(), + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN); + } +#endif + return framework::OpKernelType(input_data_type, ctx.GetPlace()); } framework::OpKernelType GetKernelTypeForVar( diff --git a/paddle/fluid/operators/mkldnn/expand_v2_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/expand_v2_mkldnn_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..ffd64a841ecb39ed9de4f761a936b196cf4d4eaa --- /dev/null +++ b/paddle/fluid/operators/mkldnn/expand_v2_mkldnn_op.cc @@ -0,0 +1,161 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/platform/mkldnn_reuse.h" + +namespace { + +using paddle::framework::Tensor; +using paddle::framework::vectorize; +using paddle::framework::GradVarName; +using paddle::framework::ExecutionContext; +using paddle::platform::MKLDNNDeviceContext; + +template +class ExpandMKLDNNKernel : public paddle::framework::OpKernel { + public: + void Compute(const ExecutionContext& ctx) const override { + this->RunKernel(ctx); + } + + void RunKernel(const ExecutionContext& ctx) const { + const auto& dev_ctx = ctx.template device_context(); + const auto& onednn_engine = dev_ctx.GetEngine(); + + const auto* x = ctx.Input("X"); + auto* out = ctx.Output("Out"); + + auto x_vec_dims = vectorize(x->dims()); + auto out_vec_dims = vectorize(out->dims()); + + dnnl::memory::format_tag x_format_tag = x->format(); + if (x_vec_dims.size() != out_vec_dims.size()) { + x_format_tag = + GetExtendedFormatTag(x_vec_dims, out_vec_dims.size(), x_format_tag); + } + + out->set_format(x_format_tag); + + paddle::platform::BroadcastDataMKLDNNHandler handler( + dnnl::algorithm::binary_add, dev_ctx, onednn_engine, ctx.GetPlace(), + out, x, 0.0f, 1.0f, ctx.InputName("X"), x_vec_dims); + + auto src_memory_p = handler.AcquireSrcMemory(x); + auto dst_memory_p = handler.AcquireDstMemory(out); + auto binary_p = handler.AcquireForwardPrimitive(); + + const std::unordered_map args = { + {DNNL_ARG_SRC_0, *dst_memory_p}, + {DNNL_ARG_SRC_1, *src_memory_p}, + {DNNL_ARG_DST, *dst_memory_p}}; + + auto& astream = MKLDNNDeviceContext::tls().get_stream(); + binary_p->execute(astream, args); + astream.wait(); + + out->set_layout(paddle::framework::DataLayout::kMKLDNN); + out->set_format(paddle::platform::GetMKLDNNFormat(*dst_memory_p)); + } + + private: + dnnl::memory::format_tag GetExtendedFormatTag( + std::vector& dims, int new_size, + mkldnn::memory::format_tag format_tag) const { + mkldnn::memory::desc md(dims, paddle::platform::MKLDNNGetDataType(), + format_tag); + std::vector new_dims(new_size, 1); + std::copy(dims.begin(), dims.end(), + new_dims.begin() + new_size - dims.size()); + + dims = std::move(new_dims); + return paddle::platform::GetMKLDNNFormat(md.reshape(dims)); + } +}; + +template +class ExpandGradMKLDNNKernel : public paddle::framework::OpKernel { + public: + void Compute(const ExecutionContext& ctx) const override { + this->RunKernel(ctx); + } + + void RunKernel(const ExecutionContext& ctx) const { + const auto& dev_ctx = ctx.template device_context(); + const auto& onednn_engine = dev_ctx.GetEngine(); + + auto* dout = ctx.Input(GradVarName("Out")); + auto* dx = ctx.Output(GradVarName("X")); + + auto dx_vec_dims = vectorize(dx->dims()); + auto dout_vec_dims = vectorize(dout->dims()); + + if (dx_vec_dims.size() != dout_vec_dims.size()) { + dx_vec_dims.insert(dx_vec_dims.begin(), + dout_vec_dims.size() - dx_vec_dims.size(), 1); + } + + auto& astream = MKLDNNDeviceContext::tls().get_stream(); + if (dout_vec_dims == dx_vec_dims) { + mkldnn::memory::data_type dout_type = + paddle::framework::ToMKLDNNDataType(dout->type()); + std::string key = paddle::platform::CreateKey( + dev_ctx, dout_vec_dims, dout->format(), dout->format(), dout_type); + paddle::platform::ReorderMKLDNNHandler reorder_handler( + dout_vec_dims, dout->type(), dout_type, dev_ctx, onednn_engine, key); + + auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory( + dout->format(), paddle::platform::to_void_cast(dout->data())); + + auto reorder_dst_memory_p = + reorder_handler.AcquireDstMemory(dx, dout->format(), ctx.GetPlace()); + + auto reorder_p = reorder_handler.AcquireReorder(reorder_src_memory_p, + reorder_dst_memory_p); + + reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p); + astream.wait(); + + dx->set_layout(paddle::framework::DataLayout::kMKLDNN); + dx->set_format( + paddle::platform::GetMKLDNNFormat(reorder_dst_memory_p->get_desc())); + } else { + paddle::platform::ReductionMKLDNNHandler handler( + dnnl::algorithm::reduction_sum, 0.0f, 0.0f, dev_ctx, onednn_engine, + ctx.GetPlace(), dout, dx, ctx.InputName("X"), dx_vec_dims); + + auto src_memory_p = handler.AcquireSrcMemory(dout); + auto dst_memory_p = handler.AcquireDstMemory(dx); + + std::unordered_map reduction_args = { + {DNNL_ARG_SRC, *src_memory_p}, {DNNL_ARG_DST, *dst_memory_p}}; + + auto reduction_p = handler.AcquireForwardPrimitive(); + + reduction_p->execute(astream, reduction_args); + astream.wait(); + dx->set_layout(paddle::framework::DataLayout::kMKLDNN); + dx->set_format(paddle::platform::GetMKLDNNFormat( + dst_memory_p->get_desc().reshape(vectorize(dx->dims())))); + } + } +}; +} // anonymous namespace + +REGISTER_OP_KERNEL(expand_v2, MKLDNN, paddle::platform::CPUPlace, + ExpandMKLDNNKernel, + ExpandMKLDNNKernel); + +REGISTER_OP_KERNEL(expand_v2_grad, MKLDNN, paddle::platform::CPUPlace, + ExpandGradMKLDNNKernel, + ExpandGradMKLDNNKernel); diff --git a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h index 40cd3ba974f04c0196101f432cf8d51f2b00ce34..6a9aae046f386614464760652cbda004b3f24086 100644 --- a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h +++ b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h @@ -165,23 +165,21 @@ class ReduceGradMKLDNNKernel : public framework::OpKernel { x_format_tag = getPlainFormatTag(output_dx); } - output_dx->mutable_data(ctx.GetPlace()); output_dx->set_format(x_format_tag); - output_dx->set_layout(input_dy->layout()); platform::BroadcastDataMKLDNNHandler handler( binary_type, dev_ctx, onednn_engine, ctx.GetPlace(), output_dx, input_dy, scale_x, scale_y, ctx.InputName(framework::GradVarName("Out")), input_dims); - const auto src_dx_memory = handler.AcquireSrcMemory(output_dx); - const auto src_dy_memory = handler.AcquireSecondSrcMemory(input_dy); + const auto src_memory_p = handler.AcquireSrcMemory(input_dy); + const auto dst_memory_p = handler.AcquireDstMemory(output_dx); const auto binary_prim = handler.AcquireForwardPrimitive(); const std::unordered_map args = { - {DNNL_ARG_SRC_0, *src_dx_memory}, - {DNNL_ARG_SRC_1, *src_dy_memory}, - {DNNL_ARG_DST, *src_dx_memory}}; + {DNNL_ARG_SRC_0, *dst_memory_p}, + {DNNL_ARG_SRC_1, *src_memory_p}, + {DNNL_ARG_DST, *dst_memory_p}}; auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); binary_prim->execute(astream, args); diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index 58622fb2529b830ed222284296153dd4b55c1cf8..f63d45d7ff6ae611dc1633e94dac00c4f6db2339 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -695,8 +695,8 @@ class BroadcastDataMKLDNNHandler BroadcastDataMKLDNNHandler(const dnnl::algorithm algo, const MKLDNNDeviceContext& dev_ctx, const mkldnn::engine engine, - platform::Place cpu_place, const Tensor* x, - const Tensor* y, float scale_x, float scale_y, + platform::Place cpu_place, const Tensor* out, + const Tensor* x, float scale_x, float scale_y, const std::string& uniq_name, const std::vector& input_dims) : platform::MKLDNNHandlerT( @@ -711,19 +711,12 @@ class BroadcastDataMKLDNNHandler x->format(), MKLDNNMemoryFormat::undef, platform::errors::InvalidArgument("Wrong format set for X tensor.")); - PADDLE_ENFORCE_EQ( - y->layout(), DataLayout::kMKLDNN, - platform::errors::InvalidArgument("Wrong layout set for Y tensor.")); - PADDLE_ENFORCE_NE( - y->format(), MKLDNNMemoryFormat::undef, - platform::errors::InvalidArgument("Wrong format set for Y tensor.")); - - const auto src0_tz = framework::vectorize(x->dims()); + const auto src0_tz = framework::vectorize(out->dims()); const auto src0_md = dnnl::memory::desc( - src0_tz, platform::MKLDNNGetDataType(), x->format()); + src0_tz, platform::MKLDNNGetDataType(), out->format()); const auto src1_md = dnnl::memory::desc( - input_dims, platform::MKLDNNGetDataType(), x->format()); + input_dims, platform::MKLDNNGetDataType(), out->format()); dnnl::primitive_attr attributes; attributes.set_scales(DNNL_ARG_SRC_0, 0, {scale_x}); @@ -734,18 +727,14 @@ class BroadcastDataMKLDNNHandler } } - std::shared_ptr AcquireSrcMemory(framework::Tensor* input) { - T* input_data = input->data(); - memset(input_data, 0, this->fwd_pd_->src_desc().get_size()); - return this->AcquireMemoryFromPrimitive( - this->fwd_pd_->src_desc(), to_void_cast(input_data), "@src0_mem_p"); - } - - std::shared_ptr AcquireSecondSrcMemory( - const framework::Tensor* input) { - const T* input_data = input->data(); - return this->AcquireMemoryFromPrimitive( - this->fwd_pd_->src1_desc(), to_void_cast(input_data), "@src1_mem_p"); + template + std::shared_ptr AcquireDstMemory(framework::Tensor* output) { + T_out* ptr = output->mutable_data( + this->place_, this->fwd_pd_->dst_desc().get_size()); + ; + memset(ptr, 0, this->fwd_pd_->dst_desc().get_size()); + return this->AcquireMemoryFromPrimitive(this->fwd_pd_->dst_desc(), ptr, + "@dst_mem_p"); } }; diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_expand_v2_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_expand_v2_mkldnn_op.py new file mode 100644 index 0000000000000000000000000000000000000000..51d7fe971674dec38f01bc7137a5c2ea0ba497e6 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_expand_v2_mkldnn_op.py @@ -0,0 +1,107 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import unittest +import numpy as np +import paddle +import paddle.fluid as fluid +from paddle.fluid import compiler, Program, program_guard, core +from paddle.fluid.tests.unittests.op_test import OpTest, OpTestTool, convert_float_to_uint16 + + +@OpTestTool.skip_if(core.is_compiled_with_cuda(), + "CUDA required dygraph so oneDNN UT must be skipped") +class TestExpandV2OneDNNOp(OpTest): + def setUp(self): + self.op_type = "expand_v2" + self.init_data() + self.x = np.random.random(self.ori_shape).astype("float32") + self.set_inputs() + self.attrs = {'shape': self.shape, 'use_mkldnn': True} + output = np.tile(self.x, self.expand_times) + self.outputs = {'Out': output} + + def set_inputs(self): + self.inputs = {'X': self.x} + + def init_data(self): + self.ori_shape = [1, 140] + self.shape = [12, 140] + self.expand_times = [12, 1] + + def test_check_output(self): + self.check_output_with_place(core.CPUPlace()) + + def test_check_grad(self): + self.check_grad_with_place(core.CPUPlace(), ["X"], "Out") + + +class TestExpandV2ExpandDimOneDNNOp(TestExpandV2OneDNNOp): + def init_data(self): + self.ori_shape = [120] + self.shape = [2, 120] + self.expand_times = [2, 1] + + +class TestExpandV2CopyScenarioOneDNNOp(TestExpandV2OneDNNOp): + def init_data(self): + self.ori_shape = (2, 10, 5) + self.shape = (2, 10, 5) + self.expand_times = (1, 1, 1) + + +class TestExpandV2CopyScenarioShapeNotGivenOneDNNOp(TestExpandV2OneDNNOp): + def init_data(self): + self.ori_shape = (2, 4, 5, 7) + self.shape = (-1, -1, -1, -1) + self.expand_times = (1, 1, 1, 1) + + +# BF16 TESTS +def create_expand_v2_bf16_test_class(parent): + @OpTestTool.skip_if_not_cpu_bf16() + class TestExpandV2BF16OneDNNOp(parent): + def set_inputs(self): + self.inputs = {"X": convert_float_to_uint16(self.x)} + + def calculate_grads(self): + self.dout = self.outputs['Out'] + self.dx = self.dout.copy() + + for i in range(len(self.shape)): + if self.expand_times[i] != 1: + self.dx = np.sum(self.dx, axis=i, keepdims=True) + + def test_check_grad(self): + self.calculate_grads() + self.check_grad_with_place( + core.CPUPlace(), ["X"], + "Out", + user_defined_grads=[convert_float_to_uint16(self.dx)], + user_defined_grad_outputs=[self.dout]) + + cls_name = "{0}_{1}".format(parent.__name__, "Expand_v2_BF16") + TestExpandV2BF16OneDNNOp.__name__ = cls_name + globals()[cls_name] = TestExpandV2BF16OneDNNOp + + +create_expand_v2_bf16_test_class(TestExpandV2OneDNNOp) +create_expand_v2_bf16_test_class(TestExpandV2ExpandDimOneDNNOp) +create_expand_v2_bf16_test_class(TestExpandV2CopyScenarioOneDNNOp) +create_expand_v2_bf16_test_class(TestExpandV2CopyScenarioShapeNotGivenOneDNNOp) + +if __name__ == '__main__': + paddle.enable_static() + unittest.main()