提交 b37b0d00 编写于 作者: R Raman Sarokin 提交者: TensorFlower Gardener

DepthwiseConvolution generation changed to use storage type properties instead...

DepthwiseConvolution generation changed to use storage type properties instead of specific storage types.

PiperOrigin-RevId: 339964553
Change-Id: I2cded9c306a40b136002c08e610daca2d75e1758
上级 50ec85cd
...@@ -86,13 +86,8 @@ std::string GenerateDepthwiseConvolutionCode( ...@@ -86,13 +86,8 @@ std::string GenerateDepthwiseConvolutionCode(
} }
op->AddDstTensor("dst_tensor", dst_desc); op->AddDstTensor("dst_tensor", dst_desc);
const auto src_tensor_type = op_def.src_tensors[0].storage_type;
std::string c = GetCommonDefines(op_def.precision); std::string c = GetCommonDefines(op_def.precision);
const bool manual_clamp = src_tensor_type == TensorStorageType::BUFFER ||
src_tensor_type == TensorStorageType::IMAGE_BUFFER;
c += "__kernel void main_function(\n"; c += "__kernel void main_function(\n";
c += "$0) {\n"; c += "$0) {\n";
c += " int X = get_global_id(0);\n"; c += " int X = get_global_id(0);\n";
...@@ -142,84 +137,91 @@ std::string GenerateDepthwiseConvolutionCode( ...@@ -142,84 +137,91 @@ std::string GenerateDepthwiseConvolutionCode(
std::string kernel_size_z = std::string kernel_size_z =
dynamic_weights ? "args.weights.Depth()" : "args.kernel_size_z"; dynamic_weights ? "args.weights.Depth()" : "args.kernel_size_z";
std::string flat_coords = "x_c, y_c"; auto generate_check = [&]() {
if (manual_clamp) { std::string check;
std::string check = "!outside_x && !outside_y"; const std::vector<Axis> axes{Axis::WIDTH, Axis::HEIGHT, Axis::DEPTH};
if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) { const std::vector<std::string> names{"outside_x", "outside_y", "outside_z"};
check += " && !outside_z"; for (int i = 0; i < axes.size(); ++i) {
flat_coords += ", z_c"; const auto& axis = axes[i];
c += " for (int kz = 0; kz < " + kernel_size_z + "; ++kz) {\n"; if (src_desc.HasAxis(axis) && !src_desc.SupportsZeroClamp(axis)) {
c += " int z_c = z_offseted + kz * args.dilation_z;\n"; if (!check.empty()) {
c += " bool outside_z = z_c < 0 || z_c >= args.src_tensor.Depth();\n"; check += " && ";
} }
c += " for (int ky = 0; ky < " + kernel_size_y + "; ++ky) {\n"; check += "!" + names[i];
c += " int y_c = y_offseted + ky * args.dilation_y;\n";
c += " bool outside_y = y_c < 0 || y_c >= args.src_tensor.Height();\n";
c += " for (int kx = 0; kx < " + kernel_size_x + "; ++kx) {\n";
const std::string dilation_x =
op_def.IsBatchSupported() ? "args.dilation_x * args.src_tensor.Batch()"
: "args.dilation_x";
c += " int x_c = x_offseted + kx * " + dilation_x + ";\n";
c += " bool outside_x = x_c < 0 || x_c >= args.src_tensor.Width();\n";
c += " if (" + check + ") {\n";
if (dynamic_weights) {
c += " FLT4 f = args.weights.Read(kx, ky, S);\n";
} else {
if (weights_are_buffer) {
c += " FLT4 f = args.weights.Read(fx_c);\n";
} else {
c += " FLT4 f = args.weights.Read(fx_c, S);\n";
} }
} }
c += GetSrcValue(channel_multiplier, flat_coords); return check;
c += " r += TO_ACCUM_TYPE(src_final * f);\n"; };
c += " };\n"; auto generate_coords = [&]() {
if (!dynamic_weights) { std::string check;
c += " fx_c++;\n"; const std::vector<Axis> axes{Axis::WIDTH, Axis::HEIGHT, Axis::DEPTH};
} const std::vector<std::string> names{"x_c", "y_c", "z_c"};
c += " }\n"; for (int i = 0; i < axes.size(); ++i) {
c += " }\n"; const auto& axis = axes[i];
if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) { if (src_desc.HasAxis(axis)) {
c += " }\n"; if (!check.empty()) {
} check += ", ";
} else { // Texture types with ZERO clamping }
if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) { check += names[i];
flat_coords += ", z_c";
c += " for (int kz = 0; kz < " + kernel_size_z + "; ++kz) {\n";
c += " int z_c = z_offseted + kz * args.dilation_z;\n";
if (src_tensor_type !=
TensorStorageType::TEXTURE_3D) { // Only TEXTURE_3D supports clamping
// in DEPTH dimension
c += " if (z_c < 0 || z_c >= args.src_tensor.Depth()) {\n";
c += " fx_c += args.kernel_size_y * args.kernel_size_x;\n";
c += " continue;\n";
c += " }\n";
} }
} }
return check;
};
const std::string check = generate_check();
const std::string coords = generate_coords();
if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) {
c += " for (int kz = 0; kz < " + kernel_size_z + "; ++kz) {\n";
c += " int z_c = z_offseted + kz * args.dilation_z;\n";
if (!src_desc.SupportsZeroClamp(Axis::DEPTH)) {
c += " bool outside_z = z_c < 0 || z_c >= args.src_tensor.Depth();\n";
}
}
if (op_def.dst_tensors[0].HasAxis(Axis::HEIGHT)) {
c += " for (int ky = 0; ky < " + kernel_size_y + "; ++ky) {\n"; c += " for (int ky = 0; ky < " + kernel_size_y + "; ++ky) {\n";
c += " int y_c = y_offseted + ky * args.dilation_y;\n"; c += " int y_c = y_offseted + ky * args.dilation_y;\n";
c += " for (int kx = 0; kx < " + kernel_size_x + "; ++kx) {\n"; if (!src_desc.SupportsZeroClamp(Axis::HEIGHT)) {
c += " bool outside_y = y_c < 0 || y_c >= args.src_tensor.Height();\n";
}
}
if (op_def.dst_tensors[0].HasAxis(Axis::WIDTH)) {
c += " for (int kx = 0; kx < " + kernel_size_x + "; ++kx) {\n";
const std::string dilation_x = const std::string dilation_x =
op_def.IsBatchSupported() ? "args.dilation_x * args.src_tensor.Batch()" op_def.IsBatchSupported() ? "args.dilation_x * args.src_tensor.Batch()"
: "args.dilation_x"; : "args.dilation_x";
c += " int x_c = x_offseted + kx * " + dilation_x + ";\n"; c += " int x_c = x_offseted + kx * " + dilation_x + ";\n";
c += GetSrcValue(channel_multiplier, flat_coords); if (!src_desc.SupportsZeroClamp(Axis::WIDTH)) {
if (dynamic_weights) { c += " bool outside_x = x_c < 0 || x_c >= args.src_tensor.Width();\n";
c += " FLT4 f = args.weights.Read(kx, ky, S);\n"; }
}
if (!check.empty()) {
c += " if (" + check + ") {\n";
}
if (dynamic_weights) {
c += " FLT4 f = args.weights.Read(kx, ky, S);\n";
} else {
if (weights_are_buffer) {
c += " FLT4 f = args.weights.Read(fx_c);\n";
} else { } else {
if (weights_are_buffer) { c += " FLT4 f = args.weights.Read(fx_c, S);\n";
c += " FLT4 f = args.weights.Read(fx_c);\n";
} else {
c += " FLT4 f = args.weights.Read(fx_c, S);\n";
}
c += " fx_c++;\n";
} }
c += " r += TO_ACCUM_TYPE(src_final * f);\n"; }
c += GetSrcValue(channel_multiplier, coords);
c += " r += TO_ACCUM_TYPE(src_final * f);\n";
if (!check.empty()) {
c += " }\n"; c += " }\n";
}
if (!dynamic_weights) {
c += " fx_c++;\n";
}
if (op_def.dst_tensors[0].HasAxis(Axis::WIDTH)) {
c += " }\n";
}
if (op_def.dst_tensors[0].HasAxis(Axis::HEIGHT)) {
c += " }\n";
}
if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) {
c += " }\n"; c += " }\n";
if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) {
c += " }\n";
}
} }
c += " FLT4 res0 = TO_FLT4(r) + args.biases.Read(S);\n"; c += " FLT4 res0 = TO_FLT4(r) + args.biases.Read(S);\n";
if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) { if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) {
...@@ -228,7 +230,6 @@ std::string GenerateDepthwiseConvolutionCode( ...@@ -228,7 +230,6 @@ std::string GenerateDepthwiseConvolutionCode(
c += " args.dst_tensor.Write(res0, X, Y, S);\n"; c += " args.dst_tensor.Write(res0, X, Y, S);\n";
} }
c += "}\n"; c += "}\n";
return c; return c;
} }
} // namespace } // namespace
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册