提交 ae517dba 作者: Raman Sarokin 提交者: TensorFlower Gardener

LSTM converted to generic GPUOperation.

PiperOrigin-RevId: 328189510
Change-Id: Ic2791b36123d374fa1d3521e66b18dd7b82e5c4a
上级 84258bce
......@@ -24,33 +24,14 @@ limitations under the License.
namespace tflite {
namespace gpu {
namespace cl {
// Constructs the LSTM operation from `definition` and stores the source
// produced by GetLSTMCode() for the given device in `code_`.
LSTM::LSTM(const OperationDef& definition, const DeviceInfo& device_info)
    : GPUOperation(definition) {
  // The generator is handed the `definition_` member rather than the
  // constructor argument.
  auto kernel_source = GetLSTMCode(definition_, device_info);
  code_ = std::move(kernel_source);
}
// Move constructor: forwards `kernel` to the GPUOperation move constructor;
// LSTM itself initializes nothing further here.
LSTM::LSTM(LSTM&& kernel) : GPUOperation(std::move(kernel)) {}
// Move assignment: delegates to GPUOperation's move assignment. Assigning an
// object to itself is a no-op. Returns *this.
LSTM& LSTM::operator=(LSTM&& kernel) {
  if (this == &kernel) {
    return *this;
  }
  GPUOperation::operator=(std::move(kernel));
  return *this;
}
std::string LSTM::GetLSTMCode(const OperationDef& op_def,
const DeviceInfo& device_info) {
AddSrcTensor("intermediate", op_def.src_tensors[0]);
AddSrcTensor("prev_state", op_def.src_tensors[1]);
AddDstTensor("new_state", op_def.dst_tensors[0]);
AddDstTensor("activation", op_def.dst_tensors[1]);
namespace {
std::string GetLSTMCode(const OperationDef& op_def,
const DeviceInfo& device_info) {
std::string c = GetCommonDefines(op_def.precision);
c += "__kernel void main_function(\n";
c += "$0) {\n";
c += " int B = get_global_id(0);\n";
c += " int Z = get_global_id(1);\n";
c += " int Z = get_global_id(2);\n";
c += " if (Z >= args.activation.Slices() || B >= args.activation.Batch()) "
"return;\n";
c += " FLT4 prev_st = args.prev_state.Read(0, 0, Z, B);\n";
......@@ -105,15 +86,18 @@ std::string LSTM::GetLSTMCode(const OperationDef& op_def,
return c;
}
// Returns the dispatch grid derived from dst_[0]: x = Batch(), y = Slices(),
// z = 1.
int3 LSTM::GetGridSize() const {
  const auto& first_dst = dst_[0];
  return int3(first_dst->Batch(), first_dst->Slices(), 1);
}
} // namespace
LSTM CreateLSTM(const OperationDef& definition, const DeviceInfo& device_info) {
return LSTM(definition, device_info);
// Builds the LSTM operation as a plain GPUOperation.
//
// Registers two source tensors ("intermediate", "prev_state") taken from
// definition.src_tensors[0] and [1], and two destination tensors
// ("new_state", "activation") taken from definition.dst_tensors[0] and [1],
// then stores the generated kernel source and the grid mapping on the
// operation. Both tensor lists are indexed without a size check, so the
// caller must supply at least two entries in each.
GPUOperation CreateLSTM(const OperationDef& definition,
const DeviceInfo& device_info) {
GPUOperation op(definition);
// NOTE(review): registration order presumably determines kernel argument
// order; keep these four calls in this sequence unless confirmed otherwise.
op.AddSrcTensor("intermediate", definition.src_tensors[0]);
op.AddSrcTensor("prev_state", definition.src_tensors[1]);
op.AddDstTensor("new_state", definition.dst_tensors[0]);
op.AddDstTensor("activation", definition.dst_tensors[1]);
op.code_ = GetLSTMCode(definition, device_info);
// NOTE(review): judging by the enumerator name, batch maps to grid X and
// slices to grid Z; confirm against the TensorToGrid definition.
op.tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_SToZ;
return op;
}
} // namespace cl
......
......@@ -25,23 +25,8 @@ namespace tflite {
namespace gpu {
namespace cl {
// LSTM operation derived from GPUOperation. Instances are move-only.
class LSTM : public GPUOperation {
 public:
  LSTM(const OperationDef& definition, const DeviceInfo& device_info);

  // Copying is disabled; only move construction and move assignment exist.
  LSTM(const LSTM&) = delete;
  LSTM& operator=(const LSTM&) = delete;
  LSTM(LSTM&& kernel);
  LSTM& operator=(LSTM&& kernel);

  // Overrides the base-class grid size query.
  int3 GetGridSize() const override;

 private:
  // Produces the source string for this operation.
  std::string GetLSTMCode(const OperationDef& op_def,
                          const DeviceInfo& device_info);
};
LSTM CreateLSTM(const OperationDef& definition, const DeviceInfo& device_info);
// Creates the LSTM operation for `definition` on the device described by
// `device_info` and returns it by value as a generic GPUOperation.
GPUOperation CreateLSTM(const OperationDef& definition,
const DeviceInfo& device_info);
} // namespace cl
} // namespace gpu
......
......@@ -67,7 +67,7 @@ TEST_F(OpenCLOperationTest, LSTM) {
op_def.dst_tensors.push_back({data_type, storage, Layout::BHWC});
TensorFloat32 new_state;
TensorFloat32 new_activ;
LSTM operation = CreateLSTM(op_def, env_.GetDevicePtr()->info_);
GPUOperation operation = CreateLSTM(op_def, env_.GetDevicePtr()->info_);
ASSERT_OK(ExecuteGPUOperation(
{src_tensor, prev_state}, creation_context_, &operation,
{BHWC(1, 1, 1, 4), BHWC(1, 1, 1, 4)}, {&new_state, &new_activ}));
......
......@@ -246,7 +246,7 @@ absl::Status GPUOperationFromNode(const DeviceInfo& device_info,
return absl::OkStatus();
}
case OperationType::LSTM: {
SelectLSTM(op_def, device_info, gpu_op);
*gpu_op = SelectLSTM(op_def, device_info);
return absl::OkStatus();
}
case OperationType::MAX_UNPOOLING_2D: {
......
......@@ -45,10 +45,9 @@ namespace tflite {
namespace gpu {
namespace cl {
void SelectLSTM(const OperationDef& op_def, const DeviceInfo& device_info,
std::unique_ptr<GPUOperation>* ptr) {
LSTM operation = CreateLSTM(op_def, device_info);
*ptr = absl::make_unique<LSTM>(std::move(operation));
std::unique_ptr<GPUOperation> SelectLSTM(const OperationDef& op_def,
const DeviceInfo& device_info) {
return absl::make_unique<GPUOperation>(CreateLSTM(op_def, device_info));
}
std::unique_ptr<GPUOperation> SelectReLU(const ReLUAttributes& attr,
......
......@@ -28,8 +28,8 @@ namespace tflite {
namespace gpu {
namespace cl {
void SelectLSTM(const OperationDef& op_def, const DeviceInfo& device_info,
std::unique_ptr<GPUOperation>* ptr);
// Returns the LSTM operation for `op_def` / `device_info`, owned by a
// std::unique_ptr<GPUOperation>.
std::unique_ptr<GPUOperation> SelectLSTM(const OperationDef& op_def,
const DeviceInfo& device_info);
std::unique_ptr<GPUOperation> SelectReLU(const ReLUAttributes& attr,
const OperationDef& op_def);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册