提交 5a397567 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!4096 Add Conv2d C Primitive

Merge pull request !4096 from lianliguang/unify-primitive-and-generate-graph
......@@ -149,7 +149,9 @@ add_subdirectory(${CMAKE_SOURCE_DIR}/mindspore/core/utils util)
list(APPEND SUB_OBJECTS_SRC $<TARGET_OBJECTS:_mindspore_core_utils_obj>)
add_subdirectory(${CMAKE_SOURCE_DIR}/mindspore/core/ir ir)
list(APPEND SUB_OBJECTS_SRC $<TARGET_OBJECTS:_mindspore_ir_obj>)
add_dependencies(_mindspore_core_utils_obj _mindspore_base_obj _mindspore_ir_obj _mindspore_abstract_obj proto_input )
add_subdirectory(${CMAKE_SOURCE_DIR}/mindspore/core/c_ops c_ops)
list(APPEND SUB_OBJECTS_SRC $<TARGET_OBJECTS:_mindspore_c_ops_obj>)
add_dependencies(_mindspore_core_utils_obj _mindspore_base_obj _mindspore_ir_obj _mindspore_abstract_obj _mindspore_c_ops_obj proto_input)
set_property(SOURCE ${SUB_OBJECTS_SRC} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_ME)
add_library(mindspore STATIC ${SUB_OBJECTS_SRC})
......
file(GLOB_RECURSE _C_OPS_ALL_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
add_library(_mindspore_c_ops_obj OBJECT ${_C_OPS_ALL_SRC_FILES})
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/conv2d.h"
#include <string>
#include <algorithm>
#include <memory>
#include <set>
#include <vector>
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace {
using PrimConv2dPtr = std::shared_ptr<Conv2d>;
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto conv_prim = primitive->cast<PrimConv2dPtr>();
MS_EXCEPTION_IF_NULL(conv_prim);
auto prim_name = conv_prim->name();
CheckAndConvertUtils::CheckInRange("Conv2d Infer", input_args.size(), kIncludeLeft, {2, 3}, prim_name);
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[0]->GetShapeTrack(), prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[1]->GetShapeTrack(), prim_name);
CheckAndConvertUtils::CheckInteger("weight rank", w_shape.size(), kEqual, 4, prim_name);
CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name);
CheckAndConvertUtils::Check("x_shape[1] / group", x_shape[1] / conv_prim->GetGroup(), kEqual, "w_shape[1]",
w_shape[1], conv_prim->name());
auto out_channel = conv_prim->GetOutputChannel();
CheckAndConvertUtils::Check("out_channel", out_channel, kEqual, "w_shape[0]", w_shape[0], conv_prim->name());
std::vector<int> temp_w;
std::copy(w_shape.begin() + 2, w_shape.end(), std::back_inserter(temp_w));
CheckAndConvertUtils::Check("kernel_size", conv_prim->GetKernelSize(), kEqual, "w_shape[2:4]", temp_w,
conv_prim->name());
auto kernel_size_h = w_shape[2];
auto kernel_size_w = w_shape[3];
auto stride = conv_prim->GetStride();
auto dilation = conv_prim->GetDilation();
auto stride_h = stride[2];
auto stride_w = stride[3];
auto dilation_h = dilation[2];
auto dilation_w = dilation[3];
int h_out = -1;
int w_out = -1;
std::vector<int> pad_list(4, 0);
auto pad_mode = conv_prim->GetPadMode();
if (pad_mode == "valid") {
h_out = ceil((x_shape[2] - dilation_h * (kernel_size_h - 1)) / stride_h);
w_out = ceil((x_shape[3] - dilation_w * (kernel_size_w - 1)) / stride_w);
} else if (pad_mode == "same") {
h_out = ceil(x_shape[2] / stride_h);
w_out = ceil(x_shape[3] / stride_w);
auto pad_needed_h = std::max(0, (h_out - 1) * stride_h + dilation_h * (kernel_size_h - 1) + 1 - x_shape[2]);
pad_list.emplace_back(floor(pad_needed_h / 2));
pad_list.emplace_back(pad_needed_h / 2);
auto pad_needed_w = std::max(0, (w_out - 1) * stride_w + dilation_w * (kernel_size_w - 1) + 1 - x_shape[3]);
auto pad_left = floor(pad_needed_w / 2);
pad_list.emplace_back(pad_left);
pad_list.emplace_back(pad_needed_h - pad_left);
} else if (pad_mode == "pad") {
std::copy(conv_prim->GetPad().begin(), conv_prim->GetPad().end(), std::back_inserter(pad_list));
auto pad_top = conv_prim->GetPad()[0];
auto pad_bottom = conv_prim->GetPad()[1];
auto pad_right = conv_prim->GetPad()[2];
auto pad_left = conv_prim->GetPad()[3];
h_out = 1 + (x_shape[2] + pad_top + pad_bottom - kernel_size_h - (kernel_size_h - 1) * (dilation_h - 1)) / stride_h;
w_out = 1 + (x_shape[3] + pad_left + pad_right - kernel_size_w - (kernel_size_w - 1) * (dilation_w - 1)) / stride_w;
h_out = floor(h_out);
w_out = floor(w_out);
}
conv_prim->SetPadList(pad_list);
std::vector<int> out_shape = {x_shape[0], out_channel, h_out, w_out};
return std::make_shared<abstract::Shape>(out_shape);
}
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
CheckAndConvertUtils::CheckInRange("", input_args.size(), kIncludeLeft, {2, 3}, prim->name());
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto x_type = CheckAndConvertUtils::ConvertTypePtrToTypeId("x_dtype", input_args[0]->GetTypeTrack(), prim->name());
const std::set<TypeId> valid_types = {kNumberTypeInt8, kNumberTypeInt32, kNumberTypeFloat16, kNumberTypeFloat32};
std::map<std::string, TypePtr> types;
types.emplace("x", input_args[0]->GetTypeTrack());
types.emplace("w", input_args[1]->GetTypeTrack());
CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
if (x_type == kNumberTypeInt8) {
return std::make_shared<TensorType>(TypeIdToType(kNumberTypeInt32));
}
return std::make_shared<TensorType>(TypeIdToType(x_type));
}
} // namespace
void Conv2d::Init(int out_channel, const std::vector<int> &kernel_size, int mode, const std::string &pad_mode,
const std::vector<int> &pad, const std::vector<int> &stride, const std::vector<int> &dilation,
int group) {
auto prim_name = this->name();
this->AddAttr("data_format", MakeValue("NCHW"));
this->AddAttr("offset_a", MakeValue(0));
this->SetKernelSize(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, prim_name));
this->SetStride(CheckAndConvertUtils::CheckPositiveVector(kStride, stride, this->name(), true, true));
this->SetDilation(CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, this->name(), true, true));
this->SetPadMode(CheckAndConvertUtils::CheckString(kPadMode, pad_mode, {"valid", "same", "pad"}, prim_name));
CheckAndConvertUtils::CheckInteger("pad size", pad.size(), kEqual, 4, prim_name);
if (pad_mode == "pad") {
for (auto item : pad) {
CheckAndConvertUtils::Check("pad item", item, kGreaterEqual, "zeros list", 0, prim_name);
}
} else {
CheckAndConvertUtils::Check(kPad, pad, kEqual, "zeros list", {0, 0, 0, 0}, prim_name);
}
this->SetPad(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, this->name(), true, true));
this->SetMode(CheckAndConvertUtils::CheckInteger("mode", mode, kEqual, 1, prim_name));
this->SetOutChannel(CheckAndConvertUtils::CheckInteger("out_channel", out_channel, kGreaterThan, 0, prim_name));
this->SetGroup(CheckAndConvertUtils::CheckInteger("group", group, kGreaterThan, 0, prim_name));
}
AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
InferShape(primitive, input_args)->shape());
}
} // namespace mindspore
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_C_OPS_CONV2D_H
#define MINDSPORE_CORE_C_OPS_CONV2D_H
#include <map>
#include <vector>
#include <string>
#include "c_ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
class Conv2d : public PrimitiveC {
public:
Conv2d() : PrimitiveC(kConv2DName) { InitIOName({"x", "w"}, {"output"}); }
void Init(int out_channel, const std::vector<int> &kernel_size, int mode = 1, const std::string &pad_mode = "valid",
const std::vector<int> &pad = {0, 0, 0, 0}, const std::vector<int> &stride = {1, 1, 1, 1},
const std::vector<int> &dilation = {1, 1, 1, 1}, int group = 1);
std::vector<int> GetKernelSize() const {
auto value_ptr = this->GetAttr(kKernelSize);
return GetValue<std::vector<int>>(value_ptr);
}
std::vector<int> GetStride() const {
auto value_ptr = GetAttr(kStride);
return GetValue<std::vector<int>>(value_ptr);
}
std::vector<int> GetDilation() const {
auto value_ptr = GetAttr(kDilation);
return GetValue<std::vector<int>>(value_ptr);
}
std::string GetPadMode() const {
auto value_ptr = this->GetAttr(kPadMode);
return GetValue<string>(value_ptr);
}
std::vector<int> GetPad() const {
auto value_ptr = this->GetAttr(kPad);
return GetValue<std::vector<int>>(value_ptr);
}
int GetMode() const {
auto value_ptr = this->GetAttr(kMode);
return GetValue<int>(value_ptr);
}
int GetGroup() const {
auto value_ptr = this->GetAttr(kGroup);
return GetValue<int>(value_ptr);
}
int GetOutputChannel() const {
auto value_ptr = this->GetAttr(kOutputChannel);
return GetValue<int>(value_ptr);
}
void SetKernelSize(const std::vector<int> &kernel_size) { this->AddAttr(kKernelSize, MakeValue(kernel_size)); }
void SetStride(const std::vector<int> &stride) { this->AddAttr(kStride, MakeValue(stride)); }
void SetDilation(const std::vector<int> &dilation) { this->AddAttr(kDilation, MakeValue(dilation)); }
void SetPadMode(const std::string &pad_mode) { this->AddAttr(kPadMode, MakeValue(pad_mode)); }
void SetPad(const std::vector<int> &pad) { this->AddAttr(kPad, MakeValue(pad)); }
void SetMode(int mode) { this->AddAttr(kMode, MakeValue(mode)); }
void SetGroup(int group) { this->AddAttr(kGroup, MakeValue(group)); }
void SetOutChannel(int output_channel) { this->AddAttr(kOutputChannel, MakeValue(output_channel)); }
void SetPadList(const std::vector<int> &pad_list) { this->AddAttr(kPadList, MakeValue(pad_list)); }
private:
inline static const string kKernelSize = "kernel_size";
inline static const string kStride = "stride";
inline static const string kDilation = "dilation";
inline static const string kPadMode = "pad_mode";
inline static const string kPad = "pad";
inline static const string kMode = "mode";
inline static const string kGroup = "group";
inline static const string kOutputChannel = "output channel";
inline static const string kPadList = "pad_list";
inline static const string kConv2DName = "Conv2D";
};
AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
} // namespace mindspore
#endif // MINDSPORE_CORE_C_OPS_CONV2D_H
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H
#define MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H
#include <string>
#include <vector>
#include "ir/primitive.h"
#include "ir/value.h"
namespace mindspore {
class PrimitiveC : public Primitive {
public:
explicit PrimitiveC(const std::string &name) : Primitive(name) { attrs_ = {}; }
protected:
void InitIOName(const std::vector<std::string> &inputs_name, const std::vector<std::string> &outputs_name) {
this->AddAttr("input_names", MakeValue(inputs_name));
this->AddAttr("output_names", MakeValue(outputs_name));
}
};
} // namespace mindspore
#endif // MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H
......@@ -632,6 +632,19 @@ void FuncGraph::CheckOrder() {
MS_LOG(DEBUG) << "Check order okay.";
}
}
CNodePtr FuncGraph::NewCNode(const PrimitivePtr &primitive, const std::vector<AnfNodePtr> &inputs) {
auto primitive_node = std::make_shared<ValueNode>(primitive);
std::vector<AnfNodePtr> input_node_list = {primitive_node};
std::copy(inputs.begin(), inputs.end(), std::back_inserter(input_node_list));
return NewCNode(input_node_list);
}
ParameterPtr FuncGraph::add_parameter(const tensor::MetaTensorPtr &meta_tensor) {
auto parameter = add_parameter();
parameter->set_default_param(MakeValue(meta_tensor));
parameter->set_abstract(meta_tensor->ToAbstract());
return parameter;
}
size_t NewFgSeenGeneration() {
static size_t fg_seen_generation = 0;
......
......@@ -170,7 +170,9 @@ class FuncGraph : public FuncGraphBase {
// create a cnode with given inputs, bound to this graph, and set to specific scope
CNodePtr NewCNodeWithScope(const std::vector<AnfNodePtr> &inputs, const ScopePtr &scope);
virtual CNodePtr NewCNode(const PrimitivePtr &primitive, const std::vector<AnfNodePtr> &prim_inputs);
virtual ParameterPtr add_parameter(const tensor::MetaTensorPtr &meta_tensor);
// Functions for handling variable argument, keyword-only arguments and variable keyword argument
AnfNodePtr GetDefaultValueByName(const std::string &name);
void set_param_default_value(const std::string &name, const AnfNodePtr &node) {
......
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "utils/check_convert_utils.h"
#include <utility>
#include "abstract/abstract_value.h"
namespace mindspore {
namespace {
const std::map<CompareEnum, std::function<bool(int, int)>> kCompareMap = {
{kEqual, [](int num1, int num2) -> bool { return num1 == num2; }},
{kNotEqual, [](int num1, int num2) -> bool { return num1 != num2; }},
{kLessThan, [](int num1, int num2) -> bool { return num1 < num2; }},
{kLessEqual, [](int num1, int num2) -> bool { return num1 <= num2; }},
{kGreaterThan, [](int num1, int num2) -> bool { return num1 > num2; }},
{kGreaterEqual, [](int num1, int num2) -> bool { return num1 >= num2; }}};
const std::map<CompareRange, std::function<bool(int, std::pair<int, int>)>> kCompareRangeMap = {
{kIncludeNeither,
[](int num1, std::pair<int, int> range) -> bool { return num1 > range.first && num1 < range.second; }},
{kIncludeLeft,
[](int num1, std::pair<int, int> range) -> bool { return num1 >= range.first && num1 < range.second; }},
{kIncludeRight,
[](int num1, std::pair<int, int> range) -> bool { return num1 > range.first && num1 <= range.second; }},
{kIncludeBoth,
[](int num1, std::pair<int, int> range) -> bool { return num1 >= range.first && num1 <= range.second; }}};
const std::map<CompareEnum, std::string> kCompareToString = {
{kEqual, "equal"}, {kNotEqual, "not equal"}, {kLessThan, "less than"},
{kLessEqual, "less eqaul"}, {kGreaterThan, "greater than"}, {kGreaterEqual, "greate equal"}};
const std::map<CompareRange, std::pair<std::string, std::string>> kCompareRangeToString = {
{kIncludeNeither, {"in (", ")"}},
{kIncludeLeft, {" in [", ")"}},
{kIncludeRight, {"in (", "]"}},
{kIncludeBoth, {"in [", "]"}}};
} // namespace
bool CheckAndConvertUtils::IsEqualVector(const std::vector<int> &vec_1, const std::vector<int> &vec_2) {
if (vec_1.size() != vec_2.size()) {
return false;
}
for (size_t index = 0; index < vec_1.size(); ++index) {
if (vec_1[index] != vec_2[index]) {
return false;
}
}
return true;
}
std::vector<int> CheckAndConvertUtils::CheckPositiveVector(const std::string &arg_name,
const std::vector<int> &arg_value,
const std::string &prim_name, bool allow_four,
bool ret_four) {
if (arg_value.size() == 2) {
return ret_four ? std::vector<int>{1, 1, arg_value[0], arg_value[1]} : arg_value;
} else if (arg_value.size() == 4 && allow_four) {
return ret_four ? arg_value : std::vector<int>{arg_value[2], arg_value[3]};
}
std::ostringstream buffer;
buffer << "For " << prim_name << " attr " << arg_name << " should be a positive vector of size two ";
if (allow_four) {
buffer << "or four ";
}
buffer << " positive int numbers , but got [";
for (auto item : arg_value) {
buffer << item << ",";
}
buffer << "]";
MS_EXCEPTION(ValueError) << buffer.str();
}
std::string CheckAndConvertUtils::CheckString(const std::string &arg_name, const std::string &arg_value,
const std::set<std::string> &check_list, const std::string &prim_name) {
if (check_list.find(arg_value) != check_list.end()) {
return arg_value;
}
std::ostringstream buffer;
buffer << "For " << prim_name << " the " << arg_name << " should be str and must be ";
if (check_list.size() == 1) {
buffer << (*check_list.begin()) << "but got " << arg_value;
MS_EXCEPTION(ValueError) << buffer.str();
}
buffer << "one of {";
for (const auto &item : check_list) {
buffer << item << " ,";
}
buffer << " }"
<< " but got " << arg_value;
MS_EXCEPTION(ValueError) << buffer.str();
}
int CheckAndConvertUtils::CheckInteger(const std::string &arg_name, int arg_value, CompareEnum compare_operator,
int match_value, const std::string &prim_name) {
auto iter = kCompareMap.find(compare_operator);
if (iter == kCompareMap.end()) {
MS_EXCEPTION(NotExistsError) << "compare_operator " << compare_operator << " cannot find in the compare map";
}
if (iter->second(arg_value, match_value)) {
return arg_value;
}
std::ostringstream buffer;
if (prim_name.empty()) {
buffer << "The ";
} else {
buffer << "For " << prim_name << " the ";
}
buffer << arg_name << " must ";
auto iter_to_string = kCompareToString.find(compare_operator);
if (iter_to_string == kCompareToString.end()) {
MS_EXCEPTION(NotExistsError) << "compare_operator " << compare_operator << " cannot find in the compare string map";
}
buffer << iter_to_string->second << match_value << " , but got " << arg_value;
MS_EXCEPTION(ValueError) << buffer.str();
}
void CheckAndConvertUtils::CheckInRange(const std::string &arg_name, int arg_value, CompareRange compare_operator,
const std::pair<int, int> &range, const std::string &prim_name) {
auto iter = kCompareRangeMap.find(compare_operator);
if (iter == kCompareRangeMap.end()) {
MS_EXCEPTION(NotExistsError) << "compare_operator " << compare_operator << " cannot find in the compare map";
}
if (iter->second(arg_value, range)) {
return;
}
std::ostringstream buffer;
if (prim_name.empty()) {
buffer << "The ";
} else {
buffer << "For " << prim_name << " the ";
}
buffer << arg_name << " must ";
auto iter_to_string = kCompareRangeToString.find(compare_operator);
if (iter_to_string == kCompareRangeToString.end()) {
MS_EXCEPTION(NotExistsError) << "compare_operator " << compare_operator << " cannot find in the compare string map";
}
auto range_strng = iter_to_string->second;
buffer << range_strng.first << range.first << "," << range_strng.second << " , but got " << arg_value;
MS_EXCEPTION(ValueError) << buffer.str();
}
std::vector<int> CheckAndConvertUtils::ConvertShapePtrToShape(const std::string &arg_name, const BaseShapePtr &shape,
const std::string &prim_name) {
MS_EXCEPTION_IF_NULL(shape);
if (!shape->isa<abstract::Shape>()) {
MS_EXCEPTION(ValueError) << "The " << arg_name << "'s shape is " << shape->ToString()
<< "should be a common shape!";
}
auto shape_element = shape->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(shape_element);
return shape_element->shape();
}
TypeId CheckAndConvertUtils::ConvertTypePtrToTypeId(const string &arg_name, const TypePtr &type_ptr,
const string &prim_name) {
MS_EXCEPTION_IF_NULL(type_ptr);
if (!type_ptr->isa<TensorType>() || !type_ptr->isa<Number>()) {
MS_EXCEPTION(ValueError) << "The " << arg_name << "'s shape is " << type_ptr->ToString()
<< "should be a common type!(tensor_type && numbertype)";
}
return type_ptr->type_id();
}
void CheckAndConvertUtils::Check(const string &arg_name, int arg_value, CompareEnum compare_type,
const string &value_name, int value, const string &prim_name,
ExceptionType exception_type) {
auto iter = kCompareMap.find(compare_type);
if (iter == kCompareMap.end()) {
MS_EXCEPTION(NotExistsError) << "the compare type :" << compare_type << " is not in the compare map";
}
if (iter->second(arg_value, value)) {
return;
}
std::ostringstream buffer;
if (prim_name.empty()) {
buffer << "The ";
} else {
buffer << "For " << prim_name << " the ";
}
auto iter_to_string = kCompareToString.find(compare_type);
if (iter_to_string == kCompareToString.end()) {
MS_EXCEPTION(NotExistsError) << "compare_operator " << compare_type << " cannot find in the compare string map";
}
MS_EXCEPTION(exception_type) << buffer.str() << arg_name << " should be " << iter_to_string->second << value
<< " but got " << arg_value;
}
void CheckAndConvertUtils::Check(const string &arg_name, const std::vector<int> &arg_value, CompareEnum compare_type,
const string &value_name, const std::vector<int> &value, const string &prim_name,
ExceptionType exception_type) {
if (compare_type != kEqual) {
auto iter = kCompareToString.find(compare_type);
if (iter != kCompareToString.end()) {
MS_EXCEPTION(NotSupportError) << "Only supported equal to compare two vectors but got " << iter->second;
}
MS_EXCEPTION(UnknownError) << "Cannot find the operator " << compare_type << "in the compare map!";
}
if (arg_value == value) {
return;
}
std::ostringstream buffer;
if (prim_name.empty()) {
buffer << "The ";
} else {
buffer << "For " << prim_name << " the ";
}
auto iter_to_string = kCompareToString.find(compare_type);
if (iter_to_string == kCompareToString.end()) {
MS_EXCEPTION(NotExistsError) << "compare_operator " << compare_type << " cannot find in the compare string map";
}
buffer << arg_name << "should be " << iter_to_string->second << " [";
for (auto item : value) {
buffer << item << ",";
}
buffer << "] "
<< "but got [";
for (auto item : arg_value) {
buffer << item << " ,";
}
buffer << "]";
MS_EXCEPTION(exception_type) << buffer.str();
}
void CheckAndConvertUtils::CheckTensorTypeSame(const std::map<std::string, TypePtr> &types,
const std::set<TypeId> &check_list, const std::string &prim_name) {
if (types.empty()) {
MS_LOG(WARNING) << "Tryinh to use the function to check a empty types map!";
return;
}
std::set<TypeId> types_id;
std::ostringstream buffer;
buffer << "For " << prim_name;
for (const auto &type : types) {
MS_EXCEPTION_IF_NULL(type.second);
if (!type.second->isa<TensorType>()) {
MS_EXCEPTION(TypeError) << "The " << prim_name << "'s" << type.first << " input must be tensor type but got "
<< type.second->ToString();
}
types_id.emplace(type.second->type_id());
}
if (types_id.size() > 1) {
buffer << "'s input type is not same : ";
for (const auto &item : types) {
buffer << "[ name : " << item.first << " ,type : " << item.second->ToString() << "]";
}
MS_EXCEPTION(TypeError) << buffer.str();
}
if (check_list.find(*(types_id.begin())) != check_list.end()) {
buffer << " type of ";
for (const auto &elem : types) {
buffer << elem.first << " should be in [";
for (auto type_elem : check_list) {
buffer << type_elem << " ,";
}
buffer << "] , but got " << types.begin()->second->ToString();
}
}
MS_EXCEPTION(TypeError) << buffer.str();
}
} // namespace mindspore
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_UTILS_CHECK_CONVERT_UTILS_H
#define MINDSPORE_CORE_UTILS_CHECK_CONVERT_UTILS_H
#include <vector>
#include <string>
#include <map>
#include <set>
#include <utility>
#include "base/base.h"
#include "ir/anf.h"
#include "ir/dtype/type_id.h"
#include "utils/log_adapter.h"
namespace mindspore {
enum CompareEnum : int {
kEqual = 1, // ==
kNotEqual = 2, // !=
kLessThan = 3, // <
kLessEqual = 4, // <=
kGreaterThan = 5, // >
kGreaterEqual = 6, // >=
};
enum CompareRange {
kIncludeNeither = 1, // (a,b)
kIncludeLeft = 2, // [a,b)
kIncludeRight = 3, // (a,b]
kIncludeBoth = 4, // [a,b]
};
class CheckAndConvertUtils {
public:
static std::vector<int> CheckPositiveVector(const std::string &arg_name, const std::vector<int> &arg_value,
const std::string &prim_name, bool allow_four = false,
bool ret_four = false);
static std::string CheckString(const std::string &arg_name, const std::string &arg_value,
const std::set<std::string> &check_list, const std::string &prim_name);
static int CheckInteger(const std::string &arg_name, int arg_value, CompareEnum compare_operator, int match_value,
const std::string &prim_name);
static void CheckInRange(const std::string &arg_name, int arg_value, CompareRange compare_operator,
const std::pair<int, int> &range, const std::string &prim_name);
static std::vector<int> ConvertShapePtrToShape(const std::string &arg_name, const BaseShapePtr &shape,
const std::string &prim_name);
static TypeId ConvertTypePtrToTypeId(const std::string &arg_name, const TypePtr &type_ptr,
const std::string &prim_name);
static void Check(const std::string &arg_name, int arg_value, CompareEnum compare_type, const std::string &value_name,
int value, const std::string &prim_name = "", ExceptionType exception_type = ValueError);
static void Check(const std::string &arg_name, const std::vector<int> &arg_value, CompareEnum compare_type,
const std::string &value_name, const std::vector<int> &value, const std::string &prim_name = "",
ExceptionType exception_type = ValueError);
static void CheckTensorTypeSame(const std::map<std::string, TypePtr> &types, const std::set<TypeId> &check_list,
const std::string &prim_name);
private:
static bool IsEqualVector(const std::vector<int> &vec_1, const std::vector<int> &vec_2);
};
} // namespace mindspore
#endif // MINDSPORE_CORE_UTILS_CHECK_CONVERT_UTILS_H
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册