Unverified · Commit b3232936 authored by pangengzheng, committed by GitHub

support add(x_float32, y_bfloat16) or add(x_float32, y_float16) (#54415)

* support add(x_float32, y_bfloat16) or add(x_float32, y_float16)

* polish

* fix test
Parent 06b8fbb0
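In user terms, the change lets a float32 tensor accumulate a bfloat16 or float16 tensor in place without an explicit cast. A minimal dygraph sketch of the new behavior (illustrative only; assumes a CUDA build on a bfloat16-capable GPU, and mirrors the unit tests added below):

    import paddle

    x = paddle.randn([4, 4], dtype=paddle.float32)
    y = paddle.cast(paddle.randn([4, 4]), paddle.bfloat16)
    x.add_(y)        # previously required x.add_(y.cast(paddle.float32))
    print(x.dtype)   # paddle.float32: the accumulator stays full precision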
@@ -44,6 +44,22 @@ struct InverseAddFunctor {
   inline HOSTDEVICE T operator()(const T a, const T b) const { return b + a; }
 };
 
+// Float32Bfloat16Add
+template <typename T>
+struct Float32Bfloat16AddFunctor {
+  inline HOSTDEVICE T operator()(const T x, const phi::bfloat16 y) {
+    return x + static_cast<T>(y);
+  }
+};
+
+// Float32Float16Add
+template <typename T>
+struct Float32Float16AddFunctor {
+  inline HOSTDEVICE T operator()(const T x, const phi::float16 y) {
+    return x + static_cast<T>(y);
+  }
+};
+
 // Subtract
 template <typename T>
 struct SubtractFunctor {
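Both functors upcast the low-precision operand to T before adding, so the sum is computed and stored at float32 precision. A small NumPy sketch of why that matters (illustrative only; NumPy has no native bfloat16, so float16 stands in):

    import numpy as np

    # Near 1.0 the fp16 spacing is ~2**-10 (~0.001), so a 1e-4 update
    # vanishes if the accumulation itself is done in fp16.
    x = np.float32(1.0)
    y = np.float16(1e-4)
    print(x + np.float32(y))   # ~1.0001 -- float32 accumulation keeps the update
    print(np.float16(x) + y)   # 1.0     -- fp16 accumulation rounds it away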
@@ -41,12 +41,45 @@ void AddCudaFunctor(const Context& dev_ctx,
       dev_ctx, inputs, &outputs, funcs::AddFunctor<T>(), axis);
 }
 
+template <typename T, typename Context>
+void Float32Bfloat16OrFloat16AddCudaFunctor(const Context& dev_ctx,
+                                            const DenseTensor& x,
+                                            const DenseTensor& y,
+                                            DenseTensor* out) {
+  std::vector<const DenseTensor*> inputs;
+  inputs.reserve(2);
+  std::vector<DenseTensor*> outputs;
+  outputs.reserve(1);
+  inputs.emplace_back(&x);
+  inputs.emplace_back(&y);
+  outputs.emplace_back(out);
+  if (y.dtype() == phi::DataType::BFLOAT16) {
+    funcs::ElementwiseKernel<T>(
+        dev_ctx, inputs, &outputs, funcs::Float32Bfloat16AddFunctor<T>());
+  } else if (y.dtype() == phi::DataType::FLOAT16) {
+    funcs::ElementwiseKernel<T>(
+        dev_ctx, inputs, &outputs, funcs::Float32Float16AddFunctor<T>());
+  } else {
+    PADDLE_THROW(phi::errors::InvalidArgument(
+        "Unsupported x dtype:%s, y dtype:%s for add(x, y) operation",
+        phi::DataTypeToString(x.type()),
+        phi::DataTypeToString(y.type())));
+  }
+}
+
 template <typename T, typename Context>
 void AddKernel(const Context& dev_ctx,
                const DenseTensor& x,
                const DenseTensor& y,
                DenseTensor* out) {
-  AddCudaFunctor<T, Context>(dev_ctx, x, y, -1, out);
+  if (x.dtype() == phi::DataType::FLOAT32 &&
+      (y.dtype() == phi::DataType::BFLOAT16 ||
+       y.dtype() == phi::DataType::FLOAT16)) {
+    using Type = DataTypeToCppType<phi::DataType::FLOAT32>::type;
+    Float32Bfloat16OrFloat16AddCudaFunctor<Type, Context>(dev_ctx, x, y, out);
+  } else {
+    AddCudaFunctor<T, Context>(dev_ctx, x, y, -1, out);
+  }
 }
 
 template <typename T, typename Context>
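Note the dispatch is a runtime dtype check inside AddKernel: only the float32 instantiation can take the mixed-precision path, and every other input combination falls through to the original AddCudaFunctor call unchanged.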
@@ -59,7 +59,7 @@ class MixPrecisionLayer(nn.Layer):
                     name="main_grad@" + param.name,
                 )
             else:
-                param.main_grad.add_(tmp_grad.cast(paddle.float32))
+                param.main_grad.add_(tmp_grad)
             tmp_grad._clear_data()
 
         return None
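With the mixed-dtype kernel in place, the float32 main_grad can consume the low-precision tmp_grad directly; the removed cast had materialized a temporary float32 copy of the gradient on every accumulation step.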
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import copy
 import os
 import unittest
 import warnings
@@ -734,10 +735,10 @@ class TestComplexElementwiseAddOp(OpTest):
 
 class TestRealComplexElementwiseAddOp(TestComplexElementwiseAddOp):
     def init_input_output(self):
-        self.x = np.random.random(self.shape).astype(self.dtype)
-        self.y = np.random.random(self.shape).astype(
+        self.x = np.random.random(self.shape).astype(
             self.dtype
         ) + 1j * np.random.random(self.shape).astype(self.dtype)
+        self.y = np.random.random(self.shape).astype(self.dtype)
 
         self.out = self.x + self.y
@@ -848,6 +849,50 @@ class TestTensorAddAPIWarnings(unittest.TestCase):
         os.environ['FLAGS_print_extra_attrs'] = "0"
 
+
+class TestTensorFloa32Bfloat16OrFloat16Add(unittest.TestCase):
+    def _floa32_bfloat16_or_float16_add(self, y_dtype):
+        paddle.disable_static()
+        test_num = 5
+        val_range = 10000
+        shapes = []
+        for i in range(test_num):
+            shape = [np.random.randint(val_range), np.random.randint(val_range)]
+            shapes.append(shape)
+
+        for i, shape in enumerate(shapes):
+            x = paddle.randn(list(shape), dtype=paddle.float32)
+            x_copy = copy.deepcopy(x)
+            y = paddle.randn(list(shape), dtype=y_dtype)
+            x.add_(y)
+            x_copy.add_(paddle.cast(y, paddle.float32))
+            np.testing.assert_equal(x.numpy(), x_copy.numpy())
+            del x, x_copy
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or core.cudnn_version() < 8100
+    or paddle.device.cuda.get_device_capability()[0] < 8,
+    "requires a CUDA build, cuDNN version at least 8.1.0, and GPU compute capability at least 8.0",
+)
+class TestTensorFloa32Bfloat16Add(TestTensorFloa32Bfloat16OrFloat16Add):
+    def test_floa32_bfloat16_add(self):
+        place = core.CUDAPlace(0)
+        with fluid.dygraph.base.guard(place=place):
+            self._floa32_bfloat16_or_float16_add(y_dtype=paddle.bfloat16)
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda() or core.cudnn_version() < 8100,
+    "requires a CUDA build and cuDNN version at least 8.1.0",
+)
+class TestTensorFloa32Float16Add(TestTensorFloa32Bfloat16OrFloat16Add):
+    def test_floa32_float16_add(self):
+        place = core.CUDAPlace(0)
+        with fluid.dygraph.base.guard(place=place):
+            self._floa32_bfloat16_or_float16_add(y_dtype=paddle.float16)
+
+
 if __name__ == '__main__':
     paddle.enable_static()
     unittest.main()