提交 03991a96 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!4932 Add CropAndResize for old backend.

Merge pull request !4932 from liuxiao93/Add-CropAndResize-for-old-backend
......@@ -189,6 +189,7 @@ constexpr const char kNameRange[] = "Range";
constexpr const char kNameSquareSumAll[] = "SquareSumAll";
constexpr const char kNameAscendQuant[] = "Quant";
constexpr const char kNameAscendDequant[] = "Dequant";
constexpr const char kNameCropAndResize[] = "CropAndResize";
constexpr const char kNameReverseSequence[] = "ReverseSequence";
constexpr const char kNameEditDistance[] = "EditDistance";
constexpr const char kNameCase[] = "Case";
......
......@@ -45,4 +45,12 @@ ATTR_MAP(ResizeBilinearV2D) = {
{"align_corners", ATTR_DESC(align_corners, AnyTraits<bool>())}};
OUTPUT_MAP(ResizeBilinearV2D) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(ResizeBilinearV2D, kNameResizeBilinear, ADPT_DESC(ResizeBilinearV2D))
// CropAndResize
INPUT_MAP(CropAndResize) = {
{1, INPUT_DESC(x)}, {2, INPUT_DESC(boxes)}, {3, INPUT_DESC(box_index)}, {4, INPUT_DESC(crop_size)}};
ATTR_MAP(CropAndResize) = {{"extrapolation_value", ATTR_DESC(extrapolation_value, AnyTraits<float>())},
{"method", ATTR_DESC(method, AnyTraits<std::string>())}};
OUTPUT_MAP(CropAndResize) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(CropAndResize, kNameCropAndResize, ADPT_DESC(CropAndResize))
} // namespace mindspore::transform
......@@ -34,5 +34,8 @@ DECLARE_OP_USE_OUTPUT(ResizeBilinearV2D)
DECLARE_OP_ADAPTER(ResizeBilinearV2Grad)
DECLARE_OP_USE_OUTPUT(ResizeBilinearV2Grad)
DECLARE_OP_ADAPTER(CropAndResize)
DECLARE_OP_USE_OUTPUT(CropAndResize)
} // namespace mindspore::transform
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_IMAGE_OPS_DECLARE_H_
......@@ -14,6 +14,7 @@
# ============================================================================
"""image_ops"""
from ... import context
from ..._checkparam import Validator as validator
from ..._checkparam import Rel
from ...common import dtype as mstype
......@@ -84,6 +85,7 @@ class CropAndResize(PrimitiveWithInfer):
self.method = method
validator.check_value_type("extrapolation_value", extrapolation_value, [float], self.name)
self.extrapolation_value = extrapolation_value
self.is_ge = context.get_context("enable_ge")
def __infer__(self, x, boxes, box_index, crop_size):
# get shape
......@@ -124,6 +126,9 @@ class CropAndResize(PrimitiveWithInfer):
crop_height = crop_size_value[0]
crop_width = crop_size_value[1]
depth = x_shape[3]
return {'shape': (num_boxes, crop_height, crop_width, depth),
out_shape = (num_boxes, crop_height, crop_width, depth)
if self.is_ge:
out_shape = (num_boxes, x_shape[1], crop_height, crop_width)
return {'shape': out_shape,
'dtype': mstype.float32,
'value': None}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册