提交 a256d413 编写于 作者: Z zhaojiaying01

update op registry

上级 fbd23d7f
......@@ -33,21 +33,15 @@ struct OpInfo {
}
};
template <typename Dtype>
class OpInfoMap;
template <typename Dtype>
static OpInfoMap<Dtype> *g_op_info_map = nullptr;
template <typename Dtype>
class OpInfoMap {
public:
static OpInfoMap &Instance() {
LOG(paddle_mobile::kLOG_DEBUG1) << " TODO: fix bug";
if (g_op_info_map<Dtype> == nullptr) {
g_op_info_map<Dtype> = new OpInfoMap();
static OpInfoMap<Dtype> *Instance() {
static OpInfoMap<Dtype> *s_instance = nullptr;
if (s_instance == nullptr) {
s_instance = new OpInfoMap();
}
return *g_op_info_map<Dtype>;
return s_instance;
}
bool Has(const std::string &op_type) const {
......
......@@ -35,7 +35,7 @@ class OperatorRegistrarRecursive;
template <typename Dtype, typename... ARGS>
struct OperatorRegistrar : public Registrar {
explicit OperatorRegistrar(const std::string& op_type) {
if (OpInfoMap<Dtype>::Instance().Has(op_type)) {
if (OpInfoMap<Dtype>::Instance()->Has(op_type)) {
LOG(paddle_mobile::kLOG_DEBUG1)
<< op_type << " is registered more than once.";
return;
......@@ -47,7 +47,7 @@ struct OperatorRegistrar : public Registrar {
}
OpInfo<Dtype> info;
OperatorRegistrarRecursive<Dtype, 0, false, ARGS...>(op_type, &info);
OpInfoMap<Dtype>::Instance().Insert(op_type, info);
OpInfoMap<Dtype>::Instance()->Insert(op_type, info);
}
};
......@@ -95,10 +95,10 @@ class OpRegistry {
LOG(paddle_mobile::kLOG_DEBUG1) << " output size: " << outputs.size();
LOG(paddle_mobile::kLOG_DEBUG1) << " attr size: " << attrs.size();
LOG(paddle_mobile::kLOG_DEBUG1)
<< " OpInfoMap size: " << OpInfoMap<Dtype>::Instance().map().size();
<< " OpInfoMap size: " << OpInfoMap<Dtype>::Instance()->map().size();
LOG(paddle_mobile::kLOG_DEBUG1) << " has type: " << type << " "
<< OpInfoMap<Dtype>::Instance().Has(type);
auto& info = OpInfoMap<Dtype>::Instance().Get(type);
<< OpInfoMap<Dtype>::Instance()->Has(type);
auto& info = OpInfoMap<Dtype>::Instance()->Get(type);
auto op = info.Creator()(type, inputs, outputs, attrs, scope);
return std::shared_ptr<OperatorBase<Dtype>>(op);
}
......
......@@ -11,7 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define __ARM_NEON true
#include "operators/math/softmax.h"
#include "common/types.h"
#if __ARM_NEON
......
......@@ -17,8 +17,9 @@ limitations under the License. */
#include <string>
#include <vector>
#include "./io.h"
#include "common/log.h"
#include "io.h"
#include "framework/op_registry.h"
#include "operators/conv_op.h"
#include "operators/pool_op.h"
#include "operators/reshape_op.h"
......@@ -57,9 +58,12 @@ class Executor4Test : public Executor<DeviceType> {
std::vector<std::shared_ptr<OpDesc>> ops = block_desc->Ops();
for (std::shared_ptr<OpDesc> op : ops) {
if (op->Type() == op_type) {
std::shared_ptr<OpType> op_ptr = std::make_shared<OpType>(
op->Type(), op->GetInputs(), op->GetOutputs(), op->GetAttrMap(),
this->program_.scope);
std::shared_ptr<paddle_mobile::framework::OperatorBase<DeviceType>>
op_ptr = paddle_mobile::framework::OpRegistry<
paddle_mobile::CPU>::CreateOp(op->Type(), op->GetInputs(),
op->GetOutputs(),
op->GetAttrMap(),
this->program_.scope);
this->ops_of_block_[*block_desc.get()].push_back(op_ptr);
break;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册