Unverified commit 36d9a364 authored by Chen Weihang, committed by GitHub

add infershape utils (#39140)

Parent 60df9254
paddle/fluid/framework/CMakeLists.txt
@@ -192,11 +192,11 @@ cc_library(unused_var_check SRCS unused_var_check.cc DEPS glog no_need_buffer_va
 IF(WITH_XPU)
 cc_library(operator SRCS operator.cc DEPS xpu_op_list op_info device_context tensor scope glog trainer_desc_proto data_feed_proto
 shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type op_call_stack unused_var_check nan_inf_utils
-pten pten_utils kernel_factory)
+pten pten_utils kernel_factory infershape_utils)
 ELSE()
 cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog trainer_desc_proto data_feed_proto
 shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type op_call_stack unused_var_check nan_inf_utils
-pten pten_utils kernel_factory)
+pten pten_utils kernel_factory infershape_utils)
 ENDIF()
 cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry device_context)
@@ -408,6 +408,7 @@ cc_test(save_load_util_test SRCS save_load_util_test.cc DEPS save_load_util tens
 cc_library(generator SRCS generator.cc DEPS enforce place)
 cc_library(pten_utils SRCS pten_utils.cc DEPS lod_tensor selected_rows_utils place pten var_type_traits pten_api_utils op_info)
+cc_library(infershape_utils SRCS infershape_utils.cc DEPS lod_tensor selected_rows_utils attribute place pten var_type_traits pten pten_api_utils op_info shape_inference)
 # Get the current working branch
 execute_process(
......
paddle/fluid/framework/infershape_utils.cc
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/pten/core/compat/arg_map_context.h"
#include "paddle/pten/core/compat_utils.h"
#include "paddle/pten/core/convert_utils.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/meta_tensor.h"
namespace paddle {
namespace framework {
class InferShapeArgumentMappingContext : public pten::ArgumentMappingContext {
public:
explicit InferShapeArgumentMappingContext(const InferShapeContext& ctx)
: ctx_(ctx) {}
bool HasInput(const std::string& name) const override {
return ctx_.HasInput(name);
}
bool HasOutput(const std::string& name) const override {
return ctx_.HasOutput(name);
}
paddle::any Attr(const std::string& name) const override {
auto& attr = ctx_.Attrs().GetAttr(name);
return GetAttrValue(attr);
}
size_t InputSize(const std::string& name) const override {
return ctx_.Inputs(name).size();
}
size_t OutputSize(const std::string& name) const override {
return ctx_.Outputs(name).size();
}
bool IsDenseTensorInput(const std::string& name) const override {
auto var_types = ctx_.GetInputsVarType(name);
return var_types[0] == proto::VarType::LOD_TENSOR;
}
bool IsSelectedRowsInput(const std::string& name) const override {
auto var_types = ctx_.GetInputsVarType(name);
return var_types[0] == proto::VarType::SELECTED_ROWS;
}
private:
const InferShapeContext& ctx_;
};
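// Illustrative sketch only (hypothetical op and kernel names): an argument
// mapping function on the pten side queries an ArgumentMappingContext such as
// the adapter above to decide which kernel signature to dispatch, assuming
// KernelSignature as declared in "paddle/pten/core/compat/arg_map_context.h".
pten::KernelSignature ExampleOpArgumentMapping(
    const pten::ArgumentMappingContext& ctx) {
  if (ctx.IsDenseTensorInput("X")) {
    // Dense tensor input -> dense kernel signature.
    return pten::KernelSignature("example", {"X"}, {}, {"Out"});
  }
  // Otherwise assume the selected-rows variant.
  return pten::KernelSignature("example_sr", {"X"}, {}, {"Out"});
}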
// TODO(chenweihang): Support SelectedRows later
// TODO(chenweihang): Support TensorArray later
class CompatMetaTensor : public pten::MetaTensor {
public:
CompatMetaTensor(InferShapeVarPtr var, bool is_runtime)
: var_(std::move(var)), is_runtime_(is_runtime) {}
CompatMetaTensor() = default;
CompatMetaTensor(const CompatMetaTensor&) = default;
CompatMetaTensor(CompatMetaTensor&&) = default;
CompatMetaTensor& operator=(const CompatMetaTensor&) = delete;
CompatMetaTensor& operator=(CompatMetaTensor&&) = delete;
int64_t numel() const override {
if (is_runtime_) {
auto* var = BOOST_GET_CONST(Variable*, var_);
return var->Get<Tensor>().numel();
} else {
auto* var = BOOST_GET_CONST(VarDesc*, var_);
return var->ElementSize();
}
}
DDim dims() const override {
if (is_runtime_) {
auto* var = BOOST_GET_CONST(Variable*, var_);
return var->Get<LoDTensor>().dims();
} else {
auto* var = BOOST_GET_CONST(VarDesc*, var_);
return make_ddim(var->GetShape());
}
}
pten::DataType dtype() const override {
if (is_runtime_) {
auto* var = BOOST_GET_CONST(Variable*, var_);
return var->Get<LoDTensor>().dtype();
} else {
auto* var = BOOST_GET_CONST(VarDesc*, var_);
return pten::TransToPtenDataType(var->GetDataType());
}
}
DataLayout layout() const override {
if (is_runtime_) {
auto* var = BOOST_GET_CONST(Variable*, var_);
return var->Get<LoDTensor>().layout();
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported get layout for VarDesc now."));
}
}
void set_dims(const DDim& dims) override {
if (is_runtime_) {
auto* var = BOOST_GET(Variable*, var_);
LoDTensor* tensor = var->GetMutable<LoDTensor>();
pten::CompatibleDenseTensorUtils::GetMutableMeta(
static_cast<pten::DenseTensor*>(tensor))
->dims = dims;
} else {
auto* var = BOOST_GET(VarDesc*, var_);
var->SetShape(vectorize(dims));
}
}
void set_dtype(pten::DataType dtype) override {
if (is_runtime_) {
auto* var = BOOST_GET(Variable*, var_);
LoDTensor* tensor = var->GetMutable<LoDTensor>();
pten::CompatibleDenseTensorUtils::GetMutableMeta(
static_cast<pten::DenseTensor*>(tensor))
->dtype = dtype;
} else {
auto* var = BOOST_GET(VarDesc*, var_);
var->SetDataType(pten::TransToProtoVarType(dtype));
}
}
void set_layout(DataLayout layout) override {
if (is_runtime_) {
auto* var = BOOST_GET(Variable*, var_);
LoDTensor* tensor = var->GetMutable<LoDTensor>();
pten::CompatibleDenseTensorUtils::GetMutableMeta(
static_cast<pten::DenseTensor*>(tensor))
->layout = layout;
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported set layout for VarDesc now."));
}
}
void share_lod(const MetaTensor& meta_tensor) override {
if (is_runtime_) {
auto* var = BOOST_GET(Variable*, var_);
LoDTensor* tensor = var->GetMutable<LoDTensor>();
pten::CompatibleDenseTensorUtils::GetMutableMeta(
static_cast<pten::DenseTensor*>(tensor))
->lod =
static_cast<const CompatMetaTensor&>(meta_tensor).GetRuntimeLoD();
} else {
auto* var = BOOST_GET(VarDesc*, var_);
var->SetLoDLevel(static_cast<const CompatMetaTensor&>(meta_tensor)
.GetCompileTimeLoD());
}
}
private:
const LoD& GetRuntimeLoD() const {
auto* var = BOOST_GET_CONST(Variable*, var_);
return var->Get<LoDTensor>().lod();
}
int32_t GetCompileTimeLoD() const {
auto* var = BOOST_GET_CONST(VarDesc*, var_);
return var->GetLoDLevel();
}
InferShapeVarPtr var_;
bool is_runtime_;
};
} // namespace framework
} // namespace paddle
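The CompatMetaTensor above gives pten InferMeta code a single MetaTensor view over either a runtime Variable* or a compile-time VarDesc*, so one InferMeta implementation can serve both graph compilation and execution. Purely as a sketch (the helper name and the "X"/"Out" slots are hypothetical, and BuildInferMetaContext itself is deferred to a later PR), a caller holding an InferShapeContext could forward shape metadata like this:

// Hypothetical helper for illustration: mirror the input's dims/dtype/lod
// onto the output through the CompatMetaTensor adapter; it works for both the
// runtime (Variable*) and compile-time (VarDesc*) cases.
void ShareMetaUnchanged(paddle::framework::InferShapeContext* ctx) {
  auto x_var = ctx->GetInputVarPtrs("X")[0];
  auto out_var = ctx->GetOutputVarPtrs("Out")[0];
  paddle::framework::CompatMetaTensor x(x_var, ctx->IsRuntime());
  paddle::framework::CompatMetaTensor out(out_var, ctx->IsRuntime());
  out.set_dims(x.dims());
  out.set_dtype(x.dtype());
  out.share_lod(x);
}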
paddle/fluid/framework/infershape_utils.h
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/shape_inference.h"
namespace pten {
class InferMetaContext;
} // namespace pten
namespace paddle {
namespace framework {
// TODO(chenweihang): impl this function in next PR
pten::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
const std::string& op_type);
#define DELCARE_INFER_SHAPE_FUNCTOR(op_type, functor_name, fn) \
struct functor_name : public paddle::framework::InferShapeBase { \
void operator()( \
paddle::framework::InferShapeContext* ctx) const override { \
auto infer_meta_context = \
paddle::framework::BuildInferMetaContext(ctx, #op_type); \
fn(&infer_meta_context); \
} \
}
} // namespace framework
} // namespace paddle
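For context, DELCARE_INFER_SHAPE_FUNCTOR above generates an InferShapeBase functor whose operator() builds a pten::InferMetaContext from the fluid InferShapeContext and passes it to fn, so fn must be callable with a pten::InferMetaContext*. A rough usage sketch with hypothetical names (keeping in mind that BuildInferMetaContext is still a TODO in this commit) might look like:

// In a hypothetical operator .cc file. Assumed to exist: an InferMeta
// function that takes a pten::InferMetaContext*.
void ExampleInferMeta(pten::InferMetaContext* ctx);

// Generate a functor named ExampleInferShapeFunctor that routes the op's
// InferShape call through ExampleInferMeta.
DELCARE_INFER_SHAPE_FUNCTOR(example, ExampleInferShapeFunctor,
                            ExampleInferMeta);

// Register it with the operator so shape inference goes through the pten
// InferMeta path (assuming the usual REGISTER_OPERATOR machinery, which
// accepts InferShapeBase subclasses).
REGISTER_OPERATOR(example, ops::ExampleOp, ops::ExampleOpMaker,
                  ExampleInferShapeFunctor);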