未验证 提交 b75507d3 编写于 作者: C Chen Weihang 提交者: GitHub

[PTen] Unify InferMeta(Shape) Function in pten and fluid op (#38976)

* infermeta context init design

* support infermeta called in fluid op

* add hasattr and attr methods

* add dygraah GetVarPtrs support

* rename arg_map_context to arg_map_utils

* add registry for arg map func

* resolve conflit

* refactor op utils design

* polish meta config

* fix details

* remove hasattr method

* resolve conflit

* revert cmake order change

* revert some change

* change init pos

* fix compile faileed

* fix typo

* fix inference failed

* fix windows ccompile failed

* polish format
Co-authored-by: NWang Huan <wanghuan29@baidu.com>
上级 2bf9b844
...@@ -243,3 +243,29 @@ function(register_kernels) ...@@ -243,3 +243,29 @@ function(register_kernels)
endif() endif()
endforeach() endforeach()
endfunction() endfunction()
function(append_op_util_declare TARGET)
file(READ ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET} target_content)
string(REGEX MATCH "(PT_REGISTER_API_NAME|PT_REGISTER_ARG_MAPPING_FN)\\([ \t\r\n]*[a-z0-9_]*" util_registrar "${target_content}")
string(REPLACE "PT_REGISTER_ARG_MAPPING_FN" "PT_DECLARE_ARG_MAPPING_FN" util_declare "${util_registrar}")
string(REPLACE "PT_REGISTER_API_NAME" "PT_REGISTER_API_NAME" util_declare "${util_declare}")
string(APPEND util_declare ");")
file(APPEND ${op_utils_header} "${util_declare}")
endfunction()
function(register_op_utils TARGET_NAME)
set(utils_srcs)
set(options "")
set(oneValueArgs "")
set(multiValueArgs EXCLUDES DEPS)
cmake_parse_arguments(register_op_utils "${options}" "${oneValueArgs}"
"${multiValueArgs}" ${ARGN})
file(GLOB SIGNATURES RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*_sig.cc")
foreach(target ${SIGNATURES})
append_op_util_declare(${target})
list(APPEND utils_srcs ${CMAKE_CURRENT_SOURCE_DIR}/${target})
endforeach()
cc_library(${TARGET_NAME} SRCS ${utils_srcs} DEPS ${register_op_utils_DEPS})
endfunction()
...@@ -192,11 +192,11 @@ cc_library(unused_var_check SRCS unused_var_check.cc DEPS glog no_need_buffer_va ...@@ -192,11 +192,11 @@ cc_library(unused_var_check SRCS unused_var_check.cc DEPS glog no_need_buffer_va
IF(WITH_XPU) IF(WITH_XPU)
cc_library(operator SRCS operator.cc DEPS xpu_op_list op_info device_context tensor scope glog trainer_desc_proto data_feed_proto cc_library(operator SRCS operator.cc DEPS xpu_op_list op_info device_context tensor scope glog trainer_desc_proto data_feed_proto
shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type op_call_stack unused_var_check nan_inf_utils shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type op_call_stack unused_var_check nan_inf_utils
pten pten_utils kernel_factory infershape_utils) pten pten_utils kernel_factory infershape_utils op_utils)
ELSE() ELSE()
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog trainer_desc_proto data_feed_proto cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog trainer_desc_proto data_feed_proto
shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type op_call_stack unused_var_check nan_inf_utils shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type op_call_stack unused_var_check nan_inf_utils
pten pten_utils kernel_factory infershape_utils) pten pten_utils kernel_factory infershape_utils op_utils)
ENDIF() ENDIF()
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry device_context) cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry device_context)
...@@ -404,7 +404,7 @@ cc_test(save_load_util_test SRCS save_load_util_test.cc DEPS save_load_util tens ...@@ -404,7 +404,7 @@ cc_test(save_load_util_test SRCS save_load_util_test.cc DEPS save_load_util tens
cc_library(generator SRCS generator.cc DEPS enforce place) cc_library(generator SRCS generator.cc DEPS enforce place)
cc_library(pten_utils SRCS pten_utils.cc DEPS lod_tensor selected_rows_utils place pten var_type_traits pten_api_utils op_info) cc_library(pten_utils SRCS pten_utils.cc DEPS lod_tensor selected_rows_utils place pten var_type_traits pten_api_utils op_info)
cc_library(infershape_utils SRCS infershape_utils.cc DEPS lod_tensor selected_rows_utils attribute place pten var_type_traits pten pten_api_utils op_info shape_inference) cc_library(infershape_utils SRCS infershape_utils.cc DEPS pten_utils attribute shape_inference op_utils)
# Get the current working branch # Get the current working branch
execute_process( execute_process(
......
...@@ -275,10 +275,8 @@ struct OpInfoFiller<T, kVarTypeInference> { ...@@ -275,10 +275,8 @@ struct OpInfoFiller<T, kVarTypeInference> {
template <typename T> template <typename T>
struct OpInfoFiller<T, kShapeInference> { struct OpInfoFiller<T, kShapeInference> {
void operator()(const char* op_type, OpInfo* info) const { void operator()(const char* op_type, OpInfo* info) const {
PADDLE_ENFORCE_EQ( // Note: if fill InferShapeFN by this Filler, the infershape here
info->infer_shape_, nullptr, // will overwrite the op->InferShape func registered in kOperator Filler
platform::errors::AlreadyExists(
"Duplicate InferShapeFN of %s has been registered", op_type));
info->infer_shape_ = [](InferShapeContext* ctx) { info->infer_shape_ = [](InferShapeContext* ctx) {
T inference; T inference;
inference(ctx); inference(ctx);
......
...@@ -15,11 +15,14 @@ limitations under the License. */ ...@@ -15,11 +15,14 @@ limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/pten_utils.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/pten/core/compat/arg_map_context.h" #include "paddle/pten/core/compat/arg_map_context.h"
#include "paddle/pten/core/compat/op_utils.h"
#include "paddle/pten/core/compat_utils.h" #include "paddle/pten/core/compat_utils.h"
#include "paddle/pten/core/convert_utils.h" #include "paddle/pten/core/convert_utils.h"
#include "paddle/pten/core/dense_tensor.h" #include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/infermeta_utils.h"
#include "paddle/pten/core/meta_tensor.h" #include "paddle/pten/core/meta_tensor.h"
namespace paddle { namespace paddle {
...@@ -186,5 +189,40 @@ class CompatMetaTensor : public pten::MetaTensor { ...@@ -186,5 +189,40 @@ class CompatMetaTensor : public pten::MetaTensor {
bool is_runtime_; bool is_runtime_;
}; };
pten::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
const std::string& op_type) {
// 1. get kernel args
InitDefaultKernelSignatureMap();
auto arg_map_fn = pten::OpUtilsMap::Instance().GetArgumentMappingFn(op_type);
PADDLE_ENFORCE_NOT_NULL(
arg_map_fn, platform::errors::NotFound(
"The ArgumentMappingFn of %s op is not found.", op_type));
InferShapeArgumentMappingContext arg_map_context(*ctx);
auto signature = arg_map_fn(arg_map_context);
VLOG(3) << "BuildInferMetaContext: op kernel signature - " << signature;
// 2. build infermeta context
pten::InferMetaContext infer_meta_context(ctx->IsRuntime());
auto& input_names = std::get<0>(signature.args);
auto& output_names = std::get<2>(signature.args);
// TODO(chenweihang): support attrs in next pr
// auto& attr_names = std::get<1>(signature.args);
// TODO(chenweihang): support multiple inputs and outputs
pten::InferMetaContext infer_mete_context;
for (auto& in_name : input_names) {
infer_meta_context.EmplaceBackInput(std::make_shared<CompatMetaTensor>(
ctx->GetInputVarPtrs(in_name)[0], ctx->IsRuntime()));
}
for (auto& out_name : output_names) {
infer_meta_context.EmplaceBackOutput(std::make_shared<CompatMetaTensor>(
ctx->GetOutputVarPtrs(out_name)[0], ctx->IsRuntime()));
}
// TODO(chenweihang): support attrs later
return infer_meta_context;
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -26,7 +26,6 @@ class InferMetaContext; ...@@ -26,7 +26,6 @@ class InferMetaContext;
namespace paddle { namespace paddle {
namespace framework { namespace framework {
// TODO(chenweihang): impl this function in next PR
pten::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx, pten::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
const std::string& op_type); const std::string& op_type);
......
...@@ -31,6 +31,7 @@ limitations under the License. */ ...@@ -31,6 +31,7 @@ limitations under the License. */
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
#include "paddle/pten/common/scalar.h" #include "paddle/pten/common/scalar.h"
#include "paddle/pten/common/scalar_array.h" #include "paddle/pten/common/scalar_array.h"
#include "paddle/pten/ops/compat/signatures.h"
namespace pten { namespace pten {
class DenseTensor; class DenseTensor;
...@@ -1086,6 +1087,13 @@ bool OperatorWithKernel::CanMKLDNNBeUsed(const framework::ExecutionContext& ctx, ...@@ -1086,6 +1087,13 @@ bool OperatorWithKernel::CanMKLDNNBeUsed(const framework::ExecutionContext& ctx,
return use_mkldnn_ctx && this->SupportsMKLDNN(data_type); return use_mkldnn_ctx && this->SupportsMKLDNN(data_type);
} }
void OperatorWithKernel::InferShape(InferShapeContext* ctx) const {
PADDLE_THROW(platform::errors::PermissionDenied(
"The default InferShape function of OperatorWithKernel is not allowed to "
"be called, please override corresponding InferShape function in the "
"specific operator."));
}
void OperatorWithKernel::RuntimeInferShape(const Scope& scope, void OperatorWithKernel::RuntimeInferShape(const Scope& scope,
const platform::Place& place, const platform::Place& place,
const RuntimeContext& ctx) const { const RuntimeContext& ctx) const {
...@@ -1784,8 +1792,10 @@ OpKernelType OperatorWithKernel::GetKernelTypeForVar( ...@@ -1784,8 +1792,10 @@ OpKernelType OperatorWithKernel::GetKernelTypeForVar(
KernelSignature OperatorWithKernel::GetExpectedPtenKernelArgs( KernelSignature OperatorWithKernel::GetExpectedPtenKernelArgs(
const ExecutionContext& ctx) const { const ExecutionContext& ctx) const {
return KernelSignatureMap::Instance().Get( InitDefaultKernelSignatureMap();
pten::TransToPtenKernelName(Type())); ExecutionArgumentMappingContext arg_mapping_ctx(ctx);
return pten::OpUtilsMap::Instance().GetArgumentMappingFn(Type())(
arg_mapping_ctx);
} }
Scope* OperatorWithKernel::PreparePtenData( Scope* OperatorWithKernel::PreparePtenData(
......
...@@ -41,6 +41,7 @@ limitations under the License. */ ...@@ -41,6 +41,7 @@ limitations under the License. */
#include "paddle/utils/flat_hash_map.h" #include "paddle/utils/flat_hash_map.h"
#include "paddle/pten/core/compat/arg_map_context.h" #include "paddle/pten/core/compat/arg_map_context.h"
#include "paddle/pten/core/compat/op_utils.h"
#include "paddle/pten/core/kernel_context.h" #include "paddle/pten/core/kernel_context.h"
#include "paddle/pten/core/kernel_factory.h" #include "paddle/pten/core/kernel_factory.h"
...@@ -468,8 +469,7 @@ class ExecutionArgumentMappingContext : public pten::ArgumentMappingContext { ...@@ -468,8 +469,7 @@ class ExecutionArgumentMappingContext : public pten::ArgumentMappingContext {
} }
bool IsDenseTensorInput(const std::string& name) const override { bool IsDenseTensorInput(const std::string& name) const override {
return ctx_.InputVar(name)->IsType<framework::Tensor>() || return ctx_.InputVar(name)->IsType<framework::LoDTensor>();
ctx_.InputVar(name)->IsType<framework::LoDTensor>();
} }
bool IsSelectedRowsInput(const std::string& name) const override { bool IsSelectedRowsInput(const std::string& name) const override {
...@@ -550,7 +550,7 @@ class OperatorWithKernel : public OperatorBase { ...@@ -550,7 +550,7 @@ class OperatorWithKernel : public OperatorBase {
bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx, bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx,
proto::VarType::Type data_type) const; proto::VarType::Type data_type) const;
virtual void InferShape(InferShapeContext* ctx) const = 0; virtual void InferShape(InferShapeContext* ctx) const;
void RuntimeInferShape(const Scope& scope, const platform::Place& place, void RuntimeInferShape(const Scope& scope, const platform::Place& place,
const RuntimeContext& ctx) const override; const RuntimeContext& ctx) const override;
......
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#include <sstream> #include <sstream>
#include "paddle/fluid/framework/pten_utils.h" #include "paddle/fluid/framework/pten_utils.h"
#include "paddle/pten/core/compat/op_utils.h"
#include "paddle/pten/core/convert_utils.h" #include "paddle/pten/core/convert_utils.h"
#include "paddle/pten/core/kernel_factory.h" #include "paddle/pten/core/kernel_factory.h"
...@@ -89,48 +90,6 @@ pten::KernelKey TransOpKernelTypeToPtenKernelKey( ...@@ -89,48 +90,6 @@ pten::KernelKey TransOpKernelTypeToPtenKernelKey(
return pten::KernelKey(backend, layout, dtype); return pten::KernelKey(backend, layout, dtype);
} }
KernelSignatureMap* KernelSignatureMap::kernel_signature_map_ = nullptr;
std::once_flag KernelSignatureMap::init_flag_;
KernelSignatureMap& KernelSignatureMap::Instance() {
std::call_once(init_flag_, [] {
kernel_signature_map_ = new KernelSignatureMap();
for (const auto& pair : OpInfoMap::Instance().map()) {
const auto& op_type = pair.first;
const auto* op_proto = pair.second.proto_;
if (pten::KernelFactory::Instance().HasCompatiblePtenKernel(op_type) &&
op_proto) {
KernelArgsNameMakerByOpProto maker(op_proto);
VLOG(10) << "Register kernel signature for " << op_type;
auto success = kernel_signature_map_->map_
.emplace(pten::TransToPtenKernelName(op_type),
std::move(maker.GetKernelSignature()))
.second;
PADDLE_ENFORCE_EQ(
success, true,
platform::errors::PermissionDenied(
"Kernel signature of the operator %s has been registered.",
op_type));
}
}
});
return *kernel_signature_map_;
}
bool KernelSignatureMap::Has(const std::string& op_type) const {
return map_.find(op_type) != map_.end();
}
const KernelSignature& KernelSignatureMap::Get(
const std::string& op_type) const {
auto it = map_.find(op_type);
PADDLE_ENFORCE_NE(
it, map_.end(),
platform::errors::NotFound(
"Operator `%s`'s kernel signature is not registered.", op_type));
return it->second;
}
const paddle::SmallVector<std::string>& const paddle::SmallVector<std::string>&
KernelArgsNameMakerByOpProto::GetInputArgsNames() { KernelArgsNameMakerByOpProto::GetInputArgsNames() {
for (int i = 0; i < op_proto_->inputs_size(); ++i) { for (int i = 0; i < op_proto_->inputs_size(); ++i) {
...@@ -196,6 +155,24 @@ KernelSignature KernelArgsNameMakerByOpProto::GetKernelSignature() { ...@@ -196,6 +155,24 @@ KernelSignature KernelArgsNameMakerByOpProto::GetKernelSignature() {
GetOutputArgsNames()); GetOutputArgsNames());
} }
std::once_flag kernel_sig_map_init_flag;
void InitDefaultKernelSignatureMap() {
std::call_once(kernel_sig_map_init_flag, [] {
for (const auto& pair : paddle::framework::OpInfoMap::Instance().map()) {
const auto& op_type = pair.first;
const auto* op_proto = pair.second.proto_;
if (pten::KernelFactory::Instance().HasCompatiblePtenKernel(op_type) &&
op_proto) {
paddle::framework::KernelArgsNameMakerByOpProto maker(op_proto);
VLOG(10) << "Register kernel signature for " << op_type;
pten::DefaultKernelSignatureMap::Instance().Insert(
op_type, std::move(maker.GetKernelSignature()));
}
}
});
}
void SetAllocationForOutputTenosr(pten::DenseTensor* tensor, void SetAllocationForOutputTenosr(pten::DenseTensor* tensor,
const platform::Place& place) { const platform::Place& place) {
if (!tensor->IsInitialized() || !(tensor->place() == place)) { if (!tensor->IsInitialized() || !(tensor->place() == place)) {
......
...@@ -44,26 +44,6 @@ pten::KernelKey TransOpKernelTypeToPtenKernelKey( ...@@ -44,26 +44,6 @@ pten::KernelKey TransOpKernelTypeToPtenKernelKey(
/* Kernel Args parse */ /* Kernel Args parse */
// TODO(chenweihang): we can generate this map by proto info in compile time
class KernelSignatureMap {
public:
static KernelSignatureMap& Instance();
bool Has(const std::string& op_type) const;
const KernelSignature& Get(const std::string& op_type) const;
private:
KernelSignatureMap() = default;
DISABLE_COPY_AND_ASSIGN(KernelSignatureMap);
private:
static KernelSignatureMap* kernel_signature_map_;
static std::once_flag init_flag_;
paddle::flat_hash_map<std::string, KernelSignature> map_;
};
class KernelArgsNameMaker { class KernelArgsNameMaker {
public: public:
virtual ~KernelArgsNameMaker() {} virtual ~KernelArgsNameMaker() {}
...@@ -72,6 +52,8 @@ class KernelArgsNameMaker { ...@@ -72,6 +52,8 @@ class KernelArgsNameMaker {
virtual const paddle::SmallVector<std::string>& GetAttrsArgsNames() = 0; virtual const paddle::SmallVector<std::string>& GetAttrsArgsNames() = 0;
}; };
void InitDefaultKernelSignatureMap();
void SetAllocationForOutputTenosr(pten::DenseTensor* tensor, void SetAllocationForOutputTenosr(pten::DenseTensor* tensor,
const platform::Place& place); const platform::Place& place);
......
...@@ -15,7 +15,6 @@ limitations under the License. */ ...@@ -15,7 +15,6 @@ limitations under the License. */
#include "paddle/fluid/operators/scale_op.h" #include "paddle/fluid/operators/scale_op.h"
#include <string> #include <string>
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
#include "paddle/pten/ops/compat/scale_args_fn.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -71,12 +70,6 @@ class ScaleOp : public framework::OperatorWithKernel { ...@@ -71,12 +70,6 @@ class ScaleOp : public framework::OperatorWithKernel {
#endif #endif
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
framework::KernelSignature GetExpectedPtenKernelArgs(
const framework::ExecutionContext &ctx) const override {
framework::ExecutionArgumentMappingContext arg_mapping_ctx(ctx);
return pten::ScaleOpArgumentMapping(arg_mapping_ctx);
}
}; };
class ScaleOpMaker : public framework::OpProtoAndCheckerMaker { class ScaleOpMaker : public framework::OpProtoAndCheckerMaker {
......
...@@ -14,7 +14,10 @@ limitations under the License. */ ...@@ -14,7 +14,10 @@ limitations under the License. */
#include "paddle/fluid/operators/sign_op.h" #include "paddle/fluid/operators/sign_op.h"
#include <memory> #include <memory>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
#include "paddle/pten/core/infermeta_utils.h"
#include "paddle/pten/infermeta/unary.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -22,14 +25,6 @@ namespace operators { ...@@ -22,14 +25,6 @@ namespace operators {
class SignOp : public framework::OperatorWithKernel { class SignOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "sign");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "sign");
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ "Out");
}
}; };
template <typename AttrType> template <typename AttrType>
...@@ -64,9 +59,12 @@ class SignGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -64,9 +59,12 @@ class SignGradMaker : public framework::SingleGradOpMaker<T> {
namespace ops = paddle::operators; namespace ops = paddle::operators;
DELCARE_INFER_SHAPE_FUNCTOR(sign, SignInferShapeFunctor,
PT_INFER_META(pten::UnchangedInferMetaNew));
REGISTER_OPERATOR(sign, ops::SignOp, ops::SignOpMaker<float>, REGISTER_OPERATOR(sign, ops::SignOp, ops::SignOpMaker<float>,
ops::SignGradMaker<paddle::framework::OpDesc>, ops::SignGradMaker<paddle::framework::OpDesc>,
ops::SignGradMaker<paddle::imperative::OpBase>); ops::SignGradMaker<paddle::imperative::OpBase>,
SignInferShapeFunctor);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
sign, ops::SignKernel<paddle::platform::CPUDeviceContext, float>, sign, ops::SignKernel<paddle::platform::CPUDeviceContext, float>,
ops::SignKernel<paddle::platform::CPUDeviceContext, double>); ops::SignKernel<paddle::platform::CPUDeviceContext, double>);
......
...@@ -2,7 +2,7 @@ set(PYBIND_DEPS pybind python proto_desc memory executor fleet_wrapper box_wrapp ...@@ -2,7 +2,7 @@ set(PYBIND_DEPS pybind python proto_desc memory executor fleet_wrapper box_wrapp
feed_fetch_method pass generate_pass pass_builder parallel_executor profiler layer tracer engine scope_pool feed_fetch_method pass generate_pass pass_builder parallel_executor profiler layer tracer engine scope_pool
analysis_predictor imperative_profiler imperative_flag save_load_util dlpack_tensor device_context analysis_predictor imperative_profiler imperative_flag save_load_util dlpack_tensor device_context
gloo_wrapper infer_io_utils heter_wrapper generator op_version_registry ps_gpu_wrapper custom_operator gloo_wrapper infer_io_utils heter_wrapper generator op_version_registry ps_gpu_wrapper custom_operator
cost_model cuda_graph_with_memory_pool fleet_executor global_utils) cost_model cuda_graph_with_memory_pool fleet_executor global_utils pten_utils)
if (WITH_PSCORE) if (WITH_PSCORE)
set(PYBIND_DEPS ${PYBIND_DEPS} ps_service) set(PYBIND_DEPS ${PYBIND_DEPS} ps_service)
......
...@@ -50,6 +50,7 @@ limitations under the License. */ ...@@ -50,6 +50,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/parallel_executor.h" #include "paddle/fluid/framework/parallel_executor.h"
#include "paddle/fluid/framework/prune.h" #include "paddle/fluid/framework/prune.h"
#include "paddle/fluid/framework/pten_utils.h"
#include "paddle/fluid/framework/reader.h" #include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/save_load_util.h" #include "paddle/fluid/framework/save_load_util.h"
#include "paddle/fluid/framework/scope_pool.h" #include "paddle/fluid/framework/scope_pool.h"
......
...@@ -21,7 +21,7 @@ add_subdirectory(ops) ...@@ -21,7 +21,7 @@ add_subdirectory(ops)
add_subdirectory(tests) add_subdirectory(tests)
# make an unity target for compile deps # make an unity target for compile deps
set(PTEN_DEPS convert_utils dense_tensor pten_context kernel_factory kernel_context arg_map_context infermeta lod_utils) set(PTEN_DEPS convert_utils dense_tensor pten_context kernel_factory kernel_context arg_map_context infermeta lod_utils op_compat_infos)
get_property(pten_kernels GLOBAL PROPERTY PTEN_KERNELS) get_property(pten_kernels GLOBAL PROPERTY PTEN_KERNELS)
# keep this message for debug, remove it later if needless # keep this message for debug, remove it later if needless
message(STATUS "All standard pten kernels: ${pten_kernels}") message(STATUS "All standard pten kernels: ${pten_kernels}")
......
...@@ -21,13 +21,14 @@ cc_library(tensor_meta SRCS tensor_meta.cc DEPS pten_enforce mixed_vector) ...@@ -21,13 +21,14 @@ cc_library(tensor_meta SRCS tensor_meta.cc DEPS pten_enforce mixed_vector)
cc_library(lod_utils SRCS lod_utils.cc DEPS pten_enforce mixed_vector) cc_library(lod_utils SRCS lod_utils.cc DEPS pten_enforce mixed_vector)
cc_library(dense_tensor SRCS dense_tensor.cc DEPS convert_utils tensor_meta tensor_base) cc_library(dense_tensor SRCS dense_tensor.cc DEPS convert_utils tensor_meta tensor_base)
cc_library(pten_device_context SRCS device_context.cc DEPS tensor_base ) cc_library(pten_device_context SRCS device_context.cc DEPS tensor_base )
cc_library(meta_tensor SRCS meta_tensor.cc DEPS tensor_base tensor_meta dense_tensor) cc_library(meta_tensor SRCS meta_tensor.cc DEPS tensor_base tensor_meta dense_tensor)
cc_library(infermeta_utils SRCS infermeta_utils.cc DEPS meta_tensor) cc_library(infermeta_utils SRCS infermeta_utils.cc DEPS meta_tensor)
cc_library(selected_rows SRCS selected_rows.cc DEPS dense_tensor mixed_vector enforce ddim)
cc_test(unroll_array_ops_test SRCS unroll_array_ops_test.cc) cc_test(unroll_array_ops_test SRCS unroll_array_ops_test.cc)
cc_library(ddim SRCS ddim.cc DEPS eigen3 boost enforce) cc_library(ddim SRCS ddim.cc DEPS eigen3 boost enforce)
cc_library(selected_rows SRCS selected_rows.cc DEPS dense_tensor mixed_vector enforce ddim)
# Will remove once we implemented MKLDNN_Tensor # Will remove once we implemented MKLDNN_Tensor
if(WITH_MKLDNN) if(WITH_MKLDNN)
add_dependencies(dense_tensor mkldnn) add_dependencies(dense_tensor mkldnn)
......
cc_library(arg_map_context SRCS arg_map_context.cc DEPS pten_enforce) cc_library(arg_map_context SRCS arg_map_context.cc DEPS pten_enforce)
cc_library(op_utils SRCS op_utils.cc DEPS arg_map_context enforce convert_utils)
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/pten/core/compat/op_utils.h"
namespace pten {
DefaultKernelSignatureMap& DefaultKernelSignatureMap::Instance() {
static DefaultKernelSignatureMap g_default_kernel_sig_map;
return g_default_kernel_sig_map;
}
OpUtilsMap& OpUtilsMap::Instance() {
static OpUtilsMap g_op_utils_map;
return g_op_utils_map;
}
} // namespace pten
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <mutex>
#include "paddle/pten/core/compat/arg_map_context.h"
#include "paddle/pten/core/infermeta_utils.h"
#include "paddle/pten/core/kernel_def.h"
#include "paddle/pten/core/macros.h"
#include "paddle/utils/flat_hash_map.h"
#include "paddle/fluid/platform/enforce.h"
namespace pten {
class DefaultKernelSignatureMap {
public:
static DefaultKernelSignatureMap& Instance();
bool Has(const std::string& op_type) const { return map_.count(op_type) > 0; }
const KernelSignature& Get(const std::string& op_type) const {
auto it = map_.find(op_type);
PADDLE_ENFORCE_NE(
it,
map_.end(),
paddle::platform::errors::NotFound(
"Operator `%s`'s kernel signature is not registered.", op_type));
return it->second;
}
void Insert(std::string op_type, KernelSignature signature) {
PADDLE_ENFORCE_NE(
Has(op_type),
true,
paddle::platform::errors::AlreadyExists(
"Operator (%s)'s Kernel Siginature has been registered.", op_type));
map_.insert({std::move(op_type), std::move(signature)});
}
private:
DefaultKernelSignatureMap() = default;
paddle::flat_hash_map<std::string, KernelSignature> map_;
DISABLE_COPY_AND_ASSIGN(DefaultKernelSignatureMap);
};
class OpUtilsMap {
public:
static OpUtilsMap& Instance();
bool Contains(const std::string& op_type) const {
return name_map_.count(op_type) || arg_mapping_fn_map_.count(op_type);
}
void InsertApiName(std::string op_type, std::string api_name) {
PADDLE_ENFORCE_EQ(
name_map_.count(op_type),
0UL,
paddle::platform::errors::AlreadyExists(
"Operator (%s)'s api name has been registered.", op_type));
name_map_.insert({std::move(op_type), std::move(api_name)});
}
void InsertArgumentMappingFn(std::string op_type, ArgumentMappingFn fn) {
PADDLE_ENFORCE_EQ(
arg_mapping_fn_map_.count(op_type),
0UL,
paddle::platform::errors::AlreadyExists(
"Operator (%s)'s argu,emt mapping function has been registered.",
op_type));
arg_mapping_fn_map_.insert({std::move(op_type), std::move(fn)});
}
std::string GetApiName(const std::string& op_type) const {
auto it = name_map_.find(op_type);
if (it == name_map_.end()) {
return "deprecated";
} else {
return it->second;
}
}
ArgumentMappingFn GetArgumentMappingFn(const std::string& op_type) const {
auto it = arg_mapping_fn_map_.find(op_type);
if (it == arg_mapping_fn_map_.end()) {
auto func =
[op_type](const ArgumentMappingContext& ctx) -> KernelSignature {
return DefaultKernelSignatureMap::Instance().Get(op_type);
};
return func;
} else {
return it->second;
}
}
private:
OpUtilsMap() = default;
paddle::flat_hash_map<std::string, std::string> name_map_;
paddle::flat_hash_map<std::string, ArgumentMappingFn> arg_mapping_fn_map_;
DISABLE_COPY_AND_ASSIGN(OpUtilsMap);
};
struct ApiNameRegistrar {
ApiNameRegistrar(const char* op_type, const char* api_name) {
OpUtilsMap::Instance().InsertApiName(op_type, api_name);
}
};
struct ArgumentMappingFnRegistrar {
ArgumentMappingFnRegistrar(const char* op_type,
ArgumentMappingFn arg_mapping_fn) {
OpUtilsMap::Instance().InsertArgumentMappingFn(op_type,
std::move(arg_mapping_fn));
}
};
#define PT_REGISTER_API_NAME(op_type, api_name) \
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \
pt_register_api_name_ns_check_##op_type, \
"PT_REGISTER_API_NAME must be called in global namespace."); \
static const ::pten::ApiNameRegistrar __registrar_api_name_for_##op_type( \
#op_type, #api_name); \
int TouchApiNameSymbol_##op_type() { return 0; }
#define PT_DECLARE_API_NAME(op_type) \
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \
pt_declare_ai_name_ns_check_##op_type, \
"PT_DECLARE_API_NAME must be called in global namespace."); \
extern int TouchApiNameSymbol_##op_type(); \
UNUSED static int __declare_api_name_symbol_for_##op_type = \
TouchApiNameSymbol_##op_type()
#define PT_REGISTER_ARG_MAPPING_FN(op_type, arg_mapping_fn) \
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \
pt_register_arg_map_fn_ns_check_##op_type, \
"PT_REGISTER_ARG_MAPPING_FN must be called in global namespace."); \
static const ::pten::ArgumentMappingFnRegistrar \
__registrar_arg_map_fn_for_##op_type(#op_type, arg_mapping_fn); \
int TouchArgumentMappingFnSymbol_##op_type() { return 0; }
#define PT_DECLARE_ARG_MAPPING_FN(op_type) \
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \
pt_declare_arg_map_fn_ns_check_##op_type, \
"PT_DECLARE_ARG_MAPPING_FN must be called in global namespace."); \
extern int TouchArgumentMappingFnSymbol_##op_type(); \
UNUSED static int __declare_arg_map_fn_symbol_for_##op_type = \
TouchArgumentMappingFnSymbol_##op_type()
} // namespace pten
...@@ -151,7 +151,7 @@ struct InferMetaFnImpl<Return (*)(Args...), infer_meta_fn> { ...@@ -151,7 +151,7 @@ struct InferMetaFnImpl<Return (*)(Args...), infer_meta_fn> {
struct InferMetaFnCallHelper<MetaConfig, Tail...> { struct InferMetaFnCallHelper<MetaConfig, Tail...> {
template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs> template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs>
static void Call(InferMetaContext* ctx, PreviousArgs&... pargs) { static void Call(InferMetaContext* ctx, PreviousArgs&... pargs) {
const MetaConfig& arg = ctx->GetMetaConfig(); MetaConfig arg = ctx->GetMetaConfig();
InferMetaFnCallHelper<Tail...>::template Call<in_idx, attr_idx, out_idx>( InferMetaFnCallHelper<Tail...>::template Call<in_idx, attr_idx, out_idx>(
ctx, pargs..., arg); ctx, pargs..., arg);
} }
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
......
cc_library(infermeta SRCS nullary.cc unary.cc binary.cc multiary.cc DEPS convert_utils) cc_library(infermeta SRCS nullary.cc unary.cc binary.cc multiary.cc DEPS convert_utils infermeta_utils)
cc_library(backward_infermeta SRCS backward.cc DEPS convert_utils) cc_library(backward_infermeta SRCS backward.cc DEPS convert_utils)
...@@ -12,12 +12,21 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,12 +12,21 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
// See Note [ Why still include the fluid headers? ]
#include "paddle/pten/infermeta/unary.h" #include "paddle/pten/infermeta/unary.h"
#include <set> #include <set>
#include "paddle/pten/core/infermeta_utils.h"
namespace pten { namespace pten {
void UnchangedInferMetaNew(MetaConfig config,
const MetaTensor& x,
MetaTensor* out) {
out->set_dims(x.dims());
out->share_lod(x);
}
DenseTensorMeta UnchangedInferMeta(const DenseTensorMeta& x_meta) { DenseTensorMeta UnchangedInferMeta(const DenseTensorMeta& x_meta) {
return x_meta; return x_meta;
} }
......
...@@ -16,23 +16,27 @@ limitations under the License. */ ...@@ -16,23 +16,27 @@ limitations under the License. */
// See Note [ Why still include the fluid headers? ] // See Note [ Why still include the fluid headers? ]
#include "paddle/pten/common/scalar_array.h" #include "paddle/pten/common/scalar_array.h"
#include "paddle/pten/core/infermeta_utils.h"
#include "paddle/pten/core/meta_tensor.h"
#include "paddle/pten/core/tensor_meta.h" #include "paddle/pten/core/tensor_meta.h"
namespace pten { namespace pten {
class MetaConfig;
// Common InferMeta Functions for unary operators, The format like: // Common InferMeta Functions for unary operators, The format like:
// //
// 1. DenseTensorMeta [OpName]InferMeta(const DenseTensorMeta& x_meta, ...) // void [OpName]InferMeta(const MetaTensor& x, ..., MetaTensor* out) {}
// {} //
// 2. std::pair<DenseTensorMeta, DenseTensorMeta> [OpName]InferMeta(const // NOTE: The name "InferShape" may be not appropriate. "InferMeta" may be good.
// DenseTensorMeta& // Because functions in this file not only can infer shape, but also need
// x_meta, ...) {} // infer lod or other useful data.
// 3. std::tuple<DenseTensorMeta, DenseTensorMeta, DenseTensorMeta>
// [OpName]InferMeta(const // TODO(chenweihang): update all InferMeta function format in next pr,
// DenseTensorMeta& x_meta, ...) // now add UnchangedInferMetaNew for test new format
// NOTE: The name "InferMeta" may be not appropriate. "InferMeta" may be good. void UnchangedInferMetaNew(MetaConfig config,
// Because functions in this file const MetaTensor& x,
// not only can infer shape, but alse need infer lod or other useful data. MetaTensor* out);
DenseTensorMeta UnchangedInferMeta(const DenseTensorMeta& x_meta); DenseTensorMeta UnchangedInferMeta(const DenseTensorMeta& x_meta);
......
set(op_utils_header ${PADDLE_BINARY_DIR}/paddle/pten/ops/compat/signatures.h.tmp CACHE INTERNAL "op_args_fns.cc file")
set(op_utils_header_final ${PADDLE_BINARY_DIR}/paddle/pten/ops/compat/signatures.h)
file(WRITE ${op_utils_header} "// Generated by the paddle/pten/ops/compat/CMakeLists.txt. DO NOT EDIT!\n\n")
file(APPEND ${op_utils_header} "#include \"paddle/pten/core/compat/op_utils.h\"\n\n")
# Automatically generate the registration code of all arg map functions
# and compile the corresponding target to avoid frequent code conflicts
# when writing to same file
register_op_utils(op_compat_infos DEPS op_utils)
copy_if_different(${op_utils_header} ${op_utils_header_final})
...@@ -12,9 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,9 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #include "paddle/pten/core/compat/op_utils.h"
#include "paddle/pten/core/compat/arg_map_context.h"
namespace pten { namespace pten {
...@@ -22,15 +20,18 @@ KernelSignature ScaleOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -22,15 +20,18 @@ KernelSignature ScaleOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.IsDenseTensorInput("X")) { if (ctx.IsDenseTensorInput("X")) {
std::string scale_attr; std::string scale_attr;
if (ctx.HasInput("ScaleTensor")) { if (ctx.HasInput("ScaleTensor")) {
scale_attr = "ScaleTensor"; return KernelSignature(
"scale", {"X"}, {"ScaleTensor", "bias", "bias_after_scale"}, {"Out"});
} else { } else {
scale_attr = "scale"; return KernelSignature(
"scale", {"X"}, {"scale", "bias", "bias_after_scale"}, {"Out"});
} }
return KernelSignature(
"scale", {"X"}, {scale_attr, "bias", "bias_after_scale"}, {"Out"});
} }
// TODO(chenweihang): support other cases after selected rows added // TODO(chenweihang): support other cases after selected rows added
return KernelSignature("scale.unregistered", {}, {}, {}); return KernelSignature("scale.unregistered", {}, {}, {});
} }
} // namespace pten } // namespace pten
// op_type, api_name, arg_mapping_fn
PT_REGISTER_ARG_MAPPING_FN(scale, pten::ScaleOpArgumentMapping);
...@@ -3,6 +3,7 @@ cc_test(test_intrusive_ptr SRCS test_intrusive_ptr.cc) ...@@ -3,6 +3,7 @@ cc_test(test_intrusive_ptr SRCS test_intrusive_ptr.cc)
cc_test(test_type_info SRCS test_type_info.cc) cc_test(test_type_info SRCS test_type_info.cc)
cc_test(test_convert_utils SRCS test_convert_utils.cc DEPS convert_utils) cc_test(test_convert_utils SRCS test_convert_utils.cc DEPS convert_utils)
cc_test(test_kernel_factory SRCS test_kernel_factory.cc DEPS kernel_factory scale_kernel) cc_test(test_kernel_factory SRCS test_kernel_factory.cc DEPS kernel_factory scale_kernel)
cc_test(test_op_utils SRCS test_op_utils.cc DEPS op_compat_infos)
cc_test(test_pten_device_context SRCS test_device_context.cc DEPS pten_context cpu_context) cc_test(test_pten_device_context SRCS test_device_context.cc DEPS pten_context cpu_context)
cc_test(test_ddim SRCS test_ddim.cc DEPS ddim) cc_test(test_ddim SRCS test_ddim.cc DEPS ddim)
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include "gtest/gtest.h"
#include "paddle/pten/core/compat/op_utils.h"
#include "paddle/pten/ops/compat/signatures.h"
namespace pten {
namespace tests {
TEST(OpUtilsMap, ArgMappingFnExists) {
std::cout << "enter ArgMappingFnExists";
auto scale_arg_mapping_fn =
pten::OpUtilsMap::Instance().GetArgumentMappingFn("scale");
EXPECT_NE(scale_arg_mapping_fn, nullptr);
}
} // namespace tests
} // namespace pten
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册