未验证 提交 fa657b3d 编写于 作者: L Leo Chen 提交者: GitHub

fix bug of prelu when rank not equal 4, test=develop (#25067)

* fix bug of prelu when rank not equal 4, test=develop

* fix prelu inference, test=develop

* fix api, test=develop

* fix shape when mode is chennel, test=develop

* remove debug code, test=develop

* add unittest, test=develop
上级 479c8834
...@@ -13,8 +13,10 @@ ...@@ -13,8 +13,10 @@
// limitations under the License. // limitations under the License.
#include <stdio.h> #include <stdio.h>
#include <cassert> #include <cassert>
#include <vector> #include <vector>
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h" #include "paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h" #include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
...@@ -55,24 +57,23 @@ int PReluPlugin::enqueue(int batch_size, const void *const *inputs, ...@@ -55,24 +57,23 @@ int PReluPlugin::enqueue(int batch_size, const void *const *inputs,
// const float *alpha = reinterpret_cast<const float *>(alpha_.get().values); // const float *alpha = reinterpret_cast<const float *>(alpha_.get().values);
const float *alpha = p_gpu_weight_; const float *alpha = p_gpu_weight_;
float *output = reinterpret_cast<float **>(outputs)[0]; float *output = reinterpret_cast<float **>(outputs)[0];
int numel = 1;
std::vector<int> input_shape;
input_shape.push_back(batch_size);
for (int i = 0; i < input_dims.nbDims; i++) { for (int i = 0; i < input_dims.nbDims; i++) {
input_shape.push_back(input_dims.d[i]); numel *= input_dims.d[i];
} }
if (mode_ == "channel") { if (mode_ == "channel") {
operators::math::PreluChannelWiseDirectCUDAFunctor<float> operators::math::PreluChannelWiseDirectCUDAFunctor<float>
prelu_channel_wise; prelu_channel_wise;
prelu_channel_wise(stream, input, alpha, output, input_shape); prelu_channel_wise(stream, input, alpha, output, input_dims.d[0],
input_dims.d[1], numel);
} else if (mode_ == "element") { } else if (mode_ == "element") {
operators::math::PreluElementWiseDirectCUDAFunctor<float> operators::math::PreluElementWiseDirectCUDAFunctor<float>
prelu_element_wise; prelu_element_wise;
prelu_element_wise(stream, input, alpha, output, input_shape); prelu_element_wise(stream, input, alpha, output, input_dims.d[0], numel);
} else { } else {
operators::math::PreluScalarDirectCUDAFunctor<float> prelu_scalar; operators::math::PreluScalarDirectCUDAFunctor<float> prelu_scalar;
prelu_scalar(stream, input, alpha, output, input_shape); prelu_scalar(stream, input, alpha, output, numel);
} }
return cudaGetLastError() != cudaSuccess; return cudaGetLastError() != cudaSuccess;
} }
...@@ -133,23 +134,23 @@ int PReluPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc, ...@@ -133,23 +134,23 @@ int PReluPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc,
const float *alpha = p_gpu_weight_; const float *alpha = p_gpu_weight_;
const float *input = static_cast<const float *>(inputs[0]); const float *input = static_cast<const float *>(inputs[0]);
float *output = static_cast<float *>(outputs[0]); float *output = static_cast<float *>(outputs[0]);
int numel = 1;
std::vector<int> input_shape;
for (int i = 0; i < input_dims.nbDims; i++) { for (int i = 0; i < input_dims.nbDims; i++) {
input_shape.push_back(input_dims.d[i]); numel *= input_dims.d[i];
} }
if (mode_ == "channel") { if (mode_ == "channel") {
operators::math::PreluChannelWiseDirectCUDAFunctor<float> operators::math::PreluChannelWiseDirectCUDAFunctor<float>
prelu_channel_wise; prelu_channel_wise;
prelu_channel_wise(stream, input, alpha, output, input_shape); prelu_channel_wise(stream, input, alpha, output, input_dims.d[0],
input_dims.d[1], numel);
} else if (mode_ == "element") { } else if (mode_ == "element") {
operators::math::PreluElementWiseDirectCUDAFunctor<float> operators::math::PreluElementWiseDirectCUDAFunctor<float>
prelu_element_wise; prelu_element_wise;
prelu_element_wise(stream, input, alpha, output, input_shape); prelu_element_wise(stream, input, alpha, output, input_dims.d[0], numel);
} else { } else {
operators::math::PreluScalarDirectCUDAFunctor<float> prelu_scalar; operators::math::PreluScalarDirectCUDAFunctor<float> prelu_scalar;
prelu_scalar(stream, input, alpha, output, input_shape); prelu_scalar(stream, input, alpha, output, numel);
} }
return cudaGetLastError() != cudaSuccess; return cudaGetLastError() != cudaSuccess;
} }
......
...@@ -21,8 +21,8 @@ namespace math { ...@@ -21,8 +21,8 @@ namespace math {
#define CUDA_NUM_THREADS 1024 #define CUDA_NUM_THREADS 1024
// CUDA: grid stride looping // CUDA: grid stride looping
#define CUDA_KERNEL_LOOP(i, n) \ #define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x) i += blockDim.x * gridDim.x)
inline static int PADDLE_GET_BLOCKS(const int N) { inline static int PADDLE_GET_BLOCKS(const int N) {
...@@ -33,7 +33,6 @@ template <typename T> ...@@ -33,7 +33,6 @@ template <typename T>
__global__ void PReluChannelWiseKernel(const T *input, const T *alpha, __global__ void PReluChannelWiseKernel(const T *input, const T *alpha,
T *output, size_t channel_num, T *output, size_t channel_num,
size_t plane_size, size_t numel) { size_t plane_size, size_t numel) {
size_t index;
CUDA_KERNEL_LOOP(index, numel) { CUDA_KERNEL_LOOP(index, numel) {
size_t temp = index / plane_size; size_t temp = index / plane_size;
size_t channel_index = temp % channel_num; size_t channel_index = temp % channel_num;
...@@ -47,7 +46,6 @@ template <typename T> ...@@ -47,7 +46,6 @@ template <typename T>
__global__ void PReluElementWiseKernel(const T *input, const T *alpha, __global__ void PReluElementWiseKernel(const T *input, const T *alpha,
T *output, size_t spatial_size, T *output, size_t spatial_size,
size_t numel) { size_t numel) {
size_t index;
CUDA_KERNEL_LOOP(index, numel) { CUDA_KERNEL_LOOP(index, numel) {
size_t element_index = index % spatial_size; size_t element_index = index % spatial_size;
T scale = alpha[element_index]; T scale = alpha[element_index];
...@@ -60,7 +58,6 @@ template <typename T> ...@@ -60,7 +58,6 @@ template <typename T>
__global__ void PReluScalarKernel(const T *input, const T *alpha, T *output, __global__ void PReluScalarKernel(const T *input, const T *alpha, T *output,
size_t numel) { size_t numel) {
T scale = alpha[0]; T scale = alpha[0];
size_t index;
CUDA_KERNEL_LOOP(index, numel) { CUDA_KERNEL_LOOP(index, numel) {
T x = input[index]; T x = input[index];
output[index] = (x > 0) ? x : scale * x; output[index] = (x > 0) ? x : scale * x;
...@@ -70,34 +67,27 @@ __global__ void PReluScalarKernel(const T *input, const T *alpha, T *output, ...@@ -70,34 +67,27 @@ __global__ void PReluScalarKernel(const T *input, const T *alpha, T *output,
template <typename T> template <typename T>
void PreluChannelWiseDirectCUDAFunctor<T>::operator()( void PreluChannelWiseDirectCUDAFunctor<T>::operator()(
cudaStream_t stream, const T *input, const T *alpha, T *output, cudaStream_t stream, const T *input, const T *alpha, T *output,
std::vector<int> input_shape) { size_t batch_size, size_t channel, size_t numel) {
size_t plane_size = input_shape[2] * input_shape[3];
size_t spatial_size = input_shape[1] * plane_size;
size_t numel = input_shape[0] * spatial_size;
PReluChannelWiseKernel<<<PADDLE_GET_BLOCKS(numel), CUDA_NUM_THREADS, 0, PReluChannelWiseKernel<<<PADDLE_GET_BLOCKS(numel), CUDA_NUM_THREADS, 0,
stream>>>(input, alpha, output, input_shape[1], stream>>>(input, alpha, output, channel,
plane_size, numel); numel / batch_size / channel, numel);
} }
template <typename T> template <typename T>
void PreluElementWiseDirectCUDAFunctor<T>::operator()( void PreluElementWiseDirectCUDAFunctor<T>::operator()(cudaStream_t stream,
cudaStream_t stream, const T *input, const T *alpha, T *output, const T *input,
std::vector<int> input_shape) { const T *alpha, T *output,
size_t plane_size = input_shape[2] * input_shape[3]; size_t batch_size,
size_t spatial_size = input_shape[1] * plane_size; size_t numel) {
size_t numel = input_shape[0] * spatial_size;
PReluElementWiseKernel<<<PADDLE_GET_BLOCKS(numel), CUDA_NUM_THREADS, 0, PReluElementWiseKernel<<<PADDLE_GET_BLOCKS(numel), CUDA_NUM_THREADS, 0,
stream>>>(input, alpha, output, spatial_size, numel); stream>>>(input, alpha, output, numel / batch_size,
numel);
} }
template <typename T> template <typename T>
void PreluScalarDirectCUDAFunctor<T>::operator()(cudaStream_t stream, void PreluScalarDirectCUDAFunctor<T>::operator()(cudaStream_t stream,
const T *input, const T *alpha, const T *input, const T *alpha,
T *output, T *output, size_t numel) {
std::vector<int> input_shape) {
size_t plane_size = input_shape[2] * input_shape[3];
size_t spatial_size = input_shape[1] * plane_size;
size_t numel = input_shape[0] * spatial_size;
PReluScalarKernel<<<PADDLE_GET_BLOCKS(numel), CUDA_NUM_THREADS, 0, stream>>>( PReluScalarKernel<<<PADDLE_GET_BLOCKS(numel), CUDA_NUM_THREADS, 0, stream>>>(
input, alpha, output, numel); input, alpha, output, numel);
} }
......
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#include <vector> #include <vector>
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/cudnn_helper.h" #include "paddle/fluid/platform/cudnn_helper.h"
...@@ -26,21 +27,21 @@ template <typename T> ...@@ -26,21 +27,21 @@ template <typename T>
class PreluChannelWiseDirectCUDAFunctor { class PreluChannelWiseDirectCUDAFunctor {
public: public:
void operator()(cudaStream_t stream, const T *input, const T *alpha, void operator()(cudaStream_t stream, const T *input, const T *alpha,
T *output, std::vector<int> input_shape); T *output, size_t batch_size, size_t channel, size_t numel);
}; };
template <typename T> template <typename T>
class PreluElementWiseDirectCUDAFunctor { class PreluElementWiseDirectCUDAFunctor {
public: public:
void operator()(cudaStream_t stream, const T *input, const T *alpha, void operator()(cudaStream_t stream, const T *input, const T *alpha,
T *output, std::vector<int> input_shape); T *output, size_t batch_size, size_t numel);
}; };
template <typename T> template <typename T>
class PreluScalarDirectCUDAFunctor { class PreluScalarDirectCUDAFunctor {
public: public:
void operator()(cudaStream_t stream, const T *input, const T *alpha, void operator()(cudaStream_t stream, const T *input, const T *alpha,
T *output, std::vector<int> input_shape); T *output, size_t numel);
}; };
#endif #endif
......
...@@ -10,6 +10,7 @@ See the License for the specific language governing permissions and ...@@ -10,6 +10,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/prelu_op.h" #include "paddle/fluid/operators/prelu_op.h"
#include <memory> #include <memory>
#include <string> #include <string>
...@@ -43,10 +44,23 @@ class PReluOp : public framework::OperatorWithKernel { ...@@ -43,10 +44,23 @@ class PReluOp : public framework::OperatorWithKernel {
"equal to the number of channels of input(x). But " "equal to the number of channels of input(x). But "
"recevied alpha's size: %d, x_dim[1]: %d", "recevied alpha's size: %d, x_dim[1]: %d",
product(ctx->GetInputDim("Alpha")), x_dim[1])); product(ctx->GetInputDim("Alpha")), x_dim[1]));
auto x_rank = x_dim.size();
PADDLE_ENFORCE_GE(x_rank, 2,
platform::errors::InvalidArgument(
"For mode 'channel', rank of input X must be "
"equal or larger than 2. But recevied X's "
"rank: %d",
x_rank));
} else if (mode == "element") { } else if (mode == "element") {
auto alpha_dim = ctx->GetInputDim("Alpha"); auto alpha_dim = ctx->GetInputDim("Alpha");
auto alpha_rank = alpha_dim.size(); auto alpha_rank = alpha_dim.size();
auto x_rank = x_dim.size(); auto x_rank = x_dim.size();
PADDLE_ENFORCE_GE(x_rank, 1,
platform::errors::InvalidArgument(
"For mode 'element', rank of input X must be "
"equal or larger than 2. But recevied X's "
"rank: %d",
x_rank));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
alpha_rank, x_rank, alpha_rank, x_rank,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
......
...@@ -11,6 +11,7 @@ limitations under the License. */ ...@@ -11,6 +11,7 @@ limitations under the License. */
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/prelu.h" #include "paddle/fluid/operators/math/prelu.h"
#include "paddle/fluid/operators/prelu_op.h" #include "paddle/fluid/operators/prelu_op.h"
...@@ -49,20 +50,22 @@ class CUDAPReluKernel : public framework::OpKernel<T> { ...@@ -49,20 +50,22 @@ class CUDAPReluKernel : public framework::OpKernel<T> {
int numel = x->numel(); int numel = x->numel();
auto dim = x->dims(); auto dim = x->dims();
std::vector<int> input_shape = framework::vectorize<int>(dim);
VLOG(4) << "dim[0]:" << dim[0] << ", dim[1]:" << dim[1]
<< ", numel:" << numel;
if (mode == "channel") { if (mode == "channel") {
math::PreluChannelWiseDirectCUDAFunctor<T> prelu_channel_wise; math::PreluChannelWiseDirectCUDAFunctor<T> prelu_channel_wise;
prelu_channel_wise(context.cuda_device_context().stream(), x_ptr, prelu_channel_wise(context.cuda_device_context().stream(), x_ptr,
alpha_ptr, o_ptr, input_shape); alpha_ptr, o_ptr, dim[0], dim[1], numel);
} else if (mode == "element") { } else if (mode == "element") {
math::PreluElementWiseDirectCUDAFunctor<T> prelu_element_wise; math::PreluElementWiseDirectCUDAFunctor<T> prelu_element_wise;
prelu_element_wise(context.cuda_device_context().stream(), x_ptr, prelu_element_wise(context.cuda_device_context().stream(), x_ptr,
alpha_ptr, o_ptr, input_shape); alpha_ptr, o_ptr, dim[0], numel);
} else { } else {
math::PreluScalarDirectCUDAFunctor<T> prelu_scalar; math::PreluScalarDirectCUDAFunctor<T> prelu_scalar;
prelu_scalar(context.cuda_device_context().stream(), x_ptr, alpha_ptr, prelu_scalar(context.cuda_device_context().stream(), x_ptr, alpha_ptr,
o_ptr, input_shape); o_ptr, numel);
} }
} }
}; };
...@@ -75,7 +78,6 @@ __global__ void PReluOpGradKernel(const T* x_ptr, const T* alpha_ptr, ...@@ -75,7 +78,6 @@ __global__ void PReluOpGradKernel(const T* x_ptr, const T* alpha_ptr,
size_t channel_num, size_t plane_size, size_t channel_num, size_t plane_size,
size_t spatial_size, size_t numel, size_t spatial_size, size_t numel,
PRELU_MODE mode) { PRELU_MODE mode) {
size_t index;
CUDA_KERNEL_LOOP(index, numel) { CUDA_KERNEL_LOOP(index, numel) {
T scale; T scale;
if (mode == Element) { if (mode == Element) {
...@@ -99,14 +101,18 @@ template <typename T> ...@@ -99,14 +101,18 @@ template <typename T>
class PreluOpGradFunctor { class PreluOpGradFunctor {
public: public:
void operator()(cudaStream_t stream, const T* x, const T* alpha, const T* dy, void operator()(cudaStream_t stream, const T* x, const T* alpha, const T* dy,
T* dx, T* dalpha, std::vector<int> input_shape, T* dx, T* dalpha, const framework::DDim& input_dims,
PRELU_MODE mode) { PRELU_MODE mode) {
size_t plane_size = input_shape[2] * input_shape[3]; size_t numel = 1;
size_t spatial_size = plane_size * input_shape[1]; for (size_t i = 0; i < input_dims.size(); ++i) {
size_t numel = spatial_size * input_shape[0]; numel *= input_dims[i];
}
size_t plane_size = numel / input_dims[0] / input_dims[1];
size_t spatial_size = numel / input_dims[0];
PReluOpGradKernel< PReluOpGradKernel<
T><<<PADDLE_GET_BLOCKS(numel), CUDA_NUM_THREADS, 0, stream>>>( T><<<PADDLE_GET_BLOCKS(numel), CUDA_NUM_THREADS, 0, stream>>>(
x, alpha, dy, dx, dalpha, input_shape[1], plane_size, spatial_size, x, alpha, dy, dx, dalpha, input_dims[1], plane_size, spatial_size,
numel, mode); numel, mode);
} }
}; };
...@@ -161,13 +167,13 @@ class CUDAPReluGradKernel : public framework::OpKernel<T> { ...@@ -161,13 +167,13 @@ class CUDAPReluGradKernel : public framework::OpKernel<T> {
m = Scalar; m = Scalar;
} }
PreluOpGradFunctor<T> prelu_grad; PreluOpGradFunctor<T> prelu_grad;
prelu_grad(stream, x_ptr, alpha_ptr, dy_ptr, dx_ptr, dalpha_tmp_ptr, prelu_grad(stream, x_ptr, alpha_ptr, dy_ptr, dx_ptr, dalpha_tmp_ptr, dim,
input_shape, m); m);
if (dalpha_tmp_ptr == nullptr) return; if (dalpha_tmp_ptr == nullptr) return;
std::vector<int> reduce_dims; std::vector<int> reduce_dims;
for (size_t i = 0; i < input_shape.size(); i++) { for (size_t i = 0; i < dim.size(); i++) {
if (mode == "channel" && i == 1) continue; if (mode == "channel" && i == 1) continue;
if (mode == "element" && i != 0) continue; if (mode == "element" && i != 0) continue;
reduce_dims.push_back(i); reduce_dims.push_back(i);
......
...@@ -2297,7 +2297,10 @@ class PRelu(layers.Layer): ...@@ -2297,7 +2297,10 @@ class PRelu(layers.Layer):
assert isinstance( assert isinstance(
channel, channel,
int), "channel argument is required when mode is 'channel'." int), "channel argument is required when mode is 'channel'."
self._alpha_shape = [1, channel, 1, 1] #NOTE(zhiqiu): The _alpha_shape should be [1, channel] + [1] * len(input_shape[2:]), not [1, channel, 1, 1].
# However, the suffix 1 in the list is useless, since the tensor is viewed as one demension array during kernel calculation.
# And, input_shape is not required when mode is 'channel', so it is simplified.
self._alpha_shape = [1, channel]
elif mode == 'element': elif mode == 'element':
assert isinstance(input_shape, ( assert isinstance(input_shape, (
list, tuple list, tuple
......
...@@ -9626,10 +9626,20 @@ def prelu(x, mode, param_attr=None, name=None): ...@@ -9626,10 +9626,20 @@ def prelu(x, mode, param_attr=None, name=None):
if mode not in ['all', 'channel', 'element']: if mode not in ['all', 'channel', 'element']:
raise ValueError('mode should be one of all, channel, element.') raise ValueError('mode should be one of all, channel, element.')
alpha_shape = [1] alpha_shape = [1]
# NOTE(): The input of this API should be ``N,C,...`` format,
# which means x.shape[0] is batch_size and x.shape[0] is channel.
if mode == 'channel': if mode == 'channel':
alpha_shape = [1, x.shape[1], 1, 1] assert len(
x.shape
) >= 2, "The size of input shape should be equal or larger than 2 in prelu() when mode is 'channel'"
#NOTE(zhiqiu): The alpha_shape should be [1, channel] + [1] * len(x.shape[2:]).
# To be consistent with Prelu, it is simplified.
alpha_shape = [1, x.shape[1]]
elif mode == 'element': elif mode == 'element':
alpha_shape = [1, x.shape[1], x.shape[2], x.shape[3]] assert len(
x.shape
) >= 1, "The size of input shape should be equal or larger than 1 in prelu() when mode is 'element'"
alpha_shape = [1] + list(x.shape)[1:]
dtype = helper.input_dtype(input_param_name='x') dtype = helper.input_dtype(input_param_name='x')
alpha = helper.create_parameter( alpha = helper.create_parameter(
attr=helper.param_attr, attr=helper.param_attr,
......
...@@ -51,12 +51,18 @@ class PReluTest(OpTest): ...@@ -51,12 +51,18 @@ class PReluTest(OpTest):
if self.attrs == {'mode': "all"}: if self.attrs == {'mode': "all"}:
alpha_np = np.random.uniform(-1, -0.5, (1)) alpha_np = np.random.uniform(-1, -0.5, (1))
elif self.attrs == {'mode': "channel"}: elif self.attrs == {'mode': "channel"}:
alpha_np = np.random.uniform(-1, -0.5, (1, x_np.shape[1], 1, 1)) alpha_np = np.random.uniform(-1, -0.5, [1, self.x_shape[1]])
else: else:
alpha_np = np.random.uniform(-1, -0.5, \ alpha_np = np.random.uniform(-1, -0.5, [1] + self.x_shape[1:])
(1, x_np.shape[1], x_np.shape[2], x_np.shape[3]))
self.inputs = {'X': x_np, 'Alpha': alpha_np} self.inputs = {'X': x_np, 'Alpha': alpha_np}
# NOTE(zhiqu): reshape inputs['Alpha'] from [1, 100] to [1, 100, 1, 1] since np operands could not be broadcast together with shapes (2,100,3,4) (1,100)
if self.attrs == {'mode': "channel"}:
self.inputs['Alpha'] = np.reshape(
self.inputs['Alpha'],
[1, self.x_shape[1]] + [1] * len(self.x_shape[2:]))
out_np = np.maximum(self.inputs['X'], 0.) out_np = np.maximum(self.inputs['X'], 0.)
out_np = out_np + np.minimum(self.inputs['X'], out_np = out_np + np.minimum(self.inputs['X'],
0.) * self.inputs['Alpha'] 0.) * self.inputs['Alpha']
...@@ -64,7 +70,7 @@ class PReluTest(OpTest): ...@@ -64,7 +70,7 @@ class PReluTest(OpTest):
self.outputs = {'Out': out_np} self.outputs = {'Out': out_np}
def init_input_shape(self): def init_input_shape(self):
self.x_shape = (2, 100, 3, 4) self.x_shape = [2, 100, 3, 4]
def init_attr(self): def init_attr(self):
self.attrs = {'mode': "channel"} self.attrs = {'mode': "channel"}
...@@ -81,7 +87,7 @@ class PReluTest(OpTest): ...@@ -81,7 +87,7 @@ class PReluTest(OpTest):
) )
class TestModeAll(PReluTest): class TestModeAll(PReluTest):
def init_input_shape(self): def init_input_shape(self):
self.x_shape = (2, 3, 4, 5) self.x_shape = [2, 3, 4, 5]
def init_attr(self): def init_attr(self):
self.attrs = {'mode': "all"} self.attrs = {'mode': "all"}
...@@ -89,7 +95,61 @@ class TestModeAll(PReluTest): ...@@ -89,7 +95,61 @@ class TestModeAll(PReluTest):
class TestModeElt(PReluTest): class TestModeElt(PReluTest):
def init_input_shape(self): def init_input_shape(self):
self.x_shape = (3, 2, 5, 10) self.x_shape = [3, 2, 5, 10]
def init_attr(self):
self.attrs = {'mode': "element"}
@skip_check_grad_ci(
reason="[skip shape check] Input(Alpha) must be 1-D and only has one data in 'all' mode"
)
class TestModeAllRank3(PReluTest):
def init_input_shape(self):
self.x_shape = [1, 200, 3]
def init_attr(self):
self.attrs = {'mode': "all"}
@skip_check_grad_ci(
reason="[skip shape check] Input(Alpha) must be 1-D and only has one data in 'all' mode"
)
class TestModeAllRank6(PReluTest):
def init_input_shape(self):
self.x_shape = [1, 2, 3, 4, 5, 6]
def init_attr(self):
self.attrs = {'mode': "all"}
class TestModeChannelRank3(PReluTest):
def init_input_shape(self):
self.x_shape = [1, 200, 3]
def init_attr(self):
self.attrs = {'mode': "channel"}
class TestModeChannelRank6(PReluTest):
def init_input_shape(self):
self.x_shape = [1, 100, 2, 2, 2, 2]
def init_attr(self):
self.attrs = {'mode': "channel"}
class TestModeElementRank3(PReluTest):
def init_input_shape(self):
self.x_shape = [3, 10, 10]
def init_attr(self):
self.attrs = {'mode': "element"}
class TestModeElementRank6(PReluTest):
def init_input_shape(self):
self.x_shape = [3, 2, 2, 4, 5, 2]
def init_attr(self): def init_attr(self):
self.attrs = {'mode': "element"} self.attrs = {'mode': "element"}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册