提交 446df2b9 编写于 作者: E eclipsycn 提交者: GitHub

Merge pull request #287 from smilejames/develop

update op registry
......@@ -33,21 +33,15 @@ struct OpInfo {
}
};
template <typename Dtype>
class OpInfoMap;
template <typename Dtype>
static OpInfoMap<Dtype> *g_op_info_map = nullptr;
template <typename Dtype>
class OpInfoMap {
public:
static OpInfoMap &Instance() {
LOG(paddle_mobile::kLOG_DEBUG1) << " TODO: fix bug";
if (g_op_info_map<Dtype> == nullptr) {
g_op_info_map<Dtype> = new OpInfoMap();
static OpInfoMap<Dtype> *Instance() {
static OpInfoMap<Dtype> *s_instance = nullptr;
if (s_instance == nullptr) {
s_instance = new OpInfoMap();
}
return *g_op_info_map<Dtype>;
return s_instance;
}
bool Has(const std::string &op_type) const {
......
......@@ -35,7 +35,7 @@ class OperatorRegistrarRecursive;
template <typename Dtype, typename... ARGS>
struct OperatorRegistrar : public Registrar {
explicit OperatorRegistrar(const std::string& op_type) {
if (OpInfoMap<Dtype>::Instance().Has(op_type)) {
if (OpInfoMap<Dtype>::Instance()->Has(op_type)) {
LOG(paddle_mobile::kLOG_DEBUG1)
<< op_type << " is registered more than once.";
return;
......@@ -47,7 +47,7 @@ struct OperatorRegistrar : public Registrar {
}
OpInfo<Dtype> info;
OperatorRegistrarRecursive<Dtype, 0, false, ARGS...>(op_type, &info);
OpInfoMap<Dtype>::Instance().Insert(op_type, info);
OpInfoMap<Dtype>::Instance()->Insert(op_type, info);
}
};
......@@ -95,10 +95,10 @@ class OpRegistry {
LOG(paddle_mobile::kLOG_DEBUG1) << " output size: " << outputs.size();
LOG(paddle_mobile::kLOG_DEBUG1) << " attr size: " << attrs.size();
LOG(paddle_mobile::kLOG_DEBUG1)
<< " OpInfoMap size: " << OpInfoMap<Dtype>::Instance().map().size();
<< " OpInfoMap size: " << OpInfoMap<Dtype>::Instance()->map().size();
LOG(paddle_mobile::kLOG_DEBUG1) << " has type: " << type << " "
<< OpInfoMap<Dtype>::Instance().Has(type);
auto& info = OpInfoMap<Dtype>::Instance().Get(type);
<< OpInfoMap<Dtype>::Instance()->Has(type);
auto& info = OpInfoMap<Dtype>::Instance()->Get(type);
auto op = info.Creator()(type, inputs, outputs, attrs, scope);
return std::shared_ptr<OperatorBase<Dtype>>(op);
}
......
......@@ -11,7 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define __ARM_NEON true
#include "operators/math/softmax.h"
#include "common/types.h"
#if __ARM_NEON
......
......@@ -17,8 +17,9 @@ limitations under the License. */
#include <string>
#include <vector>
#include "./io.h"
#include "common/log.h"
#include "io.h"
#include "framework/op_registry.h"
#include "operators/conv_op.h"
#include "operators/pool_op.h"
#include "operators/reshape_op.h"
......@@ -57,9 +58,12 @@ class Executor4Test : public Executor<DeviceType> {
std::vector<std::shared_ptr<OpDesc>> ops = block_desc->Ops();
for (std::shared_ptr<OpDesc> op : ops) {
if (op->Type() == op_type) {
std::shared_ptr<OpType> op_ptr = std::make_shared<OpType>(
op->Type(), op->GetInputs(), op->GetOutputs(), op->GetAttrMap(),
this->program_.scope);
std::shared_ptr<paddle_mobile::framework::OperatorBase<DeviceType>>
op_ptr = paddle_mobile::framework::OpRegistry<
paddle_mobile::CPU>::CreateOp(op->Type(), op->GetInputs(),
op->GetOutputs(),
op->GetAttrMap(),
this->program_.scope);
this->ops_of_block_[*block_desc.get()].push_back(op_ptr);
break;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册