未验证 提交 3f2a665a 编写于 作者: G Guoxia Wang 提交者: GitHub

support data_format='NHWC' for prelu channel mode (#37019)

* support data_format='NHWC' for prelu channel mode
上级 0c82e3a0
......@@ -34,6 +34,11 @@ class PReluOpConverter : public OpConverter {
auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
// Get attrs
std::string mode = BOOST_GET_CONST(std::string, op_desc.GetAttr("mode"));
std::string data_format = "NCHW";
if (op_desc.HasAttr("data_format")) {
data_format =
BOOST_GET_CONST(std::string, op_desc.GetAttr("data_format"));
}
auto* alpha_var = scope.FindVar(op_desc.Input("Alpha")[0]);
auto* alpha_tensor = alpha_var->GetMutable<framework::LoDTensor>();
......@@ -47,7 +52,7 @@ class PReluOpConverter : public OpConverter {
nvinfer1::ILayer* layer = nullptr;
if (engine_->with_dynamic_shape()) {
plugin::PReluPluginDynamic* plugin = new plugin::PReluPluginDynamic(
alpha_data, alpha_tensor_temp->numel(), mode);
alpha_data, alpha_tensor_temp->numel(), mode, data_format);
layer = engine_->AddDynamicPlugin(&input, input_num, plugin);
} else {
#if IS_TRT_VERSION_GE(7000)
......@@ -74,8 +79,8 @@ class PReluOpConverter : public OpConverter {
layer = TRT_ENGINE_ADD_LAYER(engine_, ParametricReLU, *input,
*alpha_layer_output);
#else
plugin::PReluPlugin* plugin =
new plugin::PReluPlugin(alpha_data, alpha_tensor_temp->numel(), mode);
plugin::PReluPlugin* plugin = new plugin::PReluPlugin(
alpha_data, alpha_tensor_temp->numel(), mode, data_format);
layer = engine_->AddPlugin(&input, input_num, plugin);
#endif
}
......
......@@ -69,10 +69,11 @@ int PReluPlugin::enqueue(int batch_size, const void *const *inputs,
}
if (mode_ == "channel") {
bool channel_last = data_format_ == "NHWC";
operators::math::PreluChannelWiseDirectCUDAFunctor<float>
prelu_channel_wise;
prelu_channel_wise(stream, input, alpha, output, input_dims.d[0],
input_dims.d[1], numel);
input_dims.d[1], channel_last, numel);
} else if (mode_ == "element") {
operators::math::PreluElementWiseDirectCUDAFunctor<float>
prelu_element_wise;
......@@ -168,10 +169,11 @@ int PReluPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc,
}
if (mode_ == "channel") {
bool channel_last = data_format_ == "NHWC";
operators::math::PreluChannelWiseDirectCUDAFunctor<float>
prelu_channel_wise;
prelu_channel_wise(stream, input, alpha, output, input_dims.d[0],
input_dims.d[1], numel);
input_dims.d[1], channel_last, numel);
} else if (mode_ == "element") {
operators::math::PreluElementWiseDirectCUDAFunctor<float>
prelu_element_wise;
......
......@@ -32,11 +32,12 @@ class PReluPlugin : public PluginTensorRT {
std::vector<float> weight_;
float* p_gpu_weight_;
std::string mode_;
std::string data_format_;
public:
size_t getSerializationSize() const TRT_NOEXCEPT override {
return getBaseSerializationSize() + SerializedSize(mode_.c_str()) +
SerializedSize(weight_);
SerializedSize(data_format_.c_str()) + SerializedSize(weight_);
}
// TRT will call this func when we need to serialize the configuration of
......@@ -46,11 +47,12 @@ class PReluPlugin : public PluginTensorRT {
serializeBase(buffer);
SerializeValue(&buffer, weight_);
SerializeValue(&buffer, mode_.c_str());
SerializeValue(&buffer, data_format_.c_str());
}
PReluPlugin(const float* weight, const int weight_num,
std::string const& mode)
: mode_(mode) {
std::string const& mode, std::string const& data_format)
: mode_(mode), data_format_(data_format) {
weight_.resize(weight_num);
std::copy(weight, weight + weight_num, weight_.data());
}
......@@ -63,13 +65,17 @@ class PReluPlugin : public PluginTensorRT {
const char* prelu_mode;
DeserializeValue(&serialData, &serialLength, &prelu_mode);
mode_ = std::string(prelu_mode);
const char* prelu_data_format;
DeserializeValue(&serialData, &serialLength, &prelu_data_format);
data_format_ = std::string(prelu_data_format);
}
~PReluPlugin() {}
int initialize() TRT_NOEXCEPT override;
void terminate() TRT_NOEXCEPT override;
PReluPlugin* clone() const TRT_NOEXCEPT override {
auto* ptr = new PReluPlugin(weight_.data(), weight_.size(), mode_);
auto* ptr =
new PReluPlugin(weight_.data(), weight_.size(), mode_, data_format_);
ptr->p_gpu_weight_ = p_gpu_weight_;
return ptr;
}
......@@ -108,8 +114,8 @@ REGISTER_TRT_PLUGIN_V2(PReluPluginCreator);
class PReluPluginDynamic : public DynamicPluginTensorRT {
public:
PReluPluginDynamic(const float* weight, const int weight_num,
std::string const& mode)
: mode_(mode) {
std::string const& mode, std::string const& data_format)
: mode_(mode), data_format_(data_format) {
weight_.resize(weight_num);
std::copy(weight, weight + weight_num, weight_.data());
}
......@@ -117,7 +123,8 @@ class PReluPluginDynamic : public DynamicPluginTensorRT {
PReluPluginDynamic(void const* serialData, size_t serialLength);
~PReluPluginDynamic() {}
nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override {
auto ptr = new PReluPluginDynamic(weight_.data(), weight_.size(), mode_);
auto ptr = new PReluPluginDynamic(weight_.data(), weight_.size(), mode_,
data_format_);
ptr->p_gpu_weight_ = p_gpu_weight_;
return ptr;
}
......@@ -167,6 +174,7 @@ class PReluPluginDynamic : public DynamicPluginTensorRT {
std::vector<float> weight_;
float* p_gpu_weight_;
std::string mode_;
std::string data_format_;
};
#endif
......
......@@ -25,7 +25,7 @@ inline static int PADDLE_GET_BLOCKS(const int N) {
}
template <typename T>
__global__ void PReluChannelWiseKernel(const T *input, const T *alpha,
__global__ void PReluChannelFirstWiseKernel(const T *input, const T *alpha,
T *output, size_t channel_num,
size_t plane_size, size_t numel) {
CUDA_KERNEL_LOOP(index, numel) {
......@@ -38,6 +38,19 @@ __global__ void PReluChannelWiseKernel(const T *input, const T *alpha,
}
}
template <typename T>
__global__ void PReluChannelLastWiseKernel(const T *input, const T *alpha,
T *output, size_t channel_num,
size_t numel) {
CUDA_KERNEL_LOOP(index, numel) {
size_t channel_index = index % channel_num;
T scale = alpha[channel_index];
T x = input[index];
T zero = static_cast<T>(0);
output[index] = (x > zero) ? x : scale * x;
}
}
template <typename T>
__global__ void PReluElementWiseKernel(const T *input, const T *alpha,
T *output, size_t spatial_size,
......@@ -65,10 +78,16 @@ __global__ void PReluScalarKernel(const T *input, const T *alpha, T *output,
template <typename T>
void PreluChannelWiseDirectCUDAFunctor<T>::operator()(
gpuStream_t stream, const T *input, const T *alpha, T *output,
size_t batch_size, size_t channel, size_t numel) {
PReluChannelWiseKernel<<<PADDLE_GET_BLOCKS(numel), CUDA_NUM_THREADS, 0,
size_t batch_size, size_t channel, bool channel_last, size_t numel) {
if (channel_last) {
PReluChannelLastWiseKernel<<<PADDLE_GET_BLOCKS(numel), CUDA_NUM_THREADS, 0,
stream>>>(input, alpha, output, channel,
numel / batch_size / channel, numel);
numel);
} else {
PReluChannelFirstWiseKernel<<<PADDLE_GET_BLOCKS(numel), CUDA_NUM_THREADS, 0,
stream>>>(
input, alpha, output, channel, numel / batch_size / channel, numel);
}
}
template <typename T>
......
......@@ -31,7 +31,8 @@ template <typename T>
class PreluChannelWiseDirectCUDAFunctor {
public:
void operator()(gpuStream_t stream, const T *input, const T *alpha, T *output,
size_t batch_size, size_t channel, size_t numel);
size_t batch_size, size_t channel, bool channel_last,
size_t numel);
};
template <typename T>
......
......@@ -34,7 +34,7 @@ class PReluMKLDNNHandler
const dnnl::engine engine, platform::Place cpu_place,
const Tensor* x, const Tensor* weights,
const std::string& uniq_name, const std::string& mode,
bool is_test = false)
const std::string& data_format, bool is_test = false)
: platform::MKLDNNHandlerT<T, dnnl::prelu_forward, dnnl::prelu_backward>(
dev_ctx, engine, cpu_place,
platform::CreateKey(dev_ctx, framework::vectorize(x->dims()),
......@@ -49,9 +49,14 @@ class PReluMKLDNNHandler
if (weights->dims().size() != x->dims().size()) {
auto new_weights_dims = std::vector<int64_t>(x->dims().size(), 1);
if (mode == "channel") {
if (data_format == "NHWC") {
new_weights_dims[x->dims().size() - 1] =
*std::max_element(weights_dims.begin(), weights_dims.end());
} else {
new_weights_dims[1] =
*std::max_element(weights_dims.begin(), weights_dims.end());
}
}
weights_dims = std::move(new_weights_dims);
}
auto weights_md = memory::desc(weights_dims, MKLDNNGetDataType<T>(),
......@@ -110,9 +115,11 @@ class PReluMKLDNNKernel : public framework::OpKernel<T> {
auto* out = ctx.Output<Tensor>("Out");
const bool is_test = ctx.Attr<bool>("is_test");
const auto mode = ctx.Attr<std::string>("mode");
const auto data_format = ctx.Attr<std::string>("data_format");
PReluMKLDNNHandler<T> handler(dev_ctx, onednn_engine, ctx.GetPlace(), x,
alpha, ctx.InputName("X"), mode, is_test);
alpha, ctx.InputName("X"), mode, data_format,
is_test);
auto src_memory_p = handler.AcquireSrcMemory(x);
auto weights_memory_p =
......@@ -149,9 +156,11 @@ class PReluGradMKLDNNKernel : public framework::OpKernel<T> {
auto* alpha = ctx.Input<Tensor>("Alpha");
const bool is_test = ctx.Attr<bool>("is_test");
const auto mode = ctx.Attr<std::string>("mode");
const auto data_format = ctx.Attr<std::string>("data_format");
PReluMKLDNNHandler<T> handler(dev_ctx, onednn_engine, ctx.GetPlace(), x,
alpha, framework::GradVarName("X"), mode);
alpha, framework::GradVarName("X"), mode,
data_format);
auto src_memory_p = handler.AcquireSrcMemory(x);
auto weights_memory_p =
......
......@@ -38,12 +38,6 @@ class PReluOp : public framework::OperatorWithKernel {
"But recevied alpha's size: %d.",
product(ctx->GetInputDim("Alpha"))));
} else if (mode == "channel") {
PADDLE_ENFORCE_EQ(product(ctx->GetInputDim("Alpha")), x_dim[1],
platform::errors::InvalidArgument(
"For mode 'channel', size of weight Alpha must be "
"equal to the number of channels of input(x). But "
"recevied alpha's size: %d, x_dim[1]: %d",
product(ctx->GetInputDim("Alpha")), x_dim[1]));
auto x_rank = x_dim.size();
PADDLE_ENFORCE_GE(x_rank, 2,
platform::errors::InvalidArgument(
......@@ -51,6 +45,33 @@ class PReluOp : public framework::OperatorWithKernel {
"equal or larger than 2. But recevied X's "
"rank: %d",
x_rank));
const std::string data_format_str =
ctx->Attrs().Get<std::string>("data_format");
PADDLE_ENFORCE_EQ(data_format_str == "NCHW" || data_format_str == "NHWC",
true,
platform::errors::InvalidArgument(
"For mode 'channel', data_format must be one of "
"NCHW and NHWC. But recevied data_format: %s",
data_format_str));
if (data_format_str == "NCHW") {
PADDLE_ENFORCE_EQ(
product(ctx->GetInputDim("Alpha")) == x_dim[1], true,
platform::errors::InvalidArgument(
"For mode 'channel', size of weight Alpha must be "
"equal to the number of channels of input(x). But "
"recevied alpha's size: %d, x_dim[1]: %d",
product(ctx->GetInputDim("Alpha")), x_dim[1]));
} else {
PADDLE_ENFORCE_EQ(
product(ctx->GetInputDim("Alpha")) == x_dim[x_rank - 1], true,
platform::errors::InvalidArgument(
"For mode 'channel', size of weight Alpha must be "
"equal to the number of channels of input(x). But "
"recevied alpha's size: %d, x_dim[%d]: %d",
product(ctx->GetInputDim("Alpha")), x_rank - 1,
x_dim[x_rank - 1]));
}
} else if (mode == "element") {
auto alpha_dim = ctx->GetInputDim("Alpha");
auto alpha_rank = alpha_dim.size();
......@@ -134,6 +155,9 @@ There are modes:
)DOC");
AddAttr<std::string>("mode", "The mode for inputs to share weights.")
.SetDefault("all");
AddAttr<std::string>("data_format",
"Data format that specifies the layout of input")
.SetDefault("NCHW");
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false)
......
......@@ -42,17 +42,22 @@ class CUDAPReluKernel : public framework::OpKernel<T> {
const T* alpha_ptr = alpha->data<T>();
auto& mode = context.Attr<std::string>("mode");
auto& data_format = context.Attr<std::string>("data_format");
int numel = x->numel();
auto dim = x->dims();
auto x_rank = dim.size();
VLOG(4) << "dim[0]:" << dim[0] << ", dim[1]:" << dim[1]
<< ", numel:" << numel;
VLOG(4) << "dim[0]:" << dim[0] << ", dim[1]:" << dim[1] << ", dim["
<< x_rank - 1 << "]:" << dim[x_rank - 1] << ", numel:" << numel;
if (mode == "channel") {
bool channel_last = data_format == "NHWC";
size_t channel = channel_last ? dim[x_rank - 1] : dim[1];
math::PreluChannelWiseDirectCUDAFunctor<T> prelu_channel_wise;
prelu_channel_wise(context.cuda_device_context().stream(), x_ptr,
alpha_ptr, o_ptr, dim[0], dim[1], numel);
alpha_ptr, o_ptr, dim[0], channel, channel_last,
numel);
} else if (mode == "element") {
math::PreluElementWiseDirectCUDAFunctor<T> prelu_element_wise;
prelu_element_wise(context.cuda_device_context().stream(), x_ptr,
......@@ -65,7 +70,7 @@ class CUDAPReluKernel : public framework::OpKernel<T> {
}
};
enum PRELU_MODE { Element, Channel, Scalar };
enum PRELU_MODE { Element, ChannelFirst, ChannelLast, Scalar };
template <typename T>
__global__ void PReluOpGradKernel(const T* x_ptr, const T* alpha_ptr,
......@@ -78,10 +83,13 @@ __global__ void PReluOpGradKernel(const T* x_ptr, const T* alpha_ptr,
if (mode == Element) {
size_t element_index = index % spatial_size;
scale = alpha_ptr[element_index];
} else if (mode == Channel) {
} else if (mode == ChannelFirst) {
size_t temp = index / plane_size;
size_t channel_index = temp % channel_num;
scale = alpha_ptr[channel_index];
} else if (mode == ChannelLast) {
size_t channel_index = index % channel_num;
scale = alpha_ptr[channel_index];
} else {
scale = alpha_ptr[0];
}
......@@ -105,11 +113,13 @@ class PreluOpGradFunctor {
}
size_t plane_size = numel / input_dims[0] / input_dims[1];
size_t spatial_size = numel / input_dims[0];
size_t channel =
mode == ChannelLast ? input_dims[input_dims.size() - 1] : input_dims[1];
PReluOpGradKernel<
T><<<PADDLE_GET_BLOCKS(numel), CUDA_NUM_THREADS, 0, stream>>>(
x, alpha, dy, dx, dalpha, input_dims[1], plane_size, spatial_size,
numel, mode);
x, alpha, dy, dx, dalpha, channel, plane_size, spatial_size, numel,
mode);
}
};
......@@ -140,9 +150,11 @@ class CUDAPReluGradKernel : public framework::OpKernel<T> {
if (!dx && !dalpha) return;
auto& mode = context.Attr<std::string>("mode");
auto& data_format = context.Attr<std::string>("data_format");
int numel = x->numel();
auto dim = x->dims();
auto x_rank = dim.size();
std::vector<int> input_shape = framework::vectorize<int>(dim);
auto stream = context.cuda_device_context().stream();
......@@ -157,10 +169,12 @@ class CUDAPReluGradKernel : public framework::OpKernel<T> {
}
PRELU_MODE m;
bool channel_last = false;
if (mode == "element") {
m = Element;
} else if (mode == "channel") {
m = Channel;
channel_last = data_format == "NHWC";
m = channel_last ? ChannelLast : ChannelFirst;
} else {
m = Scalar;
}
......@@ -172,7 +186,8 @@ class CUDAPReluGradKernel : public framework::OpKernel<T> {
std::vector<int> reduce_dims;
for (size_t i = 0; i < dim.size(); i++) {
if (mode == "channel" && i == 1) continue;
if (mode == "channel" && !channel_last && i == 1) continue;
if (mode == "channel" && channel_last && i == dim.size() - 1) continue;
if (mode == "element" && i != 0) continue;
reduce_dims.push_back(i);
}
......
......@@ -33,12 +33,14 @@ class PReluKernel : public framework::OpKernel<T> {
const T* alpha_ptr = alpha->data<T>();
auto& mode = context.Attr<std::string>("mode");
auto& data_format = context.Attr<std::string>("data_format");
int numel = x->numel();
auto dim = x->dims();
int index = 0;
int i = 0;
if (mode == "channel") {
if (data_format == "NCHW") {
int temp = 1;
for (int j = 2; j < dim.size(); j++) {
temp *= dim[j];
......@@ -47,6 +49,12 @@ class PReluKernel : public framework::OpKernel<T> {
index = (i / temp) % dim[1];
o_ptr[i] = x_ptr[i] > 0 ? x_ptr[i] : alpha_ptr[index] * x_ptr[i];
}
} else {
for (i = 0; i < numel; i++) {
index = i % dim[dim.size() - 1];
o_ptr[i] = x_ptr[i] > 0 ? x_ptr[i] : alpha_ptr[index] * x_ptr[i];
}
}
} else if (mode == "element") {
int temp = 1;
for (int j = 1; j < dim.size(); j++) {
......@@ -77,6 +85,7 @@ class PReluGradKernel : public framework::OpKernel<T> {
const T* x_ptr = x->data<T>();
const T* dout_ptr = dout->data<T>();
std::string mode = context.Attr<std::string>("mode");
auto& data_format = context.Attr<std::string>("data_format");
int numel = x->numel();
auto dim = x->dims();
int index = 0;
......@@ -84,6 +93,7 @@ class PReluGradKernel : public framework::OpKernel<T> {
if (dx) {
T* dx_ptr = dx->mutable_data<T>(context.GetPlace());
if (mode == "channel") {
if (data_format == "NCHW") {
int temp = 1;
for (int j = 2; j < dim.size(); j++) {
temp *= dim[j];
......@@ -93,6 +103,13 @@ class PReluGradKernel : public framework::OpKernel<T> {
dx_ptr[i] =
x_ptr[i] > 0 ? dout_ptr[i] : alpha_ptr[index] * dout_ptr[i];
}
} else {
for (i = 0; i < numel; i++) {
index = i % dim[dim.size() - 1];
dx_ptr[i] =
x_ptr[i] > 0 ? dout_ptr[i] : alpha_ptr[index] * dout_ptr[i];
}
}
} else if (mode == "element") {
int temp = 1;
for (int j = 1; j < dim.size(); j++) {
......@@ -116,6 +133,7 @@ class PReluGradKernel : public framework::OpKernel<T> {
memset(dalpha_ptr, 0, sizeof(T) * dalpha->numel());
if (mode == "channel") {
if (data_format == "NCHW") {
int temp = 1;
for (int j = 2; j < dim.size(); j++) {
temp *= dim[j];
......@@ -124,6 +142,12 @@ class PReluGradKernel : public framework::OpKernel<T> {
index = (i / temp) % dim[1];
dalpha_ptr[index] += x_ptr[i] > 0 ? 0 : x_ptr[i] * dout_ptr[i];
}
} else {
for (i = 0; i < numel; i++) {
index = i % dim[dim.size() - 1];
dalpha_ptr[index] += x_ptr[i] > 0 ? 0 : x_ptr[i] * dout_ptr[i];
}
}
} else if (mode == "element") {
int temp = 1;
for (int j = 1; j < dim.size(); j++) {
......
......@@ -9791,7 +9791,7 @@ def swish(x, beta=1.0, name=None):
@deprecated(since="2.0.0", update_to="paddle.static.nn.prelu")
def prelu(x, mode, param_attr=None, name=None):
def prelu(x, mode, param_attr=None, data_format="NCHW", name=None):
r"""
prelu activation.
......@@ -9819,6 +9819,9 @@ def prelu(x, mode, param_attr=None, name=None):
name (str, optional): Name for the operation (optional, default is None). \
For more information, please refer to :ref:`api_guide_Name`.
data_format(str, optional): Data format that specifies the layout of input.
It may be "NC", "NCL", "NCHW", "NCDHW", "NLC", "NHWC" or "NDHWC". Default: "NCHW".
Returns:
Tensor: A tensor with the same shape and data type as x.
......@@ -9839,17 +9842,32 @@ def prelu(x, mode, param_attr=None, name=None):
helper = LayerHelper('prelu', **locals())
if mode not in ['all', 'channel', 'element']:
raise ValueError('mode should be one of all, channel, element.')
alpha_shape = [1]
# NOTE(): The input of this API should be ``N,C,...`` format,
# which means x.shape[0] is batch_size and x.shape[0] is channel.
if mode == 'channel':
true_data_format = [
'NC', 'NCL', 'NCHW', 'NCDHW', 'NLC', 'NHWC', 'NDHWC'
]
if data_format not in true_data_format:
raise ValueError(
"data_format must be one of 'NC', 'NCL', 'NCHW', 'NCDHW', "
"'NLC', 'NHWC', 'NDHWC' but receive {}".format(data_format))
data_format = 'NCHW' if data_format[1] == 'C' else 'NHWC'
assert len(
x.shape
) >= 2, "The size of input shape should be equal or larger than 2 in prelu() when mode is 'channel'"
#NOTE(zhiqiu): The alpha_shape should be [1, channel] + [1] * len(x.shape[2:]).
# To be consistent with Prelu, it is simplified.
#NOTE(zhiqiu): Revert shape to [1, channel, 1, 1] for compatibility with saved model of old version.
#NOTE(GuoxiaWang): support NHWC data format
if data_format == 'NHWC':
alpha_shape = [1, 1, 1, x.shape[1]]
else:
alpha_shape = [1, x.shape[1], 1, 1]
elif mode == 'element':
assert len(
x.shape
......@@ -9867,7 +9885,8 @@ def prelu(x, mode, param_attr=None, name=None):
type="prelu",
inputs={"X": x,
'Alpha': alpha},
attrs={"mode": mode},
attrs={"mode": mode,
"data_format": data_format},
outputs={"Out": out})
return out
......
......@@ -44,8 +44,12 @@ class TestMkldnnPreluOp(MkldnnAutoScanTest):
if len(kwargs['in_shape']) <= 1:
# not valid case, just return 0
return np.zeros((1)).astype(np.float32)
if kwargs['data_format'] == 'NCHW':
return np.random.random(kwargs['in_shape'][1]).astype(
np.float32)
else:
return np.random.random(kwargs['in_shape'][-1]).astype(
np.float32)
else:
if len(kwargs['in_shape']) <= 1:
# not valid case, just return 0
......@@ -57,7 +61,10 @@ class TestMkldnnPreluOp(MkldnnAutoScanTest):
inputs={"X": ["input_data"],
"Alpha": ["alpha_weight"]},
outputs={"Out": ["output_data"]},
attrs={"mode": kwargs['mode']})
attrs={
"mode": kwargs['mode'],
"data_format": kwargs['data_format']
})
program_config = ProgramConfig(
ops=[prelu_op],
......@@ -82,6 +89,7 @@ class TestMkldnnPreluOp(MkldnnAutoScanTest):
@given(
mode=st.sampled_from(['all', 'channel', 'element']),
data_format=st.sampled_from(['NCHW', 'NHWC']),
in_shape=st.lists(
st.integers(
min_value=1, max_value=32), min_size=1, max_size=4))
......
......@@ -39,7 +39,8 @@ class TrtConvertPreluTest(TrtLayerAutoScanTest):
def generate_alpha(attrs: List[Dict[str, Any]], dim1, dim2, dim3):
if attrs[0]["mode"] == "all":
return np.random.random(size=(1)).astype(np.float32)
elif attrs[0]["mode"] == "channel":
elif attrs[0]["mode"] == "channel" and attrs[0][
"data_format"] == "NCHW":
shape = [1]
if dim1 != 0:
shape.append(dim1)
......@@ -48,6 +49,16 @@ class TrtConvertPreluTest(TrtLayerAutoScanTest):
if dim3 != 0:
shape.append(1)
return np.random.random(size=shape).astype(np.float32)
elif attrs[0]["mode"] == "channel" and attrs[0][
"data_format"] == "NHWC":
shape = [1]
if dim1 != 0:
shape.append(1)
if dim2 != 0:
shape.append(1)
if dim3 != 0:
shape.append(dim3)
return np.random.random(size=shape).astype(np.float32)
elif attrs[0]["mode"] == "element":
shape = [1]
if dim1 != 0:
......@@ -72,9 +83,15 @@ class TrtConvertPreluTest(TrtLayerAutoScanTest):
continue
for mode in ["all", "channel", "element"]:
if mode == "channel" and dim1 == 0:
for data_format in ['NCHW', 'NHWC']:
if mode == "channel" and dim1 == 0 and data_format == "NCHW":
continue
dics = [{"mode": mode}]
if mode == "channel" and dim3 == 0 and data_format == "NHWC":
continue
dics = [{
"mode": mode,
"data_format": data_format
}]
ops_config = [{
"op_type": "prelu",
"op_inputs": {
......@@ -92,13 +109,15 @@ class TrtConvertPreluTest(TrtLayerAutoScanTest):
ops=ops,
weights={
"alpha_weight": TensorConfig(
data_gen=partial(generate_alpha, dics,
dim1, dim2, dim3))
data_gen=partial(generate_alpha,
dics, dim1, dim2,
dim3))
},
inputs={
"input_data": TensorConfig(
data_gen=partial(generate_input, batch,
dim1, dim2, dim3)),
data_gen=partial(generate_input,
batch, dim1, dim2,
dim3)),
},
outputs=["output_data"])
......
......@@ -41,10 +41,11 @@ class TestLayerPrint(unittest.TestCase):
self.assertEqual(
str(module), 'Hardtanh(min=-1.0, max=1.0, name=Hardtanh)')
module = nn.PReLU(1, 0.25, name="PReLU")
module = nn.PReLU(1, 0.25, name="PReLU", data_format="NCHW")
self.assertEqual(
str(module),
'PReLU(num_parameters=1, init=0.25, dtype=float32, name=PReLU)')
'PReLU(num_parameters=1, data_format=NCHW, init=0.25, dtype=float32, name=PReLU)'
)
module = nn.ReLU()
self.assertEqual(str(module), 'ReLU()')
......
......@@ -163,10 +163,18 @@ class PReluTest(OpTest):
# zero.
x_np[np.abs(x_np) < 0.005] = 0.02
if self.attrs == {'mode': "all"}:
if self.attrs == {
'mode': "all",
"data_format": "NCHW"
} or self.attrs == {
'mode': "all",
"data_format": "NHWC"
}:
alpha_np = np.random.uniform(-1, -0.5, (1))
elif self.attrs == {'mode': "channel"}:
elif self.attrs == {'mode': "channel", "data_format": "NCHW"}:
alpha_np = np.random.uniform(-1, -0.5, [1, self.x_shape[1], 1, 1])
elif self.attrs == {'mode': "channel", "data_format": "NHWC"}:
alpha_np = np.random.uniform(-1, -0.5, [1, 1, 1, self.x_shape[-1]])
else:
alpha_np = np.random.uniform(-1, -0.5, [1] + self.x_shape[1:])
alpha_np = alpha_np.astype(self.dtype)
......@@ -176,11 +184,14 @@ class PReluTest(OpTest):
# NOTE(zhiqu): reshape inputs['Alpha'] from [1, 100, 1, 1] to [1, 100] + [1]*len(x.shape[2:])
# since np operands could not be broadcast together with shapes (1,100,2,2,2,3) (1,100,1,1)
reshaped_alpha = self.inputs['Alpha']
if self.attrs == {'mode': "channel"}:
if self.attrs == {'mode': "channel", "data_format": "NCHW"}:
reshaped_alpha = np.reshape(
self.inputs['Alpha'],
[1, self.x_shape[1]] + [1] * len(self.x_shape[2:]))
elif self.attrs == {'mode': "channel", "data_format": "NHWC"}:
reshaped_alpha = np.reshape(
self.inputs['Alpha'],
[1] + [1] * len(self.x_shape[1:-1]) + [self.x_shape[-1]])
out_np = np.maximum(self.inputs['X'], 0.)
out_np = out_np + np.minimum(self.inputs['X'], 0.) * reshaped_alpha
assert out_np is not self.inputs['X']
......@@ -193,7 +204,7 @@ class PReluTest(OpTest):
self.x_shape = [2, 100, 3, 4]
def init_attr(self):
self.attrs = {'mode': "channel"}
self.attrs = {'mode': "channel", "data_format": "NCHW"}
def test_check_output(self):
self.check_output()
......@@ -210,7 +221,18 @@ class TestModeAll(PReluTest):
self.x_shape = [2, 3, 4, 5]
def init_attr(self):
self.attrs = {'mode': "all"}
self.attrs = {'mode': "all", "data_format": "NCHW"}
@skip_check_grad_ci(
reason="[skip shape check] Input(Alpha) must be 1-D and only has one data in 'all' mode"
)
class TestModeAllNHWC(PReluTest):
def init_input_shape(self):
self.x_shape = [2, 3, 4, 50]
def init_attr(self):
self.attrs = {'mode': "all", "data_format": "NHWC"}
class TestModeElt(PReluTest):
......@@ -218,7 +240,15 @@ class TestModeElt(PReluTest):
self.x_shape = [3, 2, 5, 10]
def init_attr(self):
self.attrs = {'mode': "element"}
self.attrs = {'mode': "element", "data_format": "NCHW"}
class TestModeEltNHWC(PReluTest):
def init_input_shape(self):
self.x_shape = [3, 2, 5, 10]
def init_attr(self):
self.attrs = {'mode': "element", "data_format": "NHWC"}
@skip_check_grad_ci(
......@@ -229,7 +259,18 @@ class TestModeAllRank3(PReluTest):
self.x_shape = [1, 200, 3]
def init_attr(self):
self.attrs = {'mode': "all"}
self.attrs = {'mode': "all", "data_format": "NCHW"}
@skip_check_grad_ci(
reason="[skip shape check] Input(Alpha) must be 1-D and only has one data in 'all' mode"
)
class TestModeAllRank3NHWC(PReluTest):
def init_input_shape(self):
self.x_shape = [1, 200, 3]
def init_attr(self):
self.attrs = {'mode': "all", "data_format": "NHWC"}
@skip_check_grad_ci(
......@@ -240,7 +281,18 @@ class TestModeAllRank6(PReluTest):
self.x_shape = [1, 2, 3, 4, 5, 6]
def init_attr(self):
self.attrs = {'mode': "all"}
self.attrs = {'mode': "all", "data_format": "NCHW"}
@skip_check_grad_ci(
reason="[skip shape check] Input(Alpha) must be 1-D and only has one data in 'all' mode"
)
class TestModeAllRank6NHWC(PReluTest):
def init_input_shape(self):
self.x_shape = [1, 2, 3, 4, 5, 6]
def init_attr(self):
self.attrs = {'mode': "all", "data_format": "NHWC"}
class TestModeChannelRank3(PReluTest):
......@@ -248,7 +300,15 @@ class TestModeChannelRank3(PReluTest):
self.x_shape = [1, 200, 3]
def init_attr(self):
self.attrs = {'mode': "channel"}
self.attrs = {'mode': "channel", "data_format": "NCHW"}
class TestModeChannelRank3NHWC(PReluTest):
def init_input_shape(self):
self.x_shape = [1, 3, 100]
def init_attr(self):
self.attrs = {'mode': "channel", "data_format": "NHWC"}
class TestModeChannelRank6(PReluTest):
......@@ -256,7 +316,15 @@ class TestModeChannelRank6(PReluTest):
self.x_shape = [1, 100, 2, 2, 2, 2]
def init_attr(self):
self.attrs = {'mode': "channel"}
self.attrs = {'mode': "channel", "data_format": "NCHW"}
class TestModeChannelRank6NHWC(PReluTest):
def init_input_shape(self):
self.x_shape = [1, 2, 2, 2, 2, 100]
def init_attr(self):
self.attrs = {'mode': "channel", "data_format": "NHWC"}
class TestModeElementRank3(PReluTest):
......@@ -264,7 +332,15 @@ class TestModeElementRank3(PReluTest):
self.x_shape = [3, 10, 10]
def init_attr(self):
self.attrs = {'mode': "element"}
self.attrs = {'mode': "element", "data_format": "NCHW"}
class TestModeElementRank3NHWC(PReluTest):
def init_input_shape(self):
self.x_shape = [3, 10, 10]
def init_attr(self):
self.attrs = {'mode': "element", "data_format": "NHWC"}
class TestModeElementRank6(PReluTest):
......@@ -272,7 +348,15 @@ class TestModeElementRank6(PReluTest):
self.x_shape = [3, 2, 2, 4, 5, 2]
def init_attr(self):
self.attrs = {'mode': "element"}
self.attrs = {'mode': "element", "data_format": "NCHW"}
class TestModeElementRank6NHWC(PReluTest):
def init_input_shape(self):
self.x_shape = [3, 2, 2, 4, 5, 2]
def init_attr(self):
self.attrs = {'mode': "element", "data_format": "NHWC"}
def create_test_fp16_class(parent,
......@@ -311,9 +395,16 @@ create_test_fp16_class(TestModeChannelRank3)
create_test_fp16_class(TestModeChannelRank6)
create_test_fp16_class(TestModeElementRank3)
create_test_fp16_class(TestModeElementRank6)
create_test_fp16_class(TestModeEltNHWC)
create_test_fp16_class(TestModeAllRank3NHWC)
create_test_fp16_class(TestModeAllRank6NHWC)
create_test_fp16_class(TestModeChannelRank3NHWC)
create_test_fp16_class(TestModeChannelRank6NHWC)
create_test_fp16_class(TestModeElementRank3NHWC)
create_test_fp16_class(TestModeElementRank6NHWC)
def prelu_t(x, mode, param_attr=None, name=None):
def prelu_t(x, mode, param_attr=None, name=None, data_format='NCHW'):
helper = fluid.layer_helper.LayerHelper('prelu', **locals())
alpha_shape = [1, x.shape[1], 1, 1]
dtype = helper.input_dtype(input_param_name='x')
......@@ -328,13 +419,19 @@ def prelu_t(x, mode, param_attr=None, name=None):
type="prelu",
inputs={"X": x,
'Alpha': alpha},
attrs={"mode": mode},
attrs={"mode": mode,
'data_format': data_format},
outputs={"Out": out})
return out
# error message test if mode is not one of 'all', 'channel', 'element'
class TestModeError(unittest.TestCase):
def setUp(self):
self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda(
) else paddle.CPUPlace()
self.x_np = np.ones([1, 2, 3, 4]).astype('float32')
def test_mode_error(self):
main_program = Program()
with fluid.program_guard(main_program, Program()):
......@@ -344,6 +441,24 @@ class TestModeError(unittest.TestCase):
except Exception as e:
assert (e.args[0].find('InvalidArgument') != -1)
def test_data_format_error1(self):
main_program = Program()
with fluid.program_guard(main_program, Program()):
x = fluid.data(name='x', shape=[2, 3, 4, 5])
try:
y = prelu_t(x, 'channel', data_format='N')
except Exception as e:
assert (e.args[0].find('InvalidArgument') != -1)
def test_data_format_error2(self):
main_program = Program()
with fluid.program_guard(main_program, Program()):
x = fluid.data(name='x', shape=[2, 3, 4, 5])
try:
y = paddle.static.nn.prelu(x, 'channel', data_format='N')
except ValueError as e:
pass
if __name__ == "__main__":
unittest.main()
......@@ -442,7 +442,7 @@ def leaky_relu(x, negative_slope=0.01, name=None):
return out
def prelu(x, weight, name=None):
def prelu(x, weight, data_format="NCHW", name=None):
"""
prelu activation.
......@@ -456,6 +456,8 @@ def prelu(x, weight, name=None):
The weight shape is [1] or [in], where `in` is the input channel of ``x``.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
data_format(str, optional): Data format that specifies the layout of input.
It may be "NC", "NCL", "NCHW", "NCDHW", "NLC", "NHWC" or "NDHWC". Default: "NCHW".
Returns:
A Tensor with the same data type and shape as ``x`` .
......@@ -490,19 +492,34 @@ def prelu(x, weight, name=None):
assert len(weight.shape
) == 1, "The dim count of weight shape should be 1 in prelu()."
# NOTE(): The input of this API should be ``N,C,...`` format,
# which means x.shape[0] is batch_size and x.shape[0] is channel.
mode = 'all'
if weight.shape[0] > 1:
true_data_format = [
'NC', 'NCL', 'NCHW', 'NCDHW', 'NLC', 'NHWC', 'NDHWC'
]
if data_format not in true_data_format:
raise ValueError(
"data_format must be one of 'NC', 'NCL', 'NCHW', 'NCDHW', "
"'NLC', 'NHWC', 'NDHWC' but receive {}".format(data_format))
data_format = 'NCHW' if data_format[1] == 'C' else 'NHWC'
assert len(
x.shape
) > 1, "The dim count of x should be equal or larger than 2 in prelu() when weight shape is not [1]."
#NOTE(GuoxiaWang): support NHWC data format
if data_format == 'NHWC':
assert weight.shape[0] == x.shape[
-1], "The weight size should be equal to x input channel in prelu() when weight shape is not [1]."
else:
assert weight.shape[0] == x.shape[
1], "The weight size should be equal to x input channel in prelu() when weight shape is not [1]."
mode = 'channel'
if in_dygraph_mode():
return _C_ops.prelu(x, weight, 'mode', mode)
return _C_ops.prelu(x, weight, 'mode', mode, 'data_format', data_format)
helper = LayerHelper('prelu', **locals())
out = helper.create_variable_for_type_inference(x.dtype)
......@@ -511,7 +528,8 @@ def prelu(x, weight, name=None):
inputs={"X": x,
"Alpha": weight},
outputs={"Out": out},
attrs={"mode": mode})
attrs={"mode": mode,
"data_format": data_format})
return out
......
......@@ -376,6 +376,8 @@ class PReLU(Layer):
Default is None. For more information, please refer to :ref:`api_paddle_ParamAttr`.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
data_format(str, optional): Data format that specifies the layout of input.
It may be "NC", "NCL", "NCHW", "NCDHW", "NLC", "NHWC" or "NDHWC". Default: "NCHW".
Shape:
- input: Tensor with any shape. Default dtype is float32.
......@@ -406,13 +408,18 @@ class PReLU(Layer):
# [ 6. , 7. , 8. , 9. ]]]]
"""
def __init__(self, num_parameters=1, init=0.25, weight_attr=None,
def __init__(self,
num_parameters=1,
init=0.25,
weight_attr=None,
data_format="NCHW",
name=None):
super(PReLU, self).__init__()
self._num_parameters = num_parameters
self._init = init
self._weight_attr = weight_attr
self._name = name
self._data_format = data_format
self._weight = self.create_parameter(
attr=self._weight_attr,
......@@ -422,12 +429,13 @@ class PReLU(Layer):
default_initializer=Constant(self._init))
def forward(self, x):
return F.prelu(x, self._weight)
return F.prelu(x, self._weight, data_format=self._data_format)
def extra_repr(self):
name_str = ', name={}'.format(self._name) if self._name else ''
return 'num_parameters={}, init={}, dtype={}{}'.format(
self._num_parameters, self._init, self._dtype, name_str)
return 'num_parameters={}, data_format={}, init={}, dtype={}{}'.format(
self._num_parameters, self._data_format, self._init, self._dtype,
name_str)
class ReLU(Layer):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册