提交 6760d997 编写于 作者: W WilliamLian

add reshape type to tensor

上级 0ff1000b
......@@ -30,14 +30,6 @@ namespace mindspore {
enum KernelType : int { UNKNOWN_KERNEL_TYPE = 0, AKG_KERNEL, AICPU_KERNEL, RT_KERNEL, HCCL_KERNEL, TBE_KERNEL };
namespace kernel {
// Tensor layout axis identifiers used to describe 4-D (NCHW) padding/reshape
// behavior. Values are fixed and serialized; do not reorder.
enum Axis : int {
  N = 0,  // batch
  C = 1,  // channel
  H = 2,  // height
  W = 3,  // width
};
// Supported fusion type
enum FusionType {
CONVLUTION = 0,
......
......@@ -22,6 +22,7 @@
#include <string>
#include <utility>
#include "ir/dtype.h"
#include "ir/kernel_info_dev.h"
#include "backend/kernel_compiler/kernel.h"
namespace mindspore {
......
......@@ -406,16 +406,16 @@ void TbeKernelSelect::StringToAxisVector(const std::string &reshape_type_str, st
for (const auto &c : reshape_type_str) {
switch (c) {
case 'N':
reshape_type_vec->push_back(kernel::N);
reshape_type_vec->push_back(N);
break;
case 'C':
reshape_type_vec->push_back(kernel::C);
reshape_type_vec->push_back(C);
break;
case 'H':
reshape_type_vec->push_back(kernel::H);
reshape_type_vec->push_back(H);
break;
case 'W':
reshape_type_vec->push_back(kernel::W);
reshape_type_vec->push_back(W);
break;
default:
MS_LOG(EXCEPTION) << "Unknown axis " << c << "in reshape type.";
......
......@@ -55,7 +55,7 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
CNodePtr trans_data = nullptr;
std::string input_format = is_insert_input ? kOpFormat_DEFAULT : AnfAlgo::GetOutputFormat(node, 0);
std::string dst_format = is_insert_input ? AnfAlgo::GetInputFormat(node, 0) : kOpFormat_DEFAULT;
std::vector<kernel::Axis> padding_axis;
std::vector<Axis> padding_axis;
MS_EXCEPTION_IF_NULL(node);
// if insert transdata for input we need to change the input
if (is_insert_input) {
......@@ -170,7 +170,7 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const
}
} // namespace
void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format,
const AnfNodePtr &trans_data, const std::vector<kernel::Axis> &reshape_type,
const AnfNodePtr &trans_data, const std::vector<Axis> &reshape_type,
const TypeId &type_id) {
MS_EXCEPTION_IF_NULL(trans_data);
auto ori_build_info = AnfAlgo::GetSelectKernelBuildInfo(trans_data);
......
......@@ -86,7 +86,7 @@ class OpFinder {
using OpFinderPtr = std::shared_ptr<OpFinder>;
void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format,
const AnfNodePtr &trans_data, const std::vector<kernel::Axis> &reshape_type = {},
const AnfNodePtr &trans_data, const std::vector<Axis> &reshape_type = {},
const TypeId &type_id = kTypeUnknown);
CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select,
......
......@@ -418,7 +418,7 @@ std::string AnfRuntimeAlgorithm::GetPrevNodeOutputFormat(const AnfNodePtr &anf_n
return AnfRuntimeAlgorithm::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
}
std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetPrevNodeOutputReshapeType(const AnfNodePtr &node, size_t input_idx) {
std::vector<Axis> AnfRuntimeAlgorithm::GetPrevNodeOutputReshapeType(const AnfNodePtr &node, size_t input_idx) {
KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx);
return GetOutputReshapeType(kernel_with_index.first, kernel_with_index.second);
}
......@@ -483,7 +483,7 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetInputDeviceShape(const AnfNodePtr &n
return trans::TransShapeToDevice(infer_shape, format);
}
std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNodePtr &node, size_t input_idx) {
std::vector<Axis> AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNodePtr &node, size_t input_idx) {
MS_EXCEPTION_IF_NULL(node);
if (input_idx > GetInputTensorNum(node)) {
MS_LOG(EXCEPTION) << "The index:" << input_idx
......@@ -503,7 +503,7 @@ std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNode
return build_info->GetInputReshapeType(input_idx);
}
std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx) {
std::vector<Axis> AnfRuntimeAlgorithm::GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx) {
MS_EXCEPTION_IF_NULL(node);
if (output_idx > GetOutputTensorNum(node)) {
MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ "
......
......@@ -27,6 +27,7 @@
#include "ir/dtype.h"
#include "base/base.h"
#include "ir/primitive.h"
#include "ir/kernel_info_dev.h"
#include "runtime/device/device_address.h"
#include "backend/kernel_compiler/kernel.h"
#include "backend/kernel_compiler/kernel_build_info.h"
......@@ -109,7 +110,7 @@ class AnfRuntimeAlgorithm {
// get output format from prev node,input_index is the input index of current node related to prev node
static std::string GetPrevNodeOutputFormat(const AnfNodePtr &node, size_t input_idx);
// get reshape_type from the output of the input node.
static std::vector<kernel::Axis> GetPrevNodeOutputReshapeType(const AnfNodePtr &node, size_t input_idx);
static std::vector<Axis> GetPrevNodeOutputReshapeType(const AnfNodePtr &node, size_t input_idx);
// get output shapes inferred by ME from input nodes.
static std::vector<size_t> GetOutputInferShape(const AnfNodePtr &node, size_t output_idx);
// get input shapes inferred by ME from input nodes.
......@@ -119,9 +120,9 @@ class AnfRuntimeAlgorithm {
// get input shapes which will built and run in device
static std::vector<size_t> GetInputDeviceShape(const AnfNodePtr &node, size_t input_idx);
// Get Input Padding Axis
static std::vector<kernel::Axis> GetInputReshapeType(const AnfNodePtr &node, size_t output_idx);
static std::vector<Axis> GetInputReshapeType(const AnfNodePtr &node, size_t output_idx);
// Get Output Padding Axis
static std::vector<kernel::Axis> GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx);
static std::vector<Axis> GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx);
// get output data type inferred by ME of anf node
static TypeId GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx);
// get output original data type from prev node,input_index is the input index of current node related to prev node
......
......@@ -66,12 +66,13 @@ tensor::TensorPtr CreateOutputTensor(const AnfNodePtr &node, size_t output_index
if (type_id == kTypeUnknown) {
type_id = AnfAlgo::GetOutputInferDataType(node, output_index);
}
tensor::TensorPtr tensor;
tensor::TensorPtr tensor = nullptr;
std::vector<int> temp_shape;
if (graph->IsUniqueTargetInternalOutput(node, output_index)) {
temp_shape.emplace_back(1);
tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
tensor->set_device_address(address);
tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(node, output_index));
tensor->set_dirty(false);
return tensor;
}
......@@ -86,6 +87,7 @@ tensor::TensorPtr CreateOutputTensor(const AnfNodePtr &node, size_t output_index
graph->AddInternalOutputTensor(node, output_index, tensor);
}
}
tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(node, output_index));
// in pynative mode, data is only copied to host when the user wants to print it
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
......@@ -240,6 +242,7 @@ ParameterPtr ConstructRunOpParameter(const std::shared_ptr<KernelGraph> &graph,
} else {
kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{device_address->format()});
kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{device_address->type_id()});
kernel_build_info_builder->SetOutputsReshapeType({input_tensor->padding_type()});
}
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), param.get());
// construct abstract of parameter
......
......@@ -399,7 +399,7 @@ std::vector<int> GetRuntimePaddingShape(const AnfNodePtr &node, size_t index) {
return shape;
}
std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape, const std::vector<kernel::Axis> &padding_axis) {
std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape, const std::vector<Axis> &padding_axis) {
if (padding_axis.empty() || shape.size() != padding_axis.size()) {
return PaddingShapeTo4dByDefault(shape);
}
......
......@@ -51,8 +51,7 @@ size_t TypeIdSize(const TypeId data_type);
size_t ShapeSize(const std::vector<size_t> &shape);
size_t CubeSizeByType(const TypeId data_type);
std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape,
const std::vector<kernel::Axis> &padding_axis = {});
std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape, const std::vector<Axis> &padding_axis = {});
std::vector<int> GetRuntimePaddingShape(const AnfNodePtr &node, size_t index);
bool IsNeedPadding(const std::string &format, const size_t shape_size);
std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format);
......
......@@ -20,6 +20,12 @@
#include <memory>
namespace mindspore {
// Axis labels for a 4-D (NCHW) tensor layout, used to record per-dimension
// padding/reshape information. Enumerator values are fixed and must not change.
enum Axis : int {
  N = 0,  // batch
  C = 1,  // channel
  H = 2,  // height
  W = 3,  // width
};
// Interface for device kernel program information.
class KernelInfoDevice {
public:
......
......@@ -384,7 +384,8 @@ Tensor::Tensor(const Tensor &tensor)
data_(tensor.data_),
dirty_(tensor.dirty_),
id_(tensor.id_),
device_sync_(tensor.device_sync_) {}
device_sync_(tensor.device_sync_),
padding_type_(tensor.padding_type()) {}
Tensor::Tensor(const Tensor &tensor, TypeId data_type)
: MetaTensor(data_type, tensor.shape_),
......@@ -392,7 +393,8 @@ Tensor::Tensor(const Tensor &tensor, TypeId data_type)
data_(MakeTensorData(data_type, tensor.shape_, tensor.data_->data(), tensor.data_type_)),
dirty_(tensor.dirty_),
id_(tensor.id_),
device_sync_(tensor.device_sync_) {}
device_sync_(tensor.device_sync_),
padding_type_(tensor.padding_type()) {}
Tensor::Tensor(TypeId data_type, const std::vector<int> &shape, TensorDataPtr data)
: MetaTensor(data_type, shape), data_(std::move(data)), id_(MakeId()) {}
......@@ -441,6 +443,7 @@ Tensor &Tensor::AssignValue(const Tensor &tensor) {
device_sync_ = tensor.device_sync_;
data_ = tensor.data_;
id_ = tensor.id_;
padding_type_ = tensor.padding_type_;
}
return *this;
}
......
......@@ -221,6 +221,8 @@ class Tensor : public MetaTensor {
DeviceSyncPtr device_address() const { return device_sync_; }
void set_device_address(const DeviceSyncPtr &device_sync) { device_sync_ = device_sync; }
void set_padding_type(std::vector<Axis> padding_type) { padding_type_ = padding_type; }
std::vector<Axis> padding_type() const { return padding_type_; }
std::string id() const { return id_; }
......@@ -230,6 +232,7 @@ class Tensor : public MetaTensor {
bool dirty_{true};
std::string id_{""};
DeviceSyncPtr device_sync_{nullptr};
std::vector<Axis> padding_type_;
};
using TensorPtr = std::shared_ptr<Tensor>;
using TensorPtrList = std::vector<std::shared_ptr<Tensor>>;
......
......@@ -54,10 +54,10 @@ class A {
std::shared_ptr<int> i;
};
class C : public A {
class Ca : public A {
public:
C() {}
explicit C(signals *sigs) : A(sigs) { printf("conn C:%p\n", this); }
Ca() {}
explicit Ca(signals *sigs) : A(sigs) { printf("conn C:%p\n", this); }
void FuncA(int v1, float v2, std::string str) { printf("C: --%d--%f--%s--\n", v1, v2, str.c_str()); }
};
......@@ -71,13 +71,13 @@ class B : public A {
TEST_F(TestSignal, test_common) {
A objA;
B objB;
C objC;
Ca objC;
Signal<void(int, float, std::string)> signal;
signal.connect(&objA, &A::FuncA);
signal.connect(&objB, &B::FuncA);
signal.connect(&objC, &C::FuncA);
signal.connect(&objC, &Ca::FuncA);
signal(20, 20, "Signal-Slot test");
}
......@@ -85,11 +85,11 @@ TEST_F(TestSignal, test_sigs) {
signals sigs;
A objA(&sigs);
B objB(&sigs);
C objC(&sigs);
Ca objC(&sigs);
sigs.signal.connect(&objA, &A::FuncA);
sigs.signal.connect(&objB, &B::FuncA);
sigs.signal.connect(&objC, &C::FuncA);
sigs.signal.connect(&objC, &Ca::FuncA);
sigs.signal(20, 20, "sigs Signal-Slot test");
}
......@@ -97,7 +97,7 @@ TEST_F(TestSignal, test_sigs_Named) {
signals sigs;
A objA(&sigs);
B objB(&sigs);
C objC(&sigs);
Ca objC(&sigs);
sigs.signal(10, 20, "Signal-Slot test");
std::shared_ptr<Named> a;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册