Unverified · Commit 4e62af80 · authored by C cc, committed by GitHub

Add FP16 PRelu (#35532)

Parent: afd1b372
No related merge requests
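The change below touches three areas: the forward CUDA kernels and their `Prelu*DirectCUDAFunctor` instantiations, the backward kernel plus its reduction functor and the CUDA kernel registrations, and the operator's Python unit tests, which gain an FP16 variant of each existing PRelu test case. The common thread is replacing comparisons against the integer literal `0` with a zero of the element type `T`, so the same kernel template also compiles when `T` is `paddle::platform::float16`. Below is a minimal standalone sketch of that pattern; it uses CUDA's `__half` and a hypothetical `PReluScalarSketch` kernel purely for illustration and is not Paddle's actual code.

```cpp
// Sketch only: a scalar-alpha PRelu kernel written so it also instantiates for
// a 16-bit element type. With a half type, `x > 0` relies on an implicit
// int -> T conversion the type may not provide, so a zero of type T is
// materialized once and reused. (__half device arithmetic assumes sm_53+.)
#include <cuda_fp16.h>

template <typename T>
__global__ void PReluScalarSketch(const T* x, const T* alpha, T* y, size_t n) {
  const T zero = static_cast<T>(0);
  // Grid-stride loop; plays the same role as CUDA_KERNEL_LOOP in the kernels below.
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += static_cast<size_t>(gridDim.x) * blockDim.x) {
    T v = x[i];
    y[i] = (v > zero) ? v : alpha[0] * v;  // PRelu: identity for v > 0, scaled otherwise
  }
}
```

The three forward kernels in the first hunks (channel-wise, element-wise, scalar) apply exactly this change and differ only in how the index into `alpha` is computed.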
@@ -33,7 +33,8 @@ __global__ void PReluChannelWiseKernel(const T *input, const T *alpha,
     size_t channel_index = temp % channel_num;
     T scale = alpha[channel_index];
     T x = input[index];
-    output[index] = (x > 0) ? x : scale * x;
+    T zero = static_cast<T>(0);
+    output[index] = (x > zero) ? x : scale * x;
   }
 }
@@ -45,7 +46,8 @@ __global__ void PReluElementWiseKernel(const T *input, const T *alpha,
     size_t element_index = index % spatial_size;
     T scale = alpha[element_index];
     T x = input[index];
-    output[index] = (x > 0) ? x : scale * x;
+    T zero = static_cast<T>(0);
+    output[index] = (x > zero) ? x : scale * x;
   }
 }
@@ -55,7 +57,8 @@ __global__ void PReluScalarKernel(const T *input, const T *alpha, T *output,
   T scale = alpha[0];
   CUDA_KERNEL_LOOP(index, numel) {
     T x = input[index];
-    output[index] = (x > 0) ? x : scale * x;
+    T zero = static_cast<T>(0);
+    output[index] = (x > zero) ? x : scale * x;
   }
 }
@@ -88,12 +91,15 @@ void PreluScalarDirectCUDAFunctor<T>::operator()(gpuStream_t stream,
 }

 template class PreluChannelWiseDirectCUDAFunctor<float>;
+template class PreluChannelWiseDirectCUDAFunctor<paddle::platform::float16>;
 template class PreluChannelWiseDirectCUDAFunctor<double>;

 template class PreluElementWiseDirectCUDAFunctor<float>;
+template class PreluElementWiseDirectCUDAFunctor<paddle::platform::float16>;
 template class PreluElementWiseDirectCUDAFunctor<double>;

 template class PreluScalarDirectCUDAFunctor<float>;
+template class PreluScalarDirectCUDAFunctor<paddle::platform::float16>;
 template class PreluScalarDirectCUDAFunctor<double>;

 } // namespace math
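The new `template class ...<paddle::platform::float16>;` lines are explicit instantiations: the functor bodies are compiled in this CUDA translation unit, so only the element types listed here end up in the object file, and callers that merely see the declaration would otherwise fail to link for float16. A generic sketch of that pattern, using a toy `ScaleNegativeFunctor` and hypothetical file names rather than Paddle's:

```cpp
// toy_prelu.h -- declaration seen by users of the functor (hypothetical file).
template <typename T>
class ScaleNegativeFunctor {
 public:
  void operator()(const T* in, const T* scale, T* out, int n) const;
};

// toy_prelu.cu -- out-of-line definition plus explicit instantiations
// (hypothetical file). Only the listed types are emitted, which is why the
// hunk above adds a float16 line next to the existing float and double ones.
template <typename T>
void ScaleNegativeFunctor<T>::operator()(const T* in, const T* scale, T* out,
                                         int n) const {
  // Stand-in body (a real version would launch a kernel templated on T).
  const T zero = static_cast<T>(0);
  for (int i = 0; i < n; ++i) out[i] = (in[i] > zero) ? in[i] : scale[0] * in[i];
}

template class ScaleNegativeFunctor<float>;
template class ScaleNegativeFunctor<double>;
// template class ScaleNegativeFunctor<Float16>;  // the analogue of the added lines
```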
@@ -87,8 +87,9 @@ __global__ void PReluOpGradKernel(const T* x_ptr, const T* alpha_ptr,
     }
     T x = x_ptr[index];
     T dy = dy_ptr[index];
-    if (dx_ptr != nullptr) dx_ptr[index] = (x > 0) ? dy : scale * dy;
-    if (dalpha_ptr != nullptr) dalpha_ptr[index] = (x > 0) ? 0 : x * dy;
+    T zero = static_cast<T>(0);
+    if (dx_ptr != nullptr) dx_ptr[index] = (x > zero) ? dy : scale * dy;
+    if (dalpha_ptr != nullptr) dalpha_ptr[index] = (x > zero) ? zero : x * dy;
   }
 }
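For reference, with `out = x` for `x > 0` and `out = alpha * x` otherwise, the per-element gradients this kernel writes are the ones below; the `dalpha` values are still per-element here and get summed over the broadcast dimensions by the `TensorReduce` call in a later hunk. Note that the literal `0` result in the old `dalpha` line is replaced by `zero` as well, since both branches of the conditional operator must have type `T`.

```latex
\frac{\partial L}{\partial x_i} =
  \begin{cases}
    \dfrac{\partial L}{\partial y_i}, & x_i > 0 \\[4pt]
    \alpha \, \dfrac{\partial L}{\partial y_i}, & x_i \le 0
  \end{cases}
\qquad
\left(\frac{\partial L}{\partial \alpha}\right)_i =
  \begin{cases}
    0, & x_i > 0 \\[4pt]
    x_i \, \dfrac{\partial L}{\partial y_i}, & x_i \le 0
  \end{cases}
```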
@@ -112,9 +113,11 @@ class PreluOpGradFunctor {
   }
 };

-template <typename T>
 struct IdentityFunctor {
-  HOSTDEVICE inline T operator()(const T& x) const { return x; }
+  template <typename T>
+  HOSTDEVICE inline T operator()(const T& x) const {
+    return x;
+  }
 };

 template <typename DeviceContext, typename T>
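`IdentityFunctor` changes from a class template to a plain struct whose call operator is the template. With the type parameter on `operator()`, the element type is deduced at each call, so the reduction further down can name the functor as `IdentityFunctor` and construct it as `IdentityFunctor()` without spelling out `<T>`. A minimal sketch of the difference, with toy names:

```cpp
// Before: the element type is fixed when the functor *type* is named.
template <typename T>
struct IdentityOld {
  T operator()(const T& x) const { return x; }
};

// After: one non-template type; T is deduced per call, so the same type name
// can be passed around (e.g. as a template argument) without a <T> suffix.
struct IdentityNew {
  template <typename T>
  T operator()(const T& x) const {
    return x;
  }
};

int main() {
  IdentityNew id;
  float f = id(1.5f);  // T deduced as float
  double d = id(2.5);  // T deduced as double
  return (f == 1.5f && d == 2.5) ? 0 : 1;
}
```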
@@ -174,9 +177,9 @@ class CUDAPReluGradKernel : public framework::OpKernel<T> {
       reduce_dims.push_back(i);
     }

-    TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
+    TensorReduce<T, T, cub::Sum, IdentityFunctor>(
         dalpha_tmp, dalpha, reduce_dims, static_cast<T>(0), cub::Sum(),
-        IdentityFunctor<T>(), stream);
+        IdentityFunctor(), stream);
   }
 };
@@ -184,10 +187,14 @@ class CUDAPReluGradKernel : public framework::OpKernel<T> {
 } // namespace paddle

 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(
     prelu, ops::CUDAPReluKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::CUDAPReluKernel<paddle::platform::CUDADeviceContext, plat::float16>,
     ops::CUDAPReluKernel<paddle::platform::CUDADeviceContext, double>);
 REGISTER_OP_CUDA_KERNEL(
     prelu_grad,
     ops::CUDAPReluGradKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::CUDAPReluGradKernel<paddle::platform::CUDADeviceContext,
+                             plat::float16>,
     ops::CUDAPReluGradKernel<paddle::platform::CUDADeviceContext, double>);
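Registering `plat::float16` for both `prelu` and `prelu_grad` is what makes the new instantiations reachable at runtime: the framework selects a CUDA kernel by the input tensor's dtype, and a dtype with no registered kernel simply cannot run on GPU. A conceptual sketch of dtype-keyed dispatch (toy registry, not Paddle's actual `REGISTER_OP_CUDA_KERNEL` machinery):

```cpp
#include <functional>
#include <map>
#include <stdexcept>
#include <string>

// Toy stand-in for op-kernel registration: each op name maps dtype -> callable,
// and running an op looks the kernel up by the tensor's dtype.
using Kernel = std::function<void()>;
static std::map<std::string, std::map<std::string, Kernel>> g_registry;

void RegisterKernel(const std::string& op, const std::string& dtype, Kernel k) {
  g_registry[op][dtype] = std::move(k);
}

void Run(const std::string& op, const std::string& dtype) {
  auto it = g_registry[op].find(dtype);
  if (it == g_registry[op].end()) {
    throw std::runtime_error(op + ": no CUDA kernel registered for " + dtype);
  }
  it->second();
}
// In this analogy, the hunk above is what turns Run("prelu", "float16") and
// Run("prelu_grad", "float16") from an error into a real kernel launch.
```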
@@ -153,11 +153,12 @@ class TestNNPReluAPI(unittest.TestCase):
 class PReluTest(OpTest):
     def setUp(self):
+        self.init_dtype()
         self.init_input_shape()
         self.init_attr()
         self.op_type = "prelu"

-        x_np = np.random.uniform(-1, 1, self.x_shape)
+        x_np = np.random.uniform(-1, 1, self.x_shape).astype(self.dtype)
         # Since zero point in prelu is not differentiable, avoid randomize
         # zero.
         x_np[np.abs(x_np) < 0.005] = 0.02
@@ -168,6 +169,7 @@ class PReluTest(OpTest):
             alpha_np = np.random.uniform(-1, -0.5, [1, self.x_shape[1], 1, 1])
         else:
             alpha_np = np.random.uniform(-1, -0.5, [1] + self.x_shape[1:])
+        alpha_np = alpha_np.astype(self.dtype)

         self.inputs = {'X': x_np, 'Alpha': alpha_np}
@@ -184,6 +186,9 @@ class PReluTest(OpTest):
         assert out_np is not self.inputs['X']
         self.outputs = {'Out': out_np}

+    def init_dtype(self):
+        self.dtype = np.float64
+
     def init_input_shape(self):
         self.x_shape = [2, 100, 3, 4]
@@ -270,6 +275,44 @@ class TestModeElementRank6(PReluTest):
         self.attrs = {'mode': "element"}


+def create_test_fp16_class(parent,
+                           check_grad=True,
+                           atol=1e-3,
+                           max_relative_error=0.05):
+    @unittest.skipIf(not core.is_compiled_with_cuda(),
+                     "core is not compiled with CUDA")
+    class TestPReluFp16Case(parent):
+        def init_dtype(self):
+            self.dtype = np.float16
+
+        def test_check_output(self):
+            if core.is_compiled_with_cuda():
+                place = core.CUDAPlace(0)
+                if core.is_float16_supported(place):
+                    self.check_output_with_place(place, atol=atol)
+
+        def test_check_grad(self):
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place) and check_grad:
+                self.check_grad_with_place(
+                    place, ['X', 'Alpha'],
+                    'Out',
+                    max_relative_error=max_relative_error)
+
+    cls_name = "{0}_{1}".format(parent.__name__, "Fp16Op")
+    TestPReluFp16Case.__name__ = cls_name
+    globals()[cls_name] = TestPReluFp16Case
+
+
+create_test_fp16_class(TestModeElt)
+create_test_fp16_class(TestModeAllRank3)
+create_test_fp16_class(TestModeAllRank6)
+create_test_fp16_class(TestModeChannelRank3)
+create_test_fp16_class(TestModeChannelRank6)
+create_test_fp16_class(TestModeElementRank3)
+create_test_fp16_class(TestModeElementRank6)
+
+
 def prelu_t(x, mode, param_attr=None, name=None):
     helper = fluid.layer_helper.LayerHelper('prelu', **locals())
     alpha_shape = [1, x.shape[1], 1, 1]