提交 1f2e88a9 编写于 作者: R Raman Sarokin 提交者: TensorFlower Gardener

ConvolutionTransposed3x3 generation changed to use storage type properties...

ConvolutionTransposed3x3 generation changed to use storage type properties instead of specific storage types.

PiperOrigin-RevId: 339967056
Change-Id: I5f6e192e21cbabf33bd152f91f4ab664b139fd14
上级 64edb2fb
...@@ -1356,6 +1356,7 @@ test_suite( ...@@ -1356,6 +1356,7 @@ test_suite(
"conv_buffer_1x1_test", "conv_buffer_1x1_test",
"conv_constants_test", "conv_constants_test",
"conv_powervr_test", "conv_powervr_test",
"convolution_transposed_3x3_test",
"convolution_transposed_3x3_thin_test", "convolution_transposed_3x3_thin_test",
"convolution_transposed_4x4_test", "convolution_transposed_4x4_test",
"convolution_transposed_test", "convolution_transposed_test",
......
...@@ -86,10 +86,6 @@ std::string ConvolutionTransposed3x3::GenerateConvolutionTransposedCode( ...@@ -86,10 +86,6 @@ std::string ConvolutionTransposed3x3::GenerateConvolutionTransposedCode(
args_.AddInt("padding_x"); args_.AddInt("padding_x");
args_.AddInt("padding_y"); args_.AddInt("padding_y");
const auto src_tensor_type = op_def.src_tensors[0].storage_type;
const bool manual_clamp = src_tensor_type == TensorStorageType::BUFFER ||
src_tensor_type == TensorStorageType::IMAGE_BUFFER;
const bool need_local_mem = const bool need_local_mem =
weights_upload_type == weights_upload_type ==
ConvolutionTransposed3x3::WeightsUploadType::LOCAL_MEM_BY_THREADS || ConvolutionTransposed3x3::WeightsUploadType::LOCAL_MEM_BY_THREADS ||
...@@ -170,26 +166,35 @@ std::string ConvolutionTransposed3x3::GenerateConvolutionTransposedCode( ...@@ -170,26 +166,35 @@ std::string ConvolutionTransposed3x3::GenerateConvolutionTransposedCode(
ConvolutionTransposed3x3::WeightsUploadType::LOCAL_MEM_BY_THREADS) { ConvolutionTransposed3x3::WeightsUploadType::LOCAL_MEM_BY_THREADS) {
c += " int local_id = (int)(get_local_id(1) * 8 + get_local_id(0));\n"; c += " int local_id = (int)(get_local_id(1) * 8 + get_local_id(0));\n";
} }
if (manual_clamp) { const std::string next_x = "SRC_X + " + pixel_stride;
const std::string next_x = "SRC_X + " + pixel_stride; if (!src_desc.SupportsZeroClamp(Axis::WIDTH)) {
c += " bool in_x0 = SRC_X >= 0 && SRC_X < args.src_tensor.Width();\n"; c += " bool in_x0 = SRC_X >= 0 && SRC_X < args.src_tensor.Width();\n";
c += " bool in_x1 = " + next_x + " >= 0 && " + next_x + c += " bool in_x1 = " + next_x + " >= 0 && " + next_x +
" < args.src_tensor.Width();\n"; " < args.src_tensor.Width();\n";
}
if (!src_desc.SupportsZeroClamp(Axis::HEIGHT)) {
c += " bool in_y0 = SRC_Y >= 0 && SRC_Y < args.src_tensor.Height();\n"; c += " bool in_y0 = SRC_Y >= 0 && SRC_Y < args.src_tensor.Height();\n";
c += " bool in_y1 = SRC_Y + 1 >= 0 && SRC_Y + 1 < " c += " bool in_y1 = SRC_Y + 1 >= 0 && SRC_Y + 1 < "
"args.src_tensor.Height();\n"; "args.src_tensor.Height();\n";
if (src_tensor_type == TensorStorageType::BUFFER) { }
c += " int xc0 = clamp(SRC_X, 0, args.src_tensor.Width() - 1);\n"; auto generate_check = [&](int x, int y) {
c += " int xc1 = clamp(" + next_x + std::string check;
", 0, args.src_tensor.Width() - 1);\n"; const std::vector<Axis> axes{Axis::WIDTH, Axis::HEIGHT};
c += " int yc0 = clamp(SRC_Y, 0, args.src_tensor.Height() - 1);\n"; const std::vector<std::string> names{"in_x" + std::to_string(x),
c += " int yc1 = clamp(SRC_Y + 1, 0, args.src_tensor.Height() - 1);\n"; "in_y" + std::to_string(y)};
c += " args.src_tensor.GetAddress(addr_0, xc0, yc0, 0);\n"; for (int i = 0; i < axes.size(); ++i) {
c += " args.src_tensor.GetAddress(addr_1, xc1, yc0, 0);\n"; const auto& axis = axes[i];
c += " args.src_tensor.GetAddress(addr_2, xc0, yc1, 0);\n"; if (src_desc.HasAxis(axis) && !src_desc.SupportsZeroClamp(axis)) {
c += " args.src_tensor.GetAddress(addr_3, xc1, yc1, 0);\n"; if (!check.empty()) {
c += " int dz = args.src_tensor.SliceStride();\n"; check += " && ";
} else { // TensorStorageType::IMAGE_BUFFER }
check += names[i];
}
}
return check;
};
if (src_desc.IsLinear()) {
if (src_desc.ReturnsZeroForNegOneRead()) {
c += " args.src_tensor.GetAddress(addr_0, SRC_X, SRC_Y, 0);\n"; c += " args.src_tensor.GetAddress(addr_0, SRC_X, SRC_Y, 0);\n";
c += " args.src_tensor.GetAddress(addr_1," + next_x + ", SRC_Y, 0);\n"; c += " args.src_tensor.GetAddress(addr_1," + next_x + ", SRC_Y, 0);\n";
c += " args.src_tensor.GetAddress(addr_2, SRC_X, SRC_Y + 1, 0);\n"; c += " args.src_tensor.GetAddress(addr_2, SRC_X, SRC_Y + 1, 0);\n";
...@@ -206,13 +211,24 @@ std::string ConvolutionTransposed3x3::GenerateConvolutionTransposedCode( ...@@ -206,13 +211,24 @@ std::string ConvolutionTransposed3x3::GenerateConvolutionTransposedCode(
"in_y1));\n"; "in_y1));\n";
c += " int dz_3 = select(0, args.src_tensor.SliceStride(), (in_x1 && " c += " int dz_3 = select(0, args.src_tensor.SliceStride(), (in_x1 && "
"in_y1));\n"; "in_y1));\n";
} else {
c += " int xc0 = clamp(SRC_X, 0, args.src_tensor.Width() - 1);\n";
c += " int xc1 = clamp(" + next_x +
", 0, args.src_tensor.Width() - 1);\n";
c += " int yc0 = clamp(SRC_Y, 0, args.src_tensor.Height() - 1);\n";
c += " int yc1 = clamp(SRC_Y + 1, 0, args.src_tensor.Height() - 1);\n";
c += " args.src_tensor.GetAddress(addr_0, xc0, yc0, 0);\n";
c += " args.src_tensor.GetAddress(addr_1, xc1, yc0, 0);\n";
c += " args.src_tensor.GetAddress(addr_2, xc0, yc1, 0);\n";
c += " args.src_tensor.GetAddress(addr_3, xc1, yc1, 0);\n";
c += " int dz = args.src_tensor.SliceStride();\n";
} }
} }
auto read_src = [&](int x, int y) { auto read_src = [&](int x, int y) {
if (manual_clamp) { if (src_desc.IsLinear()) {
const std::string id = std::to_string(y * 2 + x); const std::string id = std::to_string(y * 2 + x);
const std::string addr = "addr_" + std::to_string(y * 2 + x); const std::string addr = "addr_" + std::to_string(y * 2 + x);
if (src_tensor_type == TensorStorageType::IMAGE_BUFFER) { if (src_desc.ReturnsZeroForNegOneRead()) {
return "args.src_tensor.Read(" + addr + "); " + addr + " += dz_" + id + return "args.src_tensor.Read(" + addr + "); " + addr + " += dz_" + id +
";\n"; ";\n";
} else { } else {
...@@ -221,8 +237,13 @@ std::string ConvolutionTransposed3x3::GenerateConvolutionTransposedCode( ...@@ -221,8 +237,13 @@ std::string ConvolutionTransposed3x3::GenerateConvolutionTransposedCode(
addr + " += dz;\n"; addr + " += dz;\n";
} }
} else { } else {
std::string check = generate_check(x, y);
if (!check.empty()) {
check = " * (FLT)(" + check + ")";
}
return "args.src_tensor.Read(SRC_X + " + std::to_string(x) + "*" + return "args.src_tensor.Read(SRC_X + " + std::to_string(x) + "*" +
pixel_stride + ", SRC_Y + " + std::to_string(y) + ", s);\n"; pixel_stride + ", SRC_Y + " + std::to_string(y) + ", s)" + check +
";\n";
} }
}; };
const int padding_x_rem = abs(padding.x) % 2; const int padding_x_rem = abs(padding.x) % 2;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册