Commit 8532bb4a authored by: S Superjomn

add io_copy op and kernel for cuda

Parent 25990d29
@@ -50,6 +50,22 @@ static void TargetFree(TargetType target, void* data) {
   }
 }
 
+static void TargetCopy(TargetType target, void* dst, const void* src,
+                       size_t size) {
+  switch (static_cast<int>(target)) {
+    case static_cast<int>(TargetType::kX86):
+    case static_cast<int>(TargetType::kHost):
+      TargetWrapper<TARGET(kHost)>::MemcpySync(dst, src, size,
+                                               IoDirection::DtoD);
+      break;
+    case static_cast<int>(TargetType::kCUDA):
+      TargetWrapper<TARGET(kCUDA)>::MemcpySync(dst, src, size,
+                                               IoDirection::DtoD);
+      break;
+  }
+}
+
 // Memory buffer manager.
 class Buffer {
  public:
@@ -57,6 +73,8 @@ class Buffer {
   Buffer(TargetType target, size_t size) : space_(size), target_(target) {}
 
   void* data() const { return data_; }
+  TargetType target() const { return target_; }
+  size_t space() const { return space_; }
 
   void ResetLazy(TargetType target, size_t size) {
     if (target != target_ || space_ < size) {
@@ -64,8 +82,8 @@ }
     }
     if (size < space_) return;
-    data_ = TargetMalloc(target, size);
     target_ = target;
+    data_ = TargetMalloc(target, size);
     space_ = size;
   }
@@ -83,10 +101,11 @@ class Buffer {
     target_ = other.target_;
     ResizeLazy(nbytes);
     // TODO(Superjomn) support copy between different targets.
-    memcpy(data_, other.data_, nbytes);
+    TargetCopy(target_, data_, other.data_, nbytes);
   }
 
  private:
+  // Bytes actually allocated.
   size_t space_{0};
   void* data_{nullptr};
   TargetType target_{TargetType::kHost};
......
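For orientation, a minimal host-only sketch of the new dispatch (a hypothetical standalone usage; only the TargetMalloc/TargetFree/TargetCopy helpers from the hunk above are assumed):

    const size_t kBytes = 64;
    void* src = TargetMalloc(TargetType::kHost, kBytes);
    void* dst = TargetMalloc(TargetType::kHost, kBytes);
    // The kHost/kX86 branch reduces to TargetWrapper<TARGET(kHost)>::MemcpySync.
    TargetCopy(TargetType::kHost, dst, src, kBytes);
    TargetFree(TargetType::kHost, src);
    TargetFree(TargetType::kHost, dst);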
@@ -45,4 +45,4 @@ TEST(Optimizer, test) {
 }  // namespace paddle
 
 USE_LITE_OP(fc);
-USE_LITE_KERNEL(fc, kHost, kFloat);
+USE_LITE_KERNEL(fc, kHost, kFloat, def);
@@ -27,5 +27,20 @@ size_t Place::hash() const {
   return hash;
 }
 
+bool operator<(const Place &a, const Place &b) {
+  if (a.target != b.target) return a.target < b.target;
+  if (a.precision != b.precision) return a.precision < b.precision;
+  if (a.layout != b.layout) return a.layout < b.layout;
+  if (a.device != b.device) return a.device < b.device;
+  // Equal places must compare as not-less to keep a strict weak ordering.
+  return false;
+}
+
+std::string Place::DebugString() const {
+  std::stringstream os;
+  os << TargetToStr(target) << "/" << PrecisionToStr(precision) << "/"
+     << DataLayoutToStr(layout);
+  return os.str();
+}
+
 }  // namespace lite
 }  // namespace paddle
\ No newline at end of file
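With operator< and DebugString defined here, Place can key sorted containers and show up in logs. A hypothetical snippet, assuming Place aggregate-initializes from {target, precision, layout}:

    std::map<Place, std::string> table;  // requires the strict ordering above
    Place p{TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW)};
    table[p] = "fc_host";
    std::cout << p.DebugString();  // "host/float/NCHW"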
@@ -24,10 +24,22 @@ enum class TargetType : int {
   kHost,
   kX86,
   kCUDA,
-  kLastAsPlaceHolder
+  kAny,  // any target
+  kLastAsPlaceHolder,
+};
+
+enum class PrecisionType : int {
+  kUnk = 0,
+  kFloat,
+  kInt8,
+  kAny,  // any precision
+  kLastAsPlaceHolder,
+};
+
+enum class DataLayoutType : int {
+  kUnk = 0,
+  kNCHW,
+  kAny,  // any data layout
+  kLastAsPlaceHolder,
 };
-enum class PrecisionType : int { kUnk = 0, kFloat, kInt8, kLastAsPlaceHolder };
-enum class DataLayoutType : int { kUnk = 0, kNCHW, kLastAsPlaceHolder };
 
 // Some helper macro to get a specific TargetType.
 #define TARGET(item__) paddle::lite::TargetType::item__
@@ -42,17 +54,18 @@ constexpr const int kNumPrecisions =
 constexpr const int kNumTargets =
     TARGET_VAL(kLastAsPlaceHolder) - TARGET_VAL(kHost);
 
-static const std::string target2string[] = {"unk", "host", "x86", "cuda"};
+static const std::string target2string[] = {"unk", "host", "x86", "cuda",
+                                            "any"};
 static const std::string& TargetToStr(TargetType target) {
   return target2string[static_cast<int>(target)];
 }
 
-static const std::string precision2string[] = {"unk", "float", "int8"};
+static const std::string precision2string[] = {"unk", "float", "int8", "any"};
 static const std::string& PrecisionToStr(PrecisionType precision) {
   return precision2string[static_cast<int>(precision)];
 }
 
-static const std::string datalayout2string[] = {"unk", "NCHW"};
+static const std::string datalayout2string[] = {"unk", "NCHW", "any"};
 static const std::string& DataLayoutToStr(DataLayoutType x) {
   return datalayout2string[static_cast<int>(x)];
 }
@@ -86,45 +99,30 @@ struct Place {
   bool operator!=(const Place& other) const { return !(*this == other); }
 
-  friend bool operator<(const Place& a, const Place& b) {
-    if (a.target != b.target) return a.target < b.target;
-    if (a.precision != b.precision) return a.precision < b.precision;
-    if (a.layout != b.layout) return a.layout < b.layout;
-    if (a.device != b.device) return a.device < b.device;
-    return true;
-  }
+  friend bool operator<(const Place& a, const Place& b);
 
   friend std::ostream& operator<<(std::ostream& os, const Place& other) {
     os << other.DebugString();
     return os;
   }
 
-  std::string DebugString() const {
-    std::stringstream os;
-    os << TargetToStr(target) << "/" << PrecisionToStr(precision) << "/"
-       << DataLayoutToStr(layout);
-    return os.str();
-  }
+  std::string DebugString() const;
 };
 
-// Event sync for multi-stream devices like CUDA and OpenCL.
-// For the devices without support of stream, leave it empty.
-template <TargetType Target>
-class Event {};
-
 // Memory copy directions.
 enum class IoDirection {
   HtoH = 0,  // Host to host
   HtoD,      // Host to device
   DtoH,      // Device to host
+  DtoD,      // Device to device
 };
 
 // This interface should be specified by each kind of target.
-template <TargetType Target>
+template <TargetType Target, typename StreamTy = int, typename EventTy = int>
 class TargetWrapper {
  public:
-  using stream_t = int;
-  using event_t = Event<Target>;
+  using stream_t = StreamTy;
+  using event_t = EventTy;
 
   static size_t num_devices() { return 0; }
   static size_t maximum_stream() { return 0; }
@@ -143,9 +141,10 @@ class TargetWrapper {
   static void* Malloc(size_t size) { return new char[size]; }
   static void Free(void* ptr) { delete[] static_cast<char*>(ptr); }
 
-  static void MemcpySync(void* dst, void* src, size_t size, IoDirection dir) {}
-  static void MemcpyAsync(void* dst, void* src, size_t size,
-                          const stream_t& stream, IoDirection dir) {
+  static void MemcpySync(void* dst, const void* src, size_t size,
+                         IoDirection dir) {}
+  static void MemcpyAsync(void* dst, const void* src, size_t size,
+                          IoDirection dir, const stream_t& stream) {
     MemcpySync(dst, src, size, dir);
   }
 };
......
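With the StreamTy/EventTy template parameters, each backend can alias TargetWrapper onto its native stream and event types while the host default keeps plain ints. The CUDA alias below is the one lite/cuda uses:

    using HostW = paddle::lite::TargetWrapper<TARGET(kHost)>;  // stream_t = int
    using CudaW =
        paddle::lite::TargetWrapper<TARGET(kCUDA), cudaStream_t, cudaEvent_t>;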
@@ -23,23 +23,6 @@
 namespace paddle {
 namespace lite {
 
-template <TargetType Target>
-class EventTree {
- public:
-  using event_t = Event<Target>;
-
-  void AddChild(const event_t& event) { children_.push_back(event); }
-
-  void Sync() {
-    for (auto& event : children_) {
-      TargetWrapper<Target>::SyncEvent(event);
-    }
-  }
-
- private:
-  std::vector<event_t> children_;
-};
-
 using DDim = std::vector<int64_t>;
 static DDim SliceDims(const DDim& dims, int begin, int end) {
   return DDim(dims.begin() + begin, dims.begin() + end - 1);
@@ -80,10 +63,30 @@ class Tensor {
   template <typename T>
   T* mutable_data() {
-    buffer_->ResetLazy(target_, product(dims_) * sizeof(T));
+    memory_size_ = product(dims_) * sizeof(T);
+    buffer_->ResetLazy(target_, memory_size_);
     return static_cast<T*>(buffer_->data());
   }
 
+  template <typename T>
+  T* mutable_data(TargetType target) {
+    target_ = target;
+    buffer_->ResetLazy(target, memory_size());
+    return static_cast<T*>(buffer_->data());
+  }
+
+  void* mutable_data(size_t memory_size) {
+    // Record the raw byte size so memory_size() stays in sync.
+    memory_size_ = memory_size;
+    buffer_->ResetLazy(target_, memory_size);
+    return buffer_->data();
+  }
+
+  void* mutable_data(TargetType target, size_t memory_size) {
+    target_ = target;
+    return mutable_data(memory_size);
+  }
+
+  size_t memory_size() const { return memory_size_; }
+
   bool IsInitialized() const { return buffer_->data(); }
 
   // Other share data to this.
@@ -101,11 +104,14 @@ class Tensor {
     *buffer_ = *other.buffer_;
   }
 
+  TargetType target() const { return target_; }
+
  private:
   TargetType target_{TargetType::kHost};
   DDim dims_;
   std::shared_ptr<Buffer> buffer_;
   LoD lod_;
+  size_t memory_size_{};
 };
 
 std::ostream& operator<<(std::ostream& os, const DDim& dims);
......
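The untyped mutable_data overloads added here are what the io_copy kernels below rely on: the destination tensor can be allocated byte-for-byte from the source without knowing the element type. A sketch (assumes x's dims have been set and its buffer filled beforehand):

    Tensor x, y;
    float* xd = x.mutable_data<float>();  // also records x.memory_size()
    // ... write into xd ...
    void* yd = y.mutable_data(TARGET(kCUDA), x.memory_size());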
@@ -57,6 +57,9 @@ class DataTypeBase {
     Tensor_Fp32_NCHW,
     Tensor_Int8_NCHW,
     Tensor_Int64_NCHW,
+    // Tensor_Any represents a Tensor with any target, precision, or layout.
+    // It is used by some IO kernels that don't care about the concrete data.
+    Tensor_Any,
     NumTypes,  // Must remain the last defined ID.
   };
@@ -137,6 +140,12 @@ class UnsupportedTy : public Type {
  public:
   UnsupportedTy() : Type(ID::Unsupported, "Unsupported", false /*is_tensor*/) {}
 };
+
+class TensorAnyTy : public Type {
+ public:
+  TensorAnyTy(TargetType target)
+      : Type(ID::Tensor_Any, "TensorAny", true /*is_tensor*/, target,
+             PRECISION(kAny), DATALAYOUT(kAny)) {}
+};
+
 class TensorFp32NCHWTy : public Type {
  public:
   TensorFp32NCHWTy(TargetType target)
......
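The registry hands out these singleton type descriptors via Type::Get, exactly as the kernel registrations below do; only the target is pinned while precision and layout stay kAny:

    const paddle::lite::Type* any_on_host =
        paddle::lite::Type::Get<paddle::lite::TensorAnyTy>(TARGET(kHost));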
@@ -16,4 +16,63 @@
 // Created by chunwei on 19-2-23.
 //
 
-#include "target_wrapper.h"
+#include "paddle/fluid/lite/cuda/target_wrapper.h"
+#include <glog/logging.h>
+
+namespace paddle {
+namespace lite {
+
+using TargetW = TargetWrapper<TARGET(kCUDA), cudaStream_t, cudaEvent_t>;
+
+template <>
+void* TargetW::Malloc(size_t size) {
+  // Allocate device memory; a plain new[] would hand out host memory here.
+  void* ptr{};
+  CHECK(cudaSuccess == cudaMalloc(&ptr, size));
+  return ptr;
+}
+
+template <>
+void TargetW::Free(void* ptr) {
+  CHECK(cudaSuccess == cudaFree(ptr));
+}
+
+template <>
+void TargetW::MemcpySync(void* dst, const void* src, size_t size,
+                         IoDirection dir) {
+  switch (dir) {
+    case IoDirection::DtoD:
+      CHECK(cudaSuccess ==
+            cudaMemcpy(dst, src, size, cudaMemcpyDeviceToDevice));
+      break;
+    case IoDirection::HtoD:
+      CHECK(cudaSuccess == cudaMemcpy(dst, src, size, cudaMemcpyHostToDevice));
+      break;
+    case IoDirection::DtoH:
+      CHECK(cudaSuccess == cudaMemcpy(dst, src, size, cudaMemcpyDeviceToHost));
+      break;
+    default:
+      LOG(FATAL) << "Unsupported IoDirection " << static_cast<int>(dir);
+  }
+}
+
+template <>
+void TargetW::MemcpyAsync(void* dst, const void* src, size_t size,
+                          IoDirection dir, const stream_t& stream) {
+  switch (dir) {
+    case IoDirection::DtoD:
+      CHECK(cudaSuccess ==
+            cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToDevice, stream));
+      break;
+    case IoDirection::HtoD:
+      CHECK(cudaSuccess ==
+            cudaMemcpyAsync(dst, src, size, cudaMemcpyHostToDevice, stream));
+      break;
+    case IoDirection::DtoH:
+      CHECK(cudaSuccess ==
+            cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToHost, stream));
+      break;
+    default:
+      LOG(FATAL) << "Unsupported IoDirection " << static_cast<int>(dir);
+  }
+}
+
+}  // namespace lite
+}  // namespace paddle
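A round-trip through these specializations, as a sketch (assumes a CUDA device and a translation unit where TargetW is visible; sizes are illustrative):

    float in[4] = {0, 1, 2, 3}, out[4] = {};
    void* dev = TargetW::Malloc(sizeof(in));  // device memory via cudaMalloc
    TargetW::MemcpySync(dev, in, sizeof(in), IoDirection::HtoD);
    TargetW::MemcpySync(out, dev, sizeof(in), IoDirection::DtoH);
    TargetW::Free(dev);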
@@ -12,10 +12,17 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include "paddle/fluid/lite/core/target_wrapper.h"
+
 namespace paddle {
-namespace framework {
 namespace lite {
-namespace cuda {}  // namespace cuda
+namespace cuda {
+
+using TargetWrap = TargetWrapper<TARGET(kCUDA)>;
+using TargetWrapAsync = TargetWrapper<TARGET(kCUDA), cudaStream_t, cudaEvent_t>;
+
+}  // namespace cuda
 }  // namespace lite
-}  // namespace framework
 }  // namespace paddle
 cc_library(mul_compute_cuda SRCS mul_compute.cc DEPS tensor_lite)
+cc_library(io_copy_compute_cuda SRCS io_copy_compute.cc DEPS tensor_lite)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/cuda/target_wrapper.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
using TargetW = TargetWrapper<TARGET(kCUDA), cudaStream_t, cudaEvent_t>;
// Host to CUDA memory.
void CopyFromHostSync(void* target, const void* source, size_t size) {
TargetW::MemcpySync(target, source, size, IoDirection::HtoD);
}
void CopyFromHostAsync(void* target, const void* source, size_t size,
TargetW::stream_t stream) {
TargetW::MemcpyAsync(target, source, size, IoDirection::HtoD, stream);
}
// CUDA to host memory.
void CopyToHostSync(void* target, const void* source, size_t size) {
TargetW::MemcpySync(target, source, size, IoDirection::DtoH);
}
/*
* This kernel copies a tensor from host to CUDA space.
*/
class IoCopyHostToCudaCompute
: public OpKernel<TARGET(kCUDA), PRECISION(kAny), DATALAYOUT(kAny)> {
public:
void Run() override {
auto& param = Param<operators::IoCopyParam>();
CHECK(param.x->target() == TARGET(kHost) ||
param.x->target() == TARGET(kX86));
auto* data = param.y->mutable_data(target(), param.x->memory_size());
CopyFromHostSync(data, param.x->data<void>(), param.x->memory_size());
}
};
/*
* This kernel copies a tensor from CUDA to host space.
*/
class IoCopyCudaToHostCompute
: public OpKernel<TARGET(kCUDA), PRECISION(kAny), DATALAYOUT(kAny)> {
public:
void Run() override {
auto& param = Param<operators::IoCopyParam>();
CHECK(param.x->target() == TARGET(kCUDA));
auto* data = param.y->mutable_data(TARGET(kHost), param.x->memory_size());
CopyToHostSync(data, param.x->data<void>(), param.x->memory_size());
}
};
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(io_copy, kCUDA, kAny,
paddle::lite::kernels::cuda::IoCopyHostToCudaCompute,
host_to_device)
.BindInput("Input", {paddle::lite::Type::Get<paddle::lite::TensorAnyTy>(
TARGET(kHost))})
.BindOutput("Out", {paddle::lite::Type::Get<paddle::lite::TensorAnyTy>(
TARGET(kCUDA))})
.Finalize();
REGISTER_LITE_KERNEL(io_copy, kCUDA, kAny,
paddle::lite::kernels::cuda::IoCopyCudaToHostCompute,
device_to_host)
.BindInput("Input", {paddle::lite::Type::Get<paddle::lite::TensorAnyTy>(
TARGET(kCUDA))})
.BindOutput("Out", {paddle::lite::Type::Get<paddle::lite::TensorAnyTy>(
TARGET(kHost))})
.Finalize();
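Downstream code that wants these kernels linked in would reference them the same way the optimizer test above references fc — a sketch, assuming the four-argument USE_LITE_KERNEL form shown there:

    USE_LITE_OP(io_copy);
    USE_LITE_KERNEL(io_copy, kCUDA, kAny, host_to_device);
    USE_LITE_KERNEL(io_copy, kCUDA, kAny, device_to_host);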
@@ -3,6 +3,7 @@ cc_library(relu_op_lite SRCS relu_op.cc DEPS op_lite)
 cc_library(mul_op_lite SRCS mul_op.cc DEPS op_lite)
 cc_library(scale_op_lite SRCS scale_op.cc DEPS op_lite)
 cc_library(feed_op_lite SRCS feed_op.cc DEPS op_lite)
+cc_library(io_copy_op_lite SRCS io_copy_op.cc DEPS op_lite)
 cc_library(op_params_lite SRCS op_params.cc DEPS tensor_lite)
 cc_library(ops_lite DEPS
@@ -11,6 +12,7 @@ cc_library(ops_lite DEPS
         mul_op_lite
         scale_op_lite
         feed_op_lite
+        io_copy_op_lite
         )
 cc_test(test_fc_op_lite SRCS fc_op_test.cc DEPS fc_op_lite fc_compute_host)
@@ -64,7 +64,13 @@ struct ScaleParam {
   bool bias_after_scale{true};
 };
 
+struct IoCopyParam {
+  const Tensor* x{};
+  Tensor* y{};
+};
+
-using param_t = variant<FeedParam, FcParam, ReluParam, MulParam, ScaleParam>;
+using param_t =
+    variant<FeedParam, FcParam, ReluParam, MulParam, ScaleParam, IoCopyParam>;
 
 }  // namespace operators
 }  // namespace lite
......
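Wiring the new param struct is symmetric with the other op params; a hypothetical setup the io_copy op would perform before the kernel's Run() (tensor names are illustrative):

    operators::IoCopyParam param;
    param.x = &host_tensor;  // source, on kHost/kX86 for host_to_device
    param.y = &cuda_tensor;  // destination; the kernel allocates its buffer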
@@ -20,8 +20,8 @@ namespace paddle {
 namespace lite {
 
 template <>
-void TargetWrapper<TARGET(kX86)>::MemcpySync(void *dst, void *src, size_t size,
-                                             IoDirection dir) {
-  std::copy_n(reinterpret_cast<uint8_t *>(src), size,
+void TargetWrapper<TARGET(kX86)>::MemcpySync(void *dst, const void *src,
+                                             size_t size, IoDirection dir) {
+  std::copy_n(reinterpret_cast<const uint8_t *>(src), size,
               reinterpret_cast<uint8_t *>(dst));
 }
......