diff --git a/paddle/fluid/lite/core/CMakeLists.txt b/paddle/fluid/lite/core/CMakeLists.txt
index c96727cfb8852ba8f024195051ca0be3e74b681c..b656c494ddfb49f709f21cf12cfccf71a8a42246 100644
--- a/paddle/fluid/lite/core/CMakeLists.txt
+++ b/paddle/fluid/lite/core/CMakeLists.txt
@@ -8,3 +8,4 @@ cc_library(scope_lite SRCS scope.cc)
cc_test(test_scope_lite SRCS scope_test.cc DEPS scope_lite)
cc_test(test_kernel_lite SRCS kernel_test.cc DEPS target_wrapper_x86)
+cc_test(test_op_lite SRCS op_lite_test.cc DEPS op_lite)
diff --git a/paddle/fluid/lite/core/kernel.h b/paddle/fluid/lite/core/kernel.h
index e9eeb71a302aad90f29e5686fbb18acb64d8910d..af1adb9338cb7ede4391bd27172a82580418d91d 100644
--- a/paddle/fluid/lite/core/kernel.h
+++ b/paddle/fluid/lite/core/kernel.h
@@ -26,6 +26,8 @@
namespace paddle {
namespace lite {
+// A base class with virtual functions that unifies the kernel implementations
+// on different targets.
class KernelBase {
public:
virtual void Run() = 0;
@@ -45,8 +47,12 @@ class KernelBase {
return param_.get();
}
- protected:
+ virtual TargetType target() const = 0;
+ virtual PrecisionType precision() const = 0;
+
virtual ~KernelBase() = default;
+
+ protected:
core::any_context_t context_;
mutable operators::param_t param_;
};
diff --git a/paddle/fluid/lite/core/kernel_test.cc b/paddle/fluid/lite/core/kernel_test.cc
index 42ebe2ed29d042ad3f7ba2d01fca501c6c1f9864..325744aeee25d635db6079fdcd0e557b2295ffe7 100644
--- a/paddle/fluid/lite/core/kernel_test.cc
+++ b/paddle/fluid/lite/core/kernel_test.cc
@@ -28,6 +28,10 @@ class SomeKernel : public OpKernel {
LOG(INFO) << param().in_num_col_dims;
test_code = param().in_num_col_dims;
}
+ // This test kernel reports the host target and float precision.
+ TargetType target() const override { return TARGET(kHost); }
+ PrecisionType precision() const override { return PRECISION(kFloat); }
+
};
TEST(Kernel, test) {
diff --git a/paddle/fluid/lite/core/memory.h b/paddle/fluid/lite/core/memory.h
index 36f247c20937a85b57ce998b4d07cf6474418783..0267dd43ed1fb636148c018262edc2781baaaa83 100644
--- a/paddle/fluid/lite/core/memory.h
+++ b/paddle/fluid/lite/core/memory.h
@@ -28,9 +28,6 @@ static void* TargetMalloc(TargetType target, size_t size) {
case static_cast(TargetType::kCUDA):
data = TargetWrapper::Malloc(size);
break;
- case static_cast(TargetType::kARM):
- data = TargetWrapper::Malloc(size);
- break;
case static_cast(TargetType::kHost):
data = TargetWrapper::Malloc(size);
break;
@@ -48,9 +45,6 @@ static void TargetFree(TargetType target, void* data) {
case static_cast(TargetType::kCUDA):
TargetWrapper::Free(data);
break;
- case static_cast(TargetType::kARM):
- TargetWrapper::Free(data);
- break;
default:
LOG(FATAL) << "Unknown type";
}
diff --git a/paddle/fluid/lite/core/op_lite.cc b/paddle/fluid/lite/core/op_lite.cc
index f277ecb580015448f60c939c400be4c10b83547b..53098a2a957dc7333395305101fb602447417f96 100644
--- a/paddle/fluid/lite/core/op_lite.cc
+++ b/paddle/fluid/lite/core/op_lite.cc
@@ -12,4 +12,36 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include "paddle/fluid/lite/core/op_lite.h"
#include "op_lite.h"
+#include "paddle/fluid/lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+// Creates one kernel per entry of `places` through the global KernelRegistry.
+std::vector> OpLite::CreateKernels(
+ const std::vector &places) {
+ std::vector> kernels;
+ CHECK(!op_type_.empty()) << "op_type_ should be set first";
+ // Look each kernel up by op type plus the place's target and precision.
+ for (auto place : places) {
+ kernels.emplace_back(KernelRegistry::Global().Create(op_type_, place.target,
+ place.precision));
+ }
+
+ return kernels;
+}
+// Picks the kernel according to kernel_strategy; only kStatic is supported.
+void OpLite::PickKernel(const std::vector &valid_places,
+ OpLite::KernelStrategy kernel_strategy) {
+ switch (kernel_strategy) {
+ case KernelStrategy::kStatic:
+ StaticPickKernel(valid_places);  // pure virtual, supplied by the subclass
+ break;
+ default:
+ LOG(FATAL) << "unsupported kernel strategy";
+ }
+}
+
+} // namespace lite
+} // namespace paddle
diff --git a/paddle/fluid/lite/core/op_lite.h b/paddle/fluid/lite/core/op_lite.h
index 76f5851e43f4d25192a46fdb61b432f0a9f2e94f..fa72d95051a64daa9c8a92f13b7291ac767620fe 100644
--- a/paddle/fluid/lite/core/op_lite.h
+++ b/paddle/fluid/lite/core/op_lite.h
@@ -70,7 +70,15 @@ class OpLite : public Registry {
// Inference the outputs' shape.
virtual bool InferShape() const { return true; }
// Run this operator.
- virtual bool Run() = 0;
+ virtual bool Run() {
+ CHECK(kernel_);  // a kernel must be set before the operator can run
+ SyncInputEvents();  // currently an empty stub
+
+ kernel_->Run();
+
+ RecordOutputEvents();  // currently an empty stub
+ return true;  // always reports success
+ }
// Build the operator, attach it with the runtime environment.
virtual bool Build(const framework::OpDesc &opdesc, lite::Scope *scope) = 0;
// Human-readable information.
@@ -79,21 +87,31 @@ class OpLite : public Registry {
const Place &kernel_place() const { return kernel_place_; }
protected:
+ void PickKernel(const std::vector &valid_places,
+ KernelStrategy kernel_strategy = KernelStrategy::kStatic);
+
// Specify the kernel to run by default. This will specify the value of
// `kernel_place_`.
virtual void StaticPickKernel(const std::vector &valid_targets) = 0;
- void PickKernel(const std::vector &valid_places,
- KernelStrategy kernel_strategy = KernelStrategy::kStatic);
+ // Wait until all the inputs' events are ready. Currently an empty stub.
+ void SyncInputEvents() {}
+
+ // Record the output events, which tells all the dependent operators that
+ // some of their inputs are ready. Currently an empty stub.
+ void RecordOutputEvents() {}
// Create all the kernels for the valid targets.
- void CreateKernels();
+ std::vector> CreateKernels(
+ const std::vector &places);
virtual ~OpLite() = default;
protected:
std::unique_ptr op_context_;
Place kernel_place_;
+ std::unique_ptr kernel_;
+ std::string op_type_;
};
} // namespace lite
diff --git a/paddle/fluid/lite/core/op_lite_test.cc b/paddle/fluid/lite/core/op_lite_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..0892da86606d5f9fe1be636f8287fa8b867a4116
--- /dev/null
+++ b/paddle/fluid/lite/core/op_lite_test.cc
@@ -0,0 +1,13 @@
+#include
+#include "paddle/fluid/lite/core/op_lite.h"
+
+namespace paddle {
+namespace lite {
+
+TEST(OpLite, test) {
+ // TODO: placeholder; this test body is empty and asserts nothing yet.
+
+}
+
+} // namespace lite
+} // namespace paddle
diff --git a/paddle/fluid/lite/core/op_registry.h b/paddle/fluid/lite/core/op_registry.h
index deb1654bc35a409020d92ab69df52b2d29927868..bf3bacb346c5301b4c5ae7cd900b40034cf0f914 100644
--- a/paddle/fluid/lite/core/op_registry.h
+++ b/paddle/fluid/lite/core/op_registry.h
@@ -59,7 +59,6 @@ class KernelRegistry final {
KernelRegistryForTarget *, //
KernelRegistryForTarget *, //
KernelRegistryForTarget *, //
- KernelRegistryForTarget *, //
KernelRegistryForTarget * //
>;
@@ -77,7 +76,6 @@ registries_[0].set(
*>(&KernelRegistryForTarget::Global());
// Currently, just register 2 kernel targets.
- INIT_FOR(kARM, kFloat);
INIT_FOR(kHost, kFloat);
#undef INIT_FOR
}
@@ -97,6 +95,42 @@ registries_[0].set(
->Register(name, std::move(creator));
}
+ template
+ std::unique_ptr Create(const std::string &op_type) {
+ using kernel_registor_t = KernelRegistryForTarget;
+ return registries_[GetKernelOffset()]  // select a registry, then let it
+ .template get()                        // create the kernel for op_type
+ ->Create(op_type);
+ }
+ // Creates a kernel for op_type, dispatching on runtime target and precision.
+ std::unique_ptr Create(const std::string &op_type,
+ TargetType target,
+ PrecisionType precision) {
+#define CREATE_KERNEL(target__) \
+ switch (precision) { \
+ case PRECISION(kFloat): \
+ return Create(op_type); \
+ default: \
+ CHECK(false) << "not supported kernel place yet"; \
+ }
+ // Only kHost, kX86 and kCUDA are dispatched; only kFloat precision is handled.
+ switch (target) {
+ case TARGET(kHost): {
+ CREATE_KERNEL(kHost);
+ } break;
+ case TARGET(kX86): {
+ CREATE_KERNEL(kX86);
+ } break;
+ case TARGET(kCUDA): {
+ CREATE_KERNEL(kCUDA);
+ } break;
+ default:
+ CHECK(false) << "not supported kernel place";
+ }
+ // NOTE(review): no return statement follows the CHECK(false) fallbacks.
+#undef CREATE_KERNEL
+ }
+
// Get a kernel registry offset in all the registries.
template
static constexpr int GetKernelOffset() {
diff --git a/paddle/fluid/lite/core/target_wrapper.h b/paddle/fluid/lite/core/target_wrapper.h
index 60de8c224db223727f62fff8f8220493af301ec7..c7bd2ce2172eb0e1ce5063128b6e933094687a1a 100644
--- a/paddle/fluid/lite/core/target_wrapper.h
+++ b/paddle/fluid/lite/core/target_wrapper.h
@@ -18,7 +18,7 @@
namespace paddle {
namespace lite {
-enum class TargetType { kHost = 0, kX86, kCUDA, kARM, kLastAsPlaceHolder };
+enum class TargetType { kHost = 0, kX86, kCUDA, kLastAsPlaceHolder };
// Some helper macro to get a specific TargetType.
#define TARGET(item__) paddle::lite::TargetType::item__
#define TARGET_VAL(item__) static_cast(TARGET(item__))
diff --git a/paddle/fluid/lite/core/types.h b/paddle/fluid/lite/core/types.h
index 566d407f8508f59103dfc05850cd72bd6d73517f..0e95882c56056e93d02e65fb34152e2818da0972 100644
--- a/paddle/fluid/lite/core/types.h
+++ b/paddle/fluid/lite/core/types.h
@@ -21,9 +21,8 @@ namespace paddle {
namespace lite {
namespace core {
-using any_context_t = variant, //
- Context, //
- Context //
+using any_context_t = variant, //
+ Context //
>;
} // namespace core