提交 1fb93746 编写于 作者: S superjomn

correct the running logic of host model

- make the target wrapper for host works
- code clean
上级 ca629eb4
add_subdirectory(core)
add_subdirectory(x86)
add_subdirectory(host)
add_subdirectory(cuda)
add_subdirectory(operators)
add_subdirectory(kernels)
......
cc_library(cxx_api_lite SRCS cxx_api.cc DEPS scope_lite op_executor_lite host_kernels ops_lite optimizer_lite)
cc_library(cxx_api_lite SRCS cxx_api.cc DEPS scope_lite op_executor_lite host_kernels ops_lite optimizer_lite target_wrapper_host)
cc_test(test_cxx_api_lite SRCS cxx_api_test.cc DEPS cxx_api_lite model_parser_lite)
......@@ -58,7 +58,7 @@ class Predictor {
const Tensor* GetOutput(size_t offset) {
auto* _fetch_list = program_->exec_scope()->FindVar("fetch");
CHECK(_fetch_list) << "no fatch variable in exec_scope";
auto fetch_list = _fetch_list->Get<std::vector<Tensor>>();
auto& fetch_list = *_fetch_list->GetMutable<std::vector<lite::Tensor>>();
CHECK_LT(offset, fetch_list.size()) << "offset " << offset << " overflow";
return &fetch_list.at(offset);
}
......
......@@ -28,8 +28,22 @@ TEST(CXXApi, test) {
auto* input_tensor = predictor.GetInput(0);
input_tensor->Resize({100, 100});
input_tensor->mutable_data<float>();
auto* data = input_tensor->mutable_data<float>();
for (int i = 0; i < 100 * 100; i++) {
data[i] = i;
}
LOG(INFO) << "input " << input_tensor;
LOG(INFO) << "input " << *input_tensor;
predictor.Run();
auto* out = predictor.GetOutput(0);
LOG(INFO) << out << " memory size " << out->memory_size();
LOG(INFO) << "out " << out->data<float>()[0];
LOG(INFO) << "out " << out->data<float>()[1];
LOG(INFO) << "dims " << out->dims();
LOG(INFO) << "out " << *out;
}
} // namespace lite
......
......@@ -65,6 +65,13 @@ class KernelBase {
virtual ~KernelBase() = default;
std::string DebugString() const {
std::stringstream ss;
ss << op_type() << ":" << TargetToStr(target()) << "/"
<< PrecisionToStr(precision()) << "/" << DataLayoutToStr(layout());
return ss.str();
}
protected:
std::unique_ptr<KernelContext> context_;
mutable operators::param_t param_;
......
......@@ -21,18 +21,16 @@ namespace lite {
static void* TargetMalloc(TargetType target, size_t size) {
void* data{nullptr};
switch (static_cast<int>(target)) {
case static_cast<int>(TargetType::kX86):
data = TargetWrapper<TARGET(kX86)>::Malloc(size);
switch (target) {
case TargetType::kHost:
case TargetType::kX86:
data = TargetWrapper<TARGET(kHost)>::Malloc(size);
break;
case static_cast<int>(TargetType::kCUDA):
case TargetType::kCUDA:
data = TargetWrapper<TARGET(kCUDA)>::Malloc(size);
break;
case static_cast<int>(TargetType::kHost):
data = TargetWrapper<TARGET(kHost)>::Malloc(size);
break;
default:
LOG(FATAL) << "Unknown type";
LOG(FATAL) << "Unknown supported target " << TargetToStr(target);
}
return data;
}
......@@ -52,17 +50,19 @@ static void TargetFree(TargetType target, void* data) {
static void TargetCopy(TargetType target, void* dst, const void* src,
size_t size) {
switch (static_cast<int>(target)) {
case static_cast<int>(TargetType::kX86):
case static_cast<int>(TargetType::kHost):
switch (target) {
case TargetType::kX86:
case TargetType::kHost:
TargetWrapper<TARGET(kHost)>::MemcpySync(dst, src, size,
IoDirection::DtoD);
break;
case static_cast<int>(TargetType::kCUDA):
case TargetType::kCUDA:
TargetWrapper<TARGET(kCUDA)>::MemcpySync(dst, src, size,
IoDirection::DtoD);
break;
default:
LOG(FATAL) << "unsupported type";
}
}
......@@ -79,12 +79,10 @@ class Buffer {
void ResetLazy(TargetType target, size_t size) {
if (target != target_ || space_ < size) {
Free();
data_ = TargetMalloc(target, size);
target_ = target;
space_ = size;
}
if (size < space_) return;
target_ = target;
data_ = TargetMalloc(target, size);
space_ = size;
}
void ResizeLazy(size_t size) { ResetLazy(target_, size); }
......
......@@ -60,9 +60,6 @@ class SSAGraph : GraphBase {
op->SetValidPlaces(valid_places);
auto &new_node = node_storage_.back();
auto kernels = op->CreateKernels(valid_places);
for (auto &kernel : kernels) {
op->AttachKernel(kernel.get());
}
node_storage_.back().AsInstruct(op->op_type_, std::move(kernels), op,
op->op_info());
......
......@@ -29,6 +29,7 @@ std::vector<std::unique_ptr<KernelBase>> OpLite::CreateKernels(
(kernel_type.empty() ? op_type_ : kernel_type), place.target,
place.precision);
for (auto &&it : ks) {
AttachKernel(it.get());
kernels.emplace_back(std::move(it));
}
}
......
......@@ -105,6 +105,11 @@ struct Instruction {
void Run() {
CHECK(op_);
CHECK(kernel_);
LOG(INFO) << "running kernel> " << kernel_->DebugString();
if (UNLIKELY(first_epoch_)) {
first_epoch_ = false;
op_->CheckShape();
}
op_->InferShape();
kernel_->Run();
}
......@@ -112,6 +117,7 @@ struct Instruction {
private:
std::shared_ptr<OpLite> op_;
std::unique_ptr<KernelBase> kernel_;
bool first_epoch_{true};
};
/*
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#pragma once
#include <glog/logging.h>
#include <iostream>
#include <sstream>
......@@ -138,11 +139,48 @@ class TargetWrapper {
static void StreamSync(const stream_t& stream) {}
static void* Malloc(size_t size) { return new char[size]; }
static void Free(void* ptr) { delete[] static_cast<char*>(ptr); }
static void* Malloc(size_t size) {
LOG(FATAL) << "Unimplemented malloc for " << TargetToStr(Target);
return nullptr;
}
static void Free(void* ptr) { LOG(FATAL) << "Unimplemented"; }
static void MemcpySync(void* dst, const void* src, size_t size,
IoDirection dir) {
LOG(FATAL) << "Unimplemented";
}
static void MemcpyAsync(void* dst, const void* src, size_t size,
IoDirection dir, const stream_t& stream) {
MemcpySync(dst, src, size, dir);
}
};
// This interface should be specified by each kind of target.
template <>
class TargetWrapper<TARGET(kHost)> {
public:
using stream_t = int;
using event_t = int;
static size_t num_devices() { return 0; }
static size_t maximum_stream() { return 0; }
static void CreateStream(stream_t* stream) {}
static void DestroyStream(const stream_t& stream) {}
static void CreateEvent(event_t* event) {}
static void DestroyEvent(const event_t& event) {}
static void RecordEvent(const event_t& event) {}
static void SyncEvent(const event_t& event) {}
static void StreamSync(const stream_t& stream) {}
static void* Malloc(size_t size);
static void Free(void* ptr);
static void MemcpySync(void* dst, const void* src, size_t size,
IoDirection dir) {}
IoDirection dir);
static void MemcpyAsync(void* dst, const void* src, size_t size,
IoDirection dir, const stream_t& stream) {
MemcpySync(dst, src, size, dir);
......
......@@ -95,13 +95,15 @@ class Tensor {
dims_ = other.dims_;
target_ = other.target_;
lod_ = other.lod_;
memory_size_ = other.memory_size_;
}
void CopyDataFrom(const Tensor& other) {
dims_ = other.dims_;
target_ = other.target_;
lod_ = other.lod_;
*buffer_ = *other.buffer_;
memory_size_ = other.memory_size_;
buffer_->CopyDataFrom(*other.buffer_, memory_size_);
}
TargetType target() const { return target_; }
......
......@@ -38,9 +38,39 @@ namespace lite {
// The DNN system is simple, and the architecture can not process that many data
// types as a compiler, or that will turn out to a chaos.
//
// We should make sure that supported data types should be registered here, and
// keep the quantity small. And avoid using some special data types as op's IO,
// such as some runtime cache, that need to be avoided.
// We should make sure that the supported data types be registered here, and
// keep the quantity small and avoid using some special data types as op's
// inputs or outputs, such as some runtime cache, those types can't be processed
// by the MIR.
//
// A tensor with different places(target, precision, data layout or device)
// should be treated as different types. Different types might be compatible
// with each other, for example, the `VoidTy` means any type, so any other types
// can be treated as a `VoidTy`.
//
// The Different Types can transform to others by adding some special
// transforming operators, for example, a DataLayoutTransformOp can convert a
// `TensorFp32NCHWTy` to a `TensorFp32NHWCTy`; a IoCopyOp can convert a
// `TensorFp32NCHWTy(kHost)` to `TensorFp32NCHWTy(kCUDA)`. There are many other
// convertions between different Types, but there are some unsupportted type
// convertions, for example, there is noway to convert a `UnsupportedTy` to a
// `TensorAnyTy`.
//
// We use Types to declare the definition of a kernel, each inputs' and outputs'
// arguments have a specific Types.
//
// REGISTER_LITE_KERNEL(mul, kHost, kFloat,
// paddle::lite::kernels::host::MulCompute, def)
// .BindInput("X", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
// TARGET(kHost))})
// .BindInput("Y", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
// TARGET(kHost))})
// .BindOutput("Out",
// {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(TARGET(kHost))})
// .Finalize();
//
// The above definition will be used in MIR by Type inference and uncompatible
// types check.
//
// TODO(Superjomn) Add operator/kernel-wise static checking to avoid unsupported
// type mixed in the system.
......
......@@ -26,12 +26,14 @@ using TargetW = TargetWrapper<TARGET(kCUDA), cudaStream_t, cudaEvent_t>;
template <>
void* TargetW::Malloc(size_t size) {
return new char[size];
void* ptr{};
CHECK_EQ(cudaSuccess, cudaMalloc(&ptr, size));
return ptr;
}
template <>
void TargetW::Free(void* ptr) {
delete[] static_cast<char*>(ptr);
CHECK_EQ(cudaSuccess, cudaFree(ptr));
}
template <>
......
cc_library(target_wrapper_host SRCS target_wrapper.cc)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/target_wrapper.h"
#include <cstring>
namespace paddle {
namespace lite {
void* TargetWrapper<TARGET(kHost)>::Malloc(size_t size) {
return new char[size];
}
void TargetWrapper<TARGET(kHost)>::Free(void* ptr) {
delete[] static_cast<char*>(ptr);
}
void TargetWrapper<TARGET(kHost)>::MemcpySync(void* dst, const void* src,
size_t size, IoDirection dir) {
memcpy(dst, src, size);
}
} // namespace lite
} // namespace paddle
......@@ -28,6 +28,8 @@ class FeedCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
auto &param = Param<operators::FeedParam>();
const Tensor &feed_item = param.feed_list->at(param.col);
param.out->CopyDataFrom(feed_item);
LOG(INFO) << "FEED input " << feed_item << " col " << param.col;
LOG(INFO) << "FEED output " << *param.out;
}
};
......@@ -40,6 +42,6 @@ REGISTER_LITE_KERNEL(feed, kHost, kFloat,
paddle::lite::kernels::host::FeedCompute, def)
.BindInput("X", {paddle::lite::Type::Get<paddle::lite::TensorAnyTy>(
TARGET(kHost))})
.BindOutput("Out", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
.BindOutput("Out", {paddle::lite::Type::Get<paddle::lite::TensorAnyTy>(
TARGET(kHost))})
.Finalize();
......@@ -32,7 +32,7 @@ class FetchCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
}
auto& dst = fetch_list->at(param.col);
dst.CopyDataFrom(*param.input);
dst.ShareDataWith(*param.input);
}
};
......
......@@ -40,22 +40,23 @@ class MulCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
using param_t = operators::MulParam;
void Run() override {
auto& theparam = Param<operators::MulParam>();
core::dim2 x_shape(
{product(theparam.x->dims().begin(),
theparam.x->dims().begin() + theparam.x_num_col_dims),
product(theparam.x->dims().begin() + theparam.x_num_col_dims,
theparam.x->dims().end())});
auto& param = Param<operators::MulParam>();
core::dim2 x_shape({product(param.x->dims().begin(),
param.x->dims().begin() + param.x_num_col_dims),
product(param.x->dims().begin() + param.x_num_col_dims,
param.x->dims().end())});
core::dim2 y_shape(
{product(theparam.y->dims().begin(),
theparam.y->dims().begin() + theparam.x_num_col_dims),
product(theparam.y->dims().begin() + theparam.x_num_col_dims,
theparam.y->dims().end())});
core::dim2 y_shape({product(param.y->dims().begin(),
param.y->dims().begin() + param.x_num_col_dims),
product(param.y->dims().begin() + param.x_num_col_dims,
param.y->dims().end())});
mul_compute_eigen(theparam.x->data<float>(), x_shape.x, x_shape.y, //
theparam.y->data<float>(), y_shape.x, y_shape.y, //
theparam.output->mutable_data<float>());
mul_compute_eigen(param.x->data<float>(), x_shape.x, x_shape.y, //
param.y->data<float>(), y_shape.x, y_shape.y, //
param.output->mutable_data<float>());
LOG(INFO) << "MUL x " << *param.x;
LOG(INFO) << "MUL W " << *param.y;
LOG(INFO) << "MUL out " << *param.output;
}
virtual ~MulCompute() = default;
......
......@@ -36,10 +36,10 @@ class ScaleCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
using param_t = operators::MulParam;
void Run() override {
auto& theparam = Param<operators::ScaleParam>();
scale_compute(theparam.x->data<float>(), theparam.x->mutable_data<float>(),
product(theparam.x->dims()), theparam.scale, theparam.bias,
theparam.bias_after_scale);
auto& param = Param<operators::ScaleParam>();
scale_compute(param.x->data<float>(), param.output->mutable_data<float>(),
product(param.x->dims()), param.scale, param.bias,
param.bias_after_scale);
}
virtual ~ScaleCompute() = default;
......
......@@ -77,7 +77,7 @@ void TensorFromStream(std::istream &is, lite::Tensor *tensor) {
DO(INT64, int64_t);
#undef DO
default:
LOG(FATAL) << "unknown type";
LOG(FATAL) << "unknown type " << desc.data_type();
}
is.read(static_cast<char *>(buf), size);
......
......@@ -51,7 +51,6 @@ class FeedOp : public OpLite {
// NOTE need boost here
// TODO(Superjomn) drop the need of framework::op_desc
param_.col = boost::get<int>(opdesc.GetAttr("col"));
kernel_->SetParam(param_);
return true;
}
......
......@@ -21,3 +21,10 @@
#endif
#define LITE_UNIMPLEMENTED CHECK(false) << "Not Implemented";
#ifndef LIKELY
#define LIKELY(x) __builtin_expect(!!(x), 1)
#endif
#ifndef UNLIKELY
#define UNLIKELY(x) __built_expect(!!(x), 0)
#endif
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册