提交 b2e3824e 编写于 作者: Q qiaolongfei

change operator

上级 81f5f861
...@@ -120,8 +120,10 @@ class OpProtoAndCheckerMaker { ...@@ -120,8 +120,10 @@ class OpProtoAndCheckerMaker {
}; };
class OpRegistry { class OpRegistry {
using OpCreator = std::function<OperatorBase*()>;
using VarNameMap = OperatorBase::VarNameMap; using VarNameMap = OperatorBase::VarNameMap;
using OpCreator = std::function<OperatorBase*(
const std::string& type, const VarNameMap& inputs,
const VarNameMap& outputs, const AttributeMap& attrs)>;
public: public:
template <typename OpType, typename ProtoMakerType> template <typename OpType, typename ProtoMakerType>
...@@ -153,14 +155,9 @@ class OpRegistry { ...@@ -153,14 +155,9 @@ class OpRegistry {
PADDLE_ENFORCE(op_create_it != op_creators().end(), PADDLE_ENFORCE(op_create_it != op_creators().end(),
"Operator %s cannot be found.", type); "Operator %s cannot be found.", type);
auto op = op_create_it->second(); auto attrMap = attrs;
op->type_ = type; op_checkers().at(type).Check(attrMap);
op->inputs_ = inputs; auto op = op_create_it->second(type, inputs, outputs, attrMap);
op->outputs_ = outputs;
op->attrs_ = attrs;
op_checkers().at(type).Check(op->attrs_);
GenerateTempVariableName(op); GenerateTempVariableName(op);
op->Init(); op->Init();
...@@ -217,12 +214,14 @@ class OpRegistry { ...@@ -217,12 +214,14 @@ class OpRegistry {
static void GenerateTempVariableName(OperatorBase* op) { static void GenerateTempVariableName(OperatorBase* op) {
static std::atomic<size_t> gUniqId(0UL); static std::atomic<size_t> gUniqId(0UL);
for (auto& output : op->outputs_) { for (auto& output : op->Outputs()) {
for (auto& output_name : output.second) { for (auto& output_name : output.second) {
if (output_name == kTempVarName) { if (output_name == kTempVarName) {
output_name += op->type_; auto new_name = output_name;
output_name += "@"; new_name += op->Type();
output_name += std::to_string(gUniqId.fetch_add(1)); new_name += "@";
new_name += std::to_string(gUniqId.fetch_add(1));
op->Rename(output_name, new_name);
} }
} }
} }
......
...@@ -105,6 +105,8 @@ class OperatorBase { ...@@ -105,6 +105,8 @@ class OperatorBase {
/// rename inputs outputs name /// rename inputs outputs name
void Rename(const std::string& old_name, const std::string& new_name); void Rename(const std::string& old_name, const std::string& new_name);
const VarNameMap& Inputs() const { return inputs_; }
const VarNameMap& Outputs() const { return outputs_; }
//! Get a input with argument's name described in `op_proto` //! Get a input with argument's name described in `op_proto`
const std::string& Input(const std::string& name) const; const std::string& Input(const std::string& name) const;
//! Get a input which has multiple variables. //! Get a input which has multiple variables.
...@@ -118,10 +120,10 @@ class OperatorBase { ...@@ -118,10 +120,10 @@ class OperatorBase {
virtual std::vector<std::string> OutputVars(bool has_intermediate) const; virtual std::vector<std::string> OutputVars(bool has_intermediate) const;
std::string Type() const { return type_; } const std::string& Type() const { return type_; }
const AttributeMap& Attrs() const { return attrs_; } const AttributeMap& Attrs() const { return attrs_; }
public: protected:
std::string type_; std::string type_;
// NOTE: in case of OpGrad, inputs_ contains: // NOTE: in case of OpGrad, inputs_ contains:
// I (Inputs) // I (Inputs)
......
...@@ -29,7 +29,7 @@ void NetOp::CompleteAddOp(bool calc) { ...@@ -29,7 +29,7 @@ void NetOp::CompleteAddOp(bool calc) {
std::set<std::string> input_set; std::set<std::string> input_set;
std::set<std::string> output_set; std::set<std::string> output_set;
for (auto& op : ops_) { for (auto& op : ops_) {
for (auto& ipt : op->inputs_) { for (auto& ipt : op->Inputs()) {
for (auto& var_name : ipt.second) { for (auto& var_name : ipt.second) {
if (!Contains(output_set, var_name)) { // Not other op's output if (!Contains(output_set, var_name)) { // Not other op's output
input_set.insert(var_name); input_set.insert(var_name);
...@@ -39,7 +39,7 @@ void NetOp::CompleteAddOp(bool calc) { ...@@ -39,7 +39,7 @@ void NetOp::CompleteAddOp(bool calc) {
} }
} }
for (auto& opt : op->outputs_) { for (auto& opt : op->Outputs()) {
for (auto& var_name : opt.second) { for (auto& var_name : opt.second) {
output_set.insert(var_name); output_set.insert(var_name);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册