提交 a256d413 编写于 作者: Z zhaojiaying01

update op registry

上级 fbd23d7f
...@@ -33,21 +33,15 @@ struct OpInfo { ...@@ -33,21 +33,15 @@ struct OpInfo {
} }
}; };
template <typename Dtype>
class OpInfoMap;
template <typename Dtype>
static OpInfoMap<Dtype> *g_op_info_map = nullptr;
template <typename Dtype> template <typename Dtype>
class OpInfoMap { class OpInfoMap {
public: public:
static OpInfoMap &Instance() { static OpInfoMap<Dtype> *Instance() {
LOG(paddle_mobile::kLOG_DEBUG1) << " TODO: fix bug"; static OpInfoMap<Dtype> *s_instance = nullptr;
if (g_op_info_map<Dtype> == nullptr) { if (s_instance == nullptr) {
g_op_info_map<Dtype> = new OpInfoMap(); s_instance = new OpInfoMap();
} }
return *g_op_info_map<Dtype>; return s_instance;
} }
bool Has(const std::string &op_type) const { bool Has(const std::string &op_type) const {
......
...@@ -35,7 +35,7 @@ class OperatorRegistrarRecursive; ...@@ -35,7 +35,7 @@ class OperatorRegistrarRecursive;
template <typename Dtype, typename... ARGS> template <typename Dtype, typename... ARGS>
struct OperatorRegistrar : public Registrar { struct OperatorRegistrar : public Registrar {
explicit OperatorRegistrar(const std::string& op_type) { explicit OperatorRegistrar(const std::string& op_type) {
if (OpInfoMap<Dtype>::Instance().Has(op_type)) { if (OpInfoMap<Dtype>::Instance()->Has(op_type)) {
LOG(paddle_mobile::kLOG_DEBUG1) LOG(paddle_mobile::kLOG_DEBUG1)
<< op_type << " is registered more than once."; << op_type << " is registered more than once.";
return; return;
...@@ -47,7 +47,7 @@ struct OperatorRegistrar : public Registrar { ...@@ -47,7 +47,7 @@ struct OperatorRegistrar : public Registrar {
} }
OpInfo<Dtype> info; OpInfo<Dtype> info;
OperatorRegistrarRecursive<Dtype, 0, false, ARGS...>(op_type, &info); OperatorRegistrarRecursive<Dtype, 0, false, ARGS...>(op_type, &info);
OpInfoMap<Dtype>::Instance().Insert(op_type, info); OpInfoMap<Dtype>::Instance()->Insert(op_type, info);
} }
}; };
...@@ -95,10 +95,10 @@ class OpRegistry { ...@@ -95,10 +95,10 @@ class OpRegistry {
LOG(paddle_mobile::kLOG_DEBUG1) << " output size: " << outputs.size(); LOG(paddle_mobile::kLOG_DEBUG1) << " output size: " << outputs.size();
LOG(paddle_mobile::kLOG_DEBUG1) << " attr size: " << attrs.size(); LOG(paddle_mobile::kLOG_DEBUG1) << " attr size: " << attrs.size();
LOG(paddle_mobile::kLOG_DEBUG1) LOG(paddle_mobile::kLOG_DEBUG1)
<< " OpInfoMap size: " << OpInfoMap<Dtype>::Instance().map().size(); << " OpInfoMap size: " << OpInfoMap<Dtype>::Instance()->map().size();
LOG(paddle_mobile::kLOG_DEBUG1) << " has type: " << type << " " LOG(paddle_mobile::kLOG_DEBUG1) << " has type: " << type << " "
<< OpInfoMap<Dtype>::Instance().Has(type); << OpInfoMap<Dtype>::Instance()->Has(type);
auto& info = OpInfoMap<Dtype>::Instance().Get(type); auto& info = OpInfoMap<Dtype>::Instance()->Get(type);
auto op = info.Creator()(type, inputs, outputs, attrs, scope); auto op = info.Creator()(type, inputs, outputs, attrs, scope);
return std::shared_ptr<OperatorBase<Dtype>>(op); return std::shared_ptr<OperatorBase<Dtype>>(op);
} }
......
...@@ -11,7 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,7 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#define __ARM_NEON true
#include "operators/math/softmax.h" #include "operators/math/softmax.h"
#include "common/types.h" #include "common/types.h"
#if __ARM_NEON #if __ARM_NEON
......
...@@ -17,8 +17,9 @@ limitations under the License. */ ...@@ -17,8 +17,9 @@ limitations under the License. */
#include <string> #include <string>
#include <vector> #include <vector>
#include "./io.h"
#include "common/log.h" #include "common/log.h"
#include "io.h" #include "framework/op_registry.h"
#include "operators/conv_op.h" #include "operators/conv_op.h"
#include "operators/pool_op.h" #include "operators/pool_op.h"
#include "operators/reshape_op.h" #include "operators/reshape_op.h"
...@@ -57,9 +58,12 @@ class Executor4Test : public Executor<DeviceType> { ...@@ -57,9 +58,12 @@ class Executor4Test : public Executor<DeviceType> {
std::vector<std::shared_ptr<OpDesc>> ops = block_desc->Ops(); std::vector<std::shared_ptr<OpDesc>> ops = block_desc->Ops();
for (std::shared_ptr<OpDesc> op : ops) { for (std::shared_ptr<OpDesc> op : ops) {
if (op->Type() == op_type) { if (op->Type() == op_type) {
std::shared_ptr<OpType> op_ptr = std::make_shared<OpType>( std::shared_ptr<paddle_mobile::framework::OperatorBase<DeviceType>>
op->Type(), op->GetInputs(), op->GetOutputs(), op->GetAttrMap(), op_ptr = paddle_mobile::framework::OpRegistry<
this->program_.scope); paddle_mobile::CPU>::CreateOp(op->Type(), op->GetInputs(),
op->GetOutputs(),
op->GetAttrMap(),
this->program_.scope);
this->ops_of_block_[*block_desc.get()].push_back(op_ptr); this->ops_of_block_[*block_desc.get()].push_back(op_ptr);
break; break;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册