From caaa5f86b91beda67daf8ae295cf99fa4dce12ba Mon Sep 17 00:00:00 2001 From: zchen0211 Date: Fri, 11 Aug 2017 15:09:04 -0700 Subject: [PATCH] gather op added --- paddle/framework/CMakeLists.txt | 2 ++ paddle/framework/empty_test.cc | 56 +++++++++++++++++++++++++++++++++ paddle/operators/gather_op.cc | 2 ++ 3 files changed, 60 insertions(+) create mode 100644 paddle/framework/empty_test.cc diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 0398526024..9e306c8650 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -10,6 +10,8 @@ cc_test(eigen_test SRCS eigen_test.cc DEPS tensor) cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor) cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor) +cc_test(empty_test SRCS empty_test.cc DEPS tensor) + cc_test(variable_test SRCS variable_test.cc) cc_library(scope SRCS scope.cc) diff --git a/paddle/framework/empty_test.cc b/paddle/framework/empty_test.cc new file mode 100644 index 0000000000..2237f8ce0e --- /dev/null +++ b/paddle/framework/empty_test.cc @@ -0,0 +1,56 @@ +/* + Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +#include +#include +#include "paddle/framework/tensor.h" + +TEST(Empty, Dims) { + using namespace paddle::framework; + using namespace paddle::platform; + Tensor tt; + tt.Resize(make_ddim({0, 3, 4})); + DDim dims = tt.dims(); + ASSERT_EQ(arity(dims), 3); + EXPECT_EQ(0, dims[0]); + EXPECT_EQ(3, dims[1]); + EXPECT_EQ(4, dims[2]); +} + +TEST(Empty, MutableData) { + using namespace paddle::framework; + using namespace paddle::platform; + { + Tensor src_tensor; + float* p1 = nullptr; + // initialization + p1 = src_tensor.mutable_data(make_ddim({0, 2, 3}), CPUPlace()); + EXPECT_NE(p1, nullptr); + } + +#ifndef PADDLE_ONLY_CPU + { + Tensor src_tensor; + float* p1 = nullptr; + float* p2 = nullptr; + // initialization + p1 = src_tensor.mutable_data(make_ddim({0, 2, 3}), GPUPlace()); + EXPECT_NE(p1, nullptr); + // set src_tensor a new dim with large size + // momery is supposed to be re-allocated + p2 = src_tensor.mutable_data(make_ddim({0, 4}), GPUPlace()); + EXPECT_NE(p2, nullptr); + // EXPECT_NE(p1, p2); + } +#endif +} diff --git a/paddle/operators/gather_op.cc b/paddle/operators/gather_op.cc index 1008a57a87..3414a3c263 100644 --- a/paddle/operators/gather_op.cc +++ b/paddle/operators/gather_op.cc @@ -23,6 +23,8 @@ class GatherOp : public framework::OperatorWithKernel { void InferShape(const framework::InferShapeContext &ctx) const override { PADDLE_ENFORCE(ctx.InputSize() == 2, ""); PADDLE_ENFORCE(ctx.OutputSize() == 1, ""); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(0), + "Inputs of GatherOp must all be set"); int batch_size = ctx.Input(1)->dims()[0]; PADDLE_ENFORCE(batch_size > 0); } -- GitLab