提交 5d225f93 编写于 作者: L lianliguang

change the padding strategy & refactor insert transdata

上级 60958d6b
......@@ -20,6 +20,8 @@
#include <utility>
#include "./securec.h"
#include "common/utils.h"
#include "session/anf_runtime_algorithm.h"
#include "kernel/kernel.h"
#include "device/convert_tensor_utils.h"
#include "utils/convert_utils.h"
#include "utils/log_adapter.h"
......@@ -27,6 +29,33 @@
namespace mindspore {
namespace trans {
namespace {
std::vector<size_t> PaddingShapeTo4dByDefault(const std::vector<size_t> &shape) {
std::vector<size_t> shape_4d(4, 1);
switch (shape.size()) {
case 0:
return shape_4d;
case 1:
shape_4d[1] = shape[0];
break;
case 2:
shape_4d[1] = shape[0];
shape_4d[2] = shape[1];
break;
case 3:
shape_4d[1] = shape[0];
shape_4d[2] = shape[1];
shape_4d[3] = shape[2];
break;
case 4:
std::copy(shape.begin(), shape.end(), shape_4d.begin());
break;
default:
MS_LOG(EXCEPTION) << "Unexpect shape size = " << shape.size();
}
return shape_4d;
}
} // namespace
const size_t kNchwDims = 4;
const std::map<TypeId, size_t> type_map = {{kNumberTypeBool, 1}, {kNumberTypeInt, 4}, {kNumberTypeInt8, 1},
{kNumberTypeInt16, 2}, {kNumberTypeInt32, 4}, {kNumberTypeInt64, 8},
......@@ -154,38 +183,64 @@ size_t TypeIdSize(const TypeId data_type) {
return unsupported_type_error;
}
std::vector<size_t> TransShapeTo4d(const std::vector<size_t> &shape) {
bool IsNeedPadding(const std::string &format, const size_t shape_size) {
if (shape_size == 0) {
return false;
}
if (format == kOpFormat_DEFAULT || format == kOpFormat_FRAC_NZ) {
return false;
} else if (shape_size < 4) {
return true;
}
return false;
}
std::vector<int> GetRuntimePaddingShape(const AnfNodePtr &node, size_t index) {
std::vector<int> shape;
std::vector<size_t> host_shape;
if (node->isa<ValueNode>()) {
auto value_node = node->cast<ValueNodePtr>();
auto node_value = value_node->value();
auto tensor = node_value->cast<tensor::TensorPtr>();
if (tensor == nullptr) {
MS_LOG(EXCEPTION) << " the node[ " << node->DebugString() << "]'s cannot convert ";
}
shape = tensor->shape();
(void)std::transform(shape.begin(), shape.end(), std::back_inserter(host_shape), IntToSize);
if (host_shape.empty()) {
host_shape.push_back(1);
}
} else {
host_shape = AnfAlgo::GetOutputInferShape(node, index);
}
if (trans::IsNeedPadding(AnfAlgo::GetOutputFormat(node, 0), host_shape.size())) {
host_shape = trans::PaddingShapeTo4d(host_shape, AnfAlgo::GetOutputReshapeType(node, 0));
}
std::transform(host_shape.begin(), host_shape.end(), std::back_inserter(shape), SizeToInt);
return shape;
}
std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape, const std::vector<kernel::Axis> &padding_axis) {
if (padding_axis.empty() || shape.size() != padding_axis.size()) {
return PaddingShapeTo4dByDefault(shape);
}
std::vector<size_t> shape_4d(4, 1);
switch (shape.size()) {
case 0:
break;
case 1:
shape_4d[1] = shape[0];
break;
case 2:
shape_4d[0] = shape[0];
shape_4d[1] = shape[1];
break;
case 3:
MS_LOG(EXCEPTION) << "Unexpected shape size = 3,it should has a default format";
case 4:
for (size_t i = 0; i < 4; ++i) {
shape_4d[i] = shape[i];
}
break;
default:
MS_LOG(EXCEPTION) << "Unexpected shape size = " << shape.size();
for (size_t index = 0; index < padding_axis.size(); index++) {
shape_4d[padding_axis[index]] = shape[index];
}
return shape_4d;
}
std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format) {
if (format == kOpFormat_ND || format == kOpFormat_DEFAULT) {
return shape;
}
auto temp_shape = shape;
std::vector<size_t> device_shape;
if (format == kOpFormat_FRAC_NZ) {
if (shape.size() < 2) {
MS_EXCEPTION(NotSupportError) << "Format " << format << " is not support shape " << shape.size();
}
if (shape.size() > 2) {
MS_LOG(EXCEPTION) << "Format" << format << " is not support shape " << shape.size();
} else {
(void)std::copy(shape.begin(), shape.end() - 2, std::back_inserter(device_shape));
}
auto h1 = (shape[shape.size() - 2] - 1) / kCubeSize + 1;
......@@ -197,35 +252,36 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s
return device_shape;
}
if (shape.size() != 4) {
MS_LOG(EXCEPTION) << "shape_4d size should be 4";
MS_LOG(WARNING) << "Get Device Shape using a shape size is less than 4 ,should be Padding shape by Default firstly";
temp_shape = PaddingShapeTo4dByDefault(shape);
}
if (format == kOpFormat_NC1HWC0) {
size_t C1 = (shape[1] + kCubeSize - 1) / kCubeSize;
size_t C1 = (temp_shape[1] + kCubeSize - 1) / kCubeSize;
size_t C0 = kCubeSize;
device_shape.push_back(shape[0]);
device_shape.push_back(temp_shape[0]);
device_shape.push_back(C1);
device_shape.push_back(shape[2]);
device_shape.push_back(shape[3]);
device_shape.push_back(temp_shape[2]);
device_shape.push_back(temp_shape[3]);
device_shape.push_back(C0);
return device_shape;
} else if (format == kOpFormat_FRAC_Z) {
size_t cout16 = ((shape[0] + kCubeSize - 1) / kCubeSize) * kCubeSize;
size_t cin16 = ((shape[1] + kCubeSize - 1) / kCubeSize) * kCubeSize;
device_shape.push_back(shape[2] * shape[3] * cin16 / kCubeSize);
size_t cout16 = ((temp_shape[0] + kCubeSize - 1) / kCubeSize) * kCubeSize;
size_t cin16 = ((temp_shape[1] + kCubeSize - 1) / kCubeSize) * kCubeSize;
device_shape.push_back(temp_shape[2] * temp_shape[3] * cin16 / kCubeSize);
device_shape.push_back(cout16 / kCubeSize);
device_shape.push_back(kCubeSize);
device_shape.push_back(kCubeSize);
return device_shape;
} else if (format == kOpFormat_NHWC) {
device_shape.push_back(shape[0]);
device_shape.push_back(shape[2]);
device_shape.push_back(shape[3]);
device_shape.push_back(shape[1]);
device_shape.push_back(temp_shape[0]);
device_shape.push_back(temp_shape[2]);
device_shape.push_back(temp_shape[3]);
device_shape.push_back(temp_shape[1]);
return device_shape;
} else if (format == kOpFormat_NCHW) {
return shape;
} else if (format == kOpFormat_HWCN) {
return {shape[2], shape[3], shape[1], shape[0]};
return {temp_shape[2], temp_shape[3], temp_shape[1], temp_shape[0]};
} else if (format == kOpFormat_NCHW) {
return temp_shape;
}
MS_LOG(EXCEPTION) << "Unexpected format[" << format << "]";
}
......
......@@ -24,6 +24,7 @@
#include <utility>
#include <vector>
#include "ir/dtype.h"
#include "kernel/kernel.h"
#include "ir/dtype/type.h"
namespace mindspore {
......@@ -49,7 +50,10 @@ size_t TypeIdSize(const TypeId data_type);
size_t ShapeSize(const std::vector<size_t> &shape);
size_t CubeSizeByType(const TypeId data_type);
std::vector<size_t> TransShapeTo4d(const std::vector<size_t> &shape);
std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape,
const std::vector<kernel::Axis> &padding_axis = {});
std::vector<int> GetRuntimePaddingShape(const AnfNodePtr &node, size_t index);
bool IsNeedPadding(const std::string &format, const size_t shape_size);
std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format);
bool TransDataType(const TypeIdArgs &args, void *result);
bool TransFormat(const FormatArgs &args, void *result);
......
......@@ -141,7 +141,7 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const std::vector<int
if (format_ == kOpFormat_FRAC_NZ) {
device_shape = trans::TransShapeToDevice(host_shape, format_);
} else {
host_shape = trans::TransShapeTo4d(host_shape);
host_shape = trans::PaddingShapeTo4d(host_shape);
device_shape = trans::TransShapeToDevice(host_shape, format_);
}
if (type_id_ != type) {
......@@ -224,7 +224,7 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const std::vector<int
if (format_ == kOpFormat_FRAC_NZ) {
device_shape = trans::TransShapeToDevice(host_shape, format_);
} else {
host_shape = trans::TransShapeTo4d(host_shape);
host_shape = trans::PaddingShapeTo4d(host_shape);
device_shape = trans::TransShapeToDevice(host_shape, format_);
}
if (type_id_ != type) {
......
......@@ -27,6 +27,7 @@
#include "utils/context/ms_context.h"
#include "device/ascend/profiling/profiling_manager.h"
#include "hccl/hcom.h"
#include "common/trans.h"
#include "runtime/context.h"
#include "device/ascend/ascend_stream_assign.h"
#include "device/ascend/ascend_memory_pool.h"
......@@ -150,7 +151,7 @@ void DumpOutput(mindspore::session::KernelGraph *graph, const string &dump_path,
auto output_size = AnfAlgo::GetOutputTensorNum(node);
for (size_t j = 0; j < output_size; ++j) {
auto addr = AnfAlgo::GetOutputAddr(node, j);
auto shape = AnfAlgo::GetOutputInferShape(node, j);
auto shape = trans::GetRuntimePaddingShape(node, j);
auto type = AnfAlgo::GetOutputInferDataType(node, j);
auto format = kOpFormat_DEFAULT;
string filepath = dump_path + '/' + kernel_name + '_' + "output_" + std::to_string(j);
......@@ -181,7 +182,7 @@ void DumpParameters(mindspore::session::KernelGraph *graph, const string &dump_p
continue;
}
auto addr = AnfAlgo::GetOutputAddr(item, PRAMATER_OUTPUT_INDEX);
auto shape = AnfAlgo::GetOutputInferShape(item, PRAMATER_OUTPUT_INDEX);
auto shape = trans::GetRuntimePaddingShape(item, PRAMATER_OUTPUT_INDEX);
auto type = AnfAlgo::GetOutputInferDataType(item, PRAMATER_OUTPUT_INDEX);
auto format = kOpFormat_DEFAULT;
string filepath = dump_path + '/' + parameter_name + '_' + "output_0";
......
......@@ -184,7 +184,7 @@ void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, cons
}
if (kernel_build_info.GetInputFormat(input_index) == AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index)) {
if (AnfAlgo::IsFeatureMapInput(kernel_node, input_index) &&
kSpecialFormatSet.find(kernel_build_info.GetInputFormat(input_index)) != kSpecialFormatSet.end()) {
kNeedTransFormatSet.find(kernel_build_info.GetInputFormat(input_index)) != kNeedTransFormatSet.end()) {
(*cur_kernelinfo_match_counts)[MATCH_SPECIAL_FORMAT_COUNT]++;
}
(*cur_kernelinfo_match_counts)[MATCH_FORMAT_COUNT]++;
......@@ -210,19 +210,22 @@ void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, cons
(*cur_kernelinfo_match_counts)[MATCH_OUTPUT_DTYPE_COUNT]++;
}
}
} // namespace
}
void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
auto input_kernel_node = AnfAlgo::GetInputNode(kernel_node, input_index);
MS_EXCEPTION_IF_NULL(input_kernel_node);
if (AnfAlgo::IsFeatureMapInput(kernel_node, input_index)) {
continue;
}
auto input_with_index = AnfAlgo::VisitKernel(input_kernel_node, 0);
MS_EXCEPTION_IF_NULL(input_with_index.first);
auto real_input_node = input_with_index.first;
if (real_input_node->isa<CNode>()) {
continue;
}
if (real_input_node->isa<Parameter>() && !AnfAlgo::IsParameterWeight(real_input_node->cast<ParameterPtr>())) {
continue;
}
std::shared_ptr<kernel::KernelBuildInfo::KernelBuildInfoBuilder> builder =
std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
// we set special device info of a input tensor.
......
......@@ -25,6 +25,7 @@
#include "session/anf_runtime_algorithm.h"
#include "utils/context/ms_context.h"
#include "common/trans.h"
#include "utils/config_manager.h"
#include "common/utils.h"
#include "kernel/kernel_build_info.h"
......@@ -391,7 +392,8 @@ bool KernelAdjust::StepLoadCtrlInputs(const std::shared_ptr<session::Context> &c
auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0);
MS_EXCEPTION_IF_NULL(device_address);
tensor->set_device_address(device_address);
if (!device_address->SyncHostToDevice(tensor->shape(), LongToSize(tensor->data().nbytes()), tensor->data_type(),
if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0),
LongToSize(tensor->data().nbytes()), tensor->data_type(),
tensor->data_c(false))) {
MS_LOG(INFO) << "SyncHostToDevice failed.";
return false;
......
......@@ -31,6 +31,7 @@ class KernelInfo {
public:
KernelInfo() {
kernel_mod_ = nullptr;
is_feature_map_ = false;
select_kernel_build_info_ = nullptr;
output_address_list_ = {};
workspace_address_list_ = {};
......@@ -45,6 +46,7 @@ class KernelInfo {
void set_select_kernel_build_info(const kernel::KernelBuildInfoPtr &select_kernel_build_info) {
select_kernel_build_info_ = select_kernel_build_info;
}
void SetFeatureMapFlag(bool flag) { is_feature_map_ = flag; }
const DeviceAddress *GetOutputAddr(size_t index) const;
DeviceAddressPtr GetMutableOutputAddr(size_t index) const;
bool OutputAddrExist(size_t index) const;
......@@ -63,8 +65,10 @@ class KernelInfo {
void set_graph_id(uint32_t graph_id) { graph_id_ = graph_id; }
uint32_t graph_id() const { return graph_id_; }
bool operator==(const KernelInfo &other) const;
bool is_feature_map() const { return is_feature_map_; }
private:
bool is_feature_map_;
kernel::KernelBuildInfoPtr select_kernel_build_info_;
std::vector<std::shared_ptr<DeviceAddress>> output_address_list_;
std::vector<std::shared_ptr<DeviceAddress>> workspace_address_list_;
......
......@@ -105,7 +105,7 @@ size_t KernelRuntime::CountNodeDeviceMemorySize(const mindspore::AnfNodePtr &nod
std::vector<size_t> shape = AnfAlgo::GetOutputDeviceShape(node, output_index);
auto format = AnfAlgo::GetOutputFormat(node, output_index);
if (shape.empty() && format != kOpFormat_DEFAULT) {
shape = trans::TransShapeTo4d(shape);
shape = trans::PaddingShapeTo4d(shape, AnfAlgo::GetOutputReshapeType(node, output_index));
shape = trans::TransShapeToDevice(shape, format);
}
// scalar's output shape is a empty vector
......@@ -401,8 +401,9 @@ void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const
auto address = CreateDeviceAddress(ptr, node_size, AnfAlgo::GetOutputFormat(value_node, output_idx), output_type_id);
MS_EXCEPTION_IF_NULL(address);
AnfAlgo::SetOutputAddr(address, output_idx, value_node.get());
if (!address->SyncHostToDevice(tensor->shape(), tensor_size, tensor->data_type(), tensor->data_c(false))) {
MS_EXCEPTION(NotExistsError) << "kValueNode SyncHostToDevice fail!" << value_node->DebugString() << "node format is"
if (!address->SyncHostToDevice(trans::GetRuntimePaddingShape(value_node, 0), tensor_size, tensor->data_type(),
tensor->data_c(false))) {
MS_EXCEPTION(NotExistsError) << "ValueNode SyncHostToDevice fail!" << value_node->DebugString() << "node format is"
<< AnfAlgo::GetOutputFormat(value_node, output_idx) << "node dtype is "
<< AnfAlgo::GetOutputInferDataType(value_node, output_idx);
}
......@@ -421,19 +422,6 @@ void KernelRuntime::AssignStaticMemoryValueNode(session::KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(node_value);
if (node_value->isa<Tensor>()) {
AssignValueNodeTensor(value_node, node_value, 0);
} else if (node_value->isa<ValueTuple>()) {
auto value_tuple = node_value->cast<ValueTuplePtr>();
if (value_tuple == nullptr) {
MS_LOG(WARNING) << "value_tuple is null";
continue;
}
size_t i = 0;
auto value_list = value_tuple->value();
for (auto value_ptr : value_list) {
if (value_ptr->isa<Tensor>()) {
AssignValueNodeTensor(value_node, value_ptr, i++);
}
}
} else if (node_value->isa<StringImm>()) {
auto value = GetValue<std::string>(node_value);
size_t tensor_size = value.size();
......
......@@ -59,30 +59,20 @@ size_t KernelBuildInfo::GetInputNum() const { return inputs_format_.size(); }
size_t KernelBuildInfo::GetOutputNum() const { return outputs_format_.size(); }
bool KernelBuildInfo::GetInputReshapeType(size_t input_index, std::vector<Axis> *reshape_type) const {
MS_EXCEPTION_IF_NULL(reshape_type);
reshape_type->clear();
std::vector<Axis> KernelBuildInfo::GetInputReshapeType(size_t input_index) const {
if (input_index >= input_reshape_type_.size()) {
MS_LOG(WARNING) << "The index [" << input_index << "] is exceed the number of input node size "
<< input_reshape_type_.size();
return false;
MS_LOG(EXCEPTION) << "The index [" << input_index << "] is exceed the number of input node size "
<< input_reshape_type_.size();
}
(void)std::copy(input_reshape_type_[input_index].begin(), input_reshape_type_[input_index].end(),
std::inserter(*reshape_type, (*reshape_type).begin()));
return true;
return input_reshape_type_[input_index];
}
bool KernelBuildInfo::GetOutputReshapeType(size_t output_index, std::vector<Axis> *reshape_type) const {
MS_EXCEPTION_IF_NULL(reshape_type);
reshape_type->clear();
std::vector<Axis> KernelBuildInfo::GetOutputReshapeType(size_t output_index) const {
if (output_index >= output_reshape_type_.size()) {
MS_LOG(WARNING) << "The index [" << output_index << "] is exceed the number of output node dixr"
<< output_reshape_type_.size();
return false;
MS_LOG(EXCEPTION) << "The index [" << output_index << "] is exceed the number of output node size "
<< output_reshape_type_.size();
}
(void)std::copy(output_reshape_type_[output_index].begin(), output_reshape_type_[output_index].end(),
std::inserter(*reshape_type, (*reshape_type).begin()));
return true;
return output_reshape_type_[output_index];
}
std::string KernelBuildInfo::ToString() const {
......@@ -115,6 +105,10 @@ bool KernelBuildInfo::operator==(const KernelBuildInfo &other) const {
return !(inputs_device_type_ != other.inputs_device_type_ || outputs_device_type_ != other.outputs_device_type_);
}
bool KernelBuildInfo::IsInputDefaultPadding() const { return output_reshape_type_.empty(); }
bool KernelBuildInfo::IsOutputDefaultPadding() const { return input_reshape_type_.empty(); }
void KernelBuildInfo::KernelBuildInfoBuilder::SetKernelType(const KernelType &kernel_type) {
MS_EXCEPTION_IF_NULL(kernel_build_info_);
kernel_build_info_->kernel_type_ = kernel_type;
......
......@@ -54,9 +54,13 @@ class KernelBuildInfo {
TypeId GetOutputDeviceType(size_t output_index) const;
bool GetInputReshapeType(size_t input_index, std::vector<Axis> *reshape_type) const;
std::vector<Axis> GetInputReshapeType(size_t input_index) const;
bool GetOutputReshapeType(size_t input_index, std::vector<Axis> *reshape_type) const;
bool IsInputDefaultPadding() const;
bool IsOutputDefaultPadding() const;
std::vector<Axis> GetOutputReshapeType(size_t input_index) const;
std::vector<std::string> GetAllInputFormats() const;
......
......@@ -18,20 +18,21 @@
#include <set>
#include "common/trans.h"
#include "common/utils.h"
#include "utils/utils.h"
#include "device/kernel_info.h"
#include "kernel/oplib/oplib.h"
#include "operator/ops.h"
#include "session/anf_runtime_algorithm.h"
#include "session/kernel_graph.h"
#include "utils/context/ms_context.h"
#include "utils/utils.h"
namespace mindspore {
namespace opt {
using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
namespace {
kernel::KernelBuildInfoPtr CreateKernelBuildInfo(const std::string &input_format, const std::string &output_format,
const AnfNodePtr &node, const kernel::KernelBuildInfo ori_build_info) {
kernel::KernelBuildInfoPtr RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format,
const AnfNodePtr &node,
const kernel::KernelBuildInfo ori_build_info) {
KernelBuildInfoBuilder builder;
builder.SetInputsFormat({input_format});
builder.SetOutputsFormat({output_format});
......@@ -54,9 +55,11 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
CNodePtr trans_node = func_graph->NewCNode(trans_inputs);
MS_EXCEPTION_IF_NULL(trans_node);
if (need_padding) {
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)},
{trans::TransShapeTo4d(AnfAlgo::GetOutputInferShape(input, 0))},
trans_node.get());
// if need padding we should set the transdata node's shape to the padding shape
AnfAlgo::SetOutputInferTypeAndShape(
{AnfAlgo::GetOutputInferDataType(input, 0)},
{trans::PaddingShapeTo4d(AnfAlgo::GetOutputInferShape(input, 0), AnfAlgo::GetOutputReshapeType(input, 0))},
trans_node.get());
} else {
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)},
{AnfAlgo::GetOutputInferShape(input, 0)}, trans_node.get());
......@@ -92,9 +95,11 @@ AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &i
AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &node, size_t index,
const KernelSelectPtr &kernel_select) {
MS_EXCEPTION_IF_NULL(node);
bool padding_flag = false;
auto input_node = AnfAlgo::GetInputNode(node, index);
if (input_node->isa<ValueNode>() || input_node->isa<Parameter>()) {
auto node_with_index = AnfAlgo::VisitKernel(input_node, 0);
MS_EXCEPTION_IF_NULL(node_with_index.first);
auto real_input = node_with_index.first;
if (real_input->isa<ValueNode>() || real_input->isa<Parameter>()) {
input_node = InsertTransOpForOutput(func_graph, input_node, kernel_select);
MS_EXCEPTION_IF_NULL(input_node);
AnfAlgo::SetNodeInput(node, input_node, index);
......@@ -106,33 +111,11 @@ AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &
std::vector<size_t> origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, index);
std::string origin_format = kOpFormat_DEFAULT;
std::string dest_format = AnfAlgo::GetInputFormat(node, index);
if (dest_format == kOpFormat_C1HWNCoC0) {
padding_flag = (origin_shape.size() != kShape4dDims);
AnfNodePtr replace_input = AddTransOpNodeToGraph(func_graph, node, kernel_select, index, padding_flag,
origin_format, dest_format, kTransDataOpName, true);
MS_EXCEPTION_IF_NULL(replace_input);
return replace_input;
}
if (dest_format == kOpFormat_NC1HWC0 && origin_shape.size() > 1) {
padding_flag = (origin_shape.size() != kShape4dDims);
AnfNodePtr replace_input = AddTransOpNodeToGraph(func_graph, node, kernel_select, index, padding_flag,
origin_format, dest_format, kTransDataOpName, true);
MS_EXCEPTION_IF_NULL(replace_input);
MS_LOG(DEBUG) << "Inserted Translate45, index: " << index;
return replace_input;
} else if (dest_format == kOpFormat_FRAC_NZ) {
AnfNodePtr replace_input = AddTransOpNodeToGraph(func_graph, node, kernel_select, index, padding_flag,
origin_format, dest_format, kTransDataOpName, true);
MS_EXCEPTION_IF_NULL(replace_input);
MS_LOG(DEBUG) << "inserted translate " << AnfAlgo::GetInputFormat(node, index) << " To default, index: " << index;
return replace_input;
} else if (dest_format == kOpFormat_FRAC_Z && !origin_shape.empty()) {
padding_flag = (origin_shape.size() != kShape4dDims);
AnfNodePtr replace_input = AddTransOpNodeToGraph(func_graph, node, kernel_select, index, padding_flag,
origin_format, dest_format, kTransDataOpName, true);
MS_EXCEPTION_IF_NULL(replace_input);
MS_LOG(DEBUG) << "Inserted Translate45, index: " << index;
return replace_input;
if (kNeedTransFormatSet.find(dest_format) != kNeedTransFormatSet.end() && origin_shape.size() > 1) {
MS_LOG(DEBUG) << node->DebugString() << "Insert transdata " << AnfAlgo::GetInputFormat(node, index)
<< " To DefaultFormat , index: " << index;
return AddTransOpNodeToGraph(func_graph, node, kernel_select, index, origin_format, dest_format, kTransDataOpName,
true);
}
return input_node;
}
......@@ -140,7 +123,6 @@ AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &
AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const KernelSelectPtr &kernel_select) {
MS_EXCEPTION_IF_NULL(node);
bool padding_flag = false;
std::string output_format;
std::vector<size_t> origin_shape;
if (!AnfAlgo::IsRealKernel(node)) {
......@@ -156,46 +138,14 @@ AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const An
}
std::string origin_format = output_format;
std::string dest_format = kOpFormat_DEFAULT;
if (output_format == kOpFormat_C1HWNCoC0) {
padding_flag = (origin_shape.size() != kShape4dDims);
AnfNodePtr replace_input = AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, padding_flag, origin_format,
dest_format, kTransDataOpName, false);
MS_EXCEPTION_IF_NULL(replace_input);
return replace_input;
}
if (output_format == kOpFormat_NC1HWC0 && origin_shape.size() > 1) {
padding_flag = (origin_shape.size() != kShape4dDims);
AnfNodePtr replace_output = AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, padding_flag, origin_format,
dest_format, kTransDataOpName, false);
MS_EXCEPTION_IF_NULL(replace_output);
MS_LOG(DEBUG) << "Inserted Trans54";
return replace_output;
} else if (output_format == kOpFormat_FRAC_NZ) {
AnfNodePtr replace_output = AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, padding_flag, origin_format,
dest_format, kTransDataOpName, false);
MS_EXCEPTION_IF_NULL(replace_output);
MS_LOG(DEBUG) << "Inserted Translate " << output_format << " To default, index: 0";
return replace_output;
} else if (output_format == kOpFormat_FRAC_Z && !origin_shape.empty()) {
padding_flag = (origin_shape.size() != kShape4dDims);
AnfNodePtr replace_output = AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, padding_flag, origin_format,
dest_format, kTransDataOpName, false);
MS_EXCEPTION_IF_NULL(replace_output);
MS_LOG(DEBUG) << "Inserted Trans54";
return replace_output;
if (kNeedTransFormatSet.find(output_format) != kNeedTransFormatSet.end() && origin_shape.size() > 1) {
MS_LOG(DEBUG) << "Inserted Transdata " << output_format << " To default , index :0";
return AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, origin_format, dest_format, kTransDataOpName,
false);
}
return node;
}
void GetTransDataInputFormat(const AnfNodePtr &node, size_t idx, std::string *input_format) {
MS_EXCEPTION_IF_NULL(input_format);
if (AnfAlgo::IsRealKernel(node)) {
*input_format = AnfAlgo::GetOutputFormat(node, idx);
} else {
*input_format = AnfAlgo::GetPrevNodeOutputFormat(node, 0);
}
}
AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const KernelSelectPtr &kernel_select) {
MS_EXCEPTION_IF_NULL(func_graph);
......@@ -203,46 +153,17 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const
std::vector<AnfNodePtr> make_tuple_inputs;
make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
for (size_t output_idx = 0; output_idx < AnfAlgo::GetOutputTensorNum(node); ++output_idx) {
bool padding_flag = false;
std::string output_format;
GetTransDataInputFormat(node, output_idx, &output_format);
std::string output_format = AnfAlgo::GetOutputFormat(node, output_idx);
if (output_format == kOpFormat_NC1KHKWHWC0) {
MS_LOG(EXCEPTION) << "got the hw format" << output_format << " when insert the transdata node "
MS_LOG(EXCEPTION) << "Got the special format" << output_format << " when insert the transdata node "
<< node->DebugString();
}
auto tuple_getitem = CreatTupleGetItemNode(func_graph, node, output_idx);
std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx);
std::string origin_format = output_format;
std::string dest_format = kOpFormat_DEFAULT;
if (output_format == kOpFormat_C1HWNCoC0) {
padding_flag = (origin_shape.size() != kShape4dDims);
AnfNodePtr replace_input = AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, padding_flag,
origin_format, dest_format, kTransDataOpName, false);
MS_EXCEPTION_IF_NULL(replace_input);
return replace_input;
}
if (output_format == kOpFormat_NC1HWC0 && origin_shape.size() > 1) {
padding_flag = (origin_shape.size() != kShape4dDims);
// Insert a 5to4 trans op.
AnfNodePtr replace_output = AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, padding_flag,
origin_format, dest_format, kTransDataOpName, false);
MS_EXCEPTION_IF_NULL(replace_output);
MS_LOG(DEBUG) << "Inserted Translate54";
make_tuple_inputs.push_back(replace_output);
} else if (output_format == kOpFormat_FRAC_NZ) {
AnfNodePtr replace_output = AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, padding_flag,
origin_format, dest_format, kTransDataOpName, false);
MS_EXCEPTION_IF_NULL(replace_output);
MS_LOG(DEBUG) << "Inserted Translate " << output_format << " To default, index: " << output_idx;
make_tuple_inputs.push_back(replace_output);
} else if (output_format == kOpFormat_FRAC_Z && !origin_shape.empty()) {
padding_flag = (origin_shape.size() != kShape4dDims);
AnfNodePtr replace_output = AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, padding_flag,
origin_format, dest_format, kTransDataOpName, false);
MS_EXCEPTION_IF_NULL(replace_output);
MS_LOG(DEBUG) << "Inserted Translate54";
make_tuple_inputs.push_back(replace_output);
if (kNeedTransFormatSet.find(output_format) != kNeedTransFormatSet.end() && origin_shape.size() > 1) {
make_tuple_inputs.emplace_back(AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, output_format,
dest_format, kTransDataOpName, false));
} else {
// No need insert trans op.
make_tuple_inputs.push_back(tuple_getitem);
......@@ -253,16 +174,17 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const
}
} // namespace
AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const KernelSelectPtr &kernel_select, size_t insert_index, const bool padding_flag,
const KernelSelectPtr &kernel_select, size_t insert_index,
const std::string &origin_format, const std::string &dest_format,
const std::string &op_name, bool is_insert_input) {
AnfNodePtr trans_node = nullptr;
AnfNodePtr input_node = nullptr;
AnfNodePtr input_node = node;
AnfNodePtr trans_data = nullptr;
MS_EXCEPTION_IF_NULL(node);
if (origin_format.empty() || dest_format.empty()) {
MS_LOG(EXCEPTION) << "trans op format is error, origin = " << origin_format << ", dest " << origin_format;
}
// if insert transdata for input we need to change the input
if (is_insert_input) {
if (!node->isa<CNode>()) {
MS_LOG(EXCEPTION) << "cannot insert a transdata node to a node's input which the node is not a cnode";
......@@ -270,29 +192,34 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
input_node = AnfAlgo::GetInputNode(cnode, insert_index);
if (padding_flag) {
auto padd_shape = trans::TransShapeTo4d(AnfAlgo::GetOutputInferShape(input_node, 0));
auto reshape_node = CreateReshapeNode(func_graph, input_node, kernel_select, padd_shape);
trans_data = NewTransOpNode(func_graph, reshape_node, kernel_select, padding_flag, op_name);
} else {
trans_data = NewTransOpNode(func_graph, input_node, kernel_select, padding_flag, op_name);
}
}
bool need_padding = (trans::IsNeedPadding(dest_format, AnfAlgo::GetOutputInferShape(input_node, 0).size()) &&
op_name == kTransDataOpName);
if (!need_padding) {
// don't need padding insert transdata only
trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, op_name);
trans_node = trans_data;
} else if (is_insert_input) {
// if need padding & is input need insert a transdata
// reshape[padding shape] -> transdata[padding shape] -> node
auto padding_shape =
trans::PaddingShapeTo4d(AnfAlgo::GetOutputInferShape(input_node, 0), AnfAlgo::GetInputReshapeType(node, 0));
auto reshape_node = CreateReshapeNode(func_graph, input_node, kernel_select, padding_shape);
trans_data = NewTransOpNode(func_graph, reshape_node, kernel_select, need_padding, op_name);
trans_node = trans_data;
} else {
input_node = node;
trans_data = NewTransOpNode(func_graph, input_node, kernel_select, padding_flag, op_name);
if (padding_flag) {
auto reshape_node =
CreateReshapeNode(func_graph, trans_data, kernel_select, AnfAlgo::GetOutputInferShape(input_node, 0));
trans_node = reshape_node;
} else {
trans_node = trans_data;
}
// if need padding & is output need insert a transdata
// node -> transdata[padding shape] -> reshape[ori_shape]
trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, op_name);
auto reshape_node =
CreateReshapeNode(func_graph, trans_data, kernel_select, AnfAlgo::GetOutputInferShape(input_node, 0));
trans_node = reshape_node;
}
// refresh the transdata's format to ori format & dst format
MS_EXCEPTION_IF_NULL(trans_data);
MS_EXCEPTION_IF_NULL(trans_data->kernel_info());
auto trans_ori_build_info = trans_data->kernel_info()->select_kernel_build_info();
auto kernel_build_info = CreateKernelBuildInfo(origin_format, dest_format, input_node, *trans_ori_build_info);
auto kernel_build_info = RefreshKernelBuildInfo(origin_format, dest_format, input_node, *trans_ori_build_info);
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, trans_data.get());
return trans_node;
}
......@@ -376,7 +303,17 @@ CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnod
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) {
TypeId origin_type;
auto cur_input = AnfAlgo::GetInputNode(cnode, input_index);
if (!AnfAlgo::IsFeatureMapInput(cnode, input_index)) {
auto kernel_with_index = AnfAlgo::VisitKernel(cur_input, 0);
auto is_weight_boundary = [](const AnfNodePtr &node) -> bool {
if (node->isa<ValueNode>()) {
return true;
} else if (node->isa<Parameter>() && AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>())) {
return true;
}
return false;
};
auto real_input_node = kernel_with_index.first;
if (is_weight_boundary(real_input_node)) {
// weight
origin_type = AnfAlgo::GetPrevNodeOutputDeviceDataType(cnode, input_index);
} else {
......
......@@ -48,7 +48,7 @@ class KernelQuery {
using KernelQueryPtr = std::shared_ptr<KernelQuery>;
AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const KernelSelectPtr &kernel_select, size_t insert_index, bool padding_flag,
const KernelSelectPtr &kernel_select, size_t insert_index,
const std::string &origin_format, const std::string &dest_format,
const std::string &op_name, bool is_insert_input);
......
......@@ -105,10 +105,8 @@ AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodeP
// insert trans
if (origin_format != cur_format) {
auto kernel_select = std::make_shared<KernelSelect>();
bool need_padding =
(cur_format == kOpFormat_NC1HWC0 && AnfAlgo::GetOutputInferShape(final_node, 0).size() != kShape4dDims);
final_node = AddTransOpNodeToGraph(func_graph, final_node, kernel_select, 0, need_padding, cur_format,
origin_format, kTransDataOpName, false);
final_node = AddTransOpNodeToGraph(func_graph, final_node, kernel_select, 0, cur_format, origin_format,
kTransDataOpName, false);
final_index = 0;
MS_EXCEPTION_IF_NULL(final_node);
MS_LOG(INFO) << "DealRefTransAndCast add trans op, op debug info is " << final_node->DebugString();
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pre_activate/ascend/ir_fusion/transdata_split.h"
#include <set>
#include "pre_activate/ascend/ascend_helper.h"
#include "session/anf_runtime_algorithm.h"
#include "debug/anf_ir_dump.h"
namespace mindspore {
namespace opt {
const std::set<std::pair<string, string>> invalid_formats_pair = {{kOpFormat_C1HWNCoC0, kOpFormat_NCHW},
{kOpFormat_NCHW, kOpFormat_C1HWNCoC0},
{kOpFormat_C1HWNCoC0, kOpFormat_DEFAULT},
{kOpFormat_DEFAULT, kOpFormat_C1HWNCoC0}};
bool TransDataSplit::Run(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
bool changed = false;
std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return());
for (auto &node : node_list) {
if (node != nullptr && node->isa<CNode>() && AnfAlgo::GetCNodeName(node) == kTransDataOpName) {
CheckCNodeInputSize(node->cast<CNodePtr>(), kBackendTransDataInputNum);
if (IsFormatInvaild(node)) {
changed = DoSplit(func_graph, node);
}
}
}
return changed;
}
bool TransDataSplit::IsFormatInvaild(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto input_format = AnfAlgo::GetInputFormat(node, 0);
auto output_format = AnfAlgo::GetOutputFormat(node, 0);
auto format_pair = std::make_pair(input_format, output_format);
return invalid_formats_pair.find(format_pair) != invalid_formats_pair.end();
}
// transdata cannot support frac_z to nchw need split transdata(frac_z-HWCN) and transpose(HWCN-NCHW)
bool TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto input_node = node->cast<CNodePtr>()->input(1);
MS_EXCEPTION_IF_NULL(input_node);
auto input_format = AnfAlgo::GetInputFormat(node, 0);
auto output_format = AnfAlgo::GetOutputFormat(node, 0);
AnfNodePtr new_transdata_node = nullptr;
AnfNodePtr new_transpose_node = nullptr;
AnfNodePtr new_replace_node = nullptr;
// if output_format=default transdata need split transdata->transpose else transpose->transdata
if (output_format == kOpFormat_DEFAULT || output_format == kOpFormat_NCHW) {
// trans input_format to hwcn
new_transdata_node = AddTransOpNodeToGraph(func_graph, node, kernel_select_, 0, false, input_format, kOpFormat_HWCN,
kTransDataOpName, true);
// trans hwcn to default_format
new_transpose_node = AddTransOpNodeToGraph(func_graph, new_transdata_node, kernel_select_, 0, false, kOpFormat_HWCN,
output_format, prim::kPrimTranspose->name(), false);
AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector<int>{3, 2, 0, 1}), new_transpose_node);
new_replace_node = new_transpose_node;
} else {
// trans default to hwcn
new_transpose_node = AddTransOpNodeToGraph(func_graph, node, kernel_select_, 0, false, input_format, kOpFormat_HWCN,
prim::kPrimTranspose->name(), true);
AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector<int>{2, 3, 1, 0}), new_transpose_node);
// trans hwcn to output_format
new_transdata_node = AddTransOpNodeToGraph(func_graph, new_transpose_node, kernel_select_, 0, false, kOpFormat_HWCN,
output_format, kTransDataOpName, false);
new_replace_node = new_transdata_node;
}
FuncGraphManagerPtr manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
manager->AddFuncGraph(func_graph);
if (!manager->Replace(node, new_replace_node)) {
MS_LOG(EXCEPTION) << "manager replace node failed";
}
MS_LOG(INFO) << "transdata node:" << cnode->DebugString() << "split success.";
return true;
}
} // namespace opt
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pre_activate/ascend/ir_fusion/transdata_split.h"
#include <set>
#include "pre_activate/ascend/ascend_helper.h"
#include "session/anf_runtime_algorithm.h"
#include "debug/anf_ir_dump.h"
namespace mindspore {
namespace opt {
const std::set<std::pair<string, string>> invalid_formats_pair = {{kOpFormat_C1HWNCoC0, kOpFormat_NCHW},
{kOpFormat_NCHW, kOpFormat_C1HWNCoC0},
{kOpFormat_C1HWNCoC0, kOpFormat_DEFAULT},
{kOpFormat_DEFAULT, kOpFormat_C1HWNCoC0}};
bool TransDataSplit::Run(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
bool changed = false;
std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return());
for (auto &node : node_list) {
if (node != nullptr && node->isa<CNode>() && AnfAlgo::GetCNodeName(node) == kTransDataOpName) {
CheckCNodeInputSize(node->cast<CNodePtr>(), kBackendTransDataInputNum);
if (IsFormatInvaild(node)) {
changed = DoSplit(func_graph, node);
}
}
}
return changed;
}
bool TransDataSplit::IsFormatInvaild(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto input_format = AnfAlgo::GetInputFormat(node, 0);
auto output_format = AnfAlgo::GetOutputFormat(node, 0);
auto format_pair = std::make_pair(input_format, output_format);
return invalid_formats_pair.find(format_pair) != invalid_formats_pair.end();
}
// transdata cannot support frac_z to nchw need split transdata(frac_z-HWCN) and transpose(HWCN-NCHW)
bool TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto input_node = node->cast<CNodePtr>()->input(1);
MS_EXCEPTION_IF_NULL(input_node);
auto input_format = AnfAlgo::GetInputFormat(node, 0);
auto output_format = AnfAlgo::GetOutputFormat(node, 0);
AnfNodePtr new_transdata_node = nullptr;
AnfNodePtr new_transpose_node = nullptr;
AnfNodePtr new_replace_node = nullptr;
// if output_format=default transdata need split transdata->transpose else transpose->transdata
if (output_format == kOpFormat_DEFAULT || output_format == kOpFormat_NCHW) {
// trans input_format to hwcn
new_transdata_node =
AddTransOpNodeToGraph(func_graph, node, kernel_select_, 0, input_format, kOpFormat_HWCN, kTransDataOpName, true);
// trans hwcn to default_format
new_transpose_node = AddTransOpNodeToGraph(func_graph, new_transdata_node, kernel_select_, 0, kOpFormat_HWCN,
output_format, prim::kPrimTranspose->name(), false);
AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector<int>{3, 2, 0, 1}), new_transpose_node);
new_replace_node = new_transpose_node;
} else {
// trans default to hwcn
new_transpose_node = AddTransOpNodeToGraph(func_graph, node, kernel_select_, 0, input_format, kOpFormat_HWCN,
prim::kPrimTranspose->name(), true);
AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector<int>{2, 3, 1, 0}), new_transpose_node);
// trans hwcn to output_format
new_transdata_node = AddTransOpNodeToGraph(func_graph, new_transpose_node, kernel_select_, 0, kOpFormat_HWCN,
output_format, kTransDataOpName, false);
new_replace_node = new_transdata_node;
}
FuncGraphManagerPtr manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
manager->AddFuncGraph(func_graph);
if (!manager->Replace(node, new_replace_node)) {
MS_LOG(EXCEPTION) << "Manager replace node failed";
}
MS_LOG(INFO) << "Transdata node:" << cnode->DebugString() << "split success.";
return true;
}
} // namespace opt
} // namespace mindspore
......@@ -289,6 +289,11 @@ size_t AnfRuntimeAlgorithm::GetOutputTensorNum(const AnfNodePtr &node) {
std::string AnfRuntimeAlgorithm::GetOutputFormat(const AnfNodePtr &node, size_t output_idx) {
MS_EXCEPTION_IF_NULL(node);
if (output_idx > GetOutputTensorNum(node)) {
MS_LOG(EXCEPTION) << "Output index:" << output_idx
<< " is out of the node output range :" << GetOutputTensorNum(node) << " #node ["
<< node->DebugString() << "]";
}
auto kernel_info = node->kernel_info();
MS_EXCEPTION_IF_NULL(kernel_info);
auto build_info = kernel_info->select_kernel_build_info();
......@@ -298,6 +303,11 @@ std::string AnfRuntimeAlgorithm::GetOutputFormat(const AnfNodePtr &node, size_t
std::string AnfRuntimeAlgorithm::GetInputFormat(const AnfNodePtr &node, size_t input_idx) {
MS_EXCEPTION_IF_NULL(node);
if (input_idx > GetInputTensorNum(node)) {
MS_LOG(EXCEPTION) << "Input index :" << input_idx
<< " is out of the number node Input range :" << GetInputTensorNum(node) << "#node ["
<< node->DebugString() << "]";
}
auto kernel_info = node->kernel_info();
MS_EXCEPTION_IF_NULL(kernel_info);
auto build_info = kernel_info->select_kernel_build_info();
......@@ -362,62 +372,60 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetPrevNodeOutputInferShape(const AnfNo
std::vector<size_t> AnfRuntimeAlgorithm::GetOutputDeviceShape(const AnfNodePtr &node, size_t output_idx) {
auto format = GetOutputFormat(node, output_idx);
auto infer_shape = GetOutputInferShape(node, output_idx);
// if format is default_format or NC1KHKWHWC0,device shape = original shape
if (format == kOpFormat_DEFAULT || format == kOpFormat_NC1KHKWHWC0) {
return infer_shape;
}
// scalar shape
if (infer_shape.empty()) {
return infer_shape;
}
if (format == kOpFormat_FRAC_NZ) {
return trans::TransShapeToDevice(infer_shape, format);
// if format is default_format or NC1KHKWHWC0,device shape = original shape
if (trans::IsNeedPadding(format, infer_shape.size())) {
infer_shape = trans::PaddingShapeTo4d(infer_shape, GetOutputReshapeType(node, output_idx));
}
// else trans infer shape to 4d and then calculate device shape
return trans::TransShapeToDevice(trans::TransShapeTo4d(infer_shape), format);
return trans::TransShapeToDevice(infer_shape, format);
}
std::vector<size_t> AnfRuntimeAlgorithm::GetInputDeviceShape(const AnfNodePtr &node, size_t input_idx) {
auto format = GetInputFormat(node, input_idx);
auto infer_shape = GetPrevNodeOutputInferShape(node, input_idx);
// if format is default_format or NC1KHKWHWC0,device shape = original shape
if (format == kOpFormat_DEFAULT || format == kOpFormat_NC1KHKWHWC0) {
return infer_shape;
}
if (infer_shape.empty()) {
return infer_shape;
}
if (format == kOpFormat_FRAC_NZ) {
return trans::TransShapeToDevice(infer_shape, format);
// if format is default_format or NC1KHKWHWC0,device shape = original shape
if (trans::IsNeedPadding(format, infer_shape.size())) {
infer_shape = trans::PaddingShapeTo4d(infer_shape, GetInputReshapeType(node, input_idx));
}
// else trans infer shape to 4d and then calculate device shape
return trans::TransShapeToDevice(trans::TransShapeTo4d(infer_shape), format);
return trans::TransShapeToDevice(infer_shape, format);
}
std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNodePtr &node, size_t input_idx) {
MS_EXCEPTION_IF_NULL(node);
if (input_idx > GetInputTensorNum(node)) {
MS_LOG(EXCEPTION) << "The index:" << input_idx
<< " is out of range of the node's input size : " << GetInputTensorNum(node) << "#node["
<< node->DebugString() << "]";
}
auto kernel_info = node->kernel_info();
MS_EXCEPTION_IF_NULL(kernel_info);
auto build_info = kernel_info->select_kernel_build_info();
MS_EXCEPTION_IF_NULL(build_info);
std::vector<kernel::Axis> result;
if (!build_info->GetInputReshapeType(input_idx, &result)) {
MS_LOG(EXCEPTION) << "Failed to get the node's[ " << node->DebugString() << "] reshape type !";
if (build_info->IsInputDefaultPadding()) {
return {};
}
return result;
return build_info->GetInputReshapeType(input_idx);
}
std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx) {
MS_EXCEPTION_IF_NULL(node);
if (output_idx > GetOutputTensorNum(node)) {
MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ "
<< GetOutputTensorNum(node) << "#node[ " << node->DebugString() << "]";
}
auto kernel_info = node->kernel_info();
MS_EXCEPTION_IF_NULL(kernel_info);
auto build_info = kernel_info->select_kernel_build_info();
MS_EXCEPTION_IF_NULL(build_info);
std::vector<kernel::Axis> result;
if (!build_info->GetOutputReshapeType(output_idx, &result)) {
MS_LOG(EXCEPTION) << "Failed to get the node's[ " << node->DebugString() << "] reshape type !";
if (build_info->IsOutputDefaultPadding()) {
return {};
}
return result;
return build_info->GetOutputReshapeType(output_idx);
}
TypeId AnfRuntimeAlgorithm::GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx) {
......@@ -463,6 +471,10 @@ TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputInferDataType(const AnfNodePtr &nod
TypeId AnfRuntimeAlgorithm::GetOutputDeviceDataType(const AnfNodePtr &node, size_t output_idx) {
MS_EXCEPTION_IF_NULL(node);
if (output_idx > GetOutputTensorNum(node)) {
MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ "
<< GetOutputTensorNum(node) << "#node [ " << node->DebugString() << "]";
}
auto kernel_info = node->kernel_info();
MS_EXCEPTION_IF_NULL(kernel_info);
auto build_info = kernel_info->select_kernel_build_info();
......@@ -472,6 +484,10 @@ TypeId AnfRuntimeAlgorithm::GetOutputDeviceDataType(const AnfNodePtr &node, size
TypeId AnfRuntimeAlgorithm::GetInputDeviceDataType(const AnfNodePtr &node, size_t input_idx) {
MS_EXCEPTION_IF_NULL(node);
if (input_idx > GetInputTensorNum(node)) {
MS_LOG(EXCEPTION) << "The index [" << input_idx << "] is out of range of the node's input size [ "
<< GetInputTensorNum(node) << "#node [ " << node->DebugString() << "]";
}
auto kernel_info = node->kernel_info();
MS_EXCEPTION_IF_NULL(kernel_info);
auto build_info = kernel_info->select_kernel_build_info();
......@@ -496,11 +512,15 @@ const DeviceAddress *AnfRuntimeAlgorithm::GetOutputAddr(const AnfNodePtr &node,
MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node";
}
}
if (output_idx > GetOutputTensorNum(node)) {
MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ "
<< GetOutputTensorNum(node) << "#node:[ " << node->DebugString() << "]";
}
auto kernel_info = node->kernel_info();
MS_EXCEPTION_IF_NULL(kernel_info);
auto addr = kernel_info->GetOutputAddr(output_idx);
if (addr == nullptr) {
MS_LOG(EXCEPTION) << "output_idx " << output_idx << " of node " << node->DebugString()
MS_LOG(EXCEPTION) << "Output_idx " << output_idx << " of node " << node->DebugString()
<< " output addr is not exist";
}
return addr;
......@@ -517,11 +537,15 @@ DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &nod
MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node.";
}
}
if (output_idx > GetOutputTensorNum(node)) {
MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ "
<< GetOutputTensorNum(node) << "#node:[ " << node->DebugString() << "]";
}
auto kernel_info = node->kernel_info();
MS_EXCEPTION_IF_NULL(kernel_info);
auto addr = kernel_info->GetMutableOutputAddr(output_idx);
if (addr == nullptr) {
MS_LOG(EXCEPTION) << "output_idx" << output_idx << " of node " << node->DebugString()
MS_LOG(EXCEPTION) << "Output_idx" << output_idx << " of node " << node->DebugString()
<< " output addr is not exist";
}
return addr;
......@@ -530,6 +554,10 @@ DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &nod
// get output device addr of anf_node
bool AnfRuntimeAlgorithm::OutputAddrExist(const AnfNodePtr &node, size_t output_idx) {
MS_EXCEPTION_IF_NULL(node);
if (output_idx > GetOutputTensorNum(node)) {
MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ "
<< GetOutputTensorNum(node) << "#node:[ " << node->DebugString() << "]";
}
auto kernel_info = node->kernel_info();
MS_EXCEPTION_IF_NULL(kernel_info);
return kernel_info->OutputAddrExist(output_idx);
......@@ -769,22 +797,24 @@ AnfNodePtr AnfRuntimeAlgorithm::GetInputNode(const CNodePtr &node, size_t index)
return node->input(get_input_index);
}
bool AnfRuntimeAlgorithm::IsFeatureMapOutput(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (node->isa<ValueNode>()) {
return false;
}
auto kernel_info = node->kernel_info();
MS_EXCEPTION_IF_NULL(kernel_info);
return kernel_info->is_feature_map();
}
bool AnfRuntimeAlgorithm::IsFeatureMapInput(const AnfNodePtr &node, size_t input_index) {
if (!node->isa<CNode>()) {
MS_LOG(EXCEPTION) << "Cannot input a parameter or a valuenode to charge it's input if is a feature";
MS_LOG(EXCEPTION) << "Cannot input a parameter or a valuenode to charge it's input if is a feature map";
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto input_node = cnode->input(input_index + 1);
auto node_with_index = VisitKernel(input_node, 0);
MS_EXCEPTION_IF_NULL(node_with_index.first);
if (node_with_index.first->isa<ValueNode>()) {
return false;
}
if (node_with_index.first->isa<Parameter>()) {
return !AnfAlgo::IsParameterWeight(node_with_index.first->cast<ParameterPtr>());
}
return true;
return IsFeatureMapOutput(input_node);
}
size_t AnfRuntimeAlgorithm::GetRealInputIndex(const mindspore::AnfNodePtr &anf_node, const size_t cur_index) {
......
......@@ -101,7 +101,9 @@ class AnfRuntimeAlgorithm {
static std::vector<size_t> GetOutputDeviceShape(const AnfNodePtr &node, size_t output_idx);
// get input shapes which will built and run in device
static std::vector<size_t> GetInputDeviceShape(const AnfNodePtr &node, size_t input_idx);
// Get Input Padding Axis
static std::vector<kernel::Axis> GetInputReshapeType(const AnfNodePtr &node, size_t output_idx);
// Get Output Padding Axis
static std::vector<kernel::Axis> GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx);
// get output data type inferred by ME of anf node
static TypeId GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx);
......@@ -165,6 +167,9 @@ class AnfRuntimeAlgorithm {
// get graph id
static uint32_t GetGraphId(const AnfNode *node);
static AnfNodePtr GetInputNode(const CNodePtr &node, size_t index);
// charge if the node's output is a feature map output
static bool IsFeatureMapOutput(const AnfNodePtr &node);
// charge if the node's input is from a feature map output
static bool IsFeatureMapInput(const AnfNodePtr &node, size_t input_index);
// get real input index for some tbe ops which input order is different between me and tbe impl
static size_t GetRealInputIndex(const AnfNodePtr &anf_node, const size_t cur_index);
......
......@@ -18,6 +18,7 @@
#include "operator/ops.h"
#include "ir/meta_tensor.h"
#include "ir/anf.h"
#include "common/trans.h"
#include "device/kernel_runtime.h"
#include "device/ascend/kernel_select_ascend.h"
#include "device/ascend/kernel_build_ascend.h"
......@@ -730,8 +731,8 @@ void AscendSession::SetChildGraphParameter(const tensor::TensorPtr &front_tensor
size_t tensor_size = front_tensor->data().nbytes();
auto addr = AnfAlgo::GetOutputAddr(backend_parameter, 0);
MS_EXCEPTION_IF_NULL(addr);
if (!addr->SyncHostToDevice(front_tensor->shape(), tensor_size, front_tensor->data_type(),
front_tensor->data_c(false))) {
if (!addr->SyncHostToDevice(trans::GetRuntimePaddingShape(backend_parameter, 0), tensor_size,
front_tensor->data_type(), front_tensor->data_c(false))) {
MS_LOG(EXCEPTION) << "Tensor SyncHostToDevice fail!";
}
MS_LOG(INFO) << "Finish!";
......
......@@ -143,6 +143,12 @@ CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
cnode->set_abstract(std::make_shared<abstract::AbstractNone>());
// create kernel_info from new parameter
auto kernel_info = std::make_shared<device::KernelInfo>();
// if the node only has the primitive(such as getNext) or the node's input has a feature map input
// then the node's output is a feature map output
if (inputs.size() == 1 || std::any_of(inputs.begin() + 1, inputs.end(),
[&](const AnfNodePtr &node) { return AnfAlgo::IsFeatureMapOutput(node); })) {
kernel_info->SetFeatureMapFlag(true);
}
cnode->set_kernel_info(kernel_info);
AnfAlgo::SetGraphId(graph_id_, cnode.get());
return cnode;
......@@ -162,22 +168,26 @@ CNodePtr KernelGraph::NewCNode(const CNodePtr &cnode) {
ParameterPtr KernelGraph::NewParameter(const ParameterPtr &parameter) {
ParameterPtr new_parameter = add_parameter();
MS_EXCEPTION_IF_NULL(new_parameter);
// create kernel_info form new parameter
auto kernel_info = std::make_shared<device::KernelInfo>();
size_t output_tensor_num = 1;
// if use default parameter = nullptr,it remarks create a new parameter from no parameter
if (parameter == nullptr) {
new_parameter->set_abstract(std::make_shared<abstract::AbstractNone>());
kernel_info->SetFeatureMapFlag(true);
} else {
// if don't use default parameter = nullptr,it remarks create a new parameter from a old parameter
new_parameter->set_abstract(parameter->abstract());
new_parameter->set_name(parameter->name());
if (parameter->has_default()) {
if (AnfAlgo::IsParameterWeight(parameter)) {
new_parameter->set_default_param(parameter->default_param());
kernel_info->SetFeatureMapFlag(false);
} else {
kernel_info->SetFeatureMapFlag(true);
}
// if output is a tuple tensor,now can use for loop to handle tuple tensor
output_tensor_num = AnfAlgo::GetOutputTensorNum(parameter);
}
// create kernel_info form new parameter
auto kernel_info = std::make_shared<device::KernelInfo>();
new_parameter->set_kernel_info(kernel_info);
// create kernel_build_info for new parameter
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
......@@ -217,6 +227,7 @@ std::vector<AnfNodePtr> KernelGraph::SplitTupleValueNodeToNodeList(const ValueNo
AddValueNodeToGraph(new_value_node);
auto kernel_info = std::make_shared<device::KernelInfo>();
new_value_node->set_kernel_info(kernel_info);
kernel_info->SetFeatureMapFlag(false);
// create kernel_build_info for new value node
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
// set the format of value_node to DEFAULT_FORMAT
......@@ -240,6 +251,7 @@ ValueNodePtr KernelGraph::NewValueNode(const ValueNodePtr &value_node) {
new_value_node->set_abstract(value_node->abstract());
// create kernel_info fo new value node
auto kernel_info = std::make_shared<device::KernelInfo>();
kernel_info->SetFeatureMapFlag(false);
new_value_node->set_kernel_info(kernel_info);
// create kernel_build_info for new value node
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
......
......@@ -20,6 +20,7 @@
#include "pipeline/parse/data_converter.h"
#include "ir/manager.h"
#include "operator/ops.h"
#include "common/trans.h"
#include "utils/context/ms_context.h"
#include "utils/config_manager.h"
#include "session/anf_runtime_algorithm.h"
......@@ -124,7 +125,8 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->enable_pynative_infer()) {
tensor->set_device_address(AnfAlgo::GetMutableOutputAddr(node, output_index));
} else if (!address->SyncDeviceToHost(tensor->shape(), LongToSize(tensor->data().nbytes()), tensor->data_type(),
} else if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, output_index),
LongToSize(tensor->data().nbytes()), tensor->data_type(),
tensor->data_c(true))) {
MS_LOG(INFO) << "output sync device to host error!!!";
tensor->set_dirty(false);
......@@ -369,7 +371,7 @@ ParameterPtr ConstructRunOpParameter(const std::shared_ptr<KernelGraph> &graph,
kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{input_tensor->device_address()->type_id()});
}
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), param.get());
// construct abstract of parameter
// ftruct abstract of parameter
auto abstract = std::make_shared<abstract::AbstractTensor>(input_tensor);
param->set_abstract(abstract);
return param;
......@@ -548,7 +550,8 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap
if (need_sync) {
tensor->set_device_address(device_address);
MS_EXCEPTION_IF_NULL(device_address);
if (!device_address->SyncHostToDevice(tensor->shape(), LongToSize(tensor->data().nbytes()), tensor->data_type(),
if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0),
LongToSize(tensor->data().nbytes()), tensor->data_type(),
tensor->data_c(false))) {
MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
}
......@@ -620,8 +623,8 @@ void SessionBasic::Summary(KernelGraph *graph) {
(void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
MS_EXCEPTION_IF_NULL(address);
if (!address->SyncDeviceToHost(tensor->shape(), LongToSize(tensor->data().nbytes()), tensor->data_type(),
tensor->data_c(true))) {
if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, index), LongToSize(tensor->data().nbytes()),
tensor->data_type(), tensor->data_c(true))) {
MS_LOG(ERROR) << "Failed to sync output from device to host.";
}
tensor->set_dirty(false);
......
......@@ -197,8 +197,8 @@ const std::set<std::string> kOptOperatorSet = {
kApplyRMSPropOpName,
};
const std::set<std::string> kSpecialFormatSet = {kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0,
kOpFormat_FRAC_NZ, kOpFormat_C1HWNCoC0};
const std::set<std::string> kNeedTransFormatSet = {kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0,
kOpFormat_FRAC_NZ, kOpFormat_C1HWNCoC0};
static inline void ChangeFileMode(const std::string& file_name, mode_t mode) {
if (access(file_name.c_str(), F_OK) != 0) {
......
......@@ -80,6 +80,8 @@ TEST_F(TestHWLayerNormBetaGammaBackpropFusion, layernorm_beta_gamma_backprop_fus
builder1.SetOutputsDeviceType({kNumberTypeFloat32});
cast0->set_kernel_info(std::make_shared<device::KernelInfo>());
cast1->set_kernel_info(std::make_shared<device::KernelInfo>());
cast0->set_abstract(x_abstract);
cast1->set_abstract(x_abstract);
AnfAlgo::SetSelectKernelBuildInfo(builder1.Build(), cast0.get());
AnfAlgo::SetSelectKernelBuildInfo(builder1.Build(), cast1.get());
......
......@@ -211,8 +211,8 @@ TEST_F(AnfRuntimeAlgorithmTest, EraseNodeAttr) {
TEST_F(AnfRuntimeAlgorithmTest, GetInputTensorNum) {
auto kernel_graph = std::make_shared<KernelGraph>();
// test cnode node
auto parameter_one = kernel_graph->add_parameter();
auto parameter_two = kernel_graph->add_parameter();
auto parameter_one = kernel_graph->NewParameter();
auto parameter_two = kernel_graph->NewParameter();
std::vector<AnfNodePtr> add_inputs{NewValueNode(prim::kPrimTensorAdd), parameter_one, parameter_two};
auto add = kernel_graph->NewCNode(add_inputs);
EXPECT_EQ(AnfAlgo::GetInputTensorNum(add), 2);
......@@ -247,9 +247,11 @@ TEST_F(AnfRuntimeAlgorithmTest, GetOutputTensorNum) {
TEST_F(AnfRuntimeAlgorithmTest, GetOutputFormat) {
auto kernel_graph = std::make_shared<KernelGraph>();
std::vector<AnfNodePtr> inputs;
inputs.push_back(NewValueNode(prim::kPrimTensorAdd));
std::vector<AnfNodePtr> inputs = {NewValueNode(prim::kPrimTensorAdd), kernel_graph->NewParameter(),
kernel_graph->NewParameter()};
auto add = kernel_graph->NewCNode(inputs);
std::vector<size_t> shape = {1, 2, 3, 4};
AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32, kNumberTypeFloat32}, {shape, shape}, add.get());
MS_EXCEPTION_IF_NULL(add);
add->set_kernel_info(std::make_shared<KernelInfo>());
auto d_kernel_info = add->kernel_info();
......@@ -266,8 +268,8 @@ TEST_F(AnfRuntimeAlgorithmTest, GetOutputFormat) {
TEST_F(AnfRuntimeAlgorithmTest, GetInputFormat) {
auto kernel_graph = std::make_shared<KernelGraph>();
std::vector<AnfNodePtr> inputs;
inputs.push_back(NewValueNode(prim::kPrimTensorAdd));
std::vector<AnfNodePtr> inputs = {NewValueNode(prim::kPrimTensorAdd), kernel_graph->NewParameter(),
kernel_graph->NewParameter()};
auto add = kernel_graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(add);
add->set_kernel_info(std::make_shared<KernelInfo>());
......@@ -345,7 +347,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetPrevNodeOutputInferShape) {
std::vector<int> shp{2, 32, 224, 224};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
// test parameter node as input
auto parameter_node = kernel_graph->add_parameter();
auto parameter_node = kernel_graph->NewParameter();
MS_EXCEPTION_IF_NULL(parameter_node);
parameter_node->set_abstract(x_abstract);
EXPECT_THROW(AnfAlgo::GetPrevNodeOutputInferShape(parameter_node, 0), std::runtime_error);
......@@ -387,13 +389,13 @@ TEST_F(AnfRuntimeAlgorithmTest, GetInputDeviceShape) {
auto kernel_graph = std::make_shared<KernelGraph>();
std::vector<int> shp{2, 32, 224, 224};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
auto parameter_one = kernel_graph->add_parameter();
auto parameter_one = kernel_graph->NewParameter();
MS_EXCEPTION_IF_NULL(parameter_one);
parameter_one->set_abstract(x_abstract);
auto parameter_two = kernel_graph->add_parameter();
auto parameter_two = kernel_graph->NewParameter();
MS_EXCEPTION_IF_NULL(parameter_two);
parameter_two->set_abstract(x_abstract);
auto parameter_third = kernel_graph->add_parameter();
auto parameter_third = kernel_graph->NewParameter();
MS_EXCEPTION_IF_NULL(parameter_third);
parameter_third->set_abstract(x_abstract);
// test cnode as input
......@@ -466,8 +468,8 @@ TEST_F(AnfRuntimeAlgorithmTest, GetOutputDeviceDataTypeTest) {
TEST_F(AnfRuntimeAlgorithmTest, GetInputDeviceDataTypeTest) {
auto kernel_graph = std::make_shared<KernelGraph>();
std::vector<AnfNodePtr> inputs;
inputs.push_back(NewValueNode(prim::kPrimTensorAdd));
std::vector<AnfNodePtr> inputs = {NewValueNode(prim::kPrimTensorAdd), kernel_graph->NewParameter(),
kernel_graph->NewParameter()};
auto add = kernel_graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(add);
add->set_kernel_info(std::make_shared<KernelInfo>());
......
......@@ -140,11 +140,11 @@ TEST_F(KernelGraphTest, SetExecOrderByDefault) {
std::vector<int> shape = {2, 32, 224, 224};
auto abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shape);
auto x_parameter = kernel_graph->add_parameter();
auto x_parameter = kernel_graph->NewParameter();
MS_EXCEPTION_IF_NULL(x_parameter);
x_parameter->set_name("x_parameter");
x_parameter->set_abstract(abstract);
auto y_parameter = kernel_graph->add_parameter();
auto y_parameter = kernel_graph->NewParameter();
MS_EXCEPTION_IF_NULL(y_parameter);
y_parameter->set_name("y_parameter");
y_parameter->set_abstract(abstract);
......@@ -153,7 +153,7 @@ TEST_F(KernelGraphTest, SetExecOrderByDefault) {
MS_EXCEPTION_IF_NULL(add);
add->set_abstract(abstract);
auto z_parameter = kernel_graph->add_parameter();
auto z_parameter = kernel_graph->NewParameter();
MS_EXCEPTION_IF_NULL(z_parameter);
z_parameter->set_name("z_parameter");
z_parameter->set_abstract(abstract);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册