提交 ae517dba 编写于 作者: Raman Sarokin 提交者: TensorFlower Gardener

LSTM converted to generic GPUOperation.

PiperOrigin-RevId: 328189510
Change-Id: Ic2791b36123d374fa1d3521e66b18dd7b82e5c4a
上级 84258bce
...@@ -24,33 +24,14 @@ limitations under the License. ...@@ -24,33 +24,14 @@ limitations under the License.
namespace tflite { namespace tflite {
namespace gpu { namespace gpu {
namespace cl { namespace cl {
namespace {
LSTM::LSTM(const OperationDef& definition, const DeviceInfo& device_info) std::string GetLSTMCode(const OperationDef& op_def,
: GPUOperation(definition) { const DeviceInfo& device_info) {
code_ = GetLSTMCode(definition_, device_info);
}
LSTM::LSTM(LSTM&& kernel) : GPUOperation(std::move(kernel)) {}
LSTM& LSTM::operator=(LSTM&& kernel) {
if (this != &kernel) {
GPUOperation::operator=(std::move(kernel));
}
return *this;
}
std::string LSTM::GetLSTMCode(const OperationDef& op_def,
const DeviceInfo& device_info) {
AddSrcTensor("intermediate", op_def.src_tensors[0]);
AddSrcTensor("prev_state", op_def.src_tensors[1]);
AddDstTensor("new_state", op_def.dst_tensors[0]);
AddDstTensor("activation", op_def.dst_tensors[1]);
std::string c = GetCommonDefines(op_def.precision); std::string c = GetCommonDefines(op_def.precision);
c += "__kernel void main_function(\n"; c += "__kernel void main_function(\n";
c += "$0) {\n"; c += "$0) {\n";
c += " int B = get_global_id(0);\n"; c += " int B = get_global_id(0);\n";
c += " int Z = get_global_id(1);\n"; c += " int Z = get_global_id(2);\n";
c += " if (Z >= args.activation.Slices() || B >= args.activation.Batch()) " c += " if (Z >= args.activation.Slices() || B >= args.activation.Batch()) "
"return;\n"; "return;\n";
c += " FLT4 prev_st = args.prev_state.Read(0, 0, Z, B);\n"; c += " FLT4 prev_st = args.prev_state.Read(0, 0, Z, B);\n";
...@@ -105,15 +86,18 @@ std::string LSTM::GetLSTMCode(const OperationDef& op_def, ...@@ -105,15 +86,18 @@ std::string LSTM::GetLSTMCode(const OperationDef& op_def,
return c; return c;
} }
int3 LSTM::GetGridSize() const { } // namespace
const int grid_x = dst_[0]->Batch();
const int grid_y = dst_[0]->Slices();
const int grid_z = 1;
return int3(grid_x, grid_y, grid_z);
}
LSTM CreateLSTM(const OperationDef& definition, const DeviceInfo& device_info) { GPUOperation CreateLSTM(const OperationDef& definition,
return LSTM(definition, device_info); const DeviceInfo& device_info) {
GPUOperation op(definition);
op.AddSrcTensor("intermediate", definition.src_tensors[0]);
op.AddSrcTensor("prev_state", definition.src_tensors[1]);
op.AddDstTensor("new_state", definition.dst_tensors[0]);
op.AddDstTensor("activation", definition.dst_tensors[1]);
op.code_ = GetLSTMCode(definition, device_info);
op.tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_SToZ;
return op;
} }
} // namespace cl } // namespace cl
......
...@@ -25,23 +25,8 @@ namespace tflite { ...@@ -25,23 +25,8 @@ namespace tflite {
namespace gpu { namespace gpu {
namespace cl { namespace cl {
class LSTM : public GPUOperation { GPUOperation CreateLSTM(const OperationDef& definition,
public: const DeviceInfo& device_info);
LSTM(const OperationDef& definition, const DeviceInfo& device_info);
int3 GetGridSize() const override;
// Move only
LSTM(LSTM&& kernel);
LSTM& operator=(LSTM&& kernel);
LSTM(const LSTM&) = delete;
LSTM& operator=(const LSTM&) = delete;
private:
std::string GetLSTMCode(const OperationDef& op_def,
const DeviceInfo& device_info);
};
LSTM CreateLSTM(const OperationDef& definition, const DeviceInfo& device_info);
} // namespace cl } // namespace cl
} // namespace gpu } // namespace gpu
......
...@@ -67,7 +67,7 @@ TEST_F(OpenCLOperationTest, LSTM) { ...@@ -67,7 +67,7 @@ TEST_F(OpenCLOperationTest, LSTM) {
op_def.dst_tensors.push_back({data_type, storage, Layout::BHWC}); op_def.dst_tensors.push_back({data_type, storage, Layout::BHWC});
TensorFloat32 new_state; TensorFloat32 new_state;
TensorFloat32 new_activ; TensorFloat32 new_activ;
LSTM operation = CreateLSTM(op_def, env_.GetDevicePtr()->info_); GPUOperation operation = CreateLSTM(op_def, env_.GetDevicePtr()->info_);
ASSERT_OK(ExecuteGPUOperation( ASSERT_OK(ExecuteGPUOperation(
{src_tensor, prev_state}, creation_context_, &operation, {src_tensor, prev_state}, creation_context_, &operation,
{BHWC(1, 1, 1, 4), BHWC(1, 1, 1, 4)}, {&new_state, &new_activ})); {BHWC(1, 1, 1, 4), BHWC(1, 1, 1, 4)}, {&new_state, &new_activ}));
......
...@@ -246,7 +246,7 @@ absl::Status GPUOperationFromNode(const DeviceInfo& device_info, ...@@ -246,7 +246,7 @@ absl::Status GPUOperationFromNode(const DeviceInfo& device_info,
return absl::OkStatus(); return absl::OkStatus();
} }
case OperationType::LSTM: { case OperationType::LSTM: {
SelectLSTM(op_def, device_info, gpu_op); *gpu_op = SelectLSTM(op_def, device_info);
return absl::OkStatus(); return absl::OkStatus();
} }
case OperationType::MAX_UNPOOLING_2D: { case OperationType::MAX_UNPOOLING_2D: {
......
...@@ -45,10 +45,9 @@ namespace tflite { ...@@ -45,10 +45,9 @@ namespace tflite {
namespace gpu { namespace gpu {
namespace cl { namespace cl {
void SelectLSTM(const OperationDef& op_def, const DeviceInfo& device_info, std::unique_ptr<GPUOperation> SelectLSTM(const OperationDef& op_def,
std::unique_ptr<GPUOperation>* ptr) { const DeviceInfo& device_info) {
LSTM operation = CreateLSTM(op_def, device_info); return absl::make_unique<GPUOperation>(CreateLSTM(op_def, device_info));
*ptr = absl::make_unique<LSTM>(std::move(operation));
} }
std::unique_ptr<GPUOperation> SelectReLU(const ReLUAttributes& attr, std::unique_ptr<GPUOperation> SelectReLU(const ReLUAttributes& attr,
......
...@@ -28,8 +28,8 @@ namespace tflite { ...@@ -28,8 +28,8 @@ namespace tflite {
namespace gpu { namespace gpu {
namespace cl { namespace cl {
void SelectLSTM(const OperationDef& op_def, const DeviceInfo& device_info, std::unique_ptr<GPUOperation> SelectLSTM(const OperationDef& op_def,
std::unique_ptr<GPUOperation>* ptr); const DeviceInfo& device_info);
std::unique_ptr<GPUOperation> SelectReLU(const ReLUAttributes& attr, std::unique_ptr<GPUOperation> SelectReLU(const ReLUAttributes& attr,
const OperationDef& op_def); const OperationDef& op_def);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册