提交 c94b71c5 编写于 作者: Y Yuan Shuai 提交者: GitHub

[LITE][OPENCL] Fix OpenCL kernel of exp, tanh; Fix layout pass for opencl backend. (#3212)

* Fix OpenCL kernel of exp, tanh,

* Fix layout pass for opencl backend,

* Add how to debug cl kernel. test=develop.
上级 55cb82f2
...@@ -65,9 +65,11 @@ rm ./lite/api/paddle_use_ops.h ...@@ -65,9 +65,11 @@ rm ./lite/api/paddle_use_ops.h
--arm_os=android \ --arm_os=android \
--arm_abi=armv8 \ --arm_abi=armv8 \
--arm_lang=gcc \ --arm_lang=gcc \
build_test_arm_opencl build_opencl
``` ```
注:如果要调试cl kernel,假设已经完成上述脚本编译(已生成cmake文件)。调试只需要修改`./lite/backends/opencl/cl_kernel/`下对应的kernel文件,保存后在项目根目录执行`python ./lite/tools/cmake_tools/gen_opencl_code.py ./lite/backends/opencl/cl_kernel ./lite/backends/opencl/opencl_kernels_source.cc`,该命令会自动将修改后的kernel文件内容重新生成到`opencl_kernels_source.cc`中;之后再切到build目录下执行`make publish_inference`或者你要编译的单测的可执行文件名,cl kernel文件的内容会随着编译自动打包到产物包如 .so 中或者对应单测可执行文件中。
### 编译产物说明 ### 编译产物说明
编译产物位于`build.lite.android.armv8.gcc.opencl`下的`inference_lite_lib.android.armv8.opencl`文件夹内,这里仅罗列关键产物: 编译产物位于`build.lite.android.armv8.gcc.opencl`下的`inference_lite_lib.android.armv8.opencl`文件夹内,这里仅罗列关键产物:
......
...@@ -145,11 +145,12 @@ class StaticKernelPickPass : public mir::StmtPass { ...@@ -145,11 +145,12 @@ class StaticKernelPickPass : public mir::StmtPass {
} }
VLOG(4) << "[score(final)]:" << final_score; VLOG(4) << "[score(final)]:" << final_score;
VLOG(4) << "-------- pick summary --------"; VLOG(2) << "-------- pick summary for " << instruct.op_type()
VLOG(4) << " ===> winner_place():" << PrecisionToStr(winner_place.precision) << " --------";
VLOG(2) << " ===> winner_place():" << PrecisionToStr(winner_place.precision)
<< " " << DataLayoutToStr(winner_place.layout) << " " << " " << DataLayoutToStr(winner_place.layout) << " "
<< TargetToStr(winner_place.target); << TargetToStr(winner_place.target);
VLOG(4) << " ===> kernel.place():" VLOG(2) << " ===> kernel.place():"
<< PrecisionToStr(kernel.place().precision) << " " << PrecisionToStr(kernel.place().precision) << " "
<< DataLayoutToStr(kernel.place().layout) << " " << DataLayoutToStr(kernel.place().layout) << " "
<< TargetToStr(kernel.place().target); << TargetToStr(kernel.place().target);
......
...@@ -41,8 +41,9 @@ void TypeLayoutTransformPass::Apply(const std::unique_ptr<SSAGraph>& graph) { ...@@ -41,8 +41,9 @@ void TypeLayoutTransformPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
VLOG(4) << "!node->IsStmt():" << !node->IsStmt(); VLOG(4) << "!node->IsStmt():" << !node->IsStmt();
if (!node->IsStmt() || node->AsStmt().op_type() == "while") continue; if (!node->IsStmt() || node->AsStmt().op_type() == "while") continue;
auto inlinks = node->inlinks; auto inlinks = node->inlinks;
VLOG(4) << "node->AsStmt().desc:" << node->AsStmt().desc VLOG(4) << "============== node->AsStmt().op_type():"
<< " inlinks.size():" << inlinks.size(); << node->AsStmt().op_type() << " inlinks.size():" << inlinks.size()
<< " ================";
for (auto* in : inlinks) { for (auto* in : inlinks) {
ComplementInputs(graph.get(), node, in); ComplementInputs(graph.get(), node, in);
} }
...@@ -68,13 +69,25 @@ void TypeLayoutTransformPass::ComplementInputs(SSAGraph* graph, ...@@ -68,13 +69,25 @@ void TypeLayoutTransformPass::ComplementInputs(SSAGraph* graph,
CHECK(inst.op_info()->GetInputArgname(in_arg_name, &inst_in_tensor_name)); CHECK(inst.op_info()->GetInputArgname(in_arg_name, &inst_in_tensor_name));
auto decl_arg_type = auto decl_arg_type =
inst.picked_kernel().GetInputDeclType(inst_in_tensor_name); inst.picked_kernel().GetInputDeclType(inst_in_tensor_name);
CHECK(in->AsArg().type); CHECK(in->AsArg().type);
VLOG(5) << "\n inst_in_tensor_name:" << inst_in_tensor_name VLOG(3) << "\n inst_in_tensor_name:" << inst_in_tensor_name
<< "\n in->AsArg().name:" << in->AsArg().name << "\n in->AsArg().name:" << in->AsArg().name
<< "\n *in->AsArg().type:" << *in->AsArg().type << "\n *in->AsArg().type:" << *in->AsArg().type
<< "\n *decl_arg_type:" << *decl_arg_type << "\n *decl_arg_type:" << *decl_arg_type
<< "\n inst.op()->DebugString():" << inst.op()->DebugString(); << "\n inst.op()->DebugString():" << inst.op()->DebugString();
// TODO(ysh329): conflict if tensor with kARM target but kImageDefault(OpenCL
// layout).
// not a good judge, but don't find the source of this issue from
// static_pick_kernel_pass
// to this pass.
auto* in_arg_type = const_cast<Type*>(in->AsArg().type);
if (in_arg_type->target() == TARGET(kARM) &&
in_arg_type->layout() == DATALAYOUT(kImageDefault)) {
return;
}
if (!DataLayoutCompatible(*in->AsArg().type, *decl_arg_type)) { if (!DataLayoutCompatible(*in->AsArg().type, *decl_arg_type)) {
VLOG(4) << "found Layout unmatched tensor: " << in->AsArg().name VLOG(4) << "found Layout unmatched tensor: " << in->AsArg().name
<< " for kernel " << inst.op()->DebugString() << " " << " for kernel " << inst.op()->DebugString() << " "
......
...@@ -177,7 +177,7 @@ REGISTER_LITE_KERNEL( ...@@ -177,7 +177,7 @@ REGISTER_LITE_KERNEL(
// exp // exp
REGISTER_LITE_KERNEL( REGISTER_LITE_KERNEL(
exp_act, exp,
kOpenCL, kOpenCL,
kFP16, kFP16,
kImageDefault, kImageDefault,
...@@ -195,7 +195,7 @@ REGISTER_LITE_KERNEL( ...@@ -195,7 +195,7 @@ REGISTER_LITE_KERNEL(
// tanh // tanh
REGISTER_LITE_KERNEL( REGISTER_LITE_KERNEL(
tanh_act, tanh,
kOpenCL, kOpenCL,
kFP16, kFP16,
kImageDefault, kImageDefault,
......
...@@ -109,13 +109,13 @@ TEST(act_image2d_fp16, compute) { ...@@ -109,13 +109,13 @@ TEST(act_image2d_fp16, compute) {
func_name = "sigmoid"; func_name = "sigmoid";
break; break;
case 6: // tanh case 6: // tanh
func_name = "tanh_act"; func_name = "tanh";
break; break;
case 7: // swish case 7: // swish
func_name = "swish"; func_name = "swish";
break; break;
case 8: // exp case 8: // exp
func_name = "exp_act"; func_name = "exp";
break; break;
} }
LOG(INFO) << "func_name: " << func_name; LOG(INFO) << "func_name: " << func_name;
...@@ -307,7 +307,7 @@ USE_LITE_KERNEL(layout, kOpenCL, kAny, kImageDefault, NCHW_to_ImageDefault); ...@@ -307,7 +307,7 @@ USE_LITE_KERNEL(layout, kOpenCL, kAny, kImageDefault, NCHW_to_ImageDefault);
USE_LITE_KERNEL(layout, kOpenCL, kAny, kNCHW, ImageDefault_to_NCHW); USE_LITE_KERNEL(layout, kOpenCL, kAny, kNCHW, ImageDefault_to_NCHW);
// exp // exp
USE_LITE_KERNEL(exp_act, kOpenCL, kFP16, kImageDefault, ImageDefault); USE_LITE_KERNEL(exp, kOpenCL, kFP16, kImageDefault, ImageDefault);
// swish // swish
USE_LITE_KERNEL(swish, kOpenCL, kFP16, kImageDefault, ImageDefault); USE_LITE_KERNEL(swish, kOpenCL, kFP16, kImageDefault, ImageDefault);
...@@ -316,7 +316,7 @@ USE_LITE_KERNEL(swish, kOpenCL, kFP16, kImageDefault, ImageDefault); ...@@ -316,7 +316,7 @@ USE_LITE_KERNEL(swish, kOpenCL, kFP16, kImageDefault, ImageDefault);
USE_LITE_KERNEL(leaky_relu, kOpenCL, kFP16, kImageDefault, ImageDefault); USE_LITE_KERNEL(leaky_relu, kOpenCL, kFP16, kImageDefault, ImageDefault);
// tanh act // tanh act
USE_LITE_KERNEL(tanh_act, kOpenCL, kFP16, kImageDefault, ImageDefault); USE_LITE_KERNEL(tanh, kOpenCL, kFP16, kImageDefault, ImageDefault);
// relu image2d fp16 // relu image2d fp16
USE_LITE_KERNEL(relu, kOpenCL, kFP16, kImageDefault, ImageDefault); USE_LITE_KERNEL(relu, kOpenCL, kFP16, kImageDefault, ImageDefault);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册