未验证 提交 02e9453f 编写于 作者: Q QingshuChen 提交者: GitHub

add xpu resnet_unit (#44297)

* add xpu resnet_unit
*test=kunlun

* tmp
*test=kunlun
上级 74412dfe
...@@ -10,7 +10,7 @@ set(XPU_RT_LIB_NAME "libxpurt.so") ...@@ -10,7 +10,7 @@ set(XPU_RT_LIB_NAME "libxpurt.so")
if(NOT DEFINED XPU_BASE_URL) if(NOT DEFINED XPU_BASE_URL)
set(XPU_BASE_URL_WITHOUT_DATE set(XPU_BASE_URL_WITHOUT_DATE
"https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev") "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev")
set(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220712") set(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220718")
else() else()
set(XPU_BASE_URL "${XPU_BASE_URL}") set(XPU_BASE_URL "${XPU_BASE_URL}")
endif() endif()
...@@ -19,7 +19,7 @@ endif() ...@@ -19,7 +19,7 @@ endif()
if(NOT DEFINED XPU_XDNN_BASE_URL) if(NOT DEFINED XPU_XDNN_BASE_URL)
set(XPU_XDNN_BASE_URL_WITHOUT_DATE set(XPU_XDNN_BASE_URL_WITHOUT_DATE
"https://klx-sdk-release-public.su.bcebos.com/xdnn/dev") "https://klx-sdk-release-public.su.bcebos.com/xdnn/dev")
set(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL_WITHOUT_DATE}/20220712") set(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL_WITHOUT_DATE}/20220718")
else() else()
set(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL}") set(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL}")
endif() endif()
......
...@@ -35,6 +35,7 @@ op_library(fusion_lstm_op) ...@@ -35,6 +35,7 @@ op_library(fusion_lstm_op)
if(WITH_XPU) if(WITH_XPU)
op_library(resnet_basic_block_op) op_library(resnet_basic_block_op)
op_library(resnet_unit_op)
endif() endif()
if(WITH_GPU OR WITH_ROCM) if(WITH_GPU OR WITH_ROCM)
......
...@@ -159,22 +159,28 @@ class ResNetUnitOp : public framework::OperatorWithKernel { ...@@ -159,22 +159,28 @@ class ResNetUnitOp : public framework::OperatorWithKernel {
bn_param_dims, bn_param_dims,
bn_param_dims.size())); bn_param_dims.size()));
auto data_format = ctx->Attrs().Get<std::string>("data_format"); auto data_format = ctx->Attrs().Get<std::string>("data_format");
PADDLE_ENFORCE_EQ( bool is_nchw = (data_format == "NCHW");
data_format,
"NHWC",
platform::errors::InvalidArgument("The data format must equal to NHWC. "
"But received: the data format "
"= [%s]",
data_format));
// Calculate the dims of outputs // Calculate the dims of outputs
int batch = x_dims[0]; int batch = x_dims[0];
int output_channel = w_dims[0]; int output_channel = w_dims[0];
int filter_size = w_dims[2]; int filter_size = w_dims[2];
int stride = ctx->Attrs().Get<int>("stride"); int stride = ctx->Attrs().Get<int>("stride");
int padding = ctx->Attrs().Get<int>("padding"); int padding = ctx->Attrs().Get<int>("padding");
int out_h = (x_dims[1] + padding * 2 - filter_size) / stride + 1; std::vector<int> out_shape;
int out_w = (x_dims[2] + padding * 2 - filter_size) / stride + 1; out_shape.push_back(batch);
std::vector<int> out_shape = {batch, out_h, out_w, output_channel}; if (is_nchw) {
int out_h = (x_dims[2] + padding * 2 - filter_size) / stride + 1;
int out_w = (x_dims[3] + padding * 2 - filter_size) / stride + 1;
out_shape.push_back(output_channel);
out_shape.push_back(out_h);
out_shape.push_back(out_w);
} else {
int out_h = (x_dims[1] + padding * 2 - filter_size) / stride + 1;
int out_w = (x_dims[2] + padding * 2 - filter_size) / stride + 1;
out_shape.push_back(out_h);
out_shape.push_back(out_w);
out_shape.push_back(output_channel);
}
auto y_dims = phi::make_ddim(out_shape); auto y_dims = phi::make_ddim(out_shape);
auto bitmask_dims = GetBitmaskDims(out_shape); auto bitmask_dims = GetBitmaskDims(out_shape);
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
class ResNetUnitXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto place = ctx.GetPlace();
PADDLE_ENFORCE_EQ(
platform::is_xpu_place(place),
true,
platform::errors::PreconditionNotMet("It must use XPUPlace."));
bool is_nchw = (ctx.Attr<std::string>("data_format") == "NCHW");
// input x
const Tensor *input_x = ctx.Input<Tensor>("X");
const Tensor *filter_x = ctx.Input<Tensor>("FilterX");
const Tensor *scale_x = ctx.Input<Tensor>("ScaleX");
const Tensor *bias_x = ctx.Input<Tensor>("BiasX");
// output x
Tensor *conv_out_x = ctx.Output<Tensor>("ConvX");
Tensor *saved_mean_x = ctx.Output<Tensor>("SavedMeanX");
Tensor *saved_invstd_x = ctx.Output<Tensor>("SavedInvstdX");
Tensor *running_mean_x = ctx.Output<Tensor>("RunningMeanX");
Tensor *running_var_x = ctx.Output<Tensor>("RunningVarX");
Tensor *output = ctx.Output<Tensor>("Y");
// attrs
int padding = ctx.Attr<int>("padding");
int stride = ctx.Attr<int>("stride");
int stride_z = ctx.Attr<int>("stride_z");
int dilation = ctx.Attr<int>("dilation");
int group = ctx.Attr<int>("group");
float eps = ctx.Attr<float>("epsilon");
float momentum = ctx.Attr<float>("momentum");
bool has_shortcut = ctx.Attr<bool>("has_shortcut");
bool fuse_add = ctx.Attr<bool>("fuse_add");
bool use_global_stats = ctx.Attr<bool>("use_global_stats");
bool is_test = ctx.Attr<bool>("is_test");
bool is_train = !is_test && !use_global_stats;
std::string act_type = ctx.Attr<std::string>("act_type");
auto &dev_ctx = ctx.template device_context<platform::XPUDeviceContext>();
std::vector<const T *> x_list = {input_x->data<T>()};
std::vector<const T *> w_list = {filter_x->data<T>()};
std::vector<T *> conv_y_list = {conv_out_x->mutable_data<T>(place)};
std::vector<std::vector<int>> x_shape_list = {
phi::vectorize<int>(input_x->dims())};
auto filter_x_shape = phi::vectorize<int>(filter_x->dims());
std::vector<int> ksize = {filter_x_shape[2], filter_x_shape[3]};
if (!is_nchw) {
ksize[0] = filter_x_shape[1];
ksize[1] = filter_x_shape[2];
}
std::vector<int> strides = {stride, stride};
std::vector<std::vector<int>> ksize_list = {ksize};
std::vector<std::vector<int>> stride_list = {strides};
std::vector<int> paddings = {padding, padding};
std::vector<int> dilations = {dilation, dilation};
std::vector<const float *> scale_list = {scale_x->data<float>()};
std::vector<const float *> bias_list = {bias_x->data<float>()};
std::vector<float *> batch_mean_list = {
saved_mean_x->mutable_data<float>(place)};
std::vector<float *> batch_invstd_list = {
saved_invstd_x->mutable_data<float>(place)};
std::vector<float *> global_mean_list = {
running_mean_x->mutable_data<float>(place)};
std::vector<float *> global_var_list = {
running_var_x->mutable_data<float>(place)};
std::vector<const float *> x_maxlist = {nullptr};
std::vector<const float *> w_maxlist = {nullptr};
if (has_shortcut) {
// input z
const Tensor *input_z = ctx.Input<Tensor>("Z");
const Tensor *filter_z = ctx.Input<Tensor>("FilterZ");
const Tensor *scale_z = ctx.Input<Tensor>("ScaleZ");
const Tensor *bias_z = ctx.Input<Tensor>("BiasZ");
Tensor *conv_out_z = ctx.Output<Tensor>("ConvZ");
Tensor *saved_mean_z = ctx.Output<Tensor>("SavedMeanZ");
Tensor *saved_invstd_z = ctx.Output<Tensor>("SavedInvstdZ");
Tensor *running_mean_z = ctx.Output<Tensor>("RunningMeanZ");
Tensor *running_var_z = ctx.Output<Tensor>("RunningVarZ");
x_list.push_back(input_z->data<T>());
w_list.push_back(filter_z->data<T>());
conv_y_list.push_back(conv_out_z->mutable_data<T>(place));
x_shape_list.push_back(phi::vectorize<int>(input_z->dims()));
auto filter_z_shape = phi::vectorize<int>(filter_z->dims());
std::vector<int> ksize_z = {filter_z_shape[2], filter_z_shape[3]};
if (!is_nchw) {
ksize_z[0] = filter_z_shape[1];
ksize_z[1] = filter_z_shape[2];
}
ksize_list.push_back(ksize_z);
stride_list.push_back({stride_z, stride_z});
scale_list.push_back(scale_z->data<float>());
bias_list.push_back(bias_z->data<float>());
batch_mean_list.push_back(saved_mean_z->mutable_data<float>(place));
batch_invstd_list.push_back(saved_invstd_z->mutable_data<float>(place));
global_mean_list.push_back(running_mean_z->mutable_data<float>(place));
global_var_list.push_back(running_var_z->mutable_data<float>(place));
x_maxlist.push_back(nullptr);
w_maxlist.push_back(nullptr);
} else {
if (fuse_add) {
const Tensor *input_z = ctx.Input<Tensor>("Z");
auto input_z_shape = phi::vectorize<int>(input_z->dims());
x_list.push_back(input_z->data<T>());
x_shape_list.push_back(input_z_shape);
x_maxlist.push_back(nullptr);
}
}
int r = xpu::resnet_unit_fusion<T, T, T, int16_t>(
dev_ctx.x_context(),
x_list,
w_list,
conv_y_list,
output->mutable_data<T>(place),
x_shape_list,
filter_x_shape[0],
ksize_list,
stride_list,
paddings,
dilations,
group,
eps,
momentum,
x_maxlist,
w_maxlist,
scale_list,
bias_list,
batch_mean_list,
batch_invstd_list,
global_mean_list,
global_var_list,
xpu::Activation_t::RELU,
is_nchw,
has_shortcut,
fuse_add,
is_train);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "resnet_unit_fusion");
}
};
template <typename T>
class ResNetUnitGradXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto place = ctx.GetPlace();
PADDLE_ENFORCE_EQ(
platform::is_xpu_place(place),
true,
platform::errors::PreconditionNotMet("It must use XPUPlace."));
bool is_nchw = (ctx.Attr<std::string>("data_format") == "NCHW");
const Tensor *y_grad = ctx.Input<Tensor>(framework::GradVarName("Y"));
const Tensor *x = ctx.Input<Tensor>("X");
const Tensor *filter_x = ctx.Input<Tensor>("FilterX");
const Tensor *scale_x = ctx.Input<Tensor>("ScaleX");
const Tensor *saved_mean_x = ctx.Input<Tensor>("SavedMeanX");
const Tensor *saved_invstd_x = ctx.Input<Tensor>("SavedInvstdX");
const Tensor *conv_out_x = ctx.Input<Tensor>("ConvX");
const Tensor *output = ctx.Input<Tensor>("Y");
Tensor *x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
Tensor *filter_x_grad =
ctx.Output<Tensor>(framework::GradVarName("FilterX"));
Tensor *scale_x_grad = ctx.Output<Tensor>(framework::GradVarName("ScaleX"));
Tensor *bias_x_grad = ctx.Output<Tensor>(framework::GradVarName("BiasX"));
int padding = ctx.Attr<int>("padding");
int stride = ctx.Attr<int>("stride");
int stride_z = ctx.Attr<int>("stride_z");
int dilation = ctx.Attr<int>("dilation");
int group = ctx.Attr<int>("group");
float eps = ctx.Attr<float>("epsilon");
bool has_shortcut = ctx.Attr<bool>("has_shortcut");
bool fuse_add = ctx.Attr<bool>("fuse_add");
std::string act_type = ctx.Attr<std::string>("act_type");
auto &dev_ctx = ctx.template device_context<platform::XPUDeviceContext>();
std::vector<const T *> x_list = {x->data<T>()};
std::vector<const T *> w_list = {filter_x->data<T>()};
std::vector<const T *> conv_y_list = {conv_out_x->data<T>()};
std::vector<T *> dx_list = {x_grad->mutable_data<T>(place)};
std::vector<T *> dw_list = {filter_x_grad->mutable_data<T>(place)};
std::vector<std::vector<int>> x_shape_list = {
phi::vectorize<int>(x->dims())};
auto filter_x_shape = phi::vectorize<int>(filter_x->dims());
std::vector<int> x_ksize = {filter_x_shape[2], filter_x_shape[3]};
if (!is_nchw) {
x_ksize[0] = filter_x_shape[1];
x_ksize[1] = filter_x_shape[2];
}
std::vector<std::vector<int>> ksize_list = {x_ksize};
std::vector<std::vector<int>> stride_list = {{stride, stride}};
std::vector<int> paddings = {padding, padding};
std::vector<int> dilations = {dilation, dilation};
std::vector<const float *> x_maxlist = {nullptr};
std::vector<const float *> w_maxlist = {nullptr};
std::vector<const float *> scale_list = {scale_x->data<float>()};
std::vector<const float *> batch_mean_list = {saved_mean_x->data<float>()};
std::vector<const float *> batch_invstd_list = {
saved_invstd_x->data<float>()};
std::vector<float *> dscale_list = {
scale_x_grad->mutable_data<float>(place)};
std::vector<float *> dbias_list = {bias_x_grad->mutable_data<float>(place)};
if (has_shortcut) {
// X Z
// | |
// NormConv NormConv
// | |
// BNStatsFinalize BNStatsFinalize
// \ /
// ScaleBiasAddRelu
// |
// Y
const Tensor *z = ctx.Input<Tensor>("Z");
const Tensor *filter_z = ctx.Input<Tensor>("FilterZ");
const Tensor *scale_z = ctx.Input<Tensor>("ScaleZ");
const Tensor *saved_mean_z = ctx.Input<Tensor>("SavedMeanZ");
const Tensor *saved_invstd_z = ctx.Input<Tensor>("SavedInvstdZ");
const Tensor *conv_out_z = ctx.Input<Tensor>("ConvZ");
Tensor *z_grad = ctx.Output<Tensor>(framework::GradVarName("Z"));
Tensor *filter_z_grad =
ctx.Output<Tensor>(framework::GradVarName("FilterZ"));
Tensor *scale_z_grad =
ctx.Output<Tensor>(framework::GradVarName("ScaleZ"));
Tensor *bias_z_grad = ctx.Output<Tensor>(framework::GradVarName("BiasZ"));
x_list.push_back(z->data<T>());
w_list.push_back(filter_z->data<T>());
conv_y_list.push_back(conv_out_z->data<T>());
dx_list.push_back(z_grad->mutable_data<T>(place));
dw_list.push_back(filter_z_grad->mutable_data<T>(place));
x_shape_list.push_back(phi::vectorize<int>(z->dims()));
auto filter_z_shape = phi::vectorize<int>(filter_z->dims());
std::vector<int> ksize_z = {filter_z_shape[2], filter_z_shape[3]};
if (!is_nchw) {
ksize_z[0] = filter_z_shape[1];
ksize_z[1] = filter_z_shape[2];
}
ksize_list.push_back(ksize_z);
stride_list.push_back({stride_z, stride_z});
x_maxlist.push_back(nullptr);
w_maxlist.push_back(nullptr);
scale_list.push_back(scale_z->data<float>());
batch_mean_list.push_back(saved_mean_z->data<float>());
batch_invstd_list.push_back(saved_invstd_z->data<float>());
dscale_list.push_back(scale_z_grad->mutable_data<float>(place));
dbias_list.push_back(bias_z_grad->mutable_data<float>(place));
} else {
if (fuse_add) {
auto z_grad = ctx.Output<Tensor>(framework::GradVarName("Z"));
dx_list.push_back(z_grad->mutable_data<T>(place));
}
}
int r =
xpu::resnet_unit_grad_fusion<T, T, T, int16_t>(dev_ctx.x_context(),
x_list,
w_list,
y_grad->data<T>(),
output->data<T>(),
conv_y_list,
dx_list,
dw_list,
x_shape_list,
filter_x_shape[0],
ksize_list,
stride_list,
paddings,
dilations,
group,
x_maxlist,
w_maxlist,
scale_list,
batch_mean_list,
batch_invstd_list,
dscale_list,
dbias_list,
xpu::Activation_t::RELU,
eps,
is_nchw,
has_shortcut,
fuse_add);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "resnet_unit_grad_fusion");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_XPU_KERNEL(resnet_unit, ops::ResNetUnitXPUKernel<float>);
REGISTER_OP_XPU_KERNEL(resnet_unit_grad, ops::ResNetUnitGradXPUKernel<float>);
...@@ -374,6 +374,9 @@ XPUOpMap& get_kl2_ops() { ...@@ -374,6 +374,9 @@ XPUOpMap& get_kl2_ops() {
pOpKernelType(vartype::INT32, XPUPlace()), pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace()), pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})}, pOpKernelType(vartype::FP32, XPUPlace())})},
{"resnet_unit", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"resnet_unit_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"rmsprop", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"rmsprop", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"rnn", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"rnn", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"rnn_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"rnn_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
......
...@@ -87,7 +87,9 @@ xpu_test_device_type_white_list = ['xpu1_float64'] ...@@ -87,7 +87,9 @@ xpu_test_device_type_white_list = ['xpu1_float64']
xpu_test_op_type_white_list = [ xpu_test_op_type_white_list = [
'dropout_float16', 'dropout_float16',
'dropout_grad_float16', 'dropout_grad_float16',
"grad_add_float32" # no api for grad_add, skip "grad_add_float32", # no api for grad_add, skip
"resnet_unit",
"resnet_unit_grad"
] ]
xpu_test_device_op_white_list = [] xpu_test_device_op_white_list = []
xpu_test_device_op_type_white_list = [] xpu_test_device_op_type_white_list = []
......
...@@ -170,7 +170,7 @@ class ResNetUnit(Layer): ...@@ -170,7 +170,7 @@ class ResNetUnit(Layer):
self._is_test = is_test self._is_test = is_test
# check format # check format
valid_format = {'NHWC'} valid_format = {'NHWC', 'NCHW'}
if data_format not in valid_format: if data_format not in valid_format:
raise ValueError( raise ValueError(
"conv_format must be one of {}, but got conv_format='{}'". "conv_format must be one of {}, but got conv_format='{}'".
...@@ -181,11 +181,25 @@ class ResNetUnit(Layer): ...@@ -181,11 +181,25 @@ class ResNetUnit(Layer):
std = (2.0 / filter_elem_num)**0.5 std = (2.0 / filter_elem_num)**0.5
return I.Normal(0.0, std) return I.Normal(0.0, std)
is_nchw = (data_format == 'NCHW')
# initial filter # initial filter
bn_param_dtype = fluid.core.VarDesc.VarType.FP32 bn_param_dtype = fluid.core.VarDesc.VarType.FP32
bn_param_shape = [1, 1, 1, num_filters] if not is_nchw:
filter_x_shape = [num_filters, filter_size, filter_size, num_channels_x] bn_param_shape = [1, 1, 1, num_filters]
filter_z_shape = [num_filters, filter_size, filter_size, num_channels_z] filter_x_shape = [
num_filters, filter_size, filter_size, num_channels_x
]
filter_z_shape = [
num_filters, filter_size, filter_size, num_channels_z
]
else:
bn_param_shape = [1, num_filters, 1, 1]
filter_x_shape = [
num_filters, num_channels_x, filter_size, filter_size
]
filter_z_shape = [
num_filters, num_channels_z, filter_size, filter_size
]
self.filter_x = self.create_parameter( self.filter_x = self.create_parameter(
shape=filter_x_shape, shape=filter_x_shape,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册