From 77cdbdce0c04927f3f64b73ef6e6f26d0d32b33b Mon Sep 17 00:00:00 2001 From: Yuan Shuai Date: Fri, 11 Oct 2019 10:39:21 +0800 Subject: [PATCH] [LITE][OPENCL] support image2d type (#2158) * [LITE][OPENCL] support image2d. test=develop * add context changed with consider image*. test=develop * add layout, relu image kernels. test=develop * replace image_data with data, mutable_image_data with mutable_data, test=develop * comment unused var. test=develop * remove unused var. test=develop --- lite/backends/opencl/cl_functions_test.cc | 6 +- .../opencl/cl_kernel/buffer/layout_kernel.cl | 116 +++++++ .../opencl/cl_kernel/image/relu_kernel.cl | 30 ++ lite/backends/opencl/target_wrapper.cc | 56 +++- lite/backends/opencl/target_wrapper.h | 4 +- lite/core/context.h | 2 +- lite/core/memory.cc | 11 + lite/core/memory.h | 38 +++ lite/core/mir/static_kernel_pick_pass.cc | 15 +- lite/core/mir/static_kernel_pick_pass.h | 15 +- lite/core/mir/type_layout_cast_pass.cc | 39 ++- lite/core/mir/type_precision_cast_pass.cc | 2 +- lite/core/mir/type_target_cast_pass.cc | 18 +- lite/core/mir/variable_place_inference_pass.h | 22 +- lite/core/op_registry.cc | 4 + lite/core/op_registry.h | 21 +- lite/core/optimizer.h | 30 +- lite/core/tensor.cc | 33 +- lite/core/tensor.h | 47 +-- lite/kernels/opencl/CMakeLists.txt | 19 +- lite/kernels/opencl/image_helper.h | 47 +++ lite/kernels/opencl/io_copy_compute.cc | 82 +++-- lite/kernels/opencl/io_copy_compute_test.cc | 4 +- lite/kernels/opencl/layout_compute.cc | 295 ++++++++++++++++++ lite/kernels/opencl/layout_compute_test.cc | 154 +++++++++ lite/kernels/opencl/relu_compute.cc | 90 +++++- 26 files changed, 1090 insertions(+), 110 deletions(-) create mode 100644 lite/backends/opencl/cl_kernel/buffer/layout_kernel.cl create mode 100644 lite/backends/opencl/cl_kernel/image/relu_kernel.cl create mode 100644 lite/kernels/opencl/image_helper.h create mode 100644 lite/kernels/opencl/layout_compute.cc create mode 100644 lite/kernels/opencl/layout_compute_test.cc diff --git a/lite/backends/opencl/cl_functions_test.cc b/lite/backends/opencl/cl_functions_test.cc index b041952b34..b2f31ad0f8 100644 --- a/lite/backends/opencl/cl_functions_test.cc +++ b/lite/backends/opencl/cl_functions_test.cc @@ -396,9 +396,9 @@ TEST(cl_test, target_wrapper_buffer_test) { TEST(cl_test, target_wrapper_image_test) { const std::array image_shape{28, 32}; - auto *d_image = static_cast( - TargetWrapperCL::MallocImage(image_shape, PRECISION(kFloat))); std::array image_pitch; + auto *d_image = static_cast( + TargetWrapperCL::MallocImage(image_shape)); // Map/Unmap test auto *h_image = static_cast( TargetWrapperCL::MapImage(d_image, image_shape, &image_pitch)); @@ -430,7 +430,7 @@ TEST(cl_test, target_wrapper_image_test) { TargetWrapperCL::ImgcpySync( d_image, h_image_cpy.data(), image_shape, image_pitch, IoDirection::HtoD); auto *d_image_cpy = static_cast( - TargetWrapperCL::MallocImage(image_shape, PRECISION(kFloat))); + TargetWrapperCL::MallocImage(image_shape)); TargetWrapperCL::ImgcpySync( d_image_cpy, d_image, image_shape, image_pitch, IoDirection::DtoD); std::fill(h_image_cpy.begin(), h_image_cpy.end(), 0); diff --git a/lite/backends/opencl/cl_kernel/buffer/layout_kernel.cl b/lite/backends/opencl/cl_kernel/buffer/layout_kernel.cl new file mode 100644 index 0000000000..c9c16581d6 --- /dev/null +++ b/lite/backends/opencl/cl_kernel/buffer/layout_kernel.cl @@ -0,0 +1,116 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include + +// buffer -> image2d +__kernel void buffer_to_image2d(__global CL_DTYPE *in, + __write_only image2d_t output_image, + __private const int out_H, + __private const int out_W, + __private const int out_C, + __private const int Stride0, + __private const int Stride1, + __private const int Stride2) { + + const int out_c = get_global_id(0); + const int out_w = get_global_id(1); + const int out_nh = get_global_id(2); + const int out_n = out_nh / out_H; + const int out_h = out_nh % out_H; + + const int in_n = out_n; + const int in_c0 = out_c * 4 + 0; + const int in_c1 = out_c * 4 + 1; + const int in_c2 = out_c * 4 + 2; + const int in_c3 = out_c * 4 + 3; + const int in_h = out_h; + const int in_w = out_w; + + int input_pos0 = in_n * Stride2 + in_c0 * Stride1 + in_h * Stride0 + in_w; + int input_pos1 = in_n * Stride2 + in_c1 * Stride1 + in_h * Stride0 + in_w; + int input_pos2 = in_n * Stride2 + in_c2 * Stride1 + in_h * Stride0 + in_w; + int input_pos3 = in_n * Stride2 + in_c3 * Stride1 + in_h * Stride0 + in_w; + + int2 output_pos; + output_pos.x = out_c * out_W + out_w; + output_pos.y = out_nh; + + CL_DTYPE4 output = (CL_DTYPE4)0.0f; + output.x = convert_float(in[input_pos0]); + if(out_C - 4 * out_c >= 2){ + output.y = convert_float(in[input_pos1]); + } + if(out_C - 4 * out_c >= 3){ + output.z = convert_float(in[input_pos2]); + } + if(out_C - 4 * out_c >= 4){ + output.w = convert_float(in[input_pos3]); + } + write_imagef(output_image, output_pos, output); +} + +// image2d -> buffer +__kernel void image2d_to_buffer(__read_only image2d_t input, + __private const int in_width, + __private const int in_height, + __global CL_DTYPE* out, + __private const int size_ch, + __private const int size_block, + __private const int size_batch, + __private const int C) { + const int in_c = get_global_id(0); + const int in_w = get_global_id(1); + const int in_nh = get_global_id(2); + const int in_n = in_nh / in_height; + const int in_h = in_nh % in_height; + + const sampler_t sampler = + CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; + + const int pos_x = mad24(in_c, in_width, in_w); + CL_DTYPE4 in = read_imagef(input, sampler, (int2)(pos_x, in_nh)); + + const int index = in_n * size_batch + in_c * size_block + in_h * in_width + in_w; + out[index] = convert_float(in.x); + if (C - 4 * in_c >= 2) { + out[index + size_ch] = convert_float(in.y); + } + if(C - 4 * in_c >= 3) { + out[index + size_ch * 2] = convert_float(in.z); + } + if(C - 4 * in_c >= 4) { + out[index + size_ch * 3] = convert_float(in.w); + } +} + +// image2d -> buffer +__kernel void image2d_to_buffer_2d(__private const int in_height, + __private const int in_width, + __read_only image2d_t input, + __global CL_DTYPE* out) { + const int in_w = get_global_id(1); + const int in_h = get_global_id(2); + + const sampler_t sampler = + CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; + + CL_DTYPE4 in = read_imagef(input, sampler, (int2)(in_w, in_h)); + + const int index = (in_h * in_width + in_w) * 4; + out[index] = convert_float(in.x); + out[index + 1] = convert_float(in.y); + out[index + 2] = convert_float(in.z); + out[index + 3] = convert_float(in.w); +} diff --git a/lite/backends/opencl/cl_kernel/image/relu_kernel.cl b/lite/backends/opencl/cl_kernel/image/relu_kernel.cl new file mode 100644 index 0000000000..a99ac79d32 --- /dev/null +++ b/lite/backends/opencl/cl_kernel/image/relu_kernel.cl @@ -0,0 +1,30 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include + +__kernel void relu(__read_only image2d_t input, + __write_only image2d_t output) { + + const int x = get_global_id(0); // image_width + const int y = get_global_id(1); // image_height + + const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | + CLK_ADDRESS_CLAMP | + CLK_FILTER_NEAREST; + + CL_DTYPE4 in = read_imagef(input, sampler, (int2)(x, y)); + in = max((CL_DTYPE4)(0.0f), in); + write_imagef(output, (int2)(x, y), in); +} diff --git a/lite/backends/opencl/target_wrapper.cc b/lite/backends/opencl/target_wrapper.cc index eb324fcb0f..51bab721bf 100644 --- a/lite/backends/opencl/target_wrapper.cc +++ b/lite/backends/opencl/target_wrapper.cc @@ -18,7 +18,6 @@ #include "lite/backends/opencl/cl_include.h" #include "lite/backends/opencl/cl_runtime.h" #include "lite/backends/opencl/cl_utility.h" - namespace paddle { namespace lite { @@ -58,9 +57,58 @@ void TargetWrapperCL::Free(void *ptr) { } } -void *TargetWrapperCL::MallocImage(const std::array &image_shape, - PrecisionType data_type) { - cl::ImageFormat img_format(CL_RGBA, GetCLChannelType(data_type)); +template <> +void *TargetWrapperCL::MallocImage( + const std::array &image_shape) { + cl::ImageFormat img_format(CL_RGBA, GetCLChannelType(PRECISION(kFloat))); + cl_int status; + size_t width = image_shape[0]; + size_t height = image_shape[1]; + cl::Image2D *cl_image = + new cl::Image2D(CLRuntime::Global()->context(), + CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, + img_format, + width, + height, + 0, + nullptr, + &status); + if (status != CL_SUCCESS) { + delete cl_image; + cl_image = nullptr; + } + CL_CHECK_FATAL(status); + return cl_image; +} + +template <> +void *TargetWrapperCL::MallocImage( + const std::array &image_shape) { + cl::ImageFormat img_format(CL_RGBA, GetCLChannelType(PRECISION(kInt8))); + cl_int status; + size_t width = image_shape[0]; + size_t height = image_shape[1]; + cl::Image2D *cl_image = + new cl::Image2D(CLRuntime::Global()->context(), + CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, + img_format, + width, + height, + 0, + nullptr, + &status); + if (status != CL_SUCCESS) { + delete cl_image; + cl_image = nullptr; + } + CL_CHECK_FATAL(status); + return cl_image; +} + +template <> +void *TargetWrapperCL::MallocImage( + const std::array &image_shape) { + cl::ImageFormat img_format(CL_RGBA, GetCLChannelType(PRECISION(kInt32))); cl_int status; size_t width = image_shape[0]; size_t height = image_shape[1]; diff --git a/lite/backends/opencl/target_wrapper.h b/lite/backends/opencl/target_wrapper.h index 8ff8e6fd40..cf2c835bff 100644 --- a/lite/backends/opencl/target_wrapper.h +++ b/lite/backends/opencl/target_wrapper.h @@ -47,8 +47,8 @@ class TargetWrapper { static void* Malloc(size_t size); static void Free(void* ptr); - static void* MallocImage(const std::array& image_shape, - PrecisionType data_type); + template + static void* MallocImage(const std::array& image_shape); static void FreeImage(void* image); static void* Map(void* buffer, size_t offset, size_t size); diff --git a/lite/core/context.h b/lite/core/context.h index a707dfc375..4f4ed31559 100644 --- a/lite/core/context.h +++ b/lite/core/context.h @@ -261,7 +261,7 @@ template <> class Context { std::shared_ptr cl_context_; using WaitListType = - std::unordered_map(nullptr)), + std::unordered_map(nullptr)), std::shared_ptr>; std::shared_ptr cl_wait_list_; diff --git a/lite/core/memory.cc b/lite/core/memory.cc index 463e10b9f9..6d24a785fd 100644 --- a/lite/core/memory.cc +++ b/lite/core/memory.cc @@ -105,5 +105,16 @@ void TargetCopy(TargetType target, void* dst, const void* src, size_t size) { } } +#ifdef LITE_WITH_OPENCL +void TargetCopyImage2D(TargetType target, + void* dst, + const void* src, + const std::array& image_shape, + const std::array& image_pitch) { + TargetWrapperCL::ImgcpySync( + dst, src, image_shape, image_pitch, IoDirection::DtoD); +} +#endif + } // namespace lite } // namespace paddle diff --git a/lite/core/memory.h b/lite/core/memory.h index 31d7fd34e1..0033fa4606 100644 --- a/lite/core/memory.h +++ b/lite/core/memory.h @@ -38,6 +38,13 @@ void LITE_API TargetFree(TargetType target, void* data); // Copy a buffer from host to another target. void TargetCopy(TargetType target, void* dst, const void* src, size_t size); +#ifdef LITE_WITH_OPENCL +void TargetCopyImage2D(TargetType target, + void* dst, + const void* src, + const std::array& image_shape, + const std::array& image_pitch); +#endif // LITE_WITH_OPENCL template void CopySync(void* dst, const void* src, size_t size, IoDirection dir) { @@ -87,6 +94,37 @@ class Buffer { void ResizeLazy(size_t size) { ResetLazy(target_, size); } +#ifdef LITE_WITH_OPENCL + template + void ResetLazyImage2D(TargetType target, + const std::array& image2d_shape) { + size_t size = + sizeof(T) * image2d_shape[0] * image2d_shape[1] * 4; // 4 for RGBA + VLOG(4) << "image2d_shape:" << image2d_shape[0] << " " << image2d_shape[1]; + if (target != target_) { + Free(); + data_ = TargetWrapperCL::MallocImage(image2d_shape); + target_ = target; + space_ = size; + } + } + + template + void ResizeLazyImage2D(const std::array& image2d_shape) { + ResetLazyImage2D(target_, image2d_shape); + } + + template + void CopyImage2DFrom(const Buffer& other, + const std::array& image2d_shape, + const std::array& image2d_pitch) { + target_ = other.target_; + ResizeLazyImage2D(image2d_shape, image2d_pitch); + TargetCopyImage2D( + target_, data_, other.data_, image2d_shape, image2d_pitch); + } +#endif + void Free() { if (space_ > 0) { TargetFree(target_, data_); diff --git a/lite/core/mir/static_kernel_pick_pass.cc b/lite/core/mir/static_kernel_pick_pass.cc index 10e4f6c1b2..adadbd6d98 100644 --- a/lite/core/mir/static_kernel_pick_pass.cc +++ b/lite/core/mir/static_kernel_pick_pass.cc @@ -17,6 +17,7 @@ #include #include #include +#include "lite/core/mir/graph_visualize_pass.h" #include "lite/core/mir/pass_registry.h" namespace paddle { @@ -29,11 +30,15 @@ bool KernelScoreCmp(const std::pair>& a, } void StaticKernelPickPass::Apply(const std::unique_ptr& graph) { + kernel_pick_factors_.ConsiderTarget(); + kernel_pick_factors_.ConsiderPrecision(); + kernel_pick_factors_.ConsiderDataLayout(); CHECK(kernel_pick_factors_.any_factor_considered()) << "kernel_pick_factors should be specified first"; CHECK(graph) << "graph not valid"; - // sort kernels by the factors. + // sort kernels by the factors. + VLOG(4) << "graph->mutable_nodes().size():" << graph->mutable_nodes().size(); for (auto& node : graph->mutable_nodes()) { if (!node.IsStmt()) continue; auto& instruct = node.AsStmt(); @@ -42,8 +47,11 @@ void StaticKernelPickPass::Apply(const std::unique_ptr& graph) { std::vector>> scored; CHECK(!instruct.kernels().empty()) << "No kernels found for " << instruct.op_type(); + VLOG(4) << "instruct.kernels().size():" << instruct.kernels().size(); for (auto&& kernel : instruct.kernels()) { size_t score = KernelGrade(*kernel); + VLOG(4) << "kernel->summary():" << kernel->summary() + << " score:" << score; scored.emplace_back(score, std::move(kernel)); } std::sort(scored.begin(), scored.end(), KernelScoreCmp); @@ -54,7 +62,7 @@ void StaticKernelPickPass::Apply(const std::unique_ptr& graph) { // Just keep a single best kernel. // TODO(Superjomn) reconsider this. instruct.kernels().emplace_back(std::move(scored.front().second)); - VLOG(2) << "pick " << instruct.kernels().front()->name(); + VLOG(2) << "pick " << instruct.kernels().front()->name() << "\n\n"; } else { bool out_type_int8 = true; @@ -117,7 +125,8 @@ void StaticKernelPickPass::Apply(const std::unique_ptr& graph) { if (all_output_type_match) { instruct.kernels().emplace_back(std::move(candidate.second)); - VLOG(2) << "pick " << instruct.kernels().front()->name(); + VLOG(2) << "instruct.kernels.emplace_back " + << instruct.kernels().front()->name(); break; } } diff --git a/lite/core/mir/static_kernel_pick_pass.h b/lite/core/mir/static_kernel_pick_pass.h index 3412278229..4e8707aa49 100644 --- a/lite/core/mir/static_kernel_pick_pass.h +++ b/lite/core/mir/static_kernel_pick_pass.h @@ -53,7 +53,7 @@ class StaticKernelPickPass : public mir::StmtPass { size_t score{}; const int kMax = std::numeric_limits::max(); - + VLOG(4) << "[score s1]:" << score; // The more important factor comes first if (kernel_pick_factors_.IsTargetConsidered() && (place().target == kernel.target() || kernel.target() == TARGET(kAny) || @@ -61,6 +61,7 @@ class StaticKernelPickPass : public mir::StmtPass { score += kMax / static_cast(core::KernelPickFactor::Factor::TargetFirst); } + VLOG(4) << "[score s2]:" << score; if (kernel_pick_factors_.IsPrecisionConsidered() && (place().precision == kernel.precision() || kernel.precision() == PRECISION(kAny) || @@ -68,6 +69,7 @@ class StaticKernelPickPass : public mir::StmtPass { score += kMax / static_cast(core::KernelPickFactor::Factor::PrecisionFirst); } + VLOG(4) << "[score s3]:" << score; if (kernel_pick_factors_.IsDataLayoutConsidered() && (place().layout == kernel.layout() || kernel.layout() == DATALAYOUT(kAny) || @@ -75,10 +77,21 @@ class StaticKernelPickPass : public mir::StmtPass { score += kMax / static_cast( core::KernelPickFactor::Factor::DataLayoutFirst); } + VLOG(4) << "[score s4(final)]:" << score; + VLOG(4) << "-------- pick summary --------"; + VLOG(4) << " ===> place():" << PrecisionToStr(place().precision) << " " + << DataLayoutToStr(place().layout) << " " + << TargetToStr(place().target); + VLOG(4) << " ===> kernel.place():" + << PrecisionToStr(kernel.place().precision) << " " + << DataLayoutToStr(kernel.place().layout) << " " + << TargetToStr(kernel.place().target); + VLOG(4) << "kernel.op_type():" << kernel.op_type(); VLOG(4) << "picker tactic " << kernel_pick_factors_; VLOG(4) << "kernel place " << kernel.place().DebugString(); VLOG(4) << "picker place " << place().DebugString(); VLOG(4) << "score " << score; + VLOG(4) << "------------------------------"; // The data layout is not considered, for the input and output arguments // might have different data layout. diff --git a/lite/core/mir/type_layout_cast_pass.cc b/lite/core/mir/type_layout_cast_pass.cc index 6103ec6b3f..afd3a80ca6 100644 --- a/lite/core/mir/type_layout_cast_pass.cc +++ b/lite/core/mir/type_layout_cast_pass.cc @@ -28,19 +28,24 @@ namespace mir { void TypeLayoutTransformPass::Apply(const std::unique_ptr& graph) { // Start from inputs of the graph, those should have place set. + VLOG(4) << "\n" << Visualize(graph.get()); std::list nodes; for (auto& node : graph->mutable_nodes()) { nodes.push_back(&node); } + LOG(INFO) << "nodes.size():" << nodes.size(); for (auto& node : nodes) { + LOG(INFO) << "!node->IsStmt():" << !node->IsStmt(); if (!node->IsStmt()) continue; auto inlinks = node->inlinks; + LOG(INFO) << "node->AsStmt().desc:" << node->AsStmt().desc + << " inlinks.size():" << inlinks.size(); for (auto* in : inlinks) { ComplementInputs(graph.get(), node, in); } } - VLOG(3) << "\n" << Visualize(graph.get()); + VLOG(4) << "\n" << Visualize(graph.get()); } void TypeLayoutTransformPass::ComplementInputs(SSAGraph* graph, @@ -53,6 +58,7 @@ void TypeLayoutTransformPass::ComplementInputs(SSAGraph* graph, CHECK(inst_node->IsStmt()); auto& inst = inst_node->AsStmt(); + LOG(INFO) << "found Target tensor: " << in->AsArg().name; CHECK(in->IsRoleSet()); CHECK(in->IsArg()); auto in_arg_name = in->AsArg().name; @@ -60,10 +66,15 @@ void TypeLayoutTransformPass::ComplementInputs(SSAGraph* graph, CHECK(inst.op_info()->GetInputArgname(in_arg_name, &tmp)); auto decl_arg_type = inst.picked_kernel().GetInputDeclType(tmp); CHECK(in->AsArg().type); + LOG(INFO) << "\n tmp:" << tmp << "\n in->AsArg().name:" << in->AsArg().name + << "\n *in->AsArg().type:" << *in->AsArg().type + << "\n *decl_arg_type:" << *decl_arg_type + << "\n inst.op()->DebugString():" << inst.op()->DebugString(); + if (!DataLayoutCompatible(*in->AsArg().type, *decl_arg_type)) { - VLOG(4) << "found Layout unmatched tensor: " << in->AsArg().name - << " for kernel " << inst.op()->DebugString() << " " - << *in->AsArg().type << " -> " << *decl_arg_type; + LOG(INFO) << "found Layout unmatched tensor: " << in->AsArg().name + << " for kernel " << inst.op()->DebugString() << " " + << *in->AsArg().type << " -> " << *decl_arg_type; AddLayoutInst(*in->AsArg().type, *decl_arg_type, in, @@ -85,7 +96,7 @@ void TypeLayoutTransformPass::AddLayoutInst( CHECK(in->IsArg()); auto node_id = [&] { return graph->nodes().size(); }; auto layout_output_name = - string_format("%s/trans/%d", in->AsArg().name.c_str(), node_id()); + string_format("%s/layout_trans/%d", in->AsArg().name.c_str(), node_id()); auto* layout_output_arg = graph->NewArgumentNode(layout_output_name); layout_output_arg->AsArg().type = LiteType::GetTensorTy(from.target(), from.precision(), to.layout()); @@ -113,19 +124,29 @@ void TypeLayoutTransformPass::AddLayoutInst( bool is_found = false; for (auto& kernel : kernels) { const Type* in_arg_ty = kernel->GetInputDeclType("Input"); - const Type* out_arg_ty = kernel->GetOutputDeclType("Out"); - if (TypeCompatible(*in_arg_ty, from) && - out_arg_ty->layout() == to.layout()) { +// const Type* out_arg_ty = kernel->GetOutputDeclType("Out"); // unused variable +#ifdef LITE_WITH_OPENCL + // ignore [layout check] for layout trans from image2d to buffer + if (TargetCompatibleTo(*in_arg_ty, from) && + PrecisionCompatibleTo(*in_arg_ty, from) && + DeviceCompatibleTo(*in_arg_ty, from)) { +#else + if (TypeCompatible(*in_arg_ty, from)) { +#endif is_found = true; selected_kernels.emplace_back(std::move(kernel)); // we pick the kernel - layout_inst->AsStmt(layout_type, std::move(kernels), layout_op); + layout_inst->AsStmt(layout_type, std::move(selected_kernels), layout_op); break; } } CHECK(is_found) << "Can't find a layout kernel for layout op: " << from << ":" << in->AsArg().name << "->" << to << ":" << inst_node->AsStmt().op_info()->Type(); + LOG(INFO) << "========= final picked kernel [info]:" + << layout_inst->AsStmt().picked_kernel().name() + << " [summary]:" << layout_inst->AsStmt().picked_kernel().summary() + << "\n"; // Remove the old link RemoveDirectedLink(in, inst_node); diff --git a/lite/core/mir/type_precision_cast_pass.cc b/lite/core/mir/type_precision_cast_pass.cc index c9837241d2..c44f4cd0ea 100644 --- a/lite/core/mir/type_precision_cast_pass.cc +++ b/lite/core/mir/type_precision_cast_pass.cc @@ -88,7 +88,7 @@ void PrecisionCastPass::AddCastInst(const Type& from, CHECK(in->IsArg()); auto node_id = [&] { return graph->nodes().size(); }; auto cast_op_output_name = - in->AsArg().name + "/trans/" + std::to_string(node_id()); + in->AsArg().name + "/precision_trans/" + std::to_string(node_id()); auto* cast_op_output_arg = graph->NewArgumentNode(cast_op_output_name); cast_op_output_arg->AsArg().type = LiteType::GetTensorTy(from.target(), to.precision(), from.layout()); diff --git a/lite/core/mir/type_target_cast_pass.cc b/lite/core/mir/type_target_cast_pass.cc index eca6444938..d32767e7c1 100644 --- a/lite/core/mir/type_target_cast_pass.cc +++ b/lite/core/mir/type_target_cast_pass.cc @@ -60,6 +60,7 @@ void TypeTargetTransformPass::ComplementInputs(SSAGraph* graph, auto in_arg_name = in->AsArg().name; std::string tmp; CHECK(inst.op_info()->GetInputArgname(in_arg_name, &tmp)); + LOG(INFO) << "tmp:" << tmp; auto decl_arg_type = inst.picked_kernel().GetInputDeclType(tmp); CHECK(in->AsArg().type); if (!TargetCompatibleTo(*in->AsArg().type, *decl_arg_type)) { @@ -86,7 +87,8 @@ void TypeTargetTransformPass::AddIoCopyInst( CHECK(in->IsArg()); auto node_id = [&] { return graph->nodes().size(); }; auto io_copy_output_name = - string_format("%s/trans/%d", in->AsArg().name.c_str(), node_id()); + string_format("%s/target_trans/%d", in->AsArg().name.c_str(), node_id()); + // TODO(MyPandaShaoxiang) should set same place with input? auto* io_copy_output_arg = graph->NewArgumentNode(io_copy_output_name); // Set the place for io_copy_output_arg node, the target should be equal to // to.target() @@ -118,9 +120,14 @@ void TypeTargetTransformPass::AddIoCopyInst( std::vector> selected_kernels; for (auto& kernel : kernels) { const Type* in_arg_ty = kernel->GetInputDeclType("Input"); - const Type* out_arg_ty = kernel->GetOutputDeclType("Out"); - if (TypeCompatible(*in_arg_ty, from) && - out_arg_ty->target() == to.target()) { +#ifdef LITE_WITH_OPENCL + // ignore [layout check] for layout trans from buffer to image2d + if (TargetCompatibleTo(*in_arg_ty, from) && + PrecisionCompatibleTo(*in_arg_ty, from) && + DeviceCompatibleTo(*in_arg_ty, from)) { +#else + if (TypeCompatible(*in_arg_ty, from)) { +#endif is_found = true; selected_kernels.emplace_back(std::move(kernel)); // we pick the kernel @@ -130,9 +137,8 @@ void TypeTargetTransformPass::AddIoCopyInst( } } CHECK(is_found) << "Can't find a io_copy kernel for io_copy op: " << from - << ":" << in->AsArg().name << "->" << to << ":" + << ":" << in->AsArg().name << " -> " << to << ":" << inst_node->AsStmt().op_info()->Type(); - // Remove the old link RemoveDirectedLink(in, inst_node); diff --git a/lite/core/mir/variable_place_inference_pass.h b/lite/core/mir/variable_place_inference_pass.h index ccdea38203..255641018a 100644 --- a/lite/core/mir/variable_place_inference_pass.h +++ b/lite/core/mir/variable_place_inference_pass.h @@ -57,12 +57,21 @@ class VariablePlaceInferencePass : public DebugPass { // Set the tye of the weight void SetWeightType(Node* w, const LiteType& type) { // TODO(xg) to optimize this -#ifndef LITE_WITH_FPGA +#ifdef LITE_WITH_FPGA w->AsArg().type = LiteType::GetTensorTy( - TARGET(kHost), type.precision(), DATALAYOUT(kNCHW)); -#else + TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW)); +#endif + +#ifdef LITE_WITH_OPENCL w->AsArg().type = LiteType::GetTensorTy( TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW)); +#endif + +#ifndef LITE_WITH_FPGA +#ifndef LITE_WITH_OPENCL + w->AsArg().type = + LiteType::GetTensorTy(TARGET(kHost), type.precision(), type.layout()); +#endif #endif } @@ -74,7 +83,10 @@ class VariablePlaceInferencePass : public DebugPass { // in fpga, we has io_copy+cali+layout tool ops, so we need type inference for // tool operator #ifndef LITE_WITH_FPGA +#ifndef LITE_WITH_OPENCL + VLOG(3) << "inst.op_type() == 'io_copy', continue"; if (inst.op_type() == "io_copy") continue; +#endif #endif // deal with inputs VLOG(4) << "Infering op " << inst.op_info()->Repr(); @@ -97,8 +109,8 @@ class VariablePlaceInferencePass : public DebugPass { std::string arg_name = get_argname(node_name, inst.op_info()->inputs()); CHECK(arg_name.size() > 0) << "can not found op arguments for node " << node_name; - VLOG(4) << "-- input arg_name " << arg_name - << "-- node name :" << node_name; + VLOG(4) << "-- input arg_name:" << arg_name << " " + << "-- node name:" << node_name; auto type = inst.picked_kernel().GetInputDeclType(arg_name); if (!x_in->AsArg().type) { VLOG(4) << "set type " << *type << " " << x_in->AsArg().name; diff --git a/lite/core/op_registry.cc b/lite/core/op_registry.cc index 9d4042435f..0fdce27e3b 100644 --- a/lite/core/op_registry.cc +++ b/lite/core/op_registry.cc @@ -130,7 +130,11 @@ KernelRegistry::KernelRegistry() INIT_FOR(kARM, kAny, kAny); INIT_FOR(kOpenCL, kFloat, kNCHW); + INIT_FOR(kOpenCL, kFloat, kNHWC); INIT_FOR(kOpenCL, kAny, kNCHW); + INIT_FOR(kOpenCL, kAny, kNHWC); + INIT_FOR(kOpenCL, kFloat, kAny); + INIT_FOR(kOpenCL, kInt8, kNCHW); INIT_FOR(kOpenCL, kAny, kAny); INIT_FOR(kNPU, kFloat, kNCHW); diff --git a/lite/core/op_registry.h b/lite/core/op_registry.h index 81e921f0a4..13f83c5346 100644 --- a/lite/core/op_registry.h +++ b/lite/core/op_registry.h @@ -109,12 +109,29 @@ class KernelRegistry final { KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // @@ -124,6 +141,7 @@ class KernelRegistry final { KernelRegistryForTarget *, // + KernelRegistryForTarget *, // @@ -277,7 +295,8 @@ class KernelRegistor : public lite::Registor { op_type__##__##target__##__##precision__##__registor__ #define LITE_KERNEL_REGISTER_INSTANCE( \ op_type__, target__, precision__, layout__, alias__) \ - op_type__##__##target__##__##precision__##__registor__instance__##alias__ + op_type__##__##target__##__##precision__##__##layout__##registor__instance__##alias__ // NOLINT + #define LITE_KERNEL_REGISTER_FAKE(op_type__, target__, precision__, alias__) \ LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__, alias__) diff --git a/lite/core/optimizer.h b/lite/core/optimizer.h index 5b6a32447c..a96f4fec07 100644 --- a/lite/core/optimizer.h +++ b/lite/core/optimizer.h @@ -73,15 +73,23 @@ class Optimizer { #ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK "lite_elementwise_add_activation_fuse_pass", // #endif - "static_kernel_pick_pass", // + "static_kernel_pick_pass", // pick original kernel from graph + "variable_place_inference_pass", // inference arg/var's + // info(target/precision/layout/device) + // using kernel info + "argument_type_display_pass", // debug pass: show arg-type-node's + // info + // (target/precision/layout/device) + + "type_target_cast_pass", // add io_copy/io_copy_once if meet + // different targets when last and next + // node "variable_place_inference_pass", // "argument_type_display_pass", // - "type_target_cast_pass", // - "variable_place_inference_pass", // - "argument_type_display_pass", // + "io_copy_kernel_pick_pass", // + "argument_type_display_pass", // - "io_copy_kernel_pick_pass", // "variable_place_inference_pass", // "argument_type_display_pass", // @@ -89,12 +97,20 @@ class Optimizer { "variable_place_inference_pass", // "argument_type_display_pass", // - "type_layout_cast_pass", // + "type_layout_cast_pass", // add layout/layout_once op if meet + // different layout when last and next node + "argument_type_display_pass", // + "variable_place_inference_pass", // "argument_type_display_pass", // "runtime_context_assign_pass", - "memory_optimize_pass"}}); + "argument_type_display_pass", // +#ifndef LITE_WITH_OPENCL + // TODO(ysh329): cause CL_INVALID_MEM_OBJECT when setArg in kernel + "memory_optimize_pass", +#endif + "argument_type_display_pass"}}); } else { RunPasses(passes); } diff --git a/lite/core/tensor.cc b/lite/core/tensor.cc index 4dd4f5319d..1c7db871c7 100644 --- a/lite/core/tensor.cc +++ b/lite/core/tensor.cc @@ -79,6 +79,14 @@ void TensorLite::ShareDataWith(const TensorLite &other) { memory_size_ = other.memory_size_; } +void TensorLite::CopyDataFrom(const TensorLite &other) { + dims_ = other.dims_; + target_ = other.target_; + lod_ = other.lod_; + memory_size_ = other.memory_size_; + buffer_->CopyDataFrom(*other.buffer_, memory_size_); +} + void *TensorLite::mutable_data(size_t memory_size) { memory_size_ = memory_size; buffer_->ResetLazy(target_, memory_size_); @@ -90,26 +98,15 @@ void *TensorLite::mutable_data(TargetType target, size_t memory_size) { return mutable_data(memory_size); } -void TensorLite::CopyDataFrom(const TensorLite &other) { - dims_ = other.dims_; - target_ = other.target_; - lod_ = other.lod_; - memory_size_ = other.memory_size_; - buffer_->CopyDataFrom(*other.buffer_, memory_size_); +#ifdef LITE_WITH_OPENCL +template <> +const cl::Image2D *TensorLite::data() const { + if (nullptr == buffer_->data()) return nullptr; + return static_cast(buffer_->data()); } - -// static LoD TensorLite::ToAbsOffset(const LoD &lod) { -// if (lod.empty() || lod.size() == 1) return lod; -// LoD ret = lod; -// for (int level = static_cast(lod.size()) - 2; level >= 0; --level) { -// for (size_t i = 0; i < lod[level].size(); ++i) { -// size_t index = lod[level][i]; -// result[level][i] = result[level + 1][index]; -// } -// } -//} +#endif } // namespace lite } // namespace paddle -#endif +#endif // #ifndef LITE_WITH_FPGA diff --git a/lite/core/tensor.h b/lite/core/tensor.h index aa4cb1b3c5..f5468cacb1 100644 --- a/lite/core/tensor.h +++ b/lite/core/tensor.h @@ -138,14 +138,35 @@ class TensorLite { // and the data type can be float/int8_t. // For other devices, T and R may be the same type. template - R *mutable_data(); + R *mutable_data() { + memory_size_ = dims_.production() * sizeof(T); + buffer_->ResetLazy(target_, memory_size_); + return reinterpret_cast(static_cast(buffer_->data()) + + offset_); + } + +#ifdef LITE_WITH_OPENCL + template + R *mutable_data(const size_t img_w, const size_t img_h) { + target_ = TARGET(kOpenCL); + std::array image2d_shape{img_w, img_h}; + buffer_->ResetLazyImage2D(target_, image2d_shape); + return static_cast(buffer_->data()); + } +#endif // T is the data type and R is the return type // For OpenCL, the return type can be cl::Buffer // and the data type can be float/int8_t. // For other devices, T and R may be the same type. template - R *mutable_data(TargetType target); + R *mutable_data(TargetType target) { + target_ = target; + memory_size_ = dims_.production() * sizeof(T); + buffer_->ResetLazy(target, memory_size()); + return reinterpret_cast(static_cast(buffer_->data()) + + offset_); + } void *mutable_data(size_t memory_size); void *mutable_data(TargetType target, size_t memory_size); @@ -201,21 +222,6 @@ class TensorLite { size_t offset_{0}; }; -template -R *TensorLite::mutable_data() { - memory_size_ = dims_.production() * sizeof(T); - buffer_->ResetLazy(target_, memory_size_); - return reinterpret_cast(static_cast(buffer_->data()) + offset_); -} - -template -R *TensorLite::mutable_data(TargetType target) { - target_ = target; - memory_size_ = dims_.production() * sizeof(T); - buffer_->ResetLazy(target, memory_size()); - return reinterpret_cast(static_cast(buffer_->data()) + offset_); -} - template TensorLite TensorLite::Slice(int64_t begin, int64_t end) const { CHECK_GE(begin, 0); @@ -243,7 +249,12 @@ bool TensorCompareWith(const TensorT &a, const TensorT &b) { return true; } +#ifdef LITE_WITH_OPENCL +template <> +const cl::Image2D *TensorLite::data() const; +#endif + } // namespace lite } // namespace paddle -#endif +#endif // #ifndef LITE_WITH_FPGA diff --git a/lite/kernels/opencl/CMakeLists.txt b/lite/kernels/opencl/CMakeLists.txt index 65145c40b8..d070eb84c5 100644 --- a/lite/kernels/opencl/CMakeLists.txt +++ b/lite/kernels/opencl/CMakeLists.txt @@ -15,6 +15,7 @@ add_kernel(io_copy_compute_opencl OPENCL basic SRCS io_copy_compute.cc DEPS ${te add_kernel(relu_opencl OPENCL basic SRCS relu_compute.cc DEPS ${cl_kernel_deps}) add_kernel(depthwise_conv2d_opencl OPENCL basic SRCS depthwise_conv2d_compute.cc DEPS ${cl_kernel_deps}) add_kernel(conv_opencl OPENCL basic SRCS conv_compute.cc DEPS ${cl_kernel_deps}) +add_kernel(layout_opencl OPENCL basic SRCS layout_compute.cc DEPS ${cl_kernel_deps}) lite_cc_test(test_elementwise_add_opencl SRCS elementwise_add_compute_test.cc DEPS elementwise_add_opencl fusion_elementwise_add_activation_opencl op_registry program context @@ -28,17 +29,19 @@ lite_cc_test(test_fc_opencl SRCS fc_compute_test.cc DEPS fc_opencl op_registry program context ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl) -lite_cc_test(test_mul_opencl SRCS mul_compute_test.cc - DEPS mul_opencl op_registry program context - ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl) +# TODO(ysh329): comment for buffer-impl mul +#lite_cc_test(test_mul_opencl SRCS mul_compute_test.cc +# DEPS mul_opencl op_registry program context +# ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl) lite_cc_test(test_io_copy_compute_opencl SRCS io_copy_compute_test.cc DEPS io_copy_compute_opencl op_registry program context ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl) -lite_cc_test(test_relu_opencl SRCS relu_compute_test.cc - DEPS relu_opencl op_registry program context - ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl) +#TODO(ysh329): comment buffer-impl relu +#lite_cc_test(test_relu_opencl SRCS relu_compute_test.cc +# DEPS relu_opencl op_registry program context +# ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl) lite_cc_test(test_depthwise_conv2d_opencl SRCS depthwise_conv2d_compute_test.cc DEPS depthwise_conv2d_opencl op_registry program context @@ -47,3 +50,7 @@ lite_cc_test(test_depthwise_conv2d_opencl SRCS depthwise_conv2d_compute_test.cc lite_cc_test(test_conv_opencl SRCS conv_compute_test.cc DEPS conv_opencl op_registry program context ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl) + +lite_cc_test(test_layout_opencl SRCS layout_compute_test.cc + DEPS layout_opencl op_registry program context + ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl) diff --git a/lite/kernels/opencl/image_helper.h b/lite/kernels/opencl/image_helper.h new file mode 100644 index 0000000000..d164f1ef77 --- /dev/null +++ b/lite/kernels/opencl/image_helper.h @@ -0,0 +1,47 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include "lite/core/tensor.h" +#include "lite/utils/cp_logging.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace opencl { + +static std::map InitImageDimInfoWith( + const DDim& tensor_dim) { + size_t new_dims[] = {1, 1, 1, 1}; + for (size_t j = 0; j < tensor_dim.size(); ++j) { + new_dims[4 - tensor_dim.size() + j] = tensor_dim[j]; + } + size_t N, C, H, W; + N = new_dims[0]; + C = new_dims[1]; + H = new_dims[2]; + W = new_dims[3]; + size_t width = W * ((C + 3) / 4); + size_t height = H * N; + return std::map({{"width", width}, {"height", height}}); +} + +} // namespace opencl +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/opencl/io_copy_compute.cc b/lite/kernels/opencl/io_copy_compute.cc index 1d43f7d97e..bed0ba2553 100644 --- a/lite/kernels/opencl/io_copy_compute.cc +++ b/lite/kernels/opencl/io_copy_compute.cc @@ -35,14 +35,23 @@ void CopyToHostSync(void* target, const void* source, size_t size) { * This kernel copies a tensor from host to OpenCL space. */ class IoCopyHostToOpenCLCompute - : public KernelLite { + : public KernelLite { public: void Run() override { auto& param = Param(); CHECK(param.x->target() == TARGET(kHost) || param.x->target() == TARGET(kARM)); auto mem_size = param.x->memory_size(); + VLOG(4) << "copy size " << mem_size; + VLOG(4) << "param.x->dims().size():" << param.x->dims().size(); + VLOG(4) << "param.x->dims():" << param.x->dims()[0] << " " + << param.x->dims()[1] << " " << param.x->dims()[2] << " " + << param.x->dims()[3]; + VLOG(4) << "param.y->dims().size():" << param.y->dims().size(); + VLOG(4) << "param.y->dims():" << param.y->dims()[0] << " " + << param.y->dims()[1] << " " << param.y->dims()[2] << " " + << param.y->dims()[3]; auto* data = param.y->mutable_data(TARGET(kOpenCL), mem_size); CopyFromHostSync(data, param.x->raw_data(), mem_size); } @@ -74,17 +83,28 @@ class IoCopyHostToOpenCLCompute * This kernel copies a tensor from OpenCL to host space. */ class IoCopykOpenCLToHostCompute - : public KernelLite { + : public KernelLite { public: void Run() override { auto& param = Param(); CHECK(param.x->target() == TARGET(kOpenCL)); auto mem_size = param.x->memory_size(); VLOG(4) << "copy size " << mem_size; + VLOG(4) << "param.x->dims().size():" << param.x->dims().size(); + VLOG(4) << "param.x->dims():" << param.x->dims()[0] << " " + << param.x->dims()[1] << " " << param.x->dims()[2] << " " + << param.x->dims()[3]; + VLOG(4) << "param.y->dims().size():" << param.y->dims().size(); + VLOG(4) << "param.y->dims():" << param.y->dims()[0] << " " + << param.y->dims()[1] << " " << param.y->dims()[2] << " " + << param.y->dims()[3]; auto* data = param.y->mutable_data(TARGET(kHost), mem_size); auto& context = ctx_->As(); auto* wait_list = context.cl_wait_list(); auto* x_ptr = param.x->data(); + + /* TODO(ysh329): io_copy(device->host) jammed if emplace to `cl_wait_list` + in kernel and enable wait_list auto it = wait_list->find(x_ptr); if (it != wait_list->end()) { VLOG(4) << "--- Find the sync event for the target cl tensor. ---"; @@ -93,6 +113,8 @@ class IoCopykOpenCLToHostCompute } else { LOG(FATAL) << "Could not find the sync event for the target cl tensor."; } + */ + CopyToHostSync(data, param.x->raw_data(), mem_size); } @@ -106,40 +128,64 @@ class IoCopykOpenCLToHostCompute REGISTER_LITE_KERNEL(io_copy, kOpenCL, - kAny, - kAny, + kFloat, + kNCHW, paddle::lite::kernels::opencl::IoCopyHostToOpenCLCompute, host_to_device) - .BindInput("Input", {LiteType::GetTensorTy(TARGET(kHost))}) - .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kOpenCL))}) + .BindInput("Input", + {LiteType::GetTensorTy(TARGET(kHost), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) .Finalize(); REGISTER_LITE_KERNEL(io_copy, kOpenCL, - kAny, - kAny, + kFloat, + kNCHW, paddle::lite::kernels::opencl::IoCopykOpenCLToHostCompute, device_to_host) - .BindInput("Input", {LiteType::GetTensorTy(TARGET(kOpenCL))}) - .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))}) + .BindInput("Input", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kHost), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) .Finalize(); REGISTER_LITE_KERNEL(io_copy_once, kOpenCL, - kAny, - kAny, + kFloat, + kNCHW, paddle::lite::kernels::opencl::IoCopyHostToOpenCLCompute, host_to_device) - .BindInput("Input", {LiteType::GetTensorTy(TARGET(kHost))}) - .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kOpenCL))}) + .BindInput("Input", + {LiteType::GetTensorTy(TARGET(kHost), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) .Finalize(); REGISTER_LITE_KERNEL(io_copy_once, kOpenCL, - kAny, - kAny, + kFloat, + kNCHW, paddle::lite::kernels::opencl::IoCopykOpenCLToHostCompute, device_to_host) - .BindInput("Input", {LiteType::GetTensorTy(TARGET(kOpenCL))}) - .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))}) + .BindInput("Input", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kHost), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) .Finalize(); diff --git a/lite/kernels/opencl/io_copy_compute_test.cc b/lite/kernels/opencl/io_copy_compute_test.cc index 320e257d39..ecbcf5bbf0 100644 --- a/lite/kernels/opencl/io_copy_compute_test.cc +++ b/lite/kernels/opencl/io_copy_compute_test.cc @@ -79,5 +79,5 @@ TEST(io_copy, compute) { } // namespace lite } // namespace paddle -USE_LITE_KERNEL(io_copy, kOpenCL, kAny, kAny, host_to_device); -USE_LITE_KERNEL(io_copy, kOpenCL, kAny, kAny, device_to_host); +USE_LITE_KERNEL(io_copy, kOpenCL, kFloat, kNCHW, host_to_device); +USE_LITE_KERNEL(io_copy, kOpenCL, kFloat, kNCHW, device_to_host); diff --git a/lite/kernels/opencl/layout_compute.cc b/lite/kernels/opencl/layout_compute.cc new file mode 100644 index 0000000000..c4cf7efe5c --- /dev/null +++ b/lite/kernels/opencl/layout_compute.cc @@ -0,0 +1,295 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include "lite/api/paddle_place.h" +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" +#include "lite/core/target_wrapper.h" +#include "lite/core/type_system.h" +#include "lite/kernels/opencl/image_helper.h" +#include "lite/operators/op_params.h" +#include "lite/utils/cp_logging.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace opencl { + +class LayoutComputeBufferChwToImage2DHwc + : public KernelLite { + public: + using param_t = operators::LayoutParam; + + void PrepareForRun() override { + auto& context = ctx_->As(); + context.cl_context()->AddKernel( + kernel_func_name_, "buffer/layout_kernel.cl", build_options_); + } + + void Run() override { + auto& param = Param(); + auto* x_data = param.x->data(); + auto x_dims = param.x->dims(); + auto image_shape = InitImageDimInfoWith(x_dims); + auto* y_data = param.y->mutable_data( + image_shape["width"], image_shape["height"]); + auto y_dims = param.y->dims(); + + // out info + std::vector new_dims = {1, 1, 1, 1}; + for (int tidx = 0; tidx < x_dims.size(); ++tidx) { + new_dims[4 - x_dims.size() + tidx] = x_dims[tidx]; + } + const int out_C = new_dims[1]; + const int out_H = new_dims[2]; + const int out_W = new_dims[3]; + const int Stride2 = out_C * out_H * out_W; + const int Stride1 = out_H * out_W; + const int Stride0 = out_W; + + VLOG(4) << "x_dims[" << x_dims.size() << "D]:" << x_dims[0] << " " + << x_dims[1] << " " << x_dims[2] << " " << x_dims[3]; + VLOG(4) << "y_dims[" << y_dims.size() << "D]:" << y_dims[0] << " " + << y_dims[1] << " " << y_dims[2] << " " << y_dims[3]; + VLOG(4) << "new_dims[" << new_dims.size() << "D]:" << new_dims[0] << " " + << new_dims[1] << " " << new_dims[2] << " " << new_dims[3]; + VLOG(4) << "out_C:" << out_C; + VLOG(4) << "out_H:" << out_H; + VLOG(4) << "out_W:" << out_W; + VLOG(4) << "Stride2:" << Stride2; + VLOG(4) << "Stride1:" << Stride1; + VLOG(4) << "Stride0:" << Stride0; + + auto& context = ctx_->As(); + CHECK(context.cl_context() != nullptr); + STL::stringstream kernel_key; + kernel_key << kernel_func_name_ << build_options_; + auto kernel = context.cl_context()->GetKernel(kernel_key.str()); + + int arg_idx = 0; + cl_int status = kernel.setArg(arg_idx, *x_data); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, *y_data); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, static_cast(out_H)); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, static_cast(out_W)); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, static_cast(out_C)); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, static_cast(Stride0)); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, static_cast(Stride1)); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, static_cast(Stride2)); + CL_CHECK_FATAL(status); + + VLOG(4) << "gws:[3D]" << ((new_dims[1] + 3) / 4) << " " << new_dims[3] + << " " << (new_dims[0] * new_dims[2]); + auto global_work_size = + cl::NDRange{static_cast((new_dims[1] + 3) / 4), + static_cast(new_dims[3]), + static_cast(new_dims[0] * new_dims[2])}; + status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( + kernel, + cl::NullRange, + global_work_size, + cl::NullRange, + nullptr, + event_.get()); + CL_CHECK_FATAL(status); + // TODO(ysh329): io_copy(device->host) jammed if emplace to `cl_wait_list` + // context.cl_wait_list()->emplace(y_data, event_); + context.cl_context()->GetCommandQueue().finish(); + } + + std::string doc() const override { + return "Trans Layout from cl::Buffer(NCHW) to cl::Image2D(RGBA)"; + } + + private: + std::string kernel_func_name_{"buffer_to_image2d"}; + std::string build_options_{"-DCL_DTYPE=float"}; + std::shared_ptr event_{new cl::Event}; +}; + +class LayoutComputeImage2DHwcToBufferChw + : public KernelLite { + public: + using param_t = operators::LayoutParam; + + void PrepareForRun() override { + auto& context = ctx_->As(); + context.cl_context()->AddKernel( + kernel_func_name_, "buffer/layout_kernel.cl", build_options_); + } + + void Run() override { + auto& param = Param(); + auto* y_data = param.y->mutable_data(TARGET(kOpenCL)); + auto y_dims = param.y->dims(); + auto* x_data = param.x->data(); + auto x_dims = param.x->dims(); + + std::vector new_dims = {1, 1, 1, 1}; + for (int j = 0; j < x_dims.size(); ++j) { + new_dims[4 - x_dims.size() + j] = x_dims[j]; + } + + VLOG(4) << "x_dims[" << x_dims.size() << "D]:" << x_dims[0] << " " + << x_dims[1] << " " << x_dims[2] << " " << x_dims[3]; + VLOG(4) << "y_dims[" << y_dims.size() << "D]:" << y_dims[0] << " " + << y_dims[1] << " " << y_dims[2] << " " << y_dims[3]; + VLOG(4) << "new_dims[" << new_dims.size() << "D]:" << new_dims[0] << " " + << new_dims[1] << " " << new_dims[2] << " " << new_dims[3]; + + size_t C = new_dims[1]; + size_t in_height = new_dims[2]; + size_t in_width = new_dims[3]; + int size_ch = in_height * in_width; + int size_block = size_ch * 4; + int size_batch = size_ch * C; + + auto& context = ctx_->As(); + CHECK(context.cl_context() != nullptr); + STL::stringstream kernel_key; + kernel_key << kernel_func_name_ << build_options_; + auto kernel = context.cl_context()->GetKernel(kernel_key.str()); + + int arg_idx = 0; + cl_int status = kernel.setArg(arg_idx, *x_data); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, static_cast(in_width)); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, static_cast(in_height)); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, *y_data); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, static_cast(size_ch)); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, static_cast(size_ch)); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, static_cast(size_batch)); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, static_cast(C)); + CL_CHECK_FATAL(status); + VLOG(4) << "gws:[3D]" << ((new_dims[1] + 3) / 4) << " " << new_dims[3] + << " " << (new_dims[0] * new_dims[2]); + auto global_work_size = + cl::NDRange{static_cast((new_dims[1] + 3) / 4), + static_cast(new_dims[3]), + static_cast(new_dims[0] * new_dims[2])}; + status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( + kernel, + cl::NullRange, + global_work_size, + cl::NullRange, + nullptr, + event_.get()); + CL_CHECK_FATAL(status); + // TODO(ysh329): io_copy(device->host) jammed if emplace to `cl_wait_list` + // context.cl_wait_list()->emplace(y_data, event_); + context.cl_context()->GetCommandQueue().finish(); + } + + std::string doc() const override { + return "Trans Layout from cl::Image2D(RGBA) to cl::Buffer(NCHW)"; + } + + private: + std::string kernel_func_name_{"image2d_to_buffer"}; + std::string build_options_{"-DCL_DTYPE=float"}; + std::shared_ptr event_{new cl::Event}; +}; + +} // namespace opencl +} // namespace kernels +} // namespace lite +} // namespace paddle + +// BufferChwToImage2DHwc +// [chw] -> [hwc] +REGISTER_LITE_KERNEL( + layout, + kOpenCL, + kFloat, + kNHWC, + paddle::lite::kernels::opencl::LayoutComputeBufferChwToImage2DHwc, + buffer_chw_to_image2d_hwc_opencl_fp32) + .BindInput("Input", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFloat), + DATALAYOUT(kNHWC))}) + .Finalize(); + +// [chw] -> [hwc] +REGISTER_LITE_KERNEL( + layout_once, + kOpenCL, + kFloat, + kNHWC, + paddle::lite::kernels::opencl::LayoutComputeBufferChwToImage2DHwc, + buffer_chw_to_image2d_hwc_opencl_fp32) + .BindInput("Input", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFloat), + DATALAYOUT(kNHWC))}) + .Finalize(); + +// Image2DHwcBufferChw +// [hwc] -> [chw] +REGISTER_LITE_KERNEL( + layout, + kOpenCL, + kFloat, + kNCHW, + paddle::lite::kernels::opencl::LayoutComputeImage2DHwcToBufferChw, + image2d_hwc_to_buffer_chw_opencl_fp32) + .BindInput("Input", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFloat), + DATALAYOUT(kNHWC))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) + .Finalize(); + +// [hwc] -> [chw] +REGISTER_LITE_KERNEL( + layout_once, + kOpenCL, + kFloat, + kNCHW, + paddle::lite::kernels::opencl::LayoutComputeImage2DHwcToBufferChw, + image2d_hwc_to_buffer_chw_opencl_fp32) + .BindInput("Input", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFloat), + DATALAYOUT(kNHWC))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) + .Finalize(); diff --git a/lite/kernels/opencl/layout_compute_test.cc b/lite/kernels/opencl/layout_compute_test.cc new file mode 100644 index 0000000000..0d86aba0e8 --- /dev/null +++ b/lite/kernels/opencl/layout_compute_test.cc @@ -0,0 +1,154 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include "lite/backends/opencl/target_wrapper.h" +#include "lite/core/op_registry.h" +#include "lite/core/tensor.h" +#include "lite/kernels/opencl/image_helper.h" + +namespace paddle { +namespace lite { + +// #define LOOP_TEST +// #define PRINT_RESULT +TEST(layout, compute) { + LOG(INFO) << "main steps of test: host -> layout(buf2img) -> layout(img2buf) " + "-> device"; + +#ifdef LOOP_TEST + for (int n = 1; n <= 100; n += 21) { + for (auto c : {1, 3}) { + for (int h = 1; h <= 100; h += 13) { + for (int w = 1; w <= 100; w += 17) { +#else + const int n = 1; + const int c = 1; + const int h = 1; + const int w = 100; +#endif // LOOP_TEST + + LOG(INFO) << "======== input shape[n,c,h,w]:" << n << " " << c << " " + << h << " " << w << " ========"; + // set layout kernels + auto buf_to_img_kernels = KernelRegistry::Global().Create( + "layout", TARGET(kOpenCL), PRECISION(kFloat), DATALAYOUT(kNHWC)); + auto img_to_buf_kernels = KernelRegistry::Global().Create( + "layout", TARGET(kOpenCL), PRECISION(kFloat), DATALAYOUT(kNCHW)); + ASSERT_FALSE(buf_to_img_kernels.empty()); + ASSERT_FALSE(buf_to_img_kernels.empty()); + + auto buf_to_img_kernel = std::move(buf_to_img_kernels.front()); + auto img_to_buf_kernel = std::move(img_to_buf_kernels.front()); + LOG(INFO) << "get 1st kernel: " << buf_to_img_kernel->doc(); + LOG(INFO) << "get 2nd kernel: " << img_to_buf_kernel->doc(); + + // set tensors about op param + LOG(INFO) << "set tensors about op param"; + lite::Tensor x, y_image, y; + operators::LayoutParam BufferToImageParam; + operators::LayoutParam ImageToBufferParam; + BufferToImageParam.x = &x; + BufferToImageParam.y = &y_image; + ImageToBufferParam.x = &y_image; + ImageToBufferParam.y = &y; + + const DDim x_dim = DDim(std::vector{n, c, h, w}); + x.Resize(x_dim); + y_image.Resize(x_dim); // useless for image2D + y.Resize(x_dim); + + // initialize tensors + LOG(INFO) << "initialize tensors"; + auto* x_data = x.mutable_data(TARGET(kOpenCL)); + auto* y_data = y.mutable_data(TARGET(kOpenCL)); + auto image_shape = + paddle::lite::kernels::opencl::InitImageDimInfoWith(x_dim); + auto* y_image_data = y_image.mutable_data( + image_shape["width"], image_shape["height"]); + auto* mapped_x = static_cast(TargetWrapperCL::Map( + x_data, 0, sizeof(float) * x_dim.production())); + auto* mapped_y = static_cast(TargetWrapperCL::Map( + y_data, 0, sizeof(float) * x_dim.production())); + for (int i = 0; i < x_dim.production(); ++i) { + mapped_x[i] = static_cast(i); + mapped_y[i] = static_cast(0); + } + + // set context and kernel args + LOG(INFO) << "set context and kernel args"; + std::unique_ptr context(new KernelContext); + context->As().InitOnce(); + + buf_to_img_kernel->SetParam(BufferToImageParam); + std::unique_ptr buf_to_img_context(new KernelContext); + context->As().CopySharedTo( + &(buf_to_img_context->As())); + buf_to_img_kernel->SetContext(std::move(buf_to_img_context)); + + img_to_buf_kernel->SetParam(ImageToBufferParam); + std::unique_ptr img_to_buf_context(new KernelContext); + context->As().CopySharedTo( + &(img_to_buf_context->As())); + img_to_buf_kernel->SetContext(std::move(img_to_buf_context)); + + // run kernels + LOG(INFO) << "run kernel: buf_to_img_kernel"; + buf_to_img_kernel->Launch(); + LOG(INFO) << "run kernel: img_to_buf_kernel"; + img_to_buf_kernel->Launch(); + +// result +#ifdef PRINT_RESULT + LOG(INFO) << "---- print result ----"; + for (int eidx = 0; i < x_dim.production(); ++eidx) { + std::cout << mapped_x[eidx] << " -> " << mapped_y[eidx] + << std::endl; + } +#endif // PRINT_RESULT + + // check result: compare input and output + for (int eidx = 0; eidx < x_dim.production(); eidx++) { + EXPECT_NEAR(mapped_x[eidx], mapped_y[eidx], 1e-6); + if (abs(mapped_x[eidx] - mapped_y[eidx]) > 1e-6) { + LOG(INFO) << "1st diff in this case at eidx[from 0]:" << eidx + << " / " << x_dim.production() << ", mapped_x[" << eidx + << "]:" << mapped_x[eidx] << ", mapped_y[" << eidx + << "]:" << mapped_y[eidx]; + break; + } + } + + // free + LOG(INFO) << "free: unmap x, y"; + TargetWrapperCL::Unmap(x_data, mapped_x); + TargetWrapperCL::Unmap(y_data, mapped_y); +#ifdef LOOP_TEST + } // w + } // h + } // c + } // n +#else +// nothing to do. +#endif +} + +} // namespace lite +} // namespace paddle + +USE_LITE_KERNEL( + layout, kOpenCL, kFloat, kNHWC, buffer_chw_to_image2d_hwc_opencl_fp32); +USE_LITE_KERNEL( + layout, kOpenCL, kFloat, kNCHW, image2d_hwc_to_buffer_chw_opencl_fp32); diff --git a/lite/kernels/opencl/relu_compute.cc b/lite/kernels/opencl/relu_compute.cc index 93d1dec674..ca5e5cfe77 100644 --- a/lite/kernels/opencl/relu_compute.cc +++ b/lite/kernels/opencl/relu_compute.cc @@ -15,6 +15,7 @@ #include "lite/backends/opencl/cl_include.h" #include "lite/core/kernel.h" #include "lite/core/op_registry.h" +#include "lite/kernels/opencl/image_helper.h" #include "lite/operators/op_params.h" #include "lite/utils/replace_stl/stream.h" @@ -75,17 +76,96 @@ class ReluCompute std::shared_ptr event_{new cl::Event}; }; +class ReluComputeFloatImage + : public KernelLite { + public: + using param_t = operators::ActivationParam; + + void PrepareForRun() override { + auto& context = ctx_->As(); + context.cl_context()->AddKernel( + kernel_func_name_, "image/relu_kernel.cl", build_options_); + } + + void Run() override { + auto& param = *param_.get_mutable(); + const auto& x_dims = param.X->dims(); + auto* x_buf = param.X->data(); + auto image_shape = InitImageDimInfoWith(x_dims); + auto* out_buf = param.Out->mutable_data( + image_shape["width"], image_shape["height"]); + const auto& y_dims = param.Out->dims(); // useless: check dim only + + auto& context = ctx_->As(); + CHECK(context.cl_context() != nullptr); + STL::stringstream kernel_key; + kernel_key << kernel_func_name_ << build_options_; + auto kernel = context.cl_context()->GetKernel(kernel_key.str()); + + int arg_idx = 0; + cl_int status = kernel.setArg(arg_idx, *x_buf); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, *out_buf); + CL_CHECK_FATAL(status); + + VLOG(4) << TargetToStr(param.X->target()); + VLOG(4) << TargetToStr(param.Out->target()); + VLOG(4) << "image_shape(w,h):" << image_shape["width"] << " " + << image_shape["height"]; + VLOG(4) << "x_dims[" << x_dims.size() << "D]:" << x_dims[0] << " " + << x_dims[1] << " " << x_dims[2] << " " << x_dims[3]; + VLOG(4) << "y_dims[" << y_dims.size() << "D]:" << y_dims[0] << " " + << y_dims[1] << " " << y_dims[2] << " " << y_dims[3]; + + auto global_work_size = + cl::NDRange{static_cast(image_shape["width"]), + static_cast(image_shape["height"])}; + status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( + kernel, + cl::NullRange, + global_work_size, + cl::NullRange, + nullptr, + event_.get()); + CL_CHECK_FATAL(status); + // TODO(ysh329): io_copy(device->host) jammed if emplace to `cl_wait_list` + // context.cl_wait_list()->emplace(out_buf, event_); + context.cl_context()->GetCommandQueue().finish(); + } + + private: + std::string kernel_func_name_{"relu"}; + std::string build_options_{"-DCL_DTYPE=float -DRELU"}; + std::shared_ptr event_{new cl::Event}; +}; + } // namespace opencl } // namespace kernels } // namespace lite } // namespace paddle +// REGISTER_LITE_KERNEL(relu, +// kOpenCL, +// kFloat, +// kNCHW, +// paddle::lite::kernels::opencl::ReluCompute, +// def) +// .BindInput("X", {LiteType::GetTensorTy(TARGET(kOpenCL))}) +// .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kOpenCL))}) +// .Finalize(); + REGISTER_LITE_KERNEL(relu, kOpenCL, kFloat, - kNCHW, - paddle::lite::kernels::opencl::ReluCompute, - def) - .BindInput("X", {LiteType::GetTensorTy(TARGET(kOpenCL))}) - .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kOpenCL))}) + kNHWC, + paddle::lite::kernels::opencl::ReluComputeFloatImage, + image2d) + .BindInput("X", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFloat), + DATALAYOUT(kNHWC))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFloat), + DATALAYOUT(kNHWC))}) .Finalize(); -- GitLab