diff --git a/paddle/fluid/imperative/infer_shape_context.h b/paddle/fluid/imperative/infer_shape_context.h index a1f307e0e14b8848fc0c5fdd1095fb66211b49e5..65ac570bc7aa07a1a06e9deffcf797d6ef5d2519 100644 --- a/paddle/fluid/imperative/infer_shape_context.h +++ b/paddle/fluid/imperative/infer_shape_context.h @@ -285,10 +285,12 @@ class DygraphInferShapeContext : public framework::InferShapeContext { it, var_base_map_out_->end(), platform::errors::NotFound("can not find [%s] in output", name)); - PADDLE_ENFORCE_EQ(it->second.size(), dims.size(), - platform::errors::PreconditionNotMet( - "dim size [%d] is not match output var number [%d]", - dims.size(), it->second.size())); + PADDLE_ENFORCE_EQ(dims.size(), it->second.size(), + platform::errors::InvalidArgument( + "The number of dims is expected to be equal to the " + "number of Outputs(%s). But receieved: the number of " + "dims = %d, the number of Outputs(%s) = %d.", + name, dims.size(), name, it->second.size())); for (size_t i = 0; i < dims.size(); ++i) { if (it->second[i]) { diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc index 3106a20520f38aa7453aaa3f031b526261cb9281..57a5db88c53a393e80204ef94acc9fbc3472e2f7 100644 --- a/paddle/fluid/operators/conv_op.cc +++ b/paddle/fluid/operators/conv_op.cc @@ -30,10 +30,10 @@ limitations under the License. */ namespace paddle { namespace operators { -void ConvOp::InferShape(framework::InferShapeContext* ctx) const { +std::vector ConvOp::ComputeOutputShape( + framework::InferShapeContext* ctx) const { OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "Conv"); OP_INOUT_CHECK(ctx->HasInput("Filter"), "Input", "Filter", "Conv"); - OP_INOUT_CHECK(ctx->HasOutput("Output"), "Output", "Output", "Conv"); auto in_dims = ctx->GetInputDim("Input"); auto filter_dims = ctx->GetInputDim("Filter"); @@ -54,30 +54,30 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { PADDLE_ENFORCE_EQ( in_dims.size() == 4 || in_dims.size() == 5, true, platform::errors::InvalidArgument( - "The input of Op(conv) should be 4-D or 5-D Tensor. But " - "received: %u-D Tensor, the shape of input is [%s].", + "The input of Op(Conv) should be a 4-D or 5-D Tensor. But " + "received: input's dimension is %u, input's shape is [%s].", in_dims.size(), in_dims)); PADDLE_ENFORCE_EQ( in_dims.size(), filter_dims.size(), platform::errors::InvalidArgument( - "The input's dimension size and filter's dimension size of " - "Op(conv) should be equal. But received: the shape of input is [%s], " - "the dimension size of input is [%d], the shape of filter is [%s], " - "the dimension size of filter is [%d].", + "The input's dimension and filter's dimension of " + "Op(Conv) should be equal. But received: the input's shape is [%s], " + "the input's dimension is %d; the filter's shape is [%s], " + "the filter's dimension is %d.", in_dims, in_dims.size(), filter_dims, filter_dims.size())); int in_sub_stride_size = in_dims.size() - strides.size(); PADDLE_ENFORCE_EQ( in_dims.size(), strides.size() + 2U, platform::errors::InvalidArgument( - "The dimension size of input minus the size of " - "Attr(stride) must be euqal to 2 for Op(conv)." - "But received: the dimension size of input minus the size " - "of Attr(stride) is [%d], the " - "input's dimension size is [%d], the shape of input " - "is [%s], the Attr(stride)'s size is [%d].", - in_sub_stride_size, in_dims.size(), in_dims, strides.size())); + "The difference of input's dimension and Attr(strides)'s " + "length must be euqal to 2 for Op(Conv). " + "But received: input's dimension is %d, input's shape is [%s]; " + "Attr(stride)'s length is %d, Attr(stride) is [%s]; " + "difference of input's dimention and Attr(strides)'s length = %u.", + in_dims.size(), in_dims, strides.size(), + framework::make_ddim(strides), in_sub_stride_size)); const auto input_channels = channel_last ? in_dims[in_dims.size() - 1] : in_dims[1]; @@ -85,31 +85,31 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { PADDLE_ENFORCE_EQ( input_channels, filter_dims[1] * groups, platform::errors::InvalidArgument( - "The number of input channels should be equal to filter channels * " - "groups for Op(conv). But received: the input's channels is [%d], " - "the shape of input is [%s], the filter's channel is [%d], the shape " - "of filter is [%s], the groups is [%d], the data_format is %s. The " - "error may come from wrong data_format setting.", + "The number of input's channels should be equal to filter's channels " + "* groups for Op(Conv). But received: the input's channels is %d, " + "the input's shape is [%s]; the filter's channels is %d, the " + "filter's shape is [%s]; the groups is %d, the data_format is %s. " + "The error may come from wrong data_format setting.", input_channels, in_dims, filter_dims[1], filter_dims, groups, data_format)); PADDLE_ENFORCE_EQ( filter_dims[0] % groups, 0, platform::errors::InvalidArgument( - "The number of output channels of Op(conv) should be divided " - "by groups. But received: the output channels is [%d], the shape " - "of filter is [%s] (the first dimension of filter is output " - "channel), the groups is [%d].", + "The number of output's channels (filter's first dimension) of " + "Op(Conv) should be divided by groups. But received: " + "the output channels is %d, the filter's shape is [%s], " + "the groups is %d.", filter_dims[0], filter_dims, groups)); framework::DDim in_data_dims; - framework::DDim filter_data_dims; if (channel_last) { in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1); } else { in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size()); } - filter_data_dims = framework::slice_ddim(filter_dims, 2, filter_dims.size()); + framework::DDim filter_data_dims = + framework::slice_ddim(filter_dims, 2, filter_dims.size()); std::vector ksize = framework::vectorize(filter_data_dims); UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm, @@ -133,8 +133,7 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { output_shape.push_back(filter_dims[0]); } - ctx->SetOutputDim("Output", framework::make_ddim(output_shape)); - ctx->ShareLoD("Input", "Output"); + return output_shape; } framework::OpKernelType ConvOp::GetExpectedKernelType( diff --git a/paddle/fluid/operators/conv_op.h b/paddle/fluid/operators/conv_op.h index ec9adafaa0dd35a8264741946be7256248596048..85ef42548ab1f471f85f9e8272a676a2bbbf05d2 100644 --- a/paddle/fluid/operators/conv_op.h +++ b/paddle/fluid/operators/conv_op.h @@ -41,10 +41,13 @@ inline int ConvOutputSize(int input_size, int filter_size, int dilation, int output_size = (input_size + 2 * padding - dkernel) / stride + 1; PADDLE_ENFORCE_GT( output_size, 0, - "Due to the settings of padding(%d), filter_size(%d), dilation(%d) and " - "stride(%d), the output size is less than 0, please check " - "again. Input_size:%d", - padding, filter_size, dilation, stride, input_size); + platform::errors::InvalidArgument( + "The output's size is expected to be greater than 0. " + "But recieved: output's size is %d. The output's size is computed by " + "((input_size + 2 * padding - (dilation * (filter_size - 1) + 1)) / " + "stride + 1), where input_size is %d, padding is %d, " + "filter_size is %d, dilation is %d, stride is %d.", + output_size, input_size, padding, filter_size, dilation, stride)); return output_size; } @@ -53,13 +56,16 @@ inline int ConvOutputSize(int input_size, int filter_size, int dilation, int padding_1, int padding_2, int stride) { const int dkernel = dilation * (filter_size - 1) + 1; int output_size = (input_size + padding_1 + padding_2 - dkernel) / stride + 1; - PADDLE_ENFORCE_GT(output_size, 0, - "Due to the settings of padding(%d, %d), filter_size(%d), " - "dilation(%d) and " - "stride(%d), the output size is less than 0, please check " - "again. Input_size:%d", - padding_1, padding_2, filter_size, dilation, stride, - input_size); + PADDLE_ENFORCE_GT( + output_size, 0, + platform::errors::InvalidArgument( + "The output's size is expected to be greater than 0. " + "But recieved: output's size is %d. The output's size is computed by " + "((input_size + padding_1 + padding_2 - (dilation * (filter_size - " + "1) + 1)) / stride + 1), where input_size is %d, padding is " + "(%d, %d), filter_size is %d, dilation is %d, stride is %d.", + output_size, input_size, padding_1, padding_2, filter_size, dilation, + stride)); return output_size; } @@ -81,7 +87,13 @@ inline void UpdatePaddingAndDilation(std::vector* paddings, } else { PADDLE_ENFORCE_EQ( data_dims.size() * 2, paddings->size(), - "Paddings size should be the same or twice as the input data size."); + platform::errors::InvalidArgument( + "Attribute padding's size should be the same or twice as the " + "input's dimension. " + "But recieved: padding's size is %d, padding is [%s]; input's " + "dimension is %d, input's shape is [%s].", + paddings->size(), framework::make_ddim(*paddings), data_dims.size(), + data_dims)); } // when padding_algorithm is "VALID" or "SAME" @@ -252,9 +264,18 @@ class ConvOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput { class ConvOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override; + void InferShape(framework::InferShapeContext* ctx) const override { + std::vector output_shape = ComputeOutputShape(ctx); + + OP_INOUT_CHECK(ctx->HasOutput("Output"), "Output", "Output", "Conv"); + ctx->SetOutputDim("Output", framework::make_ddim(output_shape)); + ctx->ShareLoD("Input", "Output"); + } protected: + std::vector ComputeOutputShape( + framework::InferShapeContext* ctx) const; + framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override; diff --git a/paddle/fluid/operators/fill_constant_op.cc b/paddle/fluid/operators/fill_constant_op.cc index 9794ddc672ab3cc1c99257e03870c56caab0582d..f1a96f6f6ec6a146b7e81d64f975c10c1d96e4a2 100644 --- a/paddle/fluid/operators/fill_constant_op.cc +++ b/paddle/fluid/operators/fill_constant_op.cc @@ -30,7 +30,7 @@ class FillConstantOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_GE( shape[i], 0, platform::errors::InvalidArgument( - "Each value of attribute 'shape' is expected to be greater " + "Each value of attribute 'shape' is expected to be no less " "than 0. But recieved: shape[%u] = %d; shape = [%s].", i, shape[i], framework::make_ddim(shape))); } diff --git a/paddle/fluid/operators/fused/conv_fusion_op.cc b/paddle/fluid/operators/fused/conv_fusion_op.cc index b60ae9127c9355a477ed84b4f8852876ba3f24a9..6b94f4ea5bdd2ffc1cdf7684568fe1a7b2928278 100644 --- a/paddle/fluid/operators/fused/conv_fusion_op.cc +++ b/paddle/fluid/operators/fused/conv_fusion_op.cc @@ -65,108 +65,62 @@ class Conv2DFusionOp : public operators::ConvOp { protected: void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE_EQ(ctx->HasInput("Input"), true, - "Input(Input) of ConvOp should not be null."); - PADDLE_ENFORCE_EQ(ctx->HasInput("Filter"), true, - "Input(Filter) of ConvOp should not be null."); - auto in_dims = ctx->GetInputDim("Input"); - auto filter_dims = ctx->GetInputDim("Filter"); - - std::vector strides = ctx->Attrs().Get>("strides"); - std::vector paddings = ctx->Attrs().Get>("paddings"); - std::vector dilations = - ctx->Attrs().Get>("dilations"); - std::string padding_algorithm = - ctx->Attrs().Get("padding_algorithm"); - int groups = ctx->Attrs().Get("groups"); - - framework::DDim in_data_dims; - in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size()); - - PADDLE_ENFORCE_EQ( - in_dims.size() == 4 || in_dims.size() == 5, true, - "ShapeError: Conv_fusion input should be 4-D or 5-D tensor. But " - "received: %u-D Tensor," - "the shape of Conv_fusion input is [%s]", - in_dims.size(), in_dims); - - PADDLE_ENFORCE_EQ(in_dims.size(), filter_dims.size(), - "ShapeError: Conv_fusion input dimension and filter " - "dimension should be the " - "equal." - "But received: the shape of Conv_fusion input is [%s], " - "input dimension of Conv_fusion " - "input is [%d]," - "the shape of filter is [%s], the filter dimension of " - "Conv_fusion is [%d]", - in_dims, in_dims.size(), filter_dims, filter_dims.size()); - - int in_sub_stride_size = in_dims.size() - strides.size(); - PADDLE_ENFORCE_EQ( - in_dims.size() - strides.size() == 2U, true, - "ShapeError: the dimension of input minus the dimension of " - "stride must be euqal to 2." - "But received: the dimension of input minus the dimension " - "of stride is [%d], the" - "input dimension of Conv_fusion is [%d], the shape of Conv_fusion " - "input " - "is [%s], the stride" - "dimension of Conv_fusion is [%d]", - in_sub_stride_size, in_dims.size(), in_dims, strides.size()); - - const auto input_channels = in_dims[1]; + OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "Conv2DFusion"); + OP_INOUT_CHECK(ctx->HasInput("Bias"), "Input", "Bias", "Conv2DFusion"); + auto in_dims = ctx->GetInputDim("Input"); PADDLE_ENFORCE_EQ( - input_channels, filter_dims[1] * groups, - "ShapeError: The number of input channels should be equal to filter " - "channels * groups. But received: the input channels is [%d], the shape" - "of input is [%s], the filter channel is [%d], the shape of filter is " - "[%s]," - "the groups is [%d]", - in_dims[1], in_dims, filter_dims[1], filter_dims, groups); - PADDLE_ENFORCE_EQ( - filter_dims[0] % groups, 0, - "ShapeError: The number of output channels should be divided by groups." - "But received: the output channels is [%d], the shape of filter is [%s]" - "(the first dimension of filter is output channel), the groups is [%d]", - filter_dims[0], filter_dims, groups); - - framework::DDim filter_data_dims = - framework::slice_ddim(filter_dims, 2, filter_dims.size()); - std::vector ksize = framework::vectorize(filter_data_dims); - UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm, - in_data_dims, strides, ksize); - - std::vector output_shape({in_dims[0]}); - output_shape.push_back(filter_dims[0]); - - for (int i = 0; i < in_data_dims.size(); ++i) { - if ((!ctx->IsRuntime()) && - (in_data_dims[i] <= 0 || filter_dims[i + 2] <= 0)) { - output_shape.push_back(-1); - } else { - output_shape.push_back( - ConvOutputSize(in_data_dims[i], filter_dims[i + 2], dilations[i], - paddings[2 * i], paddings[2 * i + 1], strides[i])); - } - } - - PADDLE_ENFORCE_EQ(ctx->HasOutput("Output"), true, - "Output(Output) of ConvOp should not be null."); + in_dims.size(), 4U, + platform::errors::InvalidArgument( + "The input's dimension of Operator(Conv2DFusion) is expected " + "to be 4. But received: input's dimension = %u, shape = [%s].", + in_dims.size(), in_dims)); + + // In some case, attribute data_format is "AnyLayout". + std::string data_format = ctx->Attrs().Get("data_format"); + PADDLE_ENFORCE_NE( + data_format, "NHWC", + platform::errors::PermissionDenied( + "Operator(Conv2DFusion) only supports data format of " + "channel first (NCHW) now. But recieved: data_format = '%s'.", + data_format)); + + std::vector output_shape = ComputeOutputShape(ctx); ctx->SetOutputDim("Output", framework::make_ddim(output_shape)); + ctx->ShareLoD("Input", "Output"); - std::vector channels = + std::vector split_channels = ctx->Attrs().Get>("split_channels"); - if (channels.size()) { - PADDLE_ENFORCE_EQ(ctx->HasOutputs("Outputs"), true, - "Output(Outputs) of ConvOp should not be null."); - std::vector oshapes; - oshapes.reserve(channels.size()); - for (size_t i = 0; i < channels.size(); ++i) { - oshapes.push_back( - {output_shape[0], channels[i], output_shape[2], output_shape[3]}); + if (split_channels.size()) { + OP_INOUT_CHECK(ctx->HasOutputs("Outputs"), "Output", "Outputs", + "Conv2DFusion"); + PADDLE_ENFORCE_EQ( + ctx->Outputs("Outputs").size(), split_channels.size(), + platform::errors::InvalidArgument( + "The number of Output(Outputs) of operator 'Conv2DFusion' is " + "expected to be equal to the length of Attr(split_channels). But " + "reiceved: the number of Output(Outputs) = %u; the length of " + "Attr(split_channels) = %u, the content = [%s].", + ctx->Outputs("Outputs").size(), split_channels.size(), + framework::make_ddim(split_channels))); + + int split_channels_sum = 0; + std::vector output_shapes(split_channels.size()); + for (size_t i = 0; i < split_channels.size(); ++i) { + split_channels_sum += split_channels[i]; + output_shapes[i] = + framework::make_ddim({output_shape[0], split_channels[i], + output_shape[2], output_shape[3]}); } - ctx->SetOutputsDim("Outputs", oshapes); + PADDLE_ENFORCE_EQ( + split_channels_sum, output_shape[1], + platform::errors::InvalidArgument( + "The sum of Attr(split_channels) is expected to be equal to the " + "total output channels. But recieved: the sum of " + "Attr(split_channels) = %d, the total output channels = %d.", + split_channels_sum, output_shape[1])); + + ctx->SetOutputsDim("Outputs", output_shapes); } } }; diff --git a/paddle/fluid/operators/fused/conv_fusion_op.cu b/paddle/fluid/operators/fused/conv_fusion_op.cu index c4844300eaa67791d0c5a9edef408835e2648556..f2068d3aadf77594deb409a130e770a745b96741 100644 --- a/paddle/fluid/operators/fused/conv_fusion_op.cu +++ b/paddle/fluid/operators/fused/conv_fusion_op.cu @@ -46,7 +46,6 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel { auto* input = ctx.Input("Input"); auto* filter = ctx.Input("Filter"); auto* bias = ctx.Input("Bias"); - PADDLE_ENFORCE_NOT_NULL(bias, "The bias should not be null."); auto* residual = ctx.Input("ResidualData"); auto* output = ctx.Output("Output"); output->mutable_data(ctx.GetPlace()); @@ -61,28 +60,25 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel { bool exhaustive_search = FLAGS_cudnn_exhaustive_search || ctx.Attr("exhaustive_search"); - // const T* input_data = input->data(); const T* filter_data = filter->data(); const T* bias_data = bias->data(); - // T* output_data = output->mutable_data(ctx.GetPlace()); const std::string padding_algorithm = ctx.Attr("padding_algorithm"); - const std::string data_format = ctx.Attr("data_format"); Tensor transformed_input_channel(input->type()); Tensor transformed_output(output->type()); - T* output_data = nullptr; - transformed_input_channel = *input; transformed_output = *output; - output_data = transformed_output.data(); + T* output_data = transformed_output.data(); + const T* residual_data = residual ? residual->data() : output_data; + // update padding and dilation auto in_dims = transformed_input_channel.dims(); auto filter_dims = filter->dims(); - framework::DDim in_data_dims; - in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size()); + framework::DDim in_data_dims = + framework::slice_ddim(in_dims, 2, in_dims.size()); framework::DDim filter_data_dims = framework::slice_ddim(filter_dims, 2, filter_dims.size()); @@ -134,7 +130,10 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel { &transformed_input); } break; default: - PADDLE_THROW("ConvOp only support tensors with 4 or 5 dimensions."); + PADDLE_THROW(platform::errors::PermissionDenied( + "Operator Conv2DFusion expects Input to be a 4-D or 5-D Tensor. " + "But recieved the actual dimension = %d, shape = [%s].", + rank, transformed_input_channel.dims())); } } else { @@ -168,7 +167,13 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel { conv_desc.descriptor(padding_common, strides, dilations); PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::cudnnSetConvolutionGroupCount(cudnn_conv_desc, - groups)); + groups), + platform::errors::External( + "Call of cudnnSetConvolutionGroupCount(cudnn_conv_desc, groups) " + "failed, where cudnn_conv_desc is configured: padding = [%s], " + "strides = [%s], dilations = [%s]; groups = %d", + framework::make_ddim(padding_common), framework::make_ddim(strides), + framework::make_ddim(dilations), groups)); cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor( layout, framework::vectorize(transformed_input.dims())); @@ -199,8 +204,15 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel { auto handle = dev_ctx.cudnn_handle(); auto workspace_handle = dev_ctx.cudnn_workspace_handle(); - PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetConvolutionMathType( - cudnn_conv_desc, CUDNN_DEFAULT_MATH)); + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cudnnSetConvolutionMathType(cudnn_conv_desc, + CUDNN_DEFAULT_MATH), + platform::errors::External( + "Call of cudnnSetConvolutionMathType(cudnn_conv_desc, " + "CUDNN_DEFAULT_MATH) failed, where cudnn_conv_desc is configured: " + "padding = %d, strides = %d, dilations = %d.", + framework::make_ddim(padding_common), framework::make_ddim(strides), + framework::make_ddim(dilations))); auto x_dims = framework::vectorize(transformed_input.dims()); auto f_dims = framework::vectorize(filter->dims()); @@ -209,7 +221,9 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel { platform::dynload::cudnnGetConvolutionForwardAlgorithm( handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc, cudnn_output_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, - workspace_size_limit, &algo)); + workspace_size_limit, &algo), + platform::errors::External( + "Call of cudnnGetConvolutionForwardAlgorithm failed.")); VLOG(3) << "cuDNN forward algo " << algo; } else { std::function search_func = @@ -223,7 +237,9 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel { handle, cudnn_input_desc, input_data, cudnn_filter_desc, filter_data, cudnn_conv_desc, cudnn_output_desc, output_data, kNUM_CUDNN_FWD_ALGS, &returned_algo_count, - fwd_perf_stat.data(), cudnn_workspace, workspace_size_limit)); + fwd_perf_stat.data(), cudnn_workspace, workspace_size_limit), + platform::errors::External( + "Call of cudnnFindConvolutionForwardAlgorithmEx failed.")); }; workspace_handle.RunFuncSync(cudnn_find_func, workspace_size_limit); VLOG(3) << "Perf result: (algo: stat, time, memory)"; @@ -257,9 +273,16 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel { PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::cudnnGetConvolutionForwardWorkspaceSize( handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc, - cudnn_output_desc, algo, &workspace_size_in_bytes)); - PADDLE_ENFORCE_LE(workspace_size_in_bytes, workspace_size_limit, - "workspace_size to be allocated exceeds the limit"); + cudnn_output_desc, algo, &workspace_size_in_bytes), + platform::errors::External( + "Call of cudnnGetConvolutionForwardWorkspaceSize failed.")); + PADDLE_ENFORCE_LE( + workspace_size_in_bytes, workspace_size_limit, + platform::errors::InvalidArgument( + "The actual workspace size to be allocated for cuDNN is expected " + "to be less than the limit. But recieved: the actual workspace " + "size = %d, limit = %d.", + workspace_size_in_bytes, workspace_size_limit)); if ((activation == "identity") && (!residual)) { // Only the CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM algo is @@ -269,15 +292,20 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel { // ------------- cudnn conv forward and bias add --------------------- ScalingParamType alpha = 1.0f, beta = 0.0f; auto cudnn_func = [&](void* cudnn_workspace) { - PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnConvolutionForward( - handle, &alpha, cudnn_input_desc, input_data, cudnn_filter_desc, - filter_data, cudnn_conv_desc, algo, cudnn_workspace, - workspace_size_in_bytes, &beta, cudnn_output_desc, output_data)); + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cudnnConvolutionForward( + handle, &alpha, cudnn_input_desc, input_data, cudnn_filter_desc, + filter_data, cudnn_conv_desc, algo, cudnn_workspace, + workspace_size_in_bytes, &beta, cudnn_output_desc, output_data), + platform::errors::External( + "Call of cudnnConvolutionForward failed.")); }; workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes); - PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnAddTensor( - handle, &alpha, cudnn_bias_desc, bias_data, &alpha, cudnn_output_desc, - output_data)); + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cudnnAddTensor(handle, &alpha, cudnn_bias_desc, + bias_data, &alpha, + cudnn_output_desc, output_data), + platform::errors::External("Call of cudnnAddTensor failed.")); } else { if (activation == "identity") { algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; @@ -292,7 +320,9 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel { cudnn_filter_desc, filter_data, cudnn_conv_desc, algo, cudnn_workspace, workspace_size_in_bytes, &alpha2, cudnn_output_desc, residual_data, cudnn_bias_desc, bias_data, - cudnn_act_desc, cudnn_output_desc, output_data)); + cudnn_act_desc, cudnn_output_desc, output_data), + platform::errors::External( + "Call of cudnnConvolutionBiasActivationForward failed.")); }; workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes); } @@ -314,7 +344,10 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel { } } else { // TODO(qingiqng): do copy when batch size large than 1 - PADDLE_THROW("Batch size greater than 1 is Unsupported"); + PADDLE_THROW(platform::errors::Unimplemented( + "Input with batch size greater than 1 is unsupported. The recieved " + "batch size is %d, Input's shape is [%s].", + x_dims[0], framework::make_ddim(x_dims))); } } } diff --git a/python/paddle/fluid/tests/unittests/test_conv2d_fusion_op.py b/python/paddle/fluid/tests/unittests/test_conv2d_fusion_op.py index 94ea62d793212064e0fbabaf54dc50d7f25dd6ad..dd1e69f74b3c321aa044b8aa3655b2de70ea87d5 100644 --- a/python/paddle/fluid/tests/unittests/test_conv2d_fusion_op.py +++ b/python/paddle/fluid/tests/unittests/test_conv2d_fusion_op.py @@ -24,14 +24,14 @@ from test_conv2d_op import conv2d_forward_naive def create_test_padding_SAME_class(parent): - class TestPaddingSMAECase(parent): + class TestPaddingSAMECase(parent): def init_paddings(self): self.pad = [0, 0] self.padding_algorithm = "SAME" cls_name = "{0}_{1}".format(parent.__name__, "PaddingSAMEOp") - TestPaddingSMAECase.__name__ = cls_name - globals()[cls_name] = TestPaddingSMAECase + TestPaddingSAMECase.__name__ = cls_name + globals()[cls_name] = TestPaddingSAMECase def create_test_padding_VALID_class(parent): @@ -52,16 +52,15 @@ class TestConv2dFusionOp(OpTest): self.data_format = "NCHW" self.dtype = np.float32 self.activation = 'relu' - self.add_bias = True self.add_residual_data = True - self.channels = None + self.split_channels = None self.outputs = None self.padding_algorithm = "EXIPLICIT" self.init_group() self.init_dilation() self.init_test_case() - self.init_bias_residual() + self.init_residual() self.init_activation() self.init_paddings() self.set_search_method() @@ -74,6 +73,7 @@ class TestConv2dFusionOp(OpTest): input = np.random.random(self.input_size).astype(self.dtype) filter = np.random.random(self.filter_size).astype(self.dtype) + bias = np.random.random(self.filter_size[0]).astype(self.dtype) self.output, _, _, _, _ = conv2d_forward_naive( input, filter, self.groups, conv2d_param, self.padding_algorithm, @@ -83,7 +83,8 @@ class TestConv2dFusionOp(OpTest): self.inputs = { 'Input': OpTest.np_dtype_to_fluid_dtype(input), - 'Filter': OpTest.np_dtype_to_fluid_dtype(filter) + 'Filter': OpTest.np_dtype_to_fluid_dtype(filter), + 'Bias': OpTest.np_dtype_to_fluid_dtype(bias) } if self.add_residual_data: @@ -93,10 +94,8 @@ class TestConv2dFusionOp(OpTest): residual_data) self.output += residual_data - if self.add_bias: - bias = np.random.random(self.filter_size[0]).astype(self.dtype) - self.inputs['Bias'] = OpTest.np_dtype_to_fluid_dtype(bias) - self.output = self.output + bias.reshape((1, bias.size, 1, 1)) + # Add bias + self.output = self.output + bias.reshape((1, bias.size, 1, 1)) assert self.activation in ['relu', 'identity'] if self.activation == 'relu': @@ -110,9 +109,11 @@ class TestConv2dFusionOp(OpTest): 'data_format': self.data_format, 'exhaustive_search': self.exhaustive_search, 'activation': self.activation, - 'split_channels': self.channels, 'padding_algorithm': self.padding_algorithm } + if self.split_channels is not None: + self.attrs['split_channels'] = self.split_channels + self.outputs = {'Output': self.output} self.set_outputs() @@ -124,8 +125,6 @@ class TestConv2dFusionOp(OpTest): if self.has_cuda(): place = core.CUDAPlace(0) self.check_output_with_place(place, atol=1e-5) - else: - pass def init_test_case(self): self.pad = [0, 0] @@ -141,8 +140,7 @@ class TestConv2dFusionOp(OpTest): def init_group(self): self.groups = 1 - def init_bias_residual(self): - self.add_bias = True + def init_residual(self): self.add_residual_data = True def init_activation(self): @@ -160,7 +158,7 @@ class TestConv2dFusionOp(OpTest): class TestWithoutResidual(TestConv2dFusionOp): - def init_bias_residual(self): + def init_residual(self): self.add_residual_data = False @@ -209,7 +207,7 @@ class TestMultipleOutputs(TestConv2dFusionOp): assert np.mod(self.input_size[1], self.groups) == 0 f_c = self.input_size[1] // self.groups self.filter_size = [126, f_c, 3, 3] - self.channels = [84, 42] + self.split_channels = [84, 42] def set_outputs(self): out1 = self.output[:, 0:84, :, :]