Unverified commit 7aa4d879 authored by Leo Chen, committed by GitHub

add clip_by_norm fp16 kernel (#35446)

* add clip_by_norm fp16 kernel

* add ut
Parent 28abd5d8
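The change extends the clip_by_norm CUDA kernel to float16 inputs. For context only (this snippet is not part of the commit, and it assumes the paddle.fluid.layers.clip_by_norm API of this Paddle version), the op is driven from Python roughly like this:

import paddle
import paddle.fluid as fluid

# Hypothetical usage sketch: clip the L2 norm of an fp16 tensor to 0.1.
paddle.enable_static()
x = fluid.data(name='x', shape=[16, 16], dtype='float16')
clipped = fluid.layers.clip_by_norm(x=x, max_norm=0.1)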
@@ -13,8 +13,123 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/clip_by_norm_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
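// Transformer + reducer pair for the squared-norm reduction below. Elements
// are cast to Ty (float) before squaring, so an fp16 input is accumulated in
// fp32 and the sum of squares neither overflows nor loses precision.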
template <typename Tx, typename Ty = Tx>
struct SquareTransformer {
  HOSTDEVICE explicit inline SquareTransformer(int n) {}

  HOSTDEVICE inline Ty operator()(const Tx& x) const {
    return static_cast<Ty>(x) * static_cast<Ty>(x);
  }

  HOSTDEVICE inline Ty operator()(const Tx* x) const {
    return static_cast<Ty>(x[0]) * static_cast<Ty>(x[0]);
  }
};

template <typename Tx, typename Ty = Tx>
struct SquareSum {
  using Transformer = SquareTransformer<Tx, Ty>;

  inline Ty initial() { return static_cast<Ty>(0.0f); }

  __device__ __forceinline__ Ty operator()(const Ty& a, const Ty& b) const {
    return b + a;
  }
};
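// FP16 specialization of ClipByNormKernel: the norm and the scaling factor
// are computed in fp32, and only the final elementwise multiply runs in fp16.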
template <>
class ClipByNormKernel<platform::CUDADeviceContext, platform::float16>
    : public framework::OpKernel<platform::float16> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto max_norm = context.Attr<float>("max_norm");
    auto in_var = context.InputVar("X");
    auto& dev_ctx =
        context.template device_context<platform::CUDADeviceContext>();

    Tensor* output = nullptr;
    const Tensor* input = nullptr;
    if (in_var->IsType<framework::LoDTensor>()) {
      input = context.Input<Tensor>("X");

      output = context.Output<Tensor>("Out");
      output->mutable_data<platform::float16>(context.GetPlace());
    } else if (in_var->IsType<SelectedRows>()) {
      auto* x = context.Input<SelectedRows>("X");

      // merge ids in selected rows first
      math::scatter::MergeAdd<platform::CUDADeviceContext, platform::float16>
          merge_func;
      SelectedRows* merged_input =
          const_cast<framework::Scope&>(context.scope())
              .Var()
              ->GetMutable<SelectedRows>();
      merge_func(
          context.template device_context<platform::CUDADeviceContext>(), *x,
          merged_input);
      input = &(merged_input->value());

      SelectedRows* output_selected_rows = context.Output<SelectedRows>("Out");
      output_selected_rows->set_rows(merged_input->rows());
      output_selected_rows->set_height(merged_input->height());
      output = output_selected_rows->mutable_value();
      output->Resize(merged_input->value().dims());
      output->mutable_data<platform::float16>(context.GetPlace());
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Invalid input variable type, only support LoDTensor and "
          "SelectedRows types, but got type is %s.",
          framework::ToTypeName(in_var->Type())));
    }

    PADDLE_ENFORCE_NOT_NULL(input,
                            platform::errors::InvalidArgument(
                                "Input(X) of ClipByNormOp should not be null. "
                                "Please check if it is created correctly."));
    std::vector<int> reduce_dims;
    reduce_dims.resize(input->dims().size());
    for (int i = 0; i < reduce_dims.size(); ++i) {
      reduce_dims[i] = i;
    }
    Tensor tmp = context.AllocateTmpTensor<float, platform::CUDADeviceContext>(
        {1}, dev_ctx);
    TensorReduceFunctorImpl<platform::float16, float, SquareSum>(
        *input, &tmp, reduce_dims, dev_ctx.stream());
    auto tmp_eigen = EigenVector<float>::Flatten(tmp);
    auto x_norm = tmp_eigen.sqrt();

    auto x = EigenVector<platform::float16>::Flatten(*input);
    auto out = EigenVector<platform::float16>::Flatten(*output);
    auto& place =
        *context.template device_context<platform::CUDADeviceContext>()
             .eigen_device();
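    // scaling is 1 where x_norm <= max_norm and max_norm / x_norm elsewhere;
    // epsilon is non-zero only when the entire tensor is (near) zero, which
    // guards the division against 0/0.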
    auto temp = (x_norm <= max_norm).template cast<float>();
    auto epsilon =
        ((x_norm <= static_cast<float>(1e-30)).all().template cast<float>()) *
        static_cast<float>(1e-6);

    auto scaling =
        (temp + (static_cast<float>(1) - temp) * max_norm / (x_norm + epsilon))
            .template cast<platform::float16>();
    Eigen::array<int, 1> one_dim{{1}};
    Eigen::DSizes<int, 1> m_dsize(input->numel());
    out.device(place) = x * scaling.reshape(one_dim).broadcast(m_dsize);
  }
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
    clip_by_norm,
    ops::ClipByNormKernel<paddle::platform::CUDADeviceContext, float>,
    ops::ClipByNormKernel<paddle::platform::CUDADeviceContext, plat::float16>);
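For reference, here is a minimal NumPy sketch of the semantics the kernel implements. The name clip_by_norm_ref is hypothetical and the function is not part of this commit; it mirrors the fp32 accumulation and the epsilon guard described above.

import numpy as np

def clip_by_norm_ref(x, max_norm):
    # Accumulate the squared norm in fp32, as the fp16 kernel does.
    norm = np.sqrt(np.square(x.astype(np.float32)).sum())
    # epsilon kicks in only for an (almost) all-zero tensor, guarding 0/0.
    eps = 1e-6 if norm <= 1e-30 else 0.0
    scaling = 1.0 if norm <= max_norm else max_norm / (norm + eps)
    return (x.astype(np.float32) * scaling).astype(x.dtype)

For example, clip_by_norm_ref(np.ones(100, dtype=np.float16), max_norm=1.0) returns a tensor of 0.1s, since the input's L2 norm is 10.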
@@ -25,8 +25,9 @@ import paddle.fluid.core as core
class TestClipByNormOp(OpTest):
    def setUp(self):
        self.max_relative_error = 0.006
        self.init_dtype()
        self.initTestCase()
        input = np.random.random(self.shape).astype(self.dtype)
        input[np.abs(input) < self.max_relative_error] = 0.5
        self.op_type = "clip_by_norm"
        self.inputs = {'X': input, }
@@ -46,6 +47,9 @@ class TestClipByNormOp(OpTest):
        self.shape = (100, )
        self.max_norm = 1.0

    def init_dtype(self):
        self.dtype = np.float32


class TestCase1(TestClipByNormOp):
    def initTestCase(self):
@@ -65,6 +69,35 @@ class TestCase3(TestClipByNormOp):
        self.max_norm = 1.0
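

# FP16 variants of the cases above: they run only when CUDA is available and
# the device supports fp16, checking outputs with a looser atol of 0.001.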
class TestClipByNormOpFp16(TestClipByNormOp):
    def init_dtype(self):
        self.dtype = np.float16

    def test_check_output(self):
        if core.is_compiled_with_cuda():
            place = core.CUDAPlace(0)
            if core.is_float16_supported(place):
                self.check_output_with_place(place, atol=0.001)


class TestClipByNormOpFp16Case1(TestClipByNormOpFp16):
    def initTestCase(self):
        self.shape = (100, )
        self.max_norm = 1e20


class TestClipByNormOpFp16Case2(TestClipByNormOpFp16):
    def initTestCase(self):
        self.shape = (16, 16)
        self.max_norm = 0.1


class TestClipByNormOpFp16Case3(TestClipByNormOpFp16):
    def initTestCase(self):
        self.shape = (4, 8, 16)
        self.max_norm = 1.0


class TestClipByNormOpWithSelectedRows(unittest.TestCase):
    def check_with_place(self, place):
        self.config_test_case()
......