未验证 提交 3f2a665a 编写于 作者: G Guoxia Wang 提交者: GitHub

support data_format='NHWC' for prelu channel mode (#37019)

* support data_format='NHWC' for prelu channel mode
上级 0c82e3a0
...@@ -34,6 +34,11 @@ class PReluOpConverter : public OpConverter { ...@@ -34,6 +34,11 @@ class PReluOpConverter : public OpConverter {
auto* input = engine_->GetITensor(op_desc.Input("X")[0]); auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
// Get attrs // Get attrs
std::string mode = BOOST_GET_CONST(std::string, op_desc.GetAttr("mode")); std::string mode = BOOST_GET_CONST(std::string, op_desc.GetAttr("mode"));
std::string data_format = "NCHW";
if (op_desc.HasAttr("data_format")) {
data_format =
BOOST_GET_CONST(std::string, op_desc.GetAttr("data_format"));
}
auto* alpha_var = scope.FindVar(op_desc.Input("Alpha")[0]); auto* alpha_var = scope.FindVar(op_desc.Input("Alpha")[0]);
auto* alpha_tensor = alpha_var->GetMutable<framework::LoDTensor>(); auto* alpha_tensor = alpha_var->GetMutable<framework::LoDTensor>();
...@@ -47,7 +52,7 @@ class PReluOpConverter : public OpConverter { ...@@ -47,7 +52,7 @@ class PReluOpConverter : public OpConverter {
nvinfer1::ILayer* layer = nullptr; nvinfer1::ILayer* layer = nullptr;
if (engine_->with_dynamic_shape()) { if (engine_->with_dynamic_shape()) {
plugin::PReluPluginDynamic* plugin = new plugin::PReluPluginDynamic( plugin::PReluPluginDynamic* plugin = new plugin::PReluPluginDynamic(
alpha_data, alpha_tensor_temp->numel(), mode); alpha_data, alpha_tensor_temp->numel(), mode, data_format);
layer = engine_->AddDynamicPlugin(&input, input_num, plugin); layer = engine_->AddDynamicPlugin(&input, input_num, plugin);
} else { } else {
#if IS_TRT_VERSION_GE(7000) #if IS_TRT_VERSION_GE(7000)
...@@ -74,8 +79,8 @@ class PReluOpConverter : public OpConverter { ...@@ -74,8 +79,8 @@ class PReluOpConverter : public OpConverter {
layer = TRT_ENGINE_ADD_LAYER(engine_, ParametricReLU, *input, layer = TRT_ENGINE_ADD_LAYER(engine_, ParametricReLU, *input,
*alpha_layer_output); *alpha_layer_output);
#else #else
plugin::PReluPlugin* plugin = plugin::PReluPlugin* plugin = new plugin::PReluPlugin(
new plugin::PReluPlugin(alpha_data, alpha_tensor_temp->numel(), mode); alpha_data, alpha_tensor_temp->numel(), mode, data_format);
layer = engine_->AddPlugin(&input, input_num, plugin); layer = engine_->AddPlugin(&input, input_num, plugin);
#endif #endif
} }
......
...@@ -69,10 +69,11 @@ int PReluPlugin::enqueue(int batch_size, const void *const *inputs, ...@@ -69,10 +69,11 @@ int PReluPlugin::enqueue(int batch_size, const void *const *inputs,
} }
if (mode_ == "channel") { if (mode_ == "channel") {
bool channel_last = data_format_ == "NHWC";
operators::math::PreluChannelWiseDirectCUDAFunctor<float> operators::math::PreluChannelWiseDirectCUDAFunctor<float>
prelu_channel_wise; prelu_channel_wise;
prelu_channel_wise(stream, input, alpha, output, input_dims.d[0], prelu_channel_wise(stream, input, alpha, output, input_dims.d[0],
input_dims.d[1], numel); input_dims.d[1], channel_last, numel);
} else if (mode_ == "element") { } else if (mode_ == "element") {
operators::math::PreluElementWiseDirectCUDAFunctor<float> operators::math::PreluElementWiseDirectCUDAFunctor<float>
prelu_element_wise; prelu_element_wise;
...@@ -168,10 +169,11 @@ int PReluPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc, ...@@ -168,10 +169,11 @@ int PReluPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc,
} }
if (mode_ == "channel") { if (mode_ == "channel") {
bool channel_last = data_format_ == "NHWC";
operators::math::PreluChannelWiseDirectCUDAFunctor<float> operators::math::PreluChannelWiseDirectCUDAFunctor<float>
prelu_channel_wise; prelu_channel_wise;
prelu_channel_wise(stream, input, alpha, output, input_dims.d[0], prelu_channel_wise(stream, input, alpha, output, input_dims.d[0],
input_dims.d[1], numel); input_dims.d[1], channel_last, numel);
} else if (mode_ == "element") { } else if (mode_ == "element") {
operators::math::PreluElementWiseDirectCUDAFunctor<float> operators::math::PreluElementWiseDirectCUDAFunctor<float>
prelu_element_wise; prelu_element_wise;
......
...@@ -32,11 +32,12 @@ class PReluPlugin : public PluginTensorRT { ...@@ -32,11 +32,12 @@ class PReluPlugin : public PluginTensorRT {
std::vector<float> weight_; std::vector<float> weight_;
float* p_gpu_weight_; float* p_gpu_weight_;
std::string mode_; std::string mode_;
std::string data_format_;
public: public:
size_t getSerializationSize() const TRT_NOEXCEPT override { size_t getSerializationSize() const TRT_NOEXCEPT override {
return getBaseSerializationSize() + SerializedSize(mode_.c_str()) + return getBaseSerializationSize() + SerializedSize(mode_.c_str()) +
SerializedSize(weight_); SerializedSize(data_format_.c_str()) + SerializedSize(weight_);
} }
// TRT will call this func when we need to serialize the configuration of // TRT will call this func when we need to serialize the configuration of
...@@ -46,11 +47,12 @@ class PReluPlugin : public PluginTensorRT { ...@@ -46,11 +47,12 @@ class PReluPlugin : public PluginTensorRT {
serializeBase(buffer); serializeBase(buffer);
SerializeValue(&buffer, weight_); SerializeValue(&buffer, weight_);
SerializeValue(&buffer, mode_.c_str()); SerializeValue(&buffer, mode_.c_str());
SerializeValue(&buffer, data_format_.c_str());
} }
PReluPlugin(const float* weight, const int weight_num, PReluPlugin(const float* weight, const int weight_num,
std::string const& mode) std::string const& mode, std::string const& data_format)
: mode_(mode) { : mode_(mode), data_format_(data_format) {
weight_.resize(weight_num); weight_.resize(weight_num);
std::copy(weight, weight + weight_num, weight_.data()); std::copy(weight, weight + weight_num, weight_.data());
} }
...@@ -63,13 +65,17 @@ class PReluPlugin : public PluginTensorRT { ...@@ -63,13 +65,17 @@ class PReluPlugin : public PluginTensorRT {
const char* prelu_mode; const char* prelu_mode;
DeserializeValue(&serialData, &serialLength, &prelu_mode); DeserializeValue(&serialData, &serialLength, &prelu_mode);
mode_ = std::string(prelu_mode); mode_ = std::string(prelu_mode);
const char* prelu_data_format;
DeserializeValue(&serialData, &serialLength, &prelu_data_format);
data_format_ = std::string(prelu_data_format);
} }
~PReluPlugin() {} ~PReluPlugin() {}
int initialize() TRT_NOEXCEPT override; int initialize() TRT_NOEXCEPT override;
void terminate() TRT_NOEXCEPT override; void terminate() TRT_NOEXCEPT override;
PReluPlugin* clone() const TRT_NOEXCEPT override { PReluPlugin* clone() const TRT_NOEXCEPT override {
auto* ptr = new PReluPlugin(weight_.data(), weight_.size(), mode_); auto* ptr =
new PReluPlugin(weight_.data(), weight_.size(), mode_, data_format_);
ptr->p_gpu_weight_ = p_gpu_weight_; ptr->p_gpu_weight_ = p_gpu_weight_;
return ptr; return ptr;
} }
...@@ -108,8 +114,8 @@ REGISTER_TRT_PLUGIN_V2(PReluPluginCreator); ...@@ -108,8 +114,8 @@ REGISTER_TRT_PLUGIN_V2(PReluPluginCreator);
class PReluPluginDynamic : public DynamicPluginTensorRT { class PReluPluginDynamic : public DynamicPluginTensorRT {
public: public:
PReluPluginDynamic(const float* weight, const int weight_num, PReluPluginDynamic(const float* weight, const int weight_num,
std::string const& mode) std::string const& mode, std::string const& data_format)
: mode_(mode) { : mode_(mode), data_format_(data_format) {
weight_.resize(weight_num); weight_.resize(weight_num);
std::copy(weight, weight + weight_num, weight_.data()); std::copy(weight, weight + weight_num, weight_.data());
} }
...@@ -117,7 +123,8 @@ class PReluPluginDynamic : public DynamicPluginTensorRT { ...@@ -117,7 +123,8 @@ class PReluPluginDynamic : public DynamicPluginTensorRT {
PReluPluginDynamic(void const* serialData, size_t serialLength); PReluPluginDynamic(void const* serialData, size_t serialLength);
~PReluPluginDynamic() {} ~PReluPluginDynamic() {}
nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override { nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override {
auto ptr = new PReluPluginDynamic(weight_.data(), weight_.size(), mode_); auto ptr = new PReluPluginDynamic(weight_.data(), weight_.size(), mode_,
data_format_);
ptr->p_gpu_weight_ = p_gpu_weight_; ptr->p_gpu_weight_ = p_gpu_weight_;
return ptr; return ptr;
} }
...@@ -167,6 +174,7 @@ class PReluPluginDynamic : public DynamicPluginTensorRT { ...@@ -167,6 +174,7 @@ class PReluPluginDynamic : public DynamicPluginTensorRT {
std::vector<float> weight_; std::vector<float> weight_;
float* p_gpu_weight_; float* p_gpu_weight_;
std::string mode_; std::string mode_;
std::string data_format_;
}; };
#endif #endif
......
...@@ -25,7 +25,7 @@ inline static int PADDLE_GET_BLOCKS(const int N) { ...@@ -25,7 +25,7 @@ inline static int PADDLE_GET_BLOCKS(const int N) {
} }
template <typename T> template <typename T>
__global__ void PReluChannelWiseKernel(const T *input, const T *alpha, __global__ void PReluChannelFirstWiseKernel(const T *input, const T *alpha,
T *output, size_t channel_num, T *output, size_t channel_num,
size_t plane_size, size_t numel) { size_t plane_size, size_t numel) {
CUDA_KERNEL_LOOP(index, numel) { CUDA_KERNEL_LOOP(index, numel) {
...@@ -38,6 +38,19 @@ __global__ void PReluChannelWiseKernel(const T *input, const T *alpha, ...@@ -38,6 +38,19 @@ __global__ void PReluChannelWiseKernel(const T *input, const T *alpha,
} }
} }
template <typename T>
__global__ void PReluChannelLastWiseKernel(const T *input, const T *alpha,
T *output, size_t channel_num,
size_t numel) {
CUDA_KERNEL_LOOP(index, numel) {
size_t channel_index = index % channel_num;
T scale = alpha[channel_index];
T x = input[index];
T zero = static_cast<T>(0);
output[index] = (x > zero) ? x : scale * x;
}
}
template <typename T> template <typename T>
__global__ void PReluElementWiseKernel(const T *input, const T *alpha, __global__ void PReluElementWiseKernel(const T *input, const T *alpha,
T *output, size_t spatial_size, T *output, size_t spatial_size,
...@@ -65,10 +78,16 @@ __global__ void PReluScalarKernel(const T *input, const T *alpha, T *output, ...@@ -65,10 +78,16 @@ __global__ void PReluScalarKernel(const T *input, const T *alpha, T *output,
template <typename T> template <typename T>
void PreluChannelWiseDirectCUDAFunctor<T>::operator()( void PreluChannelWiseDirectCUDAFunctor<T>::operator()(
gpuStream_t stream, const T *input, const T *alpha, T *output, gpuStream_t stream, const T *input, const T *alpha, T *output,
size_t batch_size, size_t channel, size_t numel) { size_t batch_size, size_t channel, bool channel_last, size_t numel) {
PReluChannelWiseKernel<<<PADDLE_GET_BLOCKS(numel), CUDA_NUM_THREADS, 0, if (channel_last) {
PReluChannelLastWiseKernel<<<PADDLE_GET_BLOCKS(numel), CUDA_NUM_THREADS, 0,
stream>>>(input, alpha, output, channel, stream>>>(input, alpha, output, channel,
numel / batch_size / channel, numel); numel);
} else {
PReluChannelFirstWiseKernel<<<PADDLE_GET_BLOCKS(numel), CUDA_NUM_THREADS, 0,
stream>>>(
input, alpha, output, channel, numel / batch_size / channel, numel);
}
} }
template <typename T> template <typename T>
......
...@@ -31,7 +31,8 @@ template <typename T> ...@@ -31,7 +31,8 @@ template <typename T>
class PreluChannelWiseDirectCUDAFunctor { class PreluChannelWiseDirectCUDAFunctor {
public: public:
void operator()(gpuStream_t stream, const T *input, const T *alpha, T *output, void operator()(gpuStream_t stream, const T *input, const T *alpha, T *output,
size_t batch_size, size_t channel, size_t numel); size_t batch_size, size_t channel, bool channel_last,
size_t numel);
}; };
template <typename T> template <typename T>
......
...@@ -34,7 +34,7 @@ class PReluMKLDNNHandler ...@@ -34,7 +34,7 @@ class PReluMKLDNNHandler
const dnnl::engine engine, platform::Place cpu_place, const dnnl::engine engine, platform::Place cpu_place,
const Tensor* x, const Tensor* weights, const Tensor* x, const Tensor* weights,
const std::string& uniq_name, const std::string& mode, const std::string& uniq_name, const std::string& mode,
bool is_test = false) const std::string& data_format, bool is_test = false)
: platform::MKLDNNHandlerT<T, dnnl::prelu_forward, dnnl::prelu_backward>( : platform::MKLDNNHandlerT<T, dnnl::prelu_forward, dnnl::prelu_backward>(
dev_ctx, engine, cpu_place, dev_ctx, engine, cpu_place,
platform::CreateKey(dev_ctx, framework::vectorize(x->dims()), platform::CreateKey(dev_ctx, framework::vectorize(x->dims()),
...@@ -49,9 +49,14 @@ class PReluMKLDNNHandler ...@@ -49,9 +49,14 @@ class PReluMKLDNNHandler
if (weights->dims().size() != x->dims().size()) { if (weights->dims().size() != x->dims().size()) {
auto new_weights_dims = std::vector<int64_t>(x->dims().size(), 1); auto new_weights_dims = std::vector<int64_t>(x->dims().size(), 1);
if (mode == "channel") { if (mode == "channel") {
if (data_format == "NHWC") {
new_weights_dims[x->dims().size() - 1] =
*std::max_element(weights_dims.begin(), weights_dims.end());
} else {
new_weights_dims[1] = new_weights_dims[1] =
*std::max_element(weights_dims.begin(), weights_dims.end()); *std::max_element(weights_dims.begin(), weights_dims.end());
} }
}
weights_dims = std::move(new_weights_dims); weights_dims = std::move(new_weights_dims);
} }
auto weights_md = memory::desc(weights_dims, MKLDNNGetDataType<T>(), auto weights_md = memory::desc(weights_dims, MKLDNNGetDataType<T>(),
...@@ -110,9 +115,11 @@ class PReluMKLDNNKernel : public framework::OpKernel<T> { ...@@ -110,9 +115,11 @@ class PReluMKLDNNKernel : public framework::OpKernel<T> {
auto* out = ctx.Output<Tensor>("Out"); auto* out = ctx.Output<Tensor>("Out");
const bool is_test = ctx.Attr<bool>("is_test"); const bool is_test = ctx.Attr<bool>("is_test");
const auto mode = ctx.Attr<std::string>("mode"); const auto mode = ctx.Attr<std::string>("mode");
const auto data_format = ctx.Attr<std::string>("data_format");
PReluMKLDNNHandler<T> handler(dev_ctx, onednn_engine, ctx.GetPlace(), x, PReluMKLDNNHandler<T> handler(dev_ctx, onednn_engine, ctx.GetPlace(), x,
alpha, ctx.InputName("X"), mode, is_test); alpha, ctx.InputName("X"), mode, data_format,
is_test);
auto src_memory_p = handler.AcquireSrcMemory(x); auto src_memory_p = handler.AcquireSrcMemory(x);
auto weights_memory_p = auto weights_memory_p =
...@@ -149,9 +156,11 @@ class PReluGradMKLDNNKernel : public framework::OpKernel<T> { ...@@ -149,9 +156,11 @@ class PReluGradMKLDNNKernel : public framework::OpKernel<T> {
auto* alpha = ctx.Input<Tensor>("Alpha"); auto* alpha = ctx.Input<Tensor>("Alpha");
const bool is_test = ctx.Attr<bool>("is_test"); const bool is_test = ctx.Attr<bool>("is_test");
const auto mode = ctx.Attr<std::string>("mode"); const auto mode = ctx.Attr<std::string>("mode");
const auto data_format = ctx.Attr<std::string>("data_format");
PReluMKLDNNHandler<T> handler(dev_ctx, onednn_engine, ctx.GetPlace(), x, PReluMKLDNNHandler<T> handler(dev_ctx, onednn_engine, ctx.GetPlace(), x,
alpha, framework::GradVarName("X"), mode); alpha, framework::GradVarName("X"), mode,
data_format);
auto src_memory_p = handler.AcquireSrcMemory(x); auto src_memory_p = handler.AcquireSrcMemory(x);
auto weights_memory_p = auto weights_memory_p =
......
...@@ -38,12 +38,6 @@ class PReluOp : public framework::OperatorWithKernel { ...@@ -38,12 +38,6 @@ class PReluOp : public framework::OperatorWithKernel {
"But recevied alpha's size: %d.", "But recevied alpha's size: %d.",
product(ctx->GetInputDim("Alpha")))); product(ctx->GetInputDim("Alpha"))));
} else if (mode == "channel") { } else if (mode == "channel") {
PADDLE_ENFORCE_EQ(product(ctx->GetInputDim("Alpha")), x_dim[1],
platform::errors::InvalidArgument(
"For mode 'channel', size of weight Alpha must be "
"equal to the number of channels of input(x). But "
"recevied alpha's size: %d, x_dim[1]: %d",
product(ctx->GetInputDim("Alpha")), x_dim[1]));
auto x_rank = x_dim.size(); auto x_rank = x_dim.size();
PADDLE_ENFORCE_GE(x_rank, 2, PADDLE_ENFORCE_GE(x_rank, 2,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
...@@ -51,6 +45,33 @@ class PReluOp : public framework::OperatorWithKernel { ...@@ -51,6 +45,33 @@ class PReluOp : public framework::OperatorWithKernel {
"equal or larger than 2. But recevied X's " "equal or larger than 2. But recevied X's "
"rank: %d", "rank: %d",
x_rank)); x_rank));
const std::string data_format_str =
ctx->Attrs().Get<std::string>("data_format");
PADDLE_ENFORCE_EQ(data_format_str == "NCHW" || data_format_str == "NHWC",
true,
platform::errors::InvalidArgument(
"For mode 'channel', data_format must be one of "
"NCHW and NHWC. But recevied data_format: %s",
data_format_str));
if (data_format_str == "NCHW") {
PADDLE_ENFORCE_EQ(
product(ctx->GetInputDim("Alpha")) == x_dim[1], true,
platform::errors::InvalidArgument(
"For mode 'channel', size of weight Alpha must be "
"equal to the number of channels of input(x). But "
"recevied alpha's size: %d, x_dim[1]: %d",
product(ctx->GetInputDim("Alpha")), x_dim[1]));
} else {
PADDLE_ENFORCE_EQ(
product(ctx->GetInputDim("Alpha")) == x_dim[x_rank - 1], true,
platform::errors::InvalidArgument(
"For mode 'channel', size of weight Alpha must be "
"equal to the number of channels of input(x). But "
"recevied alpha's size: %d, x_dim[%d]: %d",
product(ctx->GetInputDim("Alpha")), x_rank - 1,
x_dim[x_rank - 1]));
}
} else if (mode == "element") { } else if (mode == "element") {
auto alpha_dim = ctx->GetInputDim("Alpha"); auto alpha_dim = ctx->GetInputDim("Alpha");
auto alpha_rank = alpha_dim.size(); auto alpha_rank = alpha_dim.size();
...@@ -134,6 +155,9 @@ There are modes: ...@@ -134,6 +155,9 @@ There are modes:
)DOC"); )DOC");
AddAttr<std::string>("mode", "The mode for inputs to share weights.") AddAttr<std::string>("mode", "The mode for inputs to share weights.")
.SetDefault("all"); .SetDefault("all");
AddAttr<std::string>("data_format",
"Data format that specifies the layout of input")
.SetDefault("NCHW");
AddAttr<bool>("use_mkldnn", AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel") "(bool, default false) Only used in mkldnn kernel")
.SetDefault(false) .SetDefault(false)
......
...@@ -42,17 +42,22 @@ class CUDAPReluKernel : public framework::OpKernel<T> { ...@@ -42,17 +42,22 @@ class CUDAPReluKernel : public framework::OpKernel<T> {
const T* alpha_ptr = alpha->data<T>(); const T* alpha_ptr = alpha->data<T>();
auto& mode = context.Attr<std::string>("mode"); auto& mode = context.Attr<std::string>("mode");
auto& data_format = context.Attr<std::string>("data_format");
int numel = x->numel(); int numel = x->numel();
auto dim = x->dims(); auto dim = x->dims();
auto x_rank = dim.size();
VLOG(4) << "dim[0]:" << dim[0] << ", dim[1]:" << dim[1] VLOG(4) << "dim[0]:" << dim[0] << ", dim[1]:" << dim[1] << ", dim["
<< ", numel:" << numel; << x_rank - 1 << "]:" << dim[x_rank - 1] << ", numel:" << numel;
if (mode == "channel") { if (mode == "channel") {
bool channel_last = data_format == "NHWC";
size_t channel = channel_last ? dim[x_rank - 1] : dim[1];
math::PreluChannelWiseDirectCUDAFunctor<T> prelu_channel_wise; math::PreluChannelWiseDirectCUDAFunctor<T> prelu_channel_wise;
prelu_channel_wise(context.cuda_device_context().stream(), x_ptr, prelu_channel_wise(context.cuda_device_context().stream(), x_ptr,
alpha_ptr, o_ptr, dim[0], dim[1], numel); alpha_ptr, o_ptr, dim[0], channel, channel_last,
numel);
} else if (mode == "element") { } else if (mode == "element") {
math::PreluElementWiseDirectCUDAFunctor<T> prelu_element_wise; math::PreluElementWiseDirectCUDAFunctor<T> prelu_element_wise;
prelu_element_wise(context.cuda_device_context().stream(), x_ptr, prelu_element_wise(context.cuda_device_context().stream(), x_ptr,
...@@ -65,7 +70,7 @@ class CUDAPReluKernel : public framework::OpKernel<T> { ...@@ -65,7 +70,7 @@ class CUDAPReluKernel : public framework::OpKernel<T> {
} }
}; };
enum PRELU_MODE { Element, Channel, Scalar }; enum PRELU_MODE { Element, ChannelFirst, ChannelLast, Scalar };
template <typename T> template <typename T>
__global__ void PReluOpGradKernel(const T* x_ptr, const T* alpha_ptr, __global__ void PReluOpGradKernel(const T* x_ptr, const T* alpha_ptr,
...@@ -78,10 +83,13 @@ __global__ void PReluOpGradKernel(const T* x_ptr, const T* alpha_ptr, ...@@ -78,10 +83,13 @@ __global__ void PReluOpGradKernel(const T* x_ptr, const T* alpha_ptr,
if (mode == Element) { if (mode == Element) {
size_t element_index = index % spatial_size; size_t element_index = index % spatial_size;
scale = alpha_ptr[element_index]; scale = alpha_ptr[element_index];
} else if (mode == Channel) { } else if (mode == ChannelFirst) {
size_t temp = index / plane_size; size_t temp = index / plane_size;
size_t channel_index = temp % channel_num; size_t channel_index = temp % channel_num;
scale = alpha_ptr[channel_index]; scale = alpha_ptr[channel_index];
} else if (mode == ChannelLast) {
size_t channel_index = index % channel_num;
scale = alpha_ptr[channel_index];
} else { } else {
scale = alpha_ptr[0]; scale = alpha_ptr[0];
} }
...@@ -105,11 +113,13 @@ class PreluOpGradFunctor { ...@@ -105,11 +113,13 @@ class PreluOpGradFunctor {
} }
size_t plane_size = numel / input_dims[0] / input_dims[1]; size_t plane_size = numel / input_dims[0] / input_dims[1];
size_t spatial_size = numel / input_dims[0]; size_t spatial_size = numel / input_dims[0];
size_t channel =
mode == ChannelLast ? input_dims[input_dims.size() - 1] : input_dims[1];
PReluOpGradKernel< PReluOpGradKernel<
T><<<PADDLE_GET_BLOCKS(numel), CUDA_NUM_THREADS, 0, stream>>>( T><<<PADDLE_GET_BLOCKS(numel), CUDA_NUM_THREADS, 0, stream>>>(
x, alpha, dy, dx, dalpha, input_dims[1], plane_size, spatial_size, x, alpha, dy, dx, dalpha, channel, plane_size, spatial_size, numel,
numel, mode); mode);
} }
}; };
...@@ -140,9 +150,11 @@ class CUDAPReluGradKernel : public framework::OpKernel<T> { ...@@ -140,9 +150,11 @@ class CUDAPReluGradKernel : public framework::OpKernel<T> {
if (!dx && !dalpha) return; if (!dx && !dalpha) return;
auto& mode = context.Attr<std::string>("mode"); auto& mode = context.Attr<std::string>("mode");
auto& data_format = context.Attr<std::string>("data_format");
int numel = x->numel(); int numel = x->numel();
auto dim = x->dims(); auto dim = x->dims();
auto x_rank = dim.size();
std::vector<int> input_shape = framework::vectorize<int>(dim); std::vector<int> input_shape = framework::vectorize<int>(dim);
auto stream = context.cuda_device_context().stream(); auto stream = context.cuda_device_context().stream();
...@@ -157,10 +169,12 @@ class CUDAPReluGradKernel : public framework::OpKernel<T> { ...@@ -157,10 +169,12 @@ class CUDAPReluGradKernel : public framework::OpKernel<T> {
} }
PRELU_MODE m; PRELU_MODE m;
bool channel_last = false;
if (mode == "element") { if (mode == "element") {
m = Element; m = Element;
} else if (mode == "channel") { } else if (mode == "channel") {
m = Channel; channel_last = data_format == "NHWC";
m = channel_last ? ChannelLast : ChannelFirst;
} else { } else {
m = Scalar; m = Scalar;
} }
...@@ -172,7 +186,8 @@ class CUDAPReluGradKernel : public framework::OpKernel<T> { ...@@ -172,7 +186,8 @@ class CUDAPReluGradKernel : public framework::OpKernel<T> {
std::vector<int> reduce_dims; std::vector<int> reduce_dims;
for (size_t i = 0; i < dim.size(); i++) { for (size_t i = 0; i < dim.size(); i++) {
if (mode == "channel" && i == 1) continue; if (mode == "channel" && !channel_last && i == 1) continue;
if (mode == "channel" && channel_last && i == dim.size() - 1) continue;
if (mode == "element" && i != 0) continue; if (mode == "element" && i != 0) continue;
reduce_dims.push_back(i); reduce_dims.push_back(i);
} }
......
...@@ -33,12 +33,14 @@ class PReluKernel : public framework::OpKernel<T> { ...@@ -33,12 +33,14 @@ class PReluKernel : public framework::OpKernel<T> {
const T* alpha_ptr = alpha->data<T>(); const T* alpha_ptr = alpha->data<T>();
auto& mode = context.Attr<std::string>("mode"); auto& mode = context.Attr<std::string>("mode");
auto& data_format = context.Attr<std::string>("data_format");
int numel = x->numel(); int numel = x->numel();
auto dim = x->dims(); auto dim = x->dims();
int index = 0; int index = 0;
int i = 0; int i = 0;
if (mode == "channel") { if (mode == "channel") {
if (data_format == "NCHW") {
int temp = 1; int temp = 1;
for (int j = 2; j < dim.size(); j++) { for (int j = 2; j < dim.size(); j++) {
temp *= dim[j]; temp *= dim[j];
...@@ -47,6 +49,12 @@ class PReluKernel : public framework::OpKernel<T> { ...@@ -47,6 +49,12 @@ class PReluKernel : public framework::OpKernel<T> {
index = (i / temp) % dim[1]; index = (i / temp) % dim[1];
o_ptr[i] = x_ptr[i] > 0 ? x_ptr[i] : alpha_ptr[index] * x_ptr[i]; o_ptr[i] = x_ptr[i] > 0 ? x_ptr[i] : alpha_ptr[index] * x_ptr[i];
} }
} else {
for (i = 0; i < numel; i++) {
index = i % dim[dim.size() - 1];
o_ptr[i] = x_ptr[i] > 0 ? x_ptr[i] : alpha_ptr[index] * x_ptr[i];
}
}
} else if (mode == "element") { } else if (mode == "element") {
int temp = 1; int temp = 1;
for (int j = 1; j < dim.size(); j++) { for (int j = 1; j < dim.size(); j++) {
...@@ -77,6 +85,7 @@ class PReluGradKernel : public framework::OpKernel<T> { ...@@ -77,6 +85,7 @@ class PReluGradKernel : public framework::OpKernel<T> {
const T* x_ptr = x->data<T>(); const T* x_ptr = x->data<T>();
const T* dout_ptr = dout->data<T>(); const T* dout_ptr = dout->data<T>();
std::string mode = context.Attr<std::string>("mode"); std::string mode = context.Attr<std::string>("mode");
auto& data_format = context.Attr<std::string>("data_format");
int numel = x->numel(); int numel = x->numel();
auto dim = x->dims(); auto dim = x->dims();
int index = 0; int index = 0;
...@@ -84,6 +93,7 @@ class PReluGradKernel : public framework::OpKernel<T> { ...@@ -84,6 +93,7 @@ class PReluGradKernel : public framework::OpKernel<T> {
if (dx) { if (dx) {
T* dx_ptr = dx->mutable_data<T>(context.GetPlace()); T* dx_ptr = dx->mutable_data<T>(context.GetPlace());
if (mode == "channel") { if (mode == "channel") {
if (data_format == "NCHW") {
int temp = 1; int temp = 1;
for (int j = 2; j < dim.size(); j++) { for (int j = 2; j < dim.size(); j++) {
temp *= dim[j]; temp *= dim[j];
...@@ -93,6 +103,13 @@ class PReluGradKernel : public framework::OpKernel<T> { ...@@ -93,6 +103,13 @@ class PReluGradKernel : public framework::OpKernel<T> {
dx_ptr[i] = dx_ptr[i] =
x_ptr[i] > 0 ? dout_ptr[i] : alpha_ptr[index] * dout_ptr[i]; x_ptr[i] > 0 ? dout_ptr[i] : alpha_ptr[index] * dout_ptr[i];
} }
} else {
for (i = 0; i < numel; i++) {
index = i % dim[dim.size() - 1];
dx_ptr[i] =
x_ptr[i] > 0 ? dout_ptr[i] : alpha_ptr[index] * dout_ptr[i];
}
}
} else if (mode == "element") { } else if (mode == "element") {
int temp = 1; int temp = 1;
for (int j = 1; j < dim.size(); j++) { for (int j = 1; j < dim.size(); j++) {
...@@ -116,6 +133,7 @@ class PReluGradKernel : public framework::OpKernel<T> { ...@@ -116,6 +133,7 @@ class PReluGradKernel : public framework::OpKernel<T> {
memset(dalpha_ptr, 0, sizeof(T) * dalpha->numel()); memset(dalpha_ptr, 0, sizeof(T) * dalpha->numel());
if (mode == "channel") { if (mode == "channel") {
if (data_format == "NCHW") {
int temp = 1; int temp = 1;
for (int j = 2; j < dim.size(); j++) { for (int j = 2; j < dim.size(); j++) {
temp *= dim[j]; temp *= dim[j];
...@@ -124,6 +142,12 @@ class PReluGradKernel : public framework::OpKernel<T> { ...@@ -124,6 +142,12 @@ class PReluGradKernel : public framework::OpKernel<T> {
index = (i / temp) % dim[1]; index = (i / temp) % dim[1];
dalpha_ptr[index] += x_ptr[i] > 0 ? 0 : x_ptr[i] * dout_ptr[i]; dalpha_ptr[index] += x_ptr[i] > 0 ? 0 : x_ptr[i] * dout_ptr[i];
} }
} else {
for (i = 0; i < numel; i++) {
index = i % dim[dim.size() - 1];
dalpha_ptr[index] += x_ptr[i] > 0 ? 0 : x_ptr[i] * dout_ptr[i];
}
}
} else if (mode == "element") { } else if (mode == "element") {
int temp = 1; int temp = 1;
for (int j = 1; j < dim.size(); j++) { for (int j = 1; j < dim.size(); j++) {
......
...@@ -9791,7 +9791,7 @@ def swish(x, beta=1.0, name=None): ...@@ -9791,7 +9791,7 @@ def swish(x, beta=1.0, name=None):
@deprecated(since="2.0.0", update_to="paddle.static.nn.prelu") @deprecated(since="2.0.0", update_to="paddle.static.nn.prelu")
def prelu(x, mode, param_attr=None, name=None): def prelu(x, mode, param_attr=None, data_format="NCHW", name=None):
r""" r"""
prelu activation. prelu activation.
...@@ -9819,6 +9819,9 @@ def prelu(x, mode, param_attr=None, name=None): ...@@ -9819,6 +9819,9 @@ def prelu(x, mode, param_attr=None, name=None):
name (str, optional): Name for the operation (optional, default is None). \ name (str, optional): Name for the operation (optional, default is None). \
For more information, please refer to :ref:`api_guide_Name`. For more information, please refer to :ref:`api_guide_Name`.
data_format(str, optional): Data format that specifies the layout of input.
It may be "NC", "NCL", "NCHW", "NCDHW", "NLC", "NHWC" or "NDHWC". Default: "NCHW".
Returns: Returns:
Tensor: A tensor with the same shape and data type as x. Tensor: A tensor with the same shape and data type as x.
...@@ -9839,17 +9842,32 @@ def prelu(x, mode, param_attr=None, name=None): ...@@ -9839,17 +9842,32 @@ def prelu(x, mode, param_attr=None, name=None):
helper = LayerHelper('prelu', **locals()) helper = LayerHelper('prelu', **locals())
if mode not in ['all', 'channel', 'element']: if mode not in ['all', 'channel', 'element']:
raise ValueError('mode should be one of all, channel, element.') raise ValueError('mode should be one of all, channel, element.')
alpha_shape = [1] alpha_shape = [1]
# NOTE(): The input of this API should be ``N,C,...`` format,
# which means x.shape[0] is batch_size and x.shape[0] is channel.
if mode == 'channel': if mode == 'channel':
true_data_format = [
'NC', 'NCL', 'NCHW', 'NCDHW', 'NLC', 'NHWC', 'NDHWC'
]
if data_format not in true_data_format:
raise ValueError(
"data_format must be one of 'NC', 'NCL', 'NCHW', 'NCDHW', "
"'NLC', 'NHWC', 'NDHWC' but receive {}".format(data_format))
data_format = 'NCHW' if data_format[1] == 'C' else 'NHWC'
assert len( assert len(
x.shape x.shape
) >= 2, "The size of input shape should be equal or larger than 2 in prelu() when mode is 'channel'" ) >= 2, "The size of input shape should be equal or larger than 2 in prelu() when mode is 'channel'"
#NOTE(zhiqiu): The alpha_shape should be [1, channel] + [1] * len(x.shape[2:]). #NOTE(zhiqiu): The alpha_shape should be [1, channel] + [1] * len(x.shape[2:]).
# To be consistent with Prelu, it is simplified. # To be consistent with Prelu, it is simplified.
#NOTE(zhiqiu): Revert shape to [1, channel, 1, 1] for compatibility with saved model of old version. #NOTE(zhiqiu): Revert shape to [1, channel, 1, 1] for compatibility with saved model of old version.
#NOTE(GuoxiaWang): support NHWC data format
if data_format == 'NHWC':
alpha_shape = [1, 1, 1, x.shape[1]]
else:
alpha_shape = [1, x.shape[1], 1, 1] alpha_shape = [1, x.shape[1], 1, 1]
elif mode == 'element': elif mode == 'element':
assert len( assert len(
x.shape x.shape
...@@ -9867,7 +9885,8 @@ def prelu(x, mode, param_attr=None, name=None): ...@@ -9867,7 +9885,8 @@ def prelu(x, mode, param_attr=None, name=None):
type="prelu", type="prelu",
inputs={"X": x, inputs={"X": x,
'Alpha': alpha}, 'Alpha': alpha},
attrs={"mode": mode}, attrs={"mode": mode,
"data_format": data_format},
outputs={"Out": out}) outputs={"Out": out})
return out return out
......
...@@ -44,8 +44,12 @@ class TestMkldnnPreluOp(MkldnnAutoScanTest): ...@@ -44,8 +44,12 @@ class TestMkldnnPreluOp(MkldnnAutoScanTest):
if len(kwargs['in_shape']) <= 1: if len(kwargs['in_shape']) <= 1:
# not valid case, just return 0 # not valid case, just return 0
return np.zeros((1)).astype(np.float32) return np.zeros((1)).astype(np.float32)
if kwargs['data_format'] == 'NCHW':
return np.random.random(kwargs['in_shape'][1]).astype( return np.random.random(kwargs['in_shape'][1]).astype(
np.float32) np.float32)
else:
return np.random.random(kwargs['in_shape'][-1]).astype(
np.float32)
else: else:
if len(kwargs['in_shape']) <= 1: if len(kwargs['in_shape']) <= 1:
# not valid case, just return 0 # not valid case, just return 0
...@@ -57,7 +61,10 @@ class TestMkldnnPreluOp(MkldnnAutoScanTest): ...@@ -57,7 +61,10 @@ class TestMkldnnPreluOp(MkldnnAutoScanTest):
inputs={"X": ["input_data"], inputs={"X": ["input_data"],
"Alpha": ["alpha_weight"]}, "Alpha": ["alpha_weight"]},
outputs={"Out": ["output_data"]}, outputs={"Out": ["output_data"]},
attrs={"mode": kwargs['mode']}) attrs={
"mode": kwargs['mode'],
"data_format": kwargs['data_format']
})
program_config = ProgramConfig( program_config = ProgramConfig(
ops=[prelu_op], ops=[prelu_op],
...@@ -82,6 +89,7 @@ class TestMkldnnPreluOp(MkldnnAutoScanTest): ...@@ -82,6 +89,7 @@ class TestMkldnnPreluOp(MkldnnAutoScanTest):
@given( @given(
mode=st.sampled_from(['all', 'channel', 'element']), mode=st.sampled_from(['all', 'channel', 'element']),
data_format=st.sampled_from(['NCHW', 'NHWC']),
in_shape=st.lists( in_shape=st.lists(
st.integers( st.integers(
min_value=1, max_value=32), min_size=1, max_size=4)) min_value=1, max_value=32), min_size=1, max_size=4))
......
...@@ -39,7 +39,8 @@ class TrtConvertPreluTest(TrtLayerAutoScanTest): ...@@ -39,7 +39,8 @@ class TrtConvertPreluTest(TrtLayerAutoScanTest):
def generate_alpha(attrs: List[Dict[str, Any]], dim1, dim2, dim3): def generate_alpha(attrs: List[Dict[str, Any]], dim1, dim2, dim3):
if attrs[0]["mode"] == "all": if attrs[0]["mode"] == "all":
return np.random.random(size=(1)).astype(np.float32) return np.random.random(size=(1)).astype(np.float32)
elif attrs[0]["mode"] == "channel": elif attrs[0]["mode"] == "channel" and attrs[0][
"data_format"] == "NCHW":
shape = [1] shape = [1]
if dim1 != 0: if dim1 != 0:
shape.append(dim1) shape.append(dim1)
...@@ -48,6 +49,16 @@ class TrtConvertPreluTest(TrtLayerAutoScanTest): ...@@ -48,6 +49,16 @@ class TrtConvertPreluTest(TrtLayerAutoScanTest):
if dim3 != 0: if dim3 != 0:
shape.append(1) shape.append(1)
return np.random.random(size=shape).astype(np.float32) return np.random.random(size=shape).astype(np.float32)
elif attrs[0]["mode"] == "channel" and attrs[0][
"data_format"] == "NHWC":
shape = [1]
if dim1 != 0:
shape.append(1)
if dim2 != 0:
shape.append(1)
if dim3 != 0:
shape.append(dim3)
return np.random.random(size=shape).astype(np.float32)
elif attrs[0]["mode"] == "element": elif attrs[0]["mode"] == "element":
shape = [1] shape = [1]
if dim1 != 0: if dim1 != 0:
...@@ -72,9 +83,15 @@ class TrtConvertPreluTest(TrtLayerAutoScanTest): ...@@ -72,9 +83,15 @@ class TrtConvertPreluTest(TrtLayerAutoScanTest):
continue continue
for mode in ["all", "channel", "element"]: for mode in ["all", "channel", "element"]:
if mode == "channel" and dim1 == 0: for data_format in ['NCHW', 'NHWC']:
if mode == "channel" and dim1 == 0 and data_format == "NCHW":
continue continue
dics = [{"mode": mode}] if mode == "channel" and dim3 == 0 and data_format == "NHWC":
continue
dics = [{
"mode": mode,
"data_format": data_format
}]
ops_config = [{ ops_config = [{
"op_type": "prelu", "op_type": "prelu",
"op_inputs": { "op_inputs": {
...@@ -92,13 +109,15 @@ class TrtConvertPreluTest(TrtLayerAutoScanTest): ...@@ -92,13 +109,15 @@ class TrtConvertPreluTest(TrtLayerAutoScanTest):
ops=ops, ops=ops,
weights={ weights={
"alpha_weight": TensorConfig( "alpha_weight": TensorConfig(
data_gen=partial(generate_alpha, dics, data_gen=partial(generate_alpha,
dim1, dim2, dim3)) dics, dim1, dim2,
dim3))
}, },
inputs={ inputs={
"input_data": TensorConfig( "input_data": TensorConfig(
data_gen=partial(generate_input, batch, data_gen=partial(generate_input,
dim1, dim2, dim3)), batch, dim1, dim2,
dim3)),
}, },
outputs=["output_data"]) outputs=["output_data"])
......
...@@ -41,10 +41,11 @@ class TestLayerPrint(unittest.TestCase): ...@@ -41,10 +41,11 @@ class TestLayerPrint(unittest.TestCase):
self.assertEqual( self.assertEqual(
str(module), 'Hardtanh(min=-1.0, max=1.0, name=Hardtanh)') str(module), 'Hardtanh(min=-1.0, max=1.0, name=Hardtanh)')
module = nn.PReLU(1, 0.25, name="PReLU") module = nn.PReLU(1, 0.25, name="PReLU", data_format="NCHW")
self.assertEqual( self.assertEqual(
str(module), str(module),
'PReLU(num_parameters=1, init=0.25, dtype=float32, name=PReLU)') 'PReLU(num_parameters=1, data_format=NCHW, init=0.25, dtype=float32, name=PReLU)'
)
module = nn.ReLU() module = nn.ReLU()
self.assertEqual(str(module), 'ReLU()') self.assertEqual(str(module), 'ReLU()')
......
...@@ -163,10 +163,18 @@ class PReluTest(OpTest): ...@@ -163,10 +163,18 @@ class PReluTest(OpTest):
# zero. # zero.
x_np[np.abs(x_np) < 0.005] = 0.02 x_np[np.abs(x_np) < 0.005] = 0.02
if self.attrs == {'mode': "all"}: if self.attrs == {
'mode': "all",
"data_format": "NCHW"
} or self.attrs == {
'mode': "all",
"data_format": "NHWC"
}:
alpha_np = np.random.uniform(-1, -0.5, (1)) alpha_np = np.random.uniform(-1, -0.5, (1))
elif self.attrs == {'mode': "channel"}: elif self.attrs == {'mode': "channel", "data_format": "NCHW"}:
alpha_np = np.random.uniform(-1, -0.5, [1, self.x_shape[1], 1, 1]) alpha_np = np.random.uniform(-1, -0.5, [1, self.x_shape[1], 1, 1])
elif self.attrs == {'mode': "channel", "data_format": "NHWC"}:
alpha_np = np.random.uniform(-1, -0.5, [1, 1, 1, self.x_shape[-1]])
else: else:
alpha_np = np.random.uniform(-1, -0.5, [1] + self.x_shape[1:]) alpha_np = np.random.uniform(-1, -0.5, [1] + self.x_shape[1:])
alpha_np = alpha_np.astype(self.dtype) alpha_np = alpha_np.astype(self.dtype)
...@@ -176,11 +184,14 @@ class PReluTest(OpTest): ...@@ -176,11 +184,14 @@ class PReluTest(OpTest):
# NOTE(zhiqu): reshape inputs['Alpha'] from [1, 100, 1, 1] to [1, 100] + [1]*len(x.shape[2:]) # NOTE(zhiqu): reshape inputs['Alpha'] from [1, 100, 1, 1] to [1, 100] + [1]*len(x.shape[2:])
# since np operands could not be broadcast together with shapes (1,100,2,2,2,3) (1,100,1,1) # since np operands could not be broadcast together with shapes (1,100,2,2,2,3) (1,100,1,1)
reshaped_alpha = self.inputs['Alpha'] reshaped_alpha = self.inputs['Alpha']
if self.attrs == {'mode': "channel"}: if self.attrs == {'mode': "channel", "data_format": "NCHW"}:
reshaped_alpha = np.reshape( reshaped_alpha = np.reshape(
self.inputs['Alpha'], self.inputs['Alpha'],
[1, self.x_shape[1]] + [1] * len(self.x_shape[2:])) [1, self.x_shape[1]] + [1] * len(self.x_shape[2:]))
elif self.attrs == {'mode': "channel", "data_format": "NHWC"}:
reshaped_alpha = np.reshape(
self.inputs['Alpha'],
[1] + [1] * len(self.x_shape[1:-1]) + [self.x_shape[-1]])
out_np = np.maximum(self.inputs['X'], 0.) out_np = np.maximum(self.inputs['X'], 0.)
out_np = out_np + np.minimum(self.inputs['X'], 0.) * reshaped_alpha out_np = out_np + np.minimum(self.inputs['X'], 0.) * reshaped_alpha
assert out_np is not self.inputs['X'] assert out_np is not self.inputs['X']
...@@ -193,7 +204,7 @@ class PReluTest(OpTest): ...@@ -193,7 +204,7 @@ class PReluTest(OpTest):
self.x_shape = [2, 100, 3, 4] self.x_shape = [2, 100, 3, 4]
def init_attr(self): def init_attr(self):
self.attrs = {'mode': "channel"} self.attrs = {'mode': "channel", "data_format": "NCHW"}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
...@@ -210,7 +221,18 @@ class TestModeAll(PReluTest): ...@@ -210,7 +221,18 @@ class TestModeAll(PReluTest):
self.x_shape = [2, 3, 4, 5] self.x_shape = [2, 3, 4, 5]
def init_attr(self): def init_attr(self):
self.attrs = {'mode': "all"} self.attrs = {'mode': "all", "data_format": "NCHW"}
@skip_check_grad_ci(
reason="[skip shape check] Input(Alpha) must be 1-D and only has one data in 'all' mode"
)
class TestModeAllNHWC(PReluTest):
def init_input_shape(self):
self.x_shape = [2, 3, 4, 50]
def init_attr(self):
self.attrs = {'mode': "all", "data_format": "NHWC"}
class TestModeElt(PReluTest): class TestModeElt(PReluTest):
...@@ -218,7 +240,15 @@ class TestModeElt(PReluTest): ...@@ -218,7 +240,15 @@ class TestModeElt(PReluTest):
self.x_shape = [3, 2, 5, 10] self.x_shape = [3, 2, 5, 10]
def init_attr(self): def init_attr(self):
self.attrs = {'mode': "element"} self.attrs = {'mode': "element", "data_format": "NCHW"}
class TestModeEltNHWC(PReluTest):
def init_input_shape(self):
self.x_shape = [3, 2, 5, 10]
def init_attr(self):
self.attrs = {'mode': "element", "data_format": "NHWC"}
@skip_check_grad_ci( @skip_check_grad_ci(
...@@ -229,7 +259,18 @@ class TestModeAllRank3(PReluTest): ...@@ -229,7 +259,18 @@ class TestModeAllRank3(PReluTest):
self.x_shape = [1, 200, 3] self.x_shape = [1, 200, 3]
def init_attr(self): def init_attr(self):
self.attrs = {'mode': "all"} self.attrs = {'mode': "all", "data_format": "NCHW"}
@skip_check_grad_ci(
reason="[skip shape check] Input(Alpha) must be 1-D and only has one data in 'all' mode"
)
class TestModeAllRank3NHWC(PReluTest):
def init_input_shape(self):
self.x_shape = [1, 200, 3]
def init_attr(self):
self.attrs = {'mode': "all", "data_format": "NHWC"}
@skip_check_grad_ci( @skip_check_grad_ci(
...@@ -240,7 +281,18 @@ class TestModeAllRank6(PReluTest): ...@@ -240,7 +281,18 @@ class TestModeAllRank6(PReluTest):
self.x_shape = [1, 2, 3, 4, 5, 6] self.x_shape = [1, 2, 3, 4, 5, 6]
def init_attr(self): def init_attr(self):
self.attrs = {'mode': "all"} self.attrs = {'mode': "all", "data_format": "NCHW"}
@skip_check_grad_ci(
reason="[skip shape check] Input(Alpha) must be 1-D and only has one data in 'all' mode"
)
class TestModeAllRank6NHWC(PReluTest):
def init_input_shape(self):
self.x_shape = [1, 2, 3, 4, 5, 6]
def init_attr(self):
self.attrs = {'mode': "all", "data_format": "NHWC"}
class TestModeChannelRank3(PReluTest): class TestModeChannelRank3(PReluTest):
...@@ -248,7 +300,15 @@ class TestModeChannelRank3(PReluTest): ...@@ -248,7 +300,15 @@ class TestModeChannelRank3(PReluTest):
self.x_shape = [1, 200, 3] self.x_shape = [1, 200, 3]
def init_attr(self): def init_attr(self):
self.attrs = {'mode': "channel"} self.attrs = {'mode': "channel", "data_format": "NCHW"}
class TestModeChannelRank3NHWC(PReluTest):
def init_input_shape(self):
self.x_shape = [1, 3, 100]
def init_attr(self):
self.attrs = {'mode': "channel", "data_format": "NHWC"}
class TestModeChannelRank6(PReluTest): class TestModeChannelRank6(PReluTest):
...@@ -256,7 +316,15 @@ class TestModeChannelRank6(PReluTest): ...@@ -256,7 +316,15 @@ class TestModeChannelRank6(PReluTest):
self.x_shape = [1, 100, 2, 2, 2, 2] self.x_shape = [1, 100, 2, 2, 2, 2]
def init_attr(self): def init_attr(self):
self.attrs = {'mode': "channel"} self.attrs = {'mode': "channel", "data_format": "NCHW"}
class TestModeChannelRank6NHWC(PReluTest):
def init_input_shape(self):
self.x_shape = [1, 2, 2, 2, 2, 100]
def init_attr(self):
self.attrs = {'mode': "channel", "data_format": "NHWC"}
class TestModeElementRank3(PReluTest): class TestModeElementRank3(PReluTest):
...@@ -264,7 +332,15 @@ class TestModeElementRank3(PReluTest): ...@@ -264,7 +332,15 @@ class TestModeElementRank3(PReluTest):
self.x_shape = [3, 10, 10] self.x_shape = [3, 10, 10]
def init_attr(self): def init_attr(self):
self.attrs = {'mode': "element"} self.attrs = {'mode': "element", "data_format": "NCHW"}
class TestModeElementRank3NHWC(PReluTest):
def init_input_shape(self):
self.x_shape = [3, 10, 10]
def init_attr(self):
self.attrs = {'mode': "element", "data_format": "NHWC"}
class TestModeElementRank6(PReluTest): class TestModeElementRank6(PReluTest):
...@@ -272,7 +348,15 @@ class TestModeElementRank6(PReluTest): ...@@ -272,7 +348,15 @@ class TestModeElementRank6(PReluTest):
self.x_shape = [3, 2, 2, 4, 5, 2] self.x_shape = [3, 2, 2, 4, 5, 2]
def init_attr(self): def init_attr(self):
self.attrs = {'mode': "element"} self.attrs = {'mode': "element", "data_format": "NCHW"}
class TestModeElementRank6NHWC(PReluTest):
def init_input_shape(self):
self.x_shape = [3, 2, 2, 4, 5, 2]
def init_attr(self):
self.attrs = {'mode': "element", "data_format": "NHWC"}
def create_test_fp16_class(parent, def create_test_fp16_class(parent,
...@@ -311,9 +395,16 @@ create_test_fp16_class(TestModeChannelRank3) ...@@ -311,9 +395,16 @@ create_test_fp16_class(TestModeChannelRank3)
create_test_fp16_class(TestModeChannelRank6) create_test_fp16_class(TestModeChannelRank6)
create_test_fp16_class(TestModeElementRank3) create_test_fp16_class(TestModeElementRank3)
create_test_fp16_class(TestModeElementRank6) create_test_fp16_class(TestModeElementRank6)
create_test_fp16_class(TestModeEltNHWC)
create_test_fp16_class(TestModeAllRank3NHWC)
create_test_fp16_class(TestModeAllRank6NHWC)
create_test_fp16_class(TestModeChannelRank3NHWC)
create_test_fp16_class(TestModeChannelRank6NHWC)
create_test_fp16_class(TestModeElementRank3NHWC)
create_test_fp16_class(TestModeElementRank6NHWC)
def prelu_t(x, mode, param_attr=None, name=None): def prelu_t(x, mode, param_attr=None, name=None, data_format='NCHW'):
helper = fluid.layer_helper.LayerHelper('prelu', **locals()) helper = fluid.layer_helper.LayerHelper('prelu', **locals())
alpha_shape = [1, x.shape[1], 1, 1] alpha_shape = [1, x.shape[1], 1, 1]
dtype = helper.input_dtype(input_param_name='x') dtype = helper.input_dtype(input_param_name='x')
...@@ -328,13 +419,19 @@ def prelu_t(x, mode, param_attr=None, name=None): ...@@ -328,13 +419,19 @@ def prelu_t(x, mode, param_attr=None, name=None):
type="prelu", type="prelu",
inputs={"X": x, inputs={"X": x,
'Alpha': alpha}, 'Alpha': alpha},
attrs={"mode": mode}, attrs={"mode": mode,
'data_format': data_format},
outputs={"Out": out}) outputs={"Out": out})
return out return out
# error message test if mode is not one of 'all', 'channel', 'element' # error message test if mode is not one of 'all', 'channel', 'element'
class TestModeError(unittest.TestCase): class TestModeError(unittest.TestCase):
def setUp(self):
self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda(
) else paddle.CPUPlace()
self.x_np = np.ones([1, 2, 3, 4]).astype('float32')
def test_mode_error(self): def test_mode_error(self):
main_program = Program() main_program = Program()
with fluid.program_guard(main_program, Program()): with fluid.program_guard(main_program, Program()):
...@@ -344,6 +441,24 @@ class TestModeError(unittest.TestCase): ...@@ -344,6 +441,24 @@ class TestModeError(unittest.TestCase):
except Exception as e: except Exception as e:
assert (e.args[0].find('InvalidArgument') != -1) assert (e.args[0].find('InvalidArgument') != -1)
def test_data_format_error1(self):
main_program = Program()
with fluid.program_guard(main_program, Program()):
x = fluid.data(name='x', shape=[2, 3, 4, 5])
try:
y = prelu_t(x, 'channel', data_format='N')
except Exception as e:
assert (e.args[0].find('InvalidArgument') != -1)
def test_data_format_error2(self):
main_program = Program()
with fluid.program_guard(main_program, Program()):
x = fluid.data(name='x', shape=[2, 3, 4, 5])
try:
y = paddle.static.nn.prelu(x, 'channel', data_format='N')
except ValueError as e:
pass
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -442,7 +442,7 @@ def leaky_relu(x, negative_slope=0.01, name=None): ...@@ -442,7 +442,7 @@ def leaky_relu(x, negative_slope=0.01, name=None):
return out return out
def prelu(x, weight, name=None): def prelu(x, weight, data_format="NCHW", name=None):
""" """
prelu activation. prelu activation.
...@@ -456,6 +456,8 @@ def prelu(x, weight, name=None): ...@@ -456,6 +456,8 @@ def prelu(x, weight, name=None):
The weight shape is [1] or [in], where `in` is the input channel of ``x``. The weight shape is [1] or [in], where `in` is the input channel of ``x``.
name (str, optional): Name for the operation (optional, default is None). name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`. For more information, please refer to :ref:`api_guide_Name`.
data_format(str, optional): Data format that specifies the layout of input.
It may be "NC", "NCL", "NCHW", "NCDHW", "NLC", "NHWC" or "NDHWC". Default: "NCHW".
Returns: Returns:
A Tensor with the same data type and shape as ``x`` . A Tensor with the same data type and shape as ``x`` .
...@@ -490,19 +492,34 @@ def prelu(x, weight, name=None): ...@@ -490,19 +492,34 @@ def prelu(x, weight, name=None):
assert len(weight.shape assert len(weight.shape
) == 1, "The dim count of weight shape should be 1 in prelu()." ) == 1, "The dim count of weight shape should be 1 in prelu()."
# NOTE(): The input of this API should be ``N,C,...`` format,
# which means x.shape[0] is batch_size and x.shape[0] is channel.
mode = 'all' mode = 'all'
if weight.shape[0] > 1: if weight.shape[0] > 1:
true_data_format = [
'NC', 'NCL', 'NCHW', 'NCDHW', 'NLC', 'NHWC', 'NDHWC'
]
if data_format not in true_data_format:
raise ValueError(
"data_format must be one of 'NC', 'NCL', 'NCHW', 'NCDHW', "
"'NLC', 'NHWC', 'NDHWC' but receive {}".format(data_format))
data_format = 'NCHW' if data_format[1] == 'C' else 'NHWC'
assert len( assert len(
x.shape x.shape
) > 1, "The dim count of x should be equal or larger than 2 in prelu() when weight shape is not [1]." ) > 1, "The dim count of x should be equal or larger than 2 in prelu() when weight shape is not [1]."
#NOTE(GuoxiaWang): support NHWC data format
if data_format == 'NHWC':
assert weight.shape[0] == x.shape[
-1], "The weight size should be equal to x input channel in prelu() when weight shape is not [1]."
else:
assert weight.shape[0] == x.shape[ assert weight.shape[0] == x.shape[
1], "The weight size should be equal to x input channel in prelu() when weight shape is not [1]." 1], "The weight size should be equal to x input channel in prelu() when weight shape is not [1]."
mode = 'channel' mode = 'channel'
if in_dygraph_mode(): if in_dygraph_mode():
return _C_ops.prelu(x, weight, 'mode', mode) return _C_ops.prelu(x, weight, 'mode', mode, 'data_format', data_format)
helper = LayerHelper('prelu', **locals()) helper = LayerHelper('prelu', **locals())
out = helper.create_variable_for_type_inference(x.dtype) out = helper.create_variable_for_type_inference(x.dtype)
...@@ -511,7 +528,8 @@ def prelu(x, weight, name=None): ...@@ -511,7 +528,8 @@ def prelu(x, weight, name=None):
inputs={"X": x, inputs={"X": x,
"Alpha": weight}, "Alpha": weight},
outputs={"Out": out}, outputs={"Out": out},
attrs={"mode": mode}) attrs={"mode": mode,
"data_format": data_format})
return out return out
......
...@@ -376,6 +376,8 @@ class PReLU(Layer): ...@@ -376,6 +376,8 @@ class PReLU(Layer):
Default is None. For more information, please refer to :ref:`api_paddle_ParamAttr`. Default is None. For more information, please refer to :ref:`api_paddle_ParamAttr`.
name (str, optional): Name for the operation (optional, default is None). name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`. For more information, please refer to :ref:`api_guide_Name`.
data_format(str, optional): Data format that specifies the layout of input.
It may be "NC", "NCL", "NCHW", "NCDHW", "NLC", "NHWC" or "NDHWC". Default: "NCHW".
Shape: Shape:
- input: Tensor with any shape. Default dtype is float32. - input: Tensor with any shape. Default dtype is float32.
...@@ -406,13 +408,18 @@ class PReLU(Layer): ...@@ -406,13 +408,18 @@ class PReLU(Layer):
# [ 6. , 7. , 8. , 9. ]]]] # [ 6. , 7. , 8. , 9. ]]]]
""" """
def __init__(self, num_parameters=1, init=0.25, weight_attr=None, def __init__(self,
num_parameters=1,
init=0.25,
weight_attr=None,
data_format="NCHW",
name=None): name=None):
super(PReLU, self).__init__() super(PReLU, self).__init__()
self._num_parameters = num_parameters self._num_parameters = num_parameters
self._init = init self._init = init
self._weight_attr = weight_attr self._weight_attr = weight_attr
self._name = name self._name = name
self._data_format = data_format
self._weight = self.create_parameter( self._weight = self.create_parameter(
attr=self._weight_attr, attr=self._weight_attr,
...@@ -422,12 +429,13 @@ class PReLU(Layer): ...@@ -422,12 +429,13 @@ class PReLU(Layer):
default_initializer=Constant(self._init)) default_initializer=Constant(self._init))
def forward(self, x): def forward(self, x):
return F.prelu(x, self._weight) return F.prelu(x, self._weight, data_format=self._data_format)
def extra_repr(self): def extra_repr(self):
name_str = ', name={}'.format(self._name) if self._name else '' name_str = ', name={}'.format(self._name) if self._name else ''
return 'num_parameters={}, init={}, dtype={}{}'.format( return 'num_parameters={}, data_format={}, init={}, dtype={}{}'.format(
self._num_parameters, self._init, self._dtype, name_str) self._num_parameters, self._data_format, self._init, self._dtype,
name_str)
class ReLU(Layer): class ReLU(Layer):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册