提交 b2e3824e 编写于 作者: Q qiaolongfei

change operator

上级 81f5f861
......@@ -120,8 +120,10 @@ class OpProtoAndCheckerMaker {
};
class OpRegistry {
using OpCreator = std::function<OperatorBase*()>;
using VarNameMap = OperatorBase::VarNameMap;
using OpCreator = std::function<OperatorBase*(
const std::string& type, const VarNameMap& inputs,
const VarNameMap& outputs, const AttributeMap& attrs)>;
public:
template <typename OpType, typename ProtoMakerType>
......@@ -153,14 +155,9 @@ class OpRegistry {
PADDLE_ENFORCE(op_create_it != op_creators().end(),
"Operator %s cannot be found.", type);
auto op = op_create_it->second();
op->type_ = type;
op->inputs_ = inputs;
op->outputs_ = outputs;
op->attrs_ = attrs;
op_checkers().at(type).Check(op->attrs_);
auto attrMap = attrs;
op_checkers().at(type).Check(attrMap);
auto op = op_create_it->second(type, inputs, outputs, attrMap);
GenerateTempVariableName(op);
op->Init();
......@@ -217,12 +214,14 @@ class OpRegistry {
static void GenerateTempVariableName(OperatorBase* op) {
static std::atomic<size_t> gUniqId(0UL);
for (auto& output : op->outputs_) {
for (auto& output : op->Outputs()) {
for (auto& output_name : output.second) {
if (output_name == kTempVarName) {
output_name += op->type_;
output_name += "@";
output_name += std::to_string(gUniqId.fetch_add(1));
auto new_name = output_name;
new_name += op->Type();
new_name += "@";
new_name += std::to_string(gUniqId.fetch_add(1));
op->Rename(output_name, new_name);
}
}
}
......
......@@ -105,6 +105,8 @@ class OperatorBase {
/// rename inputs outputs name
void Rename(const std::string& old_name, const std::string& new_name);
const VarNameMap& Inputs() const { return inputs_; }
const VarNameMap& Outputs() const { return outputs_; }
//! Get a input with argument's name described in `op_proto`
const std::string& Input(const std::string& name) const;
//! Get a input which has multiple variables.
......@@ -118,10 +120,10 @@ class OperatorBase {
virtual std::vector<std::string> OutputVars(bool has_intermediate) const;
std::string Type() const { return type_; }
const std::string& Type() const { return type_; }
const AttributeMap& Attrs() const { return attrs_; }
public:
protected:
std::string type_;
// NOTE: in case of OpGrad, inputs_ contains:
// I (Inputs)
......
......@@ -29,7 +29,7 @@ void NetOp::CompleteAddOp(bool calc) {
std::set<std::string> input_set;
std::set<std::string> output_set;
for (auto& op : ops_) {
for (auto& ipt : op->inputs_) {
for (auto& ipt : op->Inputs()) {
for (auto& var_name : ipt.second) {
if (!Contains(output_set, var_name)) { // Not other op's output
input_set.insert(var_name);
......@@ -39,7 +39,7 @@ void NetOp::CompleteAddOp(bool calc) {
}
}
for (auto& opt : op->outputs_) {
for (auto& opt : op->Outputs()) {
for (auto& var_name : opt.second) {
output_set.insert(var_name);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册