Unverified commit 01c45bfa, authored by Juncheng, committed by GitHub

Make class Tensor abstract (#3757)

* Make class Tensor abstract

* revert
Parent commit: 7b128f36
......@@ -20,51 +20,11 @@ namespace oneflow {
namespace user_op {
Tensor::Tensor(Blob* blob) {
dptr_ = blob->ForceMutDptr();
shape_ = blob->shape();
blob_access_checker_ = blob->blob_access_checker();
if (blob->ForceMutShapeView()) {
mut_shape_.reset(new MutShapeView(*blob->ForceMutShapeView()));
} else {
mut_shape_.reset();
}
data_type_ = blob->data_type();
mem_case_ = &(blob->mem_case());
}
void Tensor::header_access_check() { this->blob_access_checker_->CheckHeaderMutable(); }
void Tensor::body_access_check() { this->blob_access_checker_->CheckBodyMutable(); }
void Tensor::CopyWithoutData(const Tensor& rhs) {
dptr_ = rhs.dptr_;
shape_ = rhs.shape_;
if (rhs.mut_shape_) {
mut_shape_.reset(new MutShapeView(*rhs.mut_shape_));
} else {
mut_shape_.reset();
}
data_type_ = rhs.data_type_;
mem_case_ = rhs.mem_case_;
blob_access_checker_ = rhs.blob_access_checker_;
}
Tensor& Tensor::operator=(Tensor&& rhs) {
dptr_ = rhs.dptr_;
shape_ = rhs.shape_;
mut_shape_ = std::move(rhs.mut_shape_);
data_type_ = rhs.data_type_;
mem_case_ = rhs.mem_case_;
blob_access_checker_ = rhs.blob_access_checker_;
return *this;
}
#ifdef WITH_CUDA
template<>
void Tensor::CheckDataType<half>() const {
LOG_IF(FATAL, data_type_ != DataType::kFloat16)
LOG_IF(FATAL, data_type() != DataType::kFloat16)
<< "tensor data_type mismatched. value: kFloat16, template T: half";
}
......
......@@ -19,63 +19,44 @@ limitations under the License.
#include "oneflow/core/common/data_type.h"
#include "oneflow/core/common/shape_view.h"
#include "oneflow/core/memory/memory_case.pb.h"
namespace oneflow {
class Blob;
class BlobAccessChecker;
namespace user_op {
class Tensor final {
class Tensor {
public:
Tensor(Blob*);
~Tensor() = default;
Tensor(const Tensor& rhs) { this->CopyWithoutData(rhs); }
Tensor(Tensor&& rhs) { *this = std::move(rhs); }
void CopyWithoutData(const Tensor& rhs);
Tensor& operator=(Tensor&& rhs);
const ShapeView& shape() const { return shape_; }
MutShapeView* mut_shape() {
this->header_access_check();
return mut_shape_.get();
}
DataType data_type() const { return data_type_; }
const MemoryCase& mem_case() const { return *mem_case_; }
virtual const ShapeView& shape() const = 0;
virtual MutShapeView* mut_shape() = 0;
virtual DataType data_type() const = 0;
virtual const MemoryCase& mem_case() const = 0;
virtual const void* raw_dptr() const = 0;
virtual void* mut_raw_dptr() = 0;
template<typename T = void>
const T* dptr() const {
CheckDataType<T>();
return static_cast<const T*>(dptr_);
return reinterpret_cast<const T*>(raw_dptr());
}
template<typename T = void>
T* mut_dptr() {
this->body_access_check();
CheckDataType<T>();
return static_cast<T*>(dptr_);
return reinterpret_cast<T*>(mut_raw_dptr());
}
private:
protected:
template<typename T>
void CheckDataType() const {
LOG_IF(FATAL, (std::is_same<T, void>::value == false && std::is_same<T, char>::value == false
&& data_type_ != DataType::kChar && data_type_ != GetDataType<T>::value))
<< "tensor data_type mismatched. value: " << DataType_Name(data_type_)
&& data_type() != DataType::kChar && data_type() != GetDataType<T>::value))
<< "tensor data_type mismatched. value: " << DataType_Name(data_type())
<< ", template T:" << DataType_Name(GetDataType<T>::value);
}
void header_access_check();
void body_access_check();
void* dptr_;
ShapeView shape_;
std::unique_ptr<MutShapeView> mut_shape_;
DataType data_type_;
const MemoryCase* mem_case_;
const BlobAccessChecker* blob_access_checker_;
};
} // namespace user_op
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/kernel/blob_tensor_view.h"
#include "oneflow/core/register/blob.h"
namespace oneflow {
namespace user_op {
// Non-owning view: the wrapped Blob must outlive this object.
BlobTensorView::BlobTensorView(Blob* blob) : blob_(blob) {}
// All accessors below forward directly to the underlying Blob.
const ShapeView& BlobTensorView::shape() const { return blob_->shape(); }
MutShapeView* BlobTensorView::mut_shape() { return blob_->mut_shape_view(); }
DataType BlobTensorView::data_type() const { return blob_->data_type(); }
const MemoryCase& BlobTensorView::mem_case() const { return blob_->mem_case(); }
const void* BlobTensorView::raw_dptr() const { return blob_->dptr(); }
void* BlobTensorView::mut_raw_dptr() { return blob_->mut_dptr(); }
// Re-points the view at a new Blob so kernel contexts can reuse the same
// BlobTensorView instance across launches instead of reallocating it.
void BlobTensorView::Reset(Blob* blob) { blob_ = blob; }
} // namespace user_op
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_KERNEL_BLOB_TENSOR_VIEW_H_
#define ONEFLOW_CORE_KERNEL_BLOB_TENSOR_VIEW_H_
#include "oneflow/core/framework/tensor.h"
namespace oneflow {
class Blob;
namespace user_op {
// Concrete user_op::Tensor implementation backed by a runtime Blob.
// Holds a raw, non-owning Blob pointer and forwards every Tensor accessor
// to it; Reset() allows the owning kernel context to re-bind the view to a
// different Blob on each launch without reallocating the wrapper.
// NOTE(review): the abstract base Tensor does not appear to declare a
// virtual destructor — deleting a BlobTensorView through a Tensor* would be
// undefined behavior; confirm ownership is always via the derived type.
class BlobTensorView final : public Tensor {
public:
// `blob` is borrowed, not owned; it must outlive this view.
explicit BlobTensorView(Blob* blob);
~BlobTensorView() = default;
const ShapeView& shape() const override;
MutShapeView* mut_shape() override;
DataType data_type() const override;
const MemoryCase& mem_case() const override;
const void* raw_dptr() const override;
void* mut_raw_dptr() override;
// Re-binds the view to another Blob (used when the same context object is
// reused across kernel launches).
void Reset(Blob* blob);
private:
Blob* blob_;  // non-owning
};
} // namespace user_op
} // namespace oneflow
#endif // ONEFLOW_CORE_KERNEL_BLOB_TENSOR_VIEW_H_
......@@ -18,6 +18,7 @@ limitations under the License.
#include "oneflow/core/framework/op_kernel.h"
#include "oneflow/core/framework/op_kernel_infer_cache.h"
#include "oneflow/core/framework/tensor.h"
#include "oneflow/core/kernel/blob_tensor_view.h"
#include "oneflow/core/framework/to_string.h"
#include "oneflow/core/framework/user_op_conf.h"
#include "oneflow/core/framework/user_op_registry_manager.h"
......@@ -26,7 +27,8 @@ limitations under the License.
namespace oneflow {
using Arg2Tensor = HashMap<std::pair<std::string, int32_t>, std::unique_ptr<user_op::Tensor>>;
using Arg2Tensor =
HashMap<std::pair<std::string, int32_t>, std::unique_ptr<user_op::BlobTensorView>>;
using ArgVec = std::vector<std::pair<std::string, int32_t>>;
namespace {
......@@ -100,7 +102,7 @@ class KernelCreateContext final : public user_op::KernelCreateContext {
explicit KernelCreateContext(const KernelConf& kernel_conf)
: user_op_conf_(kernel_conf.op_attribute().op_conf()) {}
const user_op::UserOpConfWrapper& user_op_conf() const { return user_op_conf_; }
const user_op::UserOpConfWrapper& user_op_conf() const override { return user_op_conf_; }
private:
user_op::UserOpConfWrapper user_op_conf_;
......@@ -304,13 +306,13 @@ class UserKernelInferContext final : public user_op::KernelInferContext {
void UpdateArg2Tensor(const std::function<Blob*(const std::string&)>& BnInOp2Blob) {
for (auto& pair : arg2tensor_) {
const auto& arg_pair = pair.first;
std::unique_ptr<user_op::Tensor>* arg_tensor_ptr = &pair.second;
std::unique_ptr<user_op::BlobTensorView>* arg_tensor_ptr = &pair.second;
Blob* blob = BnInOp2Blob(GenRepeatedBn(arg_pair.first, arg_pair.second));
if (blob == nullptr) { continue; }
if (*arg_tensor_ptr) {
*(arg_tensor_ptr->get()) = std::move(user_op::Tensor(blob));
arg_tensor_ptr->get()->Reset(blob);
} else {
arg_tensor_ptr->reset(new user_op::Tensor(blob));
arg_tensor_ptr->reset(new user_op::BlobTensorView(blob));
}
}
}
......@@ -320,9 +322,32 @@ class UserKernelInferContext final : public user_op::KernelInferContext {
UserKernelBaseContext base_ctx_;
UserKernelOpInferContext op_infer_ctx_;
user_op::TensorDescInferFn tensor_desc_infer_fn_;
HashMap<std::pair<std::string, int32_t>, std::unique_ptr<user_op::Tensor>> arg2tensor_;
HashMap<std::pair<std::string, int32_t>, std::unique_ptr<user_op::BlobTensorView>> arg2tensor_;
};
namespace {
struct BnTensorPair {
std::string bn;
std::unique_ptr<user_op::BlobTensorView> tensor;
};
BnTensorPair MakeBnTensorPair(const std::string& bn) {
BnTensorPair pair;
pair.bn = bn;
return pair;
}
BnTensorPair MakeBnTensorPair(const std::string& bn,
std::unique_ptr<user_op::BlobTensorView>&& tensor) {
BnTensorPair pair;
pair.bn = bn;
pair.tensor = std::move(tensor);
return pair;
}
} // namespace
class UserKernelComputeContext final : public user_op::KernelComputeContext {
public:
explicit UserKernelComputeContext(DeviceCtx* device_ctx, const KernelConf& kernel_conf,
......@@ -332,16 +357,18 @@ class UserKernelComputeContext final : public user_op::KernelComputeContext {
device_ctx_(device_ctx),
base_ctx_(std::move(UserKernelBaseContext(kernel_conf, job_desc))) {
auto InitInOrOut = [&](const PbMap<std::string, UserOpConf::ListString>& arg_map) {
for (auto it = arg_map.begin(); it != arg_map.end(); ++it) {
const std::string& arg_name = it->first;
for (int32_t i = 0; i < it->second.s_size(); ++i) {
arg2tensor_.emplace(std::make_pair(arg_name, i), std::unique_ptr<user_op::Tensor>());
for (const auto& it : arg_map) {
const std::string& arg_name = it.first;
for (int32_t i = 0; i < it.second.s_size(); ++i) {
arg2bn_tensor_pair_.emplace(std::make_pair(arg_name, i),
MakeBnTensorPair(GenRepeatedBn(arg_name, i)));
}
}
};
InitInOrOut(kernel_conf.op_attribute().op_conf().user_conf().input());
InitInOrOut(kernel_conf.op_attribute().op_conf().user_conf().output());
arg2tensor_.emplace(std::make_pair("tmp_buffer", 0), std::unique_ptr<user_op::Tensor>());
arg2bn_tensor_pair_.emplace(std::make_pair("tmp_buffer", 0),
MakeBnTensorPair(GenRepeatedBn("tmp_buffer", 0)));
}
~UserKernelComputeContext() = default;
......@@ -351,22 +378,21 @@ class UserKernelComputeContext final : public user_op::KernelComputeContext {
}
user_op::Tensor* Tensor4ArgNameAndIndex(const std::string& arg_name, int32_t index) override {
auto it = arg2tensor_.find(std::make_pair(arg_name, index));
if (it == arg2tensor_.end()) { return nullptr; }
return it->second.get();
auto it = arg2bn_tensor_pair_.find(std::make_pair(arg_name, index));
if (it == arg2bn_tensor_pair_.end()) { return nullptr; }
return it->second.tensor.get();
}
DeviceCtx* device_ctx() override { return device_ctx_; }
void UpdateTensorWithCorrBlob(std::function<Blob*(const std::string&)> BnInOp2Blob) {
for (auto& pair : arg2tensor_) {
const auto& arg_pair = pair.first;
std::unique_ptr<user_op::Tensor>* arg_tensor_ptr = &pair.second;
Blob* blob = BnInOp2Blob(GenRepeatedBn(arg_pair.first, arg_pair.second));
void UpdateTensorWithCorrBlob(const std::function<Blob*(const std::string&)>& BnInOp2Blob) {
for (auto& pair : arg2bn_tensor_pair_) {
std::unique_ptr<user_op::BlobTensorView>* arg_tensor_ptr = &pair.second.tensor;
Blob* blob = BnInOp2Blob(pair.second.bn);
if (blob == nullptr) { continue; }
if (*arg_tensor_ptr) {
*(arg_tensor_ptr->get()) = std::move(user_op::Tensor(blob));
arg_tensor_ptr->get()->Reset(blob);
} else {
arg_tensor_ptr->reset(new user_op::Tensor(blob));
arg_tensor_ptr->reset(new user_op::BlobTensorView(blob));
}
}
}
......@@ -380,7 +406,7 @@ class UserKernelComputeContext final : public user_op::KernelComputeContext {
private:
DeviceCtx* device_ctx_;
Arg2Tensor arg2tensor_;
HashMap<std::pair<std::string, int32_t>, BnTensorPair> arg2bn_tensor_pair_;
UserKernelBaseContext base_ctx_;
};
......
......@@ -132,7 +132,7 @@ Maybe<void> FillImageInSummary(const user_op::Tensor& tensor, const std::string&
const int64_t hw = h * w;
const int64_t depth = static_cast<int64_t>(tensor.shape().At(3));
if (tensor.data_type() == DataType::kUInt8) {
auto ith_image = [tensor, hw, depth](int i) {
auto ith_image = [&tensor, hw, depth](int i) {
auto images = tensor.dptr<uint8_t>();
uint8_t* image_i = (uint8_t*)malloc(sizeof(uint8_t) * hw * depth);
memcpy(image_i, images + i * hw * depth, hw * depth);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register