未验证 提交 3a23c1a2 编写于 作者: C Chen Weihang 提交者: GitHub

move get expected kernel args into pten (#38825)

上级 657b6742
......@@ -1287,7 +1287,7 @@ OpKernelType OperatorWithKernel::InnerGetExpectedKernelType(
void OperatorWithKernel::ChoosePtenKernel(const ExecutionContext& ctx) const {
pt_kernel_signature_.reset(
new KernelSignature(std::move(this->GetExpectedPtenKernelArgs(ctx))));
VLOG(6) << KernelSignatureToString(*pt_kernel_signature_.get());
VLOG(6) << *pt_kernel_signature_.get();
kernel_type_.reset(
new OpKernelType(std::move(InnerGetExpectedKernelType(ctx))));
......
......@@ -40,6 +40,7 @@ limitations under the License. */
#include "paddle/fluid/platform/variant.h"
#include "paddle/utils/flat_hash_map.h"
#include "paddle/pten/core/arg_map_context.h"
#include "paddle/pten/include/core.h"
namespace paddle {
......@@ -438,6 +439,45 @@ class ExecutionContext {
const RuntimeContext& ctx_;
};
// TODO(chenweihang): split impl based OpProto or Dygraph if needed
// Adapter that exposes a fluid ExecutionContext through the pten
// ArgumentMappingContext interface, so pten argument-mapping functions can
// query op inputs/outputs/attributes without depending on fluid directly.
class ExecutionArgumentMappingContext : public pten::ArgumentMappingContext {
 public:
  explicit ExecutionArgumentMappingContext(const ExecutionContext& ctx)
      : ctx_(ctx) {}

  // Presence queries forward straight to the wrapped ExecutionContext.
  bool HasInput(const std::string& name) const override {
    return ctx_.HasInput(name);
  }

  bool HasOutput(const std::string& name) const override {
    return ctx_.HasOutput(name);
  }

  bool HasAttr(const std::string& name) const override {
    return ctx_.HasAttr(name);
  }

  // Number of variables bound to the given input/output slot.
  size_t InputSize(const std::string& name) const override {
    return ctx_.InputSize(name);
  }

  size_t OutputSize(const std::string& name) const override {
    return ctx_.OutputSize(name);
  }

  // Tensor-kind queries inspect the underlying variable's held type.
  bool IsDenseTensorInput(const std::string& name) const override {
    const auto* var = ctx_.InputVar(name);
    return var->IsType<framework::Tensor>() ||
           var->IsType<framework::LoDTensor>();
  }

  bool IsSelectedRowsInput(const std::string& name) const override {
    return ctx_.InputVar(name)->IsType<framework::SelectedRows>();
  }

 private:
  const ExecutionContext& ctx_;  // borrowed; must outlive this adapter
};
template <>
const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const;
......
......@@ -196,15 +196,5 @@ KernelSignature KernelArgsNameMakerByOpProto::GetKernelSignature() {
GetOutputArgsNames());
}
// Renders a KernelSignature as a single human-readable line, e.g. for VLOG:
// "Kernel Signature - name: scale; inputs: X; attributes: ...; outputs: Out".
std::string KernelSignatureToString(const KernelSignature& signature) {
  const auto& input_names = std::get<0>(signature.args);
  const auto& attr_names = std::get<1>(signature.args);
  const auto& output_names = std::get<2>(signature.args);
  std::ostringstream ss;
  ss << "Kernel Signature - name: " << signature.name
     << "; inputs: " << string::join_strings(input_names, ", ")
     << "; attributes: " << string::join_strings(attr_names, ", ")
     << "; outputs: " << string::join_strings(output_names, ", ");
  return ss.str();
}
} // namespace framework
} // namespace paddle
......@@ -22,17 +22,19 @@ limitations under the License. */
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/op_kernel_type.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/pten/api/lib/utils/tensor_utils.h"
#include "paddle/pten/include/core.h"
#include "paddle/pten/core/arg_map_context.h"
#include "paddle/pten/core/kernel_factory.h"
#include "paddle/utils/flat_hash_map.h"
#include "paddle/utils/small_vector.h"
namespace paddle {
namespace framework {
using KernelSignature = pten::KernelSignature;
/* Kernel Key translate */
OpKernelType TransPtenKernelKeyToOpKernelType(
......@@ -42,24 +44,6 @@ pten::KernelKey TransOpKernelTypeToPtenKernelKey(
/* Kernel Args parse */
// Describes how a fluid op maps onto a pten kernel: the kernel name plus the
// ordered input / attribute / output argument names.
// `args` layout: tuple(input_names, attr_names, output_names).
struct KernelSignature {
  std::string name;
  KernelArgsTuple args;

  KernelSignature() = default;

  // Rvalue overload: forward the name lists into the tuple. The original
  // called make_tuple on the (lvalue) parameters, which copied every
  // SmallVector despite the && signature.
  KernelSignature(std::string&& kernel_name,
                  paddle::SmallVector<std::string>&& inputs,
                  paddle::SmallVector<std::string>&& attrs,
                  paddle::SmallVector<std::string>&& outputs)
      : name(std::move(kernel_name)),
        args(std::make_tuple(std::move(inputs), std::move(attrs),
                             std::move(outputs))) {}

  KernelSignature(const std::string& kernel_name,
                  const paddle::SmallVector<std::string>& inputs,
                  const paddle::SmallVector<std::string>& attrs,
                  const paddle::SmallVector<std::string>& outputs)
      : name(kernel_name), args(std::make_tuple(inputs, attrs, outputs)) {}
};
// TODO(chenweihang): we can generate this map by proto info in compile time
class KernelSignatureMap {
public:
......@@ -88,7 +72,5 @@ class KernelArgsNameMaker {
virtual const paddle::SmallVector<std::string>& GetAttrsArgsNames() = 0;
};
std::string KernelSignatureToString(const KernelSignature& signature);
} // namespace framework
} // namespace paddle
......@@ -84,10 +84,5 @@ using InferShapeFN = std::function<void(InferShapeContext*)>;
using InplacePair = std::unordered_map<std::string, std::string>;
using InferInplaceOpFN = std::function<InplacePair(bool /*use_cuda*/)>;
// tuple(input_names, attr_names, output_names)
using KernelArgsTuple = std::tuple<paddle::SmallVector<std::string>,
paddle::SmallVector<std::string>,
paddle::SmallVector<std::string>>;
} // namespace framework
} // namespace paddle
......@@ -164,7 +164,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
if (FLAGS_run_pten_kernel &&
pten::KernelFactory::Instance().HasCompatiblePtenKernel(op.Type())) {
auto pt_kernel_signature = op.GetExpectedPtenKernelArgs(dygraph_exe_ctx);
VLOG(6) << framework::KernelSignatureToString(pt_kernel_signature);
VLOG(6) << pt_kernel_signature;
auto pt_kernel_name = pt_kernel_signature.name;
auto pt_kernel_key = TransOpKernelTypeToPtenKernelKey(expected_kernel_key);
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/operators/scale_op.h"
#include <string>
#include "paddle/fluid/platform/float16.h"
#include "paddle/pten/ops/compat/scale_args_fn.h"
namespace paddle {
namespace framework {
......@@ -73,19 +74,8 @@ class ScaleOp : public framework::OperatorWithKernel {
// Maps this op's runtime arguments to the pten kernel signature to launch.
// NOTE(review): this span appears to interleave the pre-refactor body
// (inline "scale" signature construction) with the post-refactor body that
// delegates to pten::ScaleOpArgumentMapping; the statements after the
// second return below are unreachable diff residue — confirm against the
// actual committed file before relying on this text.
framework::KernelSignature GetExpectedPtenKernelArgs(
const framework::ExecutionContext &ctx) const override {
// Only dense-tensor inputs are supported by the pten scale kernel so far.
if (ctx.InputVar("X")->IsType<framework::LoDTensor>() ||
ctx.InputVar("X")->IsType<framework::Tensor>()) {
// Prefer the ScaleTensor input over the plain `scale` attribute when bound.
std::string scale_attr;
if (ctx.HasInput("ScaleTensor")) {
scale_attr = "ScaleTensor";
} else {
scale_attr = "scale";
}
return framework::KernelSignature(
"scale", {"X"}, {scale_attr, "bias", "bias_after_scale"}, {"Out"});
}
// TODO(chenweihang): support other cases after selected rows added
return framework::KernelSignature("scale.unregistered", {}, {}, {});
// NOTE(review): lines below are the replacement implementation added by
// this commit (unreachable as concatenated here).
framework::ExecutionArgumentMappingContext arg_mapping_ctx(ctx);
return pten::ScaleOpArgumentMapping(arg_mapping_ctx);
}
};
......
......@@ -23,7 +23,7 @@ add_subdirectory(ops)
add_subdirectory(tests)
# make an unity target for compile deps
set(PTEN_DEPS convert_utils dense_tensor pten_context kernel_factory kernel_context infermeta)
set(PTEN_DEPS convert_utils dense_tensor pten_context kernel_factory kernel_context arg_map_context infermeta)
get_property(pten_kernels GLOBAL PROPERTY PTEN_KERNELS)
# keep this message for debug, remove it later if needless
message(STATUS "All standard pten kernels: ${pten_kernels}")
......
......@@ -8,8 +8,9 @@ endif()
cc_library(kernel_factory SRCS kernel_factory.cc DEPS enforce convert_utils)
cc_library(kernel_context SRCS kernel_context.cc DEPS enforce pten_context)
cc_library(tensor_base SRCS tensor_base.cc allocator.cc storage.cc DEPS enforce)
cc_library(arg_map_context SRCS arg_map_context.cc DEPS enforce)
cc_library(tensor_base SRCS tensor_base.cc allocator.cc storage.cc DEPS enforce)
cc_library(tensor_meta SRCS tensor_meta.cc DEPS enforce mixed_vector)
cc_library(dense_tensor SRCS dense_tensor.cc DEPS convert_utils tensor_meta tensor_base)
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/pten/core/arg_map_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/string_helper.h"
namespace pten {
// Meyers singleton: process-wide registry of op-type -> argument mapping fn.
OpArgumentMappingFnMap& OpArgumentMappingFnMap::Instance() {
  static OpArgumentMappingFnMap instance;
  return instance;
}
// True if an argument mapping function was registered for `op_type`.
bool OpArgumentMappingFnMap::Has(const std::string& op_type) const {
  return fn_map_.count(op_type) != 0;
}
// Looks up the argument mapping function registered for `op_type`.
// Raises a NotFound error when no function was registered.
const ArgumentMappingFn& OpArgumentMappingFnMap::Get(
    const std::string& op_type) const {
  auto it = fn_map_.find(op_type);
  PADDLE_ENFORCE_NE(
      it,
      fn_map_.end(),
      paddle::platform::errors::NotFound(
          // Fixed typo in the user-facing message: "funciton" -> "function".
          "Operator `%s`'s argument mapping function is not registered.",
          op_type));
  return it->second;
}
// Registers the pten api name and the argument mapping function for
// `op_type`. Existing entries are kept (flat_hash_map::emplace is a no-op
// on duplicate keys).
// NOTE(review): `api_name` is taken by value to match the header
// declaration; presumably it was meant to be `const std::string&`.
void OpArgumentMappingFnMap::Emplace(const std::string& op_type,
                                     const std::string api_name,
                                     ArgumentMappingFn fn) {
  fn_map_.emplace(op_type, fn);
  name_map_.emplace(op_type, api_name);
}
// Streams a one-line human-readable description of `signature`.
// NOTE(review): `signature` is taken by value to match the declaration in
// arg_map_context.h; a const reference would avoid copying the name lists.
std::ostream& operator<<(std::ostream& os, KernelSignature signature) {
  const auto& input_names = std::get<0>(signature.args);
  const auto& attr_names = std::get<1>(signature.args);
  const auto& output_names = std::get<2>(signature.args);
  os << "Kernel Signature - name: " << signature.name
     << "; inputs: " << paddle::string::join_strings(input_names, ", ")
     << "; attributes: " << paddle::string::join_strings(attr_names, ", ")
     << "; outputs: " << paddle::string::join_strings(output_names, ", ");
  return os;
}
} // namespace pten
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <ostream>
#include <string>
#include <tuple>
#include "paddle/utils/flat_hash_map.h"
#include "paddle/utils/small_vector.h"
namespace pten {
// tuple(input_names, attr_names, output_names)
using KernelArgsTuple = std::tuple<paddle::SmallVector<std::string>,
paddle::SmallVector<std::string>,
paddle::SmallVector<std::string>>;
// TODO(chenweihang): Add more methods if needed in future
// Read-only view over an operator's runtime arguments, consumed by argument
// mapping functions (ArgumentMappingFn) to decide which kernel signature
// applies. Concrete implementations adapt fluid's execution contexts.
class ArgumentMappingContext {
 public:
  virtual ~ArgumentMappingContext() = default;
  // Whether the op has an input / output / attribute with this name.
  virtual bool HasInput(const std::string& name) const = 0;
  virtual bool HasOutput(const std::string& name) const = 0;
  virtual bool HasAttr(const std::string& name) const = 0;
  // Number of variables bound to the given input / output slot.
  virtual size_t InputSize(const std::string& name) const = 0;
  virtual size_t OutputSize(const std::string& name) const = 0;
  // Tensor-kind queries for the input variable with this name.
  virtual bool IsDenseTensorInput(const std::string& name) const = 0;
  virtual bool IsSelectedRowsInput(const std::string& name) const = 0;
};
// Binds a kernel name to its ordered argument names.
// `args` layout: tuple(input_names, attr_names, output_names).
struct KernelSignature {
  std::string name;
  KernelArgsTuple args;

  KernelSignature() = default;

  // Rvalue overload: move the name lists into the tuple. The original
  // called make_tuple on the (lvalue) parameters, copying every
  // SmallVector despite the && signature.
  KernelSignature(std::string&& kernel_name,
                  paddle::SmallVector<std::string>&& inputs,
                  paddle::SmallVector<std::string>&& attrs,
                  paddle::SmallVector<std::string>&& outputs)
      : name(std::move(kernel_name)),
        args(std::make_tuple(std::move(inputs), std::move(attrs),
                             std::move(outputs))) {}

  KernelSignature(const std::string& kernel_name,
                  const paddle::SmallVector<std::string>& inputs,
                  const paddle::SmallVector<std::string>& attrs,
                  const paddle::SmallVector<std::string>& outputs)
      : name(kernel_name), args(std::make_tuple(inputs, attrs, outputs)) {}
};
std::ostream& operator<<(std::ostream& os, KernelSignature signature);
using ArgumentMappingFn = KernelSignature (*)(const ArgumentMappingContext&);
// Global registry mapping a fluid op type to its pten api name and to the
// function that builds the op's KernelSignature at runtime.
class OpArgumentMappingFnMap {
 public:
  // Process-wide singleton accessor.
  static OpArgumentMappingFnMap& Instance();

  // True if a mapping function has been registered for `op_type`.
  bool Has(const std::string& op_type) const;

  // Returns the registered mapping function; raises NotFound otherwise.
  const ArgumentMappingFn& Get(const std::string& op_type) const;

  // Registers `api_name` and `fn` for `op_type`.
  void Emplace(const std::string& op_type,
               const std::string api_name,
               ArgumentMappingFn fn);

 private:
  // op type -> pten api name
  paddle::flat_hash_map<std::string, std::string> name_map_;
  // op type -> signature-building function
  paddle::flat_hash_map<std::string, ArgumentMappingFn> fn_map_;
};
} // namespace pten
......@@ -26,17 +26,4 @@ using KernelArgsDefFn = void (*)(Kernel* kernel);
using KernelArgsParseFn = void (*)(const KernelKey& default_key,
KernelArgsDef* args_def);
// Multiple kernels of the same operation are distinguished by the difference
// of the overload name. For the convenience of reuse, we define some overload
// naming strings for the naming of the kernel
// For kernels that contains dynamic tensor attribute and it need to be always
// on host device, such as `ScaleTensor`
constexpr char kContainHostTensorSuffix[] = "host";
// For kernels with SelectedRowsTensor input and output
constexpr char kContainSelectedRowsSuffix[] = "sr";
// For kernels with intermediate output
constexpr char kContainMidOutputTensorSuffix[] = "mid";
} // namespace pten
......@@ -24,7 +24,7 @@ endif()
# pten depends all pten kernel targets
set_property(GLOBAL PROPERTY PTEN_KERNELS "")
set(COMMON_KERNEL_DEPS dense_tensor kernel_context kernel_factory convert_utils)
set(COMMON_KERNEL_DEPS dense_tensor kernel_context kernel_factory arg_map_context convert_utils)
set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} eigen_function blas)
# remove this dep after removing fluid deps on tensor creation
set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} pten_api_utils)
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/pten/core/arg_map_context.h"
namespace pten {
// Argument mapping for the scale op: selects the pten "scale" kernel
// signature when the "X" input holds a dense tensor.
KernelSignature ScaleOpArgumentMapping(const ArgumentMappingContext& ctx) {
  if (ctx.IsDenseTensorInput("X")) {
    // Prefer the dynamic ScaleTensor input over the `scale` attribute
    // whenever it is bound.
    const char* scale_attr =
        ctx.HasInput("ScaleTensor") ? "ScaleTensor" : "scale";
    return KernelSignature(
        "scale", {"X"}, {scale_attr, "bias", "bias_after_scale"}, {"Out"});
  }
  // TODO(chenweihang): support other cases after selected rows added
  return KernelSignature("scale.unregistered", {}, {}, {});
}
} // namespace pten
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册