Commit cf6b65a1 authored by: S superjomn

init lite

Parent: a67f7179
cc_library(executor_lite SRCS executor.cc)
cc_library(op_lite SRCS op_lite.cc)
cc_library(memory_lite SRCS memory.cc)
cc_library(tensor_lite SRCS tensor.cc DEPS memory_lite)
cc_library(op_registry_lite SRCS op_registry.cc)
add_subdirectory(x86)
add_subdirectory(cuda)
add_subdirectory(operators)
add_subdirectory(kernels)
add_subdirectory(model_parser)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Created by chunwei on 19-2-22.
//
#include "context.h"
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <vector>
#include "target_wrapper.h"
namespace paddle {
namespace lite {
template <TargetType Target>
class Context {
public:
using target_wrapper_t = TargetWrapper<Target>;
using stream_t = typename TargetWrapper<Target>::stream_t;
Context() = default;
Context(int device_id, stream_t compute_stream, stream_t data_stream)
: device_id_(device_id),
compute_stream_(compute_stream),
data_stream_(data_stream) {}
void SetDeviceId(int device_id) { device_id_ = device_id; }
void SetComputeStream(stream_t x) { compute_stream_ = x; }
void SetDataStream(stream_t x) { data_stream_ = x; }
int device_id() const { return device_id_; }
stream_t compute_stream() const { return compute_stream_; }
stream_t data_stream() const { return data_stream_; }
private:
int device_id_{0};
stream_t compute_stream_;
stream_t data_stream_;
};
class OpContext final {
public:
template <TargetType Target>
using target_ptr_t = std::unique_ptr<Context<Target>>;
// @param target valid target.
explicit OpContext(TargetType target)
: targets_(std::vector<TargetType>({target})) {}
// @param target valid target.
explicit OpContext(const std::vector<TargetType>& target) : targets_(target) {}
const std::vector<TargetType>& target() const { return targets_; }
template <TargetType Target>
target_ptr_t<Target> CreateContext() {
return target_ptr_t<Target>(new Context<Target>);
}
private:
std::vector<TargetType> targets_;
};
} // namespace lite
} // namespace paddle
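// Usage sketch (illustrative only): build a per-op context holder for a
// single target, then derive a device context from it.
//
//   paddle::lite::OpContext op_ctx(TARGET(kHost));
//   auto ctx = op_ctx.CreateContext<TARGET(kHost)>();
//   ctx->SetDeviceId(0);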
nv_library(target_wrapper_cuda SRCS target_wrapper.cc)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Created by chunwei on 19-2-23.
//
#include "target_wrapper.h"
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
namespace paddle {
namespace lite {
namespace cuda {} // namespace cuda
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
add_subdirectory(host)
add_subdirectory(arm)
cc_library(fc_compute_host SRCS fc_compute.cc DEPS tensor_lite)
cc_library(relu_compute_host SRCS relu_compute.cc DEPS tensor_lite)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/kernels/host/fc_compute.h"
#include <Eigen/Core>
#include "paddle/fluid/lite/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace host {
// NOTE should eventually use a pure std C++ implementation; Eigen is used
// here for brevity. Tensor data is row-major, so map it as row-major
// matrices.
void FcCompute::Run() {
  using matrix_t =
      Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
  using matrix_map_t = Eigen::Map<matrix_t>;
  auto& param = this->param<param_t>();
  CHECK_EQ(param.in_mat_dims.size(), 2UL);
  CHECK_EQ(param.output->dims().size(), 2UL);
  // input is MxK, weight is KxN, output is MxN.
  Eigen::Map<const matrix_t> input(param.input->data<float>(),
                                   param.in_mat_dims[0], param.in_mat_dims[1]);
  Eigen::Map<const matrix_t> weight(param.w->data<float>(), param.w->dims()[0],
                                    param.w->dims()[1]);
  matrix_map_t output(param.output->mutable_data<float>(),
                      param.output->dims()[0], param.output->dims()[1]);
  output = input * weight;
  if (param.bias) {
    // bias holds w_dims[1] elements whether stored as (1, N) or (N);
    // broadcast it over the output rows.
    Eigen::Map<const Eigen::RowVectorXf> bias(param.bias->data<float>(),
                                              param.w->dims()[1]);
    output.rowwise() += bias;
  }
}
} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(fc, kHost, kFloat, paddle::lite::kernels::host::FcCompute);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/lite/op_kernel.h"
#include "paddle/fluid/lite/operators/fc_op.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace host {
class FcCompute final : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
public:
using param_t = operators::FcParam;
void Run() override;
virtual ~FcCompute() = default;
};
} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/kernels/relu_compute.h"
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/lite/op_kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace host {
class ReluCompute final : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
public:
};
} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/memory.h"
namespace paddle {
namespace lite {} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <glog/logging.h>
#include "target_wrapper.h"
namespace paddle {
namespace lite {
inline void* TargetMalloc(TargetType target, size_t size) {
void* data{nullptr};
switch (static_cast<int>(target)) {
case static_cast<int>(TargetType::kX86):
data = TargetWrapper<TARGET(kX86)>::Malloc(size);
break;
case static_cast<int>(TargetType::kCUDA):
data = TargetWrapper<TARGET(kCUDA)>::Malloc(size);
break;
case static_cast<int>(TargetType::kARM):
data = TargetWrapper<TARGET(kARM)>::Malloc(size);
break;
case static_cast<int>(TargetType::kHost):
data = TargetWrapper<TARGET(kHost)>::Malloc(size);
break;
default:
LOG(FATAL) << "Unknown type";
}
return data;
}
inline void TargetFree(TargetType target, void* data) {
  switch (static_cast<int>(target)) {
    case static_cast<int>(TargetType::kX86):
      TargetWrapper<TARGET(kX86)>::Free(data);
      break;
    case static_cast<int>(TargetType::kCUDA):
      TargetWrapper<TARGET(kCUDA)>::Free(data);
      break;
    case static_cast<int>(TargetType::kARM):
      TargetWrapper<TARGET(kARM)>::Free(data);
      break;
    case static_cast<int>(TargetType::kHost):
      TargetWrapper<TARGET(kHost)>::Free(data);
      break;
    default:
      LOG(FATAL) << "Unknown type";
  }
}
// Memory buffer manager.
class Buffer {
 public:
  Buffer() = default;
  Buffer(TargetType target, size_t size) { ResetLazy(target, size); }
  void* data() const { return data_; }
  // Reallocate only when the target changes or the requested size outgrows
  // the current space; otherwise reuse the existing allocation.
  void ResetLazy(TargetType target, size_t size) {
    if (target != target_ || space_ < size) {
      Free();
      data_ = TargetMalloc(target, size);
      target_ = target;
      space_ = size;
    }
  }
  void ResizeLazy(size_t size) { ResetLazy(target_, size); }
  void Free() {
    if (space_ > 0) {
      TargetFree(target_, data_);
    }
    data_ = nullptr;
    target_ = TargetType::kHost;
    space_ = 0;
  }

 private:
  size_t space_{0};
  void* data_{nullptr};
  TargetType target_{TargetType::kHost};
};
} // namespace lite
} // namespace paddle
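// Usage sketch (illustrative only): a Buffer reallocates only when the
// target changes or the requested size outgrows the current space.
//
//   paddle::lite::Buffer buf(TargetType::kHost, 256);  // allocates 256 bytes
//   buf.ResizeLazy(128);   // no-op, the existing space suffices
//   buf.ResizeLazy(1024);  // frees and reallocates
//   buf.Free();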
cc_library(model_parser SRCS model_parser.cc)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Created by chunwei on 19-2-25.
//
#include "model_parser.h"
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// This file contains model format related operations, such as load a model,
// parse an operator definitions and so on.
#include <string>
#include <vector>
namespace paddle {
namespace lite {
void LoadProgram(const std::string& path);
void LoadParams(const std::string& path);
void LoadModel(const std::string& model_dir);
} // namespace lite
} // namespace paddle
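// Usage sketch (illustrative only; assumes the fluid layout of a __model__
// program file plus parameter files in one directory):
//
//   paddle::lite::LoadModel("./fc_model");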
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <glog/logging.h>
#include <boost/variant.hpp>
#include <map>
#include <memory>
#include <string>
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/lite/context.h"
#include "paddle/fluid/lite/target_wrapper.h"
#include "paddle/fluid/lite/utils/all.h"
namespace paddle {
namespace lite {
// Light-weight kernel implementation.
// The OpKernel is designed to implement the specific algorithm on a target
// device.
template <TargetType Target, PrecisionType Precision>
class OpKernel {
public:
using context_t = Context<Target>;
using context_ptr_t = std::unique_ptr<context_t>;
OpKernel() = default;
void SetContext(context_ptr_t&& ctx) { context_ = std::move(ctx); }
void SetParam(any param) { param_ = param; }
template <typename Param>
Param& param() const {
  auto* p = any_cast<Param>(&param_);
  CHECK(p) << "stored param type mismatches the requested one";
  return *p;
}
virtual void Run() { CHECK(false) << "Not Implemented"; }
virtual ~OpKernel() = default;
protected:
context_ptr_t context_;
mutable any param_;
};
} // namespace lite
} // namespace paddle
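// Sketch of a concrete kernel (see kernels/host/fc_compute.h in this commit
// for a real one); MyCompute and MyParam are hypothetical names:
//
//   class MyCompute final : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
//    public:
//     void Run() override {
//       auto& p = this->param<MyParam>();  // typed view of the stored param
//       // ... compute with p ...
//     }
//   };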
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "op_lite.h"
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <glog/logging.h>
#include <boost/variant.hpp>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "context.h"
#include "op_kernel.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/variable.h"
namespace paddle {
namespace lite {
using any_t = boost::variant<int, float, framework::Variable *>;
using anys_t = std::map<std::string, any_t>;
// For registry factory.
struct Registry {
void Touch() {}
};
/**
 * The base class of light-weight operators, currently just used in inference
 * to eliminate the overhead of some operations in the current framework.
 *
 * The Operator is designed as follows:
 * - it can have some members to hold the argument addresses,
 * - it should act just like a function call; no more logic should be
 *   included.
 */
class OpLite : public Registry {
public:
enum class KernelStrategy {
// Return the user specified one.
kStatic = 0,
// Specify the expected kernel externally.
kSpecified,
// Run each kernel to evaluate and get the best kernel.
kRuntime,
};
OpLite() = default;
explicit OpLite(std::unique_ptr<OpContext> &&x) : op_context_(std::move(x)) {}
virtual bool CheckShape() const { return true; }
virtual bool InferShape() const { return true; }
virtual bool Run() = 0;
virtual bool Build(const framework::OpDesc &opdesc,
framework::Scope *scope) = 0;
virtual std::string DebugString() const = 0;
virtual void StaticPickKernel(
    const std::vector<TargetType> &valid_targets) = 0;
void PickBestKernel(const std::vector<TargetType> &valid_places,
                    KernelStrategy kernel_strategy = KernelStrategy::kStatic);
// Create all the kernels for the valid targets.
void CreateKernels();
virtual ~OpLite() = default;
protected:
std::unique_ptr<OpContext> op_context_;
};
} // namespace lite
} // namespace paddle
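// Sketch of a concrete op (see operators/fc_op.h in this commit for a real
// one); MyOp and MyParam are hypothetical names, and the remaining pure
// virtuals (Build, StaticPickKernel) are omitted:
//
//   class MyOp : public OpLite {
//    public:
//     bool CheckShape() const override { /* validate param_ dims */ return true; }
//     bool InferShape() const override { /* set output dims */ return true; }
//     bool Run() override { return true; }
//     std::string DebugString() const override { return "my_op"; }
//
//    private:
//     mutable MyParam param_;
//   };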
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "op_registry.h"
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <array>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include "paddle/fluid/lite/op_kernel.h"
#include "paddle/fluid/lite/op_lite.h"
#include "paddle/fluid/lite/target_wrapper.h"
#include "paddle/fluid/lite/utils/all.h"
namespace paddle {
namespace lite {
using KernelFunc = std::function<void()>;
using KernelFuncCreator = std::function<std::unique_ptr<KernelFunc>()>;
class LiteOpRegistry final : public Factory<OpLite> {
public:
static LiteOpRegistry &Global() {
static auto *x = new LiteOpRegistry;
return *x;
}
private:
LiteOpRegistry() = default;
};
template <typename OpClass>
class OpLiteRegistor : public Registor<OpClass> {
public:
explicit OpLiteRegistor(const std::string &op_type)
: Registor<OpClass>([&] {
LiteOpRegistry::Global().Register(
op_type, []() -> std::unique_ptr<OpLite> {
return std::unique_ptr<OpLite>(new OpClass);
});
}) {}
};
template <TargetType Target, PrecisionType Precision>
class KernelRegistryForTarget : public Factory<OpKernel<Target, Precision>> {
 public:
  static KernelRegistryForTarget &Global() {
    static auto *x = new KernelRegistryForTarget;
    return *x;
  }
};
class KernelRegistry final {
public:
KernelRegistry() {
#define INIT_FOR(target__, precision__) \
registries_[KernelRegistry::GetKernelOffset<TARGET(target__), \
PRECISION(precision__)>()] = \
&KernelRegistryForTarget<TARGET(target__), \
PRECISION(precision__)>::Global();
// Currently, just register 2 kernel targets.
INIT_FOR(kARM, kFloat);
INIT_FOR(kHost, kFloat);
#undef INIT_FOR
}
static KernelRegistry &Global() {
static auto *x = new KernelRegistry;
return *x;
}
template <TargetType Target, PrecisionType Precision>
void Register(const std::string &name,
typename KernelRegistryForTarget<Target, Precision>::creator_t
&&creator) {
using kernel_registor_t = KernelRegistryForTarget<Target, Precision>;
any_cast<kernel_registor_t *>(
registries_[GetKernelOffset<Target, Precision>()])
->Register(name, std::move(creator));
}
// Get a kernel registry offset in all the registries.
template <TargetType Target, PrecisionType Precision>
static constexpr int GetKernelOffset() {
  return kNumPrecisions * static_cast<int>(Target) +
         static_cast<int>(Precision);
}
private:
std::array<any, kNumTargets * kNumPrecisions> registries_;
};
template <TargetType target, PrecisionType precision, typename KernelType>
class KernelRegistor : public lite::Registor<KernelType> {
 public:
  explicit KernelRegistor(const std::string &op_type)
      : Registor<KernelType>([&] {
          KernelRegistry::Global().Register<target, precision>(
              op_type, []() -> std::unique_ptr<KernelType> {
                return std::unique_ptr<KernelType>(new KernelType);
              });
        }) {}
};
} // namespace lite
} // namespace paddle
// Operator registry
#define LITE_OP_REGISTER_INSTANCE(op_type__) op_type__##__registry__instance__
#define LITE_OP_REGISTER_FAKE(op_type__) op_type__##__registry__
#define REGISTER_LITE_OP(op_type__, OpClass) \
static paddle::lite::OpLiteRegistor<OpClass> LITE_OP_REGISTER_INSTANCE( \
op_type__)(#op_type__);
#define USE_LITE_OP(op_type__)                                   \
  int LITE_OP_REGISTER_FAKE(op_type__) __attribute__((unused)) = \
      LITE_OP_REGISTER_INSTANCE(op_type__).Touch();
// Kernel registry
#define LITE_KERNEL_REGISTER(op_type__, target__, precision__) \
op_type__##target__##precision__##__registor__
#define LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__) \
op_type__##target__##precision__##__registor__instance__
#define LITE_KERNEL_REGISTER_FAKE(op_type__, target__, precision__) \
  op_type__##target__##precision__##__registor__fake__
#define REGISTER_LITE_KERNEL(op_type__, target__, precision__, KernelClass) \
static paddle::lite::KernelRegistor<TARGET(target__), \
PRECISION(precision__), KernelClass> \
LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, \
precision__)(#op_type__);
#define USE_LITE_KERNEL(op_type__, target__, precision__)                 \
  int LITE_KERNEL_REGISTER_FAKE(op_type__, target__, precision__)         \
      __attribute__((unused)) =                                           \
          LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__) \
              .Touch();
cc_library(fc_op_lite SRCS fc_op.cc DEPS op_lite)
cc_library(relu_op_lite SRCS relu_op.cc DEPS op_lite)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fc_op.h"
#include "paddle/fluid/lite/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool FcOpLite::CheckShape() const {
CHECK_OR_FALSE(param_.input);
CHECK_OR_FALSE(param_.output);
CHECK_OR_FALSE(param_.w);
// bias is optional.
const auto input_dims = param_.input->dims();
const auto w_dims = param_.w->dims();
if (param_.bias) {
const auto bias_dims = param_.bias->dims();
if (bias_dims.size() == 2) {
CHECK_EQ_OR_FALSE(bias_dims[0], 1);
CHECK_EQ_OR_FALSE(bias_dims[1], w_dims[1]);
} else if (bias_dims.size() == 1) {
CHECK_EQ_OR_FALSE(bias_dims[0], w_dims[1]);
}
}
CHECK_EQ_OR_FALSE(w_dims.size(), 2UL);
CHECK_GT_OR_FALSE(input_dims.size(),
static_cast<size_t>(param_.in_num_col_dims));
param_.in_mat_dims = lite::flatten_to_2d(input_dims, param_.in_num_col_dims);
CHECK_EQ_OR_FALSE(param_.in_mat_dims[1], w_dims[0]);
return true;
}
bool FcOpLite::InferShape() const {
const auto input_dims = param_.input->dims();
const auto w_dims = param_.w->dims();
// Set output dims
std::vector<int> output_dims(param_.in_num_col_dims + 1, 0);
for (int i = 0; i < param_.in_num_col_dims; ++i) {
output_dims[i] = input_dims[i];
}
output_dims.back() = w_dims[1];
param_.output->Resize(output_dims);
// share LoD
// param_.output->set_lod(param_.input->lod());
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(fc, paddle::lite::operators::FcOpLite);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <vector>
#include "paddle/fluid/lite/op_lite.h"
#include "paddle/fluid/lite/tensor.h"
#include "paddle/fluid/lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
struct FcParam {
Tensor* input{nullptr};
Tensor* w{nullptr};
Tensor* bias{nullptr};
Tensor* output{nullptr};
// the input matrix dimensions.
lite::DDim in_mat_dims;
int in_num_col_dims{0};
};
class FcOpLite : public OpLite {
public:
FcOpLite() {}
bool CheckShape() const override;
bool InferShape() const override;
bool Run() override { return false; }
bool Build(const framework::OpDesc& opdesc,
framework::Scope* scope) override {
return false;
}
std::string DebugString() const override { return "fc"; }
void StaticPickKernel(const std::vector<TargetType>& valid_targets) override {}
private:
mutable FcParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/operators/relu_op.h"
#include "paddle/fluid/lite/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool ReluOp::CheckShape() const { return true; }
bool ReluOp::InferShape() const {
CHECK_OR_FALSE(param_.input);
CHECK_OR_FALSE(param_.output);
// TODO(Superjomn) Enable data sharing.
param_.output->Resize(param_.input->dims());
// param_.output->ShareDataWith(*param_.input);
// share lod
// param_.output->set_lod(param_.input->lod());
return true;
}
bool ReluOp::Run() { return false; }
bool ReluOp::Build(const framework::OpDesc &opdesc, framework::Scope *scope) {
return false;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(relu, paddle::lite::operators::ReluOp);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <vector>
#include "paddle/fluid/lite/op_lite.h"
#include "paddle/fluid/lite/tensor.h"
#include "paddle/fluid/lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
struct ReluParam {
Tensor* input{nullptr};
Tensor* output{nullptr};
};
class ReluOp : public OpLite {
public:
ReluOp() {}
bool CheckShape() const override;
bool InferShape() const override;
bool Run() override;
bool Build(const framework::OpDesc& opdesc, framework::Scope* scope) override;
std::string DebugString() const override { return "relu"; }
void StaticPickKernel(const std::vector<TargetType>& valid_targets) override {}
private:
mutable ReluParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <iostream>
namespace paddle {
namespace lite {
enum class TargetType { kHost = 0, kX86, kCUDA, kARM, kLastAsPlaceHolder };
#define TARGET(item__) paddle::lite::TargetType::item__
#define TARGET_VAL(item__) static_cast<int>(TARGET(item__))
constexpr int kNumTargets = TARGET_VAL(kLastAsPlaceHolder) - TARGET_VAL(kHost);
/*
template <TargetType target>
struct Target {};
using Host = Target<TargetType::kHost>;
using X86 = Target<TargetType::kX86>;
using CUDA = Target<TargetType::kCUDA>;
using ARM = Target<TargetType::kARM>;
*/
enum class PrecisionType { kFloat = 0, kInt8, kLastAsPlaceHolder };
#define PRECISION(item__) paddle::lite::PrecisionType::item__
#define PRECISION_VAL(item__) static_cast<int>(PRECISION(item__))
constexpr int kNumPrecisions =
PRECISION_VAL(kLastAsPlaceHolder) - PRECISION_VAL(kFloat);
// Event sync for multi-stream devices like CUDA and OpenCL.
template <TargetType Target>
class Event {};
// Memory copy directions.
enum class IoDirection {
HtoH = 0,
HtoD,
DtoH,
};
// This interface should be specified by each kind of target.
template <TargetType Target>
class TargetWrapper {
public:
using stream_t = int;
using event_t = Event<Target>;
static size_t num_devices() { return 0; }
static size_t maximum_stream() { return 0; }
static void CreateStream(stream_t* stream) {}
static void DestroyStream(const stream_t& stream) {}
static void CreateEvent(event_t* event) {}
static void DestroyEvent(const event_t& event) {}
static void RecordEvent(const event_t& event) {}
static void SyncEvent(const event_t& event) {}
static void StreamSync(const stream_t& stream) {}
static void* Malloc(size_t size) { return nullptr; }
static void Free(void* ptr) {}
static void MemcpySync(void* dst, void* src, size_t size, IoDirection dir) {}
static void MemcpyAsync(void* dst, void* src, size_t size,
const stream_t& stream, IoDirection dir) {
MemcpySync(dst, src, size, dir);
}
};
} // namespace lite
} // namespace paddle
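// A device backend specializes the members it supports, in the same spirit
// as x86/target_wrapper.cc in this commit. Sketch (assuming plain
// malloc/free suffice for the host target):
//
//   template <>
//   void* TargetWrapper<TARGET(kHost)>::Malloc(size_t size) {
//     return std::malloc(size);
//   }
//   template <>
//   void TargetWrapper<TARGET(kHost)>::Free(void* ptr) { std::free(ptr); }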
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "tensor.h"
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <numeric>
#include <vector>
#include "memory.h"
namespace paddle {
namespace lite {
template <TargetType Target>
class EventTree {
public:
using event_t = Event<Target>;
void AddChild(const event_t& event) { children_.push_back(event); }
void Sync() {
for (auto& event : children_) {
TargetWrapper<Target>::SyncEvent(event);
}
}
private:
std::vector<event_t> children_;
};
using DDim = std::vector<int>;
inline DDim SliceDims(const DDim& dims, int begin, int end) {
  return DDim(dims.begin() + begin, dims.begin() + end);
}
inline int product(const DDim& dims) {
  return std::accumulate(dims.begin(), dims.end(), 1,
                         [](int a, int b) { return a * b; });
}
inline DDim flatten_to_2d(const DDim& dims, int col) {
  return DDim({product(SliceDims(dims, 0, col)),
               product(SliceDims(dims, col, static_cast<int>(dims.size())))});
}
// A light-weight tensor implementation.
class Tensor {
public:
void SyncEventTree();
template <typename T>
const T* data() const {
return static_cast<const T*>(buffer_.data());
}
void Resize(const DDim& ddim) { dims_ = ddim; }
const DDim& dims() const { return dims_; }
template <typename T>
T* mutable_data() {
  buffer_.ResetLazy(target_, product(dims_) * sizeof(T));
  return static_cast<T*>(buffer_.data());
}
bool IsInitialized() const { return buffer_.data() != nullptr; }
private:
TargetType target_{TargetType::kHost};
DDim dims_;
Buffer buffer_;
};
} // namespace lite
} // namespace paddle
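// Usage sketch (illustrative only, host target assumed):
//
//   paddle::lite::Tensor x;
//   x.Resize({2, 3});
//   float* d = x.mutable_data<float>();  // lazily allocates 6 * sizeof(float)
//   for (int i = 0; i < product(x.dims()); ++i) d[i] = 1.f;
//   const float* ro = x.data<float>();   // read-only view of the same buffer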
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/lite/utils/any.h"
#include "paddle/fluid/lite/utils/check.h"
#include "paddle/fluid/lite/utils/factory.h"
#include "paddle/fluid/lite/utils/macros.h"
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <exception>
#include <memory>
#include <type_traits>
#include <typeinfo>
#include <utility>
// This is an equivalent implementation of boost::any. We implement it to
// avoid including the whole boost library and to keep the inference library
// small. This code references https://gist.github.com/shoooe/9202235
namespace paddle {
namespace lite {
class any;
template <class Type>
Type any_cast(any&);
template <class Type>
Type any_cast(const any&);
template <class Type>
Type* any_cast(any*);
template <class Type>
const Type* any_cast(const any*);
struct bad_any_cast : public std::bad_cast {};
class any {
public:
template <class Type>
friend Type any_cast(any&);
template <class Type>
friend Type any_cast(const any&);
template <class Type>
friend Type* any_cast(any*);
template <class Type>
friend const Type* any_cast(const any*);
any() : ptr(nullptr) {}
any(any&& x) : ptr(std::move(x.ptr)) {}
any(const any& x) {
  if (x.ptr) ptr = x.ptr->clone();
}
template <class Type>
explicit any(const Type& x)
: ptr(new concrete<typename std::decay<const Type>::type>(x)) {}
any& operator=(any&& rhs) {
ptr = std::move(rhs.ptr);
return (*this);
}
any& operator=(const any& rhs) {
ptr = std::move(any(rhs).ptr);
return (*this);
}
// Constrained so that assigning an `any` picks the copy/move assignments
// above instead of wrapping the `any` inside another `any`.
template <class T, class = typename std::enable_if<!std::is_same<
                       typename std::decay<T>::type, any>::value>::type>
any& operator=(T&& x) {
  ptr.reset(new concrete<typename std::decay<T>::type>(std::forward<T>(x)));
  return (*this);
}
void clear() { ptr.reset(nullptr); }
bool empty() const { return ptr == nullptr; }
const std::type_info& type() const {
return (!empty()) ? ptr->type() : typeid(void);
}
private:
struct placeholder {
virtual std::unique_ptr<placeholder> clone() const = 0;
virtual const std::type_info& type() const = 0;
virtual ~placeholder() {}
};
template <class T>
struct concrete : public placeholder {
explicit concrete(T&& x) : value(std::move(x)) {}
explicit concrete(const T& x) : value(x) {}
virtual std::unique_ptr<placeholder> clone() const override {
return std::unique_ptr<placeholder>(new concrete<T>(value));
}
virtual const std::type_info& type() const override { return typeid(T); }
T value;
};
std::unique_ptr<placeholder> ptr;
};
template <class Type>
Type any_cast(any& val) {
  if (val.empty() || val.ptr->type() != typeid(Type)) throw bad_any_cast();
  return static_cast<any::concrete<Type>*>(val.ptr.get())->value;
}
template <class Type>
Type any_cast(const any& val) {
  return any_cast<Type>(const_cast<any&>(val));
}
template <class Type>
Type* any_cast(any* ptr) {
  if (!ptr || ptr->empty() || ptr->type() != typeid(Type)) return nullptr;
  return &static_cast<any::concrete<Type>*>(ptr->ptr.get())->value;
}
template <class Type>
const Type* any_cast(const any* ptr) {
  return any_cast<Type>(const_cast<any*>(ptr));
}
} // namespace lite
} // namespace paddle
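// Usage sketch (illustrative only):
//
//   paddle::lite::any v;
//   v = 42;                                    // stores an int
//   int i = paddle::lite::any_cast<int>(v);    // by value; throws on mismatch
//   int* p = paddle::lite::any_cast<int>(&v);  // non-throwing pointer form
//   // any_cast<float>(v) would throw bad_any_cast.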
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <glog/logging.h>
#define CHECK_OR_FALSE(cond) \
if (!(cond)) { \
LOG(ERROR) << #cond << " test error!"; \
return false; \
}
#define CHECK_EQ_OR_FALSE(a__, b__) \
if ((a__) != (b__)) { \
LOG(ERROR) << #a__ << " == " << #b__ << " check failed!"; \
LOG(ERROR) << a__ << " != " << b__; \
return false; \
}
#define CHECK_GT_OR_FALSE(a__, b__) \
if (!((a__) > (b__))) { \
LOG(ERROR) << #a__ << " > " << #b__ << " check failed!"; \
LOG(ERROR) << a__ << " <= " << b__; \
return false; \
}
#define CHECK_GE_OR_FALSE(a__, b__) \
if (!((a__) >= (b__))) { \
LOG(ERROR) << #a__ << " >= " << #b__ << " check failed!"; \
LOG(ERROR) << a__ << " < " << b__; \
return false; \
}
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <glog/logging.h>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
namespace paddle {
namespace lite {
template <typename ItemType>
class Factory {
public:
using item_t = ItemType;
using self_t = Factory<item_t>;
using item_ptr_t = std::unique_ptr<item_t>;
using creator_t = std::function<item_ptr_t()>;
static Factory& Global() {
static Factory* x = new self_t;
return *x;
}
void Register(const std::string& op_type, creator_t&& creator) {
CHECK(!creators_.count(op_type)) << "The op " << op_type
<< " has already registered";
creators_.emplace(op_type, std::move(creator));
}
item_ptr_t Create(const std::string& op_type) const {
auto it = creators_.find(op_type);
CHECK(it != creators_.end());
return it->second();
}
protected:
std::unordered_map<std::string, creator_t> creators_;
};
/* A helper class that runs a lambda at static-initialization time; used to
 * trigger op/kernel registration.
 */
template <typename Type>
class Registor {
public:
Registor(std::function<void()>&& functor) { functor(); }
int Touch() { return 0; }
};
} // namespace lite
} // namespace paddle
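// Usage sketch (illustrative only): LiteOpRegistry in op_registry.h is a
// Factory<OpLite>, and each *Registor object runs its registration lambda
// during static initialization, so lookup by op type works afterwards:
//
//   auto op = paddle::lite::LiteOpRegistry::Global().Create("fc");
//   LOG(INFO) << op->DebugString();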
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifndef DISALLOW_COPY_AND_ASSIGN
#define DISALLOW_COPY_AND_ASSIGN(class__) \
class__(const class__&) = delete; \
class__& operator=(const class__&) = delete;
#endif
cc_library(target_wrapper_x86 SRCS target_wrapper.cc)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "target_wrapper.h"
#include <algorithm>
namespace paddle {
namespace lite {
template <>
void TargetWrapper<TARGET(kX86)>::MemcpySync(void* dst, void* src, size_t size,
                                             IoDirection dir) {
  std::copy_n(reinterpret_cast<uint8_t*>(src), size,
              reinterpret_cast<uint8_t*>(dst));
}
template class TargetWrapper<TARGET(kX86)>;
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/lite/target_wrapper.h"
namespace paddle {
namespace lite {
namespace x86 {} // namespace x86
} // namespace lite
} // namespace paddle