Commit 1fb93746 authored by superjomn

correct the running logic of host model

- make the target wrapper for host work
- code cleanup
Parent ca629eb4
 add_subdirectory(core)
 add_subdirectory(x86)
+add_subdirectory(host)
 add_subdirectory(cuda)
 add_subdirectory(operators)
 add_subdirectory(kernels)
......
-cc_library(cxx_api_lite SRCS cxx_api.cc DEPS scope_lite op_executor_lite host_kernels ops_lite optimizer_lite)
+cc_library(cxx_api_lite SRCS cxx_api.cc DEPS scope_lite op_executor_lite host_kernels ops_lite optimizer_lite target_wrapper_host)
 cc_test(test_cxx_api_lite SRCS cxx_api_test.cc DEPS cxx_api_lite model_parser_lite)
@@ -58,7 +58,7 @@ class Predictor {
   const Tensor* GetOutput(size_t offset) {
     auto* _fetch_list = program_->exec_scope()->FindVar("fetch");
     CHECK(_fetch_list) << "no fetch variable in exec_scope";
-    auto fetch_list = _fetch_list->Get<std::vector<Tensor>>();
+    auto& fetch_list = *_fetch_list->GetMutable<std::vector<lite::Tensor>>();
     CHECK_LT(offset, fetch_list.size()) << "offset " << offset << " overflow";
     return &fetch_list.at(offset);
   }
......
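The switch from Get to GetMutable above matters beyond mutability: the old line copied the whole fetch vector into a local, so the returned pointer referenced a temporary that died when GetOutput returned. A minimal standalone sketch of that hazard (hypothetical names, not Paddle-Lite code):

#include <cstddef>
#include <vector>

// Dangling: returns a pointer into a by-value copy destroyed on return.
const int* bad_at(const std::vector<int>& src, size_t i) {
  auto copy = src;     // old pattern: auto fetch_list = var->Get<...>()
  return &copy.at(i);  // points into `copy`, invalid after return
}

// Safe: bind a reference to the stored container, as the new code does.
const int* good_at(std::vector<int>& src, size_t i) {
  auto& ref = src;     // new pattern: auto& fetch_list = *var->GetMutable<...>()
  return &ref.at(i);
}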
@@ -28,8 +28,22 @@ TEST(CXXApi, test) {
   auto* input_tensor = predictor.GetInput(0);
   input_tensor->Resize({100, 100});
-  input_tensor->mutable_data<float>();
+  auto* data = input_tensor->mutable_data<float>();
+  for (int i = 0; i < 100 * 100; i++) {
+    data[i] = i;
+  }
+
+  LOG(INFO) << "input " << input_tensor;
+  LOG(INFO) << "input " << *input_tensor;
+
   predictor.Run();
+
+  auto* out = predictor.GetOutput(0);
+  LOG(INFO) << out << " memory size " << out->memory_size();
+  LOG(INFO) << "out " << out->data<float>()[0];
+  LOG(INFO) << "out " << out->data<float>()[1];
+  LOG(INFO) << "dims " << out->dims();
+  LOG(INFO) << "out " << *out;
 }

 }  // namespace lite
......
@@ -65,6 +65,13 @@ class KernelBase {
   virtual ~KernelBase() = default;

+  std::string DebugString() const {
+    std::stringstream ss;
+    ss << op_type() << ":" << TargetToStr(target()) << "/"
+       << PrecisionToStr(precision()) << "/" << DataLayoutToStr(layout());
+    return ss.str();
+  }
+
  protected:
   std::unique_ptr<KernelContext> context_;
   mutable operators::param_t param_;
......
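For reference, DebugString renders a kernel id as "op:target/precision/layout". A standalone sketch of the same formatting, with hypothetical helper values since the actual output of TargetToStr and friends is not shown in this diff:

#include <iostream>
#include <sstream>
#include <string>

// Sketch of the "op:target/precision/layout" id format used above.
std::string KernelId(const std::string& op, const std::string& target,
                     const std::string& precision, const std::string& layout) {
  std::stringstream ss;
  ss << op << ":" << target << "/" << precision << "/" << layout;
  return ss.str();
}

int main() {
  // Hypothetical values; the real ones come from the *ToStr helpers.
  std::cout << KernelId("mul", "host", "float", "NCHW") << "\n";
}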
@@ -21,18 +21,16 @@ namespace lite {
 static void* TargetMalloc(TargetType target, size_t size) {
   void* data{nullptr};
-  switch (static_cast<int>(target)) {
-    case static_cast<int>(TargetType::kX86):
-      data = TargetWrapper<TARGET(kX86)>::Malloc(size);
-      break;
-    case static_cast<int>(TargetType::kCUDA):
-      data = TargetWrapper<TARGET(kCUDA)>::Malloc(size);
-      break;
-    case static_cast<int>(TargetType::kHost):
-      data = TargetWrapper<TARGET(kHost)>::Malloc(size);
-      break;
-    default:
-      LOG(FATAL) << "Unknown type";
+  switch (target) {
+    case TargetType::kHost:
+    case TargetType::kX86:
+      data = TargetWrapper<TARGET(kHost)>::Malloc(size);
+      break;
+    case TargetType::kCUDA:
+      data = TargetWrapper<TARGET(kCUDA)>::Malloc(size);
+      break;
+    default:
+      LOG(FATAL) << "Unsupported target " << TargetToStr(target);
   }
   return data;
 }
@@ -52,17 +50,19 @@ static void TargetFree(TargetType target, void* data) {
 static void TargetCopy(TargetType target, void* dst, const void* src,
                        size_t size) {
-  switch (static_cast<int>(target)) {
-    case static_cast<int>(TargetType::kX86):
-    case static_cast<int>(TargetType::kHost):
+  switch (target) {
+    case TargetType::kX86:
+    case TargetType::kHost:
       TargetWrapper<TARGET(kHost)>::MemcpySync(dst, src, size,
                                                IoDirection::DtoD);
       break;
-    case static_cast<int>(TargetType::kCUDA):
+    case TargetType::kCUDA:
       TargetWrapper<TARGET(kCUDA)>::MemcpySync(dst, src, size,
                                                IoDirection::DtoD);
       break;
+    default:
+      LOG(FATAL) << "unsupported type";
   }
 }
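Both switches above now fold kHost and kX86 into the single host backend. A compilable sketch of that dispatch shape, with stand-in types rather than the real TargetWrapper:

#include <cstdlib>

enum class TargetType { kHost, kX86, kCUDA };

void* HostMalloc(size_t size) { return malloc(size); }

// Same shape as TargetMalloc above: kHost and kX86 share the host backend.
void* DispatchMalloc(TargetType target, size_t size) {
  switch (target) {
    case TargetType::kHost:
    case TargetType::kX86:
      return HostMalloc(size);
    case TargetType::kCUDA:
      return nullptr;  // would call the CUDA wrapper here
  }
  return nullptr;
}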
@@ -79,12 +79,10 @@ class Buffer {
   void ResetLazy(TargetType target, size_t size) {
     if (target != target_ || space_ < size) {
       Free();
-      data_ = TargetMalloc(target, size);
-      target_ = target;
-      space_ = size;
     }
+    if (size < space_) return;
+    target_ = target;
+    data_ = TargetMalloc(target, size);
+    space_ = size;
   }

   void ResizeLazy(size_t size) { ResetLazy(target_, size); }
......
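A minimal model of the lazy-reset policy above, assuming Free() zeroes the capacity bookkeeping (the full Buffer class is not visible in this diff): reallocate only when the target changes or the current space is too small, otherwise keep the existing allocation.

#include <cstddef>
#include <cstdlib>

enum class TargetType { kHost, kCUDA };

// Toy buffer mirroring ResetLazy's policy; hypothetical, not the real class.
struct ToyBuffer {
  void* data{nullptr};
  size_t space{0};
  TargetType target{TargetType::kHost};

  void Free() {
    free(data);
    data = nullptr;
    space = 0;  // assumption: Free() resets the capacity bookkeeping
  }

  void ResetLazy(TargetType t, size_t size) {
    if (t != target || space < size) Free();
    if (size < space) return;  // existing allocation still fits
    target = t;
    data = malloc(size);
    space = size;
  }
};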
@@ -60,9 +60,6 @@ class SSAGraph : GraphBase {
     op->SetValidPlaces(valid_places);
     auto &new_node = node_storage_.back();
     auto kernels = op->CreateKernels(valid_places);
-    for (auto &kernel : kernels) {
-      op->AttachKernel(kernel.get());
-    }
     node_storage_.back().AsInstruct(op->op_type_, std::move(kernels), op,
                                     op->op_info());
......
@@ -29,6 +29,7 @@ std::vector<std::unique_ptr<KernelBase>> OpLite::CreateKernels(
         (kernel_type.empty() ? op_type_ : kernel_type), place.target,
         place.precision);
     for (auto &&it : ks) {
+      AttachKernel(it.get());
       kernels.emplace_back(std::move(it));
     }
   }
......
@@ -105,6 +105,11 @@ struct Instruction {
   void Run() {
     CHECK(op_);
     CHECK(kernel_);
+    LOG(INFO) << "running kernel> " << kernel_->DebugString();
+    if (UNLIKELY(first_epoch_)) {
+      first_epoch_ = false;
+      op_->CheckShape();
+    }
     op_->InferShape();
     kernel_->Run();
   }

@@ -112,6 +117,7 @@ struct Instruction {
  private:
   std::shared_ptr<OpLite> op_;
   std::unique_ptr<KernelBase> kernel_;
+  bool first_epoch_{true};
 };

 /*
......
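The first_epoch_ flag makes CheckShape a one-time cost while InferShape still runs every step. The same run-once pattern in isolation, using UNLIKELY as defined at the end of this commit (with the builtin spelling corrected):

#define UNLIKELY(x) __builtin_expect(!!(x), 0)

// Sketch of Instruction::Run's first-epoch check; stand-in methods.
struct Step {
  bool first_epoch{true};

  void Run() {
    if (UNLIKELY(first_epoch)) {
      first_epoch = false;
      CheckShape();  // validated once, on the first run
    }
    InferShape();    // recomputed every run
    Compute();
  }

  void CheckShape() {}
  void InferShape() {}
  void Compute() {}
};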
@@ -13,6 +13,7 @@
 // limitations under the License.

 #pragma once
+#include <glog/logging.h>
 #include <iostream>
 #include <sstream>
@@ -138,11 +139,48 @@ class TargetWrapper {
   static void StreamSync(const stream_t& stream) {}

-  static void* Malloc(size_t size) { return new char[size]; }
-  static void Free(void* ptr) { delete[] static_cast<char*>(ptr); }
+  static void* Malloc(size_t size) {
+    LOG(FATAL) << "Unimplemented malloc for " << TargetToStr(Target);
+    return nullptr;
+  }
+  static void Free(void* ptr) { LOG(FATAL) << "Unimplemented"; }
+
+  static void MemcpySync(void* dst, const void* src, size_t size,
+                         IoDirection dir) {
+    LOG(FATAL) << "Unimplemented";
+  }
+  static void MemcpyAsync(void* dst, const void* src, size_t size,
+                          IoDirection dir, const stream_t& stream) {
+    MemcpySync(dst, src, size, dir);
+  }
+};
+
+// This interface should be specialized by each kind of target.
+template <>
+class TargetWrapper<TARGET(kHost)> {
+ public:
+  using stream_t = int;
+  using event_t = int;
+
+  static size_t num_devices() { return 0; }
+  static size_t maximum_stream() { return 0; }
+
+  static void CreateStream(stream_t* stream) {}
+  static void DestroyStream(const stream_t& stream) {}
+
+  static void CreateEvent(event_t* event) {}
+  static void DestroyEvent(const event_t& event) {}
+
+  static void RecordEvent(const event_t& event) {}
+  static void SyncEvent(const event_t& event) {}
+
+  static void StreamSync(const stream_t& stream) {}
+
+  static void* Malloc(size_t size);
+  static void Free(void* ptr);
+
   static void MemcpySync(void* dst, const void* src, size_t size,
-                         IoDirection dir) {}
+                         IoDirection dir);
   static void MemcpyAsync(void* dst, const void* src, size_t size,
                           IoDirection dir, const stream_t& stream) {
     MemcpySync(dst, src, size, dir);
......
@@ -95,13 +95,15 @@ class Tensor {
     dims_ = other.dims_;
     target_ = other.target_;
     lod_ = other.lod_;
+    memory_size_ = other.memory_size_;
   }

   void CopyDataFrom(const Tensor& other) {
     dims_ = other.dims_;
     target_ = other.target_;
     lod_ = other.lod_;
-    *buffer_ = *other.buffer_;
+    memory_size_ = other.memory_size_;
+    buffer_->CopyDataFrom(*other.buffer_, memory_size_);
   }

   TargetType target() const { return target_; }
......
@@ -38,9 +38,39 @@ namespace lite {
 // The DNN system is simple, and the architecture cannot process as many data
 // types as a compiler can, or it will turn into chaos.
 //
-// We should make sure that supported data types should be registered here, and
-// keep the quantity small. And avoid using some special data types as op's IO,
-// such as some runtime cache, that need to be avoided.
+// We should make sure that the supported data types are registered here and
+// keep their number small, and avoid using special data types as op inputs or
+// outputs (such as runtime caches), since those types can't be processed by
+// the MIR.
+//
+// A tensor with a different place (target, precision, data layout or device)
+// should be treated as a different type. Different types might be compatible
+// with each other; for example, `VoidTy` means any type, so any other type can
+// be treated as a `VoidTy`.
+//
+// One Type can be transformed into another by inserting special transform
+// operators: for example, a DataLayoutTransformOp can convert a
+// `TensorFp32NCHWTy` to a `TensorFp32NHWCTy`, and an IoCopyOp can convert a
+// `TensorFp32NCHWTy(kHost)` to a `TensorFp32NCHWTy(kCUDA)`. There are many
+// other conversions between different Types, but some conversions are
+// unsupported; for example, there is no way to convert an `UnsupportedTy` to a
+// `TensorAnyTy`.
+//
+// We use Types to declare the definition of a kernel: each input and output
+// argument has a specific Type.
+//
+// REGISTER_LITE_KERNEL(mul, kHost, kFloat,
+//                      paddle::lite::kernels::host::MulCompute, def)
+//     .BindInput("X", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
+//                     TARGET(kHost))})
+//     .BindInput("Y", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
+//                     TARGET(kHost))})
+//     .BindOutput("Out",
+//     {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(TARGET(kHost))})
+//     .Finalize();
+//
+// The above definition will be used in the MIR for Type inference and
+// incompatible-type checks.
 //
 // TODO(Superjomn) Add operator/kernel-wise static checking to avoid unsupported
 // type mixed in the system.
......
@@ -26,12 +26,14 @@ using TargetW = TargetWrapper<TARGET(kCUDA), cudaStream_t, cudaEvent_t>;

 template <>
 void* TargetW::Malloc(size_t size) {
-  return new char[size];
+  void* ptr{};
+  CHECK_EQ(cudaSuccess, cudaMalloc(&ptr, size));
+  return ptr;
 }

 template <>
 void TargetW::Free(void* ptr) {
-  delete[] static_cast<char*>(ptr);
+  CHECK_EQ(cudaSuccess, cudaFree(ptr));
 }

 template <>
......
cc_library(target_wrapper_host SRCS target_wrapper.cc)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/target_wrapper.h"
#include <cstring>
namespace paddle {
namespace lite {
void* TargetWrapper<TARGET(kHost)>::Malloc(size_t size) {
return new char[size];
}
void TargetWrapper<TARGET(kHost)>::Free(void* ptr) {
delete[] static_cast<char*>(ptr);
}
void TargetWrapper<TARGET(kHost)>::MemcpySync(void* dst, const void* src,
size_t size, IoDirection dir) {
memcpy(dst, src, size);
}
} // namespace lite
} // namespace paddle
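A quick usage sketch of the new host specialization; HostRoundTrip is a hypothetical function, and it assumes the target_wrapper_host library added above is linked in:

#include "paddle/fluid/lite/core/target_wrapper.h"

namespace paddle {
namespace lite {

// Hypothetical round trip: allocate, copy, and release through the wrapper.
void HostRoundTrip() {
  float src[4] = {0.f, 1.f, 2.f, 3.f};
  using HostW = TargetWrapper<TARGET(kHost)>;
  void* dst = HostW::Malloc(sizeof(src));
  // DtoD matches what TargetCopy passes for host-to-host copies above.
  HostW::MemcpySync(dst, src, sizeof(src), IoDirection::DtoD);
  HostW::Free(dst);
}

}  // namespace lite
}  // namespace paddle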
@@ -28,6 +28,8 @@ class FeedCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
     auto &param = Param<operators::FeedParam>();
     const Tensor &feed_item = param.feed_list->at(param.col);
     param.out->CopyDataFrom(feed_item);
+    LOG(INFO) << "FEED input " << feed_item << " col " << param.col;
+    LOG(INFO) << "FEED output " << *param.out;
   }
 };

@@ -40,6 +42,6 @@ REGISTER_LITE_KERNEL(feed, kHost, kFloat,
                      paddle::lite::kernels::host::FeedCompute, def)
     .BindInput("X", {paddle::lite::Type::Get<paddle::lite::TensorAnyTy>(
                     TARGET(kHost))})
-    .BindOutput("Out", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
+    .BindOutput("Out", {paddle::lite::Type::Get<paddle::lite::TensorAnyTy>(
                     TARGET(kHost))})
     .Finalize();
@@ -32,7 +32,7 @@ class FetchCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
     }

     auto& dst = fetch_list->at(param.col);
-    dst.CopyDataFrom(*param.input);
+    dst.ShareDataWith(*param.input);
   }
 };
......
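The fetch kernel now aliases the input's storage instead of deep-copying it. The distinction, sketched with a shared_ptr-backed toy tensor (the real Tensor and Buffer types are only partially visible in this diff):

#include <memory>
#include <vector>

// Toy tensor: ShareDataWith aliases the buffer, CopyDataFrom duplicates it.
struct ToyTensor {
  std::shared_ptr<std::vector<float>> buf{
      std::make_shared<std::vector<float>>()};

  void ShareDataWith(const ToyTensor& other) { buf = other.buf; }  // O(1)
  void CopyDataFrom(const ToyTensor& other) {                      // O(n)
    buf = std::make_shared<std::vector<float>>(*other.buf);
  }
};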
@@ -40,22 +40,23 @@ class MulCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
   using param_t = operators::MulParam;

   void Run() override {
-    auto& theparam = Param<operators::MulParam>();
-    core::dim2 x_shape(
-        {product(theparam.x->dims().begin(),
-                 theparam.x->dims().begin() + theparam.x_num_col_dims),
-         product(theparam.x->dims().begin() + theparam.x_num_col_dims,
-                 theparam.x->dims().end())});
-    core::dim2 y_shape(
-        {product(theparam.y->dims().begin(),
-                 theparam.y->dims().begin() + theparam.x_num_col_dims),
-         product(theparam.y->dims().begin() + theparam.x_num_col_dims,
-                 theparam.y->dims().end())});
+    auto& param = Param<operators::MulParam>();
+    core::dim2 x_shape({product(param.x->dims().begin(),
+                                param.x->dims().begin() + param.x_num_col_dims),
+                        product(param.x->dims().begin() + param.x_num_col_dims,
+                                param.x->dims().end())});
+    core::dim2 y_shape({product(param.y->dims().begin(),
+                                param.y->dims().begin() + param.x_num_col_dims),
+                        product(param.y->dims().begin() + param.x_num_col_dims,
+                                param.y->dims().end())});

-    mul_compute_eigen(theparam.x->data<float>(), x_shape.x, x_shape.y,  //
-                      theparam.y->data<float>(), y_shape.x, y_shape.y,  //
-                      theparam.output->mutable_data<float>());
+    mul_compute_eigen(param.x->data<float>(), x_shape.x, x_shape.y,  //
+                      param.y->data<float>(), y_shape.x, y_shape.y,  //
+                      param.output->mutable_data<float>());
+    LOG(INFO) << "MUL x " << *param.x;
+    LOG(INFO) << "MUL W " << *param.y;
+    LOG(INFO) << "MUL out " << *param.output;
   }

   virtual ~MulCompute() = default;
......
@@ -36,10 +36,10 @@ class ScaleCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
   using param_t = operators::MulParam;

   void Run() override {
-    auto& theparam = Param<operators::ScaleParam>();
-    scale_compute(theparam.x->data<float>(), theparam.x->mutable_data<float>(),
-                  product(theparam.x->dims()), theparam.scale, theparam.bias,
-                  theparam.bias_after_scale);
+    auto& param = Param<operators::ScaleParam>();
+    scale_compute(param.x->data<float>(), param.output->mutable_data<float>(),
+                  product(param.x->dims()), param.scale, param.bias,
+                  param.bias_after_scale);
   }

   virtual ~ScaleCompute() = default;
......
@@ -77,7 +77,7 @@ void TensorFromStream(std::istream &is, lite::Tensor *tensor) {
       DO(INT64, int64_t);
 #undef DO
     default:
-      LOG(FATAL) << "unknown type";
+      LOG(FATAL) << "unknown type " << desc.data_type();
   }
   is.read(static_cast<char *>(buf), size);
......
@@ -51,7 +51,6 @@ class FeedOp : public OpLite {
     // NOTE need boost here
     // TODO(Superjomn) drop the need of framework::op_desc
     param_.col = boost::get<int>(opdesc.GetAttr("col"));
-    kernel_->SetParam(param_);
     return true;
   }
......
@@ -21,3 +21,10 @@
 #endif

 #define LITE_UNIMPLEMENTED CHECK(false) << "Not Implemented";
+
+#ifndef LIKELY
+#define LIKELY(x) __builtin_expect(!!(x), 1)
+#endif
+
+#ifndef UNLIKELY
+#define UNLIKELY(x) __builtin_expect(!!(x), 0)
+#endif
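With the builtin spelled __builtin_expect (the original hunk had a typo, __built_expect, which would not compile), these hints work under GCC and Clang; a typical use annotates a rarely taken branch:

#ifndef UNLIKELY
#define UNLIKELY(x) __builtin_expect(!!(x), 0)
#endif

// Division with a guarded error path, hinted as cold.
int checked_div(int a, int b) {
  if (UNLIKELY(b == 0)) return 0;  // rare branch
  return a / b;
}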