未验证 提交 273b1fa2 编写于 作者: Y ysh329 提交者: GitHub

[cherry-pick][OPENCL] remove conv redundant's for opencl kernel. test… (#3938)

* [cherry-pick][OPENCL] remove conv redundant's for opencl kernel. test=develop
Co-authored-by: Nxiebaiyuan <xiebaiyuan@qq.com>
上级 594175af
......@@ -27,6 +27,28 @@
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "argmax"; }
#ifdef LITE_WITH_PROFILE
// Fills the profiler record `ch` for this argmax op: the input/output shapes
// as strings, a remark carrying the configured axis, and a `macs` estimate.
// The estimate is (product of input extents after the axis)! multiplied by
// the number of output elements, with the factorial accumulated in float.
void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter *ch) {
  auto in_shape = param_.X->dims();
  auto out_shape = param_.Out->dims();
  ch->input_shape = ch->DimToStr(in_shape);
  ch->output_shape = ch->DimToStr(out_shape);
  ch->remark = "axis" + std::to_string(param_.Axis);

  // A negative axis counts from the back; shift it by the input rank.
  auto norm_axis = param_.Axis;
  if (norm_axis < 0) {
    norm_axis += in_shape.size();
  }

  // Multiply together every extent that follows the chosen axis.
  int trailing_count = 1;
  int64_t dim_idx = norm_axis + 1;
  while (dim_idx < in_shape.size()) {
    trailing_count *= in_shape[dim_idx];
    dim_idx++;
  }

  // Factorial of that count, kept in single precision.
  float gops = 1.0f;
  for (int k = 1; k <= trailing_count; k++) {
    gops *= k;
  }

  ch->macs = gops * out_shape.production();
}
#endif
private:
mutable ArgmaxParam param_;
};
......@@ -85,6 +107,13 @@
using param_t = operators::ArgmaxParam;
void Run() override;
virtual ~ArgmaxCompute() = default;
#ifdef LITE_WITH_PROFILE
// Profiler hook: copies the name stored in `kernel_func_name_` into the
// op record `ch` so the profile output can show which routine was used.
virtual void SetProfileRuntimeKernelInfo(
    paddle::lite::profile::OpCharacter* ch) {
  ch->kernel_func_name = this->kernel_func_name_;
}
std::string kernel_func_name_{"NotImplForArgmax"};
#endif
};
```
- 在paddlelite/lite/kernels/arm/目录下新建argmax_compute.cc文件,主要实现Run函数。`Run()`函数调用paddlelite/lite/backends/arm/math/argmax.h中的`argmax_func()`函数,根据输入计算输出。最后在argmax_compute.cc文件中,我们绑定argmax的输入输出(为tensor的输入参数都需要绑定),代码如下:
......@@ -95,6 +124,9 @@
lite::Tensor* output = param.Out;
int axis = param.Axis;
lite::arm::math::argmax_func(input, axis, output);
#ifdef LITE_WITH_PROFILE
kernel_func_name_ = "argmax_func";
#endif
return;
}
......
......@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include <sstream>
#include <string>
#include <vector>
......@@ -25,6 +24,7 @@
#ifdef LITE_WITH_PROFILE
#include "lite/core/profile/basic_profiler.h"
#endif // LITE_WITH_PROFILE
#include <gflags/gflags.h>
using paddle::lite::profile::Timer;
......@@ -34,6 +34,10 @@ DEFINE_string(input_shape,
DEFINE_bool(use_optimize_nb,
false,
"optimized & naive buffer model for mobile devices");
DEFINE_string(backend,
"arm_cpu",
"choose backend for valid_places: arm_cpu | opencl. Compile "
"OpenCL version if you choose opencl");
DEFINE_string(arg_name, "", "the arg name");
namespace paddle {
......@@ -49,9 +53,19 @@ void OutputOptModel(const std::string& load_model_dir,
Place{TARGET(kX86), PRECISION(kInt64)},
Place{TARGET(kHost), PRECISION(kFloat)}});
#else
if (FLAGS_backend == "opencl") {
config.set_valid_places({
Place{TARGET(kOpenCL), PRECISION(kFP16), DATALAYOUT(kImageDefault)},
Place{TARGET(kOpenCL), PRECISION(kFloat), DATALAYOUT(kNCHW)},
Place{TARGET(kOpenCL), PRECISION(kAny), DATALAYOUT(kImageDefault)},
Place{TARGET(kOpenCL), PRECISION(kAny), DATALAYOUT(kNCHW)},
TARGET(kARM), // enable kARM CPU kernel when no opencl kernel
});
} else { // arm_cpu
config.set_valid_places({
Place{TARGET(kARM), PRECISION(kFloat)},
});
}
#endif
auto predictor = lite_api::CreatePaddlePredictor(config);
......@@ -117,16 +131,40 @@ void Run(const std::vector<std::vector<int64_t>>& input_shapes,
<< ", min time: " << ti.LapTimes().Min() << " ms"
<< ", max time: " << ti.LapTimes().Max() << " ms.";
auto output = predictor->GetOutput(0);
auto out = output->data<float>();
LOG(INFO) << "out " << out[0];
LOG(INFO) << "out " << out[1];
auto output_shape = output->shape();
int output_num = 1;
for (int i = 0; i < output_shape.size(); ++i) {
output_num *= output_shape[i];
// output summary
size_t output_tensor_num = predictor->GetOutputNames().size();
LOG(INFO) << "output tensor num:" << output_tensor_num;
for (size_t tidx = 0; tidx < output_tensor_num; ++tidx) {
auto output_tensor = predictor->GetOutput(tidx);
LOG(INFO) << "============= output tensor " << tidx << " =============";
auto tensor_shape = output_tensor->shape();
std::string tensor_shape_str{""};
int output_tensor_numel = 1;
for (int i = 0; i < tensor_shape.size(); ++i) {
output_tensor_numel *= tensor_shape[i];
tensor_shape_str += std::to_string(tensor_shape[i]);
tensor_shape_str += (i < tensor_shape.size() - 1) ? "x" : "";
}
auto out_data = output_tensor->data<float>();
auto out_mean =
paddle::lite::compute_mean<float>(out_data, output_tensor_numel);
auto out_std_dev = paddle::lite::compute_standard_deviation<float>(
out_data, output_tensor_numel, true, out_mean);
VLOG(0) << "output tensor " << tidx << " dims:" << tensor_shape_str;
VLOG(0) << "output tensor " << tidx
<< " elements num:" << output_tensor_numel;
VLOG(0) << "output tensor " << tidx
<< " standard deviation:" << out_std_dev;
VLOG(0) << "output tensor " << tidx << " mean value:" << out_mean << "\n";
// print result
for (int i = 0; i < output_tensor_numel; ++i) {
VLOG(2) << "output_tensor->data<float>()[" << i
<< "]:" << output_tensor->data<float>()[i];
}
}
LOG(INFO) << "output_num: " << output_num;
// please turn off memory_optimize_pass to use this feature.
if (FLAGS_arg_name != "") {
......@@ -162,6 +200,7 @@ int main(int argc, char** argv) {
<< "--model_dir /path/to/your/model";
exit(0);
}
std::string save_optimized_model_dir = "";
if (FLAGS_use_optimize_nb) {
save_optimized_model_dir = FLAGS_model_dir;
......
......@@ -26,6 +26,7 @@ USE_MIR_PASS(argument_type_display_pass);
USE_MIR_PASS(runtime_context_assign_pass);
USE_MIR_PASS(graph_visualize_pass);
USE_MIR_PASS(remove_tf_redundant_ops_pass);
USE_MIR_PASS(lite_conv_bn_fuse_pass);
USE_MIR_PASS(lite_fc_fuse_pass);
USE_MIR_PASS(lite_shuffle_channel_fuse_pass);
......
......@@ -1254,6 +1254,19 @@ void elementwise_max_relu_broadcast<float>(const float* dinx,
}
}
// int64_t specialization: element-by-element integer quotient of two arrays.
//   dinx - dividends, `num` elements
//   diny - divisors,  `num` elements
//   dout - receives dinx[k] / diny[k] for every k in [0, num)
// Divisors are not checked for zero.
template <>
void elementwise_div<int64_t>(const int64_t* dinx,
                              const int64_t* diny,
                              int64_t* dout,
                              int num) {
  for (int k = 0; k < num; ++k) {
    dout[k] = dinx[k] / diny[k];
  }
}
template <>
void elementwise_div<float>(const float* dinx,
const float* diny,
......@@ -1306,6 +1319,28 @@ void elementwise_div<float>(const float* dinx,
}
}
// int64_t specialization of the broadcast division.
// `dinx` and `dout` are indexed as [batch][channels][num]; `diny` supplies one
// divisor per channel, which is applied to all `num` elements of that channel
// in every batch. Divisors are not checked for zero.
template <>
void elementwise_div_broadcast<int64_t>(const int64_t* dinx,
                                        const int64_t* diny,
                                        int64_t* dout,
                                        int batch,
                                        int channels,
                                        int num) {
  for (int b = 0; b < batch; ++b) {
    for (int c = 0; c < channels; ++c) {
      // Start of the (b, c) slice in both the input and the output.
      const int base = (b * channels + c) * num;
      const int64_t divisor = diny[c];
      const int64_t* src = dinx + base;
      int64_t* dst = dout + base;
      for (int k = 0; k < num; ++k) {
        dst[k] = src[k] / divisor;
      }
    }
  }
}
template <>
void elementwise_div_broadcast<float>(const float* dinx,
const float* diny,
......
......@@ -119,19 +119,13 @@ cl::NDRange CLContext::DefaultWorkSize(const CLImage &image) {
}
}
cl::NDRange CLContext::LocalWorkSizeTurn(cl::NDRange global_work_size,
cl::NDRange CLContext::LocalWorkSizeTune(cl::NDRange global_work_size,
size_t max_work_size,
int divisor) {
int preferred_lws = 0;
#if 1
auto gws0 = global_work_size[0];
auto gws1 = global_work_size[1];
auto gws2 = global_work_size[2];
#else
auto gws2 = global_work_size[0];
auto gws1 = global_work_size[1];
auto gws0 = global_work_size[2];
#endif
if (divisor > 1) {
max_work_size /= divisor;
}
......@@ -147,15 +141,40 @@ cl::NDRange CLContext::LocalWorkSizeTurn(cl::NDRange global_work_size,
while (gws0 * gws1 * gws2 > max_work_size && max_work_size > 0) {
gws0 = gws0 % 2 == 0 ? gws0 / 2 : 1;
}
#if 1
return cl::NDRange{static_cast<size_t>(gws0),
static_cast<size_t>(gws1),
static_cast<size_t>(gws2)};
#else
}
// Derives a local work size from `global_work_size` by repeatedly halving
// its dimensions until their product fits into the work-size budget.
// The budget is `max_work_size`, divided by `divisor` when divisor > 1.
// The three dimensions are read in reversed order (gws2 <- [0], gws0 <- [2])
// and returned as {gws2, gws1, gws0}, i.e. back in the caller's order.
cl::NDRange CLContext::LocalWorkSizeTuneReverse(cl::NDRange global_work_size,
size_t max_work_size,
int divisor) {
// Always 0 in this function, so the "preferred" override below never fires.
int preferred_lws = 0;
auto gws2 = global_work_size[0];
auto gws1 = global_work_size[1];
auto gws0 = global_work_size[2];
if (divisor > 1) {
max_work_size /= divisor;
}
if (preferred_lws > 0 && preferred_lws <= max_work_size) {
max_work_size = preferred_lws;
}
// Each loop halves a dimension while it is even; an odd value drops
// straight to 1. The `max_work_size > 0` test stops the loops from spinning
// forever when the budget is 0.
while (gws1 > max_work_size && max_work_size > 0) {
gws1 = gws1 % 2 == 0 ? gws1 / 2 : 1;
}
while (gws2 * gws1 > max_work_size && max_work_size > 0) {
gws2 = gws2 % 2 == 0 ? gws2 / 2 : 1;
}
while (gws0 * gws1 * gws2 > max_work_size && max_work_size > 0) {
gws0 = gws0 % 2 == 0 ? gws0 / 2 : 1;
}
return cl::NDRange{static_cast<size_t>(gws2),
static_cast<size_t>(gws1),
static_cast<size_t>(gws0)};
// NOTE(review): no #if/#ifdef is opened inside this function; verify the
// preprocessor nesting of this #endif against the surrounding file.
#endif
}
bool CLContext::IsArmMali() {
return CLRuntime::Global()->GetGpuType() == GpuType::ARM_MALI;
}
cl::NDRange CLContext::LocalWorkSize(cl::NDRange global_work_size,
......
......@@ -63,11 +63,14 @@ class CLContext {
cl::NDRange LocalWorkSize(cl::NDRange global_work_size, size_t max_work_size);
cl::NDRange LocalWorkSizeTurn(cl::NDRange global_work_size,
cl::NDRange LocalWorkSizeTune(cl::NDRange global_work_size,
size_t max_work_size,
int divitor = 2);
// cl::NDRange LocalWorkSizeConv1x1(cl::NDRange global_work_size,
// size_t max_work_size);
cl::NDRange LocalWorkSizeTuneReverse(cl::NDRange global_work_size,
size_t max_work_size,
int divitor = 2);
bool IsArmMali();
private:
std::unordered_map<std::string, std::unique_ptr<cl::Program>> programs_;
......
......@@ -6,9 +6,7 @@ __kernel void conv2d_1x1_opt(
__private const int global_size_dim2,
__read_only image2d_t input_image,
__read_only image2d_t filter,
#if defined(BIASE_CH) || defined(BIASE_ELE)
__read_only image2d_t bias,
#endif
#ifdef BATCH_NORM
__read_only image2d_t new_scale,
__read_only image2d_t new_biase,
......@@ -284,9 +282,7 @@ __kernel void conv2d_1x1_simple(
__private const int global_size_dim2,
__read_only image2d_t input_image,
__read_only image2d_t filter,
#if defined(BIASE_CH) || defined(BIASE_ELE)
__read_only image2d_t bias,
#endif
#ifdef BATCH_NORM
__read_only image2d_t new_scale,
__read_only image2d_t new_biase,
......
......@@ -19,9 +19,7 @@ __kernel void conv2d_3x3(__private const int global_size_dim0,
__private const int global_size_dim2,
__read_only image2d_t input_image,
__read_only image2d_t filter,
#if defined(BIASE_CH) || defined(BIASE_ELE)
__read_only image2d_t bias,
#endif
__write_only image2d_t output_image,
__private const int stride,
__private const int offset,
......
......@@ -19,9 +19,7 @@ __kernel void conv2d_3x3_opt(__private const int item_ch,
__private const int item_h,
__read_only image2d_t input_image,
__read_only image2d_t filter_image,
#if defined(BIASE_CH) || defined(BIASE_ELE)
__read_only image2d_t bias,
#endif
__write_only image2d_t output_image,
__private const int stride,
__private const int pad,
......@@ -264,9 +262,7 @@ __kernel void conv2d_3x3_multi_batch(__private const int item_ch,
__private const int item_h,
__read_only image2d_t input_image,
__read_only image2d_t filter_image,
#if defined(BIASE_CH) || defined(BIASE_ELE)
__read_only image2d_t bias,
#endif
__write_only image2d_t output_image,
__private const int stride,
__private const int pad,
......
......@@ -5,9 +5,7 @@ __kernel void conv2d_5x5(__private const int global_size_dim0,
__private const int global_size_dim2,
__read_only image2d_t input_image,
__read_only image2d_t filter_image,
#if defined(BIASE_CH) || defined(BIASE_ELE)
__read_only image2d_t bias,
#endif
#ifdef BATCH_NORM
__read_only image2d_t new_scale,
__read_only image2d_t new_biase,
......
......@@ -20,9 +20,7 @@ __kernel void conv2d_5x5_opt(__private const int item_ch,
__private const int item_h,
__read_only image2d_t input_image,
__read_only image2d_t filter_image,
#if defined(BIASE_CH) || defined(BIASE_ELE)
__read_only image2d_t bias,
#endif
__write_only image2d_t output_image,
__private const int stride,
__private const int pad,
......@@ -268,9 +266,7 @@ __kernel void conv2d_5x5_multi_batch(__private const int item_ch,
__private const int item_h,
__read_only image2d_t input_image,
__read_only image2d_t filter_image,
#if defined(BIASE_CH) || defined(BIASE_ELE)
__read_only image2d_t bias,
#endif
__write_only image2d_t output_image,
__private const int stride,
__private const int pad,
......
......@@ -5,9 +5,7 @@ __kernel void conv2d_7x7(__private const int global_size_dim0,
__private const int global_size_dim2,
__read_only image2d_t input_image,
__read_only image2d_t filter_image,
#if defined(BIASE_CH) || defined(BIASE_ELE)
__read_only image2d_t bias,
#endif
#ifdef BATCH_NORM
__read_only image2d_t new_scale,
__read_only image2d_t new_biase,
......
......@@ -20,9 +20,7 @@ __kernel void conv2d_7x7_opt(__private const int item_ch,
__private const int item_h,
__read_only image2d_t input_image,
__read_only image2d_t filter_image,
#if defined(BIASE_CH) || defined(BIASE_ELE)
__read_only image2d_t bias,
#endif
__write_only image2d_t output_image,
__private const int stride,
__private const int pad,
......@@ -268,9 +266,7 @@ __kernel void conv2d_7x7_multi_batch(__private const int item_ch,
__private const int item_h,
__read_only image2d_t input_image,
__read_only image2d_t filter_image,
#if defined(BIASE_CH) || defined(BIASE_ELE)
__read_only image2d_t bias,
#endif
__write_only image2d_t output_image,
__private const int stride,
__private const int pad,
......
......@@ -19,9 +19,7 @@ __kernel void depth_conv2d(__private const int global_size_dim0,
__private const int global_size_dim2,
__read_only image2d_t input,
__read_only image2d_t filter,
#if defined(BIASE_CH) || defined(BIASE_ELE)
__read_only image2d_t bias,
#endif
#ifdef BATCH_NORM
__read_only image2d_t new_scale,
__read_only image2d_t new_biase,
......
......@@ -20,9 +20,7 @@ __kernel void depth_conv2d_3x3(
__private const int global_size_dim2,
__read_only image2d_t input,
__read_only image2d_t filter,
#if defined(BIASE_CH) || defined(BIASE_ELE)
__read_only image2d_t bias,
#endif
__write_only image2d_t output_image,
__private const int stride,
__private const int offset,
......@@ -249,9 +247,7 @@ __kernel void depth_conv2d_3x3s1(__private const int ou_ch_blk,
__private const int ou_nh,
__read_only image2d_t input,
__read_only image2d_t filter,
#if defined(BIASE_CH) || defined(BIASE_ELE)
__read_only image2d_t bias,
#endif
__write_only image2d_t output_image,
__private const int stride,
__private const int pad,
......
#include <cl_common.h>
/*
 * expend_c1: one work-item per output pixel (channel block, width, fused N*H).
 * The source pixel is found by dividing the output width/height/batch indices
 * by w_times/h_times/n_times; it is always read from the first channel block
 * of the input. Only the .x lane of the pixel is kept (y, z, w are zeroed)
 * before it is written to the output image.
 * NOTE(review): the "_c1" suffix presumably means a one-channel input -
 * confirm against the host-side kernel selection.
 */
__kernel void expend_c1(__private const int OUT_C,
                        __private const int OUT_W,
                        __private const int OUT_NH,
                        __private const int IN_C,
                        __private const int IN_W,
                        __private const int IN_NH,
                        __private const int input_width,  /* of one block */
                        __private const int input_height, /* of one block */
                        __private const int output_width,
                        __private const int output_height,
                        __read_only image2d_t input,
                        __write_only image2d_t output,
                        __private const int n_times,
                        __private const int c_times,
                        __private const int h_times,
                        __private const int w_times) {
  const int gid_c = get_global_id(0);
  const int gid_w = get_global_id(1);
  const int gid_nh = get_global_id(2);

  /* Work-items outside the output extent have nothing to write. */
  if (!(gid_c < OUT_C && gid_w < OUT_W && gid_nh < OUT_NH)) {
    return;
  }

  /* Split the fused N*H index, then scale each coordinate back to the input. */
  const int batch_out = gid_nh / output_height;
  const int row_out = gid_nh % output_height;
  const int src_w = gid_w / w_times;
  const int src_h = row_out / h_times;
  const int src_n = batch_out / n_times;
  const int src_nh = src_n * input_height + src_h;

  int2 src_pos = (int2)(src_w, src_nh);
  int2 dst_pos = (int2)(gid_c * OUT_W + gid_w, gid_nh);

  const sampler_t sampler =
      CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;

  CL_DTYPE4 pixel = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, src_pos);
  pixel.y = 0;
  pixel.z = 0;
  pixel.w = 0;

  WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, dst_pos, pixel);
}
/*
 * expend_c2: one work-item per output pixel (channel block, width, fused N*H).
 * The source pixel is found by dividing the output width/height/batch indices
 * by w_times/h_times/n_times; it is always read from the first channel block
 * of the input. The .x and .y lanes are kept, .z and .w are zeroed, and the
 * result is written to the output image.
 * NOTE(review): the "_c2" suffix presumably means a two-channel input -
 * confirm against the host-side kernel selection.
 */
__kernel void expend_c2(__private const int OUT_C,
                        __private const int OUT_W,
                        __private const int OUT_NH,
                        __private const int IN_C,
                        __private const int IN_W,
                        __private const int IN_NH,
                        __private const int input_width,  /* of one block */
                        __private const int input_height, /* of one block */
                        __private const int output_width,
                        __private const int output_height,
                        __read_only image2d_t input,
                        __write_only image2d_t output,
                        __private const int n_times,
                        __private const int c_times,
                        __private const int h_times,
                        __private const int w_times) {
  const int gid_c = get_global_id(0);
  const int gid_w = get_global_id(1);
  const int gid_nh = get_global_id(2);

  /* Work-items outside the output extent have nothing to write. */
  if (!(gid_c < OUT_C && gid_w < OUT_W && gid_nh < OUT_NH)) {
    return;
  }

  /* Split the fused N*H index, then scale each coordinate back to the input. */
  const int batch_out = gid_nh / output_height;
  const int row_out = gid_nh % output_height;
  const int src_w = gid_w / w_times;
  const int src_h = row_out / h_times;
  const int src_n = batch_out / n_times;
  const int src_nh = src_n * input_height + src_h;

  int2 src_pos = (int2)(src_w, src_nh);
  int2 dst_pos = (int2)(gid_c * OUT_W + gid_w, gid_nh);

  const sampler_t sampler =
      CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;

  CL_DTYPE4 pixel = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, src_pos);
  pixel.z = 0;
  pixel.w = 0;

  WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, dst_pos, pixel);
}
/*
 * expend_c3: one work-item per output pixel (channel block, width, fused N*H).
 * The source pixel is found by dividing the output width/height/batch indices
 * by w_times/h_times/n_times; it is always read from the first channel block
 * of the input. The .x, .y and .z lanes are kept, .w is zeroed, and the
 * result is written to the output image.
 * NOTE(review): the "_c3" suffix presumably means a three-channel input -
 * confirm against the host-side kernel selection.
 */
__kernel void expend_c3(__private const int OUT_C,
                        __private const int OUT_W,
                        __private const int OUT_NH,
                        __private const int IN_C,
                        __private const int IN_W,
                        __private const int IN_NH,
                        __private const int input_width,  /* of one block */
                        __private const int input_height, /* of one block */
                        __private const int output_width,
                        __private const int output_height,
                        __read_only image2d_t input,
                        __write_only image2d_t output,
                        __private const int n_times,
                        __private const int c_times,
                        __private const int h_times,
                        __private const int w_times) {
  const int gid_c = get_global_id(0);
  const int gid_w = get_global_id(1);
  const int gid_nh = get_global_id(2);

  /* Work-items outside the output extent have nothing to write. */
  if (!(gid_c < OUT_C && gid_w < OUT_W && gid_nh < OUT_NH)) {
    return;
  }

  /* Split the fused N*H index, then scale each coordinate back to the input. */
  const int batch_out = gid_nh / output_height;
  const int row_out = gid_nh % output_height;
  const int src_w = gid_w / w_times;
  const int src_h = row_out / h_times;
  const int src_n = batch_out / n_times;
  const int src_nh = src_n * input_height + src_h;

  int2 src_pos = (int2)(src_w, src_nh);
  int2 dst_pos = (int2)(gid_c * OUT_W + gid_w, gid_nh);

  const sampler_t sampler =
      CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;

  CL_DTYPE4 pixel = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, src_pos);
  pixel.w = 0;

  WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, dst_pos, pixel);
}
/*
 * expend_c4: one work-item per output pixel (channel block, width, fused N*H).
 * The source pixel is found by dividing the output width/height/batch indices
 * by w_times/h_times/n_times; it is always read from the first channel block
 * of the input and copied to the output image with all four lanes intact.
 * NOTE(review): the "_c4" suffix presumably means a four-channel input -
 * confirm against the host-side kernel selection.
 */
__kernel void expend_c4(__private const int OUT_C,
                        __private const int OUT_W,
                        __private const int OUT_NH,
                        __private const int IN_C,
                        __private const int IN_W,
                        __private const int IN_NH,
                        __private const int input_width,  /* of one block */
                        __private const int input_height, /* of one block */
                        __private const int output_width,
                        __private const int output_height,
                        __read_only image2d_t input,
                        __write_only image2d_t output,
                        __private const int n_times,
                        __private const int c_times,
                        __private const int h_times,
                        __private const int w_times) {
  const int gid_c = get_global_id(0);
  const int gid_w = get_global_id(1);
  const int gid_nh = get_global_id(2);

  /* Work-items outside the output extent have nothing to write. */
  if (!(gid_c < OUT_C && gid_w < OUT_W && gid_nh < OUT_NH)) {
    return;
  }

  /* Split the fused N*H index, then scale each coordinate back to the input. */
  const int batch_out = gid_nh / output_height;
  const int row_out = gid_nh % output_height;
  const int src_w = gid_w / w_times;
  const int src_h = row_out / h_times;
  const int src_n = batch_out / n_times;
  const int src_nh = src_n * input_height + src_h;

  int2 src_pos = (int2)(src_w, src_nh);
  int2 dst_pos = (int2)(gid_c * OUT_W + gid_w, gid_nh);

  const sampler_t sampler =
      CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;

  CL_DTYPE4 pixel = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, src_pos);

  WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, dst_pos, pixel);
}
/*
 * expend_cn: one work-item per output pixel (channel block, width, fused N*H).
 * Unlike the expend_c1..c4 kernels, the source pixel is read from the input
 * channel block with the same index as the output channel block. Width,
 * height and batch indices are divided by w_times/h_times/n_times to locate
 * the source, and the pixel is copied unchanged.
 */
__kernel void expend_cn(__private const int OUT_C,
                        __private const int OUT_W,
                        __private const int OUT_NH,
                        __private const int IN_C,
                        __private const int IN_W,
                        __private const int IN_NH,
                        __private const int input_width,  /* of one block */
                        __private const int input_height, /* of one block */
                        __private const int output_width,
                        __private const int output_height,
                        __read_only image2d_t input,
                        __write_only image2d_t output,
                        __private const int n_times,
                        __private const int c_times,
                        __private const int h_times,
                        __private const int w_times) {
  const int gid_c = get_global_id(0);
  const int gid_w = get_global_id(1);
  const int gid_nh = get_global_id(2);

  /* Work-items outside the output extent have nothing to write. */
  if (!(gid_c < OUT_C && gid_w < OUT_W && gid_nh < OUT_NH)) {
    return;
  }

  /* Split the fused N*H index, then scale each coordinate back to the input. */
  const int batch_out = gid_nh / output_height;
  const int row_out = gid_nh % output_height;
  const int src_c = gid_c;
  const int src_w = gid_w / w_times;
  const int src_h = row_out / h_times;
  const int src_n = batch_out / n_times;
  const int src_nh = src_n * input_height + src_h;

  int2 src_pos = (int2)(src_c * IN_W + src_w, src_nh);
  int2 dst_pos = (int2)(gid_c * OUT_W + gid_w, gid_nh);

  const sampler_t sampler =
      CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;

  CL_DTYPE4 pixel = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, src_pos);

  WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, dst_pos, pixel);
}
\ No newline at end of file
......@@ -63,7 +63,10 @@ __kernel void grid_sampler(__read_only image2d_t input,
if (x0 + 1 < 0 || x0 + 1 > out_width - 1 || y0 + 1 < 0 || y0 + 1 > out_height - 1){
input3 = (CL_DTYPE4)(0.0);
}
CL_DTYPE4 out_val = input0 * xe * ye + input1 * xs * ye + input2 * xe * ys + input3 * xs * ys;
CL_DTYPE4 out_val = input0 * (CL_DTYPE4)(xe) * (CL_DTYPE4)(ye) +
input1 * (CL_DTYPE4)(xs) * (CL_DTYPE4)(ye) +
input2 * (CL_DTYPE4)(xe) * (CL_DTYPE4)(ys) +
input3 * (CL_DTYPE4)(xs) * (CL_DTYPE4)(ys);
WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, outpoints, out_val);
// y
......@@ -97,7 +100,10 @@ __kernel void grid_sampler(__read_only image2d_t input,
input3 = (CL_DTYPE4)(0.0);
}
out_val = input0 * xe * ye + input1 * xs * ye + input2 * xe * ys + input3 * xs * ys;
out_val = input0 * (CL_DTYPE4)(xe) * (CL_DTYPE4)(ye) +
input1 * (CL_DTYPE4)(xs) * (CL_DTYPE4)(ye) +
input2 * (CL_DTYPE4)(xe) * (CL_DTYPE4)(ys) +
input3 * (CL_DTYPE4)(xs) * (CL_DTYPE4)(ys);
WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, (int2)(outpoints.x, outpoints.y + 1), out_val);
// z
......@@ -130,7 +136,10 @@ __kernel void grid_sampler(__read_only image2d_t input,
if (x0 + 1 < 0 || x0 + 1 > out_width - 1 || y0 + 1 < 0 || y0 + 1 > out_height - 1){
input3 = (CL_DTYPE4)(0.0);
}
out_val = input0 * xe * ye + input1 * xs * ye + input2 * xe * ys + input3 * xs * ys;
out_val = input0 * (CL_DTYPE4)(xe) * (CL_DTYPE4)(ye) +
input1 * (CL_DTYPE4)(xs) * (CL_DTYPE4)(ye) +
input2 * (CL_DTYPE4)(xe) * (CL_DTYPE4)(ys) +
input3 * (CL_DTYPE4)(xs) * (CL_DTYPE4)(ys);
WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, (int2)(outpoints.x, outpoints.y + 2), out_val);
// w
......@@ -163,6 +172,9 @@ __kernel void grid_sampler(__read_only image2d_t input,
if (x0 + 1 < 0 || x0 + 1 > out_width - 1 || y0 + 1 < 0 || y0 + 1 > out_height - 1){
input3 = (CL_DTYPE4)(0.0);
}
out_val = input0 * xe * ye + input1 * xs * ye + input2 * xe * ys + input3 * xs * ys;
out_val = input0 * (CL_DTYPE4)(xe) * (CL_DTYPE4)(ye) +
input1 * (CL_DTYPE4)(xs) * (CL_DTYPE4)(ye) +
input2 * (CL_DTYPE4)(xe) * (CL_DTYPE4)(ys) +
input3 * (CL_DTYPE4)(xs) * (CL_DTYPE4)(ys);
WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, (int2)(outpoints.x, outpoints.y + 3), out_val);
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cl_common.h>
/*
 * pixel_shuffle: one work-item per output pixel (channel block of 4, width,
 * fused N*H). For each of the four output channels in the block, the source
 * channel is
 *   out_c * upscale_factor^2 + (out_h % upscale_factor) * upscale_factor
 *                            + (out_w % upscale_factor)
 * and the source position is (out_h / upscale_factor, out_w / upscale_factor).
 * The matching lane of the source pixel is copied into the corresponding lane
 * of the result, which is then written to the output image.
 * No bounds check is made on the global ids.
 */
__kernel void pixel_shuffle(__read_only image2d_t input_image,
                            __write_only image2d_t output_image,
                            __private const int in_N,
                            __private const int in_C,
                            __private const int in_H,
                            __private const int in_W,
                            __private const int out_N,
                            __private const int out_C,
                            __private const int out_H,
                            __private const int out_W,
                            __private const int upscale_factor) {
  const int blk_c = get_global_id(0);
  const int col = get_global_id(1);
  const int fused_nh = get_global_id(2);

  const int row = fused_nh % out_H;
  const int batch = fused_nh / out_H;

  const sampler_t sampler =
      CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;

  /* Source row/column and the fused N*H row inside the input image. */
  const int src_h = row / upscale_factor;
  const int src_w = col / upscale_factor;
  const int src_nh = batch * in_H + src_h;

  /* Position inside the upscale_factor x upscale_factor cell; shared by all
     four lanes. */
  const int cell_offset =
      (row % upscale_factor) * upscale_factor + (col % upscale_factor);

  CL_DTYPE4 res;
  CL_DTYPE4 src;
  int src_c;
  int2 src_pos;
  src_pos.y = src_nh;

  /* lane x */
  src_c = (blk_c * 4 + 0) * upscale_factor * upscale_factor + cell_offset;
  src_pos.x = (src_c / 4) * in_W + src_w;
  src = READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, src_pos);
  switch (src_c % 4) {
    case 0: res.x = src.x; break;
    case 1: res.x = src.y; break;
    case 2: res.x = src.z; break;
    case 3: res.x = src.w; break;
    default: break;
  }

  /* lane y */
  src_c = (blk_c * 4 + 1) * upscale_factor * upscale_factor + cell_offset;
  src_pos.x = (src_c / 4) * in_W + src_w;
  src = READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, src_pos);
  switch (src_c % 4) {
    case 0: res.y = src.x; break;
    case 1: res.y = src.y; break;
    case 2: res.y = src.z; break;
    case 3: res.y = src.w; break;
    default: break;
  }

  /* lane z */
  src_c = (blk_c * 4 + 2) * upscale_factor * upscale_factor + cell_offset;
  src_pos.x = (src_c / 4) * in_W + src_w;
  src = READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, src_pos);
  switch (src_c % 4) {
    case 0: res.z = src.x; break;
    case 1: res.z = src.y; break;
    case 2: res.z = src.z; break;
    case 3: res.z = src.w; break;
    default: break;
  }

  /* lane w */
  src_c = (blk_c * 4 + 3) * upscale_factor * upscale_factor + cell_offset;
  src_pos.x = (src_c / 4) * in_W + src_w;
  src = READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, src_pos);
  switch (src_c % 4) {
    case 0: res.w = src.x; break;
    case 1: res.w = src.y; break;
    case 2: res.w = src.z; break;
    case 3: res.w = src.w; break;
    default: break;
  }

  int2 dst_pos;
  dst_pos.x = blk_c * out_W + col;
  dst_pos.y = fused_nh;
  WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, dst_pos, res);
}
......@@ -191,6 +191,9 @@ bool CLRuntime::InitializeDevice() {
}
return t_str;
};
const std::string device_version = device_->getInfo<CL_DEVICE_VERSION>();
LOG(INFO) << "device_version:" << device_version;
LOG(INFO) << "device_type:" << device_type_to_str(device_type);
device_info_["CL_DEVICE_TYPE"] = device_type;
......@@ -317,6 +320,8 @@ std::map<std::string, size_t>& CLRuntime::GetDeviceInfo() {
return device_info_;
}
// Returns a mutable reference to the GPU type stored on this runtime.
GpuType& CLRuntime::GetGpuType() {
  return this->gpu_type_;
}
void CLRuntime::GetAdrenoContextProperties(
std::vector<cl_context_properties>* properties,
GPUPerfMode gpu_perf_mode,
......@@ -365,5 +370,26 @@ void CLRuntime::GetAdrenoContextProperties(
properties->push_back(0);
}
// Returns the execution time of `event` (COMMAND_END - COMMAND_START),
// converted from nanoseconds to milliseconds. Blocks until the command
// queue has finished before querying the profiling counters.
double CLRuntime::GetCommandTime(const cl::Event& event) {
  command_queue().finish();
  const auto begin_ns = event.getProfilingInfo<CL_PROFILING_COMMAND_START>();
  const auto end_ns = event.getProfilingInfo<CL_PROFILING_COMMAND_END>();
  const auto elapsed_ns = end_ns - begin_ns;
  return elapsed_ns / 1000000.0;
}
// Returns the time between `event` being queued and starting to execute
// (COMMAND_START - COMMAND_QUEUED), divided by 1e6. Blocks until the
// command queue has finished before querying the profiling counters.
double CLRuntime::GetQueuedTime(const cl::Event& event) {
  command_queue().finish();
  const auto start_stamp = event.getProfilingInfo<CL_PROFILING_COMMAND_START>();
  const auto queued_stamp =
      event.getProfilingInfo<CL_PROFILING_COMMAND_QUEUED>();
  const auto waited = start_stamp - queued_stamp;
  return waited / 1000000.0;
}
// Returns the time between `event` being submitted to the device and starting
// to execute (COMMAND_START - COMMAND_SUBMIT), divided by 1e6. Blocks until
// the command queue has finished before querying the profiling counters.
double CLRuntime::GetSubmitTime(const cl::Event& event) {
  command_queue().finish();
  const auto start_stamp = event.getProfilingInfo<CL_PROFILING_COMMAND_START>();
  const auto submit_stamp =
      event.getProfilingInfo<CL_PROFILING_COMMAND_SUBMIT>();
  const auto waited = start_stamp - submit_stamp;
  return waited / 1000000.0;
}
} // namespace lite
} // namespace paddle
......@@ -93,6 +93,14 @@ class CLRuntime {
std::map<std::string, size_t>& GetDeviceInfo();
GpuType& GetGpuType();
double GetCommandTime(const cl::Event& event);
double GetQueuedTime(const cl::Event& event);
double GetSubmitTime(const cl::Event& event);
private:
CLRuntime() { Init(); }
......
......@@ -45,5 +45,18 @@ const char* opencl_error_to_str(cl_int error);
#else
#define CL_CHECK_FATAL(err_code__)
#endif
// EnqueueNDRangeKernel(context, kernel, gws_offset, gws, lws,
//                      event_wait_list, event)
// Forwards to enqueueNDRangeKernel on the command queue obtained from
// `context.cl_context()`. Both variants take the same argument list so call
// sites do not need their own #ifdef.
#ifdef LITE_WITH_PROFILE
// Profiling build: the address of `event` is passed to the enqueue call.
#define EnqueueNDRangeKernel( \
context, kernel, gws_offset, gws, lws, event_wait_list, event) \
context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( \
kernel, gws_offset, gws, lws, event_wait_list, &event)
#else
// Non-profiling build: `event` is ignored and nullptr is passed instead.
#define EnqueueNDRangeKernel( \
context, kernel, gws_offset, gws, lws, event_wait_list, event) \
context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( \
kernel, gws_offset, gws, lws, event_wait_list, nullptr)
#endif
} // namespace lite
} // namespace paddle
......@@ -62,6 +62,18 @@ class KernelBase {
profiler_ = profiler;
profile_id_ = id;
}
// Default profiler hook: records a placeholder kernel name in `ch`. Kernels
// that know which routine they ran override this to report the real name.
// In OpenCL builds the kernel's `event_` is also handed to the profiler.
virtual void SetProfileRuntimeKernelInfo(
    paddle::lite::profile::OpCharacter* ch) {
  ch->kernel_func_name = std::string("NotImpl");
  // `event_` is only declared when LITE_WITH_OPENCL is defined, so the copy
  // must be guarded by the same macro (it was LITE_WITH_ARM before, which
  // breaks ARM builds without OpenCL).
#ifdef LITE_WITH_OPENCL
  ch->cl_event = event_;
#endif
}
virtual void SetIsKernelTest(bool is_kernel_test) {
is_kernel_test_ = is_kernel_test;
}
#endif
void Launch() {
......@@ -86,14 +98,24 @@ class KernelBase {
#if defined(LITE_WITH_MLU)
WorkSpace::Global_MLU().AllocReset();
#endif
#ifdef LITE_WITH_PROFILE
if (!is_kernel_test_) {
profiler_->StopTiming(profile::Type::kCreate, profile_id_, ctx_.get());
profiler_->StartTiming(profile::Type::kDispatch, profile_id_, ctx_.get());
}
Run();
#ifdef LITE_WITH_OPENCL
CLRuntime::Global()->command_queue().finish();
#endif
if (is_first_epoch_for_profiler_ && (!is_kernel_test_)) {
SetProfileRuntimeKernelInfo(profiler_->GetOpCharacter(profile_id_));
is_first_epoch_for_profiler_ = false;
}
if (!is_kernel_test_) {
profiler_->StopTiming(profile::Type::kDispatch, profile_id_, ctx_.get());
}
#else
Run();
#endif
......@@ -185,6 +207,11 @@ class KernelBase {
#ifdef LITE_WITH_PROFILE
profile::Profiler* profiler_{nullptr};
int profile_id_{-1};
bool is_first_epoch_for_profiler_{true};
bool is_kernel_test_{true};
#ifdef LITE_WITH_OPENCL
cl::Event event_;
#endif
#endif
};
......
......@@ -28,6 +28,7 @@ lite_cc_library(mir_passes
elimination/identity_scale_eliminate_pass.cc
elimination/identity_dropout_eliminate_pass.cc
elimination/elementwise_mul_constant_eliminate_pass.cc
elimination/remove_tf_redundant_ops_pass.cc
static_kernel_pick_pass.cc
variable_place_inference_pass.cc
type_target_cast_pass.cc
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/elimination/remove_tf_redundant_ops_pass.h"
#include <unordered_set>
#include "lite/core/mir/graph_visualize_pass.h"
#include "lite/core/mir/pass.h"
#include "lite/core/mir/pass_registry.h"
#include "lite/core/mir/pattern_matcher.h"
#include "lite/model_parser/cpp/var_desc.h"
namespace paddle {
namespace lite {
namespace mir {
void RemoveTFRedundantOpsPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
RemoveSqueeze2Reshape2Pattern(graph);
RemoveReshape2Pattern(graph);
}
// Removes a `reshape2` statement that sits between `softmax` and `fetch` when
// it does not change the tensor dims, and rewires `fetch` to read the softmax
// output directly.
//
// The pattern matches only when all of the following hold:
//   * softmax, reshape2 and fetch statements all exist (the last of each in
//     topological order is used),
//   * softmax's first output arg is reshape2's first input arg,
//   * the dims stored in the scope for both outputs are equal,
//   * reshape2's output arg name equals the name of fetch's first input arg.
void RemoveTFRedundantOpsPass::RemoveReshape2Pattern(
    const std::unique_ptr<SSAGraph>& graph) {
  Node* softmax_node{nullptr};
  Node* reshape2_node{nullptr};
  Node* fetch_node{nullptr};
  std::string reshape2_out_arg_name;
  std::string fetch_in_arg_name;
  DDim softmax_out_dims;
  DDim reshape2_out_dims;

  // Remember the last softmax / reshape2 / fetch statement encountered.
  for (auto& op_node : graph->StmtTopologicalOrder()) {
    if (op_node->AsStmt().picked_kernel().op_type() == "softmax") {
      softmax_node = op_node;
    } else if (op_node->AsStmt().picked_kernel().op_type() == "reshape2") {
      reshape2_node = op_node;
    } else if (op_node->AsStmt().picked_kernel().op_type() == "fetch") {
      fetch_node = op_node;
      fetch_in_arg_name = fetch_node->inlinks.front()->AsArg().name;
    }
  }

  if (softmax_node == nullptr || reshape2_node == nullptr) {
    return;
  }

  // Get out tensor dims of softmax, reshape2
  auto* scope = softmax_node->AsStmt().op()->scope();
  auto softmax_out_arg_name = softmax_node->outlinks.front()->AsArg().name;
  softmax_out_dims =
      scope->FindVar(softmax_out_arg_name)->Get<lite::Tensor>().dims();

  for (auto out_node : reshape2_node->outlinks) {
    if (out_node->IsArg() && out_node->outlinks.size() != 0) {
      // NOTE(review): the name is taken from outlinks.front() rather than
      // from `out_node`; kept as-is, but confirm this is intended when the
      // consumed output is not the first outlink.
      reshape2_out_arg_name = reshape2_node->outlinks.front()->AsArg().name;
      reshape2_out_dims =
          scope->FindVar(reshape2_out_arg_name)->Get<lite::Tensor>().dims();
    }
  }

  // `fetch_node` is checked explicitly: without a fetch statement both names
  // can be empty and compare equal, which would lead to a null dereference
  // in the rewiring below.
  const bool found =
      fetch_node != nullptr && softmax_out_dims == reshape2_out_dims &&
      softmax_node->outlinks.front() == reshape2_node->inlinks.front() &&
      reshape2_out_arg_name == fetch_in_arg_name;

  VLOG(3) << "reshape2_out_dims:" << reshape2_out_dims;
  VLOG(3) << "softmax_out_dims:" << softmax_out_dims;
  // Logged after the match is evaluated so the value is meaningful.
  VLOG(3) << "found:" << found;

  if (found) {
    // link out_arg to op
    IR_NODE_LINK_TO(softmax_node->outlinks.front(), fetch_node);

    // collect nodes to safe remove: reshape2 itself plus all of its outputs
    std::unordered_set<const Node*> nodes_to_remove;
    nodes_to_remove.insert(reshape2_node);
    for (auto& out : reshape2_node->outlinks) {
      nodes_to_remove.insert(out);
    }
    GraphSafeRemoveNodes(graph.get(), nodes_to_remove);

    // Point fetch's "X" input at the softmax output.
    auto fetch_op_desc = fetch_node->AsStmt().mutable_op_info();
    fetch_op_desc->SetInput("X",
                            {softmax_node->outlinks.front()->AsArg().name});
  }
  VLOG(5) << "\n" << Visualize(graph.get());
}
void RemoveTFRedundantOpsPass::RemoveSqueeze2Reshape2Pattern(
const std::unique_ptr<SSAGraph>& graph) {
VLOG(5) << Visualize(graph.get());
bool found = false;
// find out_arg->squeeze2
// find out_arg_dims of out_arg
Node* out_arg_node{nullptr};
DDim out_arg_dims;
Node* squeeze2_node{nullptr};
// find squeeze2->reshape2
// find output dims of squeeze2 and reshape2 nodes
DDim squeeze2_out_dims;
Node* reshape2_node{nullptr};
Node* reshape2_out_node{nullptr};
DDim reshape2_out_dims;
// find next inst node of reshape2
Node* next_inst_node_of_reshape2_out{nullptr};
for (auto& node : graph->StmtTopologicalOrder()) {
if (node->AsStmt().picked_kernel().op_type() != "squeeze2") continue;
auto* scope = node->AsStmt().op()->scope();
// find inlinks of squeeze2: out_arg_node
squeeze2_node = node;
auto squeeze2_inlinks = squeeze2_node->inlinks;
VLOG(5) << "squeeze2_inlinks.size():" << squeeze2_inlinks.size();
for (auto& in_link : squeeze2_inlinks) {
if (in_link->IsArg() && squeeze2_inlinks.size() == 1) {
out_arg_node = in_link;
auto* var = scope->FindVar(out_arg_node->AsArg().name);
out_arg_dims = var->Get<lite::Tensor>().dims();
VLOG(5) << "arg name:" << out_arg_node->AsArg().name
<< " dims:" << out_arg_dims;
} else {
// found mutli-input links
continue;
}
}
// find squeeze2->reshape2 pattern
// and output dims of squeeze2, reshape2 nodes
auto squeeze2_outlinks = squeeze2_node->outlinks;
for (auto& squeeze2_out_link : squeeze2_outlinks) {
if (squeeze2_out_link->IsArg() &&
squeeze2_out_link->outlinks.size() != 0) {
auto* squeeze2_out_var =
scope->FindVar(squeeze2_out_link->AsArg().name);
squeeze2_out_dims = squeeze2_out_var->Get<lite::Tensor>().dims();
VLOG(5) << "squeeze2_out_arg.name:" << squeeze2_out_link->AsArg().name
<< " squeeze2_out_dims:" << squeeze2_out_dims
<< " squeeze2_out_link->outlinks.size():"
<< squeeze2_out_link->outlinks.size();
for (auto& out2_link : squeeze2_out_link->outlinks) {
if (out2_link->IsStmt() &&
out2_link->AsStmt().picked_kernel().op_type() == "reshape2") {
reshape2_node = out2_link;
for (auto& reshape2_out_link : reshape2_node->outlinks) {
if (reshape2_out_link->IsArg() &&
reshape2_out_link->outlinks.size() != 0) {
reshape2_out_node = reshape2_out_link;
auto* reshape2_out_var =
scope->FindVar(reshape2_out_link->AsArg().name);
reshape2_out_dims =
reshape2_out_var->Get<lite::Tensor>().dims();
VLOG(5) << "reshape2_out_node:" << reshape2_out_node
<< " reshape2_out_name:"
<< reshape2_out_link->AsArg().name
<< " reshape2_out_dims:" << reshape2_out_dims;
}
}
}
}
}
}
// find next inst node of reshape2
VLOG(5) << "reshape2_out_node->outlinks.size():"
<< reshape2_out_node->outlinks.size()
<< " reshape2_out_node->IsStmt():" << reshape2_out_node->IsStmt();
VLOG(5) << "reshape2_out_node->AsArg().name:"
<< reshape2_out_node->AsArg().name;
if (reshape2_out_node != nullptr &&
reshape2_out_node->outlinks.size() == 1 &&
reshape2_out_node->outlinks.front()->IsStmt()) {
next_inst_node_of_reshape2_out = reshape2_out_node->outlinks.front();
found = true;
break;
VLOG(5)
<< "next_inst_node_of_reshape2_out->picked_kernel().op_type():"
<< next_inst_node_of_reshape2_out->AsStmt().picked_kernel().op_type();
}
VLOG(5) << "==============================";
VLOG(5) << "out_arg_dims:" << out_arg_dims;
VLOG(5) << "squeeze2_out_dims:" << squeeze2_out_dims;
VLOG(5) << "reshape2_out_dims:" << reshape2_out_dims;
VLOG(5) << "==============================";
}
// replace pattern
if (found && out_arg_dims[1] == squeeze2_out_dims[1] &&
out_arg_dims[1] == reshape2_out_dims[1] && out_arg_dims[1] == 1001 &&
out_arg_dims[2] == out_arg_dims[3] && out_arg_dims[2] == 1 &&
next_inst_node_of_reshape2_out->AsStmt().picked_kernel().op_type() ==
"softmax") {
// link out_arg to op
IR_NODE_LINK_TO(out_arg_node, next_inst_node_of_reshape2_out);
// collect nodes to safe remove
std::unordered_set<const Node*> nodes_to_remove;
auto remove_inst_node_and_out_args_node = [&](Node* n) {
nodes_to_remove.insert(n);
for (auto& out : n->outlinks) {
nodes_to_remove.insert(out);
}
};
remove_inst_node_and_out_args_node(squeeze2_node);
remove_inst_node_and_out_args_node(reshape2_node);
GraphSafeRemoveNodes(graph.get(), nodes_to_remove);
auto next_inst_op_desc =
next_inst_node_of_reshape2_out->AsStmt().mutable_op_info();
next_inst_op_desc->SetInput("X", {out_arg_node->AsArg().name});
VLOG(5) << Visualize(graph.get());
}
VLOG(5) << "replace pattern fininshed";
}
} // namespace mir
} // namespace lite
} // namespace paddle
// Registers RemoveTFRedundantOpsPass under the name
// `remove_tf_redundant_ops_pass` and binds it to the OpenCL and ARM targets.
REGISTER_MIR_PASS(remove_tf_redundant_ops_pass,
paddle::lite::mir::RemoveTFRedundantOpsPass)
.BindTargets({TARGET(kOpenCL), TARGET(kARM)});
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <limits>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "lite/core/mir/pass.h"
#include "lite/core/tensor.h"
#include "lite/core/types.h"
namespace paddle {
namespace lite {
namespace mir {
/*
 * mir::RemoveTFRedundantOpsPass removes two redundant op patterns found in
 * TensorFlow mobilenetv1/v2 models:
 *   - a squeeze2->reshape2 pair, and
 *   - the last reshape2 op.
 */
class RemoveTFRedundantOpsPass : public mir::StmtPass {
public:
// Pass entry point.
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
// Eliminates the last reshape2 op.
void RemoveReshape2Pattern(const std::unique_ptr<SSAGraph>& graph);
// Eliminates the squeeze2->reshape2 pair.
void RemoveSqueeze2Reshape2Pattern(const std::unique_ptr<SSAGraph>& graph);
};
} // namespace mir
} // namespace lite
} // namespace paddle
......@@ -20,6 +20,7 @@
#include <memory>
#include <numeric>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
......
......@@ -170,6 +170,9 @@ class VariablePlaceInferencePass : public DebugPass {
// If is quantization, infer the Int8 type.
if (type->precision() == PRECISION(kInt8)) {
x_out->AsArg().type = type;
} else if (type->precision() == PRECISION(kFP16) &&
type->target() != TARGET(kOpenCL)) {
x_out->AsArg().type = type;
} else {
PrecisionType tmp_ptype = x_out->AsArg().type->precision();
x_out->AsArg().type = LiteType::GetTensorTy(
......
......@@ -73,6 +73,9 @@ class OpLite : public Registry {
// Indicate whether the Op runs only once or not.
// The base implementation returns false; being virtual, a subclass may
// override it.
virtual bool run_once() const { return false; }
// Returns a copy of the op type string stored in op_type_.
std::string Type() { return op_type_; }
#ifdef LITE_WITH_PROFILE
// Profiling hook: an op may override this to fill `ch` with its runtime
// description. The base implementation intentionally does nothing.
virtual void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter *ch) {}
#endif
// Link the external execution environ to internal context.
bool Attach(const cpp::OpDesc &opdesc, lite::Scope *scope);
......
......@@ -93,6 +93,7 @@ class Optimizer {
"apu_subgraph_pass",
"rknpu_subgraph_pass",
"static_kernel_pick_pass", // pick original kernel from graph
"remove_tf_redundant_ops_pass",
"variable_place_inference_pass", // inference arg/var's
// info(target/precision/layout/device)
// using kernel info
......@@ -158,6 +159,55 @@ class Optimizer {
// Read-only access to the stored exec_scope_ pointer.
const lite::Scope* exec_scope() const { return exec_scope_; }
// Copies the shape recorded in each var desc of the program into the matching
// tensor of the exec scope, so that passes can read input / output tensor dims
// of an op before any kernel has run.
//
// A tensor is resized only when it has no dims yet and its var desc carries a
// non-empty shape. The "feed" and "fetch" vars are skipped, and so are var
// descs that have no variable in the exec scope.
//
// Example: If you have node `Node* softmax_node`,
// you can get dims of output tensor in passes:
//
//   auto* scope = softmax_node->AsStmt().op()->scope();
//   auto softmax_out_arg_name =
//       softmax_node->outlinks.front()->AsArg().name;
//   auto softmax_out_tensor =
//       scope->FindVar(softmax_out_arg_name)->Get<lite::Tensor>();
//   softmax_out_dims = softmax_out_tensor.dims();
void SetVarDescShapeToScopeVar() {
  // Formats a shape as "AxBxC" for logging.
  auto dims_to_str_func =
      [](const std::vector<int64_t>& shape) -> std::string {
    std::string str_res;
    for (size_t i = 0; i < shape.size(); ++i) {
      str_res += std::to_string(shape[i]);
      if (i != shape.size() - 1) {
        str_res += "x";
      }
    }
    return str_res;
  };

  auto* program_desc = program_->program_desc();
  VLOG(5) << "program_desc->BlocksSize():" << program_desc->BlocksSize();

  auto blocks_desc = program_desc->GetBlocks();
  for (size_t bidx = 0; bidx < blocks_desc.size(); ++bidx) {
    auto& block_desc = blocks_desc[bidx];
    auto vars_desc = block_desc.GetVars();
    for (size_t vidx = 0; vidx < vars_desc.size(); ++vidx) {
      auto& var_desc = vars_desc[vidx];
      // Fetch the shape once instead of re-querying it for every use below.
      const std::vector<int64_t> shape = var_desc.GetShape();
      VLOG(5) << var_desc.Name() << " " << dims_to_str_func(shape);
      if (var_desc.Name() == "feed" || var_desc.Name() == "fetch") continue;

      auto* var = program_->exec_scope()->FindVar(var_desc.Name());
      if (var == nullptr) {
        // A var desc without a scope variable has nothing to resize;
        // dereferencing the null result would crash the optimizer.
        VLOG(5) << "var_desc.Name():" << var_desc.Name()
                << " is not found in exec scope, skipped";
        continue;
      }

      auto tensor = var->GetMutable<lite::Tensor>();
      if (tensor->dims().size() == 0 && shape.size() != 0) {
        VLOG(5) << "var_desc.Name():" << var_desc.Name()
                << " shape:" << dims_to_str_func(shape);
        tensor->Resize(shape);
      }
      VLOG(5) << "var_desc.Name():" << var_desc.Name()
              << " shape:" << dims_to_str_func(shape)
              << " tensor:" << tensor->dims();
    }
  }
}
// Generate a new program based on the mir graph.
std::unique_ptr<RuntimeProgram> GenRuntimeProgram() {
auto pass = mir::PassManager::Global().LookUp<mir::GenerateProgramPass>(
......@@ -198,6 +248,7 @@ class Optimizer {
// Specify the passes and run them.
void RunPasses(const std::vector<std::string>& passes) {
SetVarDescShapeToScopeVar();
for (auto& x : passes) {
LOG(INFO) << "== Running pass: " << x;
mir::Pass* pass = mir::PassManager::Global().LookUp(x);
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "lite/core/profile/profiler.h"
#include <iomanip>
#include <map>
#include <string>
#include <utility>
......@@ -23,10 +24,11 @@ namespace profile {
namespace {
// Ordering used as the comparator of the concise-summary map: two
// OpCharacters are equivalent when the concatenation of their kernel name and
// kernel function name is equal.
//
// A chain of `(a1 < a2) || (b1 < b2) || ...` must not be used here: it can
// report both x < y and y < x, which violates the strict weak ordering a
// std::map comparator has to provide.
auto op_comp = [](const OpCharacter& c1, const OpCharacter& c2) {
  // compare for unique key of map
  return (c1.kernel_name + c1.kernel_func_name <
          c2.kernel_name + c2.kernel_func_name);
};
}
} // namespace
std::map<Type, std::string> TypeStr{
{Type::kUnk, "Unknown"},
......@@ -64,22 +66,62 @@ int Profiler::NewTimer(const OpCharacter& ch) {
return units_.size() - 1;
}
// Returns a mutable pointer to the OpCharacter of the timer unit at `index`.
// The index is range-checked against units_ with CHECK_LT.
OpCharacter* Profiler::GetOpCharacter(const size_t index) {
  CHECK_LT(index, units_.size())
      << "The timer index in the profiler is out of range.";
  auto& unit = units_[index];
  return &unit.Character();
}
// Starts the timer of kind `type` that belongs to the unit at `index`,
// forwarding the kernel context to it. The index is range-checked first.
void Profiler::StartTiming(Type type, const int index, KernelContext* ctx) {
  CHECK_LT(index, units_.size())
      << "The timer index in the profiler is out of range.";
  auto* timer = units_[index].Timer(type);
  timer->Start(ctx);
}
float Profiler::StopTiming(Type type, const int index, KernelContext* ctx) {
void Profiler::StopTiming(Type type, const int index, KernelContext* ctx) {
CHECK_LT(index, units_.size())
<< "The timer index in the profiler is out of range.";
return units_[index].Timer(type)->Stop(ctx);
units_[index].Timer(type)->Stop(ctx);
#ifdef LITE_WITH_OPENCL
units_[index].Timer(type)->CLStop(units_[index].character.op_type,
units_[index].character.io_duration,
units_[index].character.cl_event);
#endif
}
int Profiler::GetKernelFuncCalledTimes(const std::string& op_type,
const std::string& kernel_attr,
const std::string& kernel_func_name) {
int count = 0;
for (size_t i = 0; i < units_.size(); ++i) {
if ((units_[i].character.kernel_func_name == kernel_func_name) &&
(units_[i].character.kernel_attr == kernel_attr) &&
(units_[i].character.op_type == op_type)) {
++count;
}
}
return count;
}
float Profiler::GetKernelFuncSummaryGOPs(const std::string& op_type,
const std::string& kernel_attr,
const std::string& kernel_func_name) {
float GOPs = 0;
for (size_t i = 0; i < units_.size(); ++i) {
if ((units_[i].character.kernel_func_name == kernel_func_name) &&
(units_[i].character.kernel_attr == kernel_attr) &&
(units_[i].character.op_type == op_type)) {
GOPs += units_[i].character.macs;
}
}
return GOPs * 1e-9f;
}
std::string Profiler::Summary(Type type, bool concise, size_t w) {
using std::setw;
using std::left;
using std::fixed;
using std::setprecision;
STL::stringstream ss;
std::string title;
// Title.
......@@ -94,14 +136,41 @@ std::string Profiler::Summary(Type type, bool concise, size_t w) {
<< " Profiler Summary: " << name_ << ", Exclude " << w
<< " warm-ups =====" << std::endl;
}
ss << setw(25) << left << "Operator Type"
<< " " << setw(40) << left << "Kernel Name"
<< " " << setw(12) << left << "Remark"
<< " " << setw(12) << left << "Avg (ms)"
<< " " << setw(12) << left << "Min (ms)"
<< " " << setw(12) << left << "Max (ms)"
<< " " << setw(12) << left << "Last (ms)"
<< " " << setw(12) << left << "Percent (%)" << std::endl;
ss << setw(20) << left << "OperatorType"
<< " " << setw(30) << left << "KerneAttr(Place)"
<< " " << setw(24) << left << "KernelFuncName";
if (!concise) {
ss << " " << setw(26) << left << "Remark"
<< " " << setw(15) << left << "InDim"
<< " " << setw(15) << left << "FilterDim"
<< " " << setw(15) << left << "OutDim";
}
ss << " " << setw(7) << left << "Avg(ms)"
<< " " << setw(7) << left << "Min(ms)"
<< " " << setw(7) << left << "Max(ms)";
if (!concise) {
ss << " " << setw(7) << left << "Last(ms)";
}
ss << " " << setw(7) << left << "Avg(%)"
<< " " << setw(7) << left << "GOPs";
if (!concise) {
ss << " " << setw(7) << left << "GOPS";
}
if (concise) {
ss << " " << setw(11) << left << "CalledTimes";
}
#ifdef LITE_WITH_OPENCL
ss << " " << setw(9) << left << "clAvg(ms)"
<< " " << setw(9) << left << "clMin(ms)"
<< " " << setw(9) << left << "clMax(ms)"
<< " " << setw(9) << left << "clAvg(%)";
if (!concise) {
ss << " " << setw(12) << left << "GlobalWorkSize"
<< " " << setw(12) << left << "LocalWorkSize";
}
#endif
ss << std::endl;
// Profile information.
if (concise) {
std::map<OpCharacter, TimeInfo, decltype(op_comp)> summary(op_comp);
......@@ -111,32 +180,75 @@ std::string Profiler::Summary(Type type, bool concise, size_t w) {
ch->second.avg += unit.Timer(type)->LapTimes().Avg(w);
ch->second.min += unit.Timer(type)->LapTimes().Min(w);
ch->second.max += unit.Timer(type)->LapTimes().Max(w);
#ifdef LITE_WITH_OPENCL
ch->second.cl_avg += unit.Timer(type)->CLLapTimes().Avg(w);
ch->second.cl_min += unit.Timer(type)->CLLapTimes().Min(w);
ch->second.cl_max += unit.Timer(type)->CLLapTimes().Max(w);
#endif
} else {
TimeInfo info({unit.Timer(type)->LapTimes().Avg(w),
unit.Timer(type)->LapTimes().Min(w),
unit.Timer(type)->LapTimes().Max(w)});
TimeInfo info;
info.avg = unit.Timer(type)->LapTimes().Avg(w);
info.min = unit.Timer(type)->LapTimes().Min(w);
info.max = unit.Timer(type)->LapTimes().Max(w);
#ifdef LITE_WITH_OPENCL
info.cl_avg = unit.Timer(type)->CLLapTimes().Avg(w);
info.cl_min = unit.Timer(type)->CLLapTimes().Min(w);
info.cl_max = unit.Timer(type)->CLLapTimes().Max(w);
#endif
summary.insert({unit.Character(), info});
}
}
// compute total time
float total = 0.0;
for (const auto& item : summary) {
total += item.second.avg;
}
#ifdef LITE_WITH_OPENCL
float cl_total = 0.0;
for (const auto& item : summary) {
cl_total += item.second.cl_avg;
}
#endif
for (const auto& item : summary) {
float percent = 0;
if (total > 0) {
percent = 100 * (item.second.avg / total);
}
// clang-format off
ss << setw(25) << left << fixed << item.first.op_type \
<< " " << setw(40) << left << fixed << item.first.kernel_name \
<< " " << setw(12) << left << fixed << item.first.remark \
<< " " << setw(12) << left << fixed << item.second.avg \
<< " " << setw(12) << left << fixed << item.second.min \
<< " " << setw(12) << left << fixed << item.second.max \
<< " " << setw(12) << left << fixed << percent << "%" \
<< " " << std::endl;
ss << setw(20) << left << fixed << item.first.op_type
<< " " << setw(30) << left << fixed << item.first.kernel_attr
<< " " << setw(24) << left << fixed << item.first.kernel_func_name
<< " " << setw(7) << left << fixed << setprecision(3)
<< item.second.avg
<< " " << setw(7) << left << fixed << setprecision(3)
<< item.second.min
<< " " << setw(7) << left << fixed << setprecision(3)
<< item.second.max
<< " " << setprecision(2) << percent << "% "
<< " " << setw(7) << left << fixed << setprecision(3)
<< GetKernelFuncSummaryGOPs(item.first.op_type,
item.first.kernel_attr,
item.first.kernel_func_name)
<< " " << setw(11) << left << fixed
<< GetKernelFuncCalledTimes(item.first.op_type,
item.first.kernel_attr,
item.first.kernel_func_name);
#ifdef LITE_WITH_OPENCL
float cl_percent = 0;
if (cl_total > 0) {
cl_percent = 100 * (item.second.cl_avg / cl_total);
}
ss << " " << setw(9) << left << fixed << setprecision(3)
<< item.second.cl_avg
<< " " << setw(9) << left << fixed << setprecision(3)
<< item.second.cl_min
<< " " << setw(9) << left << fixed << setprecision(3)
<< item.second.cl_max
<< " " << left << fixed << setprecision(2) << cl_percent << "% ";
#endif
ss << std::endl;
// clang-format on
}
} else {
......@@ -145,6 +257,13 @@ std::string Profiler::Summary(Type type, bool concise, size_t w) {
const auto& times = unit.Timer(type)->LapTimes();
total += times.Avg(w);
}
#ifdef LITE_WITH_OPENCL
float cl_total = 0.0;
for (auto& unit : units_) {
const auto& cl_times = unit.Timer(type)->CLLapTimes();
cl_total += cl_times.Avg(w);
}
#endif
for (auto& unit : units_) {
const auto& times = unit.Timer(type)->LapTimes();
float run = times.Avg(w);
......@@ -152,17 +271,46 @@ std::string Profiler::Summary(Type type, bool concise, size_t w) {
if (total > 0) {
percent = 100 * (run / total);
}
#ifdef LITE_WITH_OPENCL
const auto& cl_times = unit.Timer(type)->CLLapTimes();
float cl_run = cl_times.Avg(w);
float cl_percent = 0;
if (cl_total > 0) {
cl_percent = 100 * (cl_run / cl_total);
}
#endif
// clang-format off
ss << setw(25) << left << fixed << unit.Character().op_type \
<< " " << setw(40) << left << fixed << unit.Character().kernel_name \
<< " " << setw(12) << left << fixed << unit.Character().remark \
<< " " << setw(12) << left << fixed << times.Avg(w) \
<< " " << setw(12) << left << fixed << times.Min(w) \
<< " " << setw(12) << left << fixed << times.Max(w) \
<< " " << setw(12) << left << fixed << times.Last(w) \
<< " " << setw(12) << left << fixed << percent << "%" \
<< std::endl;
// clang-format on
ss << setw(20) << left << fixed << unit.Character().op_type
<< " " << setw(30) << left << fixed << unit.Character().kernel_attr
<< " " << setw(24) << left << fixed
<< unit.Character().kernel_func_name
<< " " << setw(26) << left << fixed << unit.Character().remark
<< " " << setw(15) << left << fixed << unit.Character().input_shape
<< " " << setw(15) << left << fixed << unit.Character().filter_shape
<< " " << setw(15) << left << fixed << unit.Character().output_shape
<< " " << setw(7) << left << fixed << setprecision(3) << times.Avg(w)
<< " " << setw(7) << left << fixed << setprecision(3) << times.Min(w)
<< " " << setw(7) << left << fixed << setprecision(3) << times.Max(w)
<< " " << setw(7) << left << fixed << setprecision(3) << times.Last(w)
<< " " << left << setprecision(2) << percent << "% "
<< " " << setw(7) << left << fixed << setprecision(3)
<< 1e-9f * unit.Character().macs
<< " " << setw(7) << left << fixed << setprecision(2)
<< 1e-6f * unit.Character().macs / times.Avg(w);
// clang-format on
#ifdef LITE_WITH_OPENCL
ss << " " << setw(9) << left << fixed << setprecision(3)
<< cl_times.Avg(w) << " " << setw(9) << left << fixed
<< setprecision(3) << cl_times.Min(w) << " " << setw(9) << left
<< fixed << setprecision(3) << cl_times.Max(w) << " " << left
<< setprecision(2) << cl_percent << "% "
<< " " << setw(12) << left << fixed
<< unit.Character().global_work_size << " " << setw(12) << left
<< fixed << unit.Character().local_work_size;
#endif
ss << std::endl;
}
}
return ss.str();
......
......@@ -18,6 +18,10 @@
#include <string>
#include <vector>
#include "lite/core/profile/timer.h"
#include "lite/core/tensor.h"
#ifdef LITE_WITH_OPENCL
#include "lite/backends/opencl/cl_include.h"
#endif
namespace paddle {
namespace lite {
......@@ -35,25 +39,83 @@ struct TimeInfo {
float avg;
float min;
float max;
#ifdef LITE_WITH_OPENCL
float cl_avg;
float cl_min;
float cl_max;
#endif
};
// Descriptive record attached to one profiled op / kernel: identification
// strings, shape strings, an operation count and (with OpenCL) the event and
// work sizes of the last enqueued kernel. String fields default to "N/A".
struct OpCharacter {
  // Value-initialized so that copying a default-constructed record never
  // reads an indeterminate enum value.
  TargetType target{};
  // Type-erased pointer to the op this record describes; may stay null.
  void* op_lite{nullptr};
  std::string op_type{std::string("N/A")};
  std::string kernel_name{std::string("N/A")};
  std::string kernel_attr{std::string("N/A")};
  std::string kernel_func_name{std::string("N/A")};
  std::string remark{std::string("N/A")};

  std::string input_shape{"N/A"};
  std::string output_shape{"N/A"};
  std::string filter_shape{"N/A"};

  float macs{0};
  float macs_ps{0};

  float io_duration{0};

#ifdef LITE_WITH_OPENCL
  cl::Event cl_event{};
  std::string global_work_size{"N/A"};
  std::string local_work_size{"N/A"};

  // Formats the first three entries of `range` as "a,b,c".
  std::string NDRangeToStr(const cl::NDRange& range) const {
    std::string range_str{""};
    const size_t range_size = 3;
    for (size_t i = 0; i < range_size /*range.size()*/; ++i) {
      // Verbose-only: this helper may be called for every kernel, so the
      // per-element trace must not be emitted at INFO level.
      VLOG(4) << "range[" << i << "]:" << std::to_string(range[i]);
      range_str += std::to_string(range[i]);
      if (i != range_size - 1) {
        range_str += ",";
      }
    }
    return range_str;
  }
#else
  void* cl_event{nullptr};
#endif

  // Formats `dim` as "AxBxC"; an empty dim yields "NotImpl".
  std::string DimToStr(const paddle::lite::DDimLite& dim) const {
    if (!dim.size()) return "NotImpl";
    std::string dim_str{""};
    for (size_t i = 0; i < dim.size(); ++i) {
      dim_str += std::to_string(dim[i]);
      if (i != dim.size() - 1) {
        dim_str += "x";
      }
    }
    return dim_str;
  }

  // Joins kernel name, kernel function name, remark and the three shape
  // strings with '/'.
  std::string str() const {
    std::string str{""};
    str += kernel_name + "/" + kernel_func_name + "/" + remark + "/" +
           input_shape + "/" + filter_shape + "/" + output_shape;
    return str;
  }
};
// One profiling unit: the OpCharacter describing a kernel plus the timers
// that measure it.
class StatisUnit final {
 public:
  explicit StatisUnit(const OpCharacter& ch);
  // Returns the timer that corresponds to `type`.
  lite::profile::Timer* Timer(Type type);
  const OpCharacter& Character() const { return character; }
  OpCharacter& Character() { return character; }

  // Declared exactly once, and publicly: the profiler reads fields through
  // `units_[i].character` directly. A second declaration of the same member
  // in the protected section would be a redefinition error.
  OpCharacter character;

 protected:
  std::unique_ptr<lite::profile::Timer> create_t;
  std::unique_ptr<lite::profile::Timer> dispatch_t;
};
class Profiler final {
......@@ -62,8 +124,15 @@ class Profiler final {
explicit Profiler(const std::string& name) : name_(name) {}
int NewTimer(const OpCharacter& ch);
void StartTiming(Type type, const int index, KernelContext* ctx);
float StopTiming(Type type, const int index, KernelContext* ctx);
void StopTiming(Type type, const int index, KernelContext* ctx);
std::string Summary(Type type, bool concise = true, size_t warm_up = 10);
int GetKernelFuncCalledTimes(const std::string& op_type,
const std::string& kernel_attr,
const std::string& kernel_func_name);
float GetKernelFuncSummaryGOPs(const std::string& op_type,
const std::string& kernel_attr,
const std::string& kernel_func_name);
OpCharacter* GetOpCharacter(const size_t index);
private:
std::string name_{std::string("N/A")};
......
......@@ -15,6 +15,7 @@
#pragma once
#include <algorithm>
#include <chrono> // NOLINT
#include <string>
#include <vector>
#ifdef LITE_WITH_CUDA
#include "lite/backends/cuda/cuda_utils.h"
......@@ -87,6 +88,22 @@ class Timer {
this->laps_t_.Add(elapse_ms);
return elapse_ms;
}
#ifdef LITE_WITH_OPENCL
float CLStop(const std::string& op_type, float io_duration, cl::Event event) {
float cl_kernel_elapse_ms = 0.0;
if (op_type != "io_copy") {
cl_kernel_elapse_ms =
CLRuntime::Global()->CLRuntime::GetCommandTime(event);
} else {
cl_kernel_elapse_ms = io_duration;
}
this->cl_laps_t_.Add(cl_kernel_elapse_ms);
return cl_kernel_elapse_ms;
}
// Read-only access to the laps recorded by CLStop().
const TimeList<float>& CLLapTimes() const { return cl_laps_t_; }
#endif
// Context-taking overloads. In this class the context is ignored and the
// calls forward to the argument-less Start() / Stop(); both are virtual.
virtual void Start(KernelContext* ctx) { return Start(); }
virtual float Stop(KernelContext* ctx) { return Stop(); }
// Average of the recorded laps, as computed by laps_t_.Avg().
float AvgLapTimeMs() const { return laps_t_.Avg(); }
......@@ -94,6 +111,9 @@ class Timer {
protected:
TimeList<float> laps_t_;
#ifdef LITE_WITH_OPENCL
TimeList<float> cl_laps_t_;
#endif
private:
std::chrono::time_point<std::chrono::system_clock> t_start_, t_stop_;
......
......@@ -169,7 +169,7 @@ void RuntimeProgram::Run() {
#endif // LITE_WITH_PRECISION_PROFILE
}
#ifdef LITE_WITH_PROFILE
LOG(INFO) << "\n" << profiler_.Summary(profile::Type::kDispatch, false, 0);
LOG(INFO) << "\n" << profiler_.Summary(profile::Type::kDispatch, false, 1);
#endif
#ifdef LITE_WITH_PRECISION_PROFILE
LOG(INFO) << "\n" << precision_profiler_summary;
......@@ -299,6 +299,14 @@ void Instruction::Run() {
op_->InferShape();
kernel_->Launch();
has_run_ = true;
#ifdef LITE_WITH_PROFILE
if (first_epoch_for_profiler_) {
kernel_->SetIsKernelTest(false);
SetProfileRuntimeOpInfo(profiler_->GetOpCharacter(profile_id_));
first_epoch_for_profiler_ = false;
}
#endif
}
STL::ostream& operator<<(STL::ostream& os, const Instruction& other) {
......
......@@ -23,6 +23,9 @@
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
#include "lite/model_parser/cpp/program_desc.h"
#ifdef LITE_WITH_PROFILE
#include "lite/core/profile/profiler.h"
#endif
namespace paddle {
namespace lite {
......@@ -64,6 +67,8 @@ struct Program {
// Returns the stored exec_scope_ pointer.
lite::Scope* exec_scope() { return exec_scope_; }
// Returns the raw pointer held by scope_.
lite::Scope* scope() { return scope_.get(); }
// Returns a pointer to the program description member desc_.
cpp::ProgramDesc* program_desc() { return &desc_; }
// Read-only access to the map from variable name to precision type.
const std::unordered_map<std::string, PrecisionType>& var_data_type() const {
return var_data_type_;
}
......@@ -125,13 +130,22 @@ struct Instruction {
profiler_ = profiler;
if (op_->Type() != "feed" && op_->Type() != "fetch") {
profile::OpCharacter ch;
ch.op_lite = static_cast<void*>(const_cast<paddle::lite::OpLite*>(op()));
ch.target = kernel()->target();
ch.op_type = op_->Type();
ch.kernel_name = kernel()->name();
ch.kernel_attr = kernel()->name().substr(ch.op_type.size() + 1,
kernel()->name().size());
// append `ch.kernel_func_name` in StopTiming
profile_id_ = profiler->NewTimer(ch);
kernel_->SetProfiler(profiler_, profile_id_);
}
}
// Lets the op recorded in `ch->op_lite` fill `ch` with its runtime profiling
// information. Does nothing when `ch` is null or carries no op pointer
// (OpCharacter::op_lite defaults to nullptr).
void SetProfileRuntimeOpInfo(paddle::lite::profile::OpCharacter* ch) {
  if (ch == nullptr || ch->op_lite == nullptr) {
    return;
  }
  auto* op_lite = static_cast<paddle::lite::OpLite*>(ch->op_lite);
  op_lite->GetOpRuntimeInfo(ch);
}
#endif
private:
......@@ -144,6 +158,7 @@ struct Instruction {
#ifdef LITE_WITH_PROFILE
profile::Profiler* profiler_;
int profile_id_{-1};
bool first_epoch_for_profiler_{true};
#endif // LITE_WITH_PROFILE
};
......
......@@ -35,6 +35,9 @@ void ArgmaxCompute::Run() {
}
lite::arm::math::argmax_func(input, axis, output);
#ifdef LITE_WITH_PROFILE
kernel_func_name_ = "argmax_func";
#endif
return;
}
......
......@@ -16,6 +16,10 @@
#include <algorithm>
#include "lite/core/kernel.h"
#include "lite/operators/argmax_op.h"
#ifdef LITE_WITH_PROFILE
#include <string>
#include "lite/core/profile/profiler.h"
#endif
namespace paddle {
namespace lite {
......@@ -29,6 +33,14 @@ class ArgmaxCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
void Run() override;
virtual ~ArgmaxCompute() = default;
#ifdef LITE_WITH_PROFILE
// Profiling hook: copies this kernel's kernel_func_name_ into the profiler
// record `ch`.
virtual void SetProfileRuntimeKernelInfo(
paddle::lite::profile::OpCharacter* ch) {
ch->kernel_func_name = kernel_func_name_;
}
std::string kernel_func_name_{"NotImplForArgmax"};
#endif
};
} // namespace arm
......
......@@ -33,6 +33,17 @@ void CalibComputeFp32ToInt8<DLType>::Run() {
din, dout, scale.data(), 1, 1, param.input->numel());
}
// Converts the int64 input tensor element-wise into the int32 output tensor.
// Values outside the int32 range are narrowed by the cast; no scaling is
// applied.
template <DataLayoutType DLType>
void CalibComputeInt64ToInt32<DLType>::Run() {
  auto& param = this->template Param<operators::CalibParam>();
  const auto* din = param.input->template data<int64_t>();
  auto* dout = param.output->template mutable_data<int32_t>();
  // Use a 64-bit index: an `int` counter overflows before reaching numel()
  // for tensors with more than INT_MAX elements.
  const int64_t num = param.input->numel();
  for (int64_t i = 0; i < num; ++i) {
    dout[i] = static_cast<int32_t>(din[i]);
  }
}
template <DataLayoutType DLType>
void CalibComputeInt8ToFp32<DLType>::Run() {
auto& param = this->template Param<operators::CalibParam>();
......@@ -105,6 +116,23 @@ REGISTER_LITE_KERNEL(
DATALAYOUT(kNHWC))})
.Finalize();
// Registers CalibComputeInt64ToInt32<NCHW> for the `calib` op under the alias
// `int64_to_int32`: "Input" is bound as an ARM / int64 / NCHW tensor and "Out"
// as an ARM / int32 / NCHW tensor.
REGISTER_LITE_KERNEL(
calib,
kARM,
kInt64,
kNCHW,
paddle::lite::kernels::arm::CalibComputeInt64ToInt32<DATALAYOUT(kNCHW)>,
int64_to_int32)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kARM),
PRECISION(kInt64),
DATALAYOUT(kNCHW))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kARM),
PRECISION(kInt32),
DATALAYOUT(kNCHW))})
.Finalize();
REGISTER_LITE_KERNEL(
calib_once,
kARM,
......@@ -161,3 +189,20 @@ REGISTER_LITE_KERNEL(
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.Finalize();
// Same kernel class and tensor bindings as the `calib` int64_to_int32
// registration, registered here for the `calib_once` op.
REGISTER_LITE_KERNEL(
calib_once,
kARM,
kInt64,
kNCHW,
paddle::lite::kernels::arm::CalibComputeInt64ToInt32<DATALAYOUT(kNCHW)>,
int64_to_int32)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kARM),
PRECISION(kInt64),
DATALAYOUT(kNCHW))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kARM),
PRECISION(kInt32),
DATALAYOUT(kNCHW))})
.Finalize();
......@@ -34,6 +34,19 @@ class CalibComputeFp32ToInt8
private:
};
// ARM calib kernel declared with int64 precision for layout `DLType`;
// Run() is defined out of line.
template <DataLayoutType DLType>
class CalibComputeInt64ToInt32
    : public KernelLite<TARGET(kARM), PRECISION(kInt64), DLType> {
 public:
  using param_t = operators::CalibParam;

  ~CalibComputeInt64ToInt32() override = default;

  void Run() override;
};
template <DataLayoutType DLType>
class CalibComputeInt8ToFp32
: public KernelLite<TARGET(kARM), PRECISION(kInt8), DLType> {
......
......@@ -62,8 +62,19 @@ void CastCompute::Run() {
int32_t* out_data = param.Out->mutable_data<int32_t>();
std::transform(
x_data_begin, x_data_end, out_data, TransOp<int64_t, int32_t>);
} else if (param.in_dtype == 0 && param.out_dtype == 5) { // bool->fp32
const bool* x_data_begin = param.X->data<bool>();
const bool* x_data_end = x_data_begin + param.X->numel();
float* out_data = param.Out->mutable_data<float>();
std::transform(x_data_begin, x_data_end, out_data, TransOp<bool, float>);
} else if (param.in_dtype == 3 && param.out_dtype == 5) { // int64->fp32
const int64_t* x_data_begin = param.X->data<int64_t>();
const int64_t* x_data_end = x_data_begin + param.X->numel();
float* out_data = param.Out->mutable_data<float>();
std::transform(x_data_begin, x_data_end, out_data, TransOp<int64_t, float>);
} else {
LOG(FATAL) << "other has not been implemented";
LOG(FATAL) << "other has not been implemented transform with dtype"
<< param.in_dtype << " X, dtype" << param.out_dtype << " Out";
}
}
......
......@@ -15,6 +15,9 @@
#pragma once
#include "lite/backends/arm/math/funcs.h"
#include "lite/core/kernel.h"
#ifdef LITE_WITH_PROFILE
#include "lite/core/profile/profiler.h"
#endif
namespace paddle {
namespace lite {
......@@ -36,6 +39,13 @@ class ConvCompute : public KernelLite<TARGET(kARM), Ptype> {
impl_->Run();
}
#ifdef LITE_WITH_PROFILE
// Profiling hook: forwards to the selected convolution implementation so it
// can fill `ch`. Does nothing while no implementation has been created yet
// (the destructor likewise treats impl_ as possibly null).
virtual void SetProfileRuntimeKernelInfo(
    paddle::lite::profile::OpCharacter* ch) {
  if (impl_ == nullptr) {
    return;
  }
  impl_->SetProfileRuntimeKernelInfo(ch);
}
#endif
~ConvCompute() {
if (impl_ != nullptr) {
delete impl_;
......
......@@ -50,6 +50,9 @@ void DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
flag_trans_weights_ = true;
}
impl_ = lite::arm::math::conv_depthwise_3x3_fp32;
#ifdef LITE_WITH_PROFILE
kernel_func_name_ = "conv_depthwise_3x3_fp32";
#endif
} else if (kw == 5) {
// VLOG(5) << "invoke 5x5 dw conv fp32";
auto strides = param.strides;
......@@ -67,6 +70,9 @@ void DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
w_data_in, w_data, oc, 1, cblock, kh * kw);
flag_trans_weights_ = true;
impl_ = lite::arm::math::conv_depthwise_5x5_fp32;
#ifdef LITE_WITH_PROFILE
kernel_func_name_ = "conv_depthwise_5x5_fp32";
#endif
} else {
LOG(FATAL)
<< "5x5 depthwise conv only support stride == 1 or stride == 2";
......@@ -103,6 +109,9 @@ void DepthwiseConv<PRECISION(kInt8), PRECISION(kFloat)>::PrepareForRun() {
// trans weights
// VLOG(5) << "invoke 3x3 dw conv int8 kernel fp32 out";
impl_ = lite::arm::math::conv_depthwise_3x3_int8_fp32;
#ifdef LITE_WITH_PROFILE
kernel_func_name_ = "conv_depthwise_3x3_int8_fp32";
#endif
int cround = ROUNDUP(w_dims[0], 8);
weights_.Resize({cround / 8, 1, kh * kw, 8});
auto wptr = param.filter->data<int8_t>();
......@@ -113,6 +122,9 @@ void DepthwiseConv<PRECISION(kInt8), PRECISION(kFloat)>::PrepareForRun() {
// trans weights
// VLOG(5) << "invoke 5x5 dw conv int8 kernel fp32 out";
impl_ = lite::arm::math::conv_depthwise_5x5_int8_fp32;
#ifdef LITE_WITH_PROFILE
kernel_func_name_ = "conv_depthwise_5x5_int8_fp32";
#endif
int cround = ROUNDUP(w_dims[0], 8);
weights_.Resize({cround / 8, 1, kh * kw, 8});
auto wptr = param.filter->data<int8_t>();
......@@ -162,6 +174,9 @@ void DepthwiseConv<PRECISION(kInt8), PRECISION(kInt8)>::PrepareForRun() {
// trans weights
// VLOG(5) << "invoke 3x3 dw conv int8 kernel int8 out";
impl_ = lite::arm::math::conv_depthwise_3x3_int8_int8;
#ifdef LITE_WITH_PROFILE
kernel_func_name_ = "conv_depthwise_3x3_int8_int8";
#endif
int cround = ROUNDUP(w_dims[0], 8);
weights_.Resize({cround / 8, 1, kh * kw, 8});
auto wptr = param.filter->data<int8_t>();
......@@ -172,6 +187,9 @@ void DepthwiseConv<PRECISION(kInt8), PRECISION(kInt8)>::PrepareForRun() {
// trans weights
// VLOG(5) << "invoke 5x5 dw conv int8 kernel int8 out";
impl_ = lite::arm::math::conv_depthwise_5x5_int8_int8;
#ifdef LITE_WITH_PROFILE
kernel_func_name_ = "conv_depthwise_5x5_int8_int8";
#endif
int cround = ROUNDUP(w_dims[0], 8);
weights_.Resize({cround / 8, 1, kh * kw, 8});
auto wptr = param.filter->data<int8_t>();
......@@ -183,6 +201,14 @@ void DepthwiseConv<PRECISION(kInt8), PRECISION(kInt8)>::PrepareForRun() {
}
}
#ifdef LITE_WITH_PROFILE
// Profiling hook for the <kFloat, kFloat> specialization: copies the name
// stored in `kernel_func_name_` into the profiler record `ch`.
// PrepareForRun() assigns that name when it picks a routine, e.g.
// "conv_depthwise_3x3_fp32" or "conv_depthwise_5x5_fp32".
template <>
void DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>::
SetProfileRuntimeKernelInfo(paddle::lite::profile::OpCharacter* ch) {
ch->kernel_func_name = kernel_func_name_;
}
#endif
template <>
void DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
auto& param = this->Param<param_t>();
......@@ -225,6 +251,14 @@ void DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
w_scale_.data());
}
#ifdef LITE_WITH_PROFILE
// Profiling hook for the <kInt8, kFloat> specialization: copies the name
// stored in `kernel_func_name_` into the profiler record `ch`.
// PrepareForRun() assigns that name when it picks a routine, e.g.
// "conv_depthwise_3x3_int8_fp32" or "conv_depthwise_5x5_int8_fp32".
template <>
void DepthwiseConv<PRECISION(kInt8), PRECISION(kFloat)>::
SetProfileRuntimeKernelInfo(paddle::lite::profile::OpCharacter* ch) {
ch->kernel_func_name = kernel_func_name_;
}
#endif
template <>
void DepthwiseConv<PRECISION(kInt8), PRECISION(kFloat)>::Run() {
auto& param = this->Param<param_t>();
......@@ -267,6 +301,14 @@ void DepthwiseConv<PRECISION(kInt8), PRECISION(kFloat)>::Run() {
w_scale_.data());
}
#ifdef LITE_WITH_PROFILE
// Profiling hook for the <kInt8, kInt8> specialization: copies the name
// stored in `kernel_func_name_` into the profiler record `ch`.
// PrepareForRun() assigns that name when it picks a routine, e.g.
// "conv_depthwise_3x3_int8_int8" or "conv_depthwise_5x5_int8_int8".
template <>
void DepthwiseConv<PRECISION(kInt8), PRECISION(kInt8)>::
SetProfileRuntimeKernelInfo(paddle::lite::profile::OpCharacter* ch) {
ch->kernel_func_name = kernel_func_name_;
}
#endif
template <>
void DepthwiseConv<PRECISION(kInt8), PRECISION(kInt8)>::Run() {
auto& param = this->Param<param_t>();
......
......@@ -15,6 +15,7 @@
#pragma once
#include <cmath>
#include <string>
#include <vector>
#include "lite/backends/arm/math/conv_impl.h"
#include "lite/core/context.h"
......@@ -48,6 +49,15 @@ class DepthwiseConv : public KernelLite<TARGET(kARM), Ptype> {
virtual void PrepareForRun();
virtual void Run();
#ifdef LITE_WITH_PROFILE
// Profiling hook: writes the current value of `kernel_func_name_` into the
// profiler record `ch`. The member starts as "NotImplForConvDw" and is
// overwritten once a depthwise routine has been selected.
virtual void SetProfileRuntimeKernelInfo(
paddle::lite::profile::OpCharacter* ch) {
ch->kernel_func_name = kernel_func_name_;
}
std::string kernel_func_name_{"NotImplForConvDw"};
#endif
private:
using param_t = operators::ConvParam;
Tensor weights_;
......
......@@ -19,6 +19,14 @@ namespace lite {
namespace kernels {
namespace arm {
#ifdef LITE_WITH_PROFILE
// Profiling hook for the <kFloat, kFloat> specialization: copies the name
// stored in `kernel_func_name_` into the profiler record `ch`.
// Run() assigns that name after invoking a routine, e.g.
// "conv_3x3s1_direct_fp32" or "conv_3x3s2_direct_fp32".
template <>
void DirectConv<PRECISION(kFloat), PRECISION(kFloat)>::
SetProfileRuntimeKernelInfo(paddle::lite::profile::OpCharacter* ch) {
ch->kernel_func_name = kernel_func_name_;
}
#endif
template <>
void DirectConv<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
auto& param = this->Param<param_t>();
......@@ -62,6 +70,9 @@ void DirectConv<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
b_data,
param,
&ctx);
#ifdef LITE_WITH_PROFILE
kernel_func_name_ = "conv_3x3s1_direct_fp32";
#endif
} else {
lite::arm::math::conv_3x3s2_direct_fp32(i_data,
o_data,
......@@ -76,9 +87,20 @@ void DirectConv<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
b_data,
param,
&ctx);
#ifdef LITE_WITH_PROFILE
kernel_func_name_ = "conv_3x3s2_direct_fp32";
#endif
}
}
#ifdef LITE_WITH_PROFILE
// Profiling hook for the <kInt8, kFloat> specialization: copies the name
// stored in `kernel_func_name_` into the profiler record `ch`.
// Run() assigns that name after invoking a routine, e.g.
// "conv_3x3s1_direct_int8" or "conv_3x3s2_direct_int8".
template <>
void DirectConv<PRECISION(kInt8), PRECISION(kFloat)>::
SetProfileRuntimeKernelInfo(paddle::lite::profile::OpCharacter* ch) {
ch->kernel_func_name = kernel_func_name_;
}
#endif
template <>
void DirectConv<PRECISION(kInt8), PRECISION(kFloat)>::Run() {
auto& param = this->Param<param_t>();
......@@ -117,6 +139,9 @@ void DirectConv<PRECISION(kInt8), PRECISION(kFloat)>::Run() {
param,
&ctx,
w_scale_.data());
#ifdef LITE_WITH_PROFILE
kernel_func_name_ = "conv_3x3s1_direct_int8";
#endif
} else {
lite::arm::math::conv_3x3s2_direct_int8(i_data,
o_data,
......@@ -132,9 +157,20 @@ void DirectConv<PRECISION(kInt8), PRECISION(kFloat)>::Run() {
param,
&ctx,
w_scale_.data());
#ifdef LITE_WITH_PROFILE
kernel_func_name_ = "conv_3x3s2_direct_int8";
#endif
}
}
#ifdef LITE_WITH_PROFILE
// Profiling hook for the <kInt8, kInt8> specialization: copies the name
// stored in `kernel_func_name_` into the profiler record `ch`.
// Run() assigns that name after invoking a routine, e.g.
// "conv_3x3s1_direct_int8" or "conv_3x3s2_direct_int8" (the same strings the
// <kInt8, kFloat> specialization reports).
template <>
void DirectConv<PRECISION(kInt8), PRECISION(kInt8)>::
SetProfileRuntimeKernelInfo(paddle::lite::profile::OpCharacter* ch) {
ch->kernel_func_name = kernel_func_name_;
}
#endif
template <>
void DirectConv<PRECISION(kInt8), PRECISION(kInt8)>::Run() {
auto& param = this->Param<param_t>();
......@@ -173,6 +209,9 @@ void DirectConv<PRECISION(kInt8), PRECISION(kInt8)>::Run() {
param,
&ctx,
w_scale_.data());
#ifdef LITE_WITH_PROFILE
kernel_func_name_ = "conv_3x3s1_direct_int8";
#endif
} else {
lite::arm::math::conv_3x3s2_direct_int8(i_data,
o_data,
......@@ -188,6 +227,9 @@ void DirectConv<PRECISION(kInt8), PRECISION(kInt8)>::Run() {
param,
&ctx,
w_scale_.data());
#ifdef LITE_WITH_PROFILE
kernel_func_name_ = "conv_3x3s2_direct_int8";
#endif
}
}
......
......@@ -15,6 +15,7 @@
#pragma once
#include <cmath>
#include <string>
#include <vector>
#include "lite/backends/arm/math/funcs.h"
#include "lite/core/context.h"
......@@ -180,6 +181,15 @@ class DirectConv : public KernelLite<TARGET(kARM), Ptype> {
virtual void Run();
#ifdef LITE_WITH_PROFILE
// Profiling hook: writes the current value of `kernel_func_name_` into the
// profiler record `ch`. The member starts as "NotImplForConvDirect" and is
// overwritten by Run() with the name of the direct-conv routine it invoked.
virtual void SetProfileRuntimeKernelInfo(
paddle::lite::profile::OpCharacter* ch) {
ch->kernel_func_name = kernel_func_name_;
}
std::string kernel_func_name_{"NotImplForConvDirect"};
#endif
/// todo, support inplace weights transform
protected:
Tensor weights_;
......
......@@ -81,6 +81,14 @@ void GemmLikeConv<PRECISION(kInt8), PRECISION(kInt8)>::PrepareForRun() {
}
}
#ifdef LITE_WITH_PROFILE
// Profiling hook for the <kFloat, kFloat> specialization: copies the name
// stored in `kernel_func_name_` into the profiler record `ch`.
// Run() sets it to "conv1x1s1_gemm" when `flag_1x1gemm_` is true and to
// "conv_im2col_gemm" otherwise.
template <>
void GemmLikeConv<PRECISION(kFloat), PRECISION(kFloat)>::
SetProfileRuntimeKernelInfo(paddle::lite::profile::OpCharacter* ch) {
ch->kernel_func_name = kernel_func_name_;
}
#endif
template <>
void GemmLikeConv<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
auto& param = this->Param<param_t>();
......@@ -111,12 +119,26 @@ void GemmLikeConv<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
if (flag_1x1gemm_) {
lite::arm::math::conv1x1s1_gemm(
din, dout, bs, oc, oh, ow, ic, ih, iw, weights, bias, param, &ctx);
#ifdef LITE_WITH_PROFILE
kernel_func_name_ = "conv1x1s1_gemm";
#endif
} else {
lite::arm::math::conv_im2col_gemm(
din, dout, bs, oc, oh, ow, ic, ih, iw, weights, bias, param, &ctx);
#ifdef LITE_WITH_PROFILE
kernel_func_name_ = "conv_im2col_gemm";
#endif
}
}
#ifdef LITE_WITH_PROFILE
// Profiling hook for the <kInt8, kFloat> specialization: copies the name
// stored in `kernel_func_name_` into the profiler record `ch`.
// Run() assigns that name after invoking a routine, e.g.
// "conv1x1s1_gemm_int8" or "conv_im2col_gemm_int8".
template <>
void GemmLikeConv<PRECISION(kInt8), PRECISION(kFloat)>::
SetProfileRuntimeKernelInfo(paddle::lite::profile::OpCharacter* ch) {
ch->kernel_func_name = kernel_func_name_;
}
#endif
template <>
void GemmLikeConv<PRECISION(kInt8), PRECISION(kFloat)>::Run() {
auto& param = this->Param<param_t>();
......@@ -159,6 +181,9 @@ void GemmLikeConv<PRECISION(kInt8), PRECISION(kFloat)>::Run() {
param,
&ctx,
w_scale_.data());
#ifdef LITE_WITH_PROFILE
kernel_func_name_ = "conv1x1s1_gemm_int8";
#endif
} else {
lite::arm::math::conv_im2col_gemm_int8(din,
dout,
......@@ -174,9 +199,20 @@ void GemmLikeConv<PRECISION(kInt8), PRECISION(kFloat)>::Run() {
param,
&ctx,
w_scale_.data());
#ifdef LITE_WITH_PROFILE
kernel_func_name_ = "conv_im2col_gemm_int8";
#endif
}
}
#ifdef LITE_WITH_PROFILE
// Profiling hook for the <kInt8, kInt8> specialization: copies the name
// stored in `kernel_func_name_` into the profiler record `ch`.
// Run() assigns that name after invoking a routine, e.g.
// "conv1x1s1_gemm_int8" or "conv_im2col_gemm_int8" (the same strings the
// <kInt8, kFloat> specialization reports).
template <>
void GemmLikeConv<PRECISION(kInt8), PRECISION(kInt8)>::
SetProfileRuntimeKernelInfo(paddle::lite::profile::OpCharacter* ch) {
ch->kernel_func_name = kernel_func_name_;
}
#endif
template <>
void GemmLikeConv<PRECISION(kInt8), PRECISION(kInt8)>::Run() {
auto& param = this->Param<param_t>();
......@@ -219,6 +255,9 @@ void GemmLikeConv<PRECISION(kInt8), PRECISION(kInt8)>::Run() {
param,
&ctx,
w_scale_.data());
#ifdef LITE_WITH_PROFILE
kernel_func_name_ = "conv1x1s1_gemm_int8";
#endif
} else {
lite::arm::math::conv_im2col_gemm_int8(din,
dout,
......@@ -234,6 +273,9 @@ void GemmLikeConv<PRECISION(kInt8), PRECISION(kInt8)>::Run() {
param,
&ctx,
w_scale_.data());
#ifdef LITE_WITH_PROFILE
kernel_func_name_ = "conv_im2col_gemm_int8";
#endif
}
}
......
......@@ -15,6 +15,7 @@
#pragma once
#include <cmath>
#include <string>
#include <vector>
#include "lite/backends/arm/math/conv_impl.h"
#include "lite/backends/arm/math/funcs.h"
......@@ -94,6 +95,15 @@ class GemmLikeConv : public KernelLite<TARGET(kARM), Ptype> {
virtual void PrepareForRun();
virtual void Run();
#ifdef LITE_WITH_PROFILE
// Profiling hook: writes the current value of `kernel_func_name_` into the
// profiler record `ch`. The member starts as "NotImplForConvGemm" and is
// overwritten by Run() with the name of the GEMM-based routine it invoked.
virtual void SetProfileRuntimeKernelInfo(
paddle::lite::profile::OpCharacter* ch) {
ch->kernel_func_name = kernel_func_name_;
}
std::string kernel_func_name_{"NotImplForConvGemm"};
#endif
/// todo, support inplace weights transform
protected:
using param_t = operators::ConvParam;
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#pragma once
#include <string>
#include "lite/backends/arm/math/funcs.h"
#include "lite/core/kernel.h"
#include "lite/operators/conv_transpose_op.h"
......@@ -33,6 +34,14 @@ class Conv2DTransposeCompute
~Conv2DTransposeCompute() = default;
#ifdef LITE_WITH_PROFILE
// Profiling hook: writes the current value of `kernel_func_name_` into the
// profiler record `ch`. The member is initialized to
// "NotImplForConvTranspose".
// NOTE(review): no assignment to `kernel_func_name_` is visible for this
// kernel here, so the placeholder may be what gets reported -- confirm in the
// corresponding .cc file.
virtual void SetProfileRuntimeKernelInfo(
paddle::lite::profile::OpCharacter* ch) {
ch->kernel_func_name = kernel_func_name_;
}
std::string kernel_func_name_{"NotImplForConvTranspose"};
#endif
protected:
int workspace_size_{0};
};
......
......@@ -93,6 +93,14 @@ void WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
ReInitWhenNeeded();
}
#ifdef LITE_WITH_PROFILE
// Profiling hook for the <kFloat, kFloat> specialization: copies the name
// stored in `kernel_func_name_` into the profiler record `ch`.
// Run() assigns that name after invoking a routine: "conv_compute_6x6_3x3",
// "conv_compute_2x2_3x3" or "conv_compute_2x2_3x3_small".
template <>
void WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>::
SetProfileRuntimeKernelInfo(paddle::lite::profile::OpCharacter* ch) {
ch->kernel_func_name = kernel_func_name_;
}
#endif
template <>
void WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
auto& param = this->Param<param_t>();
......@@ -129,6 +137,9 @@ void WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
b_data,
param,
&ctx);
#ifdef LITE_WITH_PROFILE
kernel_func_name_ = "conv_compute_6x6_3x3";
#endif
} else {
int tile_block = 8;
int block_count =
......@@ -147,6 +158,9 @@ void WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
b_data,
param,
&ctx);
#ifdef LITE_WITH_PROFILE
kernel_func_name_ = "conv_compute_2x2_3x3";
#endif
} else {
lite::arm::math::conv_compute_2x2_3x3_small(i_data,
o_data,
......@@ -161,6 +175,9 @@ void WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
b_data,
param,
&ctx);
#ifdef LITE_WITH_PROFILE
kernel_func_name_ = "conv_compute_2x2_3x3_small";
#endif
}
}
}
......
......@@ -35,6 +35,13 @@ class WinogradConv : public KernelLite<TARGET(kARM), Ptype> {
virtual void PrepareForRun();
virtual void ReInitWhenNeeded();
virtual void Run();
#ifdef LITE_WITH_PROFILE
// Profiling hook: writes the current value of `kernel_func_name_` into the
// profiler record `ch`. The member starts as "NotImplForConvWino" and is
// overwritten by Run() with the name of the Winograd routine it invoked.
virtual void SetProfileRuntimeKernelInfo(
paddle::lite::profile::OpCharacter* ch) {
ch->kernel_func_name = kernel_func_name_;
}
std::string kernel_func_name_{"NotImplForConvWino"};
#endif
protected:
using param_t = operators::ConvParam;
......
......@@ -300,11 +300,12 @@ void ElementwiseMaxActivationCompute::Run() {
}
}
void ElementwiseDivCompute::Run() {
auto& param = Param<operators::ElementwiseParam>();
const float* x_data = param.X->data<float>();
const float* y_data = param.Y->data<float>();
float* out_data = param.Out->mutable_data<float>();
template <typename T, PrecisionType PType>
void ElementwiseDivCompute<T, PType>::Run() {
auto& param = this->template Param<operators::ElementwiseParam>();
auto* x_data = param.X->template data<T>();
auto* y_data = param.Y->template data<T>();
auto* out_data = param.Out->template mutable_data<T>();
int axis = param.axis;
auto x_dims = param.X->dims();
auto y_dims = param.Y->dims();
......@@ -313,10 +314,10 @@ void ElementwiseDivCompute::Run() {
LOG(FATAL) << "elewise div don't support x_dims size < y_dims size";
}
if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) {
lite::arm::math::elementwise_div_broadcast(
lite::arm::math::elementwise_div_broadcast<T>(
x_data, y_data, out_data, pre, n, post);
} else {
lite::arm::math::elementwise_div(
lite::arm::math::elementwise_div<T>(
x_data, y_data, out_data, x_dims.production());
}
}
......@@ -465,17 +466,27 @@ REGISTER_LITE_KERNEL(
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
REGISTER_LITE_KERNEL(elementwise_div,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::arm::ElementwiseDivCompute,
def)
using elementwise_div_fp32 =
paddle::lite::kernels::arm::ElementwiseDivCompute<float, PRECISION(kFloat)>;
REGISTER_LITE_KERNEL(
elementwise_div, kARM, kFloat, kNCHW, elementwise_div_fp32, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
using elementwise_div_int64 =
paddle::lite::kernels::arm::ElementwiseDivCompute<int64_t,
PRECISION(kInt64)>;
REGISTER_LITE_KERNEL(
elementwise_div, kARM, kInt64, kNCHW, elementwise_div_int64, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.Finalize();
REGISTER_LITE_KERNEL(
fusion_elementwise_div_activation,
kARM,
......
......@@ -86,8 +86,8 @@ class ElementwiseMaxActivationCompute
virtual ~ElementwiseMaxActivationCompute() = default;
};
class ElementwiseDivCompute
: public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
template <typename T, PrecisionType PType>
class ElementwiseDivCompute : public KernelLite<TARGET(kARM), PType> {
public:
void Run() override;
......
......@@ -73,7 +73,6 @@ void GatherCompute::Run() {
REGISTER_LITE_KERNEL(
gather, kARM, kAny, kNCHW, paddle::lite::kernels::arm::GatherCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
.BindInput("Index",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("Index", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
.Finalize();
......@@ -34,6 +34,9 @@ add_kernel(instance_norm_opencl OPENCL basic SRCS instance_norm_image_compute.cc
add_kernel(dropout_opencl OPENCL basic SRCS dropout_image_compute.cc DEPS ${cl_kernel_deps})
add_kernel(pad2d_opencl OPENCL basic SRCS pad2d_image_compute.cc DEPS ${cl_kernel_deps})
add_kernel(box_coder_opencl OPENCL basic SRCS box_coder_image_compute.cc DEPS ${cl_kernel_deps})
add_kernel(pixel_shuffle_opencl OPENCL basic SRCS pixel_shuffle_image_compute.cc DEPS ${cl_kernel_deps})
add_kernel(expand_opencl OPENCL basic SRCS expand_image_compute.cc DEPS ${cl_kernel_deps})
# extra
# wait to add ...
......@@ -73,6 +76,12 @@ lite_cc_test(test_concat_image_opencl SRCS concat_image_compute_test.cc
lite_cc_test(test_layout_image_opencl SRCS layout_image_compute_test.cc
DEPS layout_opencl op_registry program context)
lite_cc_test(test_pixel_shuffle_image_opencl SRCS pixel_shuffle_image_compute_test.cc
DEPS pixel_shuffle_opencl op_registry program context)
lite_cc_test(test_expand_image_opencl SRCS expand_image_compute_test.cc
DEPS expand_opencl op_registry program context)
lite_cc_test(test_elementwise_add_image_opencl SRCS elementwise_add_image_compute_test.cc
DEPS elementwise_add_opencl fusion_elementwise_add_activation_opencl op_registry program context)
lite_cc_test(test_elementwise_sub_image_opencl SRCS elementwise_sub_image_compute_test.cc
......
......@@ -18,6 +18,10 @@
#include "lite/kernels/opencl/image_helper.h"
#include "lite/operators/op_params.h"
#include "lite/utils/replace_stl/stream.h"
#ifdef LITE_WITH_PROFILE
#include "lite/core/profile/profiler.h"
#endif
#include "lite/backends/opencl/cl_utility.h"
namespace paddle {
namespace lite {
......@@ -63,16 +67,24 @@ class ReluCompute
auto global_work_size = cl::NDRange{count};
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
status = EnqueueNDRangeKernel(context,
kernel,
cl::NullRange,
global_work_size,
cl::NullRange,
nullptr,
nullptr);
event_);
CL_CHECK_FATAL(status);
}
#ifdef LITE_WITH_PROFILE
// Profiling hook: fills the profiler record `ch` with this kernel's function
// name (`kernel_func_name_`, initialized to "relu") and with `event_`, the
// event handle Run() passes to EnqueueNDRangeKernel.
void SetProfileRuntimeKernelInfo(paddle::lite::profile::OpCharacter* ch) {
ch->kernel_func_name = kernel_func_name_;
ch->cl_event =
event_; // `event_` defined in `kernel.h`, valid after kernel::Run
}
#endif
private:
std::string kernel_func_name_{"relu"};
std::string build_options_{"-DCL_DTYPE_float -DRELU"};
......@@ -120,16 +132,24 @@ class SigmoidCompute
auto global_work_size = cl::NDRange{count};
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
status = EnqueueNDRangeKernel(context,
kernel,
cl::NullRange,
global_work_size,
cl::NullRange,
nullptr,
nullptr);
event_);
CL_CHECK_FATAL(status);
}
#ifdef LITE_WITH_PROFILE
// Profiling hook: fills the profiler record `ch` with this kernel's function
// name (`kernel_func_name_`, initialized to "sigmoid") and with `event_`, the
// event handle Run() passes to EnqueueNDRangeKernel.
void SetProfileRuntimeKernelInfo(paddle::lite::profile::OpCharacter* ch) {
ch->kernel_func_name = kernel_func_name_;
ch->cl_event =
event_; // `event_` defined in `kernel.h`, valid after kernel::Run
}
#endif
private:
std::string kernel_func_name_{"sigmoid"};
std::string build_options_{"-DCL_DTYPE_float -DSIGMOID"};
......
......@@ -19,6 +19,10 @@
#include "lite/kernels/opencl/image_helper.h"
#include "lite/operators/op_params.h"
#include "lite/utils/replace_stl/stream.h"
#ifdef LITE_WITH_PROFILE
#include "lite/core/profile/profiler.h"
#endif
#include "lite/backends/opencl/cl_utility.h"
namespace paddle {
namespace lite {
......@@ -148,16 +152,24 @@ class ActivationComputeImageDefault
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
status = EnqueueNDRangeKernel(context,
kernel,
cl::NullRange,
global_work_size_,
cl::NullRange,
nullptr,
nullptr);
event_);
CL_CHECK_FATAL(status);
}
#ifdef LITE_WITH_PROFILE
// Profiling hook: fills the profiler record `ch` with the activation kernel's
// function name (`kernel_func_name_`) and with `event_`, the event handle
// Run() passes to EnqueueNDRangeKernel.
void SetProfileRuntimeKernelInfo(paddle::lite::profile::OpCharacter* ch) {
ch->kernel_func_name = kernel_func_name_;
ch->cl_event =
event_; // `event_` defined in `kernel.h`, valid after kernel::Run
}
#endif
private:
param_t* act_param_{nullptr};
DDim x_img_shape_ = DDim(std::vector<DDim::value_type>(
......
......@@ -23,6 +23,10 @@
#include "lite/operators/op_params.h"
#include "lite/utils/logging.h"
#include "lite/utils/replace_stl/stream.h"
#ifdef LITE_WITH_PROFILE
#include "lite/core/profile/profiler.h"
#endif
#include "lite/backends/opencl/cl_utility.h"
namespace paddle {
namespace lite {
......@@ -142,13 +146,13 @@ class BilinearInterpImageCompute
static_cast<cl::size_type>(default_work_size[1]),
static_cast<cl::size_type>(default_work_size[2])};
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
status = EnqueueNDRangeKernel(context,
kernel,
cl::NullRange,
global_work_size,
cl::NullRange,
nullptr,
nullptr);
event_);
CL_CHECK_FATAL(status);
#ifdef LITE_WITH_LOG
VLOG(4) << "global_work_size:[2D]:" << global_work_size[0] << " "
......
......@@ -23,6 +23,10 @@
#include "lite/operators/op_params.h"
#include "lite/utils/logging.h"
#include "lite/utils/replace_stl/stream.h"
#ifdef LITE_WITH_PROFILE
#include "lite/core/profile/profiler.h"
#endif
#include "lite/backends/opencl/cl_utility.h"
namespace paddle {
namespace lite {
......@@ -121,13 +125,13 @@ class BoxCoderComputeImage : public KernelLite<TARGET(kOpenCL),
cl::NDRange{static_cast<cl::size_type>(default_work_size[0]),
static_cast<cl::size_type>(default_work_size[2])};
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
status = EnqueueNDRangeKernel(context,
kernel,
cl::NullRange,
global_work_size,
cl::NullRange,
nullptr,
nullptr);
event_);
CL_CHECK_FATAL(status);
#ifdef LITE_WITH_LOG
......@@ -138,6 +142,14 @@ class BoxCoderComputeImage : public KernelLite<TARGET(kOpenCL),
}
std::string doc() { return "Boxcoder using cl::Image, kFP16"; }
#ifdef LITE_WITH_PROFILE
// Profiling hook: fills the profiler record `ch` with the box-coder kernel's
// function name (`kernel_func_name_`) and with `event_`, the event handle
// Run() passes to EnqueueNDRangeKernel.
// NOTE(review): `kernel_func_name_` is default-initialized empty in this
// class; presumably it is assigned before profiling runs -- confirm.
void SetProfileRuntimeKernelInfo(paddle::lite::profile::OpCharacter* ch) {
ch->kernel_func_name = kernel_func_name_;
ch->cl_event =
event_; // `event_` defined in `kernel.h`, valid after kernel::Run
}
#endif
param_t* boxcoder_param_{nullptr};
std::string kernel_func_name_{};
std::string build_options_{" -DCL_DTYPE_half"};
......
......@@ -18,6 +18,10 @@
#include "lite/kernels/opencl/image_helper.h"
#include "lite/operators/op_params.h"
#include "lite/utils/replace_stl/stream.h"
#ifdef LITE_WITH_PROFILE
#include "lite/core/profile/profiler.h"
#endif
#include "lite/backends/opencl/cl_utility.h"
namespace paddle {
namespace lite {
......@@ -124,13 +128,13 @@ class ConcatCompute : public KernelLite<TARGET(kOpenCL),
status = kernel.setArg(++arg_idx, total1);
CL_CHECK_FATAL(status);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
status = EnqueueNDRangeKernel(context,
kernel,
cl::NullRange,
global_work_size,
cl::NullRange,
nullptr,
nullptr);
event_);
CL_CHECK_FATAL(status);
} else {
auto start = 0;
......@@ -157,13 +161,13 @@ class ConcatCompute : public KernelLite<TARGET(kOpenCL),
status = kernel.setArg(++arg_idx, total0);
CL_CHECK_FATAL(status);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
status = EnqueueNDRangeKernel(context,
kernel,
cl::NullRange,
global_work_size,
cl::NullRange,
nullptr,
nullptr);
event_);
CL_CHECK_FATAL(status);
start += size;
}
......@@ -172,6 +176,14 @@ class ConcatCompute : public KernelLite<TARGET(kOpenCL),
std::string doc() { return "Concat using cl::Buffer, kFloat"; }
#ifdef LITE_WITH_PROFILE
// Profiling hook: fills the profiler record `ch` with the concat kernel's
// function name (`kernel_func_name_`) and with the `event_` handle.
// NOTE(review): Run() hands `event_` to EnqueueNDRangeKernel on both of its
// branches, and the second branch appears to enqueue repeatedly; presumably
// only the most recent launch is reflected in `event_` -- confirm.
void SetProfileRuntimeKernelInfo(paddle::lite::profile::OpCharacter* ch) {
ch->kernel_func_name = kernel_func_name_;
ch->cl_event =
event_; // `event_` defined in `kernel.h`, valid after kernel::Run
}
#endif
int axis_size_ = 1;
int post_size_ = 1;
int pre_size_ = 1;
......
......@@ -19,6 +19,10 @@
#include "lite/kernels/opencl/image_helper.h"
#include "lite/operators/op_params.h"
#include "lite/utils/replace_stl/stream.h"
#ifdef LITE_WITH_PROFILE
#include "lite/core/profile/profiler.h"
#endif
#include "lite/backends/opencl/cl_utility.h"
namespace paddle {
namespace lite {
......@@ -246,6 +250,14 @@ class ConcatComputeImage : public KernelLite<TARGET(kOpenCL),
std::string doc() { return "Concat using cl::Image, kFP16"; }
#ifdef LITE_WITH_PROFILE
// Profiling hook: fills the profiler record `ch` with the image-concat
// kernel's function name (`kernel_func_name_`) and with the `event_` handle.
void SetProfileRuntimeKernelInfo(paddle::lite::profile::OpCharacter* ch) {
ch->kernel_func_name = kernel_func_name_;
ch->cl_event =
event_; // `event_` defined in `kernel.h`, valid after kernel::Run
}
#endif
int axis_size_ = 1;
int axis_ = 1;
int flag_ = 1;
......
......@@ -23,6 +23,10 @@
#include "lite/core/tensor.h"
#include "lite/kernels/opencl/image_helper.h"
#include "lite/operators/op_params.h"
#ifdef LITE_WITH_PROFILE
#include "lite/core/profile/profiler.h"
#endif
#include "lite/backends/opencl/cl_utility.h"
namespace paddle {
namespace lite {
......@@ -39,6 +43,14 @@ class ConvCompute
void Run() override;
#ifdef LITE_WITH_PROFILE
// Profiling hook: fills the profiler record `ch` with the first entry of
// `kernel_func_names_` and with the `event_` handle. Only index 0 is
// reported, even if more names are stored.
// NOTE(review): indexing `[0]` assumes at least one kernel name has been
// registered before this is called -- confirm against PrepareForRun().
void SetProfileRuntimeKernelInfo(paddle::lite::profile::OpCharacter* ch) {
ch->kernel_func_name = kernel_func_names_[0];
ch->cl_event =
event_; // `event_` defined in `kernel.h`, valid after kernel::Run
}
#endif
private:
void GemmlikeConv2d();
void Conv2d1x1();
......
......@@ -22,94 +22,89 @@
#include "lite/kernels/opencl/image_helper.h"
#include "lite/operators/op_params.h"
#undef LITE_WITH_LOG
namespace paddle {
namespace lite {
namespace kernels {
namespace opencl {
void ConvImageCompute::PrepareForRun() {
const auto& param = this->Param<param_t>();
auto x_dims = param.x->dims();
auto filter_dims = param.filter->dims();
auto output_dims = param.output->dims();
ReInitWhenNeeded();
auto filter_dims = conv_param_->filter->dims();
filter_tensor_n_ = filter_dims[0];
filter_tensor_c_ = filter_dims[1];
filter_tensor_h_ = filter_dims[2];
filter_tensor_w_ = filter_dims[3];
float* filter_cpu = param.filter->mutable_data<float>();
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
filter_gpu_image_ = std::unique_ptr<Tensor>(new Tensor);
tensor_hold_filter_image_ = std::unique_ptr<Tensor>(new Tensor);
tensor_hold_bias_image_ = std::unique_ptr<Tensor>(new Tensor);
int bs = x_dims[0];
int c_in = x_dims[1];
int h_out = output_dims[2];
int w_out = output_dims[3];
int kernel_h = filter_dims[2]; // oihw
int kernel_w = filter_dims[3];
auto paddings = *param.paddings;
auto dilations = *param.dilations;
int stride_h = param.strides[0];
int stride_w = param.strides[1];
int pad_h = paddings[0];
int pad_w = paddings[2];
int groups = param.groups;
bool relu_fused = param.fuse_relu;
bool no_dilation = (dilations[0] == 1) && (dilations[1] == 1);
bool zero_pad = (pad_h == 0) && (pad_w == 0);
bool pad_equal =
((paddings[0] == paddings[1]) && (paddings[1] == paddings[2]) &&
(paddings[2] == paddings[3]));
bool stride_equal = stride_h == stride_w;
bool dilation_equal = dilations[0] == dilations[1];
VLOG(3) << "Is relu fused? / " << (relu_fused ? "Yes" : "No");
VLOG(3) << "groups:" << groups << " stride_h:" << stride_h
<< " stride_w:" << stride_w << " pad_h:" << pad_h
<< " pad_w:" << pad_w << " kernel_h:" << kernel_h
<< " kernel_h:" << kernel_h;
VLOG(3) << "x_dims:" << x_dims[0] << " " << x_dims[1] << " " << x_dims[2]
<< " " << x_dims[3];
VLOG(3) << "dialtion:" << dilations[0] << " " << dilations[1];
VLOG(3) << "output_dims:" << output_dims[0] << " " << output_dims[1] << " "
<< output_dims[2] << " " << output_dims[3];
VLOG(3) << "filter_dims:" << filter_dims[0] << " " << filter_dims[1] << " "
<< filter_dims[2] << " " << filter_dims[3];
const bool is_mali = context.cl_context()->IsArmMali();
auto paddings = *conv_param_->paddings;
pad_up_ = paddings[0];
pad_down_ = paddings[1];
pad_left_ = paddings[2];
pad_right_ = paddings[3];
auto dilations = *conv_param_->dilations;
dilation_h_ = dilations[0];
dilation_w_ = dilations[1];
stride_h_ = conv_param_->strides[0];
stride_w_ = conv_param_->strides[1];
groups_ = conv_param_->groups;
relu_fused_ = conv_param_->fuse_relu;
has_bias_ = (conv_param_->bias) != nullptr;
offset_ = filter_tensor_h_ / 2 - pad_up_;
bool pad_equal = ((pad_left_ == pad_up_) && (pad_up_ == pad_left_) &&
(pad_left_ == pad_right_));
bool stride_equal = stride_h_ == stride_w_;
bool dilation_equal = dilation_h_ == dilation_w_;
VLOG(3) << "Is arm mali / " << (is_mali ? "Yes" : "No");
VLOG(3) << "Is relu fused? / " << (relu_fused_ ? "Yes" : "No");
VLOG(3) << "groups:" << groups_ << " stride_h_:" << stride_h_
<< " stride_w_:" << stride_w_ << " pad_left_:" << pad_left_
<< " pad_up_:" << pad_up_ << " filter_tensor_h_:" << filter_tensor_h_
<< " filter_tensor_h_:" << filter_tensor_h_;
VLOG(3) << "input_tensor_nchw:" << input_tensor_n_ << " " << input_tensor_c_
<< " " << input_tensor_h_ << " " << input_tensor_w_;
VLOG(3) << "dialtion:" << dilation_h_ << " " << dilation_w_;
VLOG(3) << "output_dims:" << output_tensor_n_ << " " << output_tensor_c_
<< " " << output_tensor_h_ << " " << output_tensor_w_;
VLOG(3) << "filter_dims:" << filter_tensor_n_ << " " << filter_tensor_c_
<< " " << filter_tensor_h_ << " " << filter_tensor_w_;
VLOG(3) << "pad_equal:" << pad_equal;
VLOG(3) << "stride_equal:" << stride_equal;
VLOG(3) << "dilation_equal:" << dilation_equal;
VLOG(3) << "padding :" << paddings[0] << " " << paddings[1] << " "
<< paddings[2] << " " << paddings[3];
VLOG(3) << "padding :" << pad_up_ << " " << pad_down_ << " " << pad_left_
<< " " << pad_right_;
CHECK(pad_equal && stride_equal && dilation_equal);
CHECK_GE(conv_param_->dilations->size(), 2);
CHECK(dilation_h_ == dilation_w_);
CHECK_GE(conv_param_->paddings->size(), 2);
CHECK(pad_left_ == pad_up_);
CHECK_GE(conv_param_->strides.size(), 2);
CHECK(stride_h_ == stride_w_);
if (!is_mali) {
use_tune_ = false;
}
// general gws..
auto out_image_shape = InitImageDimInfoWith(output_dims);
const std::vector<size_t>& default_work_size =
DefaultWorkSize(output_dims,
DDim(std::vector<DDim::value_type>{
static_cast<int64_t>(out_image_shape["width"]),
static_cast<int64_t>(out_image_shape["height"])}));
default_c_blk_ = default_work_size[0];
default_w_blk_ = default_work_size[1];
default_nh_blk_ = default_work_size[2];
c_blk_ = default_c_blk_;
w_blk_ = default_w_blk_;
nh_blk_ = default_nh_blk_;
global_work_size_ = cl::NDRange{static_cast<size_t>(c_blk_),
static_cast<size_t>(w_blk_),
static_cast<size_t>(nh_blk_)};
if (kernel_h == 1 && kernel_w == 1) {
// conv2d_1x1
// if (param.x->dims()[1] % 4 == 0) {
// kernel_func_names_.push_back("conv2d_1x1_simple");
// } else {
// kernel_func_names_.push_back("conv2d_1x1_opt");
// }
/*********************************************
* Upload filter, bias to opencl device
*********************************************/
float* filter_cpu = conv_param_->filter->mutable_data<float>();
filter_gpu_image_ = std::unique_ptr<Tensor>(new Tensor);
tensor_hold_filter_image_ = std::unique_ptr<Tensor>(new Tensor);
tensor_hold_bias_image_ = std::unique_ptr<Tensor>(new Tensor);
if (param.x->dims()[1] % 4 == 0) {
if (filter_tensor_h_ == 1 && filter_tensor_h_ == 1) {
if (input_tensor_c_ % 4 == 0) {
kernel_func_names_.push_back("conv2d_1x1_simple");
} else {
kernel_func_names_.push_back("conv2d_1x1_opt");
......@@ -118,89 +113,49 @@ void ConvImageCompute::PrepareForRun() {
CLImageConverterNWBlock converter;
const DDim& filter_image_dims = converter.InitImageDimInfoWith(filter_dims);
// std::vector<half_t> filter_image_v(filter_image_dims[0] *
// filter_image_dims[1] * 4); // 4 :
// RGBA
tensor_hold_filter_image_->Resize(
{1, filter_image_dims[0], filter_image_dims[1], 4});
filter_image_h_ = filter_image_dims[1];
filter_image_w_ = filter_image_dims[0];
tensor_hold_filter_image_->Resize({1, filter_image_w_, filter_image_h_, 4});
half_t* filter_image_data =
tensor_hold_filter_image_->mutable_data<half_t>();
converter.NCHWToImage(filter_cpu, filter_image_data, filter_dims);
filter_gpu_image_->mutable_data<half_t, cl::Image2D>(
filter_image_dims[0], filter_image_dims[1], filter_image_data);
filter_image_w_, filter_image_h_, filter_image_data);
impl_ = &ConvImageCompute::Conv2d1x1opt;
{
// calc 1x1 gws
w_blk_ = maptofactor(default_w_blk_, 4);
c_blk_ = default_c_blk_;
nh_blk_ = default_nh_blk_;
global_work_size_ = cl::NDRange{static_cast<size_t>(c_blk_),
static_cast<size_t>(w_blk_),
static_cast<size_t>(nh_blk_)};
}
#define DEPTH_CONV_USE_SPL
#ifdef DEPTH_CONV_USE_SPL
} else if (filter_dims[1] == 1 && x_dims[1] == output_dims[1] &&
kernel_h == 3 && kernel_w == 3 && groups > 1) {
} else if (filter_tensor_c_ == 1 && input_tensor_c_ == output_tensor_c_ &&
filter_tensor_h_ == 3 && filter_tensor_w_ == 3 && groups_ > 1) {
// depth_conv2d_3x3s1, depth_conv2d_3x3
if (stride_h == 1 && dilations[0] == 1) {
if (stride_h_ == 1 && dilation_h_ == 1) {
kernel_func_names_.push_back("depth_conv2d_3x3s1");
impl_ = &ConvImageCompute::DepthwiseConv2d3x3s1;
{
// depthwise spl gws s1
int c_block = (output_dims[1] + 3) / 4;
int w = output_dims[3];
int nh = output_dims[0] * output_dims[2];
int w_blk_size = 2;
int w_blk = (w + w_blk_size - 1) / w_blk_size;
c_blk_ = c_block;
w_blk_ = w_blk;
nh_blk_ = nh;
global_work_size_ = cl::NDRange{static_cast<size_t>(c_blk_),
static_cast<size_t>(w_blk_),
static_cast<size_t>(nh_blk_)};
}
} else {
kernel_func_names_.push_back("depth_conv2d_3x3");
impl_ = &ConvImageCompute::DepthwiseConv2d3x3;
{
// depthwise spl gws
int c_block = (output_dims[1] + 3) / 4;
int w = output_dims[3];
int nh = output_dims[0] * output_dims[2];
c_blk_ = c_block;
w_blk_ = w;
nh_blk_ = nh;
global_work_size_ = cl::NDRange{static_cast<size_t>(c_blk_),
static_cast<size_t>(w_blk_),
static_cast<size_t>(nh_blk_)};
}
}
kernel_func_paths_.push_back("image/depthwise_conv2d_kernel.cl");
CLImageConverterNWBlock converter;
const DDim& filter_image_dims = converter.InitImageDimInfoWith(filter_dims);
tensor_hold_filter_image_->Resize(
{1, filter_image_dims[0], filter_image_dims[1], 4});
filter_image_h_ = filter_image_dims[1];
filter_image_w_ = filter_image_dims[0];
tensor_hold_filter_image_->Resize({1, filter_image_w_, filter_image_h_, 4});
half_t* filter_image_data =
tensor_hold_filter_image_->mutable_data<half_t>();
converter.NCHWToImage(filter_cpu, filter_image_data, filter_dims);
filter_gpu_image_->mutable_data<half_t, cl::Image2D>(
filter_image_dims[0], filter_image_dims[1], filter_image_data);
filter_image_w_, filter_image_h_, filter_image_data);
#endif
} else if (filter_dims[1] == 1 && x_dims[1] == output_dims[1]
} else if (filter_tensor_c_ == 1 && input_tensor_c_ == output_tensor_c_
#ifdef DEPTH_CONV_USE_SPL
&&
kernel_h != 3
filter_tensor_h_ != 3
#endif
#undef DEPTH_CONV_USE_SPL
) {
......@@ -210,76 +165,61 @@ void ConvImageCompute::PrepareForRun() {
CLImageConverterNWBlock converter;
const DDim& filter_image_dims = converter.InitImageDimInfoWith(filter_dims);
tensor_hold_filter_image_->Resize(
{1, filter_image_dims[0], filter_image_dims[1], 4});
filter_image_h_ = filter_image_dims[1];
filter_image_w_ = filter_image_dims[0];
tensor_hold_filter_image_->Resize({1, filter_image_w_, filter_image_h_, 4});
half_t* filter_image_data =
tensor_hold_filter_image_->mutable_data<half_t>();
converter.NCHWToImage(filter_cpu, filter_image_data, filter_dims);
filter_gpu_image_->mutable_data<half_t, cl::Image2D>(
filter_image_dims[0], filter_image_dims[1], filter_image_data);
filter_image_w_, filter_image_h_, filter_image_data);
impl_ = &ConvImageCompute::DepthwiseConv2d;
} else if (kernel_w == 3 && kernel_h == 3) {
} else if (filter_tensor_h_ == 3 && filter_tensor_w_ == 3) {
// #define CONV3x3OPT_FALL_BACK
#ifndef CONV3x3OPT_FALL_BACK
// conv2d_3x3
kernel_func_names_.push_back(bs > 1 ? "conv2d_3x3_multi_batch"
kernel_func_names_.push_back(input_tensor_n_ > 1 ? "conv2d_3x3_multi_batch"
: "conv2d_3x3_opt");
kernel_func_paths_.push_back("image/conv2d_3x3_opt_kernel.cl");
CLImageConverterFolder converter;
const DDim& filter_image_dims = converter.InitImageDimInfoWith(filter_dims);
tensor_hold_filter_image_->Resize(
{1, filter_image_dims[0], filter_image_dims[1], 4});
filter_image_h_ = filter_image_dims[1];
filter_image_w_ = filter_image_dims[0];
tensor_hold_filter_image_->Resize({1, filter_image_w_, filter_image_h_, 4});
half_t* filter_image_data =
tensor_hold_filter_image_->mutable_data<half_t>();
converter.NCHWToImage(filter_cpu, filter_image_data, filter_dims);
filter_gpu_image_->mutable_data<half_t, cl::Image2D>(
filter_image_dims[0], filter_image_dims[1], filter_image_data);
filter_image_w_, filter_image_h_, filter_image_data);
impl_ = &ConvImageCompute::Conv2d3x3opt;
{
int w_blk_size = 5;
int w_blk = (default_w_blk_ + w_blk_size - 1) / w_blk_size;
int h_blk_size = 1;
int h_blk = (default_nh_blk_ + h_blk_size - 1) / h_blk_size;
c_blk_ = default_c_blk_;
w_blk_ = w_blk;
nh_blk_ = h_blk;
global_work_size_ = cl::NDRange{static_cast<size_t>(c_blk_),
static_cast<size_t>(w_blk_),
static_cast<size_t>(nh_blk_)};
}
#else
kernel_func_names_.push_back("conv2d_3x3");
kernel_func_paths_.push_back("image/conv2d_3x3_kernel.cl");
CLImageConverterFolder converter;
const DDim& filter_image_dims = converter.InitImageDimInfoWith(filter_dims);
tensor_hold_filter_image_->Resize(
{1, filter_image_dims[0], filter_image_dims[1], 4});
filter_image_h_ = filter_image_dims[1];
filter_image_w_ = filter_image_dims[0];
tensor_hold_filter_image_->Resize({1, filter_image_w_, filter_image_h_, 4});
half_t* filter_image_data =
tensor_hold_filter_image_->mutable_data<half_t>();
converter.NCHWToImage(filter_cpu, filter_image_data, filter_dims);
filter_gpu_image_->mutable_data<half_t, cl::Image2D>(
filter_image_dims[0], filter_image_dims[1], filter_image_data);
filter_image_w_, filter_image_h_, filter_image_data);
impl_ = &ConvImageCompute::Conv2d3x3;
#endif
#undef CONV3x3OPT_FALL_BACK
} else if (kernel_h == 5 && kernel_w == 5) {
} else if (filter_tensor_h_ == 5 && filter_tensor_w_ == 5) {
#define CONV_5x5_OPT
#ifndef CONV_5x5_OPT
// conv2d_5x5
......@@ -288,55 +228,42 @@ void ConvImageCompute::PrepareForRun() {
CLImageConverterFolder converter;
const DDim& filter_image_dims = converter.InitImageDimInfoWith(filter_dims);
tensor_hold_filter_image_->Resize(
{1, filter_image_dims[0], filter_image_dims[1], 4});
filter_image_h_ = filter_image_dims[1];
filter_image_w_ = filter_image_dims[0];
tensor_hold_filter_image_->Resize({1, filter_image_w_, filter_image_h_, 4});
half_t* filter_image_data =
tensor_hold_filter_image_->mutable_data<half_t>();
converter.NCHWToImage(filter_cpu, filter_image_data, filter_dims);
filter_gpu_image_->mutable_data<half_t, cl::Image2D>(
filter_image_dims[0], filter_image_dims[1], filter_image_data);
filter_image_w_, filter_image_h_, filter_image_data);
impl_ = &ConvImageCompute::Conv2d5x5;
#else
// conv2d_5x5_opt
kernel_func_names_.push_back(bs > 1 ? "conv2d_5x5_multi_batch"
kernel_func_names_.push_back(input_tensor_n_ > 1 ? "conv2d_5x5_multi_batch"
: "conv2d_5x5_opt");
kernel_func_paths_.push_back("image/conv2d_5x5_opt_kernel.cl");
CLImageConverterFolder converter;
const DDim& filter_image_dims = converter.InitImageDimInfoWith(filter_dims);
tensor_hold_filter_image_->Resize(
{1, filter_image_dims[0], filter_image_dims[1], 4});
filter_image_h_ = filter_image_dims[1];
filter_image_w_ = filter_image_dims[0];
tensor_hold_filter_image_->Resize({1, filter_image_w_, filter_image_h_, 4});
half_t* filter_image_data =
tensor_hold_filter_image_->mutable_data<half_t>();
converter.NCHWToImage(filter_cpu, filter_image_data, filter_dims);
filter_gpu_image_->mutable_data<half_t, cl::Image2D>(
filter_image_dims[0], filter_image_dims[1], filter_image_data);
filter_image_w_, filter_image_h_, filter_image_data);
impl_ = &ConvImageCompute::Conv2d5x5opt;
{
int w_blk_size = 5;
int w_blk = (default_w_blk_ + w_blk_size - 1) / w_blk_size;
int h_blk_size = 1;
int h_blk = (default_nh_blk_ + h_blk_size - 1) / h_blk_size;
c_blk_ = default_c_blk_;
w_blk_ = w_blk;
nh_blk_ = h_blk;
global_work_size_ = cl::NDRange{static_cast<size_t>(c_blk_),
static_cast<size_t>(w_blk_),
static_cast<size_t>(nh_blk_)};
}
#endif
#undef CONV_5x5_OPT
} else if (kernel_h == 7 && kernel_w == 7) {
} else if (filter_tensor_h_ == 7 && filter_tensor_w_ == 7) {
#define CONV_7x7_OPT
#ifndef CONV_7x7_OPT
// conv2d_7x7
......@@ -345,55 +272,41 @@ void ConvImageCompute::PrepareForRun() {
CLImageConverterFolder converter;
const DDim& filter_image_dims = converter.InitImageDimInfoWith(filter_dims);
tensor_hold_filter_image_->Resize(
{1, filter_image_dims[0], filter_image_dims[1], 4});
filter_image_h_ = filter_image_dims[1];
filter_image_w_ = filter_image_dims[0];
tensor_hold_filter_image_->Resize({1, filter_image_w_, filter_image_h_, 4});
half_t* filter_image_data =
tensor_hold_filter_image_->mutable_data<half_t>();
converter.NCHWToImage(filter_cpu, filter_image_data, filter_dims);
filter_gpu_image_->mutable_data<half_t, cl::Image2D>(
filter_image_dims[0], filter_image_dims[1], filter_image_data);
filter_image_w_, filter_image_h_, filter_image_data);
impl_ = &ConvImageCompute::Conv2d7x7;
#else
// conv2d_7x7
kernel_func_names_.push_back(bs > 1 ? "conv2d_7x7_multi_batch"
kernel_func_names_.push_back(input_tensor_n_ > 1 ? "conv2d_7x7_multi_batch"
: "conv2d_7x7_opt");
kernel_func_paths_.push_back("image/conv2d_7x7_opt_kernel.cl");
CLImageConverterFolder converter;
const DDim& filter_image_dims = converter.InitImageDimInfoWith(filter_dims);
tensor_hold_filter_image_->Resize(
{1, filter_image_dims[0], filter_image_dims[1], 4});
filter_image_h_ = filter_image_dims[1];
filter_image_w_ = filter_image_dims[0];
tensor_hold_filter_image_->Resize({1, filter_image_w_, filter_image_h_, 4});
half_t* filter_image_data =
tensor_hold_filter_image_->mutable_data<half_t>();
converter.NCHWToImage(filter_cpu, filter_image_data, filter_dims);
filter_gpu_image_->mutable_data<half_t, cl::Image2D>(
filter_image_dims[0], filter_image_dims[1], filter_image_data);
filter_image_w_, filter_image_h_, filter_image_data);
impl_ = &ConvImageCompute::Conv2d7x7opt;
{
int w_blk_size = 5;
int w_blk = (default_w_blk_ + w_blk_size - 1) / w_blk_size;
int h_blk_size = 1;
int h_blk = (default_nh_blk_ + h_blk_size - 1) / h_blk_size;
c_blk_ = default_c_blk_;
w_blk_ = w_blk;
nh_blk_ = h_blk;
global_work_size_ = cl::NDRange{static_cast<size_t>(c_blk_),
static_cast<size_t>(w_blk_),
static_cast<size_t>(nh_blk_)};
}
#endif
#undef CONV_7x7_OPT
} else {
LOG(FATAL) << "conv image compute not support this condition yet! ";
}
......@@ -403,30 +316,30 @@ void ConvImageCompute::PrepareForRun() {
// build options
std::string build_options_single(" -DCL_DTYPE_half");
// relu options
VLOG(3) << "relu_fused:" << relu_fused
<< " param.activation_param.active_type:"
<< static_cast<int>(param.activation_param.active_type)
<< " param.activation_param.has_active:"
<< param.activation_param.has_active;
if (param.activation_param.has_active) {
if (param.activation_param.active_type ==
lite_api::ActivationType::kRelu) { // Note: judge using `relu_fused`
VLOG(3) << "relu_fused_:" << relu_fused_
<< " conv_param_->activation_param.active_type:"
<< static_cast<int>(conv_param_->activation_param.active_type)
<< " conv_param_->activation_param.has_active:"
<< conv_param_->activation_param.has_active;
if (conv_param_->activation_param.has_active) {
if (conv_param_->activation_param.active_type ==
lite_api::ActivationType::kRelu) { // Note: judge using `relu_fused_`
// also is ok
build_options_single += " -DRELU";
} else if (param.activation_param.active_type ==
} else if (conv_param_->activation_param.active_type ==
lite_api::ActivationType::kRelu6) {
build_options_single += " -DRELU6";
} else {
LOG(FATAL) << "Unsupported activation type:"
<< static_cast<int>(param.activation_param.active_type);
<< static_cast<int>(conv_param_->activation_param.active_type);
}
}
GetGlobalWorkSize();
// bias options
const bool has_bias = param.bias != nullptr;
const bool is_element_wise_bias =
has_bias && param.output->dims() == param.bias->dims();
if (has_bias) {
has_bias_ && conv_param_->output->dims() == conv_param_->bias->dims();
if (has_bias_) {
bias_gpu_image_ = std::unique_ptr<Tensor>(new Tensor);
build_options_single +=
is_element_wise_bias ? " -DBIASE_ELE" : " -DBIASE_CH";
......@@ -434,21 +347,36 @@ void ConvImageCompute::PrepareForRun() {
// convert cpu buffer bias --> gpu image
CLImageConverterFolder bias_converter;
const DDim& bias_image_dims =
bias_converter.InitImageDimInfoWith(param.bias->dims());
bias_converter.InitImageDimInfoWith(conv_param_->bias->dims());
bias_image_h_ = bias_image_dims[1];
bias_image_w_ = bias_image_dims[0];
tensor_hold_bias_image_->Resize(
{1, bias_image_dims[0], bias_image_dims[1], 4});
half_t* bias_image_data = tensor_hold_bias_image_->mutable_data<half_t>();
float* bias_cpu_data = param.bias->mutable_data<float>();
float* bias_cpu_data = conv_param_->bias->mutable_data<float>();
bias_converter.NCHWToImage(
bias_cpu_data, bias_image_data, param.bias->dims());
bias_cpu_data, bias_image_data, conv_param_->bias->dims());
this->bias_gpu_image_->mutable_data<half_t, cl::Image2D>(
bias_image_dims[0], bias_image_dims[1], bias_image_data);
// convert cpu buffer bias --> gpu image --- end ----
} else {
bias_gpu_image_ = std::unique_ptr<Tensor>(new Tensor);
CLImageConverterFolder bias_converter;
tensor_hold_bias_image_->Resize({1, 1, 1, 4});
half_t* bias_image_data = tensor_hold_bias_image_->mutable_data<half_t>();
this->bias_gpu_image_->mutable_data<half_t, cl::Image2D>(
1, 1, bias_image_data);
}
// define image pointer for filter, bias
input_image_p_ = conv_param_->x->data<half_t, cl::Image2D>();
filter_image_p_ = filter_gpu_image_->data<half_t, cl::Image2D>();
bias_image_p_ = bias_gpu_image_->data<half_t, cl::Image2D>();
output_image_p_ = conv_param_->output->mutable_data<half_t, cl::Image2D>(
output_image_w_, output_image_h_);
build_options_.push_back(build_options_single);
for (size_t i = 0; i < kernel_func_names_.size(); i++) {
......@@ -474,453 +402,378 @@ void ConvImageCompute::PrepareForRun() {
VLOG(4) << "max_work_group_size: " << max_work_group_size;
if (max_work_group_size > 0 && use_lws_) {
double min_turn_time = DBL_MAX;
double min_tune_time = DBL_MAX;
cl::NDRange best_local_work_size = context.cl_context()->LocalWorkSize(
global_work_size_, max_work_group_size);
VLOG(3) << "origin :local_work_size_ : " << best_local_work_size[0] << " "
<< best_local_work_size[1] << " " << best_local_work_size[2];
cl::NDRange last_local_work_size = cl::NDRange{
static_cast<size_t>(0), static_cast<size_t>(0), static_cast<size_t>(0)};
if (use_turn_) {
if (use_tune_) {
for (size_t i = 1; i < 15; i++) {
if (kernel_h == 1 && kernel_w == 1) {
if (filter_tensor_h_ == 1 && filter_tensor_w_ == 1) {
// todo use diff logics
local_work_size_ = context.cl_context()->LocalWorkSizeTurn(
local_work_size_ = context.cl_context()->LocalWorkSizeTune(
global_work_size_, max_work_group_size, i);
} else {
local_work_size_ = context.cl_context()->LocalWorkSizeTurn(
local_work_size_ = context.cl_context()->LocalWorkSizeTune(
global_work_size_, max_work_group_size, i);
}
if (last_local_work_size[0] == local_work_size_[0] &&
last_local_work_size[1] == local_work_size_[1] &&
last_local_work_size[2] == local_work_size_[2]) {
// skiped turned lws
// skiped tuneed lws
continue;
}
auto turn_time = this->Turn(5);
if (min_turn_time > turn_time) {
min_turn_time = turn_time;
auto tune_time = this->Tune(10);
if (min_tune_time > tune_time) {
min_tune_time = tune_time;
best_local_work_size = local_work_size_;
}
last_local_work_size = local_work_size_;
}
// reverse
for (size_t i = 1; i < 15; i++) {
if (filter_tensor_h_ == 1 && filter_tensor_w_ == 1) {
// todo use diff logics
local_work_size_ = context.cl_context()->LocalWorkSizeTuneReverse(
global_work_size_, max_work_group_size, i);
} else {
local_work_size_ = context.cl_context()->LocalWorkSizeTuneReverse(
global_work_size_, max_work_group_size, i);
}
if (last_local_work_size[0] == local_work_size_[0] &&
last_local_work_size[1] == local_work_size_[1] &&
last_local_work_size[2] == local_work_size_[2]) {
// skiped tuneed lws
continue;
}
auto tune_time = this->Tune(10);
if (min_tune_time > tune_time) {
min_tune_time = tune_time;
best_local_work_size = local_work_size_;
}
last_local_work_size = local_work_size_;
}
}
local_work_size_ = best_local_work_size;
VLOG(3) << "chossen :local_work_size_ : " << local_work_size_[0] << " "
<< local_work_size_[1] << " " << local_work_size_[2];
VLOG(4) << "local_work_size_[3D]: {" << local_work_size_[0] << ","
<< local_work_size_[1] << "," << local_work_size_[2] << "}";
}
}
void ConvImageCompute::Conv2d1x1opt(bool is_turn) {
// Refreshes the shape-dependent state cached on this kernel.
// On the first run, or whenever the input dims differ from last_input_dims_,
// it re-reads the input/output tensor dims, recomputes the matching image
// shapes, re-derives groups_ for the "conv2d_3x3" kernel, re-fetches the
// input/output image pointers and finally calls GetGlobalWorkSize().
// NOTE(review): this span looks like a rendered diff in which lines of the
// previous Conv2d1x1opt body are interleaved with the new function (for
// example `output_dims` is declared twice in the same scope) -- confirm
// against the real file before relying on the exact statement order.
void ConvImageCompute::ReInitWhenNeeded() {
conv_param_ = param_.get_mutable<param_t>();
auto x_dims = conv_param_->x->dims();
#ifdef LITE_WITH_LOG
LOG(INFO) << "is_first_epoch_for_run_:" << is_first_epoch_for_run_
<< ", last_input_dims_:" << last_input_dims_
<< ", x_dims:" << x_dims;
#endif
// Only redo the work on the first run or when the input shape has changed.
if (is_first_epoch_for_run_ || last_input_dims_ != x_dims) {
is_first_epoch_for_run_ = false;
last_input_dims_ = x_dims;
// Cache the input tensor dims; indices 0..3 are stored as n, c, h, w.
input_tensor_n_ = x_dims[0];
input_tensor_c_ = x_dims[1];
input_tensor_h_ = x_dims[2];
input_tensor_w_ = x_dims[3];
// Image shape of the input as reported by InitImageDimInfoWith().
auto x_image_shape = InitImageDimInfoWith(x_dims);
input_image_h_ = x_image_shape["height"];
input_image_w_ = x_image_shape["width"];
// Same bookkeeping for the output tensor and its image shape.
auto output_dims = conv_param_->output->dims();
output_tensor_n_ = output_dims[0];
output_tensor_c_ = output_dims[1];
output_tensor_h_ = output_dims[2];
output_tensor_w_ = output_dims[3];
auto output_image_shape = InitImageDimInfoWith(output_dims);
output_image_h_ = output_image_shape["height"];
output_image_w_ = output_image_shape["width"];
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
// NOTE(review): the local-variable setup below (param, input_dims, paddings,
// ..., offset) appears to be left over from the old Conv2d1x1opt body -- TODO
// confirm.
const auto& param = *param_.get_mutable<param_t>();
auto input_dims = param.x->dims();
auto paddings = *param.paddings;
auto strides = param.strides;
auto* input_image = param.x->data<half_t, cl::Image2D>();
auto* filter_image = filter_gpu_image_->data<half_t, cl::Image2D>();
auto filter_dims = param.filter->dims();
auto output_dims = param.output->dims();
int input_width = input_dims[3];
int input_height = input_dims[2];
int output_width = output_dims[3];
int output_height = output_dims[2];
auto out_image_shape = InitImageDimInfoWith(output_dims);
auto* out_image = param.output->mutable_data<half_t, cl::Image2D>(
out_image_shape["width"], out_image_shape["height"]);
const bool has_bias = param.bias != nullptr;
const bool is_element_wise_bias =
has_bias && param.output->dims() == param.bias->dims();
int offset = static_cast<int>(param.filter->dims()[2]) / 2 -
static_cast<int>(paddings[0]);
// Both input and output dims must have at least four entries.
CHECK_GE(conv_param_->x->dims().size(), 4);
CHECK_GE(conv_param_->output->dims().size(), 4);
// groups_ is only re-derived when the first kernel is the generic conv2d_3x3:
//   - 1 when filter n == output c and filter c == input c;
//   - left at conv_param_->groups when filter n == input c and filter c == 1;
//   - input c / filter c otherwise.
if (kernel_func_names_.size() > 0 &&
kernel_func_names_[0] == "conv2d_3x3") {
groups_ = conv_param_->groups;
if (filter_tensor_n_ == output_tensor_c_ &&
filter_tensor_c_ == input_tensor_c_) {
groups_ = 1;
} else if (!(filter_tensor_n_ == input_tensor_c_ &&
filter_tensor_c_ == 1)) {
groups_ = input_tensor_c_ / filter_tensor_c_;
}
}
// calc input_c_block
// NOTE(review): the next four statements reference `input_dims`/`param` and
// so presumably also belong to the old revision -- TODO confirm.
auto input_image_shape = InitImageDimInfoWith(input_dims);
int input_c_block = input_image_shape["width"] / input_dims[3];
int input_c = input_dims[1];
auto dilations = *param.dilations;
// define image pointer for input, output
input_image_p_ = conv_param_->x->data<half_t, cl::Image2D>();
output_image_p_ = conv_param_->output->mutable_data<half_t, cl::Image2D>(
output_image_w_, output_image_h_);
// NOTE(review): the logging, CHECKs and bias lookup from here down to the
// GetGlobalWorkSize() call use the old local names and look like further
// old-revision lines -- TODO confirm.
#ifdef LITE_WITH_LOG
// VLOG(4) << "out_image: " << out_image;
VLOG(4) << "global_work_size_[3D]: {" << global_work_size_[0] << ","
<< global_work_size_[1] << "," << global_work_size_[2] << "}";
#endif
#ifdef LITE_WITH_LOG
VLOG(4) << "============ conv2d_1x1 params ============";
VLOG(4) << "input_image_shape: " << input_image_shape["width"] << ","
<< input_image_shape["height"];
VLOG(4) << "input_c_block: " << input_c_block;
VLOG(4) << "input_c: " << input_c;
// VLOG(4) << "input_image: " << input_image;
VLOG(4) << "filter_dims: " << filter_dims;
// VLOG(4) << "filter_image: " << filter_image;
VLOG(4) << "output_dims: " << output_dims;
VLOG(4) << "out_image_shape: " << out_image_shape["width"] << ", "
<< out_image_shape["height"];
VLOG(4) << "paddings: " << paddings[0] << "," << paddings[1];
VLOG(4) << "has bias: " << has_bias;
VLOG(4) << "is_element_wise_bias : " << is_element_wise_bias;
VLOG(4) << "strides: " << strides[0] << "," << strides[1];
VLOG(4) << "offset: " << offset;
VLOG(4) << "dilations.size : " << dilations.size();
VLOG(4) << "dilations: " << dilations[0] << ", " << dilations[1];
// VLOG(4) << "default work size{c_block, w, nh}: "
// << "{" << c_block << ", " << w << ", " << nh << ""
// << "}";
#endif
CHECK_GE(dilations.size(), 2);
CHECK(dilations[0] == dilations[1]);
CHECK_GE(input_dims.size(), 4);
CHECK_GE(paddings.size(), 2);
CHECK(paddings[0] == paddings[1]);
CHECK_GE(strides.size(), 2);
CHECK(strides[0] == strides[1]);
// handle bias use buffer for channel wise , use image for element wise
const cl::Buffer* bias_buf = nullptr;
const cl::Image2D* bias_image = nullptr;
if (has_bias) {
bias_image = bias_gpu_image_->data<half_t, cl::Image2D>();
// The work sizes depend on the dims cached above, so recompute them last.
GetGlobalWorkSize();
}
}
// Computes input_c_block_, the c/w/nh block counts and global_work_size_.
// It first derives defaults from DefaultWorkSize() on the output dims and
// output image shape, then applies a per-kernel adjustment selected by
// kernel_func_names_[0]. Returns without doing anything when no kernel name
// has been registered.
// NOTE(review): this span looks like a rendered diff; the `kernel.setArg(...)`
// sequences inside it use locals that are not declared here and appear to be
// lines of an older kernel-launch body -- confirm against the real file.
void ConvImageCompute::GetGlobalWorkSize() {
if (kernel_func_names_.size() <= 0) return;
// general input_c_block
input_c_block_ = static_cast<int>(input_image_w_ / input_tensor_w_);
// general gws
auto output_dims = conv_param_->output->dims();
const std::vector<size_t>& default_work_size =
DefaultWorkSize(output_dims,
DDim(std::vector<DDim::value_type>{
static_cast<int64_t>(output_image_w_),
static_cast<int64_t>(output_image_h_)}));
// default_work_size is read as {c block, w block, nh block}.
default_c_blk_ = default_work_size[0];
default_w_blk_ = default_work_size[1];
default_nh_blk_ = default_work_size[2];
// Start from the defaults; the branches below override them per kernel.
c_blk_ = default_c_blk_;
w_blk_ = default_w_blk_;
nh_blk_ = default_nh_blk_;
global_work_size_ = cl::NDRange{static_cast<size_t>(c_blk_),
static_cast<size_t>(w_blk_),
static_cast<size_t>(nh_blk_)};
// 1x1 kernels: only the width block count changes, via maptofactor(..., 4).
if (kernel_func_names_[0] == "conv2d_1x1_simple" ||
kernel_func_names_[0] == "conv2d_1x1_opt") {
w_blk_ = maptofactor(default_w_blk_, 4);
c_blk_ = default_c_blk_;
nh_blk_ = default_nh_blk_;
global_work_size_ = cl::NDRange{static_cast<size_t>(c_blk_),
static_cast<size_t>(w_blk_),
static_cast<size_t>(nh_blk_)};
} else if (kernel_func_names_[0] == "depth_conv2d_3x3s1") {
// depthwise spl gws s1
// c: ceil(output c / 4); w: ceil(output w / 2); nh: output n * output h.
int c_block = (output_tensor_c_ + 3) / 4;
int w = output_tensor_w_;
int nh = output_tensor_n_ * output_tensor_h_;
int w_blk_size = 2;
int w_blk = (w + w_blk_size - 1) / w_blk_size;
c_blk_ = c_block;
w_blk_ = w_blk;
nh_blk_ = nh;
global_work_size_ = cl::NDRange{static_cast<size_t>(c_blk_),
static_cast<size_t>(w_blk_),
static_cast<size_t>(nh_blk_)};
} else if (kernel_func_names_[0] == "depth_conv2d_3x3") {
// depthwise spl gws
// c: ceil(output c / 4); w: output w; nh: output n * output h.
int c_block = (output_tensor_c_ + 3) / 4;
int w = output_tensor_w_;
int nh = output_tensor_n_ * output_tensor_h_;
c_blk_ = c_block;
w_blk_ = w;
nh_blk_ = nh;
global_work_size_ = cl::NDRange{static_cast<size_t>(c_blk_),
static_cast<size_t>(w_blk_),
static_cast<size_t>(nh_blk_)};
// This kernel also replaces the general input_c_block_ with
// ceil(input c / 4).
input_c_block_ = static_cast<const int>((input_tensor_c_ + 3) / 4);
} else if (kernel_func_names_[0] == "conv2d_3x3_multi_batch" ||
kernel_func_names_[0] == "conv2d_3x3_opt") {
// Width is split into blocks of 5 (rounded up); the nh block size is 1,
// so nh keeps its default count.
int w_blk_size = 5;
int w_blk = (default_w_blk_ + w_blk_size - 1) / w_blk_size;
int h_blk_size = 1;
int h_blk = (default_nh_blk_ + h_blk_size - 1) / h_blk_size;
// NOTE(review): the setArg sequence below (through the `if (has_bias)`
// block) appears to be left over from an older launch function -- TODO
// confirm.
auto kernel = kernel_;
cl_int status;
int arg_idx = 0;
status = kernel.setArg(arg_idx, c_blk_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, w_blk_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, nh_blk_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *input_image);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *filter_image);
CL_CHECK_FATAL(status);
if (has_bias) {
status = kernel.setArg(++arg_idx, *bias_image);
CL_CHECK_FATAL(status);
c_blk_ = default_c_blk_;
w_blk_ = w_blk;
nh_blk_ = h_blk;
global_work_size_ = cl::NDRange{static_cast<size_t>(c_blk_),
static_cast<size_t>(w_blk_),
static_cast<size_t>(nh_blk_)};
} else if (kernel_func_names_[0] == "conv2d_5x5_multi_batch" ||
kernel_func_names_[0] == "conv2d_5x5_opt") {
// Same blocking as the 3x3 opt kernels: width in blocks of 5, nh block 1.
int w_blk_size = 5;
int w_blk = (default_w_blk_ + w_blk_size - 1) / w_blk_size;
int h_blk_size = 1;
int h_blk = (default_nh_blk_ + h_blk_size - 1) / h_blk_size;
c_blk_ = default_c_blk_;
w_blk_ = w_blk;
nh_blk_ = h_blk;
global_work_size_ = cl::NDRange{static_cast<size_t>(c_blk_),
static_cast<size_t>(w_blk_),
static_cast<size_t>(nh_blk_)};
} else if (kernel_func_names_[0] == "conv2d_7x7_multi_batch" ||
kernel_func_names_[0] == "conv2d_7x7_opt") {
// Same blocking as the 3x3 opt kernels: width in blocks of 5, nh block 1.
int w_blk_size = 5;
int w_blk = (default_w_blk_ + w_blk_size - 1) / w_blk_size;
int h_blk_size = 1;
int h_blk = (default_nh_blk_ + h_blk_size - 1) / h_blk_size;
c_blk_ = default_c_blk_;
w_blk_ = w_blk;
nh_blk_ = h_blk;
global_work_size_ = cl::NDRange{static_cast<size_t>(c_blk_),
static_cast<size_t>(w_blk_),
static_cast<size_t>(nh_blk_)};
}
// NOTE(review): the remaining setArg/enqueue lines reference locals that are
// not declared in this function and appear to be old-revision lines -- TODO
// confirm.
status = kernel.setArg(++arg_idx, *out_image);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, strides[0]);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, offset);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_c_block);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_c);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, dilations[0]);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_width);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_height);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, output_width);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, output_height);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, default_w_blk_);
CL_CHECK_FATAL(status);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
}
// Launches the 1x1 convolution kernel held in kernel_.
// Binds kernel arguments 0..16 from cached members (block counts, the
// input/filter/bias/output image pointers, stride, offset, channel-block and
// tensor sizes), then enqueues the kernel with global_work_size_ and
// local_work_size_. Every OpenCL status is checked with CL_CHECK_FATAL.
//
// enable_tune: when true, finish() is called on the global command queue
//   after the enqueue -- presumably so a tuning loop can time the kernel;
//   verify against the caller.
void ConvImageCompute::Conv2d1x1opt(bool enable_tune) {
#ifdef LITE_WITH_LOG
PrintConvInfo();
#endif
auto& context = ctx_->As<OpenCLContext>();
// Args 0-2: c / w / nh block counts.
status_ = kernel_.setArg(0, c_blk_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(1, w_blk_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(2, nh_blk_);
CL_CHECK_FATAL(status_);
// Args 3-6: input, filter, bias and output images. The bias image is bound
// unconditionally here.
status_ = kernel_.setArg(3, *input_image_p_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(4, *filter_image_p_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(5, *bias_image_p_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(6, *output_image_p_);
CL_CHECK_FATAL(status_);
// Args 7-16: scalar parameters; only the height stride/dilation members are
// passed.
status_ = kernel_.setArg(7, stride_h_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(8, offset_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(9, input_c_block_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(10, input_tensor_c_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(11, dilation_h_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(12, input_tensor_w_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(13, input_tensor_h_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(14, output_tensor_w_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(15, output_tensor_h_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(16, default_w_blk_);
CL_CHECK_FATAL(status_);
// NOTE(review): this span looks like a rendered diff; within the call below,
// the second `nullptr` argument, `CL_CHECK_FATAL(status)` and
// `if (is_turn) {` appear to be lines of the previous revision interleaved
// with the new ones -- TODO confirm against the real file.
status_ = EnqueueNDRangeKernel(context,
kernel_,
cl::NullRange,
global_work_size_,
local_work_size_,
nullptr,
nullptr);
CL_CHECK_FATAL(status);
if (is_turn) {
event_);
CL_CHECK_FATAL(status_);
if (enable_tune) {
CLRuntime::Global()->command_queue().finish();
}
}
void ConvImageCompute::Conv2d3x3(bool is_turn) {
auto kernel = kernel_;
const auto& param = *param_.get_mutable<param_t>();
auto input_dims = param.x->dims();
auto paddings = *param.paddings;
auto strides = param.strides;
auto* input_image = param.x->data<half_t, cl::Image2D>();
auto* filter_image = filter_gpu_image_->data<half_t, cl::Image2D>();
auto filter_dims = param.filter->dims();
auto output_dims = param.output->dims();
int input_width = input_dims[3];
int input_height = input_dims[2];
int input_channel = input_dims[1];
int output_width = output_dims[3];
int output_height = output_dims[2];
int output_channel = output_dims[1];
int filter_width = filter_dims[3];
int filter_height = filter_dims[2];
int filter_channel = filter_dims[1];
auto out_image_shape = InitImageDimInfoWith(output_dims);
auto* out_image = param.output->mutable_data<half_t, cl::Image2D>(
out_image_shape["width"], out_image_shape["height"]);
const bool has_bias = param.bias != nullptr;
const bool is_element_wise_bias =
has_bias && param.output->dims() == param.bias->dims();
int offset = static_cast<int>(param.filter->dims()[2]) / 2 -
static_cast<int>(paddings[0]);
// calc input_c_block
auto input_image_shape = InitImageDimInfoWith(input_dims);
int input_c_block = input_image_shape["width"] / input_dims[3];
int input_c = input_dims[1];
auto dilations = *param.dilations;
// re-calc group
int new_groups{param.groups};
if (filter_dims[0] == output_dims[1] && filter_dims[1] == input_dims[1]) {
new_groups = 1;
} else if (!(filter_dims[0] == input_dims[1] && filter_dims[1] == 1)) {
new_groups = input_channel / filter_channel;
}
/* TODO(ysh329): mobile has no case below
else {
LOG(FATAL) << "Not support conv3x3 case with"
<< " input_dims:" << input_dims << " output_dims:" <<
output_dims
<< " filter_dims:" << filter_dims;
}
*/
// const std::vector<size_t>& default_work_size =
// DefaultWorkSize(output_dims,
// DDim(std::vector<DDim::value_type>{
// static_cast<int64_t>(out_image_shape["width"]),
// static_cast<int64_t>(out_image_shape["height"])}));
// int c_block = default_work_size[0];
// int w = default_work_size[1];
// int nh = default_work_size[2];
// VLOG(4) << "============ conv2d params ============";
// VLOG(4) << "input_image_shape: " << input_image_shape["width"] << ","
// << input_image_shape["height"];
// VLOG(4) << "input_c_block: " << input_c_block;
// VLOG(4) << "input_c: " << input_c;
// VLOG(4) << "input_image: " << input_image;
// VLOG(4) << "input_dims: " << input_dims;
// VLOG(4) << "filter_dims: " << filter_dims;
// VLOG(4) << "filter_image: " << filter_image;
// VLOG(4) << "output_dims: " << output_dims;
// VLOG(4) << "out_image_shape: " << out_image_shape["width"] << ", "
// << out_image_shape["height"];
// VLOG(4) << "paddings: " << paddings[0] << "," << paddings[1];
// VLOG(4) << "has bias: " << has_bias;
// VLOG(4) << "is_element_wise_bias : " << is_element_wise_bias;
// VLOG(4) << "strides: " << strides[0] << "," << strides[1];
// VLOG(4) << "offset: " << offset;
// VLOG(4) << "dilations.size : " << dilations.size();
// VLOG(4) << "dilations: " << dilations[0] << ", " << dilations[1];
// VLOG(4) << "param.groups(groups):" << param.groups;
// VLOG(4) << "new_groups:" << new_groups;
// VLOG(4) << "default work size{c_block, w, nh}: "
// << "{" << c_block << ", " << w << ", " << nh << ""
// << "}";
CHECK_GE(dilations.size(), 2);
CHECK(dilations[0] == dilations[1]);
CHECK_GE(input_dims.size(), 4);
CHECK_GE(paddings.size(), 2);
CHECK(paddings[0] == paddings[1]);
CHECK_GE(strides.size(), 2);
CHECK(strides[0] == strides[1]);
const cl::Image2D* bias_image = nullptr;
if (has_bias) {
bias_image = bias_gpu_image_->data<half_t, cl::Image2D>();
}
void ConvImageCompute::Conv2d3x3(bool enable_tune) {
#ifdef LITE_WITH_LOG
PrintConvInfo();
#endif
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
// STL::stringstream kernel_key;
// kernel_key << kernel_func_names_[0] << build_options_[0];
// auto kernel = context.cl_context()->GetKernel(kernel_key.str());
// VLOG(4) << "kernel_key: " << kernel_key.str();
// VLOG(4) << "kernel ready ... " << kernel_key.str();
// VLOG(4) << "w: " << w;
cl_int status;
int arg_idx = 0;
status = kernel.setArg(arg_idx, c_blk_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, w_blk_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, nh_blk_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *input_image);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *filter_image);
CL_CHECK_FATAL(status);
if (has_bias) {
VLOG(4) << "set bias_image: ";
status = kernel.setArg(++arg_idx, *bias_image);
CL_CHECK_FATAL(status);
}
status = kernel.setArg(++arg_idx, *out_image);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, strides[0]);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, offset);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_c_block);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, dilations[0]);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_width);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_height);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, output_width);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, output_height);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, output_channel);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, filter_channel);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, filter_width);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, filter_height);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, new_groups);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<int>(input_dims[1]));
CL_CHECK_FATAL(status);
// auto global_work_size =
// cl::NDRange{static_cast<size_t>(default_work_size.data()[0]),
// static_cast<size_t>(default_work_size.data()[1]),
// static_cast<size_t>(default_work_size.data()[2])};
// VLOG(4) << "out_image: " << out_image;
// VLOG(4) << "global_work_size[3D]: {" << global_work_size[0] << ","
// << global_work_size[1] << "," << global_work_size[2] << "}";
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
status_ = kernel_.setArg(0, c_blk_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(1, w_blk_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(2, nh_blk_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(3, *input_image_p_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(4, *filter_image_p_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(5, *bias_image_p_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(6, *output_image_p_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(7, stride_h_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(8, offset_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(9, input_c_block_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(10, dilation_h_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(11, input_tensor_w_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(12, input_tensor_h_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(13, output_tensor_w_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(14, output_tensor_h_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(15, output_tensor_c_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(16, filter_tensor_c_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(17, filter_tensor_w_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(18, filter_tensor_h_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(19, groups_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(20, input_tensor_c_);
CL_CHECK_FATAL(status_);
status_ = EnqueueNDRangeKernel(context,
kernel_,
cl::NullRange,
global_work_size_,
cl::NullRange,
nullptr,
nullptr);
CL_CHECK_FATAL(status);
event_);
CL_CHECK_FATAL(status_);
}
void ConvImageCompute::Conv2d3x3opt(bool is_turn) {
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
const auto& param = *param_.get_mutable<param_t>();
auto input_dims = param.x->dims();
auto paddings = *param.paddings;
auto strides = param.strides;
auto dilations = *param.dilations;
auto* input_image = param.x->data<half_t, cl::Image2D>();
auto* filter_image = filter_gpu_image_->data<half_t, cl::Image2D>();
auto filter_dims = param.filter->dims();
auto output_dims = param.output->dims();
int input_width = input_dims[3];
int input_height = input_dims[2];
int input_channel = input_dims[1];
int output_width = output_dims[3];
int output_height = output_dims[2];
int output_channel = output_dims[1];
CHECK_EQ(input_dims[0], output_dims[0]);
int batch = input_dims[0];
auto out_image_shape = InitImageDimInfoWith(output_dims);
auto* out_image = param.output->mutable_data<half_t, cl::Image2D>(
out_image_shape["width"], out_image_shape["height"]);
const bool has_bias = param.bias != nullptr;
const bool is_element_wise_bias =
has_bias && param.output->dims() == param.bias->dims();
void ConvImageCompute::Conv2d3x3opt(bool enable_tune) {
#ifdef LITE_WITH_LOG
VLOG(4) << "============ conv2d params ============";
// VLOG(4) << "input_image_shape: " << input_image_shape["width"] << ","
// << input_image_shape["height"];
// VLOG(4) << "input_image: " << input_image;
VLOG(4) << "input_dims: " << input_dims;
VLOG(4) << "filter_dims: " << filter_dims;
// VLOG(4) << "filter_image: " << filter_image;
VLOG(4) << "output_dims: " << output_dims;
VLOG(4) << "out_image_shape: " << out_image_shape["width"] << ", "
<< out_image_shape["height"];
VLOG(4) << "paddings: " << paddings[0] << "," << paddings[1];
VLOG(4) << "has bias: " << has_bias;
VLOG(4) << "is_element_wise_bias : " << is_element_wise_bias;
VLOG(4) << "strides: " << strides[0] << "," << strides[1];
VLOG(4) << "dilations.size : " << dilations.size();
VLOG(4) << "dilations: " << dilations[0] << ", " << dilations[1];
PrintConvInfo();
#endif
auto& context = ctx_->As<OpenCLContext>();
CHECK_GE(dilations.size(), 2);
CHECK(dilations[0] == dilations[1]);
CHECK_GE(input_dims.size(), 4);
CHECK_GE(paddings.size(), 2);
CHECK(paddings[0] == paddings[1]);
CHECK_GE(strides.size(), 2);
CHECK(strides[0] == strides[1]);
const cl::Image2D* bias_image = nullptr;
if (has_bias) {
bias_image = bias_gpu_image_->data<half_t, cl::Image2D>();
}
auto kernel = kernel_;
cl_int status;
int arg_idx = 0;
status = kernel.setArg(arg_idx, c_blk_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, w_blk_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, nh_blk_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *input_image);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *filter_image);
CL_CHECK_FATAL(status);
if (has_bias) {
#ifdef LITE_WITH_LOG
VLOG(4) << "set bias_image: ";
#endif
status = kernel.setArg(++arg_idx, *bias_image);
CL_CHECK_FATAL(status);
}
status = kernel.setArg(++arg_idx, *out_image);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, strides[0]);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, paddings[0]);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, dilations[0]);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, batch);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_channel);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_width);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_height);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, output_width);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, output_height);
CL_CHECK_FATAL(status);
// Bind the kernel arguments by explicit position (0..15). Each setArg result
// is stored in status_ and checked immediately with CL_CHECK_FATAL.
// Args 0-2: c_blk_, w_blk_, nh_blk_.
status_ = kernel_.setArg(0, c_blk_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(1, w_blk_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(2, nh_blk_);
CL_CHECK_FATAL(status_);
// Args 3-6: input, filter, bias and output images, passed by dereferencing
// the cached pointers.
status_ = kernel_.setArg(3, *input_image_p_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(4, *filter_image_p_);
CL_CHECK_FATAL(status_);
// NOTE(review): bias_image_p_ is dereferenced unconditionally here, whereas
// the arg_idx-based sequence above bound the bias only when has_bias was
// true -- confirm the pointer is always non-null.
status_ = kernel_.setArg(5, *bias_image_p_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(6, *output_image_p_);
CL_CHECK_FATAL(status_);
// Args 7-9: stride, left padding, dilation.
status_ = kernel_.setArg(7, stride_h_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(8, pad_left_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(9, dilation_h_);
CL_CHECK_FATAL(status_);
// Args 10-15: input tensor n/c/w/h followed by output tensor w/h.
status_ = kernel_.setArg(10, input_tensor_n_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(11, input_tensor_c_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(12, input_tensor_w_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(13, input_tensor_h_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(14, output_tensor_w_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(15, output_tensor_h_);
CL_CHECK_FATAL(status_);
#ifdef LITE_WITH_LOG
// VLOG(4) << "out_image: " << out_image;
......@@ -928,827 +781,406 @@ void ConvImageCompute::Conv2d3x3opt(bool is_turn) {
<< global_work_size_[1] << "," << global_work_size_[2] << "}";
#endif
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
status_ = EnqueueNDRangeKernel(context,
kernel_,
cl::NullRange,
global_work_size_,
local_work_size_,
nullptr,
nullptr);
CL_CHECK_FATAL(status);
if (is_turn) {
event_);
CL_CHECK_FATAL(status_);
if (enable_tune) {
CLRuntime::Global()->command_queue().finish();
}
}
void ConvImageCompute::Conv2d5x5(bool is_turn) {
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
const auto& param = *param_.get_mutable<param_t>();
auto input_dims = param.x->dims();
auto paddings = *param.paddings;
auto strides = param.strides;
auto* input_image = param.x->data<half_t, cl::Image2D>();
auto* filter_image = filter_gpu_image_->data<half_t, cl::Image2D>();
auto filter_dims = param.filter->dims();
auto output_dims = param.output->dims();
int input_width = input_dims[3];
int input_height = input_dims[2];
int output_width = output_dims[3];
int output_height = output_dims[2];
int filter_width = filter_dims[3];
int filter_height = filter_dims[2];
auto out_image_shape = InitImageDimInfoWith(output_dims);
auto* out_image = param.output->mutable_data<half_t, cl::Image2D>(
out_image_shape["width"], out_image_shape["height"]);
const bool has_bias = param.bias != nullptr;
const bool is_element_wise_bias =
has_bias && param.output->dims() == param.bias->dims();
int offset = static_cast<int>(param.filter->dims()[2]) / 2 -
static_cast<int>(paddings[0]);
// calc input_c_block
auto input_image_shape = InitImageDimInfoWith(input_dims);
int input_c_block = input_image_shape["width"] / input_dims[3];
int input_c = input_dims[1];
auto dilations = *param.dilations;
#ifdef LITE_WITH_LOG
VLOG(4) << "============ conv2d params ============";
VLOG(4) << "input_image_shape: " << input_image_shape["width"] << ","
<< input_image_shape["height"];
VLOG(4) << "input_c_block: " << input_c_block;
VLOG(4) << "input_c: " << input_c;
// VLOG(4) << "input_image: " << input_image;
VLOG(4) << "input_dims: " << input_dims;
VLOG(4) << "filter_dims: " << filter_dims;
// VLOG(4) << "filter_image: " << filter_image;
VLOG(4) << "output_dims: " << output_dims;
VLOG(4) << "out_image_shape: " << out_image_shape["width"] << ", "
<< out_image_shape["height"];
VLOG(4) << "paddings: " << paddings[0] << "," << paddings[1];
VLOG(4) << "has bias: " << has_bias;
VLOG(4) << "is_element_wise_bias : " << is_element_wise_bias;
VLOG(4) << "strides: " << strides[0] << "," << strides[1];
VLOG(4) << "offset: " << offset;
VLOG(4) << "dilations.size : " << dilations.size();
VLOG(4) << "dilations: " << dilations[0] << ", " << dilations[1];
#endif
CHECK_GE(dilations.size(), 2);
CHECK(dilations[0] == dilations[1]);
CHECK_GE(input_dims.size(), 4);
CHECK_GE(paddings.size(), 2);
CHECK(paddings[0] == paddings[1]);
CHECK_GE(strides.size(), 2);
CHECK(strides[0] == strides[1]);
const cl::Image2D* bias_image = nullptr;
if (has_bias) {
bias_image = bias_gpu_image_->data<half_t, cl::Image2D>();
}
auto kernel = kernel_;
cl_int status;
int arg_idx = 0;
status = kernel.setArg(arg_idx, c_blk_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, w_blk_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, nh_blk_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *input_image);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *filter_image);
CL_CHECK_FATAL(status);
if (has_bias) {
void ConvImageCompute::Conv2d5x5(bool enable_tune) {
#ifdef LITE_WITH_LOG
VLOG(4) << "set bias_image: ";
#endif
status = kernel.setArg(++arg_idx, *bias_image);
CL_CHECK_FATAL(status);
}
status = kernel.setArg(++arg_idx, *out_image);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, strides[0]);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, offset);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_c_block);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, dilations[0]);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_width);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_height);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, output_width);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, output_height);
CL_CHECK_FATAL(status);
#ifdef LITE_WITH_LOG
// VLOG(4) << "out_image: " << out_image;
VLOG(4) << "global_work_size_[3D]: {" << global_work_size_[0] << ","
<< global_work_size_[1] << "," << global_work_size_[2] << "}";
PrintConvInfo();
#endif
auto& context = ctx_->As<OpenCLContext>();
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
// Bind the kernel arguments by explicit position (0..14). Each setArg result
// is stored in status_ and checked immediately with CL_CHECK_FATAL.
// Args 0-2: c_blk_, w_blk_, nh_blk_.
status_ = kernel_.setArg(0, c_blk_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(1, w_blk_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(2, nh_blk_);
CL_CHECK_FATAL(status_);
// Args 3-6: input, filter, bias and output images, passed by dereferencing
// the cached pointers.
status_ = kernel_.setArg(3, *input_image_p_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(4, *filter_image_p_);
CL_CHECK_FATAL(status_);
// NOTE(review): bias_image_p_ is dereferenced unconditionally here, whereas
// the arg_idx-based sequence above bound the bias only when has_bias was
// true -- confirm the pointer is always non-null.
status_ = kernel_.setArg(5, *bias_image_p_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(6, *output_image_p_);
CL_CHECK_FATAL(status_);
// Args 7-10: stride, offset, input channel-block count, dilation.
status_ = kernel_.setArg(7, stride_h_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(8, offset_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(9, input_c_block_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(10, dilation_h_);
CL_CHECK_FATAL(status_);
// Args 11-14: input tensor w/h followed by output tensor w/h.
status_ = kernel_.setArg(11, input_tensor_w_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(12, input_tensor_h_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(13, output_tensor_w_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(14, output_tensor_h_);
CL_CHECK_FATAL(status_);
status_ = EnqueueNDRangeKernel(context,
kernel_,
cl::NullRange,
global_work_size_,
cl::NullRange,
nullptr,
nullptr);
CL_CHECK_FATAL(status);
if (is_turn) {
event_);
CL_CHECK_FATAL(status_);
if (enable_tune) {
CLRuntime::Global()->command_queue().finish();
}
}
void ConvImageCompute::Conv2d5x5opt(bool is_turn) {
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
const auto& param = *param_.get_mutable<param_t>();
auto input_dims = param.x->dims();
auto paddings = *param.paddings;
auto strides = param.strides;
auto dilations = *param.dilations;
auto* input_image = param.x->data<half_t, cl::Image2D>();
auto* filter_image = filter_gpu_image_->data<half_t, cl::Image2D>();
auto filter_dims = param.filter->dims();
auto output_dims = param.output->dims();
int input_width = input_dims[3];
int input_height = input_dims[2];
int input_channel = input_dims[1];
int output_width = output_dims[3];
int output_height = output_dims[2];
int output_channel = output_dims[1];
CHECK_EQ(input_dims[0], output_dims[0]);
int batch = input_dims[0];
auto out_image_shape = InitImageDimInfoWith(output_dims);
auto* out_image = param.output->mutable_data<half_t, cl::Image2D>(
out_image_shape["width"], out_image_shape["height"]);
const bool has_bias = param.bias != nullptr;
const bool is_element_wise_bias =
has_bias && param.output->dims() == param.bias->dims();
// default_work_size[2] = h_blk;
void ConvImageCompute::Conv2d5x5opt(bool enable_tune) {
#ifdef LITE_WITH_LOG
VLOG(4) << "============ conv2d params ============";
// VLOG(4) << "input_image_shape: " << input_image_shape["width"] << ","
// << input_image_shape["height"];
// VLOG(4) << "input_image: " << input_image;
VLOG(4) << "input_dims: " << input_dims;
VLOG(4) << "filter_dims: " << filter_dims;
// VLOG(4) << "filter_image: " << filter_image;
VLOG(4) << "output_dims: " << output_dims;
VLOG(4) << "out_image_shape: " << out_image_shape["width"] << ", "
<< out_image_shape["height"];
VLOG(4) << "paddings: " << paddings[0] << "," << paddings[1];
VLOG(4) << "has bias: " << has_bias;
VLOG(4) << "is_element_wise_bias : " << is_element_wise_bias;
VLOG(4) << "strides: " << strides[0] << "," << strides[1];
VLOG(4) << "dilations.size : " << dilations.size();
VLOG(4) << "dilations: " << dilations[0] << ", " << dilations[1];
PrintConvInfo();
#endif
CHECK_GE(dilations.size(), 2);
CHECK(dilations[0] == dilations[1]);
CHECK_GE(input_dims.size(), 4);
CHECK_GE(paddings.size(), 2);
CHECK(paddings[0] == paddings[1]);
CHECK_GE(strides.size(), 2);
CHECK(strides[0] == strides[1]);
const cl::Image2D* bias_image = nullptr;
if (has_bias) {
bias_image = bias_gpu_image_->data<half_t, cl::Image2D>();
}
auto kernel = kernel_;
cl_int status;
int arg_idx = 0;
status = kernel.setArg(arg_idx, c_blk_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, w_blk_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, nh_blk_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *input_image);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *filter_image);
CL_CHECK_FATAL(status);
if (has_bias) {
status = kernel.setArg(++arg_idx, *bias_image);
CL_CHECK_FATAL(status);
}
status = kernel.setArg(++arg_idx, *out_image);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, strides[0]);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, paddings[0]);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, dilations[0]);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, batch);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_channel);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_width);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_height);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, output_width);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, output_height);
CL_CHECK_FATAL(status);
// VLOG(4) << "out_image: " << out_image;
auto& context = ctx_->As<OpenCLContext>();
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
// Bind the kernel arguments by explicit position (0..15). Each setArg result
// is stored in status_ and checked immediately with CL_CHECK_FATAL.
// Args 0-2: c_blk_, w_blk_, nh_blk_.
status_ = kernel_.setArg(0, c_blk_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(1, w_blk_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(2, nh_blk_);
CL_CHECK_FATAL(status_);
// Args 3-6: input, filter, bias and output images, passed by dereferencing
// the cached pointers.
status_ = kernel_.setArg(3, *input_image_p_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(4, *filter_image_p_);
CL_CHECK_FATAL(status_);
// NOTE(review): bias_image_p_ is dereferenced unconditionally here, whereas
// the arg_idx-based sequence above bound the bias only when has_bias was
// true -- confirm the pointer is always non-null.
status_ = kernel_.setArg(5, *bias_image_p_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(6, *output_image_p_);
CL_CHECK_FATAL(status_);
// Args 7-9: stride, left padding, dilation.
status_ = kernel_.setArg(7, stride_h_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(8, pad_left_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(9, dilation_h_);
CL_CHECK_FATAL(status_);
// Args 10-15: input tensor n/c/w/h followed by output tensor w/h.
status_ = kernel_.setArg(10, input_tensor_n_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(11, input_tensor_c_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(12, input_tensor_w_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(13, input_tensor_h_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(14, output_tensor_w_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(15, output_tensor_h_);
CL_CHECK_FATAL(status_);
status_ = EnqueueNDRangeKernel(context,
kernel_,
cl::NullRange,
global_work_size_,
local_work_size_,
nullptr,
nullptr);
CL_CHECK_FATAL(status);
if (is_turn) {
event_);
CL_CHECK_FATAL(status_);
if (enable_tune) {
CLRuntime::Global()->command_queue().finish();
}
}
void ConvImageCompute::Conv2d7x7(bool is_turn) {
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
const auto& param = *param_.get_mutable<param_t>();
auto input_dims = param.x->dims();
auto paddings = *param.paddings;
auto strides = param.strides;
auto* input_image = param.x->data<half_t, cl::Image2D>();
auto* filter_image = filter_gpu_image_->data<half_t, cl::Image2D>();
auto filter_dims = param.filter->dims();
auto output_dims = param.output->dims();
int input_width = input_dims[3];
int input_height = input_dims[2];
int output_width = output_dims[3];
int output_height = output_dims[2];
int filter_width = filter_dims[3];
int filter_height = filter_dims[2];
auto out_image_shape = InitImageDimInfoWith(output_dims);
auto* out_image = param.output->mutable_data<half_t, cl::Image2D>(
out_image_shape["width"], out_image_shape["height"]);
const bool has_bias = param.bias != nullptr;
const bool is_element_wise_bias =
has_bias && param.output->dims() == param.bias->dims();
int offset = static_cast<int>(param.filter->dims()[2]) / 2 -
static_cast<int>(paddings[0]);
// calc input_c_block
auto input_image_shape = InitImageDimInfoWith(input_dims);
int input_c_block = input_image_shape["width"] / input_dims[3];
int input_c = input_dims[1];
auto dilations = *param.dilations;
void ConvImageCompute::Conv2d7x7(bool enable_tune) {
#ifdef LITE_WITH_LOG
VLOG(4) << "============ conv2d params ============";
VLOG(4) << "input_image_shape: " << input_image_shape["width"] << ","
<< input_image_shape["height"];
VLOG(4) << "input_c_block: " << input_c_block;
VLOG(4) << "input_c: " << input_c;
// VLOG(4) << "input_image: " << input_image;
VLOG(4) << "input_dims: " << input_dims;
VLOG(4) << "filter_dims: " << filter_dims;
// VLOG(4) << "filter_image: " << filter_image;
VLOG(4) << "output_dims: " << output_dims;
VLOG(4) << "out_image_shape: " << out_image_shape["width"] << ", "
<< out_image_shape["height"];
VLOG(4) << "paddings: " << paddings[0] << "," << paddings[1];
VLOG(4) << "has bias: " << has_bias;
VLOG(4) << "is_element_wise_bias : " << is_element_wise_bias;
VLOG(4) << "strides: " << strides[0] << "," << strides[1];
VLOG(4) << "offset: " << offset;
VLOG(4) << "dilations.size : " << dilations.size();
VLOG(4) << "dilations: " << dilations[0] << ", " << dilations[1];
#endif
CHECK_GE(dilations.size(), 2);
CHECK(dilations[0] == dilations[1]);
CHECK_GE(input_dims.size(), 4);
CHECK_GE(paddings.size(), 2);
CHECK(paddings[0] == paddings[1]);
CHECK_GE(strides.size(), 2);
CHECK(strides[0] == strides[1]);
const cl::Image2D* bias_image = nullptr;
if (has_bias) {
bias_image = bias_gpu_image_->data<half_t, cl::Image2D>();
}
auto kernel = kernel_;
cl_int status;
int arg_idx = 0;
status = kernel.setArg(arg_idx, c_blk_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, w_blk_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, nh_blk_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *input_image);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *filter_image);
CL_CHECK_FATAL(status);
if (has_bias) {
#ifdef LITE_WITH_LOG
VLOG(4) << "set bias_image: ";
#endif
status = kernel.setArg(++arg_idx, *bias_image);
CL_CHECK_FATAL(status);
}
status = kernel.setArg(++arg_idx, *out_image);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, strides[0]);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, offset);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_c_block);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, dilations[0]);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_width);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_height);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, output_width);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, output_height);
CL_CHECK_FATAL(status);
#ifdef LITE_WITH_LOG
// VLOG(4) << "out_image: " << out_image;
VLOG(4) << "global_work_size_[3D]: {" << global_work_size_[0] << ","
<< global_work_size_[1] << "," << global_work_size_[2] << "}";
PrintConvInfo();
#endif
auto& context = ctx_->As<OpenCLContext>();
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
// Bind the conv2d 7x7 kernel arguments by explicit position (0..14). Each
// setArg result is stored in status_ and checked immediately with
// CL_CHECK_FATAL. The order mirrors the arg_idx-based sequence above and the
// positional sequence used by Conv2d5x5.
// Args 0-2: c_blk_, w_blk_, nh_blk_.
status_ = kernel_.setArg(0, c_blk_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(1, w_blk_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(2, nh_blk_);
CL_CHECK_FATAL(status_);
// Args 3-6: input, filter, bias and output images, passed by dereferencing
// the cached pointers.
status_ = kernel_.setArg(3, *input_image_p_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(4, *filter_image_p_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(5, *bias_image_p_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(6, *output_image_p_);
CL_CHECK_FATAL(status_);
// Args 7-10: stride, offset, input channel-block count, dilation.
status_ = kernel_.setArg(7, stride_h_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(8, offset_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(9, input_c_block_);
CL_CHECK_FATAL(status_);
// Fix: dilation_h_ used to be bound at index 9 as well, which overwrote
// input_c_block_, shifted every following argument down by one slot and left
// index 14 unset. The indices below restore one slot per argument.
status_ = kernel_.setArg(10, dilation_h_);
CL_CHECK_FATAL(status_);
// Args 11-14: input tensor w/h followed by output tensor w/h.
status_ = kernel_.setArg(11, input_tensor_w_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(12, input_tensor_h_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(13, output_tensor_w_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(14, output_tensor_h_);
CL_CHECK_FATAL(status_);
status_ = EnqueueNDRangeKernel(context,
kernel_,
cl::NullRange,
global_work_size_,
cl::NullRange,
nullptr,
nullptr);
CL_CHECK_FATAL(status);
if (is_turn) {
event_);
CL_CHECK_FATAL(status_);
if (enable_tune) {
CLRuntime::Global()->command_queue().finish();
}
}
void ConvImageCompute::Conv2d7x7opt(bool is_turn) {
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
const auto& param = *param_.get_mutable<param_t>();
auto input_dims = param.x->dims();
auto paddings = *param.paddings;
auto strides = param.strides;
auto dilations = *param.dilations;
auto* input_image = param.x->data<half_t, cl::Image2D>();
auto* filter_image = filter_gpu_image_->data<half_t, cl::Image2D>();
auto filter_dims = param.filter->dims();
auto output_dims = param.output->dims();
int input_width = input_dims[3];
int input_height = input_dims[2];
int input_channel = input_dims[1];
int output_width = output_dims[3];
int output_height = output_dims[2];
int output_channel = output_dims[1];
CHECK_EQ(input_dims[0], output_dims[0]);
int batch = input_dims[0];
auto out_image_shape = InitImageDimInfoWith(output_dims);
auto* out_image = param.output->mutable_data<half_t, cl::Image2D>(
out_image_shape["width"], out_image_shape["height"]);
const bool has_bias = param.bias != nullptr;
const bool is_element_wise_bias =
has_bias && param.output->dims() == param.bias->dims();
void ConvImageCompute::Conv2d7x7opt(bool enable_tune) {
#ifdef LITE_WITH_LOG
VLOG(4) << "============ conv2d 7x7 params ============";
// VLOG(4) << "input_image_shape: " << input_image_shape["width"] << ","
// << input_image_shape["height"];
// VLOG(4) << "input_image: " << input_image;
VLOG(4) << "input_dims: " << input_dims;
VLOG(4) << "filter_dims: " << filter_dims;
// VLOG(4) << "filter_image: " << filter_image;
VLOG(4) << "output_dims: " << output_dims;
VLOG(4) << "out_image_shape: " << out_image_shape["width"] << ", "
<< out_image_shape["height"];
VLOG(4) << "paddings: " << paddings[0] << "," << paddings[1];
VLOG(4) << "has bias: " << has_bias;
VLOG(4) << "is_element_wise_bias : " << is_element_wise_bias;
VLOG(4) << "strides: " << strides[0] << "," << strides[1];
VLOG(4) << "dilations.size : " << dilations.size();
VLOG(4) << "dilations: " << dilations[0] << ", " << dilations[1];
PrintConvInfo();
#endif
CHECK_GE(dilations.size(), 2);
CHECK(dilations[0] == dilations[1]);
CHECK_GE(input_dims.size(), 4);
CHECK_GE(paddings.size(), 2);
CHECK(paddings[0] == paddings[1]);
CHECK_GE(strides.size(), 2);
CHECK(strides[0] == strides[1]);
const cl::Image2D* bias_image = nullptr;
if (has_bias) {
bias_image = bias_gpu_image_->data<half_t, cl::Image2D>();
}
auto& context = ctx_->As<OpenCLContext>();
auto kernel = kernel_;
cl_int status;
int arg_idx = 0;
status = kernel.setArg(arg_idx, c_blk_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, w_blk_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, nh_blk_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *input_image);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *filter_image);
CL_CHECK_FATAL(status);
if (has_bias) {
status = kernel.setArg(++arg_idx, *bias_image);
CL_CHECK_FATAL(status);
}
status = kernel.setArg(++arg_idx, *out_image);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, strides[0]);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, paddings[0]);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, dilations[0]);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, batch);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_channel);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_width);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_height);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, output_width);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, output_height);
CL_CHECK_FATAL(status);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
// Bind the kernel arguments by explicit position (0..15). Each setArg result
// is stored in status_ and checked immediately with CL_CHECK_FATAL.
// Args 0-2: c_blk_, w_blk_, nh_blk_.
status_ = kernel_.setArg(0, c_blk_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(1, w_blk_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(2, nh_blk_);
CL_CHECK_FATAL(status_);
// Args 3-6: input, filter, bias and output images, passed by dereferencing
// the cached pointers.
status_ = kernel_.setArg(3, *input_image_p_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(4, *filter_image_p_);
CL_CHECK_FATAL(status_);
// NOTE(review): bias_image_p_ is dereferenced unconditionally here, whereas
// the arg_idx-based sequence above bound the bias only when has_bias was
// true -- confirm the pointer is always non-null.
status_ = kernel_.setArg(5, *bias_image_p_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(6, *output_image_p_);
CL_CHECK_FATAL(status_);
// Args 7-9: stride, left padding, dilation.
status_ = kernel_.setArg(7, stride_h_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(8, pad_left_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(9, dilation_h_);
CL_CHECK_FATAL(status_);
// Args 10-15: input tensor n/c/w/h followed by output tensor w/h.
status_ = kernel_.setArg(10, input_tensor_n_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(11, input_tensor_c_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(12, input_tensor_w_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(13, input_tensor_h_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(14, output_tensor_w_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(15, output_tensor_h_);
CL_CHECK_FATAL(status_);
status_ = EnqueueNDRangeKernel(context,
kernel_,
cl::NullRange,
global_work_size_,
local_work_size_,
nullptr,
nullptr);
CL_CHECK_FATAL(status);
event_);
CL_CHECK_FATAL(status_);
if (is_turn) {
if (enable_tune) {
CLRuntime::Global()->command_queue().finish();
}
}
void ConvImageCompute::DepthwiseConv2d3x3s1(bool is_turn) {
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
const auto& param = *param_.get_mutable<param_t>();
auto x_dims = param.x->dims();
auto filter_dims = param.filter->dims();
auto output_dims = param.output->dims();
auto paddings = *param.paddings;
auto strides = param.strides;
auto dilations = *param.dilations;
auto* input_img = param.x->data<half_t, cl::Image2D>();
auto* filter_img = filter_gpu_image_->data<half_t, cl::Image2D>();
const cl::Image2D* bias_img = nullptr;
if (param.bias) {
bias_img = bias_gpu_image_->data<half_t, cl::Image2D>();
}
auto image_shape = InitImageDimInfoWith(output_dims);
auto* output_img = param.output->mutable_data<half_t, cl::Image2D>(
image_shape["width"], image_shape["height"]);
auto kernel = kernel_;
cl_int status;
int arg_idx = 0;
status = kernel.setArg(arg_idx, c_blk_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, w_blk_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, nh_blk_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *input_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *filter_img);
CL_CHECK_FATAL(status);
const bool has_bias = param.bias != nullptr;
const bool is_element_wise_bias =
has_bias && param.output->dims() == param.bias->dims();
const cl::Image2D* bias_image = nullptr;
if (has_bias) {
bias_image = bias_gpu_image_->data<half_t, cl::Image2D>();
void ConvImageCompute::DepthwiseConv2d3x3s1(bool enable_tune) {
#ifdef LITE_WITH_LOG
VLOG(4) << "set bias_image: ";
PrintConvInfo();
#endif
status = kernel.setArg(++arg_idx, *bias_image);
CL_CHECK_FATAL(status);
}
status = kernel.setArg(++arg_idx, *output_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(strides[0]));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(paddings[0]));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(dilations[0]));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(x_dims[1]));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(x_dims[3]));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(x_dims[2]));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(output_dims[3]));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(output_dims[2]));
CL_CHECK_FATAL(status);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
auto& context = ctx_->As<OpenCLContext>();
// Bind the kernel arguments by explicit position (0..14). Each setArg result
// is stored in status_ and checked immediately with CL_CHECK_FATAL.
// Args 0-2: c_blk_, w_blk_, nh_blk_.
status_ = kernel_.setArg(0, c_blk_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(1, w_blk_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(2, nh_blk_);
CL_CHECK_FATAL(status_);
// Args 3-6: input, filter, bias and output images, passed by dereferencing
// the cached pointers.
status_ = kernel_.setArg(3, *input_image_p_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(4, *filter_image_p_);
CL_CHECK_FATAL(status_);
// NOTE(review): bias_image_p_ is dereferenced unconditionally here, whereas
// the arg_idx-based sequence above bound the bias only when has_bias was
// true -- confirm the pointer is always non-null.
status_ = kernel_.setArg(5, *bias_image_p_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(6, *output_image_p_);
CL_CHECK_FATAL(status_);
// Args 7-9: stride, left padding, dilation.
status_ = kernel_.setArg(7, stride_h_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(8, pad_left_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(9, dilation_h_);
CL_CHECK_FATAL(status_);
// Args 10-14: input tensor c/w/h followed by output tensor w/h.
status_ = kernel_.setArg(10, input_tensor_c_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(11, input_tensor_w_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(12, input_tensor_h_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(13, output_tensor_w_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(14, output_tensor_h_);
CL_CHECK_FATAL(status_);
status_ = EnqueueNDRangeKernel(context,
kernel_,
cl::NullRange,
global_work_size_,
local_work_size_,
nullptr,
nullptr);
CL_CHECK_FATAL(status);
event_);
CL_CHECK_FATAL(status_);
if (is_turn) {
if (enable_tune) {
CLRuntime::Global()->command_queue().finish();
}
}
void ConvImageCompute::DepthwiseConv2d3x3(bool is_turn) {
void ConvImageCompute::DepthwiseConv2d3x3(bool enable_tune) {
#ifdef LITE_WITH_LOG
PrintConvInfo();
#endif
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
const auto& param = *param_.get_mutable<param_t>();
auto x_dims = param.x->dims();
auto filter_dims = param.filter->dims();
auto output_dims = param.output->dims();
auto paddings = *param.paddings;
auto strides = param.strides;
auto dilations = *param.dilations;
int offset = filter_dims[2] / 2 - paddings[0];
int input_c_block = (x_dims[1] + 3) / 4;
auto* input_img = param.x->data<half_t, cl::Image2D>();
auto* filter_img = filter_gpu_image_->data<half_t, cl::Image2D>();
const cl::Image2D* bias_img = nullptr;
if (param.bias) {
bias_img = bias_gpu_image_->data<half_t, cl::Image2D>();
}
auto image_shape = InitImageDimInfoWith(output_dims);
auto* output_img = param.output->mutable_data<half_t, cl::Image2D>(
image_shape["width"], image_shape["height"]);
status_ = kernel_.setArg(0, c_blk_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(1, w_blk_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(2, nh_blk_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(3, *input_image_p_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(4, *filter_image_p_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(5, *bias_image_p_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(6, *output_image_p_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(7, stride_h_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(8, offset_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(9, dilation_h_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(10, input_c_block_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(11, input_tensor_w_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(12, input_tensor_h_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(13, output_tensor_w_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(14, output_tensor_h_);
CL_CHECK_FATAL(status_);
status_ = EnqueueNDRangeKernel(context,
kernel_,
cl::NullRange,
global_work_size_,
cl::NullRange,
nullptr,
event_);
CL_CHECK_FATAL(status_);
auto kernel = kernel_;
if (enable_tune) {
CLRuntime::Global()->command_queue().finish();
}
}
void ConvImageCompute::DepthwiseConv2d(bool enable_tune) {
#ifdef LITE_WITH_LOG
VLOG(4) << "setArg";
VLOG(4) << "strides = " << strides[0];
VLOG(4) << "offset = " << offset;
VLOG(4) << "dilations = " << dilations[0];
VLOG(4) << "input_c_block = " << input_c_block;
VLOG(4) << "x_dims[3] = " << x_dims[3];
VLOG(4) << "x_dims[2] = " << x_dims[2];
VLOG(4) << "output_dims[3] = " << output_dims[3];
VLOG(4) << "output_dims[2] = " << output_dims[2];
PrintConvInfo();
#endif
auto& context = ctx_->As<OpenCLContext>();
cl_int status;
int arg_idx = 0;
status = kernel.setArg(arg_idx, c_blk_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, w_blk_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, nh_blk_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *input_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *filter_img);
CL_CHECK_FATAL(status);
const bool has_bias = param.bias != nullptr;
const bool is_element_wise_bias =
has_bias && param.output->dims() == param.bias->dims();
const cl::Image2D* bias_image = nullptr;
if (has_bias) {
bias_image = bias_gpu_image_->data<half_t, cl::Image2D>();
#ifdef LITE_WITH_LOG
VLOG(4) << "set bias_image: ";
#endif
status = kernel.setArg(++arg_idx, *bias_image);
CL_CHECK_FATAL(status);
}
status = kernel.setArg(++arg_idx, *output_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(strides[0]));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(offset));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(dilations[0]));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(input_c_block));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(x_dims[3]));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(x_dims[2]));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(output_dims[3]));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(output_dims[2]));
CL_CHECK_FATAL(status);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
status_ = kernel_.setArg(0, c_blk_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(1, w_blk_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(2, nh_blk_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(3, *input_image_p_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(4, *filter_image_p_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(5, *bias_image_p_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(6, *output_image_p_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(7, stride_h_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(8, offset_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(9, input_c_block_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(10, dilation_h_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(11, input_tensor_w_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(12, input_tensor_h_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(13, output_tensor_w_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(14, output_tensor_h_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(15, filter_tensor_w_);
CL_CHECK_FATAL(status_);
status_ = kernel_.setArg(16, filter_tensor_h_);
CL_CHECK_FATAL(status_);
status_ = EnqueueNDRangeKernel(context,
kernel_,
cl::NullRange,
global_work_size_,
cl::NullRange,
nullptr,
nullptr);
CL_CHECK_FATAL(status);
event_);
CL_CHECK_FATAL(status_);
if (is_turn) {
if (enable_tune) {
CLRuntime::Global()->command_queue().finish();
}
}
void ConvImageCompute::DepthwiseConv2d(bool is_turn) {
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
const auto& param = *param_.get_mutable<param_t>();
auto input_dims = param.x->dims();
auto paddings = *param.paddings;
auto strides = param.strides;
auto* input_image = param.x->data<half_t, cl::Image2D>();
auto* filter_image = filter_gpu_image_->data<half_t, cl::Image2D>();
auto filter_dims = param.filter->dims();
auto output_dims = param.output->dims();
int input_width = input_dims[3];
int input_height = input_dims[2];
int output_width = output_dims[3];
int output_height = output_dims[2];
int filter_width = filter_dims[3];
int filter_height = filter_dims[2];
auto out_image_shape = InitImageDimInfoWith(output_dims);
auto* out_image = param.output->mutable_data<half_t, cl::Image2D>(
out_image_shape["width"], out_image_shape["height"]);
const bool has_bias = param.bias != nullptr;
const bool is_element_wise_bias =
has_bias && param.output->dims() == param.bias->dims();
int offset = static_cast<int>(param.filter->dims()[2]) / 2 -
static_cast<int>(paddings[0]);
void ConvImageCompute::Run() { (this->*impl_)(false); }
// calc input_c_block
auto input_image_shape = InitImageDimInfoWith(input_dims);
int input_c_block = input_image_shape["width"] / input_dims[3];
int input_c = input_dims[1];
auto dilations = *param.dilations;
void ConvImageCompute::PrintConvInfo() {
const bool is_element_wise_bias =
has_bias_ && conv_param_->output->dims() == conv_param_->bias->dims();
#ifdef LITE_WITH_LOG
VLOG(4) << "============ depthwise conv2d params ============";
VLOG(4) << "input_image_shape: " << input_image_shape["width"] << ","
<< input_image_shape["height"];
VLOG(4) << "input_c_block: " << input_c_block;
VLOG(4) << "input_c: " << input_c;
// VLOG(4) << "input_image: " << input_image;
VLOG(4) << "filter_dims: " << filter_dims;
VLOG(4) << "input_image_shape: " << input_image_w_ << "," << input_image_h_;
// VLOG(4) << "input_image: " << input_image_p_;
VLOG(4) << "input_dims: " << conv_param_->x->dims();
VLOG(4) << "filter_dims: " << conv_param_->filter->dims();
// VLOG(4) << "filter_image: " << filter_image;
VLOG(4) << "output_dims: " << output_dims;
VLOG(4) << "out_image_shape: " << out_image_shape["width"] << ", "
<< out_image_shape["height"];
VLOG(4) << "paddings: " << paddings[0] << "," << paddings[1];
VLOG(4) << "has bias: " << has_bias;
VLOG(4) << "output_dims: " << conv_param_->output->dims();
VLOG(4) << "out_image_shape: " << output_image_w_ << ", " << output_image_h_;
VLOG(4) << "paddings: " << pad_left_ << "," << pad_up_;
VLOG(4) << "has bias: " << has_bias_;
VLOG(4) << "is_element_wise_bias : " << is_element_wise_bias;
VLOG(4) << "strides: " << strides[0] << "," << strides[1];
VLOG(4) << "offset: " << offset;
VLOG(4) << "dilations.size : " << dilations.size();
VLOG(4) << "dilations: " << dilations[0] << ", " << dilations[1];
#endif
CHECK_GE(dilations.size(), 2);
CHECK(dilations[0] == dilations[1]);
CHECK_GE(input_dims.size(), 4);
CHECK_GE(paddings.size(), 2);
CHECK(paddings[0] == paddings[1]);
CHECK_GE(strides.size(), 2);
CHECK(strides[0] == strides[1]);
// handle bias use buffer for channel wise , use image for element wise
const cl::Buffer* bias_buf = nullptr;
const cl::Image2D* bias_image = nullptr;
if (has_bias) {
bias_image = bias_gpu_image_->data<half_t, cl::Image2D>();
}
auto kernel = kernel_;
cl_int status;
int arg_idx = 0;
status = kernel.setArg(arg_idx, c_blk_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, w_blk_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, nh_blk_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *input_image);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *filter_image);
CL_CHECK_FATAL(status);
if (has_bias) {
#ifdef LITE_WITH_LOG
VLOG(4) << "set bias_image: ";
#endif
status = kernel.setArg(++arg_idx, *bias_image);
CL_CHECK_FATAL(status);
}
status = kernel.setArg(++arg_idx, *out_image);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, strides[0]);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, offset);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_c_block);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, dilations[0]);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_width);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_height);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, output_width);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, output_height);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, filter_width);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, filter_height);
CL_CHECK_FATAL(status);
#ifdef LITE_WITH_LOG
VLOG(4) << "strides: " << stride_h_ << "," << stride_w_;
VLOG(4) << "offset: ";
VLOG(4) << "dilations.size : " << conv_param_->dilations->size();
VLOG(4) << "dilations: " << dilation_h_ << ", " << dilation_w_;
VLOG(4) << "global_work_size_[3D]: {" << global_work_size_[0] << ","
<< global_work_size_[1] << "," << global_work_size_[2] << "}";
#endif
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
global_work_size_,
cl::NullRange,
nullptr,
nullptr);
CL_CHECK_FATAL(status);
}
// Executes the convolution by invoking the member function that `impl_`
// points to, passing `false` as its single bool argument.
// NOTE(review): that argument is presumably the tuning flag
// (`enable_tune`/`is_turn` in the declarations) — confirm against the header.
void ConvImageCompute::Run() { (this->*impl_)(false); }
double ConvImageCompute::Turn(int times) {
double ConvImageCompute::Tune(int times) {
auto GetCurrentUS = []() -> double {
struct timeval time;
gettimeofday(&time, NULL);
......@@ -1802,3 +1234,4 @@ REGISTER_LITE_KERNEL(depthwise_conv2d,
PRECISION(kFP16),
DATALAYOUT(kImageDefault))})
.Finalize();
#define LITE_WITH_LOG
......@@ -24,11 +24,16 @@
#include "lite/core/tensor.h"
#include "lite/kernels/opencl/image_helper.h"
#include "lite/operators/op_params.h"
#ifdef LITE_WITH_PROFILE
#include "lite/core/profile/profiler.h"
#endif
#include "lite/backends/opencl/cl_utility.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace opencl {
class ConvImageCompute : public KernelLite<TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kImageDefault)> {
......@@ -38,20 +43,37 @@ class ConvImageCompute : public KernelLite<TARGET(kOpenCL),
void PrepareForRun() override;
void ReInitWhenNeeded() override;
void Run() override;
double Turn(int times = 5);
double Tune(int times = 5);
#ifdef LITE_WITH_PROFILE
// Fills the profiler record `ch` with this kernel's runtime details: the
// first registered kernel function name, the global/local work sizes
// rendered as strings, and the OpenCL event of the last enqueue.
void SetProfileRuntimeKernelInfo(paddle::lite::profile::OpCharacter* ch) {
  auto& op_info = *ch;
  // `event_` is defined in `kernel.h` and is valid only after kernel::Run.
  op_info.cl_event = event_;
  op_info.local_work_size = op_info.NDRangeToStr(local_work_size_);
  op_info.global_work_size = op_info.NDRangeToStr(global_work_size_);
  op_info.kernel_func_name = kernel_func_names_[0];
}
#endif
private:
void Conv2d1x1opt(bool is_turn = false);
void Conv2d3x3(bool is_turn = false);
void Conv2d3x3opt(bool is_turn = false);
void Conv2d5x5(bool is_turn = false);
void Conv2d5x5opt(bool is_turn = false);
void Conv2d7x7(bool is_turn = false);
void Conv2d7x7opt(bool is_turn = false);
void DepthwiseConv2d3x3s1(bool is_turn = false);
void DepthwiseConv2d3x3(bool is_turn = false);
void DepthwiseConv2d(bool is_turn = false);
void PrintConvInfo();
void GetGlobalWorkSize();
void Conv2d1x1opt(bool enable_tune = false);
void Conv2d3x3(bool enable_tune = false);
void Conv2d3x3opt(bool enable_tune = false);
void Conv2d5x5(bool enable_tune = false);
void Conv2d5x5opt(bool enable_tune = false);
void Conv2d7x7(bool enable_tune = false);
void Conv2d7x7opt(bool enable_tune = false);
void DepthwiseConv2d3x3s1(bool enable_tune = false);
void DepthwiseConv2d3x3(bool enable_tune = false);
void DepthwiseConv2d(bool enable_tune = false);
param_t* conv_param_{nullptr};
kernel_t impl_;
std::vector<std::string> kernel_func_names_{};
......@@ -65,19 +87,72 @@ class ConvImageCompute : public KernelLite<TARGET(kOpenCL),
std::unique_ptr<Tensor> tensor_hold_bias_image_{nullptr};
cl::NDRange global_work_size_ = cl::NDRange{
static_cast<size_t>(1), static_cast<size_t>(1), static_cast<size_t>(1)};
// opencl kernel args
int c_blk_ = 1;
int w_blk_ = 1;
int nh_blk_ = 1;
const cl::Image2D* input_image_p_{nullptr};
const cl::Image2D* filter_image_p_{nullptr};
const cl::Image2D* bias_image_p_{nullptr};
const cl::Image2D* output_image_p_{nullptr};
int stride_h_{-1};
int stride_w_{-1};
int dilation_h_{-1};
int dilation_w_{-1};
int pad_up_{-1};
int pad_down_{-1};
int pad_left_{-1};
int pad_right_{-1};
int offset_{-1};
int groups_{-1};
bool relu_fused_{false};
bool has_bias_{false};
int input_tensor_n_{-1};
int input_tensor_c_{-1};
int input_tensor_h_{-1};
int input_tensor_w_{-1};
int input_image_h_{-1};
int input_image_w_{-1};
int input_c_block_{-1};
int output_tensor_n_{-1};
int output_tensor_c_{-1};
int output_tensor_h_{-1};
int output_tensor_w_{-1};
int output_image_h_{-1};
int output_image_w_{-1};
int filter_tensor_n_{-1};
int filter_tensor_c_{-1};
int filter_tensor_h_{-1};
int filter_tensor_w_{-1};
int filter_image_h_{-1};
int filter_image_w_{-1};
int bias_image_h_{-1};
int bias_image_w_{-1};
int default_c_blk_ = 1;
int default_w_blk_ = 1;
int default_nh_blk_ = 1;
// =================
DDim last_input_dims_{};
bool is_first_epoch_for_run_{true};
cl::Kernel kernel_;
cl_int status_;
cl::NDRange local_work_size_ = cl::NDRange{
static_cast<size_t>(1), static_cast<size_t>(1), static_cast<size_t>(1)};
bool use_lws_{true};
bool use_turn_{false};
bool use_tune_{false};
};
} // namespace opencl
......
......@@ -20,6 +20,10 @@
#include "lite/kernels/opencl/image_helper.h"
#include "lite/operators/op_params.h"
#include "lite/utils/replace_stl/stream.h"
#ifdef LITE_WITH_PROFILE
#include "lite/core/profile/profiler.h"
#endif
#include "lite/backends/opencl/cl_utility.h"
namespace paddle {
namespace lite {
......@@ -119,6 +123,14 @@ class DepthwiseConv2dCompute
CL_CHECK_FATAL(status);
}
#ifdef LITE_WITH_PROFILE
// Copies the kernel function name and the OpenCL event of the last enqueue
// into the profiler record `ch`.
void SetProfileRuntimeKernelInfo(paddle::lite::profile::OpCharacter* ch) {
  auto& op_info = *ch;
  // `event_` is defined in `kernel.h` and is valid only after kernel::Run.
  op_info.cl_event = event_;
  op_info.kernel_func_name = kernel_func_name_;
}
#endif
private:
std::string kernel_func_name_{"depthwise_conv2d"};
std::string build_options_{"-DCL_DTYPE_float"};
......
......@@ -21,6 +21,10 @@
#include "lite/operators/op_params.h"
#include "lite/utils/replace_stl/stream.h"
#include "lite/utils/string.h"
#ifdef LITE_WITH_PROFILE
#include "lite/core/profile/profiler.h"
#endif
#include "lite/backends/opencl/cl_utility.h"
namespace paddle {
namespace lite {
......@@ -89,16 +93,24 @@ class DropoutComputeImage2D : public KernelLite<TARGET(kOpenCL),
static_cast<cl::size_type>(default_work_size.data()[1]),
static_cast<cl::size_type>(default_work_size.data()[2])};
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
status = EnqueueNDRangeKernel(context,
kernel,
cl::NullRange,
global_work_size,
cl::NullRange,
nullptr,
nullptr);
event_);
CL_CHECK_FATAL(status);
}
#ifdef LITE_WITH_PROFILE
// Copies the kernel function name and the OpenCL event of the last enqueue
// into the profiler record `ch`.
void SetProfileRuntimeKernelInfo(paddle::lite::profile::OpCharacter* ch) {
  auto& op_info = *ch;
  // `event_` is defined in `kernel.h` and is valid only after kernel::Run.
  op_info.cl_event = event_;
  op_info.kernel_func_name = kernel_func_name_;
}
#endif
private:
std::string kernel_func_name_{"dropout"};
std::string build_options_{"-DCL_DTYPE_half"};
......
......@@ -19,6 +19,10 @@
#include "lite/kernels/opencl/image_helper.h"
#include "lite/operators/op_params.h"
#include "lite/utils/cp_logging.h"
#ifdef LITE_WITH_PROFILE
#include "lite/core/profile/profiler.h"
#endif
#include "lite/backends/opencl/cl_utility.h"
namespace paddle {
namespace lite {
......@@ -38,6 +42,14 @@ class ElementwiseAddCompute
return "ElementwiseAdd using cl::Buffer, kFloat";
}
#ifdef LITE_WITH_PROFILE
// Copies the kernel function name and the OpenCL event of the last enqueue
// into the profiler record `ch`.
void SetProfileRuntimeKernelInfo(paddle::lite::profile::OpCharacter* ch) {
  auto& op_info = *ch;
  // `event_` is defined in `kernel.h` and is valid only after kernel::Run.
  op_info.cl_event = event_;
  op_info.kernel_func_name = kernel_func_name_;
}
#endif
protected:
void UpdateParams();
......
......@@ -18,6 +18,8 @@
#include "lite/core/op_registry.h"
#include "lite/utils/replace_stl/stream.h"
#undef LITE_WITH_LOG
namespace paddle {
namespace lite {
namespace kernels {
......@@ -154,13 +156,13 @@ void ElementwiseAddImageCompute::Run() {
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
status = EnqueueNDRangeKernel(context,
kernel,
cl::NullRange,
global_work_size_,
cl::NullRange,
nullptr,
nullptr);
event_);
CL_CHECK_FATAL(status);
}
......@@ -196,3 +198,5 @@ REGISTER_LITE_KERNEL(elementwise_add,
PRECISION(kFP16),
DATALAYOUT(kImageDefault))})
.Finalize();
#define LITE_WITH_LOG
......@@ -21,6 +21,10 @@
#include "lite/kernels/opencl/image_helper.h"
#include "lite/operators/op_params.h"
#include "lite/utils/cp_logging.h"
#ifdef LITE_WITH_PROFILE
#include "lite/core/profile/profiler.h"
#endif
#include "lite/backends/opencl/cl_utility.h"
namespace paddle {
namespace lite {
......@@ -42,6 +46,14 @@ class ElementwiseAddImageCompute
void Run() override;
#ifdef LITE_WITH_PROFILE
// Copies the kernel function name and the OpenCL event of the last enqueue
// into the profiler record `ch`.
void SetProfileRuntimeKernelInfo(paddle::lite::profile::OpCharacter* ch) {
  auto& op_info = *ch;
  // `event_` is defined in `kernel.h` and is valid only after kernel::Run.
  op_info.cl_event = event_;
  op_info.kernel_func_name = kernel_func_name_;
}
#endif
std::string doc() const override {
return "ElementwiseAdd using cl::Image2D, kFP16";
}
......
......@@ -153,13 +153,13 @@ void ElementwiseMulFloatImageCompute::Run() {
auto global_work_size = cl::NDRange{static_cast<cl::size_type>(x_img_width),
static_cast<cl::size_type>(x_img_height)};
auto status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel
auto status = EnqueueNDRangeKernel(context,
kernel,
cl::NullRange,
global_work_size,
cl::NullRange,
nullptr,
nullptr);
event_);
CL_CHECK_FATAL(status);
std::string time_stamp_{GetTimeStamp()};
......
......@@ -23,6 +23,10 @@
#include "lite/operators/op_params.h"
#include "lite/utils/logging.h"
#include "lite/utils/replace_stl/stream.h"
#ifdef LITE_WITH_PROFILE
#include "lite/core/profile/profiler.h"
#endif
#include "lite/backends/opencl/cl_utility.h"
namespace paddle {
namespace lite {
......@@ -186,13 +190,13 @@ class ElementwiseMulImageCompute
cl::NDRange{static_cast<cl::size_type>(x_img_width),
static_cast<cl::size_type>(x_img_height)};
auto status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
auto status = EnqueueNDRangeKernel(context,
kernel,
cl::NullRange,
global_work_size,
cl::NullRange,
nullptr,
nullptr);
event_);
CL_CHECK_FATAL(status);
#ifdef LITE_WITH_LOG
VLOG(4) << "global_work_size:[2D]:" << x_img_width << " " << x_img_height;
......
......@@ -138,8 +138,13 @@ void ElementwiseSubImageCompute::Run() {
VLOG(4) << "global_work_size:[2D]:" << x_img_width << " " << x_img_height;
#endif
auto status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, cl::NullRange, global_work_size, cl::NullRange, nullptr, nullptr);
auto status = EnqueueNDRangeKernel(context,
kernel,
cl::NullRange,
global_work_size,
cl::NullRange,
nullptr,
event_);
CL_CHECK_FATAL(status);
}
......
......@@ -20,6 +20,10 @@
#include "lite/kernels/opencl/image_helper.h"
#include "lite/operators/op_params.h"
#include "lite/utils/cp_logging.h"
#ifdef LITE_WITH_PROFILE
#include "lite/core/profile/profiler.h"
#endif
#include "lite/backends/opencl/cl_utility.h"
namespace paddle {
namespace lite {
......@@ -41,6 +45,14 @@ class ElementwiseSubImageCompute
return "ElementwiseSub using cl::Image2D, kFP16";
}
#ifdef LITE_WITH_PROFILE
// Copies the kernel function name and the OpenCL event of the last enqueue
// into the profiler record `ch`.
void SetProfileRuntimeKernelInfo(paddle::lite::profile::OpCharacter* ch) {
  auto& op_info = *ch;
  // `event_` is defined in `kernel.h` and is valid only after kernel::Run.
  op_info.cl_event = event_;
  op_info.kernel_func_name = kernel_func_name_;
}
#endif
protected:
param_t* ele_param_{nullptr};
std::string kernel_func_name_{"elementwise_sub"};
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <vector>
#include "lite/backends/opencl/cl_half.h"
#include "lite/backends/opencl/cl_include.h"
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
#include "lite/kernels/opencl/image_helper.h"
#include "lite/operators/op_params.h"
#include "lite/utils/replace_stl/stream.h"
#include "lite/utils/string.h"
#ifdef LITE_WITH_PROFILE
#include "lite/core/profile/profiler.h"
#endif
#include "lite/backends/opencl/cl_utility.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace opencl {
// OpenCL kernel wrapper for the `expand` op operating on FP16 cl::Image2D
// tensors in the default image layout.
//
// Lifecycle as implemented below:
//   * PrepareForRun    - validates the parameters, picks the OpenCL kernel
//                        function by input channel count and fetches it.
//   * ReInitWhenNeeded - recomputes the output image shape and the global
//                        work size on first use or when X's dims change.
//   * Run              - binds the kernel arguments and enqueues the kernel.
class ExpandComputeImage2D : public KernelLite<TARGET(kOpenCL),
                                               PRECISION(kFP16),
                                               DATALAYOUT(kImageDefault)> {
 public:
  using param_t = operators::ExpandParam;

  // Short human-readable description of this kernel.
  std::string doc() const override { return "expand using cl::Image2D, kFP16"; }

  // Validates the op parameters and builds/fetches the OpenCL kernel.
  // Only 4-D inputs with 4 expand factors are accepted, and the channel
  // factor (expand_times[1]) must be 1.
  void PrepareForRun() override {
    expand_param_ = param_.get_mutable<param_t>();
    const auto& expand_times = expand_param_->expand_times;
    const auto& in_dims = expand_param_->X->dims();

    CHECK(in_dims.size() == 4) << "expand image now only support indims size 4";
    CHECK(expand_times.size() == 4)
        << "expand image now only support in_expand_timesdims size 4";
    CHECK(expand_times[1] == 1) << "expand image do not support expend c now";

    // do not confuse with these cases.it is use to support expend c in future
    // The kernel function is chosen by the input channel count; the names
    // must match the functions in image/expand_kernel.cl.
    if (in_dims[1] == 1) {
      kernel_func_name_ = "expend_c1";
    } else if (in_dims[1] == 2) {
      kernel_func_name_ = "expend_c2";
    } else if (in_dims[1] == 3) {
      kernel_func_name_ = "expend_c3";
    } else if (in_dims[1] == 4) {
      kernel_func_name_ = "expend_c4";
    } else {
      kernel_func_name_ = "expend_cn";
    }
    VLOG(1) << "kernel_func_name_:" << kernel_func_name_;

    auto& context = ctx_->As<OpenCLContext>();
    context.cl_context()->AddKernel(kernel_func_name_,
                                    "image/expand_kernel.cl",
                                    build_options_,
                                    time_stamp_);

    // The lookup key is the concatenation name + build options + time stamp.
    STL::stringstream kernel_key;
    kernel_key << kernel_func_name_ << build_options_ << time_stamp_;
    kernel_ = context.cl_context()->GetKernel(kernel_key.str());
  }

  // Recomputes the cached output image shape and global work size on the
  // first call and whenever the dims of X differ from the last seen dims.
  // NOTE(review): the cache is keyed on x_dims only, although both cached
  // values are derived from out_dims — confirm out_dims cannot change while
  // x_dims stays the same.
  void ReInitWhenNeeded() override {
    VLOG(1) << "ReInitWhenNeeded: " << kernel_func_name_;

    const auto& x_dims = expand_param_->X->dims();
    const auto& out_dims = expand_param_->Out->dims();
    const auto& expand_times = expand_param_->expand_times;

    VLOG(1) << "x_dims: " << x_dims;
    VLOG(1) << "out_dims: " << out_dims;
    VLOG(1) << "expand_times: " << expand_times[0] << " " << expand_times[1]
            << " " << expand_times[2] << " " << expand_times[3];

    const bool needs_reinit =
        first_epoch_for_reinit_ || x_dims != last_x_dims_;
    if (!needs_reinit) {
      return;
    }
    last_x_dims_ = x_dims;
    first_epoch_for_reinit_ = false;

    // compute image shape
    paddle::lite::CLImageConverterDefault default_convertor;
    out_img_shape_ = default_convertor.InitImageDimInfoWith(out_dims);
    VLOG(1) << "out_img_shape_: " << out_img_shape_[0] << " "
            << out_img_shape_[1];

    // compute global work size: {channel blocks of 4, width, batch * height}.
    // The channel-block count is computed directly as (C + 3) / 4 instead of
    // (W * blocks) / W, which avoids a division by the width.
    const size_t work_size_0 = static_cast<size_t>((out_dims[1] + 3) / 4);
    const size_t work_size_1 = static_cast<size_t>(out_dims[3]);
    const size_t work_size_2 = static_cast<size_t>(out_dims[0] * out_dims[2]);
    global_work_size_ = cl::NDRange{work_size_0, work_size_1, work_size_2};
    VLOG(1) << "global_work_size_: " << global_work_size_[0] << " "
            << global_work_size_[1] << " " << global_work_size_[2];
  }

  // Binds the 16 kernel arguments and enqueues the kernel over
  // `global_work_size_`. The enqueue event is stored in `event_`.
  void Run() override {
    auto* x_img = expand_param_->X->data<half_t, cl::Image2D>();
    auto* out_img = expand_param_->Out->mutable_data<half_t, cl::Image2D>(
        out_img_shape_[0], out_img_shape_[1]);
    const auto& expand_times = expand_param_->expand_times;

    const auto& x_dims = expand_param_->X->dims();
    const int in_h = x_dims[2];
    const int in_w = x_dims[3];

    const auto& out_dims = expand_param_->Out->dims();
    const int out_h = out_dims[2];
    const int out_w = out_dims[3];

    // Number of 4-channel blocks and the combined batch*height extent.
    const int out_c_block = static_cast<int>((out_dims[1] + 3) / 4);
    const int out_nh = out_dims[0] * out_dims[2];
    const int in_c_block = static_cast<int>((x_dims[1] + 3) / 4);
    const int in_nh = x_dims[0] * x_dims[2];

    const int expand_times_n = expand_times[0];
    const int expand_times_c = expand_times[1];
    const int expand_times_h = expand_times[2];
    const int expand_times_w = expand_times[3];

    auto& context = ctx_->As<OpenCLContext>();
    CHECK(context.cl_context() != nullptr);

    // Argument order must match the kernel signature in
    // image/expand_kernel.cl; in_w / out_w are intentionally passed twice
    // (indices 4/6 and 1/8), exactly as before.
    cl_int status;
    status = kernel_.setArg(0, out_c_block);
    CL_CHECK_FATAL(status);
    status = kernel_.setArg(1, out_w);
    CL_CHECK_FATAL(status);
    status = kernel_.setArg(2, out_nh);
    CL_CHECK_FATAL(status);

    status = kernel_.setArg(3, in_c_block);
    CL_CHECK_FATAL(status);
    status = kernel_.setArg(4, in_w);
    CL_CHECK_FATAL(status);
    status = kernel_.setArg(5, in_nh);
    CL_CHECK_FATAL(status);

    status = kernel_.setArg(6, in_w);
    CL_CHECK_FATAL(status);
    status = kernel_.setArg(7, in_h);
    CL_CHECK_FATAL(status);
    status = kernel_.setArg(8, out_w);
    CL_CHECK_FATAL(status);
    status = kernel_.setArg(9, out_h);
    CL_CHECK_FATAL(status);

    status = kernel_.setArg(10, *x_img);
    CL_CHECK_FATAL(status);
    status = kernel_.setArg(11, *out_img);
    CL_CHECK_FATAL(status);

    status = kernel_.setArg(12, expand_times_n);
    CL_CHECK_FATAL(status);
    status = kernel_.setArg(13, expand_times_c);
    CL_CHECK_FATAL(status);
    status = kernel_.setArg(14, expand_times_h);
    CL_CHECK_FATAL(status);
    status = kernel_.setArg(15, expand_times_w);
    CL_CHECK_FATAL(status);

    status = EnqueueNDRangeKernel(context,
                                  kernel_,
                                  cl::NullRange,
                                  global_work_size_,
                                  cl::NullRange,
                                  nullptr,
                                  event_);
    CL_CHECK_FATAL(status);
  }

#ifdef LITE_WITH_PROFILE
  // Copies the kernel function name and the OpenCL event of the last enqueue
  // into the profiler record `ch`.
  void SetProfileRuntimeKernelInfo(paddle::lite::profile::OpCharacter* ch) {
    ch->kernel_func_name = kernel_func_name_;
    ch->cl_event =
        event_;  // `event_` defined in `kernel.h`, valid after kernel::Run
  }
#endif

 private:
  // Name of the OpenCL kernel function, selected in PrepareForRun.
  std::string kernel_func_name_{};
  std::string build_options_{"-DCL_DTYPE_half"};
  std::string time_stamp_{GetTimeStamp()};

  param_t* expand_param_{nullptr};
  cl::Kernel kernel_;

  // Re-init cache state, see ReInitWhenNeeded.
  bool first_epoch_for_reinit_{true};
  DDim last_x_dims_;
  DDim out_img_shape_ = DDim(std::vector<DDim::value_type>(
      {static_cast<DDim::value_type>(1), static_cast<DDim::value_type>(1)}));
  cl::NDRange global_work_size_ = cl::NDRange{
      static_cast<size_t>(1), static_cast<size_t>(1), static_cast<size_t>(1)};
};
} // namespace opencl
} // namespace kernels
} // namespace lite
} // namespace paddle
// Registers ExpandComputeImage2D as the `expand` kernel for the OpenCL
// target with FP16 precision and the default image layout, under the alias
// `image2d`. Both the input "X" and the output "Out" are bound to tensors of
// that same target/precision/layout.
REGISTER_LITE_KERNEL(expand,
kOpenCL,
kFP16,
kImageDefault,
paddle::lite::kernels::opencl::ExpandComputeImage2D,
image2d)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kImageDefault))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kImageDefault))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <cmath>
#include <memory>
#include <random>
#include <vector>
#include "lite/backends/opencl/target_wrapper.h"
#include "lite/core/op_registry.h"
#include "lite/core/tensor.h"
#include "lite/kernels/opencl/test_helper.h"
#define FP16_MAX_DIFF (5e-1)
namespace paddle {
namespace lite {
// Runs the OpenCL image2d `expand` kernel on a 1x1x2x3 input holding the
// values 0..5, tiled 2x along H and 3x along W, and compares the result read
// back from the device with a hard-coded expected tensor within
// FP16_MAX_DIFF (absolute or relative).
TEST(expand_hw_image2d, compute) {
  LOG(INFO) << "create kernel ...";
  auto kernels = KernelRegistry::Global().Create(
      "expand", TARGET(kOpenCL), PRECISION(kFP16), DATALAYOUT(kImageDefault));
  ASSERT_FALSE(kernels.empty());

  const int INPUT_N = 1;
  const int INPUT_C = 1;
  const int INPUT_H = 2;
  const int INPUT_W = 3;
  const int EXPAND_N = 1;
  const int EXPAND_C = 1;
  const int EXPAND_H = 2;
  const int EXPAND_W = 3;

  auto kernel = std::move(kernels.front());
  LOG(INFO) << "prepare to test kernel ====> " << kernel->doc();

  lite::Tensor x, out;
  operators::ExpandParam param;
  param.X = &x;
  param.Out = &out;
  param.expand_times = {EXPAND_N, EXPAND_C, EXPAND_H, EXPAND_W};

  // Initialise a shared OpenCL context and hand a copy to the kernel.
  std::unique_ptr<KernelContext> context(new KernelContext);
  context->As<OpenCLContext>().InitOnce();
  kernel->SetParam(param);
  std::unique_ptr<KernelContext> expand_context(new KernelContext);
  context->As<OpenCLContext>().CopySharedTo(
      &(expand_context->As<OpenCLContext>()));
  kernel->SetContext(std::move(expand_context));

  const DDim in_dim =
      DDim(std::vector<DDim::value_type>{INPUT_N, INPUT_C, INPUT_H, INPUT_W});
  const DDim out_dim = DDim(std::vector<DDim::value_type>{INPUT_N * EXPAND_N,
                                                          INPUT_C * EXPAND_C,
                                                          INPUT_H * EXPAND_H,
                                                          INPUT_W * EXPAND_W});
  LOG(INFO) << "in_dim: " << in_dim;
  LOG(INFO) << "expand_times: " << EXPAND_N << EXPAND_C << EXPAND_H << EXPAND_W;
  LOG(INFO) << "out_dim: " << out_dim;

  x.Resize(in_dim);
  out.Resize(out_dim);

  // Deterministic input: element k holds the value k.
  std::vector<float> input_v(INPUT_N * INPUT_C * INPUT_H * INPUT_W);
  int index = 0;
  for (auto& i : input_v) {
    i = index++;
  }
  VLOG(1) << "input_v ..... ";
  for (size_t i = 0; i < input_v.size(); i++) {
    VLOG(10) << input_v[i];
  }

  LOG(INFO) << "prepare input";
  // Stack-allocated converter: nothing to free at the end of the test.
  CLImageConverterDefault default_converter;
  DDim x_image_shape = default_converter.InitImageDimInfoWith(in_dim);
  LOG(INFO) << "x_image_shape = " << x_image_shape[0] << " "
            << x_image_shape[1];
  std::vector<half_t> x_image_data(x_image_shape.production() * 4);  // 4 : RGBA
  default_converter.NCHWToImage(input_v.data(), x_image_data.data(), in_dim);
  // Creates the input image from the host data; the returned pointer is not
  // needed afterwards.
  x.mutable_data<half_t, cl::Image2D>(
      x_image_shape[0], x_image_shape[1], x_image_data.data());
  VLOG(1) << "x_image_data ..... ";
  for (size_t i = 0; i < x_image_data.size(); i++) {
    VLOG(10) << Half2Float(x_image_data[i]);
  }

  DDim out_image_shape = default_converter.InitImageDimInfoWith(out_dim);
  LOG(INFO) << "out_image_shape = " << out_image_shape[0] << " "
            << out_image_shape[1];
  auto* out_image = out.mutable_data<half_t, cl::Image2D>(out_image_shape[0],
                                                          out_image_shape[1]);

  kernel->Launch();
  CLRuntime::Global()->command_queue().finish();

  // Expected NCHW output (1x1x4x9).
  std::vector<float> out_data_v{0, 0, 0, 1, 1, 1, 2, 2, 2, 0, 0, 0,
                                1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4,
                                5, 5, 5, 3, 3, 3, 4, 4, 4, 5, 5, 5};

  const size_t cl_image2d_row_pitch{0};
  const size_t cl_image2d_slice_pitch{0};
  // Host buffers are owned by vectors so they are released even when an
  // assertion fails.
  std::vector<half_t> out_image_data(out_image_shape.production() * 4);
  TargetWrapperCL::ImgcpySync(out_image_data.data(),
                              out_image,
                              out_image_shape[0],
                              out_image_shape[1],
                              cl_image2d_row_pitch,
                              cl_image2d_slice_pitch,
                              IoDirection::DtoH);
  VLOG(1) << "out_image_data ..... ";
  for (size_t i = 0; i < out_image_data.size(); i++) {
    VLOG(10) << Half2Float(out_image_data[i]);
  }

  std::vector<float> out_data(out_image_shape.production() * 4);
  default_converter.ImageToNCHW(
      out_image_data.data(), out_data.data(), out_image_shape, out_dim);
  VLOG(1) << "out_data ..... ";
  for (int i = 0; i < out_dim.production(); i++) {
    VLOG(10) << out_data[i];
  }

  for (int i = 0; i < out_dim.production(); i++) {
    // std::fabs keeps the comparison in floating point; an unqualified
    // `abs` may bind to the integer overload and truncate the difference.
    auto abs_diff = std::fabs(out_data[i] - out_data_v[i]);
    auto relative_diff = COMPUTE_RELATIVE_DIFF(out_data[i], out_data_v[i]);
    EXPECT_EQ((relative_diff <= FP16_MAX_DIFF) || (abs_diff <= FP16_MAX_DIFF),
              true);
    if ((relative_diff > FP16_MAX_DIFF) && (abs_diff > FP16_MAX_DIFF)) {
      LOG(ERROR) << "error idx:" << i << " out_data[" << i
                 << "]:" << out_data[i] << " "
                                           "out_ref["
                 << i << "]:" << out_data_v[i] << " abs_diff:" << abs_diff
                 << " relative_diff:" << relative_diff
                 << " FP16_MAX_DIFF:" << FP16_MAX_DIFF;
    }
  }
}
// Runs the OpenCL image2d `expand` kernel on a 1x2x2x3 (NCHW) input with
// expand_times {1, 1, 2, 1} and compares the result, converted back to NCHW
// floats, against a hand-written reference within FP16 tolerance.
TEST(expand_c2hw_image2d, compute) {
  LOG(INFO) << "create kernel ...";
  auto kernels = KernelRegistry::Global().Create(
      "expand", TARGET(kOpenCL), PRECISION(kFP16), DATALAYOUT(kImageDefault));
  ASSERT_FALSE(kernels.empty());

  const int INPUT_N = 1;
  const int INPUT_C = 2;
  const int INPUT_H = 2;
  const int INPUT_W = 3;
  const int EXPAND_N = 1;
  const int EXPAND_C = 1;
  const int EXPAND_H = 2;
  const int EXPAND_W = 1;

  auto kernel = std::move(kernels.front());
  LOG(INFO) << "prepare to test kernel ====> " << kernel->doc();

  lite::Tensor x, out;
  operators::ExpandParam param;
  param.X = &x;
  param.Out = &out;
  param.expand_times = {EXPAND_N, EXPAND_C, EXPAND_H, EXPAND_W};

  std::unique_ptr<KernelContext> context(new KernelContext);
  context->As<OpenCLContext>().InitOnce();

  kernel->SetParam(param);
  std::unique_ptr<KernelContext> expand_context(new KernelContext);
  context->As<OpenCLContext>().CopySharedTo(
      &(expand_context->As<OpenCLContext>()));
  kernel->SetContext(std::move(expand_context));

  const DDim in_dim =
      DDim(std::vector<DDim::value_type>{INPUT_N, INPUT_C, INPUT_H, INPUT_W});
  const DDim out_dim = DDim(std::vector<DDim::value_type>{INPUT_N * EXPAND_N,
                                                          INPUT_C * EXPAND_C,
                                                          INPUT_H * EXPAND_H,
                                                          INPUT_W * EXPAND_W});
  LOG(INFO) << "in_dim: " << in_dim;
  // Separate the factors so the log does not read as one number.
  LOG(INFO) << "expand_times: " << EXPAND_N << " " << EXPAND_C << " "
            << EXPAND_H << " " << EXPAND_W;
  LOG(INFO) << "out_dim: " << out_dim;
  x.Resize(in_dim);
  out.Resize(out_dim);

  // Input is the sequence 0, 1, 2, ... so every element is identifiable.
  std::vector<float> input_v(INPUT_N * INPUT_C * INPUT_H * INPUT_W);
  int index = 0;
  for (auto& value : input_v) {
    value = static_cast<float>(index++);
  }
  VLOG(1) << "input_v ..... ";
  for (size_t i = 0; i < input_v.size(); i++) {
    VLOG(10) << input_v[i];
  }

  LOG(INFO) << "prepare input";
  // The converter lives on the stack so it is released on every exit path,
  // including early returns from ASSERT_* macros.
  CLImageConverterDefault default_converter;
  DDim x_image_shape = default_converter.InitImageDimInfoWith(in_dim);
  LOG(INFO) << "x_image_shape = " << x_image_shape[0] << " "
            << x_image_shape[1];
  std::vector<half_t> x_image_data(x_image_shape.production() * 4);  // 4 : RGBA
  default_converter.NCHWToImage(input_v.data(), x_image_data.data(), in_dim);
  // Create the input image from the converted host data; the returned
  // pointer is not needed by the test.
  x.mutable_data<half_t, cl::Image2D>(
      x_image_shape[0], x_image_shape[1], x_image_data.data());
  VLOG(1) << "x_image_data ..... ";
  for (size_t i = 0; i < x_image_data.size(); i++) {
    VLOG(10) << Half2Float(x_image_data[i]);
  }

  DDim out_image_shape = default_converter.InitImageDimInfoWith(out_dim);
  LOG(INFO) << "out_image_shape = " << out_image_shape[0] << " "
            << out_image_shape[1];
  auto* out_image = out.mutable_data<half_t, cl::Image2D>(out_image_shape[0],
                                                          out_image_shape[1]);

  kernel->Launch();
  CLRuntime::Global()->command_queue().finish();

  // Reference output in NCHW order: within each channel every input row
  // appears twice in a row.
  std::vector<float> out_data_v{0, 1, 2, 0, 1, 2, 3, 4,  5,  3, 4,  5,
                                6, 7, 8, 6, 7, 8, 9, 10, 11, 9, 10, 11};
  const auto out_count = out_dim.production();
  // Guard the indexed comparison below against a reference/shape mismatch.
  ASSERT_EQ(static_cast<size_t>(out_count), out_data_v.size());

  const size_t cl_image2d_row_pitch{0};
  const size_t cl_image2d_slice_pitch{0};
  // Vectors replace the raw new[] buffers, which were never freed.
  std::vector<half_t> out_image_data(out_image_shape.production() * 4);
  TargetWrapperCL::ImgcpySync(out_image_data.data(),
                              out_image,
                              out_image_shape[0],
                              out_image_shape[1],
                              cl_image2d_row_pitch,
                              cl_image2d_slice_pitch,
                              IoDirection::DtoH);
  VLOG(1) << "out_image_data ..... ";
  for (size_t i = 0; i < out_image_data.size(); i++) {
    VLOG(10) << Half2Float(out_image_data[i]);
  }

  std::vector<float> out_data(out_image_shape.production() * 4);
  default_converter.ImageToNCHW(
      out_image_data.data(), out_data.data(), out_image_shape, out_dim);
  VLOG(1) << "out_data ..... ";
  for (int i = 0; i < out_count; i++) {
    VLOG(10) << out_data[i];
  }

  for (int i = 0; i < out_count; i++) {
    // Compute |a - b| by hand: an unqualified abs() may resolve to the int
    // overload and truncate the floating-point difference.
    const float diff = out_data[i] - out_data_v[i];
    const float abs_diff = diff < 0 ? -diff : diff;
    auto relative_diff = COMPUTE_RELATIVE_DIFF(out_data[i], out_data_v[i]);
    const bool within_tolerance =
        (relative_diff <= FP16_MAX_DIFF) || (abs_diff <= FP16_MAX_DIFF);
    EXPECT_TRUE(within_tolerance);
    if (!within_tolerance) {
      LOG(ERROR) << "error idx:" << i << " out_data[" << i
                 << "]:" << out_data[i] << " "
                                           "out_ref["
                 << i << "]:" << out_data_v[i] << " abs_diff:" << abs_diff
                 << " relative_diff:" << relative_diff
                 << " FP16_MAX_DIFF:" << FP16_MAX_DIFF;
    }
  }
}
// Runs the OpenCL image2d `expand` kernel on a 1x3x2x3 (NCHW) input with
// expand_times {1, 1, 2, 1} and compares the result, converted back to NCHW
// floats, against a hand-written reference within FP16 tolerance.
TEST(expand_c3hw_image2d, compute) {
  LOG(INFO) << "create kernel ...";
  auto kernels = KernelRegistry::Global().Create(
      "expand", TARGET(kOpenCL), PRECISION(kFP16), DATALAYOUT(kImageDefault));
  ASSERT_FALSE(kernels.empty());

  const int INPUT_N = 1;
  const int INPUT_C = 3;
  const int INPUT_H = 2;
  const int INPUT_W = 3;
  const int EXPAND_N = 1;
  const int EXPAND_C = 1;
  const int EXPAND_H = 2;
  const int EXPAND_W = 1;

  auto kernel = std::move(kernels.front());
  LOG(INFO) << "prepare to test kernel ====> " << kernel->doc();

  lite::Tensor x, out;
  operators::ExpandParam param;
  param.X = &x;
  param.Out = &out;
  param.expand_times = {EXPAND_N, EXPAND_C, EXPAND_H, EXPAND_W};

  std::unique_ptr<KernelContext> context(new KernelContext);
  context->As<OpenCLContext>().InitOnce();

  kernel->SetParam(param);
  std::unique_ptr<KernelContext> expand_context(new KernelContext);
  context->As<OpenCLContext>().CopySharedTo(
      &(expand_context->As<OpenCLContext>()));
  kernel->SetContext(std::move(expand_context));

  const DDim in_dim =
      DDim(std::vector<DDim::value_type>{INPUT_N, INPUT_C, INPUT_H, INPUT_W});
  const DDim out_dim = DDim(std::vector<DDim::value_type>{INPUT_N * EXPAND_N,
                                                          INPUT_C * EXPAND_C,
                                                          INPUT_H * EXPAND_H,
                                                          INPUT_W * EXPAND_W});
  LOG(INFO) << "in_dim: " << in_dim;
  // Separate the factors so the log does not read as one number.
  LOG(INFO) << "expand_times: " << EXPAND_N << " " << EXPAND_C << " "
            << EXPAND_H << " " << EXPAND_W;
  LOG(INFO) << "out_dim: " << out_dim;
  x.Resize(in_dim);
  out.Resize(out_dim);

  // Input is the sequence 0, 1, 2, ... so every element is identifiable.
  std::vector<float> input_v(INPUT_N * INPUT_C * INPUT_H * INPUT_W);
  int index = 0;
  for (auto& value : input_v) {
    value = static_cast<float>(index++);
  }
  VLOG(1) << "input_v ..... ";
  for (size_t i = 0; i < input_v.size(); i++) {
    VLOG(10) << input_v[i];
  }

  LOG(INFO) << "prepare input";
  // The converter lives on the stack so it is released on every exit path,
  // including early returns from ASSERT_* macros.
  CLImageConverterDefault default_converter;
  DDim x_image_shape = default_converter.InitImageDimInfoWith(in_dim);
  LOG(INFO) << "x_image_shape = " << x_image_shape[0] << " "
            << x_image_shape[1];
  std::vector<half_t> x_image_data(x_image_shape.production() * 4);  // 4 : RGBA
  default_converter.NCHWToImage(input_v.data(), x_image_data.data(), in_dim);
  // Create the input image from the converted host data; the returned
  // pointer is not needed by the test.
  x.mutable_data<half_t, cl::Image2D>(
      x_image_shape[0], x_image_shape[1], x_image_data.data());
  VLOG(1) << "x_image_data ..... ";
  for (size_t i = 0; i < x_image_data.size(); i++) {
    VLOG(10) << Half2Float(x_image_data[i]);
  }

  DDim out_image_shape = default_converter.InitImageDimInfoWith(out_dim);
  LOG(INFO) << "out_image_shape = " << out_image_shape[0] << " "
            << out_image_shape[1];
  auto* out_image = out.mutable_data<half_t, cl::Image2D>(out_image_shape[0],
                                                          out_image_shape[1]);

  kernel->Launch();
  CLRuntime::Global()->command_queue().finish();

  // Reference output in NCHW order: within each channel every input row
  // appears twice in a row.
  std::vector<float> out_data_v{0,  1,  2,  0,  1,  2,  3,  4,  5,  3,  4,  5,
                                6,  7,  8,  6,  7,  8,  9,  10, 11, 9,  10, 11,
                                12, 13, 14, 12, 13, 14, 15, 16, 17, 15, 16, 17};
  const auto out_count = out_dim.production();
  // Guard the indexed comparison below against a reference/shape mismatch.
  ASSERT_EQ(static_cast<size_t>(out_count), out_data_v.size());

  const size_t cl_image2d_row_pitch{0};
  const size_t cl_image2d_slice_pitch{0};
  // Vectors replace the raw new[] buffers, which were never freed.
  std::vector<half_t> out_image_data(out_image_shape.production() * 4);
  TargetWrapperCL::ImgcpySync(out_image_data.data(),
                              out_image,
                              out_image_shape[0],
                              out_image_shape[1],
                              cl_image2d_row_pitch,
                              cl_image2d_slice_pitch,
                              IoDirection::DtoH);
  VLOG(1) << "out_image_data ..... ";
  for (size_t i = 0; i < out_image_data.size(); i++) {
    VLOG(10) << Half2Float(out_image_data[i]);
  }

  std::vector<float> out_data(out_image_shape.production() * 4);
  default_converter.ImageToNCHW(
      out_image_data.data(), out_data.data(), out_image_shape, out_dim);
  VLOG(1) << "out_data ..... ";
  for (int i = 0; i < out_count; i++) {
    VLOG(10) << out_data[i];
  }

  for (int i = 0; i < out_count; i++) {
    // Compute |a - b| by hand: an unqualified abs() may resolve to the int
    // overload and truncate the floating-point difference.
    const float diff = out_data[i] - out_data_v[i];
    const float abs_diff = diff < 0 ? -diff : diff;
    auto relative_diff = COMPUTE_RELATIVE_DIFF(out_data[i], out_data_v[i]);
    const bool within_tolerance =
        (relative_diff <= FP16_MAX_DIFF) || (abs_diff <= FP16_MAX_DIFF);
    EXPECT_TRUE(within_tolerance);
    if (!within_tolerance) {
      LOG(ERROR) << "error idx:" << i << " out_data[" << i
                 << "]:" << out_data[i] << " "
                                           "out_ref["
                 << i << "]:" << out_data_v[i] << " abs_diff:" << abs_diff
                 << " relative_diff:" << relative_diff
                 << " FP16_MAX_DIFF:" << FP16_MAX_DIFF;
    }
  }
}
// Runs the OpenCL image2d `expand` kernel on a 1x4x2x1 (NCHW) input with
// expand_times {1, 1, 2, 1} and compares the result, converted back to NCHW
// floats, against a hand-written reference within FP16 tolerance.
TEST(expand_c4hw_image2d, compute) {
  LOG(INFO) << "create kernel ...";
  auto kernels = KernelRegistry::Global().Create(
      "expand", TARGET(kOpenCL), PRECISION(kFP16), DATALAYOUT(kImageDefault));
  ASSERT_FALSE(kernels.empty());

  const int INPUT_N = 1;
  const int INPUT_C = 4;
  const int INPUT_H = 2;
  const int INPUT_W = 1;
  const int EXPAND_N = 1;
  const int EXPAND_C = 1;
  const int EXPAND_H = 2;
  const int EXPAND_W = 1;

  auto kernel = std::move(kernels.front());
  LOG(INFO) << "prepare to test kernel ====> " << kernel->doc();

  lite::Tensor x, out;
  operators::ExpandParam param;
  param.X = &x;
  param.Out = &out;
  param.expand_times = {EXPAND_N, EXPAND_C, EXPAND_H, EXPAND_W};

  std::unique_ptr<KernelContext> context(new KernelContext);
  context->As<OpenCLContext>().InitOnce();

  kernel->SetParam(param);
  std::unique_ptr<KernelContext> expand_context(new KernelContext);
  context->As<OpenCLContext>().CopySharedTo(
      &(expand_context->As<OpenCLContext>()));
  kernel->SetContext(std::move(expand_context));

  const DDim in_dim =
      DDim(std::vector<DDim::value_type>{INPUT_N, INPUT_C, INPUT_H, INPUT_W});
  const DDim out_dim = DDim(std::vector<DDim::value_type>{INPUT_N * EXPAND_N,
                                                          INPUT_C * EXPAND_C,
                                                          INPUT_H * EXPAND_H,
                                                          INPUT_W * EXPAND_W});
  LOG(INFO) << "in_dim: " << in_dim;
  // Separate the factors so the log does not read as one number.
  LOG(INFO) << "expand_times: " << EXPAND_N << " " << EXPAND_C << " "
            << EXPAND_H << " " << EXPAND_W;
  LOG(INFO) << "out_dim: " << out_dim;
  x.Resize(in_dim);
  out.Resize(out_dim);

  // Input is the sequence 0, 1, 2, ... so every element is identifiable.
  std::vector<float> input_v(INPUT_N * INPUT_C * INPUT_H * INPUT_W);
  int index = 0;
  for (auto& value : input_v) {
    value = static_cast<float>(index++);
  }
  VLOG(1) << "input_v ..... ";
  for (size_t i = 0; i < input_v.size(); i++) {
    VLOG(10) << input_v[i];
  }

  LOG(INFO) << "prepare input";
  // The converter lives on the stack so it is released on every exit path,
  // including early returns from ASSERT_* macros.
  CLImageConverterDefault default_converter;
  DDim x_image_shape = default_converter.InitImageDimInfoWith(in_dim);
  LOG(INFO) << "x_image_shape = " << x_image_shape[0] << " "
            << x_image_shape[1];
  std::vector<half_t> x_image_data(x_image_shape.production() * 4);  // 4 : RGBA
  default_converter.NCHWToImage(input_v.data(), x_image_data.data(), in_dim);
  // Create the input image from the converted host data; the returned
  // pointer is not needed by the test.
  x.mutable_data<half_t, cl::Image2D>(
      x_image_shape[0], x_image_shape[1], x_image_data.data());
  VLOG(1) << "x_image_data ..... ";
  for (size_t i = 0; i < x_image_data.size(); i++) {
    VLOG(10) << Half2Float(x_image_data[i]);
  }

  DDim out_image_shape = default_converter.InitImageDimInfoWith(out_dim);
  LOG(INFO) << "out_image_shape = " << out_image_shape[0] << " "
            << out_image_shape[1];
  auto* out_image = out.mutable_data<half_t, cl::Image2D>(out_image_shape[0],
                                                          out_image_shape[1]);

  kernel->Launch();
  CLRuntime::Global()->command_queue().finish();

  // Reference output in NCHW order: every input value appears twice in a row.
  std::vector<float> out_data_v{0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7};
  const auto out_count = out_dim.production();
  // Guard the indexed comparison below against a reference/shape mismatch.
  ASSERT_EQ(static_cast<size_t>(out_count), out_data_v.size());

  const size_t cl_image2d_row_pitch{0};
  const size_t cl_image2d_slice_pitch{0};
  // Vectors replace the raw new[] buffers, which were never freed.
  std::vector<half_t> out_image_data(out_image_shape.production() * 4);
  TargetWrapperCL::ImgcpySync(out_image_data.data(),
                              out_image,
                              out_image_shape[0],
                              out_image_shape[1],
                              cl_image2d_row_pitch,
                              cl_image2d_slice_pitch,
                              IoDirection::DtoH);
  VLOG(1) << "out_image_data ..... ";
  for (size_t i = 0; i < out_image_data.size(); i++) {
    VLOG(10) << Half2Float(out_image_data[i]);
  }

  std::vector<float> out_data(out_image_shape.production() * 4);
  default_converter.ImageToNCHW(
      out_image_data.data(), out_data.data(), out_image_shape, out_dim);
  VLOG(1) << "out_data ..... ";
  for (int i = 0; i < out_count; i++) {
    VLOG(10) << out_data[i];
  }

  for (int i = 0; i < out_count; i++) {
    // Compute |a - b| by hand: an unqualified abs() may resolve to the int
    // overload and truncate the floating-point difference.
    const float diff = out_data[i] - out_data_v[i];
    const float abs_diff = diff < 0 ? -diff : diff;
    auto relative_diff = COMPUTE_RELATIVE_DIFF(out_data[i], out_data_v[i]);
    const bool within_tolerance =
        (relative_diff <= FP16_MAX_DIFF) || (abs_diff <= FP16_MAX_DIFF);
    EXPECT_TRUE(within_tolerance);
    if (!within_tolerance) {
      LOG(ERROR) << "error idx:" << i << " out_data[" << i
                 << "]:" << out_data[i] << " "
                                           "out_ref["
                 << i << "]:" << out_data_v[i] << " abs_diff:" << abs_diff
                 << " relative_diff:" << relative_diff
                 << " FP16_MAX_DIFF:" << FP16_MAX_DIFF;
    }
  }
}
// Runs the OpenCL image2d `expand` kernel on a 1x1x2x3 (NCHW) input with
// expand_times {2, 1, 2, 3} and compares the result, converted back to NCHW
// floats, against a hand-written reference within FP16 tolerance.
TEST(expand_n_image2d, compute) {
  LOG(INFO) << "create kernel ...";
  auto kernels = KernelRegistry::Global().Create(
      "expand", TARGET(kOpenCL), PRECISION(kFP16), DATALAYOUT(kImageDefault));
  ASSERT_FALSE(kernels.empty());

  const int INPUT_N = 1;
  const int INPUT_C = 1;
  const int INPUT_H = 2;
  const int INPUT_W = 3;
  const int EXPAND_N = 2;
  const int EXPAND_C = 1;
  const int EXPAND_H = 2;
  const int EXPAND_W = 3;

  auto kernel = std::move(kernels.front());
  LOG(INFO) << "prepare to test kernel ====> " << kernel->doc();

  lite::Tensor x, out;
  operators::ExpandParam param;
  param.X = &x;
  param.Out = &out;
  param.expand_times = {EXPAND_N, EXPAND_C, EXPAND_H, EXPAND_W};

  std::unique_ptr<KernelContext> context(new KernelContext);
  context->As<OpenCLContext>().InitOnce();

  kernel->SetParam(param);
  std::unique_ptr<KernelContext> expand_context(new KernelContext);
  context->As<OpenCLContext>().CopySharedTo(
      &(expand_context->As<OpenCLContext>()));
  kernel->SetContext(std::move(expand_context));

  const DDim in_dim =
      DDim(std::vector<DDim::value_type>{INPUT_N, INPUT_C, INPUT_H, INPUT_W});
  const DDim out_dim = DDim(std::vector<DDim::value_type>{INPUT_N * EXPAND_N,
                                                          INPUT_C * EXPAND_C,
                                                          INPUT_H * EXPAND_H,
                                                          INPUT_W * EXPAND_W});
  LOG(INFO) << "in_dim: " << in_dim;
  // Separate the factors so the log does not read as one number.
  LOG(INFO) << "expand_times: " << EXPAND_N << " " << EXPAND_C << " "
            << EXPAND_H << " " << EXPAND_W;
  LOG(INFO) << "out_dim: " << out_dim;
  x.Resize(in_dim);
  out.Resize(out_dim);

  // Input is the sequence 0, 1, 2, ... so every element is identifiable.
  std::vector<float> input_v(INPUT_N * INPUT_C * INPUT_H * INPUT_W);
  int index = 0;
  for (auto& value : input_v) {
    value = static_cast<float>(index++);
  }
  VLOG(1) << "input_v ..... ";
  for (size_t i = 0; i < input_v.size(); i++) {
    VLOG(10) << input_v[i];
  }

  LOG(INFO) << "prepare input";
  // The converter lives on the stack so it is released on every exit path,
  // including early returns from ASSERT_* macros.
  CLImageConverterDefault default_converter;
  DDim x_image_shape = default_converter.InitImageDimInfoWith(in_dim);
  LOG(INFO) << "x_image_shape = " << x_image_shape[0] << " "
            << x_image_shape[1];
  std::vector<half_t> x_image_data(x_image_shape.production() * 4);  // 4 : RGBA
  default_converter.NCHWToImage(input_v.data(), x_image_data.data(), in_dim);
  // Create the input image from the converted host data; the returned
  // pointer is not needed by the test.
  x.mutable_data<half_t, cl::Image2D>(
      x_image_shape[0], x_image_shape[1], x_image_data.data());
  VLOG(1) << "x_image_data ..... ";
  for (size_t i = 0; i < x_image_data.size(); i++) {
    VLOG(10) << Half2Float(x_image_data[i]);
  }

  DDim out_image_shape = default_converter.InitImageDimInfoWith(out_dim);
  LOG(INFO) << "out_image_shape = " << out_image_shape[0] << " "
            << out_image_shape[1];
  auto* out_image = out.mutable_data<half_t, cl::Image2D>(out_image_shape[0],
                                                          out_image_shape[1]);

  kernel->Launch();
  CLRuntime::Global()->command_queue().finish();

  // Reference output in NCHW order: each input value appears three times in
  // a row, each resulting row appears twice, and the whole plane is listed
  // twice (once per output batch).
  std::vector<float> out_data_v{
      0, 0, 0, 1, 1, 1, 2, 2, 2, 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4,
      5, 5, 5, 3, 3, 3, 4, 4, 4, 5, 5, 5, 0, 0, 0, 1, 1, 1, 2, 2, 2, 0, 0, 0,
      1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 3, 3, 3, 4, 4, 4, 5, 5, 5};
  const auto out_count = out_dim.production();
  // Guard the indexed comparison below against a reference/shape mismatch.
  ASSERT_EQ(static_cast<size_t>(out_count), out_data_v.size());

  const size_t cl_image2d_row_pitch{0};
  const size_t cl_image2d_slice_pitch{0};
  // Vectors replace the raw new[] buffers, which were never freed.
  std::vector<half_t> out_image_data(out_image_shape.production() * 4);
  TargetWrapperCL::ImgcpySync(out_image_data.data(),
                              out_image,
                              out_image_shape[0],
                              out_image_shape[1],
                              cl_image2d_row_pitch,
                              cl_image2d_slice_pitch,
                              IoDirection::DtoH);
  VLOG(1) << "out_image_data ..... ";
  for (size_t i = 0; i < out_image_data.size(); i++) {
    VLOG(10) << Half2Float(out_image_data[i]);
  }

  std::vector<float> out_data(out_image_shape.production() * 4);
  default_converter.ImageToNCHW(
      out_image_data.data(), out_data.data(), out_image_shape, out_dim);
  VLOG(1) << "out_data ..... ";
  for (int i = 0; i < out_count; i++) {
    VLOG(10) << out_data[i];
  }

  for (int i = 0; i < out_count; i++) {
    // Compute |a - b| by hand: an unqualified abs() may resolve to the int
    // overload and truncate the floating-point difference.
    const float diff = out_data[i] - out_data_v[i];
    const float abs_diff = diff < 0 ? -diff : diff;
    auto relative_diff = COMPUTE_RELATIVE_DIFF(out_data[i], out_data_v[i]);
    const bool within_tolerance =
        (relative_diff <= FP16_MAX_DIFF) || (abs_diff <= FP16_MAX_DIFF);
    EXPECT_TRUE(within_tolerance);
    if (!within_tolerance) {
      LOG(ERROR) << "error idx:" << i << " out_data[" << i
                 << "]:" << out_data[i] << " "
                                           "out_ref["
                 << i << "]:" << out_data_v[i] << " abs_diff:" << abs_diff
                 << " relative_diff:" << relative_diff
                 << " FP16_MAX_DIFF:" << FP16_MAX_DIFF;
    }
  }
}
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(expand, kOpenCL, kFP16, kImageDefault, image2d);
......@@ -20,6 +20,10 @@
#include "lite/operators/op_params.h"
#include "lite/utils/replace_stl/stream.h"
#include "lite/utils/string.h"
#ifdef LITE_WITH_PROFILE
#include "lite/core/profile/profiler.h"
#endif
#include "lite/backends/opencl/cl_utility.h"
namespace paddle {
namespace lite {
......@@ -141,16 +145,24 @@ class FcCompute
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
status = EnqueueNDRangeKernel(context,
kernel,
cl::NullRange,
global_work_size_,
cl::NullRange,
nullptr,
nullptr);
event_);
CL_CHECK_FATAL(status);
}
#ifdef LITE_WITH_PROFILE
// Copies this kernel's function name and OpenCL event into the profiler
// record `ch`.
void SetProfileRuntimeKernelInfo(paddle::lite::profile::OpCharacter* ch) {
  // `event_` is declared in `kernel.h` and is only valid after kernel::Run.
  ch->cl_event = event_;
  ch->kernel_func_name = kernel_func_name_;
}
#endif
private:
int m_, n_, k_;
param_t* fc_param_{nullptr};
......
......@@ -23,6 +23,10 @@
#include "lite/operators/op_params.h"
#include "lite/utils/logging.h"
#include "lite/utils/replace_stl/stream.h"
#ifdef LITE_WITH_PROFILE
#include "lite/core/profile/profiler.h"
#endif
#include "lite/backends/opencl/cl_utility.h"
namespace paddle {
namespace lite {
......@@ -131,16 +135,24 @@ class GridSamplerImageCompute : public KernelLite<TARGET(kOpenCL),
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
status = EnqueueNDRangeKernel(context,
kernel,
cl::NullRange,
global_work_size_,
cl::NullRange,
nullptr,
nullptr);
event_);
CL_CHECK_FATAL(status);
}
#ifdef LITE_WITH_PROFILE
// Copies this kernel's function name and OpenCL event into the profiler
// record `ch`.
void SetProfileRuntimeKernelInfo(paddle::lite::profile::OpCharacter* ch) {
  // `event_` is declared in `kernel.h` and is only valid after kernel::Run.
  ch->cl_event = event_;
  ch->kernel_func_name = kernel_func_name_;
}
#endif
protected:
param_t* grid_param_{nullptr};
bool first_epoch_for_reinit_{true};
......
......@@ -23,6 +23,10 @@
#include "lite/operators/op_params.h"
#include "lite/utils/logging.h"
#include "lite/utils/replace_stl/stream.h"
#ifdef LITE_WITH_PROFILE
#include "lite/core/profile/profiler.h"
#endif
#include "lite/backends/opencl/cl_utility.h"
namespace paddle {
namespace lite {
......@@ -137,13 +141,13 @@ class InstanceNormImageCompute : public KernelLite<TARGET(kOpenCL),
status = kernel.setArg(7, *out_img);
CL_CHECK_FATAL(status);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
status = EnqueueNDRangeKernel(context,
kernel,
cl::NullRange,
global_work_size,
local_work_size,
nullptr,
nullptr);
event_);
CL_CHECK_FATAL(status);
}
......@@ -258,17 +262,25 @@ class InstanceNormImageCompute : public KernelLite<TARGET(kOpenCL),
status = kernel.setArg(arg_idx++, in_w);
CL_CHECK_FATAL(status);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
status = EnqueueNDRangeKernel(context,
kernel,
cl::NullRange,
global_work_size,
local_work_size,
nullptr,
nullptr);
event_);
CL_CHECK_FATAL(status);
}
#endif
#ifdef LITE_WITH_PROFILE
// Copies this kernel's function name and OpenCL event into the profiler
// record `ch`.
void SetProfileRuntimeKernelInfo(paddle::lite::profile::OpCharacter* ch) {
  // `event_` is declared in `kernel.h` and is only valid after kernel::Run.
  ch->cl_event = event_;
  ch->kernel_func_name = kernel_func_name_;
}
#endif
protected:
param_t* instance_norm_param_{nullptr};
std::string kernel_func_name_{"instance_norm_onnx"};
......
......@@ -16,19 +16,46 @@
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
#undef LITE_WITH_LOG
namespace paddle {
namespace lite {
namespace kernels {
namespace opencl {
// Returns the current time reported by gettimeofday(), expressed in
// microseconds as a double (seconds scaled by 1e6 plus the microsecond part).
inline double GetCurrentUS() {
  struct timeval now;
  // The timezone argument is unused; pass nullptr rather than the NULL macro.
  gettimeofday(&now, nullptr);
  return 1e+6 * now.tv_sec + now.tv_usec;
}
// Host to OpenCL memory.
void CopyFromHostSync(void* target, const void* source, size_t size) {
// Synchronously copies `size` bytes from host memory `source` to the OpenCL
// memory `target`. Returns the elapsed time in milliseconds when
// LITE_WITH_PROFILE is defined, and 0 otherwise.
float CopyFromHostSync(void* target, const void* source, size_t size) {
#ifdef LITE_WITH_PROFILE
  const auto begin_us = GetCurrentUS();
  TargetWrapperCL::MemcpySync(target, source, size, IoDirection::HtoD);
  // GetCurrentUS() is in microseconds; divide by 1000 to report milliseconds.
  return (GetCurrentUS() - begin_us) / 1000.0;
#else
  TargetWrapperCL::MemcpySync(target, source, size, IoDirection::HtoD);
  return 0.0;
#endif
}
// Device to Host memory.
void CopyToHostSync(void* target, const void* source, size_t size) {
// Calls finish() on the global OpenCL command queue, then synchronously
// copies `size` bytes from the OpenCL memory `source` to host memory
// `target`. Returns the elapsed time in milliseconds (including the
// finish() call) when LITE_WITH_PROFILE is defined, and 0 otherwise.
float CopyToHostSync(void* target, const void* source, size_t size) {
#ifdef LITE_WITH_PROFILE
  const auto begin_us = GetCurrentUS();
  CLRuntime::Global()->command_queue().finish();
  TargetWrapperCL::MemcpySync(target, source, size, IoDirection::DtoH);
  // GetCurrentUS() is in microseconds; divide by 1000 to report milliseconds.
  return (GetCurrentUS() - begin_us) / 1000.0;
#else
  CLRuntime::Global()->command_queue().finish();
  TargetWrapperCL::MemcpySync(target, source, size, IoDirection::DtoH);
  return 0.0;
#endif
}
/*
......@@ -37,6 +64,13 @@ void CopyToHostSync(void* target, const void* source, size_t size) {
class IoCopyHostToOpenCLCompute
: public KernelLite<TARGET(kOpenCL), PRECISION(kAny), DATALAYOUT(kAny)> {
public:
#ifdef LITE_WITH_PROFILE
// Fills the profiler record `ch` with the fixed name "HostToOpenCL" and the
// duration stored in `h2d_duration_`.
void SetProfileRuntimeKernelInfo(paddle::lite::profile::OpCharacter* ch) {
  ch->io_duration = h2d_duration_;
  ch->kernel_func_name = "HostToOpenCL";
}
#endif
void Run() override {
auto& param = Param<operators::IoCopyParam>();
CHECK(param.x->target() == TARGET(kHost) ||
......@@ -50,7 +84,7 @@ class IoCopyHostToOpenCLCompute
VLOG(2) << "param.y->dims():" << param.y->dims();
#endif
auto* data = param.y->mutable_data(TARGET(kOpenCL), mem_size);
CopyFromHostSync(data, param.x->raw_data(), mem_size);
h2d_duration_ = CopyFromHostSync(data, param.x->raw_data(), mem_size);
}
std::unique_ptr<type_infer_handler_t> GetTypeInferHandler() override {
......@@ -74,6 +108,8 @@ class IoCopyHostToOpenCLCompute
}
std::string doc() const override { return "Copy IO from HOST to OpenCL"; }
float h2d_duration_{0};
};
/*
......@@ -82,6 +118,13 @@ class IoCopyHostToOpenCLCompute
class IoCopykOpenCLToHostCompute
: public KernelLite<TARGET(kOpenCL), PRECISION(kAny), DATALAYOUT(kAny)> {
public:
#ifdef LITE_WITH_PROFILE
// Fills the profiler record `ch` with the fixed name "OpenCLToHost" and the
// duration stored in `d2h_duration_`.
void SetProfileRuntimeKernelInfo(paddle::lite::profile::OpCharacter* ch) {
  ch->io_duration = d2h_duration_;
  ch->kernel_func_name = "OpenCLToHost";
}
#endif
void Run() override {
auto& param = Param<operators::IoCopyParam>();
CHECK(param.x->target() == TARGET(kOpenCL));
......@@ -109,12 +152,13 @@ class IoCopykOpenCLToHostCompute
#ifdef LITE_WITH_LOG
VLOG(2) << "--- Find the sync event for the target cl tensor. ---";
#endif
CLRuntime::Global()->command_queue().finish();
CopyToHostSync(data, param.x->raw_data(), mem_size);
d2h_duration_ = CopyToHostSync(data, param.x->raw_data(), mem_size);
}
std::string doc() const override { return "Copy IO from OpenCL to HOST"; }
float d2h_duration_{0};
};
} // namespace opencl
......@@ -161,3 +205,5 @@ REGISTER_LITE_KERNEL(io_copy_once,
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kOpenCL))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))})
.Finalize();
#define LITE_WITH_LOG
......@@ -16,6 +16,7 @@
#include <string>
#include "lite/api/paddle_place.h"
#include "lite/backends/opencl/cl_half.h"
#include "lite/backends/opencl/cl_utility.h"
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
#include "lite/core/target_wrapper.h"
......@@ -24,6 +25,8 @@
#include "lite/operators/op_params.h"
#include "lite/utils/cp_logging.h"
#undef LITE_WITH_LOG
namespace paddle {
namespace lite {
namespace kernels {
......@@ -50,6 +53,14 @@ class LayoutComputeBufferChwToImageDefault
time_stamp_);
}
#ifdef LITE_WITH_PROFILE
// Copies this kernel's function name and OpenCL event into the profiler
// record `ch`.
void SetProfileRuntimeKernelInfo(paddle::lite::profile::OpCharacter* ch) {
  // `event_` is declared in `kernel.h` and is only valid after kernel::Run.
  ch->cl_event = event_;
  ch->kernel_func_name = kernel_func_name_;
}
#endif
void Run() override {
auto& param = Param<param_t>();
const cl::Buffer* x_data;
......@@ -128,13 +139,13 @@ class LayoutComputeBufferChwToImageDefault
static_cast<cl::size_type>(new_dims[3]),
static_cast<cl::size_type>(new_dims[0] * new_dims[2])};
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
status = EnqueueNDRangeKernel(context,
kernel,
cl::NullRange,
global_work_size,
cl::NullRange,
nullptr,
nullptr);
event_);
CL_CHECK_FATAL(status);
}
......@@ -168,6 +179,14 @@ class LayoutComputeImageDefaultToBufferChw
time_stamp_);
}
#ifdef LITE_WITH_PROFILE
// Copies this kernel's function name and OpenCL event into the profiler
// record `ch`.
void SetProfileRuntimeKernelInfo(paddle::lite::profile::OpCharacter* ch) {
  // `event_` is declared in `kernel.h` and is only valid after kernel::Run.
  ch->cl_event = event_;
  ch->kernel_func_name = kernel_func_name_;
}
#endif
void Run() override {
auto& param = Param<param_t>();
const cl::Buffer* y_data;
......@@ -237,13 +256,13 @@ class LayoutComputeImageDefaultToBufferChw
static_cast<cl::size_type>(new_dims[3]),
static_cast<cl::size_type>(new_dims[0] * new_dims[2])};
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
status = EnqueueNDRangeKernel(context,
kernel,
cl::NullRange,
global_work_size,
cl::NullRange,
nullptr,
nullptr);
event_);
CL_CHECK_FATAL(status);
}
......@@ -274,6 +293,14 @@ class LayoutComputeBufferChwToImage2DNw
time_stamp_);
}
#ifdef LITE_WITH_PROFILE
// Copies this kernel's function name and OpenCL event into the profiler
// record `ch`.
void SetProfileRuntimeKernelInfo(paddle::lite::profile::OpCharacter* ch) {
  // `event_` is declared in `kernel.h` and is only valid after kernel::Run.
  ch->cl_event = event_;
  ch->kernel_func_name = kernel_func_name_;
}
#endif
void Run() override {
auto& param = Param<param_t>();
auto* x_data = param.x->data<float, cl::Buffer>();
......@@ -333,13 +360,13 @@ class LayoutComputeBufferChwToImage2DNw
static_cast<cl::size_type>(out_W), // w
static_cast<cl::size_type>(out_C * out_H)}; // ch
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
status = EnqueueNDRangeKernel(context,
kernel,
cl::NullRange,
global_work_size,
cl::NullRange,
nullptr,
nullptr);
event_);
CL_CHECK_FATAL(status);
}
......@@ -394,3 +421,4 @@ REGISTER_LITE_KERNEL(
PRECISION(kAny),
DATALAYOUT(kNCHW))})
.Finalize();
#define LITE_WITH_LOG
......@@ -23,6 +23,10 @@
#include "lite/operators/op_params.h"
#include "lite/utils/logging.h"
#include "lite/utils/replace_stl/stream.h"
#ifdef LITE_WITH_PROFILE
#include "lite/core/profile/profiler.h"
#endif
#include "lite/backends/opencl/cl_utility.h"
namespace paddle {
namespace lite {
......@@ -128,13 +132,13 @@ class LrnImageCompute : public KernelLite<TARGET(kOpenCL),
static_cast<cl::size_type>(default_work_size[1]),
static_cast<cl::size_type>(default_work_size[2])};
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
status = EnqueueNDRangeKernel(context,
kernel,
cl::NullRange,
global_work_size,
cl::NullRange,
nullptr,
nullptr);
event_);
CL_CHECK_FATAL(status);
#ifdef LITE_WITH_LOG
VLOG(4) << "global_work_size:[2D]:" << global_work_size[0] << " "
......@@ -142,6 +146,14 @@ class LrnImageCompute : public KernelLite<TARGET(kOpenCL),
#endif
}
#ifdef LITE_WITH_PROFILE
// Copies this kernel's function name and OpenCL event into the profiler
// record `ch`.
void SetProfileRuntimeKernelInfo(paddle::lite::profile::OpCharacter* ch) {
  // `event_` is declared in `kernel.h` and is only valid after kernel::Run.
  ch->cl_event = event_;
  ch->kernel_func_name = kernel_func_name_;
}
#endif
protected:
param_t* lrn_param_{nullptr};
int n_{5};
......
......@@ -20,6 +20,10 @@
#include "lite/operators/op_params.h"
#include "lite/utils/replace_stl/stream.h"
#include "lite/utils/string.h"
#ifdef LITE_WITH_PROFILE
#include "lite/core/profile/profiler.h"
#endif
#include "lite/backends/opencl/cl_utility.h"
namespace paddle {
namespace lite {
......@@ -92,16 +96,24 @@ class MulCompute
auto global_work_size = cl::NDRange{static_cast<size_t>((m_ + 3) / 4),
static_cast<size_t>((n_ + 3) / 4)};
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
status = EnqueueNDRangeKernel(context,
kernel,
cl::NullRange,
global_work_size,
cl::NullRange,
nullptr,
nullptr);
event_);
CL_CHECK_FATAL(status);
}
#ifdef LITE_WITH_PROFILE
// Copies this kernel's function name and OpenCL event into the profiler
// record `ch`.
void SetProfileRuntimeKernelInfo(paddle::lite::profile::OpCharacter* ch) {
  // `event_` is declared in `kernel.h` and is only valid after kernel::Run.
  ch->cl_event = event_;
  ch->kernel_func_name = kernel_func_name_;
}
#endif
private:
int m_, n_, k_;
std::string kernel_func_name_{"mat_mul"};
......
......@@ -19,6 +19,10 @@
#include "lite/kernels/opencl/image_helper.h"
#include "lite/operators/op_params.h"
#include "lite/utils/replace_stl/stream.h"
#ifdef LITE_WITH_PROFILE
#include "lite/core/profile/profiler.h"
#endif
#include "lite/backends/opencl/cl_utility.h"
namespace paddle {
namespace lite {
......@@ -110,16 +114,24 @@ class NearestInterpComputeImageDefault
static_cast<cl::size_type>(default_work_size.data()[1]),
static_cast<cl::size_type>(default_work_size.data()[2])};
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
status = EnqueueNDRangeKernel(context,
kernel,
cl::NullRange,
global_work_size,
cl::NullRange,
nullptr,
nullptr);
event_);
CL_CHECK_FATAL(status);
}
#ifdef LITE_WITH_PROFILE
// Copies this kernel's function name and OpenCL event into the profiler
// record `ch`.
void SetProfileRuntimeKernelInfo(paddle::lite::profile::OpCharacter* ch) {
  // `event_` is declared in `kernel.h` and is only valid after kernel::Run.
  ch->cl_event = event_;
  ch->kernel_func_name = kernel_func_name_;
}
#endif
private:
std::string kernel_func_name_{"nearest_interp"};
std::string build_options_{" -DCL_DTYPE_half"};
......
......@@ -23,6 +23,10 @@
#include "lite/operators/op_params.h"
#include "lite/utils/logging.h"
#include "lite/utils/replace_stl/stream.h"
#ifdef LITE_WITH_PROFILE
#include "lite/core/profile/profiler.h"
#endif
#include "lite/backends/opencl/cl_utility.h"
namespace paddle {
namespace lite {
......@@ -142,13 +146,13 @@ class Pad2dCompute : public KernelLite<TARGET(kOpenCL),
static_cast<cl::size_type>(default_work_size[1]),
static_cast<cl::size_type>(default_work_size[2])};
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
status = EnqueueNDRangeKernel(context,
kernel,
cl::NullRange,
global_work_size,
cl::NullRange,
nullptr,
nullptr);
event_);
CL_CHECK_FATAL(status);
#ifdef LITE_WITH_LOG
VLOG(4) << "global_work_size:[2D]:" << global_work_size[0] << " "
......@@ -156,6 +160,14 @@ class Pad2dCompute : public KernelLite<TARGET(kOpenCL),
#endif
}
#ifdef LITE_WITH_PROFILE
// Copies this kernel's function name and OpenCL event into the profiler
// record `ch`.
void SetProfileRuntimeKernelInfo(paddle::lite::profile::OpCharacter* ch) {
  // `event_` is declared in `kernel.h` and is only valid after kernel::Run.
  ch->cl_event = event_;
  ch->kernel_func_name = kernel_func_name_;
}
#endif
protected:
param_t* pad2d_param_{nullptr};
std::string kernel_func_name_{};
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <vector>
#include "lite/backends/opencl/cl_half.h"
#include "lite/backends/opencl/cl_include.h"
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
#include "lite/kernels/opencl/image_helper.h"
#include "lite/operators/op_params.h"
#include "lite/utils/replace_stl/stream.h"
#include "lite/utils/string.h"
#ifdef LITE_WITH_PROFILE
#include "lite/core/profile/profiler.h"
#endif
#include "lite/backends/opencl/cl_utility.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace opencl {
// OpenCL kernel wrapper for the pixel_shuffle operator working on
// cl::Image2D tensors (FP16 precision, default image layout).
class PixelShuffleComputeImage2D
    : public KernelLite<TARGET(kOpenCL),
                        PRECISION(kFP16),
                        DATALAYOUT(kImageDefault)> {
 public:
  using param_t = operators::PixelShuffleParam;

  // Short human-readable description of this kernel.
  std::string doc() const override {
    return "PixelShuffle using cl::Image2D, kFP16";
  }

  // Adds the pixel_shuffle program to the OpenCL context and caches the
  // kernel object, which is keyed by name + build options + time stamp.
  void PrepareForRun() override {
    VLOG(1) << "kernel_func_name_:" << kernel_func_name_;
    auto& ocl_ctx = ctx_->As<OpenCLContext>();
    ocl_ctx.cl_context()->AddKernel(kernel_func_name_,
                                    "image/pixel_shuffle_kernel.cl",
                                    build_options_,
                                    time_stamp_);

    STL::stringstream key_stream;
    key_stream << kernel_func_name_ << build_options_ << time_stamp_;
    kernel_ = ocl_ctx.cl_context()->GetKernel(key_stream.str());
  }

  // Refreshes the cached output image shape and global work size. The
  // recomputation happens on the first call and whenever the input dims
  // differ from the ones seen on the previous recomputation.
  void ReInitWhenNeeded() override {
    VLOG(1) << "ReInitWhenNeeded: " << kernel_func_name_;
    pixel_shuffle_param_ = param_.get_mutable<param_t>();
    auto x_dims = pixel_shuffle_param_->x->dims();
    auto out_dims = pixel_shuffle_param_->output->dims();
    VLOG(1) << "x_dims: " << x_dims;
    VLOG(1) << "out_dims: " << out_dims;
    VLOG(1) << "upscale_factor: " << pixel_shuffle_param_->upscale_factor;

    // Nothing to do when this is not the first call and the shape is stable.
    const bool needs_reinit =
        first_epoch_for_reinit_ || x_dims != last_x_dims_;
    if (!needs_reinit) {
      return;
    }
    last_x_dims_ = x_dims;
    first_epoch_for_reinit_ = false;

    // compute image shape
    paddle::lite::CLImageConverterDefault shape_converter;
    out_img_shape_ = shape_converter.InitImageDimInfoWith(
        pixel_shuffle_param_->output->dims());
    VLOG(1) << "out_img_shape_: " << out_img_shape_[0] << " "
            << out_img_shape_[1];

    // compute global work size
    auto image_width = out_dims[3] * ((out_dims[1] + 3) / 4);
    size_t gws_0 = image_width / out_dims[3];
    size_t gws_1 = out_dims[3];
    size_t gws_2 = out_dims[0] * out_dims[2];
    global_work_size_ = cl::NDRange{gws_0, gws_1, gws_2};
    VLOG(1) << "global_work_size_: " << global_work_size_[0] << " "
            << global_work_size_[1] << " " << global_work_size_[2];
  }

  // Binds the input/output images and the shape scalars as kernel arguments
  // 0..10, then enqueues the kernel over the cached global work size.
  void Run() override {
    auto* x_img = pixel_shuffle_param_->x->data<half_t, cl::Image2D>();
    auto* out_img =
        pixel_shuffle_param_->output->mutable_data<half_t, cl::Image2D>(
            out_img_shape_[0], out_img_shape_[1]);

    auto x_dims = pixel_shuffle_param_->x->dims();
    int in_n = x_dims[0];
    int in_c = x_dims[1];
    int in_h = x_dims[2];
    int in_w = x_dims[3];

    auto out_dims = pixel_shuffle_param_->output->dims();
    int out_n = out_dims[0];
    int out_c = out_dims[1];
    int out_h = out_dims[2];
    int out_w = out_dims[3];

    const int upscale_factor = pixel_shuffle_param_->upscale_factor;

    auto& context = ctx_->As<OpenCLContext>();
    CHECK(context.cl_context() != nullptr);

    // Work on a copy of the cached kernel object, as the arguments are set
    // per invocation.
    auto kernel = kernel_;
    cl_int status;

    // Arguments 0 and 1: source and destination images.
    status = kernel.setArg(0, *x_img);
    CL_CHECK_FATAL(status);
    status = kernel.setArg(1, *out_img);
    CL_CHECK_FATAL(status);

    // Arguments 2..10: input NCHW, output NCHW, upscale factor (all int).
    const int scalar_args[] = {in_n,
                               in_c,
                               in_h,
                               in_w,
                               out_n,
                               out_c,
                               out_h,
                               out_w,
                               upscale_factor};
    int arg_idx = 2;
    for (const int scalar : scalar_args) {
      status = kernel.setArg(arg_idx, scalar);
      CL_CHECK_FATAL(status);
      ++arg_idx;
    }

    status = EnqueueNDRangeKernel(context,
                                  kernel,
                                  cl::NullRange,
                                  global_work_size_,
                                  cl::NullRange,
                                  nullptr,
                                  event_);
    CL_CHECK_FATAL(status);
  }

#ifdef LITE_WITH_PROFILE
  // Profiling hook: exposes the kernel function name and the event handle
  // that Run() passes to EnqueueNDRangeKernel.
  void SetProfileRuntimeKernelInfo(paddle::lite::profile::OpCharacter* ch) {
    // `event_` is defined in `kernel.h` and is valid after kernel::Run.
    ch->cl_event = event_;
    ch->kernel_func_name = kernel_func_name_;
  }
#endif

 private:
  std::string kernel_func_name_{"pixel_shuffle"};
  std::string build_options_{"-DCL_DTYPE_half"};
  std::string time_stamp_{GetTimeStamp()};

  param_t* pixel_shuffle_param_{nullptr};
  cl::Kernel kernel_;
  // True until ReInitWhenNeeded() has run once.
  bool first_epoch_for_reinit_{true};
  // Input dims seen at the last recomputation in ReInitWhenNeeded().
  DDim last_x_dims_;
  // Placeholder 1x1 image shape until ReInitWhenNeeded() computes the real one.
  DDim out_img_shape_ = DDim(std::vector<DDim::value_type>(
      {static_cast<DDim::value_type>(1), static_cast<DDim::value_type>(1)}));
  // Placeholder 1x1x1 range until ReInitWhenNeeded() computes the real one.
  cl::NDRange global_work_size_ = cl::NDRange{
      static_cast<size_t>(1), static_cast<size_t>(1), static_cast<size_t>(1)};
};
} // namespace opencl
} // namespace kernels
} // namespace lite
} // namespace paddle
// Registers PixelShuffleComputeImage2D as the `pixel_shuffle` kernel for
// kOpenCL / kFP16 / kImageDefault under the alias `image2d`. Both the "X"
// input and the "Out" output are bound to tensors with that same
// target / precision / layout triple.
REGISTER_LITE_KERNEL(pixel_shuffle,
kOpenCL,
kFP16,
kImageDefault,
paddle::lite::kernels::opencl::PixelShuffleComputeImage2D,
image2d)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kImageDefault))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kImageDefault))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <cmath>
#include <memory>
#include <random>
#include <vector>
#include "lite/backends/opencl/target_wrapper.h"
#include "lite/core/op_registry.h"
#include "lite/core/tensor.h"
#include "lite/kernels/opencl/test_helper.h"
#define FP16_MAX_DIFF (5e-1)
namespace paddle {
namespace lite {
// Runs the OpenCL FP16 Image2D pixel_shuffle kernel on a 1x4x2x2 ramp input
// (values 0..15) with upscale factor 2 and compares the result, converted
// back to NCHW floats, against a hard-coded reference.
//
// All host-side buffers are owned by std::vector and the image converter
// lives on the stack, so nothing is leaked regardless of how the test exits.
TEST(pixel_shuffle_image2d, compute) {
  LOG(INFO) << "create kernel ...";
  auto kernels = KernelRegistry::Global().Create("pixel_shuffle",
                                                 TARGET(kOpenCL),
                                                 PRECISION(kFP16),
                                                 DATALAYOUT(kImageDefault));
  ASSERT_FALSE(kernels.empty());

  const int INPUT_N = 1;
  const int INPUT_C = 4;
  const int INPUT_H = 2;
  const int INPUT_W = 2;
  const int UPSCALE_FACTOR = 2;

  auto kernel = std::move(kernels.front());
  LOG(INFO) << "prepare to test kernel ====> " << kernel->doc();

  lite::Tensor x, out;
  operators::PixelShuffleParam param;
  param.x = &x;
  param.output = &out;
  param.upscale_factor = UPSCALE_FACTOR;

  std::unique_ptr<KernelContext> context(new KernelContext);
  context->As<OpenCLContext>().InitOnce();
  kernel->SetParam(param);
  std::unique_ptr<KernelContext> pixel_shuffle_context(new KernelContext);
  context->As<OpenCLContext>().CopySharedTo(
      &(pixel_shuffle_context->As<OpenCLContext>()));
  kernel->SetContext(std::move(pixel_shuffle_context));

  const DDim in_dim =
      DDim(std::vector<DDim::value_type>{INPUT_N, INPUT_C, INPUT_H, INPUT_W});
  const DDim out_dim = DDim(
      std::vector<DDim::value_type>{INPUT_N,
                                    INPUT_C / UPSCALE_FACTOR / UPSCALE_FACTOR,
                                    INPUT_H * UPSCALE_FACTOR,
                                    INPUT_W * UPSCALE_FACTOR});
  LOG(INFO) << "in_dim: " << in_dim;
  LOG(INFO) << "UPSCALE_FACTOR: " << UPSCALE_FACTOR;
  LOG(INFO) << "out_dim: " << out_dim;

  x.Resize(in_dim);
  out.Resize(out_dim);

  // Deterministic ramp input: element k holds the value k.
  std::vector<float> input_v(INPUT_N * INPUT_C * INPUT_H * INPUT_W);
  int index = 0;
  for (auto& i : input_v) {
    i = index++;
  }
  VLOG(1) << "input_v ..... ";
  for (size_t i = 0; i < input_v.size(); i++) {
    VLOG(10) << input_v[i];
  }

  LOG(INFO) << "prepare input";
  // Stack-allocated converter: previously heap-allocated and never deleted.
  CLImageConverterDefault default_converter;
  DDim x_image_shape = default_converter.InitImageDimInfoWith(in_dim);
  LOG(INFO) << "x_image_shape = " << x_image_shape[0] << " "
            << x_image_shape[1];
  std::vector<half_t> x_image_data(x_image_shape.production() * 4);  // 4 : RGBA
  default_converter.NCHWToImage(input_v.data(), x_image_data.data(), in_dim);
  auto* x_image = x.mutable_data<half_t, cl::Image2D>(
      x_image_shape[0], x_image_shape[1], x_image_data.data());
  VLOG(1) << "x_image_data ..... ";
  for (size_t i = 0; i < x_image_data.size(); i++) {
    VLOG(10) << Half2Float(x_image_data[i]);
  }

  DDim out_image_shape = default_converter.InitImageDimInfoWith(out_dim);
  LOG(INFO) << "out_image_shape = " << out_image_shape[0] << " "
            << out_image_shape[1];
  auto* out_image = out.mutable_data<half_t, cl::Image2D>(out_image_shape[0],
                                                          out_image_shape[1]);

  kernel->Launch();
  CLRuntime::Global()->command_queue().finish();

  // Expected NCHW output for the ramp input above.
  std::vector<float> out_data_v{
      0, 4, 1, 5, 8, 12, 9, 13, 2, 6, 3, 7, 10, 14, 11, 15};

  const size_t cl_image2d_row_pitch{0};
  const size_t cl_image2d_slice_pitch{0};
  // 4 channels (RGBA) per image pixel.
  const size_t out_image_numel =
      static_cast<size_t>(out_image_shape.production()) * 4;
  std::vector<half_t> out_image_data(out_image_numel);
  TargetWrapperCL::ImgcpySync(out_image_data.data(),
                              out_image,
                              out_image_shape[0],
                              out_image_shape[1],
                              cl_image2d_row_pitch,
                              cl_image2d_slice_pitch,
                              IoDirection::DtoH);
  VLOG(1) << "out_image_data ..... ";
  for (size_t i = 0; i < out_image_data.size(); i++) {
    VLOG(10) << Half2Float(out_image_data[i]);
  }

  std::vector<float> out_data(out_image_numel);
  default_converter.ImageToNCHW(
      out_image_data.data(), out_data.data(), out_image_shape, out_dim);
  const int64_t out_numel = out_dim.production();
  VLOG(1) << "out_data ..... ";
  for (int64_t i = 0; i < out_numel; i++) {
    VLOG(10) << out_data[i];
  }

  for (int64_t i = 0; i < out_numel; i++) {
    // std::fabs guarantees a floating-point absolute value; an unqualified
    // `abs` may resolve to the integer overload and truncate the difference.
    const float abs_diff = std::fabs(out_data[i] - out_data_v[i]);
    auto relative_diff = COMPUTE_RELATIVE_DIFF(out_data[i], out_data_v[i]);
    EXPECT_EQ((relative_diff <= FP16_MAX_DIFF) || (abs_diff <= FP16_MAX_DIFF),
              true);
    if ((relative_diff > FP16_MAX_DIFF) && (abs_diff > FP16_MAX_DIFF)) {
      LOG(ERROR) << "error idx:" << i << " out_data[" << i
                 << "]:" << out_data[i] << " "
                                           "out_ref["
                 << i << "]:" << out_data_v[i] << " abs_diff:" << abs_diff
                 << " relative_diff:" << relative_diff
                 << " FP16_MAX_DIFF:" << FP16_MAX_DIFF;
    }
  }
}
} // namespace lite
} // namespace paddle
// References the pixel_shuffle kernel created at the top of the test above
// (kOpenCL / kFP16 / kImageDefault, alias `image2d`).
// NOTE(review): presumably this forces the kernel's registration object to
// be linked into the test binary; confirm against op_registry.h.
USE_LITE_KERNEL(pixel_shuffle, kOpenCL, kFP16, kImageDefault, image2d);
......@@ -20,6 +20,10 @@
#include "lite/operators/op_params.h"
#include "lite/utils/replace_stl/stream.h"
#include "lite/utils/string.h"
#ifdef LITE_WITH_PROFILE
#include "lite/core/profile/profiler.h"
#endif
#include "lite/backends/opencl/cl_utility.h"
namespace paddle {
namespace lite {
......@@ -106,16 +110,24 @@ class PoolCompute
CL_CHECK_FATAL(status);
auto global_work_size = cl::NDRange(static_cast<size_t>(numel));
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
status = EnqueueNDRangeKernel(context,
kernel,
cl::NullRange,
global_work_size,
cl::NullRange,
nullptr,
nullptr);
event_);
CL_CHECK_FATAL(status);
}
#ifdef LITE_WITH_PROFILE
// Profiling hook: stores the OpenCL kernel function name and the kernel's
// event handle in the profiler record `ch`.
void SetProfileRuntimeKernelInfo(paddle::lite::profile::OpCharacter* ch) {
  // `event_` is defined in `kernel.h`, valid after kernel::Run.
  ch->cl_event = event_;
  ch->kernel_func_name = kernel_func_name_;
}
#endif
private:
std::string kernel_func_name_{"pool_"};
std::string build_options_{"-DCL_DTYPE_float"};
......
......@@ -22,6 +22,12 @@
#include "lite/operators/op_params.h"
#include "lite/utils/replace_stl/stream.h"
#include "lite/utils/string.h"
#ifdef LITE_WITH_PROFILE
#include "lite/core/profile/profiler.h"
#endif
#include "lite/backends/opencl/cl_utility.h"
#undef LITE_WITH_LOG
namespace paddle {
namespace lite {
......@@ -50,6 +56,14 @@ class PoolComputeImage2D : public KernelLite<TARGET(kOpenCL),
kernel_func_name_, "image/pool_kernel.cl", build_options_, time_stamp_);
}
#ifdef LITE_WITH_PROFILE
// Profiling hook: publishes the kernel function name and the event handle
// held in `event_` to the profiler record `ch`.
void SetProfileRuntimeKernelInfo(paddle::lite::profile::OpCharacter* ch) {
  // `event_` is defined in `kernel.h`; only valid once kernel::Run has run.
  ch->cl_event = event_;
  ch->kernel_func_name = kernel_func_name_;
}
#endif
void Run() override {
const auto& param = *param_.get_mutable<param_t>();
const auto& in_dims = param.x->dims();
......@@ -150,13 +164,13 @@ class PoolComputeImage2D : public KernelLite<TARGET(kOpenCL),
status = kernel.setArg(++arg_idx, static_cast<const int>(paddings[0]));
CL_CHECK_FATAL(status);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
status = EnqueueNDRangeKernel(context,
kernel,
cl::NullRange,
global_work_size,
cl::NullRange,
nullptr,
nullptr);
event_);
CL_CHECK_FATAL(status);
}
......@@ -186,3 +200,4 @@ REGISTER_LITE_KERNEL(pool2d,
PRECISION(kFP16),
DATALAYOUT(kImageDefault))})
.Finalize();
#define LITE_WITH_LOG
......@@ -20,6 +20,12 @@
#include "lite/operators/op_params.h"
#include "lite/utils/logging.h"
#include "lite/utils/replace_stl/stream.h"
#ifdef LITE_WITH_PROFILE
#include "lite/core/profile/profiler.h"
#endif
#include "lite/backends/opencl/cl_utility.h"
#undef LITE_WITH_LOG
namespace paddle {
namespace lite {
......@@ -42,6 +48,14 @@ class ReshapeComputeFloatImage : public KernelLite<TARGET(kOpenCL),
time_stamp_);
}
#ifdef LITE_WITH_PROFILE
// Profiling hook: fills `ch` with the name of the OpenCL kernel function
// and the event handle of this kernel.
void SetProfileRuntimeKernelInfo(paddle::lite::profile::OpCharacter* ch) {
  // `event_` is defined in `kernel.h` and becomes valid after kernel::Run.
  ch->cl_event = event_;
  ch->kernel_func_name = kernel_func_name_;
}
#endif
void Run() override {
auto& param = *param_.get_mutable<param_t>();
const Tensor* const x = param.x;
......@@ -154,13 +168,13 @@ class ReshapeComputeFloatImage : public KernelLite<TARGET(kOpenCL),
static_cast<size_t>(default_work_size.data()[1]),
static_cast<size_t>(default_work_size.data()[2])};
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
status = EnqueueNDRangeKernel(context,
kernel,
cl::NullRange,
global_work_size,
cl::NullRange,
nullptr,
nullptr);
event_);
CL_CHECK_FATAL(status);
}
......@@ -246,3 +260,4 @@ REGISTER_LITE_KERNEL(flatten2,
PRECISION(kFP16),
DATALAYOUT(kImageDefault))})
.Finalize();
#define LITE_WITH_LOG
......@@ -21,6 +21,10 @@
#include "lite/operators/op_params.h"
#include "lite/utils/replace_stl/stream.h"
#include "lite/utils/string.h"
#ifdef LITE_WITH_PROFILE
#include "lite/core/profile/profiler.h"
#endif
#include "lite/backends/opencl/cl_utility.h"
namespace paddle {
namespace lite {
......@@ -93,16 +97,24 @@ class ScaleComputeImage2D : public KernelLite<TARGET(kOpenCL),
status = kernel.setArg(3, bias);
CL_CHECK_FATAL(status);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
status = EnqueueNDRangeKernel(context,
kernel,
cl::NullRange,
global_work_size_,
cl::NullRange,
nullptr,
nullptr);
event_);
CL_CHECK_FATAL(status);
}
#ifdef LITE_WITH_PROFILE
// Profiling hook: records the kernel function name and the `event_` handle
// in the profiler record `ch`.
void SetProfileRuntimeKernelInfo(paddle::lite::profile::OpCharacter* ch) {
  // `event_` is defined in `kernel.h`, valid after kernel::Run.
  ch->cl_event = event_;
  ch->kernel_func_name = kernel_func_name_;
}
#endif
private:
std::string kernel_func_name_{"scale"};
std::string build_options_{"-DCL_DTYPE_half"};
......
......@@ -21,6 +21,10 @@
#include "lite/operators/op_params.h"
#include "lite/utils/replace_stl/stream.h"
#include "lite/utils/string.h"
#ifdef LITE_WITH_PROFILE
#include "lite/core/profile/profiler.h"
#endif
#include "lite/backends/opencl/cl_utility.h"
namespace paddle {
namespace lite {
......@@ -96,16 +100,24 @@ class SliceComputeImage2D : public KernelLite<TARGET(kOpenCL),
static_cast<cl::size_type>(default_work_size.data()[1]),
static_cast<cl::size_type>(default_work_size.data()[2])};
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
status = EnqueueNDRangeKernel(context,
kernel,
cl::NullRange,
global_work_size,
cl::NullRange,
nullptr,
nullptr);
event_);
CL_CHECK_FATAL(status);
}
#ifdef LITE_WITH_PROFILE
// Profiling hook: passes the OpenCL kernel function name and the kernel's
// event handle to the profiler through `ch`.
void SetProfileRuntimeKernelInfo(paddle::lite::profile::OpCharacter* ch) {
  // `event_` is defined in `kernel.h`; it is valid after kernel::Run.
  ch->cl_event = event_;
  ch->kernel_func_name = kernel_func_name_;
}
#endif
private:
std::string kernel_func_name_{"slice"};
std::string build_options_{"-DCL_DTYPE_half"};
......
......@@ -45,6 +45,8 @@ class BlockDesc : public BlockDescAPI {
template <typename T>
T* GetVar(int32_t idx);
// Mutable access to the underlying variable list of this block.
std::vector<VarDesc>& GetVars() {
  return vars_;
}
template <typename T>
T* AddVar();
......
......@@ -36,6 +36,8 @@ class ProgramDesc : public ProgramDescAPI {
template <typename T>
T* GetBlock(int32_t idx);
// Mutable access to the underlying block list of this program.
std::vector<BlockDesc>& GetBlocks() {
  return blocks_;
}
template <typename T>
T* AddBlock();
......
......@@ -108,6 +108,7 @@ add_operator(collect_fpn_proposals_op_lite extra SRCS collect_fpn_proposals_op.c
add_operator(distribute_fpn_proposals_op_lite extra SRCS distribute_fpn_proposals_op.cc DEPS ${op_DEPS})
add_operator(crf_decoding_op_lite extra SRCS crf_decoding_op.cc DEPS ${op_DEPS})
add_operator(ctc_align_op_lite extra SRCS ctc_align_op.cc DEPS ${op_DEPS})
# pixel_shuffle operator, added to the `extra` operator group.
add_operator(pixel_shuffle_op extra SRCS pixel_shuffle_op.cc DEPS ${op_DEPS})
# for OCR specific
add_operator(while_op extra SRCS while_op.cc DEPS ${op_DEPS})
......
......@@ -15,6 +15,9 @@
#pragma once
#include <string>
#include "lite/core/op_lite.h"
#ifdef LITE_WITH_PROFILE
#include "lite/api/paddle_place.h"
#endif
namespace paddle {
namespace lite {
......@@ -34,6 +37,58 @@ class ActivationOp : public OpLite {
std::string DebugString() const override { return "activation_op"; }
#ifdef LITE_WITH_PROFILE
// Profiling hook: records input/output shapes, the activation type name as
// the remark, and a per-element cost estimate that depends on the
// activation type. Unknown activation types are a fatal error.
void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter* ch) {
  ch->input_shape = ch->DimToStr(param_.X->dims());
  ch->output_shape = ch->DimToStr(param_.Out->dims());
  ch->remark = ActivationTypeToStr(param_.active_type);

  const auto elem_count = param_.X->numel();
  switch (param_.active_type) {
    // One operation per element.
    case lite_api::ActivationType::kRelu:
    case lite_api::ActivationType::kExp:
    case lite_api::ActivationType::kAbs:
    case lite_api::ActivationType::kReciprocal:
      ch->macs = elem_count;
      break;
    // Two operations per element.
    case lite_api::ActivationType::kRelu6:
    case lite_api::ActivationType::kLeakyRelu:
    case lite_api::ActivationType::kPRelu:
      ch->macs = elem_count * 2.0;
      break;
    case lite_api::ActivationType::kSigmoid:
      ch->macs = elem_count * 3.0;
      break;
    case lite_api::ActivationType::kSwish:
      ch->macs = elem_count * 4.0;
      break;
    // Five operations per element.
    case lite_api::ActivationType::kTanh:
    case lite_api::ActivationType::kHardSwish:
      ch->macs = elem_count * 5.0;
      break;
    // Identity leaves `macs` untouched.
    case lite_api::ActivationType::kIndentity:
      break;
    default:
      LOG(FATAL) << "This Type of Activation:"
                 << static_cast<int>(param_.active_type)
                 << ActivationTypeToStr(param_.active_type)
                 << " doesn't support";
  }
}
#endif
private:
mutable operators::ActivationParam param_;
};
......
......@@ -39,6 +39,17 @@ class AffineChannelOpLite : public OpLite {
std::string DebugString() const override { return "affine_channel"; }
#ifdef LITE_WITH_PROFILE
// Profiling hook: records input/output shapes, the data layout as the
// remark, and a cost estimate of two operations per input element.
void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter *ch) {
  ch->input_shape = ch->DimToStr(param_.X->dims());
  ch->output_shape = ch->DimToStr(param_.Out->dims());
  ch->remark = param_.data_layout;
  ch->macs = param_.X->numel() * 2.0;
}
#endif
private:
mutable AffineChannelParam param_;
};
......
......@@ -39,6 +39,27 @@ class ArgmaxOpLite : public OpLite {
std::string DebugString() const override { return "argmax"; }
#ifdef LITE_WITH_PROFILE
// Profiling hook: records input/output shapes, the axis as the remark, and
// a cost estimate derived from the dimensions that follow the (normalized)
// axis.
//
// NOTE(review): the estimate multiplies 1 * 2 * ... * max_num in a float,
// i.e. a factorial, which leaves the finite float range very quickly and
// loops max_num times. This looks unintended; confirm the intended formula.
void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter *ch) {
  auto in_dims = param_.X->dims();
  auto out_dims = param_.Out->dims();
  ch->input_shape = ch->DimToStr(in_dims);
  ch->output_shape = ch->DimToStr(out_dims);
  ch->remark = "axis" + std::to_string(param_.Axis);

  // Map a negative axis onto the equivalent non-negative index.
  auto axis = param_.Axis;
  if (axis < 0) {
    axis += in_dims.size();
  }

  // Product of all dimensions after `axis`.
  int max_num = 1;
  int64_t dim_idx = axis + 1;
  while (dim_idx < in_dims.size()) {
    max_num *= in_dims[dim_idx];
    dim_idx++;
  }

  // Running product 1 * 2 * ... * max_num.
  float gops = 1.0f;
  int factor = 1;
  while (factor <= max_num) {
    gops *= factor;
    factor++;
  }
  ch->macs = gops * out_dims.production();
}
#endif
private:
mutable ArgmaxParam param_;
};
......
......@@ -37,6 +37,17 @@ class AssignOpLite : public OpLite {
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "assign"; }
#ifdef LITE_WITH_PROFILE
// Profiling hook: records input/output shapes and a cost estimate of one
// operation per input element. No remark is set for this op.
void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter *ch) {
  ch->input_shape = ch->DimToStr(param_.X->dims());
  ch->output_shape = ch->DimToStr(param_.Out->dims());
  ch->macs = param_.X->numel() * 1.0;
}
#endif
private:
mutable AssignParam param_;
};
......
......@@ -39,6 +39,17 @@ class AssignValueOpLite : public OpLite {
std::string DebugString() const override { return "assign value"; }
#ifdef LITE_WITH_PROFILE
// Profiling hook: records the output shape, the dtype as the remark, and a
// cost estimate of one operation per output element. The input shape is
// not recorded for this op.
void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter *ch) {
  ch->output_shape = ch->DimToStr(param_.Out->dims());
  ch->remark = "dtype" + std::to_string(param_.dtype);
  ch->macs = param_.Out->numel() * 1.0;
}
#endif
private:
mutable AssignValueParam param_;
};
......
......@@ -39,6 +39,17 @@ class AxpyOpLite : public OpLite {
std::string DebugString() const override { return "axpy"; }
#ifdef LITE_WITH_PROFILE
// Profiling hook: records input/output shapes and a cost estimate of two
// operations per input element. No remark is set for this op.
void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter *ch) {
  ch->input_shape = ch->DimToStr(param_.X->dims());
  ch->output_shape = ch->DimToStr(param_.Out->dims());
  ch->macs = param_.X->numel() * 2.0;
}
#endif
private:
mutable AxpyParam param_;
};
......
......@@ -37,6 +37,17 @@ class BatchNormOp : public OpLite {
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "batch_norm"; }
#ifdef LITE_WITH_PROFILE
// Profiling hook: records the shapes of `x` and `y` and a cost estimate of
// two operations per output element. No remark is set for this op.
void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter *ch) {
  ch->input_shape = ch->DimToStr(param_.x->dims());
  ch->output_shape = ch->DimToStr(param_.y->dims());
  ch->macs = param_.y->numel() * 2.0;
}
#endif
private:
mutable BatchNormParam param_;
};
......
......@@ -39,6 +39,17 @@ class BoxClipOpLite : public OpLite {
std::string DebugString() const override { return "box clip"; }
#ifdef LITE_WITH_PROFILE
// Profiling hook: records the shapes of `Input` and `Output` and a cost
// estimate of two operations per output element. No remark is set.
void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter *ch) {
  ch->input_shape = ch->DimToStr(param_.Input->dims());
  ch->output_shape = ch->DimToStr(param_.Output->dims());
  ch->macs = param_.Output->numel() * 2.0;
}
#endif
private:
mutable BoxClipParam param_;
};
......
......@@ -34,8 +34,21 @@ class BoxCoderOpLite : public OpLite {
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "box_coder"; }
#ifdef LITE_WITH_PROFILE
// Profiling hook: records the first two dimensions of `proposals` as the
// remark ("proposals<d0>x<d1>") and a cost estimate of 30 operations per
// d0 * d1 entry. Input/output shapes are not recorded for this op.
void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter *ch) {
  const auto dim0 = param_.proposals->dims()[0];
  const auto dim1 = param_.proposals->dims()[1];
  ch->remark =
      "proposals" + std::to_string(dim0) + "x" + std::to_string(dim1);
  ch->macs = dim0 * dim1 * 30.f;
}
#endif
private:
mutable BoxCoderParam param_;
};
......
......@@ -50,6 +50,17 @@ class CalibOpLite : public OpLite {
std::string DebugString() const override { return "calib"; }
#ifdef LITE_WITH_PROFILE
// Profiling hook: records input/output shapes, the scale as the remark,
// and a cost estimate of one operation per output element.
void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter *ch) {
  ch->input_shape = ch->DimToStr(param_.input->dims());
  ch->output_shape = ch->DimToStr(param_.output->dims());
  ch->remark = "scale" + std::to_string(param_.scale);
  ch->macs = param_.output->numel() * 1.0f;
}
#endif
private:
mutable CalibParam param_;
};
......
......@@ -38,6 +38,18 @@ class CompareOp : public OpLite {
std::string DebugString() const override { return "binary logical"; }
#ifdef LITE_WITH_PROFILE
// Profiling hook: records the shapes of both operands ("X:<dims>Y:<dims>"),
// the output shape, axis and force_cpu as the remark, and a cost estimate
// of one operation per output element.
void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter *ch) {
  const std::string x_shape = ch->DimToStr(param_.X->dims());
  const std::string y_shape = ch->DimToStr(param_.Y->dims());
  ch->input_shape = "X:" + x_shape + "Y:" + y_shape;
  ch->output_shape = ch->DimToStr(param_.Out->dims());

  const std::string axis_text = std::to_string(param_.axis);
  const std::string force_cpu_text = std::to_string(param_.force_cpu);
  ch->remark = "axis" + axis_text + "force_cpu" + force_cpu_text;

  ch->macs = param_.Out->numel() * 1.0f;
}
#endif
private:
mutable CompareParam param_;
};
......
......@@ -37,6 +37,21 @@ class ConcatOpLite : public OpLite {
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "concat"; }
#ifdef LITE_WITH_PROFILE
// Profiling hook: records the shapes of all inputs joined by "/", the
// output shape and the axis as the remark. The cost estimate is zero
// because concat only moves data.
void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter *ch) {
  std::string joined_shapes = "";
  const size_t input_count = param_.x.size();
  for (size_t idx = 0; idx < input_count; ++idx) {
    // Separator goes between entries, never after the last one.
    if (idx > 0) joined_shapes += "/";
    joined_shapes += ch->DimToStr(param_.x[idx]->dims());
  }
  ch->input_shape = joined_shapes;
  ch->output_shape = ch->DimToStr(param_.output->dims());
  ch->remark = "axis" + std::to_string(param_.axis);
  ch->macs = 0.f;  // no calc. only io operation
}
#endif
private:
mutable ConcatParam param_;
};
......
......@@ -22,6 +22,9 @@
#include "lite/core/tensor.h"
#include "lite/operators/op_params.h"
#include "lite/utils/all.h"
#ifdef LITE_WITH_PROFILE
#include "lite/api/paddle_place.h"
#endif
namespace paddle {
namespace lite {
......@@ -36,6 +39,29 @@ class ConvOpLite : public OpLite {
bool CheckShape() const override;
bool InferShapeImpl() const override;
#ifdef LITE_WITH_PROFILE
// Profiling hook: records input/output/filter shapes, a compact remark of
// the form "<f2>x<f3>p<pad>s<stride>g<groups>d<dilation>[Bias]<act>", and
// the cost estimate of the convolution.
void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter* ch) {
  auto filter_dims = param_.filter->dims();
  auto input_dims = param_.x->dims();
  auto output_dims = param_.output->dims();
  ch->input_shape = ch->DimToStr(input_dims);
  ch->output_shape = ch->DimToStr(output_dims);
  ch->filter_shape = ch->DimToStr(filter_dims);

  // Only the first padding / stride / dilation entry goes into the remark.
  std::string remark = std::to_string(filter_dims[2]);
  remark += "x";
  remark += std::to_string(filter_dims[3]);
  remark += "p";
  remark += std::to_string((*param_.paddings)[0]);
  remark += "s";
  remark += std::to_string(param_.strides[0]);
  remark += "g";
  remark += std::to_string(param_.groups);
  remark += "d";
  remark += std::to_string((*param_.dilations)[0]);
  remark += (param_.bias ? "Bias" : "");
  remark += ActivationTypeToStr(param_.activation_param.active_type);
  ch->remark = remark;

  // MACs = 2.f * kw * kh * batchsize * out_c * out_h * out_w * in_c / group
  // GMACs = 1e-9f * MACs
  // GMACPS = 1e-6f * MACs / predict_ms
  ch->macs = 2.f * filter_dims[2] * filter_dims[3] *
             output_dims.production() * input_dims[1] / param_.groups;
}
#endif
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) override {
AttachParam(&param_);
......
......@@ -21,6 +21,9 @@
#include "lite/core/tensor.h"
#include "lite/operators/op_params.h"
#include "lite/utils/all.h"
#ifdef LITE_WITH_PROFILE
#include "lite/api/paddle_place.h"
#endif
namespace paddle {
namespace lite {
......@@ -42,6 +45,29 @@ class ConvTransposeOpLite : public OpLite {
std::string DebugString() const override { return "conv_transpose"; }
#ifdef LITE_WITH_PROFILE
// Profiling hook: records input/output/filter shapes, a compact remark of
// the form "<f2>x<f3>p<pad>s<stride>g<groups>d<dilation>[Bias]<act>", and
// the cost estimate of the transposed convolution.
void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter *ch) {
  auto filter_dims = param_.filter->dims();
  auto input_dims = param_.x->dims();
  auto output_dims = param_.output->dims();
  ch->input_shape = ch->DimToStr(input_dims);
  ch->output_shape = ch->DimToStr(output_dims);
  ch->filter_shape = ch->DimToStr(filter_dims);

  // Only the first padding / stride / dilation entry goes into the remark.
  std::string remark = std::to_string(filter_dims[2]);
  remark += "x";
  remark += std::to_string(filter_dims[3]);
  remark += "p";
  remark += std::to_string((*param_.paddings)[0]);
  remark += "s";
  remark += std::to_string(param_.strides[0]);
  remark += "g";
  remark += std::to_string(param_.groups);
  remark += "d";
  remark += std::to_string((*param_.dilations)[0]);
  remark += (param_.bias ? "Bias" : "");
  remark += ActivationTypeToStr(param_.activation_param.active_type);
  ch->remark = remark;

  // MACs = 2.f * kw * kh * batchsize * out_c * out_h * out_w * in_c / group
  // GMACs = 1e-9f * MACs
  // GMACPS = 1e-6f * MACs / predict_ms
  ch->macs = 2.f * filter_dims[2] * filter_dims[3] *
             output_dims.production() * input_dims[1] / param_.groups;
}
#endif
private:
mutable ConvParam param_;
std::string padding_algorithm_{""};
......
......@@ -35,6 +35,17 @@ class ElementwiseOp : public OpLite {
std::string DebugString() const override { return "elementwise_op"; }
#ifdef LITE_WITH_PROFILE
// Profiling hook: records both operand shapes ("X<dims>Y<dims>"), the
// output shape, the axis as the remark, and a cost estimate of one
// operation per output element.
void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter* ch) {
  const std::string x_shape = ch->DimToStr(param_.X->dims());
  const std::string y_shape = ch->DimToStr(param_.Y->dims());
  ch->input_shape = "X" + x_shape + "Y" + y_shape;
  ch->output_shape = ch->DimToStr(param_.Out->dims());
  ch->remark = "axis" + std::to_string(param_.axis);
  ch->macs = 1.0f * param_.Out->numel();
}
#endif
private:
mutable operators::ElementwiseParam param_;
};
......
......@@ -43,6 +43,17 @@ class FcOpLite : public OpLite {
std::string DebugString() const override { return "fc"; }
#ifdef LITE_WITH_PROFILE
// Profiling hook: records input/weight/output shapes, "[Bias]<activation>"
// as the remark, and a cost estimate of 3 * m * w[0] * w[1], where m is
// the element count of the input dims in [0, in_num_col_dims).
void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter *ch) {
  auto m = param_.input->dims().count(0, param_.in_num_col_dims);
  ch->input_shape = ch->DimToStr(param_.input->dims());
  ch->filter_shape = ch->DimToStr(param_.w->dims());
  ch->output_shape = ch->DimToStr(param_.output->dims());

  const char *bias_tag = param_.bias ? "Bias" : "";
  ch->remark = bias_tag + param_.activation_type;

  ch->macs = m * param_.w->dims()[0] * param_.w->dims()[1] * 3.0f;
}
#endif
private:
mutable FcParam param_;
};
......
......@@ -38,6 +38,15 @@ class IncrementOp : public OpLite {
std::string DebugString() const override { return "increment"; }
#ifdef LITE_WITH_PROFILE
// Profiling hook: records input/output shapes, the step as the remark,
// and a cost estimate of one operation per input element.
void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter *ch) {
  const std::string step_text = std::to_string(param_.step);
  ch->input_shape = ch->DimToStr(param_.X->dims());
  ch->output_shape = ch->DimToStr(param_.Out->dims());
  ch->remark = "step" + step_text;
  ch->macs = param_.X->numel() * 1.0f;
}
#endif
private:
mutable IncrementParam param_;
};
......
......@@ -36,8 +36,22 @@ class InstanceNormOp : public OpLite {
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "instance_norm"; }
#ifdef LITE_WITH_PROFILE
// Profiling hook: records input/output shapes and a cost estimate of
// 5 * (all elements) + 3 * (d0*d1 + d2*d3), using the first four input
// dimensions. No remark is set for this op.
void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter *ch) {
  ch->input_shape = ch->DimToStr(param_.x->dims());
  ch->output_shape = ch->DimToStr(param_.out->dims());

  auto x_dims = param_.x->dims();
  const auto batch_channels = x_dims[0] * x_dims[1];
  const auto spatial = x_dims[2] * x_dims[3];
  const auto total = x_dims.production();
  ch->macs = 5.f * total + 3.f * (batch_channels + spatial);
}
#endif
private:
mutable InstanceNormParam param_;
};
......
......@@ -36,8 +36,18 @@ class InterpolateOp : public OpLite {
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "interpolate"; }
#ifdef LITE_WITH_PROFILE
// Profiling hook: records input/output shapes, the interpolation method
// as the remark, and a cost estimate of 14 operations per output element.
void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter *ch) {
  const auto out_elems = param_.Out->numel();
  ch->input_shape = ch->DimToStr(param_.X->dims());
  ch->output_shape = ch->DimToStr(param_.Out->dims());
  ch->remark = param_.interp_method;
  ch->macs = out_elems * 14.f;
}
#endif
private:
mutable InterpolateParam param_;
};
......
......@@ -30,6 +30,16 @@ class IoCopyOp : public OpLite {
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
#ifdef LITE_WITH_PROFILE
// Profiling hook: records the shapes of `x` and `y` and the process type
// as the remark. No cost estimate is set for this op.
void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter *ch) {
  ch->input_shape = ch->DimToStr(param_.x->dims());
  ch->output_shape = ch->DimToStr(param_.y->dims());
  ch->remark = "type" + std::to_string(param_.process_type);
}
#endif
protected:
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
......@@ -38,6 +38,15 @@ class LayerNormOp : public OpLite {
std::string DebugString() const override { return "layer_norm"; }
#ifdef LITE_WITH_PROFILE
void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter *ch) {
  // Fill the profiler record: shapes, the normalisation start axis as
  // remark, and an estimate of 7 operations per output element.
  const auto *result = param_.Y;
  ch->input_shape = ch->DimToStr(param_.X->dims());
  ch->output_shape = ch->DimToStr(result->dims());
  ch->remark = "begin_norm_axis" + std::to_string(param_.begin_norm_axis);
  ch->macs = result->numel() * 7.f;
}
#endif
private:
mutable LayerNormParam param_;
};
......
......@@ -30,6 +30,16 @@ class LayoutOp : public OpLite {
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
#ifdef LITE_WITH_PROFILE
void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter *ch) {
  // Record source/destination shapes and the numeric process type.
  // No operation count is reported for this op.
  ch->input_shape = ch->DimToStr(param_.x->dims());
  ch->output_shape = ch->DimToStr(param_.y->dims());
  ch->remark = "type" + std::to_string(param_.process_type);
}
#endif
protected:
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
......@@ -38,6 +38,16 @@ class BinaryLogicalOp : public OpLite {
std::string DebugString() const override { return "binary logical"; }
#ifdef LITE_WITH_PROFILE
void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter *ch) {
  // The input shape string lists both operands as "X<dims>Y<dims>".
  ch->input_shape = "X" + ch->DimToStr(param_.X->dims());
  ch->input_shape += "Y" + ch->DimToStr(param_.Y->dims());

  const auto *result = param_.Out;
  ch->output_shape = ch->DimToStr(result->dims());
  // No remark; the cost estimate is 3 operations per output element.
  ch->macs = result->numel() * 3.f;
}
#endif
private:
mutable LogicalParam param_;
};
......@@ -57,6 +67,16 @@ class UnaryLogicalOp : public OpLite {
std::string DebugString() const override { return "binary logical"; }
#ifdef LITE_WITH_PROFILE
void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter *ch) {
  // Fill the profiler record for a unary logical op.
  //
  // The original body was copied from the binary op and dereferenced
  // param_.Y unconditionally. NOTE(review): a unary op is expected to attach
  // only X and Out, leaving Y null — confirm against AttachImpl. The Y shape
  // is therefore appended only when a Y tensor is actually present, which
  // keeps the output unchanged whenever Y is set.
  ch->input_shape = "X" + ch->DimToStr(param_.X->dims());
  if (param_.Y != nullptr) {
    ch->input_shape += "Y" + ch->DimToStr(param_.Y->dims());
  }
  ch->output_shape = ch->DimToStr(param_.Out->dims());
  // No remark; the cost estimate is 3 operations per output element.
  ch->macs = param_.Out->numel() * 3.f;
}
#endif
private:
mutable LogicalParam param_;
};
......
......@@ -33,8 +33,18 @@ class LrnOpLite : public OpLite {
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "lrn"; }
#ifdef LITE_WITH_PROFILE
void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter *ch) {
  // Fill the profiler record: shapes, "n<value><norm_region>" as remark, and
  // a cost estimate of (output elements * k * 2).
  const auto *result = param_.Out;
  ch->input_shape = ch->DimToStr(param_.X->dims());
  ch->output_shape = ch->DimToStr(result->dims());

  ch->remark = "n" + std::to_string(param_.n);
  ch->remark += param_.norm_region;

  ch->macs = result->numel() * param_.k * 2.f;
}
#endif
private:
mutable LrnParam param_;
};
......
......@@ -41,6 +41,31 @@ class MatMulOpLite : public OpLite {
std::string DebugString() const override { return "matmul"; }
#ifdef LITE_WITH_PROFILE
void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter *ch) {
  // Fill the profiler record: X is reported as input, Y as filter, plus the
  // scaling factor and transpose flags as remark.
  const auto &lhs_dims = param_.X->dims();
  const auto &rhs_dims = param_.Y->dims();
  ch->input_shape = ch->DimToStr(lhs_dims);
  ch->filter_shape = ch->DimToStr(rhs_dims);
  ch->output_shape = ch->DimToStr(param_.Out->dims());

  ch->remark = "alpha" + std::to_string(param_.alpha);
  ch->remark += "trans_x" + std::to_string(param_.transpose_X);
  ch->remark += "trans_y" + std::to_string(param_.transpose_Y);

  // Only the two trailing dimensions of each operand take part in the
  // estimate; the transpose flags decide which of them plays which role.
  const auto lhs_rank = lhs_dims.size();
  const auto rhs_rank = rhs_dims.size();
  const auto lhs_second_last = lhs_dims[lhs_rank - 2];
  const auto lhs_last = lhs_dims[lhs_rank - 1];

  const auto m = param_.transpose_X ? lhs_last : lhs_second_last;
  const auto k = param_.transpose_X ? lhs_second_last : lhs_last;
  const auto n =
      param_.transpose_Y ? rhs_dims[rhs_rank - 2] : rhs_dims[rhs_rank - 1];

  ch->macs = 3.f * m * n * k;
}
#endif
private:
mutable MatMulParam param_;
};
......
......@@ -35,6 +35,15 @@ class MeanOp : public OpLite {
std::string DebugString() const override { return "mean"; }
#ifdef LITE_WITH_PROFILE
void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter *ch) {
  // Fill the profiler record: shapes and a cost estimate of one operation
  // per input element. No remark is recorded.
  const auto *source = param_.X;
  ch->input_shape = ch->DimToStr(source->dims());
  ch->output_shape = ch->DimToStr(param_.Out->dims());
  ch->macs = source->numel() * 1.f;
}
#endif
private:
mutable operators::MeanParam param_;
};
......
......@@ -63,6 +63,20 @@ class MulOpLite : public OpLite {
std::string DebugString() const override { return "mul"; }
#ifdef LITE_WITH_PROFILE
void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter *ch) {
  // Fill the profiler record: x is reported as input, y as filter. No remark
  // is recorded.
  const auto &lhs_dims = param_.x->dims();
  const auto &rhs_dims = param_.y->dims();
  ch->input_shape = ch->DimToStr(lhs_dims);
  ch->filter_shape = ch->DimToStr(rhs_dims);
  ch->output_shape = ch->DimToStr(param_.output->dims());

  // Both operands are flattened to 2-D at their configured split axis; the
  // estimate is the product of the three resulting extents.
  const auto lhs_2d = lhs_dims.Flatten2D(param_.x_num_col_dims);
  const auto rhs_2d = rhs_dims.Flatten2D(param_.y_num_col_dims);
  ch->macs = 1.f * lhs_2d[0] * lhs_2d[1] * rhs_2d[1];
}
#endif
private:
mutable MulParam param_;
};
......
......@@ -35,8 +35,18 @@ class NegativeOpLite : public OpLite {
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "negative"; }
#ifdef LITE_WITH_PROFILE
void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter *ch) {
  // Fill the profiler record: shapes and a cost estimate of one operation
  // per output element. No remark is recorded.
  const auto *result = param_.Out;
  ch->input_shape = ch->DimToStr(param_.X->dims());
  ch->output_shape = ch->DimToStr(result->dims());
  ch->macs = 1.f * result->numel();
}
#endif
private:
mutable NegativeParam param_;
};
......
......@@ -244,6 +244,10 @@ struct ScaleParam : ParamBase {
float scale{1.};
float bias{};
bool bias_after_scale{true};
std::string activation_type{""};
bool fuse_relu{false};
float alpha{6.};
///////////////////////////////////////////////////////////////////////////////////
// get a vector of input tensors
const std::vector<const Tensor*>* input_tensor_ptrs() override {
......@@ -1511,6 +1515,11 @@ struct XPUFcParam : ParamBase {
std::string activation_type{""};
};
// Parameters of the pixel_shuffle operator.
struct PixelShuffleParam : ParamBase {
  // Input tensor; null until the op is attached.
  lite::Tensor* x = nullptr;
  // Output tensor; null until the op is attached.
  lite::Tensor* output = nullptr;
  // Factor used to rearrange channels into the spatial dimensions
  // (defaults to 1).
  int upscale_factor = 1;
};
} // namespace operators
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/pixel_shuffle_op.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
// Validates the attached parameters before shape inference.
//
// Returns false (after logging through the CHECK_*_OR_FALSE macros) when:
//   - the input or output tensor is missing,
//   - upscale_factor is not strictly positive (the original non-zero test
//     let negative factors through),
//   - the input is not 4-dimensional (InferShapeImpl indexes dims 0..3),
//   - dimension 1 is not divisible by upscale_factor squared.
bool PixelShuffleOpLite::CheckShape() const {
  CHECK_OR_FALSE(param_.x);
  CHECK_OR_FALSE(param_.output);
  CHECK_OR_FALSE(param_.upscale_factor > 0);

  const auto x_dims = param_.x->dims();
  CHECK_OR_FALSE(x_dims.size() == 4u);

  const auto upscale_factor = param_.upscale_factor;
  CHECK_EQ_OR_FALSE(x_dims[1] % (upscale_factor * upscale_factor), 0);
  return true;
}
// Derives the output shape from the input shape and upscale_factor:
// dimension 0 is kept, dimension 1 is divided by factor^2, and dimensions
// 2 and 3 are each multiplied by factor. Always returns true.
bool PixelShuffleOpLite::InferShapeImpl() const {
  const auto in_dims = param_.x->dims();
  const auto factor = param_.upscale_factor;
  const auto factor_sq = factor * factor;

  // Start from a copy so that dimension 0 carries over untouched.
  auto out_dims = in_dims;
  out_dims[1] = in_dims[1] / factor_sq;
  out_dims[2] = in_dims[2] * factor;
  out_dims[3] = in_dims[3] * factor;

  param_.output->Resize(out_dims);
  return true;
}
// Binds the op description to this operator: resolves the first "X" input
// and first "Out" output in `scope` and reads the optional "upscale_factor"
// attribute (the parameter keeps its default when the attribute is absent).
// Always returns true.
bool PixelShuffleOpLite::AttachImpl(const cpp::OpDesc& opdesc,
                                    lite::Scope* scope) {
  const auto x_name = opdesc.Input("X").front();
  const auto out_name = opdesc.Output("Out").front();

  auto* x_var = scope->FindVar(x_name);
  auto* out_var = scope->FindVar(out_name);
  param_.x = x_var->GetMutable<lite::Tensor>();
  param_.output = out_var->GetMutable<lite::Tensor>();

  if (!opdesc.HasAttr("upscale_factor")) {
    return true;
  }
  param_.upscale_factor = opdesc.GetAttr<int>("upscale_factor");
  return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
// Register PixelShuffleOpLite under the op type name "pixel_shuffle".
REGISTER_LITE_OP(pixel_shuffle, paddle::lite::operators::PixelShuffleOpLite);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
namespace paddle {
namespace lite {
namespace operators {
// Operator wrapper for pixel_shuffle. It holds a PixelShuffleParam and
// hands it to the kernel selected for execution.
class PixelShuffleOpLite : public OpLite {
public:
PixelShuffleOpLite() {}
explicit PixelShuffleOpLite(const std::string &op_type) : OpLite(op_type) {}
// Validates the attached tensors and the upscale factor.
bool CheckShape() const override;
// Resizes the output tensor according to the input shape and the factor.
bool InferShapeImpl() const override;
// Resolves the input/output variables and attributes from the op description.
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
// Passes the operator parameters to the given kernel.
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "pixel_shuffle"; }
private:
// mutable so that the const InferShapeImpl() can resize the output tensor.
mutable PixelShuffleParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
......@@ -92,6 +92,25 @@ class PoolOpLite : public OpLite {
std::string DebugString() const override { return "pool2d"; }
#ifdef LITE_WITH_PROFILE
void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter *ch) {
  // Fill the profiler record: shapes, a remark describing the pooling
  // configuration, and a cost estimate of
  // (output elements * ksize[0] * ksize[1]).
  const auto &out_dims = param_.output->dims();
  ch->input_shape = ch->DimToStr(param_.x->dims());
  ch->output_shape = ch->DimToStr(out_dims);

  const auto &window = param_.ksize;
  if (!param_.global_pooling) {
    // "<type><k0>x<k1>s<stride0>p<pad0>"
    ch->remark = param_.pooling_type + std::to_string(window[0]) + "x" +
                 std::to_string(window[1]) + "s" +
                 std::to_string(param_.strides[0]) + "p" +
                 std::to_string((*param_.paddings)[0]);
  } else {
    ch->remark = "global" + param_.pooling_type;
  }
  ch->remark += padding_algorithm_;

  ch->macs = out_dims.production() * window[0] * window[1];
}
#endif
private:
mutable PoolParam param_;
std::string padding_algorithm_{""};
......
......@@ -36,8 +36,18 @@ class PowerOp : public OpLite {
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "power"; }
#ifdef LITE_WITH_PROFILE
void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter *ch) {
  // Fill the profiler record: shapes and a cost estimate of 3 operations per
  // output element. No remark is recorded.
  const auto *result = param_.Out;
  ch->input_shape = ch->DimToStr(param_.X->dims());
  ch->output_shape = ch->DimToStr(result->dims());
  ch->macs = result->numel() * 3.0f;
}
#endif
private:
mutable PowerParam param_;
};
......
......@@ -32,8 +32,29 @@ class ReduceMaxOp : public OpLite {
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "reduce_max"; }
#ifdef LITE_WITH_PROFILE
void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter *ch) {
  // Fill the profiler record: shapes, keep_dim as remark, and a cost estimate
  // that depends on how many reduction axes are configured.
  ch->input_shape = ch->DimToStr(param_.X->dims());
  ch->output_shape = ch->DimToStr(param_.Out->dims());
  ch->remark = "keep_dim" + std::to_string(param_.keep_dim);

  const auto axis_count = param_.dim.size();
  const auto elem_count = param_.X->numel();
  switch (axis_count) {
    case 0:
    case 1:
      ch->macs = 1.f * elem_count;
      break;
    case 2:
      ch->macs = 2.f * elem_count;
      break;
    default:
      // More than two reduction axes are reported as unsupported.
      LOG(FATAL) << "This dims size of ReduceMaxParm: " << axis_count
                 << " doesn't support";
      ch->macs = 0.f;
      break;
  }
}
#endif
private:
mutable ReduceMaxParam param_;
};
......
......@@ -26,14 +26,41 @@ namespace operators {
// Operator wrapper for reduce_mean. It holds a ReduceMeanParam and hands it
// to the kernel selected for execution.
class ReduceMeanOp : public OpLite {
 public:
  ReduceMeanOp() {}
  explicit ReduceMeanOp(const std::string &op_type) : OpLite(op_type) {}

  bool CheckShape() const override;
  bool InferShapeImpl() const override;
  bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;

  // Passes the operator parameters to the given kernel.
  void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
  std::string DebugString() const override { return "reduce_mean"; }

#ifdef LITE_WITH_PROFILE
  // Fills the profiler record: shapes, keep_dim as remark, and a cost
  // estimate that depends on how many reduction axes are configured.
  void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter *ch) {
    ch->input_shape = ch->DimToStr(param_.X->dims());
    ch->output_shape = ch->DimToStr(param_.Out->dims());
    ch->remark = "keep_dim" + std::to_string(param_.keep_dim);

    const auto axis_count = param_.dim.size();
    const auto elem_count = param_.X->numel();
    switch (axis_count) {
      case 0:
        ch->macs = 1.f * elem_count;
        break;
      case 1:
        ch->macs = 2.f * elem_count;
        break;
      case 2:
        ch->macs = 4.f * elem_count;
        break;
      default:
        // More than two reduction axes are reported as unsupported.
        LOG(FATAL) << "This dims size of ReduceMean: " << axis_count
                   << " doesn't support";
        ch->macs = 0.f;
        break;
    }
  }
#endif

 private:
  mutable ReduceMeanParam param_;
};
......
......@@ -37,6 +37,27 @@ class ReduceProdOpLite : public OpLite {
std::string DebugString() const override { return "reduce_prod"; }
#ifdef LITE_WITH_PROFILE
void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter *ch) {
  // Fill the profiler record: shapes, keep_dim/reduce_all as remark, and a
  // cost estimate that depends on how many reduction axes are configured.
  ch->input_shape = ch->DimToStr(param_.x->dims());
  ch->output_shape = ch->DimToStr(param_.output->dims());

  ch->remark = "keep_dim" + std::to_string(param_.keep_dim);
  ch->remark += "reduce_all" + std::to_string(param_.reduce_all);

  const auto axis_count = param_.dim.size();
  const auto elem_count = param_.x->numel();
  switch (axis_count) {
    case 0:
    case 1:
      ch->macs = 1.f * elem_count;
      break;
    case 2:
      ch->macs = 2.f * elem_count;
      break;
    default:
      // More than two reduction axes are reported as unsupported.
      LOG(FATAL) << "This dims size of ReduceProd: " << axis_count
                 << " doesn't support";
      ch->macs = 0.f;
      break;
  }
}
#endif
private:
mutable ReduceParam param_;
};
......
......@@ -18,6 +18,9 @@
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/utils/all.h"
#ifdef LITE_WITH_PROFILE
#include "lite/api/paddle_place.h"
#endif
namespace paddle {
namespace lite {
......@@ -35,8 +38,61 @@ class ReluOp : public OpLite {
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "relu"; }
#ifdef LITE_WITH_PROFILE
void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter *ch) {
  // Fill the profiler record: shapes, the activation name as remark, and a
  // cost estimate of (input elements * per-activation factor).
  ch->input_shape = ch->DimToStr(param_.X->dims());
  ch->output_shape = ch->DimToStr(param_.Out->dims());
  ch->remark = ActivationTypeToStr(param_.active_type);

  // Per-element factor for the selected activation. `has_cost` stays false
  // for the identity activation, whose macs field is left untouched.
  double ops_per_elem = 0.0;
  bool has_cost = false;
  switch (param_.active_type) {
    case lite_api::ActivationType::kRelu:
    case lite_api::ActivationType::kExp:
    case lite_api::ActivationType::kAbs:
    case lite_api::ActivationType::kReciprocal:
      ops_per_elem = 1.0;
      has_cost = true;
      break;
    case lite_api::ActivationType::kRelu6:
    case lite_api::ActivationType::kLeakyRelu:
    case lite_api::ActivationType::kPRelu:
      ops_per_elem = 2.0;
      has_cost = true;
      break;
    case lite_api::ActivationType::kSigmoid:
      ops_per_elem = 3.0;
      has_cost = true;
      break;
    case lite_api::ActivationType::kSwish:
      ops_per_elem = 4.0;
      has_cost = true;
      break;
    case lite_api::ActivationType::kTanh:
    case lite_api::ActivationType::kHardSwish:
      ops_per_elem = 5.0;
      has_cost = true;
      break;
    case lite_api::ActivationType::kIndentity:
      break;
    default:
      LOG(FATAL) << "This Type of Activation:"
                 << static_cast<int>(param_.active_type)
                 << ActivationTypeToStr(param_.active_type)
                 << " doesn't support";
  }
  if (has_cost) {
    ch->macs = param_.X->numel() * ops_per_elem;
  }
}
#endif
private:
mutable ActivationParam param_;
};
......
......@@ -37,6 +37,15 @@ class ReshapeOp : public OpLite {
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "reshape"; }
#ifdef LITE_WITH_PROFILE
void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter *ch) {
  // Only the shapes are reported; no remark or operation count is set.
  ch->input_shape = ch->DimToStr(param_.x->dims());
  ch->output_shape = ch->DimToStr(param_.output->dims());
}
#endif
protected:
mutable ReshapeParam param_;
};
......
......@@ -35,8 +35,19 @@ class ScaleOp : public OpLite {
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "scale"; }
#ifdef LITE_WITH_PROFILE
void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter *ch) {
  // Fill the profiler record: shapes, "<activation>alpha<value>" as remark,
  // and a cost estimate of one operation per input element.
  const auto *source = param_.x;
  ch->input_shape = ch->DimToStr(source->dims());
  ch->output_shape = ch->DimToStr(param_.output->dims());

  ch->remark = param_.activation_type + "alpha";
  ch->remark += std::to_string(param_.alpha);

  ch->macs = source->numel() * 1.f;
}
#endif
private:
mutable ScaleParam param_;
};
......
......@@ -27,17 +27,48 @@ class SearchAlignedMatMulOpLite : public OpLite {
public:
SearchAlignedMatMulOpLite() {}
explicit SearchAlignedMatMulOpLite(const std::string &type) : OpLite(type) {}
explicit SearchAlignedMatMulOpLite(const std::string& type) : OpLite(type) {}
bool CheckShape() const override;
bool InferShapeImpl() const override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
bool AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) override;
bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override;
std::string DebugString() const override { return "search_aligned_mat_mul"; }
#ifdef LITE_WITH_PROFILE
void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter* ch) {
  // Fill the profiler record: X is reported as input, Y as filter, plus the
  // scaling factor and transpose flags as remark.
  ch->input_shape = ch->DimToStr(param_.X->dims());
  ch->filter_shape = ch->DimToStr(param_.Y->dims());
  ch->output_shape = ch->DimToStr(param_.Out->dims());

  ch->remark = "alpha" + std::to_string(param_.alpha);
  ch->remark += "trans_x" + std::to_string(param_.transpose_X);
  ch->remark += "trans_y" + std::to_string(param_.transpose_Y);

  // The "batch" extent of each operand is read from entry 1 of its level-0
  // LoD; the "inner" extent is dimension 1 of its shape.
  const int x_inner = param_.X->dims()[1];
  const int y_inner = param_.Y->dims()[1];
  const int x_batch = param_.X->lod()[0][1];
  const int y_batch = param_.Y->lod()[0][1];

  const auto swap_x = param_.transpose_X;
  const auto swap_y = param_.transpose_Y;
  const int M = swap_x ? x_inner : x_batch;
  const int N = swap_y ? y_batch : y_inner;
  // The identifiers X_K / Y_K are kept because CHECK_EQ may print them.
  const int X_K = swap_x ? x_batch : x_inner;
  const int Y_K = swap_y ? y_inner : y_batch;
  CHECK_EQ(X_K, Y_K) << "K of Input(X) and Input(Y) is not equal";

  ch->macs = 2.0 * M * N * X_K;
}
#endif
private:
mutable MatMulParam param_;
};
......
......@@ -35,8 +35,21 @@ class SearchFcOpLite : public OpLite {
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "search_fc"; }
#ifdef LITE_WITH_PROFILE
void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter *ch) {
  // Fill the profiler record: X is reported as input, W as filter, out_size
  // as remark, and a cost estimate of 2 * X[0] * X[1] * W[0].
  const auto &in_dims = param_.X->dims();
  const auto &weight_dims = param_.W->dims();
  ch->input_shape = ch->DimToStr(in_dims);
  ch->filter_shape = ch->DimToStr(weight_dims);
  ch->output_shape = ch->DimToStr(param_.Out->dims());
  ch->remark = "out_size" + std::to_string(param_.out_size);
  ch->macs = 2.f * in_dims[0] * in_dims[1] * weight_dims[0];
}
#endif
private:
mutable SearchFcParam param_;
};
......
......@@ -36,8 +36,21 @@ class SearchSeqFcOpLite : public OpLite {
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override;
std::string DebugString() const override { return "search_seq_fc"; }
#ifdef LITE_WITH_PROFILE
void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter *ch) {
  // Fill the profiler record: x is reported as input, w as filter, out_size
  // as remark, and a cost estimate of 2 * x[0] * x[1] * w[0].
  const auto &in_dims = param_.x->dims();
  const auto &weight_dims = param_.w->dims();
  ch->input_shape = ch->DimToStr(in_dims);
  ch->filter_shape = ch->DimToStr(weight_dims);
  ch->output_shape = ch->DimToStr(param_.out->dims());
  ch->remark = "out_size" + std::to_string(param_.out_size);
  ch->macs = 2.f * in_dims[0] * in_dims[1] * weight_dims[0];
}
#endif
private:
mutable SearchSeqFcParam param_;
};
......
......@@ -36,8 +36,20 @@ class SearchSeqSoftmaxOp : public OpLite {
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "search_seq_softmax_op"; }
#ifdef LITE_WITH_PROFILE
void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter *ch) {
  // Fill the profiler record: shapes, the axis as remark, and a cost estimate
  // of 4 operations per input element.
  const auto *source = param_.x;
  ch->input_shape = ch->DimToStr(source->dims());
  ch->output_shape = ch->DimToStr(param_.output->dims());
  ch->remark = "axis" + std::to_string(param_.axis);
  ch->macs = 4.f * source->numel();
}
#endif
private:
mutable SoftmaxParam param_;
};
......
......@@ -37,6 +37,17 @@ class SoftmaxOp : public OpLite {
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "softmax"; }
#ifdef LITE_WITH_PROFILE
void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter *ch) {
  // Fill the profiler record: shapes, the axis as remark, and a cost estimate
  // of 2 * (input elements) * 3.
  const auto &in_dims = param_.x->dims();
  ch->input_shape = ch->DimToStr(in_dims);
  ch->output_shape = ch->DimToStr(param_.output->dims());
  ch->remark = "axis" + std::to_string(param_.axis);
  ch->macs = 2.f * in_dims.production() * 3;
}
#endif
private:
mutable SoftmaxParam param_;
};
......
......@@ -37,6 +37,15 @@ class SqueezeOp : public OpLite {
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "squeeze"; }
#ifdef LITE_WITH_PROFILE
void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter *ch) {
  // Only the shapes are reported; no remark or operation count is set.
  ch->input_shape = ch->DimToStr(param_.X->dims());
  ch->output_shape = ch->DimToStr(param_.Out->dims());
}
#endif
protected:
mutable SqueezeParam param_;
};
......@@ -54,6 +63,15 @@ class Squeeze2Op : public SqueezeOp {
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "squeeze2"; }
#ifdef LITE_WITH_PROFILE
void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter *ch) {
  // Only the shapes are reported; no remark or operation count is set.
  ch->input_shape = ch->DimToStr(param_.X->dims());
  ch->output_shape = ch->DimToStr(param_.Out->dims());
}
#endif
};
} // namespace operators
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include "lite/core/context.h"
#include "lite/core/profile/timer.h"
#include "lite/operators/op_params.h"
#include "lite/tests/utils/naive_math_impl.h"
#include "lite/tests/utils/tensor_utils.h"
#ifdef LITE_WITH_ARM
#include "lite/kernels/arm/deformable_conv_compute.h"
#endif // LITE_WITH_ARM
// Command-line flags controlling the test run.
// Execution settings: power mode, thread count, warm-up and timed repeats.
DEFINE_int32(power_mode,
3,
"power mode: "
"0 for POWER_HIGH;"
"1 for POWER_LOW;"
"2 for POWER_FULL;"
"3 for NO_BIND");
DEFINE_int32(threads, 1, "threads num");
DEFINE_int32(warmup, 0, "warmup times");
DEFINE_int32(repeats, 1, "repeats times");
// basic_test enables the exhaustive random-parameter sweep; check_result
// enables comparison against the reference implementation.
DEFINE_bool(basic_test, false, "do all tests");
DEFINE_bool(check_result, true, "check the result");
// Shape and convolution settings used by the custom-size test.
DEFINE_int32(batch, 1, "batch size");
DEFINE_int32(in_channel, 32, "input channel");
DEFINE_int32(in_height, 112, "input height");
DEFINE_int32(in_width, 112, "input width");
DEFINE_int32(out_channel, 32, "output channel");
DEFINE_int32(group, 1, "group");
DEFINE_int32(kernel_h, 3, "kernel height");
DEFINE_int32(kernel_w, 3, "kernel width");
DEFINE_int32(pad_h, 1, "pad height");
DEFINE_int32(pad_w, 1, "pad width");
DEFINE_int32(stride_h, 1, "stride height");
DEFINE_int32(stride_w, 1, "stride width");
DEFINE_int32(dila_h, 1, "dilation height");
DEFINE_int32(dila_w, 1, "dilation width");
DEFINE_int32(flag_act,
0,
"do activation"); // 0-no act, 1-relu, 2-relu6, 4-leakyrelu
DEFINE_double(leakey_relu_alpha, 1.0, "leakey relu alpha");
DEFINE_bool(flag_bias, true, "with bias");
// Short aliases for the project types used throughout this test file.
typedef paddle::lite::DDim DDim;
typedef paddle::lite::Tensor Tensor;
typedef paddle::lite::operators::DeformableConvParam DeformableConvParam;
typedef paddle::lite::operators::ActivationParam ActivationParam;
using paddle::lite::profile::Timer;
// Computes the output shape of a convolution described by `param` applied to
// an input of shape `dim_in`.
//
// Dimension 0 is copied from the input, dimension 1 becomes dimension 0 of
// the filter, and dimensions 2/3 follow
//   (in + pad_a + pad_b - (dilation * (kernel - 1) + 1)) / stride + 1
// with paddings read as {[0], [1]} for dimension 2 and {[2], [3]} for
// dimension 3.
DDim compute_out_dim(const DDim& dim_in,
                     const paddle::lite::operators::ConvParam& param) {
  const auto& filter_dims = param.filter->dims();
  const std::vector<int>& pads = *param.paddings;
  const std::vector<int>& dils = *param.dilations;

  // Dilated kernel extents along dimensions 2 and 3.
  const auto extent_h = dils[0] * (filter_dims[2] - 1) + 1;
  const auto extent_w = dils[1] * (filter_dims[3] - 1) + 1;

  DDim result = dim_in;
  result[1] = filter_dims[0];
  result[2] = (dim_in[2] + pads[0] + pads[1] - extent_h) / param.strides[0] + 1;
  result[3] = (dim_in[3] + pads[2] + pads[3] - extent_w) / param.strides[1] + 1;
  return result;
}
#ifdef LITE_WITH_ARM
// Runs the ARM fp32 DeformableConvCompute kernel for every combination of
// power mode, thread count and input shape, logs timing, and (when
// FLAGS_check_result is set) compares the output with the
// deformable_conv_basic reference.
//
// input_dims         candidate input shapes, indexed as [0..3]
// weight_dim         filter shape; weight_dim[1] * group must equal dim_in[1]
// group              used for both conv groups and deformable groups
// strides/pads/dilas convolution settings (pads holds four entries)
// flag_bias          allocate and randomise a bias tensor
// flag_relu          enable the fused relu path
// modulated          forwarded to the kernel and the reference; logged as
//                    "V2" when true and "V1" when false
// thread_num         thread counts to iterate over
// power_mode         power modes to iterate over
// leakey_relu_scale  only read in the flag_act == 4 branch, which cannot be
//                    reached here because flag_act is derived from flag_relu
//
// A mismatch beyond the tolerances below ends in LOG(FATAL).
void test_deformable_conv_fp32(const std::vector<DDim>& input_dims,
const DDim& weight_dim,
int group,
const std::vector<int>& strides,
const std::vector<int>& pads,
const std::vector<int>& dilas,
bool flag_bias,
bool flag_relu,
bool modulated,
const std::vector<int>& thread_num,
const std::vector<int>& power_mode,
const float leakey_relu_scale) {
// NOTE: this inner guard sits inside an outer #ifdef LITE_WITH_ARM region.
#ifdef LITE_WITH_ARM
paddle::lite::DeviceInfo::Init();
#endif
// The tensors below are owned by raw pointers and released at the end of
// this function.
DeformableConvParam param;
param.x = new Tensor;
param.x->set_precision(PRECISION(kFloat));
param.conv_param.filter = new Tensor;
param.conv_param.filter->Resize(weight_dim);
param.conv_param.filter->set_precision(PRECISION(kFloat));
param.offset = new Tensor;
param.offset->set_precision(PRECISION(kFloat));
param.mask = new Tensor;
param.mask->set_precision(PRECISION(kFloat));
if (flag_bias) {
param.conv_param.bias = new Tensor;
param.conv_param.bias->Resize({weight_dim[0]});
param.conv_param.bias->set_precision(PRECISION(kFloat));
}
param.conv_param.strides = strides;
param.conv_param.paddings = std::make_shared<std::vector<int>>(pads);
param.conv_param.dilations = std::make_shared<std::vector<int>>(dilas);
param.conv_param.groups = group;
param.deformable_groups = group;
param.modulated = modulated;
const float six = 6.f;
// flag_act can only be 0 or 1 here, so the == 2 and == 4 branches below are
// never taken.
int flag_act = flag_relu ? 1 : 0;
if (flag_act > 0) {
ActivationParam act_param;
act_param.has_active = true;
act_param.active_type = (paddle::lite_api::ActivationType)
flag_act; // 1-relu, 2-relu6, 4-leakyrelu
if (flag_act == 1) {
param.conv_param.fuse_relu = true;
} else if (flag_act == 2) {
act_param.Relu_clipped_coef = six;
} else if (flag_act == 4) {
act_param.Leaky_relu_alpha = leakey_relu_scale;
}
param.conv_param.activation_param = act_param;
}
param.output = new Tensor;
param.output->set_precision(PRECISION(kFloat));
// Weights (and bias, if any) are randomised once and shared by every run.
paddle::lite::fill_tensor_rand(*param.conv_param.filter, -1.f, 1.f);
// paddle::lite::fill_tensor_const(*param.filter, 1.f);
if (flag_bias) {
paddle::lite::fill_tensor_rand(*param.conv_param.bias, -1.f, 1.f);
// paddle::lite::fill_tensor_const(*param.bias, 1.f);
}
auto wptr = param.conv_param.filter->data<float>();
auto bias_ptr = flag_bias ? param.conv_param.bias->data<float>() : nullptr;
// A fresh kernel and context are created per (power mode, thread count).
for (auto& cls : power_mode) {
for (auto& th : thread_num) {
paddle::lite::kernels::arm::DeformableConvCompute<PRECISION(kFloat),
PRECISION(kFloat)>
deformableConv;
std::unique_ptr<paddle::lite::KernelContext> ctx1(
new paddle::lite::KernelContext);
auto& ctx = ctx1->As<paddle::lite::ARMContext>();
ctx.SetRunMode(static_cast<paddle::lite_api::PowerMode>(cls), th);
/// set param and context
// Size x/output from the first input shape that yields a non-empty output,
// so the tensors have a shape before PrepareForRun.
for (auto& dim_in : input_dims) {
param.x->Resize(dim_in);
DDim out_tmp_dims = compute_out_dim(dim_in, param.conv_param);
if (out_tmp_dims[2] < 1 || out_tmp_dims[3] < 1) {
continue;
}
param.output->Resize(out_tmp_dims);
break;
}
deformableConv.SetParam(param);
deformableConv.SetContext(std::move(ctx1));
/// prepare for run
deformableConv.PrepareForRun();
for (auto& dim_in : input_dims) {
CHECK_EQ(weight_dim[1] * group, dim_in[1])
<< "input channel must equal to weights channel";
DDim dim_out = compute_out_dim(dim_in, param.conv_param);
int num = dim_in[0];
int in_size = dim_in[2] * dim_in[3];
int kernel_size = weight_dim[2] * weight_dim[3];
// Offset carries 2 * group * kernel_size channels and mask carries
// group * kernel_size channels, both at the input's spatial size.
param.offset->Resize(
{num, 2 * group * kernel_size, dim_in[2], dim_in[3]});
param.mask->Resize({num, group * kernel_size, dim_in[2], dim_in[3]});
paddle::lite::fill_tensor_rand(*param.offset, -1.f, 1.f);
paddle::lite::fill_tensor_rand(*param.mask, -1.f, 1.f);
// Skip shapes that produce an empty output.
if (dim_out[2] < 1 || dim_out[3] < 1) {
continue;
}
// Only cases whose output keeps the input's spatial size are exercised.
if (dim_out[2] != dim_in[2] || dim_out[3] != dim_in[3]) {
continue;
}
param.x->Resize(dim_in);
param.output->Resize(dim_out);
paddle::lite::fill_tensor_rand(*param.x, -1.f, 1.f);
// paddle::lite::fill_tensor_const(*param.x, 1.f);
auto din = param.x->data<float>();
Tensor tout_basic;
if (FLAGS_check_result) {
// Compute the expected output with the reference implementation.
auto offset_data = param.offset->data<float>();
auto mask_data = param.mask->data<float>();
tout_basic.set_precision(PRECISION(kFloat));
tout_basic.Resize(dim_out);
fill_tensor_const(tout_basic, 0.f);
auto dout_basic = tout_basic.mutable_data<float>();
LOG(INFO) << "flag_relu: " << flag_relu;
deformable_conv_basic<float, float>(din,
offset_data,
mask_data,
dout_basic,
dim_in[0],
dim_out[1],
dim_out[2],
dim_out[3],
dim_in[1],
dim_in[2],
dim_in[3],
wptr,
bias_ptr,
group,
weight_dim[3],
weight_dim[2],
strides[1],
strides[0],
dilas[1],
dilas[0],
pads[2],
pads[0],
flag_bias,
flag_relu,
modulated);
}
/// warm up
for (int i = 0; i < FLAGS_warmup; ++i) {
deformableConv.Launch();
}
/// compute
Timer t0;
for (int i = 0; i < FLAGS_repeats; ++i) {
t0.Start();
deformableConv.Launch();
t0.Stop();
}
// Operation count used for the throughput figures in the log line below.
double gops = 2.0 * dim_out.production() * dim_in[1] * weight_dim[2] *
weight_dim[3] / param.conv_param.groups;
LOG(INFO) << "deformable conv fp32: input shape: " << dim_in
<< ", output shape" << dim_out
<< ",running time, avg: " << t0.LapTimes().Avg()
<< ", min time: " << t0.LapTimes().Min()
<< ", total GOPS: " << 1e-9 * gops
<< " GOPS, avg GOPs: " << 1e-6 * gops / t0.LapTimes().Avg()
<< " GOPs, max GOPs: " << 1e-6 * gops / t0.LapTimes().Min();
if (FLAGS_check_result) {
double max_ratio = 0;
double max_diff = 0;
tensor_cmp_host(tout_basic, *param.output, max_ratio, max_diff);
LOG(INFO) << "compare result, max diff: " << max_diff
<< ", max ratio: " << max_ratio;
// The run fails only when both the ratio and the absolute difference exceed
// their tolerances; the tensors are dumped before aborting.
if (std::abs(max_ratio) > 1e-3f) {
if (max_diff > 5e-4f) {
LOG(WARNING) << "weights data";
print_tensor(*param.conv_param.filter);
LOG(WARNING) << "basic result";
print_tensor(tout_basic);
LOG(WARNING) << "lite result";
print_tensor(*param.output);
Tensor tdiff;
tdiff.Resize(tout_basic.dims());
tdiff.set_precision(PRECISION(kFloat));
tensor_diff(tout_basic, *param.output, tdiff);
print_tensor(tdiff);
LOG(FATAL) << "test fp32 deformable conv: input: " << dim_in
<< ", output: " << dim_out
<< ", weight dim: " << weight_dim
<< ", pad: " << pads[0] << ", " << pads[1] << ", "
<< pads[2] << ", " << pads[3]
<< ", stride: " << strides[0] << ", " << strides[1]
<< ", dila_: " << dilas[0] << ", " << dilas[1]
<< ", group: " << group
<< ", bias: " << (flag_bias ? "true" : "false")
<< ", relu: " << (flag_relu ? "true" : "false")
<< ", modulated: " << (modulated ? "V2" : "V1")
<< ", threads: " << th << ", power_mode: " << cls
<< " failed!!\n";
}
}
}
LOG(INFO) << "test fp32 deformable conv: input: " << dim_in
<< ", output: " << dim_out << ", weight dim: " << weight_dim
<< ", pad: " << pads[0] << ", " << pads[1] << ", " << pads[2]
<< ", " << pads[3] << ", stride: " << strides[0] << ", "
<< strides[1] << ", dila_: " << dilas[0] << ", " << dilas[1]
<< ", group: " << group
<< ", bias: " << (flag_bias ? "true" : "false")
<< ", relu: " << (flag_relu ? "true" : "false")
<< ", modulated: " << (modulated ? "V2" : "V1")
<< ", threads: " << th << ", power_mode: " << cls
<< " successed!!\n";
}
}
}
// Release the tensors allocated above.
// NOTE(review): when flag_bias is false the bias pointer is never assigned
// here; this presumably relies on it defaulting to null — confirm in
// ConvParam.
delete param.x;
delete param.conv_param.filter;
delete param.offset;
delete param.mask;
delete param.output;
delete param.conv_param.bias;
}
#else
// No-op fallback for test_deformable_conv_fp32, compiled in the `#else`
// branch of the guard closed by `#endif  // LITE_WITH_ARM`. It accepts the
// twelve arguments the tests below pass and ignores every one of them, so the
// TEST bodies still compile and link when the real implementation is absent.
void test_deformable_conv_fp32(const std::vector<DDim>& input_dims,
                               const DDim& weight_dim,
                               int group,
                               const std::vector<int>& strides,
                               const std::vector<int>& pads,
                               const std::vector<int>& dilas,
                               bool flag_bias,
                               bool flag_relu,
                               bool modulated,
                               const std::vector<int>& thread_num,
                               const std::vector<int>& power_mode,
                               const float leakey_relu_scale) {
  // Intentionally empty: nothing is executed in this configuration.
}
#endif // LITE_WITH_ARM
#if 1 /// random param conv
// Sweeps a grid of small deformable-conv configurations and forwards each one
// to test_deformable_conv_fp32. The sweep only runs when FLAGS_basic_test is
// set. Every configuration is run with a single thread ({1}) and the power
// mode taken from FLAGS_power_mode.
//
// Loop-invariant work is hoisted out of the inner loops; the order and
// arguments of the test_deformable_conv_fp32 calls are unchanged:
//   * input shapes depend only on `cin`, so they are built once per `cin`;
//   * the group-divisibility filter depends only on (cin, cout, g), so
//     incompatible combinations are skipped before entering the inner loops;
//   * the weight shape depends only on (cout, cin, g, kh, kw).
TEST(TestDeformableConvRand, test_deformable_conv_rand) {
  if (!FLAGS_basic_test) {
    return;
  }
  const float leakey_relu_scale = 8.88f;
  for (const int cin : {1, 3, 8}) {
    // Square inputs of several sizes for batch 1 and 2.
    std::vector<DDim> dims;
    for (const int batch : {1, 2}) {
      for (const int h : {1, 3, 16, 19, 32, 64}) {
        dims.push_back(DDim({batch, cin, h, h}));
      }
    }
    for (const int cout : {1, 5, 16}) {
      for (const int g : {1, 2}) {
        // Both channel counts must be divisible by the group count.
        if (cin % g != 0 || cout % g != 0) {
          continue;
        }
        for (const int kw : {1, 2, 3}) {
          for (const int kh : {1, 2, 3}) {
            const DDim weights_dim({cout, cin / g, kh, kw});
            for (const int stride : {1, 2}) {
              for (const int pad_h : {0, 1, 2}) {
                for (const int pad_w : {0, 1, 2}) {
                  for (const int dila : {1, 2}) {
                    for (const bool modulated : {false, true}) {
                      for (const bool flag_bias : {false, true}) {
                        for (const bool flag_relu : {false, true}) {
                          test_deformable_conv_fp32(
                              dims,
                              weights_dim,
                              g,
                              {stride, stride},
                              {pad_h, pad_h, pad_w, pad_w},
                              {dila, dila},
                              flag_bias,
                              flag_relu,
                              modulated,
                              {1},
                              {FLAGS_power_mode},
                              leakey_relu_scale);
                        }
                      }
                    }
                  }
                }
              }
            }
          }
        }
      }
    }
  }
}
#endif /// random param conv
#if 1 /// custom
// Runs a single deformable-conv fp32 case whose shapes and attributes come
// from the FLAGS_* values. Both channel counts are checked for divisibility by
// FLAGS_group before the case is built. The `modulated` argument is always
// passed as true here.
TEST(TestDeformableConvCustom, test_deformable_conv_fp32_custom_size) {
  CHECK_EQ(FLAGS_in_channel % FLAGS_group, 0)
      << "input channel must be divided by group";
  CHECK_EQ(FLAGS_out_channel % FLAGS_group, 0)
      << "num_output must be divided by group";

  // Input tensor shape: {batch, channels, height, width} flags, in that order.
  const std::vector<DDim> input_shapes = {
      DDim({FLAGS_batch, FLAGS_in_channel, FLAGS_in_height, FLAGS_in_width})};
  // Filter shape: {out_channel, in_channel / group, kernel_h, kernel_w}.
  const DDim filter_shape({FLAGS_out_channel,
                           FLAGS_in_channel / FLAGS_group,
                           FLAGS_kernel_h,
                           FLAGS_kernel_w});

  const std::vector<int> stride_hw = {FLAGS_stride_h, FLAGS_stride_w};
  const std::vector<int> padding = {
      FLAGS_pad_h, FLAGS_pad_h, FLAGS_pad_w, FLAGS_pad_w};
  const std::vector<int> dilation_hw = {FLAGS_dila_h, FLAGS_dila_w};
  const std::vector<int> thread_counts = {FLAGS_threads};
  const std::vector<int> power_modes = {FLAGS_power_mode};
  const bool use_modulated = true;

  test_deformable_conv_fp32(input_shapes,
                            filter_shape,
                            FLAGS_group,
                            stride_hw,
                            padding,
                            dilation_hw,
                            FLAGS_flag_bias,
                            FLAGS_flag_act,
                            use_modulated,
                            thread_counts,
                            power_modes,
                            FLAGS_leakey_relu_alpha);
}
#endif // custom
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册