提交 f3d1fac2 编写于 作者: S superjomn

fix kernel registry

上级 027cbe83
...@@ -4,3 +4,4 @@ add_subdirectory(cuda) ...@@ -4,3 +4,4 @@ add_subdirectory(cuda)
add_subdirectory(operators) add_subdirectory(operators)
add_subdirectory(kernels) add_subdirectory(kernels)
add_subdirectory(model_parser) add_subdirectory(model_parser)
add_subdirectory(utils)
...@@ -18,10 +18,11 @@ ...@@ -18,10 +18,11 @@
#include <boost/variant.hpp> #include <boost/variant.hpp>
#include <map> #include <map>
#include <string> #include <string>
#include "paddle/fluid/framework/op_desc.h"
#include "context.h" #include "context.h"
#include "target_wrapper.h" #include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/lite/operators/op_params.h"
#include "paddle/fluid/lite/utils/all.h" #include "paddle/fluid/lite/utils/all.h"
#include "target_wrapper.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -39,11 +40,11 @@ class OpKernel { ...@@ -39,11 +40,11 @@ class OpKernel {
void SetContext(context_ptr_t&& ctx) { context_ = std::move(ctx); } void SetContext(context_ptr_t&& ctx) { context_ = std::move(ctx); }
void SetParam(any param) { param_ = param; } void SetParam(operators::param_t param) { param_ = param; }
template <typename Param> template <typename Param>
Param& param() const { Param& param() const {
return *any_cast<Param>(&param_); return param_.get<Param>();
} }
virtual void Run() { CHECK(false) << "Not Implemented"; } virtual void Run() { CHECK(false) << "Not Implemented"; }
...@@ -52,7 +53,7 @@ class OpKernel { ...@@ -52,7 +53,7 @@ class OpKernel {
protected: protected:
context_ptr_t context_; context_ptr_t context_;
mutable any param_; mutable operators::param_t param_;
}; };
} // namespace lite } // namespace lite
......
...@@ -18,11 +18,12 @@ ...@@ -18,11 +18,12 @@
#include <boost/variant.hpp> #include <boost/variant.hpp>
#include <map> #include <map>
#include <string> #include <string>
#include "context.h"
#include "kernel.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/variable.h" #include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/lite/core/context.h"
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/scope.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -66,8 +67,7 @@ class OpLite : public Registry { ...@@ -66,8 +67,7 @@ class OpLite : public Registry {
// Run this operator. // Run this operator.
virtual bool Run() = 0; virtual bool Run() = 0;
// Build the operator, attach it with the runtime environment. // Build the operator, attach it with the runtime environment.
virtual bool Build(const framework::OpDesc &opdesc, virtual bool Build(const framework::OpDesc &opdesc, lite::Scope *scope) = 0;
framework::Scope *scope) = 0;
// Human-readable information. // Human-readable information.
virtual std::string DebugString() const = 0; virtual std::string DebugString() const = 0;
......
...@@ -12,4 +12,4 @@ ...@@ -12,4 +12,4 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "op_registry.h" #include "paddle/fluid/lite/core/op_registry.h"
\ No newline at end of file \ No newline at end of file
...@@ -15,9 +15,9 @@ ...@@ -15,9 +15,9 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include "kernel.h" #include "paddle/fluid/lite/core/kernel.h"
#include "op_lite.h" #include "paddle/fluid/lite/core/op_lite.h"
#include "target_wrapper.h" #include "paddle/fluid/lite/core/target_wrapper.h"
#include "paddle/fluid/lite/utils/all.h" #include "paddle/fluid/lite/utils/all.h"
namespace paddle { namespace paddle {
...@@ -50,16 +50,32 @@ class OpLiteRegistor : public Registor<OpClass> { ...@@ -50,16 +50,32 @@ class OpLiteRegistor : public Registor<OpClass> {
}; };
template <TargetType Target, PrecisionType Precision> template <TargetType Target, PrecisionType Precision>
class KernelRegistryForTarget : public Factory<OpKernel<Target, Precision>> {}; using KernelRegistryForTarget = Factory<OpKernel<Target, Precision>>;
class KernelRegistry final { class KernelRegistry final {
public: public:
using any_kernel_registor_t = variant<
KernelRegistryForTarget<TargetType::kCUDA, PrecisionType::kFloat> *, //
KernelRegistryForTarget<TargetType::kCUDA, PrecisionType::kInt8> *, //
KernelRegistryForTarget<TargetType::kX86, PrecisionType::kFloat> *, //
KernelRegistryForTarget<TargetType::kX86, PrecisionType::kInt8> *, //
KernelRegistryForTarget<TargetType::kARM, PrecisionType::kFloat> *, //
KernelRegistryForTarget<TargetType::kHost, PrecisionType::kFloat> * //
>;
KernelRegistry() { KernelRegistry() {
/*
using kernel_target_t =
KernelRegistryForTarget<TARGET(kCUDA), PRECISION(kFloat)>;
registries_[0].set<kernel_target_t *>(
&KernelRegistryForTarget<TARGET(kCUDA), PRECISION(kFloat)>::Global());
*/
#define INIT_FOR(target__, precision__) \ #define INIT_FOR(target__, precision__) \
registries_[KernelRegistry::GetKernelOffset<TARGET(target__), \ registries_[KernelRegistry::GetKernelOffset<TARGET(target__), \
PRECISION(precision__)>()] = \ PRECISION(precision__)>()] \
&KernelRegistryForTarget<TARGET(target__), \ .set<KernelRegistryForTarget<TARGET(target__), PRECISION(precision__)> \
PRECISION(precision__)>::Global(); *>(&KernelRegistryForTarget<TARGET(target__), \
PRECISION(precision__)>::Global());
// Currently, just register 2 kernel targets. // Currently, just register 2 kernel targets.
INIT_FOR(kARM, kFloat); INIT_FOR(kARM, kFloat);
INIT_FOR(kHost, kFloat); INIT_FOR(kHost, kFloat);
...@@ -76,8 +92,8 @@ class KernelRegistry final { ...@@ -76,8 +92,8 @@ class KernelRegistry final {
typename KernelRegistryForTarget<Target, Precision>::creator_t typename KernelRegistryForTarget<Target, Precision>::creator_t
&&creator) { &&creator) {
using kernel_registor_t = KernelRegistryForTarget<Target, Precision>; using kernel_registor_t = KernelRegistryForTarget<Target, Precision>;
any_cast<kernel_registor_t *>( registries_[GetKernelOffset<Target, Precision>()]
registries_[GetKernelOffset<Target, Precision>()]) .template get<kernel_registor_t *>()
->Register(name, std::move(creator)); ->Register(name, std::move(creator));
} }
...@@ -88,7 +104,7 @@ class KernelRegistry final { ...@@ -88,7 +104,7 @@ class KernelRegistry final {
} }
private: private:
std::array<any, kNumTargets * kNumPrecisions> registries_; std::array<any_kernel_registor_t, kNumTargets * kNumPrecisions> registries_;
}; };
template <TargetType target, PrecisionType precision, typename KernelType> template <TargetType target, PrecisionType precision, typename KernelType>
......
...@@ -12,8 +12,7 @@ ...@@ -12,8 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "scope.h" #include "paddle/fluid/lite/core/scope.h"
#include "scope.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
#include <unordered_map> #include <unordered_map>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "variable.h" #include "paddle/fluid/lite/core/variable.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "scope.h" #include "paddle/fluid/lite/core/scope.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
namespace paddle { namespace paddle {
......
...@@ -12,4 +12,4 @@ ...@@ -12,4 +12,4 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "tensor.h" #include "paddle/fluid/lite/core/tensor.h"
...@@ -57,7 +57,6 @@ using LoD = std::vector<std::vector<size_t>>; ...@@ -57,7 +57,6 @@ using LoD = std::vector<std::vector<size_t>>;
// A light-weight tensor implementation. // A light-weight tensor implementation.
class Tensor { class Tensor {
public: public:
void SyncEventTree();
Tensor() = default; Tensor() = default;
template <typename T> template <typename T>
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "variable.h" #include "paddle/fluid/lite/core/variable.h"
namespace paddle { namespace paddle {
namespace lite {} // namespace lite namespace lite {} // namespace lite
......
cc_library(fc_op_lite SRCS fc_op.cc DEPS op_lite) cc_library(fc_op_lite SRCS fc_op.cc DEPS op_lite op_params_lite tensor_lite)
cc_library(relu_op_lite SRCS relu_op.cc DEPS op_lite) cc_library(relu_op_lite SRCS relu_op.cc DEPS op_lite)
cc_library(op_params_lite SRCS op_params.cc DEPS tensor_lite)
cc_test(test_fc_op_lite SRCS fc_op_test.cc DEPS fc_op_lite)
...@@ -53,7 +53,7 @@ bool FcOpLite::InferShape() const { ...@@ -53,7 +53,7 @@ bool FcOpLite::InferShape() const {
const auto w_dims = param_.w->dims(); const auto w_dims = param_.w->dims();
// Set output dims // Set output dims
std::vector<int> output_dims(param_.in_num_col_dims + 1, 0); std::vector<int64_t> output_dims(param_.in_num_col_dims + 1, 0);
for (int i = 0; i < param_.in_num_col_dims; ++i) { for (int i = 0; i < param_.in_num_col_dims; ++i) {
output_dims[i] = input_dims[i]; output_dims[i] = input_dims[i];
} }
......
...@@ -15,23 +15,15 @@ ...@@ -15,23 +15,15 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/lite/core/op_lite.h" #include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/scope.h"
#include "paddle/fluid/lite/core/tensor.h" #include "paddle/fluid/lite/core/tensor.h"
#include "paddle/fluid/lite/operators/op_params.h"
#include "paddle/fluid/lite/utils/all.h" #include "paddle/fluid/lite/utils/all.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
namespace operators { namespace operators {
struct FcParam {
Tensor* input{nullptr};
Tensor* w{nullptr};
Tensor* bias{nullptr};
Tensor* output{nullptr};
// the input matrix dimentions.
lite::DDim in_mat_dims;
int in_num_col_dims{0};
};
class FcOpLite : public OpLite { class FcOpLite : public OpLite {
public: public:
FcOpLite() {} FcOpLite() {}
...@@ -42,9 +34,21 @@ class FcOpLite : public OpLite { ...@@ -42,9 +34,21 @@ class FcOpLite : public OpLite {
bool Run() override { return false; } bool Run() override { return false; }
bool Build(const framework::OpDesc& opdesc, // TODO(Superjomn) replace framework::OpDesc with a lite one.
framework::Scope* scope) override { bool Build(const framework::OpDesc& op_desc, lite::Scope* scope) override {
return false; auto input = op_desc.Input("Input").front();
auto W = op_desc.Input("W").front();
auto bias = op_desc.Input("bias").front();
auto out = op_desc.Output("bias").front();
param_.input = scope->FindVar(input)->GetMutable<Tensor>();
param_.w = scope->FindVar(W)->GetMutable<Tensor>();
param_.bias = scope->FindVar(bias)->GetMutable<Tensor>();
param_.output = scope->FindVar(out)->GetMutable<Tensor>();
param_.in_num_col_dims =
boost::any_cast<int>(op_desc.GetAttr("in_num_col_dims"));
return true;
} }
std::string DebugString() const override { return "fc"; } std::string DebugString() const override { return "fc"; }
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#pragma once #pragma once
#include "paddle/fluid/lite/utils/varient.h"
#include "paddle/fluid/lite/utils/check.h" #include "paddle/fluid/lite/utils/check.h"
#include "paddle/fluid/lite/utils/factory.h" #include "paddle/fluid/lite/utils/factory.h"
#include "paddle/fluid/lite/utils/macros.h" #include "paddle/fluid/lite/utils/macros.h"
#include "paddle/fluid/lite/utils/varient.h"
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册