提交 ec27aa46 编写于 作者: S superjomn

code clean

上级 0fb566b7
...@@ -24,6 +24,31 @@ std::string KernelBase::summary() const { ...@@ -24,6 +24,31 @@ std::string KernelBase::summary() const {
return ss.str(); return ss.str();
} }
const Type *KernelBase::GetInputDeclType(const std::string &arg_name) {
CHECK(!op_type_.empty()) << "op_type should be set first";
const auto *type = ParamTypeRegistry::Global().RetrieveInArgument(
place(), GenParamTypeKey(), arg_name);
CHECK(type) << "no type registered for kernel [" << op_type_
<< "] input argument [" << arg_name << "]"
<< " with key " << GenParamTypeKey();
return type->type;
}
const Type *KernelBase::GetOutputDeclType(const std::string &arg_name) {
CHECK(!op_type_.empty()) << "op_type should be set first";
const auto *type = ParamTypeRegistry::Global().RetrieveOutArgument(
place(), GenParamTypeKey(), arg_name);
CHECK(type) << "no type registered for kernel [" << op_type_
<< "] output argument [" << arg_name << "]";
return type->type;
}
std::string KernelBase::GenParamTypeKey() const {
std::stringstream ss;
ss << op_type() << "/" << alias_;
return ss.str();
}
bool ParamTypeRegistry::KeyCmp::operator()( bool ParamTypeRegistry::KeyCmp::operator()(
const ParamTypeRegistry::key_t &a, const ParamTypeRegistry::key_t &a,
const ParamTypeRegistry::key_t &b) const { const ParamTypeRegistry::key_t &b) const {
...@@ -37,6 +62,5 @@ std::ostream &operator<<(std::ostream &os, ...@@ -37,6 +62,5 @@ std::ostream &operator<<(std::ostream &os,
<< other.place.DebugString(); << other.place.DebugString();
return os; return os;
} }
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
\ No newline at end of file
...@@ -73,28 +73,11 @@ class KernelBase { ...@@ -73,28 +73,11 @@ class KernelBase {
void set_op_type(const std::string& type) { op_type_ = type; } void set_op_type(const std::string& type) { op_type_ = type; }
const std::string& op_type() const { return op_type_; } const std::string& op_type() const { return op_type_; }
void Torch() {} // Get input declaration Type.
const Type* GetInputDeclType(const std::string& arg_name);
// Get input declaration type. // Get output declaration Type.
const Type* GetInputDeclType(const std::string& arg_name) { const Type* GetOutputDeclType(const std::string& arg_name);
CHECK(!op_type_.empty()) << "op_type should be set first";
const auto* type = ParamTypeRegistry::Global().RetrieveInArgument(
place(), GenParamTypeKey(), arg_name);
CHECK(type) << "no type registered for kernel [" << op_type_
<< "] input argument [" << arg_name << "]"
<< " with key " << GenParamTypeKey();
return type->type;
}
// Get output declaration type.
const Type* GetOutputDeclType(const std::string& arg_name) {
CHECK(!op_type_.empty()) << "op_type should be set first";
const auto* type = ParamTypeRegistry::Global().RetrieveOutArgument(
place(), GenParamTypeKey(), arg_name);
CHECK(type) << "no type registered for kernel [" << op_type_
<< "] output argument [" << arg_name << "]";
return type->type;
}
void set_alias(const std::string& x) { alias_ = x; } void set_alias(const std::string& x) { alias_ = x; }
const std::string& alias() const { return alias_; } const std::string& alias() const { return alias_; }
...@@ -110,14 +93,11 @@ class KernelBase { ...@@ -110,14 +93,11 @@ class KernelBase {
std::string summary() const; std::string summary() const;
// Long human-readable document. // Long human-readable document.
virtual std::string doc() const { return ""; } virtual std::string doc() const { return ""; }
// Generate the key of the parameter type.
std::string GenParamTypeKey() const { std::string GenParamTypeKey() const;
std::stringstream ss;
ss << op_type() << "/" << alias_;
return ss.str();
}
virtual ~KernelBase() = default; virtual ~KernelBase() = default;
void Torch() {}
protected: protected:
std::unique_ptr<KernelContext> context_; std::unique_ptr<KernelContext> context_;
...@@ -144,10 +124,7 @@ class OpKernel : public KernelBase { ...@@ -144,10 +124,7 @@ class OpKernel : public KernelBase {
PrecisionType precision() const override { return Precision; } PrecisionType precision() const override { return Precision; }
DataLayoutType layout() const override { return DataLayout; } DataLayoutType layout() const override { return DataLayout; }
Place place() const override { return Place{Target, Precision, DataLayout}; } Place place() const override { return Place{Target, Precision, DataLayout}; }
std::string name() const override { std::string name() const override;
return op_type() + ":" + TargetToStr(Target) + "/" +
PrecisionToStr(Precision) + "/" + DataLayoutToStr(DataLayout);
}
void Touch() {} void Touch() {}
...@@ -158,5 +135,11 @@ class OpKernel : public KernelBase { ...@@ -158,5 +135,11 @@ class OpKernel : public KernelBase {
std::unique_ptr<KernelContext> ctx_; std::unique_ptr<KernelContext> ctx_;
}; };
template <TargetType Target, PrecisionType Precision, DataLayoutType DataLayout>
std::string OpKernel<Target, Precision, DataLayout>::name() const {
return op_type() + ":" + TargetToStr(Target) + "/" +
PrecisionToStr(Precision) + "/" + DataLayoutToStr(DataLayout);
}
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
...@@ -15,7 +15,5 @@ ...@@ -15,7 +15,5 @@
#include "paddle/fluid/lite/core/memory.h" #include "paddle/fluid/lite/core/memory.h"
namespace paddle { namespace paddle {
namespace framework {
namespace lite {} // namespace lite namespace lite {} // namespace lite
} // namespace framework
} // namespace paddle } // namespace paddle
...@@ -55,17 +55,6 @@ std::vector<std::unique_ptr<KernelBase>> OpLite::CreateKernels( ...@@ -55,17 +55,6 @@ std::vector<std::unique_ptr<KernelBase>> OpLite::CreateKernels(
return kernels; return kernels;
} }
void OpLite::PickKernel(const std::vector<Place> &valid_places,
OpLite::KernelStrategy kernel_strategy) {
switch (kernel_strategy) {
case KernelStrategy::kStatic:
StaticPickKernel(valid_places);
break;
default:
LOG(FATAL) << "unsupported kernel strategy";
}
}
bool OpLite::Run() { bool OpLite::Run() {
CHECK(kernel_); CHECK(kernel_);
SyncInputEvents(); SyncInputEvents();
...@@ -120,5 +109,72 @@ bool OpInfo::GetOutputArgname(const std::string &value_name, ...@@ -120,5 +109,72 @@ bool OpInfo::GetOutputArgname(const std::string &value_name,
} }
return false; return false;
} }
void OpInfo::ExtractInputsAndOutputs(const framework::proto::OpDesc &opdesc) {
for (const auto &item : opdesc.inputs()) {
for (const auto &x : item.arguments()) {
input_names_.push_back(x);
}
}
for (const auto &item : opdesc.outputs()) {
for (const auto &x : item.arguments()) {
output_names_.push_back(x);
}
}
}
void OpInfo::CollectInputAndOutputArgnames(
const framework::proto::OpDesc &opdesc) {
for (const auto &item : opdesc.inputs()) {
input_argnames_.push_back(item.parameter());
}
for (const auto &item : opdesc.outputs()) {
output_argnames_.push_back(item.parameter());
}
}
void OpInfo::CollectArguments(const framework::proto::OpDesc &opdesc) {
for (const auto &item : opdesc.inputs()) {
for (auto &x : item.arguments()) {
input_argument_[item.parameter()].push_back(x);
}
}
for (const auto &item : opdesc.outputs()) {
for (auto &x : item.arguments()) {
output_argument_[item.parameter()].push_back(x);
}
}
}
void OpInfo::Build(const framework::proto::OpDesc &desc) {
ExtractInputsAndOutputs(desc);
CollectInputAndOutputArgnames(desc);
CollectArguments(desc);
desc_.reset(new framework::proto::OpDesc(desc));
}
const std::map<std::string, std::list<std::string>> &OpInfo::input_argument()
const {
return input_argument_;
}
const std::map<std::string, std::list<std::string>> &OpInfo::output_argument()
const {
return output_argument_;
}
const std::list<std::string> &OpInfo::input_argnames() const {
return input_argnames_;
}
const std::list<std::string> &OpInfo::output_argnames() const {
return output_argnames_;
}
const framework::proto::OpDesc &OpInfo::desc() const {
CHECK(desc_) << "desc has't set";
return *desc_;
}
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
...@@ -54,16 +54,6 @@ class OpInfo; ...@@ -54,16 +54,6 @@ class OpInfo;
*/ */
class OpLite : public Registry { class OpLite : public Registry {
public: public:
// The strategies to pick a kernel from candidates.
enum class KernelStrategy {
// Return the user specified one.
kStatic = 0,
// Specify the expected kernel externally.
kSpecified,
// Run each kernel to evaluate and get the best kernel.
kRuntime,
};
OpLite() = default; OpLite() = default;
OpLite(const std::string &type) : op_type_(type) {} OpLite(const std::string &type) : op_type_(type) {}
OpLite(const std::vector<Place> &valid_places) OpLite(const std::vector<Place> &valid_places)
...@@ -91,10 +81,6 @@ class OpLite : public Registry { ...@@ -91,10 +81,6 @@ class OpLite : public Registry {
const Place &kernel_place() const { return kernel_place_; } const Place &kernel_place() const { return kernel_place_; }
// NOTE This might be discarded.
void PickKernel(const std::vector<Place> &valid_places,
KernelStrategy kernel_strategy = KernelStrategy::kStatic);
// Create all the kernels for the valid targets. // Create all the kernels for the valid targets.
std::vector<std::unique_ptr<KernelBase>> CreateKernels( std::vector<std::unique_ptr<KernelBase>> CreateKernels(
const std::vector<Place> &places, const std::string &kernel_type = ""); const std::vector<Place> &places, const std::string &kernel_type = "");
...@@ -147,71 +133,26 @@ class OpInfo { ...@@ -147,71 +133,26 @@ class OpInfo {
public: public:
// To avoid the bugs from legancy framework::OpDesc, we use the ProtoBuf // To avoid the bugs from legancy framework::OpDesc, we use the ProtoBuf
// message instead. // message instead.
void Build(const framework::proto::OpDesc &desc) { void Build(const framework::proto::OpDesc &desc);
ExtractInputsAndOutputs(desc);
CollectInputAndOutputArgnames(desc);
CollectArguments(desc);
desc_.reset(new framework::proto::OpDesc(desc));
}
const framework::proto::OpDesc &desc() const { const framework::proto::OpDesc &desc() const;
CHECK(desc_) << "desc has't set";
return *desc_;
}
framework::proto::OpDesc *mutable_desc() { return desc_.get(); } framework::proto::OpDesc *mutable_desc() { return desc_.get(); }
const std::list<std::string> &input_names() const { return input_names_; } const std::list<std::string> &input_names() const { return input_names_; }
const std::list<std::string> &output_names() const { return output_names_; } const std::list<std::string> &output_names() const { return output_names_; }
const std::map<std::string, std::list<std::string>> &input_argument() const { const std::map<std::string, std::list<std::string>> &input_argument() const;
return input_argument_; const std::map<std::string, std::list<std::string>> &output_argument() const;
}
const std::map<std::string, std::list<std::string>> &output_argument() const {
return output_argument_;
}
bool GetInputArgname(const std::string &value_name, std::string *out) const; bool GetInputArgname(const std::string &value_name, std::string *out) const;
bool GetOutputArgname(const std::string &value_name, std::string *out) const; bool GetOutputArgname(const std::string &value_name, std::string *out) const;
const std::list<std::string> &input_argnames() const { const std::list<std::string> &input_argnames() const;
return input_argnames_; const std::list<std::string> &output_argnames() const;
}
const std::list<std::string> &output_argnames() const {
return output_argnames_;
}
private: private:
void ExtractInputsAndOutputs(const framework::proto::OpDesc &opdesc) { void ExtractInputsAndOutputs(const framework::proto::OpDesc &opdesc);
for (const auto &item : opdesc.inputs()) {
for (const auto &x : item.arguments()) {
input_names_.push_back(x);
}
}
for (const auto &item : opdesc.outputs()) {
for (const auto &x : item.arguments()) {
output_names_.push_back(x);
}
}
}
void CollectInputAndOutputArgnames(const framework::proto::OpDesc &opdesc) { void CollectInputAndOutputArgnames(const framework::proto::OpDesc &opdesc);
for (const auto &item : opdesc.inputs()) {
input_argnames_.push_back(item.parameter());
}
for (const auto &item : opdesc.outputs()) {
output_argnames_.push_back(item.parameter());
}
}
void CollectArguments(const framework::proto::OpDesc &opdesc) { void CollectArguments(const framework::proto::OpDesc &opdesc);
for (const auto &item : opdesc.inputs()) {
for (auto &x : item.arguments()) {
input_argument_[item.parameter()].push_back(x);
}
}
for (const auto &item : opdesc.outputs()) {
for (auto &x : item.arguments()) {
output_argument_[item.parameter()].push_back(x);
}
}
}
private: private:
std::list<std::string> input_names_; std::list<std::string> input_names_;
......
...@@ -41,7 +41,7 @@ class Optimizer { ...@@ -41,7 +41,7 @@ class Optimizer {
graph_.reset(new mir::SSAGraph); graph_.reset(new mir::SSAGraph);
graph_->Build(program, valid_places); graph_->Build(program, valid_places);
SpecifyKernelPickTactic(kernel_pick_factor); SpecifyKernelPickTactic(kernel_pick_factor);
InitIoComplement(); InitTargetTypeTransformPass();
if (passes.empty()) { if (passes.empty()) {
RunPasses(std::vector<std::string>{{ RunPasses(std::vector<std::string>{{
...@@ -82,7 +82,7 @@ class Optimizer { ...@@ -82,7 +82,7 @@ class Optimizer {
return program; return program;
} }
void InitIoComplement() { void InitTargetTypeTransformPass() {
auto* pass = auto* pass =
mir::PassManager::Global().LookUp<mir::TypeTargetTransformPass>( mir::PassManager::Global().LookUp<mir::TypeTargetTransformPass>(
"type_target_transform_pass"); "type_target_transform_pass");
......
...@@ -33,7 +33,7 @@ struct Program { ...@@ -33,7 +33,7 @@ struct Program {
std::list<std::string> tmp_vars; std::list<std::string> tmp_vars;
std::list<std::string> weights; std::list<std::string> weights;
std::list<std::shared_ptr<OpLite>> ops; std::list<std::shared_ptr<OpLite>> ops;
// the scope to run the kernels, NOTE not the root scope. // the scope to run the kernels, NOTE this is the execution scope.
std::shared_ptr<lite::Scope> scope; std::shared_ptr<lite::Scope> scope;
std::vector<Place> valid_places; std::vector<Place> valid_places;
// Runtime scope. // Runtime scope.
...@@ -67,8 +67,6 @@ struct Program { ...@@ -67,8 +67,6 @@ struct Program {
// if (op_type == "feed" || op_type == "fetch") continue; // if (op_type == "feed" || op_type == "fetch") continue;
VLOG(4) << "create Op [" << op_type << "]"; VLOG(4) << "create Op [" << op_type << "]";
ops.emplace_back(LiteOpRegistry::Global().Create(op_type)); ops.emplace_back(LiteOpRegistry::Global().Create(op_type));
// pick initial kernel
ops.back()->PickKernel(valid_places);
ops.back()->Attach(op_desc, exec_scope); ops.back()->Attach(op_desc, exec_scope);
} }
} }
......
...@@ -54,5 +54,13 @@ Variable *Scope::FindLocalVar(const std::string &name) const { ...@@ -54,5 +54,13 @@ Variable *Scope::FindLocalVar(const std::string &name) const {
return nullptr; return nullptr;
} }
std::vector<std::string> Scope::LocalVarNames() const {
std::vector<std::string> keys;
for (const auto &item : vars_) {
keys.push_back(item.first);
}
return keys;
}
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
...@@ -40,13 +40,7 @@ class Scope final { ...@@ -40,13 +40,7 @@ class Scope final {
const Scope* parent() const { return parent_; } const Scope* parent() const { return parent_; }
// Following the legacy scope interface. // Following the legacy scope interface.
std::vector<std::string> LocalVarNames() const { std::vector<std::string> LocalVarNames() const;
std::vector<std::string> keys;
for (const auto& item : vars_) {
keys.push_back(item.first);
}
return keys;
}
private: private:
// Scope in `kids_` are owned by this class. // Scope in `kids_` are owned by this class.
......
...@@ -41,5 +41,31 @@ std::ostream &operator<<(std::ostream &os, const Tensor &tensor) { ...@@ -41,5 +41,31 @@ std::ostream &operator<<(std::ostream &os, const Tensor &tensor) {
return os; return os;
} }
void Tensor::ShareDataWith(const Tensor &other) {
buffer_ = other.buffer_;
dims_ = other.dims_;
target_ = other.target_;
lod_ = other.lod_;
memory_size_ = other.memory_size_;
}
void *Tensor::mutable_data(size_t memory_size) {
buffer_->ResetLazy(target_, memory_size);
return buffer_->data();
}
void *Tensor::mutable_data(TargetType target, size_t memory_size) {
target_ = target;
return mutable_data(memory_size);
}
void Tensor::CopyDataFrom(const Tensor &other) {
dims_ = other.dims_;
target_ = other.target_;
lod_ = other.lod_;
memory_size_ = other.memory_size_;
buffer_->CopyDataFrom(*other.buffer_, memory_size_);
}
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
...@@ -62,50 +62,20 @@ class Tensor { ...@@ -62,50 +62,20 @@ class Tensor {
LoD* mutable_lod() { return &lod_; } LoD* mutable_lod() { return &lod_; }
template <typename T> template <typename T>
T* mutable_data() { T* mutable_data();
memory_size_ = product(dims_) * sizeof(T);
buffer_->ResetLazy(target_, memory_size_);
return static_cast<T*>(buffer_->data());
}
template <typename T> template <typename T>
T* mutable_data(TargetType target) { T* mutable_data(TargetType target);
target_ = target; void* mutable_data(size_t memory_size);
memory_size_ = product(dims_) * sizeof(T); void* mutable_data(TargetType target, size_t memory_size);
buffer_->ResetLazy(target, memory_size());
return static_cast<T*>(buffer_->data());
}
void* mutable_data(size_t memory_size) {
buffer_->ResetLazy(target_, memory_size);
return buffer_->data();
}
void* mutable_data(TargetType target, size_t memory_size) {
target_ = target;
return mutable_data(memory_size);
}
size_t memory_size() const { return memory_size_; } size_t memory_size() const { return memory_size_; }
bool IsInitialized() const { return buffer_->data(); } bool IsInitialized() const { return buffer_->data(); }
// Other share data to this. // Other share data to this.
void ShareDataWith(const Tensor& other) { void ShareDataWith(const Tensor& other);
buffer_ = other.buffer_;
dims_ = other.dims_;
target_ = other.target_;
lod_ = other.lod_;
memory_size_ = other.memory_size_;
}
void CopyDataFrom(const Tensor& other) { void CopyDataFrom(const Tensor& other);
dims_ = other.dims_;
target_ = other.target_;
lod_ = other.lod_;
memory_size_ = other.memory_size_;
buffer_->CopyDataFrom(*other.buffer_, memory_size_);
}
TargetType target() const { return target_; } TargetType target() const { return target_; }
...@@ -117,6 +87,21 @@ class Tensor { ...@@ -117,6 +87,21 @@ class Tensor {
size_t memory_size_{}; size_t memory_size_{};
}; };
template <typename T>
T* Tensor::mutable_data() {
memory_size_ = product(dims_) * sizeof(T);
buffer_->ResetLazy(target_, memory_size_);
return static_cast<T*>(buffer_->data());
}
template <typename T>
T* Tensor::mutable_data(TargetType target) {
target_ = target;
memory_size_ = product(dims_) * sizeof(T);
buffer_->ResetLazy(target, memory_size());
return static_cast<T*>(buffer_->data());
}
std::ostream& operator<<(std::ostream& os, const DDim& dims); std::ostream& operator<<(std::ostream& os, const DDim& dims);
std::ostream& operator<<(std::ostream& os, const Tensor& tensor); std::ostream& operator<<(std::ostream& os, const Tensor& tensor);
......
...@@ -126,5 +126,14 @@ const Type* LookupType(DataTypeBase::ID type_id, bool is_unknown, ...@@ -126,5 +126,14 @@ const Type* LookupType(DataTypeBase::ID type_id, bool is_unknown,
// ------------------------- end GetType specification ------------------------ // ------------------------- end GetType specification ------------------------
size_t ParamTypeRegistry::KernelIdTy::hash() const {
std::hash<std::string> h;
size_t hash = h(kernel_type);
hash = hash_combine(hash, place.hash());
hash = hash_combine(hash, std::hash<int>()(static_cast<int>(io)));
hash = hash_combine(hash, std::hash<std::string>()(arg_name));
return hash;
}
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
...@@ -274,19 +274,11 @@ const Type* LookupType(DataTypeBase::ID type_id, bool is_unknown, ...@@ -274,19 +274,11 @@ const Type* LookupType(DataTypeBase::ID type_id, bool is_unknown,
* registered in the `TypeSystem`. * registered in the `TypeSystem`.
*/ */
struct ParamType { struct ParamType {
// For unsupported types.
size_t element_type_hash{};
Place tensor_place{}; Place tensor_place{};
const Type* type; const Type* type;
ParamType() = default; ParamType() = default;
explicit ParamType(size_t element_type_hash) ParamType(const Type* type) : type(type) { tensor_place = type->place(); }
: element_type_hash(element_type_hash) {}
ParamType(size_t element_type_hash, const Place& place)
: element_type_hash(element_type_hash), tensor_place(place) {}
explicit ParamType(const Type* type) : type(type) {
tensor_place = type->place();
}
std::string DebugString() const { return tensor_place.DebugString(); } std::string DebugString() const { return tensor_place.DebugString(); }
}; };
...@@ -416,14 +408,7 @@ class ParamTypeRegistry { ...@@ -416,14 +408,7 @@ class ParamTypeRegistry {
IO io; IO io;
std::string arg_name; std::string arg_name;
size_t hash() const { size_t hash() const;
std::hash<std::string> h;
size_t hash = h(kernel_type);
hash = hash_combine(hash, place.hash());
hash = hash_combine(hash, std::hash<int>()(static_cast<int>(io)));
hash = hash_combine(hash, std::hash<std::string>()(arg_name));
return hash;
}
friend std::ostream& operator<<(std::ostream& os, const KernelIdTy& other); friend std::ostream& operator<<(std::ostream& os, const KernelIdTy& other);
}; };
......
...@@ -47,6 +47,24 @@ bool KernelPickFactor::IsDeviceConsidered() const { ...@@ -47,6 +47,24 @@ bool KernelPickFactor::IsDeviceConsidered() const {
return data_ & static_cast<int>(Factor::DeviceFirst); return data_ & static_cast<int>(Factor::DeviceFirst);
} }
std::ostream& operator<<(std::ostream& os, const KernelPickFactor& k) {
std::stack<bool> bits;
auto data = k.data_;
while (data) {
bits.push(data % 2);
data /= 2;
}
int nbits = bits.size();
for (size_t i = 0; i < sizeof(data) * 8 - nbits; i++) {
os << 0;
}
while (!bits.empty()) {
os << bits.top();
bits.pop();
}
return os;
}
} // namespace core } // namespace core
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
...@@ -49,23 +49,7 @@ class KernelPickFactor { ...@@ -49,23 +49,7 @@ class KernelPickFactor {
bool IsDataLayoutConsidered() const; bool IsDataLayoutConsidered() const;
bool IsDeviceConsidered() const; bool IsDeviceConsidered() const;
friend std::ostream& operator<<(std::ostream& os, const KernelPickFactor& k) { friend std::ostream& operator<<(std::ostream& os, const KernelPickFactor& k);
std::stack<bool> bits;
auto data = k.data_;
while (data) {
bits.push(data % 2);
data /= 2;
}
int nbits = bits.size();
for (size_t i = 0; i < sizeof(data) * 8 - nbits; i++) {
os << 0;
}
while (!bits.empty()) {
os << bits.top();
bits.pop();
}
return os;
}
private: private:
unsigned char data_{}; unsigned char data_{};
......
...@@ -56,10 +56,6 @@ class ScaleOp : public OpLite { ...@@ -56,10 +56,6 @@ class ScaleOp : public OpLite {
param_.scale = op_desc.GetAttr("scale").get<float>(); param_.scale = op_desc.GetAttr("scale").get<float>();
param_.bias = op_desc.GetAttr("bias").get<float>(); param_.bias = op_desc.GetAttr("bias").get<float>();
param_.bias_after_scale = op_desc.GetAttr("bias_after_scale").get<bool>(); param_.bias_after_scale = op_desc.GetAttr("bias_after_scale").get<bool>();
CHECK(kernel_);
kernel_->SetParam(param_);
return true; return true;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册