From 85334b04b92870f44c0f0266f56b5cfcf0b7df02 Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Sun, 23 Jan 2022 10:45:58 +0800 Subject: [PATCH] [PTen] Add infermeta utils for register infermeta funtion (#39135) * add infermeta utils for register infermeta * polish license format --- paddle/pten/core/CMakeLists.txt | 1 + paddle/pten/core/infermeta_utils.cc | 73 ++++++++++++ paddle/pten/core/infermeta_utils.h | 170 ++++++++++++++++++++++++++++ 3 files changed, 244 insertions(+) create mode 100644 paddle/pten/core/infermeta_utils.cc create mode 100644 paddle/pten/core/infermeta_utils.h diff --git a/paddle/pten/core/CMakeLists.txt b/paddle/pten/core/CMakeLists.txt index cd3a1755a9d..181012732fa 100644 --- a/paddle/pten/core/CMakeLists.txt +++ b/paddle/pten/core/CMakeLists.txt @@ -19,6 +19,7 @@ cc_library(dense_tensor SRCS dense_tensor.cc DEPS convert_utils tensor_meta tens cc_library(pten_device_context SRCS device_context.cc DEPS tensor_base ) cc_library(meta_tensor SRCS meta_tensor.cc DEPS tensor_base tensor_meta dense_tensor) +cc_library(infermeta_utils SRCS infermeta_utils.cc DEPS meta_tensor) cc_test(unroll_array_ops_test SRCS unroll_array_ops_test.cc) cc_library(ddim SRCS ddim.cc DEPS eigen3 boost enforce) diff --git a/paddle/pten/core/infermeta_utils.cc b/paddle/pten/core/infermeta_utils.cc new file mode 100644 index 00000000000..9f0037d18ed --- /dev/null +++ b/paddle/pten/core/infermeta_utils.cc @@ -0,0 +1,73 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/pten/core/infermeta_utils.h" + +namespace pten { + +void InferMetaContext::SetMetaConfig(MetaConfig config) { + config_ = std::move(config); +} + +void InferMetaContext::EmplaceBackInput( + std::shared_ptr input) { + int index = inputs_.size(); + inputs_.emplace_back(std::move(input)); + input_range_.emplace_back(std::pair(index, index + 1)); +} +void InferMetaContext::EmplaceBackOutput( + std::shared_ptr output) { + int index = outputs_.size(); + outputs_.emplace_back(std::move(output)); + output_range_.emplace_back(std::pair(index, index + 1)); +} +void InferMetaContext::EmplaceBackAttr(paddle::any attr) { + attrs_.emplace_back(std::move(attr)); +} + +void InferMetaContext::EmplaceBackInputs( + paddle::SmallVector> inputs) { + int index = inputs_.size(); + input_range_.emplace_back(std::pair(index, index + inputs.size())); + inputs_.insert(inputs_.end(), + std::make_move_iterator(inputs.begin()), + std::make_move_iterator(inputs.end())); +} +void InferMetaContext::EmplaceBackOutputs( + paddle::SmallVector> outputs) { + int index = outputs_.size(); + output_range_.emplace_back( + std::pair(index, index + outputs.size())); + outputs_.insert(outputs_.end(), + std::make_move_iterator(outputs.begin()), + std::make_move_iterator(outputs.end())); +} + +const std::pair& InferMetaContext::InputRangeAt(size_t idx) const { + return input_range_.at(idx); +} +const std::pair& InferMetaContext::OutputRangeAt(size_t idx) const { + return output_range_.at(idx); +} + +const MetaConfig& InferMetaContext::GetMetaConfig() const { return config_; } + +const MetaTensor& InferMetaContext::InputAt(size_t idx) const { + return *inputs_.at(idx); +} +MetaTensor* InferMetaContext::MutableOutputAt(size_t idx) { + return outputs_.at(idx).get(); +} + +} // namespace pten diff --git a/paddle/pten/core/infermeta_utils.h b/paddle/pten/core/infermeta_utils.h new file mode 100644 index 00000000000..c6812dee92b --- /dev/null +++ b/paddle/pten/core/infermeta_utils.h @@ -0,0 +1,170 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include + +#include "paddle/pten/core/meta_tensor.h" +#include "paddle/utils/small_vector.h" + +namespace pten { + +// TODO(chenweihang): add other flags if needed +struct MetaConfig { + bool is_runtime{true}; + + MetaConfig() = default; + + // supporting implicit construction is easier to use + MetaConfig(bool is_runtime) : is_runtime(is_runtime) {} // NOLINT +}; + +class InferMetaContext { + public: + InferMetaContext() = default; + explicit InferMetaContext(MetaConfig config) : config_(config) {} + + void SetMetaConfig(MetaConfig config); + void EmplaceBackInput(std::shared_ptr input); + void EmplaceBackOutput(std::shared_ptr output); + void EmplaceBackAttr(paddle::any attr); + + void EmplaceBackInputs( + paddle::SmallVector> inputs); + void EmplaceBackOutputs( + paddle::SmallVector> outputs); + + const std::pair& InputRangeAt(size_t idx) const; + const std::pair& OutputRangeAt(size_t idx) const; + + const MetaConfig& GetMetaConfig() const; + const MetaTensor& InputAt(size_t idx) const; + MetaTensor* MutableOutputAt(size_t idx); + + template + AttrType AttrAt(size_t idx) { + try { + return paddle::any_cast(attrs_.at(idx)); + } catch (paddle::bad_any_cast&) { + PADDLE_THROW(paddle::platform::errors::InvalidArgument( + "Attribute cast error in InferMeta Context.")); + } + } + + private: + MetaConfig config_; + + // NOTE(chenweihang): Because the MetaTensor is a base class, and MetaTensor + // objects are all created in each round, so we have to use smart pointer + // here, maybe we can implemented a new InferMetaContext and a series utils + // specifically for fluid to avoid using shared_ptr + paddle::SmallVector> inputs_; + paddle::SmallVector> outputs_; + paddle::SmallVector attrs_; + + paddle::SmallVector> input_range_; + paddle::SmallVector> output_range_; +}; + +#define PT_INFER_META(...) \ + ::pten::InferMetaFnImpl::Call + +#define PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(attr_type) \ + template \ + struct InferMetaFnCallHelper { \ + template \ + static void Call(InferMetaContext* ctx, PreviousArgs&... pargs) { \ + static_assert(out_idx == 0, \ + "InferMeta's Attributes should appear before Outputs."); \ + attr_type arg = ctx->AttrAt(attr_idx); \ + InferMetaFnCallHelper< \ + Tail...>::template Call(pargs..., \ + arg); \ + } \ + } + +template +struct InferMetaTypeTag {}; + +template +struct InferMetaFnImpl; + +template +struct InferMetaFnImpl { + static void Call(InferMetaContext* ctx) { + InferMetaFnCallHelper>::template Call<0, 0, 0>(ctx); + } + + private: + template + struct InferMetaFnCallHelper; + + template + struct InferMetaFnCallHelper { + template + static void Call(InferMetaContext* ctx, PreviousArgs&... pargs) { + static_assert(attr_idx == 0, + "InferMeta's Input should appear before Attributes."); + static_assert(out_idx == 0, + "InferMeta's Input should appear before Outputs."); + const std::pair range = ctx->InputRangeAt(in_idx); + const MetaTensor& arg = ctx->InputAt(range.first); + InferMetaFnCallHelper< + Tail...>::template Call(ctx, + pargs..., + arg); + } + }; + + // TODO(chenweihang): support vector input later + + template + struct InferMetaFnCallHelper { + template + static void Call(InferMetaContext* ctx, PreviousArgs&... pargs) { + const std::pair range = ctx->OutputRangeAt(out_idx); + MetaTensor* arg = ctx->MutableOutputAt(range.first); + InferMetaFnCallHelper< + Tail...>::template Call(ctx, + pargs..., + arg); + } + }; + + // TODO(chenweihang): support vector output later + + template + struct InferMetaFnCallHelper { + template + static void Call(InferMetaContext* ctx, PreviousArgs&... pargs) { + const MetaConfig& arg = ctx->GetMetaConfig(); + InferMetaFnCallHelper::template Call( + ctx, pargs..., arg); + } + }; + + /* End case */ + template + struct InferMetaFnCallHelper> { + template + static void Call(InferMetaContext* ctx, Args&... args) { + return infer_meta_fn(args...); + } + }; +}; + +} // namespace pten -- GitLab