提交 97149f31 编写于 作者: S superjomn

update

上级 4eedd20f
......@@ -8,3 +8,4 @@ cc_library(scope_lite SRCS scope.cc)
cc_test(test_scope_lite SRCS scope_test.cc DEPS scope_lite)
cc_test(test_kernel_lite SRCS kernel_test.cc DEPS target_wrapper_x86)
cc_test(test_op_lite SRCS op_lite_test.cc DEPS op_lite)
......@@ -26,6 +26,8 @@
namespace paddle {
namespace lite {
// An base with virtual functions to unify all the kernel implementation on
// different targets.
class KernelBase {
public:
virtual void Run() = 0;
......@@ -45,8 +47,12 @@ class KernelBase {
return param_.get<Param>();
}
protected:
virtual TargetType target() const = 0;
virtual PrecisionType precision() const = 0;
virtual ~KernelBase() = default;
protected:
core::any_context_t context_;
mutable operators::param_t param_;
};
......
......@@ -28,6 +28,10 @@ class SomeKernel : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
LOG(INFO) << param<operators::FcParam>().in_num_col_dims;
test_code = param<operators::FcParam>().in_num_col_dims;
}
TargetType target() const override { return TARGET(kHost); }
PrecisionType precision() const override { return PRECISION(kFloat); }
};
TEST(Kernel, test) {
......
......@@ -28,9 +28,6 @@ static void* TargetMalloc(TargetType target, size_t size) {
case static_cast<int>(TargetType::kCUDA):
data = TargetWrapper<TARGET(kCUDA)>::Malloc(size);
break;
case static_cast<int>(TargetType::kARM):
data = TargetWrapper<TARGET(kARM)>::Malloc(size);
break;
case static_cast<int>(TargetType::kHost):
data = TargetWrapper<TARGET(kHost)>::Malloc(size);
break;
......@@ -48,9 +45,6 @@ static void TargetFree(TargetType target, void* data) {
case static_cast<int>(TargetType::kCUDA):
TargetWrapper<TARGET(kX86)>::Free(data);
break;
case static_cast<int>(TargetType::kARM):
TargetWrapper<TARGET(kX86)>::Free(data);
break;
default:
LOG(FATAL) << "Unknown type";
}
......
......@@ -12,4 +12,36 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/op_lite.h"
#include "op_lite.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
std::vector<std::unique_ptr<KernelBase>> OpLite::CreateKernels(
const std::vector<OpLite::Place> &places) {
std::vector<std::unique_ptr<KernelBase>> kernels;
CHECK(!op_type_.empty()) << "op_type_ should be set first";
for (auto place : places) {
kernels.emplace_back(KernelRegistry::Global().Create(op_type_, place.target,
place.precision));
}
return kernels;
}
void OpLite::PickKernel(const std::vector<OpLite::Place> &valid_places,
OpLite::KernelStrategy kernel_strategy) {
switch (kernel_strategy) {
case KernelStrategy::kStatic:
StaticPickKernel(valid_places);
break;
default:
LOG(FATAL) << "unsupported kernel strategy";
}
}
} // namespace lite
} // namespace paddle
......@@ -70,7 +70,15 @@ class OpLite : public Registry {
// Inference the outputs' shape.
virtual bool InferShape() const { return true; }
// Run this operator.
virtual bool Run() = 0;
virtual bool Run() {
CHECK(kernel_);
SyncInputEvents();
kernel_->Run();
RecordOutputEvents();
return true;
}
// Build the operator, attach it with the runtime environment.
virtual bool Build(const framework::OpDesc &opdesc, lite::Scope *scope) = 0;
// Human-readable information.
......@@ -79,21 +87,31 @@ class OpLite : public Registry {
const Place &kernel_place() const { return kernel_place_; }
protected:
void PickKernel(const std::vector<Place> &valid_places,
KernelStrategy kernel_strategy = KernelStrategy::kStatic);
// Specify the kernel to run by default. This will specify the value of
// `kernel_place_`.
virtual void StaticPickKernel(const std::vector<Place> &valid_targets) = 0;
void PickKernel(const std::vector<Place> &valid_places,
KernelStrategy kernel_strategy = KernelStrategy::kStatic);
// Wait until all the inputs' events are ready.
void SyncInputEvents() {}
// Record the output events, and that will tell all the dependent operators
// some inputs are ready.
void RecordOutputEvents() {}
// Create all the kernels for the valid targets.
void CreateKernels();
std::vector<std::unique_ptr<KernelBase>> CreateKernels(
const std::vector<Place> &places);
virtual ~OpLite() = default;
protected:
std::unique_ptr<OpContext> op_context_;
Place kernel_place_;
std::unique_ptr<KernelBase> kernel_;
std::string op_type_;
};
} // namespace lite
......
#include <gtest/gtest.h>
#include "paddle/fluid/lite/core/op_lite.h"
namespace paddle {
namespace lite {
TEST(OpLite, test) {
}
} // namespace lite
} // namespace paddle
......@@ -59,7 +59,6 @@ class KernelRegistry final {
KernelRegistryForTarget<TARGET(kCUDA), PRECISION(kInt8)> *, //
KernelRegistryForTarget<TARGET(kX86), PRECISION(kFloat)> *, //
KernelRegistryForTarget<TARGET(kX86), PRECISION(kInt8)> *, //
KernelRegistryForTarget<TARGET(kARM), PRECISION(kFloat)> *, //
KernelRegistryForTarget<TARGET(kHost), PRECISION(kFloat)> * //
>;
......@@ -77,7 +76,6 @@ registries_[0].set<kernel_target_t *>(
*>(&KernelRegistryForTarget<TARGET(target__), \
PRECISION(precision__)>::Global());
// Currently, just register 2 kernel targets.
INIT_FOR(kARM, kFloat);
INIT_FOR(kHost, kFloat);
#undef INIT_FOR
}
......@@ -97,6 +95,42 @@ registries_[0].set<kernel_target_t *>(
->Register(name, std::move(creator));
}
template <TargetType Target, PrecisionType Precision>
std::unique_ptr<KernelBase> Create(const std::string &op_type) {
using kernel_registor_t = KernelRegistryForTarget<Target, Precision>;
return registries_[GetKernelOffset<Target, Precision>()]
.template get<kernel_registor_t *>()
->Create(op_type);
}
std::unique_ptr<KernelBase> Create(const std::string &op_type,
TargetType target,
PrecisionType precision) {
#define CREATE_KERNEL(target__) \
switch (precision) { \
case PRECISION(kFloat): \
return Create<TARGET(target__), PRECISION(kFloat)>(op_type); \
default: \
CHECK(false) << "not supported kernel place yet"; \
}
switch (target) {
case TARGET(kHost): {
CREATE_KERNEL(kHost);
} break;
case TARGET(kX86): {
CREATE_KERNEL(kX86);
} break;
case TARGET(kCUDA): {
CREATE_KERNEL(kCUDA);
} break;
default:
CHECK(false) << "not supported kernel place";
}
#undef CREATE_KERNEL
}
// Get a kernel registry offset in all the registries.
template <TargetType Target, PrecisionType Precision>
static constexpr int GetKernelOffset() {
......
......@@ -18,7 +18,7 @@
namespace paddle {
namespace lite {
enum class TargetType { kHost = 0, kX86, kCUDA, kARM, kLastAsPlaceHolder };
enum class TargetType { kHost = 0, kX86, kCUDA, kLastAsPlaceHolder };
// Some helper macro to get a specific TargetType.
#define TARGET(item__) paddle::lite::TargetType::item__
#define TARGET_VAL(item__) static_cast<int>(TARGET(item__))
......
......@@ -21,9 +21,8 @@ namespace paddle {
namespace lite {
namespace core {
using any_context_t = variant<Context<TARGET(kX86)>, //
Context<TARGET(kCUDA)>, //
Context<TARGET(kARM)> //
using any_context_t = variant<Context<TARGET(kX86)>, //
Context<TARGET(kCUDA)> //
>;
} // namespace core
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册