未验证 提交 170842cb 编写于 作者: Y Yibing Liu 提交者: GitHub

Some improvements to support bert mixed precision training (#15585)

* Some improvements to support bert mixed precision training

test=develop

* Revert the cast in layer_norm

test=develop
上级 16d54f7f
...@@ -114,4 +114,5 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -114,4 +114,5 @@ REGISTER_OP_CUDA_KERNEL(
ops::GPUDropoutKernel<plat::CUDADeviceContext, double>); ops::GPUDropoutKernel<plat::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
dropout_grad, ops::DropoutGradKernel<plat::CUDADeviceContext, float>, dropout_grad, ops::DropoutGradKernel<plat::CUDADeviceContext, float>,
ops::DropoutGradKernel<plat::CUDADeviceContext, plat::float16>,
ops::DropoutGradKernel<plat::CUDADeviceContext, double>); ops::DropoutGradKernel<plat::CUDADeviceContext, double>);
...@@ -60,11 +60,14 @@ class GatherGradOpCUDAKernel : public framework::OpKernel<T> { ...@@ -60,11 +60,14 @@ class GatherGradOpCUDAKernel : public framework::OpKernel<T> {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(gather, ops::GatherOpCUDAKernel<float>, REGISTER_OP_CUDA_KERNEL(gather, ops::GatherOpCUDAKernel<float>,
ops::GatherOpCUDAKernel<double>, ops::GatherOpCUDAKernel<double>,
ops::GatherOpCUDAKernel<int64_t>, ops::GatherOpCUDAKernel<int64_t>,
ops::GatherOpCUDAKernel<int>); ops::GatherOpCUDAKernel<int>,
ops::GatherOpCUDAKernel<plat::float16>);
REGISTER_OP_CUDA_KERNEL(gather_grad, ops::GatherGradOpCUDAKernel<float>, REGISTER_OP_CUDA_KERNEL(gather_grad, ops::GatherGradOpCUDAKernel<float>,
ops::GatherGradOpCUDAKernel<double>, ops::GatherGradOpCUDAKernel<double>,
ops::GatherGradOpCUDAKernel<int64_t>, ops::GatherGradOpCUDAKernel<int64_t>,
ops::GatherGradOpCUDAKernel<int>); ops::GatherGradOpCUDAKernel<int>,
ops::GatherGradOpCUDAKernel<plat::float16>);
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/fluid/operators/lookup_table_op.h" #include "paddle/fluid/operators/lookup_table_op.h"
#include "paddle/fluid/platform/assert.h" #include "paddle/fluid/platform/assert.h"
#include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -193,8 +194,11 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> { ...@@ -193,8 +194,11 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(lookup_table, ops::LookupTableCUDAKernel<float>, REGISTER_OP_CUDA_KERNEL(lookup_table, ops::LookupTableCUDAKernel<float>,
ops::LookupTableCUDAKernel<double>); ops::LookupTableCUDAKernel<double>,
ops::LookupTableCUDAKernel<plat::float16>);
REGISTER_OP_CUDA_KERNEL(lookup_table_grad, REGISTER_OP_CUDA_KERNEL(lookup_table_grad,
ops::LookupTableGradCUDAKernel<float>, ops::LookupTableGradCUDAKernel<float>,
ops::LookupTableGradCUDAKernel<double>); ops::LookupTableGradCUDAKernel<double>,
ops::LookupTableGradCUDAKernel<plat::float16>);
...@@ -330,6 +330,7 @@ class Reshape2GradOp : public framework::OperatorWithKernel { ...@@ -330,6 +330,7 @@ class Reshape2GradOp : public framework::OperatorWithKernel {
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OPERATOR(reshape, ops::ReshapeOp, ops::ReshapeOpMaker, REGISTER_OPERATOR(reshape, ops::ReshapeOp, ops::ReshapeOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>); paddle::framework::DefaultGradOpDescMaker<true>);
...@@ -356,16 +357,20 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2_grad, float, ops::ReshapeGradKernel, ...@@ -356,16 +357,20 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2_grad, float, ops::ReshapeGradKernel,
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double, REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double,
ops::ReshapeKernel, int, ops::ReshapeKernel, ops::ReshapeKernel, int, ops::ReshapeKernel,
int64_t, ops::ReshapeKernel); int64_t, ops::ReshapeKernel, plat::float16,
ops::ReshapeKernel);
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel, REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
double, ops::ReshapeGradKernel, int, double, ops::ReshapeGradKernel, int,
ops::ReshapeGradKernel, int64_t, ops::ReshapeGradKernel, int64_t,
ops::ReshapeGradKernel, plat::float16,
ops::ReshapeGradKernel); ops::ReshapeGradKernel);
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double, REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double,
ops::ReshapeKernel, int, ops::ReshapeKernel, ops::ReshapeKernel, int, ops::ReshapeKernel,
int64_t, ops::ReshapeKernel); int64_t, ops::ReshapeKernel, plat::float16,
ops::ReshapeKernel);
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2_grad, float, ops::ReshapeGradKernel, REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2_grad, float, ops::ReshapeGradKernel,
double, ops::ReshapeGradKernel, int, double, ops::ReshapeGradKernel, int,
ops::ReshapeGradKernel, int64_t, ops::ReshapeGradKernel, int64_t,
ops::ReshapeGradKernel, plat::float16,
ops::ReshapeGradKernel); ops::ReshapeGradKernel);
#endif #endif
...@@ -17,13 +17,16 @@ ...@@ -17,13 +17,16 @@
namespace plat = paddle::platform; namespace plat = paddle::platform;
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(stack, ops::StackKernel<plat::CUDADeviceContext, float>, REGISTER_OP_CUDA_KERNEL(
stack, ops::StackKernel<plat::CUDADeviceContext, float>,
ops::StackKernel<plat::CUDADeviceContext, double>, ops::StackKernel<plat::CUDADeviceContext, double>,
ops::StackKernel<plat::CUDADeviceContext, int>, ops::StackKernel<plat::CUDADeviceContext, int>,
ops::StackKernel<plat::CUDADeviceContext, int64_t>); ops::StackKernel<plat::CUDADeviceContext, int64_t>,
ops::StackKernel<plat::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(stack_grad, REGISTER_OP_CUDA_KERNEL(
ops::StackGradKernel<plat::CUDADeviceContext, float>, stack_grad, ops::StackGradKernel<plat::CUDADeviceContext, float>,
ops::StackGradKernel<plat::CUDADeviceContext, double>, ops::StackGradKernel<plat::CUDADeviceContext, double>,
ops::StackGradKernel<plat::CUDADeviceContext, int>, ops::StackGradKernel<plat::CUDADeviceContext, int>,
ops::StackGradKernel<plat::CUDADeviceContext, int64_t>); ops::StackGradKernel<plat::CUDADeviceContext, int64_t>,
ops::StackGradKernel<plat::CUDADeviceContext, plat::float16>);
...@@ -15,19 +15,27 @@ limitations under the License. */ ...@@ -15,19 +15,27 @@ limitations under the License. */
#include "paddle/fluid/operators/transpose_op.h" #include "paddle/fluid/operators/transpose_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
transpose, ops::TransposeKernel<paddle::platform::CUDADeviceContext, float>, transpose, ops::TransposeKernel<paddle::platform::CUDADeviceContext, float>,
ops::TransposeKernel<paddle::platform::CUDADeviceContext, double>); ops::TransposeKernel<paddle::platform::CUDADeviceContext, double>,
ops::TransposeKernel<paddle::platform::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
transpose_grad, transpose_grad,
ops::TransposeGradKernel<paddle::platform::CUDADeviceContext, float>, ops::TransposeGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::TransposeGradKernel<paddle::platform::CUDADeviceContext, double>); ops::TransposeGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::TransposeGradKernel<paddle::platform::CUDADeviceContext,
plat::float16>);
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
transpose2, transpose2,
ops::TransposeKernel<paddle::platform::CUDADeviceContext, float>, ops::TransposeKernel<paddle::platform::CUDADeviceContext, float>,
ops::TransposeKernel<paddle::platform::CUDADeviceContext, double>); ops::TransposeKernel<paddle::platform::CUDADeviceContext, double>,
ops::TransposeKernel<paddle::platform::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
transpose2_grad, transpose2_grad,
ops::TransposeGradKernel<paddle::platform::CUDADeviceContext, float>, ops::TransposeGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::TransposeGradKernel<paddle::platform::CUDADeviceContext, double>); ops::TransposeGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::TransposeGradKernel<paddle::platform::CUDADeviceContext,
plat::float16>);
...@@ -366,17 +366,40 @@ class TruncatedNormalInitializer(Initializer): ...@@ -366,17 +366,40 @@ class TruncatedNormalInitializer(Initializer):
# Initialization Ops should be prepended and not appended # Initialization Ops should be prepended and not appended
if self._seed == 0: if self._seed == 0:
self._seed = block.program.random_seed self._seed = block.program.random_seed
# to be compatible of fp16 initalizers
if var.dtype == VarDesc.VarType.FP16:
out_dtype = VarDesc.VarType.FP32
out_var = block.create_var(
name=unique_name.generate(".".join(
['truncated_gaussian_random', 'tmp'])),
shape=var.shape,
dtype=out_dtype,
type=VarDesc.VarType.LOD_TENSOR,
persistable=False)
else:
out_dtype = var.dtype
out_var = var
op = block._prepend_op( op = block._prepend_op(
type="truncated_gaussian_random", type="truncated_gaussian_random",
outputs={"Out": var}, outputs={"Out": out_var},
attrs={ attrs={
"shape": var.shape, "shape": var.shape,
"dtype": int(var.dtype), "dtype": out_dtype,
"mean": self._mean, "mean": self._mean,
"std": self._std_dev, "std": self._std_dev,
"seed": self._seed "seed": self._seed
}, },
stop_gradient=True) stop_gradient=True)
if var.dtype == VarDesc.VarType.FP16:
block.append_op(
type="cast",
inputs={"X": out_var},
outputs={"Out": var},
attrs={"in_dtype": out_var.dtype,
"out_dtype": var.dtype})
var.op = op var.op = op
return op return op
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册