From c9b57dcc8314720ad01aa0e9c2d3f0711657749a Mon Sep 17 00:00:00 2001
From: Yu Yang
Date: Mon, 6 Nov 2017 10:31:32 -0800
Subject: [PATCH] ReadFromArray/WriteToArray op (#5407)

* Use stable_sort in lod_rank_table

  It is easier to debug and test when `stable_sort` is used, and the time
  complexity is not changed.

* Add LoDTensorArray

* Stash

* Better debug message for IsInitialized

* Stash

* Better debug message for IsInitialized

* Complete array read/write op unittests
---
 paddle/framework/op_desc.cc                       |   3 +
 paddle/operators/CMakeLists.txt                   |   9 +-
 paddle/operators/fill_constant_op.cc              |   7 +-
 paddle/operators/fill_constant_op.cu              |   3 +-
 paddle/operators/increment_op.cc                  |  18 +-
 paddle/operators/increment_op.cu                  |   5 +-
 paddle/operators/increment_op.h                   |   4 +-
 .../operators/tensor_array_read_write_op.cc       | 219 ++++++++++++++++++
 paddle/pybind/protobuf.cc                         |  55 ++---
 python/paddle/v2/framework/executor.py            |   7 +-
 python/paddle/v2/framework/framework.py           |  16 +-
 python/paddle/v2/framework/layers.py              |  70 +++++-
 .../tests/test_array_read_write_op.py             |  66 ++++++
 .../tests/test_framework_debug_str.py             |  13 ++
 14 files changed, 430 insertions(+), 65 deletions(-)
 create mode 100644 paddle/operators/tensor_array_read_write_op.cc
 create mode 100644 python/paddle/v2/framework/tests/test_array_read_write_op.py
 create mode 100644 python/paddle/v2/framework/tests/test_framework_debug_str.py

diff --git a/paddle/framework/op_desc.cc b/paddle/framework/op_desc.cc
index c96166f35d1..495acf4c0a6 100644
--- a/paddle/framework/op_desc.cc
+++ b/paddle/framework/op_desc.cc
@@ -349,6 +349,9 @@ void OpDescBind::InferVarType(BlockDescBind *block) const {
     info.infer_var_type_(*this, block);
   } else {
     // all output type is LoDTensor by default
+    VLOG(10) << this->Type()
+             << " has not registered InferVarType. Set output variables to "
+                "LOD_TENSOR";
     for (auto &out_pair : this->outputs_) {
       for (auto &out_var_name : out_pair.second) {
         block->Var(out_var_name)->SetType(VarDesc::LOD_TENSOR);
diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt
index 384f004e0ec..f22f86468db 100644
--- a/paddle/operators/CMakeLists.txt
+++ b/paddle/operators/CMakeLists.txt
@@ -110,7 +110,7 @@ function(op_library TARGET)
     # It's enough to just adding one operator to pybind
     file(APPEND ${pybind_file} "USE_GPU_ONLY_OP(ncclAllReduce);\n")
   endif()
-
+
   # reduce_op contains several operators
   if ("${TARGET}" STREQUAL "reduce_op")
     set(pybind_flag 1)
@@ -118,6 +118,11 @@ function(op_library TARGET)
     file(APPEND ${pybind_file} "USE_OP(reduce_sum);\n")
   endif()
 
+  if ("${TARGET}" STREQUAL "tensor_array_read_write_op")
+    set(pybind_flag 1)
+    file(APPEND ${pybind_file} "USE_NO_KERNEL_OP(read_from_array);\nUSE_NO_KERNEL_OP(write_to_array);\n")
+  endif()
+
   # pybind USE_NO_KERNEL_OP
   # HACK: if REGISTER_OP_CPU_KERNEL presents the operator must have kernel
   file(READ ${TARGET}.cc TARGET_CONTENT)
@@ -161,6 +166,7 @@ set(DEPS_OPS
     sequence_pool_op
     lod_rank_table_op
     lstm_op
+    tensor_array_read_write_op
    gru_op)

op_library(cond_op SRCS cond_op.cc DEPS framework_proto tensor operator net_op)
@@ -171,6 +177,7 @@ op_library(sum_op DEPS net_op selected_rows_functor)
 op_library(pool_op DEPS pooling)
 op_library(pool_with_index_op DEPS pooling)
 op_library(lod_rank_table_op SRCS lod_rank_table_op.cc DEPS lod_rank_table)
+op_library(tensor_array_read_write_op SRCS tensor_array_read_write_op.cc)
 if(WITH_GPU)
   op_library(nccl_op DEPS nccl_common)
 endif()
diff --git a/paddle/operators/fill_constant_op.cc b/paddle/operators/fill_constant_op.cc
index ee2219cd033..f60425051cc 100644
--- a/paddle/operators/fill_constant_op.cc
+++ b/paddle/operators/fill_constant_op.cc
@@ -35,7 +35,9 @@ class FillConstantOp : public framework::OperatorWithKernel {
  protected:
   framework::DataType IndicateDataType(
       const framework::ExecutionContext &ctx) const override {
-    return static_cast<framework::DataType>(ctx.Attr<int>("data_type"));
+    int data_type = ctx.Attr<int>("data_type");
+    VLOG(10) << " FillConstant data_type = " << data_type;
+    return static_cast<framework::DataType>(data_type);
   }
 };
 
@@ -71,4 +73,5 @@ REGISTER_OP_WITHOUT_GRADIENT(fill_constant, ops::FillConstantOp,
 REGISTER_OP_CPU_KERNEL(
     fill_constant, ops::FillConstantOpKernel<paddle::platform::CPUPlace, float>,
     ops::FillConstantOpKernel<paddle::platform::CPUPlace, double>,
-    ops::FillConstantOpKernel<paddle::platform::CPUPlace, int>);
+    ops::FillConstantOpKernel<paddle::platform::CPUPlace, int>,
+    ops::FillConstantOpKernel<paddle::platform::CPUPlace, int64_t>);
diff --git a/paddle/operators/fill_constant_op.cu b/paddle/operators/fill_constant_op.cu
index a57b11c6cba..bca402a8b98 100644
--- a/paddle/operators/fill_constant_op.cu
+++ b/paddle/operators/fill_constant_op.cu
@@ -20,4 +20,5 @@ namespace ops = paddle::operators;
 REGISTER_OP_GPU_KERNEL(
     fill_constant, ops::FillConstantOpKernel<paddle::platform::GPUPlace, float>,
     ops::FillConstantOpKernel<paddle::platform::GPUPlace, double>,
-    ops::FillConstantOpKernel<paddle::platform::GPUPlace, int>);
+    ops::FillConstantOpKernel<paddle::platform::GPUPlace, int>,
+    ops::FillConstantOpKernel<paddle::platform::GPUPlace, int64_t>);
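The int and int64_t kernels are added because the tensor-array subscript used
by the new read/write ops below is a one-element int64 tensor produced by
fill_constant. A minimal sketch of the Python-side usage, via the
layers.zeros/layers.fill_constant wrappers added later in this patch:

    import paddle.v2.framework.layers as layers

    # A 1-element int64 tensor holding 0; appends a fill_constant op that
    # dispatches to the new int64_t kernel.
    i = layers.zeros(shape=[1], dtype='int64')
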
AddOutput("Out", "(Tensor) The output tensor of increment operator."); - AddAttr("step", - "(float, default 1.0) " - "The step size by which the " - "input tensor will be incremented.") + AddAttr("step", + "(float, default 1.0) " + "The step size by which the " + "input tensor will be incremented.") .SetDefault(1.0); AddComment(R"DOC( Increment Operator. @@ -73,7 +72,10 @@ class IncrementGradOpMaker : public framework::SingleGradOpDescMaker { namespace ops = paddle::operators; -REGISTER_OPERATOR(increment, ops::IncrementOp, ops::IncrementOpMaker, +REGISTER_OPERATOR(increment, ops::IncrementOp, ops::IncrementOpMaker, ops::IncrementGradOpMaker); -REGISTER_OP_CPU_KERNEL(increment, - ops::IncrementKernel); +REGISTER_OP_CPU_KERNEL( + increment, ops::IncrementKernel, + ops::IncrementKernel, + ops::IncrementKernel, + ops::IncrementKernel); diff --git a/paddle/operators/increment_op.cu b/paddle/operators/increment_op.cu index 659c380d147..f97a6c46852 100644 --- a/paddle/operators/increment_op.cu +++ b/paddle/operators/increment_op.cu @@ -16,4 +16,7 @@ REGISTER_OP_GPU_KERNEL( increment, - paddle::operators::IncrementKernel); + paddle::operators::IncrementKernel, + paddle::operators::IncrementKernel, + paddle::operators::IncrementKernel, + paddle::operators::IncrementKernel); diff --git a/paddle/operators/increment_op.h b/paddle/operators/increment_op.h index 342e254fc45..3d53256dd1a 100644 --- a/paddle/operators/increment_op.h +++ b/paddle/operators/increment_op.h @@ -19,7 +19,7 @@ namespace paddle { namespace operators { -template +template class IncrementKernel : public framework::OpKernel { public: virtual void Compute(const framework::ExecutionContext& context) const { @@ -27,7 +27,7 @@ class IncrementKernel : public framework::OpKernel { auto* in = context.Input("X"); tensor->mutable_data(in->place()); - auto step = static_cast(context.Attr("step")); + auto step = static_cast(context.Attr("step")); auto eigen_out = framework::EigenVector::Flatten(*tensor); auto eigen_in = framework::EigenVector::Flatten(*in); diff --git a/paddle/operators/tensor_array_read_write_op.cc b/paddle/operators/tensor_array_read_write_op.cc new file mode 100644 index 00000000000..11eebfe9e6f --- /dev/null +++ b/paddle/operators/tensor_array_read_write_op.cc @@ -0,0 +1,219 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/paddle/operators/tensor_array_read_write_op.cc b/paddle/operators/tensor_array_read_write_op.cc
new file mode 100644
index 00000000000..11eebfe9e6f
--- /dev/null
+++ b/paddle/operators/tensor_array_read_write_op.cc
@@ -0,0 +1,219 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+#include "paddle/framework/lod_tensor_array.h"
+#include "paddle/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+class ArrayOpBase : public framework::OperatorBase {
+ public:
+  ArrayOpBase(const std::string &type, const framework::VariableNameMap &inputs,
+              const framework::VariableNameMap &outputs,
+              const framework::AttributeMap &attrs)
+      : OperatorBase(type, inputs, outputs, attrs) {}
+  void Run(const framework::Scope &scope,
+           const platform::DeviceContext &dev_ctx) const override {}
+
+ protected:
+  size_t GetOffset(const framework::Scope &scope,
+                   const platform::DeviceContext &dev_ctx) const {
+    auto *i = scope.FindVar(Input("I"));
+    PADDLE_ENFORCE(i != nullptr, "I must be set");
+    auto &i_tensor = i->Get<framework::LoDTensor>();
+    PADDLE_ENFORCE_EQ(i_tensor.numel(), 1);
+    size_t offset;
+    if (platform::is_gpu_place(i_tensor.place())) {
+      // FIXME: Avoid copy from GPU to CPU
+      framework::Tensor t;
+      t.CopyFrom(i_tensor, platform::CPUPlace(), dev_ctx);
+      dev_ctx.Wait();
+      offset = static_cast<size_t>(*t.data<int64_t>());
+    } else {
+      offset = static_cast<size_t>(*i_tensor.data<int64_t>());
+    }
+    return offset;
+  }
+};
+
+class WriteToArrayOp : public ArrayOpBase {
+ public:
+  WriteToArrayOp(const std::string &type,
+                 const framework::VariableNameMap &inputs,
+                 const framework::VariableNameMap &outputs,
+                 const framework::AttributeMap &attrs)
+      : ArrayOpBase(type, inputs, outputs, attrs) {}
+
+  void Run(const framework::Scope &scope,
+           const platform::DeviceContext &dev_ctx) const override {
+    auto *x = scope.FindVar(Input("X"));
+    PADDLE_ENFORCE(x != nullptr, "X must be set");
+    auto &x_tensor = x->Get<framework::LoDTensor>();
+    size_t offset = GetOffset(scope, dev_ctx);
+    auto *out =
+        scope.FindVar(Output("Out"))->GetMutable<framework::LoDTensorArray>();
+    if (offset >= out->size()) {
+      out->resize(offset + 1);
+    }
+    auto *out_tensor = &out->at(offset);
+    out_tensor->CopyFrom(x_tensor, dev_ctx.GetPlace(), dev_ctx);
+    out_tensor->set_lod(x_tensor.lod());
+  }
+};
+
+class WriteToArrayOpProtoMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  WriteToArrayOpProtoMaker(framework::OpProto *proto,
+                           framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X", "(LoDTensor) the tensor will be written to tensor array");
+    AddInput(
+        "I",
+        "(Tensor) the subscript index in tensor array. The number of element "
+        "should be 1");
+    AddOutput("Out", "(TensorArray) the tensor array will be written");
+    AddComment(R"DOC(Write a LoDTensor to a LoDTensor array.
+
+Assume T is LoDTensor, i is the subscript of the array, and A is the array. The
+equation is
+
+A[i] = T
+)DOC");
+  }
+};
+
+class WriteToArrayInferShape : public framework::InferShapeBase {
+ public:
+  void operator()(framework::InferShapeContext *context) const override {
+    PADDLE_ENFORCE(context->HasInput("I"), "Must set the subscript index");
+    PADDLE_ENFORCE_EQ(framework::product(context->GetInputDim("I")), 1,
+                      "The number of element of subscript index must be 1");
+    PADDLE_ENFORCE(context->HasInput("X"), NotHasXError());
+    PADDLE_ENFORCE(context->HasOutput("Out"), NotHasOutError());
+    context->SetOutputDim("Out", context->GetInputDim("X"));
+  }
+
+ protected:
+  virtual const char *NotHasXError() const { return "Must set the lod tensor"; }
+
+  virtual const char *NotHasOutError() const {
+    return "Must set the lod tensor array";
+  }
+};
+
+class WriteToArrayInferVarType : public framework::VarTypeInference {
+ public:
+  void operator()(const framework::OpDescBind &op_desc,
+                  framework::BlockDescBind *block) const override {
+    for (auto &out_var : op_desc.OutputArgumentNames()) {
+      VLOG(10) << "Set Variable " << out_var << " as LOD_TENSOR_ARRAY";
+      block->Var(out_var)->SetType(framework::VarDesc::LOD_TENSOR_ARRAY);
+    }
+  }
+};
+
+class ReadFromArrayOp : public ArrayOpBase {
+ public:
+  ReadFromArrayOp(const std::string &type,
+                  const framework::VariableNameMap &inputs,
+                  const framework::VariableNameMap &outputs,
+                  const framework::AttributeMap &attrs)
+      : ArrayOpBase(type, inputs, outputs, attrs) {}
+  void Run(const framework::Scope &scope,
+           const platform::DeviceContext &dev_ctx) const override {
+    auto *x = scope.FindVar(Input("X"));
+    PADDLE_ENFORCE(x != nullptr, "X must be set");
+    auto &x_array = x->Get<framework::LoDTensorArray>();
+    auto *out = scope.FindVar(Output("Out"));
+    PADDLE_ENFORCE(out != nullptr, "Out must be set");
+    auto *out_tensor = out->GetMutable<framework::LoDTensor>();
+    size_t offset = GetOffset(scope, dev_ctx);
+    PADDLE_ENFORCE_LT(offset, x_array.size());
+    out_tensor->CopyFrom(x_array[offset], dev_ctx.GetPlace(), dev_ctx);
+    out_tensor->set_lod(x_array[offset].lod());
+  }
+};
+
+class ReadFromArrayProtoMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  ReadFromArrayProtoMaker(framework::OpProto *proto,
+                          framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X", "(TensorArray) the array will be read from.");
+    AddInput("I",
+             "(Tensor) the subscript index in tensor array. The number of "
+             "element should be 1");
+    AddOutput("Out", "(LoDTensor) the tensor will be read from.");
+    AddComment(R"DOC(Read a LoDTensor from a LoDTensor Array.
+
+Assume T is LoDTensor, i is the subscript of the array, and A is the array. The
+equation is
+
+T = A[i]
+)DOC");
+  }
+};
+
+class ReadFromArrayInferShape : public WriteToArrayInferShape {
+ protected:
+  const char *NotHasXError() const override {
+    return "The input array X must be set";
+  }
+  const char *NotHasOutError() const override {
+    return "The output tensor out must be set";
+  }
+};
+
+class WriteToArrayGradMaker : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDescBind> Apply() const override {
+    auto *grad_op = new framework::OpDescBind();
+    grad_op->SetType("read_from_array");
+    grad_op->SetInput("I", Input("I"));
+    grad_op->SetInput("X", OutputGrad("Out"));
+    grad_op->SetOutput("Out", InputGrad("X"));
+    grad_op->SetAttrMap(Attrs());
+    return std::unique_ptr<framework::OpDescBind>(grad_op);
+  }
+};
+
+class ReadFromArrayGradMaker : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDescBind> Apply() const override {
+    auto *grad_op = new framework::OpDescBind();
+    grad_op->SetType("write_to_array");
+    grad_op->SetInput("I", Input("I"));
+    grad_op->SetInput("X", OutputGrad("Out"));
+    grad_op->SetOutput("Out", InputGrad("X"));
+    grad_op->SetAttrMap(Attrs());
+    return std::unique_ptr<framework::OpDescBind>(grad_op);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(write_to_array, ops::WriteToArrayOp,
+                  ops::WriteToArrayInferShape, ops::WriteToArrayOpProtoMaker,
+                  ops::WriteToArrayGradMaker, ops::WriteToArrayInferVarType);
+REGISTER_OPERATOR(read_from_array, ops::ReadFromArrayOp,
+                  ops::ReadFromArrayInferShape, ops::ReadFromArrayProtoMaker,
+                  ops::ReadFromArrayGradMaker);
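In short, write_to_array implements A[i] = T, growing the array when i is past
the end, and read_from_array implements T = A[i]; each op's gradient is the
other one with X and Out swapped. A minimal sketch through the Python layers
added later in this patch:

    import paddle.v2.framework.layers as layers

    x = layers.data(name='x', shape=[100])
    i = layers.zeros(shape=[1], dtype='int64')
    arr = layers.array_write(x=x, i=i)     # A[0] = x; creates the array
    y = layers.array_read(array=arr, i=i)  # y = A[0]
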
diff --git a/paddle/pybind/protobuf.cc b/paddle/pybind/protobuf.cc
index 5462e6c6c7a..5a1ff9b7976 100644
--- a/paddle/pybind/protobuf.cc
+++ b/paddle/pybind/protobuf.cc
@@ -97,6 +97,15 @@ namespace pybind {
 
 using namespace paddle::framework;  // NOLINT
 
+template <typename T>
+static py::bytes SerializeMessage(T &self) {
+  // Check IsInitialized in Python
+  std::string retv;
+  PADDLE_ENFORCE(self.Proto()->SerializePartialToString(&retv),
+                 "Cannot serialize message");
+  return retv;
+}
+
 // Bind Methods
 void BindProgramDesc(py::module &m) {
   py::class_<ProgramDescBind>(m, "ProgramDesc", "")
@@ -132,17 +141,7 @@ void BindProgramDesc(py::module &m) {
       .def("block", &ProgramDescBind::MutableBlock,
            py::return_value_policy::reference)
       .def("num_blocks", &ProgramDescBind::Size)
-      .def("serialize_to_string",
-           [](ProgramDescBind &program_desc) -> py::bytes {
-             const ProgramDesc *desc = program_desc.Proto();
-             PADDLE_ENFORCE(desc->IsInitialized(),
-                            "ProgramDesc has not been initialized.");
-             std::string res;
-             PADDLE_ENFORCE(
-                 desc->SerializeToString(&res),
-                 "Serialize ProgramDesc Error. This could be a bug of Paddle.");
-             return res;
-           })
+      .def("serialize_to_string", SerializeMessage<ProgramDescBind>)
       .def("parse_from_string",
            [](ProgramDescBind &program_desc, const std::string &data) {
              ProgramDesc *desc = program_desc.Proto();
@@ -181,16 +180,7 @@ void BindBlockDesc(py::module &m) {
            py::return_value_policy::reference)
       .def("op_size", &BlockDescBind::OpSize)
       .def("op", &BlockDescBind::Op, py::return_value_policy::reference)
-      .def("serialize_to_string", [](BlockDescBind &block_desc) -> py::bytes {
-        const BlockDesc *desc = block_desc.Proto();
-        PADDLE_ENFORCE(desc->IsInitialized(),
-                       "BlockDesc has not been initialized.");
-        std::string res;
-        PADDLE_ENFORCE(
-            desc->SerializeToString(&res),
-            "Serialize BlockDesc Error. This could be a bug of Paddle.");
-        return res;
-      });
+      .def("serialize_to_string", SerializeMessage<BlockDescBind>);
 }
 
 void BindVarDsec(py::module &m) {
@@ -219,17 +209,7 @@ void BindVarDsec(py::module &m) {
       .def("set_lod_level", &VarDescBind::SetLoDLevel)
       .def("type", &VarDescBind::GetType)
       .def("set_type", &VarDescBind::SetType)
-      .def("serialize_to_string",
-           [](VarDescBind &var_desc) -> py::bytes {
-             const VarDesc *desc = var_desc.Proto();
-             PADDLE_ENFORCE(desc->IsInitialized(),
-                            "VarDesc has not been initialized.");
-             std::string res;
-             PADDLE_ENFORCE(
-                 desc->SerializeToString(&res),
-                 "Serialize VarDesc Error. This could be a bug of Paddle.");
-             return res;
-           })
+      .def("serialize_to_string", SerializeMessage<VarDescBind>)
       .def("persistable", &VarDescBind::Persistable)
       .def("set_persistable", &VarDescBind::SetPersistable);
@@ -274,16 +254,7 @@ void BindOpDesc(py::module &m) {
       .def("check_attrs", &OpDescBind::CheckAttrs)
       .def("infer_shape", &OpDescBind::InferShape)
       .def("infer_var_type", &OpDescBind::InferVarType)
-      .def("serialize_to_string", [](OpDescBind &op_desc) -> py::bytes {
-        const OpDesc *desc = op_desc.Proto();
-        PADDLE_ENFORCE(desc->IsInitialized(),
-                       "OpDesc has not been initialized.");
-        std::string res;
-        PADDLE_ENFORCE(
-            desc->SerializeToString(&res),
-            "Serialize OpDesc Error. This could be a bug of Paddle.");
-        return res;
-      });
+      .def("serialize_to_string", SerializeMessage<OpDescBind>);
 }
 
 }  // namespace pybind
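SerializePartialToString intentionally skips the IsInitialized check that the
removed lambdas performed; validation moves to the Python side, which can name
the offending fields. A sketch of the Python half of this contract, mirroring
_debug_string_ in framework.py below (proto stands for any framework_pb2
message):

    error_fields = list()
    if not proto.IsInitialized(error_fields):
        raise ValueError("%s are not initialized" % error_fields)
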
diff --git a/python/paddle/v2/framework/executor.py b/python/paddle/v2/framework/executor.py
index 8268d0d8f51..f5c833190e7 100644
--- a/python/paddle/v2/framework/executor.py
+++ b/python/paddle/v2/framework/executor.py
@@ -1,5 +1,5 @@
 import paddle.v2.framework.core as core
-from paddle.v2.framework.framework import Block, Program
+from paddle.v2.framework.framework import Block, Program, g_main_program
 
 g_scope = core.Scope()
@@ -18,7 +18,7 @@ class Executor(object):
         self.executor = core.Executor(act_places)
 
     def run(self,
-            program,
+            program=None,
             feed=None,
             fetch_list=None,
             feed_var_name='feed',
@@ -29,6 +29,9 @@ class Executor(object):
         if fetch_list is None:
             fetch_list = []
 
+        if program is None:
+            program = g_main_program
+
         if not isinstance(program, Program):
             raise TypeError()
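With program defaulting to g_main_program, the common case no longer threads
the program through explicitly. A sketch, where the feed tensor and fetch_var
(a variable built on the default g_main_program) are assumed to exist:

    import paddle.v2.framework.core as core
    from paddle.v2.framework.executor import Executor

    exe = Executor(core.CPUPlace())
    outs = exe.run(feed={'x': tensor}, fetch_list=[fetch_var])
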
diff --git a/python/paddle/v2/framework/framework.py b/python/paddle/v2/framework/framework.py
index dd23c47961b..8fb3cca91e5 100644
--- a/python/paddle/v2/framework/framework.py
+++ b/python/paddle/v2/framework/framework.py
@@ -12,6 +12,14 @@ def unique_name(prefix):
     return "_".join([prefix, str(uid)])
 
 
+def _debug_string_(proto):
+    error_fields = list()
+    if not proto.IsInitialized(error_fields):
+        raise ValueError("{0} are not initialized\nThe message is {1}".format(
+            error_fields, proto))
+    return proto.__str__()
+
+
 class Variable(object):
     def __init__(self,
                  block,
@@ -95,7 +103,7 @@ class Variable(object):
     def __str__(self):
         protostr = self.desc.serialize_to_string()
         proto = framework_pb2.VarDesc.FromString(str(protostr))
-        return proto.__str__()
+        return _debug_string_(proto)
 
     __repr__ = __str__
@@ -286,7 +294,7 @@ class Operator(object):
     def __str__(self):
         protostr = self.desc.serialize_to_string()
         proto = framework_pb2.OpDesc.FromString(str(protostr))
-        return proto.__str__()
+        return _debug_string_(proto)
 
     __repr__ = __str__
@@ -343,7 +351,7 @@ class Block(object):
     def __str__(self):
         protostr = self.desc.serialize_to_string()
         proto = framework_pb2.BlockDesc.FromString(str(protostr))
-        return proto.__str__()
+        return _debug_string_(proto)
 
     __repr__ = __str__
@@ -448,7 +456,7 @@ class Program(object):
     def __str__(self):
         protostr = self.desc.serialize_to_string()
         proto = framework_pb2.ProgramDesc.FromString(str(protostr))
-        return proto.__str__()
+        return _debug_string_(proto)
 
     def clone(self):
         p = Program()
diff --git a/python/paddle/v2/framework/layers.py b/python/paddle/v2/framework/layers.py
index b7e468fb51c..70b6c567200 100644
--- a/python/paddle/v2/framework/layers.py
+++ b/python/paddle/v2/framework/layers.py
@@ -1,6 +1,8 @@
 import paddle.v2.framework.core as core
-from paddle.v2.framework.framework import OpProtoHolder, Variable, Program, Operator
-from paddle.v2.framework.initializer import ConstantInitializer, NormalInitializer
+from paddle.v2.framework.framework import OpProtoHolder, Variable, Program, \
+    Operator
+from paddle.v2.framework.initializer import ConstantInitializer, \
+    NormalInitializer
 from paddle.v2.framework.layer_helper import LayerHelper, unique_name
 import re
 
@@ -751,3 +753,67 @@ def lod_rank_table(x, level=0, main_program=None):
         outputs={'Out': table},
         attrs={'level': level})
     return table
+
+
+def fill_constant(shape, dtype, value, main_program=None):
+    helper = LayerHelper("fill_constant", **locals())
+    out = helper.create_tmp_variable(dtype=dtype)
+    helper.append_op(
+        type='fill_constant',
+        inputs={},
+        outputs={'Out': [out]},
+        attrs={
+            'shape': shape,
+            'data_type': out.data_type,
+            'value': float(value)
+        })
+    out.stop_gradient = True
+    return out
+
+
+def ones(shape, dtype, main_program=None):
+    return fill_constant(value=1.0, **locals())
+
+
+def zeros(shape, dtype, main_program=None):
+    return fill_constant(value=0.0, **locals())
+
+
+def increment(x, value=1.0, main_program=None):
+    helper = LayerHelper("increment", **locals())
+    helper.append_op(
+        type='increment',
+        inputs={'X': [x]},
+        outputs={'Out': [x]},
+        attrs={'step': value})
+    return x
+
+
+def array_write(x, i, array=None, main_program=None):
+    helper = LayerHelper('array_write', **locals())
+    if array is None:
+        array = helper.create_variable(
+            name="{0}.out".format(helper.name),
+            type=core.VarDesc.VarType.LOD_TENSOR_ARRAY,
+            dtype=x.data_type)
+    helper.append_op(
+        type='write_to_array',
+        inputs={'X': [x],
+                'I': [i]},
+        outputs={'Out': [array]})
+    return array
+
+
+def array_read(array, i, main_program=None):
+    helper = LayerHelper('array_read', **locals())
+    if not isinstance(
+            array,
+            Variable) or array.type != core.VarDesc.VarType.LOD_TENSOR_ARRAY:
+        raise TypeError("array should be tensor array variable")
+    out = helper.create_tmp_variable(dtype=array.data_type)
+    helper.append_op(
+        type='read_from_array',
+        inputs={'X': [array],
+                'I': [i]},
+        outputs={'Out': [out]})
+    return out
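Note the design choice in array_write: when array is None, it creates the
LOD_TENSOR_ARRAY variable itself, so the first write allocates the array and
later writes thread it through explicitly. A sketch of chaining writes,
assuming x0 and x1 are data variables and i is an int64 index (the unit test
below does exactly this):

    arr = layers.array_write(x=x0, i=i)             # creates arr; arr[0] = x0
    layers.increment(x=i)
    arr = layers.array_write(x=x1, i=i, array=arr)  # arr[1] = x1
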
diff --git a/python/paddle/v2/framework/tests/test_array_read_write_op.py b/python/paddle/v2/framework/tests/test_array_read_write_op.py
new file mode 100644
index 00000000000..d0bf3d62f9c
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_array_read_write_op.py
@@ -0,0 +1,66 @@
+import unittest
+
+import numpy
+import paddle.v2.framework.core as core
+
+import paddle.v2.framework.layers as layers
+from paddle.v2.framework.executor import Executor
+
+
+class TestArrayReadWrite(unittest.TestCase):
+    def test_read_write(self):
+        x = [
+            layers.data(
+                name='x0', shape=[100]), layers.data(
+                    name='x1', shape=[100]), layers.data(
+                        name='x2', shape=[100])
+        ]
+
+        for each_x in x:
+            each_x.stop_gradient = False
+
+        i = layers.zeros(shape=[1], dtype='int64')
+        arr = layers.array_write(x=x[0], i=i)
+        layers.increment(x=i)
+        arr = layers.array_write(x=x[1], i=i, array=arr)
+        layers.increment(x=i)
+        arr = layers.array_write(x=x[2], i=i, array=arr)
+
+        i = layers.zeros(shape=[1], dtype='int64')
+        a0 = layers.array_read(array=arr, i=i)
+        layers.increment(x=i)
+        a1 = layers.array_read(array=arr, i=i)
+        layers.increment(x=i)
+        a2 = layers.array_read(array=arr, i=i)
+
+        mean_a0 = layers.mean(x=a0)
+        mean_a1 = layers.mean(x=a1)
+        mean_a2 = layers.mean(x=a2)
+
+        a_sum = layers.sums(input=[mean_a0, mean_a1, mean_a2])
+
+        mean_x0 = layers.mean(x=x[0])
+        mean_x1 = layers.mean(x=x[1])
+        mean_x2 = layers.mean(x=x[2])
+
+        x_sum = layers.sums(input=[mean_x0, mean_x1, mean_x2])
+
+        scope = core.Scope()
+        cpu = core.CPUPlace()
+
+        exe = Executor(cpu)
+
+        tensor = core.LoDTensor()
+        tensor.set(numpy.random.random(size=(100, 100)).astype('float32'), cpu)
+
+        outs = map(numpy.array,
+                   exe.run(feed={'x0': tensor,
+                                 'x1': tensor,
+                                 'x2': tensor},
+                           fetch_list=[a_sum, x_sum],
+                           scope=scope))
+        self.assertEqual(outs[0], outs[1])
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/v2/framework/tests/test_framework_debug_str.py b/python/paddle/v2/framework/tests/test_framework_debug_str.py
new file mode 100644
index 00000000000..8fdf8f91171
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_framework_debug_str.py
@@ -0,0 +1,13 @@
+import unittest
+from paddle.v2.framework.framework import Program
+
+
+class TestDebugStringFramework(unittest.TestCase):
+    def test_debug_str(self):
+        p = Program()
+        p.current_block().create_var(name='t', shape=[0, 1])
+        self.assertRaises(ValueError, callableObj=p.__str__)
+
+
+if __name__ == '__main__':
+    unittest.main()
-- 
GitLab