未验证 提交 8d0b0cb4 编写于 作者: Y Yiqun Liu 提交者: GitHub

Op(conv2d_fusion) error message enhancement. (#23596)

上级 21eca836
...@@ -285,10 +285,12 @@ class DygraphInferShapeContext : public framework::InferShapeContext { ...@@ -285,10 +285,12 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
it, var_base_map_out_->end(), it, var_base_map_out_->end(),
platform::errors::NotFound("can not find [%s] in output", name)); platform::errors::NotFound("can not find [%s] in output", name));
PADDLE_ENFORCE_EQ(it->second.size(), dims.size(), PADDLE_ENFORCE_EQ(dims.size(), it->second.size(),
platform::errors::PreconditionNotMet( platform::errors::InvalidArgument(
"dim size [%d] is not match output var number [%d]", "The number of dims is expected to be equal to the "
dims.size(), it->second.size())); "number of Outputs(%s). But receieved: the number of "
"dims = %d, the number of Outputs(%s) = %d.",
name, dims.size(), name, it->second.size()));
for (size_t i = 0; i < dims.size(); ++i) { for (size_t i = 0; i < dims.size(); ++i) {
if (it->second[i]) { if (it->second[i]) {
......
...@@ -30,10 +30,10 @@ limitations under the License. */ ...@@ -30,10 +30,10 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
void ConvOp::InferShape(framework::InferShapeContext* ctx) const { std::vector<int64_t> ConvOp::ComputeOutputShape(
framework::InferShapeContext* ctx) const {
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "Conv"); OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "Conv");
OP_INOUT_CHECK(ctx->HasInput("Filter"), "Input", "Filter", "Conv"); OP_INOUT_CHECK(ctx->HasInput("Filter"), "Input", "Filter", "Conv");
OP_INOUT_CHECK(ctx->HasOutput("Output"), "Output", "Output", "Conv");
auto in_dims = ctx->GetInputDim("Input"); auto in_dims = ctx->GetInputDim("Input");
auto filter_dims = ctx->GetInputDim("Filter"); auto filter_dims = ctx->GetInputDim("Filter");
...@@ -54,30 +54,30 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -54,30 +54,30 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
in_dims.size() == 4 || in_dims.size() == 5, true, in_dims.size() == 4 || in_dims.size() == 5, true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The input of Op(conv) should be 4-D or 5-D Tensor. But " "The input of Op(Conv) should be a 4-D or 5-D Tensor. But "
"received: %u-D Tensor, the shape of input is [%s].", "received: input's dimension is %u, input's shape is [%s].",
in_dims.size(), in_dims)); in_dims.size(), in_dims));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
in_dims.size(), filter_dims.size(), in_dims.size(), filter_dims.size(),
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The input's dimension size and filter's dimension size of " "The input's dimension and filter's dimension of "
"Op(conv) should be equal. But received: the shape of input is [%s], " "Op(Conv) should be equal. But received: the input's shape is [%s], "
"the dimension size of input is [%d], the shape of filter is [%s], " "the input's dimension is %d; the filter's shape is [%s], "
"the dimension size of filter is [%d].", "the filter's dimension is %d.",
in_dims, in_dims.size(), filter_dims, filter_dims.size())); in_dims, in_dims.size(), filter_dims, filter_dims.size()));
int in_sub_stride_size = in_dims.size() - strides.size(); int in_sub_stride_size = in_dims.size() - strides.size();
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
in_dims.size(), strides.size() + 2U, in_dims.size(), strides.size() + 2U,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The dimension size of input minus the size of " "The difference of input's dimension and Attr(strides)'s "
"Attr(stride) must be euqal to 2 for Op(conv)." "length must be euqal to 2 for Op(Conv). "
"But received: the dimension size of input minus the size " "But received: input's dimension is %d, input's shape is [%s]; "
"of Attr(stride) is [%d], the " "Attr(stride)'s length is %d, Attr(stride) is [%s]; "
"input's dimension size is [%d], the shape of input " "difference of input's dimention and Attr(strides)'s length = %u.",
"is [%s], the Attr(stride)'s size is [%d].", in_dims.size(), in_dims, strides.size(),
in_sub_stride_size, in_dims.size(), in_dims, strides.size())); framework::make_ddim(strides), in_sub_stride_size));
const auto input_channels = const auto input_channels =
channel_last ? in_dims[in_dims.size() - 1] : in_dims[1]; channel_last ? in_dims[in_dims.size() - 1] : in_dims[1];
...@@ -85,31 +85,31 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -85,31 +85,31 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
input_channels, filter_dims[1] * groups, input_channels, filter_dims[1] * groups,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The number of input channels should be equal to filter channels * " "The number of input's channels should be equal to filter's channels "
"groups for Op(conv). But received: the input's channels is [%d], " "* groups for Op(Conv). But received: the input's channels is %d, "
"the shape of input is [%s], the filter's channel is [%d], the shape " "the input's shape is [%s]; the filter's channels is %d, the "
"of filter is [%s], the groups is [%d], the data_format is %s. The " "filter's shape is [%s]; the groups is %d, the data_format is %s. "
"error may come from wrong data_format setting.", "The error may come from wrong data_format setting.",
input_channels, in_dims, filter_dims[1], filter_dims, groups, input_channels, in_dims, filter_dims[1], filter_dims, groups,
data_format)); data_format));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
filter_dims[0] % groups, 0, filter_dims[0] % groups, 0,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The number of output channels of Op(conv) should be divided " "The number of output's channels (filter's first dimension) of "
"by groups. But received: the output channels is [%d], the shape " "Op(Conv) should be divided by groups. But received: "
"of filter is [%s] (the first dimension of filter is output " "the output channels is %d, the filter's shape is [%s], "
"channel), the groups is [%d].", "the groups is %d.",
filter_dims[0], filter_dims, groups)); filter_dims[0], filter_dims, groups));
framework::DDim in_data_dims; framework::DDim in_data_dims;
framework::DDim filter_data_dims;
if (channel_last) { if (channel_last) {
in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1); in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1);
} else { } else {
in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size()); in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
} }
filter_data_dims = framework::slice_ddim(filter_dims, 2, filter_dims.size()); framework::DDim filter_data_dims =
framework::slice_ddim(filter_dims, 2, filter_dims.size());
std::vector<int> ksize = framework::vectorize<int>(filter_data_dims); std::vector<int> ksize = framework::vectorize<int>(filter_data_dims);
UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm, UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
...@@ -133,8 +133,7 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -133,8 +133,7 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
output_shape.push_back(filter_dims[0]); output_shape.push_back(filter_dims[0]);
} }
ctx->SetOutputDim("Output", framework::make_ddim(output_shape)); return output_shape;
ctx->ShareLoD("Input", "Output");
} }
framework::OpKernelType ConvOp::GetExpectedKernelType( framework::OpKernelType ConvOp::GetExpectedKernelType(
......
...@@ -41,10 +41,13 @@ inline int ConvOutputSize(int input_size, int filter_size, int dilation, ...@@ -41,10 +41,13 @@ inline int ConvOutputSize(int input_size, int filter_size, int dilation,
int output_size = (input_size + 2 * padding - dkernel) / stride + 1; int output_size = (input_size + 2 * padding - dkernel) / stride + 1;
PADDLE_ENFORCE_GT( PADDLE_ENFORCE_GT(
output_size, 0, output_size, 0,
"Due to the settings of padding(%d), filter_size(%d), dilation(%d) and " platform::errors::InvalidArgument(
"stride(%d), the output size is less than 0, please check " "The output's size is expected to be greater than 0. "
"again. Input_size:%d", "But recieved: output's size is %d. The output's size is computed by "
padding, filter_size, dilation, stride, input_size); "((input_size + 2 * padding - (dilation * (filter_size - 1) + 1)) / "
"stride + 1), where input_size is %d, padding is %d, "
"filter_size is %d, dilation is %d, stride is %d.",
output_size, input_size, padding, filter_size, dilation, stride));
return output_size; return output_size;
} }
...@@ -53,13 +56,16 @@ inline int ConvOutputSize(int input_size, int filter_size, int dilation, ...@@ -53,13 +56,16 @@ inline int ConvOutputSize(int input_size, int filter_size, int dilation,
int padding_1, int padding_2, int stride) { int padding_1, int padding_2, int stride) {
const int dkernel = dilation * (filter_size - 1) + 1; const int dkernel = dilation * (filter_size - 1) + 1;
int output_size = (input_size + padding_1 + padding_2 - dkernel) / stride + 1; int output_size = (input_size + padding_1 + padding_2 - dkernel) / stride + 1;
PADDLE_ENFORCE_GT(output_size, 0, PADDLE_ENFORCE_GT(
"Due to the settings of padding(%d, %d), filter_size(%d), " output_size, 0,
"dilation(%d) and " platform::errors::InvalidArgument(
"stride(%d), the output size is less than 0, please check " "The output's size is expected to be greater than 0. "
"again. Input_size:%d", "But recieved: output's size is %d. The output's size is computed by "
padding_1, padding_2, filter_size, dilation, stride, "((input_size + padding_1 + padding_2 - (dilation * (filter_size - "
input_size); "1) + 1)) / stride + 1), where input_size is %d, padding is "
"(%d, %d), filter_size is %d, dilation is %d, stride is %d.",
output_size, input_size, padding_1, padding_2, filter_size, dilation,
stride));
return output_size; return output_size;
} }
...@@ -81,7 +87,13 @@ inline void UpdatePaddingAndDilation(std::vector<T>* paddings, ...@@ -81,7 +87,13 @@ inline void UpdatePaddingAndDilation(std::vector<T>* paddings,
} else { } else {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
data_dims.size() * 2, paddings->size(), data_dims.size() * 2, paddings->size(),
"Paddings size should be the same or twice as the input data size."); platform::errors::InvalidArgument(
"Attribute padding's size should be the same or twice as the "
"input's dimension. "
"But recieved: padding's size is %d, padding is [%s]; input's "
"dimension is %d, input's shape is [%s].",
paddings->size(), framework::make_ddim(*paddings), data_dims.size(),
data_dims));
} }
// when padding_algorithm is "VALID" or "SAME" // when padding_algorithm is "VALID" or "SAME"
...@@ -252,9 +264,18 @@ class ConvOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput { ...@@ -252,9 +264,18 @@ class ConvOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
class ConvOp : public framework::OperatorWithKernel { class ConvOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override; void InferShape(framework::InferShapeContext* ctx) const override {
std::vector<int64_t> output_shape = ComputeOutputShape(ctx);
OP_INOUT_CHECK(ctx->HasOutput("Output"), "Output", "Output", "Conv");
ctx->SetOutputDim("Output", framework::make_ddim(output_shape));
ctx->ShareLoD("Input", "Output");
}
protected: protected:
std::vector<int64_t> ComputeOutputShape(
framework::InferShapeContext* ctx) const;
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override; const framework::ExecutionContext& ctx) const override;
......
...@@ -30,7 +30,7 @@ class FillConstantOp : public framework::OperatorWithKernel { ...@@ -30,7 +30,7 @@ class FillConstantOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(
shape[i], 0, shape[i], 0,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Each value of attribute 'shape' is expected to be greater " "Each value of attribute 'shape' is expected to be no less "
"than 0. But recieved: shape[%u] = %d; shape = [%s].", "than 0. But recieved: shape[%u] = %d; shape = [%s].",
i, shape[i], framework::make_ddim(shape))); i, shape[i], framework::make_ddim(shape)));
} }
......
...@@ -65,108 +65,62 @@ class Conv2DFusionOp : public operators::ConvOp { ...@@ -65,108 +65,62 @@ class Conv2DFusionOp : public operators::ConvOp {
protected: protected:
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("Input"), true, OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "Conv2DFusion");
"Input(Input) of ConvOp should not be null."); OP_INOUT_CHECK(ctx->HasInput("Bias"), "Input", "Bias", "Conv2DFusion");
PADDLE_ENFORCE_EQ(ctx->HasInput("Filter"), true,
"Input(Filter) of ConvOp should not be null.");
auto in_dims = ctx->GetInputDim("Input");
auto filter_dims = ctx->GetInputDim("Filter");
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
std::vector<int> dilations =
ctx->Attrs().Get<std::vector<int>>("dilations");
std::string padding_algorithm =
ctx->Attrs().Get<std::string>("padding_algorithm");
int groups = ctx->Attrs().Get<int>("groups");
framework::DDim in_data_dims;
in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
PADDLE_ENFORCE_EQ(
in_dims.size() == 4 || in_dims.size() == 5, true,
"ShapeError: Conv_fusion input should be 4-D or 5-D tensor. But "
"received: %u-D Tensor,"
"the shape of Conv_fusion input is [%s]",
in_dims.size(), in_dims);
PADDLE_ENFORCE_EQ(in_dims.size(), filter_dims.size(),
"ShapeError: Conv_fusion input dimension and filter "
"dimension should be the "
"equal."
"But received: the shape of Conv_fusion input is [%s], "
"input dimension of Conv_fusion "
"input is [%d],"
"the shape of filter is [%s], the filter dimension of "
"Conv_fusion is [%d]",
in_dims, in_dims.size(), filter_dims, filter_dims.size());
int in_sub_stride_size = in_dims.size() - strides.size();
PADDLE_ENFORCE_EQ(
in_dims.size() - strides.size() == 2U, true,
"ShapeError: the dimension of input minus the dimension of "
"stride must be euqal to 2."
"But received: the dimension of input minus the dimension "
"of stride is [%d], the"
"input dimension of Conv_fusion is [%d], the shape of Conv_fusion "
"input "
"is [%s], the stride"
"dimension of Conv_fusion is [%d]",
in_sub_stride_size, in_dims.size(), in_dims, strides.size());
const auto input_channels = in_dims[1];
auto in_dims = ctx->GetInputDim("Input");
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
input_channels, filter_dims[1] * groups, in_dims.size(), 4U,
"ShapeError: The number of input channels should be equal to filter " platform::errors::InvalidArgument(
"channels * groups. But received: the input channels is [%d], the shape" "The input's dimension of Operator(Conv2DFusion) is expected "
"of input is [%s], the filter channel is [%d], the shape of filter is " "to be 4. But received: input's dimension = %u, shape = [%s].",
"[%s]," in_dims.size(), in_dims));
"the groups is [%d]",
in_dims[1], in_dims, filter_dims[1], filter_dims, groups); // In some case, attribute data_format is "AnyLayout".
PADDLE_ENFORCE_EQ( std::string data_format = ctx->Attrs().Get<std::string>("data_format");
filter_dims[0] % groups, 0, PADDLE_ENFORCE_NE(
"ShapeError: The number of output channels should be divided by groups." data_format, "NHWC",
"But received: the output channels is [%d], the shape of filter is [%s]" platform::errors::PermissionDenied(
"(the first dimension of filter is output channel), the groups is [%d]", "Operator(Conv2DFusion) only supports data format of "
filter_dims[0], filter_dims, groups); "channel first (NCHW) now. But recieved: data_format = '%s'.",
data_format));
framework::DDim filter_data_dims =
framework::slice_ddim(filter_dims, 2, filter_dims.size()); std::vector<int64_t> output_shape = ComputeOutputShape(ctx);
std::vector<int> ksize = framework::vectorize<int>(filter_data_dims);
UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
in_data_dims, strides, ksize);
std::vector<int64_t> output_shape({in_dims[0]});
output_shape.push_back(filter_dims[0]);
for (int i = 0; i < in_data_dims.size(); ++i) {
if ((!ctx->IsRuntime()) &&
(in_data_dims[i] <= 0 || filter_dims[i + 2] <= 0)) {
output_shape.push_back(-1);
} else {
output_shape.push_back(
ConvOutputSize(in_data_dims[i], filter_dims[i + 2], dilations[i],
paddings[2 * i], paddings[2 * i + 1], strides[i]));
}
}
PADDLE_ENFORCE_EQ(ctx->HasOutput("Output"), true,
"Output(Output) of ConvOp should not be null.");
ctx->SetOutputDim("Output", framework::make_ddim(output_shape)); ctx->SetOutputDim("Output", framework::make_ddim(output_shape));
ctx->ShareLoD("Input", "Output");
std::vector<int> channels = std::vector<int> split_channels =
ctx->Attrs().Get<std::vector<int>>("split_channels"); ctx->Attrs().Get<std::vector<int>>("split_channels");
if (channels.size()) { if (split_channels.size()) {
PADDLE_ENFORCE_EQ(ctx->HasOutputs("Outputs"), true, OP_INOUT_CHECK(ctx->HasOutputs("Outputs"), "Output", "Outputs",
"Output(Outputs) of ConvOp should not be null."); "Conv2DFusion");
std::vector<framework::DDim> oshapes; PADDLE_ENFORCE_EQ(
oshapes.reserve(channels.size()); ctx->Outputs("Outputs").size(), split_channels.size(),
for (size_t i = 0; i < channels.size(); ++i) { platform::errors::InvalidArgument(
oshapes.push_back( "The number of Output(Outputs) of operator 'Conv2DFusion' is "
{output_shape[0], channels[i], output_shape[2], output_shape[3]}); "expected to be equal to the length of Attr(split_channels). But "
"reiceved: the number of Output(Outputs) = %u; the length of "
"Attr(split_channels) = %u, the content = [%s].",
ctx->Outputs("Outputs").size(), split_channels.size(),
framework::make_ddim(split_channels)));
int split_channels_sum = 0;
std::vector<framework::DDim> output_shapes(split_channels.size());
for (size_t i = 0; i < split_channels.size(); ++i) {
split_channels_sum += split_channels[i];
output_shapes[i] =
framework::make_ddim({output_shape[0], split_channels[i],
output_shape[2], output_shape[3]});
} }
ctx->SetOutputsDim("Outputs", oshapes); PADDLE_ENFORCE_EQ(
split_channels_sum, output_shape[1],
platform::errors::InvalidArgument(
"The sum of Attr(split_channels) is expected to be equal to the "
"total output channels. But recieved: the sum of "
"Attr(split_channels) = %d, the total output channels = %d.",
split_channels_sum, output_shape[1]));
ctx->SetOutputsDim("Outputs", output_shapes);
} }
} }
}; };
......
...@@ -46,7 +46,6 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> { ...@@ -46,7 +46,6 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
auto* input = ctx.Input<Tensor>("Input"); auto* input = ctx.Input<Tensor>("Input");
auto* filter = ctx.Input<Tensor>("Filter"); auto* filter = ctx.Input<Tensor>("Filter");
auto* bias = ctx.Input<Tensor>("Bias"); auto* bias = ctx.Input<Tensor>("Bias");
PADDLE_ENFORCE_NOT_NULL(bias, "The bias should not be null.");
auto* residual = ctx.Input<Tensor>("ResidualData"); auto* residual = ctx.Input<Tensor>("ResidualData");
auto* output = ctx.Output<Tensor>("Output"); auto* output = ctx.Output<Tensor>("Output");
output->mutable_data<T>(ctx.GetPlace()); output->mutable_data<T>(ctx.GetPlace());
...@@ -61,28 +60,25 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> { ...@@ -61,28 +60,25 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
bool exhaustive_search = bool exhaustive_search =
FLAGS_cudnn_exhaustive_search || ctx.Attr<bool>("exhaustive_search"); FLAGS_cudnn_exhaustive_search || ctx.Attr<bool>("exhaustive_search");
// const T* input_data = input->data<T>();
const T* filter_data = filter->data<T>(); const T* filter_data = filter->data<T>();
const T* bias_data = bias->data<T>(); const T* bias_data = bias->data<T>();
// T* output_data = output->mutable_data<T>(ctx.GetPlace());
const std::string padding_algorithm = const std::string padding_algorithm =
ctx.Attr<std::string>("padding_algorithm"); ctx.Attr<std::string>("padding_algorithm");
const std::string data_format = ctx.Attr<std::string>("data_format");
Tensor transformed_input_channel(input->type()); Tensor transformed_input_channel(input->type());
Tensor transformed_output(output->type()); Tensor transformed_output(output->type());
T* output_data = nullptr;
transformed_input_channel = *input; transformed_input_channel = *input;
transformed_output = *output; transformed_output = *output;
output_data = transformed_output.data<T>(); T* output_data = transformed_output.data<T>();
const T* residual_data = residual ? residual->data<T>() : output_data; const T* residual_data = residual ? residual->data<T>() : output_data;
// update padding and dilation // update padding and dilation
auto in_dims = transformed_input_channel.dims(); auto in_dims = transformed_input_channel.dims();
auto filter_dims = filter->dims(); auto filter_dims = filter->dims();
framework::DDim in_data_dims; framework::DDim in_data_dims =
in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size()); framework::slice_ddim(in_dims, 2, in_dims.size());
framework::DDim filter_data_dims = framework::DDim filter_data_dims =
framework::slice_ddim(filter_dims, 2, filter_dims.size()); framework::slice_ddim(filter_dims, 2, filter_dims.size());
...@@ -134,7 +130,10 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> { ...@@ -134,7 +130,10 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
&transformed_input); &transformed_input);
} break; } break;
default: default:
PADDLE_THROW("ConvOp only support tensors with 4 or 5 dimensions."); PADDLE_THROW(platform::errors::PermissionDenied(
"Operator Conv2DFusion expects Input to be a 4-D or 5-D Tensor. "
"But recieved the actual dimension = %d, shape = [%s].",
rank, transformed_input_channel.dims()));
} }
} else { } else {
...@@ -168,7 +167,13 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> { ...@@ -168,7 +167,13 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
conv_desc.descriptor<T>(padding_common, strides, dilations); conv_desc.descriptor<T>(padding_common, strides, dilations);
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnSetConvolutionGroupCount(cudnn_conv_desc, platform::dynload::cudnnSetConvolutionGroupCount(cudnn_conv_desc,
groups)); groups),
platform::errors::External(
"Call of cudnnSetConvolutionGroupCount(cudnn_conv_desc, groups) "
"failed, where cudnn_conv_desc is configured: padding = [%s], "
"strides = [%s], dilations = [%s]; groups = %d",
framework::make_ddim(padding_common), framework::make_ddim(strides),
framework::make_ddim(dilations), groups));
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>( cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
layout, framework::vectorize<int>(transformed_input.dims())); layout, framework::vectorize<int>(transformed_input.dims()));
...@@ -199,8 +204,15 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> { ...@@ -199,8 +204,15 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
auto handle = dev_ctx.cudnn_handle(); auto handle = dev_ctx.cudnn_handle();
auto workspace_handle = dev_ctx.cudnn_workspace_handle(); auto workspace_handle = dev_ctx.cudnn_workspace_handle();
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetConvolutionMathType( PADDLE_ENFORCE_CUDA_SUCCESS(
cudnn_conv_desc, CUDNN_DEFAULT_MATH)); platform::dynload::cudnnSetConvolutionMathType(cudnn_conv_desc,
CUDNN_DEFAULT_MATH),
platform::errors::External(
"Call of cudnnSetConvolutionMathType(cudnn_conv_desc, "
"CUDNN_DEFAULT_MATH) failed, where cudnn_conv_desc is configured: "
"padding = %d, strides = %d, dilations = %d.",
framework::make_ddim(padding_common), framework::make_ddim(strides),
framework::make_ddim(dilations)));
auto x_dims = framework::vectorize(transformed_input.dims()); auto x_dims = framework::vectorize(transformed_input.dims());
auto f_dims = framework::vectorize(filter->dims()); auto f_dims = framework::vectorize(filter->dims());
...@@ -209,7 +221,9 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> { ...@@ -209,7 +221,9 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
platform::dynload::cudnnGetConvolutionForwardAlgorithm( platform::dynload::cudnnGetConvolutionForwardAlgorithm(
handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc, handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
cudnn_output_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, cudnn_output_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
workspace_size_limit, &algo)); workspace_size_limit, &algo),
platform::errors::External(
"Call of cudnnGetConvolutionForwardAlgorithm failed."));
VLOG(3) << "cuDNN forward algo " << algo; VLOG(3) << "cuDNN forward algo " << algo;
} else { } else {
std::function<cudnnConvolutionFwdAlgo_t()> search_func = std::function<cudnnConvolutionFwdAlgo_t()> search_func =
...@@ -223,7 +237,9 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> { ...@@ -223,7 +237,9 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
handle, cudnn_input_desc, input_data, cudnn_filter_desc, handle, cudnn_input_desc, input_data, cudnn_filter_desc,
filter_data, cudnn_conv_desc, cudnn_output_desc, output_data, filter_data, cudnn_conv_desc, cudnn_output_desc, output_data,
kNUM_CUDNN_FWD_ALGS, &returned_algo_count, kNUM_CUDNN_FWD_ALGS, &returned_algo_count,
fwd_perf_stat.data(), cudnn_workspace, workspace_size_limit)); fwd_perf_stat.data(), cudnn_workspace, workspace_size_limit),
platform::errors::External(
"Call of cudnnFindConvolutionForwardAlgorithmEx failed."));
}; };
workspace_handle.RunFuncSync(cudnn_find_func, workspace_size_limit); workspace_handle.RunFuncSync(cudnn_find_func, workspace_size_limit);
VLOG(3) << "Perf result: (algo: stat, time, memory)"; VLOG(3) << "Perf result: (algo: stat, time, memory)";
...@@ -257,9 +273,16 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> { ...@@ -257,9 +273,16 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnGetConvolutionForwardWorkspaceSize( platform::dynload::cudnnGetConvolutionForwardWorkspaceSize(
handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc, handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
cudnn_output_desc, algo, &workspace_size_in_bytes)); cudnn_output_desc, algo, &workspace_size_in_bytes),
PADDLE_ENFORCE_LE(workspace_size_in_bytes, workspace_size_limit, platform::errors::External(
"workspace_size to be allocated exceeds the limit"); "Call of cudnnGetConvolutionForwardWorkspaceSize failed."));
PADDLE_ENFORCE_LE(
workspace_size_in_bytes, workspace_size_limit,
platform::errors::InvalidArgument(
"The actual workspace size to be allocated for cuDNN is expected "
"to be less than the limit. But recieved: the actual workspace "
"size = %d, limit = %d.",
workspace_size_in_bytes, workspace_size_limit));
if ((activation == "identity") && (!residual)) { if ((activation == "identity") && (!residual)) {
// Only the CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM algo is // Only the CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM algo is
...@@ -269,15 +292,20 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> { ...@@ -269,15 +292,20 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
// ------------- cudnn conv forward and bias add --------------------- // ------------- cudnn conv forward and bias add ---------------------
ScalingParamType<T> alpha = 1.0f, beta = 0.0f; ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
auto cudnn_func = [&](void* cudnn_workspace) { auto cudnn_func = [&](void* cudnn_workspace) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnConvolutionForward( PADDLE_ENFORCE_CUDA_SUCCESS(
handle, &alpha, cudnn_input_desc, input_data, cudnn_filter_desc, platform::dynload::cudnnConvolutionForward(
filter_data, cudnn_conv_desc, algo, cudnn_workspace, handle, &alpha, cudnn_input_desc, input_data, cudnn_filter_desc,
workspace_size_in_bytes, &beta, cudnn_output_desc, output_data)); filter_data, cudnn_conv_desc, algo, cudnn_workspace,
workspace_size_in_bytes, &beta, cudnn_output_desc, output_data),
platform::errors::External(
"Call of cudnnConvolutionForward failed."));
}; };
workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes); workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnAddTensor( PADDLE_ENFORCE_CUDA_SUCCESS(
handle, &alpha, cudnn_bias_desc, bias_data, &alpha, cudnn_output_desc, platform::dynload::cudnnAddTensor(handle, &alpha, cudnn_bias_desc,
output_data)); bias_data, &alpha,
cudnn_output_desc, output_data),
platform::errors::External("Call of cudnnAddTensor failed."));
} else { } else {
if (activation == "identity") { if (activation == "identity") {
algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
...@@ -292,7 +320,9 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> { ...@@ -292,7 +320,9 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
cudnn_filter_desc, filter_data, cudnn_conv_desc, algo, cudnn_filter_desc, filter_data, cudnn_conv_desc, algo,
cudnn_workspace, workspace_size_in_bytes, &alpha2, cudnn_workspace, workspace_size_in_bytes, &alpha2,
cudnn_output_desc, residual_data, cudnn_bias_desc, bias_data, cudnn_output_desc, residual_data, cudnn_bias_desc, bias_data,
cudnn_act_desc, cudnn_output_desc, output_data)); cudnn_act_desc, cudnn_output_desc, output_data),
platform::errors::External(
"Call of cudnnConvolutionBiasActivationForward failed."));
}; };
workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes); workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
} }
...@@ -314,7 +344,10 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> { ...@@ -314,7 +344,10 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
} }
} else { } else {
// TODO(qingiqng): do copy when batch size large than 1 // TODO(qingiqng): do copy when batch size large than 1
PADDLE_THROW("Batch size greater than 1 is Unsupported"); PADDLE_THROW(platform::errors::Unimplemented(
"Input with batch size greater than 1 is unsupported. The recieved "
"batch size is %d, Input's shape is [%s].",
x_dims[0], framework::make_ddim(x_dims)));
} }
} }
} }
......
...@@ -24,14 +24,14 @@ from test_conv2d_op import conv2d_forward_naive ...@@ -24,14 +24,14 @@ from test_conv2d_op import conv2d_forward_naive
def create_test_padding_SAME_class(parent): def create_test_padding_SAME_class(parent):
class TestPaddingSMAECase(parent): class TestPaddingSAMECase(parent):
def init_paddings(self): def init_paddings(self):
self.pad = [0, 0] self.pad = [0, 0]
self.padding_algorithm = "SAME" self.padding_algorithm = "SAME"
cls_name = "{0}_{1}".format(parent.__name__, "PaddingSAMEOp") cls_name = "{0}_{1}".format(parent.__name__, "PaddingSAMEOp")
TestPaddingSMAECase.__name__ = cls_name TestPaddingSAMECase.__name__ = cls_name
globals()[cls_name] = TestPaddingSMAECase globals()[cls_name] = TestPaddingSAMECase
def create_test_padding_VALID_class(parent): def create_test_padding_VALID_class(parent):
...@@ -52,16 +52,15 @@ class TestConv2dFusionOp(OpTest): ...@@ -52,16 +52,15 @@ class TestConv2dFusionOp(OpTest):
self.data_format = "NCHW" self.data_format = "NCHW"
self.dtype = np.float32 self.dtype = np.float32
self.activation = 'relu' self.activation = 'relu'
self.add_bias = True
self.add_residual_data = True self.add_residual_data = True
self.channels = None self.split_channels = None
self.outputs = None self.outputs = None
self.padding_algorithm = "EXIPLICIT" self.padding_algorithm = "EXIPLICIT"
self.init_group() self.init_group()
self.init_dilation() self.init_dilation()
self.init_test_case() self.init_test_case()
self.init_bias_residual() self.init_residual()
self.init_activation() self.init_activation()
self.init_paddings() self.init_paddings()
self.set_search_method() self.set_search_method()
...@@ -74,6 +73,7 @@ class TestConv2dFusionOp(OpTest): ...@@ -74,6 +73,7 @@ class TestConv2dFusionOp(OpTest):
input = np.random.random(self.input_size).astype(self.dtype) input = np.random.random(self.input_size).astype(self.dtype)
filter = np.random.random(self.filter_size).astype(self.dtype) filter = np.random.random(self.filter_size).astype(self.dtype)
bias = np.random.random(self.filter_size[0]).astype(self.dtype)
self.output, _, _, _, _ = conv2d_forward_naive( self.output, _, _, _, _ = conv2d_forward_naive(
input, filter, self.groups, conv2d_param, self.padding_algorithm, input, filter, self.groups, conv2d_param, self.padding_algorithm,
...@@ -83,7 +83,8 @@ class TestConv2dFusionOp(OpTest): ...@@ -83,7 +83,8 @@ class TestConv2dFusionOp(OpTest):
self.inputs = { self.inputs = {
'Input': OpTest.np_dtype_to_fluid_dtype(input), 'Input': OpTest.np_dtype_to_fluid_dtype(input),
'Filter': OpTest.np_dtype_to_fluid_dtype(filter) 'Filter': OpTest.np_dtype_to_fluid_dtype(filter),
'Bias': OpTest.np_dtype_to_fluid_dtype(bias)
} }
if self.add_residual_data: if self.add_residual_data:
...@@ -93,10 +94,8 @@ class TestConv2dFusionOp(OpTest): ...@@ -93,10 +94,8 @@ class TestConv2dFusionOp(OpTest):
residual_data) residual_data)
self.output += residual_data self.output += residual_data
if self.add_bias: # Add bias
bias = np.random.random(self.filter_size[0]).astype(self.dtype) self.output = self.output + bias.reshape((1, bias.size, 1, 1))
self.inputs['Bias'] = OpTest.np_dtype_to_fluid_dtype(bias)
self.output = self.output + bias.reshape((1, bias.size, 1, 1))
assert self.activation in ['relu', 'identity'] assert self.activation in ['relu', 'identity']
if self.activation == 'relu': if self.activation == 'relu':
...@@ -110,9 +109,11 @@ class TestConv2dFusionOp(OpTest): ...@@ -110,9 +109,11 @@ class TestConv2dFusionOp(OpTest):
'data_format': self.data_format, 'data_format': self.data_format,
'exhaustive_search': self.exhaustive_search, 'exhaustive_search': self.exhaustive_search,
'activation': self.activation, 'activation': self.activation,
'split_channels': self.channels,
'padding_algorithm': self.padding_algorithm 'padding_algorithm': self.padding_algorithm
} }
if self.split_channels is not None:
self.attrs['split_channels'] = self.split_channels
self.outputs = {'Output': self.output} self.outputs = {'Output': self.output}
self.set_outputs() self.set_outputs()
...@@ -124,8 +125,6 @@ class TestConv2dFusionOp(OpTest): ...@@ -124,8 +125,6 @@ class TestConv2dFusionOp(OpTest):
if self.has_cuda(): if self.has_cuda():
place = core.CUDAPlace(0) place = core.CUDAPlace(0)
self.check_output_with_place(place, atol=1e-5) self.check_output_with_place(place, atol=1e-5)
else:
pass
def init_test_case(self): def init_test_case(self):
self.pad = [0, 0] self.pad = [0, 0]
...@@ -141,8 +140,7 @@ class TestConv2dFusionOp(OpTest): ...@@ -141,8 +140,7 @@ class TestConv2dFusionOp(OpTest):
def init_group(self): def init_group(self):
self.groups = 1 self.groups = 1
def init_bias_residual(self): def init_residual(self):
self.add_bias = True
self.add_residual_data = True self.add_residual_data = True
def init_activation(self): def init_activation(self):
...@@ -160,7 +158,7 @@ class TestConv2dFusionOp(OpTest): ...@@ -160,7 +158,7 @@ class TestConv2dFusionOp(OpTest):
class TestWithoutResidual(TestConv2dFusionOp): class TestWithoutResidual(TestConv2dFusionOp):
def init_bias_residual(self): def init_residual(self):
self.add_residual_data = False self.add_residual_data = False
...@@ -209,7 +207,7 @@ class TestMultipleOutputs(TestConv2dFusionOp): ...@@ -209,7 +207,7 @@ class TestMultipleOutputs(TestConv2dFusionOp):
assert np.mod(self.input_size[1], self.groups) == 0 assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups f_c = self.input_size[1] // self.groups
self.filter_size = [126, f_c, 3, 3] self.filter_size = [126, f_c, 3, 3]
self.channels = [84, 42] self.split_channels = [84, 42]
def set_outputs(self): def set_outputs(self):
out1 = self.output[:, 0:84, :, :] out1 = self.output[:, 0:84, :, :]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册