Unverified commit c6272b6a, authored by zhangkaihuo, committed by GitHub

Some Ops support fp16 (#44295)

* Sparse ops support AMP

* EagerAmpAutoCasts supports sparse tensors
Parent dc5a0420
@@ -39,6 +39,27 @@ static inline bool NeedCast(const paddle::experimental::Tensor& tensor,
   return false;
 }
 
+inline paddle::experimental::Tensor Cast(
+    const paddle::experimental::Tensor& input,
+    const paddle::experimental::DataType& dst_dtype,
+    const bool trace_backward = true) {
+  if (input.is_sparse_coo_tensor() || input.is_sparse_csr_tensor()) {
+    if (trace_backward) {
+      return sparse::cast_final_state_dygraph_function(
+          input, paddle::experimental::DataType::UNDEFINED, dst_dtype);
+    } else {
+      return paddle::experimental::sparse::cast(
+          input, paddle::experimental::DataType::UNDEFINED, dst_dtype);
+    }
+  } else {
+    if (trace_backward) {
+      return cast_final_state_dygraph_function(input, dst_dtype);
+    } else {
+      return paddle::experimental::cast(input, dst_dtype);
+    }
+  }
+}
+
 inline std::vector<paddle::experimental::Tensor> EagerAmpAutoCasts(
     const std::string& inputs_name,
     const std::vector<paddle::experimental::Tensor>& inputs,
@@ -51,13 +72,7 @@ inline std::vector<paddle::experimental::Tensor> EagerAmpAutoCasts(
   std::vector<paddle::experimental::Tensor> inputs_casted;
   for (auto& input : inputs) {
     if (NeedCast(input, dst_dtype)) {
-      if (trace_backward) {
-        inputs_casted.emplace_back(
-            std::move(cast_final_state_dygraph_function(input, dst_dtype)));
-      } else {
-        inputs_casted.emplace_back(
-            std::move(paddle::experimental::cast(input, dst_dtype)));
-      }
+      inputs_casted.emplace_back(std::move(Cast(input, dst_dtype)));
     } else {
       inputs_casted.emplace_back(input);
     }
@@ -92,11 +107,7 @@ inline paddle::experimental::Tensor EagerAmpAutoCast(
     }
   }
   if (NeedCast(input, dst_dtype)) {
-    if (trace_backward) {
-      return cast_final_state_dygraph_function(input, dst_dtype);
-    } else {
-      return paddle::experimental::cast(input, dst_dtype);
-    }
+    return Cast(input, dst_dtype, trace_backward);
   }
   return input;
 }
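The refactor above replaces two open-coded trace_backward branches with the single Cast helper: sparse COO/CSR tensors are routed to the sparse cast API (which takes a separate index dtype; passing DataType::UNDEFINED there presumably leaves the integer indices untouched and casts only the values), dense tensors go to the regular cast, and trace_backward selects the autograd-tracing dygraph function over the plain C++ API. A minimal self-contained sketch of that dispatch shape, using mock types rather than Paddle's real headers (every name below is illustrative only):

#include <iostream>

enum class DataType { UNDEFINED, FLOAT16, FLOAT32 };

// Mock tensor: `sparse` stands in for is_sparse_coo_tensor() || is_sparse_csr_tensor().
struct Tensor {
  bool sparse = false;
  DataType dtype = DataType::FLOAT32;
};

// Mock stand-ins for the four cast entry points seen in the diff.
Tensor SparseCastTraced(const Tensor&, DataType /*index_dtype*/, DataType value_dtype) {
  return {true, value_dtype};
}
Tensor SparseCastPlain(const Tensor&, DataType /*index_dtype*/, DataType value_dtype) {
  return {true, value_dtype};
}
Tensor DenseCastTraced(const Tensor&, DataType dst) { return {false, dst}; }
Tensor DenseCastPlain(const Tensor&, DataType dst) { return {false, dst}; }

// Mirrors the new Cast helper: sparse tensors use the sparse cast with
// DataType::UNDEFINED as the index dtype; trace_backward picks the
// autograd-tracing overload.
Tensor Cast(const Tensor& input, DataType dst, bool trace_backward = true) {
  if (input.sparse) {
    return trace_backward ? SparseCastTraced(input, DataType::UNDEFINED, dst)
                          : SparseCastPlain(input, DataType::UNDEFINED, dst);
  }
  return trace_backward ? DenseCastTraced(input, dst) : DenseCastPlain(input, dst);
}

int main() {
  Tensor coo{true, DataType::FLOAT32};
  Tensor out = Cast(coo, DataType::FLOAT16, /*trace_backward=*/false);
  std::cout << (out.dtype == DataType::FLOAT16) << "\n";  // prints 1
}

Centralizing the branch also means EagerAmpAutoCasts and EagerAmpAutoCast cannot drift apart as more tensor kinds gain AMP support.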
......
@@ -503,5 +503,10 @@ void Pad3dGradKernel(const Context& dev_ctx,
 }  // namespace phi
 
-PD_REGISTER_KERNEL(
-    pad3d_grad, GPU, ALL_LAYOUT, phi::Pad3dGradKernel, float, double) {}
+PD_REGISTER_KERNEL(pad3d_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::Pad3dGradKernel,
+                   float,
+                   double,
+                   phi::dtype::float16) {}
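PD_REGISTER_KERNEL takes the kernel's supported dtypes as a trailing list, so appending phi::dtype::float16 is all the GPU pad3d_grad kernel needs to gain an fp16 instantiation. A toy sketch of that dtype-list idea (not phi's real macro; the names here are made up for illustration), using a C++17 fold expression to instantiate one templated kernel per listed type:

#include <cstdint>
#include <cstdio>

struct float16 { std::uint16_t bits; };  // mock half type standing in for phi::dtype::float16

// Templated kernel body; the registry instantiates one of these per listed dtype.
template <typename T>
void Pad3dGradKernel() { std::printf("instantiated for a %zu-byte type\n", sizeof(T)); }

// Stand-in for the registration machinery: expand the dtype pack, one instantiation each.
template <typename... Ts>
void RegisterKernel() { (Pad3dGradKernel<Ts>(), ...); }

int main() {
  RegisterKernel<float, double, float16>();  // mirrors the updated dtype list
}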
@@ -97,6 +97,7 @@ PD_REGISTER_KERNEL(empty_like_coo,
                    GPU,
                    ALL_LAYOUT,
                    phi::sparse::EmptyLikeCooKernel,
+                   phi::dtype::float16,
                    float,
                    double,
                    int8_t,
@@ -112,6 +113,7 @@ PD_REGISTER_KERNEL(empty_like_csr,
                    GPU,
                    ALL_LAYOUT,
                    phi::sparse::EmptyLikeCsrKernel,
+                   phi::dtype::float16,
                    float,
                    double,
                    int8_t,
......
@@ -23,6 +23,7 @@
                      GPU,                                          \
                      ALL_LAYOUT,                                   \
                      phi::sparse::prefix##CooGradKernel,           \
+                     phi::dtype::float16,                          \
                      float,                                        \
                      double) {                                     \
     kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); \
@@ -32,6 +33,7 @@
                      GPU,                                          \
                      ALL_LAYOUT,                                   \
                      phi::sparse::prefix##CsrGradKernel,           \
+                     phi::dtype::float16,                          \
                      float,                                        \
                      double) {                                     \
     kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR); \
@@ -56,6 +58,7 @@ PD_REGISTER_KERNEL(cast_coo_grad,
                    GPU,
                    ALL_LAYOUT,
                    phi::sparse::CastCooGradKernel,
+                   phi::dtype::float16,
                    float,
                    double,
                    int8_t,
@@ -69,6 +72,7 @@ PD_REGISTER_KERNEL(cast_csr_grad,
                    GPU,
                    ALL_LAYOUT,
                    phi::sparse::CastCsrGradKernel,
+                   phi::dtype::float16,
                    float,
                    double,
                    int8_t,
......
@@ -67,6 +67,7 @@ void DivCsrScalarKernel(const Context& dev_ctx,
                      GPU,                                          \
                      ALL_LAYOUT,                                   \
                      phi::sparse::prefix##CooKernel,               \
+                     phi::dtype::float16,                          \
                      float,                                        \
                      double) {                                     \
     kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); \
@@ -76,6 +77,7 @@ void DivCsrScalarKernel(const Context& dev_ctx,
                      GPU,                                          \
                      ALL_LAYOUT,                                   \
                      phi::sparse::prefix##CsrKernel,               \
+                     phi::dtype::float16,                          \
                      float,                                        \
                      double) {                                     \
     kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR); \
@@ -119,6 +121,7 @@ PD_REGISTER_KERNEL(cast_coo,
                    GPU,
                    ALL_LAYOUT,
                    phi::sparse::CastCooKernel,
+                   phi::dtype::float16,
                    float,
                    double,
                    int8_t,
@@ -132,6 +135,7 @@ PD_REGISTER_KERNEL(cast_csr,
                    GPU,
                    ALL_LAYOUT,
                    phi::sparse::CastCsrKernel,
+                   phi::dtype::float16,
                    float,
                    double,
                    int8_t,
......
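The sparse unary files register their forward and grad kernels through a shared macro that token-pastes the op name into prefix##CooKernel and prefix##CsrKernel, so the single phi::dtype::float16 line added inside each macro body extends every sparse unary op registered through it at once. A compilable mock of that token-pasting pattern (not phi's real macro; the op and function names are invented for illustration):

#include <cstdio>

template <typename T> void SinCooKernel() { std::puts("Sin registered for COO"); }
template <typename T> void SinCsrKernel() { std::puts("Sin registered for CSR"); }

// One macro covers both sparse layouts; adding a dtype inside the real
// macro's list therefore upgrades every op that expands it.
#define REGISTER_SPARSE_UNARY(prefix) \
  prefix##CooKernel<float>();         \
  prefix##CsrKernel<float>();

int main() {
  REGISTER_SPARSE_UNARY(Sin)  // expands to SinCooKernel<float>(); SinCsrKernel<float>();
}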