Unverified · Commit 81dfc0cf authored by Yang Yang(Tony), committed by GitHub

Clean up unused code in operator class (#10035)

* delete unused IsNetOp() and Rename()

* rm OperatorBase::Rename implementation

* delete Operator::InputVars()

* remove unused OperatorBase::ShareLoD; ShareLoD is already implemented in InferShape (see the sketch after this commit header)

* organize operatorbase; remove unused set_type

* add comments

* fix comment
Parent f09aed04
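Before the diff itself, a minimal sketch (not part of this patch) of the InferShape-side LoD sharing the commit message refers to. It assumes the standard InferShapeContext interface (GetInputDim, SetOutputDim, ShareLoD) and the OperatorWithKernel base class; the operator name MyOp, the slot names "X"/"Out", and the include path are made up or assumed for illustration only.

// Hypothetical operator, only to show the InferShape-side ShareLoD idiom.
#include "paddle/fluid/framework/op_registry.h"  // assumed include path

namespace paddle {
namespace operators {

class MyOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    // "Out" takes the shape of "X" ...
    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
    // ... and also its LoD, which is the same thing the removed
    // ExecutionContext::ShareLoD helper used to do at kernel-run time.
    ctx->ShareLoD("X", /*->*/ "Out");
  }
};

}  // namespace operators
}  // namespace paddle

Because the LoD is propagated once at shape-inference time, a second ShareLoD helper on ExecutionContext has no remaining callers, which is what the hunks below remove.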
@@ -171,17 +171,6 @@ std::string OperatorBase::DebugStringEx(const Scope* scope) const {
   return ss.str();
 }
 
-void OperatorBase::Rename(const std::string& old_name,
-                          const std::string& new_name) {
-  for (auto& input : inputs_) {
-    std::replace(input.second.begin(), input.second.end(), old_name, new_name);
-  }
-  for (auto& output : outputs_) {
-    std::replace(output.second.begin(), output.second.end(), old_name,
-                 new_name);
-  }
-}
-
 OperatorBase::OperatorBase(const std::string& type,
                            const VariableNameMap& inputs,
                            const VariableNameMap& outputs,
@@ -327,7 +316,6 @@ bool OpSupportGPU(const std::string& op_type) {
   auto it = all_kernels.find(op_type);
   if (it == all_kernels.end()) {
     // All control operator must support GPU
-
     return true;
   }
   for (auto& kern_pair : it->second) {
...
@@ -79,31 +79,28 @@ class OperatorBase {
 
   virtual ~OperatorBase() {}
 
-  template <typename T>
-  inline const T& Attr(const std::string& name) const {
-    PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap",
-                   name);
-    return boost::get<T>(attrs_.at(name));
-  }
-
-  /// if scope is not null, also show dimensions of arguments
-  virtual std::string DebugStringEx(const Scope* scope) const;
-
-  std::string DebugString() const { return DebugStringEx(nullptr); }
-
-  /// Net will call this interface function to Run an op.
+  /// Executor will call this interface function to Run an op.
   //  The implementation should be written at RunImpl
   void Run(const Scope& scope, const platform::Place& place);
 
   // FIXME(typhoonzero): this is only used for recv_op to stop event_loop.
   virtual void Stop() {}
 
-  virtual bool IsNetOp() const { return false; }
+  /// if scope is not null, also show dimensions of arguments
+  virtual std::string DebugStringEx(const Scope* scope) const;
+  std::string DebugString() const { return DebugStringEx(nullptr); }
 
   virtual bool SupportGPU() const { return false; }
 
-  /// rename inputs outputs name
-  void Rename(const std::string& old_name, const std::string& new_name);
+  const std::string& Type() const { return type_; }
+
+  template <typename T>
+  inline const T& Attr(const std::string& name) const {
+    PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap",
+                   name);
+    return boost::get<T>(attrs_.at(name));
+  }
+  const AttributeMap& Attrs() const { return attrs_; }
 
   const VariableNameMap& Inputs() const { return inputs_; }
   const VariableNameMap& Outputs() const { return outputs_; }
@@ -112,7 +109,7 @@ class OperatorBase {
   std::string Input(const std::string& name) const;
   //! Get a input which has multiple variables.
   const std::vector<std::string>& Inputs(const std::string& name) const;
-
+  //! Get all inputs variable names
   std::vector<std::string> InputVars() const;
 
   //! Get a output with argument's name described in `op_proto`
@@ -120,13 +117,9 @@ class OperatorBase {
   //! Get an output which has multiple variables.
   //! TODO add a vector_view to prevent memory copy.
   const std::vector<std::string>& Outputs(const std::string& name) const;
-
+  //! Get all outputs variable names
   virtual std::vector<std::string> OutputVars(bool has_intermediate) const;
 
-  const std::string& Type() const { return type_; }
-  void SetType(const std::string& type) { type_ = type; }
-  const AttributeMap& Attrs() const { return attrs_; }
-
   // Return a new operator instance, which is as same as this.
   // Use unique_ptr to prevent caller forget to delete this pointer.
   virtual std::unique_ptr<OperatorBase> Clone() const = 0;
@@ -278,20 +271,6 @@ class ExecutionContext {
     return res;
   }
 
-  void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
-                size_t j = 0) const {
-    PADDLE_ENFORCE_LT(i, InputSize(in));
-    PADDLE_ENFORCE_LT(j, OutputSize(out));
-    auto* in_var = MultiInputVar(in)[i];
-    auto* out_var = MultiOutputVar(out)[j];
-    if (!in_var->IsType<LoDTensor>()) return;
-    PADDLE_ENFORCE(out_var->IsType<LoDTensor>(),
-                   "The %d-th output of Output(%s) must be LoDTensor.", j, out);
-    auto in_tensor = in_var->Get<LoDTensor>();
-    auto* out_tensor = out_var->GetMutable<LoDTensor>();
-    out_tensor->set_lod(in_tensor.lod());
-  }
-
   platform::Place GetPlace() const { return device_context_.GetPlace(); }
 
   template <typename DeviceContextType>
...
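To round off the header changes above, a minimal usage sketch (not from the patch) of the slimmed-down OperatorBase surface. It assumes `op` is some already-constructed operator, for example one created through the op registry; the attribute name "scale", the input slot "X", and the include path are hypothetical.

#include <string>
#include <vector>

#include "paddle/fluid/framework/operator.h"  // assumed include path

namespace fw = paddle::framework;

void InspectOp(const fw::OperatorBase& op) {
  // Type(), Attr<T>() and Attrs() now sit together in the public section.
  const std::string& type = op.Type();
  float scale = op.Attr<float>("scale");                      // hypothetical attribute
  const fw::VariableNameMap& all_inputs = op.Inputs();
  const std::vector<std::string>& x_names = op.Inputs("X");   // hypothetical slot
  // DebugString() is DebugStringEx(nullptr): no scope, so no tensor dims are printed.
  const std::string repr = op.DebugString();
  (void)type; (void)scale; (void)all_inputs; (void)x_names; (void)repr;
}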