diff --git a/oneflow/core/framework/tensor.cpp b/oneflow/core/framework/tensor.cpp index 851fbd13681946a0d116462fe3f9ecfab317bbc2..a830fb306960a29a619b2962457fe1b6205171dd 100644 --- a/oneflow/core/framework/tensor.cpp +++ b/oneflow/core/framework/tensor.cpp @@ -20,51 +20,11 @@ namespace oneflow { namespace user_op { -Tensor::Tensor(Blob* blob) { - dptr_ = blob->ForceMutDptr(); - shape_ = blob->shape(); - blob_access_checker_ = blob->blob_access_checker(); - if (blob->ForceMutShapeView()) { - mut_shape_.reset(new MutShapeView(*blob->ForceMutShapeView())); - } else { - mut_shape_.reset(); - } - data_type_ = blob->data_type(); - mem_case_ = &(blob->mem_case()); -} - -void Tensor::header_access_check() { this->blob_access_checker_->CheckHeaderMutable(); } - -void Tensor::body_access_check() { this->blob_access_checker_->CheckBodyMutable(); } - -void Tensor::CopyWithoutData(const Tensor& rhs) { - dptr_ = rhs.dptr_; - shape_ = rhs.shape_; - if (rhs.mut_shape_) { - mut_shape_.reset(new MutShapeView(*rhs.mut_shape_)); - } else { - mut_shape_.reset(); - } - data_type_ = rhs.data_type_; - mem_case_ = rhs.mem_case_; - blob_access_checker_ = rhs.blob_access_checker_; -} - -Tensor& Tensor::operator=(Tensor&& rhs) { - dptr_ = rhs.dptr_; - shape_ = rhs.shape_; - mut_shape_ = std::move(rhs.mut_shape_); - data_type_ = rhs.data_type_; - mem_case_ = rhs.mem_case_; - blob_access_checker_ = rhs.blob_access_checker_; - return *this; -} - #ifdef WITH_CUDA template<> void Tensor::CheckDataType() const { - LOG_IF(FATAL, data_type_ != DataType::kFloat16) + LOG_IF(FATAL, data_type() != DataType::kFloat16) << "tensor data_type mismatched. value: kFloat16, template T: half"; } diff --git a/oneflow/core/framework/tensor.h b/oneflow/core/framework/tensor.h index 0c57276eb260698cea196a92456c652e125bbbbd..e444961ee2bc3eeabfa5810f4a81464f5a2bb632 100644 --- a/oneflow/core/framework/tensor.h +++ b/oneflow/core/framework/tensor.h @@ -19,63 +19,44 @@ limitations under the License. #include "oneflow/core/common/data_type.h" #include "oneflow/core/common/shape_view.h" #include "oneflow/core/memory/memory_case.pb.h" + namespace oneflow { class Blob; -class BlobAccessChecker; namespace user_op { -class Tensor final { +class Tensor { public: - Tensor(Blob*); ~Tensor() = default; - Tensor(const Tensor& rhs) { this->CopyWithoutData(rhs); } - Tensor(Tensor&& rhs) { *this = std::move(rhs); } - void CopyWithoutData(const Tensor& rhs); - Tensor& operator=(Tensor&& rhs); - - const ShapeView& shape() const { return shape_; } - MutShapeView* mut_shape() { - this->header_access_check(); - return mut_shape_.get(); - } - - DataType data_type() const { return data_type_; } - const MemoryCase& mem_case() const { return *mem_case_; } + virtual const ShapeView& shape() const = 0; + virtual MutShapeView* mut_shape() = 0; + virtual DataType data_type() const = 0; + virtual const MemoryCase& mem_case() const = 0; + virtual const void* raw_dptr() const = 0; + virtual void* mut_raw_dptr() = 0; template const T* dptr() const { CheckDataType(); - return static_cast(dptr_); + return reinterpret_cast(raw_dptr()); } template T* mut_dptr() { - this->body_access_check(); CheckDataType(); - return static_cast(dptr_); + return reinterpret_cast(mut_raw_dptr()); } - private: + protected: template void CheckDataType() const { LOG_IF(FATAL, (std::is_same::value == false && std::is_same::value == false - && data_type_ != DataType::kChar && data_type_ != GetDataType::value)) - << "tensor data_type mismatched. value: " << DataType_Name(data_type_) + && data_type() != DataType::kChar && data_type() != GetDataType::value)) + << "tensor data_type mismatched. value: " << DataType_Name(data_type()) << ", template T:" << DataType_Name(GetDataType::value); } - - void header_access_check(); - void body_access_check(); - - void* dptr_; - ShapeView shape_; - std::unique_ptr mut_shape_; - DataType data_type_; - const MemoryCase* mem_case_; - const BlobAccessChecker* blob_access_checker_; }; } // namespace user_op diff --git a/oneflow/core/kernel/blob_tensor_view.cpp b/oneflow/core/kernel/blob_tensor_view.cpp new file mode 100644 index 0000000000000000000000000000000000000000..97c58b78e4a865f7bd08d1eb836192aa3f057d71 --- /dev/null +++ b/oneflow/core/kernel/blob_tensor_view.cpp @@ -0,0 +1,41 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/core/kernel/blob_tensor_view.h" +#include "oneflow/core/register/blob.h" + +namespace oneflow { + +namespace user_op { + +BlobTensorView::BlobTensorView(Blob* blob) : blob_(blob) {} + +const ShapeView& BlobTensorView::shape() const { return blob_->shape(); } + +MutShapeView* BlobTensorView::mut_shape() { return blob_->mut_shape_view(); } + +DataType BlobTensorView::data_type() const { return blob_->data_type(); } + +const MemoryCase& BlobTensorView::mem_case() const { return blob_->mem_case(); } + +const void* BlobTensorView::raw_dptr() const { return blob_->dptr(); } + +void* BlobTensorView::mut_raw_dptr() { return blob_->mut_dptr(); } + +void BlobTensorView::Reset(Blob* blob) { blob_ = blob; } + +} // namespace user_op + +} // namespace oneflow diff --git a/oneflow/core/kernel/blob_tensor_view.h b/oneflow/core/kernel/blob_tensor_view.h new file mode 100644 index 0000000000000000000000000000000000000000..ed63bafe56009c712bae0c1bf993f182bdc294b6 --- /dev/null +++ b/oneflow/core/kernel/blob_tensor_view.h @@ -0,0 +1,49 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#ifndef ONEFLOW_CORE_KERNEL_BLOB_TENSOR_VIEW_H_ +#define ONEFLOW_CORE_KERNEL_BLOB_TENSOR_VIEW_H_ + +#include "oneflow/core/framework/tensor.h" + +namespace oneflow { + +class Blob; + +namespace user_op { + +class BlobTensorView final : public Tensor { + public: + explicit BlobTensorView(Blob* blob); + ~BlobTensorView() = default; + + const ShapeView& shape() const override; + MutShapeView* mut_shape() override; + DataType data_type() const override; + const MemoryCase& mem_case() const override; + const void* raw_dptr() const override; + void* mut_raw_dptr() override; + + void Reset(Blob* blob); + + private: + Blob* blob_; +}; + +} // namespace user_op + +} // namespace oneflow + +#endif // ONEFLOW_CORE_KERNEL_BLOB_TENSOR_VIEW_H_ diff --git a/oneflow/core/kernel/user_kernel.cpp b/oneflow/core/kernel/user_kernel.cpp index 54a13797f73efd7fb7e28fa35e4ff0350ee55cdc..ae0ea6a291aebf78951627f67ac8e172a9dab6e6 100644 --- a/oneflow/core/kernel/user_kernel.cpp +++ b/oneflow/core/kernel/user_kernel.cpp @@ -18,6 +18,7 @@ limitations under the License. #include "oneflow/core/framework/op_kernel.h" #include "oneflow/core/framework/op_kernel_infer_cache.h" #include "oneflow/core/framework/tensor.h" +#include "oneflow/core/kernel/blob_tensor_view.h" #include "oneflow/core/framework/to_string.h" #include "oneflow/core/framework/user_op_conf.h" #include "oneflow/core/framework/user_op_registry_manager.h" @@ -26,7 +27,8 @@ limitations under the License. namespace oneflow { -using Arg2Tensor = HashMap, std::unique_ptr>; +using Arg2Tensor = + HashMap, std::unique_ptr>; using ArgVec = std::vector>; namespace { @@ -100,7 +102,7 @@ class KernelCreateContext final : public user_op::KernelCreateContext { explicit KernelCreateContext(const KernelConf& kernel_conf) : user_op_conf_(kernel_conf.op_attribute().op_conf()) {} - const user_op::UserOpConfWrapper& user_op_conf() const { return user_op_conf_; } + const user_op::UserOpConfWrapper& user_op_conf() const override { return user_op_conf_; } private: user_op::UserOpConfWrapper user_op_conf_; @@ -304,13 +306,13 @@ class UserKernelInferContext final : public user_op::KernelInferContext { void UpdateArg2Tensor(const std::function& BnInOp2Blob) { for (auto& pair : arg2tensor_) { const auto& arg_pair = pair.first; - std::unique_ptr* arg_tensor_ptr = &pair.second; + std::unique_ptr* arg_tensor_ptr = &pair.second; Blob* blob = BnInOp2Blob(GenRepeatedBn(arg_pair.first, arg_pair.second)); if (blob == nullptr) { continue; } if (*arg_tensor_ptr) { - *(arg_tensor_ptr->get()) = std::move(user_op::Tensor(blob)); + arg_tensor_ptr->get()->Reset(blob); } else { - arg_tensor_ptr->reset(new user_op::Tensor(blob)); + arg_tensor_ptr->reset(new user_op::BlobTensorView(blob)); } } } @@ -320,9 +322,32 @@ class UserKernelInferContext final : public user_op::KernelInferContext { UserKernelBaseContext base_ctx_; UserKernelOpInferContext op_infer_ctx_; user_op::TensorDescInferFn tensor_desc_infer_fn_; - HashMap, std::unique_ptr> arg2tensor_; + HashMap, std::unique_ptr> arg2tensor_; }; +namespace { + +struct BnTensorPair { + std::string bn; + std::unique_ptr tensor; +}; + +BnTensorPair MakeBnTensorPair(const std::string& bn) { + BnTensorPair pair; + pair.bn = bn; + return pair; +} + +BnTensorPair MakeBnTensorPair(const std::string& bn, + std::unique_ptr&& tensor) { + BnTensorPair pair; + pair.bn = bn; + pair.tensor = std::move(tensor); + return pair; +} + +} // namespace + class UserKernelComputeContext final : public user_op::KernelComputeContext { public: explicit UserKernelComputeContext(DeviceCtx* device_ctx, const KernelConf& kernel_conf, @@ -332,16 +357,18 @@ class UserKernelComputeContext final : public user_op::KernelComputeContext { device_ctx_(device_ctx), base_ctx_(std::move(UserKernelBaseContext(kernel_conf, job_desc))) { auto InitInOrOut = [&](const PbMap& arg_map) { - for (auto it = arg_map.begin(); it != arg_map.end(); ++it) { - const std::string& arg_name = it->first; - for (int32_t i = 0; i < it->second.s_size(); ++i) { - arg2tensor_.emplace(std::make_pair(arg_name, i), std::unique_ptr()); + for (const auto& it : arg_map) { + const std::string& arg_name = it.first; + for (int32_t i = 0; i < it.second.s_size(); ++i) { + arg2bn_tensor_pair_.emplace(std::make_pair(arg_name, i), + MakeBnTensorPair(GenRepeatedBn(arg_name, i))); } } }; InitInOrOut(kernel_conf.op_attribute().op_conf().user_conf().input()); InitInOrOut(kernel_conf.op_attribute().op_conf().user_conf().output()); - arg2tensor_.emplace(std::make_pair("tmp_buffer", 0), std::unique_ptr()); + arg2bn_tensor_pair_.emplace(std::make_pair("tmp_buffer", 0), + MakeBnTensorPair(GenRepeatedBn("tmp_buffer", 0))); } ~UserKernelComputeContext() = default; @@ -351,22 +378,21 @@ class UserKernelComputeContext final : public user_op::KernelComputeContext { } user_op::Tensor* Tensor4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { - auto it = arg2tensor_.find(std::make_pair(arg_name, index)); - if (it == arg2tensor_.end()) { return nullptr; } - return it->second.get(); + auto it = arg2bn_tensor_pair_.find(std::make_pair(arg_name, index)); + if (it == arg2bn_tensor_pair_.end()) { return nullptr; } + return it->second.tensor.get(); } DeviceCtx* device_ctx() override { return device_ctx_; } - void UpdateTensorWithCorrBlob(std::function BnInOp2Blob) { - for (auto& pair : arg2tensor_) { - const auto& arg_pair = pair.first; - std::unique_ptr* arg_tensor_ptr = &pair.second; - Blob* blob = BnInOp2Blob(GenRepeatedBn(arg_pair.first, arg_pair.second)); + void UpdateTensorWithCorrBlob(const std::function& BnInOp2Blob) { + for (auto& pair : arg2bn_tensor_pair_) { + std::unique_ptr* arg_tensor_ptr = &pair.second.tensor; + Blob* blob = BnInOp2Blob(pair.second.bn); if (blob == nullptr) { continue; } if (*arg_tensor_ptr) { - *(arg_tensor_ptr->get()) = std::move(user_op::Tensor(blob)); + arg_tensor_ptr->get()->Reset(blob); } else { - arg_tensor_ptr->reset(new user_op::Tensor(blob)); + arg_tensor_ptr->reset(new user_op::BlobTensorView(blob)); } } } @@ -380,7 +406,7 @@ class UserKernelComputeContext final : public user_op::KernelComputeContext { private: DeviceCtx* device_ctx_; - Arg2Tensor arg2tensor_; + HashMap, BnTensorPair> arg2bn_tensor_pair_; UserKernelBaseContext base_ctx_; }; diff --git a/oneflow/user/summary/event_writer_helper.cpp b/oneflow/user/summary/event_writer_helper.cpp index cdaa653799568bfb3d8b2dd332a6e4dd35f885cb..14e4ef62d02d9d795cb5e2c8ebe0fb389efb40d9 100644 --- a/oneflow/user/summary/event_writer_helper.cpp +++ b/oneflow/user/summary/event_writer_helper.cpp @@ -132,7 +132,7 @@ Maybe FillImageInSummary(const user_op::Tensor& tensor, const std::string& const int64_t hw = h * w; const int64_t depth = static_cast(tensor.shape().At(3)); if (tensor.data_type() == DataType::kUInt8) { - auto ith_image = [tensor, hw, depth](int i) { + auto ith_image = [&tensor, hw, depth](int i) { auto images = tensor.dptr(); uint8_t* image_i = (uint8_t*)malloc(sizeof(uint8_t) * hw * depth); memcpy(image_i, images + i * hw * depth, hw * depth);