From 8532bb4afa95e1a945a82ba77a9f7816425b32da Mon Sep 17 00:00:00 2001 From: Superjomn Date: Sun, 21 Apr 2019 16:45:19 +0800 Subject: [PATCH] add io_copy op and kernel for cuda --- paddle/fluid/lite/core/memory.h | 23 ++++- paddle/fluid/lite/core/optimizer_test.cc | 2 +- paddle/fluid/lite/core/target_wrapper.cc | 15 +++ paddle/fluid/lite/core/target_wrapper.h | 59 ++++++------ paddle/fluid/lite/core/tensor.h | 42 +++++---- paddle/fluid/lite/core/type_system.h | 9 ++ paddle/fluid/lite/cuda/target_wrapper.cc | 61 ++++++++++++- paddle/fluid/lite/cuda/target_wrapper.h | 13 ++- paddle/fluid/lite/kernels/cuda/CMakeLists.txt | 1 + .../lite/kernels/cuda/io_copy_compute.cc | 91 +++++++++++++++++++ paddle/fluid/lite/operators/CMakeLists.txt | 2 + paddle/fluid/lite/operators/op_params.h | 8 +- paddle/fluid/lite/x86/target_wrapper.cc | 4 +- 13 files changed, 272 insertions(+), 58 deletions(-) create mode 100644 paddle/fluid/lite/kernels/cuda/io_copy_compute.cc diff --git a/paddle/fluid/lite/core/memory.h b/paddle/fluid/lite/core/memory.h index 4cb46607f..ccfe412ed 100644 --- a/paddle/fluid/lite/core/memory.h +++ b/paddle/fluid/lite/core/memory.h @@ -50,6 +50,22 @@ static void TargetFree(TargetType target, void* data) { } } +static void TargetCopy(TargetType target, void* dst, const void* src, + size_t size) { + switch (static_cast(target)) { + case static_cast(TargetType::kX86): + case static_cast(TargetType::kHost): + TargetWrapper::MemcpySync(dst, src, size, + IoDirection::DtoD); + break; + + case static_cast(TargetType::kCUDA): + TargetWrapper::MemcpySync(dst, src, size, + IoDirection::DtoD); + break; + } +} + // Memory buffer manager. class Buffer { public: @@ -57,6 +73,8 @@ class Buffer { Buffer(TargetType target, size_t size) : space_(size), target_(target) {} void* data() const { return data_; } + TargetType target() const { return target_; } + size_t space() const { return space_; } void ResetLazy(TargetType target, size_t size) { if (target != target_ || space_ < size) { @@ -64,8 +82,8 @@ class Buffer { } if (size < space_) return; - data_ = TargetMalloc(target, size); target_ = target; + data_ = TargetMalloc(target, size); space_ = size; } @@ -83,10 +101,11 @@ class Buffer { target_ = other.target_; ResizeLazy(nbytes); // TODO(Superjomn) support copy between different targets. - memcpy(data_, other.data_, nbytes); + TargetCopy(target_, data_, other.data_, nbytes); } private: + // memory it actually malloced. size_t space_{0}; void* data_{nullptr}; TargetType target_{TargetType::kHost}; diff --git a/paddle/fluid/lite/core/optimizer_test.cc b/paddle/fluid/lite/core/optimizer_test.cc index b1acd6975..a301f996a 100644 --- a/paddle/fluid/lite/core/optimizer_test.cc +++ b/paddle/fluid/lite/core/optimizer_test.cc @@ -45,4 +45,4 @@ TEST(Optimizer, test) { } // namespace paddle USE_LITE_OP(fc); -USE_LITE_KERNEL(fc, kHost, kFloat); +USE_LITE_KERNEL(fc, kHost, kFloat, def); diff --git a/paddle/fluid/lite/core/target_wrapper.cc b/paddle/fluid/lite/core/target_wrapper.cc index 176a2cb24..4978f622f 100644 --- a/paddle/fluid/lite/core/target_wrapper.cc +++ b/paddle/fluid/lite/core/target_wrapper.cc @@ -27,5 +27,20 @@ size_t Place::hash() const { return hash; } +bool operator<(const Place &a, const Place &b) { + if (a.target != b.target) return a.target < b.target; + if (a.precision != b.precision) return a.precision < b.precision; + if (a.layout != b.layout) return a.layout < b.layout; + if (a.device != b.device) return a.device < b.device; + return true; +} + +std::string Place::DebugString() const { + std::stringstream os; + os << TargetToStr(target) << "/" << PrecisionToStr(precision) << "/" + << DataLayoutToStr(layout); + return os.str(); +} + } // namespace lite } // namespace paddle \ No newline at end of file diff --git a/paddle/fluid/lite/core/target_wrapper.h b/paddle/fluid/lite/core/target_wrapper.h index a60986c26..c8b7aa18c 100644 --- a/paddle/fluid/lite/core/target_wrapper.h +++ b/paddle/fluid/lite/core/target_wrapper.h @@ -24,10 +24,22 @@ enum class TargetType : int { kHost, kX86, kCUDA, - kLastAsPlaceHolder + kAny, // any target + kLastAsPlaceHolder, +}; +enum class PrecisionType : int { + kUnk = 0, + kFloat, + kInt8, + kAny, // any precision + kLastAsPlaceHolder, +}; +enum class DataLayoutType : int { + kUnk = 0, + kNCHW, + kAny, // any data layout + kLastAsPlaceHolder, }; -enum class PrecisionType : int { kUnk = 0, kFloat, kInt8, kLastAsPlaceHolder }; -enum class DataLayoutType : int { kUnk = 0, kNCHW, kLastAsPlaceHolder }; // Some helper macro to get a specific TargetType. #define TARGET(item__) paddle::lite::TargetType::item__ @@ -42,17 +54,18 @@ constexpr const int kNumPrecisions = constexpr const int kNumTargets = TARGET_VAL(kLastAsPlaceHolder) - TARGET_VAL(kHost); -static const std::string target2string[] = {"unk", "host", "x86", "cuda"}; +static const std::string target2string[] = {"unk", "host", "x86", "cuda", + "any"}; static const std::string& TargetToStr(TargetType target) { return target2string[static_cast(target)]; } -static const std::string precision2string[] = {"unk", "float", "int8"}; +static const std::string precision2string[] = {"unk", "float", "int8", "any"}; static const std::string& PrecisionToStr(PrecisionType precision) { return precision2string[static_cast(precision)]; } -static const std::string datalayout2string[] = {"unk", "NCHW"}; +static const std::string datalayout2string[] = {"unk", "NCHW", "any"}; static const std::string& DataLayoutToStr(DataLayoutType x) { return datalayout2string[static_cast(x)]; } @@ -86,45 +99,30 @@ struct Place { bool operator!=(const Place& other) const { return !(*this == other); } - friend bool operator<(const Place& a, const Place& b) { - if (a.target != b.target) return a.target < b.target; - if (a.precision != b.precision) return a.precision < b.precision; - if (a.layout != b.layout) return a.layout < b.layout; - if (a.device != b.device) return a.device < b.device; - return true; - } + friend bool operator<(const Place& a, const Place& b); friend std::ostream& operator<<(std::ostream& os, const Place& other) { os << other.DebugString(); return os; } - std::string DebugString() const { - std::stringstream os; - os << TargetToStr(target) << "/" << PrecisionToStr(precision) << "/" - << DataLayoutToStr(layout); - return os.str(); - } + std::string DebugString() const; }; -// Event sync for multi-stream devices like CUDA and OpenCL. -// For the devices without support of stream, leave it empty. -template -class Event {}; - // Memory copy directions. enum class IoDirection { HtoH = 0, // Host to host HtoD, // Host to device DtoH, // Device to host + DtoD, // Device to device }; // This interface should be specified by each kind of target. -template +template class TargetWrapper { public: - using stream_t = int; - using event_t = Event; + using stream_t = StreamTy; + using event_t = EventTy; static size_t num_devices() { return 0; } static size_t maximum_stream() { return 0; } @@ -143,9 +141,10 @@ class TargetWrapper { static void* Malloc(size_t size) { return new char[size]; } static void Free(void* ptr) { delete[] static_cast(ptr); } - static void MemcpySync(void* dst, void* src, size_t size, IoDirection dir) {} - static void MemcpyAsync(void* dst, void* src, size_t size, - const stream_t& stream, IoDirection dir) { + static void MemcpySync(void* dst, const void* src, size_t size, + IoDirection dir) {} + static void MemcpyAsync(void* dst, const void* src, size_t size, + IoDirection dir, const stream_t& stream) { MemcpySync(dst, src, size, dir); } }; diff --git a/paddle/fluid/lite/core/tensor.h b/paddle/fluid/lite/core/tensor.h index a9a129cda..864e5db9e 100644 --- a/paddle/fluid/lite/core/tensor.h +++ b/paddle/fluid/lite/core/tensor.h @@ -23,23 +23,6 @@ namespace paddle { namespace lite { -template -class EventTree { - public: - using event_t = Event; - - void AddChild(const event_t& event) { children_.push_back(event); } - - void Sync() { - for (auto& event : children_) { - TargetWrapper::SyncEvent(event); - } - } - - private: - std::vector children_; -}; - using DDim = std::vector; static DDim SliceDims(const DDim& dims, int begin, int end) { return DDim(dims.begin() + begin, dims.begin() + end - 1); @@ -80,10 +63,30 @@ class Tensor { template T* mutable_data() { - buffer_->ResetLazy(target_, product(dims_) * sizeof(T)); + memory_size_ = product(dims_) * sizeof(T); + buffer_->ResetLazy(target_, memory_size_); + return static_cast(buffer_->data()); + } + + template + T* mutable_data(TargetType target) { + target_ = target; + buffer_->ResetLazy(target, memory_size()); return static_cast(buffer_->data()); } + void* mutable_data(size_t memory_size) { + buffer_->ResetLazy(target_, memory_size); + return buffer_->data(); + } + + void* mutable_data(TargetType target, size_t memory_size) { + target_ = target; + return mutable_data(memory_size); + } + + size_t memory_size() const { return memory_size_; } + bool IsInitialized() const { return buffer_->data(); } // Other share data to this. @@ -101,11 +104,14 @@ class Tensor { *buffer_ = *other.buffer_; } + TargetType target() const { return target_; } + private: TargetType target_{TargetType::kHost}; DDim dims_; std::shared_ptr buffer_; LoD lod_; + size_t memory_size_{}; }; std::ostream& operator<<(std::ostream& os, const DDim& dims); diff --git a/paddle/fluid/lite/core/type_system.h b/paddle/fluid/lite/core/type_system.h index 05240e294..5fa1df0ca 100644 --- a/paddle/fluid/lite/core/type_system.h +++ b/paddle/fluid/lite/core/type_system.h @@ -57,6 +57,9 @@ class DataTypeBase { Tensor_Fp32_NCHW, Tensor_Int8_NCHW, Tensor_Int64_NCHW, + // Tensor_Any represents a Tensor with any place, data, layout. It is used + // in some IO kernels those doesn't care the data. + Tensor_Any, NumTypes, // Must remains as last defined ID. }; @@ -137,6 +140,12 @@ class UnsupportedTy : public Type { public: UnsupportedTy() : Type(ID::Unsupported, "Unsupported", false /*is_tensor*/) {} }; +class TensorAnyTy : public Type { + public: + TensorAnyTy(TargetType target) + : Type(ID::Tensor_Any, "TensorAny", true, target, PRECISION(kAny), + DATALAYOUT(kAny)) {} +}; class TensorFp32NCHWTy : public Type { public: TensorFp32NCHWTy(TargetType target) diff --git a/paddle/fluid/lite/cuda/target_wrapper.cc b/paddle/fluid/lite/cuda/target_wrapper.cc index 3376a5596..21df004aa 100644 --- a/paddle/fluid/lite/cuda/target_wrapper.cc +++ b/paddle/fluid/lite/cuda/target_wrapper.cc @@ -16,4 +16,63 @@ // Created by chunwei on 19-2-23. // -#include "target_wrapper.h" +#include "paddle/fluid/lite/cuda/target_wrapper.h" +#include + +namespace paddle { +namespace lite { + +using TargetW = TargetWrapper; + +template <> +void* TargetW::Malloc(size_t size) { + return new char[size]; +} + +template <> +void TargetW::Free(void* ptr) { + delete[] static_cast(ptr); +} + +template <> +void TargetW::MemcpySync(void* dst, const void* src, size_t size, + IoDirection dir) { + switch (dir) { + case IoDirection::DtoD: + CHECK(cudaSuccess == + cudaMemcpy(dst, src, size, cudaMemcpyDeviceToDevice)); + break; + case IoDirection::HtoD: + CHECK(cudaSuccess == cudaMemcpy(dst, src, size, cudaMemcpyHostToDevice)); + break; + case IoDirection::DtoH: + CHECK(cudaSuccess == cudaMemcpy(dst, src, size, cudaMemcpyDeviceToHost)); + break; + default: + LOG(FATAL) << "Unsupported IoDirection " << static_cast(dir); + } +} + +template <> +void TargetW::MemcpyAsync(void* dst, const void* src, size_t size, + IoDirection dir, const stream_t& stream) { + switch (dir) { + case IoDirection::DtoD: + CHECK(cudaSuccess == + cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToDevice, stream)); + break; + case IoDirection::HtoD: + CHECK(cudaSuccess == + cudaMemcpyAsync(dst, src, size, cudaMemcpyHostToDevice, stream)); + break; + case IoDirection::DtoH: + CHECK(cudaSuccess == + cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToHost, stream)); + break; + default: + LOG(FATAL) << "Unsupported IoDirection " << static_cast(dir); + } +} + +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/cuda/target_wrapper.h b/paddle/fluid/lite/cuda/target_wrapper.h index b1f8bab3b..6040f444f 100644 --- a/paddle/fluid/lite/cuda/target_wrapper.h +++ b/paddle/fluid/lite/cuda/target_wrapper.h @@ -12,10 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include +#include +#include "paddle/fluid/lite/core/target_wrapper.h" + namespace paddle { -namespace framework { namespace lite { -namespace cuda {} // namespace cuda +namespace cuda { + +using TargetWrap = TargetWrapper; +using TargetWrapAsync = TargetWrapper; + +} // namespace cuda } // namespace lite -} // namespace framework } // namespace paddle diff --git a/paddle/fluid/lite/kernels/cuda/CMakeLists.txt b/paddle/fluid/lite/kernels/cuda/CMakeLists.txt index 6814f3f51..9a435e45e 100644 --- a/paddle/fluid/lite/kernels/cuda/CMakeLists.txt +++ b/paddle/fluid/lite/kernels/cuda/CMakeLists.txt @@ -1 +1,2 @@ cc_library(mul_compute_cuda SRCS mul_compute.cc DEPS tensor_lite) +cc_library(io_copy_compute_cuda SRCS io_copy_compute.cc DEPS tensor_lite) diff --git a/paddle/fluid/lite/kernels/cuda/io_copy_compute.cc b/paddle/fluid/lite/kernels/cuda/io_copy_compute.cc new file mode 100644 index 000000000..9503d41f7 --- /dev/null +++ b/paddle/fluid/lite/kernels/cuda/io_copy_compute.cc @@ -0,0 +1,91 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/core/kernel.h" +#include "paddle/fluid/lite/core/op_registry.h" +#include "paddle/fluid/lite/cuda/target_wrapper.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +using TargetW = TargetWrapper; + +// Host to CUDA memory. +void CopyFromHostSync(void* target, const void* source, size_t size) { + TargetW::MemcpySync(target, source, size, IoDirection::HtoD); +} + +void CopyFromHostAsync(void* target, const void* source, size_t size, + TargetW::stream_t stream) { + TargetW::MemcpyAsync(target, source, size, IoDirection::HtoD, stream); +} + +// Host to Host memory. +void CopyToHostSync(void* target, const void* source, size_t size) { + TargetW::MemcpySync(target, source, size, IoDirection::DtoH); +} + +/* + * This kernel copies a tensor from host to CUDA space. + */ +class IoCopyHostToCudaCompute + : public OpKernel { + public: + void Run() override { + auto& param = Param(); + CHECK(param.x->target() == TARGET(kHost) || + param.x->target() == TARGET(kX86)); + auto* data = param.y->mutable_data(target(), param.x->memory_size()); + CopyFromHostSync(data, param.x->data(), param.x->memory_size()); + } +}; + +/* + * This kernel copies a tensor from CUDA to host space. + */ +class IoCopyCudaToHostCompute + : public OpKernel { + public: + void Run() override { + auto& param = Param(); + CHECK(param.x->target() == TARGET(kCUDA)); + auto* data = param.y->mutable_data(TARGET(kHost), param.x->memory_size()); + CopyToHostSync(data, param.x->data(), param.x->memory_size()); + } +}; + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(io_copy, kCUDA, kAny, + paddle::lite::kernels::cuda::IoCopyHostToCudaCompute, + host_to_device) + .BindInput("Input", {paddle::lite::Type::Get( + TARGET(kHost))}) + .BindOutput("Out", {paddle::lite::Type::Get( + TARGET(kCUDA))}) + .Finalize(); + +REGISTER_LITE_KERNEL(io_copy, kCUDA, kAny, + paddle::lite::kernels::cuda::IoCopyCudaToHostCompute, + device_to_host) + .BindInput("Input", {paddle::lite::Type::Get( + TARGET(kCUDA))}) + .BindOutput("Out", {paddle::lite::Type::Get( + TARGET(kHost))}) + .Finalize(); diff --git a/paddle/fluid/lite/operators/CMakeLists.txt b/paddle/fluid/lite/operators/CMakeLists.txt index 3a80f3b02..101224106 100644 --- a/paddle/fluid/lite/operators/CMakeLists.txt +++ b/paddle/fluid/lite/operators/CMakeLists.txt @@ -3,6 +3,7 @@ cc_library(relu_op_lite SRCS relu_op.cc DEPS op_lite) cc_library(mul_op_lite SRCS mul_op.cc DEPS op_lite) cc_library(scale_op_lite SRCS scale_op.cc DEPS op_lite) cc_library(feed_op_lite SRCS feed_op.cc DEPS op_lite) +cc_library(io_copy_op_lite SRCS io_copy_op.cc DEPS op_lite) cc_library(op_params_lite SRCS op_params.cc DEPS tensor_lite) cc_library(ops_lite DEPS @@ -11,6 +12,7 @@ cc_library(ops_lite DEPS mul_op_lite scale_op_lite feed_op_lite + io_copy_op_lite ) cc_test(test_fc_op_lite SRCS fc_op_test.cc DEPS fc_op_lite fc_compute_host) diff --git a/paddle/fluid/lite/operators/op_params.h b/paddle/fluid/lite/operators/op_params.h index bf57af0a4..2207e2a09 100644 --- a/paddle/fluid/lite/operators/op_params.h +++ b/paddle/fluid/lite/operators/op_params.h @@ -64,7 +64,13 @@ struct ScaleParam { bool bias_after_scale{true}; }; -using param_t = variant; +struct IoCopyParam { + const Tensor* x{}; + Tensor* y{}; +}; + +using param_t = + variant; } // namespace operators } // namespace lite diff --git a/paddle/fluid/lite/x86/target_wrapper.cc b/paddle/fluid/lite/x86/target_wrapper.cc index 3374cdd73..55cdf91b6 100644 --- a/paddle/fluid/lite/x86/target_wrapper.cc +++ b/paddle/fluid/lite/x86/target_wrapper.cc @@ -20,8 +20,8 @@ namespace paddle { namespace lite { template <> -void TargetWrapper::MemcpySync(void *dst, void *src, size_t size, - IoDirection dir) { +void TargetWrapper::MemcpySync(void *dst, const void *src, + size_t size, IoDirection dir) { std::copy_n(reinterpret_cast(src), size, reinterpret_cast(dst)); } -- GitLab