Unverified commit 3ee2b237, authored by cyberslack_lee and committed by GitHub

【Hackathon4 No58】fix exponential and pad (#51300)

Parent 351ccb63
@@ -17,10 +17,10 @@
 #include <random>
 #include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/core/generator.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/distribution_helper.h"
 namespace phi {
 template <typename T, typename Context>
...
@@ -15,7 +15,7 @@
 #pragma once
 #include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/device_context.h"
 namespace phi {
 template <typename T, typename Context>
...
@@ -13,8 +13,8 @@
 // limitations under the License.
 #include "paddle/phi/kernels/exponential_kernel.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/distribution_helper.h"
@@ -25,12 +25,19 @@ void ExponentialKernel(const Context &dev_ctx,
                        const DenseTensor &x,
                        float lambda,
                        DenseTensor *out) {
-  phi::funcs::uniform_distribution<T> dist;
-  phi::funcs::exponential_transform<T> trans(lambda);
+  using MT = typename kps::details::MPTypeTrait<T>::Type;
+  phi::funcs::uniform_distribution<MT> dist;
+  phi::funcs::exponential_transform<MT> trans(lambda);
   phi::funcs::distribution_and_transform<T>(dev_ctx, out, dist, trans);
 }
 }  // namespace phi

-PD_REGISTER_KERNEL(
-    exponential, GPU, ALL_LAYOUT, phi::ExponentialKernel, float, double) {}
+PD_REGISTER_KERNEL(exponential,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::ExponentialKernel,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
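
Editor's note: the kernel now samples and transforms in the type's higher-precision counterpart (MPType is float for float16/bfloat16) and only casts back when writing the output, which is what makes the new low-precision registrations numerically reasonable. A minimal usage sketch, not part of the diff and assuming a CUDA build since only the GPU registration is shown here:

```python
import paddle

# Hedged usage sketch: the float16/bfloat16 registrations above are for the
# GPU kernel, so this assumes a CUDA device is available.
paddle.set_device('gpu')
paddle.seed(2023)

x = paddle.ones([1024, 1024], dtype='float16')
x.exponential_(lam=0.5)      # in-place i.i.d. draws from Exp(lam=0.5)
print(x.dtype)               # paddle.float16
print(float(x.mean()))       # should be roughly 1 / lam = 2.0
```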
@@ -25,5 +25,6 @@ PD_REGISTER_KERNEL(pad_grad,
                    float,
                    double,
                    phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
@@ -28,5 +28,6 @@ PD_REGISTER_KERNEL(pad,
                    int,
                    int64_t,
                    phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
@@ -17,6 +17,7 @@
 #include "paddle/phi/common/scalar.h"
 #include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/device_context.h"
 namespace phi {
...
@@ -17,6 +17,7 @@
 #include "paddle/phi/common/scalar.h"
 #include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/device_context.h"
 namespace phi {
...
@@ -15,9 +15,14 @@
 import unittest
 import numpy as np
-from eager_op_test import OpTest
+from eager_op_test import (
+    OpTest,
+    convert_float_to_uint16,
+    convert_uint16_to_float,
+)
 import paddle
+from paddle.fluid import core
 class TestExponentialOp1(OpTest):
@@ -344,6 +349,104 @@ class TestExponentialAPI(unittest.TestCase):
         paddle.enable_static()

+class TestExponentialFP16Op(OpTest):
+    def setUp(self):
+        paddle.enable_static()
+        self.op_type = "exponential"
+        self.python_api = paddle.tensor.exponential_
+        self.config()
+        self.attrs = {"lambda": self.lam}
+        self.inputs = {'X': np.empty([1024, 1024], dtype=self.dtype)}
+        self.outputs = {'Out': np.ones([1024, 1024], dtype=self.dtype)}
+
+    def config(self):
+        self.lam = 0.5
+        self.dtype = np.float16
+
+    def test_check_output(self):
+        self.check_output_customized(self.verify_output)
+
+    def verify_output(self, outs):
+        hist1, _ = np.histogram(outs[0], range=(0, 5))
+        hist1 = hist1.astype(np.float16)
+        hist1 = hist1 / float(outs[0].size)
+
+        data_np = np.random.exponential(1.0 / self.lam, [1024, 1024])
+        hist2, _ = np.histogram(data_np, range=(0, 5))
+        hist2 = hist2.astype(np.float16)
+        hist2 = hist2 / float(data_np.size)
+
+        np.testing.assert_allclose(hist1, hist2, rtol=0.05)
+
+    def test_check_grad_normal(self):
+        self.check_grad(
+            ['X'],
+            'Out',
+            in_place=True,
+            user_defined_grads=[np.zeros([1024, 1024], dtype=self.dtype)],
+            user_defined_grad_outputs=[
+                np.random.rand(1024, 1024).astype(self.dtype)
+            ],
+            check_dygraph=False,  # inplace can not call paddle.grad
+        )
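
Editor's note: the test checks the distribution rather than exact values because the op's random stream differs from numpy's. The uniform-distribution-plus-exponential-transform pipeline in the kernel amounts to inverse-CDF sampling; a hedged numpy illustration of that idea (my own sketch, not the kernel code):

```python
import numpy as np

# Inverse-CDF sketch (assumption: this is what uniform_distribution +
# exponential_transform amount to): if U ~ Uniform(0, 1),
# then -log(1 - U) / lam ~ Exponential(lam).
lam = 0.5
u = np.random.uniform(size=[1024, 1024]).astype(np.float32)
samples = -np.log1p(-u) / lam

# Same histogram comparison as the test: normalize bin counts, allow 5%.
hist1, _ = np.histogram(samples, range=(0, 5))
hist2, _ = np.histogram(
    np.random.exponential(1.0 / lam, [1024, 1024]), range=(0, 5)
)
np.testing.assert_allclose(
    hist1 / float(samples.size), hist2 / float(samples.size), rtol=0.05
)
```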
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or does not support bfloat16",
+)
+class TestExponentialBP16Op(OpTest):
+    def setUp(self):
+        paddle.enable_static()
+        self.op_type = "exponential"
+        self.python_api = paddle.tensor.exponential_
+        self.config()
+        x = np.empty([1024, 1024]).astype('float32')
+        out = np.ones([1024, 1024]).astype('float32')
+        self.inputs = {'X': convert_float_to_uint16(x)}
+        self.attrs = {"lambda": self.lam}
+        self.outputs = {'Out': convert_float_to_uint16(out)}
+
+    def config(self):
+        self.lam = 0.5
+        self.dtype = np.uint16
+
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        self.check_output_with_place_customized(
+            checker=self.verify_output, place=place
+        )
+
+    def verify_output(self, outs):
+        outs = convert_uint16_to_float(outs)
+        self.assertEqual(outs[0].shape, (1024, 1024))
+        hist1, _ = np.histogram(outs[0], range=(-3, 5))
+        hist1 = hist1.astype("float32")
+        hist1 = hist1 / float(outs[0].size)
+
+        data_np = np.random.exponential(1.0 / self.lam, [1024, 1024])
+        hist2, _ = np.histogram(data_np, range=(-3, 5))
+        hist2 = hist2.astype("float32")
+        hist2 = hist2 / float(data_np.size)
+
+        np.testing.assert_allclose(hist1, hist2, rtol=0.05)
+
+    def test_check_grad_normal(self):
+        place = core.CUDAPlace(0)
+        self.check_grad_with_place(
+            place,
+            ['X'],
+            'Out',
+            in_place=True,
+            user_defined_grads=[np.zeros([1024, 1024], dtype=self.dtype)],
+            user_defined_grad_outputs=[
+                np.random.rand(1024, 1024).astype(self.dtype)
+            ],
+            check_dygraph=False,  # inplace can not call paddle.grad
+        )
+
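
Editor's note: the BF16 test stores its data as uint16 because OpTest represents bfloat16 tensors by their raw bit pattern; convert_float_to_uint16 and convert_uint16_to_float move between that representation and float32. A standalone illustration of the convention (my own sketch, not Paddle's helpers):

```python
import numpy as np

# bfloat16 keeps the upper 16 bits of a float32, so a bf16 tensor can be
# carried around as a uint16 ndarray of raw bits. This mirrors what the
# eager_op_test converters are used for above, but is not their implementation.
def float32_to_bf16_bits(x):
    x = np.ascontiguousarray(x, dtype=np.float32)
    return (x.view(np.uint32) >> 16).astype(np.uint16)

def bf16_bits_to_float32(bits):
    bits = np.asarray(bits, dtype=np.uint16)
    return (bits.astype(np.uint32) << 16).view(np.float32)

x = np.random.random([4, 4]).astype(np.float32)
roundtrip = bf16_bits_to_float32(float32_to_bf16_bits(x))
print(np.max(np.abs(roundtrip - x)))  # small truncation error: bf16 has only 7 mantissa bits
```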
 if __name__ == "__main__":
     paddle.enable_static()
     unittest.main()
@@ -16,7 +16,7 @@ import os
 import unittest
 import numpy as np
-from eager_op_test import OpTest
+from eager_op_test import OpTest, convert_float_to_uint16
 from test_attribute_var import UnittestBase
 import paddle
@@ -96,7 +96,7 @@ def create_test_fp16(parent):
             return np.float16

         def test_check_grad_normal(self):
-            self.check_grad(['X'], 'Out', max_relative_error=0.3)
+            self.check_grad(['X'], 'Out')

     cls_name = "{}_{}".format(parent.__name__, "Fp16")
     TestPadFp16.__name__ = cls_name
@@ -202,6 +202,41 @@ class TestPaddingValueTensor3(unittest.TestCase):
         np.testing.assert_allclose(pd_out, np_out)
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or does not support bfloat16",
+)
+class TestPadBP16Op(OpTest):
+    def setUp(self):
+        self.initTestCase()
+        self.dtype = np.uint16
+        self.op_type = "pad"
+        self.python_api = pad_wrapper
+        x = np.random.random(self.shape).astype(np.float32)
+        self.attrs = {}
+        self.attrs['paddings'] = np.array(self.paddings).flatten()
+        self.attrs['pad_value'] = self.pad_value
+        out = np.pad(
+            x, self.paddings, mode='constant', constant_values=self.pad_value
+        )
+        self.inputs = {'X': convert_float_to_uint16(x)}
+        self.outputs = {'Out': convert_float_to_uint16(out)}
+
+    def initTestCase(self):
+        self.shape = (16, 16)
+        self.paddings = [(0, 1), (2, 3)]
+        self.pad_value = 0.0
+
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        self.check_output_with_place(place)
+
+    def test_check_grad(self):
+        place = core.CUDAPlace(0)
+        self.check_grad_with_place(place, ['X'], 'Out')
+
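
Editor's note: a hedged aside on why the low-precision pad tests can call check_grad directly: constant padding contributes nothing to the gradient, so the backward pass is just a slice of the upstream gradient back to the input region. A numpy sketch of that reasoning (my own illustration, not the kernel code):

```python
import numpy as np

# Forward: constant-pad a (16, 16) input with paddings [(0, 1), (2, 3)].
x = np.random.random((16, 16)).astype(np.float32)
paddings = [(0, 1), (2, 3)]
out = np.pad(x, paddings, mode='constant', constant_values=0.0)
assert out.shape == (17, 21)

# Backward of sum(out) w.r.t. x: slice the upstream gradient back to x's
# region; the padded border receives no gradient of its own.
grad_out = np.ones_like(out)
grad_x = grad_out[0:16, 2:18]
assert grad_x.shape == x.shape
```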
 if __name__ == '__main__':
     paddle.enable_static()
     unittest.main()
@@ -1154,7 +1154,9 @@ def exponential_(x, lam=1.0, name=None):
     if in_dygraph_mode():
         return _C_ops.exponential_(x, lam)
     else:
-        check_variable_and_dtype(x, "x", ["float32", "float64"], "exponential")
+        check_variable_and_dtype(
+            x, "x", ["float16", "float32", "float64", "uint16"], "exponential"
+        )
         helper = LayerHelper("exponential", **locals())
         helper.append_op(
...