Unverified commit a4bb38cb, authored by xiongkun and committed by GitHub

[EinsumOp] Make EinsumOp support bfloat16. (#43085)

* make einsum_v2 the default and add a new flag: FLAG_einsum_opt=1|0

* make EinsumOp support bf16

* add unittest for BF16

* add condition for test_BF16

* fix bugs

* fix
Parent 0ae8a2d6
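As context for the diff below, the capability this commit enables can be exercised from the Python API roughly as follows. This is a minimal, illustrative sketch that mirrors the TestBF16 unit test added in this commit; it assumes a CUDA 11+ build of PaddlePaddle (bfloat16 kernels are only registered there) and is not itself part of the commit.

import numpy as np
import paddle

# Guard mirrors the added unit test: run only on a CUDA build whose major
# version is at least 11, where the bfloat16 kernels are available.
if paddle.is_compiled_with_cuda() and int(paddle.version.cuda().split('.')[0]) >= 11:
    A = paddle.to_tensor(np.array([1.0, 2.0])).astype(paddle.bfloat16).cuda()
    B = paddle.to_tensor(np.array([2.0, 3.0])).astype(paddle.bfloat16).cuda()
    C = paddle.einsum('i,i->', A, B)  # contracts to a scalar: 1*2 + 2*3 = 8
    print(C.item())  # 8.0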
@@ -11,6 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
+#include "paddle/phi/common/bfloat16.h"
 #include "paddle/phi/common/float16.h"
 #include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
@@ -73,6 +74,7 @@ struct EigenBroadcastGrad<Eigen::DefaultDevice, T, Rank> {
 template struct FUNCTOR<Eigen::DefaultDevice, T, 6>
 INSTANTIATION(EigenBroadcast, bool);
 INSTANTIATION(EigenBroadcast, dtype::float16);
+INSTANTIATION(EigenBroadcast, dtype::bfloat16);
 INSTANTIATION(EigenBroadcast, float);
 INSTANTIATION(EigenBroadcast, double);
 INSTANTIATION(EigenBroadcast, int);
...
@@ -11,6 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
+#include "paddle/phi/common/bfloat16.h"
 #include "paddle/phi/common/float16.h"
 #include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
@@ -73,6 +74,7 @@ struct EigenBroadcastGrad<Eigen::GpuDevice, T, Rank> {
 template struct FUNCTOR<Eigen::GpuDevice, T, 6>
 INSTANTIATION(EigenBroadcast, bool);
 INSTANTIATION(EigenBroadcast, dtype::float16);
+INSTANTIATION(EigenBroadcast, dtype::bfloat16);
 INSTANTIATION(EigenBroadcast, float);
 INSTANTIATION(EigenBroadcast, double);
 INSTANTIATION(EigenBroadcast, int);
...
@@ -24,4 +24,5 @@ PD_REGISTER_KERNEL(einsum_grad,
                    phi::EinsumGradKernel,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
@@ -27,4 +27,5 @@ PD_REGISTER_KERNEL(tile,
                    double,
                    int,
                    int64_t,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
@@ -197,20 +197,24 @@ void EinsumGradKernel(const Context& dev_ctx,
     // release the cache tensor dTC to save memory right now. they are useless
     // now.
     cache.clear();
-    *(x_grad[0]) = PerformTileAndReduction<T, Context>(dev_ctx,
-                                                       labeltype,
-                                                       labelshape,
-                                                       broadcast_dims,
-                                                       ellipsis_dims[0],
-                                                       ops[0],
-                                                       dA);
-    *(x_grad[1]) = PerformTileAndReduction<T, Context>(dev_ctx,
-                                                       labeltype,
-                                                       labelshape,
-                                                       broadcast_dims,
-                                                       ellipsis_dims[1],
-                                                       ops[1],
-                                                       dB);
+    if (x_grad[0]) {
+      *(x_grad[0]) = PerformTileAndReduction<T, Context>(dev_ctx,
+                                                         labeltype,
+                                                         labelshape,
+                                                         broadcast_dims,
+                                                         ellipsis_dims[0],
+                                                         ops[0],
+                                                         dA);
+    }
+    if (x_grad[1]) {
+      *(x_grad[1]) = PerformTileAndReduction<T, Context>(dev_ctx,
+                                                         labeltype,
+                                                         labelshape,
+                                                         broadcast_dims,
+                                                         ellipsis_dims[1],
+                                                         ops[1],
+                                                         dB);
+    }
   }
 }
 }  // namespace phi
@@ -478,5 +478,23 @@ class TestStaticGraphShape(unittest.TestCase):
         self.assertEqual(C.shape, (-1, 384))
 
 
+class TestBF16(unittest.TestCase):
+    """
+    EinsumOp supports the bfloat16 type; this unittest checks its correctness.
+    """
+
+    def test_shape(self):
+        cuda_major = paddle.version.cuda().split('.')[0].strip()
+        if paddle.is_compiled_with_cuda() and int(cuda_major) >= 11:
+            # MatmulKernel supports bfloat16 only when the CUDA major
+            # version is >= 11.
+            A = paddle.to_tensor(np.array([1.0, 2.0])).astype(paddle.bfloat16)
+            A = A.cuda()
+            B = paddle.to_tensor(np.array([2.0, 3.0])).astype(paddle.bfloat16)
+            B = B.cuda()
+            C = paddle.einsum('i,i->', A, B)
+            self.assertEqual(C.item(), 8.0)
+
+
 if __name__ == "__main__":
     unittest.main()