Unverified commit b75507d3, authored by Chen Weihang, committed by GitHub

[PTen] Unify InferMeta(Shape) Function in pten and fluid op (#38976)

* infermeta context init design

* support infermeta called in fluid op

* add hasattr and attr methods

* add dygraph GetVarPtrs support

* rename arg_map_context to arg_map_utils

* add registry for arg map func

* resolve conflict

* refactor op utils design

* polish meta config

* fix details

* remove hasattr method

* resolve conflict

* revert cmake order change

* revert some change

* change init pos

* fix compile failed

* fix typo

* fix inference failed

* fix windows compile failed

* polish format
Co-authored-by: Wang Huan <wanghuan29@baidu.com>
Parent 2bf9b844
@@ -243,3 +243,29 @@ function(register_kernels)
endif()
endforeach()
endfunction()
function(append_op_util_declare TARGET)
file(READ ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET} target_content)
string(REGEX MATCH "(PT_REGISTER_API_NAME|PT_REGISTER_ARG_MAPPING_FN)\\([ \t\r\n]*[a-z0-9_]*" util_registrar "${target_content}")
string(REPLACE "PT_REGISTER_ARG_MAPPING_FN" "PT_DECLARE_ARG_MAPPING_FN" util_declare "${util_registrar}")
string(REPLACE "PT_REGISTER_API_NAME" "PT_REGISTER_API_NAME" util_declare "${util_declare}")
string(APPEND util_declare ");")
file(APPEND ${op_utils_header} "${util_declare}")
endfunction()
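For reference, a sketch of what this generation step emits: if a signature file such as scale_sig.cc (shown near the end of this diff) contains PT_REGISTER_ARG_MAPPING_FN(scale, pten::ScaleOpArgumentMapping);, the REGEX MATCH captures "PT_REGISTER_ARG_MAPPING_FN(scale", the REPLACE swaps in the declare form, and the appended ");" completes it, so the generated signatures.h accumulates one declaration per sig file, roughly:
// Illustrative content of the generated paddle/pten/ops/compat/signatures.h;
// the actual file depends on which *_sig.cc files exist in the source tree.
#include "paddle/pten/core/compat/op_utils.h"
PT_DECLARE_ARG_MAPPING_FN(scale);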
function(register_op_utils TARGET_NAME)
set(utils_srcs)
set(options "")
set(oneValueArgs "")
set(multiValueArgs EXCLUDES DEPS)
cmake_parse_arguments(register_op_utils "${options}" "${oneValueArgs}"
"${multiValueArgs}" ${ARGN})
file(GLOB SIGNATURES RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*_sig.cc")
foreach(target ${SIGNATURES})
append_op_util_declare(${target})
list(APPEND utils_srcs ${CMAKE_CURRENT_SOURCE_DIR}/${target})
endforeach()
cc_library(${TARGET_NAME} SRCS ${utils_srcs} DEPS ${register_op_utils_DEPS})
endfunction()
@@ -192,11 +192,11 @@ cc_library(unused_var_check SRCS unused_var_check.cc DEPS glog no_need_buffer_va
IF(WITH_XPU)
cc_library(operator SRCS operator.cc DEPS xpu_op_list op_info device_context tensor scope glog trainer_desc_proto data_feed_proto
shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type op_call_stack unused_var_check nan_inf_utils
pten pten_utils kernel_factory infershape_utils)
pten pten_utils kernel_factory infershape_utils op_utils)
ELSE()
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog trainer_desc_proto data_feed_proto
shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type op_call_stack unused_var_check nan_inf_utils
pten pten_utils kernel_factory infershape_utils)
pten pten_utils kernel_factory infershape_utils op_utils)
ENDIF()
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry device_context)
@@ -404,7 +404,7 @@ cc_test(save_load_util_test SRCS save_load_util_test.cc DEPS save_load_util tens
cc_library(generator SRCS generator.cc DEPS enforce place)
cc_library(pten_utils SRCS pten_utils.cc DEPS lod_tensor selected_rows_utils place pten var_type_traits pten_api_utils op_info)
cc_library(infershape_utils SRCS infershape_utils.cc DEPS lod_tensor selected_rows_utils attribute place pten var_type_traits pten pten_api_utils op_info shape_inference)
cc_library(infershape_utils SRCS infershape_utils.cc DEPS pten_utils attribute shape_inference op_utils)
# Get the current working branch
execute_process(
......
@@ -275,10 +275,8 @@ struct OpInfoFiller<T, kVarTypeInference> {
template <typename T>
struct OpInfoFiller<T, kShapeInference> {
void operator()(const char* op_type, OpInfo* info) const {
PADDLE_ENFORCE_EQ(
info->infer_shape_, nullptr,
platform::errors::AlreadyExists(
"Duplicate InferShapeFN of %s has been registered", op_type));
// Note: if the InferShapeFN is filled by this Filler, the infershape here
// will overwrite the op->InferShape func registered in the kOperator Filler
info->infer_shape_ = [](InferShapeContext* ctx) {
T inference;
inference(ctx);
......
@@ -15,11 +15,14 @@ limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/pten_utils.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/pten/core/compat/arg_map_context.h"
#include "paddle/pten/core/compat/op_utils.h"
#include "paddle/pten/core/compat_utils.h"
#include "paddle/pten/core/convert_utils.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/infermeta_utils.h"
#include "paddle/pten/core/meta_tensor.h"
namespace paddle {
@@ -186,5 +189,40 @@ class CompatMetaTensor : public pten::MetaTensor {
bool is_runtime_;
};
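For orientation: CompatMetaTensor is the adapter that lets one pten InferMeta function serve both graph-compile time and runtime. It wraps either a compile-time VarDesc* or a runtime Variable* and exposes the pten::MetaTensor interface. A minimal sketch of the idea, assuming a variant-like member var_ alongside the is_runtime_ flag shown above (the real implementation is elided from this hunk):
  // Sketch only: serve dims() from the right source depending on is_runtime_.
  DDim dims() const override {
    if (is_runtime_) {
      auto* var = BOOST_GET_CONST(Variable*, var_);  // var_ is an assumed member
      return var->Get<LoDTensor>().dims();
    }
    auto* var = BOOST_GET_CONST(VarDesc*, var_);
    return make_ddim(var->GetShape());
  }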
pten::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
const std::string& op_type) {
// 1. get kernel args
InitDefaultKernelSignatureMap();
auto arg_map_fn = pten::OpUtilsMap::Instance().GetArgumentMappingFn(op_type);
PADDLE_ENFORCE_NOT_NULL(
arg_map_fn, platform::errors::NotFound(
"The ArgumentMappingFn of %s op is not found.", op_type));
InferShapeArgumentMappingContext arg_map_context(*ctx);
auto signature = arg_map_fn(arg_map_context);
VLOG(3) << "BuildInferMetaContext: op kernel signature - " << signature;
// 2. build infermeta context
pten::InferMetaContext infer_meta_context(ctx->IsRuntime());
auto& input_names = std::get<0>(signature.args);
auto& output_names = std::get<2>(signature.args);
// TODO(chenweihang): support attrs in next pr
// auto& attr_names = std::get<1>(signature.args);
// TODO(chenweihang): support multiple inputs and outputs
for (auto& in_name : input_names) {
infer_meta_context.EmplaceBackInput(std::make_shared<CompatMetaTensor>(
ctx->GetInputVarPtrs(in_name)[0], ctx->IsRuntime()));
}
for (auto& out_name : output_names) {
infer_meta_context.EmplaceBackOutput(std::make_shared<CompatMetaTensor>(
ctx->GetOutputVarPtrs(out_name)[0], ctx->IsRuntime()));
}
// TODO(chenweihang): support attrs later
return infer_meta_context;
}
} // namespace framework
} // namespace paddle
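To see how this is consumed: the sign op below registers an InferShape functor created by DELCARE_INFER_SHAPE_FUNCTOR. A rough sketch of what such a functor plausibly does — this is an illustration, not the macro's actual expansion, and it assumes PT_INFER_META adapts the typed InferMeta function to the uniform void(InferMetaContext*) form via InferMetaFnImpl (see the infermeta_utils.h hunk further down):
struct SignInferShapeFunctor {
  void operator()(framework::InferShapeContext* ctx) const {
    // Map fluid inputs/outputs into pten MetaTensors via the op's
    // argument mapping function, then run the pten InferMeta function.
    auto infer_meta_ctx = framework::BuildInferMetaContext(ctx, "sign");
    PT_INFER_META(pten::UnchangedInferMetaNew)(&infer_meta_ctx);
  }
};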
@@ -26,7 +26,6 @@ class InferMetaContext;
namespace paddle {
namespace framework {
// TODO(chenweihang): impl this function in next PR
pten::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
const std::string& op_type);
......
@@ -31,6 +31,7 @@ limitations under the License. */
#include "paddle/fluid/platform/profiler.h"
#include "paddle/pten/common/scalar.h"
#include "paddle/pten/common/scalar_array.h"
#include "paddle/pten/ops/compat/signatures.h"
namespace pten {
class DenseTensor;
@@ -1086,6 +1087,13 @@ bool OperatorWithKernel::CanMKLDNNBeUsed(const framework::ExecutionContext& ctx,
return use_mkldnn_ctx && this->SupportsMKLDNN(data_type);
}
void OperatorWithKernel::InferShape(InferShapeContext* ctx) const {
PADDLE_THROW(platform::errors::PermissionDenied(
"The default InferShape function of OperatorWithKernel is not allowed to "
"be called, please override corresponding InferShape function in the "
"specific operator."));
}
void OperatorWithKernel::RuntimeInferShape(const Scope& scope,
const platform::Place& place,
const RuntimeContext& ctx) const {
@@ -1784,8 +1792,10 @@ OpKernelType OperatorWithKernel::GetKernelTypeForVar(
KernelSignature OperatorWithKernel::GetExpectedPtenKernelArgs(
const ExecutionContext& ctx) const {
return KernelSignatureMap::Instance().Get(
pten::TransToPtenKernelName(Type()));
InitDefaultKernelSignatureMap();
ExecutionArgumentMappingContext arg_mapping_ctx(ctx);
return pten::OpUtilsMap::Instance().GetArgumentMappingFn(Type())(
arg_mapping_ctx);
}
Scope* OperatorWithKernel::PreparePtenData(
......
@@ -41,6 +41,7 @@ limitations under the License. */
#include "paddle/utils/flat_hash_map.h"
#include "paddle/pten/core/compat/arg_map_context.h"
#include "paddle/pten/core/compat/op_utils.h"
#include "paddle/pten/core/kernel_context.h"
#include "paddle/pten/core/kernel_factory.h"
@@ -468,8 +469,7 @@ class ExecutionArgumentMappingContext : public pten::ArgumentMappingContext {
}
bool IsDenseTensorInput(const std::string& name) const override {
return ctx_.InputVar(name)->IsType<framework::Tensor>() ||
ctx_.InputVar(name)->IsType<framework::LoDTensor>();
return ctx_.InputVar(name)->IsType<framework::LoDTensor>();
}
bool IsSelectedRowsInput(const std::string& name) const override {
@@ -550,7 +550,7 @@ class OperatorWithKernel : public OperatorBase {
bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx,
proto::VarType::Type data_type) const;
virtual void InferShape(InferShapeContext* ctx) const = 0;
virtual void InferShape(InferShapeContext* ctx) const;
void RuntimeInferShape(const Scope& scope, const platform::Place& place,
const RuntimeContext& ctx) const override;
......
@@ -15,6 +15,7 @@ limitations under the License. */
#include <sstream>
#include "paddle/fluid/framework/pten_utils.h"
#include "paddle/pten/core/compat/op_utils.h"
#include "paddle/pten/core/convert_utils.h"
#include "paddle/pten/core/kernel_factory.h"
@@ -89,48 +90,6 @@ pten::KernelKey TransOpKernelTypeToPtenKernelKey(
return pten::KernelKey(backend, layout, dtype);
}
KernelSignatureMap* KernelSignatureMap::kernel_signature_map_ = nullptr;
std::once_flag KernelSignatureMap::init_flag_;
KernelSignatureMap& KernelSignatureMap::Instance() {
std::call_once(init_flag_, [] {
kernel_signature_map_ = new KernelSignatureMap();
for (const auto& pair : OpInfoMap::Instance().map()) {
const auto& op_type = pair.first;
const auto* op_proto = pair.second.proto_;
if (pten::KernelFactory::Instance().HasCompatiblePtenKernel(op_type) &&
op_proto) {
KernelArgsNameMakerByOpProto maker(op_proto);
VLOG(10) << "Register kernel signature for " << op_type;
auto success = kernel_signature_map_->map_
.emplace(pten::TransToPtenKernelName(op_type),
std::move(maker.GetKernelSignature()))
.second;
PADDLE_ENFORCE_EQ(
success, true,
platform::errors::PermissionDenied(
"Kernel signature of the operator %s has been registered.",
op_type));
}
}
});
return *kernel_signature_map_;
}
bool KernelSignatureMap::Has(const std::string& op_type) const {
return map_.find(op_type) != map_.end();
}
const KernelSignature& KernelSignatureMap::Get(
const std::string& op_type) const {
auto it = map_.find(op_type);
PADDLE_ENFORCE_NE(
it, map_.end(),
platform::errors::NotFound(
"Operator `%s`'s kernel signature is not registered.", op_type));
return it->second;
}
const paddle::SmallVector<std::string>&
KernelArgsNameMakerByOpProto::GetInputArgsNames() {
for (int i = 0; i < op_proto_->inputs_size(); ++i) {
@@ -196,6 +155,24 @@ KernelSignature KernelArgsNameMakerByOpProto::GetKernelSignature() {
GetOutputArgsNames());
}
std::once_flag kernel_sig_map_init_flag;
void InitDefaultKernelSignatureMap() {
std::call_once(kernel_sig_map_init_flag, [] {
for (const auto& pair : paddle::framework::OpInfoMap::Instance().map()) {
const auto& op_type = pair.first;
const auto* op_proto = pair.second.proto_;
if (pten::KernelFactory::Instance().HasCompatiblePtenKernel(op_type) &&
op_proto) {
paddle::framework::KernelArgsNameMakerByOpProto maker(op_proto);
VLOG(10) << "Register kernel signature for " << op_type;
pten::DefaultKernelSignatureMap::Instance().Insert(
op_type, std::move(maker.GetKernelSignature()));
}
}
});
}
void SetAllocationForOutputTenosr(pten::DenseTensor* tensor,
const platform::Place& place) {
if (!tensor->IsInitialized() || !(tensor->place() == place)) {
......
@@ -44,26 +44,6 @@ pten::KernelKey TransOpKernelTypeToPtenKernelKey(
/* Kernel Args parse */
// TODO(chenweihang): we can generate this map by proto info in compile time
class KernelSignatureMap {
public:
static KernelSignatureMap& Instance();
bool Has(const std::string& op_type) const;
const KernelSignature& Get(const std::string& op_type) const;
private:
KernelSignatureMap() = default;
DISABLE_COPY_AND_ASSIGN(KernelSignatureMap);
private:
static KernelSignatureMap* kernel_signature_map_;
static std::once_flag init_flag_;
paddle::flat_hash_map<std::string, KernelSignature> map_;
};
class KernelArgsNameMaker {
public:
virtual ~KernelArgsNameMaker() {}
@@ -72,6 +52,8 @@ class KernelArgsNameMaker {
virtual const paddle::SmallVector<std::string>& GetAttrsArgsNames() = 0;
};
void InitDefaultKernelSignatureMap();
void SetAllocationForOutputTenosr(pten::DenseTensor* tensor,
const platform::Place& place);
......
@@ -15,7 +15,6 @@ limitations under the License. */
#include "paddle/fluid/operators/scale_op.h"
#include <string>
#include "paddle/fluid/platform/float16.h"
#include "paddle/pten/ops/compat/scale_args_fn.h"
namespace paddle {
namespace framework {
@@ -71,12 +70,6 @@ class ScaleOp : public framework::OperatorWithKernel {
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
framework::KernelSignature GetExpectedPtenKernelArgs(
const framework::ExecutionContext &ctx) const override {
framework::ExecutionArgumentMappingContext arg_mapping_ctx(ctx);
return pten::ScaleOpArgumentMapping(arg_mapping_ctx);
}
};
class ScaleOpMaker : public framework::OpProtoAndCheckerMaker {
......
@@ -14,7 +14,10 @@ limitations under the License. */
#include "paddle/fluid/operators/sign_op.h"
#include <memory>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/pten/core/infermeta_utils.h"
#include "paddle/pten/infermeta/unary.h"
namespace paddle {
namespace operators {
@@ -22,14 +25,6 @@ namespace operators {
class SignOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "sign");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "sign");
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ "Out");
}
};
template <typename AttrType>
@@ -64,9 +59,12 @@ class SignGradMaker : public framework::SingleGradOpMaker<T> {
namespace ops = paddle::operators;
DELCARE_INFER_SHAPE_FUNCTOR(sign, SignInferShapeFunctor,
PT_INFER_META(pten::UnchangedInferMetaNew));
REGISTER_OPERATOR(sign, ops::SignOp, ops::SignOpMaker<float>,
ops::SignGradMaker<paddle::framework::OpDesc>,
ops::SignGradMaker<paddle::imperative::OpBase>);
ops::SignGradMaker<paddle::imperative::OpBase>,
SignInferShapeFunctor);
REGISTER_OP_CPU_KERNEL(
sign, ops::SignKernel<paddle::platform::CPUDeviceContext, float>,
ops::SignKernel<paddle::platform::CPUDeviceContext, double>);
......
@@ -2,7 +2,7 @@ set(PYBIND_DEPS pybind python proto_desc memory executor fleet_wrapper box_wrapp
feed_fetch_method pass generate_pass pass_builder parallel_executor profiler layer tracer engine scope_pool
analysis_predictor imperative_profiler imperative_flag save_load_util dlpack_tensor device_context
gloo_wrapper infer_io_utils heter_wrapper generator op_version_registry ps_gpu_wrapper custom_operator
cost_model cuda_graph_with_memory_pool fleet_executor global_utils)
cost_model cuda_graph_with_memory_pool fleet_executor global_utils pten_utils)
if (WITH_PSCORE)
set(PYBIND_DEPS ${PYBIND_DEPS} ps_service)
......
@@ -50,6 +50,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/parallel_executor.h"
#include "paddle/fluid/framework/prune.h"
#include "paddle/fluid/framework/pten_utils.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/save_load_util.h"
#include "paddle/fluid/framework/scope_pool.h"
......
@@ -21,7 +21,7 @@ add_subdirectory(ops)
add_subdirectory(tests)
# make an unity target for compile deps
set(PTEN_DEPS convert_utils dense_tensor pten_context kernel_factory kernel_context arg_map_context infermeta lod_utils)
set(PTEN_DEPS convert_utils dense_tensor pten_context kernel_factory kernel_context arg_map_context infermeta lod_utils op_compat_infos)
get_property(pten_kernels GLOBAL PROPERTY PTEN_KERNELS)
# keep this message for debug, remove it later if needless
message(STATUS "All standard pten kernels: ${pten_kernels}")
......
@@ -21,13 +21,14 @@ cc_library(tensor_meta SRCS tensor_meta.cc DEPS pten_enforce mixed_vector)
cc_library(lod_utils SRCS lod_utils.cc DEPS pten_enforce mixed_vector)
cc_library(dense_tensor SRCS dense_tensor.cc DEPS convert_utils tensor_meta tensor_base)
cc_library(pten_device_context SRCS device_context.cc DEPS tensor_base )
cc_library(meta_tensor SRCS meta_tensor.cc DEPS tensor_base tensor_meta dense_tensor)
cc_library(infermeta_utils SRCS infermeta_utils.cc DEPS meta_tensor)
cc_library(selected_rows SRCS selected_rows.cc DEPS dense_tensor mixed_vector enforce ddim)
cc_test(unroll_array_ops_test SRCS unroll_array_ops_test.cc)
cc_library(ddim SRCS ddim.cc DEPS eigen3 boost enforce)
cc_library(selected_rows SRCS selected_rows.cc DEPS dense_tensor mixed_vector enforce ddim)
# Will remove once we implemented MKLDNN_Tensor
if(WITH_MKLDNN)
add_dependencies(dense_tensor mkldnn)
......
cc_library(arg_map_context SRCS arg_map_context.cc DEPS pten_enforce)
cc_library(op_utils SRCS op_utils.cc DEPS arg_map_context enforce convert_utils)
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/pten/core/compat/op_utils.h"
namespace pten {
DefaultKernelSignatureMap& DefaultKernelSignatureMap::Instance() {
static DefaultKernelSignatureMap g_default_kernel_sig_map;
return g_default_kernel_sig_map;
}
OpUtilsMap& OpUtilsMap::Instance() {
static OpUtilsMap g_op_utils_map;
return g_op_utils_map;
}
} // namespace pten
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <mutex>
#include "paddle/pten/core/compat/arg_map_context.h"
#include "paddle/pten/core/infermeta_utils.h"
#include "paddle/pten/core/kernel_def.h"
#include "paddle/pten/core/macros.h"
#include "paddle/utils/flat_hash_map.h"
#include "paddle/fluid/platform/enforce.h"
namespace pten {
class DefaultKernelSignatureMap {
public:
static DefaultKernelSignatureMap& Instance();
bool Has(const std::string& op_type) const { return map_.count(op_type) > 0; }
const KernelSignature& Get(const std::string& op_type) const {
auto it = map_.find(op_type);
PADDLE_ENFORCE_NE(
it,
map_.end(),
paddle::platform::errors::NotFound(
"Operator `%s`'s kernel signature is not registered.", op_type));
return it->second;
}
void Insert(std::string op_type, KernelSignature signature) {
PADDLE_ENFORCE_NE(
Has(op_type),
true,
paddle::platform::errors::AlreadyExists(
"Operator (%s)'s Kernel Siginature has been registered.", op_type));
map_.insert({std::move(op_type), std::move(signature)});
}
private:
DefaultKernelSignatureMap() = default;
paddle::flat_hash_map<std::string, KernelSignature> map_;
DISABLE_COPY_AND_ASSIGN(DefaultKernelSignatureMap);
};
class OpUtilsMap {
public:
static OpUtilsMap& Instance();
bool Contains(const std::string& op_type) const {
return name_map_.count(op_type) || arg_mapping_fn_map_.count(op_type);
}
void InsertApiName(std::string op_type, std::string api_name) {
PADDLE_ENFORCE_EQ(
name_map_.count(op_type),
0UL,
paddle::platform::errors::AlreadyExists(
"Operator (%s)'s api name has been registered.", op_type));
name_map_.insert({std::move(op_type), std::move(api_name)});
}
void InsertArgumentMappingFn(std::string op_type, ArgumentMappingFn fn) {
PADDLE_ENFORCE_EQ(
arg_mapping_fn_map_.count(op_type),
0UL,
paddle::platform::errors::AlreadyExists(
"Operator (%s)'s argu,emt mapping function has been registered.",
op_type));
arg_mapping_fn_map_.insert({std::move(op_type), std::move(fn)});
}
std::string GetApiName(const std::string& op_type) const {
auto it = name_map_.find(op_type);
if (it == name_map_.end()) {
return "deprecated";
} else {
return it->second;
}
}
ArgumentMappingFn GetArgumentMappingFn(const std::string& op_type) const {
auto it = arg_mapping_fn_map_.find(op_type);
if (it == arg_mapping_fn_map_.end()) {
auto func =
[op_type](const ArgumentMappingContext& ctx) -> KernelSignature {
return DefaultKernelSignatureMap::Instance().Get(op_type);
};
return func;
} else {
return it->second;
}
}
private:
OpUtilsMap() = default;
paddle::flat_hash_map<std::string, std::string> name_map_;
paddle::flat_hash_map<std::string, ArgumentMappingFn> arg_mapping_fn_map_;
DISABLE_COPY_AND_ASSIGN(OpUtilsMap);
};
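Usage sketch for the fallback path (the helper and the op name here are hypothetical examples, not part of this header): together with InitDefaultKernelSignatureMap in pten_utils.cc above, lookup degrades gracefully for ops that have a compatible pten kernel but no hand-written *_sig.cc.
// Ops with a registered mapping get it; anything else falls back to the
// signature derived from its OpProto and stored in DefaultKernelSignatureMap.
inline KernelSignature DemoResolveSignature(const ArgumentMappingContext& ctx) {
  auto fn = OpUtilsMap::Instance().GetArgumentMappingFn("sign");  // example op
  return fn(ctx);
}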
struct ApiNameRegistrar {
ApiNameRegistrar(const char* op_type, const char* api_name) {
OpUtilsMap::Instance().InsertApiName(op_type, api_name);
}
};
struct ArgumentMappingFnRegistrar {
ArgumentMappingFnRegistrar(const char* op_type,
ArgumentMappingFn arg_mapping_fn) {
OpUtilsMap::Instance().InsertArgumentMappingFn(op_type,
std::move(arg_mapping_fn));
}
};
#define PT_REGISTER_API_NAME(op_type, api_name) \
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \
pt_register_api_name_ns_check_##op_type, \
"PT_REGISTER_API_NAME must be called in global namespace."); \
static const ::pten::ApiNameRegistrar __registrar_api_name_for_##op_type( \
#op_type, #api_name); \
int TouchApiNameSymbol_##op_type() { return 0; }
#define PT_DECLARE_API_NAME(op_type) \
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \
pt_declare_api_name_ns_check_##op_type, \
"PT_DECLARE_API_NAME must be called in global namespace."); \
extern int TouchApiNameSymbol_##op_type(); \
UNUSED static int __declare_api_name_symbol_for_##op_type = \
TouchApiNameSymbol_##op_type()
#define PT_REGISTER_ARG_MAPPING_FN(op_type, arg_mapping_fn) \
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \
pt_register_arg_map_fn_ns_check_##op_type, \
"PT_REGISTER_ARG_MAPPING_FN must be called in global namespace."); \
static const ::pten::ArgumentMappingFnRegistrar \
__registrar_arg_map_fn_for_##op_type(#op_type, arg_mapping_fn); \
int TouchArgumentMappingFnSymbol_##op_type() { return 0; }
#define PT_DECLARE_ARG_MAPPING_FN(op_type) \
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \
pt_declare_arg_map_fn_ns_check_##op_type, \
"PT_DECLARE_ARG_MAPPING_FN must be called in global namespace."); \
extern int TouchArgumentMappingFnSymbol_##op_type(); \
UNUSED static int __declare_arg_map_fn_symbol_for_##op_type = \
TouchArgumentMappingFnSymbol_##op_type()
} // namespace pten
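A note on the Touch*Symbol pattern above: it exists to defeat linker dead-stripping. PT_REGISTER_ARG_MAPPING_FN defines TouchArgumentMappingFnSymbol_##op_type next to the static registrar in a *_sig.cc object file, and PT_DECLARE_ARG_MAPPING_FN (expanded from the generated signatures.h) references it, which forces the linker to keep that object file and thus run the registrar's constructor. For scale, the declare side expands to roughly (static-assert line omitted):
extern int TouchArgumentMappingFnSymbol_scale();
UNUSED static int __declare_arg_map_fn_symbol_for_scale =
    TouchArgumentMappingFnSymbol_scale();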
@@ -151,7 +151,7 @@ struct InferMetaFnImpl<Return (*)(Args...), infer_meta_fn> {
struct InferMetaFnCallHelper<MetaConfig, Tail...> {
template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs>
static void Call(InferMetaContext* ctx, PreviousArgs&... pargs) {
const MetaConfig& arg = ctx->GetMetaConfig();
MetaConfig arg = ctx->GetMetaConfig();
InferMetaFnCallHelper<Tail...>::template Call<in_idx, attr_idx, out_idx>(
ctx, pargs..., arg);
}
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......
cc_library(infermeta SRCS nullary.cc unary.cc binary.cc multiary.cc DEPS convert_utils)
cc_library(infermeta SRCS nullary.cc unary.cc binary.cc multiary.cc DEPS convert_utils infermeta_utils)
cc_library(backward_infermeta SRCS backward.cc DEPS convert_utils)
@@ -12,12 +12,21 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// See Note [ Why still include the fluid headers? ]
#include "paddle/pten/infermeta/unary.h"
#include <set>
#include "paddle/pten/core/infermeta_utils.h"
namespace pten {
void UnchangedInferMetaNew(MetaConfig config,
const MetaTensor& x,
MetaTensor* out) {
out->set_dims(x.dims());
out->share_lod(x);
}
DenseTensorMeta UnchangedInferMeta(const DenseTensorMeta& x_meta) {
return x_meta;
}
......
@@ -16,23 +16,27 @@ limitations under the License. */
// See Note [ Why still include the fluid headers? ]
#include "paddle/pten/common/scalar_array.h"
#include "paddle/pten/core/infermeta_utils.h"
#include "paddle/pten/core/meta_tensor.h"
#include "paddle/pten/core/tensor_meta.h"
namespace pten {
class MetaConfig;
// Common InferMeta Functions for unary operators, The format like:
//
// 1. DenseTensorMeta [OpName]InferMeta(const DenseTensorMeta& x_meta, ...)
// {}
// 2. std::pair<DenseTensorMeta, DenseTensorMeta> [OpName]InferMeta(const
// DenseTensorMeta&
// x_meta, ...) {}
// 3. std::tuple<DenseTensorMeta, DenseTensorMeta, DenseTensorMeta>
// [OpName]InferMeta(const
// DenseTensorMeta& x_meta, ...)
// NOTE: The name "InferMeta" may be not appropriate. "InferMeta" may be good.
// Because functions in this file
// not only can infer shape, but alse need infer lod or other useful data.
// void [OpName]InferMeta(const MetaTensor& x, ..., MetaTensor* out) {}
//
// NOTE: The name "InferShape" may be not appropriate. "InferMeta" may be good.
// Because functions in this file not only can infer shape, but also need
// infer lod or other useful data.
// TODO(chenweihang): update all InferMeta function format in next pr,
// now add UnchangedInferMetaNew for test new format
void UnchangedInferMetaNew(MetaConfig config,
const MetaTensor& x,
MetaTensor* out);
DenseTensorMeta UnchangedInferMeta(const DenseTensorMeta& x_meta);
......
set(op_utils_header ${PADDLE_BINARY_DIR}/paddle/pten/ops/compat/signatures.h.tmp CACHE INTERNAL "op_args_fns.cc file")
set(op_utils_header_final ${PADDLE_BINARY_DIR}/paddle/pten/ops/compat/signatures.h)
file(WRITE ${op_utils_header} "// Generated by the paddle/pten/ops/compat/CMakeLists.txt. DO NOT EDIT!\n\n")
file(APPEND ${op_utils_header} "#include \"paddle/pten/core/compat/op_utils.h\"\n\n")
# Automatically generate the registration code of all arg map functions
# and compile the corresponding target to avoid frequent code conflicts
# when writing to the same file
register_op_utils(op_compat_infos DEPS op_utils)
copy_if_different(${op_utils_header} ${op_utils_header_final})
@@ -12,9 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/pten/core/compat/arg_map_context.h"
#include "paddle/pten/core/compat/op_utils.h"
namespace pten {
@@ -22,15 +20,18 @@ KernelSignature ScaleOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.IsDenseTensorInput("X")) {
std::string scale_attr;
if (ctx.HasInput("ScaleTensor")) {
scale_attr = "ScaleTensor";
return KernelSignature(
"scale", {"X"}, {"ScaleTensor", "bias", "bias_after_scale"}, {"Out"});
} else {
scale_attr = "scale";
return KernelSignature(
"scale", {"X"}, {"scale", "bias", "bias_after_scale"}, {"Out"});
}
return KernelSignature(
"scale", {"X"}, {scale_attr, "bias", "bias_after_scale"}, {"Out"});
}
// TODO(chenweihang): support other cases after selected rows added
return KernelSignature("scale.unregistered", {}, {}, {});
}
} // namespace pten
// op_type, api_name, arg_mapping_fn
PT_REGISTER_ARG_MAPPING_FN(scale, pten::ScaleOpArgumentMapping);
@@ -3,6 +3,7 @@ cc_test(test_intrusive_ptr SRCS test_intrusive_ptr.cc)
cc_test(test_type_info SRCS test_type_info.cc)
cc_test(test_convert_utils SRCS test_convert_utils.cc DEPS convert_utils)
cc_test(test_kernel_factory SRCS test_kernel_factory.cc DEPS kernel_factory scale_kernel)
cc_test(test_op_utils SRCS test_op_utils.cc DEPS op_compat_infos)
cc_test(test_pten_device_context SRCS test_device_context.cc DEPS pten_context cpu_context)
cc_test(test_ddim SRCS test_ddim.cc DEPS ddim)
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include "gtest/gtest.h"
#include "paddle/pten/core/compat/op_utils.h"
#include "paddle/pten/ops/compat/signatures.h"
namespace pten {
namespace tests {
TEST(OpUtilsMap, ArgMappingFnExists) {
std::cout << "enter ArgMappingFnExists";
auto scale_arg_mapping_fn =
pten::OpUtilsMap::Instance().GetArgumentMappingFn("scale");
EXPECT_NE(scale_arg_mapping_fn, nullptr);
}
} // namespace tests
} // namespace pten