Commit ec27aa46 authored by superjomn

code clean

Parent 0fb566b7
@@ -24,6 +24,31 @@ std::string KernelBase::summary() const {
return ss.str();
}
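// Look up the input type this kernel declared for argument `arg_name` in the ParamTypeRegistry.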
const Type *KernelBase::GetInputDeclType(const std::string &arg_name) {
CHECK(!op_type_.empty()) << "op_type should be set first";
const auto *type = ParamTypeRegistry::Global().RetrieveInArgument(
place(), GenParamTypeKey(), arg_name);
CHECK(type) << "no type registered for kernel [" << op_type_
<< "] input argument [" << arg_name << "]"
<< " with key " << GenParamTypeKey();
return type->type;
}
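// Look up the output type this kernel declared for argument `arg_name`.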
const Type *KernelBase::GetOutputDeclType(const std::string &arg_name) {
CHECK(!op_type_.empty()) << "op_type should be set first";
const auto *type = ParamTypeRegistry::Global().RetrieveOutArgument(
place(), GenParamTypeKey(), arg_name);
CHECK(type) << "no type registered for kernel [" << op_type_
<< "] output argument [" << arg_name << "]";
return type->type;
}
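// The parameter-type key has the form "<op_type>/<alias>".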
std::string KernelBase::GenParamTypeKey() const {
std::stringstream ss;
ss << op_type() << "/" << alias_;
return ss.str();
}
bool ParamTypeRegistry::KeyCmp::operator()(
const ParamTypeRegistry::key_t &a,
const ParamTypeRegistry::key_t &b) const {
@@ -37,6 +62,5 @@ std::ostream &operator<<(std::ostream &os,
<< other.place.DebugString();
return os;
}
} // namespace lite
} // namespace paddle
\ No newline at end of file
@@ -73,28 +73,11 @@ class KernelBase {
void set_op_type(const std::string& type) { op_type_ = type; }
const std::string& op_type() const { return op_type_; }
void Torch() {}
// Get input declaration Type.
const Type* GetInputDeclType(const std::string& arg_name);
// Get input declaration type.
const Type* GetInputDeclType(const std::string& arg_name) {
CHECK(!op_type_.empty()) << "op_type should be set first";
const auto* type = ParamTypeRegistry::Global().RetrieveInArgument(
place(), GenParamTypeKey(), arg_name);
CHECK(type) << "no type registered for kernel [" << op_type_
<< "] input argument [" << arg_name << "]"
<< " with key " << GenParamTypeKey();
return type->type;
}
// Get output declaration type.
const Type* GetOutputDeclType(const std::string& arg_name) {
CHECK(!op_type_.empty()) << "op_type should be set first";
const auto* type = ParamTypeRegistry::Global().RetrieveOutArgument(
place(), GenParamTypeKey(), arg_name);
CHECK(type) << "no type registered for kernel [" << op_type_
<< "] output argument [" << arg_name << "]";
return type->type;
}
// Get output declaration Type.
const Type* GetOutputDeclType(const std::string& arg_name);
void set_alias(const std::string& x) { alias_ = x; }
const std::string& alias() const { return alias_; }
@@ -110,14 +93,11 @@ class KernelBase {
std::string summary() const;
// Long human-readable document.
virtual std::string doc() const { return ""; }
std::string GenParamTypeKey() const {
std::stringstream ss;
ss << op_type() << "/" << alias_;
return ss.str();
}
// Generate the key of the parameter type.
std::string GenParamTypeKey() const;
virtual ~KernelBase() = default;
void Torch() {}
protected:
std::unique_ptr<KernelContext> context_;
@@ -144,10 +124,7 @@ class OpKernel : public KernelBase {
PrecisionType precision() const override { return Precision; }
DataLayoutType layout() const override { return DataLayout; }
Place place() const override { return Place{Target, Precision, DataLayout}; }
std::string name() const override {
return op_type() + ":" + TargetToStr(Target) + "/" +
PrecisionToStr(Precision) + "/" + DataLayoutToStr(DataLayout);
}
std::string name() const override;
void Touch() {}
@@ -158,5 +135,11 @@ class OpKernel : public KernelBase {
std::unique_ptr<KernelContext> ctx_;
};
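// name() is defined out of line below; as a member of a class template, its definition must stay in the header.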
template <TargetType Target, PrecisionType Precision, DataLayoutType DataLayout>
std::string OpKernel<Target, Precision, DataLayout>::name() const {
return op_type() + ":" + TargetToStr(Target) + "/" +
PrecisionToStr(Precision) + "/" + DataLayoutToStr(DataLayout);
}
} // namespace lite
} // namespace paddle
@@ -15,7 +15,5 @@
#include "paddle/fluid/lite/core/memory.h"
namespace paddle {
namespace framework {
namespace lite {} // namespace lite
} // namespace framework
} // namespace paddle
@@ -55,17 +55,6 @@ std::vector<std::unique_ptr<KernelBase>> OpLite::CreateKernels(
return kernels;
}
void OpLite::PickKernel(const std::vector<Place> &valid_places,
OpLite::KernelStrategy kernel_strategy) {
switch (kernel_strategy) {
case KernelStrategy::kStatic:
StaticPickKernel(valid_places);
break;
default:
LOG(FATAL) << "unsupported kernel strategy";
}
}
bool OpLite::Run() {
CHECK(kernel_);
SyncInputEvents();
@@ -120,5 +109,72 @@ bool OpInfo::GetOutputArgname(const std::string &value_name,
}
return false;
}
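// Flatten every input and output variable name listed in the op description.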
void OpInfo::ExtractInputsAndOutputs(const framework::proto::OpDesc &opdesc) {
for (const auto &item : opdesc.inputs()) {
for (const auto &x : item.arguments()) {
input_names_.push_back(x);
}
}
for (const auto &item : opdesc.outputs()) {
for (const auto &x : item.arguments()) {
output_names_.push_back(x);
}
}
}
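// Record the argument slot names (the `parameter` fields) of all inputs and outputs.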
void OpInfo::CollectInputAndOutputArgnames(
const framework::proto::OpDesc &opdesc) {
for (const auto &item : opdesc.inputs()) {
input_argnames_.push_back(item.parameter());
}
for (const auto &item : opdesc.outputs()) {
output_argnames_.push_back(item.parameter());
}
}
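// Group the input and output variable names by their argument slot name.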
void OpInfo::CollectArguments(const framework::proto::OpDesc &opdesc) {
for (const auto &item : opdesc.inputs()) {
for (auto &x : item.arguments()) {
input_argument_[item.parameter()].push_back(x);
}
}
for (const auto &item : opdesc.outputs()) {
for (auto &x : item.arguments()) {
output_argument_[item.parameter()].push_back(x);
}
}
}
void OpInfo::Build(const framework::proto::OpDesc &desc) {
ExtractInputsAndOutputs(desc);
CollectInputAndOutputArgnames(desc);
CollectArguments(desc);
desc_.reset(new framework::proto::OpDesc(desc));
}
const std::map<std::string, std::list<std::string>> &OpInfo::input_argument()
const {
return input_argument_;
}
const std::map<std::string, std::list<std::string>> &OpInfo::output_argument()
const {
return output_argument_;
}
const std::list<std::string> &OpInfo::input_argnames() const {
return input_argnames_;
}
const std::list<std::string> &OpInfo::output_argnames() const {
return output_argnames_;
}
const framework::proto::OpDesc &OpInfo::desc() const {
CHECK(desc_) << "desc hasn't been set";
return *desc_;
}
} // namespace lite
} // namespace paddle
@@ -54,16 +54,6 @@ class OpInfo;
*/
class OpLite : public Registry {
public:
// The strategies to pick a kernel from candidates.
enum class KernelStrategy {
// Return the user specified one.
kStatic = 0,
// Specify the expected kernel externally.
kSpecified,
// Run each kernel to evaluate and get the best kernel.
kRuntime,
};
OpLite() = default;
OpLite(const std::string &type) : op_type_(type) {}
OpLite(const std::vector<Place> &valid_places)
@@ -91,10 +81,6 @@ class OpLite : public Registry {
const Place &kernel_place() const { return kernel_place_; }
// NOTE This might be discarded.
void PickKernel(const std::vector<Place> &valid_places,
KernelStrategy kernel_strategy = KernelStrategy::kStatic);
// Create all the kernels for the valid targets.
std::vector<std::unique_ptr<KernelBase>> CreateKernels(
const std::vector<Place> &places, const std::string &kernel_type = "");
@@ -147,71 +133,26 @@ class OpInfo {
public:
// To avoid the bugs from the legacy framework::OpDesc, we use the ProtoBuf
// message instead.
void Build(const framework::proto::OpDesc &desc) {
ExtractInputsAndOutputs(desc);
CollectInputAndOutputArgnames(desc);
CollectArguments(desc);
desc_.reset(new framework::proto::OpDesc(desc));
}
void Build(const framework::proto::OpDesc &desc);
const framework::proto::OpDesc &desc() const {
CHECK(desc_) << "desc hasn't been set";
return *desc_;
}
const framework::proto::OpDesc &desc() const;
framework::proto::OpDesc *mutable_desc() { return desc_.get(); }
const std::list<std::string> &input_names() const { return input_names_; }
const std::list<std::string> &output_names() const { return output_names_; }
const std::map<std::string, std::list<std::string>> &input_argument() const {
return input_argument_;
}
const std::map<std::string, std::list<std::string>> &output_argument() const {
return output_argument_;
}
const std::map<std::string, std::list<std::string>> &input_argument() const;
const std::map<std::string, std::list<std::string>> &output_argument() const;
bool GetInputArgname(const std::string &value_name, std::string *out) const;
bool GetOutputArgname(const std::string &value_name, std::string *out) const;
const std::list<std::string> &input_argnames() const {
return input_argnames_;
}
const std::list<std::string> &output_argnames() const {
return output_argnames_;
}
const std::list<std::string> &input_argnames() const;
const std::list<std::string> &output_argnames() const;
private:
void ExtractInputsAndOutputs(const framework::proto::OpDesc &opdesc) {
for (const auto &item : opdesc.inputs()) {
for (const auto &x : item.arguments()) {
input_names_.push_back(x);
}
}
for (const auto &item : opdesc.outputs()) {
for (const auto &x : item.arguments()) {
output_names_.push_back(x);
}
}
}
void ExtractInputsAndOutputs(const framework::proto::OpDesc &opdesc);
void CollectInputAndOutputArgnames(const framework::proto::OpDesc &opdesc) {
for (const auto &item : opdesc.inputs()) {
input_argnames_.push_back(item.parameter());
}
for (const auto &item : opdesc.outputs()) {
output_argnames_.push_back(item.parameter());
}
}
void CollectInputAndOutputArgnames(const framework::proto::OpDesc &opdesc);
void CollectArguments(const framework::proto::OpDesc &opdesc) {
for (const auto &item : opdesc.inputs()) {
for (auto &x : item.arguments()) {
input_argument_[item.parameter()].push_back(x);
}
}
for (const auto &item : opdesc.outputs()) {
for (auto &x : item.arguments()) {
output_argument_[item.parameter()].push_back(x);
}
}
}
void CollectArguments(const framework::proto::OpDesc &opdesc);
private:
std::list<std::string> input_names_;
......
@@ -41,7 +41,7 @@ class Optimizer {
graph_.reset(new mir::SSAGraph);
graph_->Build(program, valid_places);
SpecifyKernelPickTactic(kernel_pick_factor);
InitIoComplement();
InitTargetTypeTransformPass();
if (passes.empty()) {
RunPasses(std::vector<std::string>{{
@@ -82,7 +82,7 @@ class Optimizer {
return program;
}
void InitIoComplement() {
void InitTargetTypeTransformPass() {
auto* pass =
mir::PassManager::Global().LookUp<mir::TypeTargetTransformPass>(
"type_target_transform_pass");
......
@@ -33,7 +33,7 @@ struct Program {
std::list<std::string> tmp_vars;
std::list<std::string> weights;
std::list<std::shared_ptr<OpLite>> ops;
// the scope to run the kernels, NOTE not the root scope.
// the scope to run the kernels, NOTE this is the execution scope.
std::shared_ptr<lite::Scope> scope;
std::vector<Place> valid_places;
// Runtime scope.
@@ -67,8 +67,6 @@ struct Program {
// if (op_type == "feed" || op_type == "fetch") continue;
VLOG(4) << "create Op [" << op_type << "]";
ops.emplace_back(LiteOpRegistry::Global().Create(op_type));
// pick initial kernel
ops.back()->PickKernel(valid_places);
ops.back()->Attach(op_desc, exec_scope);
}
}
......
@@ -54,5 +54,13 @@ Variable *Scope::FindLocalVar(const std::string &name) const {
return nullptr;
}
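// Collect the names of the variables held directly by this scope (parent scopes are not visited).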
std::vector<std::string> Scope::LocalVarNames() const {
std::vector<std::string> keys;
for (const auto &item : vars_) {
keys.push_back(item.first);
}
return keys;
}
} // namespace lite
} // namespace paddle
@@ -40,13 +40,7 @@ class Scope final {
const Scope* parent() const { return parent_; }
// Following the legacy scope interface.
std::vector<std::string> LocalVarNames() const {
std::vector<std::string> keys;
for (const auto& item : vars_) {
keys.push_back(item.first);
}
return keys;
}
std::vector<std::string> LocalVarNames() const;
private:
// Scope in `kids_` are owned by this class.
......
@@ -41,5 +41,31 @@ std::ostream &operator<<(std::ostream &os, const Tensor &tensor) {
return os;
}
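// Share the other tensor's buffer and metadata; no data is copied.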
void Tensor::ShareDataWith(const Tensor &other) {
buffer_ = other.buffer_;
dims_ = other.dims_;
target_ = other.target_;
lod_ = other.lod_;
memory_size_ = other.memory_size_;
}
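// Lazily (re)allocate the buffer on the current target to hold `memory_size` bytes.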
void *Tensor::mutable_data(size_t memory_size) {
buffer_->ResetLazy(target_, memory_size);
return buffer_->data();
}
void *Tensor::mutable_data(TargetType target, size_t memory_size) {
target_ = target;
return mutable_data(memory_size);
}
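// Copy the other tensor's metadata, then deep-copy its data into this tensor's own buffer.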
void Tensor::CopyDataFrom(const Tensor &other) {
dims_ = other.dims_;
target_ = other.target_;
lod_ = other.lod_;
memory_size_ = other.memory_size_;
buffer_->CopyDataFrom(*other.buffer_, memory_size_);
}
} // namespace lite
} // namespace paddle
@@ -62,50 +62,20 @@ class Tensor {
LoD* mutable_lod() { return &lod_; }
template <typename T>
T* mutable_data() {
memory_size_ = product(dims_) * sizeof(T);
buffer_->ResetLazy(target_, memory_size_);
return static_cast<T*>(buffer_->data());
}
T* mutable_data();
template <typename T>
T* mutable_data(TargetType target) {
target_ = target;
memory_size_ = product(dims_) * sizeof(T);
buffer_->ResetLazy(target, memory_size());
return static_cast<T*>(buffer_->data());
}
void* mutable_data(size_t memory_size) {
buffer_->ResetLazy(target_, memory_size);
return buffer_->data();
}
void* mutable_data(TargetType target, size_t memory_size) {
target_ = target;
return mutable_data(memory_size);
}
T* mutable_data(TargetType target);
void* mutable_data(size_t memory_size);
void* mutable_data(TargetType target, size_t memory_size);
size_t memory_size() const { return memory_size_; }
bool IsInitialized() const { return buffer_->data(); }
// Other share data to this.
void ShareDataWith(const Tensor& other) {
buffer_ = other.buffer_;
dims_ = other.dims_;
target_ = other.target_;
lod_ = other.lod_;
memory_size_ = other.memory_size_;
}
void ShareDataWith(const Tensor& other);
void CopyDataFrom(const Tensor& other) {
dims_ = other.dims_;
target_ = other.target_;
lod_ = other.lod_;
memory_size_ = other.memory_size_;
buffer_->CopyDataFrom(*other.buffer_, memory_size_);
}
void CopyDataFrom(const Tensor& other);
TargetType target() const { return target_; }
@@ -117,6 +87,21 @@ class Tensor {
size_t memory_size_{};
};
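// Out-of-line definitions of the templated mutable_data overloads; they size the buffer
// for product(dims_) elements of T and stay in the header because they are templates.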
template <typename T>
T* Tensor::mutable_data() {
memory_size_ = product(dims_) * sizeof(T);
buffer_->ResetLazy(target_, memory_size_);
return static_cast<T*>(buffer_->data());
}
template <typename T>
T* Tensor::mutable_data(TargetType target) {
target_ = target;
memory_size_ = product(dims_) * sizeof(T);
buffer_->ResetLazy(target, memory_size());
return static_cast<T*>(buffer_->data());
}
std::ostream& operator<<(std::ostream& os, const DDim& dims);
std::ostream& operator<<(std::ostream& os, const Tensor& tensor);
......
@@ -126,5 +126,14 @@ const Type* LookupType(DataTypeBase::ID type_id, bool is_unknown,
// ------------------------- end GetType specification ------------------------
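// Combine the kernel type, place, IO direction and argument name into a single hash key.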
size_t ParamTypeRegistry::KernelIdTy::hash() const {
std::hash<std::string> h;
size_t hash = h(kernel_type);
hash = hash_combine(hash, place.hash());
hash = hash_combine(hash, std::hash<int>()(static_cast<int>(io)));
hash = hash_combine(hash, std::hash<std::string>()(arg_name));
return hash;
}
} // namespace lite
} // namespace paddle
@@ -274,19 +274,11 @@ const Type* LookupType(DataTypeBase::ID type_id, bool is_unknown,
* registered in the `TypeSystem`.
*/
struct ParamType {
// For unsupported types.
size_t element_type_hash{};
Place tensor_place{};
const Type* type;
ParamType() = default;
explicit ParamType(size_t element_type_hash)
: element_type_hash(element_type_hash) {}
ParamType(size_t element_type_hash, const Place& place)
: element_type_hash(element_type_hash), tensor_place(place) {}
explicit ParamType(const Type* type) : type(type) {
tensor_place = type->place();
}
ParamType(const Type* type) : type(type) { tensor_place = type->place(); }
std::string DebugString() const { return tensor_place.DebugString(); }
};
@@ -416,14 +408,7 @@ class ParamTypeRegistry {
IO io;
std::string arg_name;
size_t hash() const {
std::hash<std::string> h;
size_t hash = h(kernel_type);
hash = hash_combine(hash, place.hash());
hash = hash_combine(hash, std::hash<int>()(static_cast<int>(io)));
hash = hash_combine(hash, std::hash<std::string>()(arg_name));
return hash;
}
size_t hash() const;
friend std::ostream& operator<<(std::ostream& os, const KernelIdTy& other);
};
......
@@ -47,6 +47,24 @@ bool KernelPickFactor::IsDeviceConsidered() const {
return data_ & static_cast<int>(Factor::DeviceFirst);
}
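// Print the factor bits of `data_` as a fixed-width binary string, most significant bit first.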
std::ostream& operator<<(std::ostream& os, const KernelPickFactor& k) {
std::stack<bool> bits;
auto data = k.data_;
while (data) {
bits.push(data % 2);
data /= 2;
}
int nbits = bits.size();
for (size_t i = 0; i < sizeof(data) * 8 - nbits; i++) {
os << 0;
}
while (!bits.empty()) {
os << bits.top();
bits.pop();
}
return os;
}
} // namespace core
} // namespace lite
} // namespace paddle
@@ -49,23 +49,7 @@ class KernelPickFactor {
bool IsDataLayoutConsidered() const;
bool IsDeviceConsidered() const;
friend std::ostream& operator<<(std::ostream& os, const KernelPickFactor& k) {
std::stack<bool> bits;
auto data = k.data_;
while (data) {
bits.push(data % 2);
data /= 2;
}
int nbits = bits.size();
for (size_t i = 0; i < sizeof(data) * 8 - nbits; i++) {
os << 0;
}
while (!bits.empty()) {
os << bits.top();
bits.pop();
}
return os;
}
friend std::ostream& operator<<(std::ostream& os, const KernelPickFactor& k);
private:
unsigned char data_{};
......
@@ -56,10 +56,6 @@ class ScaleOp : public OpLite {
param_.scale = op_desc.GetAttr("scale").get<float>();
param_.bias = op_desc.GetAttr("bias").get<float>();
param_.bias_after_scale = op_desc.GetAttr("bias_after_scale").get<bool>();
CHECK(kernel_);
kernel_->SetParam(param_);
return true;
}
......