提交 fb48cb12 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #2936 from reyoung/feature/create_op_in_cpp_params

Make CreateOp in Plain C++ params
...@@ -199,6 +199,7 @@ Add a mark to which output is temporary is helpful for future optimization. ...@@ -199,6 +199,7 @@ Add a mark to which output is temporary is helpful for future optimization.
class OpRegistry { class OpRegistry {
using OpCreator = std::function<OperatorBase*()>; using OpCreator = std::function<OperatorBase*()>;
using VarIndexMap = std::unordered_map<std::string, int>; using VarIndexMap = std::unordered_map<std::string, int>;
using VarNameList = std::vector<std::string>;
public: public:
template <typename OpType, typename ProtoMakerType> template <typename OpType, typename ProtoMakerType>
...@@ -226,42 +227,51 @@ class OpRegistry { ...@@ -226,42 +227,51 @@ class OpRegistry {
} }
} }
static OperatorPtr CreateOp(const OpDesc& op_desc) { static OperatorPtr CreateOp(const std::string& type,
//! Create a OpPtr by type. const VarNameList& inputs,
std::string op_type = op_desc.type(); const VarNameList& outputs,
OperatorPtr op(creators().at(op_type)()); const AttributeMap& attrs) {
//! Fill op's data member. Not use constructor because it will be noising auto op_create_it = creators().find(type);
//! for Op developer. PADDLE_ENFORCE(op_create_it != creators().end(),
op->type_ = op_desc.type(); "Operator %s cannot be found", type);
// set op's inputs_ from desc.
op->inputs_.reserve((size_t)op_desc.inputs_size());
std::copy(op_desc.inputs().begin(), op_desc.inputs().end(),
std::back_inserter(op->inputs_));
// set op's outputs_ from desc.
op->outputs_.reserve((size_t)op_desc.outputs_size());
std::copy(op_desc.outputs().begin(), op_desc.outputs().end(),
std::back_inserter(op->outputs_));
//! Fill attrs, and validate attrs. auto op = op_create_it->second();
for (auto& attr : op_desc.attrs()) { op->type_ = type;
op->attrs_[attr.name()] = AttrTypeHelper::GetAttrValue(attr); op->inputs_ = inputs;
} op->outputs_ = outputs;
op_checkers().at(op_type).Check(op->attrs_); op->attrs_ = attrs;
op_checkers().at(type).Check(op->attrs_);
//! Convert Temporary variable name to an unique variable name. GenerateTempVariableName(op);
GenerateTempVariableName(op.get());
//! set argument offsets stored in op.
{ {
auto var_index_it = VarIndexMaps().find(op_type); auto var_index_it = VarIndexMaps().find(type);
if (var_index_it != VarIndexMaps().end()) { if (var_index_it != VarIndexMaps().end()) {
op->in_out_idxs_ = var_index_it->second; op->in_out_idxs_ = var_index_it->second;
} }
} }
//! Other op's custom Init for a complex Op. For simple Op, the Init
//! method do nothing.
op->Init(); op->Init();
return op; return OperatorPtr(op);
}
static OperatorPtr CreateOp(const OpDesc& op_desc) {
std::vector<std::string> inputs;
inputs.reserve((size_t)op_desc.inputs_size());
std::copy(op_desc.inputs().begin(), op_desc.inputs().end(),
std::back_inserter(inputs));
std::vector<std::string> outputs;
outputs.reserve((size_t)op_desc.outputs_size());
std::copy(op_desc.outputs().begin(), op_desc.outputs().end(),
std::back_inserter(outputs));
AttributeMap attrs;
for (auto& attr : op_desc.attrs()) {
attrs[attr.name()] = AttrTypeHelper::GetAttrValue(attr);
}
return CreateOp(op_desc.type(), inputs, outputs, attrs);
} }
static std::unordered_map<std::string, OpProto>& protos() { static std::unordered_map<std::string, OpProto>& protos() {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册