From 738069e491d5649b39706aed2526622a1594332c Mon Sep 17 00:00:00 2001 From: Michal Gallus Date: Mon, 3 Dec 2018 14:45:27 +0100 Subject: [PATCH] Refactor MKL-DNN Concat test=develop --- paddle/fluid/operators/concat_mkldnn_op.cc | 209 +++++++-------------- 1 file changed, 72 insertions(+), 137 deletions(-) diff --git a/paddle/fluid/operators/concat_mkldnn_op.cc b/paddle/fluid/operators/concat_mkldnn_op.cc index c6652b78851..37b2788d63b 100644 --- a/paddle/fluid/operators/concat_mkldnn_op.cc +++ b/paddle/fluid/operators/concat_mkldnn_op.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include #include "paddle/fluid/operators/concat_op.h" #include "paddle/fluid/platform/mkldnn_helper.h" @@ -26,25 +27,6 @@ using mkldnn::concat; using mkldnn::stream; using platform::to_void_cast; -// Generate keys for storing/retriving primitives for this operator -// TODO(jczaja): Make hashing function more optimial -static std::string gethash(const memory::dims& input_dims, - const std::string& pooling_type, - const std::vector& ksize, - const std::vector& strides, - const std::vector& paddings, - const std::string& suffix) { - auto dims2str = [](const memory::dims& operand_dims) { - std::string dstr = ""; - for (size_t i = 0; i < operand_dims.size(); ++i) { - dstr += std::to_string(operand_dims[i]) + "-"; - } - return dstr; - }; - return dims2str(input_dims) + dims2str(ksize) + dims2str(strides) + - dims2str(paddings) + pooling_type + suffix; -} - static void EnforceLayouts(const std::vector inputs) { for (auto* input : inputs) { const bool is_layout_correct = input->layout() == DataLayout::kMKLDNN; @@ -56,7 +38,7 @@ static void EnforceLayouts(const std::vector inputs) { } static memory::primitive_desc CreateMemPrimDesc( - const framework::Tensor& input, const mkldnn::engine& engine) { + const Tensor& input, const mkldnn::engine& engine) { constexpr auto data_type = mkldnn::memory::f32; const auto dims = paddle::framework::vectorize2int(input.dims()); const auto format = input.format(); @@ -65,6 +47,11 @@ static memory::primitive_desc CreateMemPrimDesc( return mem_prim_desc; } +static mkldnn::memory::format GetDstMemFormat( + const concat::primitive_desc& concat_pd) { + return (memory::format)concat_pd.dst_primitive_desc().desc().data.format; +} + static platform::CPUPlace GetCpuPlace( const paddle::framework::ExecutionContext& ctx) { auto place = ctx.GetPlace(); @@ -73,139 +60,87 @@ static platform::CPUPlace GetCpuPlace( return boost::get(place); } -template -class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel { - public: - void Compute(const paddle::framework::ExecutionContext& ctx) const override { - auto place = GetCpuPlace(ctx); +static const mkldnn::engine& GetMKLDNNEngine( + const paddle::framework::ExecutionContext& ctx) { auto& dev_ctx = ctx.template device_context(); - const auto& mkldnn_engine = dev_ctx.GetEngine(); + return dev_ctx.GetEngine(); +} - auto multi_input = ctx.MultiInput("X"); - framework::Tensor* output = ctx.Output("Out"); - int64_t concat_axis = static_cast(ctx.Attr("axis")); +template +class ConcatPrimitiveFactory { + public: + concat::primitive_desc CreateConcatPrimDescriptor( + const std::vector multi_input, Tensor* output, + int concat_axis, const mkldnn::engine& mkldnn_engine) { + CreateSourcesDescriptors(multi_input, mkldnn_engine); + auto dst_desc = CreateDstMemDescriptor(output); + return concat::primitive_desc(dst_desc, concat_axis, srcs_pd); + } - EnforceLayouts(multi_input); + concat CreateConcatPrimitive(const concat::primitive_desc& concat_pd, + Tensor* output, platform::CPUPlace place) { + CreateSourcePrimitiveAts(); + auto dst_mem = CreateDstMemory(concat_pd, output, place); + return concat(concat_pd, inputs, dst_mem); + } + + private: + memory::desc CreateDstMemDescriptor(Tensor* output) { + auto dst_dims = paddle::framework::vectorize2int(output->dims()); + return memory::desc(dst_dims, platform::MKLDNNGetDataType(), + memory::format::any); + } + + mkldnn::memory CreateDstMemory(const concat::primitive_desc& concat_pd, + Tensor* output, platform::CPUPlace place) { + return memory(concat_pd.dst_primitive_desc(), + output->mutable_data(place)); + } - std::vector srcs_pd; - std::vector srcs; + void CreateSourcesDescriptors(const std::vector multi_input, + const mkldnn::engine& mkldnn_engine) { for (size_t i = 0; i < multi_input.size(); i++) { auto mem_prim_desc = CreateMemPrimDesc(*multi_input[i], mkldnn_engine); srcs_pd.push_back(mem_prim_desc); - srcs.push_back(memory(mem_prim_desc, to_void_cast(multi_input[i]->data()))); + srcs.push_back(memory(mem_prim_desc, + to_void_cast(multi_input[i]->data()))); } - auto dst_dims = paddle::framework::vectorize2int(output->dims()); - auto dst_desc = memory::desc(dst_dims, mkldnn::memory::f32, memory::format::any); - auto concat_pd = concat::primitive_desc(dst_desc, static_cast(concat_axis), srcs_pd); - auto dst_mem = memory(concat_pd.dst_primitive_desc(), output->mutable_data(place)); + } - std::vector inputs; //= {srcs}; + void CreateSourcePrimitiveAts() { inputs.reserve(srcs.size()); for (size_t i = 0; i < srcs.size(); i++) { inputs.push_back(srcs[i]); } - auto concat_prim = concat(concat_pd, inputs, dst_mem); - - std::vector pipeline; - pipeline.push_back(concat_prim); - stream(stream::kind::eager).submit(pipeline).wait(); // TODO(mgallus): When this is not workin' split into decl and def - - /* - const T* input_data = input->data(); - T* output_data = output->mutable_data(ctx.GetPlace()); - - std::vector src_tz = paddle::framework::vectorize2int(input->dims()); - std::vector dst_tz = paddle::framework::vectorize2int(output->dims()); - - auto input_format = input->format(); - memory::format output_format{memory::format::format_undef}; - - const std::string key = gethash(src_tz, pooling_type, ksize, strides, - paddings, ctx.op().Output("Out")); - const std::string key_pool_p = key + "@pool_p"; - const std::string key_pool_pd = key + "@pool_pd"; - const std::string key_pool_src_mem_p = key + "@pool_src_mem_p"; - const std::string key_pool_dst_mem_p = key + "@pool_dst_mem_p"; - const std::string key_pool_workspace_memory = - key + "@pool_workspace_memory"; - - auto pool_p = - std::static_pointer_cast(dev_ctx.GetBlob(key_pool_p)); - if (pool_p == nullptr) { - const std::vector& padding_left_top(paddings); - std::vector padding_right_bottom(paddings); - bool ceil_mode = ctx.Attr("ceil_mode"); - if (ceil_mode) { - CorrectOutputSize(src_tz, dst_tz, ksize, paddings, strides, - padding_right_bottom); - } - auto src_md = platform::MKLDNNMemDesc( - src_tz, platform::MKLDNNGetDataType(), input_format); - - auto dst_md = platform::MKLDNNMemDesc(dst_tz, mkldnn::memory::f32, - mkldnn::memory::format::any); - - std::shared_ptr pool_pd = - CreatePrimitiveDesc(src_md, dst_md, strides, padding_left_top, - padding_right_bottom, ksize, pooling_type, - mkldnn_engine, ceil_mode, is_test); - - // save pool_pd into global device context to be referred in backward path - if (!is_test) dev_ctx.SetBlob(key_pool_pd, pool_pd); - - auto src_memory = std::make_shared(pool_pd->src_primitive_desc(), - to_void_cast(input_data)); - auto dst_memory = - std::make_shared(pool_pd->dst_primitive_desc(), output_data); - - dev_ctx.SetBlob(key_pool_src_mem_p, src_memory); - dev_ctx.SetBlob(key_pool_dst_mem_p, dst_memory); - - if (is_test) { - pool_p = std::make_shared(*pool_pd, *src_memory, - *dst_memory); - } else { - std::shared_ptr workspace_memory = - CreateWorkspaceMemory(pool_pd, pooling_type, mkldnn_engine); - - // save pool_workspace_memory to be referred in backward path - dev_ctx.SetBlob(key_pool_workspace_memory, workspace_memory); - - pool_p = std::make_shared( - *pool_pd, *src_memory, *dst_memory, *workspace_memory); - } - - dev_ctx.SetBlob(key_pool_p, pool_p); - - output_format = - (memory::format)dst_memory->get_primitive_desc().desc().data.format; - } else { - // Primitives already exist - auto pool_src_memory_p = - std::static_pointer_cast(dev_ctx.GetBlob(key_pool_src_mem_p)); - PADDLE_ENFORCE(pool_src_memory_p != nullptr, - "Fail to find pooling src mem_p in device context"); - auto pool_dst_memory_p = - std::static_pointer_cast(dev_ctx.GetBlob(key_pool_dst_mem_p)); - PADDLE_ENFORCE(pool_dst_memory_p != nullptr, - "Fail to find pooling dst mem_p in device context"); - pool_src_memory_p->set_data_handle(to_void_cast(input_data)); - pool_dst_memory_p->set_data_handle(output_data); - - output_format = (memory::format)pool_dst_memory_p->get_primitive_desc() - .desc() - .data.format; - } + } + + private: + std::vector srcs_pd; + std::vector srcs; + std::vector inputs; +}; + +template +class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel { + public: + void Compute(const paddle::framework::ExecutionContext& ctx) const override { + auto place = GetCpuPlace(ctx); + const auto& mkldnn_engine = GetMKLDNNEngine(ctx); + + auto multi_input = ctx.MultiInput("X"); + EnforceLayouts(multi_input); + Tensor* output = ctx.Output("Out"); + int64_t concat_axis = static_cast(ctx.Attr("axis")); + + ConcatPrimitiveFactory prim_creator; + auto concat_pd = prim_creator.CreateConcatPrimDescriptor(multi_input, + output, static_cast(concat_axis), mkldnn_engine); + auto concat = prim_creator.CreateConcatPrimitive(concat_pd, output, place); + stream(stream::kind::eager).submit({concat}).wait(); - // push primitive to stream and wait until it's executed - std::vector pipeline{*(pool_p.get())}; - stream(stream::kind::eager).submit(pipeline).wait(); - */ - output->mutable_data(place); output->set_layout(DataLayout::kMKLDNN); - output->set_format((memory::format)dst_mem.get_primitive_desc().desc() - .data.format); + output->set_format(GetDstMemFormat(concat_pd)); } }; } // namespace operators -- GitLab