提交 0f03e23b 编写于 作者: C Chunwei

framework support cl

上级 17ef40f3
......@@ -18,6 +18,7 @@ limitations under the License. */
#include <unordered_set>
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/inference/analysis/dot.h"
#include "paddle/fluid/lite/utils/string.h"
#include "paddle/fluid/string/printf.h"
namespace paddle {
......@@ -84,7 +85,8 @@ void GraphVizPass::ApplyImpl(ir::Graph* graph) const {
auto marked_nodes = ConsumeMarkedNodes(graph);
// Create nodes
for (const Node* n : graph->Nodes()) {
std::string node_id = FormatName(n) + "(" + std::to_string(n->id()) + ")";
std::string node_id =
lite::string_format("%s(%d)", FormatName(n).c_str(), n->id());
if (n->IsOp()) {
decltype(op_attrs) attr =
marked_nodes.count(n) ? marked_op_attrs : op_attrs;
......
......@@ -58,9 +58,11 @@ class Dot {
std::vector<Attr> attrs;
Node(const std::string& name, const std::vector<Attr>& attrs)
: name(name),
attrs(attrs),
id_("node_" + std::to_string(dot_node_counter++)) {}
: name(name), attrs(attrs) {
std::stringstream ss;
ss << "node_" << dot_node_counter++;
id_ = ss.str();
}
std::string id() const { return id_; }
......
......@@ -37,7 +37,7 @@ endfunction()
function (lite_deps TARGET)
set(options "")
set(oneValueArgs "")
set(multiValueArgs DEPS X86_DEPS CUDA_DEPS ARM_DEPS PROFILE_DEPS LIGHT_DEPS HVY_DEPS ARGS)
set(multiValueArgs DEPS X86_DEPS CUDA_DEPS ARM_DEPS PROFILE_DEPS LIGHT_DEPS HVY_DEPS CL_DEPS ARGS)
cmake_parse_arguments(lite_deps "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
set(deps ${lite_deps_DEPS})
......@@ -78,6 +78,12 @@ function (lite_deps TARGET)
endforeach(var)
endif()
if (LITE_WITH_OPENCL)
foreach(var ${lite_deps_CL_DEPS})
set(deps ${deps} ${var})
endforeach(var)
endif()
set(${TARGET} ${deps} PARENT_SCOPE)
endfunction()
......
......@@ -52,22 +52,23 @@ if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND WITH_TESTING)
set(lite_model_test_DEPS cxx_api_lite mir_passes ${ops_lite} ${host_kernels} ${arm_kernels})
lite_cc_test(test_mobilenetv1_lite SRCS mobilenetv1_test.cc
DEPS ${lite_model_test_DEPS}
DEPS ${lite_model_test_DEPS}
CL_DEPS ${opencl_kernels}
ARGS --model_dir=${LITE_MODEL_DIR}/mobilenet_v1 SERIAL)
add_dependencies(test_mobilenetv1_lite extern_lite_download_mobilenet_v1_tar_gz)
lite_cc_test(test_mobilenetv2_lite SRCS mobilenetv2_test.cc
DEPS ${lite_model_test_DEPS}
DEPS ${lite_model_test_DEPS}
ARGS --model_dir=${LITE_MODEL_DIR}/mobilenet_v2 SERIAL)
add_dependencies(test_mobilenetv2_lite extern_lite_download_mobilenet_v2_tar_gz)
lite_cc_test(test_resnet50_lite SRCS resnet50_test.cc
DEPS ${lite_model_test_DEPS}
DEPS ${lite_model_test_DEPS}
ARGS --model_dir=${LITE_MODEL_DIR}/resnet50 SERIAL)
add_dependencies(test_resnet50_lite extern_lite_download_resnet50_tar_gz)
lite_cc_test(test_inceptionv4_lite SRCS inceptionv4_test.cc
DEPS ${lite_model_test_DEPS}
DEPS ${lite_model_test_DEPS}
ARGS --model_dir=${LITE_MODEL_DIR}/inception_v4 SERIAL)
add_dependencies(test_inceptionv4_lite extern_lite_download_inception_v4_tar_gz)
endif()
......
......@@ -23,10 +23,7 @@ namespace paddle {
namespace lite {
void Predictor::SaveModel(const std::string &dir) {
#ifndef LITE_WITH_ARM
MkDirRecur(dir);
#else
#endif
program_->PersistModel(dir, program_desc_);
LOG(INFO) << "Save model to " << dir;
}
......
......@@ -25,16 +25,13 @@
namespace paddle {
namespace lite {
#ifdef LITE_WITH_ARM
TEST(MobileNetV1, test) {
void TestModel(const std::vector<Place>& valid_places,
const Place& preferred_place) {
DeviceInfo::Init();
DeviceInfo::Global().SetRunMode(LITE_POWER_HIGH, FLAGS_threads);
lite::Predictor predictor;
std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)},
Place{TARGET(kARM), PRECISION(kFloat)}});
predictor.Build(FLAGS_model_dir, Place{TARGET(kARM), PRECISION(kFloat)},
valid_places);
predictor.Build(FLAGS_model_dir, preferred_place, valid_places);
auto* input_tensor = predictor.GetInput(0);
input_tensor->Resize(DDim(std::vector<DDim::value_type>({1, 3, 224, 224})));
......@@ -70,7 +67,26 @@ TEST(MobileNetV1, test) {
ASSERT_EQ(out->dims()[0], 1);
ASSERT_EQ(out->dims()[1], 1000);
}
#endif
TEST(MobileNetV1, test_arm) {
std::vector<Place> valid_places({
Place{TARGET(kHost), PRECISION(kFloat)},
Place{TARGET(kARM), PRECISION(kFloat)},
// Place{TARGET(kOpenCL), PRECISION(kFloat)},
});
TestModel(valid_places, Place({TARGET(kARM), PRECISION(kFloat)}));
}
TEST(MobileNetV1, test_opencl) {
std::vector<Place> valid_places({
Place{TARGET(kHost), PRECISION(kFloat)},
Place{TARGET(kARM), PRECISION(kFloat)},
Place{TARGET(kOpenCL), PRECISION(kFloat)},
});
TestModel(valid_places, Place({TARGET(kOpenCL), PRECISION(kFloat)}));
}
} // namespace lite
} // namespace paddle
......@@ -14,6 +14,10 @@
#include "paddle/fluid/lite/core/context.h"
#ifdef LITE_WITH_OPENCL
DEFINE_string(cl_path, "/data/local/tmp/opencl", "The OpenCL kernels path.");
#endif
namespace paddle {
namespace lite {} // namespace lite
} // namespace paddle
......@@ -23,6 +23,11 @@
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/device_context.h"
#endif
#ifdef LITE_WITH_OPENCL
#include "paddle/fluid/lite/opencl/cl_context.h"
#include "paddle/fluid/lite/opencl/cl_engine.h"
#include "paddle/fluid/lite/opencl/cl_helper.h"
#endif
#include <map>
#include <memory>
#include <set>
......@@ -34,6 +39,10 @@
#include "paddle/fluid/lite/core/target_wrapper.h"
#include "paddle/fluid/lite/utils/all.h"
#ifdef LITE_WITH_OPENCL
DECLARE_string(cl_path);
#endif
namespace paddle {
namespace lite {
......@@ -44,6 +53,7 @@ using HostContext = Context<TargetType::kHost>;
using X86Context = Context<TargetType::kX86>;
using CUDAContext = Context<TargetType::kCUDA>;
using ARMContext = Context<TargetType::kARM>;
using OpenClContext = Context<TargetType::kOpenCL>;
template <>
class Context<TargetType::kHost> {
......@@ -51,7 +61,7 @@ class Context<TargetType::kHost> {
// NOTE: InitOnce should only be used by ContextScheduler
void InitOnce() {}
void CopyShared(const HostContext* ctx) {}
void CopySharedTo(const HostContext* ctx) {}
std::string name() const { return "HostContext"; }
};
......@@ -69,7 +79,7 @@ class Context<TargetType::kARM> {
// NOTE: InitOnce should only be used by ContextScheduler
void InitOnce() { DeviceInfo::Init(); }
void CopyShared(const ARMContext* ctx) {}
void CopySharedTo(const ARMContext* ctx) {}
void SetRunMode(PowerMode mode, int threads) {
return DeviceInfo::Global().SetRunMode(mode, threads);
......@@ -109,7 +119,7 @@ class Context<TargetType::kCUDA> {
cublas_fp32_ = std::make_shared<lite::cuda::Blas<float>>();
}
void CopyShared(const CUDAContext* ctx) {
void CopySharedTo(const CUDAContext* ctx) {
CHECK(ctx);
CHECK(cublas_fp32_) << "cublas_fp32 should be set first";
ctx->cublas_fp32_ = cublas_fp32_;
......@@ -175,7 +185,7 @@ class Context<TargetType::kX86> {
// NOTE: InitOnce should only be used by ContextScheduler
void InitOnce() {}
void CopyShared(const X86Context* ctx) {}
void CopySharedTo(const X86Context* ctx) {}
const device_ctx_t* x86_device_context() { return x86_device_context_.get(); }
void SetX86DeviceContext(std::unique_ptr<device_ctx_t>&& ctx) {
......@@ -202,6 +212,40 @@ class Context<TargetType::kX86> {
};
#endif
#ifdef LITE_WITH_OPENCL
template <>
class Context<TargetType::kOpenCL> {
mutable std::shared_ptr<CLContext> cl_context_;
mutable std::shared_ptr<CLHelper> cl_helper_;
public:
CLContext* cl_context() { return cl_context_.get(); }
CLHelper* cl_helper() { return cl_helper_.get(); }
void InitOnce() {
// Init cl engine.
CHECK(CLEngine::Global()->IsInitSuccess()) << "OpenCL engine init failed";
CLEngine::Global()->set_cl_path(FLAGS_cl_path);
cl_context_ = std::make_shared<CLContext>();
cl_helper_ = std::make_shared<CLHelper>();
cl_helper_->set_context(cl_context_.get());
PrepareKernels();
}
void CopySharedTo(const OpenClContext* ctx) {
ctx->cl_context_ = cl_context_;
}
private:
void PrepareKernels() {
cl_helper_->AddKernel("elementwise_add", "elementwise_add_kernel.cl");
cl_helper_->AddKernel("pool_max", "pool_kernel.cl");
}
};
#endif
// Context for running a kernel.
// Holds the necessary resource and information.
class KernelContext {
......@@ -230,26 +274,32 @@ class ContextScheduler {
std::unique_ptr<KernelContext> ctx(new KernelContext);
switch (target) {
case TARGET(kHost):
kernel_contexts_[TargetType::kHost].As<HostContext>().CopyShared(
kernel_contexts_[TargetType::kHost].As<HostContext>().CopySharedTo(
&ctx->As<HostContext>());
break;
#ifdef LITE_WITH_X86
case TARGET(kX86):
kernel_contexts_[TargetType::kX86].As<X86Context>().CopyShared(
kernel_contexts_[TargetType::kX86].As<X86Context>().CopySharedTo(
&ctx->As<X86Context>());
break;
#endif
#ifdef LITE_WITH_CUDA
case TARGET(kCUDA):
kernel_contexts_[TargetType::kCUDA].As<CUDAContext>().CopyShared(
kernel_contexts_[TargetType::kCUDA].As<CUDAContext>().CopySharedTo(
&ctx->As<CUDAContext>());
break;
#endif
#ifdef LITE_WITH_ARM
case TARGET(kARM):
kernel_contexts_[TargetType::kARM].As<ARMContext>().CopyShared(
kernel_contexts_[TargetType::kARM].As<ARMContext>().CopySharedTo(
&ctx->As<ARMContext>());
break;
#endif
#ifdef LITE_WITH_OPENCL
case TARGET(kOpenCL):
kernel_contexts_[TargetType::kOpenCL].As<OpenClContext>().CopySharedTo(
&ctx->As<OpenClContext>());
break;
#endif
default:
LOG(FATAL) << "unsupported target " << TargetToStr(target);
......@@ -273,6 +323,9 @@ class ContextScheduler {
#endif
#ifdef LITE_WITH_ARM
InitContext<TargetType::kARM, ARMContext>();
#endif
#ifdef LITE_WITH_OPENCL
InitContext<TargetType::kOpenCL, OpenClContext>();
#endif
}
......
......@@ -15,6 +15,7 @@
#include "paddle/fluid/lite/core/mir/fusion/quant_dequant_op_fuser.h"
#include <memory>
#include <vector>
#include "paddle/fluid/lite/utils/string.h"
namespace paddle {
namespace lite {
......@@ -57,24 +58,24 @@ void QuantDequantOpFuser::BuildPattern() {
->AsIntermediate();
std::vector<PMNode*> nodes;
for (int i = 0; i < times_; i++) {
nodes.push_back(VarNode("quantized_op_weight" + std::to_string(i))
nodes.push_back(VarNode(string_format("quantized_op_weight%d", i))
->assert_is_op_input(op_type_, weight_name)
->AsInput());
nodes.push_back(OpNode("quantized_op" + std::to_string(i), op_type_)
nodes.push_back(OpNode(string_format("quantized_op%d", i), op_type_)
->assert_is_op(op_type_)
->AsIntermediate());
nodes.push_back(VarNode("quantized_op_out" + std::to_string(i))
nodes.push_back(VarNode(string_format("quantized_op_out%d", i))
->assert_is_op_output(op_type_)
->assert_is_op_input("fake_dequantize_max_abs", "X")
->AsIntermediate());
nodes.push_back(
OpNode("dequant_op" + std::to_string(i), "fake_dequantize_max_abs")
OpNode(string_format("dequant_op%d", i), "fake_dequantize_max_abs")
->assert_is_op("fake_dequantize_max_abs")
->AsIntermediate());
nodes.push_back(VarNode("dequant_op_out" + std::to_string(i))
nodes.push_back(VarNode(string_format("dequant_op_out%d", i))
->assert_is_op_output("fake_dequantize_max_abs", "Out")
->AsOutput());
}
......@@ -108,11 +109,11 @@ void QuantDequantOpFuser::InsertNewNode(SSAGraph* graph,
std::vector<Node*> nodes;
for (int i = 0; i < times_; i++) {
nodes.push_back(matched.at("quantized_op_weight" + std::to_string(i)));
nodes.push_back(matched.at("quantized_op" + std::to_string(i)));
nodes.push_back(matched.at("quantized_op_out" + std::to_string(i)));
nodes.push_back(matched.at("dequant_op" + std::to_string(i)));
nodes.push_back(matched.at("dequant_op_out" + std::to_string(i)));
nodes.push_back(matched.at(string_format("quantized_op_weight%d", i)));
nodes.push_back(matched.at(string_format("quantized_op%d", i)));
nodes.push_back(matched.at(string_format("quantized_op_out%d", i)));
nodes.push_back(matched.at(string_format("dequant_op%d", i)));
nodes.push_back(matched.at(string_format("dequant_op_out%d", i)));
}
int bit_length = quant_op->stmt()->op_info()->GetAttr<int>("bit_length");
auto* scope = quant_op->stmt()->op()->scope();
......
......@@ -17,6 +17,7 @@
#include <set>
#include <string>
#include "paddle/fluid/lite/core/mir/pass_registry.h"
#include "paddle/fluid/lite/utils/string.h"
namespace paddle {
namespace lite {
......@@ -39,7 +40,7 @@ std::string Visualize(mir::SSAGraph* graph) {
if (node.IsArg()) {
key = node.AsArg().name;
} else {
key = node.AsStmt().op_type() + std::to_string(id++);
key = string_format("%s%d", node.AsStmt().op_type().c_str(), id++);
}
if (node.IsStmt()) {
......
......@@ -325,7 +325,7 @@ std::string PMPattern::DotString() const {
// Create Nodes
std::unordered_map<PMNode *, std::string> node2dot;
for (const auto &node : nodes()) {
std::string node_id = "Node" + std::to_string(id++);
std::string node_id = string_format("Node%d", id++);
dot.AddNode(node_id, {}, node->name());
node2dot[node.get()] = node_id;
}
......
......@@ -30,6 +30,7 @@
#include "paddle/fluid/lite/core/mir/node.h"
#include "paddle/fluid/lite/core/mir/ssa_graph.h"
#include "paddle/fluid/lite/model_parser/pb/op_desc.h"
#include "paddle/fluid/lite/utils/string.h"
namespace paddle {
namespace lite {
......@@ -228,7 +229,7 @@ class PMPattern {
FRIEND_TEST(PMPattern, NewNode);
#endif
static std::string NewID() { return "pmnode-" + std::to_string(id_++); }
static std::string NewID() { return string_format("pmnode-%d", id_++); }
std::vector<std::unique_ptr<PMNode>> nodes_;
std::vector<edge_t> edges_;
......
......@@ -20,6 +20,7 @@
#include <vector>
#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
#include "paddle/fluid/lite/utils/string.h"
namespace paddle {
namespace lite {
......@@ -80,7 +81,7 @@ void TypeTargetTransformPass::AddIoCopyInst(
CHECK(in->IsArg());
auto node_id = [&] { return graph->nodes().size(); };
auto io_copy_output_name =
in->AsArg().name + "/trans/" + std::to_string(node_id());
string_format("%s/trans/%d", in->AsArg().name.c_str(), node_id());
auto* io_copy_output_arg = graph->NewArgumentNode(io_copy_output_name);
auto* io_copy_inst = graph->NewInstructNode();
......
......@@ -30,6 +30,8 @@ std::vector<std::unique_ptr<KernelBase>> OpLite::CreateKernels(
auto pick_kernel = [&](const Place &place) {
auto ks = KernelRegistry::Global().Create(op_type_, place.target,
place.precision, place.layout);
VLOG(5) << "pick kernel for " << op_info()->Type() << " " << place
<< " get " << ks.size() << " kernels";
for (auto &&it : ks) {
AttachKernel(it.get());
kernels.emplace_back(std::move(it));
......
......@@ -62,6 +62,9 @@ std::list<std::unique_ptr<KernelBase>> KernelRegistry::Create(
case TARGET(kARM): {
CREATE_KERNEL(kARM);
} break;
case TARGET(kOpenCL): {
CREATE_KERNEL(kOpenCL);
} break;
default:
CHECK(false) << "not supported kernel target " << TargetToStr(target);
}
......@@ -99,6 +102,10 @@ KernelRegistry::KernelRegistry()
INIT_FOR(kARM, kInt8, kNCHW);
INIT_FOR(kARM, kAny, kNCHW);
INIT_FOR(kARM, kAny, kAny);
INIT_FOR(kOpenCL, kFloat, kNCHW);
INIT_FOR(kOpenCL, kAny, kNCHW);
INIT_FOR(kOpenCL, kAny, kAny);
#undef INIT_FOR
}
......
......@@ -82,6 +82,10 @@ class KernelRegistry final {
KernelRegistryForTarget<TARGET(kARM), PRECISION(kFloat),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kARM), PRECISION(kInt8),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kOpenCL), PRECISION(kFloat),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kOpenCL), PRECISION(kInt8),
DATALAYOUT(kNCHW)> * //
>;
......
......@@ -50,13 +50,22 @@ class Optimizer {
if (passes.empty()) {
RunPasses(std::vector<std::string>{{
"lite_quant_dequant_fuse_pass", //
"lite_conv_bn_fuse_pass", //
"lite_quant_dequant_fuse_pass", //
"lite_conv_bn_fuse_pass", //
// This pass is disabled to force some opencl kernels selected for final
// running, otherwise, they will be fused to ARM fusion kernels, and the OpenCL
// devices will be discarded.
// TODO(Superjomn) Refine the fusion related design to select fusion kernels for
// devices automatically.
#ifndef LITE_WITH_OPENCL
"lite_conv_elementwise_add_activation_fuse_pass", //
"lite_fc_fuse_pass", //
"identity_scale_eliminate_pass", //
#endif
"lite_fc_fuse_pass", //
"identity_scale_eliminate_pass", //
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
#ifndef LITE_WITH_OPENCL
"lite_elementwise_add_activation_fuse_pass", //
#endif
#endif
"static_kernel_pick_pass", //
"variable_place_inference_pass", //
......
......@@ -140,7 +140,8 @@ class RuntimeProgram {
void Run() {
for (auto& inst : instructions_) {
VLOG(3) << ">> Running kernel: " << inst.op()->op_info()->Repr();
VLOG(4) << ">> Running kernel: " << inst.op()->op_info()->Repr()
<< " on Target " << TargetToStr(inst.kernel()->target());
inst.Run();
}
}
......
......@@ -31,6 +31,7 @@ enum class TargetType : int {
kX86,
kCUDA,
kARM,
kOpenCL,
kAny, // any target
NUM, // number of fields.
};
......@@ -69,8 +70,8 @@ static size_t PrecisionTypeLength(PrecisionType type) {
#define DATALAYOUT(item__) paddle::lite::DataLayoutType::item__
static const std::string& TargetToStr(TargetType target) {
static const std::string target2string[] = {"unk", "host", "x86",
"cuda", "arm", "any"};
static const std::string target2string[] = {"unk", "host", "x86", "cuda",
"arm", "opencl", "any"};
auto x = static_cast<int>(target);
CHECK_LT(x, static_cast<int>(TARGET(NUM)));
return target2string[x];
......@@ -92,8 +93,8 @@ static const std::string& DataLayoutToStr(DataLayoutType layout) {
}
static const std::string& TargetRepr(TargetType target) {
static const std::string target2string[] = {"kUnk", "kHost", "kX86", "kCUDA",
"kAny"};
static const std::string target2string[] = {
"kUnk", "kHost", "kX86", "kCUDA", "kARM", "kOpenCL", "kAny"};
auto x = static_cast<int>(target);
CHECK_LT(x, static_cast<int>(TARGET(NUM)));
return target2string[x];
......
......@@ -4,5 +4,5 @@ add_subdirectory(host)
add_subdirectory(arm)
add_subdirectory(cuda)
add_subdirectory(x86)
add_subdirectory(opencl)
if (NOT LITE_WITH_OPENCL)
return ()
endif()
set(cl_kernel_deps op_params_lite cl_caller cl_engine cl_context cl_wrapper)
cc_library(elementwise_add_opencl SRCS elementwise_add_compute.cc DEPS ${cl_kernel_deps})
lite_cc_test(test_elementwise_add_opencl SRCS elementwise_add_compute_test.cc DEPS elementwise_add_opencl
op_registry_lite program_lite
context_lite
)
set(opencl_kernels
elementwise_add_opencl
CACHE INTERNAL "")
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/operators/op_params.h"
// NOTE ugly here, hide these.
#include "paddle/fluid/lite/opencl/cl_caller.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace opencl {
class ElementwiseAddCompute
: public KernelLite<TARGET(kOpenCL), PRECISION(kFloat), DATALAYOUT(kNCHW)> {
public:
using param_t = operators::ElementwiseParam;
void Run() override {
auto& param = *param_.get_mutable<param_t>();
auto& context = ctx_->As<OpenClContext>();
CHECK(context.cl_context());
elementwise_add(
context.cl_context(), static_cast<const float*>(param.X->raw_data()),
param.X->dims(), static_cast<const float*>(param.Y->raw_data()),
param.Y->dims(), param.Out->mutable_data<float>(), param.Out->dims());
}
};
} // namespace opencl
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(elementwise_add, kOpenCL, kFloat, kNCHW,
paddle::lite::kernels::opencl::ElementwiseAddCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kHost))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kHost))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
TEST(elementwise_add, init) {
LOG(INFO) << "to get kernel ...";
auto kernels = KernelRegistry::Global().Create(
"elementwise_add", TARGET(kOpenCL), PRECISION(kFloat), DATALAYOUT(kNCHW));
ASSERT_FALSE(kernels.empty());
auto kernel = std::move(kernels.front());
LOG(INFO) << "get kernel";
lite::Tensor X, Y, Out;
operators::ElementwiseParam param;
param.X = &X;
param.Y = &Y;
param.Out = &Out;
std::unique_ptr<KernelContext> context(new KernelContext);
context->As<OpenClContext>().InitOnce();
kernel->SetParam(param);
kernel->SetContext(std::move(context));
X.Resize({1, 10});
Y.Resize({1, 10});
Out.Resize({1, 10});
auto* x_data = X.mutable_data<float>();
auto* y_data = Y.mutable_data<float>();
auto* out_data = Out.mutable_data<float>();
for (int i = 0; i < 10; i++) {
x_data[i] = 1.1 * i;
y_data[i] = 2.3 * i;
}
kernel->Launch();
for (int i = 0; i < 10; i++) {
EXPECT_NEAR(out_data[i], 3.4 * i, 1e-1);
}
}
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(elementwise_add, kOpenCL, kFloat, kNCHW, def);
......@@ -61,3 +61,6 @@ USE_LITE_KERNEL(mul, kCUDA, kFloat, kNCHW, def);
USE_LITE_KERNEL(io_copy, kCUDA, kAny, kAny, host_to_device);
USE_LITE_KERNEL(io_copy, kCUDA, kAny, kAny, device_to_host);
#endif
#ifdef LITE_WITH_OPENCL
USE_LITE_KERNEL(elementwise_add, kOpenCL, kFloat, kNCHW, def);
#endif
......@@ -2,18 +2,16 @@ if (NOT LITE_WITH_OPENCL)
return()
endif()
if (WITH_LITE AND LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
cc_library(cl_wrapper SRCS cl_wrapper.cc)
cc_library(cl_tool SRCS cl_tool.cc)
target_compile_options(cl_tool BEFORE PUBLIC -Wno-ignored-qualifiers)
cc_library(cl_half SRCS cl_half.cc)
target_compile_options(cl_half BEFORE PUBLIC -fno-strict-aliasing)
cc_library(cl_engine SRCS cl_engine.cc DEPS cl_tool)
cc_library(cl_context SRCS cl_context.cc DEPS cl_engine)
cc_library(cl_helper SRCS cl_helper.cc DEPS cl_context)
cc_library(cl_image_converter SRCS cl_image_converter.cc DEPS cl_half lite_tensor)
cc_library(cl_image SRCS cl_image.cc DEPS cl_half lite_tensor cl_image_converter cl_engine)
cc_library(cl_caller SRCS cl_caller.cc DEPS cl_helper cl_image)
lite_cc_test(test_cl_runtime SRCS cl_test.cc DEPS cl_helper cl_image cl_caller cl_wrapper)
add_dependencies(cl_tool opencl_clhpp)
endif()
cc_library(cl_wrapper SRCS cl_wrapper.cc)
cc_library(cl_tool SRCS cl_tool.cc)
target_compile_options(cl_tool BEFORE PUBLIC -Wno-ignored-qualifiers)
cc_library(cl_half SRCS cl_half.cc)
target_compile_options(cl_half BEFORE PUBLIC -fno-strict-aliasing)
cc_library(cl_engine SRCS cl_engine.cc DEPS cl_tool)
cc_library(cl_context SRCS cl_context.cc DEPS cl_engine)
cc_library(cl_helper SRCS cl_helper.cc DEPS cl_context)
cc_library(cl_image_converter SRCS cl_image_converter.cc DEPS cl_half lite_tensor)
cc_library(cl_image SRCS cl_image.cc DEPS cl_half lite_tensor cl_image_converter cl_engine)
cc_library(cl_caller SRCS cl_caller.cc DEPS cl_helper cl_image)
lite_cc_test(test_cl_runtime SRCS cl_test.cc DEPS cl_helper cl_image cl_caller cl_wrapper)
add_dependencies(cl_tool opencl_clhpp)
......@@ -49,12 +49,12 @@ bool InitOpenCLEngine(std::string cl_path) {
return engine->IsInitSuccess();
}
void elementwise_add(CLContext* context, float* in, const DDim& in_dim,
float* bias, const DDim& bias_dim, float* out,
void elementwise_add(CLContext* context, const float* in, const DDim& in_dim,
const float* bias, const DDim& bias_dim, float* out,
const DDim& out_dim) {
CLHelper helper(context);
helper.AddKernel("elementwise_add", "elementwise_add_kernel.cl");
auto kernel = helper.KernelAt(0);
auto kernel = helper.GetKernel(0);
CLImage in_image;
in_image.set_tensor_data(in, in_dim);
in_image.InitNormalCLImage(helper.OpenCLContext());
......
......@@ -22,8 +22,13 @@ namespace paddle {
namespace lite {
bool InitOpenCLEngine(std::string cl_path);
void elementwise_add(CLContext* context, float* in, const DDim& in_dim,
float* bias, const DDim& bias_dim, float* out,
/// An elementwise_add method to embed OpenCL logic inside, it is used as a
/// black box so that the framework can remain simple.
/// NOTE Currently, these methods are quite expensive, we will optimize them
/// latter.
void elementwise_add(CLContext* context, const float* in, const DDim& in_dim,
const float* bias, const DDim& bias_dim, float* out,
const DDim& out_dim);
} // namespace lite
......
......@@ -29,17 +29,18 @@ void CLHelper::AddKernel(const std::string &kernel_name,
CHECK(context_ != nullptr) << "Please use set_context first!";
VLOG(3) << " --- begin to add kernel ---";
auto kernel = context_->GetKernel(kernel_name, file_name, options);
kernels.emplace_back(std::move(kernel));
kernels_.emplace_back(std::move(kernel));
kernel_offset_[kernel_name] = kernels_.size() - 1;
VLOG(3) << " --- end to add kernel --- ";
}
cl::Kernel &CLHelper::KernelAt(const int index) {
VLOG(3) << " --- kernel count: " << kernels.size() << " --- ";
CHECK(static_cast<size_t>(index) < kernels.size())
cl::Kernel &CLHelper::GetKernel(const int index) {
VLOG(3) << " --- kernel count: " << kernels_.size() << " --- ";
CHECK(static_cast<size_t>(index) < kernels_.size())
<< "The index must be less than the size of kernels.";
CHECK(kernels[index] != nullptr)
CHECK(kernels_[index] != nullptr)
<< "The target kernel pointer cannot be null.";
return *(kernels[index]);
return *(kernels_[index]);
}
cl::CommandQueue &CLHelper::OpenCLCommandQueue() {
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <map>
#include <memory>
#include <string>
#include <vector>
......@@ -35,7 +36,12 @@ class CLHelper {
void AddKernel(const std::string &kernel_name, const std::string &file_name,
const std::string &options = "");
cl::Kernel &KernelAt(const int index);
cl::Kernel &GetKernel(const int index);
cl::Kernel &GetKernel(const std::string &name) {
auto it = kernel_offset_.find(name);
CHECK(it != kernel_offset_.end());
return GetKernel(it->second);
}
cl::CommandQueue &OpenCLCommandQueue();
......@@ -45,7 +51,8 @@ class CLHelper {
private:
CLContext *context_{nullptr};
std::vector<std::unique_ptr<cl::Kernel>> kernels;
std::map<std::string, int> kernel_offset_;
std::vector<std::unique_ptr<cl::Kernel>> kernels_;
};
} // namespace lite
......
......@@ -53,7 +53,7 @@ std::ostream& operator<<(std::ostream& os, const CLImage& cl_image) {
return os;
}
void CLImage::set_tensor_data(float* tensor_data, const DDim& dim) {
void CLImage::set_tensor_data(const float* tensor_data, const DDim& dim) {
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
auto numel = dim.product();
#else
......
......@@ -33,7 +33,7 @@ class CLImage {
/*
* Will not hold input tensor data, memcpy in this method.
* */
void set_tensor_data(float* tensor_data, const DDim& dim);
void set_tensor_data(const float* tensor_data, const DDim& dim);
bool IsInit() { return initialized_; }
/*
......
......@@ -65,7 +65,7 @@ TEST(cl_test, kernel_test) {
helper->AddKernel("elementwise_add", "elementwise_add_kernel.cl");
helper->AddKernel("pool_max", "pool_kernel.cl");
helper->AddKernel("elementwise_add", "elementwise_add_kernel.cl");
auto kernel = helper->KernelAt(2);
auto kernel = helper->GetKernel(2);
std::unique_ptr<float[]> in_data(new float[1024 * 512]);
for (int i = 0; i < 1024 * 512; i++) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册