提交 380de6da 编写于 作者: S superjomn

make kernel works

上级 97149f31
cc_library(executor_lite SRCS executor.cc)
cc_library(op_lite SRCS op_lite.cc)
cc_library(memory_lite SRCS memory.cc)
cc_library(tensor_lite SRCS tensor.cc DEPS memory_lite)
cc_library(variable_lite SRCS variable.cc)
cc_library(op_registry_lite SRCS op_registry.cc)
cc_library(scope_lite SRCS scope.cc)
cc_library(op_lite SRCS op_lite.cc DEPS scope_lite op_registry_lite)
cc_test(test_scope_lite SRCS scope_test.cc DEPS scope_lite)
cc_test(test_kernel_lite SRCS kernel_test.cc DEPS target_wrapper_x86)
cc_test(test_op_lite SRCS op_lite_test.cc DEPS op_lite)
cc_test(test_tensor_lite SRCS tensor_test.cc)
......@@ -15,7 +15,7 @@
#pragma once
#include <memory>
#include <vector>
#include "target_wrapper.h"
#include "paddle/fluid/lite/core/target_wrapper.h"
namespace paddle {
namespace lite {
......
......@@ -65,6 +65,8 @@ class OpKernel : public KernelBase {
public:
virtual void Run() { CHECK(false) << "Not Implemented"; }
void Touch() {}
OpKernel() = default;
virtual ~OpKernel() = default;
......
......@@ -31,7 +31,6 @@ class SomeKernel : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
TargetType target() const override { return TARGET(kHost); }
PrecisionType precision() const override { return PRECISION(kFloat); }
};
TEST(Kernel, test) {
......
......@@ -60,11 +60,20 @@ class OpLite : public Registry {
struct Place {
TargetType target{TARGET(kHost)};
PrecisionType precision{PRECISION(kFloat)};
Place(TargetType target, PrecisionType precision)
: target(target), precision(precision) {}
};
OpLite() = default;
OpLite(std::unique_ptr<OpContext> &&x) : op_context_(std::move(x)) {}
OpLite(const std::string &type) : op_type_(type) {}
OpLite(std::unique_ptr<OpContext> &&x, const std::vector<Place> &valid_places)
: op_context_(std::move(x)), valid_places_(valid_places) {}
void SetValidPlaces(const std::vector<Place> &places) {
valid_places_ = places;
}
const std::vector<Place> &valid_places() const { return valid_places_; }
// Check the shape.
virtual bool CheckShape() const { return true; }
// Inference the outputs' shape.
......@@ -79,20 +88,27 @@ class OpLite : public Registry {
RecordOutputEvents();
return true;
}
// Build the operator, attach it with the runtime environment.
virtual bool Build(const framework::OpDesc &opdesc, lite::Scope *scope) = 0;
// Attach it with the runtime environment.
virtual bool Attach(const framework::OpDesc &opdesc, lite::Scope *scope) = 0;
// Human-readable information.
virtual std::string DebugString() const = 0;
const Place &kernel_place() const { return kernel_place_; }
protected:
void PickKernel(const std::vector<Place> &valid_places,
KernelStrategy kernel_strategy = KernelStrategy::kStatic);
virtual ~OpLite() = default;
protected:
// Specify the kernel to run by default. This will specify the value of
// `kernel_place_`.
virtual void StaticPickKernel(const std::vector<Place> &valid_targets) = 0;
virtual void StaticPickKernel(const std::vector<Place> &valid_targets) {
auto kernels = CreateKernels(valid_targets);
kernel_ = std::move(kernels.front());
}
// Wait until all the inputs' events are ready.
void SyncInputEvents() {}
......@@ -105,13 +121,12 @@ class OpLite : public Registry {
std::vector<std::unique_ptr<KernelBase>> CreateKernels(
const std::vector<Place> &places);
virtual ~OpLite() = default;
protected:
std::unique_ptr<OpContext> op_context_;
Place kernel_place_;
std::unique_ptr<KernelBase> kernel_;
std::string op_type_;
std::vector<Place> valid_places_;
Place kernel_place_{TARGET(kHost), PRECISION(kFloat)};
};
} // namespace lite
......
#include <gtest/gtest.h>
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/op_lite.h"
#include <gtest/gtest.h>
namespace paddle {
namespace lite {
TEST(OpLite, test) {
}
TEST(OpLite, test) {}
} // namespace lite
} // namespace paddle
......@@ -12,4 +12,55 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/op_registry.h"
\ No newline at end of file
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
std::unique_ptr<KernelBase> KernelRegistry::Create(const std::string &op_type,
TargetType target,
PrecisionType precision) {
#define CREATE_KERNEL(target__) \
switch (precision) { \
case PRECISION(kFloat): \
return Create<TARGET(target__), PRECISION(kFloat)>(op_type); \
default: \
CHECK(false) << "not supported kernel place yet"; \
}
switch (target) {
case TARGET(kHost): {
CREATE_KERNEL(kHost);
} break;
case TARGET(kX86): {
CREATE_KERNEL(kX86);
} break;
case TARGET(kCUDA): {
CREATE_KERNEL(kCUDA);
} break;
default:
CHECK(false) << "not supported kernel place";
}
#undef CREATE_KERNEL
}
KernelRegistry::KernelRegistry() {
#define INIT_FOR(target__, precision__) \
registries_[KernelRegistry::GetKernelOffset<TARGET(target__), \
PRECISION(precision__)>()] \
.set<KernelRegistryForTarget<TARGET(target__), PRECISION(precision__)> \
*>(&KernelRegistryForTarget<TARGET(target__), \
PRECISION(precision__)>::Global());
// Currently, just register 2 kernel targets.
INIT_FOR(kHost, kFloat);
#undef INIT_FOR
}
KernelRegistry &KernelRegistry::Global() {
static auto *x = new KernelRegistry;
return *x;
}
} // namespace lite
} // namespace paddle
\ No newline at end of file
......@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <unordered_map>
......@@ -62,28 +63,9 @@ class KernelRegistry final {
KernelRegistryForTarget<TARGET(kHost), PRECISION(kFloat)> * //
>;
KernelRegistry() {
/*
using kernel_target_t =
KernelRegistryForTarget<TARGET(kCUDA), PRECISION(kFloat)>;
registries_[0].set<kernel_target_t *>(
&KernelRegistryForTarget<TARGET(kCUDA), PRECISION(kFloat)>::Global());
*/
#define INIT_FOR(target__, precision__) \
registries_[KernelRegistry::GetKernelOffset<TARGET(target__), \
PRECISION(precision__)>()] \
.set<KernelRegistryForTarget<TARGET(target__), PRECISION(precision__)> \
*>(&KernelRegistryForTarget<TARGET(target__), \
PRECISION(precision__)>::Global());
// Currently, just register 2 kernel targets.
INIT_FOR(kHost, kFloat);
#undef INIT_FOR
}
KernelRegistry();
static KernelRegistry &Global() {
static auto *x = new KernelRegistry;
return *x;
}
static KernelRegistry &Global();
template <TargetType Target, PrecisionType Precision>
void Register(const std::string &name,
......@@ -105,31 +87,7 @@ registries_[0].set<kernel_target_t *>(
std::unique_ptr<KernelBase> Create(const std::string &op_type,
TargetType target,
PrecisionType precision) {
#define CREATE_KERNEL(target__) \
switch (precision) { \
case PRECISION(kFloat): \
return Create<TARGET(target__), PRECISION(kFloat)>(op_type); \
default: \
CHECK(false) << "not supported kernel place yet"; \
}
switch (target) {
case TARGET(kHost): {
CREATE_KERNEL(kHost);
} break;
case TARGET(kX86): {
CREATE_KERNEL(kX86);
} break;
case TARGET(kCUDA): {
CREATE_KERNEL(kCUDA);
} break;
default:
CHECK(false) << "not supported kernel place";
}
#undef CREATE_KERNEL
}
PrecisionType precision);
// Get a kernel registry offset in all the registries.
template <TargetType Target, PrecisionType Precision>
......@@ -137,8 +95,21 @@ registries_[0].set<kernel_target_t *>(
return kNumTargets * static_cast<int>(Target) + static_cast<int>(Precision);
}
std::string DebugString() const {
std::stringstream ss;
ss << "KernelCreator<host, float>:" << std::endl;
ss << registries_[GetKernelOffset<TARGET(kHost), PRECISION(kFloat)>()]
.get<
KernelRegistryForTarget<TARGET(kHost), PRECISION(kFloat)> *>()
->DebugString();
ss << std::endl;
return ss.str();
}
private:
std::array<any_kernel_registor_t, kNumTargets * kNumPrecisions> registries_;
mutable std::array<any_kernel_registor_t, kNumTargets * kNumPrecisions>
registries_;
};
template <TargetType target, PrecisionType precision, typename KernelType>
......@@ -146,6 +117,8 @@ class KernelRegistor : public lite::Registor<KernelType> {
public:
KernelRegistor(const std::string op_type)
: Registor<KernelType>([&] {
LOG(INFO) << "Register kernel " << op_type << " for "
<< TargetToStr(target) << " " << PrecisionToStr(precision);
KernelRegistry::Global().Register<target, precision>(
op_type, [&]() -> std::unique_ptr<KernelType> {
return std::unique_ptr<KernelType>(new KernelType);
......@@ -169,18 +142,27 @@ class KernelRegistor : public lite::Registor<KernelType> {
// Kernel registry
#define LITE_KERNEL_REGISTER(op_type__, target__, precision__) \
op_type__##target__##precision__##__registor__
op_type__##__##target__##__##precision__##__registor__
#define LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__) \
op_type__##target__##precision__##__registor__instance__
op_type__##__##target__##__##precision__##__registor__instance__
#define LITE_KERNEL_REGISTER_FAKE(op_type__, target__, precision__) \
LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__)##__fake__
LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__)
#define REGISTER_LITE_KERNEL(op_type__, target__, precision__, KernelClass) \
static paddle::lite::KernelRegistor<TARGET(target__), \
PRECISION(precision__), KernelClass> \
LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, \
precision__)(#op_type__); \
static KernelClass LITE_KERNEL_INSTANCE(op_type__, target__, precision__); \
int touch_##op_type__##target__##precision__() { \
LITE_KERNEL_INSTANCE(op_type__, target__, precision__).Touch(); \
return 0; \
}
#define REGISTER_LITE_KERNEL(op_type__, target__, precision__, KernelClass) \
static paddle::lite::KernelRegistor<TARGET(target__), \
PRECISION(precision__), KernelClass> \
LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, \
precision__)(#op_type__);
#define USE_LITE_KERNEL(op_type__, target__, precision__) \
extern int touch_##op_type__##target__##precision__(); \
int LITE_KERNEL_REGISTER_FAKE(op_type__, target__, precision__) \
__attribute__((unused)) = touch_##op_type__##target__##precision__();
#define USE_LITE_KERNEL(op_type__, target__, precision__) \
int LITE_KERNEL_REGISTER_FAKE(op_type__, target__, precision__)((unused)) = \
LITE_KERNEL_REGISTER(op_type__, target__, precision__).Touch();
#define LITE_KERNEL_INSTANCE(op_type__, target__, precision__) \
op_type__##target__##precision__
......@@ -43,6 +43,16 @@ enum class PrecisionType { kFloat = 0, kInt8, kLastAsPlaceHolder };
constexpr int kNumPrecisions =
PRECISION_VAL(kLastAsPlaceHolder) - PRECISION_VAL(kFloat);
static const std::string target2string[] = {"host", "x86", "cuda"};
static const std::string& TargetToStr(TargetType target) {
return target2string[static_cast<int>(target)];
}
static const std::string precision2string[] = {"float, int8"};
static const std::string& PrecisionToStr(PrecisionType precision) {
return precision2string[static_cast<int>(precision)];
}
// Event sync for multi-stream devices like CUDA and OpenCL.
// For the devices without support of stream, leave it empty.
template <TargetType Target>
......@@ -76,8 +86,8 @@ class TargetWrapper {
static void StreamSync(const stream_t& stream) {}
static void* Malloc(size_t size) { return nullptr; }
static void Free(void* ptr) {}
static void* Malloc(size_t size) { return new char[size]; }
static void Free(void* ptr) { delete[] static_cast<char*>(ptr); }
static void MemcpySync(void* dst, void* src, size_t size, IoDirection dir) {}
static void MemcpyAsync(void* dst, void* src, size_t size,
......
......@@ -47,6 +47,10 @@ static int product(const DDim& dims) {
[](int a, int b) { return a * b; });
}
static int product(DDim::const_iterator begin, DDim::const_iterator end) {
return std::accumulate(begin, end, 1, [](int a, int b) { return a * b; });
}
static DDim flatten_to_2d(const DDim& dims, int col) {
return DDim({product(SliceDims(dims, 0, col)),
product(SliceDims(dims, col, dims.size()))});
......@@ -73,7 +77,7 @@ class Tensor {
template <typename T>
T* mutable_data() {
buffer_.ResetLazy(target_, product(dims_));
buffer_.ResetLazy(target_, product(dims_) * sizeof(T));
return static_cast<T*>(buffer_.data());
}
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/tensor.h"
#include <gtest/gtest.h>
namespace paddle {
namespace lite {
TEST(tensor, test) {
Tensor tensor;
tensor.Resize({1, 8});
for (int i = 0; i < 8; i++) {
tensor.mutable_data<int>()[i] = i;
}
}
} // namespace lite
} // namespace paddle
......@@ -28,7 +28,7 @@ class Variable {
template <typename T>
T* GetMutable() {
blob_.set<T>();
if (!blob_.is<T>()) blob_.set<T>();
return &blob_.get<T>();
}
......
cc_library(fc_compute_host SRCS fc_compute.cc DEPS tensor_lite)
cc_library(relu_compute_host SRCS relu_compute.cc DEPS tensor_lite)
cc_test(test_fc_compute SRCS fc_compute_test.cc DEPS fc_compute_host fc_op_lite)
......@@ -26,12 +26,16 @@ void FcCompute::Run() {
using matrix_t = Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic>;
using matrix_map_t = Eigen::Map<matrix_t>;
auto& param = this->param<param_t>();
auto& param = this->param<operators::FcParam>();
CHECK_EQ(param.in_mat_dims.size(), 2UL);
CHECK_GE(param.input->dims().size(), 2UL);
CHECK_EQ(param.output->dims().size(), 2UL);
Eigen::Map<const matrix_t> input(param.input->data<float>(),
param.in_mat_dims[0], param.in_mat_dims[1]);
Eigen::Map<const matrix_t> input(
param.input->data<float>(),
product(param.input->dims().begin(),
param.input->dims().begin() + param.in_num_col_dims),
product(param.input->dims().begin() + param.in_num_col_dims,
param.input->dims().end()));
Eigen::Map<const matrix_t> weight(param.w->data<float>(), param.w->dims()[0],
param.w->dims()[1]);
matrix_map_t output(param.output->mutable_data<float>(),
......@@ -47,6 +51,10 @@ void FcCompute::Run() {
}
}
TargetType FcCompute::target() const { return TARGET(kHost); }
PrecisionType FcCompute::precision() const { return PRECISION(kFloat); }
} // namespace host
} // namespace kernels
} // namespace lite
......
......@@ -14,6 +14,7 @@
#pragma once
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/kernels/fc_compute.h"
#include "paddle/fluid/lite/operators/fc_op.h"
namespace paddle {
......@@ -21,12 +22,15 @@ namespace lite {
namespace kernels {
namespace host {
class FcCompute final : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
class FcCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
public:
using param_t = operators::FcParam;
void Run() override;
TargetType target() const override;
PrecisionType precision() const override;
virtual ~FcCompute() = default;
};
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/kernels/host/fc_compute.h"
#include <gtest/gtest.h>
#include <vector>
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace host {
TEST(fc_host, init) {
FcCompute fc;
ASSERT_EQ(fc.precision(), PRECISION(kFloat));
ASSERT_EQ(fc.target(), TARGET(kHost));
}
TEST(fc_host, algorithm) {
using matrix_t = Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic>;
using matrix_map_t = Eigen::Map<matrix_t>;
// dim 10, 20
std::vector<float> input(10 * 20);
std::vector<float> w(20 * 20);
std::vector<float> output(10 * 20);
Eigen::Map<const matrix_t> input_mat(input.data(), 10, 20);
Eigen::Map<const matrix_t> weight_mat(w.data(), 20, 20);
matrix_map_t output_mat(output.data(), 10, 20);
output_mat = weight_mat.transpose() * input_mat;
}
TEST(fc_host, compute) {
FcCompute fc;
operators::FcParam param;
Tensor x;
Tensor w;
Tensor bias;
Tensor output;
x.Resize({1, 10, 20});
w.Resize({20, 20});
bias.Resize({1, 10});
output.Resize({10, 20});
auto* x_data = x.mutable_data<float>();
auto* w_data = w.mutable_data<float>();
auto* bias_data = bias.mutable_data<float>();
auto* output_data = output.mutable_data<float>();
for (int i = 0; i < 10 * 20; i++) x_data[i] = i;
for (int i = 0; i < 20 * 20; i++) w_data[i] = i;
for (int i = 0; i < 10; i++) bias_data[i] = i;
for (int i = 0; i < 10 * 20; i++) output_data[i] = 0;
param.in_num_col_dims = 2;
param.input = &x;
param.w = &w;
param.bias = &bias;
param.output = &output;
param.in_mat_dims = x.dims();
fc.SetParam(param);
fc.Run();
LOG(INFO) << "x";
for (int i = 0; i < 10 * 20; i++) LOG(INFO) << x_data[i];
LOG(INFO) << "output:";
for (int i = 0; i < 10 * 20; i++) LOG(INFO) << output.data<float>()[i];
}
TEST(fc, retrive_op) {
auto fc =
KernelRegistry::Global().Create<TARGET(kHost), PRECISION(kFloat)>("fc");
ASSERT_TRUE(fc.get());
}
} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(fc, kHost, kFloat);
......@@ -74,7 +74,7 @@ class LoDTensorArrayDesc {
class VarType {
public:
framework::proto::VarType::Type type;
any desc;
variant<LoDTensorDesc, TensorDesc> desc;
void Parse(const framework::proto::VarType& proto);
};
......@@ -95,7 +95,7 @@ class OpDesc {
std::string op_type;
std::map<std::string, std::vector<std::string>> inputs;
std::map<std::string, std::vector<std::string>> outputs;
std::map<std::string, any> attrs;
std::map<std::string, variant<int, std::string>> attrs;
};
class BlockDesc {
......@@ -109,5 +109,10 @@ class BlockDesc {
std::vector<OpDesc> ops;
};
class ProgramDesc {
public:
void Parse(const framework::proto::ProgramDesc& desc);
};
} // namespace lite
} // namespace paddle
cc_library(fc_op_lite SRCS fc_op.cc DEPS op_lite op_params_lite tensor_lite)
cc_library(fc_op_lite SRCS fc_op.cc DEPS op_lite op_params_lite tensor_lite proto_desc)
cc_library(relu_op_lite SRCS relu_op.cc DEPS op_lite)
cc_library(op_params_lite SRCS op_params.cc DEPS tensor_lite)
cc_test(test_fc_op_lite SRCS fc_op_test.cc DEPS fc_op_lite)
cc_test(test_fc_op_lite SRCS fc_op_test.cc DEPS fc_op_lite fc_compute_host)
......@@ -29,32 +29,43 @@ class FcOpLite : public OpLite {
public:
FcOpLite() {}
FcOpLite(const std::string &type) : OpLite(type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool Run() override { return false; }
bool Run() override {
CHECK(kernel_);
kernel_->Run();
return true;
}
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool Build(const framework::OpDesc& op_desc, lite::Scope* scope) override {
bool Attach(const framework::OpDesc &op_desc, lite::Scope *scope) override {
auto input = op_desc.Input("Input").front();
auto W = op_desc.Input("W").front();
auto bias = op_desc.Input("bias").front();
auto out = op_desc.Output("bias").front();
auto bias = op_desc.Input("Bias").front();
auto out = op_desc.Output("Out").front();
param_.input = scope->FindVar(input)->GetMutable<Tensor>();
param_.w = scope->FindVar(W)->GetMutable<Tensor>();
param_.bias = scope->FindVar(bias)->GetMutable<Tensor>();
param_.output = scope->FindVar(out)->GetMutable<Tensor>();
param_.in_num_col_dims =
boost::any_cast<int>(op_desc.GetAttr("in_num_col_dims"));
boost::get<int>(op_desc.GetAttr("in_num_col_dims"));
kernel_->SetParam(param_);
return true;
}
std::string DebugString() const override { return "fc"; }
void StaticPickKernel(const std::vector<Place>& valid_targets) override {}
void StaticPickKernel(const std::vector<Place> &valid_targets) override {
auto kernels = CreateKernels(valid_targets);
kernel_ = std::move(kernels.front());
}
private:
mutable FcParam param_;
......
......@@ -33,7 +33,7 @@ bool ReluOp::InferShape() const {
bool ReluOp::Run() { return false; }
bool ReluOp::Build(const framework::OpDesc &opdesc, framework::Scope *scope) {
bool ReluOp::Attach(const framework::OpDesc &opdesc, framework::Scope *scope) {
return false;
}
......
......@@ -37,7 +37,8 @@ class ReluOp : public OpLite {
bool Run() override;
bool Build(const framework::OpDesc& opdesc, framework::Scope* scope) override;
bool Attach(const framework::OpDesc& opdesc,
framework::Scope* scope) override;
std::string DebugString() const override { return "tanh"; }
......
......@@ -13,7 +13,9 @@
// limitations under the License.
#pragma once
#include <iostream>
#include <memory>
#include <sstream>
#include <unordered_map>
namespace paddle {
......@@ -52,10 +54,18 @@ class Factory {
item_ptr_t Create(const std::string& op_type) const {
auto it = creators_.find(op_type);
CHECK(it != creators_.end());
CHECK(it != creators_.end()) << "no item called " << op_type;
return it->second();
}
std::string DebugString() const {
std::stringstream ss;
for (const auto& item : creators_) {
ss << " - " << item.first << std::endl;
}
return ss.str();
}
protected:
std::unordered_map<std::string, creator_t> creators_;
};
......
......@@ -93,7 +93,7 @@ struct variant {
return *this;
}
template <typename T>
void is() {
bool is() {
return (type_id == typeid(T).hash_code());
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册