Unverified commit b79c6a9b, authored by zhangbo9674, committed by GitHub

add cast_grad phi kernel (#40798)

* add cast_grad phi kernel

* refine unittest

* refine unittest

* refine unittest

* refine include header path

* refine xpu cast unittest

* refine code
Parent commit: 56cd3407
paddle/phi/kernels/cast_grad_kernel.h (new file):

/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/phi/core/dense_tensor.h"

namespace phi {

template <typename T, typename Context>
void CastGradKernel(const Context& dev_ctx,
                    const DenseTensor& x,
                    const DenseTensor& out_grad,
                    DenseTensor* x_grad);

}  // namespace phi
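The contract of the new kernel follows from cast being elementwise with an identity Jacobian: the gradient of cast(x, out_dtype) is simply out_grad converted back to x's dtype. A minimal NumPy sketch of that semantics (cast_grad_np is a hypothetical helper, not a Paddle API):

import numpy as np

# Hypothetical sketch of what CastGradKernel computes: the upstream
# gradient is converted back to the input tensor's dtype.
def cast_grad_np(x: np.ndarray, out_grad: np.ndarray) -> np.ndarray:
    return out_grad.astype(x.dtype)

x = np.array([1.0, 2.0], dtype=np.float32)
out_grad = np.ones_like(x, dtype=np.float64)  # grad of cast(x, 'float64')
x_grad = cast_grad_np(x, out_grad)
assert x_grad.dtype == np.float32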
paddle/phi/kernels/cpu/cast_grad_kernel.cc (new file):

// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/cast_grad_kernel.h"

#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cpu/cast_impl.h"

namespace phi {

template <typename T, typename Context>
void CastGradKernel(const Context& dev_ctx,
                    const DenseTensor& x,
                    const DenseTensor& out_grad,
                    DenseTensor* x_grad) {
  PD_VISIT_ALL_TYPES(x.dtype(), "CastKernelImpl", ([&] {
                       CastKernelImpl<T, data_t>(dev_ctx, out_grad, x_grad);
                     }));
}

}  // namespace phi

PD_REGISTER_KERNEL(cast_grad,
                   CPU,
                   ALL_LAYOUT,
                   phi::CastGradKernel,
                   float,
                   double,
                   int,
                   int64_t,
                   int16_t,
                   bool,
                   uint8_t,
                   phi::dtype::float16,
                   phi::dtype::bfloat16,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>) {
  kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
}
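Two dtypes are in play here: T, the registered kernel dtype (out_grad's dtype), and data_t, which PD_VISIT_ALL_TYPES binds at runtime by visiting x.dtype(); the registration then marks the output dtype UNDEFINED so it is resolved per call (from x) rather than pinned to T. A rough Python analogue of this two-level dispatch, with illustrative names only:

import numpy as np

# Hypothetical stand-in for the PD_VISIT_ALL_TYPES dispatch: the outer
# dtype (T) is fixed at kernel registration, the inner dtype (data_t)
# is visited at runtime from x.dtype.
SUPPORTED = [np.float32, np.float64, np.int32, np.int64, np.bool_]

def cast_grad_dispatch(x, out_grad):
    for data_t in SUPPORTED:                 # "visit all types"
        if x.dtype == data_t:
            return out_grad.astype(data_t)   # CastKernelImpl<T, data_t>
    raise TypeError(f"unsupported dtype: {x.dtype}")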
paddle/phi/kernels/cpu/cast_impl.h (new file):

// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/api/ext/dispatch.h"
#include "paddle/phi/backends/cpu/cpu_context.h"

// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/transform.h"

namespace phi {

template <typename InT, typename OutT>
struct CastOpTransformFunctor {
  HOSTDEVICE OutT operator()(InT in) const { return static_cast<OutT>(in); }
};

template <typename InT, typename OutT>
void CastKernelImpl(const CPUContext& dev_ctx,
                    const DenseTensor& x,
                    DenseTensor* out) {
  auto* in_begin = x.data<InT>();
  auto numel = x.numel();
  auto* in_end = in_begin + numel;
  auto* out_begin = dev_ctx.Alloc<OutT>(out);
  paddle::platform::Transform<CPUContext> trans;
  trans(dev_ctx,
        in_begin,
        in_end,
        out_begin,
        CastOpTransformFunctor<InT, OutT>());
}

}  // namespace phi
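paddle::platform::Transform behaves like std::transform: it applies the unary functor to every element in [in_begin, in_end) and writes the results to out_begin. The same functor-over-range pattern in Python, with illustrative names:

# Illustrative Python analogue of CastOpTransformFunctor + Transform:
# a callable object applied over an input range, writing to an output.
class CastOpTransform:
    def __init__(self, out_type):
        self.out_type = out_type

    def __call__(self, value):
        return self.out_type(value)  # static_cast<OutT>(in)

src = [1.9, 2.5, 3.0]
dst = list(map(CastOpTransform(int), src))  # trans(ctx, begin, end, out, functor)
assert dst == [1, 2, 3]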
paddle/phi/kernels/cpu/cast_kernel.cc:

@@ -13,39 +13,12 @@
 // limitations under the License.

 #include "paddle/phi/kernels/cast_kernel.h"
+#include "paddle/phi/kernels/cpu/cast_impl.h"

-#include "paddle/phi/api/ext/dispatch.h"
-#include "paddle/phi/backends/cpu/cpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"

-// See Note [ Why still include the fluid headers? ]
-#include "paddle/fluid/platform/transform.h"
-
 namespace phi {

-template <typename InT, typename OutT>
-struct CastOpTransformFunctor {
-  HOSTDEVICE OutT operator()(InT in) const { return static_cast<OutT>(in); }
-};
-
-template <typename InT, typename OutT>
-void CastKernelImpl(const CPUContext& dev_ctx,
-                    const DenseTensor& x,
-                    DenseTensor* out) {
-  auto* in_begin = x.data<InT>();
-  auto numel = x.numel();
-  auto* in_end = in_begin + numel;
-  auto* out_begin = dev_ctx.Alloc<OutT>(out);
-  paddle::platform::Transform<CPUContext> trans;
-  trans(dev_ctx,
-        in_begin,
-        in_end,
-        out_begin,
-        CastOpTransformFunctor<InT, OutT>());
-}
-
 template <typename T, typename Context>
 void CastKernel(const Context& dev_ctx,
                 const DenseTensor& x,
......
paddle/phi/kernels/gpu/cast_grad_kernel.cu (new file):

// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cast_grad_kernel.h"
#include "paddle/phi/kernels/gpu/cast_impl.h"

namespace phi {

template <typename T, typename Context>
void CastGradKernel(const Context& dev_ctx,
                    const DenseTensor& x,
                    const DenseTensor& out_grad,
                    DenseTensor* x_grad) {
  PD_VISIT_ALL_TYPES(x.dtype(), "CastCUDAKernelImpl", ([&] {
                       CastCUDAKernelImpl<T, data_t>(dev_ctx, out_grad, x_grad);
                     }));
}

}  // namespace phi

#define PTEN_REGISTER_CAST_CUDA_BASE_TYPE(op_name, ...)  \
  PD_REGISTER_KERNEL(cast_grad,                           \
                     GPU,                                 \
                     ALL_LAYOUT,                          \
                     phi::CastGradKernel,                 \
                     float,                               \
                     double,                              \
                     int,                                 \
                     int64_t,                             \
                     int16_t,                             \
                     bool,                                \
                     uint8_t,                             \
                     phi::dtype::float16,                 \
                     phi::dtype::complex<float>,          \
                     phi::dtype::complex<double>,         \
                     ##__VA_ARGS__) {                     \
    kernel->OutputAt(0).SetDataType(                      \
        paddle::experimental::DataType::UNDEFINED);       \
  }

PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast_grad, phi::dtype::bfloat16)
paddle/phi/kernels/gpu/cast_impl.h (new file):

// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/api/ext/dispatch.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"

namespace phi {

template <typename InT, typename OutT>
struct CastFunctor {
  __device__ __forceinline__ OutT operator()(const InT x) const {
    return static_cast<OutT>(x);
  }
};

template <typename InT, typename OutT>
void CastCUDAKernelImpl(const GPUContext& dev_ctx,
                        const DenseTensor& x,
                        DenseTensor* out) {
  std::vector<const DenseTensor*> inputs;
  std::vector<DenseTensor*> outputs;
  inputs.emplace_back(&x);
  outputs.emplace_back(out);
  dev_ctx.Alloc<OutT>(out);
  phi::funcs::ElementwiseKernel<OutT>(
      dev_ctx, inputs, &outputs, CastFunctor<InT, OutT>());
}

}  // namespace phi
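Unlike the CPU path's iterator-based Transform, the GPU path hands whole tensors to phi::funcs::ElementwiseKernel, which fuses the functor into a single vectorized kernel launch. A loose Python analogue of that interface (illustrative names, not the Paddle API):

import numpy as np

# Rough sketch of the ElementwiseKernel calling convention: a generic
# driver takes lists of inputs and outputs plus a functor and applies
# the functor across all elements; on GPU this is one fused launch.
def elementwise_kernel(inputs, outputs, functor):
    outputs[0][...] = functor(inputs[0])

x = np.array([1.5, 2.5], dtype=np.float64)
out = np.empty_like(x, dtype=np.float32)  # dev_ctx.Alloc<OutT>(out)
elementwise_kernel([x], [out], lambda v: v.astype(np.float32))
assert out.dtype == np.float32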
paddle/phi/kernels/gpu/cast_kernel.cu:

@@ -12,42 +12,12 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

 #include "paddle/phi/kernels/cast_kernel.h"

-#include "paddle/phi/api/ext/dispatch.h"
-#include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
-#include "paddle/phi/kernels/funcs/elementwise_base.h"
+#include "paddle/phi/kernels/gpu/cast_impl.h"

-// See Note [ Why still include the fluid headers? ]
-#include "paddle/fluid/platform/device/gpu/gpu_helper.h"
-#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
-#include "paddle/phi/common/bfloat16.h"
-#include "paddle/phi/common/float16.h"
-#include "paddle/phi/kernels/funcs/aligned_vector.h"
-
 namespace phi {

-template <typename InT, typename OutT>
-struct CastFunctor {
-  __device__ __forceinline__ OutT operator()(const InT x) const {
-    return static_cast<OutT>(x);
-  }
-};
-
-template <typename InT, typename OutT>
-void CastCUDAKernelImpl(const GPUContext& dev_ctx,
-                        const DenseTensor& x,
-                        DenseTensor* out) {
-  std::vector<const DenseTensor*> inputs;
-  std::vector<DenseTensor*> outputs;
-  inputs.emplace_back(&x);
-  outputs.emplace_back(out);
-  dev_ctx.Alloc<OutT>(out);
-  phi::funcs::ElementwiseKernel<OutT>(
-      dev_ctx, inputs, &outputs, CastFunctor<InT, OutT>());
-}
-
 template <typename T, typename Context>
 void CastKernel(const Context& dev_ctx,
                 const DenseTensor& x,
......
python/paddle/fluid/tests/unittests/test_cast_op.py:

@@ -52,6 +52,7 @@ class TestCastOpFp16ToFp32(OpTest):
             'out_dtype': int(core.VarDesc.VarType.FP32)
         }
         self.op_type = 'cast'
+        self.__class__.no_need_check_grad = True

     def test_check_output(self):
         self.check_output(atol=1e-3)

@@ -67,6 +68,7 @@ class TestCastOpFp32ToFp16(OpTest):
             'out_dtype': int(core.VarDesc.VarType.FP16)
         }
         self.op_type = 'cast'
+        self.__class__.no_need_check_grad = True

     def test_check_output(self):
         self.check_output(atol=1e-3)

@@ -82,6 +84,7 @@ class TestCastOpBf16ToFp32(OpTest):
             'out_dtype': int(core.VarDesc.VarType.FP32)
         }
         self.op_type = 'cast'
+        self.__class__.no_need_check_grad = True

     def test_check_output(self):
         self.check_output()

@@ -97,6 +100,7 @@ class TestCastOpFp32ToBf16(OpTest):
             'out_dtype': int(core.VarDesc.VarType.BF16)
         }
         self.op_type = 'cast'
+        self.__class__.no_need_check_grad = True

     def test_check_output(self):
         self.check_output()
......
python/paddle/fluid/tests/unittests/white_list/op_accuracy_white_list.py:

@@ -76,7 +76,8 @@ NO_FP64_CHECK_GRAD_OP_LIST = [
     'trilinear_interp_v2', \
     'var_conv_2d', \
     'warpctc', \
-    'bilateral_slice'
+    'bilateral_slice', \
+    'cast'
 ]

 NO_FP16_CHECK_GRAD_OP_LIST = [
......
python/paddle/fluid/tests/unittests/xpu/test_cast_op_xpu.py:

@@ -44,6 +44,7 @@ def create_test_class(in_typename, out_typename):
                 'out_dtype': typeid_dict[out_typename],
             }
             self.op_type = 'cast'
+            self.__class__.no_need_check_grad = True

         def test_check_output(self):
             if paddle.is_compiled_with_xpu():
......
python/paddle/utils/code_gen/api.yaml:

@@ -16,6 +16,7 @@
     func : cast
     param : [x, out_dtype]
     data_type : x
+  backward : cast_grad

 - api : concat
......
python/paddle/utils/code_gen/backward.yaml:

@@ -307,6 +307,16 @@
   kernel :
     func : mv_grad

+- backward_api : cast_grad
+  forward : cast (Tensor x, DataType out_dtype) -> Tensor(out)
+  args : (Tensor x, Tensor out_grad)
+  output : Tensor(x_grad)
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [x]
+  kernel :
+    func : cast_grad
+    data_type : out_grad
+
 # =================================== sep0
......
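In this entry, UnchangedInferMeta with param : [x] gives x_grad the same meta (shape, layout) as x, and data_type : out_grad selects the kernel by out_grad's dtype; together with backward : cast_grad in api.yaml, the generated C++ API routes cast's backward to the new kernel. The end-to-end effect can be sanity-checked from the public dygraph API; a hedged sketch assuming Paddle 2.x eager mode:

import paddle

x = paddle.to_tensor([1.0, 2.0, 3.0], dtype='float32', stop_gradient=False)
y = paddle.cast(x, 'float64')   # forward: float32 -> float64
y.sum().backward()              # backward routes through cast_grad

# x_grad is out_grad cast back to x's dtype, so it is float32 again.
assert x.grad.dtype == paddle.float32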