diff --git a/paddle/fluid/operators/concat_mkldnn_op.cc b/paddle/fluid/operators/concat_mkldnn_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..c6652b788510d85c3bb5a1c0a9045135901dfd21 --- /dev/null +++ b/paddle/fluid/operators/concat_mkldnn_op.cc @@ -0,0 +1,217 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/concat_op.h" +#include "paddle/fluid/platform/mkldnn_helper.h" + +namespace paddle { +namespace operators { + +using framework::DataLayout; +using framework::Tensor; +using mkldnn::memory; +using mkldnn::primitive; +using mkldnn::concat; +using mkldnn::stream; +using platform::to_void_cast; + +// Generate keys for storing/retriving primitives for this operator +// TODO(jczaja): Make hashing function more optimial +static std::string gethash(const memory::dims& input_dims, + const std::string& pooling_type, + const std::vector& ksize, + const std::vector& strides, + const std::vector& paddings, + const std::string& suffix) { + auto dims2str = [](const memory::dims& operand_dims) { + std::string dstr = ""; + for (size_t i = 0; i < operand_dims.size(); ++i) { + dstr += std::to_string(operand_dims[i]) + "-"; + } + return dstr; + }; + return dims2str(input_dims) + dims2str(ksize) + dims2str(strides) + + dims2str(paddings) + pooling_type + suffix; +} + +static void EnforceLayouts(const std::vector inputs) { + for (auto* input : inputs) { + const bool is_layout_correct = input->layout() == DataLayout::kMKLDNN; + const bool is_format_defined = input->format() != + memory::format::format_undef; + PADDLE_ENFORCE(is_layout_correct && is_format_defined, + "Wrong layout/format set for Input tensor"); + } +} + +static memory::primitive_desc CreateMemPrimDesc( + const framework::Tensor& input, const mkldnn::engine& engine) { + constexpr auto data_type = mkldnn::memory::f32; + const auto dims = paddle::framework::vectorize2int(input.dims()); + const auto format = input.format(); + auto description = memory::desc(dims, data_type, format); + auto mem_prim_desc = memory::primitive_desc(description, engine); + return mem_prim_desc; +} + +static platform::CPUPlace GetCpuPlace( + const paddle::framework::ExecutionContext& ctx) { + auto place = ctx.GetPlace(); + PADDLE_ENFORCE(paddle::platform::is_cpu_place(place), + "It must use CPUPlace."); + return boost::get(place); +} + +template +class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel { + public: + void Compute(const paddle::framework::ExecutionContext& ctx) const override { + auto place = GetCpuPlace(ctx); + auto& dev_ctx = + ctx.template device_context(); + const auto& mkldnn_engine = dev_ctx.GetEngine(); + + auto multi_input = ctx.MultiInput("X"); + framework::Tensor* output = ctx.Output("Out"); + int64_t concat_axis = static_cast(ctx.Attr("axis")); + + EnforceLayouts(multi_input); + + std::vector srcs_pd; + std::vector srcs; + for (size_t i = 0; i < multi_input.size(); i++) { + auto mem_prim_desc = CreateMemPrimDesc(*multi_input[i], mkldnn_engine); + srcs_pd.push_back(mem_prim_desc); + srcs.push_back(memory(mem_prim_desc, to_void_cast(multi_input[i]->data()))); + } + auto dst_dims = paddle::framework::vectorize2int(output->dims()); + auto dst_desc = memory::desc(dst_dims, mkldnn::memory::f32, memory::format::any); + auto concat_pd = concat::primitive_desc(dst_desc, static_cast(concat_axis), srcs_pd); + auto dst_mem = memory(concat_pd.dst_primitive_desc(), output->mutable_data(place)); + + std::vector inputs; //= {srcs}; + inputs.reserve(srcs.size()); + for (size_t i = 0; i < srcs.size(); i++) { + inputs.push_back(srcs[i]); + } + auto concat_prim = concat(concat_pd, inputs, dst_mem); + + std::vector pipeline; + pipeline.push_back(concat_prim); + stream(stream::kind::eager).submit(pipeline).wait(); // TODO(mgallus): When this is not workin' split into decl and def + + /* + const T* input_data = input->data(); + T* output_data = output->mutable_data(ctx.GetPlace()); + + std::vector src_tz = paddle::framework::vectorize2int(input->dims()); + std::vector dst_tz = paddle::framework::vectorize2int(output->dims()); + + auto input_format = input->format(); + memory::format output_format{memory::format::format_undef}; + + const std::string key = gethash(src_tz, pooling_type, ksize, strides, + paddings, ctx.op().Output("Out")); + const std::string key_pool_p = key + "@pool_p"; + const std::string key_pool_pd = key + "@pool_pd"; + const std::string key_pool_src_mem_p = key + "@pool_src_mem_p"; + const std::string key_pool_dst_mem_p = key + "@pool_dst_mem_p"; + const std::string key_pool_workspace_memory = + key + "@pool_workspace_memory"; + + auto pool_p = + std::static_pointer_cast(dev_ctx.GetBlob(key_pool_p)); + if (pool_p == nullptr) { + const std::vector& padding_left_top(paddings); + std::vector padding_right_bottom(paddings); + bool ceil_mode = ctx.Attr("ceil_mode"); + if (ceil_mode) { + CorrectOutputSize(src_tz, dst_tz, ksize, paddings, strides, + padding_right_bottom); + } + auto src_md = platform::MKLDNNMemDesc( + src_tz, platform::MKLDNNGetDataType(), input_format); + + auto dst_md = platform::MKLDNNMemDesc(dst_tz, mkldnn::memory::f32, + mkldnn::memory::format::any); + + std::shared_ptr pool_pd = + CreatePrimitiveDesc(src_md, dst_md, strides, padding_left_top, + padding_right_bottom, ksize, pooling_type, + mkldnn_engine, ceil_mode, is_test); + + // save pool_pd into global device context to be referred in backward path + if (!is_test) dev_ctx.SetBlob(key_pool_pd, pool_pd); + + auto src_memory = std::make_shared(pool_pd->src_primitive_desc(), + to_void_cast(input_data)); + auto dst_memory = + std::make_shared(pool_pd->dst_primitive_desc(), output_data); + + dev_ctx.SetBlob(key_pool_src_mem_p, src_memory); + dev_ctx.SetBlob(key_pool_dst_mem_p, dst_memory); + + if (is_test) { + pool_p = std::make_shared(*pool_pd, *src_memory, + *dst_memory); + } else { + std::shared_ptr workspace_memory = + CreateWorkspaceMemory(pool_pd, pooling_type, mkldnn_engine); + + // save pool_workspace_memory to be referred in backward path + dev_ctx.SetBlob(key_pool_workspace_memory, workspace_memory); + + pool_p = std::make_shared( + *pool_pd, *src_memory, *dst_memory, *workspace_memory); + } + + dev_ctx.SetBlob(key_pool_p, pool_p); + + output_format = + (memory::format)dst_memory->get_primitive_desc().desc().data.format; + } else { + // Primitives already exist + auto pool_src_memory_p = + std::static_pointer_cast(dev_ctx.GetBlob(key_pool_src_mem_p)); + PADDLE_ENFORCE(pool_src_memory_p != nullptr, + "Fail to find pooling src mem_p in device context"); + auto pool_dst_memory_p = + std::static_pointer_cast(dev_ctx.GetBlob(key_pool_dst_mem_p)); + PADDLE_ENFORCE(pool_dst_memory_p != nullptr, + "Fail to find pooling dst mem_p in device context"); + pool_src_memory_p->set_data_handle(to_void_cast(input_data)); + pool_dst_memory_p->set_data_handle(output_data); + + output_format = (memory::format)pool_dst_memory_p->get_primitive_desc() + .desc() + .data.format; + } + + // push primitive to stream and wait until it's executed + std::vector pipeline{*(pool_p.get())}; + stream(stream::kind::eager).submit(pipeline).wait(); + */ + output->mutable_data(place); + output->set_layout(DataLayout::kMKLDNN); + output->set_format((memory::format)dst_mem.get_primitive_desc().desc() + .data.format); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_KERNEL(concat, MKLDNN, ::paddle::platform::CPUPlace, + ops::ConcatMKLDNNOpKernel) diff --git a/paddle/fluid/operators/concat_op.cc b/paddle/fluid/operators/concat_op.cc index 57817da71adfd80faad29a48b05ba2f326de6c07..7e58f9cde13193ae069f794e7c12a69b3fa65778 100644 --- a/paddle/fluid/operators/concat_op.cc +++ b/paddle/fluid/operators/concat_op.cc @@ -16,6 +16,7 @@ limitations under the License. */ #include #include +#include namespace paddle { namespace operators { @@ -59,6 +60,21 @@ class ConcatOp : public framework::OperatorWithKernel { ctx->SetOutputDim("Out", out_dims); ctx->ShareLoD("X", /*->*/ "Out"); } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto input_data_type = framework::GetDataTypeOfVar(ctx.MultiInputVar("X")[0]); + + #ifdef PADDLE_WITH_MKLDNN + if (platform::CanMKLDNNBeUsed(ctx)) { + return framework::OpKernelType(input_data_type, ctx.GetPlace(), + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN); + } + #endif + return framework::OpKernelType(input_data_type, ctx.GetPlace()); + } }; class ConcatOpMaker : public framework::OpProtoAndCheckerMaker { @@ -66,6 +82,9 @@ class ConcatOpMaker : public framework::OpProtoAndCheckerMaker { void Make() override { AddInput("X", "Input tensors of concat operator.").AsDuplicable(); AddOutput("Out", "Output tensor of concat operator."); + AddAttr("use_mkldnn", + "(bool, default false) Indicates if MKL-DNN kernel will be used") + .SetDefault(false); AddAttr("axis", "The axis along which the input tensors will be concatenated.") .SetDefault(0); @@ -82,6 +101,7 @@ Examples: [5,6]] )DOC"); + } }; diff --git a/python/paddle/fluid/tests/unittests/test_concat_mkldnn_op.py b/python/paddle/fluid/tests/unittests/test_concat_mkldnn_op.py new file mode 100644 index 0000000000000000000000000000000000000000..c590687a24d8757a501181de7df82ca9c728d3d7 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_concat_mkldnn_op.py @@ -0,0 +1,56 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +from test_concat_op import TestConcatOp, TestConcatOp2, TestConcatOp3 + + +class TestMKLDNNConcatOp(TestConcatOp): + def setUp(self): + super(TestMKLDNNConcatOp, self).setUp() + self.attrs["use_mkldnn"] = True + + def test_check_grad(self): + pass + + def init_kernel_type(self): + self.use_mkldnn = True + +class TestMKLDNNConcatOp2(TestConcatOp2): + def setUp(self): + super(TestMKLDNNConcatOp2, self).setUp() + self.attrs["use_mkldnn"] = True + + def test_check_grad(self): + pass + + def init_kernel_type(self): + self.use_mkldnn = True + +class TestMKLDNNConcatOp3(TestConcatOp3): + def setUp(self): + super(TestMKLDNNConcatOp3, self).setUp() + self.attrs["use_mkldnn"] = True + + def test_check_grad(self): + pass + + def init_kernel_type(self): + self.use_mkldnn = True + + +if __name__ == '__main__': + unittest.main()