提交 2604aced 编写于 作者: W wangnan39@huawei.com

extend conv stride and dilation to 2d

上级 268d358a
...@@ -35,8 +35,22 @@ void Conv2dCPUKernel::InitKernel(const CNodePtr &kernel_node) { ...@@ -35,8 +35,22 @@ void Conv2dCPUKernel::InitKernel(const CNodePtr &kernel_node) {
dnnl::memory::desc dst_desc = GetDefaultMemDesc(dst_shape); dnnl::memory::desc dst_desc = GetDefaultMemDesc(dst_shape);
int kernel_size = SizeToInt(weight_shape[3]); int kernel_size = SizeToInt(weight_shape[3]);
int stride = AnfAlgo::GetNodeAttr<int>(kernel_node, STRIDE); auto stride_ori = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, STRIDE);
int dilation = AnfAlgo::GetNodeAttr<int>(kernel_node, DILATION); auto dilation_ori = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, DILATION);
if (stride_ori.size() != 4 || stride_ori[2] != stride_ori[3]) {
MS_LOG(EXCEPTION) << "conv2d only support equal stride, and stride must be 4d!";
}
if (stride_ori[0] != 1 || stride_ori[1] != 1) {
MS_LOG(EXCEPTION) << "conv2d stride only support 1 in N axis and C axis!";
}
if (dilation_ori.size() != 4 || dilation_ori[2] != 1 || dilation_ori[3] != 1) {
MS_LOG(EXCEPTION) << "conv2d dilation only support 1, and dilation must be 4d!";
}
if (dilation_ori[0] != 1 || dilation_ori[1] != 1) {
MS_LOG(EXCEPTION) << "conv2d dilation only support 1 in N axis and C axis!";
}
int stride = stride_ori[2];
int dilation = dilation_ori[2];
dnnl::memory::dims strides{stride, stride}; dnnl::memory::dims strides{stride, stride};
dnnl::memory::dims dilates{dilation - 1, dilation - 1}; dnnl::memory::dims dilates{dilation - 1, dilation - 1};
......
...@@ -35,8 +35,19 @@ void Conv2dGradFilterCPUKernel::InitKernel(const CNodePtr &kernel_node) { ...@@ -35,8 +35,19 @@ void Conv2dGradFilterCPUKernel::InitKernel(const CNodePtr &kernel_node) {
dnnl::memory::desc dst_desc = GetDefaultMemDesc(dst_shape); dnnl::memory::desc dst_desc = GetDefaultMemDesc(dst_shape);
int kernel_size = SizeToInt(weight_shape[3]); int kernel_size = SizeToInt(weight_shape[3]);
int stride = AnfAlgo::GetNodeAttr<int>(kernel_node, STRIDE); auto stride_ori = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, STRIDE);
int dilation = AnfAlgo::GetNodeAttr<int>(kernel_node, DILATION); auto dilation_ori = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, DILATION);
if (stride_ori.size() != 2 || stride_ori[0] != stride_ori[1]) {
MS_LOG(EXCEPTION) << "Conv2dGradFilterCPUKernel only support equal stride, and stride must be 2d!";
}
if (dilation_ori.size() != 4 || dilation_ori[2] != 1 || dilation_ori[3] != 1) {
MS_LOG(EXCEPTION) << "Conv2dGradFilterCPUKernel dilation only support 1, and dilation must be 4d!";
}
if (dilation_ori[0] != 1 || dilation_ori[1] != 1) {
MS_LOG(EXCEPTION) << "Conv2dGradFilterCPUKernel dilation only support 1 in N axis and C axis!";
}
int stride = stride_ori[0];
int dilation = dilation_ori[2];
dnnl::memory::dims strides{stride, stride}; dnnl::memory::dims strides{stride, stride};
dnnl::memory::dims dilates{dilation - 1, dilation - 1}; dnnl::memory::dims dilates{dilation - 1, dilation - 1};
......
...@@ -35,8 +35,19 @@ void Conv2dGradInputCPUKernel::InitKernel(const CNodePtr &kernel_node) { ...@@ -35,8 +35,19 @@ void Conv2dGradInputCPUKernel::InitKernel(const CNodePtr &kernel_node) {
dnnl::memory::desc dst_desc = GetDefaultMemDesc(dst_shape); dnnl::memory::desc dst_desc = GetDefaultMemDesc(dst_shape);
int kernel_size = SizeToInt(weight_shape[3]); int kernel_size = SizeToInt(weight_shape[3]);
int stride = AnfAlgo::GetNodeAttr<int>(kernel_node, STRIDE); auto stride_ori = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, STRIDE);
int dilation = AnfAlgo::GetNodeAttr<int>(kernel_node, DILATION); auto dilation_ori = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, DILATION);
if (stride_ori.size() != 2 || stride_ori[0] != stride_ori[1]) {
MS_LOG(EXCEPTION) << "Conv2dGradInputCPUKernel only support equal stride, and stride must be 2d!";
}
if (dilation_ori.size() != 4 || dilation_ori[2] != 1 || dilation_ori[3] != 1) {
MS_LOG(EXCEPTION) << "Conv2dGradInputCPUKernel dilation only support 1, and dilation must be 4d!";
}
if (dilation_ori[0] != 1 || dilation_ori[1] != 1) {
MS_LOG(EXCEPTION) << "Conv2dGradInputCPUKernel dilation only support 1 in N axis and C axis!";
}
int stride = stride_ori[0];
int dilation = dilation_ori[2];
dnnl::memory::dims strides{stride, stride}; dnnl::memory::dims strides{stride, stride};
dnnl::memory::dims dilates{dilation - 1, dilation - 1}; dnnl::memory::dims dilates{dilation - 1, dilation - 1};
std::vector<int> int_padding_l; std::vector<int> int_padding_l;
......
...@@ -113,9 +113,24 @@ class Conv2dGpuFwdKernel : public GpuKernel { ...@@ -113,9 +113,24 @@ class Conv2dGpuFwdKernel : public GpuKernel {
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolutionGroupCount(conv_desc_, group_), "cudnnSetConvGroupCount failed"); CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolutionGroupCount(conv_desc_, group_), "cudnnSetConvGroupCount failed");
pad_height_ = GetAttr<int>(kernel_node, "pad"); pad_height_ = GetAttr<int>(kernel_node, "pad");
pad_width_ = pad_height_; pad_width_ = pad_height_;
stride_ = GetAttr<int>(kernel_node, "stride");
dilation_ = GetAttr<int>(kernel_node, "dilation");
pad_mode_ = GetAttr<std::string>(kernel_node, "pad_mode"); pad_mode_ = GetAttr<std::string>(kernel_node, "pad_mode");
auto stride_ori = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, "stride");
auto dilation_ori = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, "dilation");
if (stride_ori.size() != 4 || stride_ori[2] != stride_ori[3]) {
MS_LOG(EXCEPTION) << "conv2d only support equal stride, and stride must be 4d!";
}
if (stride_ori[0] != 1 || stride_ori[1] != 1) {
MS_LOG(EXCEPTION) << "conv2d stride only support 1 in N axis and C axis!";
}
if (dilation_ori.size() != 4 || dilation_ori[2] != dilation_ori[3]) {
MS_LOG(EXCEPTION) << "conv2d only support equal dilation, and dilation must be 4d!";
}
if (dilation_ori[0] != 1 || dilation_ori[1] != 1) {
MS_LOG(EXCEPTION) << "conv2d dilation only support 1 in N axis and C axis!";
}
stride_ = stride_ori[2];
dilation_ = dilation_ori[2];
cudnnTensorDescriptor_t input_descriptor_real = nullptr; cudnnTensorDescriptor_t input_descriptor_real = nullptr;
if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) { if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) {
SetPad(in_shape, kernel_node); SetPad(in_shape, kernel_node);
......
...@@ -116,9 +116,20 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel { ...@@ -116,9 +116,20 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel {
pad_height_ = GetAttr<int>(kernel_node, "pad"); pad_height_ = GetAttr<int>(kernel_node, "pad");
pad_width_ = pad_height_; pad_width_ = pad_height_;
stride_ = GetAttr<int>(kernel_node, "stride");
dilation_ = GetAttr<int>(kernel_node, "dilation");
pad_mode_ = GetAttr<std::string>(kernel_node, "pad_mode"); pad_mode_ = GetAttr<std::string>(kernel_node, "pad_mode");
auto stride_ori = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, "stride");
auto dilation_ori = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, "dilation");
if (stride_ori.size() != 2 || stride_ori[0] != stride_ori[1]) {
MS_LOG(EXCEPTION) << "ConvGradFilterGpuBkwKernel only support equal stride, and stride must be 2d!";
}
if (dilation_ori.size() != 4 || dilation_ori[2] != dilation_ori[3]) {
MS_LOG(EXCEPTION) << "ConvGradFilterGpuBkwKernel only support equal dilation, and dilation must be 4d!";
}
if (dilation_ori[0] != 1 || dilation_ori[1] != 1) {
MS_LOG(EXCEPTION) << "ConvGradFilterGpuBkwKernel dilation only support 1 in N axis and C axis!";
}
stride_ = stride_ori[0];
dilation_ = dilation_ori[2];
cudnnTensorDescriptor_t x_desc_real = nullptr; cudnnTensorDescriptor_t x_desc_real = nullptr;
if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) { if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) {
SetPad(in_shape, kernel_node); SetPad(in_shape, kernel_node);
......
...@@ -117,9 +117,20 @@ class ConvGradInputGpuBkwKernel : public GpuKernel { ...@@ -117,9 +117,20 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
pad_height_ = GetAttr<int>(kernel_node, "pad"); pad_height_ = GetAttr<int>(kernel_node, "pad");
pad_width_ = pad_height_; pad_width_ = pad_height_;
stride_ = GetAttr<int>(kernel_node, "stride");
dilation_ = GetAttr<int>(kernel_node, "dilation");
pad_mode_ = GetAttr<std::string>(kernel_node, "pad_mode"); pad_mode_ = GetAttr<std::string>(kernel_node, "pad_mode");
auto stride_ori = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, "stride");
auto dilation_ori = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, "dilation");
if (stride_ori.size() != 2 || stride_ori[0] != stride_ori[1]) {
MS_LOG(EXCEPTION) << "ConvGradInputGpuBkwKernel only support equal stride, and stride must be 2d!";
}
if (dilation_ori.size() != 4 || dilation_ori[2] != dilation_ori[3]) {
MS_LOG(EXCEPTION) << "ConvGradInputGpuBkwKernel only support equal dilation, and dilation must be 4d!";
}
if (dilation_ori[0] != 1 || dilation_ori[1] != 1) {
MS_LOG(EXCEPTION) << "ConvGradInputGpuBkwKernel dilation only support 1 in N axis and C axis!";
}
stride_ = stride_ori[0];
dilation_ = dilation_ori[2];
cudnnTensorDescriptor_t dx_desc_real = nullptr; cudnnTensorDescriptor_t dx_desc_real = nullptr;
if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) { if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) {
SetPad(input_shape, kernel_node); SetPad(input_shape, kernel_node);
......
...@@ -148,9 +148,6 @@ void TbeAdapter::InputOrderPass(const std::string &op_name, std::vector<std::vec ...@@ -148,9 +148,6 @@ void TbeAdapter::InputOrderPass(const std::string &op_name, std::vector<std::vec
} }
std::map<std::string, FAttrsPass> TbeAdapter::build_json_attr_pass_map_ = { std::map<std::string, FAttrsPass> TbeAdapter::build_json_attr_pass_map_ = {
{"Conv2D", TbeAdapter::Conv2DAttrJsonPass},
{"Conv2DBackpropFilter", TbeAdapter::Conv2DBackpropFilterAttrJsonPass},
{"Conv2DBackpropInput", TbeAdapter::Conv2DBackpropInputAttrJsonPass},
{"MaximumGrad", TbeAdapter::MaximumGradAttrJsonPass}, {"MaximumGrad", TbeAdapter::MaximumGradAttrJsonPass},
{"MinimumGrad", TbeAdapter::MinimumGradAttrJsonPass}, {"MinimumGrad", TbeAdapter::MinimumGradAttrJsonPass},
{"Cast", TbeAdapter::CastAttrJsonPass}}; {"Cast", TbeAdapter::CastAttrJsonPass}};
...@@ -168,135 +165,6 @@ bool TbeAdapter::RunAttrPass(const mindspore::AnfNodePtr &anf_node, ...@@ -168,135 +165,6 @@ bool TbeAdapter::RunAttrPass(const mindspore::AnfNodePtr &anf_node,
return false; return false;
} }
// Builds the TBE JSON attribute list for a Conv2D node.
//
// For every attribute described in `op_info_attrs` one JSON object with the
// keys "name", "value" and "valid" is appended to `*attrs_json`:
//   - "stride" / "dilation": the scalar int stored on the primitive is
//     repeated four times;
//   - "pad": the value is taken from the primitive's "pad_list" attribute;
//   - any other attribute present on the primitive: an empty list (the
//     attribute is still read as an int first);
//   - attribute missing on the primitive: the integer 0.
// Every entry is emitted with "valid" set to true.
//
// anf_node      - node whose primitive supplies the attribute values.
// op_info_attrs - attribute descriptors; each element must be non-null.
// attrs_json    - output JSON array; must be non-null.
void TbeAdapter::Conv2DAttrJsonPass(const mindspore::AnfNodePtr &anf_node,
                                    const std::vector<std::shared_ptr<mindspore::kernel::OpAttr>> &op_info_attrs,
                                    nlohmann::json *attrs_json) {
  MS_EXCEPTION_IF_NULL(anf_node);
  MS_EXCEPTION_IF_NULL(attrs_json);
  auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
  MS_EXCEPTION_IF_NULL(primitive);
  for (size_t i = 0; i < op_info_attrs.size(); i++) {
    MS_EXCEPTION_IF_NULL(op_info_attrs[i]);
    const std::string name = op_info_attrs[i]->name();
    nlohmann::json entry;
    auto attr = primitive->GetAttr(name);
    if (attr == nullptr) {
      entry["value"] = 0;
    } else {
      // The scalar is extracted before any name-specific handling, so every
      // attribute that is present must be convertible to int.
      const int scalar = GetValue<int>(attr);
      std::vector<int> expanded;
      if (name == "stride" || name == "dilation") {
        expanded.assign(4, scalar);
      } else if (name == "pad") {
        // NOTE(review): "pad_list" is not null-checked here; presumably
        // GetValue reports a missing attribute itself - confirm.
        expanded = GetValue<std::vector<int>>(primitive->GetAttr("pad_list"));
      }
      entry["value"] = expanded;
    }
    entry["name"] = name;
    entry["valid"] = true;
    attrs_json->push_back(entry);
  }
  MS_LOG(INFO) << "Conv2DAttrPass done.";
}
// Builds the TBE JSON attribute list for a Conv2DBackpropFilter node.
//
// For every attribute described in `op_info_attrs` one JSON object is
// appended to `*attrs_json`. When the primitive carries the attribute the
// entry gets "valid" = true and a "value" chosen by name:
//   - "pad_mode": the string value converted to upper case;
//   - "filter_sizes": the int list as stored;
//   - "stride": the scalar int repeated twice;
//   - "dilation": the scalar int repeated four times;
//   - anything else: read as an int, emitted as an empty list.
// When the primitive lacks the attribute the entry has "valid" = false and
// no "value" key.
//
// anf_node      - node whose primitive supplies the attribute values.
// op_info_attrs - attribute descriptors; each element must be non-null.
// attrs_json    - output JSON array; must be non-null.
void TbeAdapter::Conv2DBackpropFilterAttrJsonPass(
  const mindspore::AnfNodePtr &anf_node, const std::vector<std::shared_ptr<mindspore::kernel::OpAttr>> &op_info_attrs,
  nlohmann::json *attrs_json) {
  MS_EXCEPTION_IF_NULL(anf_node);
  MS_EXCEPTION_IF_NULL(attrs_json);
  auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
  MS_EXCEPTION_IF_NULL(primitive);
  for (size_t i = 0; i < op_info_attrs.size(); i++) {
    MS_EXCEPTION_IF_NULL(op_info_attrs[i]);
    const std::string name = op_info_attrs[i]->name();
    nlohmann::json entry;
    auto attr = primitive->GetAttr(name);
    const bool present = (attr != nullptr);
    if (present) {
      if (name == "pad_mode") {
        std::string mode = GetValue<std::string>(attr);
        (void)std::transform(mode.begin(), mode.end(), mode.begin(), ::toupper);
        entry["value"] = mode;
      } else if (name == "filter_sizes") {
        entry["value"] = GetValue<std::vector<int>>(attr);
      } else {
        const int scalar = GetValue<int>(attr);
        size_t repeat = 0;
        if (name == "stride") {
          repeat = 2;
        } else if (name == "dilation") {
          repeat = 4;
        }
        // Parentheses (not braces) select the fill constructor.
        entry["value"] = std::vector<int>(repeat, scalar);
      }
    }
    entry["valid"] = present;
    entry["name"] = name;
    attrs_json->push_back(entry);
  }
  MS_LOG(INFO) << "Conv2DBackpropFilterAttrJsonPass done.";
}
// Builds the TBE JSON attribute list for a Conv2DBackpropInput node.
//
// For every attribute described in `op_info_attrs` one JSON object is
// appended to `*attrs_json`. When the primitive carries the attribute the
// entry gets "valid" = true and a "value" chosen by name:
//   - "pad_mode": the string value converted to upper case;
//   - "input_sizes": the int list as stored;
//   - "stride": the scalar int repeated twice;
//   - "dilation": the scalar int repeated four times;
//   - anything else: read as an int, emitted as an empty list.
// When the primitive lacks the attribute the entry has "valid" = false and
// no "value" key.
//
// anf_node      - node whose primitive supplies the attribute values.
// op_info_attrs - attribute descriptors; each element must be non-null.
// attrs_json    - output JSON array; must be non-null.
void TbeAdapter::Conv2DBackpropInputAttrJsonPass(
  const mindspore::AnfNodePtr &anf_node, const std::vector<std::shared_ptr<mindspore::kernel::OpAttr>> &op_info_attrs,
  nlohmann::json *attrs_json) {
  MS_EXCEPTION_IF_NULL(anf_node);
  MS_EXCEPTION_IF_NULL(attrs_json);
  auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
  MS_EXCEPTION_IF_NULL(primitive);
  for (size_t i = 0; i < op_info_attrs.size(); i++) {
    MS_EXCEPTION_IF_NULL(op_info_attrs[i]);
    const std::string name = op_info_attrs[i]->name();
    nlohmann::json entry;
    auto attr = primitive->GetAttr(name);
    const bool present = (attr != nullptr);
    if (present) {
      if (name == "pad_mode") {
        std::string mode = GetValue<std::string>(attr);
        (void)std::transform(mode.begin(), mode.end(), mode.begin(), ::toupper);
        entry["value"] = mode;
      } else if (name == "input_sizes") {
        entry["value"] = GetValue<std::vector<int>>(attr);
      } else {
        const int scalar = GetValue<int>(attr);
        size_t repeat = 0;
        if (name == "stride") {
          repeat = 2;
        } else if (name == "dilation") {
          repeat = 4;
        }
        // Parentheses (not braces) select the fill constructor.
        entry["value"] = std::vector<int>(repeat, scalar);
      }
    }
    entry["valid"] = present;
    entry["name"] = name;
    attrs_json->push_back(entry);
  }
  MS_LOG(INFO) << "Conv2DBackpropInputAttrJsonPass done.";
}
void TbeAdapter::MaximumGradAttrJsonPass(const mindspore::AnfNodePtr &anf_node, void TbeAdapter::MaximumGradAttrJsonPass(const mindspore::AnfNodePtr &anf_node,
const std::vector<std::shared_ptr<mindspore::kernel::OpAttr>> &op_info_attrs, const std::vector<std::shared_ptr<mindspore::kernel::OpAttr>> &op_info_attrs,
nlohmann::json *attrs_json) { nlohmann::json *attrs_json) {
......
...@@ -179,7 +179,7 @@ OPERATOR_ONNX_CONVERT_DEFINE(Squeeze, Squeeze, ...@@ -179,7 +179,7 @@ OPERATOR_ONNX_CONVERT_DEFINE(Squeeze, Squeeze,
OPERATOR_ONNX_CONVERT_DEFINE( OPERATOR_ONNX_CONVERT_DEFINE(
Conv2D, Conv, Conv2D, Conv,
OpNameInfo() OpNameInfo()
.Attr("dilation", "dilations", onnx::AttributeProto_AttributeType_INTS, SetAttrValueToProto<Int32Imm, 2>) .Attr("dilation", "dilations", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>)
.Attr("group", "group", onnx::AttributeProto_AttributeType_INT, SetAttrValueToProto<Int32Imm>) .Attr("group", "group", onnx::AttributeProto_AttributeType_INT, SetAttrValueToProto<Int32Imm>)
.Attr("kernel_size", "kernel_shape", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto) .Attr("kernel_size", "kernel_shape", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto)
.Attr("pad_mode", "auto_pad", onnx::AttributeProto_AttributeType_STRING, .Attr("pad_mode", "auto_pad", onnx::AttributeProto_AttributeType_STRING,
...@@ -197,8 +197,7 @@ OPERATOR_ONNX_CONVERT_DEFINE( ...@@ -197,8 +197,7 @@ OPERATOR_ONNX_CONVERT_DEFINE(
prim); prim);
} }
}) })
.Attr("stride", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrValueToProto<Int32Imm, 2>)) .Attr("stride", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>))
OPERATOR_ONNX_CONVERT_DEFINE(BiasAdd, Add, OpNameInfo()) OPERATOR_ONNX_CONVERT_DEFINE(BiasAdd, Add, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(MatMul, Gemm, OPERATOR_ONNX_CONVERT_DEFINE(MatMul, Gemm,
OpNameInfo() OpNameInfo()
......
...@@ -754,9 +754,9 @@ OUTPUT_MAP(MaxPoolGradWithArgmax) = {{0, OUTPUT_DESC(y)}}; ...@@ -754,9 +754,9 @@ OUTPUT_MAP(MaxPoolGradWithArgmax) = {{0, OUTPUT_DESC(y)}};
// Conv2D // Conv2D
INPUT_MAP(Conv2D) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(filter)}}; INPUT_MAP(Conv2D) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(filter)}};
ATTR_MAP(Conv2D) = { ATTR_MAP(Conv2D) = {
{"stride", ATTR_DESC(strides, "pad", AnyTraits<std::vector<int64_t>>())}, {"stride", ATTR_DESC(strides, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
{"pad_list", ATTR_DESC(pads, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}, {"pad_list", ATTR_DESC(pads, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
{"dilation", ATTR_DESC(dilations, "pad", AnyTraits<std::vector<int64_t>>())}, {"dilation", ATTR_DESC(dilations, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
}; };
OUTPUT_MAP(Conv2D) = {{0, OUTPUT_DESC(y)}}; OUTPUT_MAP(Conv2D) = {{0, OUTPUT_DESC(y)}};
...@@ -766,8 +766,8 @@ INPUT_ATTR_MAP(Conv2DBackpropInputD) = { ...@@ -766,8 +766,8 @@ INPUT_ATTR_MAP(Conv2DBackpropInputD) = {
{3, ATTR_DESC(input_sizes, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}}; {3, ATTR_DESC(input_sizes, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}};
ATTR_MAP(Conv2DBackpropInputD) = { ATTR_MAP(Conv2DBackpropInputD) = {
{"pad_list", ATTR_DESC(pads, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}, {"pad_list", ATTR_DESC(pads, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
{"stride", ATTR_DESC(strides, "strides", AnyTraits<std::vector<int64_t>>())}, {"stride", ATTR_DESC(strides, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
{"dilation", ATTR_DESC(dilations, "pad", AnyTraits<std::vector<int64_t>>())}, {"dilation", ATTR_DESC(dilations, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
}; };
OUTPUT_MAP(Conv2DBackpropInputD) = {{0, OUTPUT_DESC(y)}}; OUTPUT_MAP(Conv2DBackpropInputD) = {{0, OUTPUT_DESC(y)}};
...@@ -777,17 +777,17 @@ INPUT_ATTR_MAP(Conv2DBackpropFilterD) = { ...@@ -777,17 +777,17 @@ INPUT_ATTR_MAP(Conv2DBackpropFilterD) = {
{3, ATTR_DESC(filter_sizes, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}}; {3, ATTR_DESC(filter_sizes, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}};
ATTR_MAP(Conv2DBackpropFilterD) = { ATTR_MAP(Conv2DBackpropFilterD) = {
{"pad_list", ATTR_DESC(pads, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}, {"pad_list", ATTR_DESC(pads, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
{"stride", ATTR_DESC(strides, "strides", AnyTraits<std::vector<int64_t>>())}, {"stride", ATTR_DESC(strides, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
{"dilation", ATTR_DESC(dilations, "pad", AnyTraits<std::vector<int64_t>>())}, {"dilation", ATTR_DESC(dilations, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
}; };
OUTPUT_MAP(Conv2DBackpropFilterD) = {{0, OUTPUT_DESC(y)}}; OUTPUT_MAP(Conv2DBackpropFilterD) = {{0, OUTPUT_DESC(y)}};
// DepthwiseConv2D // DepthwiseConv2D
INPUT_MAP(DepthwiseConv2D) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(filter)}}; INPUT_MAP(DepthwiseConv2D) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(filter)}};
ATTR_MAP(DepthwiseConv2D) = { ATTR_MAP(DepthwiseConv2D) = {
{"stride", ATTR_DESC(strides, "pad", AnyTraits<std::vector<int64_t>>())}, {"stride", ATTR_DESC(strides, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
{"pads", ATTR_DESC(pads, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}, {"pads", ATTR_DESC(pads, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
{"dilation", ATTR_DESC(dilations, "pad", AnyTraits<std::vector<int64_t>>())}, {"dilation", ATTR_DESC(dilations, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
{"data_format", ATTR_DESC(data_format, AnyTraits<std::string>())}, {"data_format", ATTR_DESC(data_format, AnyTraits<std::string>())},
}; };
OUTPUT_MAP(DepthwiseConv2D) = {{0, OUTPUT_DESC(y)}}; OUTPUT_MAP(DepthwiseConv2D) = {{0, OUTPUT_DESC(y)}};
...@@ -797,9 +797,9 @@ INPUT_MAP(DepthwiseConv2DBackpropInputD) = {{2, INPUT_DESC(filter)}, {3, INPUT_D ...@@ -797,9 +797,9 @@ INPUT_MAP(DepthwiseConv2DBackpropInputD) = {{2, INPUT_DESC(filter)}, {3, INPUT_D
INPUT_ATTR_MAP(DepthwiseConv2DBackpropInputD) = { INPUT_ATTR_MAP(DepthwiseConv2DBackpropInputD) = {
{1, ATTR_DESC(input_size, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}}; {1, ATTR_DESC(input_size, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}};
ATTR_MAP(DepthwiseConv2DBackpropInputD) = { ATTR_MAP(DepthwiseConv2DBackpropInputD) = {
{"stride", ATTR_DESC(strides, "pad", AnyTraits<std::vector<int64_t>>())}, {"stride", ATTR_DESC(strides, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
{"pads", ATTR_DESC(pads, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}, {"pads", ATTR_DESC(pads, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
{"dilation", ATTR_DESC(dilations, "pad", AnyTraits<std::vector<int64_t>>())}, {"dilation", ATTR_DESC(dilations, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
}; };
OUTPUT_MAP(DepthwiseConv2DBackpropInputD) = {{0, OUTPUT_DESC(input_grad)}}; OUTPUT_MAP(DepthwiseConv2DBackpropInputD) = {{0, OUTPUT_DESC(input_grad)}};
...@@ -808,9 +808,9 @@ INPUT_MAP(DepthwiseConv2DBackpropFilterD) = {{1, INPUT_DESC(input)}, {3, INPUT_D ...@@ -808,9 +808,9 @@ INPUT_MAP(DepthwiseConv2DBackpropFilterD) = {{1, INPUT_DESC(input)}, {3, INPUT_D
INPUT_ATTR_MAP(DepthwiseConv2DBackpropFilterD) = { INPUT_ATTR_MAP(DepthwiseConv2DBackpropFilterD) = {
{2, ATTR_DESC(filter_size, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}}; {2, ATTR_DESC(filter_size, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}};
ATTR_MAP(DepthwiseConv2DBackpropFilterD) = { ATTR_MAP(DepthwiseConv2DBackpropFilterD) = {
{"stride", ATTR_DESC(strides, "pad", AnyTraits<std::vector<int64_t>>())}, {"stride", ATTR_DESC(strides, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
{"pads", ATTR_DESC(pads, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}, {"pads", ATTR_DESC(pads, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
{"dilation", ATTR_DESC(dilations, "pad", AnyTraits<std::vector<int64_t>>())}, {"dilation", ATTR_DESC(dilations, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
}; };
OUTPUT_MAP(DepthwiseConv2DBackpropFilterD) = {{0, OUTPUT_DESC(filter_grad)}}; OUTPUT_MAP(DepthwiseConv2DBackpropFilterD) = {{0, OUTPUT_DESC(filter_grad)}};
......
...@@ -17,7 +17,7 @@ from mindspore import log as logger ...@@ -17,7 +17,7 @@ from mindspore import log as logger
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.common.parameter import Parameter from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer from mindspore.common.initializer import initializer
from mindspore._checkparam import check_bool, twice, check_int_positive, check_int_non_negative, check_int from mindspore._checkparam import check_bool, twice, check_int_positive, check_int_non_negative
from mindspore._extends import cell_attr_register from mindspore._extends import cell_attr_register
from ..cell import Cell from ..cell import Cell
...@@ -42,17 +42,23 @@ class _Conv(Cell): ...@@ -42,17 +42,23 @@ class _Conv(Cell):
self.in_channels = check_int_positive(in_channels) self.in_channels = check_int_positive(in_channels)
self.out_channels = check_int_positive(out_channels) self.out_channels = check_int_positive(out_channels)
self.kernel_size = kernel_size self.kernel_size = kernel_size
self.stride = check_int_positive(stride) self.stride = stride
self.pad_mode = pad_mode self.pad_mode = pad_mode
self.padding = check_int_non_negative(padding) self.padding = check_int_non_negative(padding)
self.dilation = check_int(dilation) self.dilation = dilation
self.group = check_int_positive(group) self.group = check_int_positive(group)
self.has_bias = has_bias self.has_bias = has_bias
if (not isinstance(kernel_size, tuple)) or len(kernel_size) != 2 or \ if (not isinstance(kernel_size[0], int)) or (not isinstance(kernel_size[1], int)) or \
(not isinstance(kernel_size[0], int)) or (not isinstance(kernel_size[1], int)) or \
kernel_size[0] < 1 or kernel_size[1] < 1: kernel_size[0] < 1 or kernel_size[1] < 1:
raise ValueError("Attr 'kernel_size' of 'Conv2D' Op passed " raise ValueError("Attr 'kernel_size' of 'Conv2D' Op passed "
+ str(self.kernel_size) + ", should be a int or tuple and equal to or greater than 1.") + str(self.kernel_size) + ", should be a int or tuple and equal to or greater than 1.")
if (not isinstance(stride[0], int)) or (not isinstance(stride[1], int)) or stride[0] < 1 or stride[1] < 1:
raise ValueError("Attr 'stride' of 'Conv2D' Op passed "
+ str(self.stride) + ", should be a int or tuple and equal to or greater than 1.")
if (not isinstance(dilation[0], int)) or (not isinstance(dilation[1], int)) or \
dilation[0] < 1 or dilation[1] < 1:
raise ValueError("Attr 'dilation' of 'Conv2D' Op passed "
+ str(self.dilation) + ", should equal to or greater than 1.")
if in_channels % group != 0: if in_channels % group != 0:
raise ValueError("Attr 'in_channels' of 'Conv2D' Op must be divisible by " raise ValueError("Attr 'in_channels' of 'Conv2D' Op must be divisible by "
"attr 'group' of 'Conv2D' Op.") "attr 'group' of 'Conv2D' Op.")
...@@ -107,12 +113,13 @@ class Conv2d(_Conv): ...@@ -107,12 +113,13 @@ class Conv2d(_Conv):
Args: Args:
in_channels (int): The number of input channel :math:`C_{in}`. in_channels (int): The number of input channel :math:`C_{in}`.
out_channels (int): The number of output channel :math:`C_{out}`. out_channels (int): The number of output channel :math:`C_{out}`.
kernel_size (Union[int, tuple]): The data type is int or tuple with 2 integers. Specifies the height kernel_size (Union[int, tuple[int]]): The data type is int or tuple with 2 integers. Specifies the height
and width of the 2D convolution window. Single int means the value is for both height and width of and width of the 2D convolution window. Single int means the value is for both height and width of
the kernel. A tuple of 2 ints means the first value is for the height and the other is for the the kernel. A tuple of 2 ints means the first value is for the height and the other is for the
width of the kernel. width of the kernel.
stride (int): Specifies stride for all spatial dimensions with the same value. Value of stride should be stride (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents
greater or equal to 1 but bounded by the height and width of the input. Default: 1. the height and width of movement are both strides, or a tuple of two int numbers that
represent height and width of movement respectively. Default: 1.
pad_mode (str): Specifies padding mode. The optional values are pad_mode (str): Specifies padding mode. The optional values are
"same", "valid", "pad". Default: "same". "same", "valid", "pad". Default: "same".
...@@ -130,9 +137,11 @@ class Conv2d(_Conv): ...@@ -130,9 +137,11 @@ class Conv2d(_Conv):
Tensor borders. `padding` should be greater than or equal to 0. Tensor borders. `padding` should be greater than or equal to 0.
padding (int): Implicit paddings on both sides of the input. Default: 0. padding (int): Implicit paddings on both sides of the input. Default: 0.
dilation (int): Specifying the dilation rate to use for dilated convolution. If set to be :math:`k > 1`, dilation (Union[int, tuple[int]]): The data type is int or tuple with 2 integers. Specifies the dilation rate
there will be :math:`k - 1` pixels skipped for each sampling location. Its value should be greater to use for dilated convolution. If set to be :math:`k > 1`, there will
or equal to 1 and bounded by the height and width of the input. Default: 1. be :math:`k - 1` pixels skipped for each sampling location. Its value should
be greater or equal to 1 and bounded by the height and width of the
input. Default: 1.
group (int): Split filter into groups, `in_channels` and `out_channels` should be group (int): Split filter into groups, `in_channels` and `out_channels` should be
divisible by the number of groups. Default: 1. divisible by the number of groups. Default: 1.
has_bias (bool): Specifies whether the layer uses a bias vector. Default: False. has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
...@@ -172,6 +181,8 @@ class Conv2d(_Conv): ...@@ -172,6 +181,8 @@ class Conv2d(_Conv):
weight_init='normal', weight_init='normal',
bias_init='zeros'): bias_init='zeros'):
kernel_size = twice(kernel_size) kernel_size = twice(kernel_size)
stride = twice(stride)
dilation = twice(dilation)
super(Conv2d, self).__init__( super(Conv2d, self).__init__(
in_channels, in_channels,
out_channels, out_channels,
...@@ -241,7 +252,9 @@ class Conv2dTranspose(_Conv): ...@@ -241,7 +252,9 @@ class Conv2dTranspose(_Conv):
and width of the 2D convolution window. Single int means the value is for both height and width of and width of the 2D convolution window. Single int means the value is for both height and width of
the kernel. A tuple of 2 ints means the first value is for the height and the other is for the the kernel. A tuple of 2 ints means the first value is for the height and the other is for the
width of the kernel. width of the kernel.
stride (int): Specifies the same value for all spatial dimensions. Default: 1. stride (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents
the height and width of movement are both strides, or a tuple of two int numbers that
represent height and width of movement respectively. Default: 1.
pad_mode (str): Select the mode of the pad. The optional values are pad_mode (str): Select the mode of the pad. The optional values are
"pad", "same", "valid". Default: "same". "pad", "same", "valid". Default: "same".
...@@ -251,8 +264,11 @@ class Conv2dTranspose(_Conv): ...@@ -251,8 +264,11 @@ class Conv2dTranspose(_Conv):
- valid: Adopted the way of discarding. - valid: Adopted the way of discarding.
padding (int): Implicit paddings on both sides of the input. Default: 0. padding (int): Implicit paddings on both sides of the input. Default: 0.
dilation (int): Specifies the dilation rate to use for dilated dilation (Union[int, tuple[int]]): The data type is int or tuple with 2 integers. Specifies the dilation rate
convolution. Default: 1. to use for dilated convolution. If set to be :math:`k > 1`, there will
be :math:`k - 1` pixels skipped for each sampling location. Its value should
be greater or equal to 1 and bounded by the height and width of the
input. Default: 1.
group (int): Split filter into groups, `in_channels` and `out_channels` should be group (int): Split filter into groups, `in_channels` and `out_channels` should be
divisible by the number of groups. Default: 1. divisible by the number of groups. Default: 1.
has_bias (bool): Specifies whether the layer uses a bias vector. Default: False. has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
...@@ -290,6 +306,8 @@ class Conv2dTranspose(_Conv): ...@@ -290,6 +306,8 @@ class Conv2dTranspose(_Conv):
weight_init='normal', weight_init='normal',
bias_init='zeros'): bias_init='zeros'):
kernel_size = twice(kernel_size) kernel_size = twice(kernel_size)
stride = twice(stride)
dilation = twice(dilation)
# out_channels and in_channels swap. # out_channels and in_channels swap.
# cause Conv2DBackpropInput's out_channel refers to Conv2D's out_channel, # cause Conv2DBackpropInput's out_channel refers to Conv2D's out_channel,
# then Conv2dTranspose's out_channel refers to Conv2DBackpropInput's in_channel. # then Conv2dTranspose's out_channel refers to Conv2DBackpropInput's in_channel.
...@@ -333,26 +351,26 @@ class Conv2dTranspose(_Conv): ...@@ -333,26 +351,26 @@ class Conv2dTranspose(_Conv):
self.conv2d_transpose.set_strategy(strategy) self.conv2d_transpose.set_strategy(strategy)
return self return self
def _deconv_output_length(self, input_length, filter_size): def _deconv_output_length(self, input_length, filter_size, stride_size, dilation_size):
"""Calculate the width and height of output.""" """Calculate the width and height of output."""
length = 0 length = 0
if self.is_valid: if self.is_valid:
if filter_size - self.stride > 0: if filter_size - stride_size > 0:
length = input_length * self.stride + filter_size - self.stride length = input_length * stride_size + filter_size - stride_size
else: else:
length = input_length * self.stride length = input_length * stride_size
elif self.is_same: elif self.is_same:
length = input_length * self.stride length = input_length * stride_size
elif self.is_pad: elif self.is_pad:
length = input_length * self.stride - 2 * self.padding + filter_size + \ length = input_length * stride_size - 2 * self.padding + filter_size + \
(filter_size - 1) * (self.dilation - 1) - self.stride (filter_size - 1) * (dilation_size - 1) - stride_size
return length return length
def construct(self, x): def construct(self, x):
n, _, h, w = self.shape(x) n, _, h, w = self.shape(x)
h_out = self._deconv_output_length(h, self.kernel_size[0]) h_out = self._deconv_output_length(h, self.kernel_size[0], self.stride[0], self.dilation[0])
w_out = self._deconv_output_length(w, self.kernel_size[1]) w_out = self._deconv_output_length(w, self.kernel_size[1], self.stride[1], self.dilation[1])
if self.has_bias: if self.has_bias:
return self.bias_add(self.conv2d_transpose(x, self.weight, (n, self.out_channels, h_out, w_out)), return self.bias_add(self.conv2d_transpose(x, self.weight, (n, self.out_channels, h_out, w_out)),
self.bias) self.bias)
......
...@@ -34,7 +34,7 @@ from mindspore.ops.op_info_register import op_info_register ...@@ -34,7 +34,7 @@ from mindspore.ops.op_info_register import op_info_register
"value": "all" "value": "all"
}, },
{ {
"name": "pad", "name": "pad_list",
"param_type": "required", "param_type": "required",
"type": "listInt", "type": "listInt",
"value": "all" "value": "all"
......
...@@ -119,8 +119,8 @@ class Conv2DBackpropFilter(PrimitiveWithInfer): ...@@ -119,8 +119,8 @@ class Conv2DBackpropFilter(PrimitiveWithInfer):
pad (int): The pad value to fill. Default: 0. pad (int): The pad value to fill. Default: 0.
mode (int): 0 Math convolutiuon, 1 cross-correlation convolution , mode (int): 0 Math convolutiuon, 1 cross-correlation convolution ,
2 deconvolution, 3 depthwise convolution. Default: 1. 2 deconvolution, 3 depthwise convolution. Default: 1.
stride (int): The stride to apply conv filter. Default: 1. stride (tuple): The stride to apply conv filter. Default: (1, 1).
dilation (int): Specifies the dilation rate to use for dilated convolution. Default: 1. dilation (tuple): Specifies the dilation rate to use for dilated convolution. Default: (1, 1, 1, 1).
group (int): Splits input into groups. Default: 1. group (int): Splits input into groups. Default: 1.
Returns: Returns:
...@@ -135,8 +135,8 @@ class Conv2DBackpropFilter(PrimitiveWithInfer): ...@@ -135,8 +135,8 @@ class Conv2DBackpropFilter(PrimitiveWithInfer):
pad=0, pad=0,
pad_list=(0, 0, 0, 0), pad_list=(0, 0, 0, 0),
mode=1, mode=1,
stride=1, stride=(1, 1),
dilation=1, dilation=(1, 1, 1, 1),
group=1): group=1):
"""init Convolution""" """init Convolution"""
self.init_prim_io_names(inputs=['out_backprop', 'input', 'filter_sizes'], outputs=['output']) self.init_prim_io_names(inputs=['out_backprop', 'input', 'filter_sizes'], outputs=['output'])
...@@ -146,7 +146,9 @@ class Conv2DBackpropFilter(PrimitiveWithInfer): ...@@ -146,7 +146,9 @@ class Conv2DBackpropFilter(PrimitiveWithInfer):
pad_mode = pad_mode.upper() pad_mode = pad_mode.upper()
self.add_prim_attr('pad_mode', pad_mode) self.add_prim_attr('pad_mode', pad_mode)
self.pad = pad self.pad = pad
self.stride = stride if isinstance(stride, tuple) and len(stride) == 4:
self.stride = (stride[2], stride[3])
self.add_prim_attr('stride', self.stride)
self.dilation = dilation self.dilation = dilation
self.group = group self.group = group
self.add_prim_attr('data_format', "NCHW") self.add_prim_attr('data_format', "NCHW")
......
此差异已折叠。
...@@ -35,8 +35,8 @@ class Net4(nn.Cell): ...@@ -35,8 +35,8 @@ class Net4(nn.Cell):
pad_mode="valid", pad_mode="valid",
pad=0, pad=0,
mode=1, mode=1,
stride=1, stride=(1, 1),
dilation=1, dilation=(1, 1, 1, 1),
group=1) group=1)
self.w = Parameter(initializer(Tensor(np.array([[[[1, 0, -1], [1, 0, -1], [1, 0, -1]]]]).astype(np.float32)), [1, 1, 3, 3]), name='w') self.w = Parameter(initializer(Tensor(np.array([[[[1, 0, -1], [1, 0, -1], [1, 0, -1]]]]).astype(np.float32)), [1, 1, 3, 3]), name='w')
self.x = Parameter(initializer(Tensor(np.array([[[ self.x = Parameter(initializer(Tensor(np.array([[[
......
...@@ -35,8 +35,8 @@ class Conv2dFilter(nn.Cell): ...@@ -35,8 +35,8 @@ class Conv2dFilter(nn.Cell):
pad_mode="valid", pad_mode="valid",
pad=0, pad=0,
mode=1, mode=1,
stride=1, stride=(1, 1),
dilation=1, dilation=(1, 1, 1, 1),
group=1) group=1)
self.get_shape = P.Shape() self.get_shape = P.Shape()
......
...@@ -21,17 +21,17 @@ from mindspore.common.tensor import Tensor ...@@ -21,17 +21,17 @@ from mindspore.common.tensor import Tensor
def im2col(img, filter_h, filter_w, stride=1, pad=0, dilation=1):
    """
    Rearranges an image to row vector.

    Args:
        img (numpy.ndarray): 4-D input unpacked as (batch, channel, height, width).
        filter_h (int): Filter height.
        filter_w (int): Filter width.
        stride (Union[int, tuple[int]]): An int, a tuple (h, w), or a 4-tuple whose
            last two entries are (h, w). Default: 1.
        pad (int): Zero padding added on every side of height and width. Default: 0.
        dilation (Union[int, tuple[int]]): Same accepted forms as `stride`. Default: 1.

    Returns:
        numpy.ndarray, 2-D array with `batch * out_h * out_w` rows.

    Raises:
        ValueError: If `stride` or `dilation` is not an int or a tuple of 2 or 4 ints.
    """
    def _to_hw(value, name):
        # The int defaults must work too, so normalise every accepted form to (h, w).
        if isinstance(value, int):
            return value, value
        if isinstance(value, tuple) and len(value) == 2:
            return value[0], value[1]
        if isinstance(value, tuple) and len(value) == 4:
            return value[2], value[3]
        raise ValueError(f"The '{name}' should be an int number or "
                         f"a tuple of two or four int numbers, but got {value}")

    stride_h, stride_w = _to_hw(stride, "stride")
    dilation_h, dilation_w = _to_hw(dilation, "dilation")

    batch_num, channel, height, width = img.shape
    out_h = (height + 2*pad - filter_h - (filter_h - 1) * (dilation_h - 1))//stride_h + 1
    out_w = (width + 2*pad - filter_w - (filter_w - 1) * (dilation_w - 1))//stride_w + 1

    img = np.pad(img, [(0, 0), (0, 0), (pad, pad), (pad, pad)], 'constant')
    col = np.zeros((batch_num, channel, filter_h, filter_w, out_h, out_w)).astype(img.dtype)

    for y in range(filter_h):
        # Filter tap y samples rows y*dilation_h, y*dilation_h + stride_h, ...
        y_min = y * dilation_h
        y_max = y_min + stride_h*out_h
        for x in range(filter_w):
            # The width axis uses its own stride and dilation.
            x_min = x * dilation_w
            x_max = x_min + stride_w*out_w
            col[:, :, y, x, :, :] = img[:, :, y_min:y_max:stride_h, x_min:x_max:stride_w]

    col = col.transpose(0, 4, 5, 1, 2, 3).reshape(batch_num*out_h*out_w, -1)
    return col
...@@ -42,8 +42,8 @@ def conv2d(x, weight, bias=None, stride=1, pad=0, ...@@ -42,8 +42,8 @@ def conv2d(x, weight, bias=None, stride=1, pad=0,
"""Convolution 2D""" """Convolution 2D"""
batch_num, _, x_h, x_w = x.shape batch_num, _, x_h, x_w = x.shape
filter_num, _, filter_h, filter_w = weight.shape filter_num, _, filter_h, filter_w = weight.shape
out_h = 1 + int((x_h + 2 * pad - filter_h - (filter_h - 1) * (dilation - 1)) / stride) out_h = 1 + int((x_h + 2 * pad - filter_h - (filter_h - 1) * (dilation[2] - 1)) / stride[2])
out_w = 1 + int((x_w + 2 * pad - filter_w - (filter_w - 1) * (dilation - 1)) / stride) out_w = 1 + int((x_w + 2 * pad - filter_w - (filter_w - 1) * (dilation[3] - 1)) / stride[3])
col = im2col(x, filter_h, filter_w, stride, pad, dilation) col = im2col(x, filter_h, filter_w, stride, pad, dilation)
col_w = np.reshape(weight, (filter_num, -1)).T col_w = np.reshape(weight, (filter_num, -1)).T
out = np.dot(col, col_w) out = np.dot(col, col_w)
......
...@@ -155,23 +155,35 @@ def batch_norm_grad(dy, x, scale, save_mean, save_inv_variance): ...@@ -155,23 +155,35 @@ def batch_norm_grad(dy, x, scale, save_mean, save_inv_variance):
def col2im(col, input_shape, filter_h, filter_w, stride=1, pad=0):
    """
    Rearranges a row vector to an image.

    Overlapping filter windows are accumulated (summed) into the output image.

    Args:
        col (numpy.ndarray): Column data that can be reshaped to
            (batch, out_h, out_w, channel, filter_h, filter_w).
        input_shape (tuple[int]): Target shape (batch, channel, height, width).
        filter_h (int): Filter height.
        filter_w (int): Filter width.
        stride (Union[int, tuple[int]]): An int, a tuple (h, w), or a 4-tuple whose
            last two entries are (h, w). Default: 1.
        pad (int): Padding that was applied on every side of height and width. Default: 0.

    Returns:
        numpy.ndarray, image of shape `input_shape` with the padding stripped.

    Raises:
        ValueError: If `stride` is not an int or a tuple of 2 or 4 ints, or is not positive.
    """
    if isinstance(stride, int):
        stride_h = stride
        stride_w = stride
    elif isinstance(stride, tuple) and len(stride) == 2:
        stride_h = stride[0]
        stride_w = stride[1]
    elif isinstance(stride, tuple) and len(stride) == 4:
        # 4-tuple form: only the last two entries (height, width) are used.
        stride_h = stride[2]
        stride_w = stride[3]
    else:
        raise ValueError(f"The \'stride\' should be an int number or "
                         f"a tuple of two or four int numbers, but got {stride}")
    # A non-positive stride would divide by zero or produce empty slices below.
    if stride_h < 1 or stride_w < 1:
        raise ValueError(f"The \'stride\' should be positive, but got {stride}")

    batch_num, channel, height, width = input_shape
    out_h = (height + 2*pad - filter_h)//stride_h + 1
    out_w = (width + 2*pad - filter_w)//stride_w + 1

    col = col.reshape(batch_num, out_h, out_w, channel, filter_h, filter_w) \
        .transpose(0, 3, 4, 5, 1, 2)

    # Extra `stride - 1` rows/columns keep the strided slices below in range.
    img = np.zeros((batch_num,
                    channel,
                    height + 2*pad + stride_h - 1,
                    width + 2*pad + stride_w - 1)) \
        .astype(col.dtype)
    for y in range(filter_h):
        y_max = y + stride_h*out_h
        for x in range(filter_w):
            # The width axis uses stride_w for both the slice end and the step.
            x_max = x + stride_w*out_w
            img[:, :, y:y_max:stride_h, x:x_max:stride_w] += col[:, :, y, x, :, :]

    return img[:, :, pad:height + pad, pad:width + pad]
...@@ -205,11 +217,35 @@ def conv2d(x, weight, bias=None, stride=1, pad=0, ...@@ -205,11 +217,35 @@ def conv2d(x, weight, bias=None, stride=1, pad=0,
dilation=1, groups=1, padding_mode='zeros'): dilation=1, groups=1, padding_mode='zeros'):
"""Convolution 2D.""" """Convolution 2D."""
# pylint: disable=unused-argument # pylint: disable=unused-argument
validator.check_integer("stride", stride, 0, Rel.GT) validator.check_type('stride', stride, (int, tuple))
if isinstance(stride, int):
stride = (stride, stride)
elif len(stride) == 4:
stride = (stride[2], stride[3])
if len(stride) != 2 or (not isinstance(stride[0], int)) or \
(not isinstance(stride[1], int)) or \
stride[0] < 1 or stride[1] < 1:
raise ValueError(f"The \'stride\' of \'conv2d\' should be an positive int number or "
f"a tuple of two positive int numbers, but got {stride}")
stride_h = stride[0]
stride_w = stride[1]
validator.check_type('dilation', dilation, (int, tuple))
if isinstance(dilation, int):
dilation = (dilation, dilation)
elif len(dilation) == 4:
dilation = (dilation[2], dilation[3])
if len(dilation) != 2 or (not isinstance(dilation[0], int)) or \
(not isinstance(dilation[1], int)) or \
dilation[0] < 1 or dilation[1] < 1:
raise ValueError(f"The \'dilation\' of \'conv2d\' should be an positive int number or "
f"a tuple of two positive int numbers, but got {dilation}")
dilation_h = dilation[0]
dilation_w = dilation[1]
batch_num, _, x_h, x_w = x.shape batch_num, _, x_h, x_w = x.shape
filter_num, _, filter_h, filter_w = weight.shape filter_num, _, filter_h, filter_w = weight.shape
out_h = 1 + int((x_h + 2 * pad - filter_h - (filter_h - 1) * (dilation - 1)) / stride) out_h = 1 + int((x_h + 2 * pad - filter_h - (filter_h - 1) * (dilation_h - 1)) / stride_h)
out_w = 1 + int((x_w + 2 * pad - filter_w - (filter_w - 1) * (dilation - 1)) / stride) out_w = 1 + int((x_w + 2 * pad - filter_w - (filter_w - 1) * (dilation_w - 1)) / stride_w)
col = im2col(x, filter_h, filter_w, stride, pad, dilation) col = im2col(x, filter_h, filter_w, stride, pad, dilation)
col_w = np.reshape(weight, (filter_num, -1)).T col_w = np.reshape(weight, (filter_num, -1)).T
out = np.dot(col, col_w) out = np.dot(col, col_w)
...@@ -286,19 +322,43 @@ def flatten_grad(dout, x): ...@@ -286,19 +322,43 @@ def flatten_grad(dout, x):
def im2col(img, filter_h, filter_w, stride=1, pad=0, dilation=1):
    """
    Rearranges an image to row vector.

    Args:
        img (numpy.ndarray): 4-D input unpacked as (batch, channel, height, width).
        filter_h (int): Filter height.
        filter_w (int): Filter width.
        stride (Union[int, tuple[int]]): An int, a tuple (h, w), or a 4-tuple whose
            last two entries are (h, w). Default: 1.
        pad (int): Zero padding added on every side of height and width. Default: 0.
        dilation (Union[int, tuple[int]]): Same accepted forms as `stride`. Default: 1.

    Returns:
        numpy.ndarray, 2-D array with `batch * out_h * out_w` rows.

    Raises:
        ValueError: If `stride` or `dilation` is not an int or a tuple of 2 or 4 ints.
    """
    if isinstance(stride, int):
        stride_h = stride
        stride_w = stride
    elif isinstance(stride, tuple) and len(stride) == 2:
        stride_h = stride[0]
        stride_w = stride[1]
    elif isinstance(stride, tuple) and len(stride) == 4:
        # 4-tuple form: only the last two entries (height, width) are used.
        stride_h = stride[2]
        stride_w = stride[3]
    else:
        raise ValueError(f"The \'stride\' should be an int number or "
                         f"a tuple of two or four int numbers, but got {stride}")
    if isinstance(dilation, int):
        dilation_h = dilation
        dilation_w = dilation
    elif isinstance(dilation, tuple) and len(dilation) == 2:
        dilation_h = dilation[0]
        dilation_w = dilation[1]
    elif isinstance(dilation, tuple) and len(dilation) == 4:
        dilation_h = dilation[2]
        dilation_w = dilation[3]
    else:
        raise ValueError(f"The \'dilation\' should be an int number or "
                         f"a tuple of two or four int numbers, but got {dilation}")

    batch_num, channel, height, width = img.shape
    out_h = (height + 2*pad - filter_h- (filter_h - 1) * (dilation_h - 1))//stride_h + 1
    out_w = (width + 2*pad - filter_w- (filter_w - 1) * (dilation_w - 1))//stride_w + 1

    img = np.pad(img, [(0, 0), (0, 0), (pad, pad), (pad, pad)], 'constant')
    col = np.zeros((batch_num, channel, filter_h, filter_w, out_h, out_w)).astype(img.dtype)

    for y in range(filter_h):
        # Filter tap y samples rows y*dilation_h, y*dilation_h + stride_h, ...
        y_min = y * dilation_h
        y_max = y_min + stride_h*out_h
        for x in range(filter_w):
            # The width axis uses its own stride and dilation.
            x_min = x * dilation_w
            x_max = x_min + stride_w*out_w
            col[:, :, y, x, :, :] = img[:, :, y_min:y_max:stride_h, x_min:x_max:stride_w]

    col = col.transpose(0, 4, 5, 1, 2, 3).reshape(batch_num*out_h*out_w, -1)
    return col
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册