提交 eb9a623a 编写于 作者: O Oleg Rybakov 提交者: A. Unique TensorFlower

Update get_config with parameters used for initialization.

PiperOrigin-RevId: 485624879
上级 b3db5efe
......@@ -826,6 +826,12 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
"is_short_seq": self._is_short_seq,
"begin_kernel": self._begin_kernel,
"scale": self._scale,
"scale_by_length": self._scale_by_length,
"use_causal_windowed": self.use_causal_windowed,
"causal_chunk_length": self.causal_chunk_length,
"causal_window_length": self.causal_window_length,
"causal_window_decay": self.causal_window_decay,
"causal_padding": self.causal_padding,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册