From 39d79e64196049b6879612305bed604faac8a2dd Mon Sep 17 00:00:00 2001
From: zchen0211
Date: Thu, 14 Sep 2017 14:20:33 -0700
Subject: [PATCH] modified codes

---
 paddle/framework/tensor.h       | 11 +------
 paddle/framework/tensor_test.cc |  8 ++---
 paddle/operators/CMakeLists.txt |  3 +-
 paddle/operators/cond_op.cc     | 55 ++++++++++++++++++++-------------
 paddle/operators/cond_op.h      | 17 ++++++++--
 5 files changed, 53 insertions(+), 41 deletions(-)

diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h
index 20f019892b..4b5a2ae523 100644
--- a/paddle/framework/tensor.h
+++ b/paddle/framework/tensor.h
@@ -78,9 +78,6 @@ class Tensor {
   template <typename T>
   inline T* mutable_data(DDim dims, platform::Place place);
 
-  /*! Size of a single element in data() */
-  inline size_t element_size() { return holder_->element_size(); }
-
   /*! Return the dimensions of the memory block. */
   inline const DDim& dims() const;
 
@@ -132,7 +129,6 @@ class Tensor {
    virtual ~Placeholder() {}
    virtual void* ptr() const = 0;
    virtual size_t size() const = 0;
-    virtual size_t element_size() const = 0;
    virtual std::type_index type() const = 0;
    virtual platform::Place place() const = 0;
  };
@@ -143,8 +139,7 @@ class Tensor {
        : ptr_(static_cast<T*>(memory::Alloc(place, size)),
               memory::PODDeleter<T, Place>(place)),
          place_(place),
-          size_(size),
-          element_size_(sizeof(T)) {
+          size_(size) {
      PADDLE_ENFORCE_NOT_NULL(ptr_, "Insufficient %s memory to allocation.",
                              (is_cpu_place(place_) ? "CPU" : "GPU"));
    }
@@ -153,7 +148,6 @@ class Tensor {
    virtual platform::Place place() const { return place_; }
    virtual void* ptr() const { return static_cast<void*>(ptr_.get()); }
    virtual std::type_index type() const { return std::type_index(typeid(T)); }
-    virtual size_t element_size() const { return element_size_; }
 
    /*! the pointer of memory block. */
    std::unique_ptr<T, memory::PODDeleter<T, Place>> ptr_;
@@ -163,9 +157,6 @@ class Tensor {
 
    /*! the size of memory block. */
    size_t size_;
-
-    /*! the size of a single element */
-    size_t element_size_;
  };
 
  /*! holds the memory block if allocated. */
diff --git a/paddle/framework/tensor_test.cc b/paddle/framework/tensor_test.cc
index 8491536e6f..e2ec738de3 100644
--- a/paddle/framework/tensor_test.cc
+++ b/paddle/framework/tensor_test.cc
@@ -36,7 +36,7 @@ TEST(Tensor, DataAssert) {
  } catch (paddle::platform::EnforceNotMet err) {
    caught = true;
    std::string msg =
-        "holder_ should not be null\nTenosr holds no memory. Call "
+        "holder_ should not be null\nTensor holds no memory. Call "
        "Tensor::mutable_data first.";
    const char* what = err.what();
    for (size_t i = 0; i < msg.length(); ++i) {
@@ -59,8 +59,6 @@ TEST(Tensor, MutableData) {
    // initialization
    p1 = src_tensor.mutable_data<float>(make_ddim({1, 2, 3}), CPUPlace());
    EXPECT_NE(p1, nullptr);
-    // check tensor type
-    EXPECT_EQ(src_tensor.element_size(), sizeof(float));
    // set src_tensor a new dim with large size
    // momery is supposed to be re-allocated
    p2 = src_tensor.mutable_data<float>(make_ddim({3, 4}), CPUPlace());
@@ -114,7 +112,7 @@ TEST(Tensor, ShareDataWith) {
  } catch (paddle::platform::EnforceNotMet err) {
    caught = true;
    std::string msg =
-        "holder_ should not be null\nTenosr holds no memory. Call "
+        "holder_ should not be null\nTensor holds no memory. Call "
Call " "Tensor::mutable_data first."; const char* what = err.what(); for (size_t i = 0; i < msg.length(); ++i) { @@ -276,4 +274,4 @@ TEST(Tensor, ReshapeToMatrix) { Tensor res = ReshapeToMatrix(src, 2); ASSERT_EQ(res.dims()[0], 2 * 3); ASSERT_EQ(res.dims()[1], 4 * 9); -} \ No newline at end of file +} diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 4e83eea4ac..e3e934bccc 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -80,8 +80,7 @@ endfunction() add_subdirectory(math) set(DEPS_OPS - recurrent_op) -set(DEPS_OPS + recurrent_op cond_op) op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc DEPS framework_proto tensor net_op) diff --git a/paddle/operators/cond_op.cc b/paddle/operators/cond_op.cc index a3e4a2506f..b2e1ca395d 100644 --- a/paddle/operators/cond_op.cc +++ b/paddle/operators/cond_op.cc @@ -28,6 +28,7 @@ namespace operators { using Scope = framework::Scope; using Variable = framework::Variable; using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; using DDim = framework::DDim; void CondOp::CreateScope(const Scope& scope) const { @@ -41,8 +42,9 @@ void CondOp::CreateScope(const Scope& scope) const { void CondOp::CreateIndexTensor(const Scope& scope) const { auto index_tensors_var = scope.FindVar("IndexTensors"); PADDLE_ENFORCE(index_tensors_var != nullptr, ""); - auto& index_tensors = *index_tensors_var->GetMutable>(); - index_tensors.push_back(Tensor()); + auto& index_tensors = + *index_tensors_var->GetMutable>(); + index_tensors.push_back(LoDTensor()); } void CondOp::InferShape(const Scope& scope) const { @@ -65,8 +67,8 @@ void CondOp::InferShape(const Scope& scope) const { for (auto& input : Inputs("Xs")) { // Create a new tensor in sub-scope for input-type tensor Variable* v = sub_scopes[i]->NewVar(input); - Tensor* sub_input = v->GetMutable(); - sub_input->Resize(scope.FindVar(input)->GetMutable()->dims()); + LoDTensor* sub_input = v->GetMutable(); + sub_input->Resize(scope.FindVar(input)->GetMutable()->dims()); } for (auto& output : (*sub_net_op_[i]).Outputs()) { @@ -80,33 +82,40 @@ void CondOp::InferShape(const Scope& scope) const { } for (auto& output : Outputs("Outs")) { - Tensor* tensor_t_out = sub_scopes[0]->FindVar(output)->GetMutable(); - PADDLE_ENFORCE_NOT_NULL(tensor_t_out, "True output should be NULL"); - Tensor* tensor_f_out = sub_scopes[1]->FindVar(output)->GetMutable(); - PADDLE_ENFORCE_NOT_NULL(tensor_f_out, "True output should be NULL"); + LoDTensor* tensor_t_out = + sub_scopes[0]->FindVar(output)->GetMutable(); + PADDLE_ENFORCE_NOT_NULL(tensor_t_out, "True output should not be NULL"); + LoDTensor* tensor_f_out = + sub_scopes[1]->FindVar(output)->GetMutable(); + PADDLE_ENFORCE_NOT_NULL(tensor_f_out, "False output should not be NULL"); auto* tensor_out_var = scope.FindVar(output); PADDLE_ENFORCE_NOT_NULL(tensor_out_var, "Output not found"); - Tensor* tensor_out = tensor_out_var->GetMutable(); - PADDLE_ENFORCE_NOT_NULL(tensor_t_out, "True output should be NULL"); + LoDTensor* tensor_out = tensor_out_var->GetMutable(); + PADDLE_ENFORCE_NOT_NULL(tensor_t_out, + "True output tensor should not be NULL"); + // check output size should be same PADDLE_ENFORCE_EQ(tensor_t_out->dims(), tensor_f_out->dims(), "Outputs not of the same shape"); tensor_out->Resize(tensor_t_out->dims()); - tensor_out->mutable_data(tensor_out->dims(), platform::CPUPlace()); + // tensor_out->mutable_data(tensor_out->dims(), + // platform::CPUPlace()); + 
   }
 }
 
 void CondOp::Run(const Scope& scope,
                  const platform::DeviceContext& dev_ctx) const {
-  auto sub_scopes = scope.FindVar("SubScopes")->Get<std::vector<Scope*>>();
-  auto index_tensors =
-      scope.FindVar("IndexTensors")->Get<std::vector<Tensor>>();
+  auto* sub_scopes_var = scope.FindVar("SubScopes");
+  auto sub_scopes = sub_scopes_var->Get<std::vector<Scope*>>();
+  auto* index_tensors_var = scope.FindVar("IndexTensors");
+  auto index_tensors = index_tensors_var->Get<std::vector<LoDTensor>>();
 
   std::string cond_name = Input("Cond");
   Variable* cond_var = scope.FindVar(cond_name);
   PADDLE_ENFORCE_NOT_NULL(cond_var);
-  const Tensor* cond = cond_var->GetMutable<Tensor>();
+  const LoDTensor* cond = cond_var->GetMutable<LoDTensor>();
 
   // Step 1: get the true/false index at runtime
   // index_[0]: vector<int>, contains all index for cond[i] == true
@@ -139,11 +148,11 @@ void CondOp::Run(const Scope& scope,
       // find Tensor
       Variable* v = scope.FindVar(input);
       PADDLE_ENFORCE_NOT_NULL(v);
-      Tensor* tensor_parent = v->GetMutable<Tensor>();
+      LoDTensor* tensor_parent = v->GetMutable<LoDTensor>();
 
       v = sub_scopes[i]->FindVar(input);
       PADDLE_ENFORCE_NOT_NULL(v);
-      Tensor* tensor_child = v->GetMutable<Tensor>();
+      LoDTensor* tensor_child = v->GetMutable<LoDTensor>();
 
       // Resize child
       DDim dim = tensor_child->dims();
@@ -157,7 +166,9 @@ void CondOp::Run(const Scope& scope,
   }
 
   // Step 3: run
-  for (int i = 0; i < 2; ++i) sub_net_op_[i]->Run(*sub_scopes[i], dev_ctx);
+  for (int i = 0; i < 2; ++i) {
+    sub_net_op_[i]->Run(*sub_scopes[i], dev_ctx);
+  }
 
   // Step 4: merge output results
   for (int i = 0; i < 2; ++i) {
@@ -166,11 +177,11 @@ void CondOp::Run(const Scope& scope,
       // find Tensor
       Variable* v = scope.FindVar(output);
       PADDLE_ENFORCE_NOT_NULL(v);
-      Tensor* tensor_parent = v->GetMutable<Tensor>();
+      LoDTensor* tensor_parent = v->GetMutable<LoDTensor>();
 
       v = sub_scopes[i]->FindVar(output);
       PADDLE_ENFORCE_NOT_NULL(v);
-      Tensor* tensor_child = v->GetMutable<Tensor>();
+      LoDTensor* tensor_child = v->GetMutable<LoDTensor>();
 
       ScatterUpdate(dev_ctx.GetPlace(), tensor_child, &index_tensors[i],
                     tensor_parent);
@@ -192,7 +203,9 @@ class CondOpProtoAndCheckerMaker : public framework::OpProtoAndCheckerMaker {
     AddComment(R"DOC(
 Sample dependent Cond Operator:
-The equation is: Out[i] = subnet_t[i], if Cond[i] == true
+Given Cond[i] as a 1/0 vector to indicate true/false
+The equation is:
+Out[i] = subnet_t[i], if Cond[i] == true
 Out[i] = subnet_t[i], if Cond[i] == false
 )DOC");
   }
diff --git a/paddle/operators/cond_op.h b/paddle/operators/cond_op.h
index 27a6e9e3c3..001096d31a 100644
--- a/paddle/operators/cond_op.h
+++ b/paddle/operators/cond_op.h
@@ -24,6 +24,17 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
+/*
+ * @brief CondOp is a dynamic if-else Operator
+ *
+ * It has an input tensor named cond indicating which netop each instance will
+ * run.
+ *
+ * if cond == 1, it will run true_net, which is a NetOp.
+ *
+ * if cond == 0, it will run false_net, which is another NetOp.
+ */
+
 class CondOp : public framework::OperatorBase {
  public:
  CondOp(const std::string& type, const framework::VariableNameMap& inputs,
@@ -45,18 +56,18 @@ class CondOp : public framework::OperatorBase {
 
  void CreateIndexTensor(const framework::Scope& scope) const;
 
-  /**
+  /*
   * InferShape must be called before Run.
   */
  void InferShape(const framework::Scope& scope) const override;
 
  // Set True Block
-  void set_truenet(std::unique_ptr<NetOp> net) {
+  void set_truenet(std::unique_ptr<NetOp>&& net) {
    sub_net_op_[0] = std::move(net);
  }
 
  // Set False Block
-  void set_falsenet(std::unique_ptr<NetOp> net) {
+  void set_falsenet(std::unique_ptr<NetOp>&& net) {
    sub_net_op_[1] = std::move(net);
  }
 
-- 
GitLab
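Side note on the new rvalue-reference setters (illustration only, not part of the patch): the sketch below assumes NetOp comes from paddle/operators/net_op.h and invents a helper, BuildCondBranches, purely to show that callers now have to transfer ownership of the two branch networks with std::move.

// Hypothetical caller-side sketch; BuildCondBranches is not a real Paddle API.
#include <memory>
#include <utility>

#include "paddle/operators/cond_op.h"
#include "paddle/operators/net_op.h"

void BuildCondBranches(paddle::operators::CondOp* cond,
                       std::unique_ptr<paddle::operators::NetOp> true_net,
                       std::unique_ptr<paddle::operators::NetOp> false_net) {
  // set_truenet/set_falsenet now take std::unique_ptr<NetOp>&&, so the
  // caller must hand over ownership explicitly.
  cond->set_truenet(std::move(true_net));
  cond->set_falsenet(std::move(false_net));
}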