Commit 4eedd20f authored by superjomn

make kernel implementation work

Parent f3d1fac2
......@@ -7,3 +7,4 @@ cc_library(op_registry_lite SRCS op_registry.cc)
cc_library(scope_lite SRCS scope.cc)
cc_test(test_scope_lite SRCS scope_test.cc DEPS scope_lite)
cc_test(test_kernel_lite SRCS kernel_test.cc DEPS target_wrapper_x86)
......@@ -14,46 +14,54 @@
#pragma once
#include <glog/logging.h>
#include <boost/variant.hpp>
#include <map>
#include <string>
#include "context.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/lite/core/context.h"
#include "paddle/fluid/lite/core/target_wrapper.h"
#include "paddle/fluid/lite/core/types.h"
#include "paddle/fluid/lite/operators/op_params.h"
#include "paddle/fluid/lite/utils/all.h"
#include "target_wrapper.h"
namespace paddle {
namespace lite {
// Light-weight kernel implementation.
// The OpKernel is designed to implement the specific algorithm on a target
// device.
template <TargetType Target, PrecisionType Precision>
class OpKernel {
class KernelBase {
public:
using context_t = Context<Target>;
using context_ptr_t = std::unique_ptr<context_t>;
OpKernel() = default;
virtual void Run() = 0;
void SetContext(context_ptr_t&& ctx) { context_ = std::move(ctx); }
template <TargetType Target>
void SetContext(std::unique_ptr<Context<Target>>&& ctx) {
context_.set<std::unique_ptr<Context<Target>>>(std::move(ctx));
}
void SetParam(operators::param_t param) { param_ = param; }
template <typename T>
void SetParam(T param) {
param_.set<T>(param);
}
template <typename Param>
Param& param() const {
return param_.get<Param>();
}
protected:
virtual ~KernelBase() = default;
core::any_context_t context_;
mutable operators::param_t param_;
};
// Light-weight kernel implementation.
// The OpKernel is designed to implement the specific algorithm on a target
// device.
template <TargetType Target, PrecisionType Precision>
class OpKernel : public KernelBase {
public:
virtual void Run() { CHECK(false) << "Not Implemented"; }
virtual ~OpKernel() = default;
OpKernel() = default;
protected:
context_ptr_t context_;
mutable operators::param_t param_;
virtual ~OpKernel() = default;
};
} // namespace lite
......
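// Illustrative sketch, not part of this commit: the param round trip that
// KernelBase::SetParam<T>() and param<Param>() rely on. `operators::param_t`
// is assumed to be the variant-style holder (with set<T>/get<T>) used above;
// only FcParam's in_num_col_dims field is taken from this diff.
void ParamRoundTripSketch() {
  operators::FcParam fc;
  fc.in_num_col_dims = 2;
  operators::param_t any_param;                        // type-erased holder
  any_param.set<operators::FcParam>(fc);               // store the typed param
  auto typed = any_param.get<operators::FcParam>();    // typed view back
  CHECK_EQ(typed.in_num_col_dims, 2);                  // glog check, as used in kernel.h
}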
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/kernel.h"
#include <gtest/gtest.h>
#include "paddle/fluid/lite/core/op_lite.h"
namespace paddle {
namespace lite {
namespace core {
int test_code{-1};
class SomeKernel : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
public:
void Run() override {
LOG(INFO) << "SomeKernel executed";
LOG(INFO) << param<operators::FcParam>().in_num_col_dims;
test_code = param<operators::FcParam>().in_num_col_dims;
}
};
TEST(Kernel, test) {
SomeKernel kernel;
operators::FcParam param;
param.in_num_col_dims = 100;
kernel.SetParam<operators::FcParam>(param);
kernel.Run();
ASSERT_EQ(test_code, 100);
}
} // namespace core
} // namespace lite
} // namespace paddle
......@@ -57,7 +57,12 @@ class OpLite : public Registry {
kRuntime,
};
OpLite() {}
struct Place {
TargetType target{TARGET(kHost)};
PrecisionType precision{PRECISION(kFloat)};
};
OpLite() = default;
OpLite(std::unique_ptr<OpContext> &&x) : op_context_(std::move(x)) {}
// Check the shape.
......@@ -71,12 +76,14 @@ class OpLite : public Registry {
// Human-readable information.
virtual std::string DebugString() const = 0;
const Place &kernel_place() const { return kernel_place_; }
protected:
// Specify the kernel to run by default.
virtual void StaticPickKernel(
const std::vector<TargetType> &valid_targets) = 0;
// Specify the kernel to run by default. This will specify the value of
// `kernel_place_`.
virtual void StaticPickKernel(const std::vector<Place> &valid_targets) = 0;
void PickKernel(const std::vector<TargetType> &valid_places,
void PickKernel(const std::vector<Place> &valid_places,
KernelStrategy kernel_strategy = KernelStrategy::kStatic);
// Create all the kernels for the valid targets.
......@@ -86,6 +93,7 @@ class OpLite : public Registry {
protected:
std::unique_ptr<OpContext> op_context_;
Place kernel_place_;
};
} // namespace lite
......
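// Illustrative sketch, not part of this commit: how an OpLite subclass might
// implement the new Place-based StaticPickKernel. Only the signature, the
// Place struct and the kernel_place_ member come from this diff; FakeOp and
// its matching policy are hypothetical.
class FakeOp : public OpLite {   // stays abstract; other pure virtuals omitted
 protected:
  void StaticPickKernel(const std::vector<Place>& valid_places) override {
    for (const auto& place : valid_places) {
      if (place.target == TARGET(kHost) &&
          place.precision == PRECISION(kFloat)) {
        kernel_place_ = place;   // record where the default kernel will run
        break;
      }
    }
  }
};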
......@@ -54,14 +54,14 @@ using KernelRegistryForTarget = Factory<OpKernel<Target, Precision>>;
class KernelRegistry final {
public:
using any_kernel_registor_t = variant<
KernelRegistryForTarget<TargetType::kCUDA, PrecisionType::kFloat> *, //
KernelRegistryForTarget<TargetType::kCUDA, PrecisionType::kInt8> *, //
KernelRegistryForTarget<TargetType::kX86, PrecisionType::kFloat> *, //
KernelRegistryForTarget<TargetType::kX86, PrecisionType::kInt8> *, //
KernelRegistryForTarget<TargetType::kARM, PrecisionType::kFloat> *, //
KernelRegistryForTarget<TargetType::kHost, PrecisionType::kFloat> * //
>;
using any_kernel_registor_t =
variant<KernelRegistryForTarget<TARGET(kCUDA), PRECISION(kFloat)> *, //
KernelRegistryForTarget<TARGET(kCUDA), PRECISION(kInt8)> *, //
KernelRegistryForTarget<TARGET(kX86), PRECISION(kFloat)> *, //
KernelRegistryForTarget<TARGET(kX86), PRECISION(kInt8)> *, //
KernelRegistryForTarget<TARGET(kARM), PRECISION(kFloat)> *, //
KernelRegistryForTarget<TARGET(kHost), PRECISION(kFloat)> * //
>;
KernelRegistry() {
/*
......
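// Hedged note on the reformatting above: TARGET()/PRECISION() are assumed to
// be shorthand macros (defined in target_wrapper.h, not shown in this diff)
// expanding to the enum values that the removed lines spelled out, e.g.
static_assert(TARGET(kCUDA) == TargetType::kCUDA,
              "assumed TARGET() expansion");
static_assert(PRECISION(kFloat) == PrecisionType::kFloat,
              "assumed PRECISION() expansion");
// so the new variant list is equivalent to the old one, just shorter.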
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/lite/core/context.h"
#include "paddle/fluid/lite/utils/all.h"
namespace paddle {
namespace lite {
namespace core {
using any_context_t = variant<Context<TARGET(kX86)>, //
Context<TARGET(kCUDA)>, //
Context<TARGET(kARM)> //
>;
} // namespace core
} // namespace lite
} // namespace paddle
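// Observation on this new header (no change intended): KernelBase::context_ in
// kernel.h is declared as core::any_context_t, so only kX86, kCUDA and kARM
// contexts can be stored; there is currently no Context<TARGET(kHost)> entry
// in this variant.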
......@@ -14,6 +14,7 @@
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/scope.h"
#include "paddle/fluid/lite/core/tensor.h"
......@@ -53,8 +54,7 @@ class FcOpLite : public OpLite {
std::string DebugString() const override { return "fc"; }
void StaticPickKernel(const std::vector<TargetType>& valid_targets) override {
}
void StaticPickKernel(const std::vector<Place>& valid_targets) override {}
private:
mutable FcParam param_;
......
......@@ -19,6 +19,18 @@
namespace paddle {
namespace lite {
/*
* Factory for creators of any type.
*
* Usage:
*
* struct SomeType;
* // Register a creator.
* Factory<SomeType>::Global().Register("some_key", []() ->
* std::unique_ptr<SomeType> { ... });
* // Retrieve a creator.
* auto some_type_instance = Factory<SomeType>::Global().Create("some_key");
*/
template <typename ItemType>
class Factory {
public:
......@@ -55,6 +67,7 @@ class Registor {
public:
Registor(std::function<void()>&& functor) { functor(); }
// Touch will do nothing.
int Touch() { return 0; }
};
......
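// Illustrative sketch, not part of this commit: how op_registry.h builds on
// this Factory (KernelRegistryForTarget = Factory<OpKernel<Target, Precision>>,
// keyed by op name). The alias, function name and "fc" key below are
// hypothetical.
using HostFloatKernelFactory =
    Factory<OpKernel<TARGET(kHost), PRECISION(kFloat)>>;

void RegisterFakeFcKernelCreator() {
  HostFloatKernelFactory::Global().Register(
      "fc",
      []() -> std::unique_ptr<OpKernel<TARGET(kHost), PRECISION(kFloat)>> {
        return nullptr;  // a real creator would return a concrete kernel
      });
}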
......@@ -12,22 +12,20 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "target_wrapper.h"
#include "paddle/fluid/lite/core/target_wrapper.h"
#include <algorithm>
namespace paddle {
namespace framework {
namespace lite {
template <>
void TargetWrapper<X86>::MemcpySync(void* dst, void* src, size_t size,
IoDirection dir) {
std::copy_n(reinterpret_cast<uint8_t*>(src), size,
reinterpret_cast<uint8_t*>(dst));
void TargetWrapper<TARGET(kX86)>::MemcpySync(void *dst, void *src, size_t size,
IoDirection dir) {
std::copy_n(reinterpret_cast<uint8_t *>(src), size,
reinterpret_cast<uint8_t *>(dst));
}
template class TargetWrapper<X86>;
template class TargetWrapper<TARGET(kX86)>;
} // namespace lite
} // namespace framework
} // namespace paddle
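// Illustrative sketch, not part of this commit: a host-to-host copy through
// the x86 wrapper specialized above. The IoDirection value is an assumption
// (the enum is declared in target_wrapper.h, not in this diff) and is ignored
// by this MemcpySync implementation.
void MemcpySyncSketch() {
  float src[4] = {1.f, 2.f, 3.f, 4.f};
  float dst[4] = {0.f, 0.f, 0.f, 0.f};
  TargetWrapper<TARGET(kX86)>::MemcpySync(dst, src, sizeof(src),
                                          IoDirection::HtoH);  // assumed value
}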
......@@ -16,9 +16,7 @@
#include "paddle/fluid/lite/core/target_wrapper.h"
namespace paddle {
namespace framework {
namespace lite {
namespace x86 {} // namespace x86
} // namespace lite
} // namespace framework
} // namespace paddle