提交 adfef243 编写于 作者: Z Zhuoyuan

tensor element size support

上级 c3d684ef
......@@ -75,6 +75,9 @@ class Tensor {
template <typename T>
inline T* mutable_data(DDim dims, platform::Place place);
/*! Size of a single element in data() */
inline size_t element_size() { return holder_->element_size(); }
/*! Return the dimensions of the memory block. */
inline const DDim& dims() const;
......@@ -123,6 +126,7 @@ class Tensor {
virtual ~Placeholder() {}
virtual void* ptr() const = 0;
virtual size_t size() const = 0;
virtual size_t element_size() const = 0;
virtual std::type_index type() const = 0;
virtual platform::Place place() const = 0;
};
......@@ -133,7 +137,8 @@ class Tensor {
: ptr_(static_cast<T*>(memory::Alloc(place, size)),
memory::PODDeleter<T, Place>(place)),
place_(place),
size_(size) {
size_(size),
element_size_(sizeof(T)) {
PADDLE_ENFORCE_NOT_NULL(ptr_, "Insufficient %s memory to allocation.",
(is_cpu_place(place_) ? "CPU" : "GPU"));
}
......@@ -142,6 +147,7 @@ class Tensor {
virtual platform::Place place() const { return place_; }
virtual void* ptr() const { return static_cast<void*>(ptr_.get()); }
virtual std::type_index type() const { return std::type_index(typeid(T)); }
virtual size_t element_size() const { return element_size_; }
/*! the pointer of memory block. */
std::unique_ptr<T, memory::PODDeleter<T, Place>> ptr_;
......@@ -151,6 +157,9 @@ class Tensor {
/*! the size of memory block. */
size_t size_;
/*! the size of a single element */
size_t element_size_;
};
/*! holds the memory block if allocated. */
......
......@@ -22,7 +22,7 @@ namespace framework {
template <typename T>
inline void Tensor::check_memory_size() const {
PADDLE_ENFORCE_NOT_NULL(
holder_, "Tenosr holds no memory. Call Tensor::mutable_data first.");
holder_, "Tensor holds no memory. Call Tensor::mutable_data first.");
PADDLE_ENFORCE_GE(
holder_->size(), product(dims_) * sizeof(T) + offset_,
"Tensor's dims_ is out of bound. Call Tensor::mutable_data "
......
......@@ -59,6 +59,8 @@ TEST(Tensor, MutableData) {
// initialization
p1 = src_tensor.mutable_data<float>(make_ddim({1, 2, 3}), CPUPlace());
EXPECT_NE(p1, nullptr);
// check tensor type
EXPECT_EQ(src_tensor.element_size(), sizeof(float));
// set src_tensor a new dim with large size
// momery is supposed to be re-allocated
p2 = src_tensor.mutable_data<float>(make_ddim({3, 4}), CPUPlace());
......
#include "paddle/operators/switch_op.h"
namespace paddle {
namespace operators {
void CondOp::InferShape(const std::shared_ptr<Scope>& scope) const {
// Create two Nets
// Create two scopes
for (int i = 0; i < 2; ++i)
sub_scope.push_back(scope.NewScope());
for (int i = 0; i < 2; ++i)
sub_net_op_[i].InferShape(sub_scope[i]);
for (int i = 0; i < 2; ++i)
tensor_index = new Tensor();
for (int i = 0; i < 2; ++i)
_index.push_back(vector<int>());
for (int i = 0; i < 2; ++i)
{
// for (auto& input : net_op_[i]->Inputs()) {
for (auto& input : GetAttr<std::vector<std::string>>("True_inputs")) {
auto var_name = input.second;
// Create a new tensor in sub-scope for input-type tensor
sub_scope[i]->NewVar(var_name)->GetMutable<Tensor>();
}
}
}
class CondOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
CondOpProtoAndCheckerMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Cond", "The condition, which is a bool vector");
AddInput("Xs", "Inputs of Subnets");
AddAttr<std::vector<std::string>>("sub_inputs", "Inputs of the Whole Op, net op and so forth");
AddAttr<std::vector<std::string>>("sub_outputs", "True Outputs needs merge");
AddOutput("Outs", "The output of cond op");
AddComment(R"DOC(
Sample dependent Cond Operator:
The equation is: Out[i] = subnet_t[i], if Cond[i] == true
Out[i] = subnet_t[i], if Cond[i] == false
)DOC");
}
};
} // namespace operators
} // namespace paddle
REGISTER_OP_WITHOUT_GRADIENT(cond_op,
paddle::operators::CondOp,
paddle::operators::CondOpProtoAndCheckerMaker);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "glog/logging.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/ddim.h"
#include "paddle/operators/gather.h"
#include <vector>
namespace paddle {
namespace operators {
using namespace paddle::framework;
template <typename Place, typename T>
class CondOp final : public OperatorBase {
public:
/**
* InferShape must be called before Run.
*/
void InferShape(const std::shared_ptr<Scope>& scope) const override;
// Set True Block
void set_truenet(std::unique_ptr<OperatorBase> net) {
sub_net_op_[0] = std::move(net);
}
// Set False Block
void set_falsenet(std::unique_ptr<OperatorBase> net) {
sub_net_op_[1] = std::move(net);
}
virtual void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const override {
auto* cond = context.Input<Tensor>("Cond");
// Step 1: get the true/false index at runtime
// _index[0]: vector<int>, contains all index for cond[i] == true
// _index[1]: vector<int>, contains all index for cond[i] == false
for(int i = 0; i < 2; ++i)
_index[i].clear();
for(int i = 0; i < cond->dims()[0]; ++i) {
if (cond->data<bool>()[i])
_index[0].push_back(i);
else
_index[1].push_back(i);
}
// put _index[0] and _index[1] into two tensors
// tensor_index[0] and tensor_index[1]
framework::DDim dim_ = paddle::framework::make_ddim({0});
for(int i = 0; i < 2; ++i) {
dim_[0] = _index[i].size();
int* tmp_ = _index[i]->mutable_data<int>(dim_, CPUPlace());
tensor_index[i]->Resize(dim_);
memcpy(tmp_, index_[i], dim_[0] * sizeof(int));
}
// Step 2: collect data by calling gather
for (int i = 0; i < 2; ++i) {
// i= 0/i for True and False branches respectively
for (auto& input : GetAttr<std::vector<std::string>>("sub_inputs")) {
auto var_name = input.second;
// find Tensor
Tensor* Tensor_parent = scope.FindVar(var_name)->GetMutable<Tensor>();
Tensor* Tensor_child = sub_scope_[i].FindVar(var_name)->GetMutable<Tensor>();
Gather<T>(dev_ctx.GetPlace(), tensor_parent, tensor_index[i], tensor_child);
}
}
// Step 3: run
for (int i = 0; i < 2; ++i)
sub_net_op_[i]->Run(sub_scope_[i], dev_ctx);
// Step 4: merge output results
for (int i = 0; i < 2; ++i) {
// i= 0/i for True and False branches respectively
for (auto& output : GetAttr<std::vector<std::string>>("sub_outputs")) {
auto var_name = output.second;
// find Tensor
Tensor* Tensor_parent = scope.FindVar(var_name)->GetMutable<Tensor>();
Tensor* Tensor_child = sub_scope_[i].FindVar(var_name)->GetMutable<Tensor>();
ScatterUpdate<T>(dev_ctx.GetPlace(), tensor_child, tensor_index[i], tensor_parent);
}
}
}
private:
// sub_scope_[0]: true scope
// sub_scope_[1]: false scope
std::vector<Scope*> sub_scope_;
// sub_net_op_[0]: subnet_t
// sub_net_op_[1]: subnet_f
std::vector<std::unique_ptr<framework::OperatorBase>> sub_net_op_;
// tensor_index[0]: True_index tensor
// tensor_index[1]: False_index;
std::vector<Tensor*> tensor_index;
// _index[0]: True_index;
// _index[1]: False_index;
vector<vector<int> > _index;
};
/*
class CondGradientOp final : public OperatorBase {
public:
void Init() override;
virtual void InferShape(const std::shared_ptr<Scope>& scope) const override;
virtual void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const override;
};*/
} // namespace operators
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册