Commit c2f90b59 authored by NazgulLee, committed by Yanzhan Yang

opencl support pad2d, instancenorm, convtranspose and tanh. test=develop (#1940)

Parent 23bebdb9
......@@ -19,6 +19,7 @@ namespace paddle_mobile {
const char *G_OP_TYPE_CONV = "conv2d";
const char *G_OP_TYPE_BATCHNORM = "batch_norm";
const char *G_OP_TYPE_INSTANCENORM = "instance_norm";
const char *G_OP_TYPE_BOX_CODER = "box_coder";
const char *G_OP_TYPE_CONCAT = "concat";
const char *G_OP_TYPE_ELEMENTWISE_ADD = "elementwise_add";
......@@ -153,6 +154,7 @@ std::unordered_map<
{G_OP_TYPE_ELEMENTWISE_MUL, {{"X", "Y"}, {"Out"}}},
{G_OP_TYPE_POOL2D, {{"X"}, {"Out"}}},
{G_OP_TYPE_BATCHNORM, {{"X"}, {"Y"}}},
{G_OP_TYPE_INSTANCENORM, {{"X"}, {"Out"}}},
{G_OP_TYPE_LRN, {{"X"}, {"Out"}}},
{G_OP_TYPE_CONCAT, {{"X"}, {"Out"}}},
{G_OP_TYPE_SPLIT, {{"X"}, {"Out"}}},
......
......@@ -159,6 +159,7 @@ enum ARMArch {
extern const char *G_OP_TYPE_CONV;
extern const char *G_OP_TYPE_BATCHNORM;
extern const char *G_OP_TYPE_INSTANCENORM;
extern const char *G_OP_TYPE_BOX_CODER;
extern const char *G_OP_TYPE_CONCAT;
extern const char *G_OP_TYPE_ELEMENTWISE_ADD;
......
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <cstring>
#include <memory>
#include <string>
#include <utility>
#include "CL/cl.h"
#include "common/enforce.h"
......@@ -27,6 +28,36 @@ limitations under the License. */
namespace paddle_mobile {
namespace framework {
class CLLocalWorkSizeInfo {
public:
CLLocalWorkSizeInfo() {
max_work_group_size = 0;
max_work_item_size0 = 0;
max_work_item_size1 = 0;
max_work_item_size2 = 0;
}
CLLocalWorkSizeInfo(size_t total_size, size_t size0, size_t size1,
size_t size2) {
max_work_group_size = total_size;
max_work_item_size0 = size0;
max_work_item_size1 = size1;
max_work_item_size2 = size2;
}
bool isEmpty() {
return max_work_group_size == 0 && max_work_item_size0 == 0 &&
max_work_item_size1 == 0 && max_work_item_size2 == 0;
}
// max total number of work-items in the work-group
size_t max_work_group_size;
// max number of work-items in local_work_size in dim 0
size_t max_work_item_size0;
// max number of work-items in local_work_size in dim 1
size_t max_work_item_size1;
// max number of work-items in local_work_size in dim 2
size_t max_work_item_size2;
};
class CLEngine {
public:
static CLEngine *Instance();
......@@ -66,6 +97,43 @@ class CLEngine {
return command_queue_.get();
}
CLLocalWorkSizeInfo getLocalWorkSizeInfo() {
if (!localWorkSizeInfo_.isEmpty()) {
return localWorkSizeInfo_;
}
cl_int status;
size_t max_work_group_size = 0;
status = clGetDeviceInfo(devices_[0], CL_DEVICE_MAX_WORK_GROUP_SIZE,
sizeof(size_t), &max_work_group_size, NULL);
if (status != CL_SUCCESS) {
return CLLocalWorkSizeInfo(0, 0, 0, 0);
}
cl_uint max_dims_num = 0;
status = clGetDeviceInfo(devices_[0], CL_DEVICE_MAX_WORK_ITEM_DIMENSIONS,
sizeof(cl_uint), &max_dims_num, NULL);
if (status != CL_SUCCESS) {
return CLLocalWorkSizeInfo(0, 0, 0, 0);
}
DLOG << "max_work_item_sizes max_dims_num: " << max_dims_num;
size_t *max_work_item_sizes =
reinterpret_cast<size_t *>(calloc(max_dims_num, sizeof(size_t)));
size_t ret_size = 0;
status = clGetDeviceInfo(devices_[0], CL_DEVICE_MAX_WORK_ITEM_SIZES,
max_dims_num * sizeof(size_t), max_work_item_sizes,
&ret_size);
if (status != CL_SUCCESS || ret_size / sizeof(size_t) < 3) {
free(max_work_item_sizes);  // avoid leaking the calloc'd buffer on the error path
return CLLocalWorkSizeInfo(0, 0, 0, 0);
}
DLOG << max_work_item_sizes[0];
DLOG << max_work_item_sizes[1];
DLOG << max_work_item_sizes[2];
localWorkSizeInfo_ =
CLLocalWorkSizeInfo(max_work_group_size, max_work_item_sizes[0],
max_work_item_sizes[1], max_work_item_sizes[2]);
free(max_work_item_sizes);
return localWorkSizeInfo_;
}
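As a point of reference, a caller would query these cached limits once and clamp its launch shape against them. A minimal sketch (the call site and the 16x16 starting size are illustrative, not part of this diff; needs <algorithm> for std::min):
// Hypothetical call site: shrink a desired 2-D local size to fit device limits.
CLLocalWorkSizeInfo info = CLEngine::Instance()->getLocalWorkSizeInfo();
size_t local_x = std::min<size_t>(16, info.max_work_item_size0);
size_t local_y = std::min<size_t>(16, info.max_work_item_size1);
while (local_x * local_y > info.max_work_group_size && local_y > 1) {
local_y /= 2;  // halve one dimension until the product fits the group limit
}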
std::unique_ptr<_cl_program, CLProgramDeleter> CreateProgramWith(
cl_context context, std::string file_name) {
FILE *file = fopen(file_name.c_str(), "rb");
......@@ -127,7 +195,7 @@ class CLEngine {
CL_CHECK_ERRORS(status);
if (status_ == CL_BUILD_PROGRAM_FAILURE) {
if (status == CL_BUILD_PROGRAM_FAILURE) {
size_t log_size;
clGetProgramBuildInfo(program, CLEngine::Instance()->DeviceID(),
CL_PROGRAM_BUILD_LOG, 0, NULL, &log_size);
......@@ -158,6 +226,8 @@ class CLEngine {
bool initialized_;
CLLocalWorkSizeInfo localWorkSizeInfo_;
cl_platform_id platform_;
cl_device_id *devices_;
......
......@@ -14,8 +14,10 @@ limitations under the License. */
#pragma once
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
#include "common/log.h"
......@@ -49,6 +51,10 @@ class CLHelper {
cl_context CLContext() { return scope_->Context(); }
CLLocalWorkSizeInfo LocalWorkSizeInfo() {
return scope_->LocalWorkSizeInfo();
}
std::vector<size_t> DefaultWorkSize(const CLImage &image) {
// n c h w
auto image_dim = image.dims();
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <memory>
#include <vector>
#include "CL/cl.h"
......@@ -168,7 +169,7 @@ class CLImage {
}
void InitWithExitedMem(cl_context context, cl_command_queue command_queue,
DDim need_dims, CLImage &src) {
DDim need_dims, const CLImage &src) {
CLImageConverterNormal *normal_converter = new CLImageConverterNormal();
real_image_dims = normal_converter->InitImageDimInfoWith(src.dims());
......@@ -188,6 +189,15 @@ class CLImage {
DLOG << " end init cl image";
}
void InitConv2dTransposeFilterCLImage(cl_context context,
cl_command_queue command_queue) {
PADDLE_MOBILE_ENFORCE(tensor_data_ != nullptr,
" need call SetTensorData first");
CLImageConverterConv2dTransposeTransWeight *converter =
new CLImageConverterConv2dTransposeTransWeight();
InitCLImage(context, command_queue, converter);
}
/*! The internal of two tensors share the same memory block. */
inline CLImage &ShareHolderWith(const CLImage &src) {
PADDLE_MOBILE_ENFORCE(
......
......@@ -448,5 +448,68 @@ void CLImageConverterWinoTransWeight::ImageToNCHW(half_t *image, float *tensor,
const DDim &image_dim,
const DDim &tensor_dim) {}
const DDim &CLImageConverterConv2dTransposeTransWeight::InitImageDimInfoWith(
const DDim &tensor_dim) {
size_t new_dims[] = {1, 1, 1, 1};
for (int j = 0; j < tensor_dim.size(); ++j) {
new_dims[4 - tensor_dim.size() + j] = tensor_dim[j];
}
size_t N, C, H, W;
C = new_dims[0];
N = new_dims[1];
H = new_dims[2];
W = new_dims[3];
size_t width = W * ((C + 3) / 4);
size_t height = H * N;
return make_ddim({width, height});
}
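As a concrete check of the layout math above (example values, not from this commit): a conv2d_transpose filter of CNHW shape [16, 8, 3, 3] gives C = 16, N = 8, H = W = 3, so:
// width  = W * ((C + 3) / 4) = 3 * ((16 + 3) / 4) = 3 * 4 = 12
// height = H * N             = 3 * 8              = 24
// i.e. input channels are packed four per texel along the width axis.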
// This is effectively CNHW-to-image conversion, because conv2d_transpose's filter layout is CNHW
void CLImageConverterConv2dTransposeTransWeight::NCHWToImage(
float *nchw, half_t *image, const DDim &tensor_dim) {
size_t new_dims[] = {1, 1, 1, 1};
for (int j = 0; j < tensor_dim.size(); ++j) {
new_dims[4 - tensor_dim.size() + j] = tensor_dim[j];
}
size_t N, C, H, W;
C = new_dims[0];
N = new_dims[1];
H = new_dims[2];
W = new_dims[3];
DDim in_image_dim = InitImageDimInfoWith(tensor_dim);
DLOG << " tensor dim " << tensor_dim;
DLOG << " image dim " << in_image_dim;
size_t width = in_image_dim[0];
size_t height = in_image_dim[1];
int w_block = width / W;
float *p = nchw;
int realC = w_block * 4;
for (int c = 0; c < realC; c++) {
for (int n = 0; n < N; n++) {
for (int h = 0; h < H; h++) {
for (int w = 0; w < W; w++) {
int index = (n * H + h) * width * 4 + (c / 4) * 4 * W + w * 4 + c % 4;
if (c < C) {
image[index] = Float2Half(*p);
p++;
} else {
image[index] = 0;
}
}
}
}
}
}
void CLImageConverterConv2dTransposeTransWeight::ImageToNCHW(
half_t *image, float *tensor, const DDim &image_dim,
const DDim &tensor_dim) {}
} // namespace framework
} // namespace paddle_mobile
......@@ -109,5 +109,13 @@ class CLImageConverterWinoTransWeight : public CLImageConverterBase {
const DDim &tensor_dim);
};
class CLImageConverterConv2dTransposeTransWeight : public CLImageConverterBase {
public:
const DDim &InitImageDimInfoWith(const DDim &tensor_dim);
void NCHWToImage(float *tensor, half_t *image, const DDim &tensor_dim);
void ImageToNCHW(half_t *image, float *tensor, const DDim &image_dim,
const DDim &tensor_dim);
};
} // namespace framework
} // namespace paddle_mobile
......@@ -18,6 +18,7 @@ limitations under the License. */
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "CL/cl.h"
......@@ -38,6 +39,7 @@ class CLScope {
CLEngine *engine = CLEngine::Instance();
context_ = engine->getContext();
command_queue_ = engine->getClCommandQueue();
localWorkSizeInfo_ = engine->getLocalWorkSizeInfo();
}
cl_command_queue CommandQueue() { return command_queue_; }
......@@ -101,6 +103,8 @@ class CLScope {
return programs_[program_key].get();
}
CLLocalWorkSizeInfo LocalWorkSizeInfo() { return localWorkSizeInfo_; }
private:
cl_int status_;
cl_context context_;
......@@ -108,6 +112,7 @@ class CLScope {
std::unordered_map<std::string,
std::unique_ptr<_cl_program, CLProgramDeleter>>
programs_;
CLLocalWorkSizeInfo localWorkSizeInfo_;
};
} // namespace framework
......
......@@ -70,6 +70,9 @@ LOAD_OP2(fill_constant, CPU, FPGA)
#ifdef BATCHNORM_OP
LOAD_OP2(batch_norm, CPU, GPU_CL);
#endif
#ifdef INSTANCENORM_OP
LOAD_OP1(instance_norm, GPU_CL);
#endif
#ifdef BILINEAR_INTERP_OP
LOAD_OP1(bilinear_interp, CPU);
#endif
......@@ -159,6 +162,9 @@ LOAD_OP2(elementwise_add, CPU, GPU_CL);
#ifdef PRELU_OP
LOAD_OP1(prelu, CPU);
#endif
#ifdef TANH_OP
LOAD_OP2(tanh, CPU, GPU_CL);
#endif
#ifdef FLATTEN_OP
LOAD_OP1(flatten, CPU);
#endif
......
......@@ -80,6 +80,9 @@ REGISTER_OPERATOR_CL(sigmoid, ops::SigmoidOp);
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(tanh, ops::TanhOp);
#endif
#ifdef PADDLE_MOBILE_CL
REGISTER_OPERATOR_CL(tanh, ops::TanhOp);
#endif
#ifdef PADDLE_MOBILE_FPGA
REGISTER_OPERATOR_FPGA(tanh, ops::TanhOp);
#endif
......
......@@ -29,4 +29,8 @@ REGISTER_OPERATOR_CPU(conv2d_transpose, ops::ConvOpTranspose);
REGISTER_OPERATOR_FPGA(conv2d_transpose, ops::ConvOpTranspose);
#endif
#ifdef PADDLE_MOBILE_CL
REGISTER_OPERATOR_CL(conv2d_transpose, ops::ConvOpTranspose);
#endif
#endif
......@@ -47,6 +47,7 @@ class ConvOpTranspose : public framework::OperatorWithKernel<
std::vector<int> strides = this->param_.Strides();
std::vector<int> paddings = this->param_.Paddings();
std::vector<int> dilations = this->param_.Dilations();
std::vector<int> output_size = this->param_.OutputSize();
int groups = this->param_.Groups();
......@@ -73,11 +74,17 @@ class ConvOpTranspose : public framework::OperatorWithKernel<
"be equal to the number of filter's channels.");
std::vector<int64_t> output_shape({in_dims[0], filter_dims[1] * groups});
for (size_t i = 0; i < strides.size(); ++i) {
auto filter_extent = dilations[i] * (filter_dims[i + 2] - 1) + 1;
output_shape.push_back((in_dims[i + 2] - 1) * strides[i] -
2 * paddings[i] + filter_extent);
if (output_size.size() == 2) {
output_shape.push_back(output_size[0]);
output_shape.push_back(output_size[1]);
} else {
for (size_t i = 0; i < strides.size(); ++i) {
auto filter_extent = dilations[i] * (filter_dims[i + 2] - 1) + 1;
output_shape.push_back((in_dims[i + 2] - 1) * strides[i] -
2 * paddings[i] + filter_extent);
}
}
this->param_.Output()->Resize(framework::make_ddim(output_shape));
}
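A quick numeric check of the fallback formula above (illustrative numbers, not from the diff):
// in = 4, stride = 2, padding = 1, dilation = 1, kernel = 3:
//   filter_extent = 1 * (3 - 1) + 1 = 3
//   out           = (4 - 1) * 2 - 2 * 1 + 3 = 7
// When the model carries an explicit output_size attribute such as {8, 8},
// the new branch uses it verbatim and the formula is skipped.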
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef INSTANCENORM_OP
#include "operators/instancenorm_op.h"
#include "framework/op_proto_maker.h"
#include "framework/op_registry.h"
namespace paddle_mobile {
namespace operators {
template <typename Dtype, typename T>
void InstanceNormOp<Dtype, T>::InferShape() const {
auto x_dims = this->param_.InputX()->dims();
this->param_.Out()->Resize(x_dims);
}
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CL
REGISTER_OPERATOR_CL(instance_norm, ops::InstanceNormOp);
#endif
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef INSTANCENORM_OP
#pragma once
#include <string>
#include "framework/operator.h"
#include "operators/kernel/instancenorm_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
using std::string;
template <typename DeviceType, typename T>
class InstanceNormOp
: public framework::OperatorWithKernel<DeviceType,
InstanceNormParam<DeviceType>,
InstanceNormKernel<DeviceType, T>> {
public:
InstanceNormOp(const string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap &attrs, framework::Scope *scope)
: framework::OperatorWithKernel<DeviceType, InstanceNormParam<DeviceType>,
InstanceNormKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
void InferShape() const override;
protected:
};
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -27,8 +27,8 @@ bool Pad2DKernel<CPU, float>::Init(Pad2DParam<CPU> *param) {
template <>
void Pad2DKernel<CPU, float>::Compute(const Pad2DParam<CPU> &param) {
const auto *input = param.input_;
auto *output = param.output_;
const auto *input = param.InputX();
auto *output = param.Out();
const auto &paddings = param.paddings_;
// if (param.mode_ == "constant" && param.pad_value_ == 0) {
math::PadFunctor<CPU, float> pad;
......
......@@ -2026,10 +2026,10 @@ __kernel void conv_7x7(__private const int global_size_dim0,
out_nh >= global_size_dim2) {
return;
}
const filter_n0 = 4 * out_c + 0;
const filter_n1 = 4 * out_c + 1;
const filter_n2 = 4 * out_c + 2;
const filter_n3 = 4 * out_c + 3;
const int filter_n0 = 4 * out_c + 0;
const int filter_n1 = 4 * out_c + 1;
const int filter_n2 = 4 * out_c + 2;
const int filter_n3 = 4 * out_c + 3;
int2 stride_xy;
stride_xy.x = stride;
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "cl_common.h"
__kernel void conv_transpose(__private const int input_c_block,
__private const int input_width,/* of one block */
__private const int input_height,/* of one block */
__private const int output_width,
__private const int output_height,
__read_only image2d_t input_image,
__read_only image2d_t filter,
__write_only image2d_t output_image) {
const int out_c = get_global_id(0);
const int in_w = get_global_id(1);
const int in_nh = get_global_id(2);
const int n = in_nh / input_height;
const int h = in_nh % input_height;
const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE |
CLK_ADDRESS_CLAMP |
CLK_FILTER_NEAREST;
half4 input1, input2, input3, input4;
half4 output1 = 0.0f, output2 = 0.0f, output3 = 0.0f, output4 = 0.0f;
half4 w = 0.0f;
int2 pos_in;
for (int i = 0; i < input_c_block; i += 1) {
pos_in = (int2)(mad24(i, input_width, in_w), in_nh);
input1 = select(read_imageh(input_image, sampler,
(int2)(pos_in.x, pos_in.y)),
(half4)(0.0f),
(ushort4)((in_w < 0 || h < 0 || in_w >= input_width || h >= input_height) << 15));
input2 = select(read_imageh(input_image, sampler,
(int2)(pos_in.x + 1, pos_in.y)),
(half4)(0.0f),
(ushort4)((in_w + 1 < 0 || h < 0 || in_w + 1 >= input_width || h >= input_height) << 15));
input3 = select(read_imageh(input_image, sampler,
(int2)(pos_in.x, pos_in.y + 1)),
(half4)(0.0f),
(ushort4)((in_w < 0 || h + 1 < 0 || in_w >= input_width || h + 1 >= input_height) << 15));
input4 = select(read_imageh(input_image, sampler,
(int2)(pos_in.x + 1, pos_in.y + 1)),
(half4)(0.0f),
(ushort4)((in_w + 1 < 0 || h + 1 < 0 || in_w + 1 >= input_width || h + 1 >= input_height) << 15));
int wx = i * 3;
int wy = out_c * 4 * 3;
w = read_imageh(filter, sampler, (int2)(wx, wy));
output4.x += dot(input4, w);
w = read_imageh(filter, sampler, (int2)(wx + 1, wy));
output3.x += dot(input3, w);
w = read_imageh(filter, sampler, (int2)(wx + 2, wy));
output4.x += dot(input3, w);
w = read_imageh(filter, sampler, (int2)(wx, wy + 1));
output2.x += dot(input2, w);
w = read_imageh(filter, sampler, (int2)(wx + 1, wy + 1));
output1.x += dot(input1, w);
w = read_imageh(filter, sampler, (int2)(wx + 2, wy + 1));
output2.x += dot(input1, w);
w = read_imageh(filter, sampler, (int2)(wx, wy + 2));
output4.x += dot(input2, w);
w = read_imageh(filter, sampler, (int2)(wx + 1, wy + 2));
output3.x += dot(input1, w);
w = read_imageh(filter, sampler, (int2)(wx + 2, wy + 2));
output4.x += dot(input1, w);
wy = (out_c * 4 + 1) * 3;
w = read_imageh(filter, sampler, (int2)(wx, wy));
output4.y += dot(input4, w);
w = read_imageh(filter, sampler, (int2)(wx + 1, wy));
output3.y += dot(input3, w);
w = read_imageh(filter, sampler, (int2)(wx + 2, wy));
output4.y += dot(input3, w);
w = read_imageh(filter, sampler, (int2)(wx, wy + 1));
output2.y += dot(input2, w);
w = read_imageh(filter, sampler, (int2)(wx + 1, wy + 1));
output1.y += dot(input1, w);
w = read_imageh(filter, sampler, (int2)(wx + 2, wy + 1));
output2.y += dot(input1, w);
w = read_imageh(filter, sampler, (int2)(wx, wy + 2));
output4.y += dot(input2, w);
w = read_imageh(filter, sampler, (int2)(wx + 1, wy + 2));
output3.y += dot(input1, w);
w = read_imageh(filter, sampler, (int2)(wx + 2, wy + 2));
output4.y += dot(input1, w);
wy = (out_c * 4 + 2) * 3;
w = read_imageh(filter, sampler, (int2)(wx, wy));
output4.z += dot(input4, w);
w = read_imageh(filter, sampler, (int2)(wx + 1, wy));
output3.z += dot(input3, w);
w = read_imageh(filter, sampler, (int2)(wx + 2, wy));
output4.z += dot(input3, w);
w = read_imageh(filter, sampler, (int2)(wx, wy + 1));
output2.z += dot(input2, w);
w = read_imageh(filter, sampler, (int2)(wx + 1, wy + 1));
output1.z += dot(input1, w);
w = read_imageh(filter, sampler, (int2)(wx + 2, wy + 1));
output2.z += dot(input1, w);
w = read_imageh(filter, sampler, (int2)(wx, wy + 2));
output4.z += dot(input2, w);
w = read_imageh(filter, sampler, (int2)(wx + 1, wy + 2));
output3.z += dot(input1, w);
w = read_imageh(filter, sampler, (int2)(wx + 2, wy + 2));
output4.z += dot(input1, w);
wy = (out_c * 4 + 3) * 3;
w = read_imageh(filter, sampler, (int2)(wx, wy));
output4.w += dot(input4, w);
w = read_imageh(filter, sampler, (int2)(wx + 1, wy));
output3.w += dot(input3, w);
w = read_imageh(filter, sampler, (int2)(wx + 2, wy));
output4.w += dot(input3, w);
w = read_imageh(filter, sampler, (int2)(wx, wy + 1));
output2.w += dot(input2, w);
w = read_imageh(filter, sampler, (int2)(wx + 1, wy + 1));
output1.w += dot(input1, w);
w = read_imageh(filter, sampler, (int2)(wx + 2, wy + 1));
output2.w += dot(input1, w);
w = read_imageh(filter, sampler, (int2)(wx, wy + 2));
output4.w += dot(input2, w);
w = read_imageh(filter, sampler, (int2)(wx + 1, wy + 2));
output3.w += dot(input1, w);
w = read_imageh(filter, sampler, (int2)(wx + 2, wy + 2));
output4.w += dot(input1, w);
}
int2 pos_out = (int2)(out_c * output_width + 2 * in_w, n * output_height + 2 * h);
write_imageh(output_image, pos_out, output1);
write_imageh(output_image, (int2)(pos_out.x + 1, pos_out.y), output2);
write_imageh(output_image, (int2)(pos_out.x, pos_out.y + 1), output3);
write_imageh(output_image, (int2)(pos_out.x + 1, pos_out.y + 1), output4);
}
\ No newline at end of file
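Reading the kernel as a whole: the index math hard-codes a 3x3 filter with stride 2, and each work-item consumes one input position and emits the 2x2 output patch that position influences. A comment-level model of the coordinate mapping, under those assumptions:
// Each input position (in_w, h) of sample n writes four outputs:
//   output1 -> (2*in_w,     2*h)
//   output2 -> (2*in_w + 1, 2*h)
//   output3 -> (2*in_w,     2*h + 1)
//   output4 -> (2*in_w + 1, 2*h + 1)
// matching pos_out = (out_c * output_width + 2 * in_w,
//                     n * output_height + 2 * h) above.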
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
__kernel void instancenorm(__private const int in_width,
__private const int in_height,
__private const int in_c_group,
__private const int local_work_size_x,
__private const int local_work_size_y,
__private const float epsilon,
__read_only image2d_t input,
__write_only image2d_t output) {
const int out_cn = get_global_id(0);
const int n = out_cn / in_c_group;
const int c = out_cn % in_c_group;
const int w = get_local_id(1);
const int h = get_local_id(2);
const int local_id = w * local_work_size_y + h;
const int local_total_size = local_work_size_x * local_work_size_y;
const sampler_t sampler =
CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
__local float4 shared_mem[256];
float4 sum = 0.0f;
for (int xIndex = w; xIndex < in_width; xIndex += local_work_size_x) {
for (int yIndex = h; yIndex < in_height; yIndex += local_work_size_y) {
sum += read_imagef(input, sampler, (int2)(mad24(c, in_width, xIndex), mad24(n, in_height, yIndex)));
}
}
shared_mem[local_id] = sum;
barrier(CLK_LOCAL_MEM_FENCE);
sum = 0.0f;
if (local_id < 32) {
for (int i = local_id + 32; i < local_total_size; i += 32) {
sum += shared_mem[i];
}
}
shared_mem[local_id] += sum;
barrier(CLK_LOCAL_MEM_FENCE);
sum = 0.0f;
if (local_id == 0) {
int top = min(32, local_total_size);
for (int i = 0; i < top; i += 1) {
sum += shared_mem[i];
}
shared_mem[0] = sum / (in_width * in_height);
}
barrier(CLK_LOCAL_MEM_FENCE);
const float4 mean_val = shared_mem[0];
barrier(CLK_LOCAL_MEM_FENCE);
sum = 0.0f;
for (int xIndex = w; xIndex < in_width; xIndex += local_work_size_x) {
for (int yIndex = h; yIndex < in_height; yIndex += local_work_size_y) {
float4 diff = read_imagef(input, sampler, (int2)(mad24(c, in_width, xIndex), mad24(n, in_height, yIndex))) - mean_val;
sum += diff * diff;  // squared deviation; explicit multiply avoids pow with an int exponent
}
}
shared_mem[local_id] = sum;
barrier(CLK_LOCAL_MEM_FENCE);
sum = 0.0f;
if (local_id < 32) {
for (int i = local_id + 32; i < local_total_size; i += 32) {
sum += shared_mem[i];
}
}
shared_mem[local_id] += sum;
barrier(CLK_LOCAL_MEM_FENCE);
sum = 0.0f;
if (local_id == 0) {
int top = min(32, local_total_size);
for (int i = 0; i < top; i += 1) {
sum += shared_mem[i];
}
shared_mem[0] = sum / (in_width * in_height);
}
barrier(CLK_LOCAL_MEM_FENCE);
const float4 sigma = sqrt(shared_mem[0] + (float4)(epsilon));
float4 s = 1 / sigma;
for (int xIndex = w; xIndex < in_width; xIndex += local_work_size_x) {
for (int yIndex = h; yIndex < in_height; yIndex += local_work_size_y) {
int2 intout_pos = (int2)(mad24(c, in_width, xIndex), mad24(n, in_height, yIndex));
float4 in_val = read_imagef(input, sampler, intout_pos);
write_imageh(output, intout_pos, convert_half4((in_val - mean_val) * s));
}
}
}
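For cross-checking the kernel's two-stage reduction, a minimal CPU reference of instance normalization over an NCHW float buffer (a sketch; the function name and layout are assumptions, not part of this commit — note the kernel above likewise applies no scale/shift):
#include <cmath>
// y = (x - mean) / sqrt(var + epsilon), computed independently per (n, c) plane.
void InstanceNormRef(const float *x, float *y, int N, int C, int H, int W,
                     float epsilon) {
  const int plane = H * W;
  for (int n = 0; n < N; ++n) {
    for (int c = 0; c < C; ++c) {
      const float *in = x + (n * C + c) * plane;
      float *out = y + (n * C + c) * plane;
      float mean = 0.f, var = 0.f;
      for (int i = 0; i < plane; ++i) mean += in[i];
      mean /= plane;
      for (int i = 0; i < plane; ++i) var += (in[i] - mean) * (in[i] - mean);
      var /= plane;
      const float s = 1.f / std::sqrt(var + epsilon);
      for (int i = 0; i < plane; ++i) out[i] = (in[i] - mean) * s;
    }
  }
}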
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
__kernel void pad2d(
__private const int in_height, __private const int in_width,
__private const int out_height, __private const int out_width,
__private const int pad_top, __private const int pad_bottom,
__private const int pad_left, __private const int pad_right,
__private const int mode, __private const float pad_value,
__read_only image2d_t input, __write_only image2d_t output) {
const int out_c = get_global_id(0);
const int out_w = get_global_id(1);
const int out_nh = get_global_id(2);
const int out_n = out_nh / out_height;
const int out_h = out_nh % out_height;
int2 output_pos = (int2)(mad24(out_c, out_width, out_w), out_nh);
const sampler_t sampler =
CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
int x = out_w - pad_left;
int y = out_h - pad_top;
if (mode == 0) {
if (x < 0 || y < 0 || x >= in_width || y >= in_height) {
write_imageh(output, output_pos, (half4)(pad_value));
} else {
write_imageh(output, output_pos, read_imageh(input, sampler, (int2)(out_c * in_width + x, out_n * in_height + y)));
}
} else if (mode == 1) {
x = abs(x);
y = abs(y);
x = x < in_width ? x : 2 * in_width - 2 - x;
y = y < in_height ? y : 2 * in_height - 2 - y;
write_imageh(output, output_pos, read_imageh(input, sampler, (int2)(out_c * in_width + x, out_n * in_height + y)));
} else if (mode == 2) {
x = x > 0 ? x : 0;
x = x < in_width ? x : in_width - 1;
y = y > 0 ? y : 0;
y = y < in_height ? y : in_height - 1;
write_imageh(output, output_pos, read_imageh(input, sampler, (int2)(out_c * in_width + x, out_n * in_height + y)));
}
}
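The three branches map each output coordinate back to an input coordinate per axis. A scalar restatement of the same index math (mode values 0/1/2 correspond to constant/reflect/edge, as the host code later in this diff sets them; the helper name is illustrative):
#include <cstdlib>  // abs
// Map an output coordinate (already shifted by the pad offset) to an input
// column; returns -1 where constant mode should write pad_value instead.
int MapPadIndex(int x, int in_width, int mode) {
  if (mode == 0)  // constant
    return (x < 0 || x >= in_width) ? -1 : x;
  if (mode == 1) {  // reflect: mirror without repeating the border pixel
    x = abs(x);
    return x < in_width ? x : 2 * in_width - 2 - x;
  }
  return x < 0 ? 0 : (x >= in_width ? in_width - 1 : x);  // edge: clamp
}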
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
__kernel void tanh_kernel(__read_only image2d_t input,
__write_only image2d_t output){
const int x = get_global_id(0);
const int y = get_global_id(1);
const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE |
CLK_ADDRESS_CLAMP |
CLK_FILTER_NEAREST;
half4 in = read_imageh(input, sampler, (int2)(x, y));
write_imageh(output, (int2)(x, y), tanh(in));
}
......@@ -85,7 +85,15 @@ bool ConvKernel<GPU_CL, float>::Init(ConvParam<GPU_CL> *param) {
this->cl_helper_.AddKernel("conv_3x3", conv_kernel_file);
// }
DLOG << "conv 3x3";
} else if (param->Filter()->dims()[2] == 7 &&
param->Filter()->dims()[3] == 7) {
param->ExecMode() = ConvParam<GPU_CL>::EXEC_SLIDINGWINDOW7x7_FLOAT;
param->Filter()->InitCLImage(cl_helper_.CLContext(),
cl_helper_.CLCommandQueue());
this->cl_helper_.AddKernel("conv_7x7", conv_kernel_file);
// }
DLOG << "conv 7x7";
} else {
PADDLE_MOBILE_THROW_EXCEPTION(" not support ");
}
......@@ -102,6 +110,7 @@ void ConvKernel<GPU_CL, float>::Compute(const ConvParam<GPU_CL> &param) {
case ConvParam<GPU_CL>::EXEC_SLIDINGWINDOW1x1_FLOAT:
case ConvParam<GPU_CL>::EXEC_SLIDINGWINDOW3x3_FLOAT:
case ConvParam<GPU_CL>::EXEC_DEPTHWISE3x3_FLOAT:
case ConvParam<GPU_CL>::EXEC_SLIDINGWINDOW7x7_FLOAT:
ConvAddBnRelu(&this->cl_helper_, param);
break;
case ConvParam<GPU_CL>::EXEC_DEPTHWISE3x3S1_FLOAT:
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef CONV_TRANSPOSE_OP
#include "operators/kernel/conv_transpose_kernel.h"
namespace paddle_mobile {
namespace operators {
template <>
bool ConvTransposeKernel<GPU_CL, float>::Init(
ConvTransposeParam<GPU_CL>* param) {
param->Filter()->InitConv2dTransposeFilterCLImage(
cl_helper_.CLContext(), cl_helper_.CLCommandQueue());
this->cl_helper_.AddKernel("conv_transpose", "conv_transpose.cl");
return true;
}
template <>
void ConvTransposeKernel<GPU_CL, float>::Compute(
const ConvTransposeParam<GPU_CL>& param) {
auto kernel = this->cl_helper_.KernelAt(0);
const auto* input = param.Input();
auto* output = param.Output();
auto* filter = param.Filter();
const int n = input->dims()[0];
const int input_c = input->dims()[1];
const int input_c_block = (input_c + 3) / 4;
const int input_width = input->dims()[3];
const int input_height = input->dims()[2];
const int output_c = output->dims()[1];
const int output_c_block = (output_c + 3) / 4;
const int output_width = output->dims()[3];
const int output_height = output->dims()[2];
auto inputImage = input->GetCLImage();
auto outputImage = output->GetCLImage();
auto filterImage = filter->GetCLImage();
cl_int status;
status = clSetKernelArg(kernel, 0, sizeof(int), &input_c_block);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 1, sizeof(int), &input_width);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 2, sizeof(int), &input_height);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 3, sizeof(int), &output_width);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 4, sizeof(int), &output_height);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 5, sizeof(cl_mem), &inputImage);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 6, sizeof(cl_mem), &filterImage);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 7, sizeof(cl_mem), &outputImage);
CL_CHECK_ERRORS(status);
const size_t work_size[3] = {(size_t)output_c_block, (size_t)input_width,
(size_t)(n * input_height)};
DLOG << "conv transpose " << input_c_block << input_width << input_height
<< output_width << output_height << work_size[0] << work_size[1]
<< work_size[2];
clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 3, NULL,
work_size, NULL, 0, NULL, NULL);
}
template class ConvTransposeKernel<GPU_CL, float>;
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef INSTANCENORM_OP
#include "operators/kernel/instancenorm_kernel.h"
#include <cmath>
namespace paddle_mobile {
namespace operators {
template <>
bool InstanceNormKernel<GPU_CL, float>::Init(InstanceNormParam<GPU_CL> *param) {
this->cl_helper_.AddKernel("instancenorm", "instancenorm_kernel.cl");
return true;
}
template <>
void InstanceNormKernel<GPU_CL, float>::Compute(
const InstanceNormParam<GPU_CL> &param) {
auto kernel = this->cl_helper_.KernelAt(0);
auto &dims = param.Out()->dims();
const int n = dims[0];
const int c_group = (dims[1] + 3) / 4;
const int h = dims[2];
const int w = dims[3];
auto epsilon = param.Epsilon();
auto input = param.InputX()->GetCLImage();
auto out = param.Out()->GetCLImage();
DLOG << "Epsilon: " << epsilon;
auto local_work_size_info = this->cl_helper_.LocalWorkSizeInfo();
DLOG << local_work_size_info.max_work_group_size;
DLOG << local_work_size_info.max_work_item_size0;
DLOG << local_work_size_info.max_work_item_size1;
DLOG << local_work_size_info.max_work_item_size2;
const int max_work_group_size =
std::min(256, static_cast<int>(local_work_size_info.max_work_group_size));
int local_work_size1 = 1;
int local_work_size2 = 1;
for (int i = 1; i <= local_work_size_info.max_work_item_size1 && i <= w;
i++) {
for (int j = 1; j <= local_work_size_info.max_work_item_size2 && j <= h;
j++) {
if (i * j <= max_work_group_size) {
if (i * j > local_work_size1 * local_work_size2) {
local_work_size1 = i;
local_work_size2 = j;
}
}
}
}
const size_t work_size[3] = {(size_t)(n * c_group), (size_t)local_work_size1,
(size_t)local_work_size2};
const size_t local_work_size[3] = {(size_t)1, (size_t)local_work_size1,
(size_t)local_work_size2};
DLOG << "work_size" << work_size[0] << " " << work_size[1] << " "
<< work_size[2];
DLOG << "local_work_size" << local_work_size[0] << " " << local_work_size[1]
<< " " << local_work_size[2];
cl_int status;
status = clSetKernelArg(kernel, 0, sizeof(cl_int), &w);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 1, sizeof(cl_int), &h);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 2, sizeof(cl_int), &c_group);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 3, sizeof(cl_int), &local_work_size1);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 4, sizeof(cl_int), &local_work_size2);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 5, sizeof(cl_float), &epsilon);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 6, sizeof(cl_mem), &input);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 7, sizeof(cl_mem), &out);
CL_CHECK_ERRORS(status);
clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 3, NULL,
work_size, local_work_size, 0, NULL, NULL);
}
template class InstanceNormKernel<GPU_CL, float>;
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PAD2D_OP
#include "operators/kernel/pad2d_kernel.h"
#include "framework/cl/cl_tensor.h"
namespace paddle_mobile {
namespace operators {
template <>
bool Pad2DKernel<GPU_CL, float>::Init(Pad2DParam<GPU_CL> *param) {
DLOG << "Init pad2d";
this->cl_helper_.AddKernel("pad2d", "pad2d_kernel.cl");
return true;
}
template <>
void Pad2DKernel<GPU_CL, float>::Compute(const Pad2DParam<GPU_CL> &param) {
auto kernel = this->cl_helper_.KernelAt(0);
auto default_work_size = this->cl_helper_.DefaultWorkSize(*(param.Out()));
cl_int status;
auto output = param.Out();
auto input = param.InputX();
auto output_image = output->GetCLImage();
auto input_image = input->GetCLImage();
const int out_H = output->dims()[2];
const int out_W = output->dims()[3];
const int input_H = input->dims()[2];
const int input_W = input->dims()[3];
const auto &paddings = param.paddings_;
const int pad_top = paddings[0];
const int pad_bottom = paddings[1];
const int pad_left = paddings[2];
const int pad_right = paddings[3];
const float pad_value = param.pad_value_;
const auto &modeStr = param.mode_;
int mode = 0;
if (modeStr == "reflect") {
mode = 1;
} else if (modeStr == "edge") {
mode = 2;
}
DLOG << "input_H: " << input_H;
status = clSetKernelArg(kernel, 0, sizeof(cl_int), &input_H);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 1, sizeof(cl_int), &input_W);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 2, sizeof(cl_int), &out_H);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 3, sizeof(cl_int), &out_W);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 4, sizeof(cl_int), &pad_top);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 5, sizeof(cl_int), &pad_bottom);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 6, sizeof(cl_int), &pad_left);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 7, sizeof(cl_int), &pad_right);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 8, sizeof(cl_int), &mode);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 9, sizeof(cl_float), &pad_value);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 10, sizeof(cl_mem), &input_image);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 11, sizeof(cl_mem), &output_image);
CL_CHECK_ERRORS(status);
status = clEnqueueNDRangeKernel(
this->cl_helper_.CLCommandQueue(), kernel, default_work_size.size(), NULL,
default_work_size.data(), NULL, 0, NULL, NULL);
CL_CHECK_ERRORS(status);
}
template class Pad2DKernel<GPU_CL, float>;
} // namespace operators
} // namespace paddle_mobile
#endif // PAD2D_OP
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef TANH_OP
#include "operators/kernel/activation_kernel.h"
namespace paddle_mobile {
namespace operators {
template <>
bool TanhKernel<GPU_CL, float>::Init(TanhParam<GPU_CL>* param) {
this->cl_helper_.AddKernel("tanh_kernel", "tanh_kernel.cl");
return true;
}
template <>
void TanhKernel<GPU_CL, float>::Compute(const TanhParam<GPU_CL>& param) {
auto kernel = this->cl_helper_.KernelAt(0);
const auto* input = param.InputX();
auto* output = param.Out();
auto default_work_size = this->cl_helper_.DefaultWorkSize(*output);
auto inputImage = input->GetCLImage();
auto outputImage = output->GetCLImage();
clSetKernelArg(kernel, 0, sizeof(cl_mem), &inputImage);
clSetKernelArg(kernel, 1, sizeof(cl_mem), &outputImage);
const size_t work_size[2] = {input->ImageWidth(), input->ImageHeight()};
clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 2, NULL,
work_size, NULL, 0, NULL, NULL);
}
template class TanhKernel<GPU_CL, float>;
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -17,7 +17,7 @@ namespace paddle_mobile {
namespace operators {
template <>
bool Pad2DKernel<FPGA, float>::Init(Pad2DParam<FPGA> *param) {
Tensor *output = param->output_;
Tensor *output = param->Out();
fpga::format_fp16_ofm(output);
return true;
}
......@@ -40,8 +40,8 @@ void pad2dFunc(const framework::Tensor *input, framework::Tensor *output) {
}
template <>
void Pad2DKernel<FPGA, float>::Compute(const Pad2DParam<FPGA> &param) {
auto in_x = param.input_;
auto out = param.output_;
auto in_x = param.InputX();
auto out = param.Out();
fpga::fpga_invalidate((void *)in_x->data<half>(), // NOLINT
in_x->numel() * sizeof(half));
pad2dFunc(in_x, out);
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef INSTANCENORM_OP
#pragma once
#include "framework/operator.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
class InstanceNormKernel
: public framework::OpKernelBase<DeviceType,
InstanceNormParam<DeviceType>> {
public:
void Compute(const InstanceNormParam<DeviceType> &param);
bool Init(InstanceNormParam<DeviceType> *param);
};
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -24,27 +24,27 @@ limitations under the License. */
namespace paddle_mobile {
namespace operators {
template <typename Dtype>
class Pad2DParam : public OpParam {
public:
Pad2DParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, Scope *scope)
: OpParam(inputs, outputs, attrs, scope) {
input_ = OpParam::GetVarValue<framework::LoDTensor>("X", inputs, *scope);
output_ =
OpParam::GetVarValue<framework::LoDTensor>("Out", outputs, *scope);
paddings_ = OpParam::GetAttr<std::vector<int>>("paddings", attrs);
pad_value_ = OpParam::GetAttr<float>("pad_value", attrs);
mode_ = OpParam::GetStringAttr("mode", attrs);
}
public:
framework::LoDTensor *input_;
framework::LoDTensor *output_;
std::vector<int> paddings_;
float pad_value_;
std::string mode_;
};
// template <typename Dtype>
// class Pad2DParam : public OpParam {
// public:
// Pad2DParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
// const AttributeMap &attrs, Scope *scope)
// : OpParam(inputs, outputs, attrs, scope) {
// input_ = OpParam::GetVarValue<framework::LoDTensor>("X", inputs, *scope);
// output_ =
// OpParam::GetVarValue<framework::LoDTensor>("Out", outputs, *scope);
// paddings_ = OpParam::GetAttr<std::vector<int>>("paddings", attrs);
// pad_value_ = OpParam::GetAttr<float>("pad_value", attrs);
// mode_ = OpParam::GetStringAttr("mode", attrs);
// }
//
// public:
// framework::LoDTensor *input_;
// framework::LoDTensor *output_;
// std::vector<int> paddings_;
// float pad_value_;
// std::string mode_;
//};
DECLARE_KERNEL(Pad2D, Pad2DParam);
......
......@@ -894,6 +894,35 @@ class BatchNormParam : public OpParam {
};
#endif
#ifdef INSTANCENORM_OP
template <typename Dtype>
class InstanceNormParam : public OpParam {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
InstanceNormParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs,
Scope *scope)
: OpParam(inputs, outputs, attrs, scope) {
input_x_ = InputXFrom<GType>(inputs, *scope);
out_ = OutFrom<GType>(outputs, *scope);
epsilon_ = GetAttr<float>("epsilon", attrs);
}
const GType *InputX() const { return input_x_; }
GType *Out() const { return out_; }
const float &Epsilon() const { return epsilon_; }
private:
GType *input_x_;
GType *out_;
float epsilon_;
};
#endif
#ifdef POOL_OP
template <typename Dtype>
class PoolParam : public OpParam {
......@@ -2472,12 +2501,16 @@ class ConvTransposeParam : public OpParam {
strides_ = GetAttr<vector<int>>("strides", attrs);
paddings_ = GetAttr<vector<int>>("paddings", attrs);
dilations_ = GetAttr<vector<int>>("dilations", attrs);
if (HasAttr("output_size", attrs)) {
output_size_ = GetAttr<vector<int>>("output_size", attrs);
DLOG << "conv transpose output size: " << output_size_;
}
groups = GetAttr<int>("groups", attrs);
}
const GType *Input() const { return input_; }
const GType *Filter() const { return filter_; }
GType *Filter() const { return filter_; }
GType *Output() const { return output_; }
......@@ -2487,6 +2520,8 @@ class ConvTransposeParam : public OpParam {
const vector<int> &Dilations() const { return dilations_; }
const vector<int> &OutputSize() const { return output_size_; }
const int &Groups() const { return groups; }
enum ExecMode {
......@@ -2505,6 +2540,7 @@ class ConvTransposeParam : public OpParam {
vector<int> strides_;
vector<int> paddings_;
vector<int> dilations_;
vector<int> output_size_;
int groups;
mutable enum ExecMode exec_mode_;
......@@ -3471,23 +3507,31 @@ class IncrementParam : public OpParam {
#endif // INCREMENT_OP
#ifdef PAD2D_OP
template <typename Dtype>
class Pad2dParam : public OpParam {
class Pad2DParam : public OpParam {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
Pad2dParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
Pad2DParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, Scope *scope)
: OpParam(inputs, outputs, attrs, scope) {
input_x_ = InputXFrom<GType>(inputs, *scope);
out_ = OutFrom<GType>(outputs, *scope);
paddings_ = OpParam::GetAttr<std::vector<int>>("paddings", attrs);
pad_value_ = OpParam::GetAttr<float>("pad_value", attrs);
mode_ = OpParam::GetStringAttr("mode", attrs);
DLOG << "mode" << mode_;
}
const RType *InputX() const { return input_x_; }
RType *Out() const { return out_; }
const GType *InputX() const { return input_x_; }
GType *Out() const { return out_; }
std::vector<int> paddings_;
float pad_value_;
std::string mode_;
private:
RType *input_x_;
RType *out_;
GType *input_x_;
GType *out_;
};
#endif
#ifdef EXP_OP
......
......@@ -20,14 +20,14 @@ namespace operators {
template <typename Dtype, typename T>
void Pad2DOp<Dtype, T>::InferShape() const {
auto input_dims = this->param_.input_->dims();
auto input_dims = this->param_.InputX()->dims();
const auto &paddings = this->param_.paddings_;
PADDLE_MOBILE_ENFORCE(paddings.size() == 4,
"Size of paddings should be equal to 4.");
input_dims[2] += paddings[0] + paddings[1];
input_dims[3] += paddings[2] + paddings[3];
this->param_.output_->Resize(input_dims);
this->param_.Out()->Resize(input_dims);
}
} // namespace operators
......@@ -40,5 +40,7 @@ REGISTER_OPERATOR_CPU(pad2d, ops::Pad2DOp);
#if defined(PADDLE_MOBILE_FPGA) || defined(PADDLE_MOBILE_FPGA_KD)
REGISTER_OPERATOR_FPGA(pad2d, ops::Pad2DOp);
#endif
#ifdef PADDLE_MOBILE_CL
REGISTER_OPERATOR_CL(pad2d, ops::Pad2DOp);
#endif
#endif // PAD2D_OP
......@@ -282,6 +282,7 @@ if(NOT FOUND_MATCH)
message("--default--")
set(NORM_OP ON)
set(BATCHNORM_OP ON)
set(INSTANCENORM_OP ON)
set(CONV_TRANSPOSE_OP ON)
set(BOXCODER_OP ON)
set(CONCAT_OP ON)
......@@ -409,6 +410,9 @@ endif()
if (BATCHNORM_OP)
add_definitions(-DBATCHNORM_OP)
endif()
if (INSTANCENORM_OP)
add_definitions(-DINSTANCENORM_OP)
endif()
if (BOXCODER_OP)
add_definitions(-DBOXCODER_OP)
endif()
......