提交 f3d1fac2 编写于 作者: S superjomn

fix kernel registry

上级 027cbe83
......@@ -4,3 +4,4 @@ add_subdirectory(cuda)
add_subdirectory(operators)
add_subdirectory(kernels)
add_subdirectory(model_parser)
add_subdirectory(utils)
......@@ -18,10 +18,11 @@
#include <boost/variant.hpp>
#include <map>
#include <string>
#include "paddle/fluid/framework/op_desc.h"
#include "context.h"
#include "target_wrapper.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/lite/operators/op_params.h"
#include "paddle/fluid/lite/utils/all.h"
#include "target_wrapper.h"
namespace paddle {
namespace lite {
......@@ -39,11 +40,11 @@ class OpKernel {
void SetContext(context_ptr_t&& ctx) { context_ = std::move(ctx); }
void SetParam(any param) { param_ = param; }
void SetParam(operators::param_t param) { param_ = param; }
template <typename Param>
Param& param() const {
return *any_cast<Param>(&param_);
return param_.get<Param>();
}
virtual void Run() { CHECK(false) << "Not Implemented"; }
......@@ -52,7 +53,7 @@ class OpKernel {
protected:
context_ptr_t context_;
mutable any param_;
mutable operators::param_t param_;
};
} // namespace lite
......
......@@ -18,11 +18,12 @@
#include <boost/variant.hpp>
#include <map>
#include <string>
#include "context.h"
#include "kernel.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/lite/core/context.h"
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/scope.h"
namespace paddle {
namespace lite {
......@@ -66,8 +67,7 @@ class OpLite : public Registry {
// Run this operator.
virtual bool Run() = 0;
// Build the operator, attach it with the runtime environment.
virtual bool Build(const framework::OpDesc &opdesc,
framework::Scope *scope) = 0;
virtual bool Build(const framework::OpDesc &opdesc, lite::Scope *scope) = 0;
// Human-readable information.
virtual std::string DebugString() const = 0;
......
......@@ -12,4 +12,4 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "op_registry.h"
\ No newline at end of file
#include "paddle/fluid/lite/core/op_registry.h"
\ No newline at end of file
......@@ -15,9 +15,9 @@
#include <memory>
#include <string>
#include <unordered_map>
#include "kernel.h"
#include "op_lite.h"
#include "target_wrapper.h"
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/target_wrapper.h"
#include "paddle/fluid/lite/utils/all.h"
namespace paddle {
......@@ -50,16 +50,32 @@ class OpLiteRegistor : public Registor<OpClass> {
};
template <TargetType Target, PrecisionType Precision>
class KernelRegistryForTarget : public Factory<OpKernel<Target, Precision>> {};
using KernelRegistryForTarget = Factory<OpKernel<Target, Precision>>;
class KernelRegistry final {
public:
// Variant over pointers to the per-(target, precision) kernel registries.
// Each alternative is a pointer to one KernelRegistryForTarget instantiation;
// the combinations listed here are the only ones a value of this type can hold.
// NOTE(review): the trailing `//` on each line presumably stops the formatter
// from re-flowing the list -- confirm before removing.
using any_kernel_registor_t = variant<
KernelRegistryForTarget<TargetType::kCUDA, PrecisionType::kFloat> *, //
KernelRegistryForTarget<TargetType::kCUDA, PrecisionType::kInt8> *, //
KernelRegistryForTarget<TargetType::kX86, PrecisionType::kFloat> *, //
KernelRegistryForTarget<TargetType::kX86, PrecisionType::kInt8> *, //
KernelRegistryForTarget<TargetType::kARM, PrecisionType::kFloat> *, //
KernelRegistryForTarget<TargetType::kHost, PrecisionType::kFloat> * //
>;
KernelRegistry() {
/*
using kernel_target_t =
KernelRegistryForTarget<TARGET(kCUDA), PRECISION(kFloat)>;
registries_[0].set<kernel_target_t *>(
&KernelRegistryForTarget<TARGET(kCUDA), PRECISION(kFloat)>::Global());
*/
#define INIT_FOR(target__, precision__) \
registries_[KernelRegistry::GetKernelOffset<TARGET(target__), \
PRECISION(precision__)>()] = \
&KernelRegistryForTarget<TARGET(target__), \
PRECISION(precision__)>::Global();
PRECISION(precision__)>()] \
.set<KernelRegistryForTarget<TARGET(target__), PRECISION(precision__)> \
*>(&KernelRegistryForTarget<TARGET(target__), \
PRECISION(precision__)>::Global());
// Currently, just register 2 kernel targets.
INIT_FOR(kARM, kFloat);
INIT_FOR(kHost, kFloat);
......@@ -76,8 +92,8 @@ class KernelRegistry final {
typename KernelRegistryForTarget<Target, Precision>::creator_t
&&creator) {
using kernel_registor_t = KernelRegistryForTarget<Target, Precision>;
any_cast<kernel_registor_t *>(
registries_[GetKernelOffset<Target, Precision>()])
registries_[GetKernelOffset<Target, Precision>()]
.template get<kernel_registor_t *>()
->Register(name, std::move(creator));
}
......@@ -88,7 +104,7 @@ class KernelRegistry final {
}
private:
std::array<any, kNumTargets * kNumPrecisions> registries_;
std::array<any_kernel_registor_t, kNumTargets * kNumPrecisions> registries_;
};
template <TargetType target, PrecisionType precision, typename KernelType>
......
......@@ -12,8 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "scope.h"
#include "scope.h"
#include "paddle/fluid/lite/core/scope.h"
namespace paddle {
namespace lite {
......
......@@ -19,7 +19,7 @@
#include <unordered_map>
#include <utility>
#include <vector>
#include "variable.h"
#include "paddle/fluid/lite/core/variable.h"
namespace paddle {
namespace lite {
......
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "scope.h"
#include "paddle/fluid/lite/core/scope.h"
#include <gtest/gtest.h>
namespace paddle {
......
......@@ -12,4 +12,4 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "tensor.h"
#include "paddle/fluid/lite/core/tensor.h"
......@@ -57,7 +57,6 @@ using LoD = std::vector<std::vector<size_t>>;
// A light-weight tensor implementation.
class Tensor {
public:
void SyncEventTree();
Tensor() = default;
template <typename T>
......
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "variable.h"
#include "paddle/fluid/lite/core/variable.h"
namespace paddle {
namespace lite {} // namespace lite
......
cc_library(fc_op_lite SRCS fc_op.cc DEPS op_lite)
cc_library(fc_op_lite SRCS fc_op.cc DEPS op_lite op_params_lite tensor_lite)
cc_library(relu_op_lite SRCS relu_op.cc DEPS op_lite)
cc_library(op_params_lite SRCS op_params.cc DEPS tensor_lite)
cc_test(test_fc_op_lite SRCS fc_op_test.cc DEPS fc_op_lite)
......@@ -53,7 +53,7 @@ bool FcOpLite::InferShape() const {
const auto w_dims = param_.w->dims();
// Set output dims
std::vector<int> output_dims(param_.in_num_col_dims + 1, 0);
std::vector<int64_t> output_dims(param_.in_num_col_dims + 1, 0);
for (int i = 0; i < param_.in_num_col_dims; ++i) {
output_dims[i] = input_dims[i];
}
......
......@@ -15,23 +15,15 @@
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/scope.h"
#include "paddle/fluid/lite/core/tensor.h"
#include "paddle/fluid/lite/operators/op_params.h"
#include "paddle/fluid/lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
// Parameter bundle for the fc operator. Every tensor pointer starts out null
// and in_num_col_dims starts at 0 until the operator fills them in.
struct FcParam {
  Tensor* input = nullptr;
  Tensor* w = nullptr;
  Tensor* bias = nullptr;
  Tensor* output = nullptr;

  // Dimensions of the input matrix.
  lite::DDim in_mat_dims;
  int in_num_col_dims = 0;
};
class FcOpLite : public OpLite {
public:
FcOpLite() {}
......@@ -42,9 +34,21 @@ class FcOpLite : public OpLite {
bool Run() override { return false; }
bool Build(const framework::OpDesc& opdesc,
framework::Scope* scope) override {
return false;
// TODO(Superjomn) replace framework::OpDesc with a lite one.
// Binds this operator to the variables that `op_desc` names inside `scope`,
// caching the resolved tensors and the `in_num_col_dims` attribute in param_.
//
// Returns true once all four tensors are bound. Returns false when a slot
// lists no variable name or the scope lookup yields a null variable.
bool Build(const framework::OpDesc& op_desc, lite::Scope* scope) override {
  // Resolves the first variable named in a slot to its tensor, or nullptr.
  // NOTE(review): assumes Scope::FindVar reports a miss as nullptr -- confirm.
  auto bind_tensor = [scope](const std::vector<std::string>& names) -> Tensor* {
    if (names.empty()) return nullptr;
    auto* var = scope->FindVar(names.front());
    return var ? var->GetMutable<Tensor>() : nullptr;
  };

  // Slot names are case-sensitive. The fluid `fc` op declares the inputs
  // "Input", "W", "Bias" and the output "Out"; the output used to be looked up
  // under "bias", which is not an output slot at all.
  param_.input = bind_tensor(op_desc.Input("Input"));
  param_.w = bind_tensor(op_desc.Input("W"));
  // Bias is still treated as required here, as it was before this fix.
  param_.bias = bind_tensor(op_desc.Input("Bias"));
  param_.output = bind_tensor(op_desc.Output("Out"));
  if (!param_.input || !param_.w || !param_.bias || !param_.output) {
    return false;
  }

  // OpDesc attributes are boost::variant values, so they are read with
  // boost::get. boost::any_cast would first wrap the whole variant in a
  // boost::any and then fail to cast that to int.
  param_.in_num_col_dims = boost::get<int>(op_desc.GetAttr("in_num_col_dims"));
  return true;
}
std::string DebugString() const override { return "fc"; }
......
......@@ -14,7 +14,7 @@
#pragma once
#include "paddle/fluid/lite/utils/varient.h"
#include "paddle/fluid/lite/utils/check.h"
#include "paddle/fluid/lite/utils/factory.h"
#include "paddle/fluid/lite/utils/macros.h"
#include "paddle/fluid/lite/utils/varient.h"
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册