Unverified commit 276017bb, authored by zhangyikun02, committed by GitHub

conv2d support FP16 on xpu and update unittest for conv2d, test=kunlun (#40395)

Parent 1eb96eec
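The diff below makes both XPU conv kernels type-generic: each kernel gains the alias XPUT = typename XPUTypeTrait<T>::Type, hands XPUT-typed buffers to xpu::conv2d / xpu::conv2d_grad, registers float16 instantiations next to the existing float ones, and declares FP16 for the four conv ops in the KL2 op list. As a mental model of the trait, here is a minimal self-contained sketch (stand-in types and names, not Paddle's actual definition, which ships with its XPU headers):

    #include <cstdint>

    // Stand-ins so the sketch compiles on its own; real code gets these
    // types from Paddle and the XPU runtime headers.
    namespace paddle { namespace platform { struct float16 { std::uint16_t x; }; } }
    struct xpu_fp16 { std::uint16_t x; };  // stand-in for the XPU runtime fp16

    // Identity mapping for most element types: for T = float, XPUT is float
    // and the reinterpret_casts in the kernels below are no-ops.
    template <typename T>
    struct XPUTypeTrait {
      using Type = T;
    };

    // Specialized so framework fp16 maps to the device-side fp16 type.
    template <>
    struct XPUTypeTrait<paddle::platform::float16> {
      using Type = xpu_fp16;
    };

    // Both representations are 16-bit PODs, which is what makes casting a
    // float16 tensor buffer to the device type legal.
    static_assert(sizeof(XPUTypeTrait<paddle::platform::float16>::Type) ==
                      sizeof(paddle::platform::float16),
                  "fp16 representations must be layout-compatible");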
@@ -19,14 +19,16 @@ namespace operators {
 
 template <typename DeviceContext, typename T>
 class GemmConvXPUKernel : public framework::OpKernel<T> {
+  using XPUT = typename XPUTypeTrait<T>::Type;
+
  public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    const Tensor* input = context.Input<Tensor>("Input");
+  void Compute(const framework::ExecutionContext &context) const override {
+    const Tensor *input = context.Input<Tensor>("Input");
     // The filter will be reshaped in the calculations,
     // so here use an assignment operation,
     // that avoids modifying the variable in the Scope.
     Tensor filter = *context.Input<Tensor>("Filter");
-    Tensor* output = context.Output<Tensor>("Output");
+    Tensor *output = context.Output<Tensor>("Output");
     output->mutable_data<T>(context.GetPlace());
     int groups = context.Attr<int>("groups");
     std::vector<int> strides = context.Attr<std::vector<int>>("strides");
@@ -53,11 +55,16 @@ class GemmConvXPUKernel : public framework::OpKernel<T> {
     const int img_h = static_cast<int>(input->dims()[2]);
     const int img_w = static_cast<int>(input->dims()[3]);
     const int f = static_cast<int>(filter.dims()[0]);
-    auto& dev_ctx = context.template device_context<DeviceContext>();
-    int r = xpu::conv2d<float, float, float, int16_t>(
-        dev_ctx.x_context(), input->data<float>(), filter.data<float>(),
-        output->data<float>(), batch_size, img_c, img_h, img_w, f, ksize,
-        strides, paddings, dilations, groups, nullptr, nullptr, nullptr, true);
+
+    const XPUT *input_data = reinterpret_cast<const XPUT *>(input->data<T>());
+    const XPUT *filter_data = reinterpret_cast<const XPUT *>(filter.data<T>());
+    XPUT *output_data = reinterpret_cast<XPUT *>(output->data<T>());
+
+    auto &dev_ctx = context.template device_context<DeviceContext>();
+    int r = xpu::conv2d<XPUT, XPUT, XPUT, int16_t>(
+        dev_ctx.x_context(), input_data, filter_data, output_data, batch_size,
+        img_c, img_h, img_w, f, ksize, strides, paddings, dilations, groups,
+        nullptr, nullptr, nullptr, true);
     PADDLE_ENFORCE_EQ(
         r, XPU_SUCCESS,
         platform::errors::External("XPU conv kernel return wrong value[%d %s]",
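The call site now names XPUT instead of hard-coding float, so one kernel body serves both dtypes, while the fourth template argument, the on-device compute type, stays int16_t as before. A toy stand-alone illustration of that dispatch shape (conv2d_stub is a stub standing in for the XDNN call, not the real API):

    #include <cstdint>
    #include <cstdio>

    struct xpu_fp16 { std::uint16_t x; };  // stand-in device fp16

    // One templated body; only the element types vary between the two paths.
    template <typename TX, typename TW, typename TY, typename TGEMM>
    int conv2d_stub(const TX *x, const TW *w, TY *y) {
      // A real implementation would launch the XPU kernel; the stub just
      // reports the element width so the instantiation is visible.
      std::printf("conv2d with %zu-byte elements\n", sizeof(TX));
      return 0;  // mirrors XPU_SUCCESS
    }

    int main() {
      float xf[1] = {0}, wf[1] = {0}, yf[1];
      xpu_fp16 xh[1] = {}, wh[1] = {}, yh[1];
      conv2d_stub<float, float, float, std::int16_t>(xf, wf, yf);           // FP32 path
      conv2d_stub<xpu_fp16, xpu_fp16, xpu_fp16, std::int16_t>(xh, wh, yh);  // FP16 path
      return 0;
    }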
@@ -67,14 +74,16 @@ class GemmConvXPUKernel : public framework::OpKernel<T> {
 
 template <typename DeviceContext, typename T>
 class GemmConvGradXPUKernel : public framework::OpKernel<T> {
+  using XPUT = typename XPUTypeTrait<T>::Type;
+
  public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    const Tensor* input = context.Input<Tensor>("Input");
-    const Tensor* output_grad =
+  void Compute(const framework::ExecutionContext &context) const override {
+    const Tensor *input = context.Input<Tensor>("Input");
+    const Tensor *output_grad =
         context.Input<Tensor>(framework::GradVarName("Output"));
-    Tensor* input_grad =
+    Tensor *input_grad =
         context.Output<Tensor>(framework::GradVarName("Input"));
-    Tensor* filter_grad =
+    Tensor *filter_grad =
         context.Output<Tensor>(framework::GradVarName("Filter"));
     // The filter and filter_grad will be reshaped in the calculations,
     // so here use an assignment operation,
@@ -107,19 +116,27 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> {
     const int img_h = static_cast<int>(input->dims()[2]);
     const int img_w = static_cast<int>(input->dims()[3]);
     const int f = static_cast<int>(filter.dims()[0]);
+
+    const XPUT *input_data = reinterpret_cast<const XPUT *>(input->data<T>());
+    const XPUT *filter_data = reinterpret_cast<const XPUT *>(filter.data<T>());
+    const XPUT *output_grad_data =
+        reinterpret_cast<const XPUT *>(output_grad->data<T>());
+    XPUT *input_grad_data = nullptr;
     if (input_grad) {
       input_grad->mutable_data<T>(context.GetPlace());
+      input_grad_data = reinterpret_cast<XPUT *>(input_grad->data<T>());
     }
+    XPUT *filter_grad_data = nullptr;
     if (filter_grad) {
       filter_grad->mutable_data<T>(context.GetPlace());
+      filter_grad_data = reinterpret_cast<XPUT *>(filter_grad->data<T>());
     }
-    auto& dev_ctx = context.template device_context<DeviceContext>();
-    int r = xpu::conv2d_grad<float, float, float, int16_t>(
-        dev_ctx.x_context(), input->data<T>(), filter.data<T>(),
-        output_grad->data<T>(), input_grad ? input_grad->data<T>() : nullptr,
-        filter_grad ? filter_grad->data<T>() : nullptr, batch_size, img_c,
-        img_h, img_w, f, ksize, strides, paddings, dilations, groups, nullptr,
-        nullptr, nullptr, nullptr, nullptr, true);
+    auto &dev_ctx = context.template device_context<DeviceContext>();
+    int r = xpu::conv2d_grad<XPUT, XPUT, XPUT, int16_t>(
+        dev_ctx.x_context(), input_data, filter_data, output_grad_data,
+        input_grad_data, filter_grad_data, batch_size, img_c, img_h, img_w, f,
+        ksize, strides, paddings, dilations, groups, nullptr, nullptr, nullptr,
+        nullptr, nullptr, true);
     PADDLE_ENFORCE_EQ(
         r, XPU_SUCCESS,
         platform::errors::External("XPU conv kernel return wrong value[%d %s]",
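In the gradient kernel the two optional outputs are hoisted into named pointers that stay nullptr when the corresponding gradient is not requested, and xpu::conv2d_grad skips any gradient whose destination pointer is null. The same idiom condensed into a hypothetical helper (illustration only, not Paddle code):

    // Hypothetical helper condensing the optional-output idiom above:
    // allocate the gradient tensor and expose a device pointer only when the
    // caller asked for it; otherwise return nullptr so the XPU call skips it.
    template <typename T, typename XPUT>
    XPUT *MaybeAllocGrad(framework::Tensor *grad, const platform::Place &place) {
      if (grad == nullptr) return nullptr;  // gradient not requested
      grad->mutable_data<T>(place);         // allocate on the current XPU
      return reinterpret_cast<XPUT *>(grad->data<T>());
    }

With it, the two if-blocks above would reduce to calls such as input_grad_data = MaybeAllocGrad<T, XPUT>(input_grad, context.GetPlace()).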
@@ -130,14 +147,22 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> {
 }  // namespace paddle
 namespace ops = paddle::operators;
 REGISTER_OP_XPU_KERNEL(
-    depthwise_conv2d,
-    ops::GemmConvXPUKernel<paddle::platform::XPUDeviceContext, float>);
-REGISTER_OP_XPU_KERNEL(
-    conv2d, ops::GemmConvXPUKernel<paddle::platform::XPUDeviceContext, float>);
+    conv2d, ops::GemmConvXPUKernel<paddle::platform::XPUDeviceContext, float>,
+    ops::GemmConvXPUKernel<paddle::platform::XPUDeviceContext,
+                           paddle::platform::float16>);
 REGISTER_OP_XPU_KERNEL(
     conv2d_grad,
-    ops::GemmConvGradXPUKernel<paddle::platform::XPUDeviceContext, float>);
+    ops::GemmConvGradXPUKernel<paddle::platform::XPUDeviceContext, float>,
+    ops::GemmConvGradXPUKernel<paddle::platform::XPUDeviceContext,
+                               paddle::platform::float16>);
+REGISTER_OP_XPU_KERNEL(
+    depthwise_conv2d,
+    ops::GemmConvXPUKernel<paddle::platform::XPUDeviceContext, float>,
+    ops::GemmConvXPUKernel<paddle::platform::XPUDeviceContext,
+                           paddle::platform::float16>);
 REGISTER_OP_XPU_KERNEL(
     depthwise_conv2d_grad,
-    ops::GemmConvGradXPUKernel<paddle::platform::XPUDeviceContext, float>);
+    ops::GemmConvGradXPUKernel<paddle::platform::XPUDeviceContext, float>,
+    ops::GemmConvGradXPUKernel<paddle::platform::XPUDeviceContext,
+                               paddle::platform::float16>);
 #endif
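REGISTER_OP_XPU_KERNEL is variadic: the op type name comes first and every further argument is one concrete kernel instantiation, so enabling a new dtype only appends an instantiation to an existing call. Schematically, with a placeholder op and kernel name mirroring the real calls above:

    // Shape of the registration macro; "my_op" and MyOpXPUKernel are
    // placeholders, not real Paddle identifiers.
    REGISTER_OP_XPU_KERNEL(
        my_op, ops::MyOpXPUKernel<paddle::platform::XPUDeviceContext, float>,
        ops::MyOpXPUKernel<paddle::platform::XPUDeviceContext,
                           paddle::platform::float16>);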
@@ -51,16 +51,20 @@ XPUOpMap& get_kl2_ops() {
     {"clip", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"concat_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"concat", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
-    {"conv2d_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
-    {"conv2d", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+    {"conv2d_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
+                                  pOpKernelType(vartype::FP16, XPUPlace())})},
+    {"conv2d", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
+                             pOpKernelType(vartype::FP16, XPUPlace())})},
     {"conv2d_transpose_grad",
      XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"conv2d_transpose",
      XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"depthwise_conv2d_grad",
-     XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+     XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
+                   pOpKernelType(vartype::FP16, XPUPlace())})},
     {"depthwise_conv2d",
-     XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+     XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
+                   pOpKernelType(vartype::FP16, XPUPlace())})},
     {"dropout_grad",
      XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"dropout", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
...
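Kernel registration alone does not enable a dtype on KL2: the (op, dtype) pair must also appear in this get_kl2_ops() map, which the framework consults when deciding whether an op can be placed on the XPU. The entry pattern, with a placeholder op name mirroring the conv2d entries above:

    // Placeholder entry showing the pattern; "my_op" is illustrative only.
    {"my_op", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
                            pOpKernelType(vartype::FP16, XPUPlace())})},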