提交 1c12a84e 编写于 作者: R Raman Sarokin 提交者: TensorFlower Gardener

LSTM converted to new style.

PiperOrigin-RevId: 317746681
Change-Id: I8471f352e2f842d40d71c2733564ab346d4eb69c
上级 13deeb09
......@@ -26,39 +26,34 @@ namespace gpu {
namespace cl {
namespace {
std::string GetLSTMCode(const OperationDef& op_def, const CLDevice& device) {
const WHSBPoint state_size{"1", "1", "state_size.z", "state_size.w"};
const WHSBPoint src_size{"1", "1", "src_size.z", "src_size.w"};
TensorCodeGenerator intermediate("src_data", src_size, op_def.src_tensors[0]);
TensorCodeGenerator prev_state("prev_state", state_size,
op_def.src_tensors[1]);
TensorCodeGenerator activation("dst_data", state_size, op_def.dst_tensors[0]);
TensorCodeGenerator new_state("new_state", state_size, op_def.dst_tensors[1]);
std::string GetLSTMCode(const OperationDef& op_def, const CLDevice& device,
Arguments* args) {
args->AddObjectRef(
"intermediate", AccessType::READ,
absl::make_unique<TensorDescriptor>(op_def.src_tensors[0]));
args->AddObjectRef(
"prev_state", AccessType::READ,
absl::make_unique<TensorDescriptor>(op_def.src_tensors[1]));
args->AddObjectRef(
"new_state", AccessType::WRITE,
absl::make_unique<TensorDescriptor>(op_def.dst_tensors[0]));
args->AddObjectRef(
"activation", AccessType::WRITE,
absl::make_unique<TensorDescriptor>(op_def.dst_tensors[1]));
std::string c = GetCommonDefines(op_def.precision);
c += "__kernel void main_function(\n";
c += intermediate.GetDeclaration(AccessType::READ) + ",\n";
c += prev_state.GetDeclaration(AccessType::READ) + ",\n";
c += new_state.GetDeclaration(AccessType::WRITE) + ",\n";
c += activation.GetDeclaration(AccessType::WRITE) + ",\n";
c += " int4 src_size, \n";
c += " int4 state_size, \n";
c += " int BATCH_SIZE \n";
c += ") {\n";
c += "$0) {\n";
c += " int B = get_global_id(0);\n";
c += " int Z = get_global_id(1);\n";
c += " if (Z >= state_size.z || B >= state_size.w) return;\n";
c += " FLT4 prev_st = " + prev_state.ReadWHSB("0", "0", "Z", "B") + ";\n";
c += " FLT4 r0 = " + intermediate.ReadWHSB("0", "0", "Z", "B") + ";\n";
c += " FLT4 r1 = " +
intermediate.ReadWHSB("0", "0", "Z + state_size.z", "B") + ";\n";
c += " FLT4 r2 = " +
intermediate.ReadWHSB("0", "0", "Z + state_size.z * 2", "B") + ";\n";
c += " FLT4 r3 = " +
intermediate.ReadWHSB("0", "0", "Z + state_size.z * 3", "B") + ";\n";
c += " if (Z >= args.activation.Slices() || B >= args.activation.Batch()) "
"return;\n";
c += " FLT4 prev_st = args.prev_state.Read(0, 0, Z, B);\n";
c += " FLT4 r0 = args.intermediate.Read(0, 0, Z, B);\n";
c += " int state_stride = args.activation.Slices();\n";
c += " FLT4 r1 = args.intermediate.Read(0, 0, Z + state_stride, B);\n";
c += " FLT4 r2 = args.intermediate.Read(0, 0, Z + state_stride * 2, B);\n";
c += " FLT4 r3 = args.intermediate.Read(0, 0, Z + state_stride * 3, B);\n";
if (op_def.precision != CalculationsPrecision::F32 && device.IsAdreno()) {
c += " FLT4 input_gate;\n";
c += " FLT4 new_input;\n";
......@@ -97,9 +92,9 @@ std::string GetLSTMCode(const OperationDef& op_def, const CLDevice& device) {
"* r3));\n";
}
c += " FLT4 new_st = input_gate * new_input + forget_gate * prev_st;\n";
c += " FLT4 activation = output_gate * tanh(new_st);\n";
c += " " + activation.WriteWHSB("activation", "0", "0", "Z", "B");
c += " " + new_state.WriteWHSB("new_st", "0", "0", "Z", "B");
c += " FLT4 act_value = output_gate * tanh(new_st);\n";
c += " args.activation.Write(act_value, 0, 0, Z, B);\n";
c += " args.new_state.Write(new_st, 0, 0, Z, B);\n";
c += "}\n";
return c;
}
......@@ -122,22 +117,20 @@ LSTM& LSTM::operator=(LSTM&& kernel) {
}
// Generates the OpenCL source for the LSTM kernel and obtains the compiled
// kernel (entry point "main_function") from the creation context's cache,
// storing it in kernel_. Returns the status of the failing step, if any.
//
// NOTE(review): this listing is a diff rendered without +/- markers, so the
// pre-change and post-change statements appear back to back and `code` looks
// like it is declared twice. Confirm against the repository which lines are
// live before editing.
absl::Status LSTM::Compile(const CreationContext& creation_context) {
// Appears to be the pre-change line: GetLSTMCode called without the
// Arguments* parameter.
const auto code = GetLSTMCode(definition_, *creation_context.device);
// Appears to be the post-change replacement: GetLSTMCode additionally
// registers the four tensor object refs on args_. `code` is non-const here
// because it is passed by pointer to TransformToCLCode below.
std::string code = GetLSTMCode(definition_, *creation_context.device, &args_);
// Presumably rewrites the `args.*` accessors and the `$0` marker emitted by
// GetLSTMCode into concrete OpenCL — TODO confirm against Arguments.
RETURN_IF_ERROR(
args_.TransformToCLCode(creation_context.device->GetInfo(), {}, &code));
return creation_context.cache->GetOrCreateCLKernel(
code, "main_function", *creation_context.context,
*creation_context.device, &kernel_);
}
// Binds the op's source tensors (src_[0], src_[1]) and destination tensors
// (dst_[0], dst_[1]) to the kernel's arguments.
//
// NOTE(review): this listing is a diff rendered without +/- markers. The
// statements up to `return absl::OkStatus();` appear to be the pre-change
// body and the statements after it the post-change body; read literally, the
// second group would be unreachable. Confirm against the repository.
absl::Status LSTM::BindArguments() {
// --- Appears to be the pre-change body: positional binding. ---
// The order follows the explicit kernel parameter list of the old generated
// source: two read tensors, two written tensors, src_size, state_size,
// BATCH_SIZE.
kernel_.ResetBindingCounter();
RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[1]->GetMemoryPtr()));
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[1]->GetMemoryPtrForWriting()));
RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWHSB()));
RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWHSB()));
RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->Batch()));
return absl::OkStatus();
// --- Appears to be the post-change body: binding by name. ---
// The names are the ones registered with AddObjectRef in GetLSTMCode:
// "new_state" is paired with dst_[0] and "activation" with dst_[1].
RETURN_IF_ERROR(args_.SetObjectRef("intermediate", src_[0]));
RETURN_IF_ERROR(args_.SetObjectRef("prev_state", src_[1]));
RETURN_IF_ERROR(args_.SetObjectRef("new_state", dst_[0]));
RETURN_IF_ERROR(args_.SetObjectRef("activation", dst_[1]));
// Hands the underlying kernel handle to args_ for the actual binding and
// returns its status.
return args_.Bind(kernel_.kernel());
}
int3 LSTM::GetGridSize() const {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册