未验证 提交 c6272b6a 编写于 作者: Z zhangkaihuo 提交者: GitHub

Add fp16 support to some ops (#44295)

* Sparse ops support AMP

* EagerAmpAutoCasts supports sparse tensors
上级 dc5a0420
......@@ -39,6 +39,27 @@ static inline bool NeedCast(const paddle::experimental::Tensor& tensor,
return false;
}
// Casts `input` to `dst_dtype` and returns the resulting tensor.
//
// Dispatch is two-level:
//   * tensor kind  - sparse (COO or CSR) tensors go through the `sparse::`
//                    cast entry points, everything else through the regular
//                    ones;
//   * trace_backward - when true the `*_final_state_dygraph_function` entry
//                    point is used, otherwise the plain
//                    `paddle::experimental` API is called.
//
// The sparse entry points take an extra dtype argument ahead of `dst_dtype`;
// it is always passed as DataType::UNDEFINED here.
// NOTE(review): presumably that argument is the index dtype and UNDEFINED
// means "leave it unchanged" - confirm against the sparse cast API.
inline paddle::experimental::Tensor Cast(
    const paddle::experimental::Tensor& input,
    const paddle::experimental::DataType& dst_dtype,
    const bool trace_backward = true) {
  const bool is_sparse =
      input.is_sparse_coo_tensor() || input.is_sparse_csr_tensor();

  // Dense (non-sparse) tensors: regular cast entry points.
  if (!is_sparse) {
    if (!trace_backward) {
      return paddle::experimental::cast(input, dst_dtype);
    }
    return cast_final_state_dygraph_function(input, dst_dtype);
  }

  // Sparse COO / CSR tensors: sparse cast entry points.
  if (!trace_backward) {
    return paddle::experimental::sparse::cast(
        input, paddle::experimental::DataType::UNDEFINED, dst_dtype);
  }
  return sparse::cast_final_state_dygraph_function(
      input, paddle::experimental::DataType::UNDEFINED, dst_dtype);
}
inline std::vector<paddle::experimental::Tensor> EagerAmpAutoCasts(
const std::string& inputs_name,
const std::vector<paddle::experimental::Tensor>& inputs,
......@@ -51,13 +72,7 @@ inline std::vector<paddle::experimental::Tensor> EagerAmpAutoCasts(
std::vector<paddle::experimental::Tensor> inputs_casted;
for (auto& input : inputs) {
if (NeedCast(input, dst_dtype)) {
if (trace_backward) {
inputs_casted.emplace_back(
std::move(cast_final_state_dygraph_function(input, dst_dtype)));
} else {
inputs_casted.emplace_back(
std::move(paddle::experimental::cast(input, dst_dtype)));
}
inputs_casted.emplace_back(std::move(Cast(input, dst_dtype)));
} else {
inputs_casted.emplace_back(input);
}
......@@ -92,11 +107,7 @@ inline paddle::experimental::Tensor EagerAmpAutoCast(
}
}
if (NeedCast(input, dst_dtype)) {
if (trace_backward) {
return cast_final_state_dygraph_function(input, dst_dtype);
} else {
return paddle::experimental::cast(input, dst_dtype);
}
return Cast(input, dst_dtype, trace_backward);
}
return input;
}
......
......@@ -503,5 +503,10 @@ void Pad3dGradKernel(const Context& dev_ctx,
} // namespace phi
PD_REGISTER_KERNEL(
pad3d_grad, GPU, ALL_LAYOUT, phi::Pad3dGradKernel, float, double) {}
// Registers phi::Pad3dGradKernel as the GPU `pad3d_grad` kernel (all layouts)
// for the float, double and phi::dtype::float16 element types.
PD_REGISTER_KERNEL(pad3d_grad,
GPU,
ALL_LAYOUT,
phi::Pad3dGradKernel,
float,
double,
phi::dtype::float16) {}
......@@ -97,6 +97,7 @@ PD_REGISTER_KERNEL(empty_like_coo,
GPU,
ALL_LAYOUT,
phi::sparse::EmptyLikeCooKernel,
phi::dtype::float16,
float,
double,
int8_t,
......@@ -112,6 +113,7 @@ PD_REGISTER_KERNEL(empty_like_csr,
GPU,
ALL_LAYOUT,
phi::sparse::EmptyLikeCsrKernel,
phi::dtype::float16,
float,
double,
int8_t,
......
......@@ -23,6 +23,7 @@
GPU, \
ALL_LAYOUT, \
phi::sparse::prefix##CooGradKernel, \
phi::dtype::float16, \
float, \
double) { \
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); \
......@@ -32,6 +33,7 @@
GPU, \
ALL_LAYOUT, \
phi::sparse::prefix##CsrGradKernel, \
phi::dtype::float16, \
float, \
double) { \
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR); \
......@@ -56,6 +58,7 @@ PD_REGISTER_KERNEL(cast_coo_grad,
GPU,
ALL_LAYOUT,
phi::sparse::CastCooGradKernel,
phi::dtype::float16,
float,
double,
int8_t,
......@@ -69,6 +72,7 @@ PD_REGISTER_KERNEL(cast_csr_grad,
GPU,
ALL_LAYOUT,
phi::sparse::CastCsrGradKernel,
phi::dtype::float16,
float,
double,
int8_t,
......
......@@ -67,6 +67,7 @@ void DivCsrScalarKernel(const Context& dev_ctx,
GPU, \
ALL_LAYOUT, \
phi::sparse::prefix##CooKernel, \
phi::dtype::float16, \
float, \
double) { \
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); \
......@@ -76,6 +77,7 @@ void DivCsrScalarKernel(const Context& dev_ctx,
GPU, \
ALL_LAYOUT, \
phi::sparse::prefix##CsrKernel, \
phi::dtype::float16, \
float, \
double) { \
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR); \
......@@ -119,6 +121,7 @@ PD_REGISTER_KERNEL(cast_coo,
GPU,
ALL_LAYOUT,
phi::sparse::CastCooKernel,
phi::dtype::float16,
float,
double,
int8_t,
......@@ -132,6 +135,7 @@ PD_REGISTER_KERNEL(cast_csr,
GPU,
ALL_LAYOUT,
phi::sparse::CastCsrKernel,
phi::dtype::float16,
float,
double,
int8_t,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册