From 36d9a364d2822c8034cdd41c1f06cd9feefd2fdb Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Sat, 22 Jan 2022 21:08:38 +0800 Subject: [PATCH] add infershape utils (#39140) --- paddle/fluid/framework/CMakeLists.txt | 5 +- paddle/fluid/framework/infershape_utils.cc | 190 +++++++++++++++++++++ paddle/fluid/framework/infershape_utils.h | 44 +++++ 3 files changed, 237 insertions(+), 2 deletions(-) create mode 100644 paddle/fluid/framework/infershape_utils.cc create mode 100644 paddle/fluid/framework/infershape_utils.h diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 5c3b24463ef..0220e5fd594 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -192,11 +192,11 @@ cc_library(unused_var_check SRCS unused_var_check.cc DEPS glog no_need_buffer_va IF(WITH_XPU) cc_library(operator SRCS operator.cc DEPS xpu_op_list op_info device_context tensor scope glog trainer_desc_proto data_feed_proto shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type op_call_stack unused_var_check nan_inf_utils - pten pten_utils kernel_factory) + pten pten_utils kernel_factory infershape_utils) ELSE() cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog trainer_desc_proto data_feed_proto shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type op_call_stack unused_var_check nan_inf_utils - pten pten_utils kernel_factory) + pten pten_utils kernel_factory infershape_utils) ENDIF() cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry device_context) @@ -408,6 +408,7 @@ cc_test(save_load_util_test SRCS save_load_util_test.cc DEPS save_load_util tens cc_library(generator SRCS generator.cc DEPS enforce place) cc_library(pten_utils SRCS pten_utils.cc DEPS lod_tensor selected_rows_utils place pten var_type_traits pten_api_utils op_info) +cc_library(infershape_utils SRCS infershape_utils.cc DEPS lod_tensor selected_rows_utils attribute place pten var_type_traits pten pten_api_utils op_info shape_inference) # Get the current working branch execute_process( diff --git a/paddle/fluid/framework/infershape_utils.cc b/paddle/fluid/framework/infershape_utils.cc new file mode 100644 index 00000000000..9a91a5208eb --- /dev/null +++ b/paddle/fluid/framework/infershape_utils.cc @@ -0,0 +1,190 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/infershape_utils.h" + +#include "paddle/fluid/framework/framework.pb.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/pten/core/compat/arg_map_context.h" +#include "paddle/pten/core/compat_utils.h" +#include "paddle/pten/core/convert_utils.h" +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/core/meta_tensor.h" + +namespace paddle { +namespace framework { + +class InferShapeArgumentMappingContext : public pten::ArgumentMappingContext { + public: + explicit InferShapeArgumentMappingContext(const InferShapeContext& ctx) + : ctx_(ctx) {} + + bool HasInput(const std::string& name) const override { + return ctx_.HasInput(name); + } + + bool HasOutput(const std::string& name) const override { + return ctx_.HasOutput(name); + } + + paddle::any Attr(const std::string& name) const override { + auto& attr = ctx_.Attrs().GetAttr(name); + return GetAttrValue(attr); + } + + size_t InputSize(const std::string& name) const override { + return ctx_.Inputs(name).size(); + } + + size_t OutputSize(const std::string& name) const override { + return ctx_.Outputs(name).size(); + } + + bool IsDenseTensorInput(const std::string& name) const override { + auto var_types = ctx_.GetInputsVarType(name); + return var_types[0] == proto::VarType::LOD_TENSOR; + } + + bool IsSelectedRowsInput(const std::string& name) const override { + auto var_types = ctx_.GetInputsVarType(name); + return var_types[0] == proto::VarType::SELECTED_ROWS; + } + + private: + const InferShapeContext& ctx_; +}; + +// TODO(chenweihang): Support SelectedRows later +// TODO(chenweihang): Support TensorArray later +class CompatMetaTensor : public pten::MetaTensor { + public: + CompatMetaTensor(InferShapeVarPtr var, bool is_runtime) + : var_(std::move(var)), is_runtime_(is_runtime) {} + + CompatMetaTensor() = default; + CompatMetaTensor(const CompatMetaTensor&) = default; + CompatMetaTensor(CompatMetaTensor&&) = default; + CompatMetaTensor& operator=(const CompatMetaTensor&) = delete; + CompatMetaTensor& operator=(CompatMetaTensor&&) = delete; + + int64_t numel() const override { + if (is_runtime_) { + auto* var = BOOST_GET_CONST(Variable*, var_); + return var->Get().numel(); + } else { + auto* var = BOOST_GET_CONST(VarDesc*, var_); + return var->ElementSize(); + } + } + + DDim dims() const override { + if (is_runtime_) { + auto* var = BOOST_GET_CONST(Variable*, var_); + return var->Get().dims(); + } else { + auto* var = BOOST_GET_CONST(VarDesc*, var_); + return make_ddim(var->GetShape()); + } + } + + pten::DataType dtype() const override { + if (is_runtime_) { + auto* var = BOOST_GET_CONST(Variable*, var_); + return var->Get().dtype(); + } else { + auto* var = BOOST_GET_CONST(VarDesc*, var_); + return pten::TransToPtenDataType(var->GetDataType()); + } + } + + DataLayout layout() const override { + if (is_runtime_) { + auto* var = BOOST_GET_CONST(Variable*, var_); + return var->Get().layout(); + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported get layout for VarDesc now.")); + } + } + + void set_dims(const DDim& dims) override { + if (is_runtime_) { + auto* var = BOOST_GET(Variable*, var_); + LoDTensor* tensor = var->GetMutable(); + pten::CompatibleDenseTensorUtils::GetMutableMeta( + static_cast(tensor)) + ->dims = dims; + } else { + auto* var = BOOST_GET(VarDesc*, var_); + var->SetShape(vectorize(dims)); + } + } + + void set_dtype(pten::DataType dtype) override { + if (is_runtime_) { + auto* var = BOOST_GET(Variable*, var_); + LoDTensor* tensor = var->GetMutable(); + pten::CompatibleDenseTensorUtils::GetMutableMeta( + static_cast(tensor)) + ->dtype = dtype; + } else { + auto* var = BOOST_GET(VarDesc*, var_); + var->SetDataType(pten::TransToProtoVarType(dtype)); + } + } + + void set_layout(DataLayout layout) override { + if (is_runtime_) { + auto* var = BOOST_GET(Variable*, var_); + LoDTensor* tensor = var->GetMutable(); + pten::CompatibleDenseTensorUtils::GetMutableMeta( + static_cast(tensor)) + ->layout = layout; + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported set layout for VarDesc now.")); + } + } + + void share_lod(const MetaTensor& meta_tensor) override { + if (is_runtime_) { + auto* var = BOOST_GET(Variable*, var_); + LoDTensor* tensor = var->GetMutable(); + pten::CompatibleDenseTensorUtils::GetMutableMeta( + static_cast(tensor)) + ->lod = + static_cast(meta_tensor).GetRuntimeLoD(); + } else { + auto* var = BOOST_GET(VarDesc*, var_); + var->SetLoDLevel(static_cast(meta_tensor) + .GetCompileTimeLoD()); + } + } + + private: + const LoD& GetRuntimeLoD() const { + auto* var = BOOST_GET_CONST(Variable*, var_); + return var->Get().lod(); + } + int32_t GetCompileTimeLoD() const { + auto* var = BOOST_GET_CONST(VarDesc*, var_); + return var->GetLoDLevel(); + } + + InferShapeVarPtr var_; + bool is_runtime_; +}; + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/infershape_utils.h b/paddle/fluid/framework/infershape_utils.h new file mode 100644 index 00000000000..f943989523e --- /dev/null +++ b/paddle/fluid/framework/infershape_utils.h @@ -0,0 +1,44 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include + +#include "paddle/fluid/framework/op_info.h" +#include "paddle/fluid/framework/shape_inference.h" + +namespace pten { +class InferMetaContext; +} // namespace pten + +namespace paddle { +namespace framework { + +// TODO(chenweihang): impl this function in next PR +pten::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx, + const std::string& op_type); + +#define DELCARE_INFER_SHAPE_FUNCTOR(op_type, functor_name, fn) \ + struct functor_name : public paddle::framework::InferShapeBase { \ + void operator()( \ + paddle::framework::InferShapeContext* ctx) const override { \ + auto infer_meta_context = \ + paddle::framework::BuildInferMetaContext(ctx, #op_type); \ + fn(&infer_meta_context); \ + } \ + } + +} // namespace framework +} // namespace paddle -- GitLab