提交 446df2b9 编写于 作者: E eclipsycn 提交者: GitHub

Merge pull request #287 from smilejames/develop

update op registry
...@@ -33,21 +33,15 @@ struct OpInfo { ...@@ -33,21 +33,15 @@ struct OpInfo {
} }
}; };
template <typename Dtype>
class OpInfoMap;
template <typename Dtype>
static OpInfoMap<Dtype> *g_op_info_map = nullptr;
template <typename Dtype> template <typename Dtype>
class OpInfoMap { class OpInfoMap {
public: public:
static OpInfoMap &Instance() { static OpInfoMap<Dtype> *Instance() {
LOG(paddle_mobile::kLOG_DEBUG1) << " TODO: fix bug"; static OpInfoMap<Dtype> *s_instance = nullptr;
if (g_op_info_map<Dtype> == nullptr) { if (s_instance == nullptr) {
g_op_info_map<Dtype> = new OpInfoMap(); s_instance = new OpInfoMap();
} }
return *g_op_info_map<Dtype>; return s_instance;
} }
bool Has(const std::string &op_type) const { bool Has(const std::string &op_type) const {
......
...@@ -35,7 +35,7 @@ class OperatorRegistrarRecursive; ...@@ -35,7 +35,7 @@ class OperatorRegistrarRecursive;
template <typename Dtype, typename... ARGS> template <typename Dtype, typename... ARGS>
struct OperatorRegistrar : public Registrar { struct OperatorRegistrar : public Registrar {
explicit OperatorRegistrar(const std::string& op_type) { explicit OperatorRegistrar(const std::string& op_type) {
if (OpInfoMap<Dtype>::Instance().Has(op_type)) { if (OpInfoMap<Dtype>::Instance()->Has(op_type)) {
LOG(paddle_mobile::kLOG_DEBUG1) LOG(paddle_mobile::kLOG_DEBUG1)
<< op_type << " is registered more than once."; << op_type << " is registered more than once.";
return; return;
...@@ -47,7 +47,7 @@ struct OperatorRegistrar : public Registrar { ...@@ -47,7 +47,7 @@ struct OperatorRegistrar : public Registrar {
} }
OpInfo<Dtype> info; OpInfo<Dtype> info;
OperatorRegistrarRecursive<Dtype, 0, false, ARGS...>(op_type, &info); OperatorRegistrarRecursive<Dtype, 0, false, ARGS...>(op_type, &info);
OpInfoMap<Dtype>::Instance().Insert(op_type, info); OpInfoMap<Dtype>::Instance()->Insert(op_type, info);
} }
}; };
...@@ -95,10 +95,10 @@ class OpRegistry { ...@@ -95,10 +95,10 @@ class OpRegistry {
LOG(paddle_mobile::kLOG_DEBUG1) << " output size: " << outputs.size(); LOG(paddle_mobile::kLOG_DEBUG1) << " output size: " << outputs.size();
LOG(paddle_mobile::kLOG_DEBUG1) << " attr size: " << attrs.size(); LOG(paddle_mobile::kLOG_DEBUG1) << " attr size: " << attrs.size();
LOG(paddle_mobile::kLOG_DEBUG1) LOG(paddle_mobile::kLOG_DEBUG1)
<< " OpInfoMap size: " << OpInfoMap<Dtype>::Instance().map().size(); << " OpInfoMap size: " << OpInfoMap<Dtype>::Instance()->map().size();
LOG(paddle_mobile::kLOG_DEBUG1) << " has type: " << type << " " LOG(paddle_mobile::kLOG_DEBUG1) << " has type: " << type << " "
<< OpInfoMap<Dtype>::Instance().Has(type); << OpInfoMap<Dtype>::Instance()->Has(type);
auto& info = OpInfoMap<Dtype>::Instance().Get(type); auto& info = OpInfoMap<Dtype>::Instance()->Get(type);
auto op = info.Creator()(type, inputs, outputs, attrs, scope); auto op = info.Creator()(type, inputs, outputs, attrs, scope);
return std::shared_ptr<OperatorBase<Dtype>>(op); return std::shared_ptr<OperatorBase<Dtype>>(op);
} }
......
...@@ -11,7 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,7 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#define __ARM_NEON true
#include "operators/math/softmax.h" #include "operators/math/softmax.h"
#include "common/types.h" #include "common/types.h"
#if __ARM_NEON #if __ARM_NEON
......
...@@ -17,8 +17,9 @@ limitations under the License. */ ...@@ -17,8 +17,9 @@ limitations under the License. */
#include <string> #include <string>
#include <vector> #include <vector>
#include "./io.h"
#include "common/log.h" #include "common/log.h"
#include "io.h" #include "framework/op_registry.h"
#include "operators/conv_op.h" #include "operators/conv_op.h"
#include "operators/pool_op.h" #include "operators/pool_op.h"
#include "operators/reshape_op.h" #include "operators/reshape_op.h"
...@@ -57,9 +58,12 @@ class Executor4Test : public Executor<DeviceType> { ...@@ -57,9 +58,12 @@ class Executor4Test : public Executor<DeviceType> {
std::vector<std::shared_ptr<OpDesc>> ops = block_desc->Ops(); std::vector<std::shared_ptr<OpDesc>> ops = block_desc->Ops();
for (std::shared_ptr<OpDesc> op : ops) { for (std::shared_ptr<OpDesc> op : ops) {
if (op->Type() == op_type) { if (op->Type() == op_type) {
std::shared_ptr<OpType> op_ptr = std::make_shared<OpType>( std::shared_ptr<paddle_mobile::framework::OperatorBase<DeviceType>>
op->Type(), op->GetInputs(), op->GetOutputs(), op->GetAttrMap(), op_ptr = paddle_mobile::framework::OpRegistry<
this->program_.scope); paddle_mobile::CPU>::CreateOp(op->Type(), op->GetInputs(),
op->GetOutputs(),
op->GetAttrMap(),
this->program_.scope);
this->ops_of_block_[*block_desc.get()].push_back(op_ptr); this->ops_of_block_[*block_desc.get()].push_back(op_ptr);
break; break;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册