Unverified commit 51fcceb2, authored by pangengzheng, committed by GitHub

support add(x_float32, y_bfloat16) or add(x_float32, y_float16) (#54611)

* support add(x_float32, y_bfloat16) or add(x_float32, y_float16)

* polish
Parent ac7f09a9
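In effect, the change lets a float32 tensor accumulate a bfloat16 or float16 tensor in place without first materializing a float32 copy. A minimal sketch of the enabled behavior, assuming a CUDA build of Paddle (the new dispatch lives in the GPU/KPS AddKernel); names and shapes are illustrative:

```python
import paddle

# A float32 accumulator and a low-precision addend (bfloat16 here;
# float16 works the same way).
x = paddle.randn([4, 4], dtype=paddle.float32)
y = paddle.randn([4, 4], dtype=paddle.bfloat16)

# Previously this pattern required an explicit cast, as the
# MixPrecisionLayer diff below shows:
#     x.add_(paddle.cast(y, paddle.float32))
# Now the float32 + bfloat16/float16 case is handled inside the kernel:
x.add_(y)
assert x.dtype == paddle.float32  # the result stays float32
```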
@@ -258,20 +258,6 @@ PD_REGISTER_KERNEL(subtract,
                   complex128,
                   phi::dtype::bfloat16) {}
-PD_REGISTER_KERNEL(add,
-                   KPS,
-                   ALL_LAYOUT,
-                   phi::AddKernel,
-                   float,
-                   double,
-                   int16_t,
-                   int,
-                   int64_t,
-                   phi::dtype::float16,
-                   phi::dtype::bfloat16,
-                   complex64,
-                   complex128) {}
PD_REGISTER_KERNEL(multiply,
                   KPS,
                   ALL_LAYOUT,
......
@@ -41,6 +41,22 @@ struct InverseAddFunctor {
  inline HOSTDEVICE T operator()(const T a, const T b) const { return b + a; }
};
+// Float32Bfloat16Add
+template <typename T>
+struct Float32Bfloat16AddFunctor {
+  inline HOSTDEVICE T operator()(const T x, const phi::bfloat16 y) {
+    return x + static_cast<T>(y);
+  }
+};
+
+// Float32Float16Add
+template <typename T>
+struct Float32Float16AddFunctor {
+  inline HOSTDEVICE T operator()(const T x, const phi::float16 y) {
+    return x + static_cast<T>(y);
+  }
+};
// Subtract
template <typename T>
struct SubtractFunctor {
......
@@ -24,6 +24,47 @@ namespace phi {
DEFINE_CUDA_ELEMENTWISE_OP(Add)
+template <typename T, typename Context>
+void Float32Bfloat16OrFloat16AddCudaFunctor(const Context& dev_ctx,
+                                            const DenseTensor& x,
+                                            const DenseTensor& y,
+                                            DenseTensor* out) {
+  std::vector<const DenseTensor*> inputs;
+  inputs.reserve(2);
+  std::vector<DenseTensor*> outputs;
+  outputs.reserve(1);
+  inputs.emplace_back(&x);
+  inputs.emplace_back(&y);
+  outputs.emplace_back(out);
+  if (y.dtype() == phi::DataType::BFLOAT16) {
+    funcs::ElementwiseKernel<T>(
+        dev_ctx, inputs, &outputs, funcs::Float32Bfloat16AddFunctor<T>());
+  } else if (y.dtype() == phi::DataType::FLOAT16) {
+    funcs::ElementwiseKernel<T>(
+        dev_ctx, inputs, &outputs, funcs::Float32Float16AddFunctor<T>());
+  } else {
+    PADDLE_THROW(phi::errors::InvalidArgument(
+        "Unsupported x dtype:%s, y dtype:%s for add(x, y) operation",
+        phi::DataTypeToString(x.type()),
+        phi::DataTypeToString(y.type())));
+  }
+}
+
+template <typename T, typename Context>
+void AddKernel(const Context& dev_ctx,
+               const DenseTensor& x,
+               const DenseTensor& y,
+               DenseTensor* out) {
+  if (x.dtype() == phi::DataType::FLOAT32 &&
+      (y.dtype() == phi::DataType::BFLOAT16 ||
+       y.dtype() == phi::DataType::FLOAT16)) {
+    using Type = DataTypeToCppType<phi::DataType::FLOAT32>::type;
+    Float32Bfloat16OrFloat16AddCudaFunctor<Type, Context>(dev_ctx, x, y, out);
+  } else {
+    AddRawKernel<T, Context>(dev_ctx, x, y, -1, out);
+  }
+}
template <typename T, typename Context>
void GradAddKernel(const Context& dev_ctx,
                   const DenseTensor& x,
@@ -43,6 +84,20 @@ using bfloat16 = phi::dtype::bfloat16;
using complex64 = ::phi::dtype::complex<float>;
using complex128 = ::phi::dtype::complex<double>;
+PD_REGISTER_KERNEL(add,
+                   KPS,
+                   ALL_LAYOUT,
+                   phi::AddKernel,
+                   float,
+                   double,
+                   int16_t,
+                   int,
+                   int64_t,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16,
+                   complex64,
+                   complex128) {}
+
PD_REGISTER_KERNEL(add_raw,
                   KPS,
                   ALL_LAYOUT,
......
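The dispatch above is keyed on runtime dtypes: when x is float32 and y is bfloat16 or float16, AddKernel routes to Float32Bfloat16OrFloat16AddCudaFunctor with T fixed to float32, so y is upcast element-wise inside the kernel; every other combination falls back to AddRawKernel. The contract, sketched below (mirroring the unit test added later in this commit), is that the in-kernel cast agrees exactly with casting y to float32 first:

```python
import numpy as np
import paddle

x = paddle.randn([8], dtype=paddle.float32)
ref = x.clone()  # reference copy for the cast-first path
y = paddle.randn([8], dtype=paddle.float16)

x.add_(y)                                 # in-kernel cast (new path)
ref.add_(paddle.cast(y, paddle.float32))  # explicit cast (old path)
np.testing.assert_equal(x.numpy(), ref.numpy())
```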
@@ -62,7 +62,7 @@ class MixPrecisionLayer(nn.Layer):
name="main_grad@" + param.name,
)
else:
param.main_grad.add_(tmp_grad.cast(paddle.float32))
param.main_grad.add_(tmp_grad)
tmp_grad._clear_data()
return None
......
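The MixPrecisionLayer change above is the motivating use case: the gradient hook can now feed the low-precision gradient straight into the float32 master gradient instead of allocating a casted copy on every accumulation. A simplified sketch of that pattern (main_grad and tmp_grad names follow the diff; the loop and sizes are illustrative):

```python
import paddle

# float32 master gradient, accumulated from bfloat16 step gradients.
main_grad = paddle.zeros([1024], dtype=paddle.float32)
for _ in range(4):
    tmp_grad = paddle.randn([1024], dtype=paddle.bfloat16)
    # Previously: main_grad.add_(tmp_grad.cast(paddle.float32))
    main_grad.add_(tmp_grad)
```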
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import copy
import os
import unittest
import warnings
@@ -872,6 +873,50 @@ class TestTensorAddAPIWarnings(unittest.TestCase):
        os.environ['FLAGS_print_extra_attrs'] = "0"
+class TestTensorFloat32Bfloat16OrFloat16Add(unittest.TestCase):
+    def _float32_bfloat16_or_float16_add(self, y_dtype):
+        paddle.disable_static()
+        test_num = 5
+        val_range = 10000
+        shapes = []
+        for i in range(test_num):
+            shape = [np.random.randint(val_range), np.random.randint(val_range)]
+            shapes.append(shape)
+        for i, shape in enumerate(shapes):
+            x = paddle.randn(list(shape), dtype=paddle.float32)
+            x_copy = copy.deepcopy(x)
+            y = paddle.randn(list(shape), dtype=y_dtype)
+            x.add_(y)
+            x_copy.add_(paddle.cast(y, paddle.float32))
+            np.testing.assert_equal(x.numpy(), x_copy.numpy())
+            del x, x_copy
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or core.cudnn_version() < 8100
+    or paddle.device.cuda.get_device_capability()[0] < 8,
+    "requires compilation with CUDA, cuDNN >= 8.1.0, and a device with compute capability >= 8.0",
+)
+class TestTensorFloat32Bfloat16Add(TestTensorFloat32Bfloat16OrFloat16Add):
+    def test_float32_bfloat16_add(self):
+        place = core.CUDAPlace(0)
+        with fluid.dygraph.base.guard(place=place):
+            self._float32_bfloat16_or_float16_add(y_dtype=paddle.bfloat16)
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda() or core.cudnn_version() < 8100,
+    "requires compilation with CUDA and cuDNN >= 8.1.0",
+)
+class TestTensorFloat32Float16Add(TestTensorFloat32Bfloat16OrFloat16Add):
+    def test_float32_float16_add(self):
+        place = core.CUDAPlace(0)
+        with fluid.dygraph.base.guard(place=place):
+            self._float32_bfloat16_or_float16_add(y_dtype=paddle.float16)
+
+
if __name__ == '__main__':
    paddle.enable_static()
    unittest.main()